gpbench-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
@@ -0,0 +1,642 @@
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import (
    TYPE_CHECKING,
    Callable,
    List,
    Literal,
    Optional,
    Protocol,
    Sequence,
    Union,
)

import torch
from torch import nn

# from eir.models.input.sequence.transformer_models import PositionalEmbedding
# from eir.models.layers.lcl_layers import LCL, LCLResidualBlock
# from eir.utils.logging import get_logger

from .lcl_layers import LCL, LCLResidualBlock
from .transformer_models import PositionalEmbedding

from torchvision.ops import StochasticDepth
from aislib.pytorch_modules import Swish


# logger = get_logger(__name__)


class FlattenFunc(Protocol):
    def __call__(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        ...


@dataclass
class SimpleLCLModelConfig:
    """
    :param fc_repr_dim:
        Controls the number of output sets in the first and only split layer.
        Analogous to channels in CNNs.
    :param num_lcl_chunks:
        Controls the number of splits applied to the input. E.g. with an input
        width of 800, using ``num_lcl_chunks=100`` will result in a kernel width
        of 8, meaning 8 elements in the flattened input. If using SNP inputs
        with a one-hot encoding of 4 possible values, this will result in
        8/4 = 2 SNPs per locally connected area.
    :param l1:
        L1 regularization applied to the first and only locally connected layer.
    """

    fc_repr_dim: int = 12
    num_lcl_chunks: int = 64
    l1: float = 0.00
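
# Worked example of the chunking arithmetic described above (illustrative
# sketch; numbers taken from the docstring):
#
#     input_width = 800                    # flattened one-hot input
#     kernel_width = input_width // 100    # num_lcl_chunks=100 -> 8 elements
#     snps_per_window = kernel_width // 4  # 4-value one-hot encoding -> 2 SNPs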


# class SimpleLCLModel(nn.Module):
#     def __init__(
#         self,
#         model_config: SimpleLCLModelConfig,
#         data_dimensions: "DataDimensions",
#         flatten_fn: FlattenFunc,
#     ):
#         super().__init__()
#
#         self.model_config = model_config
#         self.data_dimensions = data_dimensions
#         self.flatten_fn = flatten_fn
#
#         num_chunks = self.model_config.num_lcl_chunks
#         self.fc_0 = LCL(
#             in_features=self.fc_1_in_features,
#             out_feature_sets=self.model_config.fc_repr_dim,
#             num_chunks=num_chunks,
#             bias=True,
#         )
#
#         self._init_weights()
#
#     @property
#     def fc_1_in_features(self) -> int:
#         return self.data_dimensions.num_elements()
#
#     @property
#     def l1_penalized_weights(self) -> torch.Tensor:
#         return self.fc_0.weight
#
#     @property
#     def num_out_features(self) -> int:
#         return self.fc_0.out_features
#
#     def _init_weights(self):
#         pass
#
#     def forward(self, input: torch.Tensor) -> torch.Tensor:
#         out = self.flatten_fn(x=input)
#
#         out = self.fc_0(out)
#
#         return out


@dataclass
class LCLModelConfig:
    """
    Note that when using the automatic network setup, kernel widths will get
    expanded to ensure that the feature representations become smaller as they
    are propagated through the network.

    :param patch_size:
        Controls the size of the patches used in the first layer. If set to
        ``None``, the input is flattened according to the torch ``flatten``
        function. Note that when using this parameter, we generally want the
        kernel width to be set to the product of the patch size dimensions.

    :param layers:
        Controls the number of layers in the model. If set to ``None``, the
        model will automatically set up the number of layers according to the
        ``cutoff`` parameter value.

    :param kernel_width:
        Width of the locally connected kernels. Note that in the context of
        genomic inputs this refers to the flattened input, meaning that if we
        have a one-hot encoding of 4 values (e.g. SNPs), 12 refers to
        12/4 = 3 SNPs per locally connected window. Can be set to ``None`` if
        the ``num_lcl_chunks`` parameter is set, in which case the kernel width
        will be set automatically so that the input is split into that many
        chunks.

    :param first_kernel_expansion:
        Factor to extend the first kernel. This value can be either positive or
        negative. For example in the case of ``kernel_width=12``, setting
        ``first_kernel_expansion=2`` means that the first kernel will have a
        width of 24, whereas other kernels will have a width of 12. When using a
        negative value, divides the first kernel by the value instead of
        multiplying.

    :param channel_exp_base:
        Which power of 2 to use in order to set the number of channels/weight
        sets in the network. For example, setting ``channel_exp_base=3`` means
        that 2**3=8 weight sets will be used.

    :param first_channel_expansion:
        Whether to expand / shrink the number of channels in the first layer as
        compared to other layers in the network. Works analogously to the
        ``first_kernel_expansion`` parameter.

    :param num_lcl_chunks:
        Controls the number of splits applied to the input. E.g. with an input
        width of 800, using ``num_lcl_chunks=100`` will result in a kernel width
        of 8, meaning 8 elements in the flattened input. If using SNP inputs
        with a one-hot encoding of 4 possible values, this will result in
        8/4 = 2 SNPs per locally connected area.

    :param rb_do:
        Dropout in the residual blocks.

    :param stochastic_depth_p:
        Probability of dropping input.

    :param l1:
        L1 regularization applied to the first layer in the network.

    :param cutoff:
        Feature dimension cutoff where the automatic network setup stops adding
        layers. The 'auto' option is only supported when using the model for
        array *outputs*, and will set the cutoff to roughly the number of output
        features.

    :param direction:
        Whether to use a "down" or "up" network. "Down" means that the feature
        representation will get smaller as it is propagated through the network,
        whereas "up" means that the feature representation will get larger.

    :param attention_inclusion_cutoff:
        Cutoff to start including attention blocks in the network. If set to
        ``None``, no attention blocks will be included. The cutoff here refers
        to the "length" dimension of the input after reshaping according to the
        output_feature_sets in the preceding layer. For example, if we have 1024
        output features and 4 output feature sets, the length dimension will be
        1024/4 = 256. With an attention cutoff >= 256, the attention block will
        be included.
    """

    patch_size: Optional[tuple[int, int, int]] = None

    layers: Union[None, List[int]] = None

    # kernel_width: int | Literal["patch"] = 16

    kernel_width: Union[int, Literal["patch"]] = 16
    first_kernel_expansion: int = -2

    channel_exp_base: int = 2
    first_channel_expansion: int = 1

    num_lcl_chunks: Union[None, int] = None

    rb_do: float = 0.10
    stochastic_depth_p: float = 0.00
    l1: float = 0.00

    cutoff: Union[int, Literal["auto"]] = 1024
    direction: Literal["down", "up"] = "down"
    attention_inclusion_cutoff: Optional[int] = None
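
# Note: the LCLModel class below does not read this config object; it
# hard-codes values matching the defaults above (kernel_width=16,
# first_kernel_expansion=-2, channel_exp_base=2, first_channel_expansion=1,
# rb_do=0.10, stochastic_depth_p=0.0, cutoff=1024, direction="down",
# attention_inclusion_cutoff=None, layers=None), so it is equivalent to
# building the network from LCLModelConfig() defaults.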


class LCLModel(nn.Module):
    def __init__(
        self,
        # model_config: Union[LCLModelConfig, "LCLOutputModelConfig"],
        # data_dimensions: "DataDimensions",
        data_dimensions,
        # flatten_fn: FlattenFunc,
        # dynamic_cutoff: Optional[int] = None,
    ):
        super().__init__()

        # self.model_config = model_config
        self.data_dimensions = data_dimensions
        # self.flatten_fn = flatten_fn

        # kernel_width = parse_kernel_width(
        #     # kernel_width=self.model_config.kernel_width,
        #     # patch_size=self.model_config.patch_size,
        #     kernel_width=16,
        #     patch_size=None,
        # )
        kernel_width = 16

        fc_0_kernel_size = calc_value_after_expansion(
            # base=kernel_width,
            # expansion=self.model_config.first_kernel_expansion,
            base=16,
            expansion=-2,
        )
        fc_0_channel_exponent = calc_value_after_expansion(
            # base=self.model_config.channel_exp_base,
            # expansion=self.model_config.first_channel_expansion,
            base=2,
            expansion=1,
        )
        self.fc_0 = LCL(
            in_features=self.fc_1_in_features,  # 1*9033*4 = 36132
            out_feature_sets=2**fc_0_channel_exponent,  # 4
            kernel_size=fc_0_kernel_size,  # 8
            bias=True,
        )

        # cutoff = dynamic_cutoff or self.model_config.cutoff
        cutoff = 1024
        assert isinstance(cutoff, int)

        lcl_parameter_spec = LCParameterSpec(
            # in_features=self.fc_0.out_features,
            # kernel_width=kernel_width,
            # channel_exp_base=self.model_config.channel_exp_base,
            # dropout_p=self.model_config.rb_do,
            # cutoff=cutoff,
            # stochastic_depth_p=self.model_config.stochastic_depth_p,
            # attention_inclusion_cutoff=self.model_config.attention_inclusion_cutoff,
            # direction=self.model_config.direction,
            in_features=self.fc_0.out_features,
            kernel_width=kernel_width,
            channel_exp_base=2,
            dropout_p=0.10,
            cutoff=cutoff,
            stochastic_depth_p=0.00,
            attention_inclusion_cutoff=None,
            direction="down",
        )
        self.lcl_blocks = _get_lcl_blocks(
            # lcl_spec=lcl_parameter_spec,
            # block_layer_spec=self.model_config.layers,
            lcl_spec=lcl_parameter_spec,
            block_layer_spec=None,
        )

        self._init_weights()

        # MLP residual
        self.full_preactivation = False
        in_features = self.lcl_blocks[-1].out_features
        # Cattle 84  Chicken 136  Chickpea 104  Cotton 8  Loblolly_Pine 152  Maize 148  Millet 656
        # Mouse 320  Pig 1016  Rapeseed 312  Rice 56  Soybean 104  Wheat 536
        out_features = 1
        if in_features == out_features:
            self.downsample_identity = lambda x: x
        else:
            self.downsample_identity = nn.Linear(
                in_features=in_features, out_features=out_features, bias=True
            )
        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
        self.fc_1 = nn.Linear(
            in_features=in_features, out_features=out_features, bias=True
        )

        self.act_1 = Swish()
        dropout_p = 0.0
        self.do = nn.Dropout(p=dropout_p)
        self.fc_2 = nn.Linear(
            in_features=out_features, out_features=out_features, bias=True
        )
        self.stochastic_depth_p = 0.0
        self.stochastic_depth = StochasticDepth(p=self.stochastic_depth_p, mode="batch")

    @property
    def fc_1_in_features(self) -> int:
        return self.data_dimensions.num_elements()

    @property
    def l1_penalized_weights(self) -> torch.Tensor:
        return self.fc_0.weight

    @property
    def num_out_features(self) -> int:
        return self.lcl_blocks[-1].out_features

    def _init_weights(self):
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # out = self.flatten_fn(x=input)  # original

        # by ww
        out = input  # by ww
        a = input.shape[0]
        out = out.reshape(input.shape[0], 1, -1)

        out = self.fc_0(out)
        out = self.lcl_blocks(out)  # (4, 72)

        # MLP Residual
        out = self.norm_1(out)

        identity = out if self.full_preactivation else out
        identity = self.downsample_identity(identity)

        out = self.fc_1(out)

        out = self.act_1(out)
        out = self.do(out)
        out = self.fc_2(out)

        out = self.stochastic_depth(out)

        return out + identity

        # return out

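# Usage sketch for LCLModel (illustrative only; ``dims`` stands for any object
# exposing a ``num_elements()`` method, matching the ``data_dimensions``
# argument above):
#
#     model = LCLModel(data_dimensions=dims)   # e.g. dims.num_elements() == 1 * 9033 * 4
#     preds = model(genotype_batch)            # forward() flattens each sample to (batch, 1, -1)
#
# The MLP residual head maps the final LCL features down to out_features = 1.
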
def parse_kernel_width(
    kernel_width: Union[int, Literal["patch"]],
    patch_size: Optional[tuple[int, int, int]],
) -> int:
    if kernel_width == "patch":
        if patch_size is None:
            raise ValueError(
                "kernel_width set to 'patch', but no patch_size was specified."
            )
        kernel_width = patch_size[0] * patch_size[1] * patch_size[2]
    return kernel_width


def flatten_h_w_fortran(x: torch.Tensor) -> torch.Tensor:
    """
    This is needed when e.g. flattening one-hot inputs that are ordered in a
    column-wise fashion (meaning that each column is a one-hot feature), and we
    want to make sure the first part of the flattened tensor is the first
    column, i.e. the first one-hot element.
    """
    column_order_flattened = x.transpose(2, 3).flatten(1)
    return column_order_flattened
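
# Example (illustrative): for x of shape (B, 1, 4, N), i.e. 4 one-hot rows and
# N SNP columns, transpose(2, 3) gives (B, 1, N, 4) and flatten(1) gives
# (B, N * 4), so the first 4 elements of each flattened row are the first SNP's
# one-hot vector.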


def calc_value_after_expansion(base: int, expansion: int, min_value: int = 0) -> int:
    if expansion > 0:
        return base * expansion
    elif expansion < 0:
        abs_expansion = abs(expansion)
        return max(min_value, base // abs_expansion)
    return base
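
# Worked examples (illustrative; these are the values LCLModel uses above):
#     calc_value_after_expansion(base=16, expansion=-2) -> 16 // 2 == 8
#     calc_value_after_expansion(base=2, expansion=1)   -> 2 * 1 == 2  (so 2**2 == 4 weight sets)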


@dataclass
class LCParameterSpec:
    in_features: int
    kernel_width: int
    channel_exp_base: int
    dropout_p: float
    stochastic_depth_p: float
    cutoff: int
    attention_inclusion_cutoff: Optional[int] = None
    direction: Literal["down", "up"] = "down"


def _get_lcl_blocks(
    lcl_spec: LCParameterSpec,
    block_layer_spec: Optional[Sequence[int]],
) -> nn.Sequential:
    factory = _get_lcl_block_factory(block_layer_spec=block_layer_spec)

    blocks = factory(lcl_spec)

    return blocks


def _get_lcl_block_factory(
    block_layer_spec: Optional[Sequence[int]],
) -> Callable[[LCParameterSpec], nn.Sequential]:
    if not block_layer_spec:
        return generate_lcl_residual_blocks_auto

    auto_factory = partial(
        _generate_lcl_blocks_from_spec, block_layer_spec=block_layer_spec
    )

    return auto_factory


def _generate_lcl_blocks_from_spec(
    lcl_parameter_spec: LCParameterSpec,
    block_layer_spec: List[int],
) -> nn.Sequential:
    s = lcl_parameter_spec
    block_layer_spec_copy = copy(block_layer_spec)

    first_block = LCLResidualBlock(
        in_features=s.in_features,
        kernel_size=s.kernel_width,
        out_feature_sets=2**s.channel_exp_base,
        dropout_p=s.dropout_p,
        full_preactivation=True,
    )

    block_modules = [first_block]
    block_layer_spec_copy[0] -= 1

    for cur_layer_index, block_dim in enumerate(block_layer_spec_copy):
        for block in range(block_dim):
            cur_out_feature_sets = 2 ** (s.channel_exp_base + cur_layer_index)
            cur_kernel_width = s.kernel_width

            cur_out_feature_sets, cur_kernel_width = _adjust_auto_params(
                cur_out_feature_sets=cur_out_feature_sets,
                cur_kernel_width=cur_kernel_width,
                direction=s.direction,
            )

            cur_size = block_modules[-1].out_features

            cur_block = LCLResidualBlock(
                in_features=cur_size,
                kernel_size=cur_kernel_width,
                out_feature_sets=cur_out_feature_sets,
                dropout_p=s.dropout_p,
                stochastic_depth_p=s.stochastic_depth_p,
            )

            block_modules.append(cur_block)

    return nn.Sequential(*block_modules)


def _adjust_auto_params(
    cur_out_feature_sets: int, cur_kernel_width: int, direction: Literal["down", "up"]
) -> tuple[int, int]:
    """
    Down: increase kernel width until it is larger than the number of output
    feature sets.
    Up: increase number of output feature sets until it is larger than the
    kernel width.
    """
    if direction == "down":
        while cur_out_feature_sets >= cur_kernel_width:
            cur_kernel_width *= 2
    elif direction == "up":
        while cur_out_feature_sets <= cur_kernel_width:
            cur_out_feature_sets *= 2
    else:
        raise ValueError(f"Unknown direction: {direction}")

    return cur_out_feature_sets, cur_kernel_width
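
# Example (illustrative): _adjust_auto_params(cur_out_feature_sets=32,
# cur_kernel_width=16, direction="down") doubles the kernel width until it
# exceeds the number of feature sets: 16 -> 32 -> 64, returning (32, 64).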


def generate_lcl_residual_blocks_auto(lcl_parameter_spec: LCParameterSpec):
    """
    TODO: Create some over-engineered abstraction for this and
    ``_generate_lcl_blocks_from_spec`` if feeling bored.
    """

    s = lcl_parameter_spec

    first_block = LCLResidualBlock(
        in_features=s.in_features,
        kernel_size=s.kernel_width,
        out_feature_sets=2**s.channel_exp_base,
        dropout_p=s.dropout_p,
        full_preactivation=True,
    )

    block_modules: list[Union[LCLResidualBlock, LCLAttentionBlock]]
    block_modules = [first_block]

    if _do_add_attention(
        attention_inclusion_cutoff=s.attention_inclusion_cutoff,
        in_features=first_block.out_features,
        embedding_dim=first_block.out_feature_sets,
    ):
        cur_attention_block = LCLAttentionBlock(
            embedding_dim=first_block.out_feature_sets,
            in_features=first_block.out_features,
        )
        block_modules.append(cur_attention_block)

    while True:
        cur_no_blocks = len(block_modules)
        cur_index = cur_no_blocks // 2

        cur_out_feature_sets = 2 ** (s.channel_exp_base + cur_index)
        cur_kernel_width = s.kernel_width
        cur_out_feature_sets, cur_kernel_width = _adjust_auto_params(
            cur_out_feature_sets=cur_out_feature_sets,
            cur_kernel_width=cur_kernel_width,
            direction=s.direction,
        )

        cur_size = block_modules[-1].out_features

        if _should_break_auto(
            cur_size=cur_size,
            cutoff=s.cutoff,
            direction=s.direction,
        ):
            break

        cur_block = LCLResidualBlock(
            in_features=cur_size,
            kernel_size=cur_kernel_width,
            out_feature_sets=cur_out_feature_sets,
            dropout_p=s.dropout_p,
            stochastic_depth_p=s.stochastic_depth_p,
        )

        block_modules.append(cur_block)

        if _do_add_attention(
            attention_inclusion_cutoff=s.attention_inclusion_cutoff,
            in_features=cur_block.out_features,
            embedding_dim=cur_block.out_feature_sets,
        ):
            cur_attention_block = LCLAttentionBlock(
                embedding_dim=cur_block.out_feature_sets,
                in_features=cur_block.out_features,
            )
            block_modules.append(cur_attention_block)

    # logger.debug(
    #     "No SplitLinear residual blocks specified in CL arguments. Created %d "
    #     "blocks with final output dimension of %d.",
    #     len(block_modules),
    #     cur_size,
    # )
    return nn.Sequential(*block_modules)
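
# The loop above keeps appending LCLResidualBlocks with
# 2 ** (channel_exp_base + len(block_modules) // 2) output feature sets,
# inserting an LCLAttentionBlock after a block whenever _do_add_attention
# allows it, and stops once the running out_features crosses the cutoff
# (_should_break_auto).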


def _should_break_auto(
    cur_size: int, cutoff: int, direction: Literal["up", "down"]
) -> bool:
    if direction == "down":
        return cur_size <= cutoff
    elif direction == "up":
        return cur_size >= cutoff
    else:
        raise ValueError(f"Unknown direction: {direction}")


class LCLAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        in_features: int,
        num_heads: Union[int, Literal["auto"]] = "auto",
        dropout_p: float = 0.0,
        num_layers: int = 2,
        dim_feedforward_factor: int = 4,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dim_feedforward_factor = dim_feedforward_factor
        self.in_features = in_features
        self.dropout_p = dropout_p
        self.num_layers = num_layers
        self.out_features = in_features

        if num_heads == "auto":
            # default to one attention head per embedding dimension
            self.num_heads = embedding_dim
        else:
            self.num_heads = num_heads

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=self.num_heads,
            dim_feedforward=self.embedding_dim * self.dim_feedforward_factor,
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=self.num_layers,
        )

        self.pos_emb = PositionalEmbedding(
            embedding_dim=self.embedding_dim,
            max_length=self.in_features // self.embedding_dim,
            dropout=self.dropout_p,
            zero_init=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x.reshape(x.shape[0], -1, self.embedding_dim)
        out = self.pos_emb(out)
        out = self.encoder(out)
        out = torch.flatten(input=out, start_dim=1)

        return out
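
# Shape flow through forward() (illustrative): (batch, in_features) is reshaped
# to (batch, in_features // embedding_dim, embedding_dim), passed through the
# positional embedding and transformer encoder, then flattened back to
# (batch, in_features), so out_features equals in_features.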


def _do_add_attention(
    in_features: int, embedding_dim: int, attention_inclusion_cutoff: Optional[int]
) -> bool:
    if attention_inclusion_cutoff is None:
        return False

    attention_sequence_length = in_features // embedding_dim
    return attention_sequence_length <= attention_inclusion_cutoff
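
# Example (illustrative, matching the LCLModelConfig docstring): with
# in_features=1024 and embedding_dim=4 the attention sequence length is
# 1024 // 4 = 256, so attention_inclusion_cutoff >= 256 adds the block, while
# attention_inclusion_cutoff=None (the value LCLModel uses) never does.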