gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
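The hunks below reproduce five of the 188 files in full: gpbench/method_class/DeepCCR/base_DeepCCR_class.py and the four gpbench/method_class/DeepGS modules. As a hedged sketch of how those pieces are imported once the wheel is installed (the DeepGSNet alias is illustrative, not a name defined by the package):

from gpbench.method_class.DeepGS import DeepGS  # pipeline runner, aliased to DeepGS_class in __init__.py
from gpbench.method_class.DeepGS.base_deepgs_class import DeepGS as DeepGSNet  # the underlying nn.Module
from gpbench.method_class.DeepCCR.base_DeepCCR_class import DeepCCR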
gpbench/method_class/DeepCCR/base_DeepCCR_class.py
@@ -0,0 +1,209 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class DeepCCR(nn.Module):
+     def __init__(
+         self,
+         input_channels=1,
+         input_seq_len=162,
+         lstm_hidden_size=64,
+         fc1_hidden_dim=128,
+         num_classes=2,
+     ):
+         super(DeepCCR, self).__init__()
+
+         # ==================== Conv1 ====================
+         self.conv1_kernel = 500
+         self.conv1_stride = 100
+         self.conv1_out_ch = 150
+
+         # Inputs shorter than the kernel are padded up to the kernel size.
+         if input_seq_len < self.conv1_kernel:
+             pad_left = (self.conv1_kernel - input_seq_len) // 2
+             pad_right = self.conv1_kernel - input_seq_len - pad_left
+             self.conv1_padding = (pad_left, pad_right)
+             conv1_input_len = self.conv1_kernel
+         else:
+             self.conv1_padding = (0, 0)
+             conv1_input_len = input_seq_len
+
+         self.conv1 = nn.Conv1d(
+             in_channels=input_channels,
+             out_channels=self.conv1_out_ch,
+             kernel_size=self.conv1_kernel,
+             stride=self.conv1_stride,
+             padding=0
+         )
+         self.relu1 = nn.ReLU()
+
+         conv1_seq_len = (conv1_input_len - self.conv1_kernel) // self.conv1_stride + 1
+         self.max_pool1 = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
+         pool1_seq_len = (conv1_seq_len + 1) // 2
+
+         # ==================== BiLSTM ====================
+         self.bilstm = nn.LSTM(
+             input_size=150,  # equals conv1_out_ch
+             hidden_size=lstm_hidden_size,
+             batch_first=True,
+             bidirectional=True
+         )
+
+         # ==================== Conv2 ====================
+         self.conv2_kernel = 30
+         self.conv2_stride = 5
+         self.conv2_out_ch = 150
+
+         if pool1_seq_len < self.conv2_kernel:
+             pad_left2 = (self.conv2_kernel - pool1_seq_len) // 2
+             pad_right2 = self.conv2_kernel - pool1_seq_len - pad_left2
+             self.conv2_padding = (pad_left2, pad_right2)
+             conv2_input_len = self.conv2_kernel
+         else:
+             self.conv2_padding = (0, 0)
+             conv2_input_len = pool1_seq_len
+
+         self.conv2 = nn.Conv1d(
+             in_channels=lstm_hidden_size * 2,  # bidirectional doubles the feature size
+             out_channels=self.conv2_out_ch,
+             kernel_size=self.conv2_kernel,
+             stride=self.conv2_stride,
+             padding=0
+         )
+         self.relu2 = nn.ReLU()
+
+         conv2_seq_len = (conv2_input_len - self.conv2_kernel) // self.conv2_stride + 1
+         self.max_pool2 = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
+         pool2_seq_len = (conv2_seq_len + 1) // 2
+
+         # ==================== FC ====================
+         flatten_dim = self.conv2_out_ch * pool2_seq_len
+
+         self.fc1 = nn.Linear(flatten_dim, fc1_hidden_dim)
+         self.relu3 = nn.ReLU()
+         self.fc2 = nn.Linear(fc1_hidden_dim, num_classes)
+
+     # ==================================================
+     # Forward
+     # ==================================================
+     def forward(self, x):
+         """
+         x: [batch, channels, seq_len]
+         return: logits [batch, num_classes]
+         """
+         if self.conv1_padding != (0, 0):
+             x = F.pad(x, self.conv1_padding)
+
+         x = self.relu1(self.conv1(x))
+         x = self.max_pool1(x)
+
+         x = x.permute(0, 2, 1)  # [B, T, C]
+         x, _ = self.bilstm(x)
+         x = x.permute(0, 2, 1)  # [B, C, T]
+
+         if self.conv2_padding != (0, 0):
+             x = F.pad(x, self.conv2_padding)
+
+         x = self.relu2(self.conv2(x))
+         x = self.max_pool2(x)
+
+         x = torch.flatten(x, start_dim=1)
+         x = self.relu3(self.fc1(x))
+         logits = self.fc2(x)
+
+         return logits
+
+     # ==================================================
+     # Training (classification)
+     # ==================================================
+     def train_model(
+         self,
+         train_loader,
+         valid_loader,
+         num_epochs,
+         learning_rate,
+         patience,
+         device
+     ):
+         self.to(device)
+
+         optimizer = torch.optim.AdamW(
+             self.parameters(),
+             lr=learning_rate,
+             weight_decay=1e-4
+         )
+         criterion = nn.CrossEntropyLoss()
+
+         best_loss = float("inf")
+         best_state = None
+         trigger_times = 0
+
+         for epoch in range(num_epochs):
+             # -------- Train --------
+             self.train()
+             train_loss = 0.0
+
+             for inputs, labels in train_loader:
+                 inputs = inputs.to(device)
+                 labels = labels.to(device).long()
+
+                 optimizer.zero_grad()
+                 outputs = self(inputs)  # logits
+                 loss = criterion(outputs, labels)
+                 loss.backward()
+                 optimizer.step()
+
+                 train_loss += loss.item() * inputs.size(0)
+
+             train_loss /= len(train_loader.dataset)
+
+             # -------- Validation --------
+             self.eval()
+             valid_loss = 0.0
+             with torch.no_grad():
+                 for inputs, labels in valid_loader:
+                     inputs = inputs.to(device)
+                     labels = labels.to(device).long()
+                     outputs = self(inputs)
+                     loss = criterion(outputs, labels)
+                     valid_loss += loss.item() * inputs.size(0)
+
+             valid_loss /= len(valid_loader.dataset)
+
+             # -------- Early stopping --------
+             if valid_loss < best_loss:
+                 best_loss = valid_loss
+                 # Snapshot on CPU so later updates cannot mutate the copy.
+                 best_state = {k: v.cpu().clone() for k, v in self.state_dict().items()}
+                 trigger_times = 0
+             else:
+                 trigger_times += 1
+                 if trigger_times >= patience:
+                     print(f"Early stopping at epoch {epoch + 1}")
+                     break
+
+         if best_state is not None:
+             cur_device = next(self.parameters()).device
+             best_state = {k: v.to(cur_device) for k, v in best_state.items()}
+             self.load_state_dict(best_state)
+
+         return best_loss
+
+     # ==================================================
+     # Prediction (classification)
+     # ==================================================
+     def predict(self, test_loader, device):
+         self.eval()
+         self.to(device)
+
+         y_pred = []
+         with torch.no_grad():
+             for inputs, _ in test_loader:
+                 inputs = inputs.to(device)
+                 outputs = self(inputs)  # logits
+                 preds = torch.argmax(outputs, dim=1)  # class index
+                 y_pred.append(preds.cpu().numpy())
+
+         y_pred = np.concatenate(y_pred, axis=0)
+         return y_pred
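A minimal shape-check sketch for the class above (the random tensor stands in for genotype input; the 162-SNP default comes straight from the constructor, and the batch size is arbitrary):

import torch
from gpbench.method_class.DeepCCR.base_DeepCCR_class import DeepCCR

model = DeepCCR(input_channels=1, input_seq_len=162, num_classes=2)
x = torch.randn(8, 1, 162)            # [batch, channels, seq_len], per the forward() docstring
logits = model(x)                     # -> [8, 2]
preds = torch.argmax(logits, dim=1)   # class indices, as predict() computes

Note that inputs shorter than the 500-wide first kernel are symmetrically padded in forward(), so the 162-SNP default collapses each conv/pool stage to length 1 and the flatten dimension is just the 150 conv2 channels.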
gpbench/method_class/DeepGS/DeepGS_class.py
@@ -0,0 +1,184 @@
+ import os
+ import time
+ import psutil
+ import swanlab
+ import argparse
+ import random
+ import torch
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+ from sklearn.preprocessing import LabelEncoder
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from torch.utils.data import DataLoader, TensorDataset
+ from .base_deepgs_class import DeepGS
+ from . import DeepGS_he_class
+ import pynvml
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--methods', type=str, default='DeepGS/')
+     parser.add_argument('--species', type=str, default='Wheat/')
+     parser.add_argument('--phe', type=str, default='')
+     parser.add_argument('--data_dir', type=str, default='../../data/')
+     parser.add_argument('--result_dir', type=str, default='result/')
+
+     parser.add_argument('--num_round', type=int, default=6000)
+     parser.add_argument('--batch_size', type=int, default=32)
+     parser.add_argument('--weight_decay', type=float, default=1e-5)
+     parser.add_argument('--momentum', type=float, default=0.5)
+     parser.add_argument('--learning_rate', type=float, default=0.01)
+     parser.add_argument('--patience', type=int, default=50)
+     return parser.parse_args()
+
+ def load_data(args):
+     X = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     Y = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+     print("Samples:", X.shape[0], "SNPs:", X.shape[1])
+     return X, Y, X.shape[0], X.shape[1], names
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def get_gpu_mem_by_pid(pid, handle=None):
+     if handle is None:
+         return 0.0
+     try:
+         procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+         for p in procs:
+             if p.pid == pid:
+                 return p.usedGpuMemory / 1024**2
+         return 0.0
+     except Exception:
+         return 0.0
+
+ def run_nested_cv(args, data, label, nsnp, num_classes, device, gpu_handle=None):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+
+     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
+
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+     cv_start_time = time.time()
+
+     for fold, (train_idx, test_idx) in enumerate(kf.split(data, label)):
+         fold_start = time.time()
+         process = psutil.Process(os.getpid())
+         print(f"\n===== Fold {fold} =====")
+
+         X_train, X_test = data[train_idx], data[test_idx]
+         y_train, y_test = label[train_idx], label[test_idx]
+
+         X_tr, X_val, y_tr, y_val = train_test_split(
+             X_train, y_train,
+             test_size=0.1,
+             stratify=y_train,
+             random_state=42
+         )
+
+         x_tr = torch.from_numpy(X_tr).float().unsqueeze(1).to(device)
+         x_val = torch.from_numpy(X_val).float().unsqueeze(1).to(device)
+         x_te = torch.from_numpy(X_test).float().unsqueeze(1).to(device)
+
+         y_tr = torch.from_numpy(y_tr).long().to(device)
+         y_val = torch.from_numpy(y_val).long().to(device)
+         y_te = torch.from_numpy(y_test).long().to(device)
+
+         train_loader = DataLoader(TensorDataset(x_tr, y_tr), args.batch_size, shuffle=True)
+         valid_loader = DataLoader(TensorDataset(x_val, y_val), args.batch_size, shuffle=False)
+         test_loader = DataLoader(TensorDataset(x_te, y_te), args.batch_size, shuffle=False)
+
+         model = DeepGS(nsnp, num_classes=num_classes)
+
+         model.train_model(
+             train_loader,
+             valid_loader,
+             args.num_round,
+             args.learning_rate,
+             args.momentum,
+             args.weight_decay,
+             args.patience,
+             device
+         )
+
+         y_pred = model.predict(test_loader)
+
+         acc = accuracy_score(y_test, y_pred)
+         prec, rec, f1, _ = precision_recall_fscore_support(
+             y_test, y_pred,
+             average="macro",
+             zero_division=0
+         )
+
+         all_acc.append(acc)
+         all_prec.append(prec)
+         all_rec.append(rec)
+         all_f1.append(f1)
+
+         fold_time = time.time() - fold_start
+         gpu_mem = get_gpu_mem_by_pid(os.getpid(), gpu_handle)
+         cpu_mem = process.memory_info().rss / 1024**2
+
+         print(
+             f"Fold {fold}: "
+             f"ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, "
+             f"Time={fold_time:.2f}s, GPU={gpu_mem:.2f}MB, CPU={cpu_mem:.2f}MB"
+         )
+
+         pd.DataFrame({
+             "Y_test": y_test,
+             "Y_pred": y_pred
+         }).to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+         torch.cuda.empty_cache()
+
+
+ def DeepGS_class():
+     set_seed(42)
+     gpu_handle = None
+     try:
+         if torch.cuda.is_available():
+             pynvml.nvmlInit()
+             gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+     except Exception as e:
+         print(f"Warning: GPU monitoring initialization failed: {e}")
+         gpu_handle = None
+
+     args = parse_args()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     all_species = ["Human/Sim/"]
+
+     for species in all_species:
+         args.species = species
+         X, Y, nsamples, nsnp, names = load_data(args)
+
+         print("Starting:", args.methods + args.species)
+         label_raw = np.nan_to_num(Y[:, 0])
+         le = LabelEncoder()
+         label = le.fit_transform(label_raw)
+         num_classes = len(le.classes_)
+
+         best_params = DeepGS_he_class.Hyperparameter(X, label, nsnp)
+         args.learning_rate = best_params['learning_rate']
+         args.batch_size = best_params['batch_size']
+         args.momentum = best_params['momentum']
+         args.weight_decay = best_params['weight_decay']
+         args.patience = best_params['patience']
+
+         start_time = time.time()
+         run_nested_cv(args, X, label, nsnp, num_classes, device, gpu_handle)
+         elapsed_time = time.time() - start_time
+         print(f"Total running time: {elapsed_time:.2f}s")
+         print("Successfully finished:", args.species, args.phe)
+
+
+ if __name__ == "__main__":
+     DeepGS_class()
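load_data() above reads per-species genotype.npz / phenotype.npz archives by positional keys. A sketch of files that would satisfy it, with shapes and dtypes inferred from how X, Y[:, 0], and names are used rather than from any packaged documentation:

import numpy as np

rng = np.random.default_rng(42)
n_samples, n_snps = 100, 1000

# arr_0: (samples x SNPs) genotype matrix; must be float-convertible,
# since run_nested_cv calls torch.from_numpy(...).float() on slices of it.
np.savez("genotype.npz", rng.integers(0, 3, size=(n_samples, n_snps)).astype(np.float64))

# arr_0: phenotype matrix (column 0 becomes the class label after LabelEncoder);
# arr_1: name strings.
np.savez("phenotype.npz",
         rng.integers(0, 2, size=(n_samples, 1)).astype(np.float64),
         np.array(["phenotype0"]))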
gpbench/method_class/DeepGS/DeepGS_he_class.py
@@ -0,0 +1,150 @@
+ import os
+ import time
+ import psutil
+ import random
+ import torch
+ import numpy as np
+ import optuna
+ from sklearn.model_selection import KFold, train_test_split
+ from sklearn.preprocessing import LabelEncoder
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from torch.utils.data import DataLoader, TensorDataset
+ from optuna.exceptions import TrialPruned
+ from .base_deepgs_class import DeepGS
+
+ def run_nested_cv_classification(data, label, nsnp, learning_rate, momentum, weight_decay,
+                                  patience, batch_size, num_round=1000):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     print("Starting 10-fold cross-validation...")
+
+     kf = KFold(n_splits=10, shuffle=True, random_state=42)
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+
+     num_classes = len(np.unique(label))
+
+     for fold, (train_index, test_index) in enumerate(kf.split(data)):
+         print(f"\n===== Fold {fold} =====")
+         fold_start_time = time.time()
+         process = psutil.Process(os.getpid())
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
+             X_train, y_train, test_size=0.1, random_state=42
+         )
+
+         # tensors
+         x_train = torch.from_numpy(X_train_sub).float().unsqueeze(1).to(device)
+         y_train = torch.from_numpy(y_train_sub).long().to(device)
+         x_valid = torch.from_numpy(X_valid).float().unsqueeze(1).to(device)
+         y_valid = torch.from_numpy(y_valid).long().to(device)
+         x_test = torch.from_numpy(X_test).float().unsqueeze(1).to(device)
+         y_test = torch.from_numpy(y_test).long().to(device)
+
+         train_loader = DataLoader(
+             TensorDataset(x_train, y_train),
+             batch_size=batch_size, shuffle=True
+         )
+         valid_loader = DataLoader(
+             TensorDataset(x_valid, y_valid),
+             batch_size=batch_size, shuffle=False
+         )
+         test_loader = DataLoader(
+             TensorDataset(x_test, y_test),
+             batch_size=batch_size, shuffle=False
+         )
+
+         model = DeepGS(nsnp, num_classes=num_classes)
+         model.loss_fn = torch.nn.CrossEntropyLoss()
+
+         model.train_model(
+             train_loader, valid_loader,
+             num_round, learning_rate,
+             momentum, weight_decay,
+             patience, device
+         )
+
+         y_pred = model.predict(test_loader)
+
+         if y_pred.ndim == 2:
+             y_pred_class = np.argmax(y_pred, axis=1)
+         else:
+             y_pred_class = y_pred
+
+         acc = accuracy_score(y_test.cpu().numpy(), y_pred_class)
+         prec, rec, f1, _ = precision_recall_fscore_support(
+             y_test.cpu().numpy(),
+             y_pred_class,
+             average="macro",
+             zero_division=0
+         )
+
+         all_acc.append(acc)
+         all_prec.append(prec)
+         all_rec.append(rec)
+         all_f1.append(f1)
+
+         fold_time = time.time() - fold_start_time
+         cpu_mem = process.memory_info().rss / 1024**2
+
+         print(
+             f"Fold {fold}: "
+             f"ACC={acc:.4f}, "
+             f"PREC={prec:.4f}, "
+             f"REC={rec:.4f}, "
+             f"F1={f1:.4f}, "
+             f"Time={fold_time:.2f}s, "
+             f"CPU={cpu_mem:.2f}MB"
+         )
+
+     print("\n===== Final Results =====")
+     print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"F1  : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+
+     return np.mean(all_f1)
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def Hyperparameter(data, label, nsnp):
+     set_seed(42)
+
+     def objective(trial):
+         learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
+         momentum = trial.suggest_float("momentum", 0.1, 0.9, step=0.1)
+         batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+         weight_decay = trial.suggest_categorical(
+             "weight_decay", [1e-4, 1e-3, 1e-2, 1e-1]
+         )
+         patience = trial.suggest_int("patience", 10, 100, step=10)
+
+         try:
+             f1 = run_nested_cv_classification(
+                 data=data,
+                 label=label,
+                 nsnp=nsnp,
+                 learning_rate=learning_rate,
+                 momentum=momentum,
+                 weight_decay=weight_decay,
+                 patience=patience,
+                 batch_size=batch_size
+             )
+         except Exception:
+             # Any crash inside a trial prunes it rather than aborting the study.
+             raise TrialPruned()
+
+         return f1
+
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+
+     print("Best params:", study.best_params)
+     return study.best_params
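A sketch of driving Hyperparameter() directly, mirroring the call in DeepGS_class.py (the synthetic arrays are placeholders for real genotype data):

import numpy as np
from gpbench.method_class.DeepGS import DeepGS_he_class

rng = np.random.default_rng(0)
X = rng.random((200, 1000)).astype(np.float32)  # 200 samples x 1000 SNPs
y = rng.integers(0, 2, size=200)                # binary labels

best = DeepGS_he_class.Hyperparameter(X, y, nsnp=X.shape[1])
# best holds learning_rate, momentum, batch_size, weight_decay, patience
print(best)

Note that each of the 20 Optuna trials runs the full 10-fold loop above, so even this toy input is computationally heavy.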
gpbench/method_class/DeepGS/__init__.py
@@ -0,0 +1,5 @@
+ from .DeepGS_class import DeepGS_class
+
+ DeepGS = DeepGS_class
+
+ __all__ = ["DeepGS", "DeepGS_class"]
gpbench/method_class/DeepGS/base_deepgs_class.py
@@ -0,0 +1,153 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ class DeepGS(nn.Module):
+     """
+     DeepGS for multi-class classification.
+     Fully compatible with:
+       - Optuna hyperparameter optimization
+       - 10-fold cross-validation
+     """
+
+     def __init__(self, input_size: int, num_classes: int):
+         super().__init__()
+
+         # ========= Feature extractor =========
+         self.conv1 = nn.Conv1d(
+             in_channels=1,
+             out_channels=8,
+             kernel_size=18,
+             stride=1
+         )
+         self.act1 = nn.ReLU()
+         self.pool1 = nn.MaxPool1d(kernel_size=4, stride=4)
+         self.drop1 = nn.Dropout(p=0.2)
+
+         # ========= Dynamically infer FC input =========
+         with torch.no_grad():
+             dummy = torch.zeros(1, 1, input_size)
+             dummy = self.pool1(self.act1(self.conv1(dummy)))
+             conv_out_dim = dummy.view(1, -1).size(1)
+
+         # ========= Classifier =========
+         self.fc1 = nn.Linear(conv_out_dim, 32)
+         self.act2 = nn.ReLU()
+         self.drop2 = nn.Dropout(p=0.1)
+
+         self.fc2 = nn.Linear(32, num_classes)
+
+     # ==================================================
+     # Forward
+     # ==================================================
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.act1(x)
+         x = self.pool1(x)
+         x = self.drop1(x)
+
+         x = torch.flatten(x, 1)
+
+         x = self.fc1(x)
+         x = self.act2(x)
+         x = self.drop2(x)
+
+         x = self.fc2(x)  # logits
+         return x
+
+     # ==================================================
+     # Training (classification)
+     # ==================================================
+     def train_model(
+         self,
+         train_loader,
+         valid_loader,
+         num_epochs: int,
+         learning_rate: float,
+         momentum: float,
+         weight_decay: float,
+         patience: int,
+         device: torch.device
+     ):
+         self.to(device)
+
+         optimizer = torch.optim.SGD(
+             self.parameters(),
+             lr=learning_rate,
+             momentum=momentum,
+             weight_decay=weight_decay
+         )
+
+         criterion = nn.CrossEntropyLoss()
+
+         best_loss = float("inf")
+         best_state = None
+         trigger_times = 0
+
+         for epoch in range(num_epochs):
+             # -------- Train --------
+             self.train()
+             train_loss = 0.0
+
+             for inputs, labels in train_loader:
+                 inputs = inputs.to(device)
+                 labels = labels.to(device)
+
+                 optimizer.zero_grad()
+                 outputs = self(inputs)
+                 loss = criterion(outputs, labels)
+                 loss.backward()
+                 optimizer.step()
+
+                 train_loss += loss.item() * inputs.size(0)
+
+             train_loss /= len(train_loader.dataset)
+
+             # -------- Validation --------
+             self.eval()
+             valid_loss = 0.0
+
+             with torch.no_grad():
+                 for inputs, labels in valid_loader:
+                     inputs = inputs.to(device)
+                     labels = labels.to(device)
+
+                     outputs = self(inputs)
+                     loss = criterion(outputs, labels)
+                     valid_loss += loss.item() * inputs.size(0)
+
+             valid_loss /= len(valid_loader.dataset)
+
+             # -------- Early stopping --------
+             if valid_loss < best_loss:
+                 best_loss = valid_loss
+                 # Clone to CPU: state_dict() alone returns live references
+                 # that subsequent optimizer steps would overwrite.
+                 best_state = {k: v.cpu().clone() for k, v in self.state_dict().items()}
+                 trigger_times = 0
+             else:
+                 trigger_times += 1
+                 if trigger_times >= patience:
+                     break
+
+         if best_state is not None:
+             self.load_state_dict(best_state)
+
+         return best_loss
+
+     # ==================================================
+     # Prediction (classification)
+     # ==================================================
+     def predict(self, test_loader):
+         self.eval()
+         device = next(self.parameters()).device
+
+         y_pred = []
+
+         with torch.no_grad():
+             for inputs, _ in test_loader:
+                 inputs = inputs.to(device)
+                 outputs = self(inputs)  # (N, C)
+                 preds = torch.argmax(outputs, dim=1)
+                 y_pred.append(preds.cpu().numpy())
+
+         return np.concatenate(y_pred, axis=0)
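Finally, a minimal end-to-end sketch of the train/predict contract above (synthetic tensors; input_size only needs to exceed the conv kernel of 18):

import torch
from torch.utils.data import DataLoader, TensorDataset
from gpbench.method_class.DeepGS.base_deepgs_class import DeepGS

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
X = torch.randn(64, 1, 1000)   # [batch, 1 channel, SNPs]
y = torch.randint(0, 2, (64,))

loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = DeepGS(input_size=1000, num_classes=2)

# Reusing one loader for train and validation keeps the sketch short;
# the gpbench runners split properly via train_test_split.
best_loss = model.train_model(loader, loader, num_epochs=5,
                              learning_rate=0.01, momentum=0.5,
                              weight_decay=1e-5, patience=3, device=device)
preds = model.predict(loader)  # numpy array of class indices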