gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,84 @@
1
+ import abc
2
+ import joblib
3
+ import numpy as np
4
+ import pathlib
5
+
6
+
7
class ParamFreeBaseModel(abc.ABC):
    """
    BaseModel parent class for all models that do not have hyperparameters, e.g. BLUP.

    Every model must be based on :obj:`~easypheno.model.param_free_base_model.ParamFreeBaseModel` directly or ParamFreeBaseModel's child classes.

    Please add ``super().__init__(PARAMS)`` to the constructor in case you override it in a child class

    **Attributes**

    *Class attributes*

    - standard_encoding (*str*): the standard encoding for this model
    - possible_encodings (*List<str>*): a list of all encodings that are possible according to the model definition

    *Instance attributes*

    - task (*str*): ML task ('regression' or 'classification') depending on target variable
    - encoding (*str*): the encoding to use (standard encoding or user-defined)


    :param task: ML task (regression or classification) depending on target variable
    :param encoding: the encoding to use (standard encoding or user-defined)

    """

    # Class attributes #
    # Declared as abstract properties; concrete subclasses normally override
    # them with plain class attributes (e.g. ``standard_encoding = '012'``),
    # which satisfies the abstractmethod check.
    # FIX: the original stacked ``@property`` on top of ``@classmethod``;
    # chaining classmethod with other descriptors is deprecated since
    # Python 3.11 and unsupported from 3.13, and the stacked form never
    # yielded a usable class-level property in the first place.
    @property
    @abc.abstractmethod
    def standard_encoding(self):
        """the standard encoding for this model"""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def possible_encodings(self):
        """a list of all encodings that are possible according to the model definition"""
        raise NotImplementedError

    # Constructor super class #
    def __init__(self, task: str, encoding: str = None):
        self.task = task
        # Fall back to the model's standard encoding when the caller gives none.
        self.encoding = self.standard_encoding if encoding is None else encoding

    # Methods required by each child class #

    @abc.abstractmethod
    def fit(self, X: np.array, y: np.array) -> np.array:
        """
        Method that fits the model based on features X and targets y

        :param X: feature matrix for retraining
        :param y: target vector

        :return: numpy array with values predicted for X
        """

    @abc.abstractmethod
    def predict(self, X_in: np.array) -> np.array:
        """
        Method that predicts target values based on the input X_in

        :param X_in: feature matrix as input

        :return: numpy array with the predicted values
        """

    def save_model(self, path: pathlib.Path, filename: str):
        """
        Persist the whole model object on a hard drive
        (can be loaded with :obj:`~easypheno.model._model_functions.load_model`)

        :param path: path where the model will be saved
        :param filename: filename of the model
        """
        joblib.dump(self, path.joinpath(filename), compress=3)
@@ -0,0 +1,16 @@
1
+ from . import _bayesfromR
2
+
3
+
4
class BayesC(_bayesfromR.Bayes_R):
    """
    Bayes C genomic prediction model.

    Thin wrapper around the shared R-backed Bayes implementation that fixes
    ``model_name`` to ``'BayesC'``.

    *Attributes*

    *Inherited attributes*

    See :obj:`~easypheno.model._bayesfromR.Bayes_R` for more information on the attributes.
    """

    def __init__(self, task: str, encoding: str = None):
        # Everything is delegated to the parent; only the prior selection
        # (``model_name``) distinguishes Bayes C from its siblings.
        super().__init__(model_name='BayesC', task=task, encoding=encoding)
@@ -0,0 +1,186 @@
1
+ import os
2
+ import time
3
+ import psutil
4
+ import swanlab
5
+ import argparse
6
+ import random
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.model_selection import StratifiedKFold, train_test_split
11
+ from sklearn.preprocessing import LabelEncoder
12
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
13
+ from torch.utils.data import DataLoader, TensorDataset
14
+ from .base_CropARNet_class import SimpleSNPModel
15
+ from . import CropARNet_he_class
16
+ import pynvml
17
+
18
def parse_args(argv=None):
    """Parse command-line options for the CropARNet classification runner.

    :param argv: optional list of argument strings; defaults to
        ``sys.argv[1:]``. Accepting an explicit list is backward compatible
        and makes the parser testable / embeddable.
    :return: populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser(description="Argument parser")
    # Paths and run bookkeeping.
    parser.add_argument('--methods', type=str, default='CropARNet/')
    parser.add_argument('--species', type=str, default='')
    parser.add_argument('--phe', type=str, default='')
    parser.add_argument('--data_dir', type=str, default='../../data/')
    parser.add_argument('--result_dir', type=str, default='result/')

    # Training hyperparameters (several are later overwritten by the
    # Optuna search in the driver).
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--momentum', type=float, default=0.5)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--patience', type=int, default=50)
    return parser.parse_args(argv)
33
+
34
+ def load_data(args):
35
+ xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
36
+ yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
37
+ names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
38
+
39
+ print("Number of samples:", xData.shape[0])
40
+ print("Number of SNPs:", xData.shape[1])
41
+ return xData, yData, xData.shape[0], xData.shape[1], names
42
+
43
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    FIX: guard the CUDA seeding with ``torch.cuda.is_available()`` so the
    runner also works on CPU-only machines, matching the tuning module's
    ``set_seed``. The cuDNN flags are harmless no-ops without CUDA.

    :param seed: seed value applied to every RNG
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade speed for determinism in cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
50
+
51
def get_gpu_mem_by_pid(pid, handle):
    """Return the GPU memory (MiB) that process *pid* uses on *handle*'s device.

    Scans NVML's compute-process list; returns ``0.0`` when the pid is not
    found on that device.
    """
    match = next(
        (p for p in pynvml.nvmlDeviceGetComputeRunningProcesses(handle) if p.pid == pid),
        None,
    )
    return match.usedGpuMemory / 1024 ** 2 if match is not None else 0.0
57
+
58
def run_nested_cv(args, data, label, nsnp, num_classes, device, handle=None, le=None):
    """Run 10-fold stratified cross-validation for the CropARNet classifier.

    For every fold a fresh model is trained (with an inner 90/10 train/valid
    split for early stopping), evaluated on the held-out fold, per-sample
    predictions are written to ``fold{k}.csv``, and macro-averaged
    accuracy/precision/recall/F1 are accumulated and summarised at the end.

    :param args: parsed CLI namespace (uses result_dir/methods/species/phe,
        batch_size, epochs, learning_rate, weight_decay, patience)
    :param data: genotype matrix (samples x SNPs), numpy array
    :param label: integer-encoded class labels, one per sample
    :param nsnp: number of SNP features (model input width)
    :param num_classes: number of distinct classes
    :param device: torch device to place tensors/model on
    :param handle: optional pynvml device handle for per-fold GPU memory reporting
    :param le: fitted ``LabelEncoder`` used to map predictions back to the
        original label values for the CSV output.
        NOTE(review): the ``None`` default would crash at
        ``le.inverse_transform`` — confirm every caller passes it.
    """
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)

    # Stratified so each fold preserves the class distribution.
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

    all_acc, all_prec, all_rec, all_f1 = [], [], [], []
    time_star = time.time()

    for fold, (train_index, test_index) in enumerate(kf.split(data, label)):
        print(f"\nRunning fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        # Inner split: 10% of the training fold becomes the early-stopping
        # validation set, stratified like the outer split.
        X_tr, X_val, y_tr, y_val = train_test_split(
            X_train, y_train, test_size=0.1, random_state=42, stratify=y_train
        )

        x_train = torch.from_numpy(X_tr).float().to(device)
        y_train_t = torch.from_numpy(y_tr).long().to(device)
        x_valid = torch.from_numpy(X_val).float().to(device)
        y_valid_t = torch.from_numpy(y_val).long().to(device)
        x_test = torch.from_numpy(X_test).float().to(device)
        y_test_t = torch.from_numpy(y_test).long().to(device)

        train_loader = DataLoader(TensorDataset(x_train, y_train_t), args.batch_size, shuffle=True)
        valid_loader = DataLoader(TensorDataset(x_valid, y_valid_t), args.batch_size, shuffle=False)
        test_loader = DataLoader(TensorDataset(x_test, y_test_t), args.batch_size, shuffle=False)

        # Fresh model per fold so folds do not leak weights into each other.
        model = SimpleSNPModel(num_snps = nsnp, num_classes=num_classes)
        model.train_model(
            train_loader,
            valid_loader,
            args.epochs,
            args.learning_rate,
            args.weight_decay,
            args.patience,
            device
        )

        y_pred = model.predict(test_loader)

        acc = accuracy_score(y_test, y_pred)
        # Macro average; zero_division=0 avoids warnings on empty classes.
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_test, y_pred, average="macro", zero_division=0
        )

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start
        # GPU memory of this process only (0.0 when no NVML handle is given).
        gpu_mem = get_gpu_mem_by_pid(os.getpid(), handle) if handle else 0.0
        cpu_mem = process.memory_info().rss / 1024**2

        print(
            f"Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, "
            f"REC={rec:.4f}, F1={f1:.4f}, "
            f"Time={fold_time:.2f}s, GPU={gpu_mem:.2f}MB, CPU={cpu_mem:.2f}MB"
        )

        # Persist per-sample predictions in the original label space.
        pd.DataFrame({
            "y_true": le.inverse_transform(y_test),
            "y_pred": le.inverse_transform(y_pred)
        }).to_csv(
            os.path.join(result_dir, f"fold{fold}.csv"), index=False
        )

        # Release cached CUDA blocks between folds to limit peak memory.
        torch.cuda.empty_cache()

    print("\n===== Cross-validation summary =====")
    print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
    print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
    print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
    print(f"F1  : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
    print(f"Time: {time.time() - time_star:.2f}s")
138
+
139
+
140
def CropARNet_class():
    """End-to-end driver: per species, tune hyperparameters with Optuna and
    then run 10-fold cross-validation with the best configuration.

    Side effects: initialises NVML, creates result directories, writes
    ``label_mapping.npy`` and per-fold CSVs via :func:`run_nested_cv`.
    """
    set_seed(42)
    # NVML handle is used later for per-process GPU memory reporting.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Hard-coded dataset list; each entry is a sub-path under args.data_dir.
    all_species = ["Human/Sim/"]

    for sp in all_species:
        args.species = sp
        X, Y, _, nsnp, names = load_data(args)

        print("Starting:", args.methods + args.species)
        # Multi-column phenotypes: only the first column is used as the label.
        if Y.ndim == 1:
            label = Y
        else:
            label = Y[:, 0]

        # NOTE(review): NaNs are imputed with the mean label value *before*
        # label-encoding — for categorical phenotypes this creates a synthetic
        # class; confirm this is intended.
        label = np.nan_to_num(label, nan=np.nanmean(label))

        # Map raw label values to contiguous integers 0..num_classes-1.
        le = LabelEncoder()
        label = le.fit_transform(label)
        num_classes = len(np.unique(label))

        result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
        os.makedirs(result_dir, exist_ok=True)
        # Persist the class mapping so predictions can be decoded later.
        np.save(os.path.join(result_dir, 'label_mapping.npy'), le.classes_)

        # Optuna search; the best configuration overwrites the CLI defaults.
        best_params = CropARNet_he_class.Hyperparameter(X, label, nsnp)
        args.learning_rate = best_params["learning_rate"]
        args.batch_size = best_params["batch_size"]
        args.weight_decay = best_params["weight_decay"]
        args.patience = best_params["patience"]
        start_time = time.time()
        torch.cuda.reset_peak_memory_stats()
        process = psutil.Process(os.getpid())

        run_nested_cv(args, X, label, nsnp, num_classes, device, handle, le)
        elapsed_time = time.time() - start_time
        print(f"运行时间: {elapsed_time:.2f} 秒")
        print("successfully")



if __name__ == "__main__":
    CropARNet_class()
@@ -0,0 +1,154 @@
1
+ import os
2
+ import time
3
+ import psutil
4
+ import random
5
+ import torch
6
+ import numpy as np
7
+ import optuna
8
+ from sklearn.model_selection import KFold, train_test_split
9
+ from sklearn.preprocessing import LabelEncoder
10
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
11
+ from torch.utils.data import DataLoader, TensorDataset
12
+ from optuna.exceptions import TrialPruned
13
+ from .base_CropARNet_class import SimpleSNPModel
14
+
15
def run_nested_cv_with_early_stopping(
    data,
    label,
    nsnp,
    num_classes,
    learning_rate,
    weight_decay,
    patience,
    batch_size,
    num_round=500
):
    """Evaluate one hyperparameter configuration by 10-fold cross-validation.

    Used as the inner objective of the Optuna search: trains a fresh model on
    each fold (with a 90/10 early-stopping split) and returns the mean macro
    F1 across folds. Raises :class:`TrialPruned` when any fold yields NaN F1.

    :param data: genotype matrix (samples x SNPs), numpy array
    :param label: integer-encoded class labels
    :param nsnp: number of SNP features (model input width)
    :param num_classes: number of distinct classes
    :param learning_rate: optimizer learning rate under evaluation
    :param weight_decay: optimizer weight decay under evaluation
    :param patience: early-stopping patience (epochs without improvement)
    :param batch_size: mini-batch size under evaluation
    :param num_round: maximum training epochs per fold
    :return: mean macro F1 over the folds (0.0 if no fold completed)
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Starting 10-fold cross-validation...")

    # Plain (unstratified) KFold here, unlike the final run's StratifiedKFold.
    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    all_acc, all_prec, all_rec, all_f1 = [], [], [], []

    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        X_train, X_test = data[train_index], data[test_index]
        y_train, y_test = label[train_index], label[test_index]

        # 10% of the training fold is held out for early stopping.
        X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
            X_train, y_train, test_size=0.1, random_state=42
        )

        x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
        y_train_tensor = torch.from_numpy(y_train_sub).long().to(device)

        x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
        y_valid_tensor = torch.from_numpy(y_valid).long().to(device)

        x_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).long().to(device)

        train_loader = DataLoader(
            TensorDataset(x_train_tensor, y_train_tensor),
            batch_size=batch_size,
            shuffle=True
        )
        valid_loader = DataLoader(
            TensorDataset(x_valid_tensor, y_valid_tensor),
            batch_size=batch_size,
            shuffle=False
        )
        test_loader = DataLoader(
            TensorDataset(x_test_tensor, y_test_tensor),
            batch_size=batch_size,
            shuffle=False
        )

        # Fresh model per fold; weights never carry over between folds.
        model = SimpleSNPModel(num_snps = nsnp, num_classes=num_classes)

        model.train_model(
            train_loader,
            valid_loader,
            num_round,
            learning_rate,
            weight_decay,
            patience,
            device
        )

        y_pred = model.predict(test_loader)
        acc = accuracy_score(y_test, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_test,
            y_pred,
            average="macro",
            zero_division=0
        )

        # A NaN F1 makes the whole trial unusable; hand control back to Optuna.
        if np.isnan(f1):
            print(f"Fold {fold} resulted in NaN F1, pruning trial...")
            raise TrialPruned()

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start_time
        fold_cpu_mem = process.memory_info().rss / 1024**2

        print(
            f"Fold {fold}: ACC={acc:.4f}, "
            f"PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, "
            f"Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB"
        )

    return np.mean(all_f1) if all_f1 else 0.0
+
110
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible tuning runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # Pin the CUDA RNGs and force deterministic cuDNN kernel selection.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
+
119
def Hyperparameter(data, label, nsnp):
    """Search CropARNet hyperparameters with Optuna and return the best set.

    Runs 20 trials maximising the mean macro F1 returned by
    :func:`run_nested_cv_with_early_stopping`.

    :param data: genotype matrix (samples x SNPs), numpy array
    :param label: raw class labels; re-encoded here with ``LabelEncoder``
    :param nsnp: number of SNP features (model input width)
    :return: ``study.best_params`` dict with keys ``learning_rate``,
        ``batch_size``, ``weight_decay`` and ``patience``
    """
    set_seed(42)
    # Re-encode defensively; idempotent when labels are already 0..K-1 ints.
    le = LabelEncoder()
    label = le.fit_transform(label)
    num_classes = len(np.unique(label))

    def objective(trial):
        # Search space: log-uniform LR, categorical batch size / weight
        # decay, integer patience.
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        weight_decay = trial.suggest_categorical(
            "weight_decay", [1e-4, 1e-3, 1e-2, 1e-1]
        )
        patience = trial.suggest_int("patience", 5, 30)

        try:
            f1_score = run_nested_cv_with_early_stopping(
                data=data,
                label=label,
                nsnp=nsnp,
                num_classes=num_classes,
                learning_rate=learning_rate,
                weight_decay=weight_decay,
                patience=patience,
                batch_size=batch_size
            )
        except TrialPruned:
            # NOTE(review): returning -inf marks the trial COMPLETE with the
            # worst score instead of PRUNED; re-raising would let Optuna
            # record it as pruned — confirm this behaviour is intended.
            return float("-inf")

        return f1_score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

    print("Best hyperparameters:", study.best_params)
    print("successfully")
    return study.best_params
@@ -0,0 +1,5 @@
1
from .CropARNet_class import CropARNet_class

# Backwards-compatible alias: expose the runner under the package name too.
CropARNet = CropARNet_class

__all__ = ["CropARNet","CropARNet_class"]
@@ -0,0 +1,178 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
# Default architecture configuration shared by every SimpleSNPModel
# instance (read through ``self.config``).
config = {
    "batch_size": 64,
    "weights_units": [64, 32],      # hidden sizes of the attention MLP
    "classifier_units": [64, 32],   # hidden sizes of the classifier MLP
    "dropout": 0.3,
}


class SimpleSNPModel(nn.Module):
    """
    Classification version of SimpleSNPModel
    (Attention + Residual + MLP)

    ``forward`` returns ``(logits, pre_sigmoid_weights)``; the second element
    is the attention score tensor before the final Sigmoid (presumably kept
    for SNP-importance inspection — confirm with callers).
    """

    def __init__(self, num_snps: int, num_classes: int):
        """
        :param num_snps: input width (number of SNP features), must be > 0
        :param num_classes: number of output classes, must be >= 2
        :raises ValueError: on non-positive / non-integer sizes
        """
        super().__init__()

        if not isinstance(num_snps, int) or num_snps <= 0:
            raise ValueError(f"num_snps must be positive integer, got {num_snps}")
        if not isinstance(num_classes, int) or num_classes <= 1:
            raise ValueError(f"num_classes must be >=2, got {num_classes}")

        self.config = config
        self.num_classes = num_classes

        self.attention = self._build_attention_module(num_snps)
        self.classifier = self._build_classifier_module(num_snps, num_classes)

    # ==================================================
    # Attention module
    # ==================================================
    def _build_attention_module(self, num_snps):
        """MLP mapping (batch, num_snps) -> per-SNP weights in (0, 1)."""
        layers = []
        prev_size = num_snps

        for h_size in self.config["weights_units"]:
            if h_size <= 0:
                raise ValueError(f"Invalid hidden size {h_size}")
            layers.append(nn.Linear(prev_size, h_size))
            layers.append(nn.GELU())
            prev_size = h_size

        # Project back to one score per SNP; Sigmoid squashes to (0, 1).
        layers.append(nn.Linear(prev_size, num_snps))
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    # ==================================================
    # Classifier module
    # ==================================================
    def _build_classifier_module(self, num_snps, num_classes):
        """MLP mapping weighted SNPs -> class logits."""
        layers = []
        prev_size = num_snps

        for h_size in self.config["classifier_units"]:
            if h_size <= 0:
                raise ValueError(f"Invalid hidden size {h_size}")
            layers.append(nn.Linear(prev_size, h_size))
            layers.append(nn.LayerNorm(h_size))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(self.config["dropout"]))
            prev_size = h_size

        layers.append(nn.Linear(prev_size, num_classes))
        return nn.Sequential(*layers)

    # ==================================================
    # Forward
    # ==================================================
    def forward(self, x):
        """Return ``(logits, pre_sigmoid_weights)`` for a (batch, num_snps) input."""
        if x.dim() != 2:
            raise ValueError(f"Input must be 2D tensor, got {x.dim()}D")

        # Run the attention stack once up to (but excluding) the Sigmoid,
        # then apply the Sigmoid to the cached scores. (The original ran
        # the full stack a second time for ``att_weights``, doubling the
        # attention compute for identical results.)
        pre_sigmoid_weights = self.attention[:-1](x)
        att_weights = self.attention[-1](pre_sigmoid_weights)

        # Residual weighted SNPs: x * w + x keeps the raw signal alongside
        # the attention-scaled one.
        weighted = x * att_weights + x

        logits = self.classifier(weighted)
        return logits, pre_sigmoid_weights

    # ==================================================
    # Training (classification)
    # ==================================================
    def train_model(
        self,
        train_loader,
        valid_loader,
        num_epochs,
        learning_rate,
        weight_decay,
        patience,
        device
    ):
        """Train with AdamW + cross-entropy and early stopping on valid loss.

        Restores the best-validation-loss weights before returning.

        :param train_loader: DataLoader of (inputs, integer labels)
        :param valid_loader: DataLoader used for early stopping
        :param num_epochs: maximum number of epochs
        :param learning_rate: AdamW learning rate
        :param weight_decay: AdamW weight decay
        :param patience: epochs without improvement before stopping
        :param device: device to train on
        :return: best (lowest) validation loss observed
        """
        self.to(device)

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        criterion = nn.CrossEntropyLoss()

        best_loss = float("inf")
        best_state = None
        trigger_times = 0

        for epoch in range(num_epochs):
            # -------- Train --------
            self.train()
            train_loss = 0.0

            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device).long()

                optimizer.zero_grad()
                outputs, _ = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch loss is a true mean.
                train_loss += loss.item() * inputs.size(0)

            train_loss /= len(train_loader.dataset)

            # -------- Validation --------
            self.eval()
            valid_loss = 0.0

            with torch.no_grad():
                for inputs, labels in valid_loader:
                    inputs = inputs.to(device)
                    labels = labels.to(device).long()

                    outputs, _ = self(inputs)
                    loss = criterion(outputs, labels)
                    valid_loss += loss.item() * inputs.size(0)

            valid_loss /= len(valid_loader.dataset)

            # -------- Early stopping --------
            if valid_loss < best_loss:
                best_loss = valid_loss
                # BUG FIX: ``state_dict()`` returns references to the live
                # parameter tensors, which the optimizer keeps mutating in
                # place — snapshotting without copying made the "best" state
                # silently track the latest weights, so the restore below
                # was a no-op. Clone to take a real snapshot.
                best_state = {
                    k: v.detach().clone() for k, v in self.state_dict().items()
                }
                trigger_times = 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        if best_state is not None:
            self.load_state_dict(best_state)

        return best_loss

    # ==================================================
    # Prediction (classification)
    # ==================================================
    def predict(self, test_loader):
        """Return predicted class indices (numpy array) for *test_loader*."""
        self.eval()
        device = next(self.parameters()).device
        y_pred = []

        with torch.no_grad():
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                outputs, _ = self(inputs)  # logits
                preds = torch.argmax(outputs, dim=1)
                y_pred.append(preds.cpu().numpy())

        return np.concatenate(y_pred, axis=0)