py2ls 0.2.4.14__py3-none-any.whl → 0.2.4.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/ml2ls.py CHANGED
@@ -506,7 +506,7 @@ def get_models(
506
506
  "Support Vector Machine(svm)",
507
507
  "naive bayes",
508
508
  "Linear Discriminant Analysis (lda)",
509
- "adaboost",
509
+ "AdaBoost",
510
510
  "DecisionTree",
511
511
  "KNeighbors",
512
512
  "Bagging",
@@ -585,7 +585,7 @@ def get_features(
585
585
  "Support Vector Machine(svm)",
586
586
  "naive bayes",
587
587
  "Linear Discriminant Analysis (lda)",
588
- "adaboost",
588
+ "AdaBoost",
589
589
  "DecisionTree",
590
590
  "KNeighbors",
591
591
  "Bagging",
@@ -616,10 +616,10 @@ def get_features(
616
616
  if isinstance(y, str) and y in X.columns:
617
617
  y_col_name = y
618
618
  y = X[y]
619
- y = ips.df_encoder(pd.DataFrame(y), method="dummy")
619
+ y = ips.df_encoder(pd.DataFrame(y), method="label")
620
620
  X = X.drop(y_col_name, axis=1)
621
621
  else:
622
- y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
622
+ y = ips.df_encoder(pd.DataFrame(y), method="label").values.ravel()
623
623
  y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
624
624
  y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
625
625
 
@@ -699,9 +699,11 @@ def get_features(
699
699
  "Support Vector Machine(svm)",
700
700
  "Naive Bayes",
701
701
  "Linear Discriminant Analysis (lda)",
702
- "adaboost",
702
+ "AdaBoost",
703
703
  ]
704
704
  cls = [ips.strcmp(i, cls_)[0] for i in cls]
705
+
706
+ feature_importances = {}
705
707
 
706
708
  # Lasso Feature Selection
707
709
  lasso_importances = (
@@ -712,6 +714,7 @@ def get_features(
712
714
  lasso_selected_features = (
713
715
  lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
714
716
  )
717
+ feature_importances['lasso']=lasso_importances.head(n_features)
715
718
  # Ridge
716
719
  ridge_importances = (
717
720
  features_ridge(x_train, y_train, ridge_params)
@@ -721,6 +724,7 @@ def get_features(
721
724
  selected_ridge_features = (
722
725
  ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
723
726
  )
727
+ feature_importances['ridge']=ridge_importances.head(n_features)
724
728
  # Elastic Net
725
729
  enet_importances = (
726
730
  features_enet(x_train, y_train, enet_params)
@@ -730,6 +734,7 @@ def get_features(
730
734
  selected_enet_features = (
731
735
  enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
732
736
  )
737
+ feature_importances['Enet']=enet_importances.head(n_features)
733
738
  # Random Forest Feature Importance
734
739
  rf_importances = (
735
740
  features_rf(x_train, y_train, rf_params)
@@ -741,6 +746,7 @@ def get_features(
741
746
  if "Random Forest" in cls
742
747
  else []
743
748
  )
749
+ feature_importances['Random Forest']=rf_importances.head(n_features)
744
750
  # Gradient Boosting Feature Importance
745
751
  gb_importances = (
746
752
  features_gradient_boosting(x_train, y_train, gb_params)
@@ -752,6 +758,7 @@ def get_features(
752
758
  if "Gradient Boosting" in cls
753
759
  else []
754
760
  )
761
+ feature_importances['Gradient Boosting']=gb_importances.head(n_features)
755
762
  # xgb
756
763
  xgb_importances = (
757
764
  features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()
@@ -759,6 +766,7 @@ def get_features(
759
766
  top_xgb_features = (
760
767
  xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
761
768
  )
769
+ feature_importances['xgb']=xgb_importances.head(n_features)
762
770
 
763
771
  # SVM with RFE
764
772
  selected_svm_features = (
@@ -773,6 +781,7 @@ def get_features(
773
781
  selected_lda_features = (
774
782
  lda_importances.head(n_features)["feature"].values if "lda" in cls else []
775
783
  )
784
+ feature_importances['lda']=lda_importances.head(n_features)
776
785
  # AdaBoost Feature Importance
777
786
  adaboost_importances = (
778
787
  features_adaboost(x_train, y_train, adaboost_params)
@@ -784,6 +793,7 @@ def get_features(
784
793
  if "AdaBoost" in cls
785
794
  else []
786
795
  )
796
+ feature_importances['AdaBoost']=adaboost_importances.head(n_features)
787
797
  # Decision Tree Feature Importance
788
798
  dt_importances = (
789
799
  features_decision_tree(x_train, y_train, dt_params)
@@ -794,7 +804,8 @@ def get_features(
794
804
  dt_importances.head(n_features)["feature"].values
795
805
  if "Decision Tree" in cls
796
806
  else []
797
- )
807
+ )
808
+ feature_importances['Decision Tree']=dt_importances.head(n_features)
798
809
  # Bagging Feature Importance
799
810
  bagging_importances = (
800
811
  features_bagging(x_train, y_train, bagging_params)
@@ -806,6 +817,7 @@ def get_features(
806
817
  if "Bagging" in cls
807
818
  else []
808
819
  )
820
+ feature_importances['Bagging']=bagging_importances.head(n_features)
809
821
  # KNN Feature Importance via Permutation
810
822
  knn_importances = (
811
823
  features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()
@@ -813,6 +825,7 @@ def get_features(
813
825
  top_knn_features = (
814
826
  knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
815
827
  )
828
+ feature_importances['KNN']=knn_importances.head(n_features)
816
829
 
817
830
  #! Find common features
818
831
  common_features = ips.shared(
@@ -915,6 +928,7 @@ def get_features(
915
928
  "cv_train_scores": cv_train_results_df,
916
929
  "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
917
930
  "common_features": list(common_features),
931
+ "feature_importances":feature_importances
918
932
  }
919
933
  if all([plot_, dir_save]):
920
934
  from datetime import datetime
@@ -927,6 +941,7 @@ def get_features(
927
941
  "cv_train_scores": pd.DataFrame(),
928
942
  "cv_test_scores": pd.DataFrame(),
929
943
  "common_features": [],
944
+ "feature_importances":{}
930
945
  }
931
946
  print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
932
947
  return results
@@ -1217,142 +1232,335 @@ def validate_features(
1217
1232
 
1218
1233
  # # If you want to access validation scores
1219
1234
  # print(validation_results)
1220
def plot_validate_features(res_val,is_binary=True,figsize=None):
    """
    Overlay the ROC and precision-recall curves of every model in ``res_val``.

    Parameters
    ----------
    res_val : result table of 'validate_features()'
        Read via ``res_val["roc_curve"]``, ``res_val["pr_curve"]`` and
        ``res_val.shape``; the model names come from the index of the
        ``"pr_curve"`` entry.
    is_binary : bool
        True: one ROC panel and one PR panel holding all models.
        False: one ROC panel and one PR panel per class; the class names are
        the keys of the first model's ``"fpr"`` mapping.
    figsize : list or None
        Figure size; when None a default is chosen from the number of rows of
        ``res_val`` (and, for multi-class, the number of classes).
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    colors = plot.get_color(len(model_names))

    # More than five rows: curves are drawn without the shaded band and the
    # legend is split over two columns.
    crowded = res_val.shape[0] > 5
    alpha = 0 if crowded else 0.03
    subplot_layout = [1, 2] if crowded else [1, 1]
    ncols = 2 if crowded else 1

    if is_binary:
        if figsize is None:
            figsize = [8, 10] if crowded else [10, 6]
        nexttile = plot.subplot(figsize=figsize)

        # ROC panel: every model on the same axes.
        ax = nexttile(*subplot_layout)
        for idx, model_name in enumerate(model_names):
            roc = res_val["roc_curve"][model_name]
            fpr, tpr = roc["fpr"], roc["tpr"]
            (lower_ci, upper_ci) = roc["ci95"]
            plot_roc_curve(
                fpr,
                tpr,
                roc["auc"],
                lower_ci,
                upper_ci,
                model_name=model_name,
                lw=1.5,
                color=colors[idx],
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            legend=dict(
                loc="upper right",
                ncols=ncols,
                fontsize=8,
                bbox_to_anchor=[1.5, 0.6],
                markerscale=0.8,
            ),
        )

        # Precision-recall panel: every model on the same axes.
        ax = nexttile(*subplot_layout)
        for idx, model_name in enumerate(model_names):
            pr = res_val["pr_curve"][model_name]
            plot_pr_curve(
                recall=pr["recall"],
                precision=pr["precision"],
                avg_precision=pr["avg_precision"],
                model_name=model_name,
                color=colors[idx],
                lw=1.5,
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
        )
        return

    # Multi-class: the class list is taken from the first model's ROC entry.
    first_model = ips.flatten(res_val["roc_curve"].index)[0]
    classes = list(res_val["roc_curve"][first_model]['fpr'].keys())
    if figsize is None:
        figsize = [8, 8*2*(len(classes))] if crowded else [10, 6*(len(classes))]
    nexttile = plot.subplot(2*(len(classes)),2,figsize=figsize)

    for iclass, class_ in enumerate(classes):
        # ROC panel for this class. ROC curves are keyed by class name,
        # while the CI list is indexed by class position.
        ax = nexttile(*subplot_layout)
        for idx, model_name in enumerate(model_names):
            roc = res_val["roc_curve"][model_name]
            fpr, tpr = roc["fpr"][class_], roc["tpr"][class_]
            (lower_ci, upper_ci) = roc["ci95"][iclass]
            plot_roc_curve(
                fpr,
                tpr,
                roc["auc"][class_],
                lower_ci,
                upper_ci,
                model_name=model_name,
                lw=1.5,
                color=colors[idx],
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=class_,
            legend=dict(
                loc="upper right",
                ncols=ncols,
                fontsize=8,
                bbox_to_anchor=[1.5, 0.6],
                markerscale=0.8,
            ),
        )

        # Precision-recall panel for this class (PR results are indexed by
        # class position).
        ax = nexttile(*subplot_layout)
        for idx, model_name in enumerate(model_names):
            pr = res_val["pr_curve"][model_name]
            plot_pr_curve(
                recall=pr["recall"][iclass],
                precision=pr["precision"][iclass],
                avg_precision=pr["avg_precision"][iclass],
                model_name=model_name,
                color=colors[idx],
                lw=1.5,
                alpha=alpha,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=class_,
            legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
        )
1285
1367
 
1368
def plot_validate_features_single(res_val, figsize=None,is_binary=True):
    """
    Draw three panels per model: ROC curve, precision-recall curve and
    confusion matrix.

    Parameters
    ----------
    res_val : result table of 'validate_features()'
        Read via ``res_val["roc_curve"]``, ``res_val["pr_curve"]`` and
        ``res_val["confusion_matrix"]``; model names come from the index of
        the ``"pr_curve"`` entry.
    figsize : list or None
        Figure size. When None it scales with the number of models
        (``[13, 4 * n_models]`` binary, ``[15, 5 * n_models]`` multi-class).
    is_binary : bool
        True: one curve per panel. False: one curve per class in each of the
        ROC / PR panels; class names are the keys of the first model's
        ``"fpr"`` mapping.

    Notes
    -----
    ``plot.subplot(n, 3, ...)`` is presumably a rows x columns grid whose
    returned callable hands out the next axes -- confirm against ``plot``.
    """
    model_names = ips.flatten(res_val["pr_curve"].index)
    n_models = len(model_names)

    if is_binary:
        if figsize is None:
            figsize = [13, 4 * n_models]
        nexttile = plot.subplot(n_models, 3, figsize=figsize)
        for model_name in model_names:
            roc = res_val["roc_curve"][model_name]
            fpr, tpr = roc["fpr"], roc["tpr"]
            (lower_ci, upper_ci) = roc["ci95"]
            mean_auc = roc["auc"]

            # ROC panel
            plot_roc_curve(fpr, tpr, mean_auc, lower_ci, upper_ci,
                           model_name=model_name, ax=nexttile())
            plot.figsets(title=model_name, sp=2)

            # Precision-recall panel
            pr = res_val["pr_curve"][model_name]
            plot_pr_binary(
                recall=pr["recall"],
                precision=pr["precision"],
                avg_precision=pr["avg_precision"],
                model_name=model_name,
                ax=nexttile(),
            )
            plot.figsets(title=model_name, sp=2)

            # Confusion-matrix panel
            plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
            plot.figsets(title=model_name, sp=2)
        return

    # Multi-class: the class list is taken from the first model's ROC entry.
    first_model = ips.flatten(res_val["roc_curve"].index)[0]
    classes = list(res_val["roc_curve"][first_model]['fpr'].keys())

    # One row per MODEL. The grid used to be sized with len() of the first
    # model's name (a string), giving a row count unrelated to the data.
    if figsize is None:
        figsize = [15, 5 * n_models]
    nexttile = plot.subplot(n_models, 3, figsize=figsize)
    colors = plot.get_color(len(classes))

    for model_name in model_names:
        # ROC panel: one curve per class. Curves are keyed by class name,
        # the CI list is indexed by class position.
        ax = nexttile()
        roc = res_val["roc_curve"][model_name]
        for iclass, class_ in enumerate(classes):
            (lower_ci, upper_ci) = roc["ci95"][iclass]
            plot_roc_curve(
                roc["fpr"][class_],
                roc["tpr"][class_],
                roc["auc"][class_],
                lower_ci,
                upper_ci,
                model_name=class_,
                lw=1.5,
                color=colors[iclass],
                alpha=0.03,
                ax=ax,
            )
        plot.figsets(
            sp=2,
            title=model_name,
            legend=dict(
                loc="best",
                fontsize=8,
            ),
        )

        # Precision-recall panel: one curve per class (indexed by position).
        ax = nexttile()
        pr = res_val["pr_curve"][model_name]
        for iclass, class_ in enumerate(classes):
            plot_pr_curve(
                recall=pr["recall"][iclass],
                precision=pr["precision"][iclass],
                avg_precision=pr["avg_precision"][iclass],
                model_name=class_,
                color=colors[iclass],
                lw=1.5,
                alpha=0.03,
                ax=ax,
            )
        # The panel holds every class of this model, so it is titled with
        # the model name like its two neighbours (it used to show whichever
        # class happened to be last in the loop).
        plot.figsets(
            sp=2,
            title=model_name,
            legend=dict(loc="best", fontsize=8),
        )

        # Confusion-matrix panel
        plot_cm(res_val["confusion_matrix"][model_name],labels_name=classes, ax=nexttile(), normalize=False)
        plot.figsets(title=model_name, sp=2)
1317
1457
 
1458
def cal_precision_recall(
    y_true, y_pred_proba, is_binary=True):
    """
    Compute precision-recall curve points and average precision.

    Parameters
    ----------
    y_true : array-like
        True labels. In the multi-class branch each class is matched with
        ``y_true == column_position``, so labels are assumed to be integer
        codes in the same order as the columns of ``y_pred_proba`` --
        TODO confirm against the caller.
    y_pred_proba : array-like
        Predicted scores. Binary: handed to the metric functions unchanged.
        Multi-class: 2-D with one column per class.
    is_binary : bool
        Selects the binary or the one-vs-rest multi-class computation.

    Returns
    -------
    tuple
        ``(precision, recall, avg_precision)``. Binary: the values returned by
        ``precision_recall_curve`` / ``average_precision_score``.
        Multi-class: three lists holding one entry per class.
    """
    if is_binary:
        precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
        avg_precision_ = average_precision_score(y_true, y_pred_proba)
        return precision_, recall_, avg_precision_

    # Coerce to arrays so that lists and (n, 1) column inputs support the
    # element-wise comparison and column slicing used below.
    labels = np.asarray(y_true).ravel()
    scores = np.asarray(y_pred_proba)

    precision_, recall_, avg_precision_ = [], [], []
    # One-vs-rest: binarize the labels for each class in turn and score them
    # against that class's probability column.
    for class_idx in range(scores.shape[1]):
        is_class = (labels == class_idx).astype(int)
        class_scores = scores[:, class_idx]
        precision, recall, _ = precision_recall_curve(is_class, class_scores)
        precision_.append(precision)
        recall_.append(recall)
        avg_precision_.append(average_precision_score(is_class, class_scores))
    return precision_, recall_, avg_precision_
1487
+
1319
1488
def _percentile_bounds(scores, ci):
    """Return the two-sided ``ci`` percentile interval of ``scores`` (NaNs if empty)."""
    ordered = np.sort(np.asarray(scores, dtype=float))
    if ordered.size == 0:
        return float("nan"), float("nan")
    tail = (1.0 - ci) / 2.0
    last = ordered.size - 1
    # Clamp so that ci close to 1.0 cannot index one past the end.
    lower = ordered[min(int(tail * ordered.size), last)]
    upper = ordered[min(int((1.0 - tail) * ordered.size), last)]
    return lower, upper


def cal_auc_ci(
    y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,is_binary=True, verbose=True
):
    """
    Bootstrap a confidence interval for the ROC AUC.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted scores. Binary: 1-D. Multi-class: 2-D, column ``i`` is
        scored against the ``i``-th sorted unique label of ``y_true``.
    n_bootstraps : int
        Number of resamples drawn with replacement.
    ci : float
        Two-sided confidence level; 0.95 yields the 2.5th and 97.5th
        percentiles of the bootstrap scores.
    random_state : int
        Seed of the ``np.random.RandomState`` used for resampling.
    is_binary : bool
        Selects the binary or the one-vs-rest multi-class computation.
    verbose : bool
        Print the point estimate(s) and the interval(s).

    Returns
    -------
    tuple or list of tuple
        Binary: ``(lower, upper)``. Multi-class: one ``(lower, upper)`` per
        class. Bounds are NaN when no resample contained both classes.

    Notes
    -----
    Resamples containing a single class are rejected because ROC AUC is
    undefined for them; they are left out of the percentile computation
    rather than counted as a score of 0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_samples = len(y_pred)

    if is_binary:
        if verbose:
            print("auroc score:", roc_auc_score(y_true, y_pred))
        rng = np.random.RandomState(random_state)
        bootstrapped_scores = []
        for _ in range(n_bootstraps):
            # bootstrap by sampling with replacement on the prediction indices
            indices = rng.randint(0, n_samples, n_samples)
            if len(np.unique(y_true[indices])) < 2:
                # We need at least one positive and one negative sample for
                # ROC AUC to be defined: reject the sample
                continue
            bootstrapped_scores.append(roc_auc_score(y_true[indices], y_pred[indices]))

        confidence_lower, confidence_upper = _percentile_bounds(bootstrapped_scores, ci)
        if verbose:
            print(
                "Confidence interval for the score: [{:0.3f} - {:0.3f}]".format(
                    confidence_lower, confidence_upper
                )
            )
        return confidence_lower, confidence_upper

    from sklearn.preprocessing import label_binarize

    # Binarize the multi-class labels for the one-vs-rest computation.
    # NOTE(review): with only two distinct labels label_binarize yields a
    # single column, which would not line up with a two-column y_pred.
    y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
    n_classes = y_true_bin.shape[1]

    if verbose:
        print("AUROC scores for each class:")
        for i in range(n_classes):
            print(f"Class {i}: {roc_auc_score(y_true_bin[:, i], y_pred[:, i])}")

    rng = np.random.RandomState(random_state)
    # Keep only the valid scores of each class.
    per_class_scores = [[] for _ in range(n_classes)]
    for _ in range(n_bootstraps):
        indices = rng.randint(0, n_samples, n_samples)
        for class_idx in range(n_classes):
            class_labels = y_true_bin[indices, class_idx]
            if len(np.unique(class_labels)) < 2:
                continue  # Reject: this resample lacks positives or negatives
            per_class_scores[class_idx].append(
                roc_auc_score(class_labels, y_pred[indices, class_idx])
            )

    # Calculating the confidence intervals for each class
    confidence_intervals = []
    for class_idx, class_scores in enumerate(per_class_scores):
        confidence_lower, confidence_upper = _percentile_bounds(class_scores, ci)
        confidence_intervals.append((confidence_lower, confidence_upper))
        if verbose:
            print(f"Class {class_idx} - Confidence interval: [{confidence_lower:.3f} - {confidence_upper:.3f}]")

    return confidence_intervals
1356
1564
 
1357
1565
 
1358
1566
  def plot_roc_curve(
@@ -1517,7 +1725,7 @@ def plot_pr_binary(
1517
1725
 
1518
1726
  pr_boundary = interp1d(recall, precision, kind="linear", fill_value="extrapolate")
1519
1727
  for f_score in f_scores:
1520
- x_vals = np.linspace(0.01, 1, 10000)
1728
+ x_vals = np.linspace(0.01, 1, 20000)
1521
1729
  y_vals = f_score * x_vals / (2 * x_vals - f_score)
1522
1730
  y_vals_clipped = np.minimum(y_vals, pr_boundary(x_vals))
1523
1731
  y_vals_clipped = np.clip(y_vals_clipped, 1e-3, None) # Prevent going to zero
@@ -1553,7 +1761,7 @@ def plot_pr_binary(
1553
1761
  def plot_cm(
1554
1762
  cm,
1555
1763
  labels_name=None,
1556
- thresh=0.8,
1764
+ thresh=0.8, # for set color
1557
1765
  axis_labels=None,
1558
1766
  cmap="Reds",
1559
1767
  normalize=True,
@@ -2029,11 +2237,21 @@ def predict(
2029
2237
  if isinstance(y_train, str) and y_train in x_train.columns:
2030
2238
  y_train_col_name = y_train
2031
2239
  y_train = x_train[y_train]
2032
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2240
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2033
2241
  x_train = x_train.drop(y_train_col_name, axis=1)
2242
+ # else:
2243
+ # y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2244
+ y_train=pd.DataFrame(y_train)
2245
+ if y_train.select_dtypes(include=np.number).empty:
2246
+ y_train_=ips.df_encoder(y_train, method="dummy",drop=None)
2247
+ is_binary = False if y_train_.shape[1] >2 else True
2034
2248
  else:
2035
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
2249
+ y_train_=ips.flatten(y_train.values)
2250
+ is_binary = False if len(y_train_)>2 else True
2036
2251
 
2252
+ if is_binary:
2253
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="label")
2254
+ print('is_binary:',is_binary)
2037
2255
  if x_true is None:
2038
2256
  x_train, x_true, y_train, y_true = train_test_split(
2039
2257
  x_train,
@@ -2042,23 +2260,27 @@ def predict(
2042
2260
  random_state=random_state,
2043
2261
  stratify=y_train if purpose == "classification" else None,
2044
2262
  )
2263
+
2045
2264
  if isinstance(y_train, str) and y_train in x_train.columns:
2046
2265
  y_train_col_name = y_train
2047
2266
  y_train = x_train[y_train]
2048
- y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
2267
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="label") if is_binary else y_train
2049
2268
  x_train = x_train.drop(y_train_col_name, axis=1)
2050
- else:
2269
+ if is_binary:
2051
2270
  y_train = ips.df_encoder(
2052
- pd.DataFrame(y_train), method="dummy"
2053
- ).values.ravel()
2271
+ pd.DataFrame(y_train), method="label"
2272
+ ).values.ravel()
2273
+
2054
2274
  if y_true is not None:
2055
2275
  if isinstance(y_true, str) and y_true in x_true.columns:
2056
2276
  y_true_col_name = y_true
2057
2277
  y_true = x_true[y_true]
2058
- y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
2278
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="label") if is_binary else y_true
2279
+ y_true = pd.DataFrame(y_true)
2059
2280
  x_true = x_true.drop(y_true_col_name, axis=1)
2060
- else:
2061
- y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
2281
+ if is_binary:
2282
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="label").values.ravel()
2283
+ y_true = pd.DataFrame(y_true)
2062
2284
 
2063
2285
  # to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
2064
2286
 
@@ -2068,7 +2290,6 @@ def predict(
2068
2290
  y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
2069
2291
  )
2070
2292
  y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
2071
-
2072
2293
  # Ensure common features are selected
2073
2294
  if common_features is not None:
2074
2295
  x_train, x_true = x_train[common_features], x_true[common_features]
@@ -2077,10 +2298,7 @@ def predict(
2077
2298
  x_train, x_true = x_train[share_col_names], x_true[share_col_names]
2078
2299
 
2079
2300
  x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
2080
- x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
2081
- x_true, method="dummy"
2082
- )
2083
-
2301
+ x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(x_true, method="dummy")
2084
2302
  # Handle class imbalance using SMOTE (only for classification)
2085
2303
  if (
2086
2304
  smote
@@ -2091,7 +2309,13 @@ def predict(
2091
2309
 
2092
2310
  smote_sampler = SMOTE(random_state=random_state)
2093
2311
  x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
2094
-
2312
+ if not is_binary:
2313
+ if isinstance(y_train, np.ndarray):
2314
+ y_train = ips.df_encoder(data=pd.DataFrame(y_train),method='label')
2315
+ y_train=np.asarray(y_train)
2316
+ if isinstance(y_train, np.ndarray):
2317
+ y_true = ips.df_encoder(data=pd.DataFrame(y_true),method='label')
2318
+ y_true=np.asarray(y_true)
2095
2319
  # Hyperparameter grids for tuning
2096
2320
  if cv_level in ["low", "simple", "s", "l"]:
2097
2321
  param_grids = {
@@ -2670,95 +2894,181 @@ def predict(
2670
2894
  print(f"\nTraining and validating {name}:")
2671
2895
 
2672
2896
  # Grid search with KFold or StratifiedKFold
2673
- gs = GridSearchCV(
2674
- clf,
2675
- param_grid=param_grids.get(name, {}),
2676
- scoring=(
2677
- "roc_auc" if purpose == "classification" else "neg_mean_squared_error"
2678
- ),
2679
- cv=cv,
2680
- n_jobs=n_jobs,
2681
- verbose=verbose,
2682
- )
2683
- gs.fit(x_train, y_train)
2684
- best_clf = gs.best_estimator_
2685
- # make sure x_train and x_test has the same name
2686
- x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2687
- y_pred = best_clf.predict(x_true)
2688
-
2689
- # y_pred_proba
2690
- if hasattr(best_clf, "predict_proba"):
2691
- y_pred_proba = best_clf.predict_proba(x_true)[:, 1]
2692
- elif hasattr(best_clf, "decision_function"):
2693
- # If predict_proba is not available, use decision_function (e.g., for SVM)
2694
- y_pred_proba = best_clf.decision_function(x_true)
2695
- # Ensure y_pred_proba is within 0 and 1 bounds
2696
- y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
2697
- y_pred_proba.max() - y_pred_proba.min()
2897
+ if is_binary:
2898
+ gs = GridSearchCV(
2899
+ clf,
2900
+ param_grid=param_grids.get(name, {}),
2901
+ scoring=(
2902
+ "roc_auc" if purpose == "classification" else "neg_mean_squared_error"
2903
+ ),
2904
+ cv=cv,
2905
+ n_jobs=n_jobs,
2906
+ verbose=verbose,
2698
2907
  )
2908
+
2909
+ gs.fit(x_train, y_train)
2910
+ best_clf = gs.best_estimator_
2911
+ # make sure x_train and x_test has the same name
2912
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2913
+ y_pred = best_clf.predict(x_true)
2914
+ if hasattr(best_clf, "predict_proba"):
2915
+ y_pred_proba = best_clf.predict_proba(x_true)
2916
+ print("Shape of predicted probabilities:", y_pred_proba.shape)
2917
+ if y_pred_proba.shape[1] == 1:
2918
+ y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba]) # Add missing class probabilities
2919
+ y_pred_proba = y_pred_proba[:, 1]
2920
+ elif hasattr(best_clf, "decision_function"):
2921
+ # If predict_proba is not available, use decision_function (e.g., for SVM)
2922
+ y_pred_proba = best_clf.decision_function(x_true)
2923
+ # Ensure y_pred_proba is within 0 and 1 bounds
2924
+ y_pred_proba = (y_pred_proba - y_pred_proba.min()) / (
2925
+ y_pred_proba.max() - y_pred_proba.min()
2926
+ )
2927
+ else:
2928
+ y_pred_proba = None # No probability output for certain models
2699
2929
  else:
2700
- y_pred_proba = None # No probability output for certain models
2930
+ gs = GridSearchCV(
2931
+ clf,
2932
+ param_grid=param_grids.get(name, {}),
2933
+ scoring=(
2934
+ "roc_auc_ovr" if purpose == "classification" else "neg_mean_squared_error"
2935
+ ),
2936
+ cv=cv,
2937
+ n_jobs=n_jobs,
2938
+ verbose=verbose,
2939
+ )
2701
2940
 
2941
+ # Fit GridSearchCV
2942
+ gs.fit(x_train, y_train)
2943
+ best_clf = gs.best_estimator_
2944
+
2945
+ # Ensure x_true aligns with x_train columns
2946
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2947
+ y_pred = best_clf.predict(x_true)
2948
+
2949
+ # Handle prediction probabilities for multiclass
2950
+ if hasattr(best_clf, "predict_proba"):
2951
+ y_pred_proba = best_clf.predict_proba(x_true)
2952
+ elif hasattr(best_clf, "decision_function"):
2953
+ y_pred_proba = best_clf.decision_function(x_true)
2954
+
2955
+ # Normalize for multiclass if necessary
2956
+ if y_pred_proba.ndim == 2:
2957
+ y_pred_proba = (y_pred_proba - y_pred_proba.min(axis=1, keepdims=True)) / \
2958
+ (y_pred_proba.max(axis=1, keepdims=True) - y_pred_proba.min(axis=1, keepdims=True))
2959
+ else:
2960
+ y_pred_proba = None # No probability output for certain models
2961
+
2702
2962
  validation_scores = {}
2703
- if y_true is not None:
2963
+
2964
+ if y_true is not None and y_pred_proba is not None:
2704
2965
  validation_scores = cal_metrics(
2705
2966
  y_true,
2706
2967
  y_pred,
2707
2968
  y_pred_proba=y_pred_proba,
2969
+ is_binary=is_binary,
2708
2970
  purpose=purpose,
2709
2971
  average="weighted",
2710
2972
  )
2711
-
2712
- # Calculate ROC curve
2713
- # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2714
- if y_pred_proba is not None:
2715
- # fpr, tpr, roc_auc = dict(), dict(), dict()
2716
- fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2717
- lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
2718
- roc_auc = auc(fpr, tpr)
2719
- roc_info = {
2720
- "fpr": fpr.tolist(),
2721
- "tpr": tpr.tolist(),
2722
- "auc": roc_auc,
2723
- "ci95": (lower_ci, upper_ci),
2724
- }
2725
- # precision-recall curve
2726
- precision_, recall_, _ = precision_recall_curve(y_true, y_pred_proba)
2727
- avg_precision_ = average_precision_score(y_true, y_pred_proba)
2728
- pr_info = {
2729
- "precision": precision_,
2730
- "recall": recall_,
2731
- "avg_precision": avg_precision_,
2732
- }
2733
- else:
2734
- roc_info, pr_info = None, None
2735
- if purpose == "classification":
2736
- results[name] = {
2737
- "best_clf": gs.best_estimator_,
2738
- "best_params": gs.best_params_,
2739
- "auc_indiv": [
2740
- gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
2741
- for i in range(cv_folds)
2742
- ],
2743
- "scores": validation_scores,
2744
- "roc_curve": roc_info,
2745
- "pr_curve": pr_info,
2746
- "confusion_matrix": confusion_matrix(y_true, y_pred),
2747
- "predictions": y_pred.tolist(),
2748
- "predictions_proba": (
2749
- y_pred_proba.tolist() if y_pred_proba is not None else None
2750
- ),
2751
- }
2752
- else: # "regression"
2753
- results[name] = {
2754
- "best_clf": gs.best_estimator_,
2755
- "best_params": gs.best_params_,
2756
- "scores": validation_scores, # e.g., neg_MSE, R², etc.
2757
- "predictions": y_pred.tolist(),
2758
- "predictions_proba": (
2759
- y_pred_proba.tolist() if y_pred_proba is not None else None
2760
- ),
2761
- }
2973
+ if is_binary:
2974
+ # Calculate ROC curve
2975
+ # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2976
+ if y_pred_proba is not None:
2977
+ # fpr, tpr, roc_auc = dict(), dict(), dict()
2978
+ fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2979
+ lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
2980
+ roc_auc = auc(fpr, tpr)
2981
+ roc_info = {
2982
+ "fpr": fpr.tolist(),
2983
+ "tpr": tpr.tolist(),
2984
+ "auc": roc_auc,
2985
+ "ci95": (lower_ci, upper_ci),
2986
+ }
2987
+ # precision-recall curve
2988
+ precision_, recall_, _ = cal_precision_recall(y_true, y_pred_proba)
2989
+ avg_precision_ = average_precision_score(y_true, y_pred_proba)
2990
+ pr_info = {
2991
+ "precision": precision_,
2992
+ "recall": recall_,
2993
+ "avg_precision": avg_precision_,
2994
+ }
2995
+ else:
2996
+ roc_info, pr_info = None, None
2997
+ if purpose == "classification":
2998
+ results[name] = {
2999
+ "best_clf": gs.best_estimator_,
3000
+ "best_params": gs.best_params_,
3001
+ "auc_indiv": [
3002
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
3003
+ for i in range(cv_folds)
3004
+ ],
3005
+ "scores": validation_scores,
3006
+ "roc_curve": roc_info,
3007
+ "pr_curve": pr_info,
3008
+ "confusion_matrix": confusion_matrix(y_true, y_pred),
3009
+ "predictions": y_pred.tolist(),
3010
+ "predictions_proba": (
3011
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3012
+ ),
3013
+ }
3014
+ else: # "regression"
3015
+ results[name] = {
3016
+ "best_clf": gs.best_estimator_,
3017
+ "best_params": gs.best_params_,
3018
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
3019
+ "predictions": y_pred.tolist(),
3020
+ "predictions_proba": (
3021
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3022
+ ),
3023
+ }
3024
+ else: # multi-classes
3025
+ if y_pred_proba is not None:
3026
+ # fpr, tpr, roc_auc = dict(), dict(), dict()
3027
+ # fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
3028
+ confidence_intervals = cal_auc_ci(y_true, y_pred_proba, verbose=False,is_binary=is_binary)
3029
+ roc_info = {
3030
+ "fpr": validation_scores["fpr"],
3031
+ "tpr": validation_scores["tpr"],
3032
+ "auc": validation_scores["roc_auc_by_class"],
3033
+ "ci95": confidence_intervals,
3034
+ }
3035
+ # precision-recall curve
3036
+ precision_, recall_, avg_precision_ = cal_precision_recall(y_true, y_pred_proba,is_binary=is_binary)
3037
+ pr_info = {
3038
+ "precision": precision_,
3039
+ "recall": recall_,
3040
+ "avg_precision": avg_precision_,
3041
+ }
3042
+ else:
3043
+ roc_info, pr_info = None, None
3044
+
3045
+ if purpose == "classification":
3046
+ results[name] = {
3047
+ "best_clf": gs.best_estimator_,
3048
+ "best_params": gs.best_params_,
3049
+ "auc_indiv": [
3050
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
3051
+ for i in range(cv_folds)
3052
+ ],
3053
+ "scores": validation_scores,
3054
+ "roc_curve": roc_info,
3055
+ "pr_curve": pr_info,
3056
+ "confusion_matrix": confusion_matrix(y_true, y_pred),
3057
+ "predictions": y_pred.tolist(),
3058
+ "predictions_proba": (
3059
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3060
+ ),
3061
+ }
3062
+ else: # "regression"
3063
+ results[name] = {
3064
+ "best_clf": gs.best_estimator_,
3065
+ "best_params": gs.best_params_,
3066
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
3067
+ "predictions": y_pred.tolist(),
3068
+ "predictions_proba": (
3069
+ y_pred_proba.tolist() if y_pred_proba is not None else None
3070
+ ),
3071
+ }
2762
3072
 
2763
3073
  else:
2764
3074
  results[name] = {
@@ -2773,7 +3083,6 @@ def predict(
2773
3083
 
2774
3084
  # Convert results to DataFrame
2775
3085
  df_results = pd.DataFrame.from_dict(results, orient="index")
2776
-
2777
3086
  # sort
2778
3087
  if y_true is not None and purpose == "classification":
2779
3088
  df_scores = pd.DataFrame(
@@ -2790,26 +3099,29 @@ def predict(
2790
3099
  plot.figsets(xangle=30)
2791
3100
  if dir_save:
2792
3101
  ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
3102
+
3103
+ df_scores=df_scores.select_dtypes(include=np.number)
3104
+
2793
3105
  if df_scores.shape[0] > 1: # draw cluster
2794
3106
  plot.heatmap(df_scores, kind="direct", cluster=True)
2795
3107
  plot.figsets(xangle=30)
2796
3108
  if dir_save:
2797
3109
  ips.figsave(dir_save + f"scores_clus{now_}.pdf")
2798
3110
  if all([plot_, y_true is not None, purpose == "classification"]):
2799
- try:
2800
- if len(models) > 3:
2801
- plot_validate_features(df_results)
2802
- else:
2803
- plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
2804
- if dir_save:
2805
- ips.figsave(dir_save + f"validate_features{now_}.pdf")
2806
- except Exception as e:
2807
- print(f"Error: 在画图的过程中出现了问题:{e}")
3111
+ # try:
3112
+ if len(models) > 3:
3113
+ plot_validate_features(df_results,is_binary=is_binary)
3114
+ else:
3115
+ plot_validate_features_single(df_results, is_binary=is_binary)
3116
+ if dir_save:
3117
+ ips.figsave(dir_save + f"validate_features{now_}.pdf")
3118
+ # except Exception as e:
3119
+ # print(f"Error: 在画图的过程中出现了问题:{e}")
2808
3120
  return df_results
2809
3121
 
2810
3122
 
2811
3123
  def cal_metrics(
2812
- y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
3124
+ y_true, y_pred, y_pred_proba=None, is_binary=True,purpose="regression", average="weighted"
2813
3125
  ):
2814
3126
  """
2815
3127
  Calculate regression or classification metrics based on the purpose.
@@ -2879,19 +3191,362 @@ def cal_metrics(
2879
3191
  }
2880
3192
 
2881
3193
  # Confusion matrix to calculate specificity
2882
- tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
2883
- validation_scores["specificity"] = (
2884
- tn / (tn + fp) if (tn + fp) > 0 else 0
2885
- ) # Specificity calculation
3194
+ if is_binary:
3195
+ cm = confusion_matrix(y_true, y_pred)
3196
+ if cm.size == 4:
3197
+ tn, fp, fn, tp = cm.ravel()
3198
+ else:
3199
+ # Handle single-class predictions
3200
+ tn, fp, fn, tp = 0, 0, 0, 0
3201
+ print("Warning: Only one class found in y_pred or y_true.")
3202
+
3203
+ # Specificity calculation
3204
+ validation_scores["specificity"] = (
3205
+ tn / (tn + fp) if (tn + fp) > 0 else 0
3206
+ )
3207
+ if y_pred_proba is not None:
3208
+ # Calculate ROC-AUC
3209
+ validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
3210
+ # PR-AUC (Precision-Recall AUC) calculation
3211
+ validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
3212
+
3213
+ else: # multi-class
3214
+ from sklearn.preprocessing import label_binarize
3215
+ #* Multi-class ROC calculation
3216
+ y_pred_proba = np.asarray(y_pred_proba)
3217
+ classes = np.unique(y_true)
3218
+ y_true_bin = label_binarize(y_true, classes=classes)
3219
+ if isinstance(y_true, np.ndarray):
3220
+ y_true = ips.df_encoder(data=pd.DataFrame(y_true), method='dum',prefix='Label')
3221
+ # Initialize dictionaries to store FPR, TPR, and AUC for each class
3222
+ fpr = dict()
3223
+ tpr = dict()
3224
+ roc_auc = dict()
3225
+ for i, class_label in enumerate(classes):
3226
+ fpr[class_label], tpr[class_label], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
3227
+ roc_auc[class_label] = auc(fpr[class_label], tpr[class_label])
3228
+
3229
+ # Store the mean ROC AUC
3230
+ try:
3231
+ validation_scores["roc_auc"] = roc_auc_score(
3232
+ y_true, y_pred_proba, multi_class="ovr", average=average
3233
+ )
3234
+ except Exception as e:
3235
+ y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)
3236
+ validation_scores["roc_auc"] = roc_auc_score(
3237
+ y_true, y_pred_proba, multi_class="ovr", average=average
3238
+ )
3239
+
3240
+ validation_scores["roc_auc_by_class"] = roc_auc # Individual class AUCs
3241
+ validation_scores["fpr"] = fpr
3242
+ validation_scores["tpr"] = tpr
2886
3243
 
2887
- if y_pred_proba is not None:
2888
- # Calculate ROC-AUC
2889
- validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
2890
- # PR-AUC (Precision-Recall AUC) calculation
2891
- validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
2892
3244
  else:
2893
3245
  raise ValueError(
2894
3246
  "Invalid purpose specified. Choose 'regression' or 'classification'."
2895
3247
  )
2896
3248
 
2897
3249
  return validation_scores
3250
+
3251
def plot_trees(
    X, y, cls, max_trees=500, test_size=0.2, random_state=42, early_stopping_rounds=None
):
    """
    Plot error rates (OOB, training, testing) of a tree-based ensemble
    classifier as the number of trees grows from 1 to ``max_trees``.

    The estimator passed as ``cls`` is reconfigured via ``set_params`` and
    refitted in place on every iteration, so the caller's object is modified.

    Parameters:
    - X (array-like): Feature matrix.
    - y (array-like): Target labels.
    - cls (object): Tree-based ensemble classifier instance
      (e.g., RandomForestClassifier()).
    - max_trees (int): Maximum number of trees to evaluate. Default is 500.
    - test_size (float): Proportion of data held out for the testing error.
      Default is 0.2.
    - random_state (int): Random state used for the train/test split.
      Default is 42.
    - early_stopping_rounds (int): For AdaBoost / GradientBoosting only.
      Training stops once the last ``early_stopping_rounds`` validation
      errors are in non-decreasing order. The "validation" error is the
      held-out test error (no separate validation split is made).

    Returns:
    - None. A matplotlib figure is shown.

    Example:
        X = np.random.rand(100, 10)          # 100 samples, 10 features
        y = np.random.randint(0, 2, 100)     # binary target
        plot_trees(X, y, RandomForestClassifier(), max_trees=100)
        plot_trees(X, y, GradientBoostingClassifier(), max_trees=100,
                   early_stopping_rounds=10)
        plot_trees(X, y, ExtraTreesClassifier(), max_trees=100)
    """
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    from sklearn.ensemble import (
        RandomForestClassifier,
        BaggingClassifier,
        ExtraTreesClassifier,
    )
    from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier

    boosting_types = (AdaBoostClassifier, GradientBoostingClassifier)

    # Split data for training and testing error calculation
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Error-rate histories; validation_error stays None unless early stopping
    # is requested for a boosting model.
    oob_error_rate = []
    train_error_rate = []
    test_error_rate = []
    validation_error = None

    # Configure the classifier based on its type
    oob_enabled = False  # no OOB error unless explicitly enabled below

    if isinstance(cls, (RandomForestClassifier, ExtraTreesClassifier)):
        # warm_start lets each iteration add one tree instead of refitting
        cls.set_params(warm_start=True, n_estimators=1)
        if hasattr(cls, "oob_score"):
            # OOB estimates require bootstrapping
            cls.set_params(bootstrap=True, oob_score=True)
            oob_enabled = True
    elif isinstance(cls, BaggingClassifier):
        # BaggingClassifier refuses to fit with warm_start=True and
        # oob_score=True together, so the ensemble is refitted from scratch
        # on every iteration to keep the OOB estimate available.
        cls.set_params(warm_start=False, bootstrap=True, oob_score=True, n_estimators=1)
        oob_enabled = True
    elif isinstance(cls, boosting_types):
        cls.set_params(n_estimators=1)
        if early_stopping_rounds:
            validation_error = []

    # Train and evaluate with an increasing number of trees
    for n_trees in range(1, max_trees + 1):
        cls.set_params(n_estimators=n_trees)
        cls.fit(x_train, y_train)

        # OOB error (only for models that expose it)
        if oob_enabled and hasattr(cls, "oob_score_") and cls.oob_score:
            oob_error_rate.append(1 - cls.oob_score_)

        # Training error
        train_error_rate.append(1 - accuracy_score(y_train, cls.predict(x_train)))

        # Testing error (computed once and reused as the validation error)
        test_error = 1 - accuracy_score(y_test, cls.predict(x_test))
        test_error_rate.append(test_error)

        # Early stopping for boosting models
        if early_stopping_rounds and isinstance(cls, boosting_types):
            validation_error.append(test_error)
            if len(validation_error) > early_stopping_rounds:
                recent = validation_error[-early_stopping_rounds:]
                # A non-decreasing tail means no improvement in the window
                if recent == sorted(recent):
                    print(f"Early stopping at tree {n_trees} due to lack of improvement in validation error.")
                    break

    # Plot results (plt is expected to be the module-level matplotlib.pyplot)
    plt.figure(figsize=(10, 6))
    if oob_error_rate:
        plt.plot(
            range(1, len(oob_error_rate) + 1),
            oob_error_rate,
            color="black",
            label="OOB Error Rate",
            linewidth=2,
        )
    if train_error_rate:
        plt.plot(
            range(1, len(train_error_rate) + 1),
            train_error_rate,
            linestyle="dotted",
            color="green",
            label="Training Error Rate",
        )
    if test_error_rate:
        plt.plot(
            range(1, len(test_error_rate) + 1),
            test_error_rate,
            linestyle="dashed",
            color="red",
            label="Testing Error Rate",
        )
    if validation_error:
        plt.plot(
            range(1, len(validation_error) + 1),
            validation_error,
            linestyle="solid",
            color="blue",
            label="Validation Error (Boosting)",
        )

    # Customize plot
    plt.xlabel("Number of Trees")
    plt.ylabel("Error Rate")
    plt.title(f"Error Rate Analysis for {cls.__class__.__name__}")
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.show()
3390
+
3391
def img_datasets_preprocessing(
    data: pd.DataFrame,
    x_col: str,
    y_col: str = None,
    target_size: tuple = (224, 224),
    batch_size: int = 128,
    class_mode: str = "raw",
    shuffle: bool = False,
    augment: bool = False,
    scaler: str = "normalize",  # 'normalize', 'standardize', 'clahe', 'raw'
    grayscale: bool = False,
    encoder: str = "label",  # Options: 'label', 'onehot', 'binary'
    label_encoder=None,
    kws_augmentation: dict = None,
    verbose: bool = True,
    drop_missing: bool = True,
    output="df",  # "iterator":data_iterator,'df':return DataFrame
):
    """
    Load images listed in a DataFrame and return them either as a Keras
    iterator or as a DataFrame of flattened pixels.

    Parameters:
    - data (pd.DataFrame): Input DataFrame with image paths and labels.
      It is copied internally; the caller's frame is not modified.
    - x_col (str): Column in `data` containing image file paths.
    - y_col (str): Column in `data` containing image labels. When None,
      `class_mode` is forced to None and no labels are produced.
    - target_size (tuple): Desired image size in (height, width).
    - batch_size (int): Number of images per batch.
    - class_mode (str): Label mode passed to `flow_from_dataframe`
      ('raw', 'categorical', 'binary').
    - shuffle (bool): Shuffle the images in the DataFrame.
    - augment (bool): Apply data augmentation.
    - scaler (str): 'normalize' rescales by 1/255 in the generator;
      'standardize' subtracts the per-image mean and divides by the
      per-image std; 'clahe' calls `apply_clahe`; 'raw' leaves pixels as
      loaded. 'standardize' and 'clahe' are applied only when a DataFrame
      is returned.
    - grayscale (bool): Load images in grayscale.
    - encoder (str): Label encoder method ('label', 'onehot', 'binary').
    - label_encoder: Optional label encoder; it is (re)fitted on `y_col`.
    - kws_augmentation (dict): Overrides for the default augmentation
      parameters.
    - verbose (bool): Print status messages.
    - drop_missing (bool): Drop rows with missing or invalid image paths.
    - output (str): 'generator'/'tf'/'iterator'/'transform'/'transformer'
      returns the iterator; 'dataframe'/'df'/'pd'/'pandas' returns a
      DataFrame.

    Returns:
    - The Keras DataFrameIterator, or a pd.DataFrame with one `pixel_i`
      column per flattened pixel plus a 'Label' column (the 'Label' column
      is omitted when `y_col` is None).
    """
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.utils import to_categorical
    from sklearn.preprocessing import LabelEncoder
    import os

    # Validate input DataFrame for required columns
    if y_col:
        assert (
            x_col in data.columns and y_col in data.columns
        ), "Missing required columns in DataFrame."
    if y_col is None:
        class_mode = None
    # Resolve the requested output format to one of the known aliases
    output = ips.strcmp(output, [
        "generator", "tf", "iterator", "transform", "transformer", "dataframe",
        "df", "pd", "pandas"])[0]

    # Handle missing file paths; non-string entries (e.g. NaN) count as missing
    if drop_missing:
        data = data[
            data[x_col].apply(
                lambda path: isinstance(path, (str, os.PathLike)) and os.path.isfile(path)
            )
        ]
    # Work on a private copy so label encoding never writes into the
    # caller's DataFrame (or into a slice of it).
    data = data.copy()

    # Encoding labels if necessary
    if encoder and y_col is not None:
        if encoder == "binary":
            # The first unique value becomes class 1, everything else 0
            data[y_col] = (data[y_col] == data[y_col].unique()[0]).astype(int)
        elif encoder == "onehot":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])
            # NOTE(review): to_categorical returns a 2-D array; assigning it
            # to a single column presumably fails for more than one class --
            # confirm and decide how one-hot labels should be stored.
            data[y_col] = to_categorical(data[y_col])
        elif encoder == "label":
            if not label_encoder:
                label_encoder = LabelEncoder()
            data[y_col] = label_encoder.fit_transform(data[y_col])

    # Only 'normalize' is handled by the generator itself
    rescale = 1.0 / 255 if scaler == "normalize" else None

    # Set up data augmentation
    aug_params = {}
    if augment:
        aug_params = {
            "rotation_range": 20,
            "width_shift_range": 0.2,
            "height_shift_range": 0.2,
            "shear_range": 0.2,
            "zoom_range": 0.2,
            "horizontal_flip": True,
            "fill_mode": "nearest",
        }
        if kws_augmentation:
            aug_params.update(kws_augmentation)
    dat = ImageDataGenerator(rescale=rescale, **aug_params)

    # Create DataFrameIterator
    data_iterator = dat.flow_from_dataframe(
        dataframe=data,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        color_mode="grayscale" if grayscale else "rgb",
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
    )
    print(f"target_size:{target_size}")
    if output.lower() in ["generator", "tf", "iterator", "transform", "transformer"]:
        return data_iterator
    elif output.lower() in ["dataframe", "df", "pd", "pandas"]:
        has_labels = class_mode is not None
        data_list = []

        # The iterator cycles forever, so consume exactly one epoch:
        # len(data_iterator) batches cover every image once, including a
        # final partial batch, without repeating any.
        for _ in range(len(data_iterator)):
            batch = next(data_iterator)
            if has_labels:
                batch_images, batch_labels = batch
            else:
                # class_mode=None yields image batches only
                batch_images, batch_labels = batch, [None] * len(batch)

            for img, label in zip(batch_images, batch_labels):
                if scaler in ("normalize", "raw"):
                    # 'normalize' was already rescaled by the generator;
                    # 'raw' is left untouched
                    pass
                elif scaler == "standardize":
                    # Standardize per image; guard against constant images
                    std = np.std(img)
                    img = (img - np.mean(img)) / (std if std > 0 else 1.0)
                elif scaler == "clahe":
                    # Apply CLAHE to the image
                    img = apply_clahe(img)
                flat_img = img.flatten()
                data_list.append(np.append(flat_img, label) if has_labels else flat_img)

        # Define column names for flattened image data
        pixel_count = target_size[0] * target_size[1] * (1 if grayscale else 3)
        column_names = [f"pixel_{i}" for i in range(pixel_count)]
        if has_labels:
            column_names.append("Label")

        # Create DataFrame from flattened data
        df_img = pd.DataFrame(data_list, columns=column_names)

        if verbose:
            print("Processed images:", len(df_img))
            print("Final DataFrame shape:", df_img.shape)
            # NOTE(review): `display` is not defined in this function; it is
            # presumably provided by IPython or a module-level import.
            display(df_img.head())

        return df_img
3543
+ # Function to apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
3544
def apply_clahe(img):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    RGB images are converted to LAB, CLAHE is applied to the L channel only,
    and the result is converted back to RGB. Single-channel images
    (shape (H, W) or (H, W, 1)) are equalized directly.

    Parameters:
    - img (array-like): RGB or single-channel image. uint8 input is used as
      is. Other dtypes are converted to uint8 first because CLAHE only
      accepts 8/16-bit single-channel data; values are assumed to lie in
      [0, 1] when the image maximum is <= 1.0 and in [0, 255] otherwise.

    Returns:
    - np.ndarray with the same shape as `img`. uint8 input yields uint8
      output; any other dtype is returned in the input dtype and on the
      input's original value scale.
    """
    import cv2

    arr = np.asarray(img)
    src_dtype = arr.dtype
    scale = 1.0
    if src_dtype != np.uint8:
        as_float = arr.astype(np.float32)
        # Heuristic: images already normalised to [0, 1] need stretching to
        # the 8-bit range before equalisation.
        if as_float.size and as_float.max() <= 1.0:
            scale = 255.0
        arr = np.clip(np.rint(as_float * scale), 0, 255).astype(np.uint8)

    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

    if arr.ndim == 2 or arr.shape[-1] == 1:
        # Single channel: equalize the plane itself, then restore the shape
        plane = np.ascontiguousarray(arr.reshape(arr.shape[0], arr.shape[1]))
        img_clahe = clahe.apply(plane).reshape(arr.shape)
    else:
        lab = cv2.cvtColor(arr, cv2.COLOR_RGB2LAB)  # Convert to LAB color space
        lightness, chan_a, chan_b = cv2.split(lab)  # Split into channels
        equalized = clahe.apply(lightness)  # Apply CLAHE to the L channel
        merged = cv2.merge((equalized, chan_a, chan_b))  # Merge back the channels
        img_clahe = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)  # Convert back to RGB

    if src_dtype == np.uint8:
        return img_clahe
    # Return non-uint8 input on its original scale and dtype
    return (img_clahe.astype(np.float32) / scale).astype(src_dtype)