gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,137 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ config = {
+     "batch_size": 64,
+     "weights_units": [64, 32],
+     "regressor_units": [64, 32],
+     "dropout": 0.3,
+ }
+
+
+ class SimpleSNPModel(nn.Module):
+     def __init__(self, num_snps):
+         super().__init__()
+         try:
+             self.config = config
+             if not isinstance(num_snps, int) or num_snps <= 0:
+                 raise ValueError(f"num_snps must be a positive integer, got {num_snps}")
+
+             self.attention = self._build_attention_module(num_snps)
+             self.regressor = self._build_regressor_module(num_snps)
+         except Exception as e:
+             raise ValueError(f"Model initialization failed: {str(e)}")
+
+     def _build_attention_module(self, num_snps):
+         """Build attention module with error checking"""
+         try:
+             layers = []
+             prev_size = num_snps
+             for i, h_size in enumerate(self.config['weights_units']):
+                 if not isinstance(h_size, int) or h_size <= 0:
+                     raise ValueError(f"Invalid hidden size {h_size} in attention layer {i}")
+                 layers.append(nn.Linear(prev_size, h_size))
+                 if i < len(self.config['weights_units']) - 1:
+                     layers.append(nn.GELU())
+                 prev_size = h_size
+             layers.append(nn.Linear(prev_size, num_snps))
+             layers.append(nn.Sigmoid())
+             return nn.Sequential(*layers)
+         except Exception as e:
+             raise ValueError(f"Attention module construction failed: {str(e)}")
+
+     def _build_regressor_module(self, num_snps):
+         """Build regressor module with error checking"""
+         try:
+             layers = []
+             prev_size = num_snps
+             for i, h_size in enumerate(self.config['regressor_units']):
+                 if not isinstance(h_size, int) or h_size <= 0:
+                     raise ValueError(f"Invalid hidden size {h_size} in regressor layer {i}")
+                 layers.append(nn.Linear(prev_size, h_size))
+                 if i < len(self.config['regressor_units']) - 1:
+                     layers.append(nn.LayerNorm(h_size))
+                     layers.append(nn.GELU())
+                     layers.append(nn.Dropout(self.config['dropout']))
+                 prev_size = h_size
+             layers.append(nn.Linear(prev_size, 1))
+             return nn.Sequential(*layers)
+         except Exception as e:
+             raise ValueError(f"Regressor module construction failed: {str(e)}")
+
+     def forward(self, x):
+         """Forward pass with dimension checking"""
+         try:
+             if x.dim() != 2:
+                 raise ValueError(f"Input must be a 2D tensor, got {x.dim()}D")
+
+             # All attention layers except the final Sigmoid yield the raw scores.
+             pre_sigmoid_weights = self.attention[:-1](x)
+             # Equivalent to self.attention(x); reuses the scores instead of
+             # recomputing the whole stack.
+             att_weights = torch.sigmoid(pre_sigmoid_weights)
+             weighted = x * att_weights + x  # Residual connection
+             return self.regressor(weighted).squeeze(), pre_sigmoid_weights
+         except Exception as e:
+             raise RuntimeError(f"Forward pass failed: {str(e)}")
+
+     def train_model(self, train_loader, valid_loader, num_epochs, learning_rate, weight_decay, patience, device):
+         optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
+         criterion = nn.MSELoss()
+         self.to(device)
+
+         best_loss = float('inf')
+         best_state = None
+         trigger_times = 0
+
+         for epoch in range(num_epochs):
+             self.train()
+             train_loss = 0.0
+             for inputs, labels in train_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 optimizer.zero_grad()
+                 outputs, _ = self(inputs)
+                 loss = criterion(outputs, labels)
+                 loss.backward()
+                 optimizer.step()
+                 train_loss += loss.item() * inputs.size(0)
+
+             self.eval()
+             valid_loss = 0.0
+             with torch.no_grad():
+                 for inputs, labels in valid_loader:
+                     inputs, labels = inputs.to(device), labels.to(device)
+                     outputs, _ = self(inputs)
+                     loss = criterion(outputs, labels)
+                     valid_loss += loss.item() * inputs.size(0)
+
+             train_loss /= len(train_loader.dataset)
+             valid_loss /= len(valid_loader.dataset)
+
+             # ---------- Early stopping ----------
+             if valid_loss < best_loss:
+                 best_loss = valid_loss
+                 best_state = self.state_dict()
+                 trigger_times = 0
+             else:
+                 trigger_times += 1
+                 if trigger_times >= patience:
+                     print(f"Early stopping at epoch {epoch+1}")
+                     break
+
+         if best_state is not None:
+             self.load_state_dict(best_state)
+         return best_loss
+
+     def predict(self, test_loader):
+         self.eval()
+         # Run inference on whatever device the model parameters live on.
+         device = next(self.parameters()).device
+         y_pred = []
+         with torch.no_grad():
+             for inputs, _ in test_loader:
+                 outputs, _ = self(inputs.to(device))
+                 y_pred.append(outputs.cpu().numpy())
+         y_pred = np.concatenate(y_pred, axis=0)
+         y_pred = np.squeeze(y_pred)
+         return y_pred
+
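To make the intended call pattern concrete, here is a minimal usage sketch against the SimpleSNPModel defined in the hunk above. The synthetic genotype matrix, train/validation split, and hyperparameter values are illustrative assumptions, not values shipped with the package.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic data: 200 samples x 500 SNPs with a continuous phenotype.
    num_snps = 500
    X = torch.randn(200, num_snps)
    y = torch.randn(200)

    train_loader = DataLoader(TensorDataset(X[:160], y[:160]), batch_size=64, shuffle=True)
    valid_loader = DataLoader(TensorDataset(X[160:], y[160:]), batch_size=64)

    model = SimpleSNPModel(num_snps)
    best_val = model.train_model(train_loader, valid_loader, num_epochs=50,
                                 learning_rate=1e-3, weight_decay=1e-4, patience=5,
                                 device="cuda" if torch.cuda.is_available() else "cpu")
    preds = model.predict(valid_loader)  # numpy array of shape (40,)

Note that forward returns a (prediction, pre-sigmoid attention scores) tuple, so callers that only want predictions must unpack it, as train_model and predict do above.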
@@ -0,0 +1,313 @@
+ import argparse
+ import random
+ import torch
+ import numpy as np
+ import pandas as pd
+ import torch.nn as nn
+ import time, psutil, os
+ import torch.optim as optim
+ from torch.optim import Adam
+ from torch.nn import MSELoss
+ from lightning.pytorch import LightningModule
+ from sklearn.model_selection import KFold
+ from torch.utils.data import DataLoader, TensorDataset
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+ from . import Cropformer_Hyperparameters
+ import pynvml
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Argument parser")
+     parser.add_argument('--methods', type=str, default='Cropformer/', help='Method name')
+     parser.add_argument('--species', type=str, default='Chickpea/GSTP012/', help='Dataset name')
+     parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+     parser.add_argument('--data_dir', type=str, default='../../data/', help='Path to data directory')
+     parser.add_argument('--result_dir', type=str, default='result/', help='Path to result directory')
+
+     parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
+     parser.add_argument('--num_head', type=int, default=1, help='Number of attention heads')
+     parser.add_argument('--dropout', type=float, default=0.5, help='Dropout probability')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+     parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension')
+     parser.add_argument('--kernel_size', type=int, default=3, help='Kernel size')
+     parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
+     args = parser.parse_args()
+     return args
+
+ def load_data(args):
+     xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+
+     nsample = xData.shape[0]
+     nsnp = xData.shape[1]
+     print("Number of samples: ", nsample)
+     print("Number of SNPs: ", nsnp)
+     return xData, yData, nsample, nsnp, names
+
+ class LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12):
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, x):
+         u = x.mean(-1, keepdim=True)
+         s = (x - u).pow(2).mean(-1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+         return self.weight * x + self.bias
+
+ class SelfAttention(LightningModule):
+     def __init__(self, num_attention_heads, input_size, hidden_size, output_dim=1, kernel_size=3,
+                  hidden_dropout_prob=0.5, attention_probs_dropout_prob=0.5, learning_rate=0.001):
+         super(SelfAttention, self).__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_size = int(hidden_size / num_attention_heads)
+         self.all_head_size = hidden_size
+
+         self.query = torch.nn.Linear(input_size, self.all_head_size)
+         self.key = torch.nn.Linear(input_size, self.all_head_size)
+         self.value = torch.nn.Linear(input_size, self.all_head_size)
+
+         self.attn_dropout = torch.nn.Dropout(attention_probs_dropout_prob)
+         self.out_dropout = torch.nn.Dropout(hidden_dropout_prob)
+         self.dense = torch.nn.Linear(hidden_size, input_size)
+         self.LayerNorm = torch.nn.LayerNorm(input_size, eps=1e-12)
+         self.relu = torch.nn.ReLU()
+         self.out = torch.nn.Linear(input_size, output_dim)
+         # padding=1 preserves the sequence length only for kernel_size=3.
+         self.cnn = torch.nn.Conv1d(1, 1, kernel_size, stride=1, padding=1)
+
+         self.learning_rate = learning_rate
+         self.loss_fn = MSELoss()
+
+     def forward(self, input_tensor):
+         input_tensor = input_tensor.to(self.device)
+         self.cnn = self.cnn.to(self.device)
+
+         # 1D convolution over the SNP sequence before attention: (B, 1, n_snps)
+         cnn_hidden = self.cnn(input_tensor.view(input_tensor.size(0), 1, -1))
+         input_tensor = cnn_hidden
+         mixed_query_layer = self.query(input_tensor)
+         mixed_key_layer = self.key(input_tensor)
+         mixed_value_layer = self.value(input_tensor)
+
+         query_layer = mixed_query_layer
+         key_layer = mixed_key_layer
+         value_layer = mixed_value_layer
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / np.sqrt(self.attention_head_size)
+         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
+         attention_probs = self.attn_dropout(attention_probs)
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+         hidden_states = self.dense(context_layer)
+         hidden_states = self.out_dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)  # Residual connection
+         output = self.out(self.relu(hidden_states.view(hidden_states.size(0), -1)))
+         return output
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         y_pred = self(x)
+         loss = self.loss_fn(y_pred, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         y_pred = self(x)
+         val_loss = self.loss_fn(y_pred, y)
+         return val_loss
+
+     def configure_optimizers(self):
+         return Adam(self.parameters(), lr=self.learning_rate)
+
+ class EarlyStopping:
+     def __init__(self, patience=10, delta=0):
+         self.patience = patience
+         self.delta = delta
+         self.best_score = None
+         self.counter = 0
+         self.early_stop = False
+
+     def __call__(self, score):
+         # Score is maximized (correlation); it must exceed best + delta to reset.
+         if self.best_score is None:
+             self.best_score = score
+         elif score < self.best_score + self.delta:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.counter = 0
+
+ def get_gpu_mem_by_pid(pid):
+     # Acquire the device handle here rather than relying on a global
+     # `handle`, which was never defined at module scope.
+     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+     procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+     for p in procs:
+         if p.pid == pid:
+             return p.usedGpuMemory / 1024**2
+     return 0.0
+
+ def run_nested_cv_with_early_stopping(args, data, label, outer_cv, learning_rate, batch_size, hidden_dim,
+                                       output_dim, kernel_size, patience, DEVICE):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+     best_corr_coefs = []
+     best_maes = []
+     best_r2s = []
+     best_mses = []
+
+     time_start = time.time()
+     process = psutil.Process(os.getpid())
+     for fold, (train_idx, test_idx) in enumerate(outer_cv.split(data)):
+         fold_start_time = time.time()
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+         process = psutil.Process(os.getpid())
+
+         x_train, x_test = data[train_idx], data[test_idx]
+         y_train, y_test = label[train_idx], label[test_idx]
+
+         num_attention_heads = args.num_head
+         attention_probs_dropout_prob = args.dropout
+         hidden_dropout_prob = 0.5
+
+         model = SelfAttention(num_attention_heads, x_train.shape[1], hidden_dim, output_dim,
+                               hidden_dropout_prob=hidden_dropout_prob, kernel_size=kernel_size,
+                               attention_probs_dropout_prob=attention_probs_dropout_prob).to(DEVICE)
+         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+         loss_function = torch.nn.MSELoss()
+         scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10)
+
+         scaler = StandardScaler()
+         x_train = scaler.fit_transform(x_train)
+         x_test = scaler.transform(x_test)
+
+         x_train_tensor = torch.from_numpy(x_train).float().to(DEVICE)
+         y_train_tensor = torch.from_numpy(y_train).float().to(DEVICE)
+         x_test_tensor = torch.from_numpy(x_test).float().to(DEVICE)
+         y_test_tensor = torch.from_numpy(y_test).float().to(DEVICE)
+
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+         train_loader = DataLoader(train_data, batch_size, shuffle=True)
+         test_loader = DataLoader(test_data, batch_size, shuffle=False)
+
+         early_stopping = EarlyStopping(patience=patience)
+         best_corr_coef = -float('inf')
+         best_mae = float('inf')
+         best_mse = float('inf')
+         best_r2 = -float('inf')
+         for epoch in range(100):
+             model.train()
+             for x_batch, y_batch in train_loader:
+                 optimizer.zero_grad()
+                 y_pred = model(x_batch)
+                 loss = loss_function(y_pred, y_batch.reshape(-1, 1))
+                 loss.backward()
+                 optimizer.step()
+
+             model.eval()
+             y_test_preds, y_test_trues = [], []
+
+             with torch.no_grad():
+                 for x_batch, y_batch in test_loader:
+                     y_test_pred = model(x_batch)
+                     y_test_preds.extend(y_test_pred.cpu().numpy().reshape(-1).tolist())
+                     y_test_trues.extend(y_batch.cpu().numpy().reshape(-1).tolist())
+
+             corr_coef = np.corrcoef(y_test_preds, y_test_trues)[0, 1]
+             mae = mean_absolute_error(np.array(y_test_trues), np.array(y_test_preds))
+             mse = mean_squared_error(np.array(y_test_trues), np.array(y_test_preds))
+             r2 = r2_score(np.array(y_test_trues), np.array(y_test_preds))
+             scheduler.step(corr_coef)
+
+             if corr_coef > best_corr_coef:
+                 best_mae = mae
+                 best_corr_coef = corr_coef
+                 best_mse = mse
+                 best_r2 = r2
+
+             early_stopping(corr_coef)
+             if early_stopping.early_stop:
+                 print(f"Early stopping at epoch {epoch + 1}")
+                 break
+
+         best_corr_coefs.append(best_corr_coef)
+         best_maes.append(best_mae)
+         best_mses.append(best_mse)
+         best_r2s.append(best_r2)
+
+         fold_time = time.time() - fold_start_time
+         fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+         print(f'Fold {fold + 1}: Corr={best_corr_coef:.4f}, MAE={best_mae:.4f}, MSE={best_mse:.4f}, R2={best_r2:.4f}, Time={fold_time:.2f}s, '
+               f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+         results_df = pd.DataFrame({'Y_test': y_test_trues, 'Y_pred': y_test_preds})
+         results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+     print("==== Final Results ====")
+     print(f"Corr: {np.mean(best_corr_coefs):.4f} ± {np.std(best_corr_coefs):.4f}")
+     print(f"MAE: {np.mean(best_maes):.4f} ± {np.std(best_maes):.4f}")
+     print(f"MSE: {np.mean(best_mses):.4f} ± {np.std(best_mses):.4f}")
+     print(f"R2 : {np.mean(best_r2s):.4f} ± {np.std(best_r2s):.4f}")
+     print(f"Time: {time.time() - time_start:.2f}s")
+
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def Cropformer_reg():
+     set_seed(42)
+     pynvml.nvmlInit()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     args = parse_args()
+     all_species = ['Cotton/']
+     for i in range(len(all_species)):
+         args.species = all_species[i]
+         X, Y, nsamples, nsnp, names = load_data(args)
+         for j in range(len(names)):
+             args.phe = names[j]
+             print("starting run " + args.methods + args.species + args.phe)
+             label = Y[:, j]
+             label = np.nan_to_num(label, nan=np.nanmean(label))
+             best_params = Cropformer_Hyperparameters.Hyperparameter(X, label)
+             args.lr = best_params['learning_rate']
+             args.num_head = best_params['heads']
+             args.dropout = best_params['dropout']
+             args.batch_size = best_params['batch_size']
+
+             outer_cv = KFold(n_splits=10, shuffle=True, random_state=42)
+
+             start_time = time.time()
+             if torch.cuda.is_available():
+                 torch.cuda.reset_peak_memory_stats()
+             process = psutil.Process(os.getpid())
+
+             run_nested_cv_with_early_stopping(args,
+                                               data=X,
+                                               label=label,
+                                               outer_cv=outer_cv,
+                                               learning_rate=args.lr,
+                                               batch_size=args.batch_size,
+                                               hidden_dim=args.hidden_dim,
+                                               output_dim=1,
+                                               kernel_size=3,
+                                               patience=args.patience,
+                                               DEVICE='cuda:0' if torch.cuda.is_available() else 'cpu')
+
+             elapsed_time = time.time() - start_time
+             print(f"running time: {elapsed_time:.2f} s")
+             print("finished successfully")
+
+
+ if __name__ == '__main__':
+     Cropformer_reg()
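
For orientation, here is a self-contained sketch of the scaled dot-product step that SelfAttention.forward performs after its Conv1d stem; the shapes and sizes are illustrative assumptions. Because the Conv1d emits a single channel, each sample contributes exactly one attention "token", so the softmax runs over a single position.

    import math
    import torch

    batch, snps, hidden = 4, 16, 8
    x = torch.randn(batch, 1, snps)                       # Conv1d stem output: (B, 1, n_snps)

    q_proj = torch.nn.Linear(snps, hidden)                # maps input_size -> all_head_size
    k_proj = torch.nn.Linear(snps, hidden)
    v_proj = torch.nn.Linear(snps, hidden)

    q, k, v = q_proj(x), k_proj(x), v_proj(x)             # each (B, 1, hidden)
    scores = q @ k.transpose(-1, -2) / math.sqrt(hidden)  # (B, 1, 1): one token per sample
    probs = torch.softmax(scores, dim=-1)                 # softmax over a single position
    context = probs @ v                                   # (B, 1, hidden)
    print(context.shape)                                  # torch.Size([4, 1, 8])

In the package code, context is then projected back to input_size, combined with the residual, layer-normalized, and reduced to a single regression output per sample.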