autogluon.tabular 1.3.2b20250709__py3-none-any.whl → 1.3.2b20250710__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- autogluon/tabular/models/__init__.py +3 -0
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +2 -2
- autogluon/tabular/models/catboost/catboost_utils.py +7 -3
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +3 -3
- autogluon/tabular/models/lgb/lgb_model.py +2 -2
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +347 -0
- autogluon/tabular/models/rf/rf_model.py +2 -1
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +174 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +544 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +807 -0
- autogluon/tabular/models/tabm/tabm_model.py +275 -0
- autogluon/tabular/models/tabm/tabm_reference.py +627 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +3 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -3
- autogluon/tabular/models/xgboost/xgboost_model.py +2 -2
- autogluon/tabular/predictor/predictor.py +5 -3
- autogluon/tabular/registry/_ag_model_registry.py +6 -0
- autogluon/tabular/testing/fit_helper.py +27 -25
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +1 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/METADATA +21 -13
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/RECORD +35 -26
- /autogluon.tabular-1.3.2b20250709-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250710-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/zip-safe +0 -0
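
The notable addition in this nightly is three new model families (RealMLP, TabICL, TabM) plus their registration in _ag_model_registry.py and the trainer presets. As a rough sketch (not taken from this diff), they could be requested through the existing TabularPredictor API roughly as below; the registry keys "REALMLP", "TABICL", and "TABM" are assumptions inferred from the file names and may not match the names actually registered in this build. The full diff of the new TabM reference implementation follows.

    from autogluon.tabular import TabularPredictor

    # Hypothetical model keys for the newly added families; verify against
    # _ag_model_registry.py in this build before relying on them.
    hyperparameters = {
        "REALMLP": {},
        "TABICL": {},
        "TABM": {},
    }
    predictor = TabularPredictor(label="target").fit(
        train_data,  # any tabular DataFrame containing a "target" column
        hyperparameters=hyperparameters,
    )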
autogluon/tabular/models/tabm/tabm_reference.py (new file)
@@ -0,0 +1,627 @@
+# License: https://github.com/yandex-research/tabm/blob/main/LICENSE
+
+# NOTE
+# The minimum required versions of the dependencies are specified in README.md.
+
+import itertools
+from typing import Any, Literal
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from . import rtdl_num_embeddings
+from .rtdl_num_embeddings import _Periodic
+
+
+# ======================================================================================
+# Initialization
+# ======================================================================================
+def init_rsqrt_uniform_(x: Tensor, d: int) -> Tensor:
+    assert d > 0
+    d_rsqrt = d**-0.5
+    return nn.init.uniform_(x, -d_rsqrt, d_rsqrt)
+
+
+@torch.inference_mode()
+def init_random_signs_(x: Tensor) -> Tensor:
+    return x.bernoulli_(0.5).mul_(2).add_(-1)
+
+
+# ======================================================================================
+# Modules
+# ======================================================================================
+class NLinear(nn.Module):
+    """N linear layers applied in parallel to N disjoint parts of the input.
+
+    **Shape**
+
+    - Input: ``(B, N, in_features)``
+    - Output: ``(B, N, out_features)``
+
+    The i-th linear layer is applied to the i-th matrix of the shape (B, in_features).
+
+    Technically, this is a simplified version of delu.nn.NLinear:
+    https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html.
+    The difference is that this layer supports only 3D inputs
+    with exactly one batch dimension. By contrast, delu.nn.NLinear supports
+    any number of batch dimensions.
+    """
+
+    def __init__(
+        self, n: int, in_features: int, out_features: int, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(n, in_features, out_features))
+        self.bias = nn.Parameter(torch.empty(n, out_features)) if bias else None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        d = self.weight.shape[-2]
+        init_rsqrt_uniform_(self.weight, d)
+        if self.bias is not None:
+            init_rsqrt_uniform_(self.bias, d)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 3
+        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]
+
+        x = x.transpose(0, 1)
+        x = x @ self.weight
+        x = x.transpose(0, 1)
+        if self.bias is not None:
+            x = x + self.bias
+        return x
+
+
+class OneHotEncoding0d(nn.Module):
+    # Input: (*, n_cat_features=len(cardinalities))
+    # Output: (*, sum(cardinalities))
+
+    def __init__(self, cardinalities: list[int]) -> None:
+        super().__init__()
+        self._cardinalities = cardinalities
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.ndim >= 1
+        assert x.shape[-1] == len(self._cardinalities)
+
+        return torch.cat(
+            [
+                # NOTE
+                # This is a quick hack to support out-of-vocabulary categories.
+                #
+                # Recall that lib.data.transform_cat encodes categorical features
+                # as follows:
+                # - In-vocabulary values receive indices from `range(cardinality)`.
+                # - All out-of-vocabulary values (i.e. new categories in validation
+                #   and test data that are not presented in the training data)
+                #   receive the index `cardinality`.
+                #
+                # As such, the line below will produce the standard one-hot encoding for
+                # known categories, and the all-zeros encoding for unknown categories.
+                # This may not be the best approach to deal with unknown values,
+                # but should be enough for our purposes.
+                nn.functional.one_hot(x[..., i], cardinality + 1)[..., :-1]
+                for i, cardinality in enumerate(self._cardinalities)
+            ],
+            -1,
+        )
+
+
+class ScaleEnsemble(nn.Module):
+    def __init__(
+        self,
+        k: int,
+        d: int,
+        *,
+        init: Literal['ones', 'normal', 'random-signs'],
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(k, d))
+        self._weight_init = init
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self._weight_init == 'ones':
+            nn.init.ones_(self.weight)
+        elif self._weight_init == 'normal':
+            nn.init.normal_(self.weight)
+        elif self._weight_init == 'random-signs':
+            init_random_signs_(self.weight)
+        else:
+            raise ValueError(f'Unknown weight_init: {self._weight_init}')
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.ndim >= 2
+        return x * self.weight
+
+
+class LinearEfficientEnsemble(nn.Module):
+    """
+    This layer is a more configurable version of the "BatchEnsemble" layer
+    from the paper
+    "BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning"
+    (link: https://arxiv.org/abs/2002.06715).
+
+    First, this layer allows to select only some of the "ensembled" parts:
+    - the input scaling  (r_i in the BatchEnsemble paper)
+    - the output scaling (s_i in the BatchEnsemble paper)
+    - the output bias    (not mentioned in the BatchEnsemble paper,
+                          but is presented in public implementations)
+
+    Second, the initialization of the scaling weights is configurable
+    through the `scaling_init` argument.
+
+    NOTE
+    The term "adapter" is used in the TabM paper only to tell the story.
+    The original BatchEnsemble paper does NOT use this term. So this class also
+    avoids the term "adapter".
+    """
+
+    r: None | Tensor
+    s: None | Tensor
+    bias: None | Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        *,
+        k: int,
+        ensemble_scaling_in: bool,
+        ensemble_scaling_out: bool,
+        ensemble_bias: bool,
+        scaling_init: Literal['ones', 'random-signs'],
+    ):
+        assert k > 0
+        if ensemble_bias:
+            assert bias
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.empty(out_features, in_features))
+        self.register_parameter(
+            'r',
+            (
+                nn.Parameter(torch.empty(k, in_features))
+                if ensemble_scaling_in
+                else None
+            ),  # type: ignore[code]
+        )
+        self.register_parameter(
+            's',
+            (
+                nn.Parameter(torch.empty(k, out_features))
+                if ensemble_scaling_out
+                else None
+            ),  # type: ignore[code]
+        )
+        self.register_parameter(
+            'bias',
+            (
+                nn.Parameter(torch.empty(out_features))  # type: ignore[code]
+                if bias and not ensemble_bias
+                else nn.Parameter(torch.empty(k, out_features))
+                if ensemble_bias
+                else None
+            ),
+        )
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.k = k
+        self.scaling_init = scaling_init
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init_rsqrt_uniform_(self.weight, self.in_features)
+        scaling_init_fn = {'ones': nn.init.ones_, 'random-signs': init_random_signs_}[
+            self.scaling_init
+        ]
+        if self.r is not None:
+            scaling_init_fn(self.r)
+        if self.s is not None:
+            scaling_init_fn(self.s)
+        if self.bias is not None:
+            bias_init = torch.empty(
+                # NOTE: the shape of bias_init is (out_features,) not (k, out_features).
+                # It means that all biases have the same initialization.
+                # This is similar to having one shared bias plus
+                # k zero-initialized non-shared biases.
+                self.out_features,
+                dtype=self.weight.dtype,
+                device=self.weight.device,
+            )
+            bias_init = init_rsqrt_uniform_(bias_init, self.in_features)
+            with torch.inference_mode():
+                self.bias.copy_(bias_init)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x.shape == (B, K, D)
+        assert x.ndim == 3
+
+        # >>> The equation (5) from the BatchEnsemble paper (arXiv v2).
+        if self.r is not None:
+            x = x * self.r
+        x = x @ self.weight.T
+        if self.s is not None:
+            x = x * self.s
+        # <<<
+
+        if self.bias is not None:
+            x = x + self.bias
+        return x
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        *,
+        d_in: None | int = None,
+        d_out: None | int = None,
+        n_blocks: int,
+        d_block: int,
+        dropout: float,
+        activation: str = 'ReLU',
+    ) -> None:
+        super().__init__()
+
+        d_first = d_block if d_in is None else d_in
+        self.blocks = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(d_first if i == 0 else d_block, d_block),
+                    getattr(nn, activation)(),
+                    nn.Dropout(dropout),
+                )
+                for i in range(n_blocks)
+            ]
+        )
+        self.output = None if d_out is None else nn.Linear(d_block, d_out)
+
+    def forward(self, x: Tensor) -> Tensor:
+        for block in self.blocks:
+            x = block(x)
+        if self.output is not None:
+            x = self.output(x)
+        return x
+
+
+def make_efficient_ensemble(module: nn.Module, EnsembleLayer, **kwargs) -> None:
+    """Replace linear layers with efficient ensembles of linear layers.
+
+    NOTE
+    In the paper, there are no experiments with networks with normalization layers.
+    Perhaps, their trainable weights (the affine transformations) also need
+    "ensemblification" as in the paper about "FiLM-Ensemble".
+    Additional experiments are required to make conclusions.
+    """
+    for name, submodule in list(module.named_children()):
+        if isinstance(submodule, nn.Linear):
+            module.add_module(
+                name,
+                EnsembleLayer(
+                    in_features=submodule.in_features,
+                    out_features=submodule.out_features,
+                    bias=submodule.bias is not None,
+                    **kwargs,
+                ),
+            )
+        else:
+            make_efficient_ensemble(submodule, EnsembleLayer, **kwargs)
+
+
+def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
+    if isinstance(backbone, MLP):
+        return backbone.blocks[0][0]  # type: ignore[code]
+    else:
+        raise RuntimeError(f'Unsupported backbone: {backbone}')
+
+
+@torch.inference_mode()
+def _init_first_adapter(
+    weight: Tensor,
+    distribution: Literal['normal', 'random-signs'],
+    init_sections: list[int],
+) -> None:
+    """Initialize the first adapter.
+
+    NOTE
+    The `init_sections` argument is a historical artifact that accidentally leaked
+    from irrelevant experiments to the final models. Perhaps, the code related
+    to `init_sections` can be simply removed, but this was not tested.
+    """
+    assert weight.ndim == 2
+    assert weight.shape[1] == sum(init_sections)
+
+    if distribution == 'normal':
+        init_fn_ = nn.init.normal_
+    elif distribution == 'random-signs':
+        init_fn_ = init_random_signs_
+    else:
+        raise ValueError(f'Unknown distribution: {distribution}')
+
+    section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()]
+    for i in range(len(init_sections)):
+        # NOTE
+        # As noted above, this section-based initialization is an arbitrary historical
+        # artifact. Consider the first adapter of one ensemble member.
+        # This adapter vector is implicitly split into "sections",
+        # where one section corresponds to one feature. The code below ensures that
+        # the adapter weights in one section are initialized with the same random value
+        # from the given distribution.
+        w = torch.empty((len(weight), 1), dtype=weight.dtype, device=weight.device)
+        init_fn_(w)
+        weight[:, section_bounds[i] : section_bounds[i + 1]] = w
+
+
+_CUSTOM_MODULES = {
+    # https://docs.python.org/3/library/stdtypes.html#definition.__name__
+    CustomModule.__name__: CustomModule
+    for CustomModule in [
+        rtdl_num_embeddings.LinearEmbeddings,
+        rtdl_num_embeddings.LinearReLUEmbeddings,
+        rtdl_num_embeddings.PeriodicEmbeddings,
+        rtdl_num_embeddings.PiecewiseLinearEmbeddings,
+        MLP,
+    ]
+}
+
+
+def make_module(type: str, *args, **kwargs) -> nn.Module:
+    Module = getattr(nn, type, None)
+    if Module is None:
+        Module = _CUSTOM_MODULES[type]
+    return Module(*args, **kwargs)
+
+
+# ======================================================================================
+# Optimization
+# ======================================================================================
+def default_zero_weight_decay_condition(
+    module_name: str, module: nn.Module, parameter_name: str, parameter: nn.Parameter
+):
+    del module_name, parameter
+    return parameter_name.endswith('bias') or isinstance(
+        module,
+        nn.BatchNorm1d
+        | nn.LayerNorm
+        | nn.InstanceNorm1d
+        | rtdl_num_embeddings.LinearEmbeddings
+        | rtdl_num_embeddings.LinearReLUEmbeddings
+        | _Periodic,
+    )
+
+
+def make_parameter_groups(
+    module: nn.Module,
+    zero_weight_decay_condition=default_zero_weight_decay_condition,
+    custom_groups: None | list[dict[str, Any]] = None,
+) -> list[dict[str, Any]]:
+    if custom_groups is None:
+        custom_groups = []
+    custom_params = frozenset(
+        itertools.chain.from_iterable(group['params'] for group in custom_groups)
+    )
+    assert len(custom_params) == sum(
+        len(group['params']) for group in custom_groups
+    ), 'Parameters in custom_groups must not intersect'
+    zero_wd_params = frozenset(
+        p
+        for mn, m in module.named_modules()
+        for pn, p in m.named_parameters()
+        if p not in custom_params and zero_weight_decay_condition(mn, m, pn, p)
+    )
+    default_group = {
+        'params': [
+            p
+            for p in module.parameters()
+            if p not in custom_params and p not in zero_wd_params
+        ]
+    }
+    return [
+        default_group,
+        {'params': list(zero_wd_params), 'weight_decay': 0.0},
+        *custom_groups,
+    ]
+
+
+# ======================================================================================
+# The model
+# ======================================================================================
+class Model(nn.Module):
+    """MLP & TabM."""
+
+    def __init__(
+        self,
+        *,
+        n_num_features: int,
+        cat_cardinalities: list[int],
+        n_classes: None | int,
+        backbone: dict,
+        bins: None | list[Tensor],  # For piecewise-linear encoding/embeddings.
+        num_embeddings: None | dict = None,
+        arch_type: Literal[
+            # Plain feed-forward network without any kind of ensembling.
+            'plain',
+            #
+            # TabM
+            'tabm',
+            #
+            # TabM-mini
+            'tabm-mini',
+            #
+            # TabM-packed
+            'tabm-packed',
+            #
+            # TabM. The first adapter is initialized from the normal distribution.
+            # This variant was not used in the paper, but it may be useful in practice.
+            'tabm-normal',
+            #
+            # TabM-mini. The adapter is initialized from the normal distribution.
+            # This variant was not used in the paper.
+            'tabm-mini-normal',
+        ],
+        k: None | int = None,
+        share_training_batches: bool = True,
+    ) -> None:
+        # >>> Validate arguments.
+        assert n_num_features >= 0
+        assert n_num_features or cat_cardinalities
+        if arch_type == 'plain':
+            assert k is None
+            assert (
+                share_training_batches
+            ), 'If `arch_type` is set to "plain", then `simple` must remain True'
+        else:
+            assert k is not None
+            assert k > 0
+
+        super().__init__()
+
+        # >>> Continuous (numerical) features
+        first_adapter_sections = []  # See the comment in `_init_first_adapter`.
+
+        if n_num_features == 0:
+            assert bins is None
+            self.num_module = None
+            d_num = 0
+
+        elif num_embeddings is None:
+            assert bins is None
+            self.num_module = None
+            d_num = n_num_features
+            first_adapter_sections.extend(1 for _ in range(n_num_features))
+
+        else:
+            if bins is None:
+                self.num_module = make_module(
+                    **num_embeddings, n_features=n_num_features
+                )
+            else:
+                assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
+                self.num_module = make_module(**num_embeddings, bins=bins)
+            d_num = n_num_features * num_embeddings['d_embedding']
+            first_adapter_sections.extend(
+                num_embeddings['d_embedding'] for _ in range(n_num_features)
+            )
+
+        # >>> Categorical features
+        self.cat_module = (
+            OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
+        )
+        first_adapter_sections.extend(cat_cardinalities)
+        d_cat = sum(cat_cardinalities)
+
+        # >>> Backbone
+        d_flat = d_num + d_cat
+        self.minimal_ensemble_adapter = None
+        # Any backbone can be here but we provide only MLP
+        self.backbone = make_module(d_in=d_flat, **backbone)
+
+        if arch_type != 'plain':
+            assert k is not None
+            first_adapter_init = (
+                None
+                if arch_type == 'tabm-packed'
+                else 'normal'
+                if arch_type in ('tabm-mini-normal', 'tabm-normal')
+                # For other arch_types, the initialization depends
+                # on the presence of num_embeddings.
+                else 'random-signs'
+                if num_embeddings is None
+                else 'normal'
+            )
+
+            if arch_type in ('tabm', 'tabm-normal'):
+                # Like BatchEnsemble, but all multiplicative adapters,
+                # except for the very first one, are initialized with ones.
+                assert first_adapter_init is not None
+                make_efficient_ensemble(
+                    self.backbone,
+                    LinearEfficientEnsemble,
+                    k=k,
+                    ensemble_scaling_in=True,
+                    ensemble_scaling_out=True,
+                    ensemble_bias=True,
+                    scaling_init='ones',
+                )
+                _init_first_adapter(
+                    _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
+                    first_adapter_init,
+                    first_adapter_sections,
+                )
+
+            elif arch_type in ('tabm-mini', 'tabm-mini-normal'):
+                # MiniEnsemble
+                assert first_adapter_init is not None
+                self.minimal_ensemble_adapter = ScaleEnsemble(
+                    k,
+                    d_flat,
+                    init='random-signs' if num_embeddings is None else 'normal',
+                )
+                _init_first_adapter(
+                    self.minimal_ensemble_adapter.weight,  # type: ignore[code]
+                    first_adapter_init,
+                    first_adapter_sections,
+                )
+
+            elif arch_type == 'tabm-packed':
+                # Packed ensemble.
+                # In terms of the Packed Ensembles paper by Laurent et al.,
+                # TabM-packed is PackedEnsemble(alpha=k, M=k, gamma=1).
+                assert first_adapter_init is None
+                make_efficient_ensemble(self.backbone, NLinear, n=k)
+
+            else:
+                raise ValueError(f'Unknown arch_type: {arch_type}')
+
+        # >>> Output
+        d_block = backbone['d_block']
+        d_out = 1 if n_classes is None else n_classes
+        self.output = (
+            nn.Linear(d_block, d_out)
+            if arch_type == 'plain'
+            else NLinear(k, d_block, d_out)  # type: ignore[code]
+        )
+
+        # >>>
+        self.arch_type = arch_type
+        self.k = k
+        self.share_training_batches = share_training_batches
+
+    def forward(
+        self, x_num: None | Tensor = None, x_cat: None | Tensor = None
+    ) -> Tensor:
+        x = []
+        if x_num is not None:
+            x.append(x_num if self.num_module is None else self.num_module(x_num))
+        if x_cat is None:
+            assert self.cat_module is None
+        else:
+            assert self.cat_module is not None
+            x.append(self.cat_module(x_cat).float())
+        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
+
+        if self.k is not None:
+            if self.share_training_batches or not self.training:
+                # (B, D) -> (B, K, D)
+                x = x[:, None].expand(-1, self.k, -1)
+            else:
+                # (B * K, D) -> (B, K, D)
+                x = x.reshape(len(x) // self.k, self.k, *x.shape[1:])
+            if self.minimal_ensemble_adapter is not None:
+                x = self.minimal_ensemble_adapter(x)
+        else:
+            assert self.minimal_ensemble_adapter is None
+
+        x = self.backbone(x)
+        x = self.output(x)
+        if self.k is None:
+            # Adjust the output shape for plain networks to make them compatible
+            # with the rest of the script (loss, metrics, predictions, ...).
+            # (B, D_OUT) -> (B, 1, D_OUT)
+            x = x[:, None]
+        return x
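
For orientation, here is a minimal sketch of how the reference Model above could be exercised on its own. It assumes the module is importable from the installed wheel as autogluon.tabular.models.tabm.tabm_reference; the hyperparameter values are arbitrary illustrations, not AutoGluon defaults.

    import torch
    from autogluon.tabular.models.tabm.tabm_reference import Model, make_parameter_groups

    model = Model(
        n_num_features=4,
        cat_cardinalities=[3, 5],      # two categorical features
        n_classes=2,
        backbone={'type': 'MLP', 'n_blocks': 2, 'd_block': 64, 'dropout': 0.1},
        bins=None,                      # no piecewise-linear embeddings
        num_embeddings=None,
        arch_type='tabm',               # BatchEnsemble-style efficient ensemble
        k=8,                            # number of ensemble members
    )
    optimizer = torch.optim.AdamW(
        make_parameter_groups(model), lr=2e-3, weight_decay=3e-4
    )

    x_num = torch.randn(32, 4)
    x_cat = torch.stack([torch.randint(0, c, (32,)) for c in (3, 5)], dim=1)
    logits = model(x_num, x_cat)        # shape (32, 8, 2): one prediction per member
    probs = logits.softmax(-1).mean(1)  # aggregate the k members at inference time

With arch_type='tabm' the forward pass returns one set of logits per ensemble member, which is why the output carries the extra k dimension that the caller aggregates.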
autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py
@@ -311,11 +311,11 @@ class TabPFNMixModel(AbstractModel):
 
     def _get_maximum_resources(self) -> dict[str, int | float]:
         # torch model trains slower when utilizing virtual cores and this issue scale up when the number of cpu cores increases
-        return {"num_cpus": ResourceManager.
+        return {"num_cpus": ResourceManager.get_cpu_count(only_physical_cores=True)}
 
     def _get_default_resources(self) -> tuple[int, float]:
-        #
-        num_cpus = ResourceManager.
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py
@@ -814,11 +814,11 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
 
     def _get_maximum_resources(self) -> Dict[str, Union[int, float]]:
         # torch model trains slower when utilizing virtual cores and this issue scale up when the number of cpu cores increases
-        return {"num_cpus": ResourceManager.
+        return {"num_cpus": ResourceManager.get_cpu_count(only_physical_cores=True)}
 
     def _get_default_resources(self):
-        #
-        num_cpus = ResourceManager.
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
autogluon/tabular/models/xgboost/xgboost_model.py
@@ -310,8 +310,8 @@ class XGBoostModel(AbstractModel):
 
     @disable_if_lite_mode(ret=(1, 0))
     def _get_default_resources(self):
-        #
-        num_cpus = ResourceManager.
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
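
The hunks above (TabPFNMix, the torch tabular NN, and XGBoost) all apply the same change: default and maximum CPU counts are now derived from physical cores only, since these models train slower when spread across virtual (hyper-threaded) cores, per the comments in the diff. A minimal sketch of the call, assuming ResourceManager is importable from autogluon.common.utils.resource_utils as elsewhere in AutoGluon:

    from autogluon.common.utils.resource_utils import ResourceManager  # assumed import path

    # Count only physical cores and ignore hyper-threads, which the diff's
    # comments describe as faster for training.
    num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)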