py2ls 0.2.4.10.3__py3-none-any.whl → 0.2.4.10.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/ml2ls.py
CHANGED
@@ -4,12 +4,13 @@ from sklearn.ensemble import (
     AdaBoostClassifier,
     BaggingClassifier,
 )
-from sklearn.svm import SVC,SVR
+from sklearn.svm import SVC, SVR
 from sklearn.calibration import CalibratedClassifierCV
 from sklearn.model_selection import GridSearchCV, StratifiedKFold
 from sklearn.linear_model import (
     LassoCV,
-    LogisticRegression,
+    LogisticRegression,
+    LinearRegression,
     Lasso,
     Ridge,
     RidgeClassifierCV,
@@ -47,7 +48,7 @@ from . import plot
 import matplotlib.pyplot as plt
 import seaborn as sns
 
-plt.style.use("paper")
+plt.style.use(str(get_cwd()) + "/data/styles/stylelib/paper.mplstyle")
 import logging
 import warnings
 
@@ -334,13 +335,13 @@ def features_naive_bayes(x_train: pd.DataFrame, y_train: pd.Series) -> list:
     probabilities = nb.predict_proba(x_train)
     # Limit the number of features safely, choosing the lesser of half the features or all columns
     n_features = min(x_train.shape[1] // 2, len(x_train.columns))
-
+
     # Sort probabilities, then map to valid column indices
     sorted_indices = np.argsort(probabilities.max(axis=1))[:n_features]
-
+
     # Ensure indices are within the column bounds of x_train
     valid_indices = sorted_indices[sorted_indices < len(x_train.columns)]
-
+
     return x_train.columns[valid_indices]
 
 
@@ -575,15 +576,28 @@ def get_features(
     bagging_params: Optional[Dict] = None,
     knn_params: Optional[Dict] = None,
     cls: list = [
-        "lasso",
-        "
+        "lasso",
+        "ridge",
+        "Elastic Net(Enet)",
+        "gradient Boosting",
+        "Random forest (rf)",
+        "XGBoost (xgb)",
+        "Support Vector Machine(svm)",
+        "naive bayes",
+        "Linear Discriminant Analysis (lda)",
+        "adaboost",
+        "DecisionTree",
+        "KNeighbors",
+        "Bagging",
+    ],
     metrics: Optional[List[str]] = None,
     cv_folds: int = 5,
     strict: bool = False,
     n_shared: int = 2,  # a feature counts as a common gene as soon as two methods agree on it
     use_selected_features: bool = True,
     plot_: bool = True,
-    dir_save:str="./"
+    dir_save: str = "./",
+) -> dict:
     """
     Master function to perform feature selection and validate models.
     """
@@ -598,14 +612,14 @@ def get_features(
 
     # fill na
     if fill_missing:
-        ips.df_fillna(data=X,method=
-    if isinstance(y, str) and y in X.columns:
-        y_col_name=y
-        y=X[y]
-        y=ips.df_encoder(pd.DataFrame(y),method=
-        X = X.drop(y_col_name,axis=1)
+        ips.df_fillna(data=X, method="knn", inplace=True, axis=0)
+    if isinstance(y, str) and y in X.columns:
+        y_col_name = y
+        y = X[y]
+        y = ips.df_encoder(pd.DataFrame(y), method="dummy")
+        X = X.drop(y_col_name, axis=1)
     else:
-        y=ips.df_encoder(pd.DataFrame(y),method=
+        y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
     y = y.loc[X.index]  # Align y with X after dropping rows with missing values in X
     y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
 
@@ -817,7 +831,7 @@ def get_features(
         top_knn_features,
         strict=strict,
         n_shared=n_shared,
-        verbose=False
+        verbose=False,
     )
 
     # Use selected features or all features for model validation
@@ -899,13 +913,14 @@ def get_features(
         results = {
             "selected_features": features_df,
             "cv_train_scores": cv_train_results_df,
-            "cv_test_scores": rank_models(cv_test_results_df,plot_=plot_),
+            "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
             "common_features": list(common_features),
         }
-        if all([plot_,dir_save]):
+        if all([plot_, dir_save]):
             from datetime import datetime
+
             now_ = datetime.now().strftime("%y%m%d_%H%M%S")
-            ips.figsave(dir_save+f"features{now_}.pdf")
+            ips.figsave(dir_save + f"features{now_}.pdf")
     else:
         results = {
             "selected_features": pd.DataFrame(),
@@ -931,7 +946,7 @@ def validate_features(
     metrics: Optional[list] = None,
     random_state: int = 1,
     smote: bool = False,
-    n_jobs:int = -1,
+    n_jobs: int = -1,
     plot_: bool = True,
     class_weight: str = "balanced",
 ) -> dict:
@@ -952,8 +967,11 @@ def validate_features(
 
     """
     from tqdm import tqdm
+
     # Ensure common features are selected
-    common_features = ips.shared(
+    common_features = ips.shared(
+        common_features, x_train.columns, x_true.columns, strict=True, verbose=False
+    )
 
     # Filter the training and validation datasets for the common features
     x_train_selected = x_train[common_features]
@@ -1007,8 +1025,7 @@ def validate_features(
             l1_ratio=0.5,
             random_state=random_state,
         ),
-        "XGBoost": xgb.XGBClassifier(eval_metric="logloss"
-        ),
+        "XGBoost": xgb.XGBClassifier(eval_metric="logloss"),
         "Naive Bayes": GaussianNB(),
         "LDA": LinearDiscriminantAnalysis(),
     }
@@ -1078,11 +1095,11 @@ def validate_features(
 
     # Validate each classifier with GridSearchCV
     for name, clf in tqdm(
+        models.items(),
+        desc="for metric in metrics",
+        colour="green",
+        bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
+    ):
         print(f"\nValidating {name} on the validation dataset:")
 
         # Check if `predict_proba` method exists; if not, use CalibratedClassifierCV
@@ -1162,7 +1179,7 @@ def validate_features(
         if y_pred_proba is not None:
             # fpr, tpr, roc_auc = dict(), dict(), dict()
             fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
-            lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
+            lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
             roc_auc = auc(fpr, tpr)
             roc_info = {
                 "fpr": fpr.tolist(),
@@ -1197,6 +1214,7 @@ def validate_features(
 # Validate models using the validation dataset (X_val, y_val)
 # validation_results = validate_features(X, y, X_val, y_val, common_features)
 
+
 # # If you want to access validation scores
 # print(validation_results)
 def plot_validate_features(res_val):
@@ -1204,47 +1222,75 @@ def plot_validate_features(res_val):
     plot the results of 'validate_features()'
     """
     colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
-    if res_val.shape[0]>5:
-        alpha=0
-        figsize=[8,10]
-        subplot_layout=[1,2]
-        ncols=2
-        bbox_to_anchor=[1.5,0.6]
+    if res_val.shape[0] > 5:
+        alpha = 0
+        figsize = [8, 10]
+        subplot_layout = [1, 2]
+        ncols = 2
+        bbox_to_anchor = [1.5, 0.6]
     else:
-        alpha=0.03
-        figsize=[10,6]
-        subplot_layout=[1,1]
-        ncols=1
-        bbox_to_anchor=[1,1]
+        alpha = 0.03
+        figsize = [10, 6]
+        subplot_layout = [1, 1]
+        ncols = 1
+        bbox_to_anchor = [1, 1]
     nexttile = plot.subplot(figsize=figsize)
-    ax = nexttile(subplot_layout[0],subplot_layout[1])
+    ax = nexttile(subplot_layout[0], subplot_layout[1])
     for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
         fpr = res_val["roc_curve"][model_name]["fpr"]
         tpr = res_val["roc_curve"][model_name]["tpr"]
         (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
         mean_auc = res_val["roc_curve"][model_name]["auc"]
         plot_roc_curve(
-            fpr,
+            fpr,
+            tpr,
+            mean_auc,
+            lower_ci,
+            upper_ci,
+            model_name=model_name,
+            lw=1.5,
+            color=colors[i],
+            alpha=alpha,
+            ax=ax,
+        )
+    plot.figsets(
+        sp=2,
+        legend=dict(
+            loc="upper right",
+            ncols=ncols,
+            fontsize=8,
+            bbox_to_anchor=[1.5, 0.6],
+            markerscale=0.8,
+        ),
+    )
     # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
 
-    ax = nexttile(subplot_layout[0],subplot_layout[1])
+    ax = nexttile(subplot_layout[0], subplot_layout[1])
     for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
         plot_pr_curve(
             recall=res_val["pr_curve"][model_name]["recall"],
             precision=res_val["pr_curve"][model_name]["precision"],
             avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
             model_name=model_name,
-            color=colors[i],
+            color=colors[i],
+            lw=1.5,
+            alpha=alpha,
+            ax=ax,
+        )
+    plot.figsets(
+        sp=2,
+        legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
+    )
     # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
-
-
+
+
+def plot_validate_features_single(res_val, figsize=None):
     if figsize is None:
         nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3)
     else:
-        nexttile = plot.subplot(
+        nexttile = plot.subplot(
+            len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
+        )
     for model_name in ips.flatten(res_val["pr_curve"].index):
         fpr = res_val["roc_curve"][model_name]["fpr"]
         tpr = res_val["roc_curve"][model_name]["tpr"]
@@ -1268,7 +1314,10 @@ def plot_validate_features_single(res_val,figsize=None):
         plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
         plot.figsets(title=model_name, sp=2)
 
-
+
+def cal_auc_ci(
+    y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
+):
     y_true = np.asarray(y_true)
     y_pred = np.asarray(y_pred)
     bootstrapped_scores = []
@@ -1298,10 +1347,10 @@ def cal_auc_ci(y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,verbos
     confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
     if verbose:
         print(
-
-
+            "Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
+                confidence_lower, confidence_upper
+            )
         )
-        )
     return confidence_lower, confidence_upper
 
 
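For orientation, the confidence bounds printed above come from a percentile bootstrap over resampled AUCs. The hunk only shows the tail of cal_auc_ci, so the following is a minimal, self-contained sketch of that idea rather than the package's exact implementation (the resampling loop and the handling of single-class resamples are assumptions based on the signature and the percentile lines shown):

import numpy as np
from sklearn.metrics import roc_auc_score

def bootstrap_auc_ci(y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1):
    # Resample (label, score) pairs with replacement and collect one AUC per resample.
    rng = np.random.RandomState(random_state)
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for _ in range(n_bootstraps):
        idx = rng.randint(0, len(y_pred), len(y_pred))
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUC is undefined when a resample contains only one class
        scores.append(roc_auc_score(y_true[idx], y_pred[idx]))
    sorted_scores = np.sort(scores)
    # Lower/upper percentile of the sorted bootstrap AUCs, mirroring the indexing shown above.
    lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
    upper = sorted_scores[int(ci * len(sorted_scores))]
    return lower, upper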
@@ -1339,7 +1388,7 @@ def plot_roc_curve(
     # Plot ROC curve and the diagonal reference line
     ax.fill_between(fpr, tpr, alpha=alpha, color=color)
     ax.plot([0, 1], [0, 1], color=diagonal_color, clip_on=False, linestyle="--")
-    ax.plot(fpr, tpr, color=color, lw=lw, label=label,clip_on=False, **kwargs)
+    ax.plot(fpr, tpr, color=color, lw=lw, label=label, clip_on=False, **kwargs)
     # Setting plot limits, labels, and title
     ax.set_xlim([-0.01, 1.0])
     ax.set_ylim([0.0, 1.0])
@@ -1536,12 +1585,11 @@ def plot_cm(
         color=color,
         fontsize=fontsize,
     )
-
-    plot.figsets(ax=ax,
-                 boxloc="none"
-                 )
+
+    plot.figsets(ax=ax, boxloc="none")
     return ax
 
+
 def rank_models(
     cv_test_scores,
     rm_outlier=False,
@@ -1644,7 +1692,7 @@ def rank_models(
     if rm_outlier:
        cv_test_scores_ = ips.df_outlier(cv_test_scores)
     else:
-        cv_test_scores_=cv_test_scores
+        cv_test_scores_ = cv_test_scores
 
     # Normalize the scores of metrics if normalize is True
     scaler = MinMaxScaler()
@@ -1673,7 +1721,7 @@ def rank_models(
         )
         plt.title("Classifier Performance")
         plt.tight_layout()
-        return plt
+        return plt
 
     nexttile = plot.subplot(2, 2, figsize=[10, 7])
     generate_bar_plot(nexttile(), top_models.dropna())
@@ -1703,10 +1751,11 @@ def rank_models(
 
 # figsave("classifier_performance.pdf")
 
+
 def predict(
     x_train: pd.DataFrame,
     y_train: pd.Series,
-    x_true: pd.DataFrame=None,
+    x_true: pd.DataFrame = None,
     y_true: Optional[pd.Series] = None,
     common_features: set = None,
     purpose: str = "classification",  # 'classification' or 'regression'
@@ -1714,117 +1763,156 @@ def predict(
     metrics: Optional[List[str]] = None,
     random_state: int = 1,
     smote: bool = False,
-    n_jobs:int
+    n_jobs: int = -1,
     plot_: bool = True,
-    dir_save:str="./",
-    test_size:float=0.2
-    cv_folds:int=5
-    cv_level:str="l"
+    dir_save: str = "./",
+    test_size: float = 0.2,  # used only when x_true is None
+    cv_folds: int = 5,  # more cv_folds gives more stable results, though AUC may come out lower
+    cv_level: str = "l",  # "s":'low',"m":'medium',"l":"high"
     class_weight: str = "balanced",
-    verbose:bool=False,
+    verbose: bool = False,
 ) -> pd.DataFrame:
-    """
+    """
+    Case (1) is an internal split, case (2) is direct prediction, and case (3) is external validation.
+    Usage:
+        (1). predict(x_train, y_train, ...): split x_train into train/test sets and validate on the test split.
+            predict() splits an internal test set out of x_train and y_train according to test_size; the models are then trained on the remaining training split and validated on that test split.
+        (2). predict(x_train, y_train, x_true, ...): train on x_train and y_train and predict on x_true.
+            Because x_true is passed in, the split of x_train is skipped and all of x_train and y_train are used for training. x_true is then predicted, but since y_true is not provided,
+            the predictions cannot be compared against ground truth.
+        (3). predict(x_train, y_train, x_true, y_true, ...): train on x_train and y_train and validate x_true against the true labels y_true.
+            predict() trains on x_train and y_train and uses x_true as the test set. Because y_true is provided, the predictions can be compared against y_true, so
+            validation metrics are computed and x_true is genuinely validated.
+    trains and validates a variety of machine learning models for both classification and regression tasks.
+    It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
+    feature scaling, and handling of class imbalance through SMOTE.
+
+    Parameters:
+    - x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
+    - y_train (pd.Series):Target variable for the training dataset.
+    - x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
+    - y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
+    - common_features (set, optional):Specifies a subset of features common across training and test data.
+    - purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
+        metrics and models are applied.
+    - cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
+    - metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
+    - random_state (int, default = 1):Random seed to ensure reproducibility.
+    - smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
+    - n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
+    - plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
+    - test_size (float, default = 0.2):Test data proportion if x_true is not provided.
+    - cv_folds (int, default = 5):Number of cross-validation folds.
+    - cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
+    - class_weight (str, default = "balanced"):Balances class weights in classification tasks.
+    - verbose (bool, default = False):If True, prints detailed output during model training.
+    - dir_save (str, default = "./"):Directory path to save plot outputs and results.
+
+    Key Steps in the Function:
+        Model Initialization: Depending on purpose, initializes either classification or regression models.
+        Feature Selection: Ensures training and test sets have matching feature columns.
+        SMOTE Application: Balances classes if smote is enabled and the task is classification.
+        Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
+        Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
     """
     from tqdm import tqdm
-    from sklearn.ensemble import
+    from sklearn.ensemble import (
+        RandomForestClassifier,
+        RandomForestRegressor,
+        ExtraTreesClassifier,
+        ExtraTreesRegressor,
+        BaggingClassifier,
+        BaggingRegressor,
+        AdaBoostClassifier,
+        AdaBoostRegressor,
+    )
     from sklearn.svm import SVC, SVR
     from sklearn.tree import DecisionTreeRegressor
-    from sklearn.linear_model import
+    from sklearn.linear_model import (
+        LogisticRegression,
+        ElasticNet,
+        ElasticNetCV,
+        LinearRegression,
+        Lasso,
+        RidgeClassifierCV,
+        Perceptron,
+        SGDClassifier,
+    )
+    from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
+    from sklearn.naive_bayes import GaussianNB, BernoulliNB
     from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
     import xgboost as xgb
     import lightgbm as lgb
     import catboost as cb
     from sklearn.neural_network import MLPClassifier, MLPRegressor
     from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
-    from sklearn.discriminant_analysis import
+    from sklearn.discriminant_analysis import (
+        LinearDiscriminantAnalysis,
+        QuadraticDiscriminantAnalysis,
+    )
     from sklearn.preprocessing import PolynomialFeatures
 
-
     # spell check
-    purpose=ips.strcmp(purpose,[
+    purpose = ips.strcmp(purpose, ["classification", "regression"])[0]
     print(f"{purpose} processing...")
     # Default models or regressors if not provided
     if purpose == "classification":
         model_ = {
-            "Random Forest": RandomForestClassifier(
+            "Random Forest": RandomForestClassifier(
+                random_state=random_state, class_weight=class_weight
+            ),
+            # SVC (Support Vector Classification)
+            "SVM": SVC(
+                kernel="rbf",
+                probability=True,
+                class_weight=class_weight,
+                random_state=random_state,
+            ),
             # fit the best model without enforcing sparsity, which means it does not directly perform feature selection.
-            "Logistic Regression": LogisticRegression(
+            "Logistic Regression": LogisticRegression(
+                class_weight=class_weight, random_state=random_state
+            ),
             # Logistic Regression with L1 Regularization (Lasso)
-            "Lasso Logistic Regression": LogisticRegression(
+            "Lasso Logistic Regression": LogisticRegression(
+                penalty="l1", solver="saga", random_state=random_state
+            ),
             "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
-            "XGBoost": xgb.XGBClassifier(
+            "XGBoost": xgb.XGBClassifier(
+                eval_metric="logloss",
+                random_state=random_state,
+            ),
             "KNN": KNeighborsClassifier(n_neighbors=5),
             "Naive Bayes": GaussianNB(),
             "Linear Discriminant Analysis": LinearDiscriminantAnalysis(),
-            "AdaBoost":
+            "AdaBoost": AdaBoostClassifier(
+                algorithm="SAMME", random_state=random_state
+            ),
             # "LightGBM": lgb.LGBMClassifier(random_state=random_state, class_weight=class_weight),
             "CatBoost": cb.CatBoostClassifier(verbose=0, random_state=random_state),
-            "Extra Trees": ExtraTreesClassifier(
+            "Extra Trees": ExtraTreesClassifier(
+                random_state=random_state, class_weight=class_weight
+            ),
             "Bagging": BaggingClassifier(random_state=random_state),
             "Neural Network": MLPClassifier(max_iter=500, random_state=random_state),
             "DecisionTree": DecisionTreeClassifier(),
             "Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis(),
-            "Ridge": RidgeClassifierCV(
+            "Ridge": RidgeClassifierCV(
+                class_weight=class_weight, store_cv_results=True
+            ),
             "Perceptron": Perceptron(random_state=random_state),
             "Bernoulli Naive Bayes": BernoulliNB(),
-            "SGDClassifier": SGDClassifier(random_state=random_state),
+            "SGDClassifier": SGDClassifier(random_state=random_state),
         }
     elif purpose == "regression":
         model_ = {
             "Random Forest": RandomForestRegressor(random_state=random_state),
-            "SVM": SVR()
+            "SVM": SVR(),  # SVR (Support Vector Regression)
             # "Lasso": Lasso(random_state=random_state),  # same as LassoCV (but alpha must be provided),
-            "LassoCV": LassoCV(
+            "LassoCV": LassoCV(
+                cv=cv_folds, random_state=random_state
+            ),  # LassoCV finds the best alpha automatically, so it is preferable to Lasso
             "Gradient Boosting": GradientBoostingRegressor(random_state=random_state),
-            "XGBoost": xgb.XGBRegressor(eval_metric="rmse",random_state=random_state),
+            "XGBoost": xgb.XGBRegressor(eval_metric="rmse", random_state=random_state),
             "Linear Regression": LinearRegression(),
             "Lasso": Lasso(random_state=random_state),
             "AdaBoost": AdaBoostRegressor(random_state=random_state),
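The translated docstring above describes three calling patterns for predict(). A hedged usage sketch follows; the import path py2ls.ml2ls is inferred from the changed file's name, the synthetic data and column names are illustrative, and only keyword arguments visible in the signature above are used:

import pandas as pd
from sklearn.datasets import make_classification
from py2ls import ml2ls  # assumed module path, based on the file py2ls/ml2ls.py

X, y = make_classification(n_samples=200, n_features=10, random_state=1)
x_train = pd.DataFrame(X, columns=[f"f{i}" for i in range(10)])
y_train = pd.Series(y, name="label")
x_new = x_train.sample(40, random_state=1)   # stand-in for unseen samples
y_new = y_train.loc[x_new.index]

# (1) internal split: predict() holds out `test_size` of x_train and validates on it
res_split = ml2ls.predict(x_train, y_train, purpose="classification", cv_level="s", plot_=False)

# (2) direct prediction: no y_true, so predictions are not compared against ground truth
res_pred = ml2ls.predict(x_train, y_train, x_true=x_new, purpose="classification", cv_level="s", plot_=False)

# (3) external validation: providing y_true enables the validation metrics
res_val = ml2ls.predict(x_train, y_train, x_true=x_new, y_true=y_new, purpose="classification", cv_level="s", plot_=False)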
@@ -1834,71 +1922,76 @@ def predict(
             "Bagging": BaggingRegressor(random_state=random_state),
             "Neural Network": MLPRegressor(max_iter=500, random_state=random_state),
             "ElasticNet": ElasticNet(random_state=random_state),
-            "Ridge": Ridge(),
-            "KNN":KNeighborsRegressor()
+            "Ridge": Ridge(),
+            "KNN": KNeighborsRegressor(),
         }
-    # indicate cls:
-    if ips.run_once_within(30)
+    # indicate cls:
+    if ips.run_once_within(30):  # 10 min
         print(f"supported models: {list(model_.keys())}")
     if cls is None:
-        models=model_
+        models = model_
     else:
         if not isinstance(cls, list):
-            cls=[cls]
-        models={}
-        for cls_ in cls:
+            cls = [cls]
+        models = {}
+        for cls_ in cls:
             cls_ = ips.strcmp(cls_, list(model_.keys()))[0]
             models[cls_] = model_[cls_]
-    if
-    x_train=ips.df_special_characters_cleaner(x_train)
-    x_true=
+    if "LightGBM" in models:
+        x_train = ips.df_special_characters_cleaner(x_train)
+        x_true = (
+            ips.df_special_characters_cleaner(x_true) if x_true is not None else None
+        )
+
+    if isinstance(y_train, str) and y_train in x_train.columns:
+        y_train_col_name = y_train
+        y_train = x_train[y_train]
+        y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
+        x_train = x_train.drop(y_train_col_name, axis=1)
     else:
-        y_train=ips.df_encoder(pd.DataFrame(y_train),method=
+        y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
 
     if x_true is None:
         x_train, x_true, y_train, y_true = train_test_split(
-            x_train,
-            y_train,
-            test_size=test_size,
-            random_state=random_state,
-            stratify=y_train if purpose == "classification" else None
+            x_train,
+            y_train,
+            test_size=test_size,
+            random_state=random_state,
+            stratify=y_train if purpose == "classification" else None,
         )
-        if isinstance(y_train, str) and y_train in x_train.columns:
-            y_train_col_name=y_train
-            y_train=x_train[y_train]
-            y_train=ips.df_encoder(pd.DataFrame(y_train),method=
-            x_train = x_train.drop(y_train_col_name,axis=1)
+        if isinstance(y_train, str) and y_train in x_train.columns:
+            y_train_col_name = y_train
+            y_train = x_train[y_train]
+            y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
+            x_train = x_train.drop(y_train_col_name, axis=1)
         else:
-            y_train=ips.df_encoder(
+            y_train = ips.df_encoder(
+                pd.DataFrame(y_train), method="dummy"
+            ).values.ravel()
     if y_true is not None:
-        if isinstance(y_true, str) and y_true in x_true.columns:
-            y_true_col_name=y_true
-            y_true=x_true[y_true]
-            y_true=ips.df_encoder(pd.DataFrame(y_true),method=
-            x_true = x_true.drop(y_true_col_name,axis=1)
+        if isinstance(y_true, str) and y_true in x_true.columns:
+            y_true_col_name = y_true
+            y_true = x_true[y_true]
+            y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
+            x_true = x_true.drop(y_true_col_name, axis=1)
         else:
-            y_true=ips.df_encoder(pd.DataFrame(y_true),method=
+            y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
 
     # to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
 
     # y_train=y_train.values.ravel() if y_train is not None else None
     # y_true=y_true.values.ravel() if y_true is not None else None
-    y_train =
+    y_train = (
+        y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
+    )
     y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
 
     # Ensure common features are selected
     if common_features is not None:
         x_train, x_true = x_train[common_features], x_true[common_features]
     else:
-        share_col_names = ips.shared(x_train.columns, x_true.columns,verbose=verbose)
-        x_train, x_true =x_train[share_col_names], x_true[share_col_names]
+        share_col_names = ips.shared(x_train.columns, x_true.columns, verbose=verbose)
+        x_train, x_true = x_train[share_col_names], x_true[share_col_names]
 
     x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
     x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
@@ -1917,26 +2010,30 @@ def predict(
         x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
 
     # Hyperparameter grids for tuning
-    if cv_level in ["low",
+    if cv_level in ["low", "simple", "s", "l"]:
         param_grids = {
-            "Random Forest":
-            "
+            "Random Forest": (
+                {
+                    "n_estimators": [100],  # One basic option
+                    "max_depth": [None, 10],
+                    "min_samples_split": [2],
+                    "min_samples_leaf": [1],
+                    "class_weight": [None],
+                }
+                if purpose == "classification"
+                else {
+                    "n_estimators": [100],  # One basic option
+                    "max_depth": [None, 10],
+                    "min_samples_split": [2],
+                    "min_samples_leaf": [1],
+                    "max_features": [None],
+                    "bootstrap": [True],  # Only one option for simplicity
+                }
+            ),
             "SVM": {
                 "C": [1],
-                "gamma": [
-                "kernel": [
+                "gamma": ["scale"],
+                "kernel": ["rbf"],
             },
             "Lasso": {
                 "alpha": [0.1],
@@ -1946,8 +2043,8 @@ def predict(
             },
             "Logistic Regression": {
                 "C": [1],
-                "solver": [
-                "penalty": [
+                "solver": ["lbfgs"],
+                "penalty": ["l2"],
                 "max_iter": [500],
             },
             "Gradient Boosting": {
@@ -1964,25 +2061,29 @@ def predict(
                 "subsample": [0.8],
                 "colsample_bytree": [0.8],
             },
-            "KNN":
+            "KNN": (
+                {
+                    "n_neighbors": [3],
+                    "weights": ["uniform"],
+                    "algorithm": ["auto"],
+                    "p": [2],
+                }
+                if purpose == "classification"
+                else {
+                    "n_neighbors": [3],
+                    "weights": ["uniform"],
+                    "metric": ["euclidean"],
+                    "leaf_size": [30],
+                    "p": [2],
+                }
+            ),
             "Naive Bayes": {
                 "var_smoothing": [1e-9],
             },
             "SVR": {
                 "C": [1],
-                "gamma": [
-                "kernel": [
+                "gamma": ["scale"],
+                "kernel": ["rbf"],
             },
             "Linear Regression": {
                 "fit_intercept": [True],
@@ -2003,9 +2104,9 @@ def predict(
                 "n_estimators": [100],
                 "num_leaves": [31],
                 "max_depth": [10],
+                "min_data_in_leaf": [20],
+                "min_gain_to_split": [0.01],
+                "scale_pos_weight": [10],
             },
             "Bagging": {
                 "n_estimators": [50],
@@ -2033,132 +2134,168 @@ def predict(
                 "shrinkage": [None],
             },
             "Quadratic Discriminant Analysis": {
-            },
-            "Ridge": {'class_weight': [None, 'balanced']} if purpose == "classification" else {
-                'alpha': [0.1, 1, 10],
+                "reg_param": [0.0],
+                "priors": [None],
+                "tol": [1e-4],
             },
+            "Ridge": (
+                {"class_weight": [None, "balanced"]}
+                if purpose == "classification"
+                else {
+                    "alpha": [0.1, 1, 10],
+                }
+            ),
             "Perceptron": {
+                "alpha": [1e-3],
+                "penalty": ["l2"],
+                "max_iter": [1000],
+                "eta0": [1.0],
             },
             "Bernoulli Naive Bayes": {
+                "alpha": [0.1, 1, 10],
+                "binarize": [0.0],
+                "fit_prior": [True],
             },
             "SGDClassifier": {
+                "eta0": [0.01],
+                "loss": ["hinge"],
+                "penalty": ["l2"],
+                "alpha": [1e-3],
+                "max_iter": [1000],
+                "tol": [1e-3],
+                "random_state": [random_state],
+                "learning_rate": ["constant"],
             },
         }
-    elif cv_level in [
-        param_grids = {
-            "Random Forest":
+    elif cv_level in ["high", "advanced", "h"]:
+        param_grids = {
+            "Random Forest": (
+                {
+                    "n_estimators": [100, 200, 500, 700, 1000],
+                    "max_depth": [None, 3, 5, 10, 15, 20, 30],
+                    "min_samples_split": [2, 5, 10, 20],
+                    "min_samples_leaf": [1, 2, 4],
+                    "class_weight": (
+                        [None, "balanced"] if purpose == "classification" else {}
+                    ),
+                }
+                if purpose == "classification"
+                else {
+                    "n_estimators": [100, 200, 500, 700, 1000],
+                    "max_depth": [None, 3, 5, 10, 15, 20, 30],
+                    "min_samples_split": [2, 5, 10, 20],
+                    "min_samples_leaf": [1, 2, 4],
+                    "max_features": [
+                        "auto",
+                        "sqrt",
+                        "log2",
+                    ],  # Number of features to consider when looking for the best split
+                    "bootstrap": [
+                        True,
+                        False,
+                    ],  # Whether bootstrap samples are used when building trees
+                }
+            ),
             "SVM": {
+                "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
+                "gamma": ["scale", "auto", 0.001, 0.01, 0.1],
+                "kernel": ["linear", "rbf", "poly"],
+            },
             "Logistic Regression": {
-            "Lasso":{
-            "LassoCV":{
+                "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
+                "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
+                "penalty": ["l1", "l2", "elasticnet"],
+                "max_iter": [100, 200, 300, 500],
+            },
+            "Lasso": {
+                "alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
+                "max_iter": [500, 1000, 2000, 5000],
+                "tol": [1e-4, 1e-5, 1e-6],
+                "selection": ["cyclic", "random"],
+            },
+            "LassoCV": {
+                "alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
+                "max_iter": [500, 1000, 2000, 5000],
+                "cv": [3, 5, 10],
+                "tol": [1e-4, 1e-5, 1e-6],
+            },
             "Gradient Boosting": {
+                "n_estimators": [100, 200, 300, 400, 500, 700, 1000],
+                "learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
+                "max_depth": [3, 5, 7, 9, 15],
+                "min_samples_split": [2, 5, 10, 20],
+                "subsample": [0.8, 1.0],
+            },
             "XGBoost": {
-            "KNN":
+                "n_estimators": [100, 200, 500, 700],
+                "max_depth": [3, 5, 7, 10],
+                "learning_rate": [0.01, 0.1, 0.2, 0.3],
+                "subsample": [0.8, 1.0],
+                "colsample_bytree": [0.8, 0.9, 1.0],
+            },
+            "KNN": (
+                {
+                    "n_neighbors": [1, 3, 5, 10, 15, 20],
+                    "weights": ["uniform", "distance"],
+                    "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
+                    "p": [1, 2],  # 1 for Manhattan, 2 for Euclidean distance
+                }
+                if purpose == "classification"
+                else {
+                    "n_neighbors": [3, 5, 7, 9, 11],  # Number of neighbors
+                    "weights": [
+                        "uniform",
+                        "distance",
+                    ],  # Weight function used in prediction
+                    "metric": [
+                        "euclidean",
+                        "manhattan",
+                        "minkowski",
+                    ],  # Distance metric
+                    "leaf_size": [
+                        20,
+                        30,
+                        40,
+                        50,
+                    ],  # Leaf size for KDTree or BallTree algorithms
+                    "p": [
+                        1,
+                        2,
+                    ],  # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
+                }
+            ),
             "Naive Bayes": {
                 "var_smoothing": [1e-10, 1e-9, 1e-8, 1e-7],
+            },
             "AdaBoost": {
+                "n_estimators": [50, 100, 200, 300, 500],
+                "learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
+            },
             "SVR": {
+                "C": [0.01, 0.1, 1, 10, 100, 1000],
+                "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
+                "kernel": ["linear", "rbf", "poly"],
+            },
             "Linear Regression": {
                 "fit_intercept": [True, False],
-            "Lasso":{
+            },
+            "Lasso": {
                 "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
-                "max_iter": [1000, 2000]  # Higher iteration limit for fine-tuning
+                "max_iter": [1000, 2000],  # Higher iteration limit for fine-tuning
+            },
             "Extra Trees": {
                 "n_estimators": [100, 200, 500, 700, 1000],
                 "max_depth": [None, 5, 10, 15, 20, 30],
                 "min_samples_split": [2, 5, 10, 20],
-                "min_samples_leaf": [1, 2, 4]
+                "min_samples_leaf": [1, 2, 4],
+            },
             "CatBoost": {
                 "iterations": [100, 200, 500],
                 "learning_rate": [0.001, 0.01, 0.1, 0.2],
                 "depth": [3, 5, 7, 10],
                 "l2_leaf_reg": [1, 3, 5, 7, 10],
                 "border_count": [32, 64, 128],
+            },
             "LightGBM": {
                 "n_estimators": [100, 200, 500, 700, 1000],
                 "learning_rate": [0.001, 0.01, 0.1, 0.2],
@@ -2167,66 +2304,97 @@ def predict(
                 "min_child_samples": [5, 10, 20],
                 "subsample": [0.8, 1.0],
                 "colsample_bytree": [0.8, 0.9, 1.0],
+            },
             "Neural Network": {
                 "hidden_layer_sizes": [(50,), (100,), (100, 50), (200, 100)],
                 "activation": ["relu", "tanh", "logistic"],
                 "solver": ["adam", "sgd", "lbfgs"],
                 "alpha": [0.0001, 0.001, 0.01],
                 "learning_rate": ["constant", "adaptive"],
+            },
             "Decision Tree": {
                 "max_depth": [None, 5, 10, 20, 30],
                 "min_samples_split": [2, 5, 10, 20],
                 "min_samples_leaf": [1, 2, 5, 10],
                 "criterion": ["gini", "entropy"],
                 "splitter": ["best", "random"],
+            },
             "Linear Discriminant Analysis": {
                 "solver": ["svd", "lsqr", "eigen"],
-                "shrinkage": [
+                "shrinkage": [
+                    None,
+                    "auto",
+                    0.1,
+                    0.5,
+                    1.0,
+                ],  # shrinkage levels for 'lsqr' and 'eigen'
+            },
+            "Ridge": (
+                {"class_weight": [None, "balanced"]}
+                if purpose == "classification"
+                else {
+                    "alpha": [0.1, 1, 10, 100, 1000],
+                    "solver": ["auto", "svd", "cholesky", "lsqr", "lbfgs"],
+                    "fit_intercept": [
+                        True,
+                        False,
+                    ],  # Whether to calculate the intercept
+                    "normalize": [
+                        True,
+                        False,
+                    ],  # If True, the regressors X will be normalized
                 }
+            ),
+        }
+    else:  # median level
+        param_grids = {
+            "Random Forest": (
+                {
+                    "n_estimators": [100, 200, 500],
+                    "max_depth": [None, 10, 20, 30],
+                    "min_samples_split": [2, 5, 10],
+                    "min_samples_leaf": [1, 2, 4],
+                    "class_weight": [None, "balanced"],
+                }
+                if purpose == "classification"
+                else {
+                    "n_estimators": [100, 200, 500],
+                    "max_depth": [None, 10, 20, 30],
+                    "min_samples_split": [2, 5, 10],
+                    "min_samples_leaf": [1, 2, 4],
+                    "max_features": [
+                        "auto",
+                        "sqrt",
+                        "log2",
+                    ],  # Number of features to consider when looking for the best split
+                    "bootstrap": [
+                        True,
+                        False,
+                    ],  # Whether bootstrap samples are used when building trees
+                }
+            ),
             "SVM": {
                 "C": [0.1, 1, 10, 100],  # Regularization strength
-                "gamma": [
-                "kernel": [
+                "gamma": ["scale", "auto"],  # Common gamma values
+                "kernel": ["rbf", "linear", "poly"],
             },
             "Logistic Regression": {
                 "C": [0.1, 1, 10, 100],  # Regularization strength
-                "solver": [
-                "penalty": [
-                "max_iter": [
+                "solver": ["lbfgs", "liblinear", "saga"],  # Common solvers
+                "penalty": ["l2"],  # L2 penalty is most common
+                "max_iter": [
+                    500,
+                    1000,
+                    2000,
+                ],  # Increased max_iter for better convergence
             },
-            "Lasso":{
+            "Lasso": {
                 "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
-                "max_iter": [500, 1000, 2000]
+                "max_iter": [500, 1000, 2000],
             },
-            "LassoCV":{
+            "LassoCV": {
                 "alphas": [[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
-                "max_iter": [500, 1000, 2000]
+                "max_iter": [500, 1000, 2000],
             },
             "Gradient Boosting": {
                 "n_estimators": [100, 200, 500],
@@ -2242,25 +2410,44 @@ def predict(
                 "subsample": [0.8, 1.0],
                 "colsample_bytree": [0.8, 1.0],
             },
-            "KNN":
+            "KNN": (
+                {
+                    "n_neighbors": [3, 5, 7, 10],
+                    "weights": ["uniform", "distance"],
+                    "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
+                    "p": [1, 2],
+                }
+                if purpose == "classification"
+                else {
+                    "n_neighbors": [3, 5, 7, 9, 11],  # Number of neighbors
+                    "weights": [
+                        "uniform",
+                        "distance",
+                    ],  # Weight function used in prediction
+                    "metric": [
+                        "euclidean",
+                        "manhattan",
+                        "minkowski",
+                    ],  # Distance metric
+                    "leaf_size": [
+                        20,
+                        30,
+                        40,
+                        50,
+                    ],  # Leaf size for KDTree or BallTree algorithms
+                    "p": [
+                        1,
+                        2,
+                    ],  # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
+                }
+            ),
             "Naive Bayes": {
                 "var_smoothing": [1e-9, 1e-8, 1e-7],
             },
             "SVR": {
                 "C": [0.1, 1, 10, 100],
-                "gamma": [
-                "kernel": [
+                "gamma": ["scale", "auto"],
+                "kernel": ["rbf", "linear"],
             },
             "Linear Regression": {
                 "fit_intercept": [True, False],
@@ -2286,10 +2473,10 @@ def predict(
                 "learning_rate": [0.01, 0.1],
                 "num_leaves": [31, 50, 100],
                 "max_depth": [-1, 10, 20],
-            },
+                "min_data_in_leaf": [20],  # Minimum samples in each leaf
+                "min_gain_to_split": [0.01],  # Minimum gain to allow a split
+                "scale_pos_weight": [10],  # Address class imbalance
+            },
             "Bagging": {
                 "n_estimators": [10, 50, 100],
                 "max_samples": [0.5, 0.7, 1.0],
@@ -2314,41 +2501,73 @@ def predict(
             "Linear Discriminant Analysis": {
                 "solver": ["svd", "lsqr", "eigen"],
                 "shrinkage": [None, "auto"],
-            },
-            "
+            },
+            "Quadratic Discriminant Analysis": {
+                "reg_param": [0.0, 0.1, 0.5, 1.0],  # Regularization parameter
+                "priors": [None, [0.5, 0.5], [0.3, 0.7]],  # Class priors
+                "tol": [
+                    1e-4,
+                    1e-3,
+                    1e-2,
+                ],  # Tolerance value for the convergence of the algorithm
+            },
+            "Perceptron": {
+                "alpha": [1e-4, 1e-3, 1e-2],  # Regularization parameter
+                "penalty": ["l2", "l1", "elasticnet"],  # Regularization penalty
+                "max_iter": [1000, 2000],  # Maximum number of iterations
+                "eta0": [1.0, 0.1],  # Learning rate for gradient descent
+                "tol": [1e-3, 1e-4, 1e-5],  # Tolerance for stopping criteria
+                "random_state": [random_state],  # Random state for reproducibility
+            },
+            "Bernoulli Naive Bayes": {
+                "alpha": [0.1, 1.0, 10.0],  # Additive (Laplace) smoothing parameter
+                "binarize": [
+                    0.0,
+                    0.5,
+                    1.0,
+                ],  # Threshold for binarizing the input features
+                "fit_prior": [
+                    True,
+                    False,
+                ],  # Whether to learn class prior probabilities
+            },
+            "SGDClassifier": {
+                "eta0": [0.01, 0.1, 1.0],
+                "loss": [
+                    "hinge",
+                    "log",
+                    "modified_huber",
+                    "squared_hinge",
+                    "perceptron",
+                ],  # Loss function
+                "penalty": ["l2", "l1", "elasticnet"],  # Regularization penalty
+                "alpha": [1e-4, 1e-3, 1e-2],  # Regularization strength
+                "l1_ratio": [0.15, 0.5, 0.85],  # L1 ratio for elasticnet penalty
+                "max_iter": [1000, 2000],  # Maximum number of iterations
+                "tol": [1e-3, 1e-4],  # Tolerance for stopping criteria
+                "random_state": [random_state],  # Random state for reproducibility
+                "learning_rate": [
+                    "constant",
+                    "optimal",
+                    "invscaling",
+                    "adaptive",
+                ],  # Learning rate schedule
+            },
+            "Ridge": (
+                {"class_weight": [None, "balanced"]}
+                if purpose == "classification"
+                else {
+                    "alpha": [0.1, 1, 10, 100],
+                    "solver": [
+                        "auto",
+                        "svd",
+                        "cholesky",
+                        "lsqr",
+                    ],  # Solver for optimization
+                }
+            ),
         }
-
+
     results = {}
     # Use StratifiedKFold for classification and KFold for regression
     cv = (
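For context, each entry in the param_grids dictionaries above is consumed by GridSearchCV, which is imported earlier in predict(). The following is a minimal sketch of that pattern with an illustrative grid and synthetic data, not the package's exact loop:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold

X, y = make_classification(n_samples=300, n_features=12, random_state=1)
grid = {"n_estimators": [100, 200], "max_depth": [None, 10]}  # illustrative subset of a grid
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
gs = GridSearchCV(RandomForestClassifier(random_state=1), grid, cv=cv, scoring="roc_auc", n_jobs=-1)
gs.fit(X, y)
# gs.best_estimator_ is what the diff stores per model as "best_clf"; gs.best_params_ as "best_params".
print(gs.best_params_, gs.best_score_)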
@@ -2359,11 +2578,11 @@ def predict(
 
     # Train and validate each model
     for name, clf in tqdm(
+        models.items(),
+        desc="models",
+        colour="green",
+        bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
+    ):
         if verbose:
             print(f"\nTraining and validating {name}:")
 
@@ -2381,7 +2600,7 @@ def predict(
         gs.fit(x_train, y_train)
         best_clf = gs.best_estimator_
         # make sure x_train and x_test has the same name
-        x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
+        x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
         y_pred = best_clf.predict(x_true)
 
         # y_pred_proba
@@ -2396,18 +2615,23 @@ def predict(
             )
         else:
             y_pred_proba = None  # No probability output for certain models
-
-
+
         validation_scores = {}
         if y_true is not None:
-            validation_scores = cal_metrics(
+            validation_scores = cal_metrics(
+                y_true,
+                y_pred,
+                y_pred_proba=y_pred_proba,
+                purpose=purpose,
+                average="weighted",
+            )
 
             # Calculate ROC curve
             # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
             if y_pred_proba is not None:
                 # fpr, tpr, roc_auc = dict(), dict(), dict()
                 fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
-                lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
+                lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
                 roc_auc = auc(fpr, tpr)
                 roc_info = {
                     "fpr": fpr.tolist(),
@@ -2425,11 +2649,14 @@ def predict(
                 }
             else:
                 roc_info, pr_info = None, None
-        if purpose=="classification":
+        if purpose == "classification":
             results[name] = {
-                "best_clf": gs.best_estimator_,
+                "best_clf": gs.best_estimator_,
                 "best_params": gs.best_params_,
-                "auc_indiv":[
+                "auc_indiv": [
+                    gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
+                    for i in range(cv_folds)
+                ],
                 "scores": validation_scores,
                 "roc_curve": roc_info,
                 "pr_curve": pr_info,
@@ -2439,11 +2666,11 @@ def predict(
                     y_pred_proba.tolist() if y_pred_proba is not None else None
                 ),
             }
-        else:
+        else:  # "regression"
             results[name] = {
-                "best_clf": gs.best_estimator_,
+                "best_clf": gs.best_estimator_,
                 "best_params": gs.best_params_,
-                "scores": validation_scores,
+                "scores": validation_scores,  # e.g., neg_MSE, R², etc.
                 "predictions": y_pred.tolist(),
                 "predictions_proba": (
                     y_pred_proba.tolist() if y_pred_proba is not None else None
@@ -2452,9 +2679,9 @@ def predict(
 
         else:
             results[name] = {
-                "best_clf": gs.best_estimator_,
+                "best_clf": gs.best_estimator_,
                 "best_params": gs.best_params_,
-                "scores": validation_scores,
+                "scores": validation_scores,
                 "predictions": y_pred.tolist(),
                 "predictions_proba": (
                     y_pred_proba.tolist() if y_pred_proba is not None else None
@@ -2465,76 +2692,80 @@ def predict(
     df_results = pd.DataFrame.from_dict(results, orient="index")
 
     # sort
-    if y_true is not None and purpose=="classification":
+    if y_true is not None and purpose == "classification":
         df_scores = pd.DataFrame(
-        df_results=df_results.loc[df_scores.index]
+            df_results["scores"].tolist(), index=df_results["scores"].index
+        ).sort_values(by="roc_auc", ascending=False)
+        df_results = df_results.loc[df_scores.index]
 
         if plot_:
             from datetime import datetime
+
             now_ = datetime.now().strftime("%y%m%d_%H%M%S")
-            nexttile=plot.subplot(figsize=[12, 10])
-            plot.heatmap(df_scores, kind="direct",ax=nexttile())
+            nexttile = plot.subplot(figsize=[12, 10])
+            plot.heatmap(df_scores, kind="direct", ax=nexttile())
             plot.figsets(xangle=30)
             if dir_save:
-                ips.figsave(dir_save+f"scores_sorted_heatmap{now_}.pdf")
-            if df_scores.shape[0]>1
-                plot.heatmap(df_scores, kind="direct",cluster=True)
+                ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
+            if df_scores.shape[0] > 1:  # draw cluster
+                plot.heatmap(df_scores, kind="direct", cluster=True)
                 plot.figsets(xangle=30)
                 if dir_save:
-                    ips.figsave(dir_save+f"scores_clus{now_}.pdf")
-    if all([plot_, y_true is not None, purpose==
+                    ips.figsave(dir_save + f"scores_clus{now_}.pdf")
+    if all([plot_, y_true is not None, purpose == "classification"]):
         try:
-            if len(models)>3:
+            if len(models) > 3:
                 plot_validate_features(df_results)
             else:
-                plot_validate_features_single(df_results,figsize=(12,4*len(models)))
+                plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
             if dir_save:
-                ips.figsave(dir_save+f"validate_features{now_}.pdf")
+                ips.figsave(dir_save + f"validate_features{now_}.pdf")
         except Exception as e:
             print(f"Error: 在画图的过程中出现了问题:{e}")
     return df_results
 
 
-def cal_metrics(
+def cal_metrics(
+    y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
+):
     """
     Calculate regression or classification metrics based on the purpose.
-
+
     Parameters:
     - y_true: Array of true values.
     - y_pred: Array of predicted labels for classification or predicted values for regression.
     - y_pred_proba: Array of predicted probabilities for classification (optional).
     - purpose: str, "regression" or "classification".
     - average: str, averaging method for multi-class classification ("binary", "micro", "macro", "weighted", etc.).
-
+
     Returns:
     - validation_scores: dict of computed metrics.
     """
     from sklearn.metrics import (
+        mean_squared_error,
+        mean_absolute_error,
+        mean_absolute_percentage_error,
+        explained_variance_score,
+        r2_score,
+        mean_squared_log_error,
+        accuracy_score,
+        precision_score,
+        recall_score,
+        f1_score,
+        roc_auc_score,
+        matthews_corrcoef,
+        confusion_matrix,
+        balanced_accuracy_score,
+        average_precision_score,
+        precision_recall_curve,
+    )
+
     validation_scores = {}
 
     if purpose == "regression":
         y_true = np.asarray(y_true)
         y_true = y_true.ravel()
-        y_pred = np.asarray(y_pred)
+        y_pred = np.asarray(y_pred)
         y_pred = y_pred.ravel()
         # Regression metrics
         validation_scores = {
@@ -2544,7 +2775,7 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
             "r2": r2_score(y_true, y_pred),
             "mape": mean_absolute_percentage_error(y_true, y_pred),
             "explained_variance": explained_variance_score(y_true, y_pred),
-            "mbd": np.mean(y_pred - y_true)  # Mean Bias Deviation
+            "mbd": np.mean(y_pred - y_true),  # Mean Bias Deviation
         }
         # Check if MSLE can be calculated
         if np.all(y_true >= 0) and np.all(y_pred >= 0):  # Ensure no negative values
@@ -2560,21 +2791,24 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
             "recall": recall_score(y_true, y_pred, average=average),
             "f1": f1_score(y_true, y_pred, average=average),
             "mcc": matthews_corrcoef(y_true, y_pred),
-            "specificity": None,
-            "balanced_accuracy": balanced_accuracy_score(y_true, y_pred)
+            "specificity": None,
+            "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
         }
 
         # Confusion matrix to calculate specificity
         tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
-        validation_scores["specificity"] =
+        validation_scores["specificity"] = (
+            tn / (tn + fp) if (tn + fp) > 0 else 0
+        )  # Specificity calculation
 
-        if y_pred_proba is not None:
+        if y_pred_proba is not None:
             # Calculate ROC-AUC
-            validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
+            validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
             # PR-AUC (Precision-Recall AUC) calculation
             validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
     else:
-        raise ValueError(
+        raise ValueError(
+            "Invalid purpose specified. Choose 'regression' or 'classification'."
+        )
 
     return validation_scores
-