gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,209 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional
4
+ import numpy as np
5
+
6
+
7
class ModelHyperparams:
    """Hyperparameter bundle for the G2PDeep architecture.

    Any list argument left as ``None`` (or otherwise falsy) falls back to the
    published default configuration; ``activation`` and ``dropout_rate`` are
    stored as given.
    """

    def __init__(self,
                 left_tower_filters_list: Optional[List[int]] = None,
                 left_tower_kernel_size_list: Optional[List[int]] = None,
                 right_tower_filters_list: Optional[List[int]] = None,
                 right_tower_kernel_size_list: Optional[List[int]] = None,
                 central_tower_filters_list: Optional[List[int]] = None,
                 central_tower_kernel_size_list: Optional[List[int]] = None,
                 dnn_size_list: Optional[List[int]] = None,
                 activation: str = "linear",
                 dropout_rate: float = 0.75):
        # (attribute name, supplied value, default) — a falsy supplied value
        # selects the default, mirroring the `value or default` idiom.
        list_specs = (
            ("left_tower_filters_list", left_tower_filters_list, [4, 4]),
            ("left_tower_kernel_size_list", left_tower_kernel_size_list, [3, 5]),
            ("right_tower_filters_list", right_tower_filters_list, [4]),
            ("right_tower_kernel_size_list", right_tower_kernel_size_list, [3]),
            ("central_tower_filters_list", central_tower_filters_list, [4]),
            ("central_tower_kernel_size_list", central_tower_kernel_size_list, [3]),
            ("dnn_size_list", dnn_size_list, [32]),
        )
        for attr, supplied, default in list_specs:
            setattr(self, attr, supplied if supplied else default)
        self.activation = activation
        self.dropout_rate = dropout_rate
27
+
28
def get_activation(name: str):
    """Return a fresh activation module for *name* (case-insensitive).

    Supported names: ``"relu"`` -> ``nn.ReLU``, ``"linear"`` -> ``nn.Identity``.
    Raises ``ValueError`` for anything else.
    """
    factories = {"relu": nn.ReLU, "linear": nn.Identity}
    factory = factories.get(name.lower())
    if factory is None:
        raise ValueError(f"Unsupported activation: {name}")
    return factory()
35
+
36
+
37
class G2PDeep(nn.Module):
    """Two-tower 1-D CNN classifier for SNP sequences (G2PDeep architecture).

    Two parallel convolutional towers read the same one-hot-style input,
    are projected to a common channel count, summed, refined by a central
    tower, then classified by a small fully connected head.

    Parameters
    ----------
    nsnp : int
        Sequence length (number of SNP positions).
    num_classes : int
        Size of the logits the model emits.
    hyperparams : ModelHyperparams
        Tower filter/kernel lists, DNN sizes, activation name, dropout rate.
    """

    def __init__(self, nsnp: int, num_classes: int, hyperparams: ModelHyperparams):
        super().__init__()
        self.nsnp = nsnp
        self.num_classes = num_classes
        hp = hyperparams

        # --- Left Tower: stacked Conv1d, 4 input channels; padding="same"
        # (stride 1) keeps the sequence length at nsnp throughout. ---
        self.left_convs = nn.ModuleList()
        in_ch = 4
        for filt, k in zip(hp.left_tower_filters_list, hp.left_tower_kernel_size_list):
            self.left_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
            in_ch = filt

        # --- Right Tower: same structure, independently parameterized. ---
        self.right_convs = nn.ModuleList()
        in_ch = 4
        for filt, k in zip(hp.right_tower_filters_list, hp.right_tower_kernel_size_list):
            self.right_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
            in_ch = filt

        # --- Channel alignment: 1x1 conv lifts the narrower tower to the
        # wider tower's channel count so the two outputs can be added. ---
        left_out_ch = hp.left_tower_filters_list[-1]
        right_out_ch = hp.right_tower_filters_list[-1]
        self.merged_ch = max(left_out_ch, right_out_ch)

        self.left_proj = nn.Conv1d(left_out_ch, self.merged_ch, 1) \
            if left_out_ch != self.merged_ch else nn.Identity()
        self.right_proj = nn.Conv1d(right_out_ch, self.merged_ch, 1) \
            if right_out_ch != self.merged_ch else nn.Identity()

        # --- Central Tower: refines the merged representation. ---
        self.central_convs = nn.ModuleList()
        in_ch = self.merged_ch
        for filt, k in zip(hp.central_tower_filters_list, hp.central_tower_kernel_size_list):
            self.central_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
            in_ch = filt

        # --- DNN head: flatten -> (Linear, activation, Dropout)* -> logits.
        # flattened_dim relies on padding="same"/stride 1 preserving nsnp. ---
        self.dropout = nn.Dropout(p=hp.dropout_rate)
        final_conv_ch = hp.central_tower_filters_list[-1]
        flattened_dim = final_conv_ch * nsnp

        dnn_layers = []
        prev = flattened_dim
        for out_sz in hp.dnn_size_list:
            dnn_layers.append(nn.Linear(prev, out_sz))
            dnn_layers.append(get_activation(hp.activation))
            dnn_layers.append(nn.Dropout(hp.dropout_rate))
            prev = out_sz

        dnn_layers.append(nn.Linear(prev, num_classes))
        self.dnn = nn.Sequential(*dnn_layers)

        # Shared activation applied after every conv in forward(); a single
        # module instance is fine since ReLU/Identity are stateless.
        self.activation = get_activation(hp.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Expects ``x`` of shape (B, Seq, 4); transposed internally to
        (B, 4, Seq) for Conv1d. Returns (B, num_classes) raw logits.
        """
        # (B, Seq, 4) -> (B, 4, Seq)
        if x.shape[-1] != 4:
            raise ValueError(f"Expected input with 4 channels, got {x.shape}")

        x = x.transpose(1, 2)

        # Left tower
        left = x
        for conv in self.left_convs:
            left = self.activation(conv(left))

        # Right tower
        right = x
        for conv in self.right_convs:
            right = self.activation(conv(right))

        # Element-wise fusion of the two towers (after channel projection).
        merged = self.left_proj(left) + self.right_proj(right)

        # Central tower
        for conv in self.central_convs:
            merged = self.activation(conv(merged))

        x_flat = torch.flatten(merged, 1)
        x_flat = self.dropout(x_flat)
        return self.dnn(x_flat)  # (B, num_classes) logits

    def train_model(self, train_loader, valid_loader, num_epochs, learning_rate, patience, device):
        """Train with Adam + cross-entropy, AMP on CUDA, and early stopping.

        Validation loss drives early stopping: after `patience` epochs
        without improvement training stops and the best weights (kept on
        CPU to save GPU memory) are restored.

        Returns
        -------
        float
            Best (lowest) validation loss observed.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        self.to(device)

        # Mixed precision only makes sense (and only works) on CUDA devices.
        use_amp = device.type == 'cuda'
        scaler = torch.amp.GradScaler('cuda') if use_amp else None

        best_loss = float('inf')
        best_state = None
        trigger_times = 0  # consecutive epochs without validation improvement

        for epoch in range(num_epochs):
            self.train()
            train_loss = 0.0
            for inputs, labels in train_loader:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)  # (B,) long tensor

                optimizer.zero_grad()

                if use_amp:
                    with torch.amp.autocast('cuda'):
                        outputs = self(inputs)  # (B, num_classes)
                        loss = criterion(outputs, labels)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    outputs = self(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                # Weight by batch size so the epoch average is per-sample.
                train_loss += loss.item() * inputs.size(0)
            train_loss /= len(train_loader.dataset)

            # Validation pass (no gradients).
            self.eval()
            valid_loss = 0.0
            with torch.no_grad():
                for inputs, labels in valid_loader:
                    inputs = inputs.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                    if use_amp:
                        with torch.amp.autocast('cuda'):
                            outputs = self(inputs)
                            loss = criterion(outputs, labels)
                    else:
                        outputs = self(inputs)
                        loss = criterion(outputs, labels)

                    valid_loss += loss.item() * inputs.size(0)
            valid_loss /= len(valid_loader.dataset)

            # Early stopping: snapshot best weights on CPU.
            if valid_loss < best_loss:
                best_loss = valid_loss
                best_state = {k: v.cpu().clone() for k, v in self.state_dict().items()}
                trigger_times = 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Restore the best snapshot (moved back to the model's device).
        if best_state is not None:
            cur_device = next(self.parameters()).device
            best_state = {k: v.to(cur_device) for k, v in best_state.items()}
            self.load_state_dict(best_state)
        return best_loss

    def predict(self, test_loader, device):
        """Return predicted class indices for *test_loader* as a numpy array.

        Runs in eval mode without gradients; argmax over logits per sample.
        """
        self.eval()
        self.to(device)
        y_pred_list = []
        use_amp = device.type == 'cuda'
        with torch.no_grad():
            for inputs, _ in test_loader:
                inputs = inputs.to(device, non_blocking=True)
                if use_amp:
                    with torch.amp.autocast('cuda'):
                        outputs = self(inputs)  # (B, num_classes)
                else:
                    outputs = self(inputs)
                preds = torch.argmax(outputs, dim=1)  # (B,)
                y_pred_list.append(preds.cpu())
        y_pred = torch.cat(y_pred_list, dim=0).numpy()
        return y_pred
@@ -0,0 +1,183 @@
1
+ import os
2
+ import time
3
+ import psutil
4
+ import swanlab
5
+ import argparse
6
+ import random
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.model_selection import StratifiedKFold
11
+ from sklearn.preprocessing import LabelEncoder
12
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
13
+
14
+ import rpy2.robjects as ro
15
+ from rpy2.robjects import numpy2ri
16
+ numpy2ri.activate()
17
+ ro.r('library(rrBLUP)')
18
+
19
+
20
def gblup_classification(X_train, y_train_bin, X_test):
    """Fit one binary GBLUP model in R (rrBLUP) and score the test set.

    Used as a one-vs-rest building block: ``y_train_bin`` is a 0/1 vector
    for a single class, and the returned continuous scores are later
    argmax-ed across classes by the caller.

    Parameters
    ----------
    X_train, X_test : 2-D numpy arrays of genotypes (samples x SNPs).
        # NOTE(review): the R code computes allele frequencies as
        # colMeans(X)/2, i.e. it assumes 0/1/2 dosage coding — confirm
        # against the genotype.npz encoding.
    y_train_bin : 1-D float array of 0/1 class-membership indicators.

    Returns
    -------
    numpy.ndarray
        1-D array of GBLUP scores (mu + genomic BLUP) for each test sample.
    """
    # Pass data to R
    ro.globalenv['X_train'] = X_train
    ro.globalenv['y_train_bin'] = y_train_bin
    ro.globalenv['X_test'] = X_test

    # Runtime R script: VanRaden GRM + REML mixed model via rrBLUP.
    r_code = """
    library(rrBLUP)

    n_train <- nrow(X_train)
    m <- ncol(X_train)

    # Step1: allele frequencies
    p <- colMeans(X_train) / 2
    p <- pmax(pmin(p, 0.99), 0.01)

    # Step2: VanRaden standardized genotype
    Z_train <- sweep(X_train, 2, 2*p, "-") / sqrt(2*p*(1-p))
    Z_train[is.na(Z_train)] <- 0

    Z_test <- sweep(X_test, 2, 2*p, "-") / sqrt(2*p*(1-p))
    Z_test[is.na(Z_test)] <- 0

    # Step3: Genomic relationship matrix (VanRaden method 2)
    denom <- sum(2*p*(1-p))
    G <- Z_train %*% t(Z_train) / denom
    G <- G + diag(1e-6, n_train)  # stability

    # Step4: REML GBLUP
    fit <- mixed.solve(y = y_train_bin, K = G, SE = FALSE)

    # Extract variance components and fixed effect
    Vu <- fit$Vu
    Ve <- fit$Ve
    mu <- as.numeric(fit$beta)  # intercept
    h2 <- Vu / (Vu + Ve)

    # Step5: GBLUP prediction for test set
    y_centered <- y_train_bin - mu
    A <- G + (Ve / Vu) * diag(n_train)  # G + λ I

    G_test_train <- Z_test %*% t(Z_train) / denom
    u_test <- G_test_train %*% solve(A, y_centered)  # strictly correct formula

    y_pred_score <- mu + u_test
    y_pred_score
    """

    # The last R expression (y_pred_score) is the value returned by ro.r().
    y_pred_score = np.array(ro.r(r_code)).flatten()
    return y_pred_score
71
+
72
+
73
def parse_args():
    """Build the CLI for the GBLUP runner and parse ``sys.argv``.

    All options are strings; defaults match the published pipeline layout.
    """
    cli = argparse.ArgumentParser(description="Argument parser")
    # (flag, default, help) — help may be None, which argparse treats the
    # same as omitting the keyword entirely.
    option_table = (
        ('--methods', 'GBLUP/', 'Method name'),
        ('--species', '', None),
        ('--phe', '', 'Dataset name'),
        ('--data_dir', '../../data/', None),
        ('--result_dir', 'result/', None),
    )
    for flag, default, help_text in option_table:
        cli.add_argument(flag, type=str, default=default, help=help_text)
    return cli.parse_args()
81
+
82
+
83
def load_data(args):
    """Load the genotype matrix and phenotype table for ``args.species``.

    Reads ``genotype.npz`` (matrix under key "arr_0") and ``phenotype.npz``
    (values under "arr_0", trait names under "arr_1") from the species
    directory, prints the dataset dimensions, and returns
    (X, Y, n_samples, n_snps, names).
    """
    species_dir = os.path.join(args.data_dir, args.species)
    xData = np.load(os.path.join(species_dir, 'genotype.npz'))["arr_0"]
    phenotype_npz = np.load(os.path.join(species_dir, 'phenotype.npz'))
    yData = phenotype_npz["arr_0"]
    names = phenotype_npz["arr_1"]
    nsample, nsnp = xData.shape[0], xData.shape[1]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names
92
+
93
+
94
def set_seed(seed=42):
    """Seed every RNG this pipeline touches so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
101
+
102
+
103
def run_nested_cv(args, data, label, process):
    """Run 10-fold stratified CV of one-vs-rest GBLUP classification.

    For each fold, one binary GBLUP model per class is fitted via
    gblup_classification() and the class with the highest score wins.
    Per-fold decoded predictions go to ``<result_dir>/fold<k>.csv``;
    per-fold metrics and a final summary are printed.

    Parameters
    ----------
    args : argparse.Namespace with result_dir / methods / species.
    data : genotype matrix, samples x SNPs.
    label : per-sample class labels (label-encoded internally).
    process : psutil.Process used to report resident memory per fold.
    """
    result_dir = os.path.join(args.result_dir, args.methods + args.species)
    os.makedirs(result_dir, exist_ok=True)
    print("Starting 10-fold cross-validation (GBLUP Classification with R)...")

    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    le = LabelEncoder()
    label_all = le.fit_transform(label)
    # Persist the label <-> integer mapping so fold CSVs can be decoded later.
    np.save(os.path.join(result_dir, 'label_mapping.npy'), le.classes_)

    all_acc, all_prec, all_rec, all_f1 = [], [], [], []

    for fold, (train_idx, test_idx) in enumerate(skf.split(data, label_all)):
        fold_start = time.time()
        print(f"===== Fold {fold} =====")
        X_train, X_test = data[train_idx], data[test_idx]
        Y_train, Y_test = label_all[train_idx], label_all[test_idx]

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        # One-vs-rest: row idx of `scores` holds the continuous GBLUP scores
        # of class `classes[idx]` for every test sample.
        classes = np.unique(Y_train)
        scores = np.zeros((len(classes), X_test.shape[0]))
        for idx, cls in enumerate(classes):
            y_train_bin = (Y_train == cls).astype(float)
            scores[idx, :] = gblup_classification(X_train, y_train_bin, X_test)

        # NOTE(review): argmax yields positions in `classes`; with stratified
        # folds and labels encoded 0..K-1 every class appears in training, so
        # position == encoded label — confirm for extremely rare classes.
        Y_pred = np.argmax(scores, axis=0)

        acc = accuracy_score(Y_test, Y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(Y_test, Y_pred, average='macro', zero_division=0)
        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start
        # GPU figure is 0 on CPU-only machines; both memory figures are MiB.
        fold_gpu_mem = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, '
              f'Time={fold_time:.2f}s, GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')

        # Save this fold's predictions decoded back to the original labels.
        Y_test_orig = le.inverse_transform(Y_test)
        Y_pred_orig = le.inverse_transform(Y_pred)
        results_df = pd.DataFrame({'Y_test': Y_test_orig, 'Y_pred': Y_pred_orig})
        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)

    print("\n===== Cross-validation summary =====")
    print(f"Average ACC: {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
    print(f"Average PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
    print(f"Average REC: {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
    print(f"Average F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
155
+
156
+
157
def GBLUP_class():
    """Entry point: run GBLUP-based classification over all configured species.

    For each species, loads genotype/phenotype data, uses the first phenotype
    column as the class label (NaNs imputed with the mode), then runs 10-fold
    cross-validation via run_nested_cv(). Prints total wall-clock time.
    """
    set_seed(42)
    # BUG FIX: only touch CUDA bookkeeping when a GPU is actually present —
    # torch.cuda.reset_peak_memory_stats() raises on CPU-only installations.
    # This also makes the guards consistent with run_nested_cv().
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    args = parse_args()
    process = psutil.Process(os.getpid())

    all_species = ["Human/Sim/"]
    for sp in all_species:
        args.species = sp
        X, Y, nsamples, nsnp, names = load_data(args)
        print("Starting run " + args.methods + args.species)
        # First phenotype column is the class label; impute NaNs with the
        # mode (falling back to 0 when the column is all-NaN).
        label = Y[:, 0]
        s = pd.Series(label)
        fill_val = s.mode().iloc[0] if not s.dropna().empty else 0
        label = np.nan_to_num(label, nan=fill_val)

        start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        run_nested_cv(args, data=X, label=label, process=process)

        elapsed_time = time.time() - start_time
        print(f"Total running time: {elapsed_time:.2f} s")
        print("Successfully finished!")


if __name__ == "__main__":
    GBLUP_class()
@@ -0,0 +1,5 @@
1
# Package export for the GBLUP classification method.
from .GBLUP_class import GBLUP_class

# Backward-compatible alias: `GBLUP` and `GBLUP_class` name the same entry point.
GBLUP = GBLUP_class

__all__ = ["GBLUP","GBLUP_class"]
@@ -0,0 +1,169 @@
1
+ import os
2
+ import torch
3
+ import swanlab
4
+ import argparse
5
+ import psutil
6
+ import time
7
+ import random
8
+ import numpy as np
9
+ import pandas as pd
10
+ import pynvml
11
+ from . import GEFormer_he_class
12
+
13
+ from .gMLP_class import GEFormer
14
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
15
+ from sklearn.preprocessing import LabelEncoder
16
+ from torch.utils.data import DataLoader, TensorDataset
17
+ from sklearn.model_selection import StratifiedKFold, train_test_split
18
+
19
def parse_args():
    """Parse CLI options for the GEFormer classification runner.

    Path options are strings; training options (epoch, batch_size,
    learning_rate, patience) are numeric and may later be overwritten by the
    hyperparameter search.
    """
    cli = argparse.ArgumentParser()

    # Path / bookkeeping options.
    for flag, default in (('--methods', 'GEFormer/'),
                          ('--species', ''),
                          ('--phe', ''),
                          ('--data_dir', '../../data/'),
                          ('--result_dir', 'result/')):
        cli.add_argument(flag, type=str, default=default)

    # Training hyperparameters.
    cli.add_argument('--epoch', type=int, default=1000)
    cli.add_argument('--batch_size', type=int, default=64)
    cli.add_argument('--learning_rate', type=float, default=0.01)
    cli.add_argument('--patience', type=int, default=10)
    return cli.parse_args()
32
+
33
def set_seed(seed=42):
    """Make Python, NumPy and PyTorch (CPU + CUDA) RNGs deterministic."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Deterministic cuDNN kernels; disables autotuning for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
40
+
41
+
42
def load_data(args):
    """Load genotype and phenotype arrays for ``args.species``.

    Returns (X, Y, n_samples, n_snps, names) where X comes from
    ``genotype.npz`` ("arr_0") and Y / names from ``phenotype.npz``
    ("arr_0" / "arr_1").
    """
    base = os.path.join(args.data_dir, args.species)
    geno = np.load(os.path.join(base, 'genotype.npz'))["arr_0"]
    pheno_npz = np.load(os.path.join(base, 'phenotype.npz'))
    pheno = pheno_npz["arr_0"]
    trait_names = pheno_npz["arr_1"]
    return geno, pheno, geno.shape[0], geno.shape[1], trait_names
47
+
48
+
49
def get_gpu_mem_by_pid(pid, device_index=0):
    """Return GPU memory (MiB) used by process *pid*, or 0.0 if none found.

    Parameters
    ----------
    pid : int
        Process id to look up among the device's compute processes.
    device_index : int, optional
        NVML device index used when no module-level handle exists.

    BUG FIX: the original read a module-global ``handle`` that was only ever
    assigned as a *local* inside GEFormer_class(), so every call raised
    NameError. We now reuse a module-level ``handle`` when one exists (for
    backward compatibility) and otherwise resolve the device handle here;
    NVML failures degrade to 0.0 instead of crashing the metrics printout.
    """
    try:
        dev = globals().get('handle')
        if dev is None:
            dev = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        procs = pynvml.nvmlDeviceGetComputeRunningProcesses(dev)
    except pynvml.NVMLError:
        return 0.0
    for p in procs:
        if p.pid == pid:
            return p.usedGpuMemory / 1024**2
    return 0.0
55
+
56
def run_nested_cv(args, data, label, nsnp, device):
    """Run 10-fold stratified CV for the GEFormer classifier.

    Each fold holds out 10% of its training split (stratified) for early
    stopping, trains a fresh GEFormer, evaluates on the test split, writes
    ``fold<k>.csv`` with predictions, and prints per-fold and summary metrics.

    Parameters
    ----------
    args : argparse.Namespace with paths and training hyperparameters.
    data : genotype matrix, samples x SNPs (numpy).
    label : integer-encoded class labels, shape (samples,).
    nsnp : number of SNP features, forwarded to the model constructor.
    device : torch.device used for tensors and the model.
    """
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)

    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

    all_acc, all_prec, all_rec, all_f1 = [], [], [], []
    total_start = time.time()

    num_classes = len(np.unique(label))

    for fold, (train_idx, test_idx) in enumerate(kf.split(data, label)):
        fold_start = time.time()
        process = psutil.Process(os.getpid())

        X_train, X_test = data[train_idx], data[test_idx]
        y_train, y_test = label[train_idx], label[test_idx]

        # Stratified 90/10 split of the training fold for early stopping.
        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
            X_train, y_train, test_size=0.1, stratify=y_train, random_state=42
        )

        # NOTE(review): the whole fold is moved to `device` up front; for
        # large SNP matrices this keeps everything resident in GPU memory.
        x_train = torch.from_numpy(X_train_sub).float().to(device)
        y_train = torch.from_numpy(y_train_sub).long().to(device)
        x_valid = torch.from_numpy(X_valid).float().to(device)
        y_valid = torch.from_numpy(y_valid).long().to(device)
        x_test = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).long().to(device)

        train_loader = DataLoader(TensorDataset(x_train, y_train), args.batch_size, shuffle=True)
        valid_loader = DataLoader(TensorDataset(x_valid, y_valid), args.batch_size, shuffle=False)
        test_loader = DataLoader(TensorDataset(x_test, y_test_tensor), args.batch_size, shuffle=False)

        # A fresh model per fold — no weight leakage across folds.
        model = GEFormer(nsnp=nsnp, num_classes=num_classes).to(device)
        model.train_model(
            train_loader, valid_loader,
            args.epoch, args.learning_rate, args.patience, device
        )

        # assumes GEFormer.predict returns per-sample class scores/logits of
        # shape (n_test, num_classes) — TODO confirm against gMLP_class.
        logits = model.predict(test_loader)
        y_pred = np.argmax(logits, axis=1)

        acc = accuracy_score(y_test, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_test, y_pred, average='macro', zero_division=0
        )

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start
        # Per-process GPU usage via NVML; RSS in MiB via psutil.
        gpu_mem = get_gpu_mem_by_pid(os.getpid())
        cpu_mem = process.memory_info().rss / 1024**2

        # Free cached GPU memory between folds and reset the peak counters.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        print(
            f"Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, "
            f"F1={f1:.4f}, Time={fold_time:.2f}s, "
            f"GPU={gpu_mem:.2f}MB, CPU={cpu_mem:.2f}MB"
        )

        pd.DataFrame({"Y_test": y_test, "Y_pred": y_pred}).to_csv(
            os.path.join(result_dir, f"fold{fold}.csv"), index=False
        )

    total_time = time.time() - total_start
    print("\n===== CV Summary =====")
    print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
    print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
    print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
    print(f"F1  : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
    print(f"Time: {total_time:.2f}s")
134
+
135
def GEFormer_class():
    """Entry point: hyperparameter search + 10-fold CV for GEFormer classification.

    For each configured species: loads data, label-encodes the first phenotype
    column (NaNs -> 0 before encoding), tunes learning rate / batch size /
    patience via GEFormer_he_class.Hyperparameter, then runs run_nested_cv().
    """
    # BUG FIX: get_gpu_mem_by_pid() reads a module-level `handle`; the
    # original assigned it as a local here, so the helper raised NameError.
    global handle
    set_seed(42)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    all_species = ["Human/Sim/"]

    for species in all_species:
        args.species = species
        X, Y, _, nsnp, _ = load_data(args)
        # Class labels come from the first phenotype column; NaNs become 0.
        label_raw = np.nan_to_num(Y[:, 0])
        le = LabelEncoder()
        label = le.fit_transform(label_raw)

        # Tune training hyperparameters on this dataset before the final CV.
        best_params = GEFormer_he_class.Hyperparameter(X, label, nsnp)
        args.learning_rate = best_params['learning_rate']
        args.batch_size = best_params['batch_size']
        args.patience = best_params['patience']
        start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        run_nested_cv(args, X, label, nsnp, device)

        elapsed_time = time.time() - start_time
        print(f"Running time: {elapsed_time:.2f}s")
        print("Successfully finished:", args.species, args.phe)


if __name__ == "__main__":
    GEFormer_class()