gpbench-1.0.0-py3-none-any.whl

This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,137 @@
+ import os
+ import time
+ import psutil
+ import random
+ import torch
+ import numpy as np
+ import optuna
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from torch.utils.data import DataLoader, TensorDataset
+ from optuna.exceptions import TrialPruned
+ from .gMLP_class import GEFormer
+
+ # 10-fold stratified CV with an inner validation split for early stopping;
+ # returns the mean macro F1 across folds.
+ def run_nested_cv_with_early_stopping(
+     data, label, nsnp,
+     learning_rate, patience, batch_size,
+     epoch=1000
+ ):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     print("Starting 10-fold cross-validation...")
+
+     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
+
+     all_acc, all_prec, all_rec, all_f1 = [], [], [], []
+
+     num_classes = len(np.unique(label))
+
+     for fold, (train_index, test_index) in enumerate(kf.split(data, label)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         X_train, X_test = data[train_index], data[test_index]
+         y_train, y_test = label[train_index], label[test_index]
+
+         # hold out 10% of the training fold as a validation set for early stopping
+         X_train_sub, X_valid, y_train_sub, y_valid = train_test_split(
+             X_train, y_train, test_size=0.1, stratify=y_train, random_state=42
+         )
+
+         x_train = torch.from_numpy(X_train_sub).float().to(device)
+         y_train = torch.from_numpy(y_train_sub).long().to(device)
+         x_valid = torch.from_numpy(X_valid).float().to(device)
+         y_valid = torch.from_numpy(y_valid).long().to(device)
+         x_test = torch.from_numpy(X_test).float().to(device)
+         y_test_tensor = torch.from_numpy(y_test).long().to(device)
+
+         train_loader = DataLoader(
+             TensorDataset(x_train, y_train), batch_size, shuffle=True
+         )
+         valid_loader = DataLoader(
+             TensorDataset(x_valid, y_valid), batch_size, shuffle=False
+         )
+         test_loader = DataLoader(
+             TensorDataset(x_test, y_test_tensor), batch_size, shuffle=False
+         )
+
+         model = GEFormer(nsnp=nsnp, num_classes=num_classes)
+         model.train_model(
+             train_loader, valid_loader,
+             epoch, learning_rate, patience, device
+         )
+
+         logits = model.predict(test_loader)
+         y_pred = np.argmax(logits, axis=1)
+
+         acc = accuracy_score(y_test, y_pred)
+         prec, rec, f1, _ = precision_recall_fscore_support(
+             y_test, y_pred, average="macro", zero_division=0
+         )
+
+         if np.isnan(f1) or f1 <= 0:
+             print(f"Fold {fold} resulted in NaN or zero F1, pruning trial")
+             raise TrialPruned()
+
+         all_acc.append(acc)
+         all_prec.append(prec)
+         all_rec.append(rec)
+         all_f1.append(f1)
+
+         fold_time = time.time() - fold_start_time
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+
+         print(
+             f"Fold {fold}: "
+             f"ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, "
+             f"Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB"
+         )
+
+     print("\n===== CV Summary =====")
+     print(f"ACC : {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
+     print(f"PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
+     print(f"REC : {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
+     print(f"F1  : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
+
+     return np.mean(all_f1) if all_f1 else 0.0
+
+ # Seed Python, NumPy and PyTorch RNGs for reproducible trials.
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+ # Optuna search over learning rate, batch size and early-stopping patience,
+ # maximising the mean cross-validated macro F1.
+ def Hyperparameter(data, label, nsnp):
+     set_seed(42)
+
+     def objective(trial):
+         learning_rate = trial.suggest_float(
+             "learning_rate", 1e-4, 0.1, log=True
+         )
+         batch_size = trial.suggest_categorical(
+             "batch_size", [32, 64, 128]
+         )
+         patience = trial.suggest_int("patience", 1, 10)
+
+         try:
+             f1_score = run_nested_cv_with_early_stopping(
+                 data=data,
+                 label=label,
+                 nsnp=nsnp,
+                 learning_rate=learning_rate,
+                 patience=patience,
+                 batch_size=batch_size
+             )
+         except TrialPruned:
+             # treat a pruned trial as the worst possible score
+             return float("-inf")
+
+         return f1_score
+
+     study = optuna.create_study(direction="maximize")
+     study.optimize(objective, n_trials=20)
+
+     print("best params:", study.best_params)
+     print("Hyperparameter search finished successfully.")
+     return study.best_params
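If this hunk is gpbench/method_class/GEFormer/GEFormer_he_class.py (its +137 line count matches entry 76 in the file list), the search can be exercised end to end on synthetic inputs. A minimal sketch, assuming `data` is a float genotype matrix of shape (samples, nsnp) and `label` holds integer class codes; the toy sizes are illustrative, not taken from the package:

    import numpy as np
    from gpbench.method_class.GEFormer.GEFormer_he_class import Hyperparameter

    rng = np.random.default_rng(0)
    data = rng.random((200, 500)).astype(np.float32)  # 200 samples x 500 SNPs
    label = rng.integers(0, 2, size=200)              # binary phenotype labels

    best = Hyperparameter(data, label, nsnp=500)      # 20 Optuna trials, 10-fold CV each
    print(best)  # {'learning_rate': ..., 'batch_size': ..., 'patience': ...}

Note that StratifiedKFold with n_splits=10 requires at least 10 samples per class, so any real input must satisfy that as well.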
@@ -0,0 +1,5 @@
+ from .GEFormer_class import GEFormer_class
+
+ # expose the implementation class under the short name as well
+ GEFormer = GEFormer_class
+
+ __all__ = ["GEFormer", "GEFormer_class"]
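Both exported names resolve to the same class, so downstream code can use either spelling; a quick check of the alias (import path taken from the file list above):

    from gpbench.method_class.GEFormer import GEFormer, GEFormer_class
    assert GEFormer is GEFormer_class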
@@ -0,0 +1,357 @@
+ from random import randrange
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, einsum
+
+ from einops import rearrange, repeat
+ from einops.layers.torch import Rearrange, Reduce
+
+ # functions
+
+ def exists(val):
+     return val is not None
+
+ def pair(val):
+     return (val, val) if not isinstance(val, tuple) else val
+
+ # stochastic depth: randomly drop layers during training
+ def dropout_layers(layers, prob_survival):
+     if prob_survival == 1:
+         return layers
+
+     num_layers = len(layers)
+     to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival
+
+     # make sure at least one layer makes it
+     if all(to_drop):
+         rand_index = randrange(num_layers)
+         to_drop[rand_index] = False
+
+     layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
+     return layers
+
+
+ # helper classes
+
+ class Residual(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x):
+         return self.fn(x) + x
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x, **kwargs):
+         x = self.norm(x)
+         return self.fn(x, **kwargs)
+
+ # optional tiny self-attention branch that feeds the spatial gating unit
+ class Attention(nn.Module):
+     def __init__(self, dim_in, dim_out, dim_inner, causal=False):
+         super().__init__()
+         self.scale = dim_inner ** -0.5
+         self.causal = causal
+
+         self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias=False)
+         self.to_out = nn.Linear(dim_inner, dim_out)
+
+     def forward(self, x):
+         device = x.device
+         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+         if self.causal:
+             mask = torch.ones(sim.shape[-2:], device=device).triu(1).bool()
+             sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)
+
+         attn = sim.softmax(dim=-1)
+         out = einsum('b i j, b j d -> b i d', attn, v)
+         return self.to_out(out)
+
+ class SpatialGatingUnit(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_seq,
+         causal=False,
+         act=nn.Identity(),
+         heads=1,
+         init_eps=1e-3,
+         circulant_matrix=False
+     ):
+         super().__init__()
+         dim_out = dim // 2
+         self.heads = heads
+         self.causal = causal
+         self.norm = nn.LayerNorm(dim_out)
+
+         self.act = act
+
+         # parameters
+
+         if circulant_matrix:
+             self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
+             self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))
+
+         self.circulant_matrix = circulant_matrix
+         shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
+         weight = torch.zeros(shape)
+
+         self.weight = nn.Parameter(weight)
+         init_eps /= dim_seq
+         nn.init.uniform_(self.weight, -init_eps, init_eps)
+
+         self.bias = nn.Parameter(torch.ones(heads, dim_seq))
+
+     def forward(self, x, gate_res=None):
+         device, n, h = x.device, x.shape[1], self.heads
+
+         res, gate = x.chunk(2, dim=-1)
+         gate = self.norm(gate)
+
+         weight, bias = self.weight, self.bias
+
+         if self.circulant_matrix:
+             # build the circulant matrix
+             dim_seq = weight.shape[-1]
+             weight = F.pad(weight, (0, dim_seq), value=0)
+             weight = repeat(weight, '... n -> ... (r n)', r=dim_seq)
+             weight = weight[:, :-dim_seq].reshape(h, dim_seq, 2 * dim_seq - 1)
+             weight = weight[:, :, (dim_seq - 1):]
+
+             # give circulant matrix absolute position awareness
+             pos_x, pos_y = self.circulant_pos_x, self.circulant_pos_y
+             weight = weight * rearrange(pos_x, 'h i -> h i ()') * rearrange(pos_y, 'h j -> h () j')
+
+         if self.causal:
+             weight, bias = weight[:, :n, :n], bias[:, :n]
+             mask = torch.ones(weight.shape[-2:], device=device).triu_(1).bool()
+             mask = rearrange(mask, 'i j -> () i j')
+             weight = weight.masked_fill(mask, 0.)
+
+         gate = rearrange(gate, 'b n (h d) -> b h n d', h=h)
+
+         gate = einsum('b h n d, h m n -> b h m d', gate, weight)
+         gate = gate + rearrange(bias, 'h n -> () h n ()')
+
+         gate = rearrange(gate, 'b h n d -> b n (h d)')
+
+         if exists(gate_res):
+             gate = gate + gate_res
+
+         return self.act(gate) * res
+
+ class gMLPBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_ff,
+         seq_len,
+         heads=1,
+         attn_dim=None,
+         causal=False,
+         act=nn.Identity(),
+         circulant_matrix=False
+     ):
+         super().__init__()
+         self.proj_in = nn.Sequential(
+             nn.Linear(dim, dim_ff),
+             nn.GELU()
+         )
+
+         self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None
+
+         self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, circulant_matrix=circulant_matrix)
+         self.proj_out = nn.Linear(dim_ff // 2, dim)
+
+     def forward(self, x):
+         gate_res = self.attn(x) if exists(self.attn) else None
+         x = self.proj_in(x)
+         x = self.sgu(x, gate_res=gate_res)
+         x = self.proj_out(x)
+         return x
+
+
+ # main classes
+
+ class gMLPVision(nn.Module):
+     def __init__(
+         self,
+         *,
+         image_size,
+         patch_size,
+         num_classes,
+         dim,
+         depth,
+         snp_len,
+         heads=1,
+         ff_mult=4,
+         channels=1,
+         attn_dim=None,
+         prob_survival=1.
+     ):
+         super().__init__()
+         assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
+
+         image_height, image_width = pair(image_size)
+         patch_height, patch_width = pair(patch_size)
+         # assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
+         # num_patches must equal the sequence length produced by the patch
+         # embedding below, otherwise the spatial gating einsum cannot match
+         # shapes (the published code hard-coded num_patches = 200 here)
+         num_patches = (image_height // patch_height) * (image_width // patch_width)
+         dim_ff = dim * ff_mult
+
+         self.to_patch_embed = nn.Sequential(
+             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=patch_height, p2=patch_width),
+             nn.Linear(1 * snp_len * 1, dim)
+         )
+
+         self.prob_survival = prob_survival
+
+         self.layers = nn.ModuleList([
+             Residual(PreNorm(dim, gMLPBlock(dim=dim, heads=heads, dim_ff=dim_ff, seq_len=num_patches, attn_dim=attn_dim)))
+             for _ in range(depth)
+         ])
+
+         self.to_logits = nn.Sequential(
+             nn.LayerNorm(dim),
+             Reduce('b n d -> b d', 'mean'),
+             nn.Linear(dim, num_classes)
+         )
+
+     def forward(self, x):
+         x = self.to_patch_embed(x)
+         layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
+         x = nn.Sequential(*layers)(x)
+         return self.to_logits(x)
+
+ # score-based early stopping tracker (not used by train_model below, which
+ # implements its own loss-based early stopping)
+ class EarlyStopping:
+     def __init__(self, patience=10, delta=0):
+         self.patience = patience
+         self.delta = delta
+         self.best_score = None
+         self.counter = 0
+         self.early_stop = False
+
+     def __call__(self, score):
+         if self.best_score is None:
+             self.best_score = score
+         elif score < self.best_score + self.delta:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.counter = 0
+
+ class GEFormer(nn.Module):
+     def __init__(self, nsnp, num_classes):
+         super(GEFormer, self).__init__()
+
+         # gMLP backbone: one SNP-wide patch per sample, 126-dim representation
+         self.gmlp = gMLPVision(
+             image_size=(nsnp, 1),
+             patch_size=(nsnp, 1),
+             num_classes=126,
+             dim=126,
+             depth=1,
+             snp_len=nsnp
+         )
+
+         # MLP head mapping the backbone features to the phenotype classes
+         self.MLP = nn.Sequential(
+             nn.Linear(126, 128),
+             nn.LeakyReLU(),
+             nn.Dropout(0.4),
+             nn.Linear(128, 64),
+             nn.LeakyReLU(),
+             nn.Dropout(0.4),
+             nn.Linear(64, num_classes)
+         )
+
+         self.numsnp = nsnp
+
+     def forward(self, x):
+         x = x.view(x.size(0), 1, self.numsnp, 1)
+         x = self.gmlp(x)
+         logits = self.MLP(x)
+         return logits
+
+     def train_model(
+         self, train_loader, valid_loader,
+         num_epochs, learning_rate, patience, device
+     ):
+         optimizer = torch.optim.Adam(
+             self.parameters(), lr=learning_rate, weight_decay=1e-4
+         )
+
+         criterion = nn.CrossEntropyLoss()
+         self.to(device)
+
+         best_loss = float('inf')
+         best_state = None
+         trigger_times = 0
+
+         for epoch in range(num_epochs):
+             self.train()
+             train_loss = 0.0
+
+             for inputs, labels in train_loader:
+                 inputs = inputs.to(device)
+                 labels = labels.to(device).long()
+
+                 optimizer.zero_grad()
+                 logits = self(inputs)
+                 loss = criterion(logits, labels)
+                 loss.backward()
+                 optimizer.step()
+
+                 train_loss += loss.item() * inputs.size(0)
+
+             self.eval()
+             valid_loss = 0.0
+             with torch.no_grad():
+                 for inputs, labels in valid_loader:
+                     inputs = inputs.to(device)
+                     labels = labels.to(device).long()
+                     logits = self(inputs)
+                     loss = criterion(logits, labels)
+                     valid_loss += loss.item() * inputs.size(0)
+
+             train_loss /= len(train_loader.dataset)
+             valid_loss /= len(valid_loader.dataset)
+
+             if valid_loss < best_loss:
+                 best_loss = valid_loss
+                 # clone the tensors: state_dict() returns live references that
+                 # later optimizer steps would otherwise overwrite
+                 best_state = {k: v.detach().clone() for k, v in self.state_dict().items()}
+                 trigger_times = 0
+             else:
+                 trigger_times += 1
+                 if trigger_times >= patience:
+                     print(f"Early stopping at epoch {epoch+1}")
+                     break
+
+         # restore the best validation checkpoint
+         if best_state is not None:
+             self.load_state_dict(best_state)
+
+         return best_loss
+
+     def predict(self, test_loader):
+         self.eval()
+         logits_all = []
+
+         with torch.no_grad():
+             for inputs, _ in test_loader:
+                 inputs = inputs.to(next(self.parameters()).device)
+                 logits = self(inputs)
+                 logits_all.append(logits.cpu().numpy())
+
+         return np.concatenate(logits_all, axis=0)  # raw logits, not probabilities
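For completeness, a minimal sketch of driving the model class directly rather than through the Optuna wrapper above; the loader and sizes are illustrative, and the training loader is reused as the validation and test loader only to keep the example short:

    import torch
    import numpy as np
    from torch.utils.data import DataLoader, TensorDataset
    from gpbench.method_class.GEFormer.gMLP_class import GEFormer

    nsnp = 500
    x = torch.randn(64, nsnp)               # 64 samples, 500 SNPs
    y = torch.randint(0, 2, (64,))          # binary labels
    loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

    model = GEFormer(nsnp=nsnp, num_classes=2)
    model.train_model(loader, loader, num_epochs=5, learning_rate=1e-3,
                      patience=2, device=torch.device("cpu"))
    logits = model.predict(loader)          # (64, 2) array of raw logits
    preds = np.argmax(logits, axis=1)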