py2ls 0.2.4.14__py3-none-any.whl → 0.2.4.16__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/ml2ls.py CHANGED
@@ -506,7 +506,7 @@ def get_models(
506
506
  "Support Vector Machine(svm)",
507
507
  "naive bayes",
508
508
  "Linear Discriminant Analysis (lda)",
509
- "adaboost",
509
+ "AdaBoost",
510
510
  "DecisionTree",
511
511
  "KNeighbors",
512
512
  "Bagging",
@@ -585,7 +585,7 @@ def get_features(
585
585
  "Support Vector Machine(svm)",
586
586
  "naive bayes",
587
587
  "Linear Discriminant Analysis (lda)",
588
- "adaboost",
588
+ "AdaBoost",
589
589
  "DecisionTree",
590
590
  "KNeighbors",
591
591
  "Bagging",
@@ -616,10 +616,10 @@ def get_features(
616
616
  if isinstance(y, str) and y in X.columns:
617
617
  y_col_name = y
618
618
  y = X[y]
619
- y = ips.df_encoder(pd.DataFrame(y), method="dummy")
619
+ y = ips.df_encoder(pd.DataFrame(y), method="label")
620
620
  X = X.drop(y_col_name, axis=1)
621
621
  else:
622
- y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
622
+ y = ips.df_encoder(pd.DataFrame(y), method="label").values.ravel()
623
623
  y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
624
624
  y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
625
625
 
@@ -699,9 +699,11 @@ def get_features(
699
699
  "Support Vector Machine(svm)",
700
700
  "Naive Bayes",
701
701
  "Linear Discriminant Analysis (lda)",
702
- "adaboost",
702
+ "AdaBoost",
703
703
  ]
704
704
  cls = [ips.strcmp(i, cls_)[0] for i in cls]
705
+
706
+ feature_importances = {}
705
707
 
706
708
  # Lasso Feature Selection
707
709
  lasso_importances = (
@@ -712,6 +714,7 @@ def get_features(
712
714
  lasso_selected_features = (
713
715
  lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
714
716
  )
717
+ feature_importances['lasso']=lasso_importances.head(n_features)
715
718
  # Ridge
716
719
  ridge_importances = (
717
720
  features_ridge(x_train, y_train, ridge_params)
@@ -721,6 +724,7 @@ def get_features(
721
724
  selected_ridge_features = (
722
725
  ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
723
726
  )
727
+ feature_importances['ridge']=ridge_importances.head(n_features)
724
728
  # Elastic Net
725
729
  enet_importances = (
726
730
  features_enet(x_train, y_train, enet_params)
@@ -730,6 +734,7 @@ def get_features(
730
734
  selected_enet_features = (
731
735
  enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
732
736
  )
737
+ feature_importances['Enet']=enet_importances.head(n_features)
733
738
  # Random Forest Feature Importance
734
739
  rf_importances = (
735
740
  features_rf(x_train, y_train, rf_params)
@@ -741,6 +746,7 @@ def get_features(
741
746
  if "Random Forest" in cls
742
747
  else []
743
748
  )
749
+ feature_importances['Random Forest']=rf_importances.head(n_features)
744
750
  # Gradient Boosting Feature Importance
745
751
  gb_importances = (
746
752
  features_gradient_boosting(x_train, y_train, gb_params)
@@ -752,6 +758,7 @@ def get_features(
752
758
  if "Gradient Boosting" in cls
753
759
  else []
754
760
  )
761
+ feature_importances['Gradient Boosting']=gb_importances.head(n_features)
755
762
  # xgb
756
763
  xgb_importances = (
757
764
  features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
@@ -759,6 +766,7 @@ def get_features(
759
766
  top_xgb_features = (
760
767
  xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
761
768
  )
769
+ feature_importances['xgb']=xgb_importances.head(n_features)
762
770
 
763
771
  # SVM with RFE
764
772
  selected_svm_features = (
@@ -773,6 +781,7 @@ def get_features(
773
781
  selected_lda_features = (
774
782
  lda_importances.head(n_features)["feature"].values if "lda" in cls else []
775
783
  )
784
+ feature_importances['lda']=lda_importances.head(n_features)
776
785
  # AdaBoost Feature Importance
777
786
  adaboost_importances = (
778
787
  features_adaboost(x_train, y_train, adaboost_params)
@@ -784,6 +793,7 @@ def get_features(
784
793
  if "AdaBoost" in cls
785
794
  else []
786
795
  )
796
+ feature_importances['AdaBoost']=adaboost_importances.head(n_features)
787
797
  # Decision Tree Feature Importance
788
798
  dt_importances = (
789
799
  features_decision_tree(x_train, y_train, dt_params)
@@ -794,7 +804,8 @@ def get_features(
794
804
  dt_importances.head(n_features)["feature"].values
795
805
  if "Decision Tree" in cls
796
806
  else []
797
- )
807
+ )
808
+ feature_importances['Decision Tree']=dt_importances.head(n_features)
798
809
  # Bagging Feature Importance
799
810
  bagging_importances = (
800
811
  features_bagging(x_train, y_train, bagging_params)
@@ -806,6 +817,7 @@ def get_features(
806
817
  if "Bagging" in cls
807
818
  else []
808
819
  )
820
+ feature_importances['Bagging']=bagging_importances.head(n_features)
809
821
  # KNN Feature Importance via Permutation
810
822
  knn_importances = (
811
823
  features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
@@ -813,6 +825,7 @@ def get_features(
813
825
  top_knn_features = (
814
826
  knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
815
827
  )
828
+ feature_importances['KNN']=knn_importances.head(n_features)
816
829
 
817
830
  #! Find common features
818
831
  common_features = ips.shared(
@@ -915,6 +928,7 @@ def get_features(
915
928
  "cv_train_scores": cv_train_results_df,
916
929
  "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
917
930
  "common_features": list(common_features),
931
+ "feature_importances":feature_importances
918
932
  }
919
933
  if all([plot_, dir_save]):
920
934
  from datetime import datetime
@@ -927,6 +941,7 @@ def get_features(
927
941
  "cv_train_scores": pd.DataFrame(),
928
942
  "cv_test_scores": pd.DataFrame(),
929
943
  "common_features": [],
944
+ "feature_importances":{}
930
945
  }
931
946
  print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
932
947
  return results
@@ -1217,142 +1232,335 @@ def validate_features(
1217
1232
 
1218
1233
  # # If you want to access validation scores
1219
1234
  # print(validation_results)
def plot_validate_features(res_val, is_binary=True, figsize=None):
    """
    plot the results of 'validate_features()'

    Draws one ROC panel and one precision-recall panel with every model
    overlaid (binary case), or one ROC/PR panel pair per class (multi-class
    case).

    Args:
        res_val: result table of ``validate_features()``; read through the
            "roc_curve" and "pr_curve" columns, whose index holds the model
            names.
        is_binary: True when the per-model curve entries are single curves;
            False when "fpr"/"tpr"/"auc" are keyed by class and
            "ci95"/"recall"/"precision"/"avg_precision" are indexed by class
            position.
        figsize: figure size handed to ``plot.subplot``; when None a default
            is derived from the number of models (and classes).
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    colors = plot.get_color(len(model_names))

    # More than five models: drop the CI shading and use a two-column legend
    # so the curves stay readable.
    crowded = res_val.shape[0] > 5
    alpha = 0 if crowded else 0.03
    subplot_layout = [1, 2] if crowded else [1, 1]
    ncols = 2 if crowded else 1

    if is_binary:
        if figsize is None:
            figsize = [8, 10] if crowded else [10, 6]
        nexttile = plot.subplot(figsize=figsize)

        # ROC panel: all models on one axis.
        ax = nexttile(subplot_layout[0], subplot_layout[1])
        for i, model_name in enumerate(model_names):
            fpr = res_val["roc_curve"][model_name]["fpr"]
            tpr = res_val["roc_curve"][model_name]["tpr"]
            (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
            mean_auc = res_val["roc_curve"][model_name]["auc"]
            plot_roc_curve(
                fpr,
                tpr,
                mean_auc,
                lower_ci,
                upper_ci,
                model_name=model_name,
                lw=1.5,
                color=colors[i],
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            legend=dict(
                loc="upper right",
                ncols=ncols,
                fontsize=8,
                bbox_to_anchor=[1.5, 0.6],
                markerscale=0.8,
            ),
        )

        # Precision-recall panel: all models on one axis.
        ax = nexttile(subplot_layout[0], subplot_layout[1])
        for i, model_name in enumerate(model_names):
            plot_pr_curve(
                recall=res_val["pr_curve"][model_name]["recall"],
                precision=res_val["pr_curve"][model_name]["precision"],
                avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
                model_name=model_name,
                color=colors[i],
                lw=1.5,
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
        )
        return

    # Multi-class: the class labels are taken from the first model's ROC entry.
    modname_tmp = ips.flatten(res_val["roc_curve"].index)[0]
    classes = list(res_val["roc_curve"][modname_tmp]["fpr"].keys())
    if figsize is None:
        figsize = [8, 8 * 2 * (len(classes))] if crowded else [10, 6 * (len(classes))]
    nexttile = plot.subplot(2 * (len(classes)), 2, figsize=figsize)

    for iclass, class_ in enumerate(classes):
        # ROC panel for this class, all models overlaid.
        ax = nexttile(subplot_layout[0], subplot_layout[1])
        for i, model_name in enumerate(model_names):
            fpr = res_val["roc_curve"][model_name]["fpr"][class_]
            tpr = res_val["roc_curve"][model_name]["tpr"][class_]
            (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
            mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
            plot_roc_curve(
                fpr,
                tpr,
                mean_auc,
                lower_ci,
                upper_ci,
                model_name=model_name,
                lw=1.5,
                color=colors[i],
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=class_,
            legend=dict(
                loc="upper right",
                ncols=ncols,
                fontsize=8,
                bbox_to_anchor=[1.5, 0.6],
                markerscale=0.8,
            ),
        )

        # Precision-recall panel for this class, all models overlaid.
        ax = nexttile(subplot_layout[0], subplot_layout[1])
        for i, model_name in enumerate(model_names):
            plot_pr_curve(
                recall=res_val["pr_curve"][model_name]["recall"][iclass],
                precision=res_val["pr_curve"][model_name]["precision"][iclass],
                avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
                model_name=model_name,
                color=colors[i],
                lw=1.5,
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=class_,
            legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
        )
1285
1367
 
def plot_validate_features_single(res_val, figsize=None, is_binary=True):
    """
    Plot one row of three panels (ROC, precision-recall, confusion matrix)
    for every model in the results of 'validate_features()'.

    Args:
        res_val: result table of ``validate_features()``; read through the
            "roc_curve", "pr_curve" and "confusion_matrix" columns, whose
            index holds the model names.
        figsize: figure size handed to ``plot.subplot``; when None it scales
            with the number of models.
        is_binary: True when each model stores a single ROC/PR curve; False
            when "fpr"/"tpr"/"auc" are keyed by class and
            "ci95"/"recall"/"precision"/"avg_precision" are indexed by class
            position.
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    n_models = len(model_names)

    if is_binary:
        if figsize is None:
            figsize = [13, 4 * n_models]
        nexttile = plot.subplot(n_models, 3, figsize=figsize)
        for model_name in model_names:
            fpr = res_val["roc_curve"][model_name]["fpr"]
            tpr = res_val["roc_curve"][model_name]["tpr"]
            (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
            mean_auc = res_val["roc_curve"][model_name]["auc"]

            # ROC panel
            plot_roc_curve(
                fpr, tpr, mean_auc, lower_ci, upper_ci, model_name=model_name, ax=nexttile()
            )
            plot.figsets(title=model_name, sp=2)

            # Precision-recall panel
            plot_pr_binary(
                recall=res_val["pr_curve"][model_name]["recall"],
                precision=res_val["pr_curve"][model_name]["precision"],
                avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
                model_name=model_name,
                ax=nexttile(),
            )
            plot.figsets(title=model_name, sp=2)

            # Confusion-matrix panel
            plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
            plot.figsets(title=model_name, sp=2)
        return

    # Multi-class: the class labels are taken from the first model's ROC entry.
    modname_tmp = ips.flatten(res_val["roc_curve"].index)[0]
    classes = list(res_val["roc_curve"][modname_tmp]["fpr"].keys())

    # One row per model. The grid was previously sized with len(modname_tmp),
    # i.e. the number of characters in the first model's name.
    if figsize is None:
        figsize = [15, n_models * 5]
    nexttile = plot.subplot(n_models, 3, figsize=figsize)
    colors = plot.get_color(len(classes))

    for model_name in model_names:
        # ROC panel: one curve per class.
        ax = nexttile()
        for iclass, class_ in enumerate(classes):
            fpr = res_val["roc_curve"][model_name]["fpr"][class_]
            tpr = res_val["roc_curve"][model_name]["tpr"][class_]
            (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"][iclass]
            mean_auc = res_val["roc_curve"][model_name]["auc"][class_]
            plot_roc_curve(
                fpr,
                tpr,
                mean_auc,
                lower_ci,
                upper_ci,
                model_name=class_,
                lw=1.5,
                color=colors[iclass],
                alpha=0.03,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=model_name,
            legend=dict(
                loc="best",
                fontsize=8,
            ),
        )

        # Precision-recall panel: one curve per class.
        ax = nexttile()
        for iclass, class_ in enumerate(classes):
            plot_pr_curve(
                recall=res_val["pr_curve"][model_name]["recall"][iclass],
                precision=res_val["pr_curve"][model_name]["precision"][iclass],
                avg_precision=res_val["pr_curve"][model_name]["avg_precision"][iclass],
                model_name=class_,
                color=colors[iclass],
                lw=1.5,
                alpha=0.03,
                ax=ax,
            )
        # The panel overlays every class for this model, so it is titled with
        # the model name like its ROC and confusion-matrix neighbours (the
        # leaked loop variable used before named only the last class).
        plot.figsets(
            sp=2,
            title=model_name,
            legend=dict(loc="best", fontsize=8),
        )

        # Confusion-matrix panel
        plot_cm(
            res_val["confusion_matrix"][model_name],
            labels_name=classes,
            ax=nexttile(),
            normalize=False,
        )
        plot.figsets(title=model_name, sp=2)
1317
1457
 
def cal_precision_recall(y_true, y_pred_proba, is_binary=True):
    """
    Compute precision-recall curve data and average precision.

    Args:
        y_true: true labels. In the multi-class case a label is matched
            against the column position of ``y_pred_proba``
            (``y_true == class_idx``), so labels must be encoded as
            0 .. n_classes-1.
        y_pred_proba: predicted scores. Binary: passed straight to
            scikit-learn. Multi-class: 2-D, one column per class.
        is_binary: selects the binary or the one-vs-rest multi-class path.

    Returns:
        ``(precision, recall, avg_precision)``. Binary: two arrays and a
        scalar. Multi-class: three lists with one entry per class, in column
        order.
    """
    if is_binary:
        precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
        avg_precision_ = average_precision_score(y_true, y_pred_proba)
        return precision_, recall_, avg_precision_

    # Coerce once so lists and pandas objects support the element-wise
    # comparison and the positional column slicing used below.
    y_true = np.asarray(y_true).ravel()
    y_pred_proba = np.asarray(y_pred_proba)
    n_classes = y_pred_proba.shape[1]  # Number of classes

    precision_ = []
    recall_ = []
    avg_precision_ = []
    # One-vs-rest: binarize the labels against each class in turn.
    for class_idx in range(n_classes):
        y_true_bin = (y_true == class_idx).astype(int)
        class_scores = y_pred_proba[:, class_idx]
        precision, recall, _ = precision_recall_curve(y_true_bin, class_scores)
        precision_.append(precision)
        recall_.append(recall)
        avg_precision_.append(average_precision_score(y_true_bin, class_scores))
    return precision_, recall_, avg_precision_
1487
+
def cal_auc_ci(
    y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, is_binary=True, verbose=True
):
    """
    Bootstrap percentile bounds for the ROC AUC.

    Samples are drawn with replacement ``n_bootstraps`` times; a resample in
    which the (binarized) labels contain a single class is rejected because
    ROC AUC is undefined there.

    Args:
        y_true: true labels.
        y_pred: predicted scores. Binary: one score per sample. Multi-class:
            2-D, column ``i`` is scored against the ``i``-th sorted unique
            label of ``y_true``.
        n_bootstraps: number of bootstrap resamples.
        ci: the bounds are the ``(1 - ci)`` and ``ci`` quantiles of the
            bootstrap scores. NOTE: with the default 0.95 this spans the 5th
            to the 95th percentile (a 90% interval); kept as-is for backward
            compatibility.
        random_state: seed for the resampling RNG.
        is_binary: selects the binary or the one-vs-rest multi-class path.
        verbose: print the point estimate(s) and the resulting bounds.

    Returns:
        Binary: ``(lower, upper)``. Multi-class: a list of ``(lower, upper)``
        tuples, one per class. Bounds are NaN when every resample for that
        score was rejected.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    rng = np.random.RandomState(random_state)
    n_samples = len(y_pred)

    def _percentile_bounds(scores):
        # Sort the accepted bootstrap scores and pick the two quantiles; the
        # index is clamped so ci close to 1 cannot run past the last element.
        ordered = np.sort(np.asarray(scores, dtype=float))
        if ordered.size == 0:
            return np.nan, np.nan
        last = ordered.size - 1
        lower = ordered[min(int((1 - ci) * ordered.size), last)]
        upper = ordered[min(int(ci * ordered.size), last)]
        return lower, upper

    if is_binary:
        if verbose:
            print("auroc score:", roc_auc_score(y_true, y_pred))
        bootstrapped_scores = []
        for _ in range(n_bootstraps):
            # bootstrap by sampling with replacement on the prediction indices
            indices = rng.randint(0, n_samples, n_samples)
            if len(np.unique(y_true[indices])) < 2:
                # We need at least one positive and one negative sample for
                # ROC AUC to be defined: reject the sample
                continue
            bootstrapped_scores.append(roc_auc_score(y_true[indices], y_pred[indices]))

        confidence_lower, confidence_upper = _percentile_bounds(bootstrapped_scores)
        if verbose:
            print(
                "Confidence interval for the score: [{:0.3f} - {:0.3f}]".format(
                    confidence_lower, confidence_upper
                )
            )
        return confidence_lower, confidence_upper

    from sklearn.preprocessing import label_binarize

    # Binarize the multi-class labels for one-vs-rest computation.
    y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
    n_classes = y_true_bin.shape[1]  # Number of classes

    if verbose:
        print("AUROC scores for each class:")
        for i in range(n_classes):
            print(f"Class {i}: {roc_auc_score(y_true_bin[:, i], y_pred[:, i])}")

    # Keep only the accepted resamples per class. A preallocated zero matrix
    # would leave rejected resamples at 0.0 and drag the lower bound down.
    bootstrapped_scores = [[] for _ in range(n_classes)]
    for _ in range(n_bootstraps):
        indices = rng.randint(0, n_samples, n_samples)
        for class_idx in range(n_classes):
            if len(np.unique(y_true_bin[indices, class_idx])) < 2:
                continue  # Reject if the class doesn't have both positive and negative samples
            bootstrapped_scores[class_idx].append(
                roc_auc_score(y_true_bin[indices, class_idx], y_pred[indices, class_idx])
            )

    # Calculating the confidence intervals for each class
    confidence_intervals = []
    for class_idx in range(n_classes):
        confidence_lower, confidence_upper = _percentile_bounds(bootstrapped_scores[class_idx])
        confidence_intervals.append((confidence_lower, confidence_upper))
        if verbose:
            print(
                f"Class {class_idx} - Confidence interval: [{confidence_lower:.3f} - {confidence_upper:.3f}]"
            )

    return confidence_intervals
1356
1564
 
1357
1565
 
1358
1566
  def plot_roc_curve(
@@ -1517,7 +1725,7 @@ def plot_pr_binary(
1517
1725
 
1518
1726
  pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
1519
1727
  for f_score in f_scores:
1520
- x_vals = np.linspace(0.01, 1, 10000)
1728
+ x_vals = np.linspace(0.01, 1, 20000)
1521
1729
  y_vals = f_score * x_vals / (2 * x_vals - f_score)
1522
1730
  y_vals_clipped = np.minimum(y_vals, pr_boundary(x_vals))
1523
1731
  y_vals_clipped = np.clip(y_vals_clipped, 1e-3, None) # Prevent going to zero
@@ -1553,7 +1761,7 @@ def plot_pr_binary(
1553
1761
  def plot_cm(
1554
1762
  cm,
1555
1763
  labels_name=None,
1556
- thresh=0.8,
1764
+ thresh=0.8, # for set color
1557
1765
  axis_labels=None,
1558
1766
  cmap="Reds",
1559
1767
  normalize=True,
@@ -2029,11 +2237,21 @@ def predict(
2029
2237
  if isinstance(y_train, str) and y_train in x_train.columns:
2030
2238
  y_train_col_name = y_train
2031
2239
  y_train = x_train[y_train]
2032
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2240
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2033
2241
  x_train = x_train.drop(y_train_col_name, axis=1)
2242
+ # else:
2243
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2244
+ y_train=pd.DataFrame(y_train)
2245
+ if y_train.select_dtypes(include=np.number).empty:
2246
+ y_train_=ips.df_encoder(y_train, method="dummy",drop=None)
2247
+ is_binary = False if y_train_.shape[1] >2 else True
2034
2248
  else:
2035
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2249
+ y_train_=ips.flatten(y_train.values)
2250
+ is_binary = False if len(y_train_)>2 else True
2036
2251
 
2252
+ if is_binary:
2253
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="label")
2254
+ print('is_binary:',is_binary)
2037
2255
  if x_true is None:
2038
2256
  x_train, x_true, y_train, y_true = train_test_split(
2039
2257
  x_train,
@@ -2042,23 +2260,27 @@ def predict(
2042
2260
  random_state=random_state,
2043
2261
  stratify=y_train if purpose == "classification" else None,
2044
2262
  )
2263
+
2045
2264
  if isinstance(y_train, str) and y_train in x_train.columns:
2046
2265
  y_train_col_name = y_train
2047
2266
  y_train = x_train[y_train]
2048
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2267
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="label") if is_binary else y_train
2049
2268
  x_train = x_train.drop(y_train_col_name, axis=1)
2050
- else:
2269
+ if is_binary:
2051
2270
  y_train = ips.df_encoder(
2052
- pd.DataFrame(y_train), method="dummy"
2053
- ).values.ravel()
2271
+ pd.DataFrame(y_train), method="label"
2272
+ ).values.ravel()
2273
+
2054
2274
  if y_true is not None:
2055
2275
  if isinstance(y_true, str) and y_true in x_true.columns:
2056
2276
  y_true_col_name = y_true
2057
2277
  y_true = x_true[y_true]
2058
- y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
2278
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="label") if is_binary else y_true
2279
+ y_true = pd.DataFrame(y_true)
2059
2280
  x_true = x_true.drop(y_true_col_name, axis=1)
2060
- else:
2061
- y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
2281
+ if is_binary:
2282
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="label").values.ravel()
2283
+ y_true = pd.DataFrame(y_true)
2062
2284
 
2063
2285
  # to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
2064
2286
 
@@ -2068,7 +2290,6 @@ def predict(
2068
2290
  y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
2069
2291
  )
2070
2292
  y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
2071
-
2072
2293
  # Ensure common features are selected
2073
2294
  if common_features is not None:
2074
2295
  x_train, x_true = x_train[common_features], x_true[common_features]
@@ -2077,10 +2298,7 @@ def predict(
2077
2298
  x_train, x_true = x_train[share_col_names], x_true[share_col_names]
2078
2299
 
2079
2300
  x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
2080
- x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
2081
- x_true, method="dummy"
2082
- )
2083
-
2301
+ x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
2084
2302
  # Handle class imbalance using SMOTE (only for classification)
2085
2303
  if (
2086
2304
  smote
@@ -2091,7 +2309,13 @@ def predict(
2091
2309
 
2092
2310
  smote_sampler = SMOTE(random_state=random_state)
2093
2311
  x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
2094
-
2312
+ if not is_binary:
2313
+ if isinstance(y_train, np.ndarray):
2314
+ y_train = ips.df_encoder(data=pd.DataFrame(y_train),method='label')
2315
+ y_train=np.asarray(y_train)
2316
+ if isinstance(y_train, np.ndarray):
2317
+ y_true = ips.df_encoder(data=pd.DataFrame(y_true),method='label')
2318
+ y_true=np.asarray(y_true)
2095
2319
  # Hyperparameter grids for tuning
2096
2320
  if cv_level in ["low", "simple", "s", "l"]:
2097
2321
  param_grids = {
@@ -2670,95 +2894,181 @@ def predict(
2670
2894
  print(f"\nTraining and validating {name}:")
2671
2895
 
2672
2896
  # Grid search with KFold or StratifiedKFold
2673
- gs = GridSearchCV(
2674
- clf,
2675
- param_grid=param_grids.get(name, {}),
2676
- scoring=(
2677
- "roc_auc" if purpose == "classification" else "neg_mean_squared_error"
2678
- ),
2679
- cv=cv,
2680
- n_jobs=n_jobs,
2681
- verbose=verbose,
2682
- )
2683
- gs.fit(x_train, y_train)
2684
- best_clf = gs.best_estimator_
2685
- # make sure x_train and x_test has the same name
2686
- x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2687
- y_pred = best_clf.predict(x_true)
2688
-
2689
- # y_pred_proba
2690
- if hasattr(best_clf, "predict_proba"):
2691
- y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
2692
- elif hasattr(best_clf, "decision_function"):
2693
- # If predict_proba is not available, use decision_function (e.g., for SVM)
2694
- y_pred_proba = best_clf.decision_function(x_true)
2695
- # Ensure y_pred_proba is within 0 and 1 bounds
2696
- y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
2697
- y_pred_proba.max() - y_pred_proba.min()
2897
+ if is_binary:
2898
+ gs = GridSearchCV(
2899
+ clf,
2900
+ param_grid=param_grids.get(name, {}),
2901
+ scoring=(
2902
+ "roc_auc" if purpose == "classification" else "neg_mean_squared_error"
2903
+ ),
2904
+ cv=cv,
2905
+ n_jobs=n_jobs,
2906
+ verbose=verbose,
2698
2907
  )
2908
+
2909
+ gs.fit(x_train, y_train)
2910
+ best_clf = gs.best_estimator_
2911
+ # make sure x_train and x_test has the same name
2912
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2913
+ y_pred = best_clf.predict(x_true)
2914
+ if hasattr(best_clf, "predict_proba"):
2915
+ y_pred_proba = best_clf.predict_proba(x_true)
2916
+ print("Shape of predicted probabilities:", y_pred_proba.shape)
2917
+ if y_pred_proba.shape[1] == 1:
2918
+ y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba]) # Add missing class probabilities
2919
+ y_pred_proba = y_pred_proba[:, 1]
2920
+ elif hasattr(best_clf, "decision_function"):
2921
+ # If predict_proba is not available, use decision_function (e.g., for SVM)
2922
+ y_pred_proba = best_clf.decision_function(x_true)
2923
+ # Ensure y_pred_proba is within 0 and 1 bounds
2924
+ y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
2925
+ y_pred_proba.max() - y_pred_proba.min()
2926
+ )
2927
+ else:
2928
+ y_pred_proba = None # No probability output for certain models
2699
2929
  else:
2700
- y_pred_proba = None # No probability output for certain models
2930
+ gs = GridSearchCV(
2931
+ clf,
2932
+ param_grid=param_grids.get(name, {}),
2933
+ scoring=(
2934
+ "roc_auc_ovr" if purpose == "classification" else "neg_mean_squared_error"
2935
+ ),
2936
+ cv=cv,
2937
+ n_jobs=n_jobs,
2938
+ verbose=verbose,
2939
+ )
2701
2940
 
2941
+ # Fit GridSearchCV
2942
+ gs.fit(x_train, y_train)
2943
+ best_clf = gs.best_estimator_
2944
+
2945
+ # Ensure x_true aligns with x_train columns
2946
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2947
+ y_pred = best_clf.predict(x_true)
2948
+
2949
+ # Handle prediction probabilities for multiclass
2950
+ if hasattr(best_clf, "predict_proba"):
2951
+ y_pred_proba = best_clf.predict_proba(x_true)
2952
+ elif hasattr(best_clf, "decision_function"):
2953
+ y_pred_proba = best_clf.decision_function(x_true)
2954
+
2955
+ # Normalize for multiclass if necessary
2956
+ if y_pred_proba.ndim == 2:
2957
+ y_pred_proba = (y_pred_proba - y_pred_proba.min(axis=1, keepdims=True)) / \
2958
+ (y_pred_proba.max(axis=1, keepdims=True) - y_pred_proba.min(axis=1, keepdims=True))
2959
+ else:
2960
+ y_pred_proba = None # No probability output for certain models
2961
+
2702
2962
  validation_scores = {}
2703
- if y_true is not None:
2963
+
2964
+ if y_true is not None and y_pred_proba is not None:
2704
2965
  validation_scores = cal_metrics(
2705
2966
  y_true,
2706
2967
  y_pred,
2707
2968
  y_pred_proba=y_pred_proba,
2969
+ is_binary=is_binary,
2708
2970
  purpose=purpose,
2709
2971
  average="weighted",
2710
2972
  )
2711
-
2712
- # Calculate ROC curve
2713
- # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2714
- if y_pred_proba is not None:
2715
- # fpr, tpr, roc_auc = dict(), dict(), dict()
2716
- fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2717
- lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
2718
- roc_auc = auc(fpr, tpr)
2719
- roc_info = {
2720
- "fpr": fpr.tolist(),
2721
- "tpr": tpr.tolist(),
2722
- "auc": roc_auc,
2723
- "ci95": (lower_ci, upper_ci),
2724
- }
2725
- # precision-recall curve
2726
- precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
2727
- avg_precision_ = average_precision_score(y_true, y_pred_proba)
2728
- pr_info = {
2729
- "precision": precision_,
2730
- "recall": recall_,
2731
- "avg_precision": avg_precision_,
2732
- }
2733
- else:
2734
- roc_info, pr_info = None, None
2735
- if purpose == "classification":
2736
- results[name] = {
2737
- "best_clf": gs.best_estimator_,
2738
- "best_params": gs.best_params_,
2739
- "auc_indiv": [
2740
- gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
2741
- for i in range(cv_folds)
2742
- ],
2743
- "scores": validation_scores,
2744
- "roc_curve": roc_info,
2745
- "pr_curve": pr_info,
2746
- "confusion_matrix": confusion_matrix(y_true, y_pred),
2747
- "predictions": y_pred.tolist(),
2748
- "predictions_proba": (
2749
- y_pred_proba.tolist() if y_pred_proba is not None else None
2750
- ),
2751
- }
2752
- else: # "regression"
2753
- results[name] = {
2754
- "best_clf": gs.best_estimator_,
2755
- "best_params": gs.best_params_,
2756
- "scores": validation_scores, # e.g., neg_MSE, R², etc.
2757
- "predictions": y_pred.tolist(),
2758
- "predictions_proba": (
2759
- y_pred_proba.tolist() if y_pred_proba is not None else None
2760
- ),
2761
- }
2973
+ if is_binary:
2974
+ # Calculate ROC curve
2975
+ # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2976
+ if y_pred_proba is not None:
2977
+ # fpr, tpr, roc_auc = dict(), dict(), dict()
2978
+ fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2979
+ lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
2980
+ roc_auc = auc(fpr, tpr)
2981
+ roc_info = {
2982
+ "fpr": fpr.tolist(),
2983
+ "tpr": tpr.tolist(),
2984
+ "auc": roc_auc,
2985
+ "ci95": (lower_ci, upper_ci),
2986
+ }
2987
+ # precision-recall curve
2988
+ precision_, recall_, _ = cal_precision_recall(y_true, y_pred_proba)
2989
+ avg_precision_ = average_precision_score(y_true, y_pred_proba)
2990
+ pr_info = {
2991
+ "precision": precision_,
2992
+ "recall": recall_,
2993
+ "avg_precision": avg_precision_,
2994
+ }
2995
+ else:
2996
+ roc_info, pr_info = None, None
2997
+ if purpose == "classification":
2998
+ results[name] = {
2999
+ "best_clf": gs.best_estimator_,
3000
+ "best_params": gs.best_params_,
3001
+ "auc_indiv": [
3002
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
3003
+ for i in range(cv_folds)
3004
+ ],
3005
+ "scores": validation_scores,
3006
+ "roc_curve": roc_info,
3007
+ "pr_curve": pr_info,
3008
+ "confusion_matrix": confusion_matrix(y_true, y_pred),
3009
+ "predictions": y_pred.tolist(),
3010
+ "predictions_proba": (
3011
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3012
+ ),
3013
+ }
3014
+ else: # "regression"
3015
+ results[name] = {
3016
+ "best_clf": gs.best_estimator_,
3017
+ "best_params": gs.best_params_,
3018
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
3019
+ "predictions": y_pred.tolist(),
3020
+ "predictions_proba": (
3021
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3022
+ ),
3023
+ }
3024
+ else: # multi-classes
3025
+ if y_pred_proba is not None:
3026
+ # fpr, tpr, roc_auc = dict(), dict(), dict()
3027
+ # fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
3028
+ confidence_intervals = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
3029
+ roc_info = {
3030
+ "fpr": validation_scores["fpr"],
3031
+ "tpr": validation_scores["tpr"],
3032
+ "auc": validation_scores["roc_auc_by_class"],
3033
+ "ci95": confidence_intervals,
3034
+ }
3035
+ # precision-recall curve
3036
+ precision_, recall_, avg_precision_ = cal_precision_recall(y_true, y_pred_proba,is_binary=is_binary)
3037
+ pr_info = {
3038
+ "precision": precision_,
3039
+ "recall": recall_,
3040
+ "avg_precision": avg_precision_,
3041
+ }
3042
+ else:
3043
+ roc_info, pr_info = None, None
3044
+
3045
+ if purpose == "classification":
3046
+ results[name] = {
3047
+ "best_clf": gs.best_estimator_,
3048
+ "best_params": gs.best_params_,
3049
+ "auc_indiv": [
3050
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
3051
+ for i in range(cv_folds)
3052
+ ],
3053
+ "scores": validation_scores,
3054
+ "roc_curve": roc_info,
3055
+ "pr_curve": pr_info,
3056
+ "confusion_matrix": confusion_matrix(y_true, y_pred),
3057
+ "predictions": y_pred.tolist(),
3058
+ "predictions_proba": (
3059
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3060
+ ),
3061
+ }
3062
+ else: # "regression"
3063
+ results[name] = {
3064
+ "best_clf": gs.best_estimator_,
3065
+ "best_params": gs.best_params_,
3066
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
3067
+ "predictions": y_pred.tolist(),
3068
+ "predictions_proba": (
3069
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3070
+ ),
3071
+ }
2762
3072
 
2763
3073
  else:
2764
3074
  results[name] = {
@@ -2773,7 +3083,6 @@ def predict(
2773
3083
 
2774
3084
  # Convert results to DataFrame
2775
3085
  df_results = pd.DataFrame.from_dict(results, orient="index")
2776
-
2777
3086
  # sort
2778
3087
  if y_true is not None and purpose == "classification":
2779
3088
  df_scores = pd.DataFrame(
@@ -2790,26 +3099,29 @@ def predict(
2790
3099
  plot.figsets(xangle=30)
2791
3100
  if dir_save:
2792
3101
  ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
3102
+
3103
+ df_scores=df_scores.select_dtypes(include=np.number)
3104
+
2793
3105
  if df_scores.shape[0] > 1: # draw cluster
2794
3106
  plot.heatmap(df_scores, kind="direct", cluster=True)
2795
3107
  plot.figsets(xangle=30)
2796
3108
  if dir_save:
2797
3109
  ips.figsave(dir_save + f"scores_clus{now_}.pdf")
2798
3110
  if all([plot_, y_true is not None, purpose == "classification"]):
2799
- try:
2800
- if len(models) > 3:
2801
- plot_validate_features(df_results)
2802
- else:
2803
- plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
2804
- if dir_save:
2805
- ips.figsave(dir_save + f"validate_features{now_}.pdf")
2806
- except Exception as e:
2807
- print(f"Error: 在画图的过程中出现了问题:{e}")
3111
+ # try:
3112
+ if len(models) > 3:
3113
+ plot_validate_features(df_results,is_binary=is_binary)
3114
+ else:
3115
+ plot_validate_features_single(df_results, is_binary=is_binary)
3116
+ if dir_save:
3117
+ ips.figsave(dir_save + f"validate_features{now_}.pdf")
3118
+ # except Exception as e:
3119
+ # print(f"Error: 在画图的过程中出现了问题:{e}")
2808
3120
  return df_results
2809
3121
 
2810
3122
 
2811
3123
  def cal_metrics(
2812
- y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
3124
+ y_true, y_pred, y_pred_proba=None, is_binary=True,purpose="regression", average="weighted"
2813
3125
  ):
2814
3126
  """
2815
3127
  Calculate regression or classification metrics based on the purpose.
@@ -2879,19 +3191,362 @@ def cal_metrics(
2879
3191
  }
2880
3192
 
2881
3193
  # Confusion matrix to calculate specificity
2882
- tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
2883
- validation_scores["specificity"] = (
2884
- tn / (tn + fp) if (tn + fp) > 0 else 0
2885
- ) # Specificity calculation
3194
+ if is_binary:
3195
+ cm = confusion_matrix(y_true, y_pred)
3196
+ if cm.size == 4:
3197
+ tn, fp, fn, tp = cm.ravel()
3198
+ else:
3199
+ # Handle single-class predictions
3200
+ tn, fp, fn, tp = 0, 0, 0, 0
3201
+ print("Warning: Only one class found in y_pred or y_true.")
3202
+
3203
+ # Specificity calculation
3204
+ validation_scores["specificity"] = (
3205
+ tn / (tn + fp) if (tn + fp) > 0 else 0
3206
+ )
3207
+ if y_pred_proba is not None:
3208
+ # Calculate ROC-AUC
3209
+ validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
3210
+ # PR-AUC (Precision-Recall AUC) calculation
3211
+ validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
3212
+
3213
+ else: # multi-class
3214
+ from sklearn.preprocessing import label_binarize
3215
+ #* Multi-class ROC calculation
3216
+ y_pred_proba = np.asarray(y_pred_proba)
3217
+ classes = np.unique(y_true)
3218
+ y_true_bin = label_binarize(y_true, classes=classes)
3219
+ if isinstance(y_true, np.ndarray):
3220
+ y_true = ips.df_encoder(data=pd.DataFrame(y_true), method='dum',prefix='Label')
3221
+ # Initialize dictionaries to store FPR, TPR, and AUC for each class
3222
+ fpr = dict()
3223
+ tpr = dict()
3224
+ roc_auc = dict()
3225
+ for i, class_label in enumerate(classes):
3226
+ fpr[class_label], tpr[class_label], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
3227
+ roc_auc[class_label] = auc(fpr[class_label], tpr[class_label])
3228
+
3229
+ # Store the mean ROC AUC
3230
+ try:
3231
+ validation_scores["roc_auc"] = roc_auc_score(
3232
+ y_true, y_pred_proba, multi_class="ovr", average=average
3233
+ )
3234
+ except Exception as e:
3235
+ y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)
3236
+ validation_scores["roc_auc"] = roc_auc_score(
3237
+ y_true, y_pred_proba, multi_class="ovr", average=average
3238
+ )
3239
+
3240
+ validation_scores["roc_auc_by_class"] = roc_auc # Individual class AUCs
3241
+ validation_scores["fpr"] = fpr
3242
+ validation_scores["tpr"] = tpr
2886
3243
 
2887
- if y_pred_proba is not None:
2888
- # Calculate ROC-AUC
2889
- validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
2890
- # PR-AUC (Precision-Recall AUC) calculation
2891
- validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
2892
3244
  else:
2893
3245
  raise ValueError(
2894
3246
  "Invalid purpose specified. Choose 'regression' or 'classification'."
2895
3247
  )
2896
3248
 
2897
3249
  return validation_scores
3250
+
3251
def plot_trees(
    X, y, cls, max_trees=500, test_size=0.2, random_state=42, early_stopping_rounds=None
):
    """
    Plot error rates (OOB, training, testing) of a tree-based ensemble classifier
    as the number of trees grows from 1 to ``max_trees``.

    The estimator passed as ``cls`` is reconfigured with ``set_params`` and refitted
    in place, so it is left fitted with the last evaluated number of trees.

    Parameters:
    - X (array-like): Feature matrix.
    - y (array-like): Target labels.
    - cls (object): Tree-based ensemble classifier instance (e.g., RandomForestClassifier()).
      It must accept ``set_params(n_estimators=...)``.
    - max_trees (int): Maximum number of trees to evaluate. Default is 500.
    - test_size (float): Proportion of data to use as test set for testing error. Default is 0.2.
    - random_state (int): Random state for the train/test split. Default is 42.
    - early_stopping_rounds (int): For AdaBoost/GradientBoosting only. Training stops once the
      held-out error has not decreased for this many consecutive rounds.

    Returns:
    - None

    Example:
        X = np.random.rand(100, 10)       # 100 samples, 10 features
        y = np.random.randint(0, 2, 100)  # binary target
        plot_trees(X, y, RandomForestClassifier(), max_trees=100)
        plot_trees(X, y, GradientBoostingClassifier(), max_trees=100, early_stopping_rounds=10)
        plot_trees(X, y, ExtraTreesClassifier(), max_trees=100)
    """
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    from sklearn.ensemble import (
        RandomForestClassifier,
        BaggingClassifier,
        ExtraTreesClassifier,
    )
    from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier

    # Split data for training and testing error calculation
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Error rates collected per ensemble size
    oob_error_rate = []
    train_error_rate = []
    test_error_rate = []
    validation_error = None

    oob_enabled = False  # Default to no OOB error unless explicitly set
    is_boosting = isinstance(cls, (AdaBoostClassifier, GradientBoostingClassifier))

    # Configure classifier based on type
    if isinstance(cls, (RandomForestClassifier, ExtraTreesClassifier)):
        # warm_start keeps already-fitted trees when n_estimators is increased
        cls.set_params(warm_start=True, n_estimators=1)
        if hasattr(cls, "oob_score"):
            # OOB estimates require bootstrap sampling
            cls.set_params(bootstrap=True, oob_score=True)
            oob_enabled = True
    elif isinstance(cls, BaggingClassifier):
        cls.set_params(warm_start=True, bootstrap=True, oob_score=True, n_estimators=1)
        oob_enabled = True
    elif is_boosting:
        cls.set_params(n_estimators=1)
        if early_stopping_rounds:
            validation_error = []

    # Train and evaluate with an increasing number of trees
    for i in range(1, max_trees + 1):
        cls.set_params(n_estimators=i)
        cls.fit(x_train, y_train)

        # Calculate OOB error (for models that support it)
        if oob_enabled and hasattr(cls, "oob_score_") and cls.oob_score:
            oob_error_rate.append(1 - cls.oob_score_)

        # Calculate training error
        train_error_rate.append(1 - accuracy_score(y_train, cls.predict(x_train)))

        # Calculate testing error
        test_error = 1 - accuracy_score(y_test, cls.predict(x_test))
        test_error_rate.append(test_error)

        # For boosting models, track the held-out error for early stopping
        if early_stopping_rounds and is_boosting:
            # The validation set is the test split, so reuse the error computed
            # above instead of predicting on the same data a second time.
            validation_error.append(test_error)
            if len(validation_error) > early_stopping_rounds:
                # N non-improving rounds need N + 1 points to compare; a window of
                # only N points would stop one round early (and always for N == 1).
                recent = validation_error[-(early_stopping_rounds + 1):]
                if recent == sorted(recent):
                    print(f"Early stopping at tree {i} due to lack of improvement in validation error.")
                    break

    # Plot results
    plt.figure(figsize=(10, 6))
    if oob_error_rate:
        plt.plot(
            range(1, len(oob_error_rate) + 1),
            oob_error_rate,
            color="black",
            label="OOB Error Rate",
            linewidth=2,
        )
    if train_error_rate:
        plt.plot(
            range(1, len(train_error_rate) + 1),
            train_error_rate,
            linestyle="dotted",
            color="green",
            label="Training Error Rate",
        )
    if test_error_rate:
        plt.plot(
            range(1, len(test_error_rate) + 1),
            test_error_rate,
            linestyle="dashed",
            color="red",
            label="Testing Error Rate",
        )
    if validation_error:
        plt.plot(
            range(1, len(validation_error) + 1),
            validation_error,
            linestyle="solid",
            color="blue",
            label="Validation Error (Boosting)",
        )

    # Customize plot
    plt.xlabel("Number of Trees")
    plt.ylabel("Error Rate")
    plt.title(f"Error Rate Analysis for {cls.__class__.__name__}")
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.show()
3390
+
3391
def img_datasets_preprocessing(
    data: pd.DataFrame,
    x_col: str,
    y_col: str = None,
    target_size: tuple = (224, 224),
    batch_size: int = 128,
    class_mode: str = "raw",
    shuffle: bool = False,
    augment: bool = False,
    scaler: str = "normalize",  # 'normalize', 'standardize', 'clahe', 'raw'
    grayscale: bool = False,
    encoder: str = "label",  # Options: 'label', 'onehot', 'binary'
    label_encoder=None,
    kws_augmentation: dict = None,
    verbose: bool = True,
    drop_missing: bool = True,
    output="df",  # "iterator":data_iterator,'df':return DataFrame
):
    """
    Load images listed in a DataFrame and return them as a Keras iterator or as a
    DataFrame of flattened pixels.

    Parameters:
    - data (pd.DataFrame): Input DataFrame with image paths and labels. It is not modified;
      label encoding is applied to an internal copy.
    - x_col (str): Column in `data` containing image file paths.
    - y_col (str): Column in `data` containing image labels. When None, `class_mode` is forced
      to None and no labels are produced.
    - target_size (tuple): Desired image size in (height, width).
    - batch_size (int): Number of images per batch.
    - class_mode (str): Label mode passed to `flow_from_dataframe` ('raw', 'categorical', 'binary').
    - shuffle (bool): Shuffle the images in the DataFrame.
    - augment (bool): Apply data augmentation.
    - scaler (str): 'normalize' (rescale by 1/255 in the generator), 'standardize' (per-image
      zero mean / unit std), 'clahe' (per-image CLAHE) or 'raw' (no scaling). 'standardize'
      and 'clahe' are applied only when a DataFrame is returned.
    - grayscale (bool): Convert images to grayscale.
    - encoder (str): Label encoder method ('label', 'onehot', 'binary').
    - label_encoder: Optional label encoder instance; it is (re)fitted on `y_col`.
    - kws_augmentation (dict): Overrides for the default augmentation parameters.
    - verbose (bool): Print status messages.
    - drop_missing (bool): Drop rows whose path is missing or does not point to a file.
    - output (str): 'df' (or an alias) to return a DataFrame, 'iterator' (or an alias) to
      return the Keras iterator.

    Returns:
    - The `flow_from_dataframe` iterator when an iterator-style `output` is requested.
    - Otherwise a pd.DataFrame with one row per image: `pixel_<i>` columns holding the
      flattened pixels, followed by a 'Label' column when labels are available.
    """
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.utils import to_categorical
    from sklearn.preprocessing import LabelEncoder
    import os

    # Validate input DataFrame for required columns
    if y_col:
        assert (
            x_col in data.columns and y_col in data.columns
        ), "Missing required columns in DataFrame."
    if y_col is None:
        class_mode = None
    # Output format; presumably strcmp resolves `output` to the closest candidate
    # and returns it first -- verify against ips.strcmp.
    output = ips.strcmp(output, [
        "generator", "tf", "iterator", "transform", "transformer", "dataframe",
        "df", "pd", "pandas"])[0]

    # Handle missing file paths. Non-path values (e.g. NaN) are treated as missing
    # instead of being handed to os.path, which raises TypeError on them.
    if drop_missing:
        is_file = data[x_col].apply(
            lambda path: isinstance(path, (str, os.PathLike)) and os.path.isfile(path)
        ).astype(bool)
        data = data[is_file]

    # Encoding labels if necessary
    if encoder and y_col is not None:
        # Work on a copy so the caller's DataFrame (or a filtered view of it) is not mutated.
        data = data.copy()
        if encoder == "binary":
            # The first unique label is mapped to 1, every other label to 0
            data[y_col] = (data[y_col] == data[y_col].unique()[0]).astype(int)
        elif encoder == "onehot":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])
            # NOTE(review): to_categorical returns a 2-D array; assigning it to a single
            # column is expected to fail with more than one class -- confirm intended usage.
            data[y_col] = to_categorical(data[y_col])
        elif encoder == "label":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])

    # Set up data augmentation
    aug_params = {}
    if augment:
        aug_params = {
            "rotation_range": 20,
            "width_shift_range": 0.2,
            "height_shift_range": 0.2,
            "shear_range": 0.2,
            "zoom_range": 0.2,
            "horizontal_flip": True,
            "fill_mode": "nearest",
        }
        if kws_augmentation:
            aug_params.update(kws_augmentation)
    dat = ImageDataGenerator(
        rescale=1.0 / 255 if scaler == "normalize" else None, **aug_params)

    # Create DataFrameIterator
    data_iterator = dat.flow_from_dataframe(
        dataframe=data,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        color_mode="grayscale" if grayscale else "rgb",
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
    )
    print(f"target_size:{target_size}")
    if output.lower() in ["generator", "tf", "iterator", "transform", "transformer"]:
        return data_iterator
    elif output.lower() in ["dataframe", "df", "pd", "pandas"]:
        # Without a class_mode the iterator yields image batches only (no labels)
        has_labels = class_mode is not None
        # The iterator loops forever, so consume exactly one pass: ceil(n / batch_size)
        # batches. Stopping any later would append wrapped-around duplicates.
        total_batches = (data_iterator.n + batch_size - 1) // batch_size

        # Load, resize, and process images in batches
        data_list = []
        for _ in range(total_batches):
            batch = next(data_iterator)
            if has_labels:
                batch_images, batch_labels = batch
            else:
                batch_images, batch_labels = batch, [None] * len(batch)
            for img, label in zip(batch_images, batch_labels):
                if scaler == "standardize":
                    # Standardize by subtracting mean and dividing by std;
                    # a constant image (std == 0) is only centered.
                    std = np.std(img)
                    img = img - np.mean(img)
                    if std > 0:
                        img = img / std
                elif scaler == "clahe":
                    # No rescale is applied for 'clahe', so pixels are floats in
                    # [0, 255]; CLAHE operates on 8-bit data.
                    img = apply_clahe(np.clip(img, 0, 255).astype(np.uint8))
                # 'normalize' was already rescaled by the generator; 'raw' is left untouched
                flat_img = np.asarray(img).flatten()
                data_list.append(np.append(flat_img, label) if has_labels else flat_img)

        # Define column names for flattened image data
        pixel_count = target_size[0] * target_size[1] * (1 if grayscale else 3)
        column_names = [f"pixel_{i}" for i in range(pixel_count)]
        if has_labels:
            column_names.append("Label")

        # Create DataFrame from flattened data
        df_img = pd.DataFrame(data_list, columns=column_names)

        if verbose:
            print("Processed images:", len(df_img))
            print("Final DataFrame shape:", df_img.shape)
            display(df_img.head())

        return df_img
3543
+ # Function to apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
3544
def apply_clahe(img):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    Color images are converted to LAB, equalized on the L (lightness) channel only,
    and converted back to RGB. Single-channel images are equalized directly.

    Parameters:
    - img (array-like): Image of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4), RGB(A) order.
      uint8 input is used as is. Any other dtype is converted to uint8 first: arrays whose
      maximum is <= 1 are assumed to be scaled to [0, 1] and multiplied by 255, then values
      are rounded and clipped to [0, 255].

    Returns:
    - np.ndarray: uint8 image. Single-channel input keeps its shape; color input is returned
      as 3-channel RGB.

    Raises:
    - ValueError: If `img` does not have one of the supported shapes.
    """
    arr = np.asarray(img)
    is_gray = arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1)
    is_color = arr.ndim == 3 and arr.shape[2] in (3, 4)
    if not (is_gray or is_color):
        raise ValueError(
            "apply_clahe expects an image of shape (H, W), (H, W, 1), (H, W, 3) "
            f"or (H, W, 4); got shape {arr.shape}"
        )

    import cv2

    # CLAHE only accepts 8-/16-bit single-channel data, so float images (e.g. the
    # float32 batches produced by Keras generators) must be brought to uint8 first.
    if arr.dtype != np.uint8:
        arr = arr.astype(np.float64)
        if arr.size and arr.max() <= 1.0:
            arr = arr * 255.0
        arr = np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

    if is_gray:
        # Equalize the single plane and restore the original (H, W[, 1]) shape
        plane = np.ascontiguousarray(arr.reshape(arr.shape[:2]))
        return clahe.apply(plane).reshape(arr.shape)

    lab = cv2.cvtColor(arr, cv2.COLOR_RGB2LAB)  # Convert to LAB color space
    lightness, chan_a, chan_b = cv2.split(lab)  # Split into channels
    equalized = clahe.apply(lightness)  # Apply CLAHE to the L channel only
    merged = cv2.merge((equalized, chan_a, chan_b))  # Merge back the channels
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)  # Convert back to RGB