gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,182 @@
1
+ import os
2
+ import time
3
+ import psutil
4
+ import swanlab
5
+ import argparse
6
+ import random
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.model_selection import KFold
11
+ from scipy.stats import pearsonr
12
+ from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
13
+
14
# rpy2 bridge: embeds an R interpreter in this Python process.
import rpy2.robjects as ro
from rpy2.robjects import numpy2ri
# Auto-convert numpy arrays <-> R matrices when passing data to R.
numpy2ri.activate()

# MASS supplies matrix utilities used for the BLUP inversion step.
ro.r('library(MASS)')
21
+
22
def gblup_r_vanraden_reml(X_train, y_train, X_test):
    """Fit GBLUP in R (rrBLUP, REML) and predict the test set.

    Builds a VanRaden genomic relationship matrix from the training
    genotypes, estimates variance components with mixed.solve, then
    predicts test-set breeding values plus the fixed intercept.

    Args:
        X_train: Training genotype matrix (samples x SNPs); the
            allele-frequency step assumes 0/1/2 coding -- TODO confirm.
        y_train: Training phenotype vector.
        X_test: Test genotype matrix with the same SNP columns.

    Returns:
        1-D numpy array of predicted phenotypes for X_test.
    """
    # Pass data to R (numpy2ri converts arrays to R matrices).
    ro.globalenv['X_train'] = X_train
    ro.globalenv['y_train'] = y_train
    ro.globalenv['X_test'] = X_test

    r_code = """
    library(rrBLUP)

    n_train <- nrow(X_train)
    m <- ncol(X_train)

    # Step1: allele frequencies
    p <- colMeans(X_train) / 2
    p <- pmax(pmin(p, 0.99), 0.01)

    # Step2: VanRaden standardized genotype
    Z_train <- sweep(X_train, 2, 2*p, "-") / sqrt(2*p*(1-p))
    Z_train[is.na(Z_train)] <- 0

    Z_test <- sweep(X_test, 2, 2*p, "-") / sqrt(2*p*(1-p))
    Z_test[is.na(Z_test)] <- 0

    # Step3: Genomic relationship matrix (VanRaden method 2)
    denom <- sum(2*p*(1-p))
    G <- Z_train %*% t(Z_train) / denom
    G <- G + diag(1e-6, n_train) # stability

    # Step4: REML GBLUP
    fit <- mixed.solve(y = y_train, K = G, SE = FALSE)

    # Extract variance components and fixed effect
    Vu <- fit$Vu
    Ve <- fit$Ve
    mu <- as.numeric(fit$beta) # <-- 转成标量,避免非兼容数组
    h2 <- Vu / (Vu + Ve)

    # Step5: GBLUP prediction for test set
    y_centered <- y_train - mu
    A <- G + (Ve / Vu) * diag(n_train) # G + λ I

    G_test_train <- Z_test %*% t(Z_train) / denom
    u_test <- G_test_train %*% solve(A, y_centered) # strictly correct formula

    y_pred <- mu + u_test
    y_pred
    """

    # The last R expression (y_pred) is the value returned by ro.r().
    y_pred = np.array(ro.r(r_code)).flatten()
    return y_pred
73
+
74
+
75
def parse_args():
    """Build and parse the command-line options for the GBLUP run."""
    parser = argparse.ArgumentParser(description="Argument parser")
    # (flag, default, help) for every supported string option.
    cli_options = [
        ('--methods', 'GBLUP_R/', 'Method name'),
        ('--species', '', None),
        ('--phe', '', 'Dataset name'),
        ('--data_dir', '../../data/', 'Path to data directory'),
        ('--result_dir', 'result/', 'Path to result directory'),
    ]
    for flag, default, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    return parser.parse_args()
84
+
85
+
86
def load_data(args):
    """Load genotype/phenotype arrays for the selected species.

    Returns:
        (genotypes, phenotypes, n_samples, n_snps, trait_names)
    """
    species_dir = os.path.join(args.data_dir, args.species)
    xData = np.load(os.path.join(species_dir, 'genotype.npz'))["arr_0"]
    phenotype_archive = np.load(os.path.join(species_dir, 'phenotype.npz'))
    yData = phenotype_archive["arr_0"]
    names = phenotype_archive["arr_1"]

    nsample, nsnp = xData.shape[0], xData.shape[1]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names
96
+
97
+
98
def set_seed(seed=42):
    """Seed python, numpy and torch RNGs for reproducible runs.

    Args:
        seed: Integer seed applied to every random source.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Bug fix: torch.manual_seed expects an int; the original wrapped the
    # seed in torch.tensor(seed) for no reason.
    torch.manual_seed(seed)
    # Only touch CUDA seeding when a GPU is present (matches the guarded
    # set_seed used elsewhere in this package).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
105
+
106
+
107
def run_nested_cv(args, data, label, process):
    """Run 10-fold cross-validation with the R GBLUP model.

    For every fold: fit/predict via gblup_r_vanraden_reml, record
    PCC/MAE/MSE/R2 plus time and memory, and dump the per-fold
    predictions to <result_dir>/fold<k>.csv. Prints a summary at the end.

    Args:
        args: Parsed CLI namespace (uses result_dir/methods/species/phe).
        data: Genotype matrix, one row per sample.
        label: Phenotype vector aligned with `data`.
        process: psutil.Process used to sample resident memory.
    """
    out_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(out_dir, exist_ok=True)
    print("Starting 10-fold cross-validation with GBLUP (R VanRaden)...")

    splitter = KFold(n_splits=10, shuffle=True, random_state=42)
    pcc_scores, mae_scores, mse_scores, r2_scores = [], [], [], []
    cv_start = time.time()

    for fold, (tr, te) in enumerate(splitter.split(label)):
        print(f"===== Fold {fold} =====")
        t0 = time.time()

        # Fit GBLUP on the training fold and predict the held-out fold.
        y_pred = gblup_r_vanraden_reml(data[tr], label[tr], data[te])
        y_true = label[te]

        # Per-fold evaluation metrics.
        pcc = pearsonr(y_true, y_pred)[0]
        mae = mean_absolute_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        pcc_scores.append(pcc)
        mae_scores.append(mae)
        mse_scores.append(mse)
        r2_scores.append(r2)

        elapsed = time.time() - t0
        gpu_mb = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
        cpu_mb = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={elapsed:.2f}s, '
              f'GPU={gpu_mb:.2f}MB, CPU={cpu_mb:.2f}MB')

        pd.DataFrame({'Y_test': y_true, 'Y_pred': y_pred}).to_csv(
            os.path.join(out_dir, f"fold{fold}.csv"), index=False)

    print("\n===== Cross-validation summary =====")
    print(f"Average PCC: {np.mean(pcc_scores):.4f} ± {np.std(pcc_scores):.4f}")
    print(f"Average MAE: {np.mean(mae_scores):.4f} ± {np.std(mae_scores):.4f}")
    print(f"Average MSE: {np.mean(mse_scores):.4f} ± {np.std(mse_scores):.4f}")
    print(f"Average R2 : {np.mean(r2_scores):.4f} ± {np.std(r2_scores):.4f}")
    print(f"Time: {time.time() - cv_start:.2f}s")
151
+
152
+
153
+
154
def GBLUP_reg():
    """Entry point: run GBLUP 10-fold CV for every species/trait.

    Loads genotype/phenotype data per species, imputes missing trait
    values with the trait mean, then delegates to run_nested_cv.
    """
    set_seed(42)
    # Bug fix: the CUDA memory calls crash on CPU-only machines, and the
    # GBLUP model itself runs in R, so the GPU is purely optional here.
    # (The original also created an unused hard-coded cuda:0 device.)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    args = parse_args()
    process = psutil.Process(os.getpid())
    all_species = ['Cotton/']

    for sp in all_species:
        args.species = sp
        X, Y, nsamples, nsnp, names = load_data(args)
        for i, phe in enumerate(names):
            args.phe = phe
            print("starting run " + args.methods + args.species + args.phe)
            # Impute missing phenotype values with the trait mean.
            label = Y[:, i]
            label = np.nan_to_num(label, nan=np.nanmean(label))
            start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            run_nested_cv(args, data=X, label=label, process=process)

            elapsed_time = time.time() - start_time
            print(f"运行时间: {elapsed_time:.2f} 秒")
            print("successfully")
179
+
180
+
181
# Allow running this module directly as a script.
if __name__ == "__main__":
    GBLUP_reg()
@@ -0,0 +1,5 @@
1
# Package API for the GBLUP method: re-export the regression entry point.
from .GBLUP_R import GBLUP_reg

# Alias so callers may import either name.
GBLUP = GBLUP_reg

__all__ = ["GBLUP","GBLUP_reg"]
@@ -0,0 +1,164 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import psutil
5
+ import time
6
+ import random
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pynvml
10
+ from . import GEFormer_Hyperparameters
11
+ from .gMLP import GEFormer
12
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
13
+ from torch.utils.data import DataLoader, TensorDataset
14
+ from sklearn.model_selection import KFold, train_test_split
15
+ from scipy.stats import pearsonr
16
+
17
def parse_args():
    """Parse the command-line options controlling a GEFormer run."""
    parser = argparse.ArgumentParser(description="Argument parser")
    # (flag, type, default, help) for every supported option; run/bookkeeping
    # options first, then the training hyper-parameters.
    specs = [
        ('--methods', str, 'GEFormer/', 'Random seed'),
        ('--species', str, 'Pig/', None),
        ('--phe', str, '', 'Dataset name'),
        ('--data_dir', str, '../../data/', None),
        ('--result_dir', str, 'result/', None),
        ('--epoch', int, 1000, 'Number of training rounds'),
        ('--batch_size', int, 64, 'Batch size'),
        ('--lr', float, 0.01, 'Learning rate'),
        ('--patience', int, 10, 'Patience for early stopping'),
        ('--dropout1', float, 0.5, 'Dropout rate for layer 1'),
        ('--dropout2', float, 0.5, 'Dropout rate for layer 2'),
    ]
    for flag, value_type, default, help_text in specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
33
+
34
def load_data(args):
    """Read the species' genotype and phenotype .npz archives.

    Returns:
        (genotypes, phenotypes, n_samples, n_snps, trait_names)
    """
    base = os.path.join(args.data_dir, args.species)
    xData = np.load(os.path.join(base, 'genotype.npz'))["arr_0"]
    phe_npz = np.load(os.path.join(base, 'phenotype.npz'))
    yData, names = phe_npz["arr_0"], phe_npz["arr_1"]

    nsample, nsnp = xData.shape[:2]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names
44
+
45
def set_seed(seed=42):
    """Seed python, numpy and torch RNGs for deterministic training.

    Args:
        seed: Integer seed applied to every random source.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Consistency fix: guard CUDA seeding like the set_seed in
    # GEFormer_Hyperparameters, so CPU-only hosts behave identically.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
52
+
53
def get_gpu_mem_by_pid(pid, handle=None):
    """Return the GPU memory (MB) used by `pid` on an NVML device.

    Args:
        pid: Process id to look up among the device's compute processes.
        handle: Optional NVML device handle; defaults to device 0.

    Returns:
        Used GPU memory in MB, or 0.0 when the process is not found.
    """
    # Bug fix: the original read a global `handle` that was never defined at
    # module level (it was only a local inside GEFormer_reg), so every call
    # raised NameError. Resolve the handle here instead.
    if handle is None:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    for p in procs:
        if p.pid == pid:
            return p.usedGpuMemory / 1024**2
    return 0.0
59
+
60
def run_nested_cv(args, data, label, nsnp, device):
    """Run 10-fold cross-validation of GEFormer on one phenotype.

    For every fold: split off 10% of the training fold for early stopping,
    train a fresh GEFormer, predict the held-out fold, and record
    MSE/MAE/R2/PCC plus time/memory usage. Per-fold predictions are
    written to <result_dir>/fold<k>.csv.

    Args:
        args: CLI namespace (uses result_dir/methods/species/phe plus
            batch_size, epoch, lr, patience).
        data: Genotype matrix (samples x SNPs) as a numpy array.
        label: Phenotype vector aligned with `data`.
        nsnp: SNP count forwarded to the GEFormer constructor.
        device: torch.device the tensors and model are placed on.
    """
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)
    print("Starting 10-fold cross-validation...")
    kf = KFold(n_splits=10, shuffle=True, random_state=42)

    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
    time_star = time.time()
    for fold, (train_idx, test_idx) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        x_train, x_test = data[train_idx], data[test_idx]
        y_train, y_test = label[train_idx], label[test_idx]

        # Hold out 10% of the training fold for early stopping.
        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(x_train, y_train, test_size=0.1, random_state=42)

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
        x_test_tensor = torch.from_numpy(x_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).float().to(device)

        train_data = TensorDataset(x_train_tensor, y_train_tensor)
        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
        test_data = TensorDataset(x_test_tensor, y_test_tensor)

        train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
        valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, args.batch_size, shuffle=False)

        model = GEFormer(nsnp=nsnp).to(device)
        model.train_model(train_loader, valid_loader, args.epoch, args.lr, args.patience, device)
        y_pred = model.predict(test_loader)

        # Bug fix: this metric block was duplicated verbatim in the
        # original; computing each metric once is sufficient.
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        pcc, _ = pearsonr(y_test, y_pred)

        all_mse.append(mse)
        all_r2.append(r2)
        all_mae.append(mae)
        all_pcc.append(pcc)

        fold_time = time.time() - fold_start_time
        fold_gpu_mem = get_gpu_mem_by_pid(os.getpid())
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
              f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')

        # Bug fix: these CUDA memory calls raise on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)

    print("\n===== Cross-validation summary =====")
    print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
    print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
    print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
    print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
    print(f"Time: {time.time() - time_star:.2f}s")
129
+
130
+
131
def GEFormer_reg():
    """Entry point: tune then run GEFormer 10-fold CV per species/trait.

    For each trait: impute missing phenotype values with the trait mean,
    tune lr/batch_size/patience with Optuna, then run the full CV with
    the tuned settings.
    """
    set_seed(42)
    # Bug fix: unguarded CUDA calls crash on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # NVML is initialized once so get_gpu_mem_by_pid can query device 0.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    args = parse_args()
    all_species = ['Cotton/']

    for i in range(len(all_species)):
        args.species = all_species[i]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        args.device = device
        X, Y, nsamples, nsnp, names = load_data(args)
        for j in range(len(names)):
            args.phe = names[j]
            print("starting run " + args.methods + args.species + args.phe)
            # Impute missing phenotype values with the trait mean.
            label = Y[:, j]
            label = np.nan_to_num(label, nan=np.nanmean(label))
            best_params = GEFormer_Hyperparameters.Hyperparameter(X, label, nsnp)
            # Bug fix: the tuned learning rate was stored in
            # args.learning_rate, but run_nested_cv reads args.lr -- so the
            # tuned value was silently ignored.
            args.lr = best_params['learning_rate']
            args.batch_size = best_params['batch_size']
            args.patience = best_params['patience']
            start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            process = psutil.Process(os.getpid())
            run_nested_cv(args, data=X, label=label, nsnp=nsnp, device=args.device)

            elapsed_time = time.time() - start_time
            print(f"running time: {elapsed_time:.2f} s")
            print("successfully")
161
+
162
+
163
# Allow running this module directly as a script.
if __name__ == '__main__':
    GEFormer_reg()
@@ -0,0 +1,106 @@
1
+ import os
2
+ import time
3
+ import psutil
4
+ import random
5
+ import torch
6
+ import numpy as np
7
+ import optuna
8
+ from sklearn.model_selection import KFold, train_test_split
9
+ from .gMLP import GEFormer
10
+ from scipy.stats import pearsonr
11
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
12
+ from torch.utils.data import DataLoader, TensorDataset
13
+ from optuna.exceptions import TrialPruned
14
+
15
def run_nested_cv_with_early_stopping(data, label, nsnp, learning_rate, patience, batch_size, epoch=1000):
    """Optuna objective helper: 10-fold CV of GEFormer, returning mean PCC.

    Args:
        data: Genotype matrix (samples x SNPs).
        label: Phenotype vector aligned with `data`.
        nsnp: SNP count forwarded to the GEFormer constructor.
        learning_rate: Learning rate under evaluation.
        patience: Early-stopping patience under evaluation.
        batch_size: Mini-batch size under evaluation.
        epoch: Maximum number of training epochs per fold.

    Returns:
        Mean Pearson correlation across folds (0.0 if no folds ran).

    Raises:
        TrialPruned: when any fold produces a NaN correlation.
    """
    # Bug fix: hard-coding cuda:0 crashed tuning on machines without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Starting 10-fold cross-validation...")
    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    all_mse, all_mae, all_r2, all_pcc = [], [], [], []

    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        # Hold out 10% of the training fold for early stopping.
        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).float().to(device)
        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).float().to(device)
        x_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).float().to(device)

        train_data = TensorDataset(x_train_tensor, y_train_tensor)
        valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
        test_data = TensorDataset(x_test_tensor, y_test_tensor)

        train_loader = DataLoader(train_data, batch_size, shuffle=True)
        valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size, shuffle=False)

        # Bug fix: the model was never moved to `device`, so training failed
        # with a CPU/GPU tensor mismatch whenever the data was on CUDA.
        model = GEFormer(nsnp=nsnp).to(device)
        model.train_model(train_loader, valid_loader, epoch, learning_rate, patience, device)
        y_pred = model.predict(test_loader)

        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        pcc, _ = pearsonr(y_test, y_pred)

        # A NaN correlation means this hyper-parameter set diverged.
        if np.isnan(pcc):
            print(f"Fold {fold} resulted in NaN PCC, pruning the trial...")
            raise TrialPruned()

        all_mse.append(mse)
        all_r2.append(r2)
        all_mae.append(mae)
        all_pcc.append(pcc)

        fold_time = time.time() - fold_start_time
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
              f'CPU={fold_cpu_mem:.2f}MB')

    return np.mean(all_pcc) if all_pcc else 0.0
70
+
71
def set_seed(seed=42):
    """Make python, numpy and torch randomness reproducible for tuning."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
79
+
80
def Hyperparameter(data, label, nsnp):
    """Tune GEFormer's lr / batch size / patience with Optuna (20 trials).

    Args:
        data: Genotype matrix (samples x SNPs).
        label: Phenotype vector aligned with `data`.
        nsnp: SNP count forwarded to the model constructor.

    Returns:
        dict with keys 'learning_rate', 'batch_size', 'patience'.
    """
    set_seed(42)

    def objective(trial):
        # Fix: suggest_loguniform is deprecated in Optuna; suggest_float
        # with log=True is the supported equivalent.
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        patience = trial.suggest_int("patience", 1, 10)
        try:
            corr_score = run_nested_cv_with_early_stopping(
                data=data,
                label=label,
                nsnp=nsnp,
                learning_rate=learning_rate,
                patience=patience,
                batch_size=batch_size
            )
        except TrialPruned:
            # NOTE(review): returning -inf records the trial as COMPLETE
            # rather than PRUNED; kept as-is to preserve existing behavior.
            return float("-inf")
        return corr_score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

    print("best params:", study.best_params)
    print("successfully")
    return study.best_params
@@ -0,0 +1,5 @@
1
# Package API for the GEFormer method: re-export the regression entry point.
from .GEFormer import GEFormer_reg

# Alias so callers may import either name.
GEFormer = GEFormer_reg

__all__ = ["GEFormer","GEFormer_reg"]