py2ls 0.2.4.15__py3-none-any.whl → 0.2.4.16__py3-none-any.whl
- py2ls/.git/index +0 -0
- py2ls/ips.py +722 -12
- py2ls/ml2ls copy.py +2906 -0
- py2ls/ml2ls.py +345 -12
- py2ls/plot.py +409 -24
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/RECORD +8 -7
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/WHEEL +0 -0
py2ls/ml2ls copy.py
ADDED
@@ -0,0 +1,2906 @@
from sklearn.ensemble import (
    RandomForestClassifier,
    GradientBoostingClassifier,
    AdaBoostClassifier,
    BaggingClassifier,
)
from sklearn.svm import SVC, SVR
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.linear_model import (
    LassoCV,
    LogisticRegression,
    LinearRegression,
    Lasso,
    Ridge,
    RidgeClassifierCV,
    ElasticNet,
)
from sklearn.feature_selection import RFE
from sklearn.inspection import permutation_importance  # needed by features_knn() and features_bagging()
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import xgboost as xgb  # Make sure you have xgboost installed

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
    matthews_corrcoef,
    roc_curve,
    auc,
    balanced_accuracy_score,
    precision_recall_curve,
    average_precision_score,
)
from imblearn.over_sampling import SMOTE
from sklearn.pipeline import Pipeline
from collections import defaultdict
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from typing import Dict, Any, Optional, List, Union
import numpy as np
import pandas as pd
from . import ips
from . import plot
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use(str(ips.get_cwd()) + "/data/styles/stylelib/paper.mplstyle")
import logging
import warnings

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger()

# Ignore specific warnings (UserWarning in this case)
warnings.filterwarnings("ignore", category=UserWarning)
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier

def features_knn(
    x_train: pd.DataFrame, y_train: pd.Series, knn_params: dict
) -> pd.DataFrame:
    """
    A distance-based classifier that assigns labels based on the majority label of nearest neighbors.
    when to use:
        Effective for small to medium datasets with a low number of features.
        It does not directly provide feature importances but can be assessed through feature permutation or similar methods.
    Recommended Use: Effective for datasets with low feature dimensionality and well-separated clusters.

    Fits KNeighborsClassifier and approximates feature influence using permutation importance.
    """
    knn = KNeighborsClassifier(**knn_params)
    knn.fit(x_train, y_train)
    importances = permutation_importance(
        knn, x_train, y_train, n_repeats=30, random_state=1, scoring="accuracy"
    )
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": importances.importances_mean}
    ).sort_values(by="importance", ascending=False)
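
#! usage sketch (illustration only, not executed on import): the feature-ranking helpers in this
#! module share one calling convention -- a numeric feature DataFrame, a label Series and a dict of
#! estimator kwargs. `X_demo`/`y_demo` below are hypothetical names, not objects defined here.
# X_demo = pd.DataFrame(np.random.rand(100, 5), columns=[f"gene_{i}" for i in range(5)])
# y_demo = pd.Series(np.random.randint(0, 2, 100))
# knn_rank = features_knn(X_demo, y_demo, {"n_neighbors": 5})
# print(knn_rank.head())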

#! 1. Linear and Regularized Regression Methods
# 1.1 Lasso
def features_lasso(
    x_train: pd.DataFrame, y_train: pd.Series, lasso_params: dict
) -> pd.DataFrame:
    """
    Lasso (Least Absolute Shrinkage and Selection Operator):
    A regularized linear regression method that uses an L1 penalty to shrink coefficients, effectively
    performing feature selection by zeroing out less important ones.
    """
    lasso = LassoCV(**lasso_params)
    lasso.fit(x_train, y_train)
    # Get non-zero coefficients and their corresponding features
    coefficients = lasso.coef_
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(coefficients)}
    )
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )


# 1.2 Ridge regression
def features_ridge(
    x_train: pd.DataFrame, y_train: pd.Series, ridge_params: dict
) -> pd.DataFrame:
    """
    Ridge Regression: A linear regression technique that applies L2 regularization, reducing coefficient
    magnitudes to avoid overfitting, especially with multicollinearity among features.
    """
    from sklearn.linear_model import RidgeCV

    ridge = RidgeCV(**ridge_params)
    ridge.fit(x_train, y_train)

    # Get the coefficients
    coefficients = ridge.coef_

    # Create a DataFrame to hold feature importance
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(coefficients)}
    )
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )


# 1.3 Elastic Net (Enet)
def features_enet(
    x_train: pd.DataFrame, y_train: pd.Series, enet_params: dict
) -> pd.DataFrame:
    """
    Elastic Net (Enet): Combines L1 and L2 penalties (lasso and ridge) in a linear model, beneficial
    when features are highly correlated or for datasets with more features than samples.
    """
    from sklearn.linear_model import ElasticNetCV

    enet = ElasticNetCV(**enet_params)
    enet.fit(x_train, y_train)
    # Get the coefficients
    coefficients = enet.coef_
    # Create a DataFrame to hold feature importance
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(coefficients)}
    )
    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )


# 1.4 Partial Least Squares Regression for Generalized Linear Models (plsRglm): combines regression and
# feature reduction, useful for high-dimensional data with correlated features, such as genomics.
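
#! usage sketch (illustration only): the *CV-based selectors expect "alphas"/"cv" keys, mirroring the
#! defaults assembled later in get_features(); `X_demo`/`y_demo` are the hypothetical inputs from above.
# lasso_rank = features_lasso(X_demo, y_demo, {"alphas": np.logspace(-4, 4, 100), "cv": 10})
# ridge_rank = features_ridge(X_demo, y_demo, {"alphas": np.logspace(-4, 4, 100), "cv": 10})
# enet_rank = features_enet(X_demo, y_demo, {"alphas": np.logspace(-4, 4, 100), "cv": 10})
# top10 = lasso_rank.head(10)["feature"].tolist()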

#! 2. Generalized Linear Models and Extensions
# 2.1


#! 3. Tree-Based and Ensemble Methods
# 3.1 Random Forest (RF)
def features_rf(
    x_train: pd.DataFrame, y_train: pd.Series, rf_params: dict
) -> pd.DataFrame:
    """
    An ensemble of decision trees that combines predictions from multiple trees for classification or
    regression, effective with high-dimensional, complex datasets.
    when to use:
        Handles high-dimensional data well.
        Robust to overfitting due to averaging of multiple trees.
        Provides feature importance, which can help in understanding the influence of different genes.
    Fit Random Forest and return sorted feature importances.
    Recommended Use: Great for classification problems, especially when you have many features (genes).
    """
    rf = RandomForestClassifier(**rf_params)
    rf.fit(x_train, y_train)
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": rf.feature_importances_}
    ).sort_values(by="importance", ascending=False)


# 3.2 Gradient Boosting Trees
def features_gradient_boosting(
    x_train: pd.DataFrame, y_train: pd.Series, gb_params: dict
) -> pd.DataFrame:
    """
    An ensemble of decision trees that combines predictions from multiple trees for classification or regression,
    effective with high-dimensional, complex datasets.
    Gradient Boosting
    Strengths:
        High predictive accuracy and works well for both classification and regression.
        Can handle a mixture of numerical and categorical features.
    Recommended Use:
        Effective for complex relationships and when you need a powerful predictive model.
    Fit Gradient Boosting classifier and return sorted feature importances.
    Recommended Use: Effective for complex datasets with many features (genes).
    """
    gb = GradientBoostingClassifier(**gb_params)
    gb.fit(x_train, y_train)
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": gb.feature_importances_}
    ).sort_values(by="importance", ascending=False)


# 3.3 XGBoost
def features_xgb(
    x_train: pd.DataFrame, y_train: pd.Series, xgb_params: dict
) -> pd.DataFrame:
    """
    XGBoost: An advanced gradient boosting technique, faster and more efficient than GBM, with excellent predictive performance on structured data.
    """
    import xgboost as xgb

    xgb_model = xgb.XGBClassifier(**xgb_params)
    xgb_model.fit(x_train, y_train)
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": xgb_model.feature_importances_}
    ).sort_values(by="importance", ascending=False)


# 3.4 Decision tree
def features_decision_tree(
    x_train: pd.DataFrame, y_train: pd.Series, dt_params: dict
) -> pd.DataFrame:
    """
    A single decision tree classifier effective for identifying key decision boundaries in data.
    when to use:
        Good for capturing non-linear patterns.
        Provides feature importance scores for each feature, though it may overfit on small datasets.
        Efficient for low to medium-sized datasets, where interpretability of decisions is key.
    Recommended Use: Useful for interpretable feature importance analysis in smaller or balanced datasets.

    Fits DecisionTreeClassifier and returns sorted feature importances.
    """
    dt = DecisionTreeClassifier(**dt_params)
    dt.fit(x_train, y_train)
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": dt.feature_importances_}
    ).sort_values(by="importance", ascending=False)


# 3.5 Bagging
def features_bagging(
    x_train: pd.DataFrame, y_train: pd.Series, bagging_params: dict
) -> pd.DataFrame:
    """
    A bagging ensemble of models, often used with weak learners like decision trees, to reduce variance.
    when to use:
        Helps reduce overfitting, especially on high-variance models.
        Effective when the dataset has numerous features and may benefit from ensemble stability.
    Recommended Use: Beneficial for high-dimensional or noisy datasets needing ensemble stability.

    Fits BaggingClassifier and returns averaged feature importances from underlying estimators if available.
    """
    bagging = BaggingClassifier(**bagging_params)
    bagging.fit(x_train, y_train)

    # Calculate feature importance by averaging importances across estimators, if feature_importances_ is available.
    if hasattr(bagging.estimators_[0], "feature_importances_"):
        importances = np.mean(
            [estimator.feature_importances_ for estimator in bagging.estimators_],
            axis=0,
        )
        return pd.DataFrame(
            {"feature": x_train.columns, "importance": importances}
        ).sort_values(by="importance", ascending=False)
    else:
        # If the base estimator does not support feature importances, fall back to permutation importance.
        importances = permutation_importance(
            bagging, x_train, y_train, n_repeats=30, random_state=1, scoring="accuracy"
        )
        return pd.DataFrame(
            {"feature": x_train.columns, "importance": importances.importances_mean}
        ).sort_values(by="importance", ascending=False)
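
#! usage sketch (illustration only): the tree-based selectors take plain estimator kwargs; the values
#! shown mirror the defaults assembled later in get_features().
# rf_rank = features_rf(X_demo, y_demo, {"n_estimators": 100, "random_state": 1})
# gb_rank = features_gradient_boosting(X_demo, y_demo, {"n_estimators": 100, "random_state": 1})
# xgb_rank = features_xgb(X_demo, y_demo, {"n_estimators": 100, "eval_metric": "logloss"})
# dt_rank = features_decision_tree(X_demo, y_demo, {"max_depth": None, "random_state": 1})
# bagging_rank = features_bagging(X_demo, y_demo, {"n_estimators": 50, "random_state": 1})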

#! 4. Support Vector Machines
def features_svm(
    x_train: pd.DataFrame, y_train: pd.Series, rfe_params: dict
) -> pd.Index:
    """
    Suitable for classification tasks where the number of features is much larger than the number of samples.
        1. Effective in high-dimensional spaces and with a clear margin of separation.
        2. Works well for both linear and non-linear classification (using kernel functions).
    Select features using RFE with SVM. When combined with SVM, RFE selects the features that are most critical for the decision boundary,
    helping reduce the dataset to a more manageable size without losing much predictive power.
    SVM (Support Vector Machines) supports various kernels (linear, rbf, poly, and sigmoid) and is good at handling high-dimensional
    data and finding an optimal decision boundary between classes, especially when using the right kernel.
    kernel: ["linear", "rbf", "poly", "sigmoid"]
        'linear': the simplest kernel; it attempts to separate data by drawing a straight line (or hyperplane) between classes. It is effective
            when the data is linearly separable, meaning the classes can be well divided by a straight boundary.
            Advantages:
                - Computationally efficient for large datasets.
                - Works well when the number of features is high, which is common in genomic data where you may have thousands of genes
                  as features.
        'rbf': a nonlinear kernel that maps the input data into a higher-dimensional space to find a decision boundary. It works well for
            data that is not linearly separable in its original space.
            Advantages:
                - Handles nonlinear relationships between features and classes.
                - Often better than a linear kernel when there is no clear linear decision boundary in the data.
        'poly': polynomial kernel; computes similarity between data points based on polynomial functions of the input features. It can model
            interactions between features to a certain degree, depending on the polynomial degree chosen.
            Advantages:
                - Allows modeling of feature interactions.
                - Can fit more complex relationships compared to linear models.
        'sigmoid': similar to the activation function in neural networks; it works well when the data follows an S-shaped decision boundary.
            Advantages:
                - Can approximate the behavior of neural networks.
                - Use case: not as widely used as the RBF or linear kernel, but can be explored when there is some evidence of non-linear,
                  S-shaped relationships.
    """
    # SVM (Support Vector Machines)
    svc = SVC(kernel=rfe_params["kernel"])  # ["linear", "rbf", "poly", "sigmoid"]
    # RFE (Recursive Feature Elimination)
    # NOTE: RFE ranks features via the estimator's coef_, so a linear kernel is expected here.
    selector = RFE(svc, n_features_to_select=rfe_params["n_features_to_select"])
    selector.fit(x_train, y_train)
    return x_train.columns[selector.support_]
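
#! usage sketch (illustration only): RFE needs an estimator that exposes coef_, so the linear kernel
#! (the default used by get_features()) is assumed here.
# svm_features = features_svm(X_demo, y_demo, {"kernel": "linear", "n_features_to_select": 10})
# print(list(svm_features))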

#! 5. Bayesian and Probabilistic Methods
def features_naive_bayes(x_train: pd.DataFrame, y_train: pd.Series) -> pd.Index:
    """
    Naive Bayes: A probabilistic classifier based on Bayes' theorem, assuming independence between features; simple and fast, and especially
    effective for text classification and other high-dimensional data.
    """
    from sklearn.naive_bayes import GaussianNB

    nb = GaussianNB()
    nb.fit(x_train, y_train)
    probabilities = nb.predict_proba(x_train)
    # Limit the number of features safely, choosing the lesser of half the features or all columns
    n_features = min(x_train.shape[1] // 2, len(x_train.columns))

    # Sort probabilities, then map to valid column indices
    sorted_indices = np.argsort(probabilities.max(axis=1))[:n_features]

    # Ensure indices are within the column bounds of x_train
    valid_indices = sorted_indices[sorted_indices < len(x_train.columns)]

    return x_train.columns[valid_indices]


#! 6. Linear Discriminant Analysis (LDA)
def features_lda(x_train: pd.DataFrame, y_train: pd.Series) -> pd.DataFrame:
    """
    Linear Discriminant Analysis (LDA): Projects data onto a lower-dimensional space to maximize class separability, often used as a dimensionality
    reduction technique before classification on high-dimensional data.
    """
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

    lda = LinearDiscriminantAnalysis()
    lda.fit(x_train, y_train)
    coef = lda.coef_.flatten()
    # Create a DataFrame to hold feature importance
    importance_df = pd.DataFrame(
        {"feature": x_train.columns, "importance": np.abs(coef)}
    )

    return importance_df[importance_df["importance"] > 0].sort_values(
        by="importance", ascending=False
    )


def features_adaboost(
    x_train: pd.DataFrame, y_train: pd.Series, adaboost_params: dict
) -> pd.DataFrame:
    """
    AdaBoost
    Strengths:
        Combines multiple weak learners to create a strong classifier.
        Focuses on examples that are hard to classify, improving overall performance.
    Recommended Use:
        Can be effective for boosting weak models in a genomics context.
    Fit AdaBoost classifier and return sorted feature importances.
    Recommended Use: Great for classification problems with a large number of features (genes).
    """
    ada = AdaBoostClassifier(**adaboost_params)
    ada.fit(x_train, y_train)
    return pd.DataFrame(
        {"feature": x_train.columns, "importance": ada.feature_importances_}
    ).sort_values(by="importance", ascending=False)
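
#! usage sketch (illustration only): the probabilistic/discriminant selectors take no params dict,
#! while AdaBoost follows the estimator-kwargs convention used above.
# nb_features = features_naive_bayes(X_demo, y_demo)
# lda_rank = features_lda(X_demo, y_demo)
# ada_rank = features_adaboost(X_demo, y_demo, {"n_estimators": 50, "random_state": 1})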

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from skorch import NeuralNetClassifier  # sklearn compatible


class DNNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=2, dropout_rate=0.5):
        super(DNNClassifier, self).__init__()

        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate)
        )

        # Adding a residual connection between hidden layers
        self.residual = nn.Linear(input_dim, hidden_dim)

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, output_dim), nn.Softmax(dim=1)
        )

    def forward(self, x):
        residual = self.residual(x)
        x = self.hidden_layer1(x)
        x = x + residual  # Residual connection
        x = self.hidden_layer2(x)
        x = self.output_layer(x)
        return x
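
#! usage sketch (illustration only; the hyperparameters are assumptions): wrapping DNNClassifier in
#! skorch's NeuralNetClassifier makes it behave like an sklearn estimator, so it could be validated
#! with validate_classifier() below.
# net = NeuralNetClassifier(
#     DNNClassifier,
#     module__input_dim=X_demo.shape[1],
#     max_epochs=20,
#     lr=1e-3,
#     optimizer=optim.Adam,
#     criterion=nn.CrossEntropyLoss,
#     verbose=0,
# )
# net.fit(X_demo.values.astype(np.float32), y_demo.values.astype(np.int64))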

def validate_classifier(
    clf,
    x_train: pd.DataFrame,
    y_train: pd.Series,
    x_test: pd.DataFrame,
    y_test: pd.Series,
    metrics: list = ["accuracy", "precision", "recall", "f1", "roc_auc"],
    cv_folds: int = 5,
) -> dict:
    """
    Perform cross-validation for a given classifier and return average scores for specified metrics on training data.
    Then fit the best model on the full training data and evaluate it on the test set.

    Parameters:
    - clf: The classifier to be validated.
    - x_train: Training features.
    - y_train: Training labels.
    - x_test: Test features.
    - y_test: Test labels.
    - metrics: List of metrics to evaluate (e.g., ['accuracy', 'roc_auc']).
    - cv_folds: Number of cross-validation folds.

    Returns:
    - results: Dictionary containing average cv_train_scores and cv_test_scores.
    """
    cv_train_scores = {metric: [] for metric in metrics}
    skf = StratifiedKFold(n_splits=cv_folds)
    # Perform cross-validation
    for metric in metrics:
        try:
            if metric == "roc_auc" and len(set(y_train)) == 2:
                scores = cross_val_score(
                    clf, x_train, y_train, cv=skf, scoring="roc_auc"
                )
                cv_train_scores[metric] = (
                    np.nanmean(scores) if not np.isnan(scores).all() else float("nan")
                )
            else:
                score = cross_val_score(clf, x_train, y_train, cv=skf, scoring=metric)
                cv_train_scores[metric] = score.mean()
        except Exception as e:
            cv_train_scores[metric] = float("nan")
    clf.fit(x_train, y_train)

    # Evaluate on the test set
    cv_test_scores = {}
    for metric in metrics:
        if metric == "roc_auc" and len(set(y_test)) == 2:
            try:
                y_prob = clf.predict_proba(x_test)[:, 1]
                cv_test_scores[metric] = roc_auc_score(y_test, y_prob)
            except AttributeError:
                cv_test_scores[metric] = float("nan")
        else:
            score_func = globals().get(
                f"{metric}_score"
            )  # Fetching the appropriate scoring function
            if score_func:
                try:
                    y_pred = clf.predict(x_test)
                    cv_test_scores[metric] = score_func(y_test, y_pred)
                except Exception as e:
                    cv_test_scores[metric] = float("nan")

    # Combine results
    results = {"cv_train_scores": cv_train_scores, "cv_test_scores": cv_test_scores}
    return results
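
#! usage sketch (illustration only, assuming the hypothetical `X_demo`/`y_demo` from above):
# xtr, xte, ytr, yte = train_test_split(X_demo, y_demo, test_size=0.2, random_state=1)
# scores = validate_classifier(RandomForestClassifier(random_state=1), xtr, ytr, xte, yte)
# print(scores["cv_train_scores"], scores["cv_test_scores"])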

def get_models(
    random_state=1,
    cls=[
        "lasso",
        "ridge",
        "Elastic Net(Enet)",
        "gradient Boosting",
        "Random forest (rf)",
        "XGBoost (xgb)",
        "Support Vector Machine(svm)",
        "naive bayes",
        "Linear Discriminant Analysis (lda)",
        "adaboost",
        "DecisionTree",
        "KNeighbors",
        "Bagging",
    ],
):
    from sklearn.ensemble import (
        RandomForestClassifier,
        GradientBoostingClassifier,
        AdaBoostClassifier,
        BaggingClassifier,
    )
    from sklearn.svm import SVC
    from sklearn.linear_model import (
        LogisticRegression,
        Lasso,
        RidgeClassifierCV,
        ElasticNet,
    )
    from sklearn.naive_bayes import GaussianNB
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    import xgboost as xgb
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.neighbors import KNeighborsClassifier

    res_cls = {}
    model_all = {
        "Lasso": LogisticRegression(
            penalty="l1", solver="saga", random_state=random_state
        ),
        "Ridge": RidgeClassifierCV(),
        "Elastic Net (Enet)": ElasticNet(random_state=random_state),
        "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
        "Random Forest (RF)": RandomForestClassifier(random_state=random_state),
        "XGBoost (XGB)": xgb.XGBClassifier(random_state=random_state),
        "Support Vector Machine (SVM)": SVC(kernel="rbf", probability=True),
        "Naive Bayes": GaussianNB(),
        "Linear Discriminant Analysis (LDA)": LinearDiscriminantAnalysis(),
        "AdaBoost": AdaBoostClassifier(random_state=random_state, algorithm="SAMME"),
        "DecisionTree": DecisionTreeClassifier(),
        "KNeighbors": KNeighborsClassifier(n_neighbors=5),
        "Bagging": BaggingClassifier(),
    }
    print("Using default models:")
    for cls_name in cls:
        cls_name = ips.strcmp(cls_name, list(model_all.keys()))[0]
        res_cls[cls_name] = model_all[cls_name]
        print(f"- {cls_name}")
    return res_cls
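
#! usage sketch (illustration only): requested names are fuzzy-matched against the keys via
#! ips.strcmp, so short aliases such as "rf" or "xgb" resolve to the full model names.
# models = get_models(random_state=1, cls=["rf", "xgb", "svm"])
# for name, clf in models.items():
#     print(name, type(clf).__name__)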

def get_features(
    X: Union[pd.DataFrame, np.ndarray],  # n_samples x n_features
    y: Union[pd.Series, np.ndarray, list],  # n_samples
    test_size: float = 0.2,
    random_state: int = 1,
    n_features: int = 10,
    fill_missing=True,
    rf_params: Optional[Dict] = None,
    rfe_params: Optional[Dict] = None,
    lasso_params: Optional[Dict] = None,
    ridge_params: Optional[Dict] = None,
    enet_params: Optional[Dict] = None,
    gb_params: Optional[Dict] = None,
    adaboost_params: Optional[Dict] = None,
    xgb_params: Optional[Dict] = None,
    dt_params: Optional[Dict] = None,
    bagging_params: Optional[Dict] = None,
    knn_params: Optional[Dict] = None,
    cls: list = [
        "lasso",
        "ridge",
        "Elastic Net(Enet)",
        "gradient Boosting",
        "Random forest (rf)",
        "XGBoost (xgb)",
        "Support Vector Machine(svm)",
        "naive bayes",
        "Linear Discriminant Analysis (lda)",
        "adaboost",
        "DecisionTree",
        "KNeighbors",
        "Bagging",
    ],
    metrics: Optional[List[str]] = None,
    cv_folds: int = 5,
    strict: bool = False,
    n_shared: int = 2,  # a feature counts as "common" once at least two methods select it
    use_selected_features: bool = True,
    plot_: bool = True,
    dir_save: str = "./",
) -> dict:
    """
    Master function to perform feature selection and validate models.
    """
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import StandardScaler, OneHotEncoder

    # Ensure X and y are DataFrames/Series for consistency
    if isinstance(X, np.ndarray):
        X = pd.DataFrame(X)
    if isinstance(y, (np.ndarray, list)):
        y = pd.Series(y)

    # fill na
    if fill_missing:
        ips.df_fillna(data=X, method="knn", inplace=True, axis=0)
    if isinstance(y, str) and y in X.columns:
        y_col_name = y
        y = X[y]
        y = ips.df_encoder(pd.DataFrame(y), method="dummy")
        X = X.drop(y_col_name, axis=1)
    else:
        # keep y as a DataFrame here so the .loc alignment below works
        y = ips.df_encoder(pd.DataFrame(y), method="dummy")
    y = y.loc[X.index]  # Align y with X after dropping rows with missing values in X
    y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()

    if X.shape[0] != len(y):
        raise ValueError("X and y must have the same number of samples (rows).")

    # #! Check for non-numeric columns in X and apply one-hot encoding if needed
    # Check if any column in X is non-numeric
    if any(not np.issubdtype(dtype, np.number) for dtype in X.dtypes):
        X = pd.get_dummies(X, drop_first=True)
        print(X.shape)

    # #! alternative: # Identify categorical and numerical columns
    # categorical_cols = X.select_dtypes(include=["object", "category"]).columns
    # numerical_cols = X.select_dtypes(include=["number"]).columns

    # # Define preprocessing pipeline
    # preprocessor = ColumnTransformer(
    #     transformers=[
    #         ("num", StandardScaler(), numerical_cols),
    #         ("cat", OneHotEncoder(drop="first", handle_unknown="ignore"), categorical_cols),
    #     ]
    # )
    # # Preprocess the data
    # X = preprocessor.fit_transform(X)

    # Split data into training and test sets
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # Standardize features
    scaler = StandardScaler()
    x_train_scaled = scaler.fit_transform(x_train)
    x_test_scaled = scaler.transform(x_test)

    # Convert back to DataFrame for consistency
    x_train = pd.DataFrame(x_train_scaled, columns=x_train.columns)
    x_test = pd.DataFrame(x_test_scaled, columns=x_test.columns)

    rf_defaults = {"n_estimators": 100, "random_state": random_state}
    rfe_defaults = {"kernel": "linear", "n_features_to_select": n_features}
    lasso_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
    ridge_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
    enet_defaults = {"alphas": np.logspace(-4, 4, 100), "cv": 10}
    xgb_defaults = {
        "n_estimators": 100,
        "use_label_encoder": False,
        "eval_metric": "logloss",
        "random_state": random_state,
    }
    gb_defaults = {"n_estimators": 100, "random_state": random_state}
    adaboost_defaults = {"n_estimators": 50, "random_state": random_state}
    dt_defaults = {"max_depth": None, "random_state": random_state}
    bagging_defaults = {"n_estimators": 50, "random_state": random_state}
    knn_defaults = {"n_neighbors": 5}
    rf_params, rfe_params = rf_params or rf_defaults, rfe_params or rfe_defaults
    lasso_params, ridge_params = (
        lasso_params or lasso_defaults,
        ridge_params or ridge_defaults,
    )
    enet_params, xgb_params = enet_params or enet_defaults, xgb_params or xgb_defaults
    gb_params, adaboost_params = (
        gb_params or gb_defaults,
        adaboost_params or adaboost_defaults,
    )
    dt_params = dt_params or dt_defaults
    bagging_params = bagging_params or bagging_defaults
    knn_params = knn_params or knn_defaults

    cls_ = [
        "lasso",
        "ridge",
        "Elastic Net(Enet)",
        "Gradient Boosting",
        "Random Forest (rf)",
        "XGBoost (xgb)",
        "Support Vector Machine(svm)",
        "Naive Bayes",
        "Linear Discriminant Analysis (lda)",
        "adaboost",
    ]
    cls = [ips.strcmp(i, cls_)[0] for i in cls]

    # Lasso Feature Selection
    lasso_importances = (
        features_lasso(x_train, y_train, lasso_params)
        if "lasso" in cls
        else pd.DataFrame()
    )
    lasso_selected_features = (
        lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
    )
    # Ridge
    ridge_importances = (
        features_ridge(x_train, y_train, ridge_params)
        if "ridge" in cls
        else pd.DataFrame()
    )
    selected_ridge_features = (
        ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
    )
    # Elastic Net
    enet_importances = (
        features_enet(x_train, y_train, enet_params)
        if "Enet" in cls
        else pd.DataFrame()
    )
    selected_enet_features = (
        enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
    )
    # Random Forest Feature Importance
    rf_importances = (
        features_rf(x_train, y_train, rf_params)
        if "Random Forest" in cls
        else pd.DataFrame()
    )
    top_rf_features = (
        rf_importances.head(n_features)["feature"].values
        if "Random Forest" in cls
        else []
    )
    # Gradient Boosting Feature Importance
    gb_importances = (
        features_gradient_boosting(x_train, y_train, gb_params)
        if "Gradient Boosting" in cls
        else pd.DataFrame()
    )
    top_gb_features = (
        gb_importances.head(n_features)["feature"].values
        if "Gradient Boosting" in cls
        else []
    )
    # xgb
    xgb_importances = (
        features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
    )
    top_xgb_features = (
        xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
    )

    # SVM with RFE
    selected_svm_features = (
        features_svm(x_train, y_train, rfe_params) if "svm" in cls else []
    )
    # Naive Bayes
    selected_naive_bayes_features = (
        features_naive_bayes(x_train, y_train) if "Naive Bayes" in cls else []
    )
    # lda: linear discriminant analysis
    lda_importances = features_lda(x_train, y_train) if "lda" in cls else pd.DataFrame()
    selected_lda_features = (
        lda_importances.head(n_features)["feature"].values if "lda" in cls else []
    )
    # AdaBoost Feature Importance
    adaboost_importances = (
        features_adaboost(x_train, y_train, adaboost_params)
        if "AdaBoost" in cls
        else pd.DataFrame()
    )
    top_adaboost_features = (
        adaboost_importances.head(n_features)["feature"].values
        if "AdaBoost" in cls
        else []
    )
    # Decision Tree Feature Importance
    dt_importances = (
        features_decision_tree(x_train, y_train, dt_params)
        if "Decision Tree" in cls
        else pd.DataFrame()
    )
    top_dt_features = (
        dt_importances.head(n_features)["feature"].values
        if "Decision Tree" in cls
        else []
    )
    # Bagging Feature Importance
    bagging_importances = (
        features_bagging(x_train, y_train, bagging_params)
        if "Bagging" in cls
        else pd.DataFrame()
    )
    top_bagging_features = (
        bagging_importances.head(n_features)["feature"].values
        if "Bagging" in cls
        else []
    )
    # KNN Feature Importance via Permutation
    knn_importances = (
        features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
    )
    top_knn_features = (
        knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
    )

    #! Find common features
    common_features = ips.shared(
        lasso_selected_features,
        selected_ridge_features,
        selected_enet_features,
        top_rf_features,
        top_gb_features,
        top_xgb_features,
        selected_svm_features,
        selected_naive_bayes_features,
        selected_lda_features,
        top_adaboost_features,
        top_dt_features,
        top_bagging_features,
        top_knn_features,
        strict=strict,
        n_shared=n_shared,
        verbose=False,
    )

    # Use selected features or all features for model validation
    x_train_selected = (
        x_train[list(common_features)] if use_selected_features else x_train
    )
    x_test_selected = x_test[list(common_features)] if use_selected_features else x_test

    if metrics is None:
        metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]

    # Prepare results DataFrame for selected features
    features_df = pd.DataFrame(
        {
            "type": ["Lasso"] * len(lasso_selected_features)
            + ["Ridge"] * len(selected_ridge_features)
            + ["Random Forest"] * len(top_rf_features)
            + ["Gradient Boosting"] * len(top_gb_features)
            + ["Enet"] * len(selected_enet_features)
            + ["xgb"] * len(top_xgb_features)
            + ["SVM"] * len(selected_svm_features)
            + ["Naive Bayes"] * len(selected_naive_bayes_features)
            + ["Linear Discriminant Analysis"] * len(selected_lda_features)
            + ["AdaBoost"] * len(top_adaboost_features)
            + ["Decision Tree"] * len(top_dt_features)
            + ["Bagging"] * len(top_bagging_features)
            + ["KNN"] * len(top_knn_features),
            "feature": np.concatenate(
                [
                    lasso_selected_features,
                    selected_ridge_features,
                    top_rf_features,
                    top_gb_features,
                    selected_enet_features,
                    top_xgb_features,
                    selected_svm_features,
                    selected_naive_bayes_features,
                    selected_lda_features,
                    top_adaboost_features,
                    top_dt_features,
                    top_bagging_features,
                    top_knn_features,
                ]
            ),
        }
    )

    #! Validate each trained classifier
    models = get_models(random_state=random_state, cls=cls)
    cv_train_results, cv_test_results = [], []
    for name, clf in models.items():
        if not x_train_selected.empty:
            cv_scores = validate_classifier(
                clf,
                x_train_selected,
                y_train,
                x_test_selected,
                y_test,
                metrics=metrics,
                cv_folds=cv_folds,
            )

            cv_train_score_df = pd.DataFrame(cv_scores["cv_train_scores"], index=[name])
            cv_test_score_df = pd.DataFrame(cv_scores["cv_test_scores"], index=[name])
            cv_train_results.append(cv_train_score_df)
            cv_test_results.append(cv_test_score_df)
    if all([cv_train_results, cv_test_results]):
        cv_train_results_df = (
            pd.concat(cv_train_results)
            .reset_index()
            .rename(columns={"index": "Classifier"})
        )
        cv_test_results_df = (
            pd.concat(cv_test_results)
            .reset_index()
            .rename(columns={"index": "Classifier"})
        )
        #! Store results in the main results dictionary
        results = {
            "selected_features": features_df,
            "cv_train_scores": cv_train_results_df,
            "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
            "common_features": list(common_features),
        }
        if all([plot_, dir_save]):
            from datetime import datetime

            now_ = datetime.now().strftime("%y%m%d_%H%M%S")
            ips.figsave(dir_save + f"features{now_}.pdf")
    else:
        results = {
            "selected_features": pd.DataFrame(),
            "cv_train_scores": pd.DataFrame(),
            "cv_test_scores": pd.DataFrame(),
            "common_features": [],
        }
        print(f"Warning: no shared genes were found with n_shared={n_shared}")
    return results

#! # usage:
# # Get features and common features
# results = get_features(X, y)
# common_features = results["common_features"]
def validate_features(
    x_train: pd.DataFrame,
    y_train: pd.Series,
    x_true: pd.DataFrame,
    y_true: pd.Series,
    common_features: set = None,
    models: Optional[Dict[str, Any]] = None,
    metrics: Optional[list] = None,
    random_state: int = 1,
    smote: bool = False,
    n_jobs: int = -1,
    plot_: bool = True,
    class_weight: str = "balanced",
) -> dict:
    """
    Validate models using selected features on the validation dataset.

    Parameters:
    - x_train (pd.DataFrame): Training feature dataset.
    - y_train (pd.Series): Training target variable.
    - x_true (pd.DataFrame): Validation feature dataset.
    - y_true (pd.Series): Validation target variable.
    - common_features (set): Set of common features to use for validation.
    - models (dict, optional): Dictionary of models to validate.
    - metrics (list, optional): List of metrics to compute.
    - random_state (int): Random state for reproducibility.
    - plot_ (bool): Option to plot metrics (to be implemented if needed).
    - class_weight (str or dict): Class weights to handle imbalance.
    """
    from tqdm import tqdm

    # Ensure common features are selected
    common_features = ips.shared(
        common_features, x_train.columns, x_true.columns, strict=True, verbose=False
    )

    # Filter the training and validation datasets for the common features
    x_train_selected = x_train[common_features]
    x_true_selected = x_true[common_features]

    if not x_true_selected.index.equals(y_true.index):
        raise ValueError(
            "Index mismatch between validation features and target. Ensure data alignment."
        )

    y_true = y_true.loc[x_true_selected.index]

    # Handle class imbalance using SMOTE
    if smote:
        if (
            y_train.value_counts(normalize=True).max() < 0.8
        ):  # Threshold to decide if data is imbalanced
            smote = SMOTE(random_state=random_state)
            x_train_resampled, y_train_resampled = smote.fit_resample(
                x_train_selected, y_train
            )
        else:
            # skip SMOTE
            x_train_resampled, y_train_resampled = x_train_selected, y_train
    else:
        x_train_resampled, y_train_resampled = x_train_selected, y_train

    # Default models if not provided
    if models is None:
        models = {
            "Random Forest": RandomForestClassifier(
                class_weight=class_weight, random_state=random_state
            ),
            "SVM": SVC(probability=True, class_weight=class_weight),
            "Logistic Regression": LogisticRegression(
                class_weight=class_weight, random_state=random_state
            ),
            "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
            "AdaBoost": AdaBoostClassifier(
                random_state=random_state, algorithm="SAMME"
            ),
            "Lasso": LogisticRegression(
                penalty="l1", solver="saga", random_state=random_state
            ),
            "Ridge": LogisticRegression(
                penalty="l2", solver="saga", random_state=random_state
            ),
            "Elastic Net": LogisticRegression(
                penalty="elasticnet",
                solver="saga",
                l1_ratio=0.5,
                random_state=random_state,
            ),
            "XGBoost": xgb.XGBClassifier(eval_metric="logloss"),
            "Naive Bayes": GaussianNB(),
            "LDA": LinearDiscriminantAnalysis(),
        }

    # Hyperparameter grids for tuning
    param_grids = {
        "Random Forest": {
            "n_estimators": [100, 200, 300, 400, 500],
            "max_depth": [None, 3, 5, 10, 20],
            "min_samples_split": [2, 5, 10],
            "min_samples_leaf": [1, 2, 4],
            "class_weight": [None, "balanced"],
        },
        "SVM": {
            "C": [0.01, 0.1, 1, 10, 100, 1000],
            "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
            "kernel": ["linear", "rbf", "poly"],
        },
        "Logistic Regression": {
            "C": [0.01, 0.1, 1, 10, 100],
            "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
            "penalty": ["l1", "l2"],
            "max_iter": [100, 200, 300],
        },
        "Gradient Boosting": {
            "n_estimators": [100, 200, 300, 400, 500],
            "learning_rate": np.logspace(-3, 0, 4),
            "max_depth": [3, 5, 7, 9],
            "min_samples_split": [2, 5, 10],
        },
        "AdaBoost": {
            "n_estimators": [50, 100, 200, 300, 500],
            "learning_rate": np.logspace(-3, 0, 4),
        },
        "Lasso": {"C": np.logspace(-3, 1, 10), "max_iter": [100, 200, 300]},
        "Ridge": {"C": np.logspace(-3, 1, 10), "max_iter": [100, 200, 300]},
        "Elastic Net": {
            "C": np.logspace(-3, 1, 10),
            "l1_ratio": [0.1, 0.5, 0.9],
            "max_iter": [100, 200, 300],
        },
        "XGBoost": {
            "n_estimators": [100, 200],
            "max_depth": [3, 5, 7],
            "learning_rate": [0.01, 0.1, 0.2],
            "subsample": [0.8, 1.0],
            "colsample_bytree": [0.8, 1.0],
        },
        "Naive Bayes": {},
        "LDA": {"solver": ["svd", "lsqr", "eigen"]},
    }
    # Default metrics if not provided
    if metrics is None:
        metrics = [
            "accuracy",
            "precision",
            "recall",
            "f1",
            "roc_auc",
            "mcc",
            "specificity",
            "balanced_accuracy",
            "pr_auc",
        ]

    results = {}

    # Validate each classifier with GridSearchCV
    for name, clf in tqdm(
        models.items(),
        desc="for metric in metrics",
        colour="green",
        bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
    ):
        print(f"\nValidating {name} on the validation dataset:")

        # Check if the `predict_proba` method exists; if not, use CalibratedClassifierCV.
        # For classifiers without predict_proba, CalibratedClassifierCV provides calibrated probability
        # estimates. To keep the code flexible, check for predict_proba when the classifier is created and,
        # if it is missing while roc_auc or pr_auc is requested, enable CalibratedClassifierCV.
        if not hasattr(clf, "predict_proba"):
            print(
                f"Using CalibratedClassifierCV for {name} due to lack of probability estimates."
            )
            calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid", cv="prefit")
        else:
            calibrated_clf = clf
        # Stratified K-Fold for cross-validation
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)

        # Create GridSearchCV object
        gs = GridSearchCV(
            estimator=calibrated_clf,
            param_grid=param_grids[name],
            scoring="roc_auc",  # Optimize for ROC AUC
            cv=skf,  # Stratified K-Folds cross-validation
            n_jobs=n_jobs,
            verbose=1,
        )

        # Fit the model using GridSearchCV
        gs.fit(x_train_resampled, y_train_resampled)
        # Best estimator from grid search
        best_clf = gs.best_estimator_
        # Make predictions on the validation set
        y_pred = best_clf.predict(x_true_selected)
        # Calculate probabilities for ROC AUC if possible
        if hasattr(best_clf, "predict_proba"):
            y_pred_proba = best_clf.predict_proba(x_true_selected)[:, 1]
        elif hasattr(best_clf, "decision_function"):
            # If predict_proba is not available, use decision_function (e.g., for SVM)
            y_pred_proba = best_clf.decision_function(x_true_selected)
            # Ensure y_pred_proba is within 0 and 1 bounds
            y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
                y_pred_proba.max() - y_pred_proba.min()
            )
        else:
            y_pred_proba = None  # No probability output for certain models

        # Calculate metrics
        validation_scores = {}
        for metric in metrics:
            if metric == "accuracy":
                validation_scores[metric] = accuracy_score(y_true, y_pred)
            elif metric == "precision":
                validation_scores[metric] = precision_score(
                    y_true, y_pred, average="weighted"
                )
            elif metric == "recall":
                validation_scores[metric] = recall_score(
                    y_true, y_pred, average="weighted"
                )
            elif metric == "f1":
                validation_scores[metric] = f1_score(y_true, y_pred, average="weighted")
            elif metric == "roc_auc" and y_pred_proba is not None:
                validation_scores[metric] = roc_auc_score(y_true, y_pred_proba)
            elif metric == "mcc":
                validation_scores[metric] = matthews_corrcoef(y_true, y_pred)
            elif metric == "specificity":
                tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
                validation_scores[metric] = tn / (tn + fp)  # Specificity calculation
            elif metric == "balanced_accuracy":
                validation_scores[metric] = balanced_accuracy_score(y_true, y_pred)
            elif metric == "pr_auc" and y_pred_proba is not None:
                precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
                validation_scores[metric] = average_precision_score(
                    y_true, y_pred_proba
                )

        # Calculate ROC curve
        # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
        if y_pred_proba is not None:
            # fpr, tpr, roc_auc = dict(), dict(), dict()
            fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
            lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
            roc_auc = auc(fpr, tpr)
            roc_info = {
                "fpr": fpr.tolist(),
                "tpr": tpr.tolist(),
                "auc": roc_auc,
                "ci95": (lower_ci, upper_ci),
            }
            # precision-recall curve
            precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
            avg_precision_ = average_precision_score(y_true, y_pred_proba)
            pr_info = {
                "precision": precision_,
                "recall": recall_,
                "avg_precision": avg_precision_,
            }
        else:
            roc_info, pr_info = None, None
        results[name] = {
            "best_params": gs.best_params_,
            "scores": validation_scores,
            "roc_curve": roc_info,
            "pr_curve": pr_info,
            "confusion_matrix": confusion_matrix(y_true, y_pred),
        }

    df_results = pd.DataFrame.from_dict(results, orient="index")

    return df_results

#! usage validate_features()
# Validate models using the validation dataset (X_val, y_val)
# validation_results = validate_features(X, y, X_val, y_val, common_features)


# # If you want to access validation scores
# print(validation_results)
def plot_validate_features(res_val):
    """
    Plot the results of 'validate_features()'.
    """
    colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
    if res_val.shape[0] > 5:
        alpha = 0
        figsize = [8, 10]
        subplot_layout = [1, 2]
        ncols = 2
        bbox_to_anchor = [1.5, 0.6]
    else:
        alpha = 0.03
        figsize = [10, 6]
        subplot_layout = [1, 1]
        ncols = 1
        bbox_to_anchor = [1, 1]
    nexttile = plot.subplot(figsize=figsize)
    ax = nexttile(subplot_layout[0], subplot_layout[1])
    for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
        fpr = res_val["roc_curve"][model_name]["fpr"]
        tpr = res_val["roc_curve"][model_name]["tpr"]
        (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
        mean_auc = res_val["roc_curve"][model_name]["auc"]
        plot_roc_curve(
            fpr,
            tpr,
            mean_auc,
            lower_ci,
            upper_ci,
            model_name=model_name,
            lw=1.5,
            color=colors[i],
            alpha=alpha,
            ax=ax,
        )
    plot.figsets(
        sp=2,
        legend=dict(
            loc="upper right",
            ncols=ncols,
            fontsize=8,
            bbox_to_anchor=bbox_to_anchor,
            markerscale=0.8,
        ),
    )
    # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)

    ax = nexttile(subplot_layout[0], subplot_layout[1])
    for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
        plot_pr_curve(
            recall=res_val["pr_curve"][model_name]["recall"],
            precision=res_val["pr_curve"][model_name]["precision"],
            avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
            model_name=model_name,
            color=colors[i],
            lw=1.5,
            alpha=alpha,
            ax=ax,
        )
    plot.figsets(
        sp=2,
        legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
    )
    # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)

def plot_validate_features_single(res_val, figsize=None):
    if figsize is None:
        nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3)
    else:
        nexttile = plot.subplot(
            len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
        )
    for model_name in ips.flatten(res_val["pr_curve"].index):
        fpr = res_val["roc_curve"][model_name]["fpr"]
        tpr = res_val["roc_curve"][model_name]["tpr"]
        (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
        mean_auc = res_val["roc_curve"][model_name]["auc"]

        # Plotting
        plot_roc_curve(
            fpr, tpr, mean_auc, lower_ci, upper_ci, model_name=model_name, ax=nexttile()
        )
        plot.figsets(title=model_name, sp=2)

        plot_pr_binary(
            recall=res_val["pr_curve"][model_name]["recall"],
            precision=res_val["pr_curve"][model_name]["precision"],
            avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
            model_name=model_name,
            ax=nexttile(),
        )
        plot.figsets(title=model_name, sp=2)

        # plot cm
        plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
        plot.figsets(title=model_name, sp=2)
|
1317
|
+
|
1318
|
+
|
1319
|
+
def cal_auc_ci(
|
1320
|
+
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
|
1321
|
+
):
|
1322
|
+
y_true = np.asarray(y_true)
|
1323
|
+
y_pred = np.asarray(y_pred)
|
1324
|
+
bootstrapped_scores = []
|
1325
|
+
if verbose:
|
1326
|
+
print("auroc score:", roc_auc_score(y_true, y_pred))
|
1327
|
+
rng = np.random.RandomState(random_state)
|
1328
|
+
for i in range(n_bootstraps):
|
1329
|
+
# bootstrap by sampling with replacement on the prediction indices
|
1330
|
+
indices = rng.randint(0, len(y_pred), len(y_pred))
|
1331
|
+
if len(np.unique(y_true[indices])) < 2:
|
1332
|
+
# We need at least one positive and one negative sample for ROC AUC
|
1333
|
+
# to be defined: reject the sample
|
1334
|
+
continue
|
1335
|
+
if isinstance(y_true, np.ndarray):
|
1336
|
+
score = roc_auc_score(y_true[indices], y_pred[indices])
|
1337
|
+
else:
|
1338
|
+
score = roc_auc_score(y_true.iloc[indices], y_pred.iloc[indices])
|
1339
|
+
bootstrapped_scores.append(score)
|
1340
|
+
# print("Bootstrap #{} ROC area: {:0.3f}".format(i + 1, score))
|
1341
|
+
sorted_scores = np.array(bootstrapped_scores)
|
1342
|
+
sorted_scores.sort()
|
1343
|
+
|
1344
|
+
# Computing the lower and upper bound of the 90% confidence interval
|
1345
|
+
# You can change the bounds percentiles to 0.025 and 0.975 to get
|
1346
|
+
# a 95% confidence interval instead.
|
1347
|
+
confidence_lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
|
1348
|
+
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1349
|
+
if verbose:
|
1350
|
+
print(
|
1351
|
+
"Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
|
1352
|
+
confidence_lower, confidence_upper
|
1353
|
+
)
|
1354
|
+
)
|
1355
|
+
return confidence_lower, confidence_upper
|
1356
|
+
|
1357
|
+
|
1358
|
+
def plot_roc_curve(
|
1359
|
+
fpr=None,
|
1360
|
+
tpr=None,
|
1361
|
+
mean_auc=None,
|
1362
|
+
lower_ci=None,
|
1363
|
+
upper_ci=None,
|
1364
|
+
model_name=None,
|
1365
|
+
color="#FF8F00",
|
1366
|
+
lw=2,
|
1367
|
+
alpha=0.1,
|
1368
|
+
ci_display=True,
|
1369
|
+
title="ROC Curve",
|
1370
|
+
xlabel="1−Specificity",
|
1371
|
+
ylabel="Sensitivity",
|
1372
|
+
legend_loc="lower right",
|
1373
|
+
diagonal_color="0.5",
|
1374
|
+
figsize=(5, 5),
|
1375
|
+
ax=None,
|
1376
|
+
**kwargs,
|
1377
|
+
):
|
1378
|
+
if ax is None:
|
1379
|
+
fig, ax = plt.subplots(figsize=figsize)
|
1380
|
+
if mean_auc is not None:
|
1381
|
+
model_name = "ROC curve" if model_name is None else model_name
|
1382
|
+
if ci_display:
|
1383
|
+
label = f"{model_name} (AUC = {mean_auc:.3f})\n95% CI: {lower_ci:.3f} - {upper_ci:.3f}"
|
1384
|
+
else:
|
1385
|
+
label = f"{model_name} (AUC = {mean_auc:.3f})"
|
1386
|
+
else:
|
1387
|
+
label = None
|
1388
|
+
|
1389
|
+
# Plot ROC curve and the diagonal reference line
|
1390
|
+
ax.fill_between(fpr, tpr, alpha=alpha, color=color)
|
1391
|
+
ax.plot([0, 1], [0, 1], color=diagonal_color, clip_on=False, linestyle="--")
|
1392
|
+
ax.plot(fpr, tpr, color=color, lw=lw, label=label, clip_on=False, **kwargs)
|
1393
|
+
# Setting plot limits, labels, and title
|
1394
|
+
ax.set_xlim([-0.01, 1.0])
|
1395
|
+
ax.set_ylim([0.0, 1.0])
|
1396
|
+
ax.set_xlabel(xlabel)
|
1397
|
+
ax.set_ylabel(ylabel)
|
1398
|
+
ax.set_title(title)
|
1399
|
+
ax.legend(loc=legend_loc)
|
1400
|
+
return ax
|
1401
|
+
|
1402
|
+
|
1403
|
+
# * usage: ml2ls.plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci)
|
1404
|
+
# for model_name in flatten(validation_results["roc_curve"].keys())[2:]:
|
1405
|
+
# fpr = validation_results["roc_curve"][model_name]["fpr"]
|
1406
|
+
# tpr = validation_results["roc_curve"][model_name]["tpr"]
|
1407
|
+
# (lower_ci, upper_ci) = validation_results["roc_curve"][model_name]["ci95"]
|
1408
|
+
# mean_auc = validation_results["roc_curve"][model_name]["auc"]
|
1409
|
+
|
1410
|
+
# # Plotting
|
1411
|
+
# ml2ls.plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci)
|
1412
|
+
# figsets(title=model_name)
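# --- Illustrative sketch (hypothetical, not used elsewhere in this module): build
# fpr/tpr and a bootstrap CI from synthetic predictions and feed them to
# plot_roc_curve() above; `y`/`scores` stand in for real validation labels and
# predicted probabilities.
def _demo_plot_roc_curve():
    rng = np.random.RandomState(2)
    y = rng.randint(0, 2, size=250)
    scores = np.clip(y * 0.55 + rng.normal(0.2, 0.3, 250), 0, 1)
    fpr, tpr, _ = roc_curve(y, scores)
    lo, hi = cal_auc_ci(y, scores, n_bootstraps=200, verbose=False)
    plot_roc_curve(fpr, tpr, auc(fpr, tpr), lo, hi, model_name="demo")
    plt.show()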
|
1413
|
+
|
1414
|
+
def plot_pr_curve(
|
1415
|
+
recall=None,
|
1416
|
+
precision=None,
|
1417
|
+
avg_precision=None,
|
1418
|
+
model_name=None,
|
1419
|
+
lw=2,
|
1420
|
+
figsize=[5, 5],
|
1421
|
+
title="Precision-Recall Curve",
|
1422
|
+
xlabel="Recall",
|
1423
|
+
ylabel="Precision",
|
1424
|
+
alpha=0.1,
|
1425
|
+
color="#FF8F00",
|
1426
|
+
legend_loc="lower left",
|
1427
|
+
ax=None,
|
1428
|
+
**kwargs,
|
1429
|
+
):
|
1430
|
+
if ax is None:
|
1431
|
+
fig, ax = plt.subplots(figsize=figsize)
|
1432
|
+
model_name = "PR curve" if model_name is None else model_name
|
1433
|
+
# Plot Precision-Recall curve
|
1434
|
+
ax.plot(
|
1435
|
+
recall,
|
1436
|
+
precision,
|
1437
|
+
lw=lw,
|
1438
|
+
color=color,
|
1439
|
+
label=(f"{model_name} (AP={avg_precision:.2f})"),
|
1440
|
+
clip_on=False,
|
1441
|
+
**kwargs,
|
1442
|
+
)
|
1443
|
+
# Fill area under the curve
|
1444
|
+
ax.fill_between(recall, precision, alpha=alpha, color=color)
|
1445
|
+
|
1446
|
+
# Customize axes
|
1447
|
+
ax.set_title(title)
|
1448
|
+
ax.set_xlabel(xlabel)
|
1449
|
+
ax.set_ylabel(ylabel)
|
1450
|
+
ax.set_xlim([-0.01, 1.0])
|
1451
|
+
ax.set_ylim([0.0, 1.0])
|
1452
|
+
ax.grid(False)
|
1453
|
+
ax.legend(loc=legend_loc)
|
1454
|
+
return ax
|
1455
|
+
|
1456
|
+
# * usage: ml2ls.plot_pr_curve()
|
1457
|
+
# for md_name in flatten(validation_results["pr_curve"].keys()):
|
1458
|
+
# ml2ls.plot_pr_curve(
|
1459
|
+
# recall=validation_results["pr_curve"][md_name]["recall"],
|
1460
|
+
# precision=validation_results["pr_curve"][md_name]["precision"],
|
1461
|
+
# avg_precision=validation_results["pr_curve"][md_name]["avg_precision"],
|
1462
|
+
# model_name=md_name,
|
1463
|
+
# lw=2,
|
1464
|
+
# alpha=0.1,
|
1465
|
+
# color="r",
|
1466
|
+
# )
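# Note: plot_pr_curve() above draws a plain precision-recall curve, while
# plot_pr_binary() below additionally overlays F1 iso-contours and an optional
# average-precision reference line for binary problems.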
|
1467
|
+
|
1468
|
+
def plot_pr_binary(
|
1469
|
+
recall=None,
|
1470
|
+
precision=None,
|
1471
|
+
avg_precision=None,
|
1472
|
+
model_name=None,
|
1473
|
+
lw=2,
|
1474
|
+
figsize=[5, 5],
|
1475
|
+
title="Precision-Recall Curve",
|
1476
|
+
xlabel="Recall",
|
1477
|
+
ylabel="Precision",
|
1478
|
+
alpha=0.1,
|
1479
|
+
color="#FF8F00",
|
1480
|
+
legend_loc="lower left",
|
1481
|
+
ax=None,
|
1482
|
+
show_avg_precision=False,
|
1483
|
+
**kwargs,
|
1484
|
+
):
|
1485
|
+
from scipy.interpolate import interp1d
|
1486
|
+
if ax is None:
|
1487
|
+
fig, ax = plt.subplots(figsize=figsize)
|
1488
|
+
model_name = "Binary PR Curve" if model_name is None else model_name
|
1489
|
+
|
1490
|
+
# * consider using sklearn's builtin 'PrecisionRecallDisplay' instead?
|
1491
|
+
# from sklearn.metrics import PrecisionRecallDisplay
|
1492
|
+
# disp = PrecisionRecallDisplay(precision=precision,
|
1493
|
+
# recall=recall,
|
1494
|
+
# average_precision=avg_precision,**kwargs)
|
1495
|
+
# disp.plot(ax=ax, name=model_name, color=color)
|
1496
|
+
|
1497
|
+
# Plot Precision-Recall curve
|
1498
|
+
ax.plot(
|
1499
|
+
recall,
|
1500
|
+
precision,
|
1501
|
+
lw=lw,
|
1502
|
+
color=color,
|
1503
|
+
label=(f"{model_name} (AP={avg_precision:.2f})"),
|
1504
|
+
clip_on=False,
|
1505
|
+
**kwargs,
|
1506
|
+
)
|
1507
|
+
|
1508
|
+
# Fill area under the curve
|
1509
|
+
ax.fill_between(recall, precision, alpha=alpha, color=color)
|
1510
|
+
# Add F1 score iso-contours
|
1511
|
+
f_scores = np.linspace(0.2, 0.8, num=4)
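# The iso-F1 curves follow from F1 = 2*P*R / (P + R): solving for precision at a fixed
# F1 value f and recall x gives P = f*x / (2*x - f), which is the y_vals expression
# used in the loop below.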
|
1512
|
+
# for f_score in f_scores:
|
1513
|
+
# x = np.linspace(0.01, 1)
|
1514
|
+
# y = f_score * x / (2 * x - f_score)
|
1515
|
+
# plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=1)
|
1516
|
+
# plt.annotate(f"$f_1={f_score:0.1f}$", xy=(0.8, y[45] + 0.02))
|
1517
|
+
|
1518
|
+
pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
|
1519
|
+
for f_score in f_scores:
|
1520
|
+
x_vals = np.linspace(0.01, 1, 10000)
|
1521
|
+
y_vals = f_score * x_vals / (2 * x_vals - f_score)
|
1522
|
+
y_vals_clipped = np.minimum(y_vals, pr_boundary(x_vals))
|
1523
|
+
y_vals_clipped = np.clip(y_vals_clipped, 1e-3, None) # Prevent going to zero
|
1524
|
+
valid = y_vals_clipped < pr_boundary(x_vals)
|
1525
|
+
valid_ = y_vals_clipped > 1e-3
|
1526
|
+
valid = valid & valid_
|
1527
|
+
x_vals = x_vals[valid]
|
1528
|
+
y_vals_clipped = y_vals_clipped[valid]
|
1529
|
+
if len(x_vals) > 0: # Ensure annotation is placed only if line segment exists
|
1530
|
+
ax.plot(x_vals, y_vals_clipped, color="gray", alpha=1)
|
1531
|
+
plt.annotate(f"$f_1={f_score:0.1f}$", xy=(0.8, y_vals_clipped[-int(len(y_vals_clipped)*0.35)] + 0.02))
|
1532
|
+
|
1533
|
+
|
1534
|
+
# # Plot the average precision line
|
1535
|
+
if show_avg_precision:
|
1536
|
+
plt.axhline(
|
1537
|
+
y=avg_precision,
|
1538
|
+
color="red",
|
1539
|
+
ls="--",
|
1540
|
+
lw=lw,
|
1541
|
+
label=f"Avg. precision={avg_precision:.2f}",
|
1542
|
+
)
|
1543
|
+
# Customize axes
|
1544
|
+
ax.set_title(title)
|
1545
|
+
ax.set_xlabel(xlabel)
|
1546
|
+
ax.set_ylabel(ylabel)
|
1547
|
+
ax.set_xlim([-0.01, 1.0])
|
1548
|
+
ax.set_ylim([0.0, 1.0])
|
1549
|
+
ax.grid(False)
|
1550
|
+
ax.legend(loc=legend_loc)
|
1551
|
+
return ax
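# --- Illustrative sketch (hypothetical, not used elsewhere in this module): a
# precision-recall curve for synthetic scores rendered with plot_pr_binary() above;
# recall is sorted ascending so the interp1d boundary used for the F1 iso-contours is
# well defined.
def _demo_plot_pr_binary():
    rng = np.random.RandomState(1)
    y = rng.randint(0, 2, size=300)
    scores = np.clip(y * 0.5 + rng.normal(0.25, 0.25, 300), 0, 1)
    prec, rec, _ = precision_recall_curve(y, scores)
    order = np.argsort(rec)  # sort by recall for a monotone x-axis
    plot_pr_binary(
        recall=rec[order],
        precision=prec[order],
        avg_precision=average_precision_score(y, scores),
        model_name="demo",
    )
    plt.show()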
|
1552
|
+
|
1553
|
+
def plot_cm(
|
1554
|
+
cm,
|
1555
|
+
labels_name=None,
|
1556
|
+
thresh=0.8,
|
1557
|
+
axis_labels=None,
|
1558
|
+
cmap="Reds",
|
1559
|
+
normalize=True,
|
1560
|
+
xlabel="Predicted Label",
|
1561
|
+
ylabel="Actual Label",
|
1562
|
+
fontsize=12,
|
1563
|
+
figsize=[5, 5],
|
1564
|
+
ax=None,
|
1565
|
+
):
|
1566
|
+
if ax is None:
|
1567
|
+
fig, ax = plt.subplots(figsize=figsize)
|
1568
|
+
|
1569
|
+
cm_normalized = np.round(
|
1570
|
+
cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100, 2
|
1571
|
+
)
|
1572
|
+
cm_value = cm_normalized if normalize else cm.astype("int")
|
1573
|
+
# Plot the heatmap
|
1574
|
+
cax = ax.imshow(cm_normalized, interpolation="nearest", cmap=cmap)
|
1575
|
+
plt.colorbar(cax, ax=ax, fraction=0.046, pad=0.04)
|
1576
|
+
cax.set_clim(0, 100)
|
1577
|
+
|
1578
|
+
# Define tick labels based on provided labels
|
1579
|
+
num_local = np.arange(len(labels_name)) if labels_name is not None else range(2)
|
1580
|
+
if axis_labels is None:
|
1581
|
+
axis_labels = labels_name if labels_name is not None else ["No", "Yes"]
|
1582
|
+
ax.set_xticks(num_local)
|
1583
|
+
ax.set_xticklabels(axis_labels)
|
1584
|
+
ax.set_yticks(num_local)
|
1585
|
+
ax.set_yticklabels(axis_labels)
|
1586
|
+
ax.set_ylabel(ylabel)
|
1587
|
+
ax.set_xlabel(xlabel)
|
1588
|
+
|
1589
|
+
# Add TN, FP, FN, TP annotations specifically for binary classification (2x2 matrix)
|
1590
|
+
if labels_name is None or len(labels_name) == 2:
|
1591
|
+
# True Negative (TN), False Positive (FP), False Negative (FN), and True Positive (TP)
|
1592
|
+
# Predicted
|
1593
|
+
# 0 | 1
|
1594
|
+
# ----------------
|
1595
|
+
# 0 | TN | FP
|
1596
|
+
# Actual ----------------
|
1597
|
+
# 1 | FN | TP
|
1598
|
+
tn_label = "TN"
|
1599
|
+
fp_label = "FP"
|
1600
|
+
fn_label = "FN"
|
1601
|
+
tp_label = "TP"
|
1602
|
+
|
1603
|
+
# Adjust positions slightly for TN, FP, FN, TP labels
|
1604
|
+
ax.text(
|
1605
|
+
0,
|
1606
|
+
0,
|
1607
|
+
(
|
1608
|
+
f"{tn_label}:{cm_normalized[0, 0]:.2f}%"
|
1609
|
+
if normalize
|
1610
|
+
else f"{tn_label}:{cm_value[0, 0]}"
|
1611
|
+
),
|
1612
|
+
ha="center",
|
1613
|
+
va="center",
|
1614
|
+
color="white" if cm_normalized[0, 0] > thresh * 100 else "black",
|
1615
|
+
fontsize=fontsize,
|
1616
|
+
)
|
1617
|
+
ax.text(
|
1618
|
+
1,
|
1619
|
+
0,
|
1620
|
+
(
|
1621
|
+
f"{fp_label}:{cm_normalized[0, 1]:.2f}%"
|
1622
|
+
if normalize
|
1623
|
+
else f"{fp_label}:{cm_value[0, 1]}"
|
1624
|
+
),
|
1625
|
+
ha="center",
|
1626
|
+
va="center",
|
1627
|
+
color="white" if cm_normalized[0, 1] > thresh * 100 else "black",
|
1628
|
+
fontsize=fontsize,
|
1629
|
+
)
|
1630
|
+
ax.text(
|
1631
|
+
0,
|
1632
|
+
1,
|
1633
|
+
(
|
1634
|
+
f"{fn_label}:{cm_normalized[1, 0]:.2f}%"
|
1635
|
+
if normalize
|
1636
|
+
else f"{fn_label}:{cm_value[1, 0]}"
|
1637
|
+
),
|
1638
|
+
ha="center",
|
1639
|
+
va="center",
|
1640
|
+
color="white" if cm_normalized[1, 0] > thresh * 100 else "black",
|
1641
|
+
fontsize=fontsize,
|
1642
|
+
)
|
1643
|
+
ax.text(
|
1644
|
+
1,
|
1645
|
+
1,
|
1646
|
+
(
|
1647
|
+
f"{tp_label}:{cm_normalized[1, 1]:.2f}%"
|
1648
|
+
if normalize
|
1649
|
+
else f"{tp_label}:{cm_value[1, 1]}"
|
1650
|
+
),
|
1651
|
+
ha="center",
|
1652
|
+
va="center",
|
1653
|
+
color="white" if cm_normalized[1, 1] > thresh * 100 else "black",
|
1654
|
+
fontsize=fontsize,
|
1655
|
+
)
|
1656
|
+
else:
|
1657
|
+
# Annotate cells with normalized percentage values
|
1658
|
+
for i in range(len(labels_name)):
|
1659
|
+
for j in range(len(labels_name)):
|
1660
|
+
val = cm_normalized[i, j]
|
1661
|
+
color = "white" if val > thresh * 100 else "black"
|
1662
|
+
ax.text(
|
1663
|
+
j,
|
1664
|
+
i,
|
1665
|
+
f"{val:.2f}%",
|
1666
|
+
ha="center",
|
1667
|
+
va="center",
|
1668
|
+
color=color,
|
1669
|
+
fontsize=fontsize,
|
1670
|
+
)
|
1671
|
+
|
1672
|
+
plot.figsets(ax=ax, boxloc="none")
|
1673
|
+
return ax
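# --- Illustrative sketch (hypothetical, not used elsewhere in this module): render a
# hand-made 2x2 confusion matrix with plot_cm() above; the counts are invented for
# demonstration only.
def _demo_plot_cm():
    cm_demo = np.array([[42, 8], [5, 45]])  # rows: actual, columns: predicted
    plot_cm(cm_demo, labels_name=["neg", "pos"], normalize=True)
    plt.show()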
|
1674
|
+
|
1675
|
+
|
1676
|
+
def rank_models(
|
1677
|
+
cv_test_scores,
|
1678
|
+
rm_outlier=False,
|
1679
|
+
metric_weights=None,
|
1680
|
+
plot_=True,
|
1681
|
+
):
|
1682
|
+
"""
|
1683
|
+
Selects the best model based on a multi-metric scoring approach, with outlier handling, optional visualization,
|
1684
|
+
and additional performance metrics.
|
1685
|
+
|
1686
|
+
Parameters:
|
1687
|
+
- cv_test_scores (pd.DataFrame): DataFrame with cross-validation results across multiple metrics.
|
1688
|
+
Assumes columns are 'Classifier', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc'.
|
1689
|
+
- metric_weights (dict or str): Dictionary of per-metric weights (e.g., {'accuracy': 0.2, 'precision': 0.3, ...}), or one of the preset keys "a"-"e" described below.
|
1690
|
+
If None, default weights are applied equally across available metrics.
|
1691
|
+
a. equal_weights (standard approach): all metrics are treated as equally important,
|
1692
|
+
e.g., {"accuracy": 0.2, "precision": 0.2, "recall": 0.2, "f1": 0.2, "roc_auc": 0.2}
|
1693
|
+
b. accuracy_focused: when overall classification correctness matters most (e.g., in balanced datasets), accuracy might be weighted more heavily.
|
1694
|
+
e.g., {"accuracy": 0.4, "precision": 0.2, "recall": 0.2, "f1": 0.1, "roc_auc": 0.1}
|
1695
|
+
c. Precision and Recall Emphasis: In cases where false positives and false negatives are particularly important (such as
|
1696
|
+
in medical applications or fraud detection), precision and recall may be weighted more heavily.
|
1697
|
+
e.g., {"accuracy": 0.2, "precision": 0.3, "recall": 0.3, "f1": 0.1, "roc_auc": 0.1}
|
1698
|
+
d. F1-Focused: When balance between precision and recall is crucial (e.g., in imbalanced datasets)
|
1699
|
+
e.g., {"accuracy": 0.2, "precision": 0.2, "recall": 0.2, "f1": 0.3, "roc_auc": 0.1}
|
1700
|
+
e. ROC-AUC Emphasis: In some cases, ROC AUC may be prioritized, particularly in classification tasks where class imbalance
|
1701
|
+
is present, as ROC AUC accounts for the model's performance across all classification thresholds.
|
1702
|
+
e.g., {"accuracy": 0.1, "precision": 0.2, "recall": 0.2, "f1": 0.3, "roc_auc": 0.3}
|
1703
|
+
|
1704
|
+
- rm_outlier (bool): If True, removes outlier scores via ips.df_outlier() before ranking.
|
1705
|
+
- plot_ (bool): If True, generates visualizations (bar plot and radar chart) of the ranked models.
|
1706
|
+
Scores of each metric are min-max scaled to [0, 1] before the weighted combination.
|
1707
|
+
Rows are returned in descending order of the combined weighted score.
|
1708
|
+
|
1709
|
+
Returns:
|
1710
|
+
- best_model (str): Name of the best model based on the combined metric scores.
|
1711
|
+
- scored_df (pd.DataFrame): DataFrame with an added 'combined_score' column used for model selection.
|
1712
|
+
- visualizations (dict): A dictionary containing visualizations if `visualize=True`.
|
1713
|
+
"""
|
1714
|
+
from sklearn.preprocessing import MinMaxScaler
|
1715
|
+
import seaborn as sns
|
1716
|
+
import matplotlib.pyplot as plt
|
1717
|
+
from py2ls import plot
|
1718
|
+
|
1719
|
+
# Check for missing metrics and set default weights if not provided
|
1720
|
+
available_metrics = cv_test_scores.columns[1:] # Exclude 'Classifier' column
|
1721
|
+
if metric_weights is None:
|
1722
|
+
metric_weights = {
|
1723
|
+
metric: 1 / len(available_metrics) for metric in available_metrics
|
1724
|
+
} # Equal weight if not specified
|
1725
|
+
elif metric_weights == "a":
|
1726
|
+
metric_weights = {
|
1727
|
+
"accuracy": 0.2,
|
1728
|
+
"precision": 0.2,
|
1729
|
+
"recall": 0.2,
|
1730
|
+
"f1": 0.2,
|
1731
|
+
"roc_auc": 0.2,
|
1732
|
+
}
|
1733
|
+
elif metric_weights == "b":
|
1734
|
+
metric_weights = {
|
1735
|
+
"accuracy": 0.4,
|
1736
|
+
"precision": 0.2,
|
1737
|
+
"recall": 0.2,
|
1738
|
+
"f1": 0.1,
|
1739
|
+
"roc_auc": 0.1,
|
1740
|
+
}
|
1741
|
+
elif metric_weights == "c":
|
1742
|
+
metric_weights = {
|
1743
|
+
"accuracy": 0.2,
|
1744
|
+
"precision": 0.3,
|
1745
|
+
"recall": 0.3,
|
1746
|
+
"f1": 0.1,
|
1747
|
+
"roc_auc": 0.1,
|
1748
|
+
}
|
1749
|
+
elif metric_weights == "d":
|
1750
|
+
metric_weights = {
|
1751
|
+
"accuracy": 0.2,
|
1752
|
+
"precision": 0.2,
|
1753
|
+
"recall": 0.2,
|
1754
|
+
"f1": 0.3,
|
1755
|
+
"roc_auc": 0.1,
|
1756
|
+
}
|
1757
|
+
elif metric_weights == "e":
|
1758
|
+
metric_weights = {
|
1759
|
+
"accuracy": 0.1,
|
1760
|
+
"precision": 0.2,
|
1761
|
+
"recall": 0.2,
|
1762
|
+
"f1": 0.3,
|
1763
|
+
"roc_auc": 0.3,
|
1764
|
+
}
|
1765
|
+
else:
|
1766
|
+
metric_weights = {
|
1767
|
+
metric: 1 / len(available_metrics) for metric in available_metrics
|
1768
|
+
}
|
1769
|
+
|
1770
|
+
# Normalize weights if they don’t sum to 1
|
1771
|
+
total_weight = sum(metric_weights.values())
|
1772
|
+
metric_weights = {
|
1773
|
+
metric: weight / total_weight for metric, weight in metric_weights.items()
|
1774
|
+
}
|
1775
|
+
if rm_outlier:
|
1776
|
+
cv_test_scores_ = ips.df_outlier(cv_test_scores)
|
1777
|
+
else:
|
1778
|
+
cv_test_scores_ = cv_test_scores
|
1779
|
+
|
1780
|
+
# Normalize the scores of metrics if normalize is True
|
1781
|
+
scaler = MinMaxScaler()
|
1782
|
+
normalized_scores = pd.DataFrame(
|
1783
|
+
scaler.fit_transform(cv_test_scores_[available_metrics]),
|
1784
|
+
columns=available_metrics,
|
1785
|
+
)
|
1786
|
+
cv_test_scores_ = pd.concat(
|
1787
|
+
[cv_test_scores_[["Classifier"]], normalized_scores], axis=1
|
1788
|
+
)
|
1789
|
+
|
1790
|
+
# Calculate weighted scores for each model
|
1791
|
+
cv_test_scores_["combined_score"] = sum(
|
1792
|
+
cv_test_scores_[metric] * weight for metric, weight in metric_weights.items()
|
1793
|
+
)
|
1794
|
+
top_models = cv_test_scores_.sort_values(by="combined_score", ascending=False)
|
1795
|
+
cv_test_scores = cv_test_scores.loc[top_models.index]
|
1796
|
+
top_models.reset_index(drop=True, inplace=True)
|
1797
|
+
cv_test_scores.reset_index(drop=True, inplace=True)
|
1798
|
+
|
1799
|
+
if plot_:
|
1800
|
+
|
1801
|
+
def generate_bar_plot(ax, cv_test_scores):
|
1802
|
+
ax = plot.plotxy(
|
1803
|
+
y="Classifier", x="combined_score", data=cv_test_scores, kind="bar"
|
1804
|
+
)
|
1805
|
+
plt.title("Classifier Performance")
|
1806
|
+
plt.tight_layout()
|
1807
|
+
return plt
|
1808
|
+
|
1809
|
+
nexttile = plot.subplot(2, 2, figsize=[10, 7])
|
1810
|
+
generate_bar_plot(nexttile(), top_models.dropna())
|
1811
|
+
plot.radar(
|
1812
|
+
ax=nexttile(projection="polar"),
|
1813
|
+
data=cv_test_scores.set_index("Classifier"),
|
1814
|
+
ylim=[0.5, 1],
|
1815
|
+
color=plot.get_color(10),
|
1816
|
+
alpha=0.05,
|
1817
|
+
circular=1,
|
1818
|
+
)
|
1819
|
+
return cv_test_scores
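# --- Illustrative sketch (hypothetical, not used elsewhere in this module): rank three
# classifiers with the accuracy-focused preset ("b"); the score values are invented for
# demonstration only.
def _demo_rank_models():
    demo_scores = pd.DataFrame(
        {
            "Classifier": ["Random Forest", "SVM", "Logistic Regression"],
            "accuracy": [0.90, 0.86, 0.84],
            "precision": [0.88, 0.85, 0.83],
            "recall": [0.87, 0.84, 0.86],
            "f1": [0.875, 0.845, 0.845],
            "roc_auc": [0.94, 0.91, 0.90],
        }
    )
    ranked = rank_models(demo_scores, metric_weights="b", plot_=False)
    print(ranked)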
|
1820
|
+
|
1821
|
+
|
1822
|
+
# # Example Usage:
|
1823
|
+
# metric_weights = {
|
1824
|
+
# "accuracy": 0.2,
|
1825
|
+
# "precision": 0.3,
|
1826
|
+
# "recall": 0.2,
|
1827
|
+
# "f1": 0.2,
|
1828
|
+
# "roc_auc": 0.1,
|
1829
|
+
# }
|
1830
|
+
# cv_test_scores = res["cv_test_scores"].copy()
|
1831
|
+
# best_model = rank_models(
|
1832
|
+
# cv_test_scores, metric_weights=metric_weights, normalize=True, plot_=True
|
1833
|
+
# )
|
1834
|
+
|
1835
|
+
# figsave("classifier_performance.pdf")
|
1836
|
+
|
1837
|
+
|
1838
|
+
def predict(
|
1839
|
+
x_train: pd.DataFrame,
|
1840
|
+
y_train: pd.Series,
|
1841
|
+
x_true: pd.DataFrame = None,
|
1842
|
+
y_true: Optional[pd.Series] = None,
|
1843
|
+
common_features: set = None,
|
1844
|
+
purpose: str = "classification", # 'classification' or 'regression'
|
1845
|
+
cls: Optional[Dict[str, Any]] = None,
|
1846
|
+
metrics: Optional[List[str]] = None,
|
1847
|
+
random_state: int = 1,
|
1848
|
+
smote: bool = False,
|
1849
|
+
n_jobs: int = -1,
|
1850
|
+
plot_: bool = True,
|
1851
|
+
dir_save: str = "./",
|
1852
|
+
test_size: float = 0.2,  # only used when x_true is None
|
1853
|
+
cv_folds: int = 5,  # more folds give more stable estimates, though the reported AUC may drop
|
1854
|
+
cv_level: str = "l", # "s":'low',"m":'medium',"l":"high"
|
1855
|
+
class_weight: str = "balanced",
|
1856
|
+
verbose: bool = False,
|
1857
|
+
) -> pd.DataFrame:
|
1858
|
+
"""
|
1859
|
+
Three usage patterns: (1) internal train/test split, (2) prediction only, (3) external validation.
|
1860
|
+
Usage:
|
1861
|
+
(1) predict(x_train, y_train, ...): splits x_train into train/test subsets and validates on the held-out test set.
|
1862
|
+
Using test_size, predict() carves an internal test set out of x_train/y_train, trains on the remaining data, and evaluates on that test set.
|
1863
|
+
(2) predict(x_train, y_train, x_true, ...): trains on all of x_train/y_train and predicts on x_true.
|
1864
|
+
Because x_true is supplied, no split is performed; the model is trained on the full training data and then applied to x_true. Since no y_true is given,
|
1865
|
+
the predictions cannot be compared against ground truth.
|
1866
|
+
(3) predict(x_train, y_train, x_true, y_true, ...): trains on x_train/y_train and validates the predictions for x_true against the true labels y_true.
|
1867
|
+
Here x_true serves as the test set. Because y_true is provided, the predictions are compared with y_true and
|
1868
|
+
validation metrics are computed, giving a true external validation on x_true.
|
1869
|
+
trains and validates a variety of machine learning models for both classification and regression tasks.
|
1870
|
+
It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
|
1871
|
+
feature scaling, and handling of class imbalance through SMOTE.
|
1872
|
+
|
1873
|
+
Parameters:
|
1874
|
+
- x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
|
1875
|
+
- y_train (pd.Series):Target variable for the training dataset.
|
1876
|
+
- x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
|
1877
|
+
- y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
|
1878
|
+
- common_features (set, optional):Specifies a subset of features common across training and test data.
|
1879
|
+
- purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
|
1880
|
+
metrics and models are applied.
|
1881
|
+
- cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
|
1882
|
+
- metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
|
1883
|
+
- random_state (int, default = 1):Random seed to ensure reproducibility.
|
1884
|
+
- smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
|
1885
|
+
- n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
|
1886
|
+
- plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
|
1887
|
+
- test_size (float, default = 0.2):Test data proportion if x_true is not provided.
|
1888
|
+
- cv_folds (int, default = 5):Number of cross-validation folds.
|
1889
|
+
- cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
|
1890
|
+
- class_weight (str, default = "balanced"):Balances class weights in classification tasks.
|
1891
|
+
- verbose (bool, default = False):If True, prints detailed output during model training.
|
1892
|
+
- dir_save (str, default = "./"):Directory path to save plot outputs and results.
|
1893
|
+
|
1894
|
+
Key Steps in the Function:
|
1895
|
+
Model Initialization: Depending on purpose, initializes either classification or regression models.
|
1896
|
+
Feature Selection: Ensures training and test sets have matching feature columns.
|
1897
|
+
SMOTE Application: Balances classes if smote is enabled and the task is classification.
|
1898
|
+
Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
|
1899
|
+
Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
|
1900
|
+
"""
|
1901
|
+
from tqdm import tqdm
|
1902
|
+
from sklearn.ensemble import (
|
1903
|
+
RandomForestClassifier,
|
1904
|
+
RandomForestRegressor,
|
1905
|
+
ExtraTreesClassifier,
|
1906
|
+
ExtraTreesRegressor,
|
1907
|
+
BaggingClassifier,
|
1908
|
+
BaggingRegressor,
|
1909
|
+
AdaBoostClassifier,
|
1910
|
+
AdaBoostRegressor,
|
1911
|
+
)
|
1912
|
+
from sklearn.svm import SVC, SVR
|
1913
|
+
from sklearn.tree import DecisionTreeRegressor
|
1914
|
+
from sklearn.linear_model import (
|
1915
|
+
LogisticRegression,
|
1916
|
+
ElasticNet,
|
1917
|
+
ElasticNetCV,
|
1918
|
+
LinearRegression,
|
1919
|
+
Lasso,
|
1920
|
+
RidgeClassifierCV,
|
1921
|
+
Perceptron,
|
1922
|
+
SGDClassifier,
|
1923
|
+
)
|
1924
|
+
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
|
1925
|
+
from sklearn.naive_bayes import GaussianNB, BernoulliNB
|
1926
|
+
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
|
1927
|
+
import xgboost as xgb
|
1928
|
+
import lightgbm as lgb
|
1929
|
+
import catboost as cb
|
1930
|
+
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
1931
|
+
from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
|
1932
|
+
from sklearn.discriminant_analysis import (
|
1933
|
+
LinearDiscriminantAnalysis,
|
1934
|
+
QuadraticDiscriminantAnalysis,
|
1935
|
+
)
|
1936
|
+
from sklearn.preprocessing import PolynomialFeatures
|
1937
|
+
|
1938
|
+
# fuzzy-match the requested purpose against the supported options
|
1939
|
+
purpose = ips.strcmp(purpose, ["classification", "regression"])[0]
|
1940
|
+
print(f"{purpose} processing...")
|
1941
|
+
# Default models or regressors if not provided
|
1942
|
+
if purpose == "classification":
|
1943
|
+
model_ = {
|
1944
|
+
"Random Forest": RandomForestClassifier(
|
1945
|
+
random_state=random_state, class_weight=class_weight
|
1946
|
+
),
|
1947
|
+
# SVC (Support Vector Classification)
|
1948
|
+
"SVM": SVC(
|
1949
|
+
kernel="rbf",
|
1950
|
+
probability=True,
|
1951
|
+
class_weight=class_weight,
|
1952
|
+
random_state=random_state,
|
1953
|
+
),
|
1954
|
+
# fit the best model without enforcing sparsity, which means it does not directly perform feature selection.
|
1955
|
+
"Logistic Regression": LogisticRegression(
|
1956
|
+
class_weight=class_weight, random_state=random_state
|
1957
|
+
),
|
1958
|
+
# Logistic Regression with L1 Regularization (Lasso)
|
1959
|
+
"Lasso Logistic Regression": LogisticRegression(
|
1960
|
+
penalty="l1", solver="saga", random_state=random_state
|
1961
|
+
),
|
1962
|
+
"Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
|
1963
|
+
"XGBoost": xgb.XGBClassifier(
|
1964
|
+
eval_metric="logloss",
|
1965
|
+
random_state=random_state,
|
1966
|
+
),
|
1967
|
+
"KNN": KNeighborsClassifier(n_neighbors=5),
|
1968
|
+
"Naive Bayes": GaussianNB(),
|
1969
|
+
"Linear Discriminant Analysis": LinearDiscriminantAnalysis(),
|
1970
|
+
"AdaBoost": AdaBoostClassifier(
|
1971
|
+
algorithm="SAMME", random_state=random_state
|
1972
|
+
),
|
1973
|
+
# "LightGBM": lgb.LGBMClassifier(random_state=random_state, class_weight=class_weight),
|
1974
|
+
"CatBoost": cb.CatBoostClassifier(verbose=0, random_state=random_state),
|
1975
|
+
"Extra Trees": ExtraTreesClassifier(
|
1976
|
+
random_state=random_state, class_weight=class_weight
|
1977
|
+
),
|
1978
|
+
"Bagging": BaggingClassifier(random_state=random_state),
|
1979
|
+
"Neural Network": MLPClassifier(max_iter=500, random_state=random_state),
|
1980
|
+
"DecisionTree": DecisionTreeClassifier(),
|
1981
|
+
"Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis(),
|
1982
|
+
"Ridge": RidgeClassifierCV(
|
1983
|
+
class_weight=class_weight, store_cv_results=True
|
1984
|
+
),
|
1985
|
+
"Perceptron": Perceptron(random_state=random_state),
|
1986
|
+
"Bernoulli Naive Bayes": BernoulliNB(),
|
1987
|
+
"SGDClassifier": SGDClassifier(random_state=random_state),
|
1988
|
+
}
|
1989
|
+
elif purpose == "regression":
|
1990
|
+
model_ = {
|
1991
|
+
"Random Forest": RandomForestRegressor(random_state=random_state),
|
1992
|
+
"SVM": SVR(), # SVR (Support Vector Regression)
|
1993
|
+
# "Lasso": Lasso(random_state=random_state), # 它和LassoCV相同(必须要提供alpha参数),
|
1994
|
+
"LassoCV": LassoCV(
|
1995
|
+
cv=cv_folds, random_state=random_state
|
1996
|
+
),  # LassoCV picks the best alpha automatically, so it is preferred over plain Lasso
|
1997
|
+
"Gradient Boosting": GradientBoostingRegressor(random_state=random_state),
|
1998
|
+
"XGBoost": xgb.XGBRegressor(eval_metric="rmse", random_state=random_state),
|
1999
|
+
"Linear Regression": LinearRegression(),
|
2000
|
+
"Lasso": Lasso(random_state=random_state),
|
2001
|
+
"AdaBoost": AdaBoostRegressor(random_state=random_state),
|
2002
|
+
# "LightGBM": lgb.LGBMRegressor(random_state=random_state),
|
2003
|
+
"CatBoost": cb.CatBoostRegressor(verbose=0, random_state=random_state),
|
2004
|
+
"Extra Trees": ExtraTreesRegressor(random_state=random_state),
|
2005
|
+
"Bagging": BaggingRegressor(random_state=random_state),
|
2006
|
+
"Neural Network": MLPRegressor(max_iter=500, random_state=random_state),
|
2007
|
+
"ElasticNet": ElasticNet(random_state=random_state),
|
2008
|
+
"Ridge": Ridge(),
|
2009
|
+
"KNN": KNeighborsRegressor(),
|
2010
|
+
}
|
2011
|
+
# indicate cls:
|
2012
|
+
if ips.run_once_within(30):  # avoid re-printing the model list on every call within this window
|
2013
|
+
print(f"supported models: {list(model_.keys())}")
|
2014
|
+
if cls is None:
|
2015
|
+
models = model_
|
2016
|
+
else:
|
2017
|
+
if not isinstance(cls, list):
|
2018
|
+
cls = [cls]
|
2019
|
+
models = {}
|
2020
|
+
for cls_ in cls:
|
2021
|
+
cls_ = ips.strcmp(cls_, list(model_.keys()))[0]
|
2022
|
+
models[cls_] = model_[cls_]
|
2023
|
+
if "LightGBM" in models:
|
2024
|
+
x_train = ips.df_special_characters_cleaner(x_train)
|
2025
|
+
x_true = (
|
2026
|
+
ips.df_special_characters_cleaner(x_true) if x_true is not None else None
|
2027
|
+
)
|
2028
|
+
|
2029
|
+
if isinstance(y_train, str) and y_train in x_train.columns:
|
2030
|
+
y_train_col_name = y_train
|
2031
|
+
y_train = x_train[y_train]
|
2032
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2033
|
+
x_train = x_train.drop(y_train_col_name, axis=1)
|
2034
|
+
# else:
|
2035
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
|
2036
|
+
y_train = pd.DataFrame(y_train)
|
2037
|
+
y_train_ = ips.df_encoder(y_train, method="dummy")
|
2038
|
+
is_binary = y_train_.shape[1] == 1  # a single dummy column means a binary target
|
2039
|
+
if verbose: print(f"is_binary: {is_binary}")
|
2040
|
+
if is_binary:
|
2041
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
|
2042
|
+
if x_true is None:
|
2043
|
+
x_train, x_true, y_train, y_true = train_test_split(
|
2044
|
+
x_train,
|
2045
|
+
y_train,
|
2046
|
+
test_size=test_size,
|
2047
|
+
random_state=random_state,
|
2048
|
+
stratify=y_train if purpose == "classification" else None,
|
2049
|
+
)
|
2050
|
+
if isinstance(y_train, str) and y_train in x_train.columns:
|
2051
|
+
y_train_col_name = y_train
|
2052
|
+
y_train = x_train[y_train]
|
2053
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2054
|
+
x_train = x_train.drop(y_train_col_name, axis=1)
|
2055
|
+
else:
|
2056
|
+
y_train = ips.df_encoder(
|
2057
|
+
pd.DataFrame(y_train), method="dummy"
|
2058
|
+
).values.ravel()
|
2059
|
+
|
2060
|
+
if y_true is not None:
|
2061
|
+
if isinstance(y_true, str) and y_true in x_true.columns:
|
2062
|
+
y_true_col_name = y_true
|
2063
|
+
y_true = x_true[y_true]
|
2064
|
+
# y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
|
2065
|
+
y_true = pd.DataFrame(y_true)
|
2066
|
+
x_true = x_true.drop(y_true_col_name, axis=1)
|
2067
|
+
# else:
|
2068
|
+
# y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
|
2069
|
+
|
2070
|
+
# to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
|
2071
|
+
|
2072
|
+
# y_train=y_train.values.ravel() if y_train is not None else None
|
2073
|
+
# y_true=y_true.values.ravel() if y_true is not None else None
|
2074
|
+
y_train = (
|
2075
|
+
y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
|
2076
|
+
)
|
2077
|
+
# note: y_true may still be None here when only x_true was provided
|
2078
|
+
y_true = (y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()) if y_true is not None else None
|
2079
|
+
if verbose and y_true is not None: print(f"n_train={len(y_train)}, n_val={len(y_true)}")
|
2080
|
+
# Ensure common features are selected
|
2081
|
+
if common_features is not None:
|
2082
|
+
x_train, x_true = x_train[common_features], x_true[common_features]
|
2083
|
+
else:
|
2084
|
+
share_col_names = ips.shared(x_train.columns, x_true.columns, verbose=verbose)
|
2085
|
+
x_train, x_true = x_train[share_col_names], x_true[share_col_names]
|
2086
|
+
|
2087
|
+
x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
|
2088
|
+
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
|
2089
|
+
# Handle class imbalance using SMOTE (only for classification)
|
2090
|
+
if (
|
2091
|
+
smote
|
2092
|
+
and purpose == "classification"
|
2093
|
+
and pd.Series(y_train).value_counts(normalize=True).max() < 0.8
|
2094
|
+
):
|
2095
|
+
from imblearn.over_sampling import SMOTE
|
2096
|
+
|
2097
|
+
smote_sampler = SMOTE(random_state=random_state)
|
2098
|
+
x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
|
2099
|
+
|
2100
|
+
# Hyperparameter grids for tuning
|
2101
|
+
if cv_level in ["low", "simple", "s", "l"]:
|
2102
|
+
param_grids = {
|
2103
|
+
"Random Forest": (
|
2104
|
+
{
|
2105
|
+
"n_estimators": [100], # One basic option
|
2106
|
+
"max_depth": [None, 10],
|
2107
|
+
"min_samples_split": [2],
|
2108
|
+
"min_samples_leaf": [1],
|
2109
|
+
"class_weight": [None],
|
2110
|
+
}
|
2111
|
+
if purpose == "classification"
|
2112
|
+
else {
|
2113
|
+
"n_estimators": [100], # One basic option
|
2114
|
+
"max_depth": [None, 10],
|
2115
|
+
"min_samples_split": [2],
|
2116
|
+
"min_samples_leaf": [1],
|
2117
|
+
"max_features": [None],
|
2118
|
+
"bootstrap": [True], # Only one option for simplicity
|
2119
|
+
}
|
2120
|
+
),
|
2121
|
+
"SVM": {
|
2122
|
+
"C": [1],
|
2123
|
+
"gamma": ["scale"],
|
2124
|
+
"kernel": ["rbf"],
|
2125
|
+
},
|
2126
|
+
"Lasso": {
|
2127
|
+
"alpha": [0.1],
|
2128
|
+
},
|
2129
|
+
"LassoCV": {
|
2130
|
+
"alphas": [[0.1]],
|
2131
|
+
},
|
2132
|
+
"Logistic Regression": {
|
2133
|
+
"C": [1],
|
2134
|
+
"solver": ["lbfgs"],
|
2135
|
+
"penalty": ["l2"],
|
2136
|
+
"max_iter": [500],
|
2137
|
+
},
|
2138
|
+
"Gradient Boosting": {
|
2139
|
+
"n_estimators": [100],
|
2140
|
+
"learning_rate": [0.1],
|
2141
|
+
"max_depth": [3],
|
2142
|
+
"min_samples_split": [2],
|
2143
|
+
"subsample": [0.8],
|
2144
|
+
},
|
2145
|
+
"XGBoost": {
|
2146
|
+
"n_estimators": [100],
|
2147
|
+
"max_depth": [3],
|
2148
|
+
"learning_rate": [0.1],
|
2149
|
+
"subsample": [0.8],
|
2150
|
+
"colsample_bytree": [0.8],
|
2151
|
+
},
|
2152
|
+
"KNN": (
|
2153
|
+
{
|
2154
|
+
"n_neighbors": [3],
|
2155
|
+
"weights": ["uniform"],
|
2156
|
+
"algorithm": ["auto"],
|
2157
|
+
"p": [2],
|
2158
|
+
}
|
2159
|
+
if purpose == "classification"
|
2160
|
+
else {
|
2161
|
+
"n_neighbors": [3],
|
2162
|
+
"weights": ["uniform"],
|
2163
|
+
"metric": ["euclidean"],
|
2164
|
+
"leaf_size": [30],
|
2165
|
+
"p": [2],
|
2166
|
+
}
|
2167
|
+
),
|
2168
|
+
"Naive Bayes": {
|
2169
|
+
"var_smoothing": [1e-9],
|
2170
|
+
},
|
2171
|
+
"SVR": {
|
2172
|
+
"C": [1],
|
2173
|
+
"gamma": ["scale"],
|
2174
|
+
"kernel": ["rbf"],
|
2175
|
+
},
|
2176
|
+
"Linear Regression": {
|
2177
|
+
"fit_intercept": [True],
|
2178
|
+
},
|
2179
|
+
"Extra Trees": {
|
2180
|
+
"n_estimators": [100],
|
2181
|
+
"max_depth": [None, 10],
|
2182
|
+
"min_samples_split": [2],
|
2183
|
+
"min_samples_leaf": [1],
|
2184
|
+
},
|
2185
|
+
"CatBoost": {
|
2186
|
+
"iterations": [100],
|
2187
|
+
"learning_rate": [0.1],
|
2188
|
+
"depth": [3],
|
2189
|
+
"l2_leaf_reg": [1],
|
2190
|
+
},
|
2191
|
+
"LightGBM": {
|
2192
|
+
"n_estimators": [100],
|
2193
|
+
"num_leaves": [31],
|
2194
|
+
"max_depth": [10],
|
2195
|
+
"min_data_in_leaf": [20],
|
2196
|
+
"min_gain_to_split": [0.01],
|
2197
|
+
"scale_pos_weight": [10],
|
2198
|
+
},
|
2199
|
+
"Bagging": {
|
2200
|
+
"n_estimators": [50],
|
2201
|
+
"max_samples": [0.7],
|
2202
|
+
"max_features": [0.7],
|
2203
|
+
},
|
2204
|
+
"Neural Network": {
|
2205
|
+
"hidden_layer_sizes": [(50,)],
|
2206
|
+
"activation": ["relu"],
|
2207
|
+
"solver": ["adam"],
|
2208
|
+
"alpha": [0.0001],
|
2209
|
+
},
|
2210
|
+
"Decision Tree": {
|
2211
|
+
"max_depth": [None, 10],
|
2212
|
+
"min_samples_split": [2],
|
2213
|
+
"min_samples_leaf": [1],
|
2214
|
+
"criterion": ["gini"],
|
2215
|
+
},
|
2216
|
+
"AdaBoost": {
|
2217
|
+
"n_estimators": [50],
|
2218
|
+
"learning_rate": [0.5],
|
2219
|
+
},
|
2220
|
+
"Linear Discriminant Analysis": {
|
2221
|
+
"solver": ["svd"],
|
2222
|
+
"shrinkage": [None],
|
2223
|
+
},
|
2224
|
+
"Quadratic Discriminant Analysis": {
|
2225
|
+
"reg_param": [0.0],
|
2226
|
+
"priors": [None],
|
2227
|
+
"tol": [1e-4],
|
2228
|
+
},
|
2229
|
+
"Ridge": (
|
2230
|
+
{"class_weight": [None, "balanced"]}
|
2231
|
+
if purpose == "classification"
|
2232
|
+
else {
|
2233
|
+
"alpha": [0.1, 1, 10],
|
2234
|
+
}
|
2235
|
+
),
|
2236
|
+
"Perceptron": {
|
2237
|
+
"alpha": [1e-3],
|
2238
|
+
"penalty": ["l2"],
|
2239
|
+
"max_iter": [1000],
|
2240
|
+
"eta0": [1.0],
|
2241
|
+
},
|
2242
|
+
"Bernoulli Naive Bayes": {
|
2243
|
+
"alpha": [0.1, 1, 10],
|
2244
|
+
"binarize": [0.0],
|
2245
|
+
"fit_prior": [True],
|
2246
|
+
},
|
2247
|
+
"SGDClassifier": {
|
2248
|
+
"eta0": [0.01],
|
2249
|
+
"loss": ["hinge"],
|
2250
|
+
"penalty": ["l2"],
|
2251
|
+
"alpha": [1e-3],
|
2252
|
+
"max_iter": [1000],
|
2253
|
+
"tol": [1e-3],
|
2254
|
+
"random_state": [random_state],
|
2255
|
+
"learning_rate": ["constant"],
|
2256
|
+
},
|
2257
|
+
}
|
2258
|
+
elif cv_level in ["high", "advanced", "h"]:
|
2259
|
+
param_grids = {
|
2260
|
+
"Random Forest": (
|
2261
|
+
{
|
2262
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2263
|
+
"max_depth": [None, 3, 5, 10, 15, 20, 30],
|
2264
|
+
"min_samples_split": [2, 5, 10, 20],
|
2265
|
+
"min_samples_leaf": [1, 2, 4],
|
2266
|
+
"class_weight": (
|
2267
|
+
[None, "balanced"] if purpose == "classification" else {}
|
2268
|
+
),
|
2269
|
+
}
|
2270
|
+
if purpose == "classification"
|
2271
|
+
else {
|
2272
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2273
|
+
"max_depth": [None, 3, 5, 10, 15, 20, 30],
|
2274
|
+
"min_samples_split": [2, 5, 10, 20],
|
2275
|
+
"min_samples_leaf": [1, 2, 4],
|
2276
|
+
"max_features": [
|
2277
|
+
"auto",
|
2278
|
+
"sqrt",
|
2279
|
+
"log2",
|
2280
|
+
], # Number of features to consider when looking for the best split
|
2281
|
+
"bootstrap": [
|
2282
|
+
True,
|
2283
|
+
False,
|
2284
|
+
], # Whether bootstrap samples are used when building trees
|
2285
|
+
}
|
2286
|
+
),
|
2287
|
+
"SVM": {
|
2288
|
+
"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
|
2289
|
+
"gamma": ["scale", "auto", 0.001, 0.01, 0.1],
|
2290
|
+
"kernel": ["linear", "rbf", "poly"],
|
2291
|
+
},
|
2292
|
+
"Logistic Regression": {
|
2293
|
+
"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
|
2294
|
+
"solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
|
2295
|
+
"penalty": ["l1", "l2", "elasticnet"],
|
2296
|
+
"max_iter": [100, 200, 300, 500],
|
2297
|
+
},
|
2298
|
+
"Lasso": {
|
2299
|
+
"alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2300
|
+
"max_iter": [500, 1000, 2000, 5000],
|
2301
|
+
"tol": [1e-4, 1e-5, 1e-6],
|
2302
|
+
"selection": ["cyclic", "random"],
|
2303
|
+
},
|
2304
|
+
"LassoCV": {
|
2305
|
+
"alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
|
2306
|
+
"max_iter": [500, 1000, 2000, 5000],
|
2307
|
+
"cv": [3, 5, 10],
|
2308
|
+
"tol": [1e-4, 1e-5, 1e-6],
|
2309
|
+
},
|
2310
|
+
"Gradient Boosting": {
|
2311
|
+
"n_estimators": [100, 200, 300, 400, 500, 700, 1000],
|
2312
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
|
2313
|
+
"max_depth": [3, 5, 7, 9, 15],
|
2314
|
+
"min_samples_split": [2, 5, 10, 20],
|
2315
|
+
"subsample": [0.8, 1.0],
|
2316
|
+
},
|
2317
|
+
"XGBoost": {
|
2318
|
+
"n_estimators": [100, 200, 500, 700],
|
2319
|
+
"max_depth": [3, 5, 7, 10],
|
2320
|
+
"learning_rate": [0.01, 0.1, 0.2, 0.3],
|
2321
|
+
"subsample": [0.8, 1.0],
|
2322
|
+
"colsample_bytree": [0.8, 0.9, 1.0],
|
2323
|
+
},
|
2324
|
+
"KNN": (
|
2325
|
+
{
|
2326
|
+
"n_neighbors": [1, 3, 5, 10, 15, 20],
|
2327
|
+
"weights": ["uniform", "distance"],
|
2328
|
+
"algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
|
2329
|
+
"p": [1, 2], # 1 for Manhattan, 2 for Euclidean distance
|
2330
|
+
}
|
2331
|
+
if purpose == "classification"
|
2332
|
+
else {
|
2333
|
+
"n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
|
2334
|
+
"weights": [
|
2335
|
+
"uniform",
|
2336
|
+
"distance",
|
2337
|
+
], # Weight function used in prediction
|
2338
|
+
"metric": [
|
2339
|
+
"euclidean",
|
2340
|
+
"manhattan",
|
2341
|
+
"minkowski",
|
2342
|
+
], # Distance metric
|
2343
|
+
"leaf_size": [
|
2344
|
+
20,
|
2345
|
+
30,
|
2346
|
+
40,
|
2347
|
+
50,
|
2348
|
+
], # Leaf size for KDTree or BallTree algorithms
|
2349
|
+
"p": [
|
2350
|
+
1,
|
2351
|
+
2,
|
2352
|
+
], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
|
2353
|
+
}
|
2354
|
+
),
|
2355
|
+
"Naive Bayes": {
|
2356
|
+
"var_smoothing": [1e-10, 1e-9, 1e-8, 1e-7],
|
2357
|
+
},
|
2358
|
+
"AdaBoost": {
|
2359
|
+
"n_estimators": [50, 100, 200, 300, 500],
|
2360
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
|
2361
|
+
},
|
2362
|
+
"SVR": {
|
2363
|
+
"C": [0.01, 0.1, 1, 10, 100, 1000],
|
2364
|
+
"gamma": [0.001, 0.01, 0.1, "scale", "auto"],
|
2365
|
+
"kernel": ["linear", "rbf", "poly"],
|
2366
|
+
},
|
2367
|
+
"Linear Regression": {
|
2368
|
+
"fit_intercept": [True, False],
|
2369
|
+
},
|
2370
|
+
"Lasso": {
|
2371
|
+
"alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2372
|
+
"max_iter": [1000, 2000], # Higher iteration limit for fine-tuning
|
2373
|
+
},
|
2374
|
+
"Extra Trees": {
|
2375
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2376
|
+
"max_depth": [None, 5, 10, 15, 20, 30],
|
2377
|
+
"min_samples_split": [2, 5, 10, 20],
|
2378
|
+
"min_samples_leaf": [1, 2, 4],
|
2379
|
+
},
|
2380
|
+
"CatBoost": {
|
2381
|
+
"iterations": [100, 200, 500],
|
2382
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.2],
|
2383
|
+
"depth": [3, 5, 7, 10],
|
2384
|
+
"l2_leaf_reg": [1, 3, 5, 7, 10],
|
2385
|
+
"border_count": [32, 64, 128],
|
2386
|
+
},
|
2387
|
+
"LightGBM": {
|
2388
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2389
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.2],
|
2390
|
+
"num_leaves": [31, 50, 100, 200],
|
2391
|
+
"max_depth": [-1, 5, 10, 20, 30],
|
2392
|
+
"min_child_samples": [5, 10, 20],
|
2393
|
+
"subsample": [0.8, 1.0],
|
2394
|
+
"colsample_bytree": [0.8, 0.9, 1.0],
|
2395
|
+
},
|
2396
|
+
"Neural Network": {
|
2397
|
+
"hidden_layer_sizes": [(50,), (100,), (100, 50), (200, 100)],
|
2398
|
+
"activation": ["relu", "tanh", "logistic"],
|
2399
|
+
"solver": ["adam", "sgd", "lbfgs"],
|
2400
|
+
"alpha": [0.0001, 0.001, 0.01],
|
2401
|
+
"learning_rate": ["constant", "adaptive"],
|
2402
|
+
},
|
2403
|
+
"Decision Tree": {
|
2404
|
+
"max_depth": [None, 5, 10, 20, 30],
|
2405
|
+
"min_samples_split": [2, 5, 10, 20],
|
2406
|
+
"min_samples_leaf": [1, 2, 5, 10],
|
2407
|
+
"criterion": ["gini", "entropy"],
|
2408
|
+
"splitter": ["best", "random"],
|
2409
|
+
},
|
2410
|
+
"Linear Discriminant Analysis": {
|
2411
|
+
"solver": ["svd", "lsqr", "eigen"],
|
2412
|
+
"shrinkage": [
|
2413
|
+
None,
|
2414
|
+
"auto",
|
2415
|
+
0.1,
|
2416
|
+
0.5,
|
2417
|
+
1.0,
|
2418
|
+
], # shrinkage levels for 'lsqr' and 'eigen'
|
2419
|
+
},
|
2420
|
+
"Ridge": (
|
2421
|
+
{"class_weight": [None, "balanced"]}
|
2422
|
+
if purpose == "classification"
|
2423
|
+
else {
|
2424
|
+
"alpha": [0.1, 1, 10, 100, 1000],
|
2425
|
+
"solver": ["auto", "svd", "cholesky", "lsqr", "lbfgs"],
|
2426
|
+
"fit_intercept": [
|
2427
|
+
True,
|
2428
|
+
False,
|
2429
|
+
], # Whether to calculate the intercept
|
2430
|
+
"normalize": [
|
2431
|
+
True,
|
2432
|
+
False,
|
2433
|
+
], # If True, the regressors X will be normalized
|
2434
|
+
}
|
2435
|
+
),
|
2436
|
+
}
|
2437
|
+
else:  # medium level
|
2438
|
+
param_grids = {
|
2439
|
+
"Random Forest": (
|
2440
|
+
{
|
2441
|
+
"n_estimators": [100, 200, 500],
|
2442
|
+
"max_depth": [None, 10, 20, 30],
|
2443
|
+
"min_samples_split": [2, 5, 10],
|
2444
|
+
"min_samples_leaf": [1, 2, 4],
|
2445
|
+
"class_weight": [None, "balanced"],
|
2446
|
+
}
|
2447
|
+
if purpose == "classification"
|
2448
|
+
else {
|
2449
|
+
"n_estimators": [100, 200, 500],
|
2450
|
+
"max_depth": [None, 10, 20, 30],
|
2451
|
+
"min_samples_split": [2, 5, 10],
|
2452
|
+
"min_samples_leaf": [1, 2, 4],
|
2453
|
+
"max_features": [
|
2454
|
+
"auto",
|
2455
|
+
"sqrt",
|
2456
|
+
"log2",
|
2457
|
+
], # Number of features to consider when looking for the best split
|
2458
|
+
"bootstrap": [
|
2459
|
+
True,
|
2460
|
+
False,
|
2461
|
+
], # Whether bootstrap samples are used when building trees
|
2462
|
+
}
|
2463
|
+
),
|
2464
|
+
"SVM": {
|
2465
|
+
"C": [0.1, 1, 10, 100], # Regularization strength
|
2466
|
+
"gamma": ["scale", "auto"], # Common gamma values
|
2467
|
+
"kernel": ["rbf", "linear", "poly"],
|
2468
|
+
},
|
2469
|
+
"Logistic Regression": {
|
2470
|
+
"C": [0.1, 1, 10, 100], # Regularization strength
|
2471
|
+
"solver": ["lbfgs", "liblinear", "saga"], # Common solvers
|
2472
|
+
"penalty": ["l2"], # L2 penalty is most common
|
2473
|
+
"max_iter": [
|
2474
|
+
500,
|
2475
|
+
1000,
|
2476
|
+
2000,
|
2477
|
+
], # Increased max_iter for better convergence
|
2478
|
+
},
|
2479
|
+
"Lasso": {
|
2480
|
+
"alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2481
|
+
"max_iter": [500, 1000, 2000],
|
2482
|
+
},
|
2483
|
+
"LassoCV": {
|
2484
|
+
"alphas": [[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
|
2485
|
+
"max_iter": [500, 1000, 2000],
|
2486
|
+
},
|
2487
|
+
"Gradient Boosting": {
|
2488
|
+
"n_estimators": [100, 200, 500],
|
2489
|
+
"learning_rate": [0.01, 0.1, 0.2],
|
2490
|
+
"max_depth": [3, 5, 7],
|
2491
|
+
"min_samples_split": [2, 5, 10],
|
2492
|
+
"subsample": [0.8, 1.0],
|
2493
|
+
},
|
2494
|
+
"XGBoost": {
|
2495
|
+
"n_estimators": [100, 200, 500],
|
2496
|
+
"max_depth": [3, 5, 7],
|
2497
|
+
"learning_rate": [0.01, 0.1, 0.2],
|
2498
|
+
"subsample": [0.8, 1.0],
|
2499
|
+
"colsample_bytree": [0.8, 1.0],
|
2500
|
+
},
|
2501
|
+
"KNN": (
|
2502
|
+
{
|
2503
|
+
"n_neighbors": [3, 5, 7, 10],
|
2504
|
+
"weights": ["uniform", "distance"],
|
2505
|
+
"algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
|
2506
|
+
"p": [1, 2],
|
2507
|
+
}
|
2508
|
+
if purpose == "classification"
|
2509
|
+
else {
|
2510
|
+
"n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
|
2511
|
+
"weights": [
|
2512
|
+
"uniform",
|
2513
|
+
"distance",
|
2514
|
+
], # Weight function used in prediction
|
2515
|
+
"metric": [
|
2516
|
+
"euclidean",
|
2517
|
+
"manhattan",
|
2518
|
+
"minkowski",
|
2519
|
+
], # Distance metric
|
2520
|
+
"leaf_size": [
|
2521
|
+
20,
|
2522
|
+
30,
|
2523
|
+
40,
|
2524
|
+
50,
|
2525
|
+
], # Leaf size for KDTree or BallTree algorithms
|
2526
|
+
"p": [
|
2527
|
+
1,
|
2528
|
+
2,
|
2529
|
+
], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
|
2530
|
+
}
|
2531
|
+
),
|
2532
|
+
"Naive Bayes": {
|
2533
|
+
"var_smoothing": [1e-9, 1e-8, 1e-7],
|
2534
|
+
},
|
2535
|
+
"SVR": {
|
2536
|
+
"C": [0.1, 1, 10, 100],
|
2537
|
+
"gamma": ["scale", "auto"],
|
2538
|
+
"kernel": ["rbf", "linear"],
|
2539
|
+
},
|
2540
|
+
"Linear Regression": {
|
2541
|
+
"fit_intercept": [True, False],
|
2542
|
+
},
|
2543
|
+
"Lasso": {
|
2544
|
+
"alpha": [0.1, 1.0, 10.0],
|
2545
|
+
"max_iter": [1000, 2000], # Sufficient iterations for convergence
|
2546
|
+
},
|
2547
|
+
"Extra Trees": {
|
2548
|
+
"n_estimators": [100, 200, 500],
|
2549
|
+
"max_depth": [None, 10, 20, 30],
|
2550
|
+
"min_samples_split": [2, 5, 10],
|
2551
|
+
"min_samples_leaf": [1, 2, 4],
|
2552
|
+
},
|
2553
|
+
"CatBoost": {
|
2554
|
+
"iterations": [100, 200],
|
2555
|
+
"learning_rate": [0.01, 0.1],
|
2556
|
+
"depth": [3, 6, 10],
|
2557
|
+
"l2_leaf_reg": [1, 3, 5, 7],
|
2558
|
+
},
|
2559
|
+
"LightGBM": {
|
2560
|
+
"n_estimators": [100, 200, 500],
|
2561
|
+
"learning_rate": [0.01, 0.1],
|
2562
|
+
"num_leaves": [31, 50, 100],
|
2563
|
+
"max_depth": [-1, 10, 20],
|
2564
|
+
"min_data_in_leaf": [20], # Minimum samples in each leaf
|
2565
|
+
"min_gain_to_split": [0.01], # Minimum gain to allow a split
|
2566
|
+
"scale_pos_weight": [10], # Address class imbalance
|
2567
|
+
},
|
2568
|
+
"Bagging": {
|
2569
|
+
"n_estimators": [10, 50, 100],
|
2570
|
+
"max_samples": [0.5, 0.7, 1.0],
|
2571
|
+
"max_features": [0.5, 0.7, 1.0],
|
2572
|
+
},
|
2573
|
+
"Neural Network": {
|
2574
|
+
"hidden_layer_sizes": [(50,), (100,), (100, 50)],
|
2575
|
+
"activation": ["relu", "tanh"],
|
2576
|
+
"solver": ["adam", "sgd"],
|
2577
|
+
"alpha": [0.0001, 0.001],
|
2578
|
+
},
|
2579
|
+
"Decision Tree": {
|
2580
|
+
"max_depth": [None, 10, 20],
|
2581
|
+
"min_samples_split": [2, 10],
|
2582
|
+
"min_samples_leaf": [1, 4],
|
2583
|
+
"criterion": ["gini", "entropy"],
|
2584
|
+
},
|
2585
|
+
"AdaBoost": {
|
2586
|
+
"n_estimators": [50, 100],
|
2587
|
+
"learning_rate": [0.5, 1.0],
|
2588
|
+
},
|
2589
|
+
"Linear Discriminant Analysis": {
|
2590
|
+
"solver": ["svd", "lsqr", "eigen"],
|
2591
|
+
"shrinkage": [None, "auto"],
|
2592
|
+
},
|
2593
|
+
"Quadratic Discriminant Analysis": {
|
2594
|
+
"reg_param": [0.0, 0.1, 0.5, 1.0], # Regularization parameter
|
2595
|
+
"priors": [None, [0.5, 0.5], [0.3, 0.7]], # Class priors
|
2596
|
+
"tol": [
|
2597
|
+
1e-4,
|
2598
|
+
1e-3,
|
2599
|
+
1e-2,
|
2600
|
+
], # Tolerance value for the convergence of the algorithm
|
2601
|
+
},
|
2602
|
+
"Perceptron": {
|
2603
|
+
"alpha": [1e-4, 1e-3, 1e-2], # Regularization parameter
|
2604
|
+
"penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
|
2605
|
+
"max_iter": [1000, 2000], # Maximum number of iterations
|
2606
|
+
"eta0": [1.0, 0.1], # Learning rate for gradient descent
|
2607
|
+
"tol": [1e-3, 1e-4, 1e-5], # Tolerance for stopping criteria
|
2608
|
+
"random_state": [random_state], # Random state for reproducibility
|
2609
|
+
},
|
2610
|
+
"Bernoulli Naive Bayes": {
|
2611
|
+
"alpha": [0.1, 1.0, 10.0], # Additive (Laplace) smoothing parameter
|
2612
|
+
"binarize": [
|
2613
|
+
0.0,
|
2614
|
+
0.5,
|
2615
|
+
1.0,
|
2616
|
+
], # Threshold for binarizing the input features
|
2617
|
+
"fit_prior": [
|
2618
|
+
True,
|
2619
|
+
False,
|
2620
|
+
], # Whether to learn class prior probabilities
|
2621
|
+
},
|
2622
|
+
"SGDClassifier": {
|
2623
|
+
"eta0": [0.01, 0.1, 1.0],
|
2624
|
+
"loss": [
|
2625
|
+
"hinge",
|
2626
|
+
"log",
|
2627
|
+
"modified_huber",
|
2628
|
+
"squared_hinge",
|
2629
|
+
"perceptron",
|
2630
|
+
], # Loss function
|
2631
|
+
"penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
|
2632
|
+
"alpha": [1e-4, 1e-3, 1e-2], # Regularization strength
|
2633
|
+
"l1_ratio": [0.15, 0.5, 0.85], # L1 ratio for elasticnet penalty
|
2634
|
+
"max_iter": [1000, 2000], # Maximum number of iterations
|
2635
|
+
"tol": [1e-3, 1e-4], # Tolerance for stopping criteria
|
2636
|
+
"random_state": [random_state], # Random state for reproducibility
|
2637
|
+
"learning_rate": [
|
2638
|
+
"constant",
|
2639
|
+
"optimal",
|
2640
|
+
"invscaling",
|
2641
|
+
"adaptive",
|
2642
|
+
], # Learning rate schedule
|
2643
|
+
},
|
2644
|
+
"Ridge": (
|
2645
|
+
{"class_weight": [None, "balanced"]}
|
2646
|
+
if purpose == "classification"
|
2647
|
+
else {
|
2648
|
+
"alpha": [0.1, 1, 10, 100],
|
2649
|
+
"solver": [
|
2650
|
+
"auto",
|
2651
|
+
"svd",
|
2652
|
+
"cholesky",
|
2653
|
+
"lsqr",
|
2654
|
+
], # Solver for optimization
|
2655
|
+
}
|
2656
|
+
),
|
2657
|
+
}
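# Note: the branches above only size the hyperparameter search space: "s"/"l"/"low"/"simple"
# select the minimal grids, "h"/"high"/"advanced" the exhaustive grids, and anything else
# falls back to these medium grids.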
|
2658
|
+
|
2659
|
+
results = {}
|
2660
|
+
# Use StratifiedKFold for classification and KFold for regression
|
2661
|
+
cv = (
|
2662
|
+
StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
2663
|
+
if purpose == "classification"
|
2664
|
+
else KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
2665
|
+
)
|
2666
|
+
|
2667
|
+
# Train and validate each model
|
2668
|
+
for name, clf in tqdm(
|
2669
|
+
models.items(),
|
2670
|
+
desc="models",
|
2671
|
+
colour="green",
|
2672
|
+
bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
|
2673
|
+
):
|
2674
|
+
if verbose:
|
2675
|
+
print(f"\nTraining and validating {name}:")
|
2676
|
+
|
2677
|
+
# Grid search with KFold or StratifiedKFold
|
2678
|
+
gs = GridSearchCV(
|
2679
|
+
clf,
|
2680
|
+
param_grid=param_grids.get(name, {}),
|
2681
|
+
scoring=(
|
2682
|
+
"roc_auc" if purpose == "classification" else "neg_mean_squared_error"
|
2683
|
+
),
|
2684
|
+
cv=cv,
|
2685
|
+
n_jobs=n_jobs,
|
2686
|
+
verbose=verbose,
|
2687
|
+
)
|
2688
|
+
gs.fit(x_train, y_train)
|
2689
|
+
best_clf = gs.best_estimator_
|
2690
|
+
# make sure x_true has the same columns (in the same order) as x_train
|
2691
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2692
|
+
y_pred = best_clf.predict(x_true)
|
2693
|
+
|
2694
|
+
# y_pred_proba
|
2695
|
+
if hasattr(best_clf, "predict_proba"):
|
2696
|
+
y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
|
2697
|
+
elif hasattr(best_clf, "decision_function"):
|
2698
|
+
# If predict_proba is not available, use decision_function (e.g., for SVM)
|
2699
|
+
y_pred_proba = best_clf.decision_function(x_true)
|
2700
|
+
# Ensure y_pred_proba is within 0 and 1 bounds
|
2701
|
+
y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
|
2702
|
+
y_pred_proba.max() - y_pred_proba.min()
|
2703
|
+
)
|
2704
|
+
else:
|
2705
|
+
y_pred_proba = None # No probability output for certain models
|
2706
|
+
|
2707
|
+
validation_scores = {}
|
2708
|
+
if y_true is not None:
|
2709
|
+
validation_scores = cal_metrics(
|
2710
|
+
y_true,
|
2711
|
+
y_pred,
|
2712
|
+
y_pred_proba=y_pred_proba,
|
2713
|
+
is_binary=is_binary,
|
2714
|
+
purpose=purpose,
|
2715
|
+
average="weighted",
|
2716
|
+
)
|
2717
|
+
|
2718
|
+
# Calculate ROC curve
|
2719
|
+
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
|
2720
|
+
if y_pred_proba is not None:
|
2721
|
+
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
2722
|
+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
2723
|
+
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
|
2724
|
+
roc_auc = auc(fpr, tpr)
|
2725
|
+
roc_info = {
|
2726
|
+
"fpr": fpr.tolist(),
|
2727
|
+
"tpr": tpr.tolist(),
|
2728
|
+
"auc": roc_auc,
|
2729
|
+
"ci95": (lower_ci, upper_ci),
|
2730
|
+
}
|
2731
|
+
# precision-recall curve
|
2732
|
+
precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
|
2733
|
+
avg_precision_ = average_precision_score(y_true, y_pred_proba)
|
2734
|
+
pr_info = {
|
2735
|
+
"precision": precision_,
|
2736
|
+
"recall": recall_,
|
2737
|
+
"avg_precision": avg_precision_,
|
2738
|
+
}
|
2739
|
+
else:
|
2740
|
+
roc_info, pr_info = None, None
|
2741
|
+
if purpose == "classification":
|
2742
|
+
results[name] = {
|
2743
|
+
"best_clf": gs.best_estimator_,
|
2744
|
+
"best_params": gs.best_params_,
|
2745
|
+
"auc_indiv": [
|
2746
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
2747
|
+
for i in range(cv_folds)
|
2748
|
+
],
|
2749
|
+
"scores": validation_scores,
|
2750
|
+
"roc_curve": roc_info,
|
2751
|
+
"pr_curve": pr_info,
|
2752
|
+
"confusion_matrix": confusion_matrix(y_true, y_pred),
|
2753
|
+
"predictions": y_pred.tolist(),
|
2754
|
+
"predictions_proba": (
|
2755
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
2756
|
+
),
|
2757
|
+
}
|
2758
|
+
else: # "regression"
|
2759
|
+
results[name] = {
|
2760
|
+
"best_clf": gs.best_estimator_,
|
2761
|
+
"best_params": gs.best_params_,
|
2762
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
2763
|
+
"predictions": y_pred.tolist(),
|
2764
|
+
"predictions_proba": (
|
2765
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
2766
|
+
),
|
2767
|
+
}
|
2768
|
+
|
2769
|
+
else:
|
2770
|
+
results[name] = {
|
2771
|
+
"best_clf": gs.best_estimator_,
|
2772
|
+
"best_params": gs.best_params_,
|
2773
|
+
"scores": validation_scores,
|
2774
|
+
"predictions": y_pred.tolist(),
|
2775
|
+
"predictions_proba": (
|
2776
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
2777
|
+
),
|
2778
|
+
}
|
2779
|
+
|
2780
|
+
# Convert results to DataFrame
|
2781
|
+
df_results = pd.DataFrame.from_dict(results, orient="index")
|
2782
|
+
|
2783
|
+
# sort
|
2784
|
+
if y_true is not None and purpose == "classification":
|
2785
|
+
df_scores = pd.DataFrame(
|
2786
|
+
df_results["scores"].tolist(), index=df_results["scores"].index
|
2787
|
+
).sort_values(by="roc_auc", ascending=False)
|
2788
|
+
df_results = df_results.loc[df_scores.index]
|
2789
|
+
|
2790
|
+
    if plot_:
        from datetime import datetime

        now_ = datetime.now().strftime("%y%m%d_%H%M%S")
        # df_scores only exists for classification (see the sorting step above), so
        # the score heatmaps are drawn only in that case; this avoids a NameError
        # when purpose == "regression".
        if y_true is not None and purpose == "classification":
            nexttile = plot.subplot(figsize=[12, 10])
            plot.heatmap(df_scores, kind="direct", ax=nexttile())
            plot.figsets(xangle=30)
            if dir_save:
                ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
            if df_scores.shape[0] > 1:  # draw clustered heatmap
                plot.heatmap(df_scores, kind="direct", cluster=True)
                plot.figsets(xangle=30)
                if dir_save:
                    ips.figsave(dir_save + f"scores_clus{now_}.pdf")
    if all([plot_, y_true is not None, purpose == "classification"]):
        try:
            if len(models) > 3:
                plot_validate_features(df_results)
            else:
                plot_validate_features_single(
                    df_results, figsize=(12, 4 * len(models))
                )
            if dir_save:
                ips.figsave(dir_save + f"validate_features{now_}.pdf")
        except Exception as e:
            print(f"Error: plotting failed: {e}")
    return df_results

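The helper cal_auc_ci() called above is defined elsewhere in this module and is not part of this hunk. Purely for orientation, a minimal bootstrap-style sketch of such a confidence-interval helper is shown below; the name bootstrap_auc_ci, its parameters and the percentile resampling scheme are assumptions for illustration, not the module's actual implementation.

# Illustrative sketch only -- NOT the module's cal_auc_ci() implementation.
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auc_ci(y_true, y_pred_proba, n_boot=1000, ci=0.95, seed=1):
    """Percentile bootstrap CI for ROC-AUC (assumes binary labels, 1-D probabilities)."""
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    aucs = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), len(y_true))  # resample with replacement
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUC is undefined if the resample contains a single class
        aucs.append(roc_auc_score(y_true[idx], y_pred_proba[idx]))
    alpha = (1 - ci) / 2
    return float(np.quantile(aucs, alpha)), float(np.quantile(aucs, 1 - alpha))
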
def cal_metrics(
    y_true, y_pred, y_pred_proba=None, is_binary=True, purpose="regression", average="weighted"
):
    """
    Calculate regression or classification metrics based on the purpose.

    Parameters:
    - y_true: Array of true values.
    - y_pred: Array of predicted labels for classification or predicted values for regression.
    - y_pred_proba: Array of predicted probabilities for classification (optional).
    - is_binary: bool, whether the classification task is binary (affects specificity and AUC handling).
    - purpose: str, "regression" or "classification".
    - average: str, averaging method for multi-class classification ("binary", "micro", "macro", "weighted", etc.).

    Returns:
    - validation_scores: dict of computed metrics.

    A usage sketch follows the function definition below.
    """
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        mean_absolute_percentage_error,
        explained_variance_score,
        r2_score,
        mean_squared_log_error,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
        roc_auc_score,
        matthews_corrcoef,
        confusion_matrix,
        balanced_accuracy_score,
        average_precision_score,
        precision_recall_curve,
    )

    validation_scores = {}

    if purpose == "regression":
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        # Regression metrics
        validation_scores = {
            "mse": mean_squared_error(y_true, y_pred),
            "rmse": np.sqrt(mean_squared_error(y_true, y_pred)),
            "mae": mean_absolute_error(y_true, y_pred),
            "r2": r2_score(y_true, y_pred),
            "mape": mean_absolute_percentage_error(y_true, y_pred),
            "explained_variance": explained_variance_score(y_true, y_pred),
            "mbd": np.mean(y_pred - y_true),  # Mean Bias Deviation
        }
        # Check if MSLE can be calculated
        if np.all(y_true >= 0) and np.all(y_pred >= 0):  # Ensure no negative values
            validation_scores["msle"] = mean_squared_log_error(y_true, y_pred)
        else:
            validation_scores["msle"] = "Cannot be calculated due to negative values"

    elif purpose == "classification":
        # Classification metrics
        validation_scores = {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, average=average),
            "recall": recall_score(y_true, y_pred, average=average),
            "f1": f1_score(y_true, y_pred, average=average),
            "mcc": matthews_corrcoef(y_true, y_pred),
            "specificity": None,
            "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
        }

        # Confusion matrix to calculate specificity
        if is_binary:
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            validation_scores["specificity"] = (
                tn / (tn + fp) if (tn + fp) > 0 else 0
            )  # Specificity calculation
        else:
            # Multi-class: macro-average TN / (TN + FP) over the classes
            cm = confusion_matrix(y_true, y_pred)
            fp = cm.sum(axis=0) - np.diag(cm)
            tn = cm.sum() - (cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm))
            denom = tn + fp
            spec = np.divide(tn, denom, out=np.zeros_like(tn, dtype=float), where=denom > 0)
            validation_scores["specificity"] = float(spec.mean())

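        # Worked example: with tn = 50 true negatives and fp = 10 false positives,
        # specificity = 50 / (50 + 10) ≈ 0.83.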
        if y_pred_proba is not None:
            if is_binary:
                # ROC-AUC and PR-AUC from the positive-class probabilities
                validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
                validation_scores["pr_auc"] = average_precision_score(
                    y_true, y_pred_proba
                )
            else:
                # Multi-class ROC-AUC needs an (n_samples, n_classes) probability
                # matrix and an explicit one-vs-rest strategy. (PR-AUC would require
                # label-binarised targets and is skipped here.)
                validation_scores["roc_auc"] = roc_auc_score(
                    y_true, y_pred_proba, multi_class="ovr", average=average
                )
    else:
        raise ValueError(
            "Invalid purpose specified. Choose 'regression' or 'classification'."
        )

    return validation_scores
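
As a quick smoke test, cal_metrics() can be exercised on a small synthetic binary problem. The arrays below are made up purely for illustration; only the call signature mirrors the function defined above.

# Hypothetical toy data -- for illustration only.
import numpy as np

y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 0, 1, 0, 1, 1])
y_prob = np.array([0.2, 0.9, 0.4, 0.1, 0.8, 0.3, 0.7, 0.6])

scores = cal_metrics(
    y_true,
    y_pred,
    y_pred_proba=y_prob,
    is_binary=True,
    purpose="classification",
    average="binary",
)
print(scores["accuracy"], scores["specificity"], scores["roc_auc"])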