py2ls 0.2.4.14__py3-none-any.whl → 0.2.4.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.git/index +0 -0
- py2ls/ips.py +722 -12
- py2ls/ml2ls copy.py +2906 -0
- py2ls/ml2ls.py +898 -243
- py2ls/plot.py +409 -24
- py2ls/translator.py +2 -0
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/RECORD +9 -8
- {py2ls-0.2.4.14.dist-info → py2ls-0.2.4.16.dist-info}/WHEEL +0 -0
py2ls/ml2ls.py
CHANGED
@@ -506,7 +506,7 @@ def get_models(
|
|
506
506
|
"Support Vector Machine(svm)",
|
507
507
|
"naive bayes",
|
508
508
|
"Linear Discriminant Analysis (lda)",
|
509
|
-
"
|
509
|
+
"AdaBoost",
|
510
510
|
"DecisionTree",
|
511
511
|
"KNeighbors",
|
512
512
|
"Bagging",
|
@@ -585,7 +585,7 @@ def get_features(
|
|
585
585
|
"Support Vector Machine(svm)",
|
586
586
|
"naive bayes",
|
587
587
|
"Linear Discriminant Analysis (lda)",
|
588
|
-
"
|
588
|
+
"AdaBoost",
|
589
589
|
"DecisionTree",
|
590
590
|
"KNeighbors",
|
591
591
|
"Bagging",
|
@@ -616,10 +616,10 @@ def get_features(
|
|
616
616
|
if isinstance(y, str) and y in X.columns:
|
617
617
|
y_col_name = y
|
618
618
|
y = X[y]
|
619
|
-
y = ips.df_encoder(pd.DataFrame(y), method="
|
619
|
+
y = ips.df_encoder(pd.DataFrame(y), method="label")
|
620
620
|
X = X.drop(y_col_name, axis=1)
|
621
621
|
else:
|
622
|
-
y = ips.df_encoder(pd.DataFrame(y), method="
|
622
|
+
y = ips.df_encoder(pd.DataFrame(y), method="label").values.ravel()
|
623
623
|
y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
|
624
624
|
y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
|
625
625
|
|
@@ -699,9 +699,11 @@ def get_features(
|
|
699
699
|
"Support Vector Machine(svm)",
|
700
700
|
"Naive Bayes",
|
701
701
|
"Linear Discriminant Analysis (lda)",
|
702
|
-
"
|
702
|
+
"AdaBoost",
|
703
703
|
]
|
704
704
|
cls = [ips.strcmp(i, cls_)[0] for i in cls]
|
705
|
+
|
706
|
+
feature_importances = {}
|
705
707
|
|
706
708
|
# Lasso Feature Selection
|
707
709
|
lasso_importances = (
|
@@ -712,6 +714,7 @@ def get_features(
|
|
712
714
|
lasso_selected_features = (
|
713
715
|
lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
|
714
716
|
)
|
717
|
+
feature_importances['lasso']=lasso_importances.head(n_features)
|
715
718
|
# Ridge
|
716
719
|
ridge_importances = (
|
717
720
|
features_ridge(x_train, y_train, ridge_params)
|
@@ -721,6 +724,7 @@ def get_features(
|
|
721
724
|
selected_ridge_features = (
|
722
725
|
ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
|
723
726
|
)
|
727
|
+
feature_importances['ridge']=ridge_importances.head(n_features)
|
724
728
|
# Elastic Net
|
725
729
|
enet_importances = (
|
726
730
|
features_enet(x_train, y_train, enet_params)
|
@@ -730,6 +734,7 @@ def get_features(
|
|
730
734
|
selected_enet_features = (
|
731
735
|
enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
|
732
736
|
)
|
737
|
+
feature_importances['Enet']=enet_importances.head(n_features)
|
733
738
|
# Random Forest Feature Importance
|
734
739
|
rf_importances = (
|
735
740
|
features_rf(x_train, y_train, rf_params)
|
@@ -741,6 +746,7 @@ def get_features(
|
|
741
746
|
if "Random Forest" in cls
|
742
747
|
else []
|
743
748
|
)
|
749
|
+
feature_importances['Random Forest']=rf_importances.head(n_features)
|
744
750
|
# Gradient Boosting Feature Importance
|
745
751
|
gb_importances = (
|
746
752
|
features_gradient_boosting(x_train, y_train, gb_params)
|
@@ -752,6 +758,7 @@ def get_features(
|
|
752
758
|
if "Gradient Boosting" in cls
|
753
759
|
else []
|
754
760
|
)
|
761
|
+
feature_importances['Gradient Boosting']=gb_importances.head(n_features)
|
755
762
|
# xgb
|
756
763
|
xgb_importances = (
|
757
764
|
features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
|
@@ -759,6 +766,7 @@ def get_features(
|
|
759
766
|
top_xgb_features = (
|
760
767
|
xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
|
761
768
|
)
|
769
|
+
feature_importances['xgb']=xgb_importances.head(n_features)
|
762
770
|
|
763
771
|
# SVM with RFE
|
764
772
|
selected_svm_features = (
|
@@ -773,6 +781,7 @@ def get_features(
|
|
773
781
|
selected_lda_features = (
|
774
782
|
lda_importances.head(n_features)["feature"].values if "lda" in cls else []
|
775
783
|
)
|
784
|
+
feature_importances['lda']=lda_importances.head(n_features)
|
776
785
|
# AdaBoost Feature Importance
|
777
786
|
adaboost_importances = (
|
778
787
|
features_adaboost(x_train, y_train, adaboost_params)
|
@@ -784,6 +793,7 @@ def get_features(
|
|
784
793
|
if "AdaBoost" in cls
|
785
794
|
else []
|
786
795
|
)
|
796
|
+
feature_importances['AdaBoost']=adaboost_importances.head(n_features)
|
787
797
|
# Decision Tree Feature Importance
|
788
798
|
dt_importances = (
|
789
799
|
features_decision_tree(x_train, y_train, dt_params)
|
@@ -794,7 +804,8 @@ def get_features(
|
|
794
804
|
dt_importances.head(n_features)["feature"].values
|
795
805
|
if "Decision Tree" in cls
|
796
806
|
else []
|
797
|
-
)
|
807
|
+
)
|
808
|
+
feature_importances['Decision Tree']=dt_importances.head(n_features)
|
798
809
|
# Bagging Feature Importance
|
799
810
|
bagging_importances = (
|
800
811
|
features_bagging(x_train, y_train, bagging_params)
|
@@ -806,6 +817,7 @@ def get_features(
|
|
806
817
|
if "Bagging" in cls
|
807
818
|
else []
|
808
819
|
)
|
820
|
+
feature_importances['Bagging']=bagging_importances.head(n_features)
|
809
821
|
# KNN Feature Importance via Permutation
|
810
822
|
knn_importances = (
|
811
823
|
features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
|
@@ -813,6 +825,7 @@ def get_features(
|
|
813
825
|
top_knn_features = (
|
814
826
|
knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
|
815
827
|
)
|
828
|
+
feature_importances['KNN']=knn_importances.head(n_features)
|
816
829
|
|
817
830
|
#! Find common features
|
818
831
|
common_features = ips.shared(
|
@@ -915,6 +928,7 @@ def get_features(
|
|
915
928
|
"cv_train_scores": cv_train_results_df,
|
916
929
|
"cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
|
917
930
|
"common_features": list(common_features),
|
931
|
+
"feature_importances":feature_importances
|
918
932
|
}
|
919
933
|
if all([plot_, dir_save]):
|
920
934
|
from datetime import datetime
|
@@ -927,6 +941,7 @@ def get_features(
|
|
927
941
|
"cv_train_scores": pd.DataFrame(),
|
928
942
|
"cv_test_scores": pd.DataFrame(),
|
929
943
|
"common_features": [],
|
944
|
+
"feature_importances":{}
|
930
945
|
}
|
931
946
|
print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
|
932
947
|
return results
|
@@ -1217,142 +1232,335 @@ def validate_features(
|
|
1217
1232
|
|
1218
1233
|
# # If you want to access validation scores
|
1219
1234
|
# print(validation_results)
|
1220
|
-
def plot_validate_features(res_val):
|
1235
|
+
def plot_validate_features(res_val,is_binary=True,figsize=None):
|
1221
1236
|
"""
|
1222
1237
|
plot the results of 'validate_features()'
|
1223
1238
|
"""
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1239
|
+
if is_binary:
|
1240
|
+
colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
|
1241
|
+
if res_val.shape[0] > 5:
|
1242
|
+
alpha = 0
|
1243
|
+
figsize = [8, 10] if figsize is None else figsize
|
1244
|
+
subplot_layout = [1, 2]
|
1245
|
+
ncols = 2
|
1246
|
+
bbox_to_anchor = [1.5, 0.6]
|
1247
|
+
else:
|
1248
|
+
alpha = 0.03
|
1249
|
+
figsize = [10, 6] if figsize is None else figsize
|
1250
|
+
subplot_layout = [1, 1]
|
1251
|
+
ncols = 1
|
1252
|
+
bbox_to_anchor = [1, 1]
|
1253
|
+
nexttile = plot.subplot(figsize=figsize)
|
1254
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1255
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1256
|
+
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1257
|
+
tpr = res_val["roc_curve"][model_name]["tpr"]
|
1258
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
|
1259
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"]
|
1260
|
+
plot_roc_curve(
|
1261
|
+
fpr,
|
1262
|
+
tpr,
|
1263
|
+
mean_auc,
|
1264
|
+
lower_ci,
|
1265
|
+
upper_ci,
|
1266
|
+
model_name=model_name,
|
1267
|
+
lw=1.5,
|
1268
|
+
color=colors[i],
|
1269
|
+
alpha=alpha,
|
1270
|
+
ax=ax,
|
1271
|
+
)
|
1272
|
+
plot.figsets(
|
1273
|
+
sp=2,
|
1274
|
+
legend=dict(
|
1275
|
+
loc="upper right",
|
1276
|
+
ncols=ncols,
|
1277
|
+
fontsize=8,
|
1278
|
+
bbox_to_anchor=[1.5, 0.6],
|
1279
|
+
markerscale=0.8,
|
1280
|
+
),
|
1255
1281
|
)
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1274
|
-
model_name=model_name,
|
1275
|
-
color=colors[i],
|
1276
|
-
lw=1.5,
|
1277
|
-
alpha=alpha,
|
1278
|
-
ax=ax,
|
1282
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1283
|
+
|
1284
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1285
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1286
|
+
plot_pr_curve(
|
1287
|
+
recall=res_val["pr_curve"][model_name]["recall"],
|
1288
|
+
precision=res_val["pr_curve"][model_name]["precision"],
|
1289
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1290
|
+
model_name=model_name,
|
1291
|
+
color=colors[i],
|
1292
|
+
lw=1.5,
|
1293
|
+
alpha=alpha,
|
1294
|
+
ax=ax,
|
1295
|
+
)
|
1296
|
+
plot.figsets(
|
1297
|
+
sp=2,
|
1298
|
+
legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
|
1279
1299
|
)
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1300
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1301
|
+
else:
|
1302
|
+
colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
|
1303
|
+
modname_tmp=ips.flatten(res_val["roc_curve"].index)[0]
|
1304
|
+
classes=list(res_val["roc_curve"][modname_tmp]['fpr'].keys())
|
1305
|
+
if res_val.shape[0] > 5:
|
1306
|
+
alpha = 0
|
1307
|
+
figsize = [8, 8*2*(len(classes))] if figsize is None else figsize
|
1308
|
+
subplot_layout = [1, 2]
|
1309
|
+
ncols = 2
|
1310
|
+
bbox_to_anchor = [1.5, 0.6]
|
1311
|
+
else:
|
1312
|
+
alpha = 0.03
|
1313
|
+
figsize = [10, 6*(len(classes))] if figsize is None else figsize
|
1314
|
+
subplot_layout = [1, 1]
|
1315
|
+
ncols = 1
|
1316
|
+
bbox_to_anchor = [1, 1]
|
1317
|
+
nexttile = plot.subplot(2*(len(classes)),2,figsize=figsize)
|
1318
|
+
for iclass, class_ in enumerate(classes):
|
1319
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1320
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1321
|
+
fpr = res_val["roc_curve"][model_name]["fpr"][class_]
|
1322
|
+
tpr = res_val["roc_curve"][model_name]["tpr"][class_]
|
1323
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
|
1324
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
|
1325
|
+
plot_roc_curve(
|
1326
|
+
fpr,
|
1327
|
+
tpr,
|
1328
|
+
mean_auc,
|
1329
|
+
lower_ci,
|
1330
|
+
upper_ci,
|
1331
|
+
model_name=model_name,
|
1332
|
+
lw=1.5,
|
1333
|
+
color=colors[i],
|
1334
|
+
alpha=alpha,
|
1335
|
+
ax=ax,
|
1336
|
+
)
|
1337
|
+
plot.figsets(
|
1338
|
+
sp=2,
|
1339
|
+
title=class_,
|
1340
|
+
legend=dict(
|
1341
|
+
loc="upper right",
|
1342
|
+
ncols=ncols,
|
1343
|
+
fontsize=8,
|
1344
|
+
bbox_to_anchor=[1.5, 0.6],
|
1345
|
+
markerscale=0.8,
|
1346
|
+
),
|
1347
|
+
)
|
1348
|
+
# plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
|
1349
|
+
|
1350
|
+
ax = nexttile(subplot_layout[0], subplot_layout[1])
|
1351
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1352
|
+
plot_pr_curve(
|
1353
|
+
recall=res_val["pr_curve"][model_name]["recall"][iclass],
|
1354
|
+
precision=res_val["pr_curve"][model_name]["precision"][iclass],
|
1355
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
|
1356
|
+
model_name=model_name,
|
1357
|
+
color=colors[i],
|
1358
|
+
lw=1.5,
|
1359
|
+
alpha=alpha,
|
1360
|
+
ax=ax,
|
1361
|
+
)
|
1362
|
+
plot.figsets(
|
1363
|
+
sp=2,
|
1364
|
+
title=class_,
|
1365
|
+
legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
|
1366
|
+
)
|
1285
1367
|
|
1368
|
+
def plot_validate_features_single(res_val, figsize=None,is_binary=True):
|
1369
|
+
if is_binary:
|
1370
|
+
if figsize is None:
|
1371
|
+
nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3,figsize=[13,4*len(ips.flatten(res_val["pr_curve"].index))])
|
1372
|
+
else:
|
1373
|
+
nexttile = plot.subplot(
|
1374
|
+
len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
|
1375
|
+
)
|
1376
|
+
for model_name in ips.flatten(res_val["pr_curve"].index):
|
1377
|
+
fpr = res_val["roc_curve"][model_name]["fpr"]
|
1378
|
+
tpr = res_val["roc_curve"][model_name]["tpr"]
|
1379
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
|
1380
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"]
|
1381
|
+
|
1382
|
+
# Plotting
|
1383
|
+
plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci,
|
1384
|
+
model_name=model_name, ax=nexttile())
|
1385
|
+
plot.figsets(title=model_name, sp=2)
|
1386
|
+
|
1387
|
+
plot_pr_binary(
|
1388
|
+
recall=res_val["pr_curve"][model_name]["recall"],
|
1389
|
+
precision=res_val["pr_curve"][model_name]["precision"],
|
1390
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
|
1391
|
+
model_name=model_name,
|
1392
|
+
ax=nexttile(),
|
1393
|
+
)
|
1394
|
+
plot.figsets(title=model_name, sp=2)
|
1286
1395
|
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1396
|
+
# plot cm
|
1397
|
+
plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
|
1398
|
+
plot.figsets(title=model_name, sp=2)
|
1290
1399
|
else:
|
1291
|
-
|
1292
|
-
|
1293
|
-
)
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1400
|
+
|
1401
|
+
modname_tmp=ips.flatten(res_val["roc_curve"].index)[0]
|
1402
|
+
classes=list(res_val["roc_curve"][modname_tmp]['fpr'].keys())
|
1403
|
+
if figsize is None:
|
1404
|
+
nexttile = plot.subplot(len(modname_tmp), 3,figsize=[15,len(modname_tmp)*5])
|
1405
|
+
else:
|
1406
|
+
nexttile = plot.subplot(len(modname_tmp), 3, figsize=figsize)
|
1407
|
+
colors = plot.get_color(len(classes))
|
1408
|
+
for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
|
1409
|
+
ax = nexttile()
|
1410
|
+
for iclass, class_ in enumerate(classes):
|
1411
|
+
fpr = res_val["roc_curve"][model_name]["fpr"][class_]
|
1412
|
+
tpr = res_val["roc_curve"][model_name]["tpr"][class_]
|
1413
|
+
(lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
|
1414
|
+
mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
|
1415
|
+
plot_roc_curve(
|
1416
|
+
fpr,
|
1417
|
+
tpr,
|
1418
|
+
mean_auc,
|
1419
|
+
lower_ci,
|
1420
|
+
upper_ci,
|
1421
|
+
model_name=class_,
|
1422
|
+
lw=1.5,
|
1423
|
+
color=colors[iclass],
|
1424
|
+
alpha=0.03,
|
1425
|
+
ax=ax,
|
1426
|
+
)
|
1427
|
+
plot.figsets(
|
1428
|
+
sp=2,
|
1429
|
+
title=model_name,
|
1430
|
+
legend=dict(
|
1431
|
+
loc="best",
|
1432
|
+
fontsize=8,
|
1433
|
+
),
|
1434
|
+
)
|
1435
|
+
|
1436
|
+
ax = nexttile()
|
1437
|
+
for iclass, class_ in enumerate(classes):
|
1438
|
+
plot_pr_curve(
|
1439
|
+
recall=res_val["pr_curve"][model_name]["recall"][iclass],
|
1440
|
+
precision=res_val["pr_curve"][model_name]["precision"][iclass],
|
1441
|
+
avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
|
1442
|
+
model_name=class_,
|
1443
|
+
color=colors[iclass],
|
1444
|
+
lw=1.5,
|
1445
|
+
alpha=0.03,
|
1446
|
+
ax=ax,
|
1447
|
+
)
|
1448
|
+
plot.figsets(
|
1449
|
+
sp=2,
|
1450
|
+
title=class_,
|
1451
|
+
legend=dict(loc="best", fontsize=8),
|
1452
|
+
)
|
1453
|
+
|
1454
|
+
plot_cm(res_val["confusion_matrix"][model_name],labels_name=classes, ax=nexttile(), normalize=False)
|
1455
|
+
plot.figsets(title=model_name, sp=2)
|
1313
1456
|
|
1314
|
-
# plot cm
|
1315
|
-
plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
|
1316
|
-
plot.figsets(title=model_name, sp=2)
|
1317
1457
|
|
1458
|
+
def cal_precision_recall(
|
1459
|
+
y_true, y_pred_proba, is_binary=True):
|
1460
|
+
if is_binary:
|
1461
|
+
precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
|
1462
|
+
avg_precision_ = average_precision_score(y_true, y_pred_proba)
|
1463
|
+
return precision_, recall_,avg_precision_
|
1464
|
+
else:
|
1465
|
+
n_classes = y_pred_proba.shape[1] # Number of classes
|
1466
|
+
precision_ = []
|
1467
|
+
recall_ = []
|
1468
|
+
|
1469
|
+
# One-vs-rest approach for multi-class precision-recall curve
|
1470
|
+
for class_idx in range(n_classes):
|
1471
|
+
precision, recall, _ = precision_recall_curve(
|
1472
|
+
(y_true == class_idx).astype(int), # Binarize true labels for the current class
|
1473
|
+
y_pred_proba[:, class_idx], # Probabilities for the current class
|
1474
|
+
)
|
1318
1475
|
|
1476
|
+
precision_.append(precision)
|
1477
|
+
recall_.append(recall)
|
1478
|
+
# Optionally, you can compute average precision for each class
|
1479
|
+
avg_precision_ = []
|
1480
|
+
for class_idx in range(n_classes):
|
1481
|
+
avg_precision = average_precision_score(
|
1482
|
+
(y_true == class_idx).astype(int), # Binarize true labels for the current class
|
1483
|
+
y_pred_proba[:, class_idx], # Probabilities for the current class
|
1484
|
+
)
|
1485
|
+
avg_precision_.append(avg_precision)
|
1486
|
+
return precision_, recall_,avg_precision_
|
1487
|
+
|
1319
1488
|
def cal_auc_ci(
|
1320
|
-
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
|
1489
|
+
y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,is_binary=True, verbose=True
|
1321
1490
|
):
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1491
|
+
if is_binary:
|
1492
|
+
y_true = np.asarray(y_true)
|
1493
|
+
y_pred = np.asarray(y_pred)
|
1494
|
+
bootstrapped_scores = []
|
1495
|
+
if verbose:
|
1496
|
+
print("auroc score:", roc_auc_score(y_true, y_pred))
|
1497
|
+
rng = np.random.RandomState(random_state)
|
1498
|
+
for i in range(n_bootstraps):
|
1499
|
+
# bootstrap by sampling with replacement on the prediction indices
|
1500
|
+
indices = rng.randint(0, len(y_pred), len(y_pred))
|
1501
|
+
if len(np.unique(y_true[indices])) < 2:
|
1502
|
+
# We need at least one positive and one negative sample for ROC AUC
|
1503
|
+
# to be defined: reject the sample
|
1504
|
+
continue
|
1505
|
+
if isinstance(y_true, np.ndarray):
|
1506
|
+
score = roc_auc_score(y_true[indices], y_pred[indices])
|
1507
|
+
else:
|
1508
|
+
score = roc_auc_score(y_true.iloc[indices], y_pred.iloc[indices])
|
1509
|
+
bootstrapped_scores.append(score)
|
1510
|
+
# print("Bootstrap #{} ROC area: {:0.3f}".format(i + 1, score))
|
1511
|
+
sorted_scores = np.array(bootstrapped_scores)
|
1512
|
+
sorted_scores.sort()
|
1513
|
+
|
1514
|
+
# Computing the lower and upper bound of the 90% confidence interval
|
1515
|
+
# You can change the bounds percentiles to 0.025 and 0.975 to get
|
1516
|
+
# a 95% confidence interval instead.
|
1517
|
+
confidence_lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
|
1518
|
+
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1519
|
+
if verbose:
|
1520
|
+
print(
|
1521
|
+
"Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
|
1522
|
+
confidence_lower, confidence_upper
|
1523
|
+
)
|
1353
1524
|
)
|
1354
|
-
|
1355
|
-
|
1525
|
+
return confidence_lower, confidence_upper
|
1526
|
+
else:
|
1527
|
+
from sklearn.preprocessing import label_binarize
|
1528
|
+
# Multi-class classification case
|
1529
|
+
y_true = np.asarray(y_true)
|
1530
|
+
y_pred = np.asarray(y_pred)
|
1531
|
+
|
1532
|
+
# Binarize the multi-class labels for OvR computation
|
1533
|
+
y_true_bin = label_binarize(y_true, classes=np.unique(y_true)) # One-vs-Rest transformation
|
1534
|
+
n_classes = y_true_bin.shape[1] # Number of classes
|
1535
|
+
|
1536
|
+
bootstrapped_scores = np.zeros((n_classes, n_bootstraps)) # Store scores for each class
|
1537
|
+
|
1538
|
+
if verbose:
|
1539
|
+
print("AUROC scores for each class:")
|
1540
|
+
for i in range(n_classes):
|
1541
|
+
print(f"Class {i}: {roc_auc_score(y_true_bin[:, i], y_pred[:, i])}")
|
1542
|
+
|
1543
|
+
rng = np.random.RandomState(random_state)
|
1544
|
+
for i in range(n_bootstraps):
|
1545
|
+
indices = rng.randint(0, len(y_pred), len(y_pred))
|
1546
|
+
for class_idx in range(n_classes):
|
1547
|
+
if len(np.unique(y_true_bin[indices, class_idx])) < 2:
|
1548
|
+
continue # Reject if the class doesn't have both positive and negative samples
|
1549
|
+
score = roc_auc_score(y_true_bin[indices, class_idx], y_pred[indices, class_idx])
|
1550
|
+
bootstrapped_scores[class_idx, i] = score
|
1551
|
+
|
1552
|
+
# Calculating the confidence intervals for each class
|
1553
|
+
confidence_intervals = []
|
1554
|
+
for class_idx in range(n_classes):
|
1555
|
+
sorted_scores = np.sort(bootstrapped_scores[class_idx])
|
1556
|
+
confidence_lower = sorted_scores[int((1 - ci) * len(sorted_scores))]
|
1557
|
+
confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
|
1558
|
+
confidence_intervals.append((confidence_lower, confidence_upper))
|
1559
|
+
|
1560
|
+
if verbose:
|
1561
|
+
print(f"Class {class_idx} - Confidence interval: [{confidence_lower:.3f} - {confidence_upper:.3f}]")
|
1562
|
+
|
1563
|
+
return confidence_intervals
|
1356
1564
|
|
1357
1565
|
|
1358
1566
|
def plot_roc_curve(
|
@@ -1517,7 +1725,7 @@ def plot_pr_binary(
|
|
1517
1725
|
|
1518
1726
|
pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
|
1519
1727
|
for f_score in f_scores:
|
1520
|
-
x_vals = np.linspace(0.01, 1,
|
1728
|
+
x_vals = np.linspace(0.01, 1, 20000)
|
1521
1729
|
y_vals = f_score * x_vals / (2 * x_vals - f_score)
|
1522
1730
|
y_vals_clipped = np.minimum(y_vals, pr_boundary(x_vals))
|
1523
1731
|
y_vals_clipped = np.clip(y_vals_clipped, 1e-3, None) # Prevent going to zero
|
@@ -1553,7 +1761,7 @@ def plot_pr_binary(
|
|
1553
1761
|
def plot_cm(
|
1554
1762
|
cm,
|
1555
1763
|
labels_name=None,
|
1556
|
-
thresh=0.8,
|
1764
|
+
thresh=0.8, # for set color
|
1557
1765
|
axis_labels=None,
|
1558
1766
|
cmap="Reds",
|
1559
1767
|
normalize=True,
|
@@ -2029,11 +2237,21 @@ def predict(
|
|
2029
2237
|
if isinstance(y_train, str) and y_train in x_train.columns:
|
2030
2238
|
y_train_col_name = y_train
|
2031
2239
|
y_train = x_train[y_train]
|
2032
|
-
y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2240
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
|
2033
2241
|
x_train = x_train.drop(y_train_col_name, axis=1)
|
2242
|
+
# else:
|
2243
|
+
# y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
|
2244
|
+
y_train=pd.DataFrame(y_train)
|
2245
|
+
if y_train.select_dtypes(include=np.number).empty:
|
2246
|
+
y_train_=ips.df_encoder(y_train, method="dummy",drop=None)
|
2247
|
+
is_binary = False if y_train_.shape[1] >2 else True
|
2034
2248
|
else:
|
2035
|
-
|
2249
|
+
y_train_=ips.flatten(y_train.values)
|
2250
|
+
is_binary = False if len(y_train_)>2 else True
|
2036
2251
|
|
2252
|
+
if is_binary:
|
2253
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="label")
|
2254
|
+
print('is_binary:',is_binary)
|
2037
2255
|
if x_true is None:
|
2038
2256
|
x_train, x_true, y_train, y_true = train_test_split(
|
2039
2257
|
x_train,
|
@@ -2042,23 +2260,27 @@ def predict(
|
|
2042
2260
|
random_state=random_state,
|
2043
2261
|
stratify=y_train if purpose == "classification" else None,
|
2044
2262
|
)
|
2263
|
+
|
2045
2264
|
if isinstance(y_train, str) and y_train in x_train.columns:
|
2046
2265
|
y_train_col_name = y_train
|
2047
2266
|
y_train = x_train[y_train]
|
2048
|
-
y_train = ips.df_encoder(pd.DataFrame(y_train), method="
|
2267
|
+
y_train = ips.df_encoder(pd.DataFrame(y_train), method="label") if is_binary else y_train
|
2049
2268
|
x_train = x_train.drop(y_train_col_name, axis=1)
|
2050
|
-
|
2269
|
+
if is_binary:
|
2051
2270
|
y_train = ips.df_encoder(
|
2052
|
-
pd.DataFrame(y_train), method="
|
2053
|
-
).values.ravel()
|
2271
|
+
pd.DataFrame(y_train), method="label"
|
2272
|
+
).values.ravel()
|
2273
|
+
|
2054
2274
|
if y_true is not None:
|
2055
2275
|
if isinstance(y_true, str) and y_true in x_true.columns:
|
2056
2276
|
y_true_col_name = y_true
|
2057
2277
|
y_true = x_true[y_true]
|
2058
|
-
y_true = ips.df_encoder(pd.DataFrame(y_true), method="
|
2278
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="label") if is_binary else y_true
|
2279
|
+
y_true = pd.DataFrame(y_true)
|
2059
2280
|
x_true = x_true.drop(y_true_col_name, axis=1)
|
2060
|
-
|
2061
|
-
y_true = ips.df_encoder(pd.DataFrame(y_true), method="
|
2281
|
+
if is_binary:
|
2282
|
+
y_true = ips.df_encoder(pd.DataFrame(y_true), method="label").values.ravel()
|
2283
|
+
y_true = pd.DataFrame(y_true)
|
2062
2284
|
|
2063
2285
|
# to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
|
2064
2286
|
|
@@ -2068,7 +2290,6 @@ def predict(
|
|
2068
2290
|
y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
|
2069
2291
|
)
|
2070
2292
|
y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
|
2071
|
-
|
2072
2293
|
# Ensure common features are selected
|
2073
2294
|
if common_features is not None:
|
2074
2295
|
x_train, x_true = x_train[common_features], x_true[common_features]
|
@@ -2077,10 +2298,7 @@ def predict(
|
|
2077
2298
|
x_train, x_true = x_train[share_col_names], x_true[share_col_names]
|
2078
2299
|
|
2079
2300
|
x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
|
2080
|
-
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
|
2081
|
-
x_true, method="dummy"
|
2082
|
-
)
|
2083
|
-
|
2301
|
+
x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
|
2084
2302
|
# Handle class imbalance using SMOTE (only for classification)
|
2085
2303
|
if (
|
2086
2304
|
smote
|
@@ -2091,7 +2309,13 @@ def predict(
|
|
2091
2309
|
|
2092
2310
|
smote_sampler = SMOTE(random_state=random_state)
|
2093
2311
|
x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
|
2094
|
-
|
2312
|
+
if not is_binary:
|
2313
|
+
if isinstance(y_train, np.ndarray):
|
2314
|
+
y_train = ips.df_encoder(data=pd.DataFrame(y_train),method='label')
|
2315
|
+
y_train=np.asarray(y_train)
|
2316
|
+
if isinstance(y_train, np.ndarray):
|
2317
|
+
y_true = ips.df_encoder(data=pd.DataFrame(y_true),method='label')
|
2318
|
+
y_true=np.asarray(y_true)
|
2095
2319
|
# Hyperparameter grids for tuning
|
2096
2320
|
if cv_level in ["low", "simple", "s", "l"]:
|
2097
2321
|
param_grids = {
|
@@ -2670,95 +2894,181 @@ def predict(
|
|
2670
2894
|
print(f"\nTraining and validating {name}:")
|
2671
2895
|
|
2672
2896
|
# Grid search with KFold or StratifiedKFold
|
2673
|
-
|
2674
|
-
|
2675
|
-
|
2676
|
-
|
2677
|
-
|
2678
|
-
|
2679
|
-
|
2680
|
-
|
2681
|
-
|
2682
|
-
|
2683
|
-
gs.fit(x_train, y_train)
|
2684
|
-
best_clf = gs.best_estimator_
|
2685
|
-
# make sure x_train and x_test has the same name
|
2686
|
-
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2687
|
-
y_pred = best_clf.predict(x_true)
|
2688
|
-
|
2689
|
-
# y_pred_proba
|
2690
|
-
if hasattr(best_clf, "predict_proba"):
|
2691
|
-
y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
|
2692
|
-
elif hasattr(best_clf, "decision_function"):
|
2693
|
-
# If predict_proba is not available, use decision_function (e.g., for SVM)
|
2694
|
-
y_pred_proba = best_clf.decision_function(x_true)
|
2695
|
-
# Ensure y_pred_proba is within 0 and 1 bounds
|
2696
|
-
y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
|
2697
|
-
y_pred_proba.max() - y_pred_proba.min()
|
2897
|
+
if is_binary:
|
2898
|
+
gs = GridSearchCV(
|
2899
|
+
clf,
|
2900
|
+
param_grid=param_grids.get(name, {}),
|
2901
|
+
scoring=(
|
2902
|
+
"roc_auc" if purpose == "classification" else "neg_mean_squared_error"
|
2903
|
+
),
|
2904
|
+
cv=cv,
|
2905
|
+
n_jobs=n_jobs,
|
2906
|
+
verbose=verbose,
|
2698
2907
|
)
|
2908
|
+
|
2909
|
+
gs.fit(x_train, y_train)
|
2910
|
+
best_clf = gs.best_estimator_
|
2911
|
+
# make sure x_train and x_test has the same name
|
2912
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2913
|
+
y_pred = best_clf.predict(x_true)
|
2914
|
+
if hasattr(best_clf, "predict_proba"):
|
2915
|
+
y_pred_proba = best_clf.predict_proba(x_true)
|
2916
|
+
print("Shape of predicted probabilities:", y_pred_proba.shape)
|
2917
|
+
if y_pred_proba.shape[1] == 1:
|
2918
|
+
y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba]) # Add missing class probabilities
|
2919
|
+
y_pred_proba = y_pred_proba[:, 1]
|
2920
|
+
elif hasattr(best_clf, "decision_function"):
|
2921
|
+
# If predict_proba is not available, use decision_function (e.g., for SVM)
|
2922
|
+
y_pred_proba = best_clf.decision_function(x_true)
|
2923
|
+
# Ensure y_pred_proba is within 0 and 1 bounds
|
2924
|
+
y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
|
2925
|
+
y_pred_proba.max() - y_pred_proba.min()
|
2926
|
+
)
|
2927
|
+
else:
|
2928
|
+
y_pred_proba = None # No probability output for certain models
|
2699
2929
|
else:
|
2700
|
-
|
2930
|
+
gs = GridSearchCV(
|
2931
|
+
clf,
|
2932
|
+
param_grid=param_grids.get(name, {}),
|
2933
|
+
scoring=(
|
2934
|
+
"roc_auc_ovr" if purpose == "classification" else "neg_mean_squared_error"
|
2935
|
+
),
|
2936
|
+
cv=cv,
|
2937
|
+
n_jobs=n_jobs,
|
2938
|
+
verbose=verbose,
|
2939
|
+
)
|
2701
2940
|
|
2941
|
+
# Fit GridSearchCV
|
2942
|
+
gs.fit(x_train, y_train)
|
2943
|
+
best_clf = gs.best_estimator_
|
2944
|
+
|
2945
|
+
# Ensure x_true aligns with x_train columns
|
2946
|
+
x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
|
2947
|
+
y_pred = best_clf.predict(x_true)
|
2948
|
+
|
2949
|
+
# Handle prediction probabilities for multiclass
|
2950
|
+
if hasattr(best_clf, "predict_proba"):
|
2951
|
+
y_pred_proba = best_clf.predict_proba(x_true)
|
2952
|
+
elif hasattr(best_clf, "decision_function"):
|
2953
|
+
y_pred_proba = best_clf.decision_function(x_true)
|
2954
|
+
|
2955
|
+
# Normalize for multiclass if necessary
|
2956
|
+
if y_pred_proba.ndim == 2:
|
2957
|
+
y_pred_proba = (y_pred_proba - y_pred_proba.min(axis=1, keepdims=True)) / \
|
2958
|
+
(y_pred_proba.max(axis=1, keepdims=True) - y_pred_proba.min(axis=1, keepdims=True))
|
2959
|
+
else:
|
2960
|
+
y_pred_proba = None # No probability output for certain models
|
2961
|
+
|
2702
2962
|
validation_scores = {}
|
2703
|
-
|
2963
|
+
|
2964
|
+
if y_true is not None and y_pred_proba is not None:
|
2704
2965
|
validation_scores = cal_metrics(
|
2705
2966
|
y_true,
|
2706
2967
|
y_pred,
|
2707
2968
|
y_pred_proba=y_pred_proba,
|
2969
|
+
is_binary=is_binary,
|
2708
2970
|
purpose=purpose,
|
2709
2971
|
average="weighted",
|
2710
2972
|
)
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2715
|
-
|
2716
|
-
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
2720
|
-
|
2721
|
-
|
2722
|
-
|
2723
|
-
|
2724
|
-
|
2725
|
-
|
2726
|
-
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2736
|
-
|
2737
|
-
|
2738
|
-
|
2739
|
-
|
2740
|
-
|
2741
|
-
|
2742
|
-
|
2743
|
-
|
2744
|
-
|
2745
|
-
|
2746
|
-
|
2747
|
-
|
2748
|
-
|
2749
|
-
|
2750
|
-
|
2751
|
-
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2973
|
+
if is_binary:
|
2974
|
+
# Calculate ROC curve
|
2975
|
+
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
|
2976
|
+
if y_pred_proba is not None:
|
2977
|
+
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
2978
|
+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
2979
|
+
lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
|
2980
|
+
roc_auc = auc(fpr, tpr)
|
2981
|
+
roc_info = {
|
2982
|
+
"fpr": fpr.tolist(),
|
2983
|
+
"tpr": tpr.tolist(),
|
2984
|
+
"auc": roc_auc,
|
2985
|
+
"ci95": (lower_ci, upper_ci),
|
2986
|
+
}
|
2987
|
+
# precision-recall curve
|
2988
|
+
precision_, recall_, _ = cal_precision_recall(y_true, y_pred_proba)
|
2989
|
+
avg_precision_ = average_precision_score(y_true, y_pred_proba)
|
2990
|
+
pr_info = {
|
2991
|
+
"precision": precision_,
|
2992
|
+
"recall": recall_,
|
2993
|
+
"avg_precision": avg_precision_,
|
2994
|
+
}
|
2995
|
+
else:
|
2996
|
+
roc_info, pr_info = None, None
|
2997
|
+
if purpose == "classification":
|
2998
|
+
results[name] = {
|
2999
|
+
"best_clf": gs.best_estimator_,
|
3000
|
+
"best_params": gs.best_params_,
|
3001
|
+
"auc_indiv": [
|
3002
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
3003
|
+
for i in range(cv_folds)
|
3004
|
+
],
|
3005
|
+
"scores": validation_scores,
|
3006
|
+
"roc_curve": roc_info,
|
3007
|
+
"pr_curve": pr_info,
|
3008
|
+
"confusion_matrix": confusion_matrix(y_true, y_pred),
|
3009
|
+
"predictions": y_pred.tolist(),
|
3010
|
+
"predictions_proba": (
|
3011
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3012
|
+
),
|
3013
|
+
}
|
3014
|
+
else: # "regression"
|
3015
|
+
results[name] = {
|
3016
|
+
"best_clf": gs.best_estimator_,
|
3017
|
+
"best_params": gs.best_params_,
|
3018
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
3019
|
+
"predictions": y_pred.tolist(),
|
3020
|
+
"predictions_proba": (
|
3021
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3022
|
+
),
|
3023
|
+
}
|
3024
|
+
else: # multi-classes
|
3025
|
+
if y_pred_proba is not None:
|
3026
|
+
# fpr, tpr, roc_auc = dict(), dict(), dict()
|
3027
|
+
# fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
3028
|
+
confidence_intervals = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
|
3029
|
+
roc_info = {
|
3030
|
+
"fpr": validation_scores["fpr"],
|
3031
|
+
"tpr": validation_scores["tpr"],
|
3032
|
+
"auc": validation_scores["roc_auc_by_class"],
|
3033
|
+
"ci95": confidence_intervals,
|
3034
|
+
}
|
3035
|
+
# precision-recall curve
|
3036
|
+
precision_, recall_, avg_precision_ = cal_precision_recall(y_true, y_pred_proba,is_binary=is_binary)
|
3037
|
+
pr_info = {
|
3038
|
+
"precision": precision_,
|
3039
|
+
"recall": recall_,
|
3040
|
+
"avg_precision": avg_precision_,
|
3041
|
+
}
|
3042
|
+
else:
|
3043
|
+
roc_info, pr_info = None, None
|
3044
|
+
|
3045
|
+
if purpose == "classification":
|
3046
|
+
results[name] = {
|
3047
|
+
"best_clf": gs.best_estimator_,
|
3048
|
+
"best_params": gs.best_params_,
|
3049
|
+
"auc_indiv": [
|
3050
|
+
gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
|
3051
|
+
for i in range(cv_folds)
|
3052
|
+
],
|
3053
|
+
"scores": validation_scores,
|
3054
|
+
"roc_curve": roc_info,
|
3055
|
+
"pr_curve": pr_info,
|
3056
|
+
"confusion_matrix": confusion_matrix(y_true, y_pred),
|
3057
|
+
"predictions": y_pred.tolist(),
|
3058
|
+
"predictions_proba": (
|
3059
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3060
|
+
),
|
3061
|
+
}
|
3062
|
+
else: # "regression"
|
3063
|
+
results[name] = {
|
3064
|
+
"best_clf": gs.best_estimator_,
|
3065
|
+
"best_params": gs.best_params_,
|
3066
|
+
"scores": validation_scores, # e.g., neg_MSE, R², etc.
|
3067
|
+
"predictions": y_pred.tolist(),
|
3068
|
+
"predictions_proba": (
|
3069
|
+
y_pred_proba.tolist() if y_pred_proba is not None else None
|
3070
|
+
),
|
3071
|
+
}
|
2762
3072
|
|
2763
3073
|
else:
|
2764
3074
|
results[name] = {
|
@@ -2773,7 +3083,6 @@ def predict(
|
|
2773
3083
|
|
2774
3084
|
# Convert results to DataFrame
|
2775
3085
|
df_results = pd.DataFrame.from_dict(results, orient="index")
|
2776
|
-
|
2777
3086
|
# sort
|
2778
3087
|
if y_true is not None and purpose == "classification":
|
2779
3088
|
df_scores = pd.DataFrame(
|
@@ -2790,26 +3099,29 @@ def predict(
|
|
2790
3099
|
plot.figsets(xangle=30)
|
2791
3100
|
if dir_save:
|
2792
3101
|
ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
|
3102
|
+
|
3103
|
+
df_scores=df_scores.select_dtypes(include=np.number)
|
3104
|
+
|
2793
3105
|
if df_scores.shape[0] > 1: # draw cluster
|
2794
3106
|
plot.heatmap(df_scores, kind="direct", cluster=True)
|
2795
3107
|
plot.figsets(xangle=30)
|
2796
3108
|
if dir_save:
|
2797
3109
|
ips.figsave(dir_save + f"scores_clus{now_}.pdf")
|
2798
3110
|
if all([plot_, y_true is not None, purpose == "classification"]):
|
2799
|
-
try:
|
2800
|
-
|
2801
|
-
|
2802
|
-
|
2803
|
-
|
2804
|
-
|
2805
|
-
|
2806
|
-
except Exception as e:
|
2807
|
-
|
3111
|
+
# try:
|
3112
|
+
if len(models) > 3:
|
3113
|
+
plot_validate_features(df_results,is_binary=is_binary)
|
3114
|
+
else:
|
3115
|
+
plot_validate_features_single(df_results, is_binary=is_binary)
|
3116
|
+
if dir_save:
|
3117
|
+
ips.figsave(dir_save + f"validate_features{now_}.pdf")
|
3118
|
+
# except Exception as e:
|
3119
|
+
# print(f"Error: 在画图的过程中出现了问题:{e}")
|
2808
3120
|
return df_results
|
2809
3121
|
|
2810
3122
|
|
2811
3123
|
def cal_metrics(
|
2812
|
-
y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
|
3124
|
+
y_true, y_pred, y_pred_proba=None, is_binary=True,purpose="regression", average="weighted"
|
2813
3125
|
):
|
2814
3126
|
"""
|
2815
3127
|
Calculate regression or classification metrics based on the purpose.
|
@@ -2879,19 +3191,362 @@ def cal_metrics(
|
|
2879
3191
|
}
|
2880
3192
|
|
2881
3193
|
# Confusion matrix to calculate specificity
|
2882
|
-
|
2883
|
-
|
2884
|
-
|
2885
|
-
|
3194
|
+
if is_binary:
|
3195
|
+
cm = confusion_matrix(y_true, y_pred)
|
3196
|
+
if cm.size == 4:
|
3197
|
+
tn, fp, fn, tp = cm.ravel()
|
3198
|
+
else:
|
3199
|
+
# Handle single-class predictions
|
3200
|
+
tn, fp, fn, tp = 0, 0, 0, 0
|
3201
|
+
print("Warning: Only one class found in y_pred or y_true.")
|
3202
|
+
|
3203
|
+
# Specificity calculation
|
3204
|
+
validation_scores["specificity"] = (
|
3205
|
+
tn / (tn + fp) if (tn + fp) > 0 else 0
|
3206
|
+
)
|
3207
|
+
if y_pred_proba is not None:
|
3208
|
+
# Calculate ROC-AUC
|
3209
|
+
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
3210
|
+
# PR-AUC (Precision-Recall AUC) calculation
|
3211
|
+
validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
|
3212
|
+
|
3213
|
+
else: # multi-class
|
3214
|
+
from sklearn.preprocessing import label_binarize
|
3215
|
+
#* Multi-class ROC calculation
|
3216
|
+
y_pred_proba = np.asarray(y_pred_proba)
|
3217
|
+
classes = np.unique(y_true)
|
3218
|
+
y_true_bin = label_binarize(y_true, classes=classes)
|
3219
|
+
if isinstance(y_true, np.ndarray):
|
3220
|
+
y_true = ips.df_encoder(data=pd.DataFrame(y_true), method='dum',prefix='Label')
|
3221
|
+
# Initialize dictionaries to store FPR, TPR, and AUC for each class
|
3222
|
+
fpr = dict()
|
3223
|
+
tpr = dict()
|
3224
|
+
roc_auc = dict()
|
3225
|
+
for i, class_label in enumerate(classes):
|
3226
|
+
fpr[class_label], tpr[class_label], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
|
3227
|
+
roc_auc[class_label] = auc(fpr[class_label], tpr[class_label])
|
3228
|
+
|
3229
|
+
# Store the mean ROC AUC
|
3230
|
+
try:
|
3231
|
+
validation_scores["roc_auc"] = roc_auc_score(
|
3232
|
+
y_true, y_pred_proba, multi_class="ovr", average=average
|
3233
|
+
)
|
3234
|
+
except Exception as e:
|
3235
|
+
y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)
|
3236
|
+
validation_scores["roc_auc"] = roc_auc_score(
|
3237
|
+
y_true, y_pred_proba, multi_class="ovr", average=average
|
3238
|
+
)
|
3239
|
+
|
3240
|
+
validation_scores["roc_auc_by_class"] = roc_auc # Individual class AUCs
|
3241
|
+
validation_scores["fpr"] = fpr
|
3242
|
+
validation_scores["tpr"] = tpr
|
2886
3243
|
|
2887
|
-
if y_pred_proba is not None:
|
2888
|
-
# Calculate ROC-AUC
|
2889
|
-
validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
|
2890
|
-
# PR-AUC (Precision-Recall AUC) calculation
|
2891
|
-
validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
|
2892
3244
|
else:
|
2893
3245
|
raise ValueError(
|
2894
3246
|
"Invalid purpose specified. Choose 'regression' or 'classification'."
|
2895
3247
|
)
|
2896
3248
|
|
2897
3249
|
return validation_scores
|
3250
|
+
|
3251
|
+
def plot_trees(
    X, y, cls, max_trees=500, test_size=0.2, random_state=42, early_stopping_rounds=None
):
    """
    Plot error rates against ensemble size for a tree-based ensemble classifier.

    The data is split once into a train and a test part. The classifier is then
    fitted repeatedly with ``n_estimators`` = 1, 2, ..., ``max_trees`` and the
    training error, testing error and (where supported) out-of-bag error are
    recorded after every fit and drawn on a single figure.

    Note: ``cls`` is modified in place (``set_params`` and ``fit`` are called on
    the instance that was passed in).

    Parameters:
    - X (array-like): Feature matrix.
    - y (array-like): Target labels.
    - cls (object): Tree-based ensemble classifier instance (e.g., RandomForestClassifier()).
    - max_trees (int): Maximum number of trees to evaluate. Default is 500.
    - test_size (float): Proportion of data held out for the testing error. Default is 0.2.
    - random_state (int): Random state for the train/test split. Default is 42.
    - early_stopping_rounds (int): For AdaBoost / GradientBoosting only. Training stops
      once the validation error of the last ``early_stopping_rounds`` fits is
      non-decreasing. The validation error is measured on the same held-out split
      as the testing error.

    Returns:
    - None

    Example:
        X = np.random.rand(100, 10)       # 100 samples, 10 features
        y = np.random.randint(0, 2, 100)  # binary target
        plot_trees(X, y, RandomForestClassifier(), max_trees=100)
        plot_trees(X, y, GradientBoostingClassifier(), max_trees=100, early_stopping_rounds=10)
        plot_trees(X, y, ExtraTreesClassifier(), max_trees=100)
    """
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    from sklearn.ensemble import (
        RandomForestClassifier,
        BaggingClassifier,
        ExtraTreesClassifier,
    )
    from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier

    boosting_types = (AdaBoostClassifier, GradientBoostingClassifier)

    # Split data for training and testing error calculation
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Error rates collected after each fit
    oob_error_rate = []
    train_error_rate = []
    test_error_rate = []
    validation_error = None  # becomes a list only for boosting + early stopping

    # Configure classifier based on type
    oob_enabled = False  # Default to no OOB error unless explicitly set
    if isinstance(cls, (RandomForestClassifier, ExtraTreesClassifier)):
        # Forests can grow incrementally (warm_start) while still reporting OOB scores
        cls.set_params(warm_start=True, n_estimators=1)
        if hasattr(cls, "oob_score"):
            cls.set_params(bootstrap=True, oob_score=True)
            oob_enabled = True
    elif isinstance(cls, BaggingClassifier):
        # Bagging refuses oob_score=True together with warm_start=True at fit
        # time, so the ensemble is refitted from scratch for every size.
        cls.set_params(warm_start=False, bootstrap=True, oob_score=True, n_estimators=1)
        oob_enabled = True
    elif isinstance(cls, boosting_types):
        cls.set_params(n_estimators=1)
        if early_stopping_rounds:
            validation_error = []

    # Train and evaluate with an increasing number of trees
    for n_trees in range(1, max_trees + 1):
        cls.set_params(n_estimators=n_trees)
        cls.fit(x_train, y_train)

        # OOB error (only for models configured for it above)
        if oob_enabled and hasattr(cls, "oob_score_") and cls.oob_score:
            oob_error_rate.append(1 - cls.oob_score_)

        train_error_rate.append(1 - accuracy_score(y_train, cls.predict(x_train)))

        test_error = 1 - accuracy_score(y_test, cls.predict(x_test))
        test_error_rate.append(test_error)

        if validation_error is not None:
            # Validation uses the same held-out split, so reuse the value
            # instead of predicting on x_test a second time.
            validation_error.append(test_error)
            if len(validation_error) > early_stopping_rounds:
                recent = validation_error[-early_stopping_rounds:]
                # A non-decreasing tail means no improvement in the last rounds
                if recent == sorted(recent):
                    print(f"Early stopping at tree {n_trees} due to lack of improvement in validation error.")
                    break

    # Plot results (uses the module-level `plt`)
    plt.figure(figsize=(10, 6))
    curves = (
        (
            oob_error_rate,
            {"color": "black", "label": "OOB Error Rate", "linewidth": 2},
        ),
        (
            train_error_rate,
            {"linestyle": "dotted", "color": "green", "label": "Training Error Rate"},
        ),
        (
            test_error_rate,
            {"linestyle": "dashed", "color": "red", "label": "Testing Error Rate"},
        ),
        (
            validation_error,
            {"linestyle": "solid", "color": "blue", "label": "Validation Error (Boosting)"},
        ),
    )
    for values, line_style in curves:
        if values:  # skips both None and empty series
            plt.plot(range(1, len(values) + 1), values, **line_style)

    # Customize plot
    plt.xlabel("Number of Trees")
    plt.ylabel("Error Rate")
    plt.title(f"Error Rate Analysis for {cls.__class__.__name__}")
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.show()
|
3390
|
+
|
3391
|
+
def img_datasets_preprocessing(
    data: pd.DataFrame,
    x_col: str,
    y_col: str = None,
    target_size: tuple = (224, 224),
    batch_size: int = 128,
    class_mode: str = "raw",
    shuffle: bool = False,
    augment: bool = False,
    scaler: str = 'normalize',  # 'normalize', 'standardize', 'clahe', 'raw'
    grayscale: bool = False,
    encoder: str = "label",  # Options: 'label', 'onehot', 'binary'
    label_encoder=None,
    kws_augmentation: dict = None,
    verbose: bool = True,
    drop_missing: bool = True,
    output="df",  # "iterator":data_iterator,'df':return DataFrame
):
    """
    Load images listed in a DataFrame and return them as a Keras iterator or as a
    DataFrame of flattened pixels.

    Parameters:
    - data (pd.DataFrame): Input DataFrame with image paths and, optionally, labels.
      The caller's frame is not modified; label encoding works on a copy.
    - x_col (str): Column in `data` containing image file paths.
    - y_col (str): Column in `data` containing image labels. When None, `class_mode`
      is forced to None and the DataFrame output has no 'Label' column.
    - target_size (tuple): Desired image size in (height, width).
    - batch_size (int): Number of images per batch.
    - class_mode (str): Passed to `flow_from_dataframe` (e.g. 'raw', 'categorical', 'binary').
    - shuffle (bool): Shuffle the images in the DataFrame.
    - augment (bool): Apply data augmentation.
    - scaler (str): 'normalize' rescales by 1/255 inside the generator; 'standardize'
      and 'clahe' are applied per image when building the DataFrame output; any
      other value (e.g. 'raw') leaves pixel values untouched.
    - grayscale (bool): Load images as grayscale instead of RGB.
    - encoder (str): Label encoder method ('label', 'onehot', 'binary').
    - label_encoder: Optional pre-defined label encoder (it is re-fitted on `data[y_col]`).
    - kws_augmentation (dict): Overrides for the default augmentation parameters.
    - verbose (bool): Print a summary of the resulting DataFrame.
    - drop_missing (bool): Drop rows whose image path does not point to an existing file.
    - output (str): 'generator'/'tf'/'iterator'/'transform'/'transformer' return the Keras
      iterator; 'dataframe'/'df'/'pd'/'pandas' return a DataFrame. The value is
      resolved through `ips.strcmp` (presumably a closest-match lookup - confirm).

    Returns:
    - The Keras DataFrameIterator, or a pd.DataFrame with one `pixel_<i>` column per
      flattened pixel value plus a 'Label' column when labels are available.
    """
    import math
    import os
    from itertools import islice

    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.utils import to_categorical
    from sklearn.preprocessing import LabelEncoder

    # Validate input DataFrame for required columns
    if y_col:
        assert (
            x_col in data.columns and y_col in data.columns
        ), "Missing required columns in DataFrame."
    if y_col is None:
        class_mode = None

    # Output format
    output = ips.strcmp(output, [
        "generator", "tf", "iterator", "transform", "transformer", "dataframe",
        "df", "pd", "pandas"])[0]

    # Handle missing file paths
    if drop_missing:
        data = data[
            data[x_col].apply(lambda path: os.path.exists(path) and os.path.isfile(path))
        ]

    # Encoding labels if necessary
    if encoder and y_col is not None:
        # Work on a copy: avoids writing into the caller's frame (or into a
        # filtered slice of it, which triggers SettingWithCopyWarning).
        data = data.copy()
        if encoder == "binary":
            # The first value encountered becomes class 1, everything else 0
            data[y_col] = (data[y_col] == data[y_col].unique()[0]).astype(int)
        elif encoder == "onehot":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])
            # NOTE(review): to_categorical returns a 2D array; assigning it to a
            # single column looks like it cannot work for more than one class -
            # confirm the intended storage for one-hot labels.
            data[y_col] = to_categorical(data[y_col])
        elif encoder == "label":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])

    # Only 'normalize' is handled by the generator; other scalers run per image below
    rescale = 1.0 / 255 if scaler == 'normalize' else None

    # Set up data augmentation
    aug_params = {}
    if augment:
        aug_params = {
            "rotation_range": 20,
            "width_shift_range": 0.2,
            "height_shift_range": 0.2,
            "shear_range": 0.2,
            "zoom_range": 0.2,
            "horizontal_flip": True,
            "fill_mode": "nearest",
        }
        if kws_augmentation:
            aug_params.update(kws_augmentation)
    dat = ImageDataGenerator(rescale=rescale, **aug_params)

    # Create DataFrameIterator
    data_iterator = dat.flow_from_dataframe(
        dataframe=data,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        color_mode="grayscale" if grayscale else "rgb",
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
    )
    print(f"target_size:{target_size}")
    if output.lower() in ["generator", "tf", "iterator", "transform", "transformer"]:
        return data_iterator
    elif output.lower() in ["dataframe", "df", "pd", "pandas"]:
        # With class_mode=None the iterator yields image batches only
        has_labels = class_mode is not None

        # The iterator cycles forever, so take exactly one pass over the data:
        # ceil(n / batch_size) batches, with the last one possibly partial.
        total_batches = math.ceil(data_iterator.n / batch_size)

        data_list = []
        for batch in islice(data_iterator, total_batches):
            if has_labels:
                batch_images, batch_labels = batch
            else:
                batch_images, batch_labels = batch, None

            for idx, img in enumerate(batch_images):
                if scaler == 'standardize':
                    # Zero mean / unit variance per image; a constant image has
                    # std 0 and is only centred to avoid dividing by zero.
                    std = np.std(img)
                    img = img - np.mean(img)
                    if std > 0:
                        img = img / std
                elif scaler == 'clahe':
                    # Apply CLAHE to the image
                    img = apply_clahe(img)
                # 'normalize' was already applied by the generator; 'raw' is untouched

                flat_img = img.flatten()
                if has_labels:
                    flat_img = np.append(flat_img, batch_labels[idx])
                data_list.append(flat_img)

        # Define column names for flattened image data
        pixel_count = target_size[0] * target_size[1] * (1 if grayscale else 3)
        column_names = [f"pixel_{i}" for i in range(pixel_count)]
        if has_labels:
            column_names.append("Label")

        # Create DataFrame from flattened data
        df_img = pd.DataFrame(data_list, columns=column_names)

        if verbose:
            print("Processed images:", len(df_img))
            print("Final DataFrame shape:", df_img.shape)
            display(df_img.head())

        return df_img
|
3543
|
+
# Function to apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
|
3544
|
+
def apply_clahe(img):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    RGB images are converted to LAB, equalized on the L (lightness) channel only,
    and converted back to RGB. Single-channel images (shape (H, W) or (H, W, 1))
    are equalized directly and returned with their original shape.

    Parameters:
    - img (array-like): Image as (H, W, 3) RGB, (H, W, 1) or (H, W). uint8 input is
      used as is. Any other dtype is assumed to hold values on the 0-255 scale
      (as produced by the Keras generator without rescaling) and is rounded and
      clipped to uint8, because CLAHE only accepts 8/16-bit single-channel data.

    Returns:
    - np.ndarray: uint8 image with the same shape as the input.
    """
    import cv2
    import numpy as np

    pixels = np.asarray(img)
    if pixels.dtype != np.uint8:
        # Float batches from the image generator would be rejected by CLAHE.apply
        pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    pixels = np.ascontiguousarray(pixels)

    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

    # Grayscale: there are no colour channels to protect, equalize the plane itself
    if pixels.ndim == 2 or pixels.shape[-1] == 1:
        plane = pixels.reshape(pixels.shape[:2])
        return clahe.apply(plane).reshape(pixels.shape)

    lab = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)  # Convert to LAB color space
    lightness, a_chan, b_chan = cv2.split(lab)  # Split into channels
    equalized = clahe.apply(lightness)  # Apply CLAHE to the L channel only
    merged = cv2.merge((equalized, a_chan, b_chan))  # Merge back the channels
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)  # Convert back to RGB
|