py2ls 0.2.4.8__py3-none-any.whl → 0.2.4.9__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
py2ls/ml2ls.py CHANGED
@@ -582,7 +582,8 @@ def get_features(
582
582
  strict: bool = False,
583
583
  n_shared: int = 2, # 只要有两个方法有重合,就纳入common genes
584
584
  use_selected_features: bool = True,
585
- ) -> dict:
585
+ plot_: bool = True,
586
+ dir_save:str="./") -> dict:
586
587
  """
587
588
  Master function to perform feature selection and validate models.
588
589
  """
@@ -598,11 +599,15 @@ def get_features(
598
599
  # fill na
599
600
  if fill_missing:
600
601
  ips.df_fillna(data=X,method='knn',inplace=True,axis=0)
601
-
602
- # rm missing values
603
- X.dropna(inplace=True)
604
- y.dropna(inplace=True)
602
+ if isinstance(y, str) and y in X.columns:
603
+ y_col_name=y
604
+ y=X[y]
605
+ y=ips.df_encoder(pd.DataFrame(y),method='dummy')
606
+ X = X.drop(y_col_name,axis=1)
607
+ else:
608
+ y=ips.df_encoder(pd.DataFrame(y),method='dummy').values.ravel()
605
609
  y = y.loc[X.index] # Align y with X after dropping rows with missing values in X
610
+ y = y.ravel() if isinstance(y, np.ndarray) else y.values.ravel()
606
611
 
607
612
  if X.shape[0] != len(y):
608
613
  raise ValueError("X and y must have the same number of samples (rows).")
@@ -894,9 +899,13 @@ def get_features(
894
899
  results = {
895
900
  "selected_features": features_df,
896
901
  "cv_train_scores": cv_train_results_df,
897
- "cv_test_scores": rank_models(cv_test_results_df),
902
+ "cv_test_scores": rank_models(cv_test_results_df,plot_=plot_),
898
903
  "common_features": list(common_features),
899
904
  }
905
+ if all([plot_,dir_save]):
906
+ from datetime import datetime
907
+ now_ = datetime.now().strftime("%y%m%d_%H%M%S")
908
+ ips.figsave(dir_save+f"features{now_}.pdf")
900
909
  else:
901
910
  results = {
902
911
  "selected_features": pd.DataFrame(),
@@ -1707,36 +1716,55 @@ def predict(
1707
1716
  smote: bool = False,
1708
1717
  n_jobs:int = -1,
1709
1718
  plot_: bool = True,
1719
+ dir_save:str="./",
1710
1720
  test_size:float=0.2,# specific only when x_true is None
1711
1721
  cv_folds:int=5,# more cv_folds 得更加稳定,auc可能更低
1712
1722
  cv_level:str="l",#"s":'low',"m":'medium',"l":"high"
1713
1723
  class_weight: str = "balanced",
1714
1724
  verbose:bool=False,
1715
- dir_save:str="./"
1716
1725
  ) -> pd.DataFrame:
1717
- """
1718
- 1. 对x_train进行split_train_test,并对其进行validate
1719
- predict(x_train, y_train)
1720
- 2. 利用x_train, y_train的数据,对x_true的数据进行predict
1721
- predict(x_train, y_train, x_true)
1722
- 3. 利用x_train, y_train的数据,validate x_true和y_true
1723
- predict(x_train, y_train, x_true, y_true)
1724
-
1725
- Advanced master predictor function with grid search for hyperparameter tuning.
1726
-
1727
- Parameters:
1728
- - x_train, y_train: Training dataset.
1729
- - x_true, y_true: Dataset for validation or prediction (y_true=None for prediction).
1730
- - common_features (set): Common features to use for validation.
1731
- - purpose (str): Task type - 'classification' or 'regression'.
1732
- - models (dict): Dictionary of models and parameters.
1733
- - metrics (list): Metrics to compute.
1734
- - random_state (int): Seed for reproducibility.
1735
- - smote (bool): Use SMOTE for class imbalance (classification only).
1736
- - class_weight (str): Class weights to handle imbalance.
1737
-
1738
- Returns:
1739
- - df_results (pd.DataFrame): DataFrame with performance metrics and hyperparameters.
1726
+ """
1727
+ 第一种情况是内部拆分,第二种是直接预测,第三种是外部验证。
1728
+ Usage:
1729
+ (1). predict(x_train, y_train,...) 对 x_train 进行拆分训练/测试集,并在测试集上进行验证.
1730
+ predict 函数会根据 test_size 参数,将 x_train y_train 拆分出内部测试集。然后模型会在拆分出的训练集上进行训练,并在测试集上验证效果。
1731
+ (2). predict(x_train, y_train, x_true,...)使用 x_train y_train 训练并对 x_true 进行预测
1732
+ 由于传入了 x_true,函数会跳过 x_train 的拆分,直接使用全部的 x_train 和 y_train 进行训练。然后对 x_true 进行预测,但由于没有提供 y_true
1733
+ 因此无法与真实值进行对比。
1734
+ (3). predict(x_train, y_train, x_true, y_true,...)使用 x_train y_train 训练,并验证 x_true 与真实标签 y_true.
1735
+ predict 函数会在 x_train 和 y_train 上进行训练,并将 x_true 作为测试集。由于提供了 y_true,函数可以将预测结果与 y_true 进行对比,从而
1736
+ 计算验证指标,完成对 x_true 的真正验证。
1737
+ trains and validates a variety of machine learning models for both classification and regression tasks.
1738
+ It supports hyperparameter tuning with grid search and includes additional features like cross-validation,
1739
+ feature scaling, and handling of class imbalance through SMOTE.
1740
+
1741
+ Parameters:
1742
+ - x_train (pd.DataFrame):Training feature data, structured with each row as an observation and each column as a feature.
1743
+ - y_train (pd.Series):Target variable for the training dataset.
1744
+ - x_true (pd.DataFrame, optional):Test feature data. If not provided, the function splits x_train based on test_size.
1745
+ - y_true (pd.Series, optional):Test target values. If not provided, y_train is split into training and testing sets.
1746
+ - common_features (set, optional):Specifies a subset of features common across training and test data.
1747
+ - purpose (str, default = "classification"):Defines whether the task is "classification" or "regression". Determines which
1748
+ metrics and models are applied.
1749
+ - cls (dict, optional):Dictionary to specify custom classifiers/regressors. Defaults to a set of common models if not provided.
1750
+ - metrics (list, optional):List of evaluation metrics (like accuracy, F1 score) used for model evaluation.
1751
+ - random_state (int, default = 1):Random seed to ensure reproducibility.
1752
+ - smote (bool, default = False):Applies Synthetic Minority Oversampling Technique (SMOTE) to address class imbalance if enabled.
1753
+ - n_jobs (int, default = -1):Number of parallel jobs for computation. Set to -1 to use all available cores.
1754
+ - plot_ (bool, default = True):If True, generates plots of the model evaluation metrics.
1755
+ - test_size (float, default = 0.2):Test data proportion if x_true is not provided.
1756
+ - cv_folds (int, default = 5):Number of cross-validation folds.
1757
+ - cv_level (str, default = "l"):Sets the detail level of cross-validation. "s" for low, "m" for medium, and "l" for high.
1758
+ - class_weight (str, default = "balanced"):Balances class weights in classification tasks.
1759
+ - verbose (bool, default = False):If True, prints detailed output during model training.
1760
+ - dir_save (str, default = "./"):Directory path to save plot outputs and results.
1761
+
1762
+ Key Steps in the Function:
1763
+ Model Initialization: Depending on purpose, initializes either classification or regression models.
1764
+ Feature Selection: Ensures training and test sets have matching feature columns.
1765
+ SMOTE Application: Balances classes if smote is enabled and the task is classification.
1766
+ Cross-Validation and Hyperparameter Tuning: Utilizes GridSearchCV for model tuning based on cv_level.
1767
+ Evaluation and Plotting: Outputs evaluation metrics like AUC, confusion matrices, and optional plotting of performance metrics.
1740
1768
  """
1741
1769
  from tqdm import tqdm
1742
1770
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor, BaggingClassifier, BaggingRegressor, AdaBoostClassifier, AdaBoostRegressor
@@ -1858,8 +1886,12 @@ def predict(
1858
1886
  y_true=ips.df_encoder(pd.DataFrame(y_true),method='dummy').values.ravel()
1859
1887
 
1860
1888
  # to convert the 2D to 1D: 2D column-vector format (like [[1], [0], [1], ...]) instead of a 1D array ([1, 0, 1, ...]
1861
- y_train=y_train.values.ravel() if y_train is not None else None
1862
- y_true=y_true.values.ravel() if y_true is not None else None
1889
+
1890
+ # y_train=y_train.values.ravel() if y_train is not None else None
1891
+ # y_true=y_true.values.ravel() if y_true is not None else None
1892
+ y_train = y_train.ravel() if isinstance(y_train, np.ndarray) else y_train.values.ravel()
1893
+ y_true = y_true.ravel() if isinstance(y_true, np.ndarray) else y_true.values.ravel()
1894
+
1863
1895
 
1864
1896
  # Ensure common features are selected
1865
1897
  if common_features is not None:
@@ -2440,15 +2472,18 @@ def predict(
2440
2472
  df_results=df_results.loc[df_scores.index]
2441
2473
 
2442
2474
  if plot_:
2475
+ from datetime import datetime
2476
+ now_ = datetime.now().strftime("%y%m%d_%H%M%S")
2443
2477
  nexttile=plot.subplot(figsize=[12, 10])
2444
2478
  plot.heatmap(df_scores, kind="direct",ax=nexttile())
2445
2479
  plot.figsets(xangle=30)
2446
2480
  if dir_save:
2447
- ips.figsave(dir_save+"scores_sorted_heatmap.pdf")
2448
- plot.heatmap(df_scores, kind="direct",cluster=True)
2449
- plot.figsets(xangle=30)
2450
- if dir_save:
2451
- ips.figsave(dir_save+"scores_clus.pdf")
2481
+ ips.figsave(dir_save+f"scores_sorted_heatmap{now_}.pdf")
2482
+ if df_scores.shape[0]>1:# draw cluster
2483
+ plot.heatmap(df_scores, kind="direct",cluster=True)
2484
+ plot.figsets(xangle=30)
2485
+ if dir_save:
2486
+ ips.figsave(dir_save+f"scores_clus{now_}.pdf")
2452
2487
  if all([plot_, y_true is not None, purpose=='classification']):
2453
2488
  try:
2454
2489
  if len(models)>3:
@@ -2456,7 +2491,7 @@ def predict(
2456
2491
  else:
2457
2492
  plot_validate_features_single(df_results,figsize=(12,4*len(models)))
2458
2493
  if dir_save:
2459
- ips.figsave(dir_save+"validate_features.pdf")
2494
+ ips.figsave(dir_save+f"validate_features{now_}.pdf")
2460
2495
  except Exception as e:
2461
2496
  print(f"Error: 在画图的过程中出现了问题:{e}")
2462
2497
  return df_results
py2ls/plot.py CHANGED
@@ -3020,7 +3020,7 @@ def plotxy(
3020
3020
  sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
3021
 
3022
3022
  valid_kinds = list(default_settings.keys())
3023
- print(valid_kinds)
3023
+ # print(valid_kinds)
3024
3024
  if kind is not None:
3025
3025
  if isinstance(kind, str):
3026
3026
  kind = [kind]
@@ -3032,13 +3032,7 @@ def plotxy(
3032
3032
  if kind is not None:
3033
3033
  for k in kind:
3034
3034
  if k in valid_kinds:
3035
- print(f"{k}:\n\t{default_settings[k]}")
3036
- print(
3037
- sns_info[sns_info["Functions"].str.contains(k)]
3038
- .iloc[:, -1]
3039
- .tolist()[0]
3040
- )
3041
- print()
3035
+ print(f"{k}:\n\t{default_settings[k]}")
3042
3036
  usage_str = """plotxy(data=ranked_genes,
3043
3037
  x="log2(fold_change)",
3044
3038
  y="-log10(p-value)",
@@ -3102,7 +3096,6 @@ def plotxy(
3102
3096
  hue = kwargs.pop("hue", None)
3103
3097
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3104
3098
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3105
-
3106
3099
  palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3107
3100
  s = kws_scatter.pop("s", 10)
3108
3101
  alpha = kws_scatter.pop("alpha", 0.7)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: py2ls
3
- Version: 0.2.4.8
3
+ Version: 0.2.4.9
4
4
  Summary: py(thon)2(too)ls
5
5
  Author: Jianfeng
6
6
  Author-email: Jianfeng.Liu0413@gmail.com
@@ -56,7 +56,7 @@ Requires-Dist: coverage (>=7.6.0)
56
56
  Requires-Dist: coveralls (>=4.0.1)
57
57
  Requires-Dist: crashtest (>=0.4.1)
58
58
  Requires-Dist: cycler (>=0.12.1)
59
- Requires-Dist: dask[dataframe] (>=2023.6,<2024.0)
59
+ Requires-Dist: dask (>=2024.7.1)
60
60
  Requires-Dist: debugpy (>=1.8.2)
61
61
  Requires-Dist: decorator (>=5.1.1)
62
62
  Requires-Dist: defusedxml (>=0.7.1)
@@ -17,7 +17,7 @@ py2ls/.git/hooks/pre-receive.sample,sha256=pMPSuce7P9jRRBwxvU7nGlldZrRPz0ndsxAlI
17
17
  py2ls/.git/hooks/prepare-commit-msg.sample,sha256=6d3KpBif3dJe2X_Ix4nsp7bKFjkLI5KuMnbwyOGqRhk,1492
18
18
  py2ls/.git/hooks/push-to-checkout.sample,sha256=pT0HQXmLKHxt16-mSu5HPzBeZdP0lGO7nXQI7DsSv18,2783
19
19
  py2ls/.git/hooks/update.sample,sha256=jV8vqD4QPPCLV-qmdSHfkZT0XL28s32lKtWGCXoU0QY,3650
20
- py2ls/.git/index,sha256=w4n_eE4nWnGzWPUMesYlqgV_hldK1TphJVZrRyij5Vo,4232
20
+ py2ls/.git/index,sha256=CwYqVYvTvnn8cGWn-ctzQIfhmbgRfxtYXrAkmCA1AuU,4232
21
21
  py2ls/.git/info/exclude,sha256=ZnH-g7egfIky7okWTR8nk7IxgFjri5jcXAbuClo7DsE,240
22
22
  py2ls/.git/logs/HEAD,sha256=8ID7WuAe_TlO9g-ARxhIJYdgdL3u3m7-1qrOanaIUlA,3535
23
23
  py2ls/.git/logs/refs/heads/main,sha256=8ID7WuAe_TlO9g-ARxhIJYdgdL3u3m7-1qrOanaIUlA,3535
@@ -214,17 +214,17 @@ py2ls/export_requirements.py,sha256=x2WgUF0jYKz9GfA1MVKN-MdsM-oQ8yUeC6Ua8oCymio,
214
214
  py2ls/fetch_update.py,sha256=9LXj661GpCEFII2wx_99aINYctDiHni6DOruDs_fdt8,4752
215
215
  py2ls/freqanalysis.py,sha256=F4218VSPbgL5tnngh6xNCYuNnfR-F_QjECUUxrPYZss,32594
216
216
  py2ls/ich2ls.py,sha256=3E9R8oVpyYZXH5PiIQgT3CN5NxLe4Dwtm2LwaeacE6I,21381
217
- py2ls/ips.py,sha256=OJgNO3F-S7m5QjrwRjFOxG-sIZruRvK52NfnTm9yhTU,260110
218
- py2ls/ml2ls.py,sha256=SODP4ebQnbpdhX1VeUXTkHIKSxz37c0Brxis87vPv4U,102625
217
+ py2ls/ips.py,sha256=2Ds3kra7LtxVu5L1vNrpKjGFhg2mdnS5qcqSqHDNkkQ,265181
218
+ py2ls/ml2ls.py,sha256=EN-ufKgFs6NWPJVyh3mu9VmRyRK4vgi6rzufDd7B2pA,106633
219
219
  py2ls/mol.py,sha256=AZnHzarIk_MjueKdChqn1V6e4tUle3X1NnHSFA6n3Nw,10645
220
220
  py2ls/netfinder.py,sha256=RJFr80tGEJiuwEx99IBOhI5-ZuXnPdWnGUYpF7XCEwI,56426
221
221
  py2ls/ocr.py,sha256=5lhUbJufIKRSOL6wAWVLEo8TqMYSjoI_Q-IO-_4u3DE,31419
222
- py2ls/plot.py,sha256=IBIlcOYmXLrsgq_7338JlowZVPns8Hr3dHnvozwINl4,167825
222
+ py2ls/plot.py,sha256=LeQpTLvRHMDrQtU8yaeXEOgDdVm7KWLcAuRia6wWMYQ,167604
223
223
  py2ls/setuptools-70.1.0-py3-none-any.whl,sha256=2bi3cUVal8ip86s0SOvgspteEF8SKLukECi-EWmFomc,882588
224
224
  py2ls/sleep_events_detectors.py,sha256=bQA3HJqv5qnYKJJEIhCyhlDtkXQfIzqksnD0YRXso68,52145
225
225
  py2ls/stats.py,sha256=DMoJd8Z5YV9T1wB-4P52F5K5scfVK55DT8UP4Twcebo,38627
226
226
  py2ls/translator.py,sha256=zBeq4pYZeroqw3DT-5g7uHfVqKd-EQptT6LJ-Adi8JY,34244
227
227
  py2ls/wb_detector.py,sha256=7y6TmBUj9exCZeIgBAJ_9hwuhkDh1x_-yg4dvNY1_GQ,6284
228
- py2ls-0.2.4.8.dist-info/METADATA,sha256=Fr6xazl4OK1paux9REIe5pEKB-xDp4soHe7PLsvIHmA,20055
229
- py2ls-0.2.4.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
230
- py2ls-0.2.4.8.dist-info/RECORD,,
228
+ py2ls-0.2.4.9.dist-info/METADATA,sha256=4HaavKedVGS05_RLEBRr7E_A9XJotqR0oXRC0u-qR4k,20038
229
+ py2ls-0.2.4.9.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
230
+ py2ls-0.2.4.9.dist-info/RECORD,,