gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
gpbench/method_class/Cropformer/Cropformer_class.py
@@ -0,0 +1,308 @@
+ import argparse
+ import random
+ import torch
+ import swanlab
+ import numpy as np
+ import pandas as pd
+ import torch.nn as nn
+ import time, psutil, os
+ import torch.optim as optim
+ from torch.optim import Adam
+ from lightning.pytorch import LightningModule
+ from sklearn.model_selection import KFold
+ from sklearn.preprocessing import LabelEncoder
+ from torch.utils.data import DataLoader, TensorDataset
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from . import cropformer_he_class
+ import pynvml
+ 
+ # =========================
+ # Argument parser
+ # =========================
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Argument parser")
+     parser.add_argument('--methods', type=str, default='Cropformer/', help='Method name')
+     parser.add_argument('--species', type=str, default='', help='Dataset name')
+     parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+     parser.add_argument('--data_dir', type=str, default='../../data/', help='Data directory')
+     parser.add_argument('--result_dir', type=str, default='result/', help='Result directory')
+ 
+     parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
+     parser.add_argument('--num_head', type=int, default=1, help='Number of attention heads')
+     parser.add_argument('--dropout', type=float, default=0.5, help='Dropout probability')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+     parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension')
+     parser.add_argument('--kernel_size', type=int, default=3, help='Kernel size')
+     parser.add_argument('--patience', type=int, default=5, help='Early stopping patience')
+     args = parser.parse_args()
+     return args
+ 
+ def load_data(args):
+     xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+     print(f"Number of samples: {xData.shape[0]}, Number of SNPs: {xData.shape[1]}")
+     return xData, yData, xData.shape[0], xData.shape[1], names
+ 
+ class LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12):
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+ 
+     def forward(self, x):
+         u = x.mean(-1, keepdim=True)
+         s = (x - u).pow(2).mean(-1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+         return self.weight * x + self.bias
+ 
+ class SelfAttention(LightningModule):
+     def __init__(self, num_attention_heads, input_size, hidden_size, output_dim=2, kernel_size=3,
+                  hidden_dropout_prob=0.5, attention_probs_dropout_prob=0.5, learning_rate=0.001):
+         super(SelfAttention, self).__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_size = int(hidden_size / num_attention_heads)
+         self.all_head_size = hidden_size
+ 
+         self.query = nn.Linear(input_size, self.all_head_size)
+         self.key = nn.Linear(input_size, self.all_head_size)
+         self.value = nn.Linear(input_size, self.all_head_size)
+ 
+         self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)
+         self.out_dropout = nn.Dropout(hidden_dropout_prob)
+         self.dense = nn.Linear(hidden_size, input_size)
+         self.LayerNorm = nn.LayerNorm(input_size, eps=1e-12)
+         self.relu = nn.ReLU()
+         self.out = nn.Linear(input_size, output_dim)
+         # "same" padding keeps the SNP-sequence length unchanged for any odd kernel size
+         self.cnn = nn.Conv1d(1, 1, kernel_size, stride=1, padding=kernel_size // 2)
+ 
+         self.learning_rate = learning_rate
+         self.loss_fn = nn.CrossEntropyLoss()
+ 
+     def forward(self, input_tensor):
+         input_tensor = input_tensor.to(self.device)
+         self.cnn = self.cnn.to(self.device)
+ 
+         cnn_hidden = self.cnn(input_tensor.view(input_tensor.size(0), 1, -1))
+         input_tensor = cnn_hidden
+         mixed_query_layer = self.query(input_tensor)
+         mixed_key_layer = self.key(input_tensor)
+         mixed_value_layer = self.value(input_tensor)
+ 
+         attention_scores = torch.matmul(mixed_query_layer, mixed_key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / np.sqrt(self.attention_head_size)
+         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
+         attention_probs = self.attn_dropout(attention_probs)
+ 
+         context_layer = torch.matmul(attention_probs, mixed_value_layer)
+         hidden_states = self.dense(context_layer)
+         hidden_states = self.out_dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         output = self.out(self.relu(hidden_states.view(hidden_states.size(0), -1)))
+         return output
+ 
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.long()
+         y_pred = self(x)
+         loss = self.loss_fn(y_pred, y)
+         return loss
+ 
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.long()
+         y_pred = self(x)
+         val_loss = self.loss_fn(y_pred, y)
+         return val_loss
+ 
+     def configure_optimizers(self):
+         return Adam(self.parameters(), lr=self.learning_rate)
+ 
+ class EarlyStopping:
+     def __init__(self, patience=10, delta=0):
+         self.patience = patience
+         self.delta = delta
+         self.best_score = None
+         self.counter = 0
+         self.early_stop = False
+ 
+     def __call__(self, score):
+         if self.best_score is None:
+             self.best_score = score
+         elif score < self.best_score + self.delta:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.counter = 0
+ 
+ def get_gpu_mem_by_pid(pid):
+     # Query NVML for this process's usage on GPU 0; requires pynvml.nvmlInit()
+     # to have been called (see Cropformer_class below). Returns MB, or 0.0 if absent.
+     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+     procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+     for p in procs:
+         if p.pid == pid:
+             return p.usedGpuMemory / 1024**2
+     return 0.0
+ 
+ def run_nested_cv_with_early_stopping(args, data, label, outer_cv, learning_rate, batch_size, hidden_dim,
+                                       output_dim, kernel_size, patience, DEVICE):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+ 
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+     time_start = time.time()
+     process = psutil.Process(os.getpid())
+ 
+     for fold, (train_idx, test_idx) in enumerate(outer_cv.split(data)):
+         fold_start_time = time.time()
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+ 
+         x_train, x_test = data[train_idx], data[test_idx]
+         y_train, y_test = label[train_idx], label[test_idx]
+ 
+         num_attention_heads = args.num_head
+         attention_probs_dropout_prob = args.dropout
+         hidden_dropout_prob = 0.5
+ 
+         model = SelfAttention(num_attention_heads, x_train.shape[1], hidden_dim, output_dim,
+                               hidden_dropout_prob=hidden_dropout_prob, kernel_size=kernel_size,
+                               attention_probs_dropout_prob=attention_probs_dropout_prob).to(DEVICE)
+ 
+         optimizer = Adam(model.parameters(), lr=learning_rate)
+ 
+         scaler = StandardScaler()
+         x_train = scaler.fit_transform(x_train)
+         x_test = scaler.transform(x_test)
+ 
+         x_train_tensor = torch.from_numpy(x_train).float().to(DEVICE)
+         y_train_tensor = torch.from_numpy(y_train).long().to(DEVICE)
+         x_test_tensor = torch.from_numpy(x_test).float().to(DEVICE)
+         y_test_tensor = torch.from_numpy(y_test).long().to(DEVICE)
+ 
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+ 
+         train_loader = DataLoader(train_data, batch_size, shuffle=True)
+         test_loader = DataLoader(test_data, batch_size, shuffle=False)
+ 
+         early_stopping = EarlyStopping(patience=patience)
+         best_f1 = -float('inf')
+ 
+         for epoch in range(100):
+             model.train()
+             for x_batch, y_batch in train_loader:
+                 optimizer.zero_grad()
+                 y_pred = model(x_batch)
+                 loss = model.loss_fn(y_pred, y_batch)
+                 loss.backward()
+                 optimizer.step()
+ 
+             model.eval()
+             y_test_preds, y_test_trues = [], []
+             with torch.no_grad():
+                 for x_batch, y_batch in test_loader:
+                     y_test_pred = model(x_batch)
+                     preds = torch.argmax(y_test_pred, dim=1)
+                     y_test_preds.extend(preds.cpu().numpy())
+                     y_test_trues.extend(y_batch.cpu().numpy())
+ 
+             acc = accuracy_score(y_test_trues, y_test_preds)
+             prec, rec, f1, _ = precision_recall_fscore_support(
+                 y_test_trues, y_test_preds, average="macro", zero_division=0
+             )
+ 
+             if f1 > best_f1:
+                 best_acc, best_prec, best_rec, best_f1 = acc, prec, rec, f1
+ 
+             early_stopping(f1)
+             if early_stopping.early_stop:
+                 print(f"Early stopping at epoch {epoch + 1}")
+                 break
+ 
+         all_acc.append(best_acc)
+         all_prec.append(best_prec)
+         all_rec.append(best_rec)
+         all_f1.append(best_f1)
+ 
+         fold_time = time.time() - fold_start_time
+         fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+ 
+         print(f'Fold {fold + 1}: ACC={best_acc:.4f}, PREC={best_prec:.4f}, REC={best_rec:.4f}, F1={best_f1:.4f}, '
+               f'Time={fold_time:.2f}s, GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+ 
+         results_df = pd.DataFrame({'Y_test': y_test_trues, 'Y_pred': y_test_preds})
+         results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+ 
+     print("==== Final Results ====")
+     print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+     print(f"Time: {time.time() - time_start:.2f}s")
+ 
+ 
+ # =========================
+ # Set seed
+ # =========================
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+ 
+ 
+ def Cropformer_class():
+     set_seed(42)
+     pynvml.nvmlInit()
+     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     args = parse_args()
+     all_species = ["Human/Sim/"]
+ 
+     for i in range(len(all_species)):
+         args.species = all_species[i]
+         X, Y, nsamples, nsnp, names = load_data(args)
+         print("Starting run " + args.methods + args.species)
+         label = Y[:, 0]
+         le = LabelEncoder()
+         label = le.fit_transform(label)
+ 
+         best_params = cropformer_he_class.Hyperparameter(X, label)
+         args.lr = best_params['learning_rate']
+         args.num_head = best_params['heads']
+         args.dropout = best_params['dropout']
+         args.batch_size = best_params['batch_size']
+ 
+         outer_cv = KFold(n_splits=10, shuffle=True, random_state=42)
+ 
+         start_time = time.time()
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+         process = psutil.Process(os.getpid())
+ 
+         run_nested_cv_with_early_stopping(
+             args,
+             data=X,
+             label=label,
+             outer_cv=outer_cv,
+             learning_rate=args.lr,
+             batch_size=args.batch_size,
+             hidden_dim=args.hidden_dim,
+             output_dim=len(np.unique(label)),
+             kernel_size=args.kernel_size,
+             patience=args.patience,
+             DEVICE='cuda:0' if torch.cuda.is_available() else 'cpu'
+         )
+ 
+         elapsed_time = time.time() - start_time
+         print(f"Total runtime: {elapsed_time:.2f} seconds")
+         print("Successfully finished!")
+ 
+ 
+ if __name__ == '__main__':
+     Cropformer_class()
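
Note on the architecture above: here is a tensor shape trace for SelfAttention.forward as defined in this file (a reading aid added by the editor, not code shipped in the wheel; B = batch size, S = number of SNPs, H = hidden_size):

```python
# Illustrative shape trace, assuming input x of shape (B, S) and kernel_size=3:
#   x.view(B, 1, -1) -> Conv1d(1, 1, 3)         # (B, 1, S): one channel per sample
#   query/key/value: Linear(S, H)               # (B, 1, H) each
#   Q @ K^T / sqrt(attention_head_size)         # (B, 1, 1): one token per sample
#   softmax + attn_dropout, then @ V            # (B, 1, H)
#   dense: Linear(H, S), residual + LayerNorm   # (B, 1, S)
#   flatten -> ReLU -> out: Linear(S, out_dim)  # (B, output_dim)
# Because each sample is folded into a single token, the softmaxed attention
# matrix is 1x1, so the block effectively acts as a residual MLP over the SNPs.
```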
gpbench/method_class/Cropformer/__init__.py
@@ -0,0 +1,5 @@
+ from .Cropformer_class import Cropformer_class
+ 
+ Cropformer = Cropformer_class
+ 
+ __all__ = ["Cropformer", "Cropformer_class"]
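
For orientation, a minimal usage sketch of this re-export (hypothetical driver code, not part of the wheel; it assumes gpbench is installed and the data layout matches load_data above):

```python
# Hypothetical: run the Cropformer classification pipeline programmatically.
# Cropformer_class() parses CLI flags with argparse (see parse_args), asks
# cropformer_he_class.Hyperparameter for tuned settings, then runs 10-fold CV.
from gpbench.method_class.Cropformer import Cropformer

if __name__ == "__main__":
    Cropformer()  # alias for Cropformer_class
```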
gpbench/method_class/Cropformer/cropformer_he_class.py
@@ -0,0 +1,221 @@
+ import time
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import random
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+ from sklearn.preprocessing import StandardScaler
+ from lightning.pytorch import LightningModule
+ import optuna
+ from sklearn.model_selection import KFold
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ 
+ class LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12):
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+ 
+     def forward(self, x):
+         u = x.mean(-1, keepdim=True)
+         s = (x - u).pow(2).mean(-1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+         return self.weight * x + self.bias
+ 
+ class SelfAttention(LightningModule):
+     def __init__(self, num_attention_heads, input_size, hidden_size, output_dim=2, kernel_size=3,
+                  hidden_dropout_prob=0.5, attention_probs_dropout_prob=0.5, learning_rate=0.001):
+         super(SelfAttention, self).__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_size = int(hidden_size / num_attention_heads)
+         self.all_head_size = hidden_size
+ 
+         self.query = nn.Linear(input_size, self.all_head_size)
+         self.key = nn.Linear(input_size, self.all_head_size)
+         self.value = nn.Linear(input_size, self.all_head_size)
+ 
+         self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)
+         self.out_dropout = nn.Dropout(hidden_dropout_prob)
+         self.dense = nn.Linear(hidden_size, input_size)
+         self.LayerNorm = nn.LayerNorm(input_size, eps=1e-12)
+         self.relu = nn.ReLU()
+         self.out = nn.Linear(input_size, output_dim)
+         # "same" padding keeps the SNP-sequence length unchanged for any odd kernel size
+         self.cnn = nn.Conv1d(1, 1, kernel_size, stride=1, padding=kernel_size // 2)
+ 
+         self.learning_rate = learning_rate
+         self.loss_fn = nn.CrossEntropyLoss()
+ 
+     def forward(self, input_tensor):
+         input_tensor = input_tensor.to(self.device)
+         self.cnn = self.cnn.to(self.device)
+ 
+         cnn_hidden = self.cnn(input_tensor.view(input_tensor.size(0), 1, -1))
+         input_tensor = cnn_hidden
+         mixed_query_layer = self.query(input_tensor)
+         mixed_key_layer = self.key(input_tensor)
+         mixed_value_layer = self.value(input_tensor)
+ 
+         attention_scores = torch.matmul(mixed_query_layer, mixed_key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / np.sqrt(self.attention_head_size)
+         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
+         attention_probs = self.attn_dropout(attention_probs)
+ 
+         context_layer = torch.matmul(attention_probs, mixed_value_layer)
+         hidden_states = self.dense(context_layer)
+         hidden_states = self.out_dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         output = self.out(self.relu(hidden_states.view(hidden_states.size(0), -1)))
+         return output
+ 
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.long()
+         y_pred = self(x)
+         loss = self.loss_fn(y_pred, y)
+         return loss
+ 
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.long()
+         y_pred = self(x)
+         val_loss = self.loss_fn(y_pred, y)
+         return val_loss
+ 
+     def configure_optimizers(self):
+         return optim.Adam(self.parameters(), lr=self.learning_rate)
+ 
+ class EarlyStopping:
+     def __init__(self, patience=10, delta=0):
+         self.patience = patience
+         self.delta = delta
+         self.best_score = None
+         self.counter = 0
+         self.early_stop = False
+ 
+     def __call__(self, score):
+         if self.best_score is None:
+             self.best_score = score
+         elif score < self.best_score + self.delta:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.counter = 0
+ 
+ def run_nested_cv_with_early_stopping(data, label, outer_cv, learning_rate, num_heads, dropout_prob, batch_size, hidden_dim,
+                                       output_dim, kernel_size, patience, DEVICE):
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+     time_start = time.time()
+ 
+     for fold, (train_idx, test_idx) in enumerate(outer_cv.split(data)):
+         x_train, x_test = data[train_idx], data[test_idx]
+         y_train, y_test = label[train_idx], label[test_idx]
+ 
+         model = SelfAttention(num_heads, x_train.shape[1], hidden_dim, output_dim,
+                               hidden_dropout_prob=0.5, kernel_size=kernel_size,
+                               attention_probs_dropout_prob=dropout_prob).to(DEVICE)
+         optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+ 
+         scaler = StandardScaler()
+         x_train = scaler.fit_transform(x_train)
+         x_test = scaler.transform(x_test)
+ 
+         x_train_tensor = torch.from_numpy(x_train).float().to(DEVICE)
+         y_train_tensor = torch.from_numpy(y_train).long().to(DEVICE)
+         x_test_tensor = torch.from_numpy(x_test).float().to(DEVICE)
+         y_test_tensor = torch.from_numpy(y_test).long().to(DEVICE)
+ 
+         train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
+         test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=batch_size, shuffle=False)
+ 
+         early_stopping = EarlyStopping(patience=patience)
+         best_f1 = -float('inf')
+ 
+         for epoch in range(100):
+             model.train()
+             for x_batch, y_batch in train_loader:
+                 optimizer.zero_grad()
+                 y_pred = model(x_batch)
+                 loss = model.loss_fn(y_pred, y_batch)
+                 loss.backward()
+                 optimizer.step()
+ 
+             model.eval()
+             y_test_preds, y_test_trues = [], []
+             with torch.no_grad():
+                 for x_batch, y_batch in test_loader:
+                     y_pred = model(x_batch)
+                     preds = torch.argmax(y_pred, dim=1)
+                     y_test_preds.extend(preds.cpu().numpy())
+                     y_test_trues.extend(y_batch.cpu().numpy())
+ 
+             acc = accuracy_score(y_test_trues, y_test_preds)
+             prec, rec, f1, _ = precision_recall_fscore_support(y_test_trues, y_test_preds, average="macro", zero_division=0)
+ 
+             if f1 > best_f1:
+                 best_acc, best_prec, best_rec, best_f1 = acc, prec, rec, f1
+ 
+             early_stopping(f1)
+             if early_stopping.early_stop:
+                 print(f"Early stopping at epoch {epoch+1}")
+                 break
+ 
+         all_acc.append(best_acc)
+         all_prec.append(best_prec)
+         all_rec.append(best_rec)
+         all_f1.append(best_f1)
+         print(f'Fold {fold+1}: ACC={best_acc:.4f}, PREC={best_prec:.4f}, REC={best_rec:.4f}, F1={best_f1:.4f}')
+ 
+     print("==== Final Results ====")
+     print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+     print(f"Time: {time.time() - time_start:.2f}s")
+ 
+     return all_f1
+ 
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+ 
+ def Hyperparameter(X, label):
+     set_seed(42)
+     torch.cuda.empty_cache()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ 
+     def objective(trial):
+         # log-uniform search space for the learning rate (suggest_float with
+         # log=True; suggest_loguniform is deprecated in current Optuna)
+         lr = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
+         heads = trial.suggest_int("heads", 1, 8)
+         dropout = trial.suggest_float("dropout", 0.1, 0.9)
+         batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+ 
+         outer_cv = KFold(n_splits=10, shuffle=True, random_state=42)
+         f1_scores = run_nested_cv_with_early_stopping(
+             data=X,
+             label=label.astype(int),
+             outer_cv=outer_cv,
+             learning_rate=lr,
+             num_heads=heads,
+             dropout_prob=dropout,
+             batch_size=batch_size,
+             hidden_dim=64,
+             output_dim=len(np.unique(label)),
+             kernel_size=3,
+             patience=5,
+             DEVICE=device
+         )
+         return np.mean(f1_scores)
+ 
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+ 
+     print("Best hyperparameters:", study.best_params)
+     return study.best_params
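
And a sketch of exercising Hyperparameter directly on synthetic data (illustrative only: the shapes and values are invented, and a real run is expensive since each of the 20 Optuna trials performs full 10-fold CV with up to 100 epochs per fold):

```python
# Hypothetical smoke test for Hyperparameter(): X is (n_samples, n_snps) and
# label holds integer class ids; the return value is Optuna's best_params dict
# with keys 'learning_rate', 'heads', 'dropout', and 'batch_size'.
import numpy as np
from gpbench.method_class.Cropformer import cropformer_he_class

rng = np.random.default_rng(0)
X = rng.integers(0, 3, size=(200, 500)).astype(np.float32)  # synthetic genotype matrix
y = rng.integers(0, 2, size=200)                            # synthetic binary labels

best = cropformer_he_class.Hyperparameter(X, y)
print(best)
```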