gpbench-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,546 @@
+ import inspect
+ import math
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import Callable, Dict, Literal, Optional, Sequence, Tuple, Type, Union
+
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+ from torch.nn.functional import pad
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from .lcl_layers import _find_lcl_padding_needed
+ from .logging import get_logger
+
+ logger = get_logger(name=__name__, tqdm_compatible=True)
+
+
+ @dataclass
+ class SequenceModelConfig:
+     """
+     :param model_init_config:
+         Configuration / arguments used to initialise the model.
+
+     :param model_type:
+         Which type of sequence model to use.
+
+     :param embedding_dim:
+         Which dimension to use for the embeddings. If ``None``, will automatically
+         set this value based on the number of tokens and attention heads.
+
+     :param position:
+         Whether to encode the token position or use learnable position embeddings.
+
+     :param position_dropout:
+         Dropout for the positional encoding / embedding.
+
+     :param window_size:
+         If set to more than 0, will apply a sliding window of feature
+         extraction over the input, meaning the model (e.g. transformer) will only
+         see a part of the input at a time. Can be useful to avoid the O(n²)
+         complexity of transformers, as it becomes O(window_size² * n_windows)
+         instead.
+
+     :param pool:
+         Whether and how to pool (max / avg) the final feature maps before they are
+         passed to the final fusion module / predictor. We pool over the sequence
+         (i.e. time) dimension, so the resulting dimension is embedding_dim instead
+         of sequence_length * embedding_dim. If using windowed / conv transformers,
+         this becomes embedding_dim * number_of_chunks.
+
+     :param pretrained_model:
+         Specify whether the model type is assumed to be pretrained and loaded
+         from the Hugging Face model hub.
+
+     :param freeze_pretrained_model:
+         Whether to freeze the pretrained model weights.
+     """
+
+     model_init_config: Union["BasicTransformerFeatureExtractorModelConfig", Dict]
+
+     # model_type: Literal["sequence-default"] | str = "sequence-default"
+     model_type: Union[Literal["sequence-default"], str] = "sequence-default"
+
+     embedding_dim: int = 64
+
+     position: Literal["encode", "embed"] = "encode"
+     position_dropout: float = 0.10
+     window_size: int = 0
+     pool: Union[Literal["avg"], Literal["max"], None] = None
+
+     pretrained_model: bool = False
+     freeze_pretrained_model: bool = False
+
+
+ class TransformerWrapperModel(nn.Module):
+     def __init__(
+         self,
+         feature_extractor: Union["TransformerFeatureExtractor", nn.Module],
+         model_config: SequenceModelConfig,
+         embedding_dim: int,
+         num_tokens: int,
+         max_length: int,
+         external_feature_extractor: bool,
+         device: str,
+         embeddings: Optional[nn.Embedding] = None,
+         pre_computed_num_out_features: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.model_config = model_config
+         self.embedding_dim = embedding_dim
+         self.num_tokens = num_tokens
+         self.max_length = max_length
+         self.external_feature_extractor = external_feature_extractor
+         self.pre_computed_num_out_features = pre_computed_num_out_features
+
+         pos_repr_class = get_positional_representation_class(
+             position_model_config=self.model_config.position
+         )
+         self.pos_representation = pos_repr_class(
+             embedding_dim=self.embedding_dim,
+             dropout=self.model_config.position_dropout,
+             max_length=self.max_length,
+         )
+
+         if embeddings:
+             self.embedding = embeddings
+         else:
+             self.embedding = nn.Embedding(
+                 num_embeddings=self.num_tokens, embedding_dim=self.embedding_dim
+             )
+
+         self.feature_extractor = feature_extractor
+
+         if not embeddings:
+             self.init_embedding_weights()
+
+         (
+             self.dynamic_extras,
+             self.extract_features,
+         ) = _get_transformer_wrapper_feature_extractor(
+             feature_extractor=self.feature_extractor,
+             window_size=self.model_config.window_size,
+             max_length=self.max_length,
+             embedding_dim=self.embedding_dim,
+             pool=self.model_config.pool,
+             device=device,
+             external_feature_extractor=self.external_feature_extractor,
+         )
+
+     @property
+     def num_out_features(self) -> int:
+         if self.pre_computed_num_out_features:
+             return self.pre_computed_num_out_features
+
+         padding = self.dynamic_extras.get("padding", 0)
+         length_with_padding = self.max_length + padding
+
+         if self.model_config.pool in ("avg", "max"):
+             num_chunks = 1
+             if self.model_config.window_size:
+                 num_chunks = length_with_padding // self.model_config.window_size
+
+             return self.embedding_dim * num_chunks
+
+         return length_with_padding * self.embedding_dim
+
+     def script_submodules_for_tracing(self):
+         self.embedding = torch.jit.script(self.embedding)
+
+     def init_embedding_weights(self) -> None:
+         init_range = 0.1
+         self.embedding.weight.data.uniform_(-init_range, init_range)
+
+     def forward(self, input: Tensor) -> Tensor:
+         out = input * math.sqrt(self.embedding_dim)
+         out = self.pos_representation(out)
+         out = self.extract_features(out)
+
+         return out
+
+
+ def get_embedding_dim_for_sequence_model(
+     embedding_dim: Union[None, int], num_tokens: int, num_heads: int
+ ) -> int:
+     if embedding_dim is None:
+         auto_emb_dim = math.ceil(int(num_tokens**0.25) / num_heads) * num_heads
+         logger.info(
+             "Setting up automatic embedding dimension of %d based on %d "
+             "tokens and %d attention heads.",
+             auto_emb_dim,
+             num_tokens,
+             num_heads,
+         )
+         embedding_dim = auto_emb_dim
+
+     return embedding_dim
+
+
+ def _get_transformer_wrapper_feature_extractor(
+     feature_extractor: Union["TransformerFeatureExtractor", nn.Module],
+     external_feature_extractor: bool,
+     window_size: int,
+     embedding_dim: int,
+     max_length: int,
+     device: str,
+     pool: Union[Literal["avg"], Literal["max"], None] = None,
+ ) -> Tuple[Dict[str, int], Callable[[torch.Tensor], torch.Tensor]]:
+     dynamic_extras = {"padding": 0}
+
+     feature_extractor_forward = _get_feature_extractor_forward(
+         is_hf_model=external_feature_extractor,
+         feature_extractor=feature_extractor,
+         input_length=window_size if window_size else max_length,
+         embedding_size=embedding_dim,
+         device=device,
+         pool=pool,
+     )
+
+     if not window_size:
+         extractor = partial(
+             feature_extractor_forward, feature_extractor=feature_extractor
+         )
+     else:
+         num_chunks = int(math.ceil(max_length / window_size))
+         logger.debug(
+             "Setting num chunks to %d as window size of %d and maximum sequence "
+             "length of %d were passed in.",
+             num_chunks,
+             window_size,
+             max_length,
+         )
+
+         padding = _find_lcl_padding_needed(
+             input_size=max_length,
+             kernel_size=window_size,
+             num_chunks=num_chunks,
+         )
+         dynamic_extras["padding"] = padding
+
+         extractor = partial(
+             _conv_transformer_forward,
+             feature_extractor=feature_extractor,
+             feature_extractor_forward_callable=feature_extractor_forward,
+             max_length=max_length,
+             window_size=window_size,
+             padding=padding,
+         )
+
+     return dynamic_extras, extractor
+
+
+ def _get_feature_extractor_forward(
+     is_hf_model: bool,
+     feature_extractor: Union[nn.Module, PreTrainedModel],
+     input_length: int,
+     embedding_size: int,
+     device: str,
+     pool: Union[Literal["avg"], Literal["max"], None] = None,
+ ) -> Callable[
+     [torch.Tensor, Union["TransformerFeatureExtractor", nn.Module]], torch.Tensor
+ ]:
+     if is_hf_model:
+         return get_hf_transformer_forward(
+             feature_extractor_=feature_extractor,
+             input_length=input_length,
+             embedding_dim=embedding_size,
+             device=device,
+             pool=pool,
+         )
+     return _get_simple_transformer_forward(pool=pool)
+
+
+ def _get_simple_transformer_forward(
+     pool: Union[Literal["avg"], Literal["max"], None] = None,
+ ):
+     pooling_func = _get_sequence_pooling_func(pool=pool)
+
+     def _simple_transformer_forward(
+         input: torch.Tensor,
+         feature_extractor: "TransformerFeatureExtractor",
+     ) -> torch.Tensor:
+         tensor_out = feature_extractor(input)
+         tensor_pooled = pooling_func(input=tensor_out)
+         final_out = tensor_pooled.flatten(1)
+
+         return final_out
+
+     return _simple_transformer_forward
+
+
+ def get_hf_transformer_forward(
+     feature_extractor_: PreTrainedModel,
+     input_length: int,
+     embedding_dim: int,
+     device: str,
+     pool: Union[Literal["avg"], Literal["max"], None] = None,
+ ):
+     forward_argnames = inspect.getfullargspec(feature_extractor_.forward)[0]
+
+     bound_kwargs = _build_transformer_forward_kwargs(
+         forward_argnames=forward_argnames,
+         config=feature_extractor_.config,
+         input_length=input_length,
+         embedding_dim=embedding_dim,
+         device=device,
+     )
+
+     pooling_func = _get_sequence_pooling_func(pool=pool)
+
+     def _hf_transformer_forward(
+         input: torch.Tensor,
+         feature_extractor: nn.Module,
+         key: str = "last_hidden_state",
+     ) -> torch.Tensor:
+         hf_transformer_out = feature_extractor(inputs_embeds=input, **bound_kwargs)
+         tensor_out = getattr(hf_transformer_out, key)
+         tensor_pooled = pooling_func(input=tensor_out)
+         final_out = tensor_pooled.flatten(1)
+         return final_out
+
+     return _hf_transformer_forward
+
+
+ def _get_sequence_pooling_func(
+     pool: Union[Literal["avg"], Literal["max"], None]
+ ) -> Callable:
+     def _identity(input: torch.Tensor) -> torch.Tensor:
+         return input
+
+     def _max(input: torch.Tensor) -> torch.Tensor:
+         return input.max(dim=1)[0]
+
+     def _avg(input: torch.Tensor) -> torch.Tensor:
+         return input.mean(dim=1)
+
+     if pool is None:
+         return _identity
+     elif pool == "max":
+         return _max
+     elif pool == "avg":
+         return _avg
+
+     raise ValueError(f"Unknown pooling option: '{pool}'.")
+
+
+ def _build_transformer_forward_kwargs(
+     forward_argnames: Sequence[str],
+     config: PretrainedConfig,
+     input_length: int,
+     embedding_dim: int,
+     device: str,
+ ) -> Dict:
+     """
+     TODO: Deprecate.
+     """
+     kwargs = {}
+
+     if "decoder_inputs_embeds" in forward_argnames:
+         kwargs["decoder_inputs_embeds"] = torch.randn(1, input_length, embedding_dim)
+
+     return kwargs
+
+
+ def _conv_transformer_forward(
+     input: torch.Tensor,
+     feature_extractor: "TransformerFeatureExtractor",
+     feature_extractor_forward_callable: Callable[
+         [torch.Tensor, Callable], torch.Tensor
+     ],
+     max_length: int,
+     window_size: int,
+     padding: int,
+ ) -> torch.Tensor:
+     out = pad(input=input, pad=[0, 0, padding, 0])
+     total_length = max_length + padding
+
+     aggregated_out = None
+     for lower_index in range(0, total_length, window_size):
+         upper_index = lower_index + window_size
+
+         cur_input = out[:, lower_index:upper_index, :]
+         cur_out = feature_extractor_forward_callable(cur_input, feature_extractor)
+
+         if aggregated_out is None:
+             aggregated_out = cur_out
+         else:
+             aggregated_out = torch.cat((aggregated_out, cur_out), dim=1)
+
+     assert aggregated_out is not None
+     return aggregated_out
+
+
+ @dataclass
+ class BasicTransformerFeatureExtractorModelConfig:
+     """
+     :param num_heads:
+         The number of heads in the multi-head attention models.
+
+     :param num_layers:
+         The number of encoder blocks in the transformer model.
+
+     :param dim_feedforward:
+         The dimension of the feedforward network model.
+
+     :param dropout:
+         Dropout value to use in the encoder layers.
+     """
+
+     num_heads: int = 8
+     num_layers: int = 2
+     dim_feedforward: Union[int, Literal["auto"]] = "auto"
+     dropout: float = 0.10
+
+
+ class TransformerFeatureExtractor(nn.Module):
+     def __init__(
+         self,
+         model_config: BasicTransformerFeatureExtractorModelConfig,
+         embedding_dim: int,
+         num_tokens: int,
+         max_length: int,
+     ) -> None:
+         super().__init__()
+
+         self.model_config = model_config
+         self.embedding_dim = embedding_dim
+         self.num_tokens = num_tokens
+         self.max_length = max_length
+
+         dim_feed_forward = parse_dim_feedforward(
+             dim_feedforward=self.model_config.dim_feedforward,
+             embedding_dim=self.embedding_dim,
+         )
+
+         encoder_layer_base = TransformerEncoderLayer(
+             d_model=self.embedding_dim,
+             nhead=self.model_config.num_heads,
+             dim_feedforward=dim_feed_forward,
+             dropout=self.model_config.dropout,
+             activation="gelu",
+             batch_first=True,
+             norm_first=True,
+         )
+         self.transformer_encoder = TransformerEncoder(
+             encoder_layer=encoder_layer_base, num_layers=self.model_config.num_layers
+         )
+
+     @property
+     def num_out_features(self) -> int:
+         return self.max_length * self.embedding_dim
+
+     def forward(self, input: Tensor) -> Tensor:
+         out = self.transformer_encoder(input)
+         return out
+
+
+ def parse_dim_feedforward(
+     dim_feedforward: Union[int, Literal["auto"]], embedding_dim: int
+ ) -> int:
+     if dim_feedforward == "auto":
+         dim_feedforward = embedding_dim * 4
+         logger.info(
+             "Setting dim_feedforward to %d based on embedding_dim=%d and 'auto' "
+             "option.",
+             dim_feedforward,
+             embedding_dim,
+         )
+
+     return dim_feedforward
+
+
+ class SequenceOutputTransformerFeatureExtractor(TransformerFeatureExtractor):
+     def __init__(
+         self,
+         model_config: BasicTransformerFeatureExtractorModelConfig,
+         embedding_dim: int,
+         num_tokens: int,
+         max_length: int,
+     ):
+         super().__init__(
+             model_config=model_config,
+             embedding_dim=embedding_dim,
+             num_tokens=num_tokens,
+             max_length=max_length,
+         )
+
+         mask = torch.triu(
+             torch.ones(self.max_length, self.max_length) * float("-inf"), diagonal=1
+         )
+         self.register_buffer("mask", mask)
+
+     def forward(self, input: Tensor) -> Tensor:
+         out = self.transformer_encoder(input, mask=self.mask)
+         return out
+
+
+ def get_positional_representation_class(
+     position_model_config: Literal["encode", "embed"]
+ ) -> Union[Type["PositionalEncoding"], Type["PositionalEmbedding"]]:
+     if position_model_config == "encode":
+         return PositionalEncoding
+     elif position_model_config == "embed":
+         return PositionalEmbedding
+     raise ValueError(
+         "Unknown value for positional representation. "
+         f"Expected 'encode' or 'embed' but got '{position_model_config}'."
+     )
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         max_length: int,
+         dropout: float = 0.1,
+     ) -> None:
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         self.max_length = max_length
+
+         position = torch.arange(max_length).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
+         )
+         pe: torch.Tensor = torch.zeros(1, max_length, embedding_dim)
+         pe[0, :, 0::2] = torch.sin(position * div_term)
+         pe[0, :, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe)
+
+         self.pe: Tensor
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = x + self.pe[:, : self.max_length, :]
+         return self.dropout(x)
+
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         max_length: int,
+         dropout: float = 0.1,
+         zero_init: bool = False,
+     ) -> None:
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         self.max_length = max_length
+
+         if zero_init:
+             self.embedding = torch.nn.Parameter(
+                 data=torch.zeros(1, max_length, embedding_dim),
+                 requires_grad=True,
+             )
+         else:
+             self.embedding = torch.nn.Parameter(
+                 data=torch.randn(1, max_length, embedding_dim),
+                 requires_grad=True,
+             )
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = x + self.embedding[:, : self.max_length, :]
+         return self.dropout(x)
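The hunk above lines up with transformer_models.py, which appears twice in the file list (under gpbench/method_class/EIR/utils/ and gpbench/method_reg/EIR/utils/, both +546). A minimal sketch of how the pieces compose, using only names defined in the hunk; the sizes, the window/pool settings, and the assumption that _find_lcl_padding_needed returns 0 when max_length divides evenly into windows are illustrative, not values taken from gpbench:

import torch

# Illustrative sizes (assumptions, not gpbench defaults).
num_tokens, max_length, window_size = 4096, 128, 32

config = BasicTransformerFeatureExtractorModelConfig(num_heads=4, num_layers=2)

# With embedding_dim=None this derives ceil(num_tokens**0.25 / num_heads) * num_heads,
# i.e. 8 here, so d_model stays divisible by the head count.
embedding_dim = get_embedding_dim_for_sequence_model(
    embedding_dim=None, num_tokens=num_tokens, num_heads=config.num_heads
)

extractor = TransformerFeatureExtractor(
    model_config=config,
    embedding_dim=embedding_dim,
    num_tokens=num_tokens,
    max_length=max_length,
)

wrapper = TransformerWrapperModel(
    feature_extractor=extractor,
    model_config=SequenceModelConfig(
        model_init_config=config,
        embedding_dim=embedding_dim,
        window_size=window_size,  # windowed path: O(window_size^2 * n_windows)
        pool="avg",
    ),
    embedding_dim=embedding_dim,
    num_tokens=num_tokens,
    max_length=max_length,
    external_feature_extractor=False,  # use the simple (non-HF) forward
    device="cpu",
)

# Note: forward() expects already-embedded input of shape
# (batch, max_length, embedding_dim); wrapper.embedding is not applied in forward().
x = torch.randn(2, max_length, embedding_dim)
features = wrapper(x)

# With pool="avg", each of the 128/32 = 4 windows pools down to embedding_dim
# features, which are concatenated: expected shape (2, 4 * embedding_dim),
# assuming _find_lcl_padding_needed pads nothing for an even split.
print(features.shape, wrapper.num_out_features)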
@@ -0,0 +1,123 @@
+ import os
+ import time
+ import psutil
+ import argparse
+ import random
+ import pandas as pd
+ import numpy as np
+ from sklearn.model_selection import KFold
+ from sklearn.linear_model import ElasticNet
+ from scipy.stats import pearsonr
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+ from . import ElasticNet_he
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Argument parser")
+     parser.add_argument('--methods', type=str, default='ElasticNet/', help='Method name')
+     parser.add_argument('--species', type=str, default='', help='Dataset name')
+     parser.add_argument('--phe', type=str, default='', help='Phenotype name')
+     parser.add_argument('--data_dir', type=str, default='../../data/')
+     parser.add_argument('--result_dir', type=str, default='result/')
+
+     parser.add_argument('--alpha', type=float, default=0.5, help='Regularization strength')
+     parser.add_argument('--l1_ratio', type=float, default=0.5, help='L1 ratio (0=Ridge, 1=Lasso)')
+     args = parser.parse_args()
+     return args
+
+
+ def load_data(args):
+     xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
+     yData = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_0"]
+     names = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))["arr_1"]
+
+     nsample = xData.shape[0]
+     nsnp = xData.shape[1]
+     print("Number of samples: ", nsample)
+     print("Number of SNPs: ", nsnp)
+     return xData, yData, nsample, nsnp, names
+
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+
+
+ def run_nested_cv(args, data, label):
+     result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
+     os.makedirs(result_dir, exist_ok=True)
+     print("Starting 10-fold cross-validation with ElasticNet (sklearn)...")
+
+     kf = KFold(n_splits=10, shuffle=True, random_state=42)
+
+     all_mse, all_mae, all_r2, all_pcc = [], [], [], []
+     time_start = time.time()
+
+     for fold, (train_index, test_index) in enumerate(kf.split(data)):
+         print(f"Running fold {fold}...")
+         process = psutil.Process(os.getpid())
+         fold_start_time = time.time()
+
+         x_train = data[train_index]
+         x_test = data[test_index]
+         y_train = label[train_index]
+         y_test = label[test_index]
+
+         model = ElasticNet(alpha=args.alpha, l1_ratio=args.l1_ratio, max_iter=1000, random_state=42)
+         model.fit(x_train, y_train)
+         y_test_preds = model.predict(x_test)
+
+         pcc, _ = pearsonr(y_test, y_test_preds)
+         mse = mean_squared_error(y_test, y_test_preds)
+         r2 = r2_score(y_test, y_test_preds)
+         mae = mean_absolute_error(y_test, y_test_preds)
+
+         all_mse.append(mse)
+         all_r2.append(r2)
+         all_mae.append(mae)
+         all_pcc.append(pcc)
+
+         fold_time = time.time() - fold_start_time
+         fold_cpu_mem = process.memory_info().rss / 1024**2
+
+         print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, '
+               f'Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB')
+
+         results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_test_preds})
+         results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)
+
+         del model, y_test_preds, x_train, x_test, y_train, y_test
+
+     print("\n===== Cross-validation summary =====")
+     print("Using sklearn ElasticNet")
+     print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
+     print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
+     print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
+     print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
+     print(f"Total Time: {time.time() - time_start:.2f}s")
+
+
+ def ElasticNet_reg():
+     set_seed(42)
+     args = parse_args()
+     all_species = ['Cotton/']
+
+     for i in range(len(all_species)):
+         args.species = all_species[i]
+         X, Y, nsamples, nsnp, names = load_data(args)
+         for j in range(len(names)):
+             args.phe = names[j]
+             print("starting run " + args.methods + args.species + args.phe)
+             label = Y[:, j]
+             label = np.nan_to_num(label, nan=np.nanmean(label))
+
+             best_params = ElasticNet_he.Hyperparameter(X, label)
+             args.alpha = best_params['alpha']
+             args.l1_ratio = best_params['l1_ratio']
+
+             start_time = time.time()
+             run_nested_cv(args, data=X, label=label)
+             elapsed_time = time.time() - start_time
+             print(f"running time: {elapsed_time:.2f} s")
+             print("finished successfully")
+
+
+ if __name__ == "__main__":
+     ElasticNet_reg()
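This second hunk matches gpbench/method_reg/ElasticNet/ElasticNet.py from the file list (+123); it expects per-species genotype.npz / phenotype.npz arrays and an ElasticNet_he.Hyperparameter search. A minimal sketch of driving it end to end on synthetic data, assuming only the array layout load_data() reads (arr_0 in genotype.npz; arr_0 and arr_1 in phenotype.npz); the paths, shapes, phenotype name, and fixed alpha/l1_ratio are illustrative, and the hyperparameter search is deliberately bypassed:

import os
from argparse import Namespace
import numpy as np

# Synthetic toy data in the layout load_data() expects (assumed paths).
data_dir, species = "toy_data/", "Cotton/"
os.makedirs(os.path.join(data_dir, species), exist_ok=True)

rng = np.random.default_rng(0)
genotype = rng.integers(0, 3, size=(100, 500)).astype(float)  # 100 samples x 500 SNPs
effects = rng.normal(size=500)
phenotype = (genotype @ effects + rng.normal(scale=5.0, size=100)).reshape(-1, 1)
names = np.array(["plant_height"])  # hypothetical phenotype name

# arr_0 / arr_1 are the positional keys np.savez assigns, matching load_data().
np.savez(os.path.join(data_dir, species, "genotype.npz"), genotype)
np.savez(os.path.join(data_dir, species, "phenotype.npz"), phenotype, names)

# Build args by hand instead of parse_args(), with fixed hyperparameters
# in place of the ElasticNet_he.Hyperparameter(X, label) search.
args = Namespace(methods="ElasticNet/", species=species, phe="plant_height",
                 data_dir=data_dir, result_dir="result/",
                 alpha=0.1, l1_ratio=0.5)

X, Y, nsamples, nsnp, names = load_data(args)
label = np.nan_to_num(Y[:, 0], nan=np.nanmean(Y[:, 0]))
run_nested_cv(args, data=X, label=label)
# Per-fold predictions land in result/ElasticNet/Cotton/plant_height/fold{k}.csv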