py2ls 0.2.4.10.3__py3-none-any.whl → 0.2.4.10.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/ml2ls.py CHANGED
@@ -4,12 +4,13 @@ from sklearn.ensemble import (
4
4
  AdaBoostClassifier,
5
5
  BaggingClassifier,
6
6
  )
7
- from sklearn.svm import SVC,SVR
7
+ from sklearn.svm import SVC, SVR
8
8
  from sklearn.calibration import CalibratedClassifierCV
9
9
  from sklearn.model_selection import GridSearchCV, StratifiedKFold
10
10
  from sklearn.linear_model import (
11
11
  LassoCV,
12
- LogisticRegression,LinearRegression,
12
+ LogisticRegression,
13
+ LinearRegression,
13
14
  Lasso,
14
15
  Ridge,
15
16
  RidgeClassifierCV,
@@ -47,7 +48,7 @@ from . import plot
47
48
  import matplotlib.pyplot as plt
48
49
  import seaborn as sns
49
50
 
50
- plt.style.use("paper")
51
+ plt.style.use(str(get_cwd()) + "/data/styles/stylelib/paper.mplstyle")
51
52
  import logging
52
53
  import warnings
53
54
 
@@ -334,13 +335,13 @@ def features_naive_bayes(x_train: pd.DataFrame, y_train: pd.Series) -> list:
334
335
  probabilities = nb.predict_proba(x_train)
335
336
  # Limit the number of features safely, choosing the lesser of half the features or all columns
336
337
  n_features = min(x_train.shape[1] // 2, len(x_train.columns))
337
-
338
+
338
339
  # Sort probabilities, then map to valid column indices
339
340
  sorted_indices = np.argsort(probabilities.max(axis=1))[:n_features]
340
-
341
+
341
342
  # Ensure indices are within the column bounds of x_train
342
343
  valid_indices = sorted_indices[sorted_indices < len(x_train.columns)]
343
-
344
+
344
345
  return x_train.columns[valid_indices]
345
346
 
346
347
 
@@ -575,15 +576,28 @@ def get_features(
575
576
  bagging_params: Optional[Dict] = None,
576
577
  knn_params: Optional[Dict] = None,
577
578
  cls: list = [
578
- "lasso","ridge","Elastic Net(Enet)","gradient Boosting","Random forest (rf)","XGBoost (xgb)","Support Vector Machine(svm)",
579
- "naive bayes","Linear Discriminant Analysis (lda)","adaboost","DecisionTree","KNeighbors","Bagging"],
579
+ "lasso",
580
+ "ridge",
581
+ "Elastic Net(Enet)",
582
+ "gradient Boosting",
583
+ "Random forest (rf)",
584
+ "XGBoost (xgb)",
585
+ "Support Vector Machine(svm)",
586
+ "naive bayes",
587
+ "Linear Discriminant Analysis (lda)",
588
+ "adaboost",
589
+ "DecisionTree",
590
+ "KNeighbors",
591
+ "Bagging",
592
+ ],
580
593
  metrics: Optional[List[str]] = None,
581
594
  cv_folds: int = 5,
582
595
  strict: bool = False,
583
596
  n_shared: int = 2, # a feature selected by at least two methods is counted among the common genes
584
597
  use_selected_features: bool = True,
585
598
  plot_: bool = True,
586
- dir_save:str="./") -> dict:
599
+ dir_save: str = "./",
600
+ ) -> dict:
587
601
  """
588
602
  Master function to perform feature selection and validate models.
589
603
  """
@@ -598,14 +612,14 @@ def get_features(
598
612
 
599
613
  # fill na
600
614
  if fill_missing:
601
- ips.df_fillna(data=X,method='knn',inplace=True,axis=0)
602
- if isinstance(y, str) and y in X.columns:
603
- y_col_name=y
604
- y=X[y]
605
- y=ips.df_encoder(pd.DataFrame(y),method='dummy')
606
- X = X.drop(y_col_name,axis=1)
615
+ ips.df_fillna(data=X, method="knn", inplace=True, axis=0)
616
+ if isinstance(y, str) and y in X.columns:
617
+ y_col_name = y
618
+ y = X[y]
619
+ y = ips.df_encoder(pd.DataFrame(y), method="dummy")
620
+ X = X.drop(y_col_name, axis=1)
607
621
  else:
608
- y=ips.df_encoder(pd.DataFrame(y),method='dummy').values.ravel()
622
+ y = ips.df_encoder(pd.DataFrame(y), method="dummy").values.ravel()
609
623
  y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
610
624
  y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
611
625
 
@@ -817,7 +831,7 @@ def get_features(
817
831
  top_knn_features,
818
832
  strict=strict,
819
833
  n_shared=n_shared,
820
- verbose=False
834
+ verbose=False,
821
835
  )
822
836
 
823
837
  # Use selected features or all features for model validation
@@ -899,13 +913,14 @@ def get_features(
899
913
  results = {
900
914
  "selected_features": features_df,
901
915
  "cv_train_scores": cv_train_results_df,
902
- "cv_test_scores": rank_models(cv_test_results_df,plot_=plot_),
916
+ "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
903
917
  "common_features": list(common_features),
904
918
  }
905
- if all([plot_,dir_save]):
919
+ if all([plot_, dir_save]):
906
920
  from datetime import datetime
921
+
907
922
  now_ = datetime.now().strftime("%y%m%d_%H%M%S")
908
- ips.figsave(dir_save+f"features{now_}.pdf")
923
+ ips.figsave(dir_save + f"features{now_}.pdf")
909
924
  else:
910
925
  results = {
911
926
  "selected_features": pd.DataFrame(),
@@ -931,7 +946,7 @@ def validate_features(
931
946
  metrics: Optional[list] = None,
932
947
  random_state: int = 1,
933
948
  smote: bool = False,
934
- n_jobs:int = -1,
949
+ n_jobs: int = -1,
935
950
  plot_: bool = True,
936
951
  class_weight: str = "balanced",
937
952
  ) -> dict:
@@ -952,8 +967,11 @@ def validate_features(
952
967
 
953
968
  """
954
969
  from tqdm import tqdm
970
+
955
971
  # Ensure common features are selected
956
- common_features = ips.shared(common_features, x_train.columns, x_true.columns, strict=True,verbose=False)
972
+ common_features = ips.shared(
973
+ common_features, x_train.columns, x_true.columns, strict=True, verbose=False
974
+ )
957
975
 
958
976
  # Filter the training and validation datasets for the common features
959
977
  x_train_selected = x_train[common_features]
@@ -1007,8 +1025,7 @@ def validate_features(
1007
1025
  l1_ratio=0.5,
1008
1026
  random_state=random_state,
1009
1027
  ),
1010
- "XGBoost": xgb.XGBClassifier(eval_metric="logloss"
1011
- ),
1028
+ "XGBoost": xgb.XGBClassifier(eval_metric="logloss"),
1012
1029
  "Naive Bayes": GaussianNB(),
1013
1030
  "LDA": LinearDiscriminantAnalysis(),
1014
1031
  }
@@ -1078,11 +1095,11 @@ def validate_features(
1078
1095
 
1079
1096
  # Validate each classifier with GridSearchCV
1080
1097
  for name, clf in tqdm(
1081
- models.items(),
1082
- desc="for metric in metrics",
1083
- colour="green",
1084
- bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
1085
- ):
1098
+ models.items(),
1099
+ desc="for metric in metrics",
1100
+ colour="green",
1101
+ bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
1102
+ ):
1086
1103
  print(f"\nValidating {name} on the validation dataset:")
1087
1104
 
1088
1105
  # Check if `predict_proba` method exists; if not, use CalibratedClassifierCV
@@ -1162,7 +1179,7 @@ def validate_features(
1162
1179
  if y_pred_proba is not None:
1163
1180
  # fpr, tpr, roc_auc = dict(), dict(), dict()
1164
1181
  fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
1165
- lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
1182
+ lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
1166
1183
  roc_auc = auc(fpr, tpr)
1167
1184
  roc_info = {
1168
1185
  "fpr": fpr.tolist(),
@@ -1197,6 +1214,7 @@ def validate_features(
1197
1214
  # Validate models using the validation dataset (X_val, y_val)
1198
1215
  # validation_results = validate_features(X, y, X_val, y_val, common_features)
1199
1216
 
1217
+
1200
1218
  # # If you want to access validation scores
1201
1219
  # print(validation_results)
1202
1220
  def plot_validate_features(res_val):
@@ -1204,47 +1222,75 @@ def plot_validate_features(res_val):
1204
1222
  plot the results of 'validate_features()'
1205
1223
  """
1206
1224
  colors = plot.get_color(len(ips.flatten(res_val["pr_curve"].index)))
1207
- if res_val.shape[0]>5:
1208
- alpha=0
1209
- figsize=[8,10]
1210
- subplot_layout=[1,2]
1211
- ncols=2
1212
- bbox_to_anchor=[1.5,0.6]
1225
+ if res_val.shape[0] > 5:
1226
+ alpha = 0
1227
+ figsize = [8, 10]
1228
+ subplot_layout = [1, 2]
1229
+ ncols = 2
1230
+ bbox_to_anchor = [1.5, 0.6]
1213
1231
  else:
1214
- alpha=0.03
1215
- figsize=[10,6]
1216
- subplot_layout=[1,1]
1217
- ncols=1
1218
- bbox_to_anchor=[1,1]
1232
+ alpha = 0.03
1233
+ figsize = [10, 6]
1234
+ subplot_layout = [1, 1]
1235
+ ncols = 1
1236
+ bbox_to_anchor = [1, 1]
1219
1237
  nexttile = plot.subplot(figsize=figsize)
1220
- ax = nexttile(subplot_layout[0],subplot_layout[1])
1238
+ ax = nexttile(subplot_layout[0], subplot_layout[1])
1221
1239
  for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
1222
1240
  fpr = res_val["roc_curve"][model_name]["fpr"]
1223
1241
  tpr = res_val["roc_curve"][model_name]["tpr"]
1224
1242
  (lower_ci, upper_ci) = res_val["roc_curve"][model_name]["ci95"]
1225
1243
  mean_auc = res_val["roc_curve"][model_name]["auc"]
1226
1244
  plot_roc_curve(
1227
- fpr,tpr,mean_auc,lower_ci,upper_ci,model_name=model_name,
1228
- lw=1.5,color=colors[i],alpha=alpha,ax=ax)
1229
- plot.figsets(sp=2,legend=dict(loc="upper right", ncols=ncols, fontsize=8, bbox_to_anchor=[1.5,0.6],markerscale=0.8))
1245
+ fpr,
1246
+ tpr,
1247
+ mean_auc,
1248
+ lower_ci,
1249
+ upper_ci,
1250
+ model_name=model_name,
1251
+ lw=1.5,
1252
+ color=colors[i],
1253
+ alpha=alpha,
1254
+ ax=ax,
1255
+ )
1256
+ plot.figsets(
1257
+ sp=2,
1258
+ legend=dict(
1259
+ loc="upper right",
1260
+ ncols=ncols,
1261
+ fontsize=8,
1262
+ bbox_to_anchor=[1.5, 0.6],
1263
+ markerscale=0.8,
1264
+ ),
1265
+ )
1230
1266
  # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
1231
1267
 
1232
- ax = nexttile(subplot_layout[0],subplot_layout[1])
1268
+ ax = nexttile(subplot_layout[0], subplot_layout[1])
1233
1269
  for i, model_name in enumerate(ips.flatten(res_val["pr_curve"].index)):
1234
1270
  plot_pr_curve(
1235
1271
  recall=res_val["pr_curve"][model_name]["recall"],
1236
1272
  precision=res_val["pr_curve"][model_name]["precision"],
1237
1273
  avg_precision=res_val["pr_curve"][model_name]["avg_precision"],
1238
1274
  model_name=model_name,
1239
- color=colors[i],lw=1.5,alpha=alpha,ax=ax)
1240
- plot.figsets(sp=2,legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5,0.5]))
1275
+ color=colors[i],
1276
+ lw=1.5,
1277
+ alpha=alpha,
1278
+ ax=ax,
1279
+ )
1280
+ plot.figsets(
1281
+ sp=2,
1282
+ legend=dict(loc="upper right", ncols=1, fontsize=8, bbox_to_anchor=[1.5, 0.5]),
1283
+ )
1241
1284
  # plot.split_legend(ax,n=2, loc=["upper left", "lower left"],bbox=[[1,0.5],[1,0.5]],ncols=2,labelcolor="k",fontsize=8)
1242
-
1243
- def plot_validate_features_single(res_val,figsize=None):
1285
+
1286
+
1287
+ def plot_validate_features_single(res_val, figsize=None):
1244
1288
  if figsize is None:
1245
1289
  nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3)
1246
1290
  else:
1247
- nexttile = plot.subplot(len(ips.flatten(res_val["pr_curve"].index)), 3,figsize=figsize)
1291
+ nexttile = plot.subplot(
1292
+ len(ips.flatten(res_val["pr_curve"].index)), 3, figsize=figsize
1293
+ )
1248
1294
  for model_name in ips.flatten(res_val["pr_curve"].index):
1249
1295
  fpr = res_val["roc_curve"][model_name]["fpr"]
1250
1296
  tpr = res_val["roc_curve"][model_name]["tpr"]
@@ -1268,7 +1314,10 @@ def plot_validate_features_single(res_val,figsize=None):
1268
1314
  plot_cm(res_val["confusion_matrix"][model_name], ax=nexttile(), normalize=False)
1269
1315
  plot.figsets(title=model_name, sp=2)
1270
1316
 
1271
- def cal_auc_ci(y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,verbose=True):
1317
+
1318
+ def cal_auc_ci(
1319
+ y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1, verbose=True
1320
+ ):
1272
1321
  y_true = np.asarray(y_true)
1273
1322
  y_pred = np.asarray(y_pred)
1274
1323
  bootstrapped_scores = []
@@ -1298,10 +1347,10 @@ def cal_auc_ci(y_true, y_pred, n_bootstraps=1000, ci=0.95, random_state=1,verbos
1298
1347
  confidence_upper = sorted_scores[int(ci * len(sorted_scores))]
1299
1348
  if verbose:
1300
1349
  print(
1301
- "Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
1302
- confidence_lower, confidence_upper
1350
+ "Confidence interval for the score: [{:0.3f} - {:0.3}]".format(
1351
+ confidence_lower, confidence_upper
1352
+ )
1303
1353
  )
1304
- )
1305
1354
  return confidence_lower, confidence_upper
1306
1355
 
1307
1356
 
@@ -1339,7 +1388,7 @@ def plot_roc_curve(
1339
1388
  # Plot ROC curve and the diagonal reference line
1340
1389
  ax.fill_between(fpr, tpr, alpha=alpha, color=color)
1341
1390
  ax.plot([0, 1], [0, 1], color=diagonal_color, clip_on=False, linestyle="--")
1342
- ax.plot(fpr, tpr, color=color, lw=lw, label=label,clip_on=False, **kwargs)
1391
+ ax.plot(fpr, tpr, color=color, lw=lw, label=label, clip_on=False, **kwargs)
1343
1392
  # Setting plot limits, labels, and title
1344
1393
  ax.set_xlim([-0.01, 1.0])
1345
1394
  ax.set_ylim([0.0, 1.0])
@@ -1536,12 +1585,11 @@ def plot_cm(
1536
1585
  color=color,
1537
1586
  fontsize=fontsize,
1538
1587
  )
1539
-
1540
- plot.figsets(ax=ax,
1541
- boxloc="none"
1542
- )
1588
+
1589
+ plot.figsets(ax=ax, boxloc="none")
1543
1590
  return ax
1544
1591
 
1592
+
1545
1593
  def rank_models(
1546
1594
  cv_test_scores,
1547
1595
  rm_outlier=False,
@@ -1644,7 +1692,7 @@ def rank_models(
1644
1692
  if rm_outlier:
1645
1693
  cv_test_scores_ = ips.df_outlier(cv_test_scores)
1646
1694
  else:
1647
- cv_test_scores_=cv_test_scores
1695
+ cv_test_scores_ = cv_test_scores
1648
1696
 
1649
1697
  # Normalize the scores of metrics if normalize is True
1650
1698
  scaler = MinMaxScaler()
@@ -1673,7 +1721,7 @@ def rank_models(
1673
1721
  )
1674
1722
  plt.title("Classifier Performance")
1675
1723
  plt.tight_layout()
1676
- return plt
1724
+ return plt
1677
1725
 
1678
1726
  nexttile = plot.subplot(2, 2, figsize=[10, 7])
1679
1727
  generate_bar_plot(nexttile(), top_models.dropna())
@@ -1703,10 +1751,11 @@ def rank_models(
1703
1751
 
1704
1752
  # figsave("classifier_performance.pdf")
1705
1753
 
1754
+
1706
1755
  def predict(
1707
1756
  x_train: pd.DataFrame,
1708
1757
  y_train: pd.Series,
1709
- x_true: pd.DataFrame=None,
1758
+ x_true: pd.DataFrame = None,
1710
1759
  y_true: Optional[pd.Series] = None,
1711
1760
  common_features: set = None,
1712
1761
  purpose: str = "classification", # 'classification' or 'regression'
@@ -1714,117 +1763,156 @@ def predict(
1714
1763
  metrics: Optional[List[str]] = None,
1715
1764
  random_state: int = 1,
1716
1765
  smote: bool = False,
1717
- n_jobs:int = -1,
1766
+ n_jobs: int = -1,
1718
1767
  plot_: bool = True,
1719
- dir_save:str="./",
1720
- test_size:float=0.2,# specific only when x_true is None
1721
- cv_folds:int=5,# more cv_folds 得更加稳定,auc可能更低
1722
- cv_level:str="l",#"s":'low',"m":'medium',"l":"high"
1768
+ dir_save: str = "./",
1769
+ test_size: float = 0.2, # used only when x_true is None
1770
+ cv_folds: int = 5, # more cv_folds gives more stable results, though the AUC may be lower
1771
+ cv_level: str = "l", # "s":'low',"m":'medium',"l":"high"
1723
1772
  class_weight: str = "balanced",
1724
- verbose:bool=False,
1773
+ verbose: bool = False,
1725
1774
  ) -> pd.DataFrame:
1726
- """
1727
- 第一种情况是内部拆分,第二种是直接预测,第三种是外部验证。
1728
- Usage:
1729
- (1). predict(x_train, y_train,...) 对 x_train 进行拆分训练/测试集,并在测试集上进行验证.
1730
- predict 函数会根据 test_size 参数,将 x_train 和 y_train 拆分出内部测试集。然后模型会在拆分出的训练集上进行训练,并在测试集上验证效果。
1731
- (2). predict(x_train, y_train, x_true,...)使用 x_train 和 y_train 训练并对 x_true 进行预测
1732
- 由于传入了 x_true,函数会跳过 x_train 的拆分,直接使用全部的 x_train 和 y_train 进行训练。然后对 x_true 进行预测,但由于没有提供 y_true,
1733
- 因此无法与真实值进行对比。
1734
- (3). predict(x_train, y_train, x_true, y_true,...)使用 x_train 和 y_train 训练,并验证 x_true 与真实标签 y_true.
1735
- predict 函数会在 x_train 和 y_train 上进行训练,并将 x_true 作为测试集。由于提供了 y_true,函数可以将预测结果与 y_true 进行对比,从而
1736
- 计算验证指标,完成对 x_true 的真正验证。
1737
- trains and validates a variety of machine learning models for both classification and regression tasks.
1738
- It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
1739
- feature scaling, and handling of class imbalance through SMOTE.
1740
-
1741
- Parameters:
1742
- - x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
1743
- - y_train (pd.Series):Target variable for the training dataset.
1744
- - x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
1745
- - y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
1746
- - common_features (set, optional):Specifies a subset of features common across training and test data.
1747
- - purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
1748
- metrics and models are applied.
1749
- - cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
1750
- - metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
1751
- - random_state (int, default = 1):Random seed to ensure reproducibility.
1752
- - smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
1753
- - n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
1754
- - plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
1755
- - test_size (float, default = 0.2):Test data proportion if x_true is not provided.
1756
- - cv_folds (int, default = 5):Number of cross-validation folds.
1757
- - cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
1758
- - class_weight (str, default = "balanced"):Balances class weights in classification tasks.
1759
- - verbose (bool, default = False):If True, prints detailed output during model training.
1760
- - dir_save (str, default = "./"):Directory path to save plot outputs and results.
1761
-
1762
- Key Steps in the Function:
1763
- Model Initialization: Depending on purpose, initializes either classification or regression models.
1764
- Feature Selection: Ensures training and test sets have matching feature columns.
1765
- SMOTE Application: Balances classes if smote is enabled and the task is classification.
1766
- Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
1767
- Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
1775
+ """
1776
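+ Three usage modes: the first is an internal split, the second is direct prediction, the third is external validation.
1777
+ Usage:
1778
+ (1). predict(x_train, y_train,...) splits x_train into training/test sets and validates on the test set.
1779
+ predict uses the test_size parameter to hold out an internal test set from x_train and y_train. The models are then trained on the remaining training split and evaluated on the held-out test set.
1780
+ (2). predict(x_train, y_train, x_true,...) trains on x_train and y_train and predicts on x_true.
1781
+ Because x_true is passed in, the function skips splitting x_train and trains on all of x_train and y_train. It then predicts on x_true, but since y_true is not provided,
1782
+ the predictions cannot be compared against the true values.
1783
+ (3). predict(x_train, y_train, x_true, y_true,...) trains on x_train and y_train, and validates x_true against the true labels y_true.
1784
+ predict trains on x_train and y_train and uses x_true as the test set. Because y_true is provided, the function can compare the predictions with y_true and thereby
1785
+ compute validation metrics, completing a genuine validation on x_true.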
1786
+ trains and validates a variety of machine learning models for both classification and regression tasks.
1787
+ It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
1788
+ feature scaling, and handling of class imbalance through SMOTE.
1789
+
1790
+ Parameters:
1791
+ - x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
1792
+ - y_train (pd.Series):Target variable for the training dataset.
1793
+ - x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
1794
+ - y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
1795
+ - common_features (set, optional):Specifies a subset of features common across training and test data.
1796
+ - purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
1797
+ metrics and models are applied.
1798
+ - cls (str or list of str, optional):Name(s) of the models to run; each name is matched against the supported model names. Defaults to all supported models if not provided.
1799
+ - metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
1800
+ - random_state (int, default = 1):Random seed to ensure reproducibility.
1801
+ - smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
1802
+ - n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
1803
+ - plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
1804
+ - test_size (float, default = 0.2):Test data proportion if x_true is not provided.
1805
+ - cv_folds (int, default = 5):Number of cross-validation folds.
1806
+ - cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
1807
+ - class_weight (str, default = "balanced"):Balances class weights in classification tasks.
1808
+ - verbose (bool, default = False):If True, prints detailed output during model training.
1809
+ - dir_save (str, default = "./"):Directory path to save plot outputs and results.
1810
+
1811
+ Key Steps in the Function:
1812
+ Model Initialization: Depending on purpose, initializes either classification or regression models.
1813
+ Feature Selection: Ensures training and test sets have matching feature columns.
1814
+ SMOTE Application: Balances classes if smote is enabled and the task is classification.
1815
+ Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
1816
+ Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
1768
1817
  """
1769
1818
  from tqdm import tqdm
1770
- from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor, BaggingClassifier, BaggingRegressor, AdaBoostClassifier, AdaBoostRegressor
1819
+ from sklearn.ensemble import (
1820
+ RandomForestClassifier,
1821
+ RandomForestRegressor,
1822
+ ExtraTreesClassifier,
1823
+ ExtraTreesRegressor,
1824
+ BaggingClassifier,
1825
+ BaggingRegressor,
1826
+ AdaBoostClassifier,
1827
+ AdaBoostRegressor,
1828
+ )
1771
1829
  from sklearn.svm import SVC, SVR
1772
1830
  from sklearn.tree import DecisionTreeRegressor
1773
- from sklearn.linear_model import LogisticRegression, ElasticNet, ElasticNetCV, LinearRegression, Lasso,RidgeClassifierCV, Perceptron, SGDClassifier
1774
- from sklearn.neighbors import KNeighborsClassifier,KNeighborsRegressor
1775
- from sklearn.naive_bayes import GaussianNB,BernoulliNB
1831
+ from sklearn.linear_model import (
1832
+ LogisticRegression,
1833
+ ElasticNet,
1834
+ ElasticNetCV,
1835
+ LinearRegression,
1836
+ Lasso,
1837
+ RidgeClassifierCV,
1838
+ Perceptron,
1839
+ SGDClassifier,
1840
+ )
1841
+ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
1842
+ from sklearn.naive_bayes import GaussianNB, BernoulliNB
1776
1843
  from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
1777
1844
  import xgboost as xgb
1778
1845
  import lightgbm as lgb
1779
1846
  import catboost as cb
1780
1847
  from sklearn.neural_network import MLPClassifier, MLPRegressor
1781
1848
  from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
1782
- from sklearn.discriminant_analysis import LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis
1849
+ from sklearn.discriminant_analysis import (
1850
+ LinearDiscriminantAnalysis,
1851
+ QuadraticDiscriminantAnalysis,
1852
+ )
1783
1853
  from sklearn.preprocessing import PolynomialFeatures
1784
1854
 
1785
-
1786
1855
  # spell check
1787
- purpose=ips.strcmp(purpose,['classification','regression'])[0]
1856
+ purpose = ips.strcmp(purpose, ["classification", "regression"])[0]
1788
1857
  print(f"{purpose} processing...")
1789
1858
  # Default models or regressors if not provided
1790
1859
  if purpose == "classification":
1791
1860
  model_ = {
1792
- "Random Forest": RandomForestClassifier(random_state=random_state, class_weight=class_weight),
1793
-
1794
- # SVC (Support Vector Classification)
1795
- "SVM": SVC(kernel="rbf",probability=True,class_weight=class_weight,random_state=random_state),
1796
-
1861
+ "Random Forest": RandomForestClassifier(
1862
+ random_state=random_state, class_weight=class_weight
1863
+ ),
1864
+ # SVC (Support Vector Classification)
1865
+ "SVM": SVC(
1866
+ kernel="rbf",
1867
+ probability=True,
1868
+ class_weight=class_weight,
1869
+ random_state=random_state,
1870
+ ),
1797
1871
  # fit the best model without enforcing sparsity, which means it does not directly perform feature selection.
1798
- "Logistic Regression": LogisticRegression(class_weight=class_weight, random_state=random_state),
1799
-
1872
+ "Logistic Regression": LogisticRegression(
1873
+ class_weight=class_weight, random_state=random_state
1874
+ ),
1800
1875
  # Logistic Regression with L1 Regularization (Lasso)
1801
- "Lasso Logistic Regression": LogisticRegression(penalty="l1", solver="saga", random_state=random_state),
1876
+ "Lasso Logistic Regression": LogisticRegression(
1877
+ penalty="l1", solver="saga", random_state=random_state
1878
+ ),
1802
1879
  "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
1803
- "XGBoost": xgb.XGBClassifier(eval_metric="logloss",random_state=random_state,),
1880
+ "XGBoost": xgb.XGBClassifier(
1881
+ eval_metric="logloss",
1882
+ random_state=random_state,
1883
+ ),
1804
1884
  "KNN": KNeighborsClassifier(n_neighbors=5),
1805
1885
  "Naive Bayes": GaussianNB(),
1806
1886
  "Linear Discriminant Analysis": LinearDiscriminantAnalysis(),
1807
- "AdaBoost": AdaBoostClassifier(algorithm='SAMME', random_state=random_state),
1887
+ "AdaBoost": AdaBoostClassifier(
1888
+ algorithm="SAMME", random_state=random_state
1889
+ ),
1808
1890
  # "LightGBM": lgb.LGBMClassifier(random_state=random_state, class_weight=class_weight),
1809
1891
  "CatBoost": cb.CatBoostClassifier(verbose=0, random_state=random_state),
1810
- "Extra Trees": ExtraTreesClassifier(random_state=random_state, class_weight=class_weight),
1892
+ "Extra Trees": ExtraTreesClassifier(
1893
+ random_state=random_state, class_weight=class_weight
1894
+ ),
1811
1895
  "Bagging": BaggingClassifier(random_state=random_state),
1812
1896
  "Neural Network": MLPClassifier(max_iter=500, random_state=random_state),
1813
1897
  "DecisionTree": DecisionTreeClassifier(),
1814
1898
  "Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis(),
1815
- "Ridge": RidgeClassifierCV(class_weight=class_weight, store_cv_results=True),
1899
+ "Ridge": RidgeClassifierCV(
1900
+ class_weight=class_weight, store_cv_results=True
1901
+ ),
1816
1902
  "Perceptron": Perceptron(random_state=random_state),
1817
1903
  "Bernoulli Naive Bayes": BernoulliNB(),
1818
- "SGDClassifier": SGDClassifier(random_state=random_state),
1904
+ "SGDClassifier": SGDClassifier(random_state=random_state),
1819
1905
  }
1820
1906
  elif purpose == "regression":
1821
1907
  model_ = {
1822
1908
  "Random Forest": RandomForestRegressor(random_state=random_state),
1823
- "SVM": SVR(),# SVR (Support Vector Regression)
1909
+ "SVM": SVR(), # SVR (Support Vector Regression)
1824
1910
  # "Lasso": Lasso(random_state=random_state), # 它和LassoCV相同(必须要提供alpha参数),
1825
- "LassoCV": LassoCV(cv=cv_folds, random_state=random_state),#LassoCV自动找出最适alpha,优于Lasso
1911
+ "LassoCV": LassoCV(
1912
+ cv=cv_folds, random_state=random_state
1913
+ ), # LassoCV automatically finds the best alpha, so it is preferable to Lasso
1826
1914
  "Gradient Boosting": GradientBoostingRegressor(random_state=random_state),
1827
- "XGBoost": xgb.XGBRegressor(eval_metric="rmse",random_state=random_state),
1915
+ "XGBoost": xgb.XGBRegressor(eval_metric="rmse", random_state=random_state),
1828
1916
  "Linear Regression": LinearRegression(),
1829
1917
  "Lasso": Lasso(random_state=random_state),
1830
1918
  "AdaBoost": AdaBoostRegressor(random_state=random_state),
@@ -1834,71 +1922,76 @@ def predict(
1834
1922
  "Bagging": BaggingRegressor(random_state=random_state),
1835
1923
  "Neural Network": MLPRegressor(max_iter=500, random_state=random_state),
1836
1924
  "ElasticNet": ElasticNet(random_state=random_state),
1837
- "Ridge": Ridge(),
1838
- "KNN":KNeighborsRegressor()
1925
+ "Ridge": Ridge(),
1926
+ "KNN": KNeighborsRegressor(),
1839
1927
  }
1840
- # indicate cls:
1841
- if ips.run_once_within(30):# 10 min
1928
+ # indicate cls:
1929
+ if ips.run_once_within(30): # 10 min
1842
1930
  print(f"supported models: {list(model_.keys())}")
1843
1931
  if cls is None:
1844
- models=model_
1932
+ models = model_
1845
1933
  else:
1846
1934
  if not isinstance(cls, list):
1847
- cls=[cls]
1848
- models={}
1849
- for cls_ in cls:
1935
+ cls = [cls]
1936
+ models = {}
1937
+ for cls_ in cls:
1850
1938
  cls_ = ips.strcmp(cls_, list(model_.keys()))[0]
1851
1939
  models[cls_] = model_[cls_]
1852
- if 'LightGBM' in models:
1853
- x_train=ips.df_special_characters_cleaner(x_train)
1854
- x_true=ips.df_special_characters_cleaner(x_true) if x_true is not None else None
1855
-
1856
- if isinstance(y_train, str) and y_train in x_train.columns:
1857
- y_train_col_name=y_train
1858
- y_train=x_train[y_train]
1859
- y_train=ips.df_encoder(pd.DataFrame(y_train),method='dummy')
1860
- x_train = x_train.drop(y_train_col_name,axis=1)
1940
+ if "LightGBM" in models:
1941
+ x_train = ips.df_special_characters_cleaner(x_train)
1942
+ x_true = (
1943
+ ips.df_special_characters_cleaner(x_true) if x_true is not None else None
1944
+ )
1945
+
1946
+ if isinstance(y_train, str) and y_train in x_train.columns:
1947
+ y_train_col_name = y_train
1948
+ y_train = x_train[y_train]
1949
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
1950
+ x_train = x_train.drop(y_train_col_name, axis=1)
1861
1951
  else:
1862
- y_train=ips.df_encoder(pd.DataFrame(y_train),method='dummy').values.ravel()
1952
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
1863
1953
 
1864
1954
  if x_true is None:
1865
1955
  x_train, x_true, y_train, y_true = train_test_split(
1866
- x_train,
1867
- y_train,
1868
- test_size=test_size,
1869
- random_state=random_state,
1870
- stratify=y_train if purpose == "classification" else None
1956
+ x_train,
1957
+ y_train,
1958
+ test_size=test_size,
1959
+ random_state=random_state,
1960
+ stratify=y_train if purpose == "classification" else None,
1871
1961
  )
1872
- if isinstance(y_train, str) and y_train in x_train.columns:
1873
- y_train_col_name=y_train
1874
- y_train=x_train[y_train]
1875
- y_train=ips.df_encoder(pd.DataFrame(y_train),method='dummy')
1876
- x_train = x_train.drop(y_train_col_name,axis=1)
1962
+ if isinstance(y_train, str) and y_train in x_train.columns:
1963
+ y_train_col_name = y_train
1964
+ y_train = x_train[y_train]
1965
+ y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy")
1966
+ x_train = x_train.drop(y_train_col_name, axis=1)
1877
1967
  else:
1878
- y_train=ips.df_encoder(pd.DataFrame(y_train),method='dummy').values.ravel()
1968
+ y_train = ips.df_encoder(
1969
+ pd.DataFrame(y_train), method="dummy"
1970
+ ).values.ravel()
1879
1971
  if y_true is not None:
1880
- if isinstance(y_true, str) and y_true in x_true.columns:
1881
- y_true_col_name=y_true
1882
- y_true=x_true[y_true]
1883
- y_true=ips.df_encoder(pd.DataFrame(y_true),method='dummy')
1884
- x_true = x_true.drop(y_true_col_name,axis=1)
1972
+ if isinstance(y_true, str) and y_true in x_true.columns:
1973
+ y_true_col_name = y_true
1974
+ y_true = x_true[y_true]
1975
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy")
1976
+ x_true = x_true.drop(y_true_col_name, axis=1)
1885
1977
  else:
1886
- y_true=ips.df_encoder(pd.DataFrame(y_true),method='dummy').values.ravel()
1978
+ y_true = ips.df_encoder(pd.DataFrame(y_true), method="dummy").values.ravel()
1887
1979
 
1888
1980
  # Convert y from a 2D column-vector format (like [[1], [0], [1], ...]) to a 1D array ([1, 0, 1, ...])
1889
1981
 
1890
1982
  # y_train=y_train.values.ravel() if y_train is not None else None
1891
1983
  # y_true=y_true.values.ravel() if y_true is not None else None
1892
- y_train = y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
1984
+ y_train = (
1985
+ y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
1986
+ )
1893
1987
  y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
1894
1988
 
1895
-
1896
1989
  # Ensure common features are selected
1897
1990
  if common_features is not None:
1898
1991
  x_train, x_true = x_train[common_features], x_true[common_features]
1899
1992
  else:
1900
- share_col_names = ips.shared(x_train.columns, x_true.columns,verbose=verbose)
1901
- x_train, x_true =x_train[share_col_names], x_true[share_col_names]
1993
+ share_col_names = ips.shared(x_train.columns, x_true.columns, verbose=verbose)
1994
+ x_train, x_true = x_train[share_col_names], x_true[share_col_names]
1902
1995
 
1903
1996
  x_train, x_true = ips.df_scaler(x_train), ips.df_scaler(x_true)
1904
1997
  x_train, x_true = ips.df_encoder(x_train, method="dummy"), ips.df_encoder(
@@ -1917,26 +2010,30 @@ def predict(
1917
2010
  x_train, y_train = smote_sampler.fit_resample(x_train, y_train)
1918
2011
 
1919
2012
  # Hyperparameter grids for tuning
1920
- if cv_level in ["low",'simple','s','l']:
2013
+ if cv_level in ["low", "simple", "s", "l"]:
1921
2014
  param_grids = {
1922
- "Random Forest": {
1923
- "n_estimators": [100], # One basic option
1924
- "max_depth": [None, 10],
1925
- "min_samples_split": [2],
1926
- "min_samples_leaf": [1],
1927
- "class_weight": [None],
1928
- } if purpose == "classification" else {
1929
- "n_estimators": [100], # One basic option
1930
- "max_depth": [None, 10],
1931
- "min_samples_split": [2],
1932
- "min_samples_leaf": [1],
1933
- "max_features": [None],
1934
- "bootstrap": [True], # Only one option for simplicity
1935
- },
2015
+ "Random Forest": (
2016
+ {
2017
+ "n_estimators": [100], # One basic option
2018
+ "max_depth": [None, 10],
2019
+ "min_samples_split": [2],
2020
+ "min_samples_leaf": [1],
2021
+ "class_weight": [None],
2022
+ }
2023
+ if purpose == "classification"
2024
+ else {
2025
+ "n_estimators": [100], # One basic option
2026
+ "max_depth": [None, 10],
2027
+ "min_samples_split": [2],
2028
+ "min_samples_leaf": [1],
2029
+ "max_features": [None],
2030
+ "bootstrap": [True], # Only one option for simplicity
2031
+ }
2032
+ ),
1936
2033
  "SVM": {
1937
2034
  "C": [1],
1938
- "gamma": ['scale'],
1939
- "kernel": ['rbf'],
2035
+ "gamma": ["scale"],
2036
+ "kernel": ["rbf"],
1940
2037
  },
1941
2038
  "Lasso": {
1942
2039
  "alpha": [0.1],
@@ -1946,8 +2043,8 @@ def predict(
1946
2043
  },
1947
2044
  "Logistic Regression": {
1948
2045
  "C": [1],
1949
- "solver": ['lbfgs'],
1950
- "penalty": ['l2'],
2046
+ "solver": ["lbfgs"],
2047
+ "penalty": ["l2"],
1951
2048
  "max_iter": [500],
1952
2049
  },
1953
2050
  "Gradient Boosting": {
@@ -1964,25 +2061,29 @@ def predict(
1964
2061
  "subsample": [0.8],
1965
2062
  "colsample_bytree": [0.8],
1966
2063
  },
1967
- "KNN": {
1968
- "n_neighbors": [3],
1969
- "weights": ['uniform'],
1970
- "algorithm": ['auto'],
1971
- "p": [2],
1972
- } if purpose == 'classification' else {
1973
- 'n_neighbors': [3],
1974
- 'weights': ['uniform'],
1975
- 'metric': ['euclidean'],
1976
- 'leaf_size': [30],
1977
- 'p': [2],
1978
- },
2064
+ "KNN": (
2065
+ {
2066
+ "n_neighbors": [3],
2067
+ "weights": ["uniform"],
2068
+ "algorithm": ["auto"],
2069
+ "p": [2],
2070
+ }
2071
+ if purpose == "classification"
2072
+ else {
2073
+ "n_neighbors": [3],
2074
+ "weights": ["uniform"],
2075
+ "metric": ["euclidean"],
2076
+ "leaf_size": [30],
2077
+ "p": [2],
2078
+ }
2079
+ ),
1979
2080
  "Naive Bayes": {
1980
2081
  "var_smoothing": [1e-9],
1981
2082
  },
1982
2083
  "SVR": {
1983
2084
  "C": [1],
1984
- "gamma": ['scale'],
1985
- "kernel": ['rbf'],
2085
+ "gamma": ["scale"],
2086
+ "kernel": ["rbf"],
1986
2087
  },
1987
2088
  "Linear Regression": {
1988
2089
  "fit_intercept": [True],
@@ -2003,9 +2104,9 @@ def predict(
2003
2104
  "n_estimators": [100],
2004
2105
  "num_leaves": [31],
2005
2106
  "max_depth": [10],
2006
- 'min_data_in_leaf': [20],
2007
- 'min_gain_to_split': [0.01],
2008
- 'scale_pos_weight': [10],
2107
+ "min_data_in_leaf": [20],
2108
+ "min_gain_to_split": [0.01],
2109
+ "scale_pos_weight": [10],
2009
2110
  },
2010
2111
  "Bagging": {
2011
2112
  "n_estimators": [50],
@@ -2033,132 +2134,168 @@ def predict(
2033
2134
  "shrinkage": [None],
2034
2135
  },
2035
2136
  "Quadratic Discriminant Analysis": {
2036
- 'reg_param': [0.0],
2037
- 'priors': [None],
2038
- 'tol': [1e-4],
2039
- },
2040
- "Ridge": {'class_weight': [None, 'balanced']} if purpose == "classification" else {
2041
- 'alpha': [0.1, 1, 10],
2137
+ "reg_param": [0.0],
2138
+ "priors": [None],
2139
+ "tol": [1e-4],
2042
2140
  },
2141
+ "Ridge": (
2142
+ {"class_weight": [None, "balanced"]}
2143
+ if purpose == "classification"
2144
+ else {
2145
+ "alpha": [0.1, 1, 10],
2146
+ }
2147
+ ),
2043
2148
  "Perceptron": {
2044
- 'alpha': [1e-3],
2045
- 'penalty': ['l2'],
2046
- 'max_iter': [1000],
2047
- 'eta0': [1.0],
2149
+ "alpha": [1e-3],
2150
+ "penalty": ["l2"],
2151
+ "max_iter": [1000],
2152
+ "eta0": [1.0],
2048
2153
  },
2049
2154
  "Bernoulli Naive Bayes": {
2050
- 'alpha': [0.1, 1, 10],
2051
- 'binarize': [0.0],
2052
- 'fit_prior': [True],
2155
+ "alpha": [0.1, 1, 10],
2156
+ "binarize": [0.0],
2157
+ "fit_prior": [True],
2053
2158
  },
2054
2159
  "SGDClassifier": {
2055
- 'eta0': [0.01],
2056
- 'loss': ['hinge'],
2057
- 'penalty': ['l2'],
2058
- 'alpha': [1e-3],
2059
- 'max_iter': [1000],
2060
- 'tol': [1e-3],
2061
- 'random_state': [random_state],
2062
- 'learning_rate': ['constant'],
2160
+ "eta0": [0.01],
2161
+ "loss": ["hinge"],
2162
+ "penalty": ["l2"],
2163
+ "alpha": [1e-3],
2164
+ "max_iter": [1000],
2165
+ "tol": [1e-3],
2166
+ "random_state": [random_state],
2167
+ "learning_rate": ["constant"],
2063
2168
  },
2064
2169
  }
2065
- elif cv_level in ['high','advanced','h']:
2066
- param_grids = {
2067
- "Random Forest": {
2068
- "n_estimators": [100, 200, 500, 700, 1000],
2069
- "max_depth": [None, 3, 5, 10, 15, 20, 30],
2070
- "min_samples_split": [2, 5, 10, 20],
2071
- "min_samples_leaf": [1, 2, 4],
2072
- "class_weight": [None, "balanced"] if purpose == "classification" else {},
2073
- } if purpose == "classification" else {
2074
- "n_estimators": [100, 200, 500, 700, 1000],
2075
- "max_depth": [None, 3, 5, 10, 15, 20, 30],
2076
- "min_samples_split": [2, 5, 10, 20],
2077
- "min_samples_leaf": [1, 2, 4],
2078
- "max_features": ['auto', 'sqrt', 'log2'], # Number of features to consider when looking for the best split
2079
- "bootstrap": [True, False], # Whether bootstrap samples are used when building trees
2080
- },
2170
+ elif cv_level in ["high", "advanced", "h"]:
2171
+ param_grids = {
2172
+ "Random Forest": (
2173
+ {
2174
+ "n_estimators": [100, 200, 500, 700, 1000],
2175
+ "max_depth": [None, 3, 5, 10, 15, 20, 30],
2176
+ "min_samples_split": [2, 5, 10, 20],
2177
+ "min_samples_leaf": [1, 2, 4],
2178
+ "class_weight": (
2179
+ [None, "balanced"] if purpose == "classification" else {}
2180
+ ),
2181
+ }
2182
+ if purpose == "classification"
2183
+ else {
2184
+ "n_estimators": [100, 200, 500, 700, 1000],
2185
+ "max_depth": [None, 3, 5, 10, 15, 20, 30],
2186
+ "min_samples_split": [2, 5, 10, 20],
2187
+ "min_samples_leaf": [1, 2, 4],
2188
+ "max_features": [
2189
+ "auto",
2190
+ "sqrt",
2191
+ "log2",
2192
+ ], # Number of features to consider when looking for the best split
2193
+ "bootstrap": [
2194
+ True,
2195
+ False,
2196
+ ], # Whether bootstrap samples are used when building trees
2197
+ }
2198
+ ),
2081
2199
  "SVM": {
2082
- "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2083
- "gamma": ["scale", "auto", 0.001, 0.01, 0.1],
2084
- "kernel": ["linear", "rbf", "poly"],
2085
- },
2200
+ "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2201
+ "gamma": ["scale", "auto", 0.001, 0.01, 0.1],
2202
+ "kernel": ["linear", "rbf", "poly"],
2203
+ },
2086
2204
  "Logistic Regression": {
2087
- "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2088
- "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
2089
- "penalty": ["l1", "l2", "elasticnet"],
2090
- "max_iter": [100, 200, 300, 500],
2091
- },
2092
- "Lasso":{
2093
- "alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2094
- "max_iter": [500, 1000, 2000, 5000],
2095
- "tol": [1e-4, 1e-5, 1e-6],
2096
- "selection": ["cyclic", "random"]
2097
- },
2098
- "LassoCV":{
2099
- "alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
2100
- "max_iter": [500, 1000, 2000, 5000],
2101
- "cv": [3, 5, 10],
2102
- "tol": [1e-4, 1e-5, 1e-6]
2103
- },
2205
+ "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000],
2206
+ "solver": ["liblinear", "saga", "newton-cg", "lbfgs"],
2207
+ "penalty": ["l1", "l2", "elasticnet"],
2208
+ "max_iter": [100, 200, 300, 500],
2209
+ },
2210
+ "Lasso": {
2211
+ "alpha": [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2212
+ "max_iter": [500, 1000, 2000, 5000],
2213
+ "tol": [1e-4, 1e-5, 1e-6],
2214
+ "selection": ["cyclic", "random"],
2215
+ },
2216
+ "LassoCV": {
2217
+ "alphas": [[0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
2218
+ "max_iter": [500, 1000, 2000, 5000],
2219
+ "cv": [3, 5, 10],
2220
+ "tol": [1e-4, 1e-5, 1e-6],
2221
+ },
2104
2222
  "Gradient Boosting": {
2105
- "n_estimators": [100, 200, 300, 400, 500, 700, 1000],
2106
- "learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
2107
- "max_depth": [3, 5, 7, 9, 15],
2108
- "min_samples_split": [2, 5, 10, 20],
2109
- "subsample": [0.8, 1.0],
2110
- },
2223
+ "n_estimators": [100, 200, 300, 400, 500, 700, 1000],
2224
+ "learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.5],
2225
+ "max_depth": [3, 5, 7, 9, 15],
2226
+ "min_samples_split": [2, 5, 10, 20],
2227
+ "subsample": [0.8, 1.0],
2228
+ },
2111
2229
  "XGBoost": {
2112
- "n_estimators": [100, 200, 500, 700],
2113
- "max_depth": [3, 5, 7, 10],
2114
- "learning_rate": [0.01, 0.1, 0.2, 0.3],
2115
- "subsample": [0.8, 1.0],
2116
- "colsample_bytree": [0.8, 0.9, 1.0],
2117
- },
2118
- "KNN": {
2119
- "n_neighbors": [1, 3, 5, 10, 15, 20],
2120
- "weights": ["uniform", "distance"],
2121
- "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
2122
- "p": [1, 2], # 1 for Manhattan, 2 for Euclidean distance
2123
- } if purpose=='classification' else {
2124
- 'n_neighbors': [3, 5, 7, 9, 11], # Number of neighbors
2125
- 'weights': ['uniform', 'distance'], # Weight function used in prediction
2126
- 'metric': ['euclidean', 'manhattan', 'minkowski'], # Distance metric
2127
- 'leaf_size': [20, 30, 40, 50], # Leaf size for KDTree or BallTree algorithms
2128
- 'p': [1, 2] # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2129
- },
2230
+ "n_estimators": [100, 200, 500, 700],
2231
+ "max_depth": [3, 5, 7, 10],
2232
+ "learning_rate": [0.01, 0.1, 0.2, 0.3],
2233
+ "subsample": [0.8, 1.0],
2234
+ "colsample_bytree": [0.8, 0.9, 1.0],
2235
+ },
2236
+ "KNN": (
2237
+ {
2238
+ "n_neighbors": [1, 3, 5, 10, 15, 20],
2239
+ "weights": ["uniform", "distance"],
2240
+ "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
2241
+ "p": [1, 2], # 1 for Manhattan, 2 for Euclidean distance
2242
+ }
2243
+ if purpose == "classification"
2244
+ else {
2245
+ "n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
2246
+ "weights": [
2247
+ "uniform",
2248
+ "distance",
2249
+ ], # Weight function used in prediction
2250
+ "metric": [
2251
+ "euclidean",
2252
+ "manhattan",
2253
+ "minkowski",
2254
+ ], # Distance metric
2255
+ "leaf_size": [
2256
+ 20,
2257
+ 30,
2258
+ 40,
2259
+ 50,
2260
+ ], # Leaf size for KDTree or BallTree algorithms
2261
+ "p": [
2262
+ 1,
2263
+ 2,
2264
+ ], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2265
+ }
2266
+ ),
2130
2267
  "Naive Bayes": {
2131
2268
  "var_smoothing": [1e-10, 1e-9, 1e-8, 1e-7],
2132
- },
2269
+ },
2133
2270
  "AdaBoost": {
2134
- "n_estimators": [50, 100, 200, 300, 500],
2135
- "learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
2136
- },
2271
+ "n_estimators": [50, 100, 200, 300, 500],
2272
+ "learning_rate": [0.001, 0.01, 0.1, 0.5, 1.0],
2273
+ },
2137
2274
  "SVR": {
2138
- "C": [0.01, 0.1, 1, 10, 100, 1000],
2139
- "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
2140
- "kernel": ["linear", "rbf", "poly"],
2141
- },
2275
+ "C": [0.01, 0.1, 1, 10, 100, 1000],
2276
+ "gamma": [0.001, 0.01, 0.1, "scale", "auto"],
2277
+ "kernel": ["linear", "rbf", "poly"],
2278
+ },
2142
2279
  "Linear Regression": {
2143
2280
  "fit_intercept": [True, False],
2144
- },
2145
- "Lasso":{
2281
+ },
2282
+ "Lasso": {
2146
2283
  "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2147
- "max_iter": [1000, 2000] # Higher iteration limit for fine-tuning
2148
- },
2284
+ "max_iter": [1000, 2000], # Higher iteration limit for fine-tuning
2285
+ },
2149
2286
  "Extra Trees": {
2150
2287
  "n_estimators": [100, 200, 500, 700, 1000],
2151
2288
  "max_depth": [None, 5, 10, 15, 20, 30],
2152
2289
  "min_samples_split": [2, 5, 10, 20],
2153
- "min_samples_leaf": [1, 2, 4]
2154
- },
2290
+ "min_samples_leaf": [1, 2, 4],
2291
+ },
2155
2292
  "CatBoost": {
2156
2293
  "iterations": [100, 200, 500],
2157
2294
  "learning_rate": [0.001, 0.01, 0.1, 0.2],
2158
2295
  "depth": [3, 5, 7, 10],
2159
2296
  "l2_leaf_reg": [1, 3, 5, 7, 10],
2160
2297
  "border_count": [32, 64, 128],
2161
- },
2298
+ },
2162
2299
  "LightGBM": {
2163
2300
  "n_estimators": [100, 200, 500, 700, 1000],
2164
2301
  "learning_rate": [0.001, 0.01, 0.1, 0.2],
@@ -2167,66 +2304,97 @@ def predict(
2167
2304
  "min_child_samples": [5, 10, 20],
2168
2305
  "subsample": [0.8, 1.0],
2169
2306
  "colsample_bytree": [0.8, 0.9, 1.0],
2170
- },
2307
+ },
2171
2308
  "Neural Network": {
2172
2309
  "hidden_layer_sizes": [(50,), (100,), (100, 50), (200, 100)],
2173
2310
  "activation": ["relu", "tanh", "logistic"],
2174
2311
  "solver": ["adam", "sgd", "lbfgs"],
2175
2312
  "alpha": [0.0001, 0.001, 0.01],
2176
2313
  "learning_rate": ["constant", "adaptive"],
2177
- },
2314
+ },
2178
2315
  "Decision Tree": {
2179
2316
  "max_depth": [None, 5, 10, 20, 30],
2180
2317
  "min_samples_split": [2, 5, 10, 20],
2181
2318
  "min_samples_leaf": [1, 2, 5, 10],
2182
2319
  "criterion": ["gini", "entropy"],
2183
2320
  "splitter": ["best", "random"],
2184
- },
2321
+ },
2185
2322
  "Linear Discriminant Analysis": {
2186
2323
  "solver": ["svd", "lsqr", "eigen"],
2187
- "shrinkage": [None, "auto", 0.1, 0.5, 1.0], # shrinkage levels for 'lsqr' and 'eigen'
2188
- },
2189
- 'Ridge': {'class_weight': [None, 'balanced']} if purpose == "classification" else {
2190
- 'alpha': [0.1, 1, 10, 100, 1000],
2191
- 'solver': ['auto', 'svd', 'cholesky', 'lsqr', 'lbfgs'],
2192
- 'fit_intercept': [True, False], # Whether to calculate the intercept
2193
- 'normalize': [True, False] # If True, the regressors X will be normalized
2324
+ "shrinkage": [
2325
+ None,
2326
+ "auto",
2327
+ 0.1,
2328
+ 0.5,
2329
+ 1.0,
2330
+ ], # shrinkage levels for 'lsqr' and 'eigen'
2331
+ },
2332
+ "Ridge": (
2333
+ {"class_weight": [None, "balanced"]}
2334
+ if purpose == "classification"
2335
+ else {
2336
+ "alpha": [0.1, 1, 10, 100, 1000],
2337
+ "solver": ["auto", "svd", "cholesky", "lsqr", "lbfgs"],
2338
+ "fit_intercept": [
2339
+ True,
2340
+ False,
2341
+ ], # Whether to calculate the intercept
2342
+ "normalize": [
2343
+ True,
2344
+ False,
2345
+ ], # If True, the regressors X will be normalized
2194
2346
  }
2195
- }
2196
- else: # median level
2197
- param_grids = {
2198
- "Random Forest": {
2199
- "n_estimators": [100, 200, 500],
2200
- "max_depth": [None, 10, 20, 30],
2201
- "min_samples_split": [2, 5, 10],
2202
- "min_samples_leaf": [1, 2, 4],
2203
- "class_weight": [None, "balanced"]
2204
- } if purpose == "classification" else {
2205
- "n_estimators": [100, 200, 500],
2206
- "max_depth": [None, 10, 20, 30],
2207
- "min_samples_split": [2, 5, 10],
2208
- "min_samples_leaf": [1, 2, 4],
2209
- "max_features": ['auto', 'sqrt', 'log2'], # Number of features to consider when looking for the best split
2210
- "bootstrap": [True, False], # Whether bootstrap samples are used when building trees
2211
- },
2347
+ ),
2348
+ }
2349
+ else: # median level
2350
+ param_grids = {
2351
+ "Random Forest": (
2352
+ {
2353
+ "n_estimators": [100, 200, 500],
2354
+ "max_depth": [None, 10, 20, 30],
2355
+ "min_samples_split": [2, 5, 10],
2356
+ "min_samples_leaf": [1, 2, 4],
2357
+ "class_weight": [None, "balanced"],
2358
+ }
2359
+ if purpose == "classification"
2360
+ else {
2361
+ "n_estimators": [100, 200, 500],
2362
+ "max_depth": [None, 10, 20, 30],
2363
+ "min_samples_split": [2, 5, 10],
2364
+ "min_samples_leaf": [1, 2, 4],
2365
+ "max_features": [
2366
+ "auto",
2367
+ "sqrt",
2368
+ "log2",
2369
+ ], # Number of features to consider when looking for the best split
2370
+ "bootstrap": [
2371
+ True,
2372
+ False,
2373
+ ], # Whether bootstrap samples are used when building trees
2374
+ }
2375
+ ),
2212
2376
  "SVM": {
2213
2377
  "C": [0.1, 1, 10, 100], # Regularization strength
2214
- "gamma": ['scale', 'auto'], # Common gamma values
2215
- "kernel": ['rbf', 'linear', 'poly'],
2378
+ "gamma": ["scale", "auto"], # Common gamma values
2379
+ "kernel": ["rbf", "linear", "poly"],
2216
2380
  },
2217
2381
  "Logistic Regression": {
2218
2382
  "C": [0.1, 1, 10, 100], # Regularization strength
2219
- "solver": ['lbfgs', 'liblinear', 'saga'], # Common solvers
2220
- "penalty": ['l2'], # L2 penalty is most common
2221
- "max_iter": [500, 1000, 2000], # Increased max_iter for better convergence
2383
+ "solver": ["lbfgs", "liblinear", "saga"], # Common solvers
2384
+ "penalty": ["l2"], # L2 penalty is most common
2385
+ "max_iter": [
2386
+ 500,
2387
+ 1000,
2388
+ 2000,
2389
+ ], # Increased max_iter for better convergence
2222
2390
  },
2223
- "Lasso":{
2391
+ "Lasso": {
2224
2392
  "alpha": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
2225
- "max_iter": [500, 1000, 2000]
2393
+ "max_iter": [500, 1000, 2000],
2226
2394
  },
2227
- "LassoCV":{
2395
+ "LassoCV": {
2228
2396
  "alphas": [[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]],
2229
- "max_iter": [500, 1000, 2000]
2397
+ "max_iter": [500, 1000, 2000],
2230
2398
  },
2231
2399
  "Gradient Boosting": {
2232
2400
  "n_estimators": [100, 200, 500],
@@ -2242,25 +2410,44 @@ def predict(
2242
2410
  "subsample": [0.8, 1.0],
2243
2411
  "colsample_bytree": [0.8, 1.0],
2244
2412
  },
2245
- "KNN": {
2246
- "n_neighbors": [3, 5, 7, 10],
2247
- "weights": ['uniform', 'distance'],
2248
- "algorithm": ['auto', 'ball_tree', 'kd_tree', 'brute'],
2249
- "p": [1, 2],
2250
- } if purpose=='classification' else {
2251
- 'n_neighbors': [3, 5, 7, 9, 11], # Number of neighbors
2252
- 'weights': ['uniform', 'distance'], # Weight function used in prediction
2253
- 'metric': ['euclidean', 'manhattan', 'minkowski'], # Distance metric
2254
- 'leaf_size': [20, 30, 40, 50], # Leaf size for KDTree or BallTree algorithms
2255
- 'p': [1, 2] # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2256
- },
2413
+ "KNN": (
2414
+ {
2415
+ "n_neighbors": [3, 5, 7, 10],
2416
+ "weights": ["uniform", "distance"],
2417
+ "algorithm": ["auto", "ball_tree", "kd_tree", "brute"],
2418
+ "p": [1, 2],
2419
+ }
2420
+ if purpose == "classification"
2421
+ else {
2422
+ "n_neighbors": [3, 5, 7, 9, 11], # Number of neighbors
2423
+ "weights": [
2424
+ "uniform",
2425
+ "distance",
2426
+ ], # Weight function used in prediction
2427
+ "metric": [
2428
+ "euclidean",
2429
+ "manhattan",
2430
+ "minkowski",
2431
+ ], # Distance metric
2432
+ "leaf_size": [
2433
+ 20,
2434
+ 30,
2435
+ 40,
2436
+ 50,
2437
+ ], # Leaf size for KDTree or BallTree algorithms
2438
+ "p": [
2439
+ 1,
2440
+ 2,
2441
+ ], # Power parameter for the Minkowski metric (1 = Manhattan, 2 = Euclidean)
2442
+ }
2443
+ ),
2257
2444
  "Naive Bayes": {
2258
2445
  "var_smoothing": [1e-9, 1e-8, 1e-7],
2259
2446
  },
2260
2447
  "SVR": {
2261
2448
  "C": [0.1, 1, 10, 100],
2262
- "gamma": ['scale', 'auto'],
2263
- "kernel": ['rbf', 'linear'],
2449
+ "gamma": ["scale", "auto"],
2450
+ "kernel": ["rbf", "linear"],
2264
2451
  },
2265
2452
  "Linear Regression": {
2266
2453
  "fit_intercept": [True, False],
@@ -2286,10 +2473,10 @@ def predict(
2286
2473
  "learning_rate": [0.01, 0.1],
2287
2474
  "num_leaves": [31, 50, 100],
2288
2475
  "max_depth": [-1, 10, 20],
2289
- 'min_data_in_leaf': [20], # Minimum samples in each leaf
2290
- 'min_gain_to_split': [0.01], # Minimum gain to allow a split
2291
- 'scale_pos_weight': [10], # Address class imbalance
2292
- },
2476
+ "min_data_in_leaf": [20], # Minimum samples in each leaf
2477
+ "min_gain_to_split": [0.01], # Minimum gain to allow a split
2478
+ "scale_pos_weight": [10], # Address class imbalance
2479
+ },
2293
2480
  "Bagging": {
2294
2481
  "n_estimators": [10, 50, 100],
2295
2482
  "max_samples": [0.5, 0.7, 1.0],
@@ -2314,41 +2501,73 @@ def predict(
2314
2501
  "Linear Discriminant Analysis": {
2315
2502
  "solver": ["svd", "lsqr", "eigen"],
2316
2503
  "shrinkage": [None, "auto"],
2317
- }, "Quadratic Discriminant Analysis":{
2318
- 'reg_param': [0.0, 0.1, 0.5, 1.0], # Regularization parameter
2319
- 'priors': [None, [0.5, 0.5], [0.3, 0.7]], # Class priors
2320
- 'tol': [1e-4, 1e-3, 1e-2] # Tolerance value for the convergence of the algorithm
2321
- },
2322
- "Perceptron":{
2323
- 'alpha': [1e-4, 1e-3, 1e-2], # Regularization parameter
2324
- 'penalty': ['l2', 'l1', 'elasticnet'], # Regularization penalty
2325
- 'max_iter': [1000, 2000], # Maximum number of iterations
2326
- 'eta0': [1.0, 0.1], # Learning rate for gradient descent
2327
- 'tol': [1e-3, 1e-4, 1e-5], # Tolerance for stopping criteria
2328
- 'random_state': [random_state] # Random state for reproducibility
2329
- },
2330
- "Bernoulli Naive Bayes":{
2331
- 'alpha': [0.1, 1.0, 10.0], # Additive (Laplace) smoothing parameter
2332
- 'binarize': [0.0, 0.5, 1.0], # Threshold for binarizing the input features
2333
- 'fit_prior': [True, False] # Whether to learn class prior probabilities
2334
- },
2335
- "SGDClassifier":{
2336
- 'eta0': [0.01, 0.1, 1.0],
2337
- 'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge', 'perceptron'], # Loss function
2338
- 'penalty': ['l2', 'l1', 'elasticnet'], # Regularization penalty
2339
- 'alpha': [1e-4, 1e-3, 1e-2], # Regularization strength
2340
- 'l1_ratio': [0.15, 0.5, 0.85], # L1 ratio for elasticnet penalty
2341
- 'max_iter': [1000, 2000], # Maximum number of iterations
2342
- 'tol': [1e-3, 1e-4], # Tolerance for stopping criteria
2343
- 'random_state': [random_state], # Random state for reproducibility
2344
- 'learning_rate': ['constant', 'optimal', 'invscaling', 'adaptive'], # Learning rate schedule
2345
- },
2346
- 'Ridge': {'class_weight': [None, 'balanced']} if purpose == "classification" else {
2347
- 'alpha': [0.1, 1, 10, 100],
2348
- 'solver': ['auto', 'svd', 'cholesky', 'lsqr'] # Solver for optimization
2349
- }
2504
+ },
2505
+ "Quadratic Discriminant Analysis": {
2506
+ "reg_param": [0.0, 0.1, 0.5, 1.0], # Regularization parameter
2507
+ "priors": [None, [0.5, 0.5], [0.3, 0.7]], # Class priors
2508
+ "tol": [
2509
+ 1e-4,
2510
+ 1e-3,
2511
+ 1e-2,
2512
+ ], # Tolerance value for the convergence of the algorithm
2513
+ },
2514
+ "Perceptron": {
2515
+ "alpha": [1e-4, 1e-3, 1e-2], # Regularization parameter
2516
+ "penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
2517
+ "max_iter": [1000, 2000], # Maximum number of iterations
2518
+ "eta0": [1.0, 0.1], # Learning rate for gradient descent
2519
+ "tol": [1e-3, 1e-4, 1e-5], # Tolerance for stopping criteria
2520
+ "random_state": [random_state], # Random state for reproducibility
2521
+ },
2522
+ "Bernoulli Naive Bayes": {
2523
+ "alpha": [0.1, 1.0, 10.0], # Additive (Laplace) smoothing parameter
2524
+ "binarize": [
2525
+ 0.0,
2526
+ 0.5,
2527
+ 1.0,
2528
+ ], # Threshold for binarizing the input features
2529
+ "fit_prior": [
2530
+ True,
2531
+ False,
2532
+ ], # Whether to learn class prior probabilities
2533
+ },
2534
+ "SGDClassifier": {
2535
+ "eta0": [0.01, 0.1, 1.0],
2536
+ "loss": [
2537
+ "hinge",
2538
+ "log",
2539
+ "modified_huber",
2540
+ "squared_hinge",
2541
+ "perceptron",
2542
+ ], # Loss function
2543
+ "penalty": ["l2", "l1", "elasticnet"], # Regularization penalty
2544
+ "alpha": [1e-4, 1e-3, 1e-2], # Regularization strength
2545
+ "l1_ratio": [0.15, 0.5, 0.85], # L1 ratio for elasticnet penalty
2546
+ "max_iter": [1000, 2000], # Maximum number of iterations
2547
+ "tol": [1e-3, 1e-4], # Tolerance for stopping criteria
2548
+ "random_state": [random_state], # Random state for reproducibility
2549
+ "learning_rate": [
2550
+ "constant",
2551
+ "optimal",
2552
+ "invscaling",
2553
+ "adaptive",
2554
+ ], # Learning rate schedule
2555
+ },
2556
+ "Ridge": (
2557
+ {"class_weight": [None, "balanced"]}
2558
+ if purpose == "classification"
2559
+ else {
2560
+ "alpha": [0.1, 1, 10, 100],
2561
+ "solver": [
2562
+ "auto",
2563
+ "svd",
2564
+ "cholesky",
2565
+ "lsqr",
2566
+ ], # Solver for optimization
2567
+ }
2568
+ ),
2350
2569
  }
2351
-
2570
+
2352
2571
  results = {}
2353
2572
  # Use StratifiedKFold for classification and KFold for regression
2354
2573
  cv = (
@@ -2359,11 +2578,11 @@ def predict(
2359
2578
 
2360
2579
  # Train and validate each model
2361
2580
  for name, clf in tqdm(
2362
- models.items(),
2363
- desc="models",
2364
- colour="green",
2365
- bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
2366
- ):
2581
+ models.items(),
2582
+ desc="models",
2583
+ colour="green",
2584
+ bar_format="{l_bar}{bar} {n_fmt}/{total_fmt}",
2585
+ ):
2367
2586
  if verbose:
2368
2587
  print(f"\nTraining and validating {name}:")
2369
2588
 
@@ -2381,7 +2600,7 @@ def predict(
2381
2600
  gs.fit(x_train, y_train)
2382
2601
  best_clf = gs.best_estimator_
2383
2602
  # make sure x_train and x_true have the same column names
2384
- x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2603
+ x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
2385
2604
  y_pred = best_clf.predict(x_true)
2386
2605
 
2387
2606
  # y_pred_proba
@@ -2396,18 +2615,23 @@ def predict(
2396
2615
  )
2397
2616
  else:
2398
2617
  y_pred_proba = None # No probability output for certain models
2399
-
2400
-
2618
+
2401
2619
  validation_scores = {}
2402
2620
  if y_true is not None:
2403
- validation_scores = cal_metrics(y_true, y_pred, y_pred_proba=y_pred_proba, purpose=purpose, average="weighted")
2621
+ validation_scores = cal_metrics(
2622
+ y_true,
2623
+ y_pred,
2624
+ y_pred_proba=y_pred_proba,
2625
+ purpose=purpose,
2626
+ average="weighted",
2627
+ )
2404
2628
 
2405
2629
  # Calculate ROC curve
2406
2630
  # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
2407
2631
  if y_pred_proba is not None:
2408
2632
  # fpr, tpr, roc_auc = dict(), dict(), dict()
2409
2633
  fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
2410
- lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba,verbose=False)
2634
+ lower_ci, upper_ci = cal_auc_ci(y_true, y_pred_proba, verbose=False)
2411
2635
  roc_auc = auc(fpr, tpr)
2412
2636
  roc_info = {
2413
2637
  "fpr": fpr.tolist(),
@@ -2425,11 +2649,14 @@ def predict(
2425
2649
  }
2426
2650
  else:
2427
2651
  roc_info, pr_info = None, None
2428
- if purpose=="classification":
2652
+ if purpose == "classification":
2429
2653
  results[name] = {
2430
- "best_clf": gs.best_estimator_,
2654
+ "best_clf": gs.best_estimator_,
2431
2655
  "best_params": gs.best_params_,
2432
- "auc_indiv":[gs.cv_results_[f'split{i}_test_score'][gs.best_index_] for i in range(cv_folds)],
2656
+ "auc_indiv": [
2657
+ gs.cv_results_[f"split{i}_test_score"][gs.best_index_]
2658
+ for i in range(cv_folds)
2659
+ ],
2433
2660
  "scores": validation_scores,
2434
2661
  "roc_curve": roc_info,
2435
2662
  "pr_curve": pr_info,
@@ -2439,11 +2666,11 @@ def predict(
2439
2666
  y_pred_proba.tolist() if y_pred_proba is not None else None
2440
2667
  ),
2441
2668
  }
2442
- else: # "regression"
2669
+ else: # "regression"
2443
2670
  results[name] = {
2444
- "best_clf": gs.best_estimator_,
2671
+ "best_clf": gs.best_estimator_,
2445
2672
  "best_params": gs.best_params_,
2446
- "scores": validation_scores, # e.g., neg_MSE, R², etc.
2673
+ "scores": validation_scores, # e.g., neg_MSE, R², etc.
2447
2674
  "predictions": y_pred.tolist(),
2448
2675
  "predictions_proba": (
2449
2676
  y_pred_proba.tolist() if y_pred_proba is not None else None
@@ -2452,9 +2679,9 @@ def predict(
2452
2679
 
2453
2680
  else:
2454
2681
  results[name] = {
2455
- "best_clf": gs.best_estimator_,
2682
+ "best_clf": gs.best_estimator_,
2456
2683
  "best_params": gs.best_params_,
2457
- "scores": validation_scores,
2684
+ "scores": validation_scores,
2458
2685
  "predictions": y_pred.tolist(),
2459
2686
  "predictions_proba": (
2460
2687
  y_pred_proba.tolist() if y_pred_proba is not None else None
@@ -2465,76 +2692,80 @@ def predict(
2465
2692
  df_results = pd.DataFrame.from_dict(results, orient="index")
2466
2693
 
2467
2694
  # sort
2468
- if y_true is not None and purpose=="classification":
2695
+ if y_true is not None and purpose == "classification":
2469
2696
  df_scores = pd.DataFrame(
2470
- df_results["scores"].tolist(), index=df_results["scores"].index
2471
- ).sort_values(by="roc_auc", ascending=False)
2472
- df_results=df_results.loc[df_scores.index]
2697
+ df_results["scores"].tolist(), index=df_results["scores"].index
2698
+ ).sort_values(by="roc_auc", ascending=False)
2699
+ df_results = df_results.loc[df_scores.index]
2473
2700
 
2474
2701
  if plot_:
2475
2702
  from datetime import datetime
2703
+
2476
2704
  now_ = datetime.now().strftime("%y%m%d_%H%M%S")
2477
- nexttile=plot.subplot(figsize=[12, 10])
2478
- plot.heatmap(df_scores, kind="direct",ax=nexttile())
2705
+ nexttile = plot.subplot(figsize=[12, 10])
2706
+ plot.heatmap(df_scores, kind="direct", ax=nexttile())
2479
2707
  plot.figsets(xangle=30)
2480
2708
  if dir_save:
2481
- ips.figsave(dir_save+f"scores_sorted_heatmap{now_}.pdf")
2482
- if df_scores.shape[0]>1:# draw cluster
2483
- plot.heatmap(df_scores, kind="direct",cluster=True)
2709
+ ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")
2710
+ if df_scores.shape[0] > 1: # draw cluster
2711
+ plot.heatmap(df_scores, kind="direct", cluster=True)
2484
2712
  plot.figsets(xangle=30)
2485
2713
  if dir_save:
2486
- ips.figsave(dir_save+f"scores_clus{now_}.pdf")
2487
- if all([plot_, y_true is not None, purpose=='classification']):
2714
+ ips.figsave(dir_save + f"scores_clus{now_}.pdf")
2715
+ if all([plot_, y_true is not None, purpose == "classification"]):
2488
2716
  try:
2489
- if len(models)>3:
2717
+ if len(models) > 3:
2490
2718
  plot_validate_features(df_results)
2491
2719
  else:
2492
- plot_validate_features_single(df_results,figsize=(12,4*len(models)))
2720
+ plot_validate_features_single(df_results, figsize=(12, 4 * len(models)))
2493
2721
  if dir_save:
2494
- ips.figsave(dir_save+f"validate_features{now_}.pdf")
2722
+ ips.figsave(dir_save + f"validate_features{now_}.pdf")
2495
2723
  except Exception as e:
2496
2724
  print(f"Error: 在画图的过程中出现了问题:{e}")
2497
2725
  return df_results
2498
2726
 
2499
2727
 
2500
- def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"):
2728
+ def cal_metrics(
2729
+ y_true, y_pred, y_pred_proba=None, purpose="regression", average="weighted"
2730
+ ):
2501
2731
  """
2502
2732
  Calculate regression or classification metrics based on the purpose.
2503
-
2733
+
2504
2734
  Parameters:
2505
2735
  - y_true: Array of true values.
2506
2736
  - y_pred: Array of predicted labels for classification or predicted values for regression.
2507
2737
  - y_pred_proba: Array of predicted probabilities for classification (optional).
2508
2738
  - purpose: str, "regression" or "classification".
2509
2739
  - average: str, averaging method for multi-class classification ("binary", "micro", "macro", "weighted", etc.).
2510
-
2740
+
2511
2741
  Returns:
2512
2742
  - validation_scores: dict of computed metrics.
2513
2743
  """
2514
2744
  from sklearn.metrics import (
2515
- mean_squared_error,
2516
- mean_absolute_error,
2517
- mean_absolute_percentage_error,
2518
- explained_variance_score,
2519
- r2_score,
2520
- mean_squared_log_error,
2521
- accuracy_score,
2522
- precision_score,
2523
- recall_score,
2524
- f1_score,
2525
- roc_auc_score,
2526
- matthews_corrcoef,
2527
- confusion_matrix,
2528
- balanced_accuracy_score,
2529
- average_precision_score,
2530
- precision_recall_curve
2531
- )
2745
+ mean_squared_error,
2746
+ mean_absolute_error,
2747
+ mean_absolute_percentage_error,
2748
+ explained_variance_score,
2749
+ r2_score,
2750
+ mean_squared_log_error,
2751
+ accuracy_score,
2752
+ precision_score,
2753
+ recall_score,
2754
+ f1_score,
2755
+ roc_auc_score,
2756
+ matthews_corrcoef,
2757
+ confusion_matrix,
2758
+ balanced_accuracy_score,
2759
+ average_precision_score,
2760
+ precision_recall_curve,
2761
+ )
2762
+
2532
2763
  validation_scores = {}
2533
2764
 
2534
2765
  if purpose == "regression":
2535
2766
  y_true = np.asarray(y_true)
2536
2767
  y_true = y_true.ravel()
2537
- y_pred = np.asarray(y_pred)
2768
+ y_pred = np.asarray(y_pred)
2538
2769
  y_pred = y_pred.ravel()
2539
2770
  # Regression metrics
2540
2771
  validation_scores = {
@@ -2544,7 +2775,7 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
2544
2775
  "r2": r2_score(y_true, y_pred),
2545
2776
  "mape": mean_absolute_percentage_error(y_true, y_pred),
2546
2777
  "explained_variance": explained_variance_score(y_true, y_pred),
2547
- "mbd": np.mean(y_pred - y_true) # Mean Bias Deviation
2778
+ "mbd": np.mean(y_pred - y_true), # Mean Bias Deviation
2548
2779
  }
2549
2780
  # Check if MSLE can be calculated
2550
2781
  if np.all(y_true >= 0) and np.all(y_pred >= 0): # Ensure no negative values
@@ -2560,21 +2791,24 @@ def cal_metrics(y_true, y_pred, y_pred_proba=None, purpose="regression", average
2560
2791
  "recall": recall_score(y_true, y_pred, average=average),
2561
2792
  "f1": f1_score(y_true, y_pred, average=average),
2562
2793
  "mcc": matthews_corrcoef(y_true, y_pred),
2563
- "specificity": None,
2564
- "balanced_accuracy": balanced_accuracy_score(y_true, y_pred)
2794
+ "specificity": None,
2795
+ "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
2565
2796
  }
2566
2797
 
2567
2798
  # Confusion matrix to calculate specificity
2568
2799
  tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
2569
- validation_scores["specificity"] = tn / (tn + fp) if (tn + fp) > 0 else 0 # Specificity calculation
2800
+ validation_scores["specificity"] = (
2801
+ tn / (tn + fp) if (tn + fp) > 0 else 0
2802
+ ) # Specificity calculation
2570
2803
 
2571
- if y_pred_proba is not None:
2804
+ if y_pred_proba is not None:
2572
2805
  # Calculate ROC-AUC
2573
- validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
2806
+ validation_scores["roc_auc"] = roc_auc_score(y_true, y_pred_proba)
2574
2807
  # PR-AUC (Precision-Recall AUC) calculation
2575
2808
  validation_scores["pr_auc"] = average_precision_score(y_true, y_pred_proba)
2576
2809
  else:
2577
- raise ValueError("Invalid purpose specified. Choose 'regression' or 'classification'.")
2810
+ raise ValueError(
2811
+ "Invalid purpose specified. Choose 'regression' or 'classification'."
2812
+ )
2578
2813
 
2579
2814
  return validation_scores
2580
-