py2ls 0.2.4.10.3__py3-none-any.whl → 0.2.4.10.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
py2ls/ml2ls.py
CHANGED
@@ -4,12 +4,13 @@ from sklearn.ensemble import (
|
|
4
4
|
AdaBoostClassifier,
|
5
5
|
BaggingClassifier,
|
6
6
|
)
|
7
|
-
from sklearn.svm import SVC,SVR
|
7
|
+
from sklearn.svm import SVC, SVR
|
8
8
|
from sklearn.calibration import CalibratedClassifierCV
|
9
9
|
from sklearn.model_selection import GridSearchCV, StratifiedKFold
|
10
10
|
from sklearn.linear_model import (
|
11
11
|
LassoCV,
|
12
|
-
LogisticRegression,
|
12
|
+
LogisticRegression,
|
13
|
+
LinearRegression,
|
13
14
|
Lasso,
|
14
15
|
Ridge,
|
15
16
|
RidgeClassifierCV,
|
@@ -47,7 +48,7 @@ from . import plot
|
|
47
48
|
import matplotlib.pyplot as plt
|
48
49
|
import seaborn as sns
|
49
50
|
|
50
|
-
plt.style.use("paper")
|
51
|
+
plt.style.use(str(ips.get_cwd()) + "/data/styles/stylelib/paper.mplstyle")
|
51
52
|
import logging
|
52
53
|
import warnings
|
53
54
|
|
@@ -334,13 +335,13 @@ def features_naive_bayes(x_train: pd.DataFrame, y_train: pd.Series) -> list:
|
|
334
335
|
probabilities = nb.predict_proba(x_train)
|
335
336
|
# Limit the number of features safely, choosing the lesser of half the features or all columns
|
336
337
|
n_features = min(x_train.shape[1] // 2, len(x_train.columns))
|
337
|
-
|
338
|
+
|
338
339
|
# Sort probabilities, then map to valid column indices
|
339
340
|
sorted_indices = np.argsort(probabilities.max(axis=1))[:n_features]
|
340
|
-
|
341
|
+
|
341
342
|
# Ensure indices are within the column bounds of x_train
|
342
343
|
valid_indices = sorted_indices[sorted_indices < len(x_train.columns)]
|
343
|
-
|
344
|
+
|
344
345
|
return x_train.columns[valid_indices]
|
345
346
|
|
346
347
|
|
@@ -575,15 +576,28 @@ def get_features(
|
|
575
576
|
bagging_params: Optional[Dict] = None,
|
576
577
|
knn_params: Optional[Dict] = None,
|
577
578
|
cls: list = [
|
578
|
-
"lasso",
|
579
|
-
"
|
579
|
+
"lasso",
|
580
|
+
"ridge",
|
581
|
+
"Elastic Net(Enet)",
|
582
|
+
"gradient Boosting",
|
583
|
+
"Random forest (rf)",
|
584
|
+
"XGBoost (xgb)",
|
585
|
+
"Support Vector Machine(svm)",
|
586
|
+
"naive bayes",
|
587
|
+
"Linear Discriminant Analysis (lda)",
|
588
|
+
"adaboost",
|
589
|
+
"DecisionTree",
|
590
|
+
"KNeighbors",
|
591
|
+
"Bagging",
|
592
|
+
],
|
580
593
|
metrics: Optional[List[str]] = None,
|
581
594
|
cv_folds: int = 5,
|
582
595
|
strict: bool = False,
|
583
596
|
n_shared: int = 2, # 只要有两个方法有重合,就纳入common genes
|
584
597
|
use_selected_features: bool = True,
|
585
598
|
plot_: bool = True,
|
586
|
-
dir_save:str="./"
|
599
|
+
dir_save: str = "./",
|
600
|
+
) -> dict:
|
587
601
|
"""
|
588
602
|
Master function to perform feature selection and validate models.
|
589
603
|
"""
|
@@ -598,14 +612,14 @@ def get_features(
|
|
598
612
|
|
599
613
|
# fill na
|
600
614
|
if fill_missing:
|
601
|
-
ips.df_fillna(data=X,method=
|
602
|
-
if isinstance(y, str) and y in X.columns:
|
603
|
-
y_col_name=y
|
604
|
-
y=X[y]
|
605
|
-
y=ips.df_encoder(pd.DataFrame(y),method=
|
606
|
-
X = X.drop(y_col_name,axis=1)
|
615
|
+
ips.df_fillna(data=X, method="knn", inplace=True, axis=0)
|
616
|
+
if isinstance(y, str) and y in X.columns:
|
617
|
+
y_col_name = y
|
618
|
+
y = X[y]
|
619
|
+
y = ips.df_encoder(pd.DataFrame(y), method="dummy")
|
620
|
+
X = X.drop(y_col_name, axis=1)
|
607
621
|
else:
|
608
|
-
y=ips.df_encoder(pd.DataFrame(y),method=
|
622
|
+
y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
|
609
623
|
y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
|
610
624
|
y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
|
611
625
|
|
@@ -817,7 +831,7 @@ def get_features(
|
|
817
831
|
top_knn_features,
|
818
832
|
strict=strict,
|
819
833
|
n_shared=n_shared,
|
820
|
-
verbose=False
|
834
|
+
verbose=False,
|
821
835
|
)
|
822
836
|
|
823
837
|
# Use selected features or all features for model validation
|
@@ -899,13 +913,14 @@ def get_features(
|
|
899
913
|
results = {
|
900
914
|
"selected_features": features_df,
|
901
915
|
"cv_train_scores": cv_train_results_df,
|
902
|
-
"cv_test_scores": rank_models(cv_test_results_df,plot_=plot_),
|
916
|
+
"cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
|
903
917
|
"common_features": list(common_features),
|
904
918
|
}
|
905
|
-
if all([plot_,dir_save]):
|
919
|
+
if all([plot_, dir_save]):
|
906
920
|
from datetime import datetime
|
921
|
+
|
907
922
|
now_ = datetime.now().strftime("%y%m%d_%H%M%S")
|
908
|
-
ips.figsave(dir_save+f"features{now_}.pdf")
|
923
|
+
ips.figsave(dir_save + f"features{now_}.pdf")
|
909
924
|
else:
|
910
925
|
results = {
|
911
926
|
"selected_features": pd.DataFrame(),
|
@@ -931,7 +946,7 @@ def validate_features(
|
|
931
946
|
metrics: Optional[list] = None,
|
932
947
|
random_state: int = 1,
|
933
948
|
smote: bool = False,
|
934
|
-
n_jobs:int = -1,
|
949
|
+
n_jobs: int = -1,
|
935
950
|
plot_: bool = True,
|
936
951
|
class_weight: str = "balanced",
|
937
952
|
) -> dict:
|
@@ -952,8 +967,11 @@ def validate_features(
|
|
952
967
|
|
953
968
|
"""
|
954
969
|
from tqdm import tqdm
|
970
|
+
|
955
971
|
# Ensure common features are selected
|
956
|
-
common_features = ips.shared(
|
972
|
+
common_features = ips.shared(
|
973
|
+
common_features, x_train.columns, x_true.columns, strict=True, verbose=False
|
974
|
+
)
|
957
975
|
|
958
976
|
# Filter the training and validation datasets for the common features
|
959
977
|
x_train_selected = x_train[common_features]
|
@@ -1007,8 +1025,7 @@ def validate_features(
|
|
1007
1025
|
l1_ratio=0.5,
|
1008
1026
|
random_state=random_state,
|
1009
1027
|
),
|
1010
|
-
"XGBoost": xgb.XGBClassifier(eval_metric="logloss"
|
1011
|
-
),
|
1028
|
+
"XGBoost": xgb.XGBClassifier(eval_metric="logloss"),
|
1012
1029
|
"Naive Bayes": GaussianNB(),
|
1013
1030
|
"LDA": LinearDiscriminantAnalysis(),
|
1014
1031
|
}
|
@@ -1078,11 +1095,11 @@ def validate_features(
|
|
1078
1095
|
|
1079
1096
|
# Validate each classifier with GridSearchCV
|
1080
1097
|
for name, clf in tqdm(
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1098
|
+
models.items(),
|
1099
|
+
desc="for metric in metrics",
|
1100
|
+
colour="green",
|
1101
|
+
bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
|
1102
|
+
):
|
1086
1103
|
print(f"\nValidating {name} on the validation dataset:")
|
1087
1104
|
|
1088
1105
|
# Check if `predict_proba` method exists; if not, use CalibratedClassifierCV
|
@@ -1162,7 +1179,7 @@ def validate_features(
|
|
1162
1179
|
if y_pred_proba is not None:
|
1163
1180
|
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
1164
1181
|
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
1165
|
-
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
|
1182
|
+
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
|
1166
1183
|
roc_auc = auc(fpr, tpr)
|
1167
1184
|
roc_info = {
|
1168
1185
|
"fpr": fpr.tolist(),
|
@@ -1197,6 +1214,7 @@ def validate_features(
|
|
1197
1214
|
# Validate models using the validation dataset (X_val, y_val)
|
1198
1215
|
# validation_results = validate_features(X, y, X_val, y_val, common_features)
|
1199
1216
|
|
1217
|
+
|
1200
1218
|
# # If you want to access validation scores
|
1201
1219
|
# print(validation_results)
|
1202
1220
|
def plot_validate_features(res_val):
|
@@ -1204,47 +1222,75 @@ def plot_validate_features(res_val):
|
|
1204
1222
|
plot the results of 'validate_features()'
|
1205
1223
|
"""
|
1206
1224
|
colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
|
1207
|
-
if res_val.shape[0]>5:
|
1208
|
-
alpha=0
|
1209
|
-
figsize=[8,10]
|
1210
|
-
subplot_layout=[1,2]
|
1211
|
-
ncols=2
|
1212
|
-
bbox_to_anchor=[1.5,0.6]
|
1225
|
+
if res_val.shape[0] > 5:
|
1226
|
+
alpha = 0
|
1227
|
+
figsize = [8, 10]
|
1228
|
+
subplot_layout = [1, 2]
|
1229
|
+
ncols = 2
|
1230
|
+
bbox_to_anchor = [1.5, 0.6]
|
1213
1231
|
else:
|
1214
|
-
alpha=0.03
|
1215
|
-
figsize=[10,6]
|
1216
|
-
subplot_layout=[1,1]
|
1217
|
-
ncols=1
|
1218
|
-
bbox_to_anchor=[1,1]
|
1232
|
+
alpha = 0.03
|
1233
|
+
figsize = [10, 6]
|
1234
|
+
subplot_layout = [1, 1]
|
1235
|
+
ncols = 1
|
1236
|
+
bbox_to_anchor = [1, 1]
|
1219
1237
|
nexttile = plot.subplot(figsize=figsize)
|
1220
|
-
ax = nexttile(subplot_layout[0],subplot_layout[1])
|
1238
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1221
1239
|
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1222
1240
|
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1223
1241
|
tpr = res_val["roc_curve"][model_name]["tpr"]
|
1224
1242
|
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
|
1225
1243
|
mean_auc = res_val["roc_curve"][model_name]["auc"]
|
1226
1244
|
plot_roc_curve(
|
1227
|
-
fpr,
|
1228
|
-
|
1229
|
-
|
1245
|
+
fpr,
|
1246
|
+
tpr,
|
1247
|
+
mean_auc,
|
1248
|
+
lower_ci,
|
1249
|
+
upper_ci,
|
1250
|
+
model_name=model_name,
|
1251
|
+
lw=1.5,
|
1252
|
+
color=colors[i],
|
1253
|
+
alpha=alpha,
|
1254
|
+
ax=ax,
|
1255
|
+
)
|
1256
|
+
plot.figsets(
|
1257
|
+
sp=2,
|
1258
|
+
legend=dict(
|
1259
|
+
loc="upper right",
|
1260
|
+
ncols=ncols,
|
1261
|
+
fontsize=8,
|
1262
|
+
bbox_to_anchor=[1.5, 0.6],
|
1263
|
+
markerscale=0.8,
|
1264
|
+
),
|
1265
|
+
)
|
1230
1266
|
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1231
1267
|
|
1232
|
-
ax = nexttile(subplot_layout[0],subplot_layout[1])
|
1268
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1233
1269
|
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1234
1270
|
plot_pr_curve(
|
1235
1271
|
recall=res_val["pr_curve"][model_name]["recall"],
|
1236
1272
|
precision=res_val["pr_curve"][model_name]["precision"],
|
1237
1273
|
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1238
1274
|
model_name=model_name,
|
1239
|
-
color=colors[i],
|
1240
|
-
|
1275
|
+
color=colors[i],
|
1276
|
+
lw=1.5,
|
1277
|
+
alpha=alpha,
|
1278
|
+
ax=ax,
|
1279
|
+
)
|
1280
|
+
plot.figsets(
|
1281
|
+
sp=2,
|
1282
|
+
legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
|
1283
|
+
)
|
1241
1284
|
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1242
|
-
|
1243
|
-
|
1285
|
+
|
1286
|
+
|
1287
|
+
def plot_validate_features_single(res_val, figsize=None):
|
1244
1288
|
if figsize is None:
|
1245
1289
|
nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3)
|
1246
1290
|
else:
|
1247
|
-
nexttile = plot.subplot(
|
1291
|
+
nexttile = plot.subplot(
|
1292
|
+
len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
|
1293
|
+
)
|
1248
1294
|
for model_name in ips.flatten(res_val["pr_curve"].index):
|
1249
1295
|
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1250
1296
|
tpr = res_val["roc_curve"][model_name]["tpr"]
|
@@ -1268,7 +1314,10 @@ def plot_validate_features_single(res_val,figsize=None):
|
|
1268
1314
|
plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
|
1269
1315
|
plot.figsets(title=model_name, sp=2)
|
1270
1316
|
|
1271
|
-
|
1317
|
+
|
1318
|
+
def cal_auc_ci(
|
1319
|
+
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
|
1320
|
+
):
|
1272
1321
|
y_true = np.asarray(y_true)
|
1273
1322
|
y_pred = np.asarray(y_pred)
|
1274
1323
|
bootstrapped_scores = []
|
@@ -1298,10 +1347,10 @@ def cal_auc_ci(y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,verbos
|
|
1298
1347
|
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1299
1348
|
if verbose:
|
1300
1349
|
print(
|
1301
|
-
|
1302
|
-
|
1350
|
+
"Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
|
1351
|
+
confidence_lower, confidence_upper
|
1352
|
+
)
|
1303
1353
|
)
|
1304
|
-
)
|
1305
1354
|
return confidence_lower, confidence_upper
|
1306
1355
|
|
1307
1356
|
|
@@ -1339,7 +1388,7 @@ def plot_roc_curve(
|
|
1339
1388
|
# Plot ROC curve and the diagonal reference line
|
1340
1389
|
ax.fill_between(fpr, tpr, alpha=alpha, color=color)
|
1341
1390
|
ax.plot([0, 1], [0, 1], color=diagonal_color, clip_on=False, linestyle="--")
|
1342
|
-
ax.plot(fpr, tpr, color=color, lw=lw, label=label,clip_on=False, **kwargs)
|
1391
|
+
ax.plot(fpr, tpr, color=color, lw=lw, label=label, clip_on=False, **kwargs)
|
1343
1392
|
# Setting plot limits, labels, and title
|
1344
1393
|
ax.set_xlim([-0.01, 1.0])
|
1345
1394
|
ax.set_ylim([0.0, 1.0])
|
@@ -1536,12 +1585,11 @@ def plot_cm(
|
|
1536
1585
|
color=color,
|
1537
1586
|
fontsize=fontsize,
|
1538
1587
|
)
|
1539
|
-
|
1540
|
-
plot.figsets(ax=ax,
|
1541
|
-
boxloc="none"
|
1542
|
-
)
|
1588
|
+
|
1589
|
+
plot.figsets(ax=ax, boxloc="none")
|
1543
1590
|
return ax
|
1544
1591
|
|
1592
|
+
|
1545
1593
|
def rank_models(
|
1546
1594
|
cv_test_scores,
|
1547
1595
|
rm_outlier=False,
|
@@ -1644,7 +1692,7 @@ def rank_models(
|
|
1644
1692
|
if rm_outlier:
|
1645
1693
|
cv_test_scores_ = ips.df_outlier(cv_test_scores)
|
1646
1694
|
else:
|
1647
|
-
cv_test_scores_=cv_test_scores
|
1695
|
+
cv_test_scores_ = cv_test_scores
|
1648
1696
|
|
1649
1697
|
# Normalize the scores of metrics if normalize is True
|
1650
1698
|
scaler = MinMaxScaler()
|
@@ -1673,7 +1721,7 @@ def rank_models(
|
|
1673
1721
|
)
|
1674
1722
|
plt.title("Classifier Performance")
|
1675
1723
|
plt.tight_layout()
|
1676
|
-
return plt
|
1724
|
+
return plt
|
1677
1725
|
|
1678
1726
|
nexttile = plot.subplot(2, 2, figsize=[10, 7])
|
1679
1727
|
generate_bar_plot(nexttile(), top_models.dropna())
|
@@ -1703,10 +1751,11 @@ def rank_models(
|
|
1703
1751
|
|
1704
1752
|
# figsave("classifier_performance.pdf")
|
1705
1753
|
|
1754
|
+
|
1706
1755
|
def predict(
|
1707
1756
|
x_train: pd.DataFrame,
|
1708
1757
|
y_train: pd.Series,
|
1709
|
-
x_true: pd.DataFrame=None,
|
1758
|
+
x_true: pd.DataFrame = None,
|
1710
1759
|
y_true: Optional[pd.Series] = None,
|
1711
1760
|
common_features: set = None,
|
1712
1761
|
purpose: str = "classification", # 'classification' or 'regression'
|
@@ -1714,117 +1763,156 @@ def predict(
|
|
1714
1763
|
metrics: Optional[List[str]] = None,
|
1715
1764
|
random_state: int = 1,
|
1716
1765
|
smote: bool = False,
|
1717
|
-
n_jobs:int
|
1766
|
+
n_jobs: int = -1,
|
1718
1767
|
plot_: bool = True,
|
1719
|
-
dir_save:str="./",
|
1720
|
-
test_size:float=0.2
|
1721
|
-
cv_folds:int=5
|
1722
|
-
cv_level:str="l"
|
1768
|
+
dir_save: str = "./",
|
1769
|
+
test_size: float = 0.2, # specific only when x_true is None
|
1770
|
+
cv_folds: int = 5, # more cv_folds 得更加稳定,auc可能更低
|
1771
|
+
cv_level: str = "l", # "s":'low',"m":'medium',"l":"high"
|
1723
1772
|
class_weight: str = "balanced",
|
1724
|
-
verbose:bool=False,
|
1773
|
+
verbose: bool = False,
|
1725
1774
|
) -> pd.DataFrame:
|
1726
|
-
"""
|
1727
|
-
|
1728
|
-
|
1729
|
-
|
1730
|
-
|
1731
|
-
|
1732
|
-
|
1733
|
-
|
1734
|
-
|
1735
|
-
|
1736
|
-
|
1737
|
-
|
1738
|
-
|
1739
|
-
|
1740
|
-
|
1741
|
-
|
1742
|
-
|
1743
|
-
|
1744
|
-
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
1748
|
-
|
1749
|
-
|
1750
|
-
|
1751
|
-
|
1752
|
-
|
1753
|
-
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1775
|
+
"""
|
1776
|
+
第一种情况是内部拆分,第二种是直接预测,第三种是外部验证。
|
1777
|
+
Usage:
|
1778
|
+
(1). predict(x_train, y_train,...) 对 x_train 进行拆分训练/测试集,并在测试集上进行验证.
|
1779
|
+
predict 函数会根据 test_size 参数,将 x_train 和 y_train 拆分出内部测试集。然后模型会在拆分出的训练集上进行训练,并在测试集上验证效果。
|
1780
|
+
(2). predict(x_train, y_train, x_true,...)使用 x_train 和 y_train 训练并对 x_true 进行预测
|
1781
|
+
由于传入了 x_true,函数会跳过 x_train 的拆分,直接使用全部的 x_train 和 y_train 进行训练。然后对 x_true 进行预测,但由于没有提供 y_true,
|
1782
|
+
因此无法与真实值进行对比。
|
1783
|
+
(3). predict(x_train, y_train, x_true, y_true,...)使用 x_train 和 y_train 训练,并验证 x_true 与真实标签 y_true.
|
1784
|
+
predict 函数会在 x_train 和 y_train 上进行训练,并将 x_true 作为测试集。由于提供了 y_true,函数可以将预测结果与 y_true 进行对比,从而
|
1785
|
+
计算验证指标,完成对 x_true 的真正验证。
|
1786
|
+
trains and validates a variety of machine learning models for both classification and regression tasks.
|
1787
|
+
It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
|
1788
|
+
feature scaling, and handling of class imbalance through SMOTE.
|
1789
|
+
|
1790
|
+
Parameters:
|
1791
|
+
- x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
|
1792
|
+
- y_train (pd.Series):Target variable for the training dataset.
|
1793
|
+
- x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
|
1794
|
+
- y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
|
1795
|
+
- common_features (set, optional):Specifies a subset of features common across training and test data.
|
1796
|
+
- purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
|
1797
|
+
metrics and models are applied.
|
1798
|
+
- cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
|
1799
|
+
- metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
|
1800
|
+
- random_state (int, default = 1):Random seed to ensure reproducibility.
|
1801
|
+
- smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
|
1802
|
+
- n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
|
1803
|
+
- plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
|
1804
|
+
- test_size (float, default = 0.2):Test data proportion if x_true is not provided.
|
1805
|
+
- cv_folds (int, default = 5):Number of cross-validation folds.
|
1806
|
+
- cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
|
1807
|
+
- class_weight (str, default = "balanced"):Balances class weights in classification tasks.
|
1808
|
+
- verbose (bool, default = False):If True, prints detailed output during model training.
|
1809
|
+
- dir_save (str, default = "./"):Directory path to save plot outputs and results.
|
1810
|
+
|
1811
|
+
Key Steps in the Function:
|
1812
|
+
Model Initialization: Depending on purpose, initializes either classification or regression models.
|
1813
|
+
Feature Selection: Ensures training and test sets have matching feature columns.
|
1814
|
+
SMOTE Application: Balances classes if smote is enabled and the task is classification.
|
1815
|
+
Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
|
1816
|
+
Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
|
1768
1817
|
"""
|
1769
1818
|
from tqdm import tqdm
|
1770
|
-
from sklearn.ensemble import
|
1819
|
+
from sklearn.ensemble import (
|
1820
|
+
RandomForestClassifier,
|
1821
|
+
RandomForestRegressor,
|
1822
|
+
ExtraTreesClassifier,
|
1823
|
+
ExtraTreesRegressor,
|
1824
|
+
BaggingClassifier,
|
1825
|
+
BaggingRegressor,
|
1826
|
+
AdaBoostClassifier,
|
1827
|
+
AdaBoostRegressor,
|
1828
|
+
)
|
1771
1829
|
from sklearn.svm import SVC, SVR
|
1772
1830
|
from sklearn.tree import DecisionTreeRegressor
|
1773
|
-
from sklearn.linear_model import
|
1774
|
-
|
1775
|
-
|
1831
|
+
from sklearn.linear_model import (
|
1832
|
+
LogisticRegression,
|
1833
|
+
ElasticNet,
|
1834
|
+
ElasticNetCV,
|
1835
|
+
LinearRegression,
|
1836
|
+
Lasso,
|
1837
|
+
RidgeClassifierCV,
|
1838
|
+
Perceptron,
|
1839
|
+
SGDClassifier,
|
1840
|
+
)
|
1841
|
+
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
|
1842
|
+
from sklearn.naive_bayes import GaussianNB, BernoulliNB
|
1776
1843
|
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
|
1777
1844
|
import xgboost as xgb
|
1778
1845
|
import lightgbm as lgb
|
1779
1846
|
import catboost as cb
|
1780
1847
|
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
1781
1848
|
from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
|
1782
|
-
from sklearn.discriminant_analysis import
|
1849
|
+
from sklearn.discriminant_analysis import (
|
1850
|
+
LinearDiscriminantAnalysis,
|
1851
|
+
QuadraticDiscriminantAnalysis,
|
1852
|
+
)
|
1783
1853
|
from sklearn.preprocessing import PolynomialFeatures
|
1784
1854
|
|
1785
|
-
|
1786
1855
|
# 拼写检查
|
1787
|
-
purpose=ips.strcmp(purpose,[
|
1856
|
+
purpose = ips.strcmp(purpose, ["classification", "regression"])[0]
|
1788
1857
|
print(f"{purpose} processing...")
|
1789
1858
|
# Default models or regressors if not provided
|
1790
1859
|
if purpose == "classification":
|
1791
1860
|
model_ = {
|
1792
|
-
"Random Forest": RandomForestClassifier(
|
1793
|
-
|
1794
|
-
|
1795
|
-
|
1796
|
-
|
1861
|
+
"Random Forest": RandomForestClassifier(
|
1862
|
+
random_state=random_state, class_weight=class_weight
|
1863
|
+
),
|
1864
|
+
# SVC (Support Vector Classification)
|
1865
|
+
"SVM": SVC(
|
1866
|
+
kernel="rbf",
|
1867
|
+
probability=True,
|
1868
|
+
class_weight=class_weight,
|
1869
|
+
random_state=random_state,
|
1870
|
+
),
|
1797
1871
|
# fit the best model without enforcing sparsity, which means it does not directly perform feature selection.
|
1798
|
-
"Logistic Regression": LogisticRegression(
|
1799
|
-
|
1872
|
+
"Logistic Regression": LogisticRegression(
|
1873
|
+
class_weight=class_weight, random_state=random_state
|
1874
|
+
),
|
1800
1875
|
# Logistic Regression with L1 Regularization (Lasso)
|
1801
|
-
"Lasso Logistic Regression": LogisticRegression(
|
1876
|
+
"Lasso Logistic Regression": LogisticRegression(
|
1877
|
+
penalty="l1", solver="saga", random_state=random_state
|
1878
|
+
),
|
1802
1879
|
"Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
|
1803
|
-
"XGBoost": xgb.XGBClassifier(
|
1880
|
+
"XGBoost": xgb.XGBClassifier(
|
1881
|
+
eval_metric="logloss",
|
1882
|
+
random_state=random_state,
|
1883
|
+
),
|
1804
1884
|
"KNN": KNeighborsClassifier(n_neighbors=5),
|
1805
1885
|
"Naive Bayes": GaussianNB(),
|
1806
1886
|
"Linear Discriminant Analysis": LinearDiscriminantAnalysis(),
|
1807
|
-
"AdaBoost":
|
1887
|
+
"AdaBoost": AdaBoostClassifier(
|
1888
|
+
algorithm="SAMME", random_state=random_state
|
1889
|
+
),
|
1808
1890
|
# "LightGBM": lgb.LGBMClassifier(random_state=random_state, class_weight=class_weight),
|
1809
1891
|
"CatBoost": cb.CatBoostClassifier(verbose=0, random_state=random_state),
|
1810
|
-
"Extra Trees": ExtraTreesClassifier(
|
1892
|
+
"Extra Trees": ExtraTreesClassifier(
|
1893
|
+
random_state=random_state, class_weight=class_weight
|
1894
|
+
),
|
1811
1895
|
"Bagging": BaggingClassifier(random_state=random_state),
|
1812
1896
|
"Neural Network": MLPClassifier(max_iter=500, random_state=random_state),
|
1813
1897
|
"DecisionTree": DecisionTreeClassifier(),
|
1814
1898
|
"Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis(),
|
1815
|
-
"Ridge": RidgeClassifierCV(
|
1899
|
+
"Ridge": RidgeClassifierCV(
|
1900
|
+
class_weight=class_weight, store_cv_results=True
|
1901
|
+
),
|
1816
1902
|
"Perceptron": Perceptron(random_state=random_state),
|
1817
1903
|
"Bernoulli Naive Bayes": BernoulliNB(),
|
1818
|
-
"SGDClassifier": SGDClassifier(random_state=random_state),
|
1904
|
+
"SGDClassifier": SGDClassifier(random_state=random_state),
|
1819
1905
|
}
|
1820
1906
|
elif purpose == "regression":
|
1821
1907
|
model_ = {
|
1822
1908
|
"Random Forest": RandomForestRegressor(random_state=random_state),
|
1823
|
-
"SVM": SVR()
|
1909
|
+
"SVM": SVR(), # SVR (Support Vector Regression)
|
1824
1910
|
# "Lasso": Lasso(random_state=random_state), # 它和LassoCV相同(必须要提供alpha参数),
|
1825
|
-
"LassoCV": LassoCV(
|
1911
|
+
"LassoCV": LassoCV(
|
1912
|
+
cv=cv_folds, random_state=random_state
|
1913
|
+
), # LassoCV自动找出最适alpha,优于Lasso
|
1826
1914
|
"Gradient Boosting": GradientBoostingRegressor(random_state=random_state),
|
1827
|
-
"XGBoost": xgb.XGBRegressor(eval_metric="rmse",random_state=random_state),
|
1915
|
+
"XGBoost": xgb.XGBRegressor(eval_metric="rmse", random_state=random_state),
|
1828
1916
|
"Linear Regression": LinearRegression(),
|
1829
1917
|
"Lasso": Lasso(random_state=random_state),
|
1830
1918
|
"AdaBoost": AdaBoostRegressor(random_state=random_state),
|
@@ -1834,71 +1922,76 @@ def predict(
|
|
1834
1922
|
"Bagging": BaggingRegressor(random_state=random_state),
|
1835
1923
|
"Neural Network": MLPRegressor(max_iter=500, random_state=random_state),
|
1836
1924
|
"ElasticNet": ElasticNet(random_state=random_state),
|
1837
|
-
"Ridge": Ridge(),
|
1838
|
-
"KNN":KNeighborsRegressor()
|
1925
|
+
"Ridge": Ridge(),
|
1926
|
+
"KNN": KNeighborsRegressor(),
|
1839
1927
|
}
|
1840
|
-
# indicate cls:
|
1841
|
-
if ips.run_once_within(30)
|
1928
|
+
# indicate cls:
|
1929
|
+
if ips.run_once_within(30): # 10 min
|
1842
1930
|
print(f"supported models: {list(model_.keys())}")
|
1843
1931
|
if cls is None:
|
1844
|
-
models=model_
|
1932
|
+
models = model_
|
1845
1933
|
else:
|
1846
1934
|
if not isinstance(cls, list):
|
1847
|
-
cls=[cls]
|
1848
|
-
models={}
|
1849
|
-
for cls_ in cls:
|
1935
|
+
cls = [cls]
|
1936
|
+
models = {}
|
1937
|
+
for cls_ in cls:
|
1850
1938
|
cls_ = ips.strcmp(cls_, list(model_.keys()))[0]
|
1851
1939
|
models[cls_] = model_[cls_]
|
1852
|
-
if
|
1853
|
-
x_train=ips.df_special_characters_cleaner(x_train)
|
1854
|
-
x_true=
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1940
|
+
if "LightGBM" in models:
|
1941
|
+
x_train = ips.df_special_characters_cleaner(x_train)
|
1942
|
+
x_true = (
|
1943
|
+
ips.df_special_characters_cleaner(x_true) if x_true is not None else None
|
1944
|
+
)
|
1945
|
+
|
1946
|
+
if isinstance(y_train, str) and y_train in x_train.columns:
|
1947
|
+
y_train_col_name = y_train
|
1948
|
+
y_train = x_train[y_train]
|
1949
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
1950
|
+
x_train = x_train.drop(y_train_col_name, axis=1)
|
1861
1951
|
else:
|
1862
|
-
y_train=ips.df_encoder(pd.DataFrame(y_train),method=
|
1952
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
|
1863
1953
|
|
1864
1954
|
if x_true is None:
|
1865
1955
|
x_train, x_true, y_train, y_true = train_test_split(
|
1866
|
-
x_train,
|
1867
|
-
y_train,
|
1868
|
-
test_size=test_size,
|
1869
|
-
random_state=random_state,
|
1870
|
-
stratify=y_train if purpose == "classification" else None
|
1956
|
+
x_train,
|
1957
|
+
y_train,
|
1958
|
+
test_size=test_size,
|
1959
|
+
random_state=random_state,
|
1960
|
+
stratify=y_train if purpose == "classification" else None,
|
1871
1961
|
)
|
1872
|
-
if isinstance(y_train, str) and y_train in x_train.columns:
|
1873
|
-
y_train_col_name=y_train
|
1874
|
-
y_train=x_train[y_train]
|
1875
|
-
y_train=ips.df_encoder(pd.DataFrame(y_train),method=
|
1876
|
-
x_train = x_train.drop(y_train_col_name,axis=1)
|
1962
|
+
if isinstance(y_train, str) and y_train in x_train.columns:
|
1963
|
+
y_train_col_name = y_train
|
1964
|
+
y_train = x_train[y_train]
|
1965
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
1966
|
+
x_train = x_train.drop(y_train_col_name, axis=1)
|
1877
1967
|
else:
|
1878
|
-
y_train=ips.df_encoder(
|
1968
|
+
y_train = ips.df_encoder(
|
1969
|
+
pd.DataFrame(y_train), method="dummy"
|
1970
|
+
).values.ravel()
|
1879
1971
|
if y_true is not None:
|
1880
|
-
if isinstance(y_true, str) and y_true in x_true.columns:
|
1881
|
-
y_true_col_name=y_true
|
1882
|
-
y_true=x_true[y_true]
|
1883
|
-
y_true=ips.df_encoder(pd.DataFrame(y_true),method=
|
1884
|
-
x_true = x_true.drop(y_true_col_name,axis=1)
|
1972
|
+
if isinstance(y_true, str) and y_true in x_true.columns:
|
1973
|
+
y_true_col_name = y_true
|
1974
|
+
y_true = x_true[y_true]
|
1975
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
|
1976
|
+
x_true = x_true.drop(y_true_col_name, axis=1)
|
1885
1977
|
else:
|
1886
|
-
y_true=ips.df_encoder(pd.DataFrame(y_true),method=
|
1978
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
|
1887
1979
|
|
1888
1980
|
# to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
|
1889
1981
|
|
1890
1982
|
# y_train=y_train.values.ravel() if y_train is not None else None
|
1891
1983
|
# y_true=y_true.values.ravel() if y_true is not None else None
|
1892
|
-
y_train =
|
1984
|
+
y_train = (
|
1985
|
+
y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
|
1986
|
+
)
|
1893
1987
|
y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
|
1894
1988
|
|
1895
|
-
|
1896
1989
|
# Ensure common features are selected
|
1897
1990
|
if common_features is not None:
|
1898
1991
|
x_train, x_true = x_train[common_features], x_true[common_features]
|
1899
1992
|
else:
|
1900
|
-
share_col_names = ips.shared(x_train.columns, x_true.columns,verbose=verbose)
|
1901
|
-
x_train, x_true =x_train[share_col_names], x_true[share_col_names]
|
1993
|
+
share_col_names = ips.shared(x_train.columns, x_true.columns, verbose=verbose)
|
1994
|
+
x_train, x_true = x_train[share_col_names], x_true[share_col_names]
|
1902
1995
|
|
1903
1996
|
x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
|
1904
1997
|
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
|
@@ -1917,26 +2010,30 @@ def predict(
|
|
1917
2010
|
x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
|
1918
2011
|
|
1919
2012
|
# Hyperparameter grids for tuning
|
1920
|
-
if cv_level in ["low",
|
2013
|
+
if cv_level in ["low", "simple", "s", "l"]:
|
1921
2014
|
param_grids = {
|
1922
|
-
"Random Forest":
|
1923
|
-
|
1924
|
-
|
1925
|
-
|
1926
|
-
|
1927
|
-
|
1928
|
-
|
1929
|
-
|
1930
|
-
"
|
1931
|
-
|
1932
|
-
|
1933
|
-
|
1934
|
-
|
1935
|
-
|
2015
|
+
"Random Forest": (
|
2016
|
+
{
|
2017
|
+
"n_estimators": [100], # One basic option
|
2018
|
+
"max_depth": [None, 10],
|
2019
|
+
"min_samples_split": [2],
|
2020
|
+
"min_samples_leaf": [1],
|
2021
|
+
"class_weight": [None],
|
2022
|
+
}
|
2023
|
+
if purpose == "classification"
|
2024
|
+
else {
|
2025
|
+
"n_estimators": [100], # One basic option
|
2026
|
+
"max_depth": [None, 10],
|
2027
|
+
"min_samples_split": [2],
|
2028
|
+
"min_samples_leaf": [1],
|
2029
|
+
"max_features": [None],
|
2030
|
+
"bootstrap": [True], # Only one option for simplicity
|
2031
|
+
}
|
2032
|
+
),
|
1936
2033
|
"SVM": {
|
1937
2034
|
"C": [1],
|
1938
|
-
"gamma": [
|
1939
|
-
"kernel": [
|
2035
|
+
"gamma": ["scale"],
|
2036
|
+
"kernel": ["rbf"],
|
1940
2037
|
},
|
1941
2038
|
"Lasso": {
|
1942
2039
|
"alpha": [0.1],
|
@@ -1946,8 +2043,8 @@ def predict(
|
|
1946
2043
|
},
|
1947
2044
|
"Logistic Regression": {
|
1948
2045
|
"C": [1],
|
1949
|
-
"solver": [
|
1950
|
-
"penalty": [
|
2046
|
+
"solver": ["lbfgs"],
|
2047
|
+
"penalty": ["l2"],
|
1951
2048
|
"max_iter": [500],
|
1952
2049
|
},
|
1953
2050
|
"Gradient Boosting": {
|
@@ -1964,25 +2061,29 @@ def predict(
|
|
1964
2061
|
"subsample": [0.8],
|
1965
2062
|
"colsample_bytree": [0.8],
|
1966
2063
|
},
|
1967
|
-
"KNN":
|
1968
|
-
|
1969
|
-
|
1970
|
-
|
1971
|
-
|
1972
|
-
|
1973
|
-
|
1974
|
-
|
1975
|
-
|
1976
|
-
|
1977
|
-
|
1978
|
-
|
2064
|
+
"KNN": (
|
2065
|
+
{
|
2066
|
+
"n_neighbors": [3],
|
2067
|
+
"weights": ["uniform"],
|
2068
|
+
"algorithm": ["auto"],
|
2069
|
+
"p": [2],
|
2070
|
+
}
|
2071
|
+
if purpose == "classification"
|
2072
|
+
else {
|
2073
|
+
"n_neighbors": [3],
|
2074
|
+
"weights": ["uniform"],
|
2075
|
+
"metric": ["euclidean"],
|
2076
|
+
"leaf_size": [30],
|
2077
|
+
"p": [2],
|
2078
|
+
}
|
2079
|
+
),
|
1979
2080
|
"Naive Bayes": {
|
1980
2081
|
"var_smoothing": [1e-9],
|
1981
2082
|
},
|
1982
2083
|
"SVR": {
|
1983
2084
|
"C": [1],
|
1984
|
-
"gamma": [
|
1985
|
-
"kernel": [
|
2085
|
+
"gamma": ["scale"],
|
2086
|
+
"kernel": ["rbf"],
|
1986
2087
|
},
|
1987
2088
|
"Linear Regression": {
|
1988
2089
|
"fit_intercept": [True],
|
@@ -2003,9 +2104,9 @@ def predict(
|
|
2003
2104
|
"n_estimators": [100],
|
2004
2105
|
"num_leaves": [31],
|
2005
2106
|
"max_depth": [10],
|
2006
|
-
|
2007
|
-
|
2008
|
-
|
2107
|
+
"min_data_in_leaf": [20],
|
2108
|
+
"min_gain_to_split": [0.01],
|
2109
|
+
"scale_pos_weight": [10],
|
2009
2110
|
},
|
2010
2111
|
"Bagging": {
|
2011
2112
|
"n_estimators": [50],
|
@@ -2033,132 +2134,168 @@ def predict(
|
|
2033
2134
|
"shrinkage": [None],
|
2034
2135
|
},
|
2035
2136
|
"Quadratic Discriminant Analysis": {
|
2036
|
-
|
2037
|
-
|
2038
|
-
|
2039
|
-
},
|
2040
|
-
"Ridge": {'class_weight': [None, 'balanced']} if purpose == "classification" else {
|
2041
|
-
'alpha': [0.1, 1, 10],
|
2137
|
+
"reg_param": [0.0],
|
2138
|
+
"priors": [None],
|
2139
|
+
"tol": [1e-4],
|
2042
2140
|
},
|
2141
|
+
"Ridge": (
|
2142
|
+
{"class_weight": [None, "balanced"]}
|
2143
|
+
if purpose == "classification"
|
2144
|
+
else {
|
2145
|
+
"alpha": [0.1, 1, 10],
|
2146
|
+
}
|
2147
|
+
),
|
2043
2148
|
"Perceptron": {
|
2044
|
-
|
2045
|
-
|
2046
|
-
|
2047
|
-
|
2149
|
+
"alpha": [1e-3],
|
2150
|
+
"penalty": ["l2"],
|
2151
|
+
"max_iter": [1000],
|
2152
|
+
"eta0": [1.0],
|
2048
2153
|
},
|
2049
2154
|
"Bernoulli Naive Bayes": {
|
2050
|
-
|
2051
|
-
|
2052
|
-
|
2155
|
+
"alpha": [0.1, 1, 10],
|
2156
|
+
"binarize": [0.0],
|
2157
|
+
"fit_prior": [True],
|
2053
2158
|
},
|
2054
2159
|
"SGDClassifier": {
|
2055
|
-
|
2056
|
-
|
2057
|
-
|
2058
|
-
|
2059
|
-
|
2060
|
-
|
2061
|
-
|
2062
|
-
|
2160
|
+
"eta0": [0.01],
|
2161
|
+
"loss": ["hinge"],
|
2162
|
+
"penalty": ["l2"],
|
2163
|
+
"alpha": [1e-3],
|
2164
|
+
"max_iter": [1000],
|
2165
|
+
"tol": [1e-3],
|
2166
|
+
"random_state": [random_state],
|
2167
|
+
"learning_rate": ["constant"],
|
2063
2168
|
},
|
2064
2169
|
}
|
2065
|
-
elif cv_level in [
|
2066
|
-
param_grids = {
|
2067
|
-
"Random Forest":
|
2068
|
-
|
2069
|
-
|
2070
|
-
|
2071
|
-
|
2072
|
-
|
2073
|
-
|
2074
|
-
|
2075
|
-
|
2076
|
-
|
2077
|
-
|
2078
|
-
|
2079
|
-
|
2080
|
-
|
2170
|
+
elif cv_level in ["high", "advanced", "h"]:
|
2171
|
+
param_grids = {
|
2172
|
+
"Random Forest": (
|
2173
|
+
{
|
2174
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2175
|
+
"max_depth": [None, 3, 5, 10, 15, 20, 30],
|
2176
|
+
"min_samples_split": [2, 5, 10, 20],
|
2177
|
+
"min_samples_leaf": [1, 2, 4],
|
2178
|
+
"class_weight": (
|
2179
|
+
[None, "balanced"] if purpose == "classification" else {}
|
2180
|
+
),
|
2181
|
+
}
|
2182
|
+
if purpose == "classification"
|
2183
|
+
else {
|
2184
|
+
"n_estimators": [100, 200, 500, 700, 1000],
|
2185
|
+
"max_depth": [None, 3, 5, 10, 15, 20, 30],
|
2186
|
+
"min_samples_split": [2, 5, 10, 20],
|
2187
|
+
"min_samples_leaf": [1, 2, 4],
|
2188
|
+
"max_features": [
|
2189
|
+
"auto",
|
2190
|
+
"sqrt",
|
2191
|
+
"log2",
|
2192
|
+
], # Number of features to consider when looking for the best split
|
2193
|
+
"bootstrap": [
|
2194
|
+
True,
|
2195
|
+
False,
|
2196
|
+
], # Whether bootstrap samples are used when building trees
|
2197
|
+
}
|
2198
|
+
),
|
2081
2199
|
"SVM": {
|
2082
|
-
|
2083
|
-
|
2084
|
-
|
2085
|
-
|
2200
|
+
"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
|
2201
|
+
"gamma": ["scale", "auto", 0.001, 0.01, 0.1],
|
2202
|
+
"kernel": ["linear", "rbf", "poly"],
|
2203
|
+
},
|
2086
2204
|
"Logistic Regression": {
|
2087
|
-
|
2088
|
-
|
2089
|
-
|
2090
|
-
|
2091
|
-
|
2092
|
-
"Lasso":{
|
2093
|
-
|
2094
|
-
|
2095
|
-
|
2096
|
-
|
2097
|
-
|
2098
|
-
"LassoCV":{
|
2099
|
-
|
2100
|
-
|
2101
|
-
|
2102
|
-
|
2103
|
-
|
2205
|
+
"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
|
2206
|
+
"solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
|
2207
|
+
"penalty": ["l1", "l2", "elasticnet"],
|
2208
|
+
"max_iter": [100, 200, 300, 500],
|
2209
|
+
},
|
2210
|
+
"Lasso": {
|
2211
|
+
"alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2212
|
+
"max_iter": [500, 1000, 2000, 5000],
|
2213
|
+
"tol": [1e-4, 1e-5, 1e-6],
|
2214
|
+
"selection": ["cyclic", "random"],
|
2215
|
+
},
|
2216
|
+
"LassoCV": {
|
2217
|
+
"alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
|
2218
|
+
"max_iter": [500, 1000, 2000, 5000],
|
2219
|
+
"cv": [3, 5, 10],
|
2220
|
+
"tol": [1e-4, 1e-5, 1e-6],
|
2221
|
+
},
|
2104
2222
|
"Gradient Boosting": {
|
2105
|
-
|
2106
|
-
|
2107
|
-
|
2108
|
-
|
2109
|
-
|
2110
|
-
|
2223
|
+
"n_estimators": [100, 200, 300, 400, 500, 700, 1000],
|
2224
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
|
2225
|
+
"max_depth": [3, 5, 7, 9, 15],
|
2226
|
+
"min_samples_split": [2, 5, 10, 20],
|
2227
|
+
"subsample": [0.8, 1.0],
|
2228
|
+
},
|
2111
2229
|
"XGBoost": {
|
2112
|
-
|
2113
|
-
|
2114
|
-
|
2115
|
-
|
2116
|
-
|
2117
|
-
|
2118
|
-
"KNN":
|
2119
|
-
|
2120
|
-
|
2121
|
-
|
2122
|
-
|
2123
|
-
|
2124
|
-
|
2125
|
-
|
2126
|
-
|
2127
|
-
|
2128
|
-
|
2129
|
-
|
2230
|
+
"n_estimators": [100, 200, 500, 700],
|
2231
|
+
"max_depth": [3, 5, 7, 10],
|
2232
|
+
"learning_rate": [0.01, 0.1, 0.2, 0.3],
|
2233
|
+
"subsample": [0.8, 1.0],
|
2234
|
+
"colsample_bytree": [0.8, 0.9, 1.0],
|
2235
|
+
},
|
2236
|
+
"KNN": (
|
2237
|
+
{
|
2238
|
+
"n_neighbors": [1, 3, 5, 10, 15, 20],
|
2239
|
+
"weights": ["uniform", "distance"],
|
2240
|
+
"algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
|
2241
|
+
"p": [1, 2], # 1 for Manhattan, 2 for Euclidean distance
|
2242
|
+
}
|
2243
|
+
if purpose == "classification"
|
2244
|
+
else {
|
2245
|
+
"n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
|
2246
|
+
"weights": [
|
2247
|
+
"uniform",
|
2248
|
+
"distance",
|
2249
|
+
], # Weight function used in prediction
|
2250
|
+
"metric": [
|
2251
|
+
"euclidean",
|
2252
|
+
"manhattan",
|
2253
|
+
"minkowski",
|
2254
|
+
], # Distance metric
|
2255
|
+
"leaf_size": [
|
2256
|
+
20,
|
2257
|
+
30,
|
2258
|
+
40,
|
2259
|
+
50,
|
2260
|
+
], # Leaf size for KDTree or BallTree algorithms
|
2261
|
+
"p": [
|
2262
|
+
1,
|
2263
|
+
2,
|
2264
|
+
], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
|
2265
|
+
}
|
2266
|
+
),
|
2130
2267
|
"Naive Bayes": {
|
2131
2268
|
"var_smoothing": [1e-10, 1e-9, 1e-8, 1e-7],
|
2132
|
-
|
2269
|
+
},
|
2133
2270
|
"AdaBoost": {
|
2134
|
-
|
2135
|
-
|
2136
|
-
|
2271
|
+
"n_estimators": [50, 100, 200, 300, 500],
|
2272
|
+
"learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
|
2273
|
+
},
|
2137
2274
|
"SVR": {
|
2138
|
-
|
2139
|
-
|
2140
|
-
|
2141
|
-
|
2275
|
+
"C": [0.01, 0.1, 1, 10, 100, 1000],
|
2276
|
+
"gamma": [0.001, 0.01, 0.1, "scale", "auto"],
|
2277
|
+
"kernel": ["linear", "rbf", "poly"],
|
2278
|
+
},
|
2142
2279
|
"Linear Regression": {
|
2143
2280
|
"fit_intercept": [True, False],
|
2144
|
-
|
2145
|
-
"Lasso":{
|
2281
|
+
},
|
2282
|
+
"Lasso": {
|
2146
2283
|
"alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2147
|
-
"max_iter": [1000, 2000] # Higher iteration limit for fine-tuning
|
2148
|
-
|
2284
|
+
"max_iter": [1000, 2000], # Higher iteration limit for fine-tuning
|
2285
|
+
},
|
2149
2286
|
"Extra Trees": {
|
2150
2287
|
"n_estimators": [100, 200, 500, 700, 1000],
|
2151
2288
|
"max_depth": [None, 5, 10, 15, 20, 30],
|
2152
2289
|
"min_samples_split": [2, 5, 10, 20],
|
2153
|
-
"min_samples_leaf": [1, 2, 4]
|
2154
|
-
|
2290
|
+
"min_samples_leaf": [1, 2, 4],
|
2291
|
+
},
|
2155
2292
|
"CatBoost": {
|
2156
2293
|
"iterations": [100, 200, 500],
|
2157
2294
|
"learning_rate": [0.001, 0.01, 0.1, 0.2],
|
2158
2295
|
"depth": [3, 5, 7, 10],
|
2159
2296
|
"l2_leaf_reg": [1, 3, 5, 7, 10],
|
2160
2297
|
"border_count": [32, 64, 128],
|
2161
|
-
|
2298
|
+
},
|
2162
2299
|
"LightGBM": {
|
2163
2300
|
"n_estimators": [100, 200, 500, 700, 1000],
|
2164
2301
|
"learning_rate": [0.001, 0.01, 0.1, 0.2],
|
@@ -2167,66 +2304,97 @@ def predict(
|
|
2167
2304
|
"min_child_samples": [5, 10, 20],
|
2168
2305
|
"subsample": [0.8, 1.0],
|
2169
2306
|
"colsample_bytree": [0.8, 0.9, 1.0],
|
2170
|
-
|
2307
|
+
},
|
2171
2308
|
"Neural Network": {
|
2172
2309
|
"hidden_layer_sizes": [(50,), (100,), (100, 50), (200, 100)],
|
2173
2310
|
"activation": ["relu", "tanh", "logistic"],
|
2174
2311
|
"solver": ["adam", "sgd", "lbfgs"],
|
2175
2312
|
"alpha": [0.0001, 0.001, 0.01],
|
2176
2313
|
"learning_rate": ["constant", "adaptive"],
|
2177
|
-
|
2314
|
+
},
|
2178
2315
|
"Decision Tree": {
|
2179
2316
|
"max_depth": [None, 5, 10, 20, 30],
|
2180
2317
|
"min_samples_split": [2, 5, 10, 20],
|
2181
2318
|
"min_samples_leaf": [1, 2, 5, 10],
|
2182
2319
|
"criterion": ["gini", "entropy"],
|
2183
2320
|
"splitter": ["best", "random"],
|
2184
|
-
|
2321
|
+
},
|
2185
2322
|
"Linear Discriminant Analysis": {
|
2186
2323
|
"solver": ["svd", "lsqr", "eigen"],
|
2187
|
-
"shrinkage": [
|
2188
|
-
|
2189
|
-
|
2190
|
-
|
2191
|
-
|
2192
|
-
|
2193
|
-
|
2324
|
+
"shrinkage": [
|
2325
|
+
None,
|
2326
|
+
"auto",
|
2327
|
+
0.1,
|
2328
|
+
0.5,
|
2329
|
+
1.0,
|
2330
|
+
], # shrinkage levels for 'lsqr' and 'eigen'
|
2331
|
+
},
|
2332
|
+
"Ridge": (
|
2333
|
+
{"class_weight": [None, "balanced"]}
|
2334
|
+
if purpose == "classification"
|
2335
|
+
else {
|
2336
|
+
"alpha": [0.1, 1, 10, 100, 1000],
|
2337
|
+
"solver": ["auto", "svd", "cholesky", "lsqr", "lbfgs"],
|
2338
|
+
"fit_intercept": [
|
2339
|
+
True,
|
2340
|
+
False,
|
2341
|
+
], # Whether to calculate the intercept
|
2342
|
+
"normalize": [
|
2343
|
+
True,
|
2344
|
+
False,
|
2345
|
+
], # If True, the regressors X will be normalized
|
2194
2346
|
}
|
2195
|
-
|
2196
|
-
|
2197
|
-
|
2198
|
-
|
2199
|
-
|
2200
|
-
|
2201
|
-
|
2202
|
-
|
2203
|
-
|
2204
|
-
|
2205
|
-
|
2206
|
-
|
2207
|
-
|
2208
|
-
|
2209
|
-
|
2210
|
-
|
2211
|
-
|
2347
|
+
),
|
2348
|
+
}
|
2349
|
+
else: # median level
|
2350
|
+
param_grids = {
|
2351
|
+
"Random Forest": (
|
2352
|
+
{
|
2353
|
+
"n_estimators": [100, 200, 500],
|
2354
|
+
"max_depth": [None, 10, 20, 30],
|
2355
|
+
"min_samples_split": [2, 5, 10],
|
2356
|
+
"min_samples_leaf": [1, 2, 4],
|
2357
|
+
"class_weight": [None, "balanced"],
|
2358
|
+
}
|
2359
|
+
if purpose == "classification"
|
2360
|
+
else {
|
2361
|
+
"n_estimators": [100, 200, 500],
|
2362
|
+
"max_depth": [None, 10, 20, 30],
|
2363
|
+
"min_samples_split": [2, 5, 10],
|
2364
|
+
"min_samples_leaf": [1, 2, 4],
|
2365
|
+
"max_features": [
|
2366
|
+
"auto",
|
2367
|
+
"sqrt",
|
2368
|
+
"log2",
|
2369
|
+
], # Number of features to consider when looking for the best split
|
2370
|
+
"bootstrap": [
|
2371
|
+
True,
|
2372
|
+
False,
|
2373
|
+
], # Whether bootstrap samples are used when building trees
|
2374
|
+
}
|
2375
|
+
),
|
2212
2376
|
"SVM": {
|
2213
2377
|
"C": [0.1, 1, 10, 100], # Regularization strength
|
2214
|
-
"gamma": [
|
2215
|
-
"kernel": [
|
2378
|
+
"gamma": ["scale", "auto"], # Common gamma values
|
2379
|
+
"kernel": ["rbf", "linear", "poly"],
|
2216
2380
|
},
|
2217
2381
|
"Logistic Regression": {
|
2218
2382
|
"C": [0.1, 1, 10, 100], # Regularization strength
|
2219
|
-
"solver": [
|
2220
|
-
"penalty": [
|
2221
|
-
"max_iter": [
|
2383
|
+
"solver": ["lbfgs", "liblinear", "saga"], # Common solvers
|
2384
|
+
"penalty": ["l2"], # L2 penalty is most common
|
2385
|
+
"max_iter": [
|
2386
|
+
500,
|
2387
|
+
1000,
|
2388
|
+
2000,
|
2389
|
+
], # Increased max_iter for better convergence
|
2222
2390
|
},
|
2223
|
-
"Lasso":{
|
2391
|
+
"Lasso": {
|
2224
2392
|
"alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
|
2225
|
-
"max_iter": [500, 1000, 2000]
|
2393
|
+
"max_iter": [500, 1000, 2000],
|
2226
2394
|
},
|
2227
|
-
"LassoCV":{
|
2395
|
+
"LassoCV": {
|
2228
2396
|
"alphas": [[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
|
2229
|
-
"max_iter": [500, 1000, 2000]
|
2397
|
+
"max_iter": [500, 1000, 2000],
|
2230
2398
|
},
|
2231
2399
|
"Gradient Boosting": {
|
2232
2400
|
"n_estimators": [100, 200, 500],
|
@@ -2242,25 +2410,44 @@ def predict(
|
|
2242
2410
|
"subsample": [0.8, 1.0],
|
2243
2411
|
"colsample_bytree": [0.8, 1.0],
|
2244
2412
|
},
|
2245
|
-
"KNN":
|
2246
|
-
|
2247
|
-
|
2248
|
-
|
2249
|
-
|
2250
|
-
|
2251
|
-
|
2252
|
-
|
2253
|
-
|
2254
|
-
|
2255
|
-
|
2256
|
-
|
2413
|
+
"KNN": (
|
2414
|
+
{
|
2415
|
+
"n_neighbors": [3, 5, 7, 10],
|
2416
|
+
"weights": ["uniform", "distance"],
|
2417
|
+
"algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
|
2418
|
+
"p": [1, 2],
|
2419
|
+
}
|
2420
|
+
if purpose == "classification"
|
2421
|
+
else {
|
2422
|
+
"n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
|
2423
|
+
"weights": [
|
2424
|
+
"uniform",
|
2425
|
+
"distance",
|
2426
|
+
], # Weight function used in prediction
|
2427
|
+
"metric": [
|
2428
|
+
"euclidean",
|
2429
|
+
"manhattan",
|
2430
|
+
"minkowski",
|
2431
|
+
], # Distance metric
|
2432
|
+
"leaf_size": [
|
2433
|
+
20,
|
2434
|
+
30,
|
2435
|
+
40,
|
2436
|
+
50,
|
2437
|
+
], # Leaf size for KDTree or BallTree algorithms
|
2438
|
+
"p": [
|
2439
|
+
1,
|
2440
|
+
2,
|
2441
|
+
], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
|
2442
|
+
}
|
2443
|
+
),
|
2257
2444
|
"Naive Bayes": {
|
2258
2445
|
"var_smoothing": [1e-9, 1e-8, 1e-7],
|
2259
2446
|
},
|
2260
2447
|
"SVR": {
|
2261
2448
|
"C": [0.1, 1, 10, 100],
|
2262
|
-
"gamma": [
|
2263
|
-
"kernel": [
|
2449
|
+
"gamma": ["scale", "auto"],
|
2450
|
+
"kernel": ["rbf", "linear"],
|
2264
2451
|
},
|
2265
2452
|
"Linear Regression": {
|
2266
2453
|
"fit_intercept": [True, False],
|
@@ -2286,10 +2473,10 @@ def predict(
|
|
2286
2473
|
"learning_rate": [0.01, 0.1],
|
2287
2474
|
"num_leaves": [31, 50, 100],
|
2288
2475
|
"max_depth": [-1, 10, 20],
|
2289
|
-
|
2290
|
-
|
2291
|
-
|
2292
|
-
},
|
2476
|
+
"min_data_in_leaf": [20], # Minimum samples in each leaf
|
2477
|
+
"min_gain_to_split": [0.01], # Minimum gain to allow a split
|
2478
|
+
"scale_pos_weight": [10], # Address class imbalance
|
2479
|
+
},
|
2293
2480
|
"Bagging": {
|
2294
2481
|
"n_estimators": [10, 50, 100],
|
2295
2482
|
"max_samples": [0.5, 0.7, 1.0],
|
@@ -2314,41 +2501,73 @@ def predict(
|
|
2314
2501
|
"Linear Discriminant Analysis": {
|
2315
2502
|
"solver": ["svd", "lsqr", "eigen"],
|
2316
2503
|
"shrinkage": [None, "auto"],
|
2317
|
-
},
|
2318
|
-
|
2319
|
-
|
2320
|
-
|
2321
|
-
|
2322
|
-
|
2323
|
-
|
2324
|
-
|
2325
|
-
|
2326
|
-
|
2327
|
-
|
2328
|
-
|
2329
|
-
|
2330
|
-
|
2331
|
-
|
2332
|
-
|
2333
|
-
|
2334
|
-
|
2335
|
-
"
|
2336
|
-
|
2337
|
-
|
2338
|
-
|
2339
|
-
|
2340
|
-
|
2341
|
-
|
2342
|
-
|
2343
|
-
|
2344
|
-
|
2345
|
-
|
2346
|
-
|
2347
|
-
|
2348
|
-
|
2349
|
-
|
2504
|
+
},
|
2505
|
+
"Quadratic Discriminant Analysis": {
|
2506
|
+
"reg_param": [0.0, 0.1, 0.5, 1.0], # Regularization parameter
|
2507
|
+
"priors": [None, [0.5, 0.5], [0.3, 0.7]], # Class priors
|
2508
|
+
"tol": [
|
2509
|
+
1e-4,
|
2510
|
+
1e-3,
|
2511
|
+
1e-2,
|
2512
|
+
], # Tolerance value for the convergence of the algorithm
|
2513
|
+
},
|
2514
|
+
"Perceptron": {
|
2515
|
+
"alpha": [1e-4, 1e-3, 1e-2], # Regularization parameter
|
2516
|
+
"penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
|
2517
|
+
"max_iter": [1000, 2000], # Maximum number of iterations
|
2518
|
+
"eta0": [1.0, 0.1], # Learning rate for gradient descent
|
2519
|
+
"tol": [1e-3, 1e-4, 1e-5], # Tolerance for stopping criteria
|
2520
|
+
"random_state": [random_state], # Random state for reproducibility
|
2521
|
+
},
|
2522
|
+
"Bernoulli Naive Bayes": {
|
2523
|
+
"alpha": [0.1, 1.0, 10.0], # Additive (Laplace) smoothing parameter
|
2524
|
+
"binarize": [
|
2525
|
+
0.0,
|
2526
|
+
0.5,
|
2527
|
+
1.0,
|
2528
|
+
], # Threshold for binarizing the input features
|
2529
|
+
"fit_prior": [
|
2530
|
+
True,
|
2531
|
+
False,
|
2532
|
+
], # Whether to learn class prior probabilities
|
2533
|
+
},
|
2534
|
+
"SGDClassifier": {
|
2535
|
+
"eta0": [0.01, 0.1, 1.0],
|
2536
|
+
"loss": [
|
2537
|
+
"hinge",
|
2538
|
+
"log",
|
2539
|
+
"modified_huber",
|
2540
|
+
"squared_hinge",
|
2541
|
+
"perceptron",
|
2542
|
+
], # Loss function
|
2543
|
+
"penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
|
2544
|
+
"alpha": [1e-4, 1e-3, 1e-2], # Regularization strength
|
2545
|
+
"l1_ratio": [0.15, 0.5, 0.85], # L1 ratio for elasticnet penalty
|
2546
|
+
"max_iter": [1000, 2000], # Maximum number of iterations
|
2547
|
+
"tol": [1e-3, 1e-4], # Tolerance for stopping criteria
|
2548
|
+
"random_state": [random_state], # Random state for reproducibility
|
2549
|
+
"learning_rate": [
|
2550
|
+
"constant",
|
2551
|
+
"optimal",
|
2552
|
+
"invscaling",
|
2553
|
+
"adaptive",
|
2554
|
+
], # Learning rate schedule
|
2555
|
+
},
|
2556
|
+
"Ridge": (
|
2557
|
+
{"class_weight": [None, "balanced"]}
|
2558
|
+
if purpose == "classification"
|
2559
|
+
else {
|
2560
|
+
"alpha": [0.1, 1, 10, 100],
|
2561
|
+
"solver": [
|
2562
|
+
"auto",
|
2563
|
+
"svd",
|
2564
|
+
"cholesky",
|
2565
|
+
"lsqr",
|
2566
|
+
], # Solver for optimization
|
2567
|
+
}
|
2568
|
+
),
|
2350
2569
|
}
|
2351
|
-
|
2570
|
+
|
2352
2571
|
results = {}
|
2353
2572
|
# Use StratifiedKFold for classification and KFold for regression
|
2354
2573
|
cv = (
|
@@ -2359,11 +2578,11 @@ def predict(
|
|
2359
2578
|
|
2360
2579
|
# Train and validate each model
|
2361
2580
|
for name, clf in tqdm(
|
2362
|
-
|
2363
|
-
|
2364
|
-
|
2365
|
-
|
2366
|
-
|
2581
|
+
models.items(),
|
2582
|
+
desc="models",
|
2583
|
+
colour="green",
|
2584
|
+
bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
|
2585
|
+
):
|
2367
2586
|
if verbose:
|
2368
2587
|
print(f"\nTraining and validating {name}:")
|
2369
2588
|
|
@@ -2381,7 +2600,7 @@ def predict(
|
|
2381
2600
|
gs.fit(x_train, y_train)
|
2382
2601
|
best_clf = gs.best_estimator_
|
2383
2602
|
# make sure x_train and x_test has the same name
|
2384
|
-
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2603
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2385
2604
|
y_pred = best_clf.predict(x_true)
|
2386
2605
|
|
2387
2606
|
# y_pred_proba
|
@@ -2396,18 +2615,23 @@ def predict(
|
|
2396
2615
|
)
|
2397
2616
|
else:
|
2398
2617
|
y_pred_proba = None # No probability output for certain models
|
2399
|
-
|
2400
|
-
|
2618
|
+
|
2401
2619
|
validation_scores = {}
|
2402
2620
|
if y_true is not None:
|
2403
|
-
validation_scores = cal_metrics(
|
2621
|
+
validation_scores = cal_metrics(
|
2622
|
+
y_true,
|
2623
|
+
y_pred,
|
2624
|
+
y_pred_proba=y_pred_proba,
|
2625
|
+
purpose=purpose,
|
2626
|
+
average="weighted",
|
2627
|
+
)
|
2404
2628
|
|
2405
2629
|
# Calculate ROC curve
|
2406
2630
|
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
|
2407
2631
|
if y_pred_proba is not None:
|
2408
2632
|
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
2409
2633
|
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
2410
|
-
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
|
2634
|
+
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
|
2411
2635
|
roc_auc = auc(fpr, tpr)
|
2412
2636
|
roc_info = {
|
2413
2637
|
"fpr": fpr.tolist(),
|
@@ -2425,11 +2649,14 @@ def predict(
|
|
2425
2649
|
}
|
2426
2650
|
else:
|
2427
2651
|
roc_info, pr_info = None, None
|
2428
|
-
if purpose=="classification":
|
2652
|
+
if purpose == "classification":
|
2429
2653
|
results[name] = {
|
2430
|
-
"best_clf": gs.best_estimator_,
|
2654
|
+
"best_clf": gs.best_estimator_,
|
2431
2655
|
"best_params": gs.best_params_,
|
2432
|
-
"auc_indiv":[
|
2656
|
+
"auc_indiv": [
|
2657
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
2658
|
+
for i in range(cv_folds)
|
2659
|
+
],
|
2433
2660
|
"scores": validation_scores,
|
2434
2661
|
"roc_curve": roc_info,
|
2435
2662
|
"pr_curve": pr_info,
|
@@ -2439,11 +2666,11 @@ def predict(
|
|
2439
2666
|
y_pred_proba.tolist() if y_pred_proba is not None else None
|
2440
2667
|
),
|
2441
2668
|
}
|
2442
|
-
else:
|
2669
|
+
else: # "regression"
|
2443
2670
|
results[name] = {
|
2444
|
-
"best_clf": gs.best_estimator_,
|
2671
|
+
"best_clf": gs.best_estimator_,
|
2445
2672
|
"best_params": gs.best_params_,
|
2446
|
-
"scores": validation_scores,
|
2673
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
2447
2674
|
"predictions": y_pred.tolist(),
|
2448
2675
|
"predictions_proba": (
|
2449
2676
|
y_pred_proba.tolist() if y_pred_proba is not None else None
|
@@ -2452,9 +2679,9 @@ def predict(
|
|
2452
2679
|
|
2453
2680
|
else:
|
2454
2681
|
results[name] = {
|
2455
|
-
"best_clf": gs.best_estimator_,
|
2682
|
+
"best_clf": gs.best_estimator_,
|
2456
2683
|
"best_params": gs.best_params_,
|
2457
|
-
"scores": validation_scores,
|
2684
|
+
"scores": validation_scores,
|
2458
2685
|
"predictions": y_pred.tolist(),
|
2459
2686
|
"predictions_proba": (
|
2460
2687
|
y_pred_proba.tolist() if y_pred_proba is not None else None
|
@@ -2465,76 +2692,80 @@ def predict(
|
|
2465
2692
|
df_results = pd.DataFrame.from_dict(results, orient="index")
|
2466
2693
|
|
2467
2694
|
# sort
|
2468
|
-
if y_true is not None and purpose=="classification":
|
2695
|
+
if y_true is not None and purpose == "classification":
|
2469
2696
|
df_scores = pd.DataFrame(
|
2470
|
-
|
2471
|
-
|
2472
|
-
df_results=df_results.loc[df_scores.index]
|
2697
|
+
df_results["scores"].tolist(), index=df_results["scores"].index
|
2698
|
+
).sort_values(by="roc_auc", ascending=False)
|
2699
|
+
df_results = df_results.loc[df_scores.index]
|
2473
2700
|
|
2474
2701
|
if plot_:
|
2475
2702
|
from datetime import datetime
|
2703
|
+
|
2476
2704
|
now_ = datetime.now().strftime("%y%m%d_%H%M%S")
|
2477
|
-
nexttile=plot.subplot(figsize=[12, 10])
|
2478
|
-
plot.heatmap(df_scores, kind="direct",ax=nexttile())
|
2705
|
+
nexttile = plot.subplot(figsize=[12, 10])
|
2706
|
+
plot.heatmap(df_scores, kind="direct", ax=nexttile())
|
2479
2707
|
plot.figsets(xangle=30)
|
2480
2708
|
if dir_save:
|
2481
|
-
ips.figsave(dir_save+f"scores_sorted_heatmap{now_}.pdf")
|
2482
|
-
if df_scores.shape[0]>1
|
2483
|
-
plot.heatmap(df_scores, kind="direct",cluster=True)
|
2709
|
+
ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
|
2710
|
+
if df_scores.shape[0] > 1: # draw cluster
|
2711
|
+
plot.heatmap(df_scores, kind="direct", cluster=True)
|
2484
2712
|
plot.figsets(xangle=30)
|
2485
2713
|
if dir_save:
|
2486
|
-
ips.figsave(dir_save+f"scores_clus{now_}.pdf")
|
2487
|
-
if all([plot_, y_true is not None, purpose==
|
2714
|
+
ips.figsave(dir_save + f"scores_clus{now_}.pdf")
|
2715
|
+
if all([plot_, y_true is not None, purpose == "classification"]):
|
2488
2716
|
try:
|
2489
|
-
if len(models)>3:
|
2717
|
+
if len(models) > 3:
|
2490
2718
|
plot_validate_features(df_results)
|
2491
2719
|
else:
|
2492
|
-
plot_validate_features_single(df_results,figsize=(12,4*len(models)))
|
2720
|
+
plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
|
2493
2721
|
if dir_save:
|
2494
|
-
ips.figsave(dir_save+f"validate_features{now_}.pdf")
|
2722
|
+
ips.figsave(dir_save + f"validate_features{now_}.pdf")
|
2495
2723
|
except Exception as e:
|
2496
2724
|
print(f"Error: 在画图的过程中出现了问题:{e}")
|
2497
2725
|
return df_results
|
2498
2726
|
|
2499
2727
|
|
2500
|
-
def cal_metrics(
|
2728
|
+
def cal_metrics(
|
2729
|
+
y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
|
2730
|
+
):
|
2501
2731
|
"""
|
2502
2732
|
Calculate regression or classification metrics based on the purpose.
|
2503
|
-
|
2733
|
+
|
2504
2734
|
Parameters:
|
2505
2735
|
- y_true: Array of true values.
|
2506
2736
|
- y_pred: Array of predicted labels for classification or predicted values for regression.
|
2507
2737
|
- y_pred_proba: Array of predicted probabilities for classification (optional).
|
2508
2738
|
- purpose: str, "regression" or "classification".
|
2509
2739
|
- average: str, averaging method for multi-class classification ("binary", "micro", "macro", "weighted", etc.).
|
2510
|
-
|
2740
|
+
|
2511
2741
|
Returns:
|
2512
2742
|
- validation_scores: dict of computed metrics.
|
2513
2743
|
"""
|
2514
2744
|
from sklearn.metrics import (
|
2515
|
-
|
2516
|
-
|
2517
|
-
|
2518
|
-
|
2519
|
-
|
2520
|
-
|
2521
|
-
|
2522
|
-
|
2523
|
-
|
2524
|
-
|
2525
|
-
|
2526
|
-
|
2527
|
-
|
2528
|
-
|
2529
|
-
|
2530
|
-
|
2531
|
-
|
2745
|
+
mean_squared_error,
|
2746
|
+
mean_absolute_error,
|
2747
|
+
mean_absolute_percentage_error,
|
2748
|
+
explained_variance_score,
|
2749
|
+
r2_score,
|
2750
|
+
mean_squared_log_error,
|
2751
|
+
accuracy_score,
|
2752
|
+
precision_score,
|
2753
|
+
recall_score,
|
2754
|
+
f1_score,
|
2755
|
+
roc_auc_score,
|
2756
|
+
matthews_corrcoef,
|
2757
|
+
confusion_matrix,
|
2758
|
+
balanced_accuracy_score,
|
2759
|
+
average_precision_score,
|
2760
|
+
precision_recall_curve,
|
2761
|
+
)
|
2762
|
+
|
2532
2763
|
validation_scores = {}
|
2533
2764
|
|
2534
2765
|
if purpose == "regression":
|
2535
2766
|
y_true = np.asarray(y_true)
|
2536
2767
|
y_true = y_true.ravel()
|
2537
|
-
y_pred = np.asarray(y_pred)
|
2768
|
+
y_pred = np.asarray(y_pred)
|
2538
2769
|
y_pred = y_pred.ravel()
|
2539
2770
|
# Regression metrics
|
2540
2771
|
validation_scores = {
|
@@ -2544,7 +2775,7 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
|
|
2544
2775
|
"r2": r2_score(y_true, y_pred),
|
2545
2776
|
"mape": mean_absolute_percentage_error(y_true, y_pred),
|
2546
2777
|
"explained_variance": explained_variance_score(y_true, y_pred),
|
2547
|
-
"mbd": np.mean(y_pred - y_true) # Mean Bias Deviation
|
2778
|
+
"mbd": np.mean(y_pred - y_true), # Mean Bias Deviation
|
2548
2779
|
}
|
2549
2780
|
# Check if MSLE can be calculated
|
2550
2781
|
if np.all(y_true >= 0) and np.all(y_pred >= 0): # Ensure no negative values
|
@@ -2560,21 +2791,24 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
|
|
2560
2791
|
"recall": recall_score(y_true, y_pred, average=average),
|
2561
2792
|
"f1": f1_score(y_true, y_pred, average=average),
|
2562
2793
|
"mcc": matthews_corrcoef(y_true, y_pred),
|
2563
|
-
"specificity": None,
|
2564
|
-
"balanced_accuracy": balanced_accuracy_score(y_true, y_pred)
|
2794
|
+
"specificity": None,
|
2795
|
+
"balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
|
2565
2796
|
}
|
2566
2797
|
|
2567
2798
|
# Confusion matrix to calculate specificity
|
2568
2799
|
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
|
2569
|
-
validation_scores["specificity"] =
|
2800
|
+
validation_scores["specificity"] = (
|
2801
|
+
tn / (tn + fp) if (tn + fp) > 0 else 0
|
2802
|
+
) # Specificity calculation
|
2570
2803
|
|
2571
|
-
if y_pred_proba is not None:
|
2804
|
+
if y_pred_proba is not None:
|
2572
2805
|
# Calculate ROC-AUC
|
2573
|
-
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
2806
|
+
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
2574
2807
|
# PR-AUC (Precision-Recall AUC) calculation
|
2575
2808
|
validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
|
2576
2809
|
else:
|
2577
|
-
raise ValueError(
|
2810
|
+
raise ValueError(
|
2811
|
+
"Invalid purpose specified. Choose 'regression' or 'classification'."
|
2812
|
+
)
|
2578
2813
|
|
2579
2814
|
return validation_scores
|
2580
|
-
|