gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,235 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from aislib.misc_utils import get_logger
6
+ from aislib.pytorch_modules import Swish
7
+ from torch import nn
8
+ from torch.nn import Parameter
9
+ from torch.nn import functional as F
10
+ from torchvision.ops import StochasticDepth
11
+
12
+ logger = get_logger(name=__name__)
13
+
14
+
15
class LCL(nn.Module):
    """Locally connected layer (LCL).

    The (right-padded) input is split into ``num_chunks`` contiguous chunks
    of ``kernel_size`` features each. Every chunk has its own weights — no
    weight sharing, unlike a convolution — and produces ``out_feature_sets``
    scalar outputs, so the layer emits ``out_feature_sets * num_chunks``
    features in total.
    """

    __constants__ = ["bias", "in_features", "out_features"]

    def __init__(
        self,
        in_features: int,
        out_feature_sets: int,
        num_chunks: int = 10,
        kernel_size: Optional[int] = None,  # e.g. 8
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.in_features = in_features  # e.g. 9033*4=36132
        self.out_feature_sets = out_feature_sets  # e.g. 4
        self.num_chunks = num_chunks  # e.g. 4517

        # kernel_size and num_chunks are two views of the same partition of
        # the input: an explicit kernel_size takes precedence and num_chunks
        # is recomputed from it; otherwise kernel_size is derived from the
        # requested num_chunks. Both use ceil, so the last chunk may need
        # padding (computed below).
        if kernel_size:
            self.kernel_size = kernel_size  # 8
            self.num_chunks = int(math.ceil(in_features / kernel_size))
            logger.debug(
                "%s: Setting num chunks to %d as kernel size of %d was passed in.",
                self.__class__,
                self.num_chunks,
                kernel_size,
            )
        else:
            self.kernel_size = int(math.ceil(self.in_features / self.num_chunks))
            logger.debug(
                "%s :Setting kernel size to %d as number of "
                "chunks of %d was passed in.",
                self.__class__,
                self.kernel_size,
                self.num_chunks,
            )

        assert self.kernel_size is not None

        # Each chunk contributes one output per output set.
        self.out_features = self.out_feature_sets * self.num_chunks

        # Zero-padding appended on the right so the input divides evenly
        # into num_chunks chunks of kernel_size.
        self.padding = _find_lcl_padding_needed(
            input_size=self.in_features,
            kernel_size=self.kernel_size,
            num_chunks=self.num_chunks,
        )

        # One (num_chunks, kernel_size) weight slab per output set.
        self.weight = Parameter(
            torch.Tensor(self.out_feature_sets, self.num_chunks, self.kernel_size),
            requires_grad=True,
        )

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features), requires_grad=True)
        else:
            self.register_parameter("bias", None)
        # Must run after both weight and bias are registered.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weights (Kaiming uniform) and zero the bias.

        NOTE: This default init actually works quite well, as compared to
        initializing for each chunk (meaning higher weights at init). In that
        case, the model takes longer to get to a good performance as it
        spends a while driving the weights down.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def extra_repr(self):
        # Shown in repr(module); mirrors nn.Linear's convention.
        return (
            "in_features={}, num_chunks={}, kernel_size={}, "
            "out_feature_sets={}, out_features={}, bias={}".format(
                self.in_features,
                self.num_chunks,
                self.kernel_size,
                self.out_feature_sets,
                self.out_features,
                self.bias is not None,
            )
        )

    def forward(self, input: torch.Tensor):
        # Right-pad the feature dimension so it reshapes cleanly into
        # (batch, 1, num_chunks, kernel_size).
        input_padded = F.pad(input=input, pad=[0, self.padding, 0, 0])

        input_reshaped = input_padded.reshape(
            input.shape[0], 1, self.num_chunks, self.kernel_size
        )

        # Returns (batch, out_feature_sets * num_chunks).
        out = calc_lcl_forward(input=input_reshaped, weight=self.weight, bias=self.bias)
        return out
107
+
108
+
109
+ def _find_lcl_padding_needed(input_size: int, kernel_size: int, num_chunks: int):
110
+ return num_chunks * kernel_size - input_size
111
+
112
+
113
def calc_lcl_forward(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    """Apply locally connected weights to a chunked input.

    Einsum index legend:
        n: num samples
        h (c): num chunks (height)
        w (s): kernel size (width)
        o: output sets

    ``input`` is (n, 1, num_chunks, kernel_size), ``weight`` is
    (out_feature_sets, num_chunks, kernel_size); the result is
    (n, out_feature_sets * num_chunks), with ``bias`` added when present.
    """
    chunked = input.squeeze(1)
    # Independent per-chunk dot products, one per output set: (n, o, h).
    per_chunk = torch.einsum("nhw, ohw -> noh", chunked, weight)
    out = per_chunk.flatten(start_dim=1)
    if bias is not None:
        out = out + bias
    return out
129
+
130
+
131
class LCLResidualBlock(nn.Module):
    """Residual block built from two locally connected layers.

    Pre-activation layout: LayerNorm -> LCL -> Swish -> Dropout -> LCL, with
    a skip connection that is itself an LCL projection whenever the block
    changes the feature width, and stochastic depth applied to the branch.
    """

    def __init__(
        self,
        in_features: int,
        out_feature_sets: int,
        kernel_size: int,
        dropout_p: float = 0.0,
        stochastic_depth_p: float = 0.0,
        full_preactivation: bool = False,
        reduce_both: bool = True,
    ):
        super().__init__()

        self.in_features = in_features
        self.kernel_size = kernel_size
        self.out_feature_sets = out_feature_sets

        self.dropout_p = dropout_p
        self.full_preactivation = full_preactivation
        self.reduce_both = reduce_both
        self.stochastic_depth_p = stochastic_depth_p

        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
        self.fc_1 = LCL(
            in_features=self.in_features,
            out_feature_sets=self.out_feature_sets,
            bias=True,
            kernel_size=self.kernel_size,
        )

        self.act_1 = Swish()
        self.do = nn.Dropout(p=dropout_p)

        # Second LCL either keeps reducing with the same kernel size
        # (reduce_both=True) or is sized so its output width matches its
        # input width — see _get_lcl_2_kwargs.
        fc_2_kwargs = _get_lcl_2_kwargs(
            in_features=self.fc_1.out_features,
            out_feature_sets=self.out_feature_sets,
            bias=True,
            kernel_size=self.kernel_size,
            reduce_both=self.reduce_both,
        )
        self.fc_2 = LCL(**fc_2_kwargs)

        self.out_features = self.fc_2.out_features

        # Project the identity path only when the block changes width;
        # num_chunks with out_feature_sets=1 makes the projection emit
        # exactly fc_2.out_features features.
        if in_features == self.out_features:
            self.downsample_identity = lambda x: x
        else:
            self.downsample_identity = LCL(
                in_features=self.in_features,
                out_feature_sets=1,
                bias=True,
                num_chunks=self.fc_2.out_features,
            )

        # Randomly drops the whole residual branch per batch during training
        # (torchvision StochasticDepth, mode="batch").
        self.stochastic_depth = StochasticDepth(p=stochastic_depth_p, mode="batch")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.norm_1(x)

        # Full pre-activation feeds the normed tensor to the skip path too.
        identity = out if self.full_preactivation else x
        identity = self.downsample_identity(identity)

        out = self.fc_1(out)

        out = self.act_1(out)
        out = self.do(out)
        out = self.fc_2(out)

        out = self.stochastic_depth(out)

        return out + identity
202
+
203
+
204
+ def _get_lcl_2_kwargs(
205
+ in_features: int,
206
+ out_feature_sets: int,
207
+ bias: bool,
208
+ reduce_both: bool,
209
+ kernel_size: int,
210
+ ):
211
+ common_kwargs = {
212
+ "in_features": in_features,
213
+ "out_feature_sets": out_feature_sets,
214
+ "bias": bias,
215
+ }
216
+
217
+ if reduce_both:
218
+ common_kwargs["kernel_size"] = kernel_size
219
+ else:
220
+ num_chunks = _calculate_num_chunks_for_equal_lcl_out_features(
221
+ in_features=in_features, out_feature_sets=out_feature_sets
222
+ )
223
+ common_kwargs["num_chunks"] = num_chunks
224
+
225
+ return common_kwargs
226
+
227
+
228
+ def _calculate_num_chunks_for_equal_lcl_out_features(
229
+ in_features: int,
230
+ out_feature_sets: int,
231
+ ) -> int:
232
+ """
233
+ Ensure total out features are equal to in features.
234
+ """
235
+ return in_features // out_feature_sets
@@ -0,0 +1,59 @@
1
+ import copy
2
+ import logging
3
+
4
+ from termcolor import colored
5
+ from tqdm import tqdm
6
+
7
+
8
class TQDMLoggingHandler(logging.Handler):
    """Logging handler that emits records via ``tqdm.write`` so log lines do
    not corrupt an active tqdm progress bar."""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            tqdm.write(self.format(record))
            self.flush()
        except (KeyboardInterrupt, SystemExit):
            # Never swallow interpreter-shutdown signals.
            raise
        except Exception:
            self.handleError(record)
21
+
22
+
23
class ColoredFormatter(logging.Formatter):
    """Formatter that colorizes the level name by severity.

    The record is shallow-copied before mutation so other handlers sharing
    the same LogRecord see the original, uncolored level name.
    """

    _LEVEL_COLORS = {
        logging.DEBUG: "blue",
        logging.INFO: "green",
        logging.WARNING: "yellow",
        logging.ERROR: "red",
    }

    def format(self, record):
        record_copy = copy.copy(record)

        color = self._LEVEL_COLORS.get(record_copy.levelno)
        if color is not None:
            record_copy.levelname = colored(record_copy.levelname, color)

        return super().format(record_copy)
37
+
38
+
39
def get_logger(name: str, tqdm_compatible: bool = False) -> logging.Logger:
    """Create (or fetch) a DEBUG-level logger with a colored formatter.

    tqdm_compatible: Overwrite default stream.write in favor of tqdm.write
    to avoid breaking progress bar.

    :param name: Logger name, forwarded to ``logging.getLogger``.
    :param tqdm_compatible: Use a TQDMLoggingHandler instead of a plain
        StreamHandler.
    :returns: The configured logger.
    """
    logger_ = logging.getLogger(name)
    logger_.setLevel(logging.DEBUG)

    # logging.getLogger returns the same instance for a given name, so the
    # original unconditional addHandler attached one extra handler per call
    # and every log line was emitted once per call site. Configure only once.
    if logger_.handlers:
        return logger_

    handler: logging.Handler | TQDMLoggingHandler
    handler = TQDMLoggingHandler() if tqdm_compatible else logging.StreamHandler()

    formatter = ColoredFormatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    handler.setFormatter(formatter)

    logger_.addHandler(handler)
    return logger_
@@ -0,0 +1,92 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ from aislib.pytorch_modules import Swish
6
+ from torch import nn
7
+ from torchvision.ops import StochasticDepth
8
+
9
+
10
@dataclass
class ResidualMLPConfig:
    """Configuration for a stack of residual MLP blocks.

    Attributes:
        layers: Number of residual MLP layers to use in for each output
            predictor after fusing.
        fc_task_dim: Number of hidden nodes in each MLP residual block.
        rb_do: Dropout in each MLP residual block.
        fc_do: Dropout before final layer.
        stochastic_depth_p: Probability of dropping input.
    """

    # default_factory keeps the mutable default list per-instance.
    layers: List[int] = field(default_factory=lambda: [2])
    fc_task_dim: int = 256
    rb_do: float = 0.10
    fc_do: float = 0.10
    stochastic_depth_p: float = 0.10
38
+
39
+
40
class MLPResidualBlock(nn.Module):
    """Pre-activation residual MLP block.

    Branch: LayerNorm -> Linear -> Swish -> Dropout -> Linear, with
    stochastic depth on the branch and a skip connection that is linearly
    projected whenever the block changes the feature width.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout_p: float = 0.0,
        full_preactivation: bool = False,
        stochastic_depth_p: float = 0.0,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.dropout_p = dropout_p
        self.full_preactivation = full_preactivation
        self.stochastic_depth_p = stochastic_depth_p

        self.norm_1 = nn.LayerNorm(normalized_shape=in_features)
        self.fc_1 = nn.Linear(
            in_features=in_features, out_features=out_features, bias=True
        )
        self.act_1 = Swish()
        self.do = nn.Dropout(p=dropout_p)
        self.fc_2 = nn.Linear(
            in_features=out_features, out_features=out_features, bias=True
        )

        # Project the skip path only when the width changes.
        if in_features == out_features:
            self.downsample_identity = lambda x: x
        else:
            self.downsample_identity = nn.Linear(
                in_features=in_features, out_features=out_features, bias=True
            )

        self.stochastic_depth = StochasticDepth(p=self.stochastic_depth_p, mode="batch")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm_1(x)

        # With full pre-activation the skip path also sees the normed input.
        skip = normed if self.full_preactivation else x
        skip = self.downsample_identity(skip)

        branch = self.fc_2(self.do(self.act_1(self.fc_1(normed))))
        branch = self.stochastic_depth(branch)

        return branch + skip