gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
gpbench/method_reg/DeepCCR/base_DeepCCR.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class DeepCCR(nn.Module):
+    def __init__(self,
+                 input_channels=1,
+                 input_seq_len=162,
+                 lstm_hidden_size=64,
+                 fc1_hidden_dim=128):
+        super(DeepCCR, self).__init__()
+
+        # ==================== Conv1 ====================
+        self.conv1_kernel = 500
+        self.conv1_stride = 100
+        self.conv1_out_ch = 150
+
+        if input_seq_len < self.conv1_kernel:
+            pad_left = (self.conv1_kernel - input_seq_len) // 2
+            pad_right = self.conv1_kernel - input_seq_len - pad_left
+            self.conv1_padding = (pad_left, pad_right)
+            conv1_input_len = self.conv1_kernel
+        else:
+            self.conv1_padding = (0, 0)
+            conv1_input_len = input_seq_len
+
+        self.conv1 = nn.Conv1d(
+            in_channels=input_channels,
+            out_channels=self.conv1_out_ch,
+            kernel_size=self.conv1_kernel,
+            stride=self.conv1_stride,
+            padding=0
+        )
+        self.relu1 = nn.ReLU()
+
+        conv1_seq_len = (conv1_input_len - self.conv1_kernel) // self.conv1_stride + 1
+
+        self.max_pool1 = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
+        pool1_seq_len = (conv1_seq_len + 1) // 2
+
+        # ==================== BiLSTM ====================
+        self.bilstm = nn.LSTM(input_size=self.conv1_out_ch, hidden_size=lstm_hidden_size, batch_first=True, bidirectional=True)
+
+        # ==================== Conv2 ====================
+        self.conv2_kernel = 30
+        self.conv2_stride = 5
+        self.conv2_out_ch = 150
+
+        if pool1_seq_len < self.conv2_kernel:
+            pad_left2 = (self.conv2_kernel - pool1_seq_len) // 2
+            pad_right2 = self.conv2_kernel - pool1_seq_len - pad_left2
+            self.conv2_padding = (pad_left2, pad_right2)
+            conv2_input_len = self.conv2_kernel
+        else:
+            self.conv2_padding = (0, 0)
+            conv2_input_len = pool1_seq_len
+
+        self.conv2 = nn.Conv1d(
+            in_channels=lstm_hidden_size * 2,
+            out_channels=self.conv2_out_ch,
+            kernel_size=self.conv2_kernel,
+            stride=self.conv2_stride,
+            padding=0
+        )
+        self.relu2 = nn.ReLU()
+
+        conv2_seq_len = (conv2_input_len - self.conv2_kernel) // self.conv2_stride + 1
+        self.max_pool2 = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
+        pool2_seq_len = (conv2_seq_len + 1) // 2
+
+        # ==================== FC ====================
+        flatten_dim = self.conv2_out_ch * pool2_seq_len
+
+        self.fc1 = nn.Linear(flatten_dim, fc1_hidden_dim)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(fc1_hidden_dim, 1)
+
+    def forward(self, x):
+        """
+        x: [batch, channels, seq_len]
+        """
+
+        # -------- Conv1 --------
+        if self.conv1_padding != (0, 0):
+            x = F.pad(x, self.conv1_padding)
+
+        x = self.relu1(self.conv1(x))
+        x = self.max_pool1(x)
+
+        # -------- BiLSTM --------
+        x = x.permute(0, 2, 1)  # [B, T, C]
+        x, _ = self.bilstm(x)
+        x = x.permute(0, 2, 1)  # [B, C, T]
+
+        # -------- Conv2 --------
+        if self.conv2_padding != (0, 0):
+            x = F.pad(x, self.conv2_padding)
+
+        x = self.relu2(self.conv2(x))
+        x = self.max_pool2(x)
+
+        # -------- FC --------
+        x = torch.flatten(x, start_dim=1)
+        x = self.relu3(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+    def train_model(self, train_loader, valid_loader, num_epochs, learning_rate, patience, device):
+        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-4)
+        criterion = nn.MSELoss()
+        self.to(device)
+
+        best_loss = float('inf')
+        best_state = None
+        trigger_times = 0
+
+        for epoch in range(num_epochs):
+            self.train()
+            train_loss = 0.0
+            for inputs, labels in train_loader:
+                inputs = inputs.to(device)
+                labels = labels.to(device).unsqueeze(1)
+                optimizer.zero_grad()
+                outputs = self(inputs)
+                loss = criterion(outputs, labels)
+                loss.backward()
+                optimizer.step()
+                train_loss += loss.item() * inputs.size(0)
+            train_loss /= len(train_loader.dataset)
+
+            self.eval()
+            valid_loss = 0.0
+            with torch.no_grad():
+                for inputs, labels in valid_loader:
+                    inputs = inputs.to(device)
+                    labels = labels.to(device).unsqueeze(1)
+                    outputs = self(inputs)
+                    loss = criterion(outputs, labels)
+                    valid_loss += loss.item() * inputs.size(0)
+            valid_loss /= len(valid_loader.dataset)
+
+            # Early stopping
+            if valid_loss < best_loss:
+                best_loss = valid_loss
+                best_state = {k: v.cpu().clone() for k, v in self.state_dict().items()}
+                trigger_times = 0
+            else:
+                trigger_times += 1
+                if trigger_times >= patience:
+                    print(f"Early stopping at epoch {epoch+1}")
+                    break
+
+        if best_state is not None:
+            cur_device = next(self.parameters()).device
+            best_state = {k: v.to(cur_device) for k, v in best_state.items()}
+            self.load_state_dict(best_state)
+        return best_loss
+
+    def predict(self, test_loader, device):
+        self.eval()
+        self.to(device)
+        y_pred = []
+        with torch.no_grad():
+            for inputs, _ in test_loader:
+                inputs = inputs.to(device)
+                outputs = self(inputs)
+                y_pred.append(outputs.cpu().numpy())
+        y_pred = np.concatenate(y_pred, axis=0)
+        y_pred = np.squeeze(y_pred)
+        return y_pred
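
For orientation, a minimal smoke test for the DeepCCR module above. It is not part of the package: the random tensors, loader sizes, and the import path (inferred from file 130 in the listing) are all assumptions.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # assumed import path, matching the file listing above
    from gpbench.method_reg.DeepCCR.base_DeepCCR import DeepCCR

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # toy data: 32 samples, 1 channel, 162 SNPs (the model's default sequence length)
    X = torch.randn(32, 1, 162)
    y = torch.randn(32)
    loader = DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)

    model = DeepCCR(input_channels=1, input_seq_len=162)
    # reusing the training loader as a stand-in validation set for the smoke test
    best_val = model.train_model(loader, loader, num_epochs=2,
                                 learning_rate=1e-3, patience=5, device=device)
    preds = model.predict(loader, device)  # numpy array of shape (32,)
    print(best_val, preds.shape)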
gpbench/method_reg/DeepGS/DeepGS.py
@@ -0,0 +1,165 @@
+import os
+import time
+import psutil
+import argparse
+import random
+import torch
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import KFold, train_test_split
+from .base_deepgs import DeepGS
+from scipy.stats import pearsonr
+from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+from torch.utils.data import DataLoader, TensorDataset
+from . import DeepGS_Hyperparameters
+import pynvml
+
+# NVML device handle; initialised in DeepGS_reg() so importing this module
+# does not require an NVIDIA driver
+handle = None
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Argument parser")
+    parser.add_argument('--methods', type=str, default='DeepGS/', help='Method name')
+    parser.add_argument('--species', type=str, default='', help='Species name')
+    parser.add_argument('--phe', type=str, default='', help='Dataset name')
+    parser.add_argument('--data_dir', type=str, default='../../data/')
+    parser.add_argument('--result_dir', type=str, default='result/')
+
+    parser.add_argument('--num_round', type=int, default=6000, help='Number of training rounds')
+    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+    parser.add_argument('--weight_decay', type=float, default=0.00001, help='Weight decay')
+    parser.add_argument('--momentum', type=float, default=0.5, help='Momentum')
+    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
+    parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping')
+    args = parser.parse_args()
+    return args
+
+def load_data(args):
+    xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+    yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+    names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+
+    nsample = xData.shape[0]
+    nsnp = xData.shape[1]
+    print("Number of samples: ", nsample)
+    print("Number of SNPs: ", nsnp)
+    return xData, yData, nsample, nsnp, names
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+def get_gpu_mem_by_pid(pid):
+    # falls back to 0.0 when NVML was never initialised (CPU-only runs)
+    if handle is None:
+        return 0.0
+    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+    for p in procs:
+        if p.pid == pid:
+            return p.usedGpuMemory / 1024**2
+    return 0.0
+
+
+def run_nested_cv(args, data, label, nsnp, device):
+    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+    os.makedirs(result_dir, exist_ok=True)
+    print("Starting 10-fold cross-validation...")
+    kf = KFold(n_splits=10, shuffle=True, random_state=42)
+
+    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+    time_start = time.time()
+    for fold, (train_index, test_index) in enumerate(kf.split(data)):
+        print(f"Running fold {fold}...")
+        process = psutil.Process(os.getpid())
+        fold_start_time = time.time()
+
+        X_train, X_test = data[train_index], data[test_index]
+        y_train, y_test = label[train_index], label[test_index]
+
+        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
+        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
+        x_test_tensor = torch.from_numpy(X_test).float().to(device)
+        y_test_tensor = torch.from_numpy(y_test).float().to(device)
+
+        x_train_tensor = x_train_tensor.unsqueeze(1)
+        x_valid_tensor = x_valid_tensor.unsqueeze(1)
+        x_test_tensor = x_test_tensor.unsqueeze(1)
+
+        train_data = TensorDataset(x_train_tensor, y_train_tensor)
+        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+        test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+        train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
+        valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
+        test_loader = DataLoader(test_data, args.batch_size, shuffle=False)
+
+        model = DeepGS(nsnp)
+        model.train_model(train_loader, valid_loader, args.num_round, args.learning_rate, args.momentum, args.weight_decay, args.patience, device)
+        y_pred = model.predict(test_loader)
+
+        mse = mean_squared_error(y_test, y_pred)
+        r2 = r2_score(y_test, y_pred)
+        mae = mean_absolute_error(y_test, y_pred)
+        pcc, _ = pearsonr(y_test, y_pred)
+
+        all_mse.append(mse)
+        all_r2.append(r2)
+        all_mae.append(mae)
+        all_pcc.append(pcc)
+
+        fold_time = time.time() - fold_start_time
+        fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
+        fold_cpu_mem = process.memory_info().rss / 1024**2
+        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+              f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+        results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
+        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+    print("\n===== Cross-validation summary =====")
+    print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
+    print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
+    print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
+    print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
+    print(f"Time: {time.time() - time_start:.2f}s")
+
+
+def DeepGS_reg():
+    global handle
+    set_seed(42)
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    args = parse_args()
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    args.device = device
+    all_species = ['Cotton/']
+    for i in range(len(all_species)):
+        args.species = all_species[i]
+        X, Y, nsamples, nsnp, names = load_data(args)
+        for j in range(len(names)):
+            args.phe = names[j]
+            print("starting run " + args.methods + args.species + args.phe)
+            label = Y[:, j]  # column j matches the phenotype name names[j]
+            label = np.nan_to_num(label, nan=np.nanmean(label))
+            best_params = DeepGS_Hyperparameters.Hyperparameter(X, label, nsnp)
+            args.learning_rate = best_params['learning_rate']
+            args.batch_size = best_params['batch_size']
+            args.momentum = best_params['momentum']
+            args.weight_decay = best_params['weight_decay']
+            args.patience = best_params['patience']
+            start_time = time.time()
+            if torch.cuda.is_available():
+                torch.cuda.reset_peak_memory_stats()
+            process = psutil.Process(os.getpid())
+
+            run_nested_cv(args, data=X, label=label, nsnp=nsnp, device=args.device)
+
+            elapsed_time = time.time() - start_time
+            print(f"running time: {elapsed_time:.2f} s")
+            print("finished successfully")
+
+
+if __name__ == "__main__":
+    DeepGS_reg()
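
The loader above reads genotype.npz and phenotype.npz via the positional keys arr_0 and arr_1. Here is a hedged sketch of fabricating files in that layout; the shapes, genotype coding, and trait names are invented placeholders, since the package does not document the format in this diff.

    import os
    import numpy as np

    os.makedirs("data/Cotton", exist_ok=True)
    n_samples, n_snps = 100, 500

    # genotype.npz: positional arg becomes arr_0 -> (samples, SNPs) matrix
    np.savez("data/Cotton/genotype.npz",
             np.random.randint(0, 3, (n_samples, n_snps)).astype(np.float32))

    # phenotype.npz: arr_0 -> (samples, traits) values, arr_1 -> trait names
    np.savez("data/Cotton/phenotype.npz",
             np.random.randn(n_samples, 2).astype(np.float32),
             np.array(["trait_a", "trait_b"]))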
gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py
@@ -0,0 +1,114 @@
+import os
+import time
+import psutil
+import random
+import torch
+import numpy as np
+import optuna
+from sklearn.model_selection import KFold, train_test_split
+from .base_deepgs import DeepGS
+from scipy.stats import pearsonr
+from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+from torch.utils.data import DataLoader, TensorDataset
+from optuna.exceptions import TrialPruned
+
+def run_nested_cv_with_early_stopping(data, label, nsnp, learning_rate, momentum, weight_decay, patience, batch_size, num_round=6000):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print("Starting 10-fold cross-validation...")
+    kf = KFold(n_splits=10, shuffle=True, random_state=42)
+    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+
+    for fold, (train_index, test_index) in enumerate(kf.split(data)):
+        print(f"Running fold {fold}...")
+        process = psutil.Process(os.getpid())
+        fold_start_time = time.time()
+
+        X_train, X_test = data[train_index], data[test_index]
+        y_train, y_test = label[train_index], label[test_index]
+
+        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
+        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
+        x_test_tensor = torch.from_numpy(X_test).float().to(device)
+        y_test_tensor = torch.from_numpy(y_test).float().to(device)
+
+        x_train_tensor = x_train_tensor.unsqueeze(1)
+        x_valid_tensor = x_valid_tensor.unsqueeze(1)
+        x_test_tensor = x_test_tensor.unsqueeze(1)
+
+        train_data = TensorDataset(x_train_tensor, y_train_tensor)
+        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+        test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+        train_loader = DataLoader(train_data, batch_size, shuffle=True)
+        valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
+        test_loader = DataLoader(test_data, batch_size, shuffle=False)
+
+        model = DeepGS(nsnp)
+        model.train_model(train_loader, valid_loader, num_round, learning_rate, momentum, weight_decay, patience, device)
+        y_pred = model.predict(test_loader)
+
+        mse = mean_squared_error(y_test, y_pred)
+        r2 = r2_score(y_test, y_pred)
+        mae = mean_absolute_error(y_test, y_pred)
+        pcc, _ = pearsonr(y_test, y_pred)
+
+        if np.isnan(pcc):
+            print(f"Fold {fold} resulted in NaN PCC, pruning the trial...")
+            raise TrialPruned()
+
+        all_mse.append(mse)
+        all_r2.append(r2)
+        all_mae.append(mae)
+        all_pcc.append(pcc)
+
+        fold_time = time.time() - fold_start_time
+        fold_cpu_mem = process.memory_info().rss / 1024**2
+        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+              f'CPU={fold_cpu_mem:.2f}MB')
+
+    return np.mean(all_pcc) if all_pcc else 0.0
+
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+def Hyperparameter(data, label, nsnp):
+    set_seed(42)
+
+    def objective(trial):
+        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
+        learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
+        momentum = trial.suggest_float("momentum", 0.1, 0.9, step=0.1)
+        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+        weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2, 1e-1])
+        patience = trial.suggest_int("patience", 10, 100, step=10)
+        try:
+            corr_score = run_nested_cv_with_early_stopping(
+                data=data,
+                label=label,
+                nsnp=nsnp,
+                learning_rate=learning_rate,
+                momentum=momentum,
+                weight_decay=weight_decay,
+                patience=patience,
+                batch_size=batch_size
+            )
+        except TrialPruned:
+            # score a pruned (NaN-correlation) trial as the worst possible value
            return float("-inf")
+        return corr_score
+
+    study = optuna.create_study(direction="maximize")
+    study.optimize(objective, n_trials=20)
+
+    print("best params:", study.best_params)
+    return study.best_params
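
A minimal sketch of driving the search above by hand, assuming the import path from the file listing. The synthetic arrays are placeholders; note that every one of the 20 trials runs full 10-fold cross-validation with up to 6000 epochs per fold, so this is expensive on real data.

    import numpy as np
    from gpbench.method_reg.DeepGS import DeepGS_Hyperparameters

    X = np.random.randn(200, 300).astype(np.float32)  # 200 samples, 300 SNPs
    y = np.random.randn(200).astype(np.float32)

    best = DeepGS_Hyperparameters.Hyperparameter(X, y, nsnp=300)
    print(best)  # dict with learning_rate, momentum, batch_size, weight_decay, patience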
gpbench/method_reg/DeepGS/__init__.py
@@ -0,0 +1,5 @@
+from .DeepGS import DeepGS_reg
+
+DeepGS = DeepGS_reg
+
+__all__ = ["DeepGS", "DeepGS_reg"]
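
The alias makes the regression entry point importable under both names; a hedged one-liner, assuming data prepared as sketched earlier:

    from gpbench.method_reg.DeepGS import DeepGS_reg

    # parses CLI flags such as --data_dir/--result_dir, then runs the full pipeline
    DeepGS_reg()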
gpbench/method_reg/DeepGS/base_deepgs.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+
+
+class DeepGS(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.input = nn.Identity()
+        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=18, stride=1)
+        self.act1 = nn.ReLU()
+        self.pool = nn.MaxPool1d(kernel_size=4, stride=4)
+        self.drop1 = nn.Dropout1d(p=0.2)
+        # infer the flattened convolution output size with a dummy forward pass
+        with torch.no_grad():
+            dummy = torch.zeros(1, 1, input_size)
+            dummy_out = self.pool(self.conv1(dummy))
+            conv_out_dim = dummy_out.view(1, -1).size(1)
+        self.fc1 = nn.Linear(in_features=conv_out_dim, out_features=32)
+        self.drop2 = nn.Dropout1d(p=0.1)
+        self.fc2 = nn.Linear(in_features=32, out_features=1)
+        self.act2 = nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.input(x)
+        x = self.conv1(x)
+        x = self.act1(x)
+        x = self.pool(x)
+        # keep a channel dimension so Dropout1d drops whole feature maps
+        x = torch.flatten(x, 1).unsqueeze(1)
+        x = self.drop1(x)
+        x = self.fc1(x)
+        x = self.act2(x)
+        x = self.drop2(x)
+        x = self.fc2(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+    def train_model(self, train_loader, valid_loader, num_epochs, learning_rate, momentum, weight_decay, patience, device):
+        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
+        criterion = nn.L1Loss()
+        self.to(device)
+
+        best_loss = float('inf')
+        best_state = None
+        trigger_times = 0
+
+        for epoch in range(num_epochs):
+            self.train()
+            train_loss = 0.0
+            for inputs, labels in train_loader:
+                inputs, labels = inputs.to(device), labels.to(device)
+                optimizer.zero_grad()
+                outputs = self(inputs)
+                labels = labels.unsqueeze(1)
+                loss = criterion(outputs, labels)
+                loss.backward()
+                optimizer.step()
+                train_loss += loss.item() * inputs.size(0)
+
+            self.eval()
+            valid_loss = 0.0
+            with torch.no_grad():
+                for inputs, labels in valid_loader:
+                    inputs, labels = inputs.to(device), labels.to(device)
+                    outputs = self(inputs)
+                    labels = labels.unsqueeze(1)
+                    loss = criterion(outputs, labels)
+                    valid_loss += loss.item() * inputs.size(0)
+
+            train_loss /= len(train_loader.dataset)
+            valid_loss /= len(valid_loader.dataset)
+
+            # ---------- Early stopping ----------
+            if valid_loss < best_loss:
+                best_loss = valid_loss
+                # clone the weights; state_dict() alone returns live references
+                # that later training steps would overwrite
+                best_state = {k: v.detach().clone() for k, v in self.state_dict().items()}
+                trigger_times = 0
+            else:
+                trigger_times += 1
+                if trigger_times >= patience:
+                    print(f"Early stopping at epoch {epoch+1}")
+                    break
+
+        if best_state is not None:
+            self.load_state_dict(best_state)
+        return best_loss
+
+    def predict(self, test_loader):
+        self.eval()
+        y_pred = []
+        device = next(self.parameters()).device
+        with torch.no_grad():
+            for inputs, _ in test_loader:
+                inputs = inputs.to(device)
+                outputs = self(inputs)
+                y_pred.append(outputs.cpu().numpy())
+        y_pred = np.concatenate(y_pred, axis=0)
+        y_pred = np.squeeze(y_pred)
+        return y_pred
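
Finally, a quick shape check for the DeepGS network itself; the toy sizes are arbitrary and the import path is assumed from the file listing.

    import torch
    from gpbench.method_reg.DeepGS.base_deepgs import DeepGS

    net = DeepGS(input_size=300)
    net.eval()  # disable dropout for a deterministic check
    out = net(torch.randn(4, 1, 300))  # [batch, channel=1, SNPs]
    print(out.shape)                   # torch.Size([4, 1])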