py2ls 0.2.4.14__py3-none-any.whl → 0.2.4.16__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- py2ls/.git/index +0 -0
- py2ls/ips.py +722 -12
- py2ls/ml2ls copy.py +2906 -0
- py2ls/ml2ls.py +898 -243
- py2ls/plot.py +409 -24
- py2ls/translator.py +2 -0
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/RECORD +9 -8
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/WHEEL +0 -0
py2ls/ml2ls.py
CHANGED
@@ -506,7 +506,7 @@ def get_models(
|
|
506
506
|
"Support Vector Machine(svm)",
|
507
507
|
"naive bayes",
|
508
508
|
"Linear Discriminant Analysis (lda)",
|
509
|
-
"
|
509
|
+
"AdaBoost",
|
510
510
|
"DecisionTree",
|
511
511
|
"KNeighbors",
|
512
512
|
"Bagging",
|
@@ -585,7 +585,7 @@ def get_features(
|
|
585
585
|
"Support Vector Machine(svm)",
|
586
586
|
"naive bayes",
|
587
587
|
"Linear Discriminant Analysis (lda)",
|
588
|
-
"
|
588
|
+
"AdaBoost",
|
589
589
|
"DecisionTree",
|
590
590
|
"KNeighbors",
|
591
591
|
"Bagging",
|
@@ -616,10 +616,10 @@ def get_features(
|
|
616
616
|
if isinstance(y, str) and y in X.columns:
|
617
617
|
y_col_name = y
|
618
618
|
y = X[y]
|
619
|
-
y = ips.df_encoder(pd.DataFrame(y), method="
|
619
|
+
y = ips.df_encoder(pd.DataFrame(y), method="label")
|
620
620
|
X = X.drop(y_col_name, axis=1)
|
621
621
|
else:
|
622
|
-
y = ips.df_encoder(pd.DataFrame(y), method="
|
622
|
+
y = ips.df_encoder(pd.DataFrame(y), method="label").values.ravel()
|
623
623
|
y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
|
624
624
|
y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
|
625
625
|
|
@@ -699,9 +699,11 @@ def get_features(
|
|
699
699
|
"Support Vector Machine(svm)",
|
700
700
|
"Naive Bayes",
|
701
701
|
"Linear Discriminant Analysis (lda)",
|
702
|
-
"
|
702
|
+
"AdaBoost",
|
703
703
|
]
|
704
704
|
cls = [ips.strcmp(i, cls_)[0] for i in cls]
|
705
|
+
|
706
|
+
feature_importances = {}
|
705
707
|
|
706
708
|
# Lasso Feature Selection
|
707
709
|
lasso_importances = (
|
@@ -712,6 +714,7 @@ def get_features(
|
|
712
714
|
lasso_selected_features = (
|
713
715
|
lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
|
714
716
|
)
|
717
|
+
feature_importances['lasso']=lasso_importances.head(n_features)
|
715
718
|
# Ridge
|
716
719
|
ridge_importances = (
|
717
720
|
features_ridge(x_train, y_train, ridge_params)
|
@@ -721,6 +724,7 @@ def get_features(
|
|
721
724
|
selected_ridge_features = (
|
722
725
|
ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
|
723
726
|
)
|
727
|
+
feature_importances['ridge']=ridge_importances.head(n_features)
|
724
728
|
# Elastic Net
|
725
729
|
enet_importances = (
|
726
730
|
features_enet(x_train, y_train, enet_params)
|
@@ -730,6 +734,7 @@ def get_features(
|
|
730
734
|
selected_enet_features = (
|
731
735
|
enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
|
732
736
|
)
|
737
|
+
feature_importances['Enet']=enet_importances.head(n_features)
|
733
738
|
# Random Forest Feature Importance
|
734
739
|
rf_importances = (
|
735
740
|
features_rf(x_train, y_train, rf_params)
|
@@ -741,6 +746,7 @@ def get_features(
|
|
741
746
|
if "Random Forest" in cls
|
742
747
|
else []
|
743
748
|
)
|
749
|
+
feature_importances['Random Forest']=rf_importances.head(n_features)
|
744
750
|
# Gradient Boosting Feature Importance
|
745
751
|
gb_importances = (
|
746
752
|
features_gradient_boosting(x_train, y_train, gb_params)
|
@@ -752,6 +758,7 @@ def get_features(
|
|
752
758
|
if "Gradient Boosting" in cls
|
753
759
|
else []
|
754
760
|
)
|
761
|
+
feature_importances['Gradient Boosting']=gb_importances.head(n_features)
|
755
762
|
# xgb
|
756
763
|
xgb_importances = (
|
757
764
|
features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
|
@@ -759,6 +766,7 @@ def get_features(
|
|
759
766
|
top_xgb_features = (
|
760
767
|
xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
|
761
768
|
)
|
769
|
+
feature_importances['xgb']=xgb_importances.head(n_features)
|
762
770
|
|
763
771
|
# SVM with RFE
|
764
772
|
selected_svm_features = (
|
@@ -773,6 +781,7 @@ def get_features(
|
|
773
781
|
selected_lda_features = (
|
774
782
|
lda_importances.head(n_features)["feature"].values if "lda" in cls else []
|
775
783
|
)
|
784
|
+
feature_importances['lda']=lda_importances.head(n_features)
|
776
785
|
# AdaBoost Feature Importance
|
777
786
|
adaboost_importances = (
|
778
787
|
features_adaboost(x_train, y_train, adaboost_params)
|
@@ -784,6 +793,7 @@ def get_features(
|
|
784
793
|
if "AdaBoost" in cls
|
785
794
|
else []
|
786
795
|
)
|
796
|
+
feature_importances['AdaBoost']=adaboost_importances.head(n_features)
|
787
797
|
# Decision Tree Feature Importance
|
788
798
|
dt_importances = (
|
789
799
|
features_decision_tree(x_train, y_train, dt_params)
|
@@ -794,7 +804,8 @@ def get_features(
|
|
794
804
|
dt_importances.head(n_features)["feature"].values
|
795
805
|
if "Decision Tree" in cls
|
796
806
|
else []
|
797
|
-
)
|
807
|
+
)
|
808
|
+
feature_importances['Decision Tree']=dt_importances.head(n_features)
|
798
809
|
# Bagging Feature Importance
|
799
810
|
bagging_importances = (
|
800
811
|
features_bagging(x_train, y_train, bagging_params)
|
@@ -806,6 +817,7 @@ def get_features(
|
|
806
817
|
if "Bagging" in cls
|
807
818
|
else []
|
808
819
|
)
|
820
|
+
feature_importances['Bagging']=bagging_importances.head(n_features)
|
809
821
|
# KNN Feature Importance via Permutation
|
810
822
|
knn_importances = (
|
811
823
|
features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
|
@@ -813,6 +825,7 @@ def get_features(
|
|
813
825
|
top_knn_features = (
|
814
826
|
knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
|
815
827
|
)
|
828
|
+
feature_importances['KNN']=knn_importances.head(n_features)
|
816
829
|
|
817
830
|
#! Find common features
|
818
831
|
common_features = ips.shared(
|
@@ -915,6 +928,7 @@ def get_features(
|
|
915
928
|
"cv_train_scores": cv_train_results_df,
|
916
929
|
"cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
|
917
930
|
"common_features": list(common_features),
|
931
|
+
"feature_importances":feature_importances
|
918
932
|
}
|
919
933
|
if all([plot_, dir_save]):
|
920
934
|
from datetime import datetime
|
@@ -927,6 +941,7 @@ def get_features(
|
|
927
941
|
"cv_train_scores": pd.DataFrame(),
|
928
942
|
"cv_test_scores": pd.DataFrame(),
|
929
943
|
"common_features": [],
|
944
|
+
"feature_importances":{}
|
930
945
|
}
|
931
946
|
print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
|
932
947
|
return results
|
@@ -1217,142 +1232,335 @@ def validate_features(
|
|
1217
1232
|
|
1218
1233
|
# # If you want to access validation scores
|
1219
1234
|
# print(validation_results)
|
1220
|
-
def plot_validate_features(res_val):
|
1235
|
+
def plot_validate_features(res_val,is_binary=True,figsize=None):
|
1221
1236
|
"""
|
1222
1237
|
plot the results of 'validate_features()'
|
1223
1238
|
"""
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1239
|
+
if is_binary:
|
1240
|
+
colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
|
1241
|
+
if res_val.shape[0] > 5:
|
1242
|
+
alpha = 0
|
1243
|
+
figsize = [8, 10] if figsize is None else figsize
|
1244
|
+
subplot_layout = [1, 2]
|
1245
|
+
ncols = 2
|
1246
|
+
bbox_to_anchor = [1.5, 0.6]
|
1247
|
+
else:
|
1248
|
+
alpha = 0.03
|
1249
|
+
figsize = [10, 6] if figsize is None else figsize
|
1250
|
+
subplot_layout = [1, 1]
|
1251
|
+
ncols = 1
|
1252
|
+
bbox_to_anchor = [1, 1]
|
1253
|
+
nexttile = plot.subplot(figsize=figsize)
|
1254
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1255
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1256
|
+
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1257
|
+
tpr = res_val["roc_curve"][model_name]["tpr"]
|
1258
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
|
1259
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"]
|
1260
|
+
plot_roc_curve(
|
1261
|
+
fpr,
|
1262
|
+
tpr,
|
1263
|
+
mean_auc,
|
1264
|
+
lower_ci,
|
1265
|
+
upper_ci,
|
1266
|
+
model_name=model_name,
|
1267
|
+
lw=1.5,
|
1268
|
+
color=colors[i],
|
1269
|
+
alpha=alpha,
|
1270
|
+
ax=ax,
|
1271
|
+
)
|
1272
|
+
plot.figsets(
|
1273
|
+
sp=2,
|
1274
|
+
legend=dict(
|
1275
|
+
loc="upper right",
|
1276
|
+
ncols=ncols,
|
1277
|
+
fontsize=8,
|
1278
|
+
bbox_to_anchor=[1.5, 0.6],
|
1279
|
+
markerscale=0.8,
|
1280
|
+
),
|
1255
1281
|
)
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1274
|
-
model_name=model_name,
|
1275
|
-
color=colors[i],
|
1276
|
-
lw=1.5,
|
1277
|
-
alpha=alpha,
|
1278
|
-
ax=ax,
|
1282
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1283
|
+
|
1284
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1285
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1286
|
+
plot_pr_curve(
|
1287
|
+
recall=res_val["pr_curve"][model_name]["recall"],
|
1288
|
+
precision=res_val["pr_curve"][model_name]["precision"],
|
1289
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1290
|
+
model_name=model_name,
|
1291
|
+
color=colors[i],
|
1292
|
+
lw=1.5,
|
1293
|
+
alpha=alpha,
|
1294
|
+
ax=ax,
|
1295
|
+
)
|
1296
|
+
plot.figsets(
|
1297
|
+
sp=2,
|
1298
|
+
legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
|
1279
1299
|
)
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1300
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1301
|
+
else:
|
1302
|
+
colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
|
1303
|
+
modname_tmp=ips.flatten(res_val["roc_curve"].index)[0]
|
1304
|
+
classes=list(res_val["roc_curve"][modname_tmp]['fpr'].keys())
|
1305
|
+
if res_val.shape[0] > 5:
|
1306
|
+
alpha = 0
|
1307
|
+
figsize = [8, 8*2*(len(classes))] if figsize is None else figsize
|
1308
|
+
subplot_layout = [1, 2]
|
1309
|
+
ncols = 2
|
1310
|
+
bbox_to_anchor = [1.5, 0.6]
|
1311
|
+
else:
|
1312
|
+
alpha = 0.03
|
1313
|
+
figsize = [10, 6*(len(classes))] if figsize is None else figsize
|
1314
|
+
subplot_layout = [1, 1]
|
1315
|
+
ncols = 1
|
1316
|
+
bbox_to_anchor = [1, 1]
|
1317
|
+
nexttile = plot.subplot(2*(len(classes)),2,figsize=figsize)
|
1318
|
+
for iclass, class_ in enumerate(classes):
|
1319
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1320
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1321
|
+
fpr = res_val["roc_curve"][model_name]["fpr"][class_]
|
1322
|
+
tpr = res_val["roc_curve"][model_name]["tpr"][class_]
|
1323
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
|
1324
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
|
1325
|
+
plot_roc_curve(
|
1326
|
+
fpr,
|
1327
|
+
tpr,
|
1328
|
+
mean_auc,
|
1329
|
+
lower_ci,
|
1330
|
+
upper_ci,
|
1331
|
+
model_name=model_name,
|
1332
|
+
lw=1.5,
|
1333
|
+
color=colors[i],
|
1334
|
+
alpha=alpha,
|
1335
|
+
ax=ax,
|
1336
|
+
)
|
1337
|
+
plot.figsets(
|
1338
|
+
sp=2,
|
1339
|
+
title=class_,
|
1340
|
+
legend=dict(
|
1341
|
+
loc="upper right",
|
1342
|
+
ncols=ncols,
|
1343
|
+
fontsize=8,
|
1344
|
+
bbox_to_anchor=[1.5, 0.6],
|
1345
|
+
markerscale=0.8,
|
1346
|
+
),
|
1347
|
+
)
|
1348
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1349
|
+
|
1350
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1351
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1352
|
+
plot_pr_curve(
|
1353
|
+
recall=res_val["pr_curve"][model_name]["recall"][iclass],
|
1354
|
+
precision=res_val["pr_curve"][model_name]["precision"][iclass],
|
1355
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
|
1356
|
+
model_name=model_name,
|
1357
|
+
color=colors[i],
|
1358
|
+
lw=1.5,
|
1359
|
+
alpha=alpha,
|
1360
|
+
ax=ax,
|
1361
|
+
)
|
1362
|
+
plot.figsets(
|
1363
|
+
sp=2,
|
1364
|
+
title=class_,
|
1365
|
+
legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
|
1366
|
+
)
|
1285
1367
|
|
1368
|
+
def plot_validate_features_single(res_val, figsize=None,is_binary=True):
|
1369
|
+
if is_binary:
|
1370
|
+
if figsize is None:
|
1371
|
+
nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3,figsize=[13,4*len(ips.flatten(res_val["pr_curve"].index))])
|
1372
|
+
else:
|
1373
|
+
nexttile = plot.subplot(
|
1374
|
+
len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
|
1375
|
+
)
|
1376
|
+
for model_name in ips.flatten(res_val["pr_curve"].index):
|
1377
|
+
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1378
|
+
tpr = res_val["roc_curve"][model_name]["tpr"]
|
1379
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
|
1380
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"]
|
1381
|
+
|
1382
|
+
# Plotting
|
1383
|
+
plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci,
|
1384
|
+
model_name=model_name, ax=nexttile())
|
1385
|
+
plot.figsets(title=model_name, sp=2)
|
1386
|
+
|
1387
|
+
plot_pr_binary(
|
1388
|
+
recall=res_val["pr_curve"][model_name]["recall"],
|
1389
|
+
precision=res_val["pr_curve"][model_name]["precision"],
|
1390
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1391
|
+
model_name=model_name,
|
1392
|
+
ax=nexttile(),
|
1393
|
+
)
|
1394
|
+
plot.figsets(title=model_name, sp=2)
|
1286
1395
|
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1396
|
+
# plot cm
|
1397
|
+
plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
|
1398
|
+
plot.figsets(title=model_name, sp=2)
|
1290
1399
|
else:
|
1291
|
-
|
1292
|
-
|
1293
|
-
)
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1400
|
+
|
1401
|
+
modname_tmp=ips.flatten(res_val["roc_curve"].index)[0]
|
1402
|
+
classes=list(res_val["roc_curve"][modname_tmp]['fpr'].keys())
|
1403
|
+
if figsize is None:
|
1404
|
+
nexttile = plot.subplot(len(modname_tmp), 3,figsize=[15,len(modname_tmp)*5])
|
1405
|
+
else:
|
1406
|
+
nexttile = plot.subplot(len(modname_tmp), 3, figsize=figsize)
|
1407
|
+
colors = plot.get_color(len(classes))
|
1408
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1409
|
+
ax = nexttile()
|
1410
|
+
for iclass, class_ in enumerate(classes):
|
1411
|
+
fpr = res_val["roc_curve"][model_name]["fpr"][class_]
|
1412
|
+
tpr = res_val["roc_curve"][model_name]["tpr"][class_]
|
1413
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
|
1414
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
|
1415
|
+
plot_roc_curve(
|
1416
|
+
fpr,
|
1417
|
+
tpr,
|
1418
|
+
mean_auc,
|
1419
|
+
lower_ci,
|
1420
|
+
upper_ci,
|
1421
|
+
model_name=class_,
|
1422
|
+
lw=1.5,
|
1423
|
+
color=colors[iclass],
|
1424
|
+
alpha=0.03,
|
1425
|
+
ax=ax,
|
1426
|
+
)
|
1427
|
+
plot.figsets(
|
1428
|
+
sp=2,
|
1429
|
+
title=model_name,
|
1430
|
+
legend=dict(
|
1431
|
+
loc="best",
|
1432
|
+
fontsize=8,
|
1433
|
+
),
|
1434
|
+
)
|
1435
|
+
|
1436
|
+
ax = nexttile()
|
1437
|
+
for iclass, class_ in enumerate(classes):
|
1438
|
+
plot_pr_curve(
|
1439
|
+
recall=res_val["pr_curve"][model_name]["recall"][iclass],
|
1440
|
+
precision=res_val["pr_curve"][model_name]["precision"][iclass],
|
1441
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
|
1442
|
+
model_name=class_,
|
1443
|
+
color=colors[iclass],
|
1444
|
+
lw=1.5,
|
1445
|
+
alpha=0.03,
|
1446
|
+
ax=ax,
|
1447
|
+
)
|
1448
|
+
plot.figsets(
|
1449
|
+
sp=2,
|
1450
|
+
title=class_,
|
1451
|
+
legend=dict(loc="best", fontsize=8),
|
1452
|
+
)
|
1453
|
+
|
1454
|
+
plot_cm(res_val["confusion_matrix"][model_name],labels_name=classes, ax=nexttile(), normalize=False)
|
1455
|
+
plot.figsets(title=model_name, sp=2)
|
1313
1456
|
|
1314
|
-
# plot cm
|
1315
|
-
plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
|
1316
|
-
plot.figsets(title=model_name, sp=2)
|
1317
1457
|
|
1458
|
+
def cal_precision_recall(
|
1459
|
+
y_true, y_pred_proba, is_binary=True):
|
1460
|
+
if is_binary:
|
1461
|
+
precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
|
1462
|
+
avg_precision_ = average_precision_score(y_true, y_pred_proba)
|
1463
|
+
return precision_, recall_,avg_precision_
|
1464
|
+
else:
|
1465
|
+
n_classes = y_pred_proba.shape[1] # Number of classes
|
1466
|
+
precision_ = []
|
1467
|
+
recall_ = []
|
1468
|
+
|
1469
|
+
# One-vs-rest approach for multi-class precision-recall curve
|
1470
|
+
for class_idx in range(n_classes):
|
1471
|
+
precision, recall, _ = precision_recall_curve(
|
1472
|
+
(y_true == class_idx).astype(int), # Binarize true labels for the current class
|
1473
|
+
y_pred_proba[:, class_idx], # Probabilities for the current class
|
1474
|
+
)
|
1318
1475
|
|
1476
|
+
precision_.append(precision)
|
1477
|
+
recall_.append(recall)
|
1478
|
+
# Optionally, you can compute average precision for each class
|
1479
|
+
avg_precision_ = []
|
1480
|
+
for class_idx in range(n_classes):
|
1481
|
+
avg_precision = average_precision_score(
|
1482
|
+
(y_true == class_idx).astype(int), # Binarize true labels for the current class
|
1483
|
+
y_pred_proba[:, class_idx], # Probabilities for the current class
|
1484
|
+
)
|
1485
|
+
avg_precision_.append(avg_precision)
|
1486
|
+
return precision_, recall_,avg_precision_
|
1487
|
+
|
1319
1488
|
def cal_auc_ci(
|
1320
|
-
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
|
1489
|
+
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,is_binary=True, verbose=True
|
1321
1490
|
):
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1491
|
+
if is_binary:
|
1492
|
+
y_true = np.asarray(y_true)
|
1493
|
+
y_pred = np.asarray(y_pred)
|
1494
|
+
bootstrapped_scores = []
|
1495
|
+
if verbose:
|
1496
|
+
print("auroc score:", roc_auc_score(y_true, y_pred))
|
1497
|
+
rng = np.random.RandomState(random_state)
|
1498
|
+
for i in range(n_bootstraps):
|
1499
|
+
# bootstrap by sampling with replacement on the prediction indices
|
1500
|
+
indices = rng.randint(0, len(y_pred), len(y_pred))
|
1501
|
+
if len(np.unique(y_true[indices])) < 2:
|
1502
|
+
# We need at least one positive and one negative sample for ROC AUC
|
1503
|
+
# to be defined: reject the sample
|
1504
|
+
continue
|
1505
|
+
if isinstance(y_true, np.ndarray):
|
1506
|
+
score = roc_auc_score(y_true[indices], y_pred[indices])
|
1507
|
+
else:
|
1508
|
+
score = roc_auc_score(y_true.iloc[indices], y_pred.iloc[indices])
|
1509
|
+
bootstrapped_scores.append(score)
|
1510
|
+
# print("Bootstrap #{} ROC area: {:0.3f}".format(i + 1, score))
|
1511
|
+
sorted_scores = np.array(bootstrapped_scores)
|
1512
|
+
sorted_scores.sort()
|
1513
|
+
|
1514
|
+
# Computing the lower and upper bound of the 90% confidence interval
|
1515
|
+
# You can change the bounds percentiles to 0.025 and 0.975 to get
|
1516
|
+
# a 95% confidence interval instead.
|
1517
|
+
confidence_lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
|
1518
|
+
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1519
|
+
if verbose:
|
1520
|
+
print(
|
1521
|
+
"Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
|
1522
|
+
confidence_lower, confidence_upper
|
1523
|
+
)
|
1353
1524
|
)
|
1354
|
-
|
1355
|
-
|
1525
|
+
return confidence_lower, confidence_upper
|
1526
|
+
else:
|
1527
|
+
from sklearn.preprocessing import label_binarize
|
1528
|
+
# Multi-class classification case
|
1529
|
+
y_true = np.asarray(y_true)
|
1530
|
+
y_pred = np.asarray(y_pred)
|
1531
|
+
|
1532
|
+
# Binarize the multi-class labels for OvR computation
|
1533
|
+
y_true_bin = label_binarize(y_true, classes=np.unique(y_true)) # One-vs-Rest transformation
|
1534
|
+
n_classes = y_true_bin.shape[1] # Number of classes
|
1535
|
+
|
1536
|
+
bootstrapped_scores = np.zeros((n_classes, n_bootstraps)) # Store scores for each class
|
1537
|
+
|
1538
|
+
if verbose:
|
1539
|
+
print("AUROC scores for each class:")
|
1540
|
+
for i in range(n_classes):
|
1541
|
+
print(f"Class {i}: {roc_auc_score(y_true_bin[:, i], y_pred[:, i])}")
|
1542
|
+
|
1543
|
+
rng = np.random.RandomState(random_state)
|
1544
|
+
for i in range(n_bootstraps):
|
1545
|
+
indices = rng.randint(0, len(y_pred), len(y_pred))
|
1546
|
+
for class_idx in range(n_classes):
|
1547
|
+
if len(np.unique(y_true_bin[indices, class_idx])) < 2:
|
1548
|
+
continue # Reject if the class doesn't have both positive and negative samples
|
1549
|
+
score = roc_auc_score(y_true_bin[indices, class_idx], y_pred[indices, class_idx])
|
1550
|
+
bootstrapped_scores[class_idx, i] = score
|
1551
|
+
|
1552
|
+
# Calculating the confidence intervals for each class
|
1553
|
+
confidence_intervals = []
|
1554
|
+
for class_idx in range(n_classes):
|
1555
|
+
sorted_scores = np.sort(bootstrapped_scores[class_idx])
|
1556
|
+
confidence_lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
|
1557
|
+
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1558
|
+
confidence_intervals.append((confidence_lower, confidence_upper))
|
1559
|
+
|
1560
|
+
if verbose:
|
1561
|
+
print(f"Class {class_idx} - Confidence interval: [{confidence_lower:.3f} - {confidence_upper:.3f}]")
|
1562
|
+
|
1563
|
+
return confidence_intervals
|
1356
1564
|
|
1357
1565
|
|
1358
1566
|
def plot_roc_curve(
|
@@ -1517,7 +1725,7 @@ def plot_pr_binary(
|
|
1517
1725
|
|
1518
1726
|
pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
|
1519
1727
|
for f_score in f_scores:
|
1520
|
-
x_vals = np.linspace(0.01, 1,
|
1728
|
+
x_vals = np.linspace(0.01, 1, 20000)
|
1521
1729
|
y_vals = f_score * x_vals / (2 * x_vals - f_score)
|
1522
1730
|
y_vals_clipped = np.minimum(y_vals, pr_boundary(x_vals))
|
1523
1731
|
y_vals_clipped = np.clip(y_vals_clipped, 1e-3, None) # Prevent going to zero
|
@@ -1553,7 +1761,7 @@ def plot_pr_binary(
|
|
1553
1761
|
def plot_cm(
|
1554
1762
|
cm,
|
1555
1763
|
labels_name=None,
|
1556
|
-
thresh=0.8,
|
1764
|
+
thresh=0.8, # for set color
|
1557
1765
|
axis_labels=None,
|
1558
1766
|
cmap="Reds",
|
1559
1767
|
normalize=True,
|
@@ -2029,11 +2237,21 @@ def predict(
|
|
2029
2237
|
if isinstance(y_train, str) and y_train in x_train.columns:
|
2030
2238
|
y_train_col_name = y_train
|
2031
2239
|
y_train = x_train[y_train]
|
2032
|
-
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2240
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2033
2241
|
x_train = x_train.drop(y_train_col_name, axis=1)
|
2242
|
+
# else:
|
2243
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
|
2244
|
+
y_train=pd.DataFrame(y_train)
|
2245
|
+
if y_train.select_dtypes(include=np.number).empty:
|
2246
|
+
y_train_=ips.df_encoder(y_train, method="dummy",drop=None)
|
2247
|
+
is_binary = False if y_train_.shape[1] >2 else True
|
2034
2248
|
else:
|
2035
|
-
|
2249
|
+
y_train_=ips.flatten(y_train.values)
|
2250
|
+
is_binary = False if len(y_train_)>2 else True
|
2036
2251
|
|
2252
|
+
if is_binary:
|
2253
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="label")
|
2254
|
+
print('is_binary:',is_binary)
|
2037
2255
|
if x_true is None:
|
2038
2256
|
x_train, x_true, y_train, y_true = train_test_split(
|
2039
2257
|
x_train,
|
@@ -2042,23 +2260,27 @@ def predict(
|
|
2042
2260
|
random_state=random_state,
|
2043
2261
|
stratify=y_train if purpose == "classification" else None,
|
2044
2262
|
)
|
2263
|
+
|
2045
2264
|
if isinstance(y_train, str) and y_train in x_train.columns:
|
2046
2265
|
y_train_col_name = y_train
|
2047
2266
|
y_train = x_train[y_train]
|
2048
|
-
y_train = ips.df_encoder(pd.DataFrame(y_train), method="
|
2267
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="label") if is_binary else y_train
|
2049
2268
|
x_train = x_train.drop(y_train_col_name, axis=1)
|
2050
|
-
|
2269
|
+
if is_binary:
|
2051
2270
|
y_train = ips.df_encoder(
|
2052
|
-
pd.DataFrame(y_train), method="
|
2053
|
-
).values.ravel()
|
2271
|
+
pd.DataFrame(y_train), method="label"
|
2272
|
+
).values.ravel()
|
2273
|
+
|
2054
2274
|
if y_true is not None:
|
2055
2275
|
if isinstance(y_true, str) and y_true in x_true.columns:
|
2056
2276
|
y_true_col_name = y_true
|
2057
2277
|
y_true = x_true[y_true]
|
2058
|
-
y_true = ips.df_encoder(pd.DataFrame(y_true), method="
|
2278
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="label") if is_binary else y_true
|
2279
|
+
y_true = pd.DataFrame(y_true)
|
2059
2280
|
x_true = x_true.drop(y_true_col_name, axis=1)
|
2060
|
-
|
2061
|
-
y_true = ips.df_encoder(pd.DataFrame(y_true), method="
|
2281
|
+
if is_binary:
|
2282
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="label").values.ravel()
|
2283
|
+
y_true = pd.DataFrame(y_true)
|
2062
2284
|
|
2063
2285
|
# to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
|
2064
2286
|
|
@@ -2068,7 +2290,6 @@ def predict(
|
|
2068
2290
|
y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
|
2069
2291
|
)
|
2070
2292
|
y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
|
2071
|
-
|
2072
2293
|
# Ensure common features are selected
|
2073
2294
|
if common_features is not None:
|
2074
2295
|
x_train, x_true = x_train[common_features], x_true[common_features]
|
@@ -2077,10 +2298,7 @@ def predict(
|
|
2077
2298
|
x_train, x_true = x_train[share_col_names], x_true[share_col_names]
|
2078
2299
|
|
2079
2300
|
x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
|
2080
|
-
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
|
2081
|
-
x_true, method="dummy"
|
2082
|
-
)
|
2083
|
-
|
2301
|
+
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
|
2084
2302
|
# Handle class imbalance using SMOTE (only for classification)
|
2085
2303
|
if (
|
2086
2304
|
smote
|
@@ -2091,7 +2309,13 @@ def predict(
|
|
2091
2309
|
|
2092
2310
|
smote_sampler = SMOTE(random_state=random_state)
|
2093
2311
|
x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
|
2094
|
-
|
2312
|
+
if not is_binary:
|
2313
|
+
if isinstance(y_train, np.ndarray):
|
2314
|
+
y_train = ips.df_encoder(data=pd.DataFrame(y_train),method='label')
|
2315
|
+
y_train=np.asarray(y_train)
|
2316
|
+
if isinstance(y_train, np.ndarray):
|
2317
|
+
y_true = ips.df_encoder(data=pd.DataFrame(y_true),method='label')
|
2318
|
+
y_true=np.asarray(y_true)
|
2095
2319
|
# Hyperparameter grids for tuning
|
2096
2320
|
if cv_level in ["low", "simple", "s", "l"]:
|
2097
2321
|
param_grids = {
|
@@ -2670,95 +2894,181 @@ def predict(
|
|
2670
2894
|
print(f"\nTraining and validating {name}:")
|
2671
2895
|
|
2672
2896
|
# Grid search with KFold or StratifiedKFold
|
2673
|
-
|
2674
|
-
|
2675
|
-
|
2676
|
-
|
2677
|
-
|
2678
|
-
|
2679
|
-
|
2680
|
-
|
2681
|
-
|
2682
|
-
|
2683
|
-
gs.fit(x_train, y_train)
|
2684
|
-
best_clf = gs.best_estimator_
|
2685
|
-
# make sure x_train and x_test has the same name
|
2686
|
-
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2687
|
-
y_pred = best_clf.predict(x_true)
|
2688
|
-
|
2689
|
-
# y_pred_proba
|
2690
|
-
if hasattr(best_clf, "predict_proba"):
|
2691
|
-
y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
|
2692
|
-
elif hasattr(best_clf, "decision_function"):
|
2693
|
-
# If predict_proba is not available, use decision_function (e.g., for SVM)
|
2694
|
-
y_pred_proba = best_clf.decision_function(x_true)
|
2695
|
-
# Ensure y_pred_proba is within 0 and 1 bounds
|
2696
|
-
y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
|
2697
|
-
y_pred_proba.max() - y_pred_proba.min()
|
2897
|
+
if is_binary:
|
2898
|
+
gs = GridSearchCV(
|
2899
|
+
clf,
|
2900
|
+
param_grid=param_grids.get(name, {}),
|
2901
|
+
scoring=(
|
2902
|
+
"roc_auc" if purpose == "classification" else "neg_mean_squared_error"
|
2903
|
+
),
|
2904
|
+
cv=cv,
|
2905
|
+
n_jobs=n_jobs,
|
2906
|
+
verbose=verbose,
|
2698
2907
|
)
|
2908
|
+
|
2909
|
+
gs.fit(x_train, y_train)
|
2910
|
+
best_clf = gs.best_estimator_
|
2911
|
+
# make sure x_train and x_test has the same name
|
2912
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2913
|
+
y_pred = best_clf.predict(x_true)
|
2914
|
+
if hasattr(best_clf, "predict_proba"):
|
2915
|
+
y_pred_proba = best_clf.predict_proba(x_true)
|
2916
|
+
print("Shape of predicted probabilities:", y_pred_proba.shape)
|
2917
|
+
if y_pred_proba.shape[1] == 1:
|
2918
|
+
y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba]) # Add missing class probabilities
|
2919
|
+
y_pred_proba = y_pred_proba[:, 1]
|
2920
|
+
elif hasattr(best_clf, "decision_function"):
|
2921
|
+
# If predict_proba is not available, use decision_function (e.g., for SVM)
|
2922
|
+
y_pred_proba = best_clf.decision_function(x_true)
|
2923
|
+
# Ensure y_pred_proba is within 0 and 1 bounds
|
2924
|
+
y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
|
2925
|
+
y_pred_proba.max() - y_pred_proba.min()
|
2926
|
+
)
|
2927
|
+
else:
|
2928
|
+
y_pred_proba = None # No probability output for certain models
|
2699
2929
|
else:
|
2700
|
-
|
2930
|
+
gs = GridSearchCV(
|
2931
|
+
clf,
|
2932
|
+
param_grid=param_grids.get(name, {}),
|
2933
|
+
scoring=(
|
2934
|
+
"roc_auc_ovr" if purpose == "classification" else "neg_mean_squared_error"
|
2935
|
+
),
|
2936
|
+
cv=cv,
|
2937
|
+
n_jobs=n_jobs,
|
2938
|
+
verbose=verbose,
|
2939
|
+
)
|
2701
2940
|
|
2941
|
+
# Fit GridSearchCV
|
2942
|
+
gs.fit(x_train, y_train)
|
2943
|
+
best_clf = gs.best_estimator_
|
2944
|
+
|
2945
|
+
# Ensure x_true aligns with x_train columns
|
2946
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2947
|
+
y_pred = best_clf.predict(x_true)
|
2948
|
+
|
2949
|
+
# Handle prediction probabilities for multiclass
|
2950
|
+
if hasattr(best_clf, "predict_proba"):
|
2951
|
+
y_pred_proba = best_clf.predict_proba(x_true)
|
2952
|
+
elif hasattr(best_clf, "decision_function"):
|
2953
|
+
y_pred_proba = best_clf.decision_function(x_true)
|
2954
|
+
|
2955
|
+
# Normalize for multiclass if necessary
|
2956
|
+
if y_pred_proba.ndim == 2:
|
2957
|
+
y_pred_proba = (y_pred_proba - y_pred_proba.min(axis=1, keepdims=True)) / \
|
2958
|
+
(y_pred_proba.max(axis=1, keepdims=True) - y_pred_proba.min(axis=1, keepdims=True))
|
2959
|
+
else:
|
2960
|
+
y_pred_proba = None # No probability output for certain models
|
2961
|
+
|
2702
2962
|
validation_scores = {}
|
2703
|
-
|
2963
|
+
|
2964
|
+
if y_true is not None and y_pred_proba is not None:
|
2704
2965
|
validation_scores = cal_metrics(
|
2705
2966
|
y_true,
|
2706
2967
|
y_pred,
|
2707
2968
|
y_pred_proba=y_pred_proba,
|
2969
|
+
is_binary=is_binary,
|
2708
2970
|
purpose=purpose,
|
2709
2971
|
average="weighted",
|
2710
2972
|
)
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2715
|
-
|
2716
|
-
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
2720
|
-
|
2721
|
-
|
2722
|
-
|
2723
|
-
|
2724
|
-
|
2725
|
-
|
2726
|
-
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2736
|
-
|
2737
|
-
|
2738
|
-
|
2739
|
-
|
2740
|
-
|
2741
|
-
|
2742
|
-
|
2743
|
-
|
2744
|
-
|
2745
|
-
|
2746
|
-
|
2747
|
-
|
2748
|
-
|
2749
|
-
|
2750
|
-
|
2751
|
-
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2973
|
+
if is_binary:
|
2974
|
+
# Calculate ROC curve
|
2975
|
+
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
|
2976
|
+
if y_pred_proba is not None:
|
2977
|
+
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
2978
|
+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
2979
|
+
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
|
2980
|
+
roc_auc = auc(fpr, tpr)
|
2981
|
+
roc_info = {
|
2982
|
+
"fpr": fpr.tolist(),
|
2983
|
+
"tpr": tpr.tolist(),
|
2984
|
+
"auc": roc_auc,
|
2985
|
+
"ci95": (lower_ci, upper_ci),
|
2986
|
+
}
|
2987
|
+
# precision-recall curve
|
2988
|
+
precision_, recall_, _ = cal_precision_recall(y_true, y_pred_proba)
|
2989
|
+
avg_precision_ = average_precision_score(y_true, y_pred_proba)
|
2990
|
+
pr_info = {
|
2991
|
+
"precision": precision_,
|
2992
|
+
"recall": recall_,
|
2993
|
+
"avg_precision": avg_precision_,
|
2994
|
+
}
|
2995
|
+
else:
|
2996
|
+
roc_info, pr_info = None, None
|
2997
|
+
if purpose == "classification":
|
2998
|
+
results[name] = {
|
2999
|
+
"best_clf": gs.best_estimator_,
|
3000
|
+
"best_params": gs.best_params_,
|
3001
|
+
"auc_indiv": [
|
3002
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
3003
|
+
for i in range(cv_folds)
|
3004
|
+
],
|
3005
|
+
"scores": validation_scores,
|
3006
|
+
"roc_curve": roc_info,
|
3007
|
+
"pr_curve": pr_info,
|
3008
|
+
"confusion_matrix": confusion_matrix(y_true, y_pred),
|
3009
|
+
"predictions": y_pred.tolist(),
|
3010
|
+
"predictions_proba": (
|
3011
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3012
|
+
),
|
3013
|
+
}
|
3014
|
+
else: # "regression"
|
3015
|
+
results[name] = {
|
3016
|
+
"best_clf": gs.best_estimator_,
|
3017
|
+
"best_params": gs.best_params_,
|
3018
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
3019
|
+
"predictions": y_pred.tolist(),
|
3020
|
+
"predictions_proba": (
|
3021
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3022
|
+
),
|
3023
|
+
}
|
3024
|
+
else: # multi-classes
|
3025
|
+
if y_pred_proba is not None:
|
3026
|
+
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
3027
|
+
# fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
3028
|
+
confidence_intervals = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
|
3029
|
+
roc_info = {
|
3030
|
+
"fpr": validation_scores["fpr"],
|
3031
|
+
"tpr": validation_scores["tpr"],
|
3032
|
+
"auc": validation_scores["roc_auc_by_class"],
|
3033
|
+
"ci95": confidence_intervals,
|
3034
|
+
}
|
3035
|
+
# precision-recall curve
|
3036
|
+
precision_, recall_, avg_precision_ = cal_precision_recall(y_true, y_pred_proba,is_binary=is_binary)
|
3037
|
+
pr_info = {
|
3038
|
+
"precision": precision_,
|
3039
|
+
"recall": recall_,
|
3040
|
+
"avg_precision": avg_precision_,
|
3041
|
+
}
|
3042
|
+
else:
|
3043
|
+
roc_info, pr_info = None, None
|
3044
|
+
|
3045
|
+
if purpose == "classification":
|
3046
|
+
results[name] = {
|
3047
|
+
"best_clf": gs.best_estimator_,
|
3048
|
+
"best_params": gs.best_params_,
|
3049
|
+
"auc_indiv": [
|
3050
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
3051
|
+
for i in range(cv_folds)
|
3052
|
+
],
|
3053
|
+
"scores": validation_scores,
|
3054
|
+
"roc_curve": roc_info,
|
3055
|
+
"pr_curve": pr_info,
|
3056
|
+
"confusion_matrix": confusion_matrix(y_true, y_pred),
|
3057
|
+
"predictions": y_pred.tolist(),
|
3058
|
+
"predictions_proba": (
|
3059
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3060
|
+
),
|
3061
|
+
}
|
3062
|
+
else: # "regression"
|
3063
|
+
results[name] = {
|
3064
|
+
"best_clf": gs.best_estimator_,
|
3065
|
+
"best_params": gs.best_params_,
|
3066
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
3067
|
+
"predictions": y_pred.tolist(),
|
3068
|
+
"predictions_proba": (
|
3069
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3070
|
+
),
|
3071
|
+
}
|
2762
3072
|
|
2763
3073
|
else:
|
2764
3074
|
results[name] = {
|
@@ -2773,7 +3083,6 @@ def predict(
|
|
2773
3083
|
|
2774
3084
|
# Convert results to DataFrame
|
2775
3085
|
df_results = pd.DataFrame.from_dict(results, orient="index")
|
2776
|
-
|
2777
3086
|
# sort
|
2778
3087
|
if y_true is not None and purpose == "classification":
|
2779
3088
|
df_scores = pd.DataFrame(
|
@@ -2790,26 +3099,29 @@ def predict(
|
|
2790
3099
|
plot.figsets(xangle=30)
|
2791
3100
|
if dir_save:
|
2792
3101
|
ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
|
3102
|
+
|
3103
|
+
df_scores=df_scores.select_dtypes(include=np.number)
|
3104
|
+
|
2793
3105
|
if df_scores.shape[0] > 1: # draw cluster
|
2794
3106
|
plot.heatmap(df_scores, kind="direct", cluster=True)
|
2795
3107
|
plot.figsets(xangle=30)
|
2796
3108
|
if dir_save:
|
2797
3109
|
ips.figsave(dir_save + f"scores_clus{now_}.pdf")
|
2798
3110
|
if all([plot_, y_true is not None, purpose == "classification"]):
|
2799
|
-
try:
|
2800
|
-
|
2801
|
-
|
2802
|
-
|
2803
|
-
|
2804
|
-
|
2805
|
-
|
2806
|
-
except Exception as e:
|
2807
|
-
|
3111
|
+
# try:
|
3112
|
+
if len(models) > 3:
|
3113
|
+
plot_validate_features(df_results,is_binary=is_binary)
|
3114
|
+
else:
|
3115
|
+
plot_validate_features_single(df_results, is_binary=is_binary)
|
3116
|
+
if dir_save:
|
3117
|
+
ips.figsave(dir_save + f"validate_features{now_}.pdf")
|
3118
|
+
# except Exception as e:
|
3119
|
+
# print(f"Error: 在画图的过程中出现了问题:{e}")
|
2808
3120
|
return df_results
|
2809
3121
|
|
2810
3122
|
|
2811
3123
|
def cal_metrics(
|
2812
|
-
y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
|
3124
|
+
y_true, y_pred, y_pred_proba=None, is_binary=True,purpose="regression", average="weighted"
|
2813
3125
|
):
|
2814
3126
|
"""
|
2815
3127
|
Calculate regression or classification metrics based on the purpose.
|
@@ -2879,19 +3191,362 @@ def cal_metrics(
|
|
2879
3191
|
}
|
2880
3192
|
|
2881
3193
|
# Confusion matrix to calculate specificity
|
2882
|
-
|
2883
|
-
|
2884
|
-
|
2885
|
-
|
3194
|
+
if is_binary:
|
3195
|
+
cm = confusion_matrix(y_true, y_pred)
|
3196
|
+
if cm.size == 4:
|
3197
|
+
tn, fp, fn, tp = cm.ravel()
|
3198
|
+
else:
|
3199
|
+
# Handle single-class predictions
|
3200
|
+
tn, fp, fn, tp = 0, 0, 0, 0
|
3201
|
+
print("Warning: Only one class found in y_pred or y_true.")
|
3202
|
+
|
3203
|
+
# Specificity calculation
|
3204
|
+
validation_scores["specificity"] = (
|
3205
|
+
tn / (tn + fp) if (tn + fp) > 0 else 0
|
3206
|
+
)
|
3207
|
+
if y_pred_proba is not None:
|
3208
|
+
# Calculate ROC-AUC
|
3209
|
+
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
3210
|
+
# PR-AUC (Precision-Recall AUC) calculation
|
3211
|
+
validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
|
3212
|
+
|
3213
|
+
else: # multi-class
|
3214
|
+
from sklearn.preprocessing import label_binarize
|
3215
|
+
#* Multi-class ROC calculation
|
3216
|
+
y_pred_proba = np.asarray(y_pred_proba)
|
3217
|
+
classes = np.unique(y_true)
|
3218
|
+
y_true_bin = label_binarize(y_true, classes=classes)
|
3219
|
+
if isinstance(y_true, np.ndarray):
|
3220
|
+
y_true = ips.df_encoder(data=pd.DataFrame(y_true), method='dum',prefix='Label')
|
3221
|
+
# Initialize dictionaries to store FPR, TPR, and AUC for each class
|
3222
|
+
fpr = dict()
|
3223
|
+
tpr = dict()
|
3224
|
+
roc_auc = dict()
|
3225
|
+
for i, class_label in enumerate(classes):
|
3226
|
+
fpr[class_label], tpr[class_label], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
|
3227
|
+
roc_auc[class_label] = auc(fpr[class_label], tpr[class_label])
|
3228
|
+
|
3229
|
+
# Store the mean ROC AUC
|
3230
|
+
try:
|
3231
|
+
validation_scores["roc_auc"] = roc_auc_score(
|
3232
|
+
y_true, y_pred_proba, multi_class="ovr", average=average
|
3233
|
+
)
|
3234
|
+
except Exception as e:
|
3235
|
+
y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)
|
3236
|
+
validation_scores["roc_auc"] = roc_auc_score(
|
3237
|
+
y_true, y_pred_proba, multi_class="ovr", average=average
|
3238
|
+
)
|
3239
|
+
|
3240
|
+
validation_scores["roc_auc_by_class"] = roc_auc # Individual class AUCs
|
3241
|
+
validation_scores["fpr"] = fpr
|
3242
|
+
validation_scores["tpr"] = tpr
|
2886
3243
|
|
2887
|
-
if y_pred_proba is not None:
|
2888
|
-
# Calculate ROC-AUC
|
2889
|
-
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
2890
|
-
# PR-AUC (Precision-Recall AUC) calculation
|
2891
|
-
validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
|
2892
3244
|
else:
|
2893
3245
|
raise ValueError(
|
2894
3246
|
"Invalid purpose specified. Choose 'regression' or 'classification'."
|
2895
3247
|
)
|
2896
3248
|
|
2897
3249
|
return validation_scores
|
3250
|
+
|
3251
|
+
def plot_trees(
    X, y, cls, max_trees=500, test_size=0.2, random_state=42, early_stopping_rounds=None
):
    """
    Plot OOB, training and testing error rates of a tree-based ensemble
    classifier as the number of trees grows from 1 to ``max_trees``.

    Example usage:
        X = np.random.rand(100, 10)       # 100 samples, 10 features
        y = np.random.randint(0, 2, 100)  # binary target
        # Random Forest
        plot_trees(X, y, RandomForestClassifier(), max_trees=100)
        # Gradient Boosting with early stopping
        plot_trees(X, y, GradientBoostingClassifier(), max_trees=100, early_stopping_rounds=10)
        # Extra Trees
        plot_trees(X, y, ExtraTreesClassifier(), max_trees=100)

    Parameters:
    - X (array-like): Feature matrix.
    - y (array-like): Target labels.
    - cls (object): Tree-based ensemble classifier instance (e.g., RandomForestClassifier()).
      NOTE: the instance is reconfigured in place via ``set_params`` (``n_estimators``,
      and for bagging-style models ``warm_start``/``bootstrap``/``oob_score``) and is
      left fitted on the training split when the function returns.
    - max_trees (int): Maximum number of trees to evaluate. Default is 500.
    - test_size (float): Proportion of data held out for the testing error. Default is 0.2.
    - random_state (int): Random state used for the train/test split. Default is 42.
    - early_stopping_rounds (int): For AdaBoost/GradientBoosting only. Training stops once
      the validation error (measured on the held-out split) has not improved on its
      previous best for this many consecutive rounds.

    Returns:
    - None (a matplotlib figure is shown).
    """
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    from sklearn.ensemble import (
        RandomForestClassifier,
        BaggingClassifier,
        ExtraTreesClassifier,
    )
    from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier

    # Split data for training and testing error calculation
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Error rates collected per ensemble size
    oob_error_rate = []
    train_error_rate = []
    test_error_rate = []
    validation_error = None

    is_boosting = isinstance(cls, (AdaBoostClassifier, GradientBoostingClassifier))
    oob_enabled = False  # Default to no OOB error unless explicitly set

    # Configure classifier based on type
    if isinstance(cls, (RandomForestClassifier, ExtraTreesClassifier)):
        # warm_start lets each iteration add one tree instead of refitting all of them
        cls.set_params(warm_start=True, n_estimators=1)
        if hasattr(cls, "oob_score"):
            cls.set_params(bootstrap=True, oob_score=True)
            oob_enabled = True
    elif isinstance(cls, BaggingClassifier):
        cls.set_params(warm_start=True, bootstrap=True, oob_score=True, n_estimators=1)
        oob_enabled = True
    elif is_boosting:
        cls.set_params(n_estimators=1)
        if early_stopping_rounds:
            validation_error = []

    # Train and evaluate with an increasing number of trees
    for i in range(1, max_trees + 1):
        cls.set_params(n_estimators=i)
        cls.fit(x_train, y_train)

        # OOB error (only for models that support and enabled it)
        if oob_enabled and hasattr(cls, "oob_score_") and cls.oob_score:
            oob_error_rate.append(1 - cls.oob_score_)

        # Training error
        train_error_rate.append(1 - accuracy_score(y_train, cls.predict(x_train)))

        # Testing error
        test_error = 1 - accuracy_score(y_test, cls.predict(x_test))
        test_error_rate.append(test_error)

        # For boosting models, track validation error and stop early if requested
        if early_stopping_rounds and is_boosting:
            # The validation error is measured on the same held-out split as the
            # testing error, so reuse it instead of predicting a second time.
            validation_error.append(test_error)
            if len(validation_error) > early_stopping_rounds:
                # Compare the best error inside the most recent window with the
                # best error seen before it: no improvement -> stop.
                best_before = min(validation_error[:-early_stopping_rounds])
                best_recent = min(validation_error[-early_stopping_rounds:])
                if best_recent >= best_before:
                    print(f"Early stopping at tree {i} due to lack of improvement in validation error.")
                    break

    # Plot results
    plt.figure(figsize=(10, 6))
    if oob_error_rate:
        plt.plot(
            range(1, len(oob_error_rate) + 1),
            oob_error_rate,
            color="black",
            label="OOB Error Rate",
            linewidth=2,
        )
    if train_error_rate:
        plt.plot(
            range(1, len(train_error_rate) + 1),
            train_error_rate,
            linestyle="dotted",
            color="green",
            label="Training Error Rate",
        )
    if test_error_rate:
        plt.plot(
            range(1, len(test_error_rate) + 1),
            test_error_rate,
            linestyle="dashed",
            color="red",
            label="Testing Error Rate",
        )
    if validation_error:
        plt.plot(
            range(1, len(validation_error) + 1),
            validation_error,
            linestyle="solid",
            color="blue",
            label="Validation Error (Boosting)",
        )

    # Customize plot
    plt.xlabel("Number of Trees")
    plt.ylabel("Error Rate")
    plt.title(f"Error Rate Analysis for {cls.__class__.__name__}")
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.show()
|
3390
|
+
|
3391
|
+
def img_datasets_preprocessing(
    data: pd.DataFrame,
    x_col: str,
    y_col: str=None,
    target_size: tuple = (224, 224),
    batch_size: int = 128,
    class_mode: str = "raw",
    shuffle: bool = False,
    augment: bool = False,
    scaler: str = 'normalize',  # 'normalize', 'standardize', 'clahe', 'raw'
    grayscale: bool = False,
    encoder: str = "label",  # Options: 'label', 'onehot', 'binary'
    label_encoder=None,
    kws_augmentation: dict = None,
    verbose: bool = True,
    drop_missing: bool = True,
    output="df",  # "iterator":data_iterator,'df':return DataFrame
):
    """
    Load images listed in a DataFrame and return them either as a Keras
    DataFrameIterator or as a DataFrame of flattened pixels.

    Parameters:
    - data (pd.DataFrame): Input DataFrame with image paths and (optionally) labels.
      The caller's frame is not modified; the function works on a copy.
    - x_col (str): Column in `data` containing image file paths.
    - y_col (str): Column in `data` containing image labels. When None, `class_mode`
      is forced to None and no label column is produced.
    - target_size (tuple): Desired image size in (height, width).
    - batch_size (int): Number of images per batch.
    - class_mode (str): Mode of label ('raw', 'categorical', 'binary').
    - shuffle (bool): Shuffle the images in the DataFrame.
    - augment (bool): Apply data augmentation.
    - scaler (str): 'normalize' (rescale by 1/255 inside the generator),
      'standardize' (per-image zero mean / unit std), 'clahe' (apply_clahe), or 'raw'.
      'standardize' and 'clahe' are only applied when a DataFrame is returned.
    - grayscale (bool): Load images in grayscale instead of RGB.
    - encoder (str): Label encoder method ('label', 'onehot', 'binary').
      'binary' maps the first unique label to 1 and everything else to 0.
    - label_encoder: Optional label encoder instance; it is (re)fitted on `data[y_col]`.
    - kws_augmentation (dict): Overrides for the default augmentation parameters.
    - verbose (bool): Print status messages.
    - drop_missing (bool): Drop rows whose image path does not point to an existing file.
    - output (str): 'generator'/'tf'/'iterator'/'transform'/'transformer' returns the
      Keras iterator; 'dataframe'/'df'/'pd'/'pandas' returns a DataFrame.

    Returns:
    - The Keras DataFrameIterator, or a pd.DataFrame with one `pixel_<i>` column per
      flattened pixel value followed by the label column(s): 'Label' for scalar labels,
      'Label_<j>' when each label is a vector (e.g. one-hot), none when `y_col` is None.
    """
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.utils import to_categorical
    from sklearn.preprocessing import LabelEncoder
    import os

    # Validate input DataFrame for required columns
    if y_col:
        assert (
            x_col in data.columns and y_col in data.columns
        ), "Missing required columns in DataFrame."
    if y_col is None:
        class_mode=None
    # Output format
    output = ips.strcmp(output,[
        "generator","tf","iterator","transform","transformer","dataframe",
        "df","pd","pandas"])[0]

    # Handle missing file paths
    if drop_missing:
        data = data[
            data[x_col].apply(lambda path: os.path.exists(path) and os.path.isfile(path))
        ]
    # Work on a private copy so the label encoding below neither mutates the
    # caller's DataFrame nor writes into a filtered view of it.
    data = data.copy()

    # Encoding labels if necessary
    if encoder and y_col is not None:
        if encoder == "binary":
            data[y_col] = (data[y_col] == data[y_col].unique()[0]).astype(int)
        elif encoder == "onehot":
            if not label_encoder:
                label_encoder = LabelEncoder()
            encoded = label_encoder.fit_transform(data[y_col])
            # to_categorical returns a 2-D array, which cannot be assigned to a
            # single column directly; store one one-hot vector per row instead.
            # NOTE(review): downstream handling of vector labels by
            # flow_from_dataframe depends on `class_mode` -- confirm for your use.
            data[y_col] = list(to_categorical(encoded))
        elif encoder == "label":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])

    # Only 'normalize' is handled by the generator itself; the other scalers
    # are applied per image further below.
    rescale = 1.0 / 255 if scaler == 'normalize' else None

    # Set up data augmentation
    if augment:
        aug_params = {
            "rotation_range": 20,
            "width_shift_range": 0.2,
            "height_shift_range": 0.2,
            "shear_range": 0.2,
            "zoom_range": 0.2,
            "horizontal_flip": True,
            "fill_mode": "nearest",
        }
        if kws_augmentation:
            aug_params.update(kws_augmentation)
        dat = ImageDataGenerator(rescale=rescale, **aug_params)
    else:
        dat = ImageDataGenerator(rescale=rescale)

    # Create DataFrameIterator
    data_iterator = dat.flow_from_dataframe(
        dataframe=data,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        color_mode="grayscale" if grayscale else "rgb",
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
    )
    if verbose:
        print(f"target_size:{target_size}")
    if output.lower() in ["generator", "tf", "iterator", "transform", "transformer"]:
        return data_iterator
    elif output.lower() in ["dataframe", "df", "pd", "pandas"]:
        # Initialize list to collect processed data
        data_list = []
        # The Keras iterator cycles forever, so pull exactly one epoch:
        # ceil(n / batch_size) batches (a plain floor division would either
        # need an extra wrapped-around batch or miss the last partial one).
        total_batches = -(-data_iterator.n // batch_size)
        has_labels = class_mode is not None

        # Load, resize, and process images in batches
        for _ in range(total_batches):
            batch = next(data_iterator)
            if has_labels:
                batch_images, batch_labels = batch
            else:
                # With class_mode=None the iterator yields the images only
                batch_images, batch_labels = batch, None

            for idx, img in enumerate(batch_images):
                if scaler in ('normalize', 'raw'):
                    # 'normalize' was already rescaled by 1.0/255 in
                    # ImageDataGenerator; 'raw' is left untouched.
                    pass
                elif scaler == 'standardize':
                    # Standardize by subtracting mean and dividing by std;
                    # a constant image has std == 0, so only center it.
                    std = np.std(img)
                    img = (img - np.mean(img)) / std if std > 0 else img - np.mean(img)
                elif scaler == 'clahe':
                    # Apply CLAHE to the image
                    img = apply_clahe(img)
                row = img.flatten()
                if batch_labels is not None:
                    row = np.append(row, batch_labels[idx])
                data_list.append(row)

        # Define column names for flattened image data
        pixel_count = target_size[0] * target_size[1] * (1 if grayscale else 3)
        column_names = [f"pixel_{i}" for i in range(pixel_count)]
        if has_labels:
            # Scalar labels give one 'Label' column; vector labels (e.g. one-hot)
            # give one column per component.
            n_label_cols = len(data_list[0]) - pixel_count if data_list else 1
            if n_label_cols == 1:
                column_names += ["Label"]
            else:
                column_names += [f"Label_{j}" for j in range(n_label_cols)]

        # Create DataFrame from flattened data
        df_img = pd.DataFrame(data_list, columns=column_names)

        if verbose:
            print("Processed images:", len(df_img))
            print("Final DataFrame shape:", df_img.shape)
            display(df_img.head())

        return df_img
|
3543
|
+
# Function to apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
|
3544
|
+
def apply_clahe(img):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    Parameters:
    - img (np.ndarray): RGB image of shape (H, W, 3), or a single-channel image of
      shape (H, W) or (H, W, 1). Floating-point input is converted to uint8 first:
      values are multiplied by 255 when the image maximum is <= 1.0 (assumed to be a
      [0, 1]-scaled image), then clipped to [0, 255].

    Returns:
    - np.ndarray: uint8 image with the same number of dimensions/channels as the
      input. For RGB input, CLAHE is applied to the L channel in LAB color space.
    """
    import cv2

    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.floating):
        # cv2's CLAHE only accepts 8/16-bit single-channel data, while image
        # generators typically hand over float arrays.
        if img.size and img.max() <= 1.0:
            img = img * 255.0
        img = np.clip(img, 0, 255).astype(np.uint8)

    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

    # Single-channel images cannot go through the RGB->LAB conversion;
    # equalize the only channel directly.
    if img.ndim == 2:
        return clahe.apply(np.ascontiguousarray(img))
    if img.ndim == 3 and img.shape[-1] == 1:
        equalized = clahe.apply(np.ascontiguousarray(img[..., 0]))
        return equalized[..., np.newaxis]

    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)  # Convert to LAB color space
    l, a, b = cv2.split(lab)  # Split into channels
    cl = clahe.apply(l)  # Apply CLAHE to the L channel
    limg = cv2.merge((cl, a, b))  # Merge back the channels
    img_clahe = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)  # Convert back to RGB
    return img_clahe
|