gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,276 @@
+ # by ww
+ import time
+ import argparse
+ import pynvml
+ import psutil
+ import os
+ import torch
+ import random
+ import swanlab
+ import numpy as np
+ import pandas as pd
+ from torch.utils.data import DataLoader, TensorDataset
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+ from sklearn.preprocessing import LabelEncoder
+ import sys
+ sys.path.append("..")
+ from .utils.models_locally_connected import LCLModel
+ from .utils.common import DataDimensions
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from .EIR_he_class import Hyperparameter
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Argument parser")
+     # Add arguments
+     parser.add_argument('--methods', type=str, default='EIR/', help='Method name (result subdirectory)')
+     parser.add_argument('--species', type=str, default='', help='Species name')
+     parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+     parser.add_argument('--data_dir', type=str, default='../../data/')
+     parser.add_argument('--result_dir', type=str, default='result/')
+
+     parser.add_argument('--epochs', type=int, default=1000, help='Number of training rounds')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+     parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
+     parser.add_argument('--patience', type=int, default=50, help='Patience for early stopping')
+     args = parser.parse_args()
+     return args
+
+
+ def load_data(args):
+     xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+
+     nsample = xData.shape[0]
+     nsnp = xData.shape[1]
+     print("Number of samples: ", nsample)
+     print("Number of SNPs: ", nsnp)
+     return xData, yData, nsample, nsnp, names
+
+ def get_gpu_mem_by_pid(pid, handle=None):
+     if handle is None:
+         return 0.0
+     try:
+         procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+         for p in procs:
+             if p.pid == pid:
+                 return p.usedGpuMemory / 1024**2
+         return 0.0
+     except Exception:
+         return 0.0
+
+ def init():
+     seed = 42
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def one_hot_encode_ATCG(char):
+     c = char
+     if char == 0:
+         return [1, 0, 0, 0]
+     elif char == 1:
+         return [0, 1, 0, 0]
+     elif char == 2:
+         return [0, 0, 1, 0]
+     else:
+         return [0, 0, 0, 1]
+
+ def one_hot_seq(df: pd.DataFrame, nsnp: int):
+     one_hot_df = df.applymap(lambda x: one_hot_encode_ATCG(x))
+     one_hot_df = one_hot_df.values.tolist()
+     one_hot_df = np.array(one_hot_df)
+     one_hot_df = np.reshape(one_hot_df, (one_hot_df.shape[0], -1))
+     tensor_data = torch.Tensor(one_hot_df)
+     return tensor_data
+
+
+ def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):
+
+     model.to(device)
+     best_loss = float('inf')
+     best_state = None
+     trigger_times = 0
+
+     for epoch in range(num_epochs):
+         model.train()
+         train_loss = 0.0
+         for inputs, labels in train_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             train_loss += loss.item() * inputs.size(0)
+         #scheduler.step()
+         # ---------- Validation ----------
+         model.eval()
+         valid_loss = 0.0
+         with torch.no_grad():
+             for inputs, labels in valid_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(outputs, labels)
+                 valid_loss += loss.item() * inputs.size(0)
+
+         train_loss /= len(train_loader.dataset)
+         valid_loss /= len(valid_loader.dataset)
+
+         # ---------- Early stopping ----------
+         if valid_loss < best_loss:
+             best_loss = valid_loss
+             best_state = model.state_dict()
+             trigger_times = 0
+         else:
+             trigger_times += 1
+             if trigger_times >= patience:
+                 print(f"Early stopping at epoch {epoch+1}")
+                 break
+     if best_state is not None:
+         model.load_state_dict(best_state)
+     return best_loss
+
+ def predict(model, test_loader, device):
+     model.eval()
+     y_pred = []
+     with torch.no_grad():
+         for inputs, _ in test_loader:
+             inputs = inputs.to(device)
+             outputs = model(inputs)  # (batch_size, num_classes)
+             preds = torch.argmax(outputs, dim=1)
+             y_pred.append(preds.cpu().numpy())
+     y_pred = np.concatenate(y_pred, axis=0)
+     return y_pred
+
+
+ def run_nested_cv(args, data, label, nsnp, num_classes, device, gpu_handle=None):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+     print("Starting 10-fold cross-validation...")
+     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
+     data = data.reshape(data.shape[0], 1, nsnp, -1)
+
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+     time_star = time.time()
+     for fold, (train_index, test_index) in enumerate(kf.split(data, label)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
+             X_train, y_train, test_size=0.1, stratify=y_train, random_state=42
+         )
+
+         x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+         y_train_tensor = torch.from_numpy(y_train_sub).long().to(device)
+         x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+         y_valid_tensor = torch.from_numpy(y_valid).long().to(device)
+         x_test_tensor = torch.from_numpy(X_test).float().to(device)
+         y_test_tensor = torch.from_numpy(y_test).long().to(device)
+
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+         train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
+         valid_loader = DataLoader(valid_data, args.batch_size, shuffle=False)
+         test_loader = DataLoader(test_data, args.batch_size, shuffle=False)
+
+         model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
+         in_features = model.fc_2.in_features
+         model.fc_2 = torch.nn.Linear(in_features, num_classes).to(device)
+         if isinstance(model.downsample_identity, torch.nn.Linear):
+             identity_in_features = model.downsample_identity.in_features
+             model.downsample_identity = torch.nn.Linear(identity_in_features, num_classes).to(device)
+         else:
+             identity_in_features = model.lcl_blocks[-1].out_features
+             model.downsample_identity = torch.nn.Linear(identity_in_features, num_classes).to(device)
+
+         optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
+         loss_fn = torch.nn.CrossEntropyLoss()
+
+         train_model(model, train_loader, valid_loader, optimizer, loss_fn, args.epochs, args.patience, device)
+         y_pred = predict(model, test_loader, device)
+
+         acc = accuracy_score(y_test, y_pred)
+         prec, rec, f1, _ = precision_recall_fscore_support(
+             y_test, y_pred, average="macro", zero_division=0
+         )
+
+         all_acc.append(acc)
+         all_prec.append(prec)
+         all_rec.append(rec)
+         all_f1.append(f1)
+
+         fold_time = time.time() - fold_start_time
+         fold_gpu_mem = get_gpu_mem_by_pid(os.getpid(), gpu_handle)
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+         print(f'Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, Time={fold_time:.2f}s, '
+               f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.reset_peak_memory_stats()
+         results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_pred})
+         results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+     print("\n===== Cross-validation summary =====")
+     print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+     print(f"Time: {time.time() - time_star:.2f}s")
+
+
+ def EIR_class():
+     start = time.time()
+     set_seed(42)
+     pynvml.nvmlInit()
+     gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+
+     args = parse_args()
+     all_species = ["Human/Sim/"]
+     for i in range(len(all_species)):
+         args.species = all_species[i]
+         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         args.device = device
+         X, Y, nsamples, nsnp, names = load_data(args)
+
+         print("starting run " + args.methods + args.species)
+         label_raw = np.nan_to_num(Y[:, 0])
+         le = LabelEncoder()
+         label = le.fit_transform(label_raw)
+         num_classes = len(le.classes_)
+
+         best_params = Hyperparameter(X, label, nsnp, num_classes)
+         args.learning_rate = best_params['learning_rate']
+         args.batch_size = best_params['batch_size']
+         args.patience = best_params['patience']
+         start_time = time.time()
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+         process = psutil.Process(os.getpid())
+         run_nested_cv(args, data=X, label=label, nsnp=nsnp, num_classes=num_classes, device=args.device, gpu_handle=gpu_handle)
+
+         elapsed_time = time.time() - start_time
+         print(f"Running time: {elapsed_time:.2f}s")
+         print("successfully")
+
+
+ if __name__ == '__main__':
+     EIR_class()
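
For context, a minimal sketch of the data layout that load_data() in the hunk above (gpbench/method_class/EIR/EIR_class.py) expects. This is inferred from the np.load calls, not from documentation shipped with the wheel: <data_dir>/<species>/genotype.npz holds a samples x SNPs matrix as arr_0, and phenotype.npz holds phenotype values as arr_0 plus phenotype names as arr_1. The toy sizes below are arbitrary:

import os
import numpy as np

data_dir, species = "../../data/", "Human/Sim/"   # argparse defaults used by EIR_class()
os.makedirs(os.path.join(data_dir, species), exist_ok=True)

genotype = np.random.randint(0, 3, size=(100, 500))                 # 100 samples x 500 SNPs
phenotype = np.random.randint(0, 2, size=(100, 1)).astype(float)    # one binary trait column
names = np.array(["trait_sim"])                                     # phenotype column names (hypothetical)

# Positional arrays are stored as arr_0, arr_1, matching the keys load_data() reads.
np.savez(os.path.join(data_dir, species, "genotype.npz"), genotype)
np.savez(os.path.join(data_dir, species, "phenotype.npz"), phenotype, names)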
@@ -0,0 +1,184 @@
+ import os
+ import time
+ import psutil
+ import random
+ import torch
+ import numpy as np
+ import optuna
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from torch.utils.data import DataLoader, TensorDataset
+ from optuna.exceptions import TrialPruned
+ from .utils.models_locally_connected import LCLModel
+ from .utils.common import DataDimensions
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, patience, device):
+     model.to(device)
+     best_loss = float('inf')
+     best_state = None
+     trigger_times = 0
+
+     for epoch in range(num_epochs):
+         model.train()
+         train_loss = 0.0
+         for inputs, labels in train_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             train_loss += loss.item() * inputs.size(0)
+
+         model.eval()
+         valid_loss = 0.0
+         with torch.no_grad():
+             for inputs, labels in valid_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = model(inputs)
+                 loss = criterion(outputs, labels)
+                 valid_loss += loss.item() * inputs.size(0)
+
+         train_loss /= len(train_loader.dataset)
+         valid_loss /= len(valid_loader.dataset)
+
+         # ---------- Early stopping ----------
+         if valid_loss < best_loss:
+             best_loss = valid_loss
+             best_state = model.state_dict()
+             trigger_times = 0
+         else:
+             trigger_times += 1
+             if trigger_times >= patience:
+                 print(f"Early stopping at epoch {epoch+1}")
+                 break
+
+     if best_state is not None:
+         model.load_state_dict(best_state)
+     return best_loss
+
+ def predict(model, test_loader, device):
+     model.eval()
+     y_pred = []
+     with torch.no_grad():
+         for inputs, _ in test_loader:
+             inputs = inputs.to(device)
+             outputs = model(inputs)  # (batch_size, num_classes)
+             preds = torch.argmax(outputs, dim=1)
+             y_pred.append(preds.cpu().numpy())
+     y_pred = np.concatenate(y_pred, axis=0)
+     return y_pred
+
+ def run_nested_cv_with_early_stopping(data, label, nsnp, num_classes, learning_rate, patience, batch_size, epochs=1000):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     print("Starting 10-fold cross-validation...")
+     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+
+     for fold, (train_index, test_index) in enumerate(kf.split(data, label)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
+             X_train, y_train, test_size=0.1, stratify=y_train, random_state=42
+         )
+
+         x_train_tensor = torch.from_numpy(X_train_sub).float().to(device)
+         y_train_tensor = torch.from_numpy(y_train_sub).long().to(device)
+         x_valid_tensor = torch.from_numpy(X_valid).float().to(device)
+         y_valid_tensor = torch.from_numpy(y_valid).long().to(device)
+         x_test_tensor = torch.from_numpy(X_test).float().to(device)
+         y_test_tensor = torch.from_numpy(y_test).long().to(device)
+
+         train_data = TensorDataset(x_train_tensor, y_train_tensor)
+         valid_data = TensorDataset(x_valid_tensor, y_valid_tensor)
+         test_data = TensorDataset(x_test_tensor, y_test_tensor)
+
+         train_loader = DataLoader(train_data, batch_size, shuffle=True)
+         valid_loader = DataLoader(valid_data, batch_size, shuffle=False)
+         test_loader = DataLoader(test_data, batch_size, shuffle=False)
+
+         model = LCLModel(DataDimensions(channels=1, height=nsnp, width=1)).to(device)
+         in_features = model.fc_2.in_features
+         model.fc_2 = torch.nn.Linear(in_features, num_classes).to(device)
+         if isinstance(model.downsample_identity, torch.nn.Linear):
+             identity_in_features = model.downsample_identity.in_features
+             model.downsample_identity = torch.nn.Linear(identity_in_features, num_classes).to(device)
+         else:
+             identity_in_features = model.lcl_blocks[-1].out_features
+             model.downsample_identity = torch.nn.Linear(identity_in_features, num_classes).to(device)
+
+         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
+         loss_fn = torch.nn.CrossEntropyLoss()
+
+         train_model(model, train_loader, valid_loader, optimizer, loss_fn, epochs, patience, device)
+         y_pred = predict(model, test_loader, device)
+         acc = accuracy_score(y_test, y_pred)
+         prec, rec, f1, _ = precision_recall_fscore_support(
+             y_test, y_pred, average="macro", zero_division=0
+         )
+
+         if np.isnan(f1) or f1 <= 0:
+             print(f"Fold {fold} resulted in NaN or zero F1, pruning the trial...")
+             raise TrialPruned()
+         all_acc.append(acc)
+         all_prec.append(prec)
+         all_rec.append(rec)
+         all_f1.append(f1)
+
+         fold_time = time.time() - fold_start_time
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+         print(f'Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, '
+               f'Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB')
+
+     print("\n===== Cross-validation summary =====")
+     print(f"Average ACC: {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"Average PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"Average REC: {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"Average F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+
+     return float(np.mean(all_f1)) if all_f1 else 0.0
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def Hyperparameter(data, label, nsnp, num_classes):
+     set_seed(42)
+
+     def objective(trial):
+         learning_rate = trial.suggest_float("learning_rate", 1e-4, 0.1, log=True)
+         batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+         patience = trial.suggest_int("patience", 10, 100, step=10)
+         try:
+             f1_score = run_nested_cv_with_early_stopping(
+                 data=data,
+                 label=label,
+                 nsnp=nsnp,
+                 num_classes=num_classes,
+                 learning_rate=learning_rate,
+                 patience=patience,
+                 batch_size=batch_size
+             )
+         except TrialPruned:
+             return float("-inf")
+         return f1_score
+
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+
+     print("Best hyperparameters:", study.best_params)
+     print("successfully")
+     return study.best_params
@@ -0,0 +1,5 @@
+ from .EIR_class import EIR_class
+
+ EIR = EIR_class
+
+ __all__ = ["EIR", "EIR_class"]
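
Given the re-export above (gpbench/method_class/EIR/__init__.py), a minimal usage sketch, assuming the wheel is installed and the data layout described earlier exists. EIR() parses its flags with argparse, so running it in a plain script picks up the defaults:

# Programmatic use of the classification-mode EIR runner (sketch, not shipped documentation).
from gpbench.method_class.EIR import EIR  # alias for EIR_class

# Runs the Optuna hyperparameter search followed by 10-fold cross-validation,
# using the argparse defaults (data under ../../data/, results under result/).
EIR()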
File without changes
@@ -0,0 +1,97 @@
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Literal, Type
+
+ import torch
+ from torch import nn
+
+ from eir.models.input.array.array_models import al_pre_normalization
+ from eir.models.input.array.models_locally_connected import LCLModel, LCLModelConfig
+ from eir.models.layers.projection_layers import get_projection_layer
+ from eir.models.output.array.output_array_models_cnn import (
+     CNNUpscaleModel,
+     CNNUpscaleModelConfig,
+ )
+
+ if TYPE_CHECKING:
+     from eir.setup.input_setup_modules.common import DataDimensions
+
+ al_array_model_types = Literal["lcl", "cnn"]
+ al_output_array_model_classes = Type[LCLModel] | Type[CNNUpscaleModel]
+ al_output_array_models = LCLModel | CNNUpscaleModel
+ al_output_array_model_config_classes = (
+     Type["LCLOutputModelConfig"] | Type[CNNUpscaleModelConfig]
+ )
+
+
+ @dataclass
+ class LCLOutputModelConfig(LCLModelConfig):
+     cutoff: int | Literal["auto"] = "auto"
+
+
+ @dataclass
+ class ArrayOutputModuleConfig:
+
+     """
+     :param model_type:
+         Which type of image model to use.
+
+     :param model_init_config:
+         Configuration used to initialise model.
+     """
+
+     model_type: al_array_model_types
+     model_init_config: LCLOutputModelConfig
+     pre_normalization: al_pre_normalization = None
+
+
+ class ArrayOutputWrapperModule(nn.Module):
+     def __init__(
+         self,
+         feature_extractor: al_output_array_models,
+         output_name: str,
+         target_data_dimensions: "DataDimensions",
+     ):
+         super().__init__()
+         self.feature_extractor = feature_extractor
+         self.output_name = output_name
+         self.data_dimensions = target_data_dimensions
+
+         self.target_width = self.data_dimensions.num_elements()
+         self.target_shape = self.data_dimensions.full_shape()
+
+         diff_tolerance = get_diff_tolerance(num_target_elements=self.target_width)
+
+         self.projection_head = get_projection_layer(
+             input_dimension=self.feature_extractor.num_out_features,
+             target_dimension=self.target_width,
+             projection_layer_type="lcl_residual",
+             lcl_diff_tolerance=diff_tolerance,
+         )
+
+     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         out = self.feature_extractor(x)
+
+         out = out.reshape(out.shape[0], -1)
+         out = self.projection_head(out)
+
+         out = out[:, : self.target_width]
+
+         out = out.reshape(-1, *self.target_shape)
+
+         return {self.output_name: out}
+
+
+ def get_diff_tolerance(num_target_elements: int) -> int:
+     return int(0.001 * num_target_elements)
+
+
+ def get_array_output_module(
+     feature_extractor: al_output_array_models,
+     output_name: str,
+     target_data_dimensions: "DataDimensions",
+ ) -> ArrayOutputWrapperModule:
+     return ArrayOutputWrapperModule(
+         feature_extractor=feature_extractor,
+         output_name=output_name,
+         target_data_dimensions=target_data_dimensions,
+     )
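
The projection head above is allowed a small mismatch between its output size and the target size; get_diff_tolerance() caps that slack at 0.1% of the target element count. A quick arithmetic check (the function body is copied from the hunk above so the snippet runs on its own):

def get_diff_tolerance(num_target_elements: int) -> int:
    # 0.1% of the target size, truncated to an integer.
    return int(0.001 * num_target_elements)

print(get_diff_tolerance(10_000))  # 10 -> up to 10 elements of slack
print(get_diff_tolerance(500))     # 0  -> small targets must match exactly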
@@ -0,0 +1,65 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+
+ # from eir.data_load.data_source_modules.deeplake_ops import (
+ #     get_deeplake_input_source_iterable,
+ #     is_deeplake_dataset,
+ #     load_deeplake_dataset,
+ # )
+ # from eir.data_load.label_setup import get_file_path_iterator
+
+
+ @dataclass
+ class DataDimensions:
+     channels: int
+     height: int
+     width: int
+     extra_dims: tuple[int, ...] = tuple()
+
+     def num_elements(self) -> int:
+         base = self.channels * self.height * self.width
+         return int(base * np.prod(self.extra_dims))
+
+     def full_shape(self) -> Tuple[int, ...]:
+         return (self.channels, self.height, self.width) + self.extra_dims
+
+
+ # def get_data_dimension_from_data_source(
+ #     data_source: Path,
+ #     deeplake_inner_key: Optional[str] = None,
+ # ) -> DataDimensions:
+ #     """
+ #     TODO: Make more dynamic / robust. Also weird to say "width" for a 1D vector.
+ #     """
+ #
+ #     if is_deeplake_dataset(data_source=str(data_source)):
+ #         assert deeplake_inner_key is not None, data_source
+ #         deeplake_ds = load_deeplake_dataset(data_source=str(data_source))
+ #         deeplake_iter = get_deeplake_input_source_iterable(
+ #             deeplake_dataset=deeplake_ds, inner_key=deeplake_inner_key
+ #         )
+ #         shape = next(deeplake_iter).shape
+ #     else:
+ #         iterator = get_file_path_iterator(data_source=data_source)
+ #         path = next(iterator)
+ #         shape = np.load(file=path).shape
+ #
+ #     extra_dims: tuple[int, ...] = tuple()
+ #     if len(shape) == 1:
+ #         channels, height, width = 1, 1, shape[0]
+ #     elif len(shape) == 2:
+ #         channels, height, width = 1, shape[0], shape[1]
+ #     elif len(shape) == 3:
+ #         channels, height, width = shape
+ #     else:
+ #         channels, height, width = shape[0], shape[1], shape[2]
+ #         extra_dims = shape[3:]
+ #
+ #     return DataDimensions(
+ #         channels=channels, height=height, width=width, extra_dims=extra_dims
+ #     )
+
+
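
For orientation, a short sketch of how this DataDimensions helper is consumed: the EIR classes above build the model with DataDimensions(channels=1, height=nsnp, width=1), so a genotype matrix with nsnp SNPs maps to a (1, nsnp, 1) input shape. The SNP count below is illustrative; the import path follows the file list at the top of this page and assumes the wheel is installed:

from gpbench.method_class.EIR.utils.common import DataDimensions

dims = DataDimensions(channels=1, height=500, width=1)  # one channel x 500 SNPs x width 1
print(dims.num_elements())  # 500
print(dims.full_shape())    # (1, 500, 1)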