gpbench-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
@@ -0,0 +1,258 @@
import time
import argparse
import pynvml
import psutil
import os
import torch
import random
import swanlab
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold, train_test_split
import sys
sys.path.append("..")
from .utils.models_locally_connected import LCLModel
from .utils.common import DataDimensions
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import pearsonr
from .EIR_Hyperparameters import Hyperparameter
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def parse_args():
    parser = argparse.ArgumentParser(description="Argument parser")
    parser.add_argument('--methods', type=str, default='EIR/', help='Method name')
    parser.add_argument('--species', type=str, default='', help='Species name')
    parser.add_argument('--phe', type=str, default='', help='Dataset name')
    parser.add_argument('--data_dir', type=str, default='../../data/')
    parser.add_argument('--result_dir', type=str, default='result/')

    parser.add_argument('--epochs', type=int, default=1000, help='Number of training rounds')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping')
    args = parser.parse_args()
    return args


def load_data(args):
    # Genotype matrix (samples x SNPs), phenotype matrix and phenotype names.
    xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
    yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
    names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]

    nsample = xData.shape[0]
    nsnp = xData.shape[1]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names

def get_gpu_mem_by_pid(pid):
    # GPU memory (MB) used by the given process; relies on the NVML handle set up in EIR_reg().
    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    for p in procs:
        if p.pid == pid:
            return p.usedGpuMemory / 1024**2
    return 0.0

def init():
    seed = 42
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def one_hot_encode_ATCG(char):
    # Integer-coded genotype (0/1/2, anything else treated as a fourth category) to a one-hot vector.
    c = char
    if char == 0:
        return [1, 0, 0, 0]
    elif char == 1:
        return [0, 1, 0, 0]
    elif char == 2:
        return [0, 0, 1, 0]
    else:
        return [0, 0, 0, 1]

def one_hot_seq(df: pd.DataFrame, nsnp: int):
    one_hot_df = df.applymap(lambda x: one_hot_encode_ATCG(x))
    one_hot_df = one_hot_df.values.tolist()
    one_hot_df = np.array(one_hot_df)
    one_hot_df = np.reshape(one_hot_df, (one_hot_df.shape[0], -1))
    tensor_data = torch.Tensor(one_hot_df)
    return tensor_data


def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):

    model.to(device)
    best_loss = float('inf')
    best_state = None
    trigger_times = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.unsqueeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                labels = labels.unsqueeze(1)
                loss = criterion(outputs, labels)
                valid_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_loader.dataset)
        valid_loss /= len(valid_loader.dataset)

        # ---------- Early stopping ----------
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_state = model.state_dict()
            trigger_times = 0
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_loss

def predict(model, test_loader, device):
    model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            y_pred.append(outputs.cpu().numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    y_pred = np.squeeze(y_pred)
    return y_pred


def run_nested_cv(args, data, label, nsnp, device):
    # 10-fold CV: train with early stopping, evaluate the held-out fold, log metrics and resource use.
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)
    print("Starting 10-fold cross-validation...")
    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    data = data.reshape(data.shape[0], 1, nsnp, -1)

    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
    time_star = time.time()
    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
        x_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).float().to(device)

        train_data = TensorDataset(x_train_tensor, y_train_tensor)
        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
        test_data = TensorDataset(x_test_tensor, y_test_tensor)

        train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
        valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, args.batch_size, shuffle=False)

        model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
        loss_fn = torch.nn.MSELoss()

        train_model(model, train_loader, valid_loader, optimizer, loss_fn, args.epochs, args.patience, device)
        y_pred = predict(model, test_loader, device)

        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        pcc, _ = pearsonr(y_test, y_pred)

        all_mse.append(mse)
        all_r2.append(r2)
        all_mae.append(mae)
        all_pcc.append(pcc)

        fold_time = time.time() - fold_start_time
        fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
              f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)

    print("\n===== Cross-validation summary =====")
    print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
    print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
    print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
    print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
    print(f"Time: {time.time() - time_star:.2f}s")


def EIR_reg():
    # Entry point: Optuna hyperparameter search, then 10-fold CV for every phenotype of every species.
    start = time.time()
    set_seed(42)

    pynvml.nvmlInit()
    global handle  # read by get_gpu_mem_by_pid()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    args = parse_args()
    all_species = ['Cotton/']

    for i in range(len(all_species)):
        args.species = all_species[i]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        args.device = device
        X, Y, nsamples, nsnp, names = load_data(args)
        for j in range(len(names)):
            args.phe = names[j]
            print("starting run " + args.methods + args.species + args.phe)
            label = Y[:, j]
            label = np.nan_to_num(label, nan=np.nanmean(label))
            best_params = Hyperparameter(X, label, nsnp)
            args.learning_rate = best_params['learning_rate']
            args.batch_size = best_params['batch_size']
            args.patience = best_params['patience']
            start_time = time.time()
            torch.cuda.reset_peak_memory_stats()
            process = psutil.Process(os.getpid())

            run_nested_cv(args, data=X, label=label, nsnp=nsnp, device=args.device)

            elapsed_time = time.time() - start_time
            print(f"running time: {elapsed_time:.2f} s")
            print("successfully")


if __name__ == '__main__':
    EIR_reg()
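The hunk above matches gpbench/method_reg/EIR/EIR.py in the file list (258 added lines and a relative import of EIR_Hyperparameters). Assuming that module path, a CUDA GPU with NVML available, and the genotype.npz/phenotype.npz layout expected under --data_dir, a minimal sketch of driving the entry point programmatically could look as follows; the argument values are simply the argparse defaults shown above, and the invocation itself is an assumption rather than documented usage.

# Hypothetical driver for the EIR regression pipeline above; the module path and
# data layout are assumptions inferred from the file listing, not documented usage.
import sys

from gpbench.method_reg.EIR.EIR import EIR_reg

sys.argv = [
    "EIR",
    "--data_dir", "../../data/",   # expects <species>/genotype.npz and phenotype.npz
    "--epochs", "1000",
    "--batch_size", "32",
]
EIR_reg()  # Optuna tuning (EIR_Hyperparameters) followed by 10-fold CV per phenotype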
@@ -0,0 +1,178 @@
import os
import time
import psutil
import random
import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.utils.data import DataLoader, TensorDataset
from optuna.exceptions import TrialPruned
from .utils.models_locally_connected import LCLModel
from .utils.common import DataDimensions
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):
    model.to(device)
    best_loss = float('inf')
    best_state = None
    trigger_times = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.unsqueeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
        # ---------- Validation ----------
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                labels = labels.unsqueeze(1)
                loss = criterion(outputs, labels)
                valid_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_loader.dataset)
        valid_loss /= len(valid_loader.dataset)

        # ---------- Early stopping ----------
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_state = model.state_dict()
            trigger_times = 0
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best weights
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_loss

def predict(model, test_loader, device):
    model.eval()
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            # Tensors were already placed on the target device when the loaders were built.
            outputs = model(inputs)
            y_pred.append(outputs.cpu().numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    y_pred = np.squeeze(y_pred)
    return y_pred

def run_nested_cv_with_early_stopping(data, label, nsnp, learning_rate, patience, batch_size, epochs=1000):
    device = torch.device("cuda:0")
    print("Starting 10-fold cross-validation...")
    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    all_mse, all_mae, all_r2, all_pcc = [], [], [], []

    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
        x_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).float().to(device)

        train_data = TensorDataset(x_train_tensor, y_train_tensor)
        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
        test_data = TensorDataset(x_test_tensor, y_test_tensor)

        train_loader = DataLoader(train_data, batch_size, shuffle=True)
        valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size, shuffle=False)

        # Initialize the model
        model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
        loss_fn = torch.nn.MSELoss()

        train_model(model, train_loader, valid_loader, optimizer, loss_fn, epochs, patience, device)
        y_pred = predict(model, test_loader, device)

        # Compute evaluation metrics
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        pcc, _ = pearsonr(y_test, y_pred)

        if np.isnan(pcc):
            print(f"Fold {fold} resulted in NaN PCC, pruning the trial...")
            raise TrialPruned()

        # Append the fold results
        all_mse.append(mse)
        all_r2.append(r2)
        all_mae.append(mae)
        all_pcc.append(pcc)

        # ====== Per-fold statistics ======
        fold_time = time.time() - fold_start_time
        # fold_gpu_mem = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
              f'CPU={fold_cpu_mem:.2f}MB')

    return np.mean(all_pcc) if all_pcc else 0.0

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def Hyperparameter(data, label, nsnp):
    # Optuna search over learning rate, batch size and early-stopping patience;
    # each trial is scored by the mean per-fold Pearson correlation.
    set_seed(42)

    def objective(trial):
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-4, 0.1)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        patience = trial.suggest_int("patience", 10, 100, step=10)
        try:
            corr_score = run_nested_cv_with_early_stopping(
                data=data,
                label=label,
                nsnp=nsnp,
                learning_rate=learning_rate,
                patience=patience,
                batch_size=batch_size
            )
        except TrialPruned:
            return float("-inf")
        return corr_score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

    print("Best parameters:", study.best_params)
    print("successfully")
    return study.best_params
File without changes
@@ -0,0 +1,97 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Type

import torch
from torch import nn

from eir.models.input.array.array_models import al_pre_normalization
from eir.models.input.array.models_locally_connected import LCLModel, LCLModelConfig
from eir.models.layers.projection_layers import get_projection_layer
from eir.models.output.array.output_array_models_cnn import (
    CNNUpscaleModel,
    CNNUpscaleModelConfig,
)

if TYPE_CHECKING:
    from eir.setup.input_setup_modules.common import DataDimensions

al_array_model_types = Literal["lcl", "cnn"]
al_output_array_model_classes = Type[LCLModel] | Type[CNNUpscaleModel]
al_output_array_models = LCLModel | CNNUpscaleModel
al_output_array_model_config_classes = (
    Type["LCLOutputModelConfig"] | Type[CNNUpscaleModelConfig]
)


@dataclass
class LCLOutputModelConfig(LCLModelConfig):
    cutoff: int | Literal["auto"] = "auto"


@dataclass
class ArrayOutputModuleConfig:

    """
    :param model_type:
        Which type of image model to use.

    :param model_init_config:
        Configuration used to initialise model.
    """

    model_type: al_array_model_types
    model_init_config: LCLOutputModelConfig
    pre_normalization: al_pre_normalization = None


class ArrayOutputWrapperModule(nn.Module):
    def __init__(
        self,
        feature_extractor: al_output_array_models,
        output_name: str,
        target_data_dimensions: "DataDimensions",
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.output_name = output_name
        self.data_dimensions = target_data_dimensions

        self.target_width = self.data_dimensions.num_elements()
        self.target_shape = self.data_dimensions.full_shape()

        diff_tolerance = get_diff_tolerance(num_target_elements=self.target_width)

        self.projection_head = get_projection_layer(
            input_dimension=self.feature_extractor.num_out_features,
            target_dimension=self.target_width,
            projection_layer_type="lcl_residual",
            lcl_diff_tolerance=diff_tolerance,
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        out = self.feature_extractor(x)

        out = out.reshape(out.shape[0], -1)
        out = self.projection_head(out)

        out = out[:, : self.target_width]

        out = out.reshape(-1, *self.target_shape)

        return {self.output_name: out}


def get_diff_tolerance(num_target_elements: int) -> int:
    return int(0.001 * num_target_elements)


def get_array_output_module(
    feature_extractor: al_output_array_models,
    output_name: str,
    target_data_dimensions: "DataDimensions",
) -> ArrayOutputWrapperModule:
    return ArrayOutputWrapperModule(
        feature_extractor=feature_extractor,
        output_name=output_name,
        target_data_dimensions=target_data_dimensions,
    )
@@ -0,0 +1,65 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# from eir.data_load.data_source_modules.deeplake_ops import (
#     get_deeplake_input_source_iterable,
#     is_deeplake_dataset,
#     load_deeplake_dataset,
# )
# from eir.data_load.label_setup import get_file_path_iterator


@dataclass
class DataDimensions:
    channels: int
    height: int
    width: int
    extra_dims: tuple[int, ...] = tuple()

    def num_elements(self) -> int:
        base = self.channels * self.height * self.width
        return int(base * np.prod(self.extra_dims))

    def full_shape(self) -> Tuple[int, ...]:
        return (self.channels, self.height, self.width) + self.extra_dims


# def get_data_dimension_from_data_source(
#     data_source: Path,
#     deeplake_inner_key: Optional[str] = None,
# ) -> DataDimensions:
#     """
#     TODO: Make more dynamic / robust. Also weird to say "width" for a 1D vector.
#     """
#
#     if is_deeplake_dataset(data_source=str(data_source)):
#         assert deeplake_inner_key is not None, data_source
#         deeplake_ds = load_deeplake_dataset(data_source=str(data_source))
#         deeplake_iter = get_deeplake_input_source_iterable(
#             deeplake_dataset=deeplake_ds, inner_key=deeplake_inner_key
#         )
#         shape = next(deeplake_iter).shape
#     else:
#         iterator = get_file_path_iterator(data_source=data_source)
#         path = next(iterator)
#         shape = np.load(file=path).shape
#
#     extra_dims: tuple[int, ...] = tuple()
#     if len(shape) == 1:
#         channels, height, width = 1, 1, shape[0]
#     elif len(shape) == 2:
#         channels, height, width = 1, shape[0], shape[1]
#     elif len(shape) == 3:
#         channels, height, width = shape
#     else:
#         channels, height, width = shape[0], shape[1], shape[2]
#         extra_dims = shape[3:]
#
#     return DataDimensions(
#         channels=channels, height=height, width=width, extra_dims=extra_dims
#     )
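For reference, a quick sketch of how the DataDimensions helper above is consumed elsewhere in the package (EIR.py constructs LCLModel(DataDimensions(channels=1, height=nsnp, width=1))); the SNP count used here is a made-up toy value, and the snippet assumes the dataclass defined above is in scope.

dims = DataDimensions(channels=1, height=4000, width=1)  # toy value: 4,000 SNPs
assert dims.num_elements() == 4000                       # 1 * 4000 * 1, no extra dims
assert dims.full_shape() == (1, 4000, 1)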