gpbench-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
The following +235-line hunk matches EIR/utils/lcl_layers.py in the file listing above (the method_class and method_reg copies have identical line counts).

@@ -0,0 +1,235 @@
+import math
+from typing import Optional
+
+import torch
+from aislib.misc_utils import get_logger
+from aislib.pytorch_modules import Swish
+from torch import nn
+from torch.nn import Parameter
+from torch.nn import functional as F
+from torchvision.ops import StochasticDepth
+
+logger = get_logger(name=__name__)
+
+
+class LCL(nn.Module):
+    __constants__ = ["bias", "in_features", "out_features"]
+
+    def __init__(
+        self,
+        in_features: int,
+        out_feature_sets: int,
+        num_chunks: int = 10,
+        kernel_size: Optional[int] = None,  # 8
+        bias: Optional[bool] = True,
+    ):
+        super().__init__()
+
+        self.in_features = in_features  # 9033*4=36132
+        self.out_feature_sets = out_feature_sets  # 4
+        self.num_chunks = num_chunks  # 4517
+
+        if kernel_size:
+            self.kernel_size = kernel_size  # 8
+            self.num_chunks = int(math.ceil(in_features / kernel_size))
+            logger.debug(
+                "%s: Setting num chunks to %d as kernel size of %d was passed in.",
+                self.__class__,
+                self.num_chunks,
+                kernel_size,
+            )
+        else:
+            self.kernel_size = int(math.ceil(self.in_features / self.num_chunks))
+            logger.debug(
+                "%s: Setting kernel size to %d as number of "
+                "chunks of %d was passed in.",
+                self.__class__,
+                self.kernel_size,
+                self.num_chunks,
+            )
+
+        assert self.kernel_size is not None
+
+        self.out_features = self.out_feature_sets * self.num_chunks
+
+        self.padding = _find_lcl_padding_needed(
+            input_size=self.in_features,
+            kernel_size=self.kernel_size,
+            num_chunks=self.num_chunks,
+        )
+
+        self.weight = Parameter(
+            torch.Tensor(self.out_feature_sets, self.num_chunks, self.kernel_size),
+            requires_grad=True,
+        )
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(self.out_features), requires_grad=True)
+        else:
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """
+        NOTE: This default init actually works quite well, as compared to initializing
+        for each chunk (meaning higher weights at init). In that case, the model takes
+        longer to get to a good performance as it spends a while driving the weights
+        down.
+        """
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def extra_repr(self):
+        return (
+            "in_features={}, num_chunks={}, kernel_size={}, "
+            "out_feature_sets={}, out_features={}, bias={}".format(
+                self.in_features,
+                self.num_chunks,
+                self.kernel_size,
+                self.out_feature_sets,
+                self.out_features,
+                self.bias is not None,
+            )
+        )
+
+    def forward(self, input: torch.Tensor):
+        input_padded = F.pad(input=input, pad=[0, self.padding, 0, 0])
+
+        input_reshaped = input_padded.reshape(
+            input.shape[0], 1, self.num_chunks, self.kernel_size
+        )
+
+        out = calc_lcl_forward(input=input_reshaped, weight=self.weight, bias=self.bias)
+        return out
+
+
+def _find_lcl_padding_needed(input_size: int, kernel_size: int, num_chunks: int):
+    return num_chunks * kernel_size - input_size
+
+
+def calc_lcl_forward(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
+    """
+    n: num samples
+    h: num chunks (height)
+    w: kernel size (width)
+    o: output sets
+    """
+    summed = torch.einsum("nhw, ohw -> noh", input.squeeze(1), weight)
+    flattened = summed.flatten(start_dim=1)

+    final = flattened
+    if bias is not None:
+        final = flattened + bias
+
+    return final
+
+
+class LCLResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_feature_sets: int,
+        kernel_size: int,
+        dropout_p: float = 0.0,
+        stochastic_depth_p: float = 0.0,
+        full_preactivation: bool = False,
+        reduce_both: bool = True,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.kernel_size = kernel_size
+        self.out_feature_sets = out_feature_sets
+
+        self.dropout_p = dropout_p
+        self.full_preactivation = full_preactivation
+        self.reduce_both = reduce_both
+        self.stochastic_depth_p = stochastic_depth_p
+
+        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
+        self.fc_1 = LCL(
+            in_features=self.in_features,
+            out_feature_sets=self.out_feature_sets,
+            bias=True,
+            kernel_size=self.kernel_size,
+        )
+
+        self.act_1 = Swish()
+        self.do = nn.Dropout(p=dropout_p)
+
+        fc_2_kwargs = _get_lcl_2_kwargs(
+            in_features=self.fc_1.out_features,
+            out_feature_sets=self.out_feature_sets,
+            bias=True,
+            kernel_size=self.kernel_size,
+            reduce_both=self.reduce_both,
+        )
+        self.fc_2 = LCL(**fc_2_kwargs)
+
+        self.out_features = self.fc_2.out_features
+
+        if in_features == self.out_features:
+            self.downsample_identity = lambda x: x
+        else:
+            self.downsample_identity = LCL(
+                in_features=self.in_features,
+                out_feature_sets=1,
+                bias=True,
+                num_chunks=self.fc_2.out_features,
+            )
+
+        self.stochastic_depth = StochasticDepth(p=stochastic_depth_p, mode="batch")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.norm_1(x)
+
+        identity = out if self.full_preactivation else x
+        identity = self.downsample_identity(identity)
+
+        out = self.fc_1(out)
+
+        out = self.act_1(out)
+        out = self.do(out)
+        out = self.fc_2(out)
+
+        out = self.stochastic_depth(out)
+
+        return out + identity
+
+
+def _get_lcl_2_kwargs(
+    in_features: int,
+    out_feature_sets: int,
+    bias: bool,
+    reduce_both: bool,
+    kernel_size: int,
+):
+    common_kwargs = {
+        "in_features": in_features,
+        "out_feature_sets": out_feature_sets,
+        "bias": bias,
+    }
+
+    if reduce_both:
+        common_kwargs["kernel_size"] = kernel_size
+    else:
+        num_chunks = _calculate_num_chunks_for_equal_lcl_out_features(
+            in_features=in_features, out_feature_sets=out_feature_sets
+        )
+        common_kwargs["num_chunks"] = num_chunks
+
+    return common_kwargs
+
+
+def _calculate_num_chunks_for_equal_lcl_out_features(
+    in_features: int,
+    out_feature_sets: int,
+) -> int:
+    """
+    Ensure total out features are equal to in features.
+    """
+    return in_features // out_feature_sets
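The LCL layer splits its (zero-padded) input into num_chunks windows of kernel_size features and learns out_feature_sets independent weight vectors per window, so out_features = out_feature_sets * num_chunks. A minimal shape-check sketch of the class reconstructed above; the sizes are illustrative, not values used by the package, and it assumes the module's aislib/torchvision dependencies are importable:

    import torch

    # Illustrative sizes only: 16 input features, kernel windows of 8.
    layer = LCL(in_features=16, out_feature_sets=4, kernel_size=8)

    # kernel_size=8 over 16 features -> num_chunks = ceil(16 / 8) = 2,
    # so out_features = out_feature_sets * num_chunks = 4 * 2 = 8.
    x = torch.randn(2, 16)
    y = layer(x)  # einsum "nhw, ohw -> noh", then flattened per sample
    assert y.shape == (2, 8)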
The following +59-line hunk matches EIR/utils/logging.py in the file listing above (the method_class and method_reg copies have identical line counts).

@@ -0,0 +1,59 @@
+import copy
+import logging
+
+from termcolor import colored
+from tqdm import tqdm
+
+
+class TQDMLoggingHandler(logging.Handler):
+    def __init__(self, level=logging.NOTSET):
+        super().__init__(level)
+
+    def emit(self, record):
+        try:
+            msg = self.format(record)
+            tqdm.write(msg)
+            self.flush()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            self.handleError(record)
+
+
+class ColoredFormatter(logging.Formatter):
+    def format(self, record):
+        record_copy = copy.copy(record)
+
+        if record_copy.levelno == logging.DEBUG:
+            record_copy.levelname = colored(record_copy.levelname, "blue")
+        elif record_copy.levelno == logging.INFO:
+            record_copy.levelname = colored(record_copy.levelname, "green")
+        elif record_copy.levelno == logging.WARNING:
+            record_copy.levelname = colored(record_copy.levelname, "yellow")
+        elif record_copy.levelno == logging.ERROR:
+            record_copy.levelname = colored(record_copy.levelname, "red")
+
+        return super().format(record_copy)
+
+
+def get_logger(name: str, tqdm_compatible: bool = False) -> logging.Logger:
+    """
+    Creates a logger with a debug level and a custom format.
+
+    tqdm_compatible: Overwrite default stream.write in favor of tqdm.write
+    to avoid breaking the progress bar.
+    """
+    logger_ = logging.getLogger(name)
+    logger_.setLevel(logging.DEBUG)
+
+    handler: logging.Handler | TQDMLoggingHandler
+    handler = TQDMLoggingHandler() if tqdm_compatible else logging.StreamHandler()
+
+    formatter = ColoredFormatter(
+        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%H:%M:%S",
+    )
+    handler.setFormatter(formatter)
+
+    logger_.addHandler(handler)
+    return logger_
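A short usage sketch for the logging helpers reconstructed above (the logger name and loop are illustrative). With tqdm_compatible=True, records go through tqdm.write, so log lines do not mangle an active progress bar. Note that get_logger attaches a new handler on every call without checking logger_.handlers, so repeated calls with the same name will duplicate output:

    from tqdm import tqdm

    logger = get_logger(name="train", tqdm_compatible=True)

    for step in tqdm(range(100)):
        if step % 50 == 0:
            # Routed via tqdm.write; the bar is redrawn cleanly below the message.
            logger.info("step %d", step)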
The following +92-line hunk matches EIR/utils/mlp_layers.py in the file listing above (the method_class and method_reg copies have identical line counts).

@@ -0,0 +1,92 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+from aislib.pytorch_modules import Swish
+from torch import nn
+from torchvision.ops import StochasticDepth
+
+
+@dataclass
+class ResidualMLPConfig:
+    """
+    :param layers:
+        Number of residual MLP layers to use for each output predictor after fusing.
+
+    :param fc_task_dim:
+        Number of hidden nodes in each MLP residual block.
+
+    :param rb_do:
+        Dropout in each MLP residual block.
+
+    :param fc_do:
+        Dropout before the final layer.
+
+    :param stochastic_depth_p:
+        Probability of dropping the input.
+    """
+
+    layers: List[int] = field(default_factory=lambda: [2])
+
+    fc_task_dim: int = 256
+
+    rb_do: float = 0.10
+    fc_do: float = 0.10
+
+    stochastic_depth_p: float = 0.10
+
+
+class MLPResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        dropout_p: float = 0.0,
+        full_preactivation: bool = False,
+        stochastic_depth_p: float = 0.0,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dropout_p = dropout_p
+        self.full_preactivation = full_preactivation
+        self.stochastic_depth_p = stochastic_depth_p
+
+        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
+
+        self.fc_1 = nn.Linear(
+            in_features=in_features, out_features=out_features, bias=True
+        )
+
+        self.act_1 = Swish()
+        self.do = nn.Dropout(p=dropout_p)
+        self.fc_2 = nn.Linear(
+            in_features=out_features, out_features=out_features, bias=True
+        )
+
+        if in_features == out_features:
+            self.downsample_identity = lambda x: x
+        else:
+            self.downsample_identity = nn.Linear(
+                in_features=in_features, out_features=out_features, bias=True
+            )
+
+        self.stochastic_depth = StochasticDepth(p=self.stochastic_depth_p, mode="batch")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.norm_1(x)
+
+        identity = out if self.full_preactivation else x
+        identity = self.downsample_identity(identity)
+
+        out = self.fc_1(out)
+
+        out = self.act_1(out)
+        out = self.do(out)
+        out = self.fc_2(out)
+
+        out = self.stochastic_depth(out)
+
+        return out + identity
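MLPResidualBlock is a pre-activation residual unit: LayerNorm, Linear, Swish, Dropout, Linear, StochasticDepth, with the identity projected through an extra Linear whenever in_features != out_features. A minimal sketch with illustrative sizes (the 0.10 rates mirror the ResidualMLPConfig defaults; assumes the module's aislib/torchvision dependencies are importable):

    import torch

    block = MLPResidualBlock(
        in_features=128,
        out_features=256,  # != in_features, so the identity path is a Linear projection
        dropout_p=0.10,
        stochastic_depth_p=0.10,
    )

    block.eval()  # disable dropout / stochastic depth for a deterministic check
    x = torch.randn(4, 128)
    out = block(x)  # residual sum of the transformed branch and the projected identity
    assert out.shape == (4, 256)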