gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
gpbench/method_reg/ElasticNet/ElasticNet_he.py
@@ -0,0 +1,83 @@
+ import gc
+ import random
+ import time
+ import numpy as np
+ import optuna
+ from sklearn.model_selection import KFold
+ from sklearn.linear_model import ElasticNet
+ from scipy.stats import pearsonr
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+ from optuna.exceptions import TrialPruned
+
+ def run_nested_cv_with_early_stopping(data, label, outer_cv, alpha, l1_ratio):
+     best_corr_coefs = []
+     best_maes = []
+     best_r2s = []
+     best_mses = []
+     time_start = time.time()
+
+     for fold, (train_idx, test_idx) in enumerate(outer_cv.split(data)):
+         x_train = data[train_idx]
+         x_test = data[test_idx]
+         y_train = label[train_idx]
+         y_test = label[test_idx]
+
+         model = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, max_iter=1000, random_state=42)
+         model.fit(x_train, y_train)
+         y_test_preds = model.predict(x_test)
+
+         pcc, _ = pearsonr(y_test, y_test_preds)
+         mse = mean_squared_error(y_test, y_test_preds)
+         r2 = r2_score(y_test, y_test_preds)
+         mae = mean_absolute_error(y_test, y_test_preds)
+
+         best_corr_coefs.append(pcc)
+         best_maes.append(mae)
+         best_r2s.append(r2)
+         best_mses.append(mse)
+
+         print(f'Fold {fold + 1}: MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Corr={pcc:.4f}')
+         del model, y_test_preds, x_train, x_test, y_train, y_test
+
+     print("==== Final Results ====")
+     print(f"MAE: {np.mean(best_maes):.4f} ± {np.std(best_maes):.4f}")
+     print(f"MSE: {np.mean(best_mses):.4f} ± {np.std(best_mses):.4f}")
+     print(f"R2 : {np.mean(best_r2s):.4f} ± {np.std(best_r2s):.4f}")
+     print(f"Corr: {np.mean(best_corr_coefs):.4f} ± {np.std(best_corr_coefs):.4f}")
+
+     print(f"Time: {time.time() - time_start:.2f}s")
+     gc.collect()
+
+     return np.mean(best_corr_coefs)
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+
+ def Hyperparameter(data, label):
+     set_seed(42)
+
+     def objective(trial):
+         alpha = trial.suggest_float("alpha", 1e-4, 1.0, log=True)
+         l1_ratio = trial.suggest_categorical("l1_ratio", [0.1, 0.3, 0.5, 0.7, 0.9])
+
+         outer_cv = KFold(n_splits=10, shuffle=True, random_state=42)
+
+         try:
+             corr_score = run_nested_cv_with_early_stopping(
+                 data=data,
+                 label=label,
+                 outer_cv=outer_cv,
+                 alpha=alpha,
+                 l1_ratio=l1_ratio
+             )
+         except TrialPruned:
+             return float("-inf")
+         return corr_score
+
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+
+     print("best params:", study.best_params)
+     print("hyperparameter search finished successfully")
+     return study.best_params
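For orientation (not part of the diff), a minimal sketch of how this tuner could be driven, assuming the module name ElasticNet_he from the file list above. The synthetic arrays and their shapes are placeholders; Hyperparameter only assumes data and label are NumPy arrays indexable by KFold splits:

    import numpy as np
    from gpbench.method_reg.ElasticNet import ElasticNet_he

    rng = np.random.default_rng(42)
    geno = rng.integers(0, 3, size=(120, 500)).astype(np.float64)   # samples x SNPs
    pheno = rng.normal(size=120)
    best = ElasticNet_he.Hyperparameter(geno, pheno)   # 20 Optuna trials of 10-fold CV
    print(best)   # e.g. {'alpha': ..., 'l1_ratio': ...}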
gpbench/method_reg/ElasticNet/__init__.py
@@ -0,0 +1,5 @@
+ from .ElasticNet import ElasticNet_reg
+
+ ElasticNet = ElasticNet_reg
+
+ __all__ = ["ElasticNet", "ElasticNet_reg"]
gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py
@@ -0,0 +1,107 @@
+ import os
+ import time
+ import psutil
+ import random
+ import torch
+ import numpy as np
+ import optuna
+ from sklearn.model_selection import KFold, train_test_split
+ from .base_G2PDeep import G2PDeep, ModelHyperparams
+ from scipy.stats import pearsonr
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+ from torch.utils.data import DataLoader, TensorDataset
+ from optuna.exceptions import TrialPruned
+
+ def run_nested_cv_with_early_stopping(data, label, nsnp, learning_rate, batch_size, patience, epoch=1000):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     print("Starting 10-fold cross-validation...")
+     kf = KFold(n_splits=10, shuffle=True, random_state=42)
+     all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+
+     for fold, (train_index, test_index) in enumerate(kf.split(data)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+         x_train_tensor = torch.from_numpy(X_train_sub).float()
+         y_train_tensor = torch.from_numpy(y_train_sub).float()
+         x_valid_tensor = torch.from_numpy(X_valid).float()
+         y_valid_tensor = torch.from_numpy(y_valid).float()
+         x_test_tensor = torch.from_numpy(X_test).float()
+         y_test_tensor = torch.from_numpy(y_test).float()
+
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+         train_loader = DataLoader(train_data, batch_size, shuffle=True)
+         valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
+         test_loader = DataLoader(test_data, batch_size, shuffle=False)
+         hp = ModelHyperparams()
+         model = G2PDeep(nsnp=nsnp, hyperparams=hp)
+         model.train_model(train_loader, valid_loader, epoch, learning_rate, patience, device)
+         y_pred = model.predict(test_loader, device)
+
+         mse = mean_squared_error(y_test, y_pred)
+         r2 = r2_score(y_test, y_pred)
+         mae = mean_absolute_error(y_test, y_pred)
+         pcc, _ = pearsonr(y_test, y_pred)
+
+         if np.isnan(pcc):
+             print(f"Fold {fold} resulted in NaN PCC, pruning the trial...")
+             raise TrialPruned()
+
+         all_mse.append(mse)
+         all_r2.append(r2)
+         all_mae.append(mae)
+         all_pcc.append(pcc)
+
+         fold_time = time.time() - fold_start_time
+         fold_gpu_mem = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+         print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+               f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+     return np.mean(all_pcc) if all_pcc else 0.0
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+ def Hyperparameter(data, label, nsnp):
+     set_seed(42)
+
+     def objective(trial):
+         lr = trial.suggest_float("learning_rate", 1e-4, 0.1)
+         batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+         patience = trial.suggest_int("patience", 1, 10)
+         try:
+             corr_score = run_nested_cv_with_early_stopping(
+                 data=data,
+                 label=label,
+                 nsnp=nsnp,
+                 learning_rate=lr,
+                 batch_size=batch_size,
+                 patience=patience
+             )
+
+         except TrialPruned:
+             return float("-inf")
+         return corr_score
+
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+
+     print("best params:", study.best_params)
+     print("hyperparameter search finished successfully")
+     return study.best_params
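For orientation (not part of the diff), a minimal driver for the tuner above. The array shape follows what the model's forward pass expects (samples × SNPs × 4 one-hot channels); the tiny sizes are placeholders, and since each trial trains up to 1000 epochs per fold, even this toy run is slow. It illustrates the call signature, not a benchmark:

    import numpy as np
    from gpbench.method_reg.G2PDeep import G2PDeep_Hyperparameters

    rng = np.random.default_rng(0)
    geno = np.eye(4, dtype=np.float32)[rng.integers(0, 4, size=(50, 100))]  # (50, 100, 4)
    pheno = rng.normal(size=50).astype(np.float32)
    # launches 20 Optuna trials, each running 10-fold CV with early stopping
    best = G2PDeep_Hyperparameters.Hyperparameter(geno, pheno, nsnp=100)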
gpbench/method_reg/G2PDeep/G2Pdeep.py
@@ -0,0 +1,166 @@
+ import os
+ import time
+ import psutil
+ import swanlab
+ import argparse
+ import random
+ import torch
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import KFold, train_test_split
+ from .base_G2PDeep import G2PDeep, ModelHyperparams
+ from scipy.stats import pearsonr
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+ from torch.utils.data import DataLoader, TensorDataset
+ from . import G2PDeep_Hyperparameters
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Argument parser")
+     parser.add_argument('--methods', type=str, default='G2PDeep/', help='Method name used in result paths')
+     parser.add_argument('--species', type=str, default='')
+     parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+     parser.add_argument('--data_dir', type=str, default='../../data/')
+     parser.add_argument('--result_dir', type=str, default='result/')
+
+     parser.add_argument('--epoch', type=int, default=1000, help='Number of training rounds')
+     parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
+     parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
+     parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping')
+     args = parser.parse_args()
+     return args
+
+ def process_snp_data(data: np.ndarray) -> np.ndarray:
+     nb_classes = 4
+     onehot_x = np.empty(
+         shape=(data.shape[0], data.shape[1], nb_classes),
+         dtype=np.float32
+     )
+
+     for i in range(data.shape[0]):
+         _data = pd.to_numeric(data[i], errors='coerce')
+         _targets = np.array(_data).reshape(-1).astype(np.int64)
+         onehot_x[i] = np.eye(nb_classes)[_targets]
+
+     return onehot_x
+
+
+ def load_data(args):
+     xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+
+     xData[xData == -9] = 0
+     xData = process_snp_data(xData)
+     nsample = xData.shape[0]
+     nsnp = xData.shape[1]
+     print("Number of samples: ", nsample)
+     print("Number of SNPs: ", nsnp)
+     return xData, yData, nsample, nsnp, names
+
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def run_nested_cv(args, data, label, nsnp, device):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+     print("Starting 10-fold cross-validation...")
+     kf = KFold(n_splits=10, shuffle=True, random_state=42)
+
+     all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+     time_start = time.time()
+     for fold, (train_index, test_index) in enumerate(kf.split(data)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
+
+         x_train_tensor = torch.from_numpy(X_train_sub).float()
+         y_train_tensor = torch.from_numpy(y_train_sub).float()
+         x_valid_tensor = torch.from_numpy(X_valid).float()
+         y_valid_tensor = torch.from_numpy(y_valid).float()
+         x_test_tensor = torch.from_numpy(X_test).float()
+         y_test_tensor = torch.from_numpy(y_test).float()
+
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+         train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
+         valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
+         test_loader = DataLoader(test_data, args.batch_size, shuffle=False)
+         hp = ModelHyperparams()
+         model = G2PDeep(nsnp=nsnp, hyperparams=hp).to(device)
+         model.train_model(train_loader, valid_loader, args.epoch, args.lr, args.patience, device)
+         y_pred = model.predict(test_loader, device)
+
+         mse = mean_squared_error(y_test, y_pred)
+         r2 = r2_score(y_test, y_pred)
+         mae = mean_absolute_error(y_test, y_pred)
+         pcc, _ = pearsonr(y_test, y_pred)
+
+         all_mse.append(mse)
+         all_r2.append(r2)
+         all_mae.append(mae)
+         all_pcc.append(pcc)
+
+         fold_time = time.time() - fold_start_time
+         fold_gpu_mem = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+         print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, Time={fold_time:.2f}s, '
+               f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.reset_peak_memory_stats()
+         results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
+         results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+     print("\n===== Cross-validation summary =====")
+     print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
+     print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
+     print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
+     print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
+     print(f"Time: {time.time() - time_start:.2f}s")
+
+
+ def G2PDeep_reg():
+     set_seed(42)
+     torch.cuda.empty_cache()
+     args = parse_args()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     all_species = ['Cotton/']
+
+     for i in range(len(all_species)):
+         args.species = all_species[i]
+         args.device = device
+         X, Y, nsamples, nsnp, names = load_data(args)
+         for j in range(len(names)):
+             args.phe = names[j]
+             print("starting run " + args.methods + args.species + args.phe)
+             label = Y[:, j]
+             label = np.nan_to_num(label, nan=np.nanmean(label))
+             best_params = G2PDeep_Hyperparameters.Hyperparameter(X, label, nsnp)
+             args.lr = best_params['learning_rate']
+             args.patience = best_params['patience']
+             args.batch_size = best_params['batch_size']
+             start_time = time.time()
+             if torch.cuda.is_available():
+                 torch.cuda.reset_peak_memory_stats()
+             process = psutil.Process(os.getpid())
+             run_nested_cv(args, data=X, label=label, nsnp=nsnp, device=args.device)
+
+             elapsed_time = time.time() - start_time
+             print(f"running time: {elapsed_time:.2f} s")
+     print("all runs completed successfully")
+
+
+ if __name__ == "__main__":
+     G2PDeep_reg()
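load_data above expects, under --data_dir, one folder per species holding genotype.npz (arr_0: samples × SNPs, numeric codes with -9 for missing) and phenotype.npz (arr_0: samples × traits, arr_1: trait names). A sketch that writes a matching synthetic layout (the sizes and trait names are made up; np.savez names positional arrays arr_0, arr_1):

    import os
    import numpy as np

    root = os.path.join("../../data", "Cotton")   # mirrors the argparse defaults
    os.makedirs(root, exist_ok=True)
    geno = np.random.randint(0, 4, size=(80, 300))       # arr_0 in genotype.npz
    pheno = np.random.normal(size=(80, 2))               # arr_0 in phenotype.npz
    names = np.array(["trait_a", "trait_b"])             # arr_1 in phenotype.npz
    np.savez(os.path.join(root, "genotype.npz"), geno)
    np.savez(os.path.join(root, "phenotype.npz"), pheno, names)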
gpbench/method_reg/G2PDeep/__init__.py
@@ -0,0 +1,5 @@
+ from .G2Pdeep import G2PDeep_reg
+
+ G2PDeep = G2PDeep_reg
+
+ __all__ = ["G2PDeep", "G2PDeep_reg"]
gpbench/method_reg/G2PDeep/base_G2PDeep.py
@@ -0,0 +1,209 @@
+ import torch
+ import torch.nn as nn
+ from typing import List, Optional
+ import numpy as np
+
+
+ class ModelHyperparams:
+     def __init__(self,
+                  left_tower_filters_list: Optional[List[int]] = None,
+                  left_tower_kernel_size_list: Optional[List[int]] = None,
+                  right_tower_filters_list: Optional[List[int]] = None,
+                  right_tower_kernel_size_list: Optional[List[int]] = None,
+                  central_tower_filters_list: Optional[List[int]] = None,
+                  central_tower_kernel_size_list: Optional[List[int]] = None,
+                  dnn_size_list: Optional[List[int]] = None,
+                  activation: str = "linear",
+                  dropout_rate: float = 0.75):  # reduced
+         self.left_tower_filters_list = left_tower_filters_list or [4, 4]
+         self.left_tower_kernel_size_list = left_tower_kernel_size_list or [3, 5]
+         self.right_tower_filters_list = right_tower_filters_list or [4]
+         self.right_tower_kernel_size_list = right_tower_kernel_size_list or [3]
+         self.central_tower_filters_list = central_tower_filters_list or [4]
+         self.central_tower_kernel_size_list = central_tower_kernel_size_list or [3]
+         self.dnn_size_list = dnn_size_list or [1]
+         self.activation = activation
+         self.dropout_rate = dropout_rate
+
+ def get_activation(name: str):
+     if name.lower() == "relu":
+         return nn.ReLU()
+     elif name.lower() == "linear":
+         return nn.Identity()
+     else:
+         raise ValueError(f"Unsupported activation: {name}")
+
+
+ class G2PDeep(nn.Module):
+     def __init__(self, nsnp: int, hyperparams: ModelHyperparams):
+         super().__init__()
+         self.nsnp = nsnp
+         hp = hyperparams
+
+         # --- Left Tower ---
+         self.left_convs = nn.ModuleList()
+         in_ch = 4
+         for filt, k in zip(hp.left_tower_filters_list, hp.left_tower_kernel_size_list):
+             self.left_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
+             in_ch = filt
+
+         # --- Right Tower ---
+         self.right_convs = nn.ModuleList()
+         in_ch = 4
+         for filt, k in zip(hp.right_tower_filters_list, hp.right_tower_kernel_size_list):
+             self.right_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
+             in_ch = filt
+
+         # --- Channel alignment ---
+         left_out_ch = hp.left_tower_filters_list[-1]
+         right_out_ch = hp.right_tower_filters_list[-1]
+         self.merged_ch = max(left_out_ch, right_out_ch)
+
+         self.left_proj = nn.Conv1d(left_out_ch, self.merged_ch, 1) \
+             if left_out_ch != self.merged_ch else nn.Identity()
+         self.right_proj = nn.Conv1d(right_out_ch, self.merged_ch, 1) \
+             if right_out_ch != self.merged_ch else nn.Identity()
+
+         # --- Central Tower ---
+         self.central_convs = nn.ModuleList()
+         in_ch = self.merged_ch
+         for filt, k in zip(hp.central_tower_filters_list, hp.central_tower_kernel_size_list):
+             self.central_convs.append(nn.Conv1d(in_ch, filt, k, padding="same"))
+             in_ch = filt
+
+         # --- DNN ---
+         self.dropout = nn.Dropout(p=hp.dropout_rate)
+         final_conv_ch = hp.central_tower_filters_list[-1]
+         flattened_dim = final_conv_ch * nsnp
+
+         dnn_layers = []
+         prev = flattened_dim
+         for out_sz in hp.dnn_size_list[:-1]:
+             dnn_layers.append(nn.Linear(prev, out_sz))
+             dnn_layers.append(get_activation(hp.activation))
+             dnn_layers.append(nn.Dropout(hp.dropout_rate))
+             prev = out_sz
+         dnn_layers.append(nn.Linear(prev, hp.dnn_size_list[-1]))
+         self.dnn = nn.Sequential(*dnn_layers)
+
+         self.activation = get_activation(hp.activation)
+
+     def forward(self, x):
+         # (B, Seq, 4) -> (B, 4, Seq)
+         if x.shape[-1] != 4:
+             raise ValueError(f"Expected input with 4 channels, got {x.shape}")
+
+         x = x.transpose(1, 2)
+
+         # Left tower
+         left = x
+         for conv in self.left_convs:
+             left = self.activation(conv(left))
+
+         # Right tower
+         right = x
+         for conv in self.right_convs:
+             right = self.activation(conv(right))
+
+         merged = self.left_proj(left) + self.right_proj(right)
+
+         # Central tower
+         for conv in self.central_convs:
+             merged = self.activation(conv(merged))
+
+         x_flat = torch.flatten(merged, 1)
+         x_flat = self.dropout(x_flat)
+         return self.dnn(x_flat)
+
+     def train_model(self, train_loader, valid_loader, num_epochs, learning_rate, patience, device):
+         optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-4)
+         criterion = nn.MSELoss()
+         self.to(device)
+
+         # Enable mixed-precision training on CUDA devices
+         use_amp = device.type == 'cuda'
+         scaler = torch.amp.GradScaler('cuda') if use_amp else None
+
+         best_loss = float('inf')
+         best_state = None
+         trigger_times = 0
+
+         for epoch in range(num_epochs):
+             # Training
+             self.train()
+             train_loss = 0.0
+             for inputs, labels in train_loader:
+                 inputs = inputs.to(device, non_blocking=True)
+                 labels = labels.to(device, non_blocking=True).unsqueeze(1)
+
+                 optimizer.zero_grad()
+
+                 if use_amp:
+                     with torch.amp.autocast('cuda'):
+                         outputs = self(inputs)
+                         loss = criterion(outputs, labels)
+                     scaler.scale(loss).backward()
+                     scaler.step(optimizer)
+                     scaler.update()
+                 else:
+                     outputs = self(inputs)
+                     loss = criterion(outputs, labels)
+                     loss.backward()
+                     optimizer.step()
+
+                 train_loss += loss.item() * inputs.size(0)
+             train_loss /= len(train_loader.dataset)
+
+             # Validation
+             self.eval()
+             valid_loss = 0.0
+             with torch.no_grad():
+                 for inputs, labels in valid_loader:
+                     inputs = inputs.to(device, non_blocking=True)
+                     labels = labels.to(device, non_blocking=True).unsqueeze(1)
+
+                     if use_amp:
+                         with torch.amp.autocast('cuda'):
+                             outputs = self(inputs)
+                             loss = criterion(outputs, labels)
+                     else:
+                         outputs = self(inputs)
+                         loss = criterion(outputs, labels)
+
+                     valid_loss += loss.item() * inputs.size(0)
+             valid_loss /= len(valid_loader.dataset)
+
+             # Early stopping
+             if valid_loss < best_loss:
+                 best_loss = valid_loss
+                 best_state = {k: v.cpu().clone() for k, v in self.state_dict().items()}
+                 trigger_times = 0
+             else:
+                 trigger_times += 1
+                 if trigger_times >= patience:
+                     print(f"Early stopping at epoch {epoch+1}")
+                     break
+
+         if best_state is not None:
+             cur_device = next(self.parameters()).device
+             best_state = {k: v.to(cur_device) for k, v in best_state.items()}
+             self.load_state_dict(best_state)
+         return best_loss
+
+     def predict(self, test_loader, device):
+         self.eval()
+         self.to(device)
+         y_pred_list = []
+         use_amp = device.type == 'cuda'
+         with torch.no_grad():
+             for inputs, _ in test_loader:
+                 inputs = inputs.to(device, non_blocking=True)
+                 if use_amp:
+                     with torch.amp.autocast('cuda'):
+                         outputs = self(inputs)
+                 else:
+                     outputs = self(inputs)
+                 y_pred_list.append(outputs.cpu())
+         y_pred = torch.cat(y_pred_list, dim=0).numpy()  # convert once at the end
+         y_pred = np.squeeze(y_pred)
+         return y_pred
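A quick shape check of the default configuration (not part of the diff; the sizes are arbitrary): forward expects one-hot input of shape (batch, nsnp, 4) and, with the default dnn_size_list of [1], returns (batch, 1):

    import torch
    from gpbench.method_reg.G2PDeep.base_G2PDeep import G2PDeep, ModelHyperparams

    model = G2PDeep(nsnp=100, hyperparams=ModelHyperparams())
    model.eval()                       # disable the large (p=0.75) dropout
    x = torch.randn(8, 100, 4)         # batch of 8, 100 SNPs, 4 one-hot channels
    with torch.no_grad():
        out = model(x)
    print(out.shape)                   # torch.Size([8, 1])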