gpbench-1.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
gpbench/method_reg/EIR/EIR.py
@@ -0,0 +1,258 @@
+import argparse
+import copy
+import os
+import random
+import time
+
+import numpy as np
+import pandas as pd
+import psutil
+import pynvml
+import torch
+from scipy.stats import pearsonr
+from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+from sklearn.model_selection import KFold, train_test_split
+from torch.utils.data import DataLoader, TensorDataset
+
+from .EIR_Hyperparameters import Hyperparameter
+from .utils.common import DataDimensions
+from .utils.models_locally_connected import LCLModel
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="EIR regression benchmark")
+    parser.add_argument('--methods', type=str, default='EIR/', help='Method name (result subdirectory)')
+    parser.add_argument('--species', type=str, default='', help='Species name')
+    parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+    parser.add_argument('--data_dir', type=str, default='../../data/')
+    parser.add_argument('--result_dir', type=str, default='result/')
+
+    parser.add_argument('--epochs', type=int, default=1000, help='Number of training epochs')
+    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
+    parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping')
+    return parser.parse_args()
+
+
+def load_data(args):
+    # genotype.npz holds the SNP matrix (samples x SNPs); phenotype.npz holds
+    # the phenotype matrix ("arr_0") and the phenotype names ("arr_1").
+    xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+    phenotypes = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))
+    yData = phenotypes["arr_0"]
+    names = phenotypes["arr_1"]
+
+    nsample = xData.shape[0]
+    nsnp = xData.shape[1]
+    print("Number of samples: ", nsample)
+    print("Number of SNPs: ", nsnp)
+    return xData, yData, nsample, nsnp, names
+
+
+def get_gpu_mem_by_pid(pid):
+    # GPU memory (MiB) used by the given process on device 0. NVML must have
+    # been initialised (see EIR_reg) before this is called.
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+    for p in procs:
+        if p.pid == pid:
+            return p.usedGpuMemory / 1024**2
+    return 0.0
+
+
+def init():
+    seed = 42
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def one_hot_encode_ATCG(char):
+    # Map a 0/1/2-coded genotype to a 4-bit one-hot vector; anything else
+    # falls into the fourth category.
+    if char == 0:
+        return [1, 0, 0, 0]
+    elif char == 1:
+        return [0, 1, 0, 0]
+    elif char == 2:
+        return [0, 0, 1, 0]
+    else:
+        return [0, 0, 0, 1]
+
+
+def one_hot_seq(df: pd.DataFrame, nsnp: int):
+    one_hot_df = df.applymap(one_hot_encode_ATCG)
+    one_hot_df = np.array(one_hot_df.values.tolist())
+    one_hot_df = np.reshape(one_hot_df, (one_hot_df.shape[0], -1))
+    return torch.Tensor(one_hot_df)
+
+
+def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):
+    model.to(device)
+    best_loss = float('inf')
+    best_state = None
+    trigger_times = 0
+
+    for epoch in range(num_epochs):
+        model.train()
+        train_loss = 0.0
+        for inputs, labels in train_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            labels = labels.unsqueeze(1)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item() * inputs.size(0)
+
+        model.eval()
+        valid_loss = 0.0
+        with torch.no_grad():
+            for inputs, labels in valid_loader:
+                inputs, labels = inputs.to(device), labels.to(device)
+                outputs = model(inputs)
+                labels = labels.unsqueeze(1)
+                loss = criterion(outputs, labels)
+                valid_loss += loss.item() * inputs.size(0)
+
+        train_loss /= len(train_loader.dataset)
+        valid_loss /= len(valid_loader.dataset)
+
+        # ---------- Early stopping ----------
+        if valid_loss < best_loss:
+            best_loss = valid_loss
+            # Deep-copy the weights: state_dict() returns references that
+            # later epochs would otherwise overwrite.
+            best_state = copy.deepcopy(model.state_dict())
+            trigger_times = 0
+        else:
+            trigger_times += 1
+            if trigger_times >= patience:
+                print(f"Early stopping at epoch {epoch+1}")
+                break
+
+    # Restore the best weights before returning.
+    if best_state is not None:
+        model.load_state_dict(best_state)
+    return best_loss
+
+
+def predict(model, test_loader, device):
+    model.eval()
+    y_pred = []
+    with torch.no_grad():
+        for inputs, _ in test_loader:
+            inputs = inputs.to(device)
+            outputs = model(inputs)
+            y_pred.append(outputs.cpu().numpy())
+    y_pred = np.concatenate(y_pred, axis=0)
+    y_pred = np.squeeze(y_pred)
+    return y_pred
+
+
+def run_nested_cv(args, data, label, nsnp, device):
+    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+    os.makedirs(result_dir, exist_ok=True)
+    print("Starting 10-fold cross-validation...")
+    kf = KFold(n_splits=10, shuffle=True, random_state=42)
+    data = data.reshape(data.shape[0], 1, nsnp, -1)
+
+    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+    time_start = time.time()
+    for fold, (train_index, test_index) in enumerate(kf.split(data)):
+        print(f"Running fold {fold}...")
+        process = psutil.Process(os.getpid())
+        fold_start_time = time.time()
+
+        X_train, X_test = data[train_index], data[test_index]
+        y_train, y_test = label[train_index], label[test_index]
+
+        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
+        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
+        x_test_tensor = torch.from_numpy(X_test).float().to(device)
+        y_test_tensor = torch.from_numpy(y_test).float().to(device)
+
+        train_data = TensorDataset(x_train_tensor, y_train_tensor)
+        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+        test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+        train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
+        valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
+        test_loader = DataLoader(test_data, args.batch_size, shuffle=False)
+
+        model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
+        loss_fn = torch.nn.MSELoss()
+
+        train_model(model, train_loader, valid_loader, optimizer, loss_fn, args.epochs, args.patience, device)
+        y_pred = predict(model, test_loader, device)
+
+        mse = mean_squared_error(y_test, y_pred)
+        r2 = r2_score(y_test, y_pred)
+        mae = mean_absolute_error(y_test, y_pred)
+        pcc, _ = pearsonr(y_test, y_pred)
+
+        all_mse.append(mse)
+        all_r2.append(r2)
+        all_mae.append(mae)
+        all_pcc.append(pcc)
+
+        fold_time = time.time() - fold_start_time
+        fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
+        fold_cpu_mem = process.memory_info().rss / 1024**2
+        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+              f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
+        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+    print("\n===== Cross-validation summary =====")
+    print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
+    print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
+    print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
+    print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
+    print(f"Total time: {time.time() - time_start:.2f}s")
+
+
+def EIR_reg():
+    set_seed(42)
+    pynvml.nvmlInit()
+    args = parse_args()
+    all_species = ['Cotton/']
+
+    for species in all_species:
+        args.species = species
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        args.device = device
+        X, Y, nsamples, nsnp, names = load_data(args)
+        for j, name in enumerate(names):
+            args.phe = name
+            print("Starting run " + args.methods + args.species + args.phe)
+            # Impute missing phenotype values with the column mean.
+            label = Y[:, j]
+            label = np.nan_to_num(label, nan=np.nanmean(label))
+            # Tune learning rate, batch size and patience with Optuna first.
+            best_params = Hyperparameter(X, label, nsnp)
+            args.learning_rate = best_params['learning_rate']
+            args.batch_size = best_params['batch_size']
+            args.patience = best_params['patience']
+            start_time = time.time()
+            torch.cuda.reset_peak_memory_stats()
+
+            run_nested_cv(args, data=X, label=label, nsnp=nsnp, device=args.device)
+
+            elapsed_time = time.time() - start_time
+            print(f"Running time: {elapsed_time:.2f} s")
+            print("Finished successfully")
+
+
+if __name__ == '__main__':
+    EIR_reg()
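
For orientation, a minimal sketch of the on-disk layout that `load_data` and `EIR_reg` above assume. The `Cotton/` directory name comes from `all_species`; the shapes and trait names are illustrative placeholders, not part of the package:

import os
import numpy as np

os.makedirs("data/Cotton", exist_ok=True)
# np.savez stores positional arrays under "arr_0", "arr_1", ..., which is
# exactly what load_data reads back.
np.savez("data/Cotton/genotype.npz", np.random.randint(0, 3, size=(100, 500)))   # samples x SNPs
np.savez("data/Cotton/phenotype.npz",
         np.random.rand(100, 2),                  # "arr_0": phenotype values (samples x traits)
         np.array(["height", "yield"]))           # "arr_1": trait names

# Then, from a script started with e.g. --data_dir data/ on the command line:
from gpbench.method_reg.EIR import EIR_reg
EIR_reg()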
gpbench/method_reg/EIR/EIR_Hyperparameters.py
@@ -0,0 +1,178 @@
+import copy
+import os
+import random
+import time
+
+import numpy as np
+import optuna
+import psutil
+import torch
+from optuna.exceptions import TrialPruned
+from scipy.stats import pearsonr
+from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+from sklearn.model_selection import KFold, train_test_split
+from torch.utils.data import DataLoader, TensorDataset
+
+from .utils.common import DataDimensions
+from .utils.models_locally_connected import LCLModel
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):
+    model.to(device)
+    best_loss = float('inf')
+    best_state = None
+    trigger_times = 0
+
+    for epoch in range(num_epochs):
+        model.train()
+        train_loss = 0.0
+        for inputs, labels in train_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            labels = labels.unsqueeze(1)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item() * inputs.size(0)
+
+        # ---------- Validation ----------
+        model.eval()
+        valid_loss = 0.0
+        with torch.no_grad():
+            for inputs, labels in valid_loader:
+                inputs, labels = inputs.to(device), labels.to(device)
+                outputs = model(inputs)
+                labels = labels.unsqueeze(1)
+                loss = criterion(outputs, labels)
+                valid_loss += loss.item() * inputs.size(0)
+
+        train_loss /= len(train_loader.dataset)
+        valid_loss /= len(valid_loader.dataset)
+
+        # ---------- Early stopping ----------
+        if valid_loss < best_loss:
+            best_loss = valid_loss
+            best_state = copy.deepcopy(model.state_dict())
+            trigger_times = 0
+        else:
+            trigger_times += 1
+            if trigger_times >= patience:
+                print(f"Early stopping at epoch {epoch+1}")
+                break
+
+    # Restore the best weights.
+    if best_state is not None:
+        model.load_state_dict(best_state)
+    return best_loss
+
+
+def predict(model, test_loader, device):
+    model.eval()
+    y_pred = []
+    with torch.no_grad():
+        for inputs, _ in test_loader:
+            inputs = inputs.to(device)
+            outputs = model(inputs)
+            y_pred.append(outputs.cpu().numpy())
+    y_pred = np.concatenate(y_pred, axis=0)
+    y_pred = np.squeeze(y_pred)
+    return y_pred
+
+
+def run_nested_cv_with_early_stopping(data, label, nsnp, learning_rate, patience, batch_size, epochs=1000):
+    print("Starting 10-fold cross-validation...")
+    kf = KFold(n_splits=10, shuffle=True, random_state=42)
+    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+
+    for fold, (train_index, test_index) in enumerate(kf.split(data)):
+        print(f"Running fold {fold}...")
+        process = psutil.Process(os.getpid())
+        fold_start_time = time.time()
+
+        X_train, X_test = data[train_index], data[test_index]
+        y_train, y_test = label[train_index], label[test_index]
+
+        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
+        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
+        x_test_tensor = torch.from_numpy(X_test).float().to(device)
+        y_test_tensor = torch.from_numpy(y_test).float().to(device)
+
+        train_data = TensorDataset(x_train_tensor, y_train_tensor)
+        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+        test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+        train_loader = DataLoader(train_data, batch_size, shuffle=True)
+        valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
+        test_loader = DataLoader(test_data, batch_size, shuffle=False)
+
+        # Initialise the model.
+        model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
+        loss_fn = torch.nn.MSELoss()
+
+        train_model(model, train_loader, valid_loader, optimizer, loss_fn, epochs, patience, device)
+        y_pred = predict(model, test_loader, device)
+
+        # Evaluation metrics.
+        mse = mean_squared_error(y_test, y_pred)
+        r2 = r2_score(y_test, y_pred)
+        mae = mean_absolute_error(y_test, y_pred)
+        pcc, _ = pearsonr(y_test, y_pred)
+
+        if np.isnan(pcc):
+            print(f"Fold {fold} resulted in NaN PCC, pruning the trial...")
+            raise TrialPruned()
+
+        all_mse.append(mse)
+        all_r2.append(r2)
+        all_mae.append(mae)
+        all_pcc.append(pcc)
+
+        # ====== Per-fold statistics ======
+        fold_time = time.time() - fold_start_time
+        fold_cpu_mem = process.memory_info().rss / 1024**2
+        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+              f'CPU={fold_cpu_mem:.2f}MB')
+
+    return np.mean(all_pcc) if all_pcc else 0.0
+
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def Hyperparameter(data, label, nsnp):
+    set_seed(42)
+
+    def objective(trial):
+        learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
+        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+        patience = trial.suggest_int("patience", 10, 100, step=10)
+        try:
+            corr_score = run_nested_cv_with_early_stopping(
+                data=data,
+                label=label,
+                nsnp=nsnp,
+                learning_rate=learning_rate,
+                patience=patience,
+                batch_size=batch_size,
+            )
+        except TrialPruned:
+            # Score pruned trials as -inf so they never win under "maximize".
+            return float("-inf")
+        return corr_score
+
+    study = optuna.create_study(direction="maximize")
+    study.optimize(objective, n_trials=20)
+
+    print("Best parameters:", study.best_params)
+    print("Hyperparameter search finished successfully")
+    return study.best_params
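
`Hyperparameter` can also be exercised on its own, mirroring the call made from EIR.py. A toy invocation with synthetic data; the shapes are illustrative, and 20 Optuna trials of 10-fold cross-validation are slow on anything but tiny inputs:

import numpy as np
from gpbench.method_reg.EIR.EIR_Hyperparameters import Hyperparameter

X = np.random.rand(60, 200).astype(np.float32)   # samples x SNPs, as passed by EIR_reg
y = np.random.rand(60).astype(np.float32)
best = Hyperparameter(X, y, nsnp=200)
# best is a dict such as {'learning_rate': ..., 'batch_size': ..., 'patience': ...}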
gpbench/method_reg/EIR/__init__.py
@@ -0,0 +1,5 @@
+from .EIR import EIR_reg
+
+EIR = EIR_reg
+
+__all__ = ["EIR", "EIR_reg"]
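
Given the alias above, both exported names resolve to the same callable (assuming the package and its dependencies such as torch, pynvml, and optuna are installed):

from gpbench.method_reg.EIR import EIR, EIR_reg
assert EIR is EIR_reg   # EIR is just an alias for the EIR_reg entry point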
gpbench/method_reg/EIR/utils/__init__.py
File without changes
gpbench/method_reg/EIR/utils/array_output_modules.py
@@ -0,0 +1,97 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal, Type
+
+import torch
+from torch import nn
+
+from eir.models.input.array.array_models import al_pre_normalization
+from eir.models.input.array.models_locally_connected import LCLModel, LCLModelConfig
+from eir.models.layers.projection_layers import get_projection_layer
+from eir.models.output.array.output_array_models_cnn import (
+    CNNUpscaleModel,
+    CNNUpscaleModelConfig,
+)
+
+if TYPE_CHECKING:
+    from eir.setup.input_setup_modules.common import DataDimensions
+
+al_array_model_types = Literal["lcl", "cnn"]
+al_output_array_model_classes = Type[LCLModel] | Type[CNNUpscaleModel]
+al_output_array_models = LCLModel | CNNUpscaleModel
+al_output_array_model_config_classes = (
+    Type["LCLOutputModelConfig"] | Type[CNNUpscaleModelConfig]
+)
+
+
+@dataclass
+class LCLOutputModelConfig(LCLModelConfig):
+    cutoff: int | Literal["auto"] = "auto"
+
+
+@dataclass
+class ArrayOutputModuleConfig:
+
+    """
+    :param model_type:
+        Which type of image model to use.
+
+    :param model_init_config:
+        Configuration used to initialise model.
+    """
+
+    model_type: al_array_model_types
+    model_init_config: LCLOutputModelConfig
+    pre_normalization: al_pre_normalization = None
+
+
+class ArrayOutputWrapperModule(nn.Module):
+    def __init__(
+        self,
+        feature_extractor: al_output_array_models,
+        output_name: str,
+        target_data_dimensions: "DataDimensions",
+    ):
+        super().__init__()
+        self.feature_extractor = feature_extractor
+        self.output_name = output_name
+        self.data_dimensions = target_data_dimensions
+
+        self.target_width = self.data_dimensions.num_elements()
+        self.target_shape = self.data_dimensions.full_shape()
+
+        diff_tolerance = get_diff_tolerance(num_target_elements=self.target_width)
+
+        self.projection_head = get_projection_layer(
+            input_dimension=self.feature_extractor.num_out_features,
+            target_dimension=self.target_width,
+            projection_layer_type="lcl_residual",
+            lcl_diff_tolerance=diff_tolerance,
+        )
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        out = self.feature_extractor(x)
+
+        # Flatten, project to the target element count, trim any overshoot
+        # from the LCL projection, then restore the target shape.
+        out = out.reshape(out.shape[0], -1)
+        out = self.projection_head(out)
+        out = out[:, : self.target_width]
+        out = out.reshape(-1, *self.target_shape)
+
+        return {self.output_name: out}
+
+
+def get_diff_tolerance(num_target_elements: int) -> int:
+    # Allow the projection to overshoot by 0.1% of the target elements.
+    return int(0.001 * num_target_elements)
+
+
+def get_array_output_module(
+    feature_extractor: al_output_array_models,
+    output_name: str,
+    target_data_dimensions: "DataDimensions",
+) -> ArrayOutputWrapperModule:
+    return ArrayOutputWrapperModule(
+        feature_extractor=feature_extractor,
+        output_name=output_name,
+        target_data_dimensions=target_data_dimensions,
+    )
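
The 0.1% slack computed by `get_diff_tolerance` bounds how far the LCL projection may overshoot the flattened target size before `forward` trims it back; two quick worked values:

get_diff_tolerance(num_target_elements=10_000)   # int(0.001 * 10_000) -> 10
get_diff_tolerance(num_target_elements=500)      # int(0.5) -> 0, i.e. no slack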
gpbench/method_reg/EIR/utils/common.py
@@ -0,0 +1,65 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+# from eir.data_load.data_source_modules.deeplake_ops import (
+#     get_deeplake_input_source_iterable,
+#     is_deeplake_dataset,
+#     load_deeplake_dataset,
+# )
+# from eir.data_load.label_setup import get_file_path_iterator
+
+
+@dataclass
+class DataDimensions:
+    channels: int
+    height: int
+    width: int
+    extra_dims: tuple[int, ...] = tuple()
+
+    def num_elements(self) -> int:
+        base = self.channels * self.height * self.width
+        return int(base * np.prod(self.extra_dims))
+
+    def full_shape(self) -> Tuple[int, ...]:
+        return (self.channels, self.height, self.width) + self.extra_dims
+
+
+# def get_data_dimension_from_data_source(
+#     data_source: Path,
+#     deeplake_inner_key: Optional[str] = None,
+# ) -> DataDimensions:
+#     """
+#     TODO: Make more dynamic / robust. Also weird to say "width" for a 1D vector.
+#     """
+#
+#     if is_deeplake_dataset(data_source=str(data_source)):
+#         assert deeplake_inner_key is not None, data_source
+#         deeplake_ds = load_deeplake_dataset(data_source=str(data_source))
+#         deeplake_iter = get_deeplake_input_source_iterable(
+#             deeplake_dataset=deeplake_ds, inner_key=deeplake_inner_key
+#         )
+#         shape = next(deeplake_iter).shape
+#     else:
+#         iterator = get_file_path_iterator(data_source=data_source)
+#         path = next(iterator)
+#         shape = np.load(file=path).shape
+#
+#     extra_dims: tuple[int, ...] = tuple()
+#     if len(shape) == 1:
+#         channels, height, width = 1, 1, shape[0]
+#     elif len(shape) == 2:
+#         channels, height, width = 1, shape[0], shape[1]
+#     elif len(shape) == 3:
+#         channels, height, width = shape
+#     else:
+#         channels, height, width = shape[0], shape[1], shape[2]
+#         extra_dims = shape[3:]
+#
+#     return DataDimensions(
+#         channels=channels, height=height, width=width, extra_dims=extra_dims
+#     )
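
`DataDimensions` is self-contained and easy to sanity-check; the numbers below are illustrative:

dims = DataDimensions(channels=1, height=5000, width=1)
dims.num_elements()    # 1 * 5000 * 1 -> 5000
dims.full_shape()      # (1, 5000, 1)

vol = DataDimensions(channels=3, height=32, width=32, extra_dims=(2,))
vol.num_elements()     # 3 * 32 * 32 * 2 -> 6144
vol.full_shape()       # (3, 32, 32, 2)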