gpbench 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
_param_free_base_model.py (shipped identically under BayesA, BayesB, and BayesC in both method_class and method_reg):

@@ -0,0 +1,84 @@
import abc
import joblib
import numpy as np
import pathlib


class ParamFreeBaseModel(abc.ABC):
    """
    BaseModel parent class for all models that do not have hyperparameters, e.g. BLUP.

    Every model must be based on :obj:`~easypheno.model.param_free_base_model.ParamFreeBaseModel` directly or ParamFreeBaseModel's child classes.

    Please add ``super().__init__(PARAMS)`` to the constructor in case you override it in a child class.

    **Attributes**

    *Class attributes*

    - standard_encoding (*str*): the standard encoding for this model
    - possible_encodings (*List<str>*): a list of all encodings that are possible according to the model definition

    *Instance attributes*

    - task (*str*): ML task ('regression' or 'classification') depending on target variable
    - encoding (*str*): the encoding to use (standard encoding or user-defined)

    :param task: ML task (regression or classification) depending on target variable
    :param encoding: the encoding to use (standard encoding or user-defined)
    """

    # Class attributes #
    @property
    @classmethod
    @abc.abstractmethod
    def standard_encoding(cls):
        """the standard encoding for this model"""
        raise NotImplementedError

    @property
    @classmethod
    @abc.abstractmethod
    def possible_encodings(cls):
        """a list of all encodings that are possible according to the model definition"""
        raise NotImplementedError

    # Constructor super class #
    def __init__(self, task: str, encoding: str = None):
        self.task = task
        self.encoding = self.standard_encoding if encoding is None else encoding

    # Methods required by each child class #

    @abc.abstractmethod
    def fit(self, X: np.array, y: np.array) -> np.array:
        """
        Method that fits the model based on features X and targets y

        :param X: feature matrix for retraining
        :param y: target vector

        :return: numpy array with values predicted for X
        """

    @abc.abstractmethod
    def predict(self, X_in: np.array) -> np.array:
        """
        Method that predicts target values based on the input X_in

        :param X_in: feature matrix as input

        :return: numpy array with the predicted values
        """

    def save_model(self, path: pathlib.Path, filename: str):
        """
        Persist the whole model object on a hard drive
        (can be loaded with :obj:`~easypheno.model._model_functions.load_model`)

        :param path: path where the model will be saved
        :param filename: filename of the model
        """
        joblib.dump(self, path.joinpath(filename), compress=3)
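The two abstract methods and two abstract class attributes above form the full contract a parameter-free model must satisfy. Below is a minimal sketch of a hypothetical subclass; the MeanPredictor name, the '012' encoding label, and the mean-predictor logic are illustrative, not part of gpbench:

import numpy as np
import pathlib


class MeanPredictor(ParamFreeBaseModel):
    """Hypothetical subclass that always predicts the training-set mean."""

    # Plain class attributes are enough to satisfy the abstract property declarations.
    standard_encoding = '012'
    possible_encodings = ['012']

    def fit(self, X: np.array, y: np.array) -> np.array:
        self.mean_ = float(np.mean(y))
        return np.full(X.shape[0], self.mean_)

    def predict(self, X_in: np.array) -> np.array:
        return np.full(X_in.shape[0], self.mean_)


model = MeanPredictor(task='regression')            # encoding=None falls back to standard_encoding
model.fit(np.zeros((10, 5)), np.arange(10.0))
model.save_model(pathlib.Path('.'), 'mean_predictor.joblib')
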
bayesCfromR.py (the same 16-line wrapper appears under both gpbench/method_class/BayesC and gpbench/method_reg/BayesC):

@@ -0,0 +1,16 @@
from . import _bayesfromR


class BayesC(_bayesfromR.Bayes_R):
    """
    Implementation of a class for Bayes C.

    *Attributes*

    *Inherited attributes*

    See :obj:`~easypheno.model._bayesfromR.Bayes_R` for more information on the attributes.
    """

    def __init__(self, task: str, encoding: str = None):
        super().__init__(task=task, model_name='BayesC', encoding=encoding)
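The wrapper only pins `model_name`; BayesA and BayesB ship the same shim with their own names. Instantiation is a one-liner — a sketch assuming the method_class copy and a working R backend behind `Bayes_R`:

from gpbench.method_class.BayesC.bayesCfromR import BayesC

model = BayesC(task='regression')   # encoding=None falls back to the standard encoding
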
gpbench/method_class/CropARNet/CropARNet_class.py:

@@ -0,0 +1,186 @@
import os
import time
import psutil
import swanlab
import argparse
import random
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, TensorDataset
from .base_CropARNet_class import SimpleSNPModel
from . import CropARNet_he_class
import pynvml


def parse_args():
    parser = argparse.ArgumentParser(description="Argument parser")
    parser.add_argument('--methods', type=str, default='CropARNet/')
    parser.add_argument('--species', type=str, default='')
    parser.add_argument('--phe', type=str, default='')
    parser.add_argument('--data_dir', type=str, default='../../data/')
    parser.add_argument('--result_dir', type=str, default='result/')

    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--momentum', type=float, default=0.5)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--patience', type=int, default=50)
    return parser.parse_args()


def load_data(args):
    xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
    yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
    names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]

    print("Number of samples:", xData.shape[0])
    print("Number of SNPs:", xData.shape[1])
    return xData, yData, xData.shape[0], xData.shape[1], names


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_gpu_mem_by_pid(pid, handle):
    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    for p in procs:
        if p.pid == pid:
            return p.usedGpuMemory / 1024**2
    return 0.0


def run_nested_cv(args, data, label, nsnp, num_classes, device, handle=None, le=None):
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)

    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

    all_acc, all_prec, all_rec, all_f1 = [], [], [], []
    time_start = time.time()

    for fold, (train_index, test_index) in enumerate(kf.split(data, label)):
        print(f"\nRunning fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        X_tr, X_val, y_tr, y_val = train_test_split(
            X_train, y_train, test_size=0.1, random_state=42, stratify=y_train
        )

        x_train = torch.from_numpy(X_tr).float().to(device)
        y_train_t = torch.from_numpy(y_tr).long().to(device)
        x_valid = torch.from_numpy(X_val).float().to(device)
        y_valid_t = torch.from_numpy(y_val).long().to(device)
        x_test = torch.from_numpy(X_test).float().to(device)
        y_test_t = torch.from_numpy(y_test).long().to(device)

        train_loader = DataLoader(TensorDataset(x_train, y_train_t), args.batch_size, shuffle=True)
        valid_loader = DataLoader(TensorDataset(x_valid, y_valid_t), args.batch_size, shuffle=False)
        test_loader = DataLoader(TensorDataset(x_test, y_test_t), args.batch_size, shuffle=False)

        model = SimpleSNPModel(num_snps=nsnp, num_classes=num_classes)
        model.train_model(
            train_loader,
            valid_loader,
            args.epochs,
            args.learning_rate,
            args.weight_decay,
            args.patience,
            device
        )

        y_pred = model.predict(test_loader)

        acc = accuracy_score(y_test, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_test, y_pred, average="macro", zero_division=0
        )

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start
        gpu_mem = get_gpu_mem_by_pid(os.getpid(), handle) if handle else 0.0
        cpu_mem = process.memory_info().rss / 1024**2

        print(
            f"Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, "
            f"REC={rec:.4f}, F1={f1:.4f}, "
            f"Time={fold_time:.2f}s, GPU={gpu_mem:.2f}MB, CPU={cpu_mem:.2f}MB"
        )

        pd.DataFrame({
            "y_true": le.inverse_transform(y_test),
            "y_pred": le.inverse_transform(y_pred)
        }).to_csv(
            os.path.join(result_dir, f"fold{fold}.csv"), index=False
        )

        torch.cuda.empty_cache()

    print("\n===== Cross-validation summary =====")
    print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
    print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
    print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
    print(f"F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
    print(f"Time: {time.time() - time_start:.2f}s")


def CropARNet_class():
    set_seed(42)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    all_species = ["Human/Sim/"]

    for sp in all_species:
        args.species = sp
        X, Y, _, nsnp, names = load_data(args)

        print("Starting:", args.methods + args.species)
        if Y.ndim == 1:
            label = Y
        else:
            label = Y[:, 0]

        label = np.nan_to_num(label, nan=np.nanmean(label))

        le = LabelEncoder()
        label = le.fit_transform(label)
        num_classes = len(np.unique(label))

        result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
        os.makedirs(result_dir, exist_ok=True)
        np.save(os.path.join(result_dir, 'label_mapping.npy'), le.classes_)

        best_params = CropARNet_he_class.Hyperparameter(X, label, nsnp)
        args.learning_rate = best_params["learning_rate"]
        args.batch_size = best_params["batch_size"]
        args.weight_decay = best_params["weight_decay"]
        args.patience = best_params["patience"]
        start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        process = psutil.Process(os.getpid())

        run_nested_cv(args, X, label, nsnp, num_classes, device, handle, le)
        elapsed_time = time.time() - start_time
        print(f"Elapsed time: {elapsed_time:.2f} s")
        print("Finished successfully")


if __name__ == "__main__":
    CropARNet_class()
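load_data above hard-codes the archive layout: genotype.npz holds the SNP matrix as arr_0, and phenotype.npz holds the target values as arr_0 plus sample names as arr_1. A sketch of writing compatible inputs; the shapes and the Human/Sim/ directory mirror the script's defaults, and the random data is purely illustrative:

import os
import numpy as np

data_dir = '../../data/Human/Sim/'
os.makedirs(data_dir, exist_ok=True)

genotype = np.random.randint(0, 3, size=(200, 1000))   # samples x SNPs
phenotype = np.random.randint(0, 2, size=200)          # class labels
names = np.array([f"sample_{i}" for i in range(200)])

# np.savez stores positional arguments as arr_0, arr_1, ... which is what load_data reads back
np.savez(os.path.join(data_dir, 'genotype.npz'), genotype)
np.savez(os.path.join(data_dir, 'phenotype.npz'), phenotype, names)
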
gpbench/method_class/CropARNet/CropARNet_he_class.py:

@@ -0,0 +1,154 @@
import os
import time
import psutil
import random
import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, TensorDataset
from optuna.exceptions import TrialPruned
from .base_CropARNet_class import SimpleSNPModel


def run_nested_cv_with_early_stopping(
    data,
    label,
    nsnp,
    num_classes,
    learning_rate,
    weight_decay,
    patience,
    batch_size,
    num_round=500
):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Starting 10-fold cross-validation...")

    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    all_acc, all_prec, all_rec, all_f1 = [], [], [], []

    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
            X_train, y_train, test_size=0.1, random_state=42
        )

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).long().to(device)

        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).long().to(device)

        x_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).long().to(device)

        train_loader = DataLoader(
            TensorDataset(x_train_tensor, y_train_tensor),
            batch_size=batch_size,
            shuffle=True
        )
        valid_loader = DataLoader(
            TensorDataset(x_valid_tensor, y_valid_tensor),
            batch_size=batch_size,
            shuffle=False
        )
        test_loader = DataLoader(
            TensorDataset(x_test_tensor, y_test_tensor),
            batch_size=batch_size,
            shuffle=False
        )

        model = SimpleSNPModel(num_snps=nsnp, num_classes=num_classes)

        model.train_model(
            train_loader,
            valid_loader,
            num_round,
            learning_rate,
            weight_decay,
            patience,
            device
        )

        y_pred = model.predict(test_loader)
        acc = accuracy_score(y_test, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_test,
            y_pred,
            average="macro",
            zero_division=0
        )

        if np.isnan(f1):
            print(f"Fold {fold} resulted in NaN F1, pruning trial...")
            raise TrialPruned()

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start_time
        fold_cpu_mem = process.memory_info().rss / 1024**2

        print(
            f"Fold {fold}: ACC={acc:.4f}, "
            f"PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, "
            f"Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB"
        )

    return np.mean(all_f1) if all_f1 else 0.0


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def Hyperparameter(data, label, nsnp):
    set_seed(42)
    le = LabelEncoder()
    label = le.fit_transform(label)
    num_classes = len(np.unique(label))

    def objective(trial):
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        weight_decay = trial.suggest_categorical(
            "weight_decay", [1e-4, 1e-3, 1e-2, 1e-1]
        )
        patience = trial.suggest_int("patience", 5, 30)

        try:
            f1_score = run_nested_cv_with_early_stopping(
                data=data,
                label=label,
                nsnp=nsnp,
                num_classes=num_classes,
                learning_rate=learning_rate,
                weight_decay=weight_decay,
                patience=patience,
                batch_size=batch_size
            )
        except TrialPruned:
            return float("-inf")

        return f1_score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

    print("Best hyperparameters:", study.best_params)
    print("Finished successfully")
    return study.best_params
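Each Optuna trial above pays for a full 10-fold cross-validation, so the 20-trial search amounts to up to 200 model fits. A sketch of invoking the search directly; the array shapes are illustrative, and with random data the returned parameters are of course meaningless:

import numpy as np

X = np.random.rand(200, 500).astype(np.float32)   # samples x SNPs
y = np.random.randint(0, 2, size=200)             # raw labels; re-encoded inside Hyperparameter

best = Hyperparameter(X, y, nsnp=X.shape[1])
print(best["learning_rate"], best["batch_size"], best["weight_decay"], best["patience"])
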
gpbench/method_class/CropARNet/base_CropARNet_class.py:

@@ -0,0 +1,178 @@
import copy

import torch
import torch.nn as nn
import numpy as np

config = {
    "batch_size": 64,
    "weights_units": [64, 32],
    "classifier_units": [64, 32],
    "dropout": 0.3,
}


class SimpleSNPModel(nn.Module):
    """
    Classification version of SimpleSNPModel
    (Attention + Residual + MLP)
    """

    def __init__(self, num_snps: int, num_classes: int):
        super().__init__()

        if not isinstance(num_snps, int) or num_snps <= 0:
            raise ValueError(f"num_snps must be a positive integer, got {num_snps}")
        if not isinstance(num_classes, int) or num_classes <= 1:
            raise ValueError(f"num_classes must be >= 2, got {num_classes}")

        self.config = config
        self.num_classes = num_classes

        self.attention = self._build_attention_module(num_snps)
        self.classifier = self._build_classifier_module(num_snps, num_classes)

    # ==================================================
    # Attention module
    # ==================================================
    def _build_attention_module(self, num_snps):
        layers = []
        prev_size = num_snps

        for h_size in self.config["weights_units"]:
            if h_size <= 0:
                raise ValueError(f"Invalid hidden size {h_size}")
            layers.append(nn.Linear(prev_size, h_size))
            layers.append(nn.GELU())
            prev_size = h_size

        layers.append(nn.Linear(prev_size, num_snps))
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    # ==================================================
    # Classifier module
    # ==================================================
    def _build_classifier_module(self, num_snps, num_classes):
        layers = []
        prev_size = num_snps

        for h_size in self.config["classifier_units"]:
            if h_size <= 0:
                raise ValueError(f"Invalid hidden size {h_size}")
            layers.append(nn.Linear(prev_size, h_size))
            layers.append(nn.LayerNorm(h_size))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(self.config["dropout"]))
            prev_size = h_size

        layers.append(nn.Linear(prev_size, num_classes))
        return nn.Sequential(*layers)

    # ==================================================
    # Forward
    # ==================================================
    def forward(self, x):
        if x.dim() != 2:
            raise ValueError(f"Input must be a 2D tensor, got {x.dim()}D")

        # Attention: run the hidden stack once, then apply the final Sigmoid,
        # instead of evaluating the whole module a second time
        pre_sigmoid_weights = self.attention[:-1](x)
        att_weights = self.attention[-1](pre_sigmoid_weights)

        # Residual weighted SNPs
        weighted = x * att_weights + x

        logits = self.classifier(weighted)
        return logits, pre_sigmoid_weights

    # ==================================================
    # Training (classification)
    # ==================================================
    def train_model(
        self,
        train_loader,
        valid_loader,
        num_epochs,
        learning_rate,
        weight_decay,
        patience,
        device
    ):
        self.to(device)

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        criterion = nn.CrossEntropyLoss()

        best_loss = float("inf")
        best_state = None
        trigger_times = 0

        for epoch in range(num_epochs):
            # -------- Train --------
            self.train()
            train_loss = 0.0

            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device).long()

                optimizer.zero_grad()
                outputs, _ = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * inputs.size(0)

            train_loss /= len(train_loader.dataset)

            # -------- Validation --------
            self.eval()
            valid_loss = 0.0

            with torch.no_grad():
                for inputs, labels in valid_loader:
                    inputs = inputs.to(device)
                    labels = labels.to(device).long()

                    outputs, _ = self(inputs)
                    loss = criterion(outputs, labels)
                    valid_loss += loss.item() * inputs.size(0)

            valid_loss /= len(valid_loader.dataset)

            # -------- Early stopping --------
            if valid_loss < best_loss:
                best_loss = valid_loss
                # snapshot the weights; state_dict() alone returns live references
                # that would keep changing as training continues
                best_state = copy.deepcopy(self.state_dict())
                trigger_times = 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        if best_state is not None:
            self.load_state_dict(best_state)

        return best_loss

    # ==================================================
    # Prediction (classification)
    # ==================================================
    def predict(self, test_loader):
        self.eval()
        device = next(self.parameters()).device
        y_pred = []

        with torch.no_grad():
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                outputs, _ = self(inputs)  # logits
                preds = torch.argmax(outputs, dim=1)
                y_pred.append(preds.cpu().numpy())

        return np.concatenate(y_pred, axis=0)
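forward returns both the class logits and the pre-sigmoid attention scores, so callers can read per-SNP importances from the second output. A minimal smoke test with illustrative shapes:

import torch

model = SimpleSNPModel(num_snps=1000, num_classes=2)
x = torch.randn(8, 1000)          # batch of 8 samples, one feature per SNP

logits, snp_scores = model(x)
print(logits.shape)               # torch.Size([8, 2])
print(snp_scores.shape)           # torch.Size([8, 1000])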