gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,642 @@
1
+ from copy import copy
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Callable,
7
+ List,
8
+ Literal,
9
+ Optional,
10
+ Protocol,
11
+ Sequence,
12
+ Union,
13
+ )
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ # from eir.models.input.sequence.transformer_models import PositionalEmbedding
19
+ # from eir.models.layers.lcl_layers import LCL, LCLResidualBlock
20
+ # from eir.utils.logging import get_logger
21
+
22
+ from .lcl_layers import LCL, LCLResidualBlock
23
+ from .transformer_models import PositionalEmbedding
24
+
25
+ from torchvision.ops import StochasticDepth
26
+ from aislib.pytorch_modules import Swish
27
+
28
+
29
+ # logger = get_logger(__name__)
30
+
31
+
32
class FlattenFunc(Protocol):
    """Structural type for callables that flatten an input tensor."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        ...
38
+
39
+
40
@dataclass
class SimpleLCLModelConfig:
    """
    Configuration for a single-layer locally connected model.

    :param fc_repr_dim:
        Number of output weight sets in the first (and only) locally
        connected layer; analogous to channels in a CNN.
    :param num_lcl_chunks:
        Number of splits applied to the flattened input. E.g. for an input
        width of 800, ``num_lcl_chunks=100`` yields a kernel width of 8,
        i.e. 8 elements of the flattened input per window. With a one-hot
        encoding of 4 possible values this corresponds to 8/4 = 2 SNPs per
        locally connected area.
    :param l1:
        L1 regularisation applied to the single locally connected layer.
    """

    fc_repr_dim: int = 12
    num_lcl_chunks: int = 64
    l1: float = 0.00
59
+
60
+
61
+ # class SimpleLCLModel(nn.Module):
62
+ # def __init__(
63
+ # self,
64
+ # model_config: SimpleLCLModelConfig,
65
+ # data_dimensions: "DataDimensions",
66
+ # flatten_fn: FlattenFunc,
67
+ # ):
68
+ # super().__init__()
69
+ #
70
+ # self.model_config = model_config
71
+ # self.data_dimensions = data_dimensions
72
+ # self.flatten_fn = flatten_fn
73
+ #
74
+ # num_chunks = self.model_config.num_lcl_chunks
75
+ # self.fc_0 = LCL(
76
+ # in_features=self.fc_1_in_features,
77
+ # out_feature_sets=self.model_config.fc_repr_dim,
78
+ # num_chunks=num_chunks,
79
+ # bias=True,
80
+ # )
81
+ #
82
+ # self._init_weights()
83
+ #
84
+ # @property
85
+ # def fc_1_in_features(self) -> int:
86
+ # return self.data_dimensions.num_elements()
87
+ #
88
+ # @property
89
+ # def l1_penalized_weights(self) -> torch.Tensor:
90
+ # return self.fc_0.weight
91
+ #
92
+ # @property
93
+ # def num_out_features(self) -> int:
94
+ # return self.fc_0.out_features
95
+ #
96
+ # def _init_weights(self):
97
+ # pass
98
+ #
99
+ # def forward(self, input: torch.Tensor) -> torch.Tensor:
100
+ # out = self.flatten_fn(x=input)
101
+ #
102
+ # out = self.fc_0(out)
103
+ #
104
+ # return out
105
+
106
+
107
@dataclass
class LCLModelConfig:
    """
    Configuration for the locally connected (LCL) residual model.

    Note that when using the automatic network setup, kernel widths will get
    expanded to ensure that the feature representations become smaller as they
    are propagated through the network.

    :param patch_size:
        Size of the patches used in the first layer. If ``None``, the input
        is flattened with torch ``flatten``. When set, the kernel width is
        generally chosen as the product of the patch dimensions.

    :param layers:
        Number of blocks per layer stage. If ``None``, the number of layers
        is set up automatically from the ``cutoff`` value.

    :param kernel_width:
        Width of the locally connected kernels, measured on the *flattened*
        input: with a one-hot encoding of 4 values (e.g. SNPs), a width of 12
        covers 12/4 = 3 SNPs per window. May be the sentinel ``"patch"`` to
        derive the width from ``patch_size``, or be resolved automatically
        when ``num_lcl_chunks`` is set instead.

    :param first_kernel_expansion:
        Factor applied to the first kernel only. Positive values multiply the
        base ``kernel_width`` (``kernel_width=12`` with
        ``first_kernel_expansion=2`` gives a first kernel of 24); negative
        values divide it instead.

    :param channel_exp_base:
        Power of 2 controlling the number of channels/weight sets, e.g.
        ``channel_exp_base=3`` means 2**3 = 8 weight sets.

    :param first_channel_expansion:
        Expansion/shrink factor for the first layer's channel count, working
        analogously to ``first_kernel_expansion``.

    :param num_lcl_chunks:
        Number of splits applied to the input. E.g. for an input width of
        800, ``num_lcl_chunks=100`` yields a kernel width of 8 flattened
        elements; with a one-hot encoding of 4 values that is 8/4 = 2 SNPs
        per locally connected area.

    :param rb_do:
        Dropout in the residual blocks.

    :param stochastic_depth_p:
        Probability of dropping input (stochastic depth).

    :param l1:
        L1 regularisation applied to the first layer in the network.

    :param cutoff:
        Feature-dimension cutoff where the automatic network setup stops
        adding layers. The ``"auto"`` option is only supported when using the
        model for array *outputs*, and sets the cutoff to roughly the number
        of output features.

    :param direction:
        ``"down"`` shrinks the feature representation as it propagates
        through the network; ``"up"`` grows it.

    :param attention_inclusion_cutoff:
        Cutoff to start including attention blocks. ``None`` disables
        attention entirely. The cutoff refers to the "length" dimension after
        reshaping by the preceding layer's output feature sets: with 1024
        output features and 4 feature sets, the length is 1024/4 = 256, so a
        cutoff >= 256 includes the attention block.
    """

    patch_size: Optional[tuple[int, int, int]] = None

    layers: Union[None, List[int]] = None

    kernel_width: Union[int, Literal["patch"]] = 16
    first_kernel_expansion: int = -2

    channel_exp_base: int = 2
    first_channel_expansion: int = 1

    num_lcl_chunks: Union[None, int] = None

    rb_do: float = 0.10
    stochastic_depth_p: float = 0.00
    l1: float = 0.00

    cutoff: Union[int, Literal["auto"]] = 1024
    direction: Literal["down", "up"] = "down"
    attention_inclusion_cutoff: Optional[int] = None
206
+
207
+
208
class LCLModel(nn.Module):
    """Locally connected residual network with a fixed hyperparameter setup.

    NOTE(review): derived from EIR's configurable LCLModel, but here the
    hyperparameters are hard-coded (kernel width 16, first kernel expansion
    -2, channel_exp_base 2, cutoff 1024, direction "down"); the commented-out
    ``model_config`` paths are kept for reference. The network is an LCL stem
    (``fc_0`` + automatic residual blocks) followed by an MLP residual head
    that projects to a single output feature (regression).
    """

    def __init__(
        self,
        # model_config: Union[LCLModelConfig, "LCLOutputModelConfig"],
        # data_dimensions: "DataDimensions",
        data_dimensions,
        # flatten_fn: FlattenFunc,
        # dynamic_cutoff: Optional[int] = None,
    ):
        super().__init__()

        # self.model_config = model_config
        # Provides num_elements(); presumably the flattened input size
        # (e.g. samples x SNPs x one-hot) — TODO confirm against caller.
        self.data_dimensions = data_dimensions
        # self.flatten_fn = flatten_fn

        # kernel_width = parse_kernel_width(
        #     # kernel_width=self.model_config.kernel_width,
        #     # patch_size=self.model_config.patch_size,
        #     kernel_width=16,
        #     patch_size=None
        #
        # )
        kernel_width=16

        # First kernel: base 16 divided by 2 -> kernel size 8.
        fc_0_kernel_size = calc_value_after_expansion(
            # base=kernel_width,
            # expansion=self.model_config.first_kernel_expansion,
            base=16,
            expansion=-2,

        )
        # Channel exponent: 2 * 1 = 2 -> 2**2 = 4 output feature sets.
        fc_0_channel_exponent = calc_value_after_expansion(
            # base=self.model_config.channel_exp_base,
            # expansion=self.model_config.first_channel_expansion,
            base=2,
            expansion=1
        )
        self.fc_0 = LCL(
            in_features=self.fc_1_in_features, # 1*9033*4 = 36132
            out_feature_sets=2**fc_0_channel_exponent, #4
            kernel_size=fc_0_kernel_size, #8
            bias=True,
        )

        # cutoff = dynamic_cutoff or self.model_config.cutoff
        cutoff = 1024
        assert isinstance(cutoff, int)

        # Hard-coded equivalents of the commented model_config values.
        lcl_parameter_spec = LCParameterSpec(
            # in_features=self.fc_0.out_features,
            # kernel_width=kernel_width,
            # channel_exp_base=self.model_config.channel_exp_base,
            # dropout_p=self.model_config.rb_do,
            # cutoff=cutoff,
            # stochastic_depth_p=self.model_config.stochastic_depth_p,
            # attention_inclusion_cutoff=self.model_config.attention_inclusion_cutoff,
            # direction=self.model_config.direction,
            in_features=self.fc_0.out_features,
            kernel_width=kernel_width,
            channel_exp_base=2,
            dropout_p=0.10,
            cutoff=cutoff,
            stochastic_depth_p=0.00,
            attention_inclusion_cutoff=None,
            direction="down",
        )
        # block_layer_spec=None selects the automatic cutoff-driven setup.
        self.lcl_blocks = _get_lcl_blocks(
            # lcl_spec=lcl_parameter_spec,
            # block_layer_spec=self.model_config.layers,
            lcl_spec=lcl_parameter_spec,
            block_layer_spec=None,
        )

        self._init_weights()

        # MLP residual
        self.full_preactivation = False
        in_features = self.lcl_blocks[-1].out_features
        # Per-dataset final LCL output sizes observed by the author:
        # Cattle 84 Chicken136 Chickpea 104 Cotton 8 Loblolly_Pine 152 Maize 148 Millet 656
        # Mouse 320 Pig 1016 Rapeseed 312 Rice 56 Soybean 104 Wheat 536
        out_features = 1
        if in_features == out_features:
            self.downsample_identity = lambda x: x
        else:
            self.downsample_identity = nn.Linear(
                in_features=in_features, out_features=out_features, bias=True
            )
        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
        self.fc_1 = nn.Linear(
            in_features=in_features, out_features=out_features, bias=True
        )

        self.act_1 = Swish()
        # Both dropout and stochastic depth are disabled (p=0) in the head.
        dropout_p = 0.0
        self.do = nn.Dropout(p=dropout_p)
        self.fc_2 = nn.Linear(
            in_features=out_features, out_features=out_features, bias=True
        )
        self.stochastic_depth_p = 0.0
        self.stochastic_depth = StochasticDepth(p=self.stochastic_depth_p, mode="batch")


    @property
    def fc_1_in_features(self) -> int:
        # Flattened input element count, supplied by the caller's
        # data_dimensions object.
        return self.data_dimensions.num_elements() #

    @property
    def l1_penalized_weights(self) -> torch.Tensor:
        # Weights subject to L1 regularisation: first LCL layer only.
        return self.fc_0.weight

    @property
    def num_out_features(self) -> int:
        # Output width of the LCL stem (before the MLP residual head).
        return self.lcl_blocks[-1].out_features

    def _init_weights(self):
        # Intentionally a no-op; default module initialisation is used.
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # out = self.flatten_fn(x=input) # original

        # by ww
        # Replaces the original flatten_fn with a plain reshape to
        # (batch, 1, -1).
        out = input # by ww
        a = input.shape[0]  # NOTE(review): unused local
        out = out.reshape(input.shape[0],1, -1)

        out = self.fc_0(out)
        out = self.lcl_blocks(out) # (4, 72)


        # MLP Residual
        out = self.norm_1(out)

        # NOTE(review): both branches are identical, so full_preactivation
        # has no effect here; the upstream pattern takes the pre-norm input
        # in the non-preactivation branch — confirm before changing.
        identity = out if self.full_preactivation else out
        identity = self.downsample_identity(identity)

        out = self.fc_1(out)

        out = self.act_1(out)
        out = self.do(out)
        out = self.fc_2(out)

        out = self.stochastic_depth(out)

        return out + identity


        # return out
355
+
356
+
357
def parse_kernel_width(
    kernel_width: Union[int, Literal["patch"]],
    patch_size: Optional[tuple[int, int, int]],
) -> int:
    """Resolve ``kernel_width`` to a concrete integer.

    The sentinel ``"patch"`` resolves to the number of elements in one patch
    (product of the three patch dimensions); an int is returned unchanged.

    :raises ValueError: if ``"patch"`` is requested without a ``patch_size``.
    """
    if kernel_width != "patch":
        return kernel_width

    if patch_size is None:
        raise ValueError(
            "kernel_width set to 'patch', but no patch_size was specified."
        )
    dim_0, dim_1, dim_2 = patch_size
    return dim_0 * dim_1 * dim_2
368
+
369
+
370
def flatten_h_w_fortran(x: torch.Tensor) -> torch.Tensor:
    """Flatten dims 1..3 of a 4D tensor in column-major (Fortran-like) order.

    Needed when flattening one-hot inputs laid out column-wise (each column
    is one one-hot feature): the first part of the flattened tensor must be
    the first column, i.e. the first one-hot element.
    """
    column_major = x.transpose(2, 3)
    return column_major.flatten(1)
379
+
380
+
381
def calc_value_after_expansion(base: int, expansion: int, min_value: int = 0) -> int:
    """Scale ``base`` by an expansion factor.

    A positive ``expansion`` multiplies ``base``; a negative one floor-divides
    by its absolute value, clamped below at ``min_value``; zero leaves
    ``base`` unchanged.
    """
    if expansion == 0:
        return base
    if expansion > 0:
        return base * expansion
    return max(min_value, base // abs(expansion))
388
+
389
+
390
@dataclass
class LCParameterSpec:
    """Hyperparameter bundle consumed by the LCL block factories."""

    in_features: int
    kernel_width: int
    channel_exp_base: int
    dropout_p: float
    stochastic_depth_p: float
    cutoff: int
    attention_inclusion_cutoff: Optional[int] = None
    direction: Literal["down", "up"] = "down"
400
+
401
+
402
def _get_lcl_blocks(
    lcl_spec: LCParameterSpec,
    block_layer_spec: Optional[Sequence[int]],
) -> nn.Sequential:
    """Build the residual LCL stack described by ``lcl_spec``.

    ``block_layer_spec`` selects between the explicit per-stage block counts
    and the automatic cutoff-driven setup.
    """
    build = _get_lcl_block_factory(block_layer_spec=block_layer_spec)
    return build(lcl_spec)
411
+
412
+
413
def _get_lcl_block_factory(
    block_layer_spec: Optional[Sequence[int]],
) -> Callable[[LCParameterSpec], nn.Sequential]:
    """Select the block-construction strategy.

    An empty or ``None`` spec selects the automatic setup; otherwise the
    explicit spec is bound into the spec-driven generator.
    """
    if block_layer_spec:
        return partial(
            _generate_lcl_blocks_from_spec, block_layer_spec=block_layer_spec
        )
    return generate_lcl_residual_blocks_auto
424
+
425
+
426
def _generate_lcl_blocks_from_spec(
    lcl_parameter_spec: LCParameterSpec,
    block_layer_spec: List[int],
) -> nn.Sequential:
    """Create an LCL residual stack from explicit per-stage block counts.

    ``block_layer_spec[i]`` is the number of blocks at stage ``i``; the
    channel exponent grows by one per stage. The very first block is created
    up front with full pre-activation and counted against stage 0.
    """
    spec = lcl_parameter_spec

    blocks = [
        LCLResidualBlock(
            in_features=spec.in_features,
            kernel_size=spec.kernel_width,
            out_feature_sets=2**spec.channel_exp_base,
            dropout_p=spec.dropout_p,
            full_preactivation=True,
        )
    ]

    # Work on a copy so the caller's spec is not mutated; the first block
    # above already accounts for one block of stage 0.
    remaining = copy(block_layer_spec)
    remaining[0] -= 1

    for stage_index, stage_block_count in enumerate(remaining):
        for _ in range(stage_block_count):
            feature_sets = 2 ** (spec.channel_exp_base + stage_index)
            kernel = spec.kernel_width

            # Rebalance so representations shrink/grow per spec.direction.
            feature_sets, kernel = _adjust_auto_params(
                cur_out_feature_sets=feature_sets,
                cur_kernel_width=kernel,
                direction=spec.direction,
            )

            blocks.append(
                LCLResidualBlock(
                    in_features=blocks[-1].out_features,
                    kernel_size=kernel,
                    out_feature_sets=feature_sets,
                    dropout_p=spec.dropout_p,
                    stochastic_depth_p=spec.stochastic_depth_p,
                )
            )

    return nn.Sequential(*blocks)
468
+
469
+
470
+ def _adjust_auto_params(
471
+ cur_out_feature_sets: int, cur_kernel_width: int, direction: Literal["down", "up"]
472
+ ) -> tuple[int, int]:
473
+ """
474
+ Down: increase kernel width until it is larger than the number of output feature
475
+ sets.
476
+ Up: increase number of output feature sets until it is larger than the kernel width.
477
+
478
+ """
479
+ if direction == "down":
480
+ while cur_out_feature_sets >= cur_kernel_width:
481
+ cur_kernel_width *= 2
482
+ elif direction == "up":
483
+ while cur_out_feature_sets <= cur_kernel_width:
484
+ cur_out_feature_sets *= 2
485
+ else:
486
+ raise ValueError(f"Unknown direction: {direction}")
487
+
488
+ return cur_out_feature_sets, cur_kernel_width
489
+
490
+
491
def generate_lcl_residual_blocks_auto(lcl_parameter_spec: LCParameterSpec):
    """
    Automatically stack LCL residual blocks until the feature size crosses
    ``cutoff`` (checked via ``_should_break_auto``), optionally interleaving
    attention blocks when ``attention_inclusion_cutoff`` is set.

    TODO: Create some over-engineered abstraction for this and
    ``_generate_lcl_blocks_from_spec`` if feeling bored.
    """

    s = lcl_parameter_spec

    # First block uses full pre-activation and the base channel exponent.
    first_block = LCLResidualBlock(
        in_features=s.in_features,
        kernel_size=s.kernel_width,
        out_feature_sets=2**s.channel_exp_base,
        dropout_p=s.dropout_p,
        full_preactivation=True,
    )

    block_modules: list[Union[LCLResidualBlock, LCLAttentionBlock]]
    block_modules = [first_block]

    if _do_add_attention(
        attention_inclusion_cutoff=s.attention_inclusion_cutoff,
        in_features=first_block.out_features,
        embedding_dim=first_block.out_feature_sets,
    ):
        cur_attention_block = LCLAttentionBlock(
            embedding_dim=first_block.out_feature_sets,
            in_features=first_block.out_features,
        )
        block_modules.append(cur_attention_block)

    while True:
        cur_no_blocks = len(block_modules)
        # NOTE(review): index grows every two appended modules — apparently
        # assumes each residual block is paired with an attention block when
        # attention is enabled; confirm intent when attention is disabled.
        cur_index = cur_no_blocks // 2

        # Channel exponent grows with depth; kernel width is rebalanced so
        # the representation moves in s.direction.
        cur_out_feature_sets = 2 ** (s.channel_exp_base + cur_index)
        cur_kernel_width = s.kernel_width
        cur_out_feature_sets, cur_kernel_width = _adjust_auto_params(
            cur_out_feature_sets=cur_out_feature_sets,
            cur_kernel_width=cur_kernel_width,
            direction=s.direction,
        )

        cur_size = block_modules[-1].out_features

        # Stop once the current feature size crosses the cutoff.
        if _should_break_auto(
            cur_size=cur_size,
            cutoff=s.cutoff,
            direction=s.direction,
        ):
            break

        cur_block = LCLResidualBlock(
            in_features=cur_size,
            kernel_size=cur_kernel_width,
            out_feature_sets=cur_out_feature_sets,
            dropout_p=s.dropout_p,
            stochastic_depth_p=s.stochastic_depth_p,
        )

        block_modules.append(cur_block)

        if _do_add_attention(
            attention_inclusion_cutoff=s.attention_inclusion_cutoff,
            in_features=cur_block.out_features,
            embedding_dim=cur_block.out_feature_sets,
        ):
            cur_attention_block = LCLAttentionBlock(
                embedding_dim=cur_block.out_feature_sets,
                in_features=cur_block.out_features,
            )
            block_modules.append(cur_attention_block)

    # logger.debug(
    #     "No SplitLinear residual blocks specified in CL arguments. Created %d "
    #     "blocks with final output dimension of %d.",
    #     len(block_modules),
    #     cur_size,
    # )
    return nn.Sequential(*block_modules)
570
+
571
+
572
+ def _should_break_auto(
573
+ cur_size: int, cutoff: int, direction: Literal["up", "down"]
574
+ ) -> bool:
575
+ if direction == "down":
576
+ return cur_size <= cutoff
577
+ elif direction == "up":
578
+ return cur_size >= cutoff
579
+ else:
580
+ raise ValueError(f"Unknown direction: {direction}")
581
+
582
+
583
class LCLAttentionBlock(nn.Module):
    """Self-attention block applied between LCL residual blocks.

    The flat input of ``in_features`` elements is reshaped to a sequence of
    length ``in_features // embedding_dim`` with ``embedding_dim`` channels,
    passed through a small transformer encoder with learned positional
    embeddings, then flattened back — so ``out_features == in_features``.

    :param embedding_dim: channel width per sequence position (``d_model``).
    :param in_features: flat input/output width; must be divisible by
        ``embedding_dim`` for the reshape in ``forward`` to be lossless.
    :param num_heads: attention heads; ``"auto"`` uses ``embedding_dim``
        heads (one per channel).
    :param dropout_p: dropout used in the positional embedding.
    :param num_layers: number of stacked encoder layers.
    :param dim_feedforward_factor: encoder FFN width as a multiple of
        ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        in_features: int,
        num_heads: Union[int, Literal["auto"]] = "auto",
        dropout_p: float = 0.0,
        num_layers: int = 2,
        dim_feedforward_factor: int = 4,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dim_feedforward_factor = dim_feedforward_factor
        self.in_features = in_features
        self.dropout_p = dropout_p
        self.num_layers = num_layers
        self.out_features = in_features

        # BUGFIX: the original only assigned self.num_heads in the "auto"
        # branch, so passing an explicit int crashed below with an
        # AttributeError. Default "auto" behaviour is unchanged.
        if num_heads == "auto":
            self.num_heads = embedding_dim
        else:
            self.num_heads = num_heads

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=self.num_heads,
            dim_feedforward=self.embedding_dim * self.dim_feedforward_factor,
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=self.num_layers,
        )

        self.pos_emb = PositionalEmbedding(
            embedding_dim=self.embedding_dim,
            max_length=self.in_features // self.embedding_dim,
            dropout=self.dropout_p,
            zero_init=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, in_features) -> (batch, seq_len, embedding_dim)
        out = x.reshape(x.shape[0], -1, self.embedding_dim)
        out = self.pos_emb(out)
        out = self.encoder(out)
        # Flatten back to (batch, in_features).
        out = torch.flatten(input=out, start_dim=1)

        return out
633
+
634
+
635
+ def _do_add_attention(
636
+ in_features: int, embedding_dim: int, attention_inclusion_cutoff: Optional[int]
637
+ ) -> bool:
638
+ if attention_inclusion_cutoff is None:
639
+ return False
640
+
641
+ attention_sequence_length = in_features // embedding_dim
642
+ return attention_sequence_length <= attention_inclusion_cutoff