autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
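The headline change in this release is the set of new model families added under autogluon/tabular/models/ (EBM, Mitra, RealMLP, TabICL, TabM, TabPFNv2), alongside the removal of the legacy TabPFN model. As orientation only, the sketch below shows how such models are typically enabled through TabularPredictor; the registry keys "TABM" and "MITRA" are assumptions inferred from the ag_key naming convention visible later in this diff (ag_key = "TABPFNMIX"), not something this diff confirms.

    # Hypothetical sketch, not taken from this diff: enabling newly added model
    # families via TabularPredictor. The keys "TABM" and "MITRA" are assumed to
    # follow the ag_key convention; "GBM" is the existing LightGBM key.
    from autogluon.tabular import TabularDataset, TabularPredictor

    train_data = TabularDataset("train.csv")  # placeholder dataset with a "target" column
    predictor = TabularPredictor(label="target").fit(
        train_data,
        hyperparameters={
            "TABM": {},   # assumed key for the new TabM model
            "MITRA": {},  # assumed key for the new Mitra model
            "GBM": {},    # existing LightGBM key, kept for comparison
        },
    )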
autogluon/tabular/models/tabm/tabm_reference.py (new file)
@@ -0,0 +1,631 @@
+ # License: https://github.com/yandex-research/tabm/blob/main/LICENSE
+
+ # NOTE
+ # The minimum required versions of the dependencies are specified in README.md.
+
+ from __future__ import annotations
+
+ import itertools
+ from typing import Any, Literal, Union
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from . import rtdl_num_embeddings
+ from .rtdl_num_embeddings import _Periodic
+
+
+ # ======================================================================================
+ # Initialization
+ # ======================================================================================
+ def init_rsqrt_uniform_(x: Tensor, d: int) -> Tensor:
+     assert d > 0
+     d_rsqrt = d**-0.5
+     return nn.init.uniform_(x, -d_rsqrt, d_rsqrt)
+
+
+ @torch.inference_mode()
+ def init_random_signs_(x: Tensor) -> Tensor:
+     return x.bernoulli_(0.5).mul_(2).add_(-1)
+
+
+ # ======================================================================================
+ # Modules
+ # ======================================================================================
+ class NLinear(nn.Module):
+     """N linear layers applied in parallel to N disjoint parts of the input.
+
+     **Shape**
+
+     - Input: ``(B, N, in_features)``
+     - Output: ``(B, N, out_features)``
+
+     The i-th linear layer is applied to the i-th matrix of the shape (B, in_features).
+
+     Technically, this is a simplified version of delu.nn.NLinear:
+     https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html.
+     The difference is that this layer supports only 3D inputs
+     with exactly one batch dimension. By contrast, delu.nn.NLinear supports
+     any number of batch dimensions.
+     """
+
+     def __init__(
+         self, n: int, in_features: int, out_features: int, bias: bool = True
+     ) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.empty(n, in_features, out_features))
+         self.bias = nn.Parameter(torch.empty(n, out_features)) if bias else None
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         d = self.weight.shape[-2]
+         init_rsqrt_uniform_(self.weight, d)
+         if self.bias is not None:
+             init_rsqrt_uniform_(self.bias, d)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         assert x.ndim == 3
+         assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]
+
+         x = x.transpose(0, 1)
+         x = x @ self.weight
+         x = x.transpose(0, 1)
+         if self.bias is not None:
+             x = x + self.bias
+         return x
+
+
+ class OneHotEncoding0d(nn.Module):
+     # Input: (*, n_cat_features=len(cardinalities))
+     # Output: (*, sum(cardinalities))
+
+     def __init__(self, cardinalities: list[int]) -> None:
+         super().__init__()
+         self._cardinalities = cardinalities
+
+     def forward(self, x: Tensor) -> Tensor:
+         assert x.ndim >= 1
+         assert x.shape[-1] == len(self._cardinalities)
+
+         return torch.cat(
+             [
+                 # NOTE
+                 # This is a quick hack to support out-of-vocabulary categories.
+                 #
+                 # Recall that lib.data.transform_cat encodes categorical features
+                 # as follows:
+                 # - In-vocabulary values receive indices from `range(cardinality)`.
+                 # - All out-of-vocabulary values (i.e. new categories in validation
+                 #   and test data that are not presented in the training data)
+                 #   receive the index `cardinality`.
+                 #
+                 # As such, the line below will produce the standard one-hot encoding for
+                 # known categories, and the all-zeros encoding for unknown categories.
+                 # This may not be the best approach to deal with unknown values,
+                 # but should be enough for our purposes.
+                 nn.functional.one_hot(x[..., i], cardinality + 1)[..., :-1]
+                 for i, cardinality in enumerate(self._cardinalities)
+             ],
+             -1,
+         )
+
+
+ class ScaleEnsemble(nn.Module):
+     def __init__(
+         self,
+         k: int,
+         d: int,
+         *,
+         init: Literal['ones', 'normal', 'random-signs'],
+     ) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.empty(k, d))
+         self._weight_init = init
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         if self._weight_init == 'ones':
+             nn.init.ones_(self.weight)
+         elif self._weight_init == 'normal':
+             nn.init.normal_(self.weight)
+         elif self._weight_init == 'random-signs':
+             init_random_signs_(self.weight)
+         else:
+             raise ValueError(f'Unknown weight_init: {self._weight_init}')
+
+     def forward(self, x: Tensor) -> Tensor:
+         assert x.ndim >= 2
+         return x * self.weight
+
+
+ class LinearEfficientEnsemble(nn.Module):
+     """
+     This layer is a more configurable version of the "BatchEnsemble" layer
+     from the paper
+     "BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning"
+     (link: https://arxiv.org/abs/2002.06715).
+
+     First, this layer allows to select only some of the "ensembled" parts:
+     - the input scaling (r_i in the BatchEnsemble paper)
+     - the output scaling (s_i in the BatchEnsemble paper)
+     - the output bias (not mentioned in the BatchEnsemble paper,
+       but is presented in public implementations)
+
+     Second, the initialization of the scaling weights is configurable
+     through the `scaling_init` argument.
+
+     NOTE
+     The term "adapter" is used in the TabM paper only to tell the story.
+     The original BatchEnsemble paper does NOT use this term. So this class also
+     avoids the term "adapter".
+     """
+
+     r: Union[None, Tensor]
+     s: Union[None, Tensor]
+     bias: Union[None, Tensor]
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         *,
+         k: int,
+         ensemble_scaling_in: bool,
+         ensemble_scaling_out: bool,
+         ensemble_bias: bool,
+         scaling_init: Literal['ones', 'random-signs'],
+     ):
+         assert k > 0
+         if ensemble_bias:
+             assert bias
+         super().__init__()
+
+         self.weight = nn.Parameter(torch.empty(out_features, in_features))
+         self.register_parameter(
+             'r',
+             (
+                 nn.Parameter(torch.empty(k, in_features))
+                 if ensemble_scaling_in
+                 else None
+             ),  # type: ignore[code]
+         )
+         self.register_parameter(
+             's',
+             (
+                 nn.Parameter(torch.empty(k, out_features))
+                 if ensemble_scaling_out
+                 else None
+             ),  # type: ignore[code]
+         )
+         self.register_parameter(
+             'bias',
+             (
+                 nn.Parameter(torch.empty(out_features))  # type: ignore[code]
+                 if bias and not ensemble_bias
+                 else nn.Parameter(torch.empty(k, out_features))
+                 if ensemble_bias
+                 else None
+             ),
+         )
+
+         self.in_features = in_features
+         self.out_features = out_features
+         self.k = k
+         self.scaling_init = scaling_init
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         init_rsqrt_uniform_(self.weight, self.in_features)
+         scaling_init_fn = {'ones': nn.init.ones_, 'random-signs': init_random_signs_}[
+             self.scaling_init
+         ]
+         if self.r is not None:
+             scaling_init_fn(self.r)
+         if self.s is not None:
+             scaling_init_fn(self.s)
+         if self.bias is not None:
+             bias_init = torch.empty(
+                 # NOTE: the shape of bias_init is (out_features,) not (k, out_features).
+                 # It means that all biases have the same initialization.
+                 # This is similar to having one shared bias plus
+                 # k zero-initialized non-shared biases.
+                 self.out_features,
+                 dtype=self.weight.dtype,
+                 device=self.weight.device,
+             )
+             bias_init = init_rsqrt_uniform_(bias_init, self.in_features)
+             with torch.inference_mode():
+                 self.bias.copy_(bias_init)
+
+     def forward(self, x: Tensor) -> Tensor:
+         # x.shape == (B, K, D)
+         assert x.ndim == 3
+
+         # >>> The equation (5) from the BatchEnsemble paper (arXiv v2).
+         if self.r is not None:
+             x = x * self.r
+         x = x @ self.weight.T
+         if self.s is not None:
+             x = x * self.s
+         # <<<
+
+         if self.bias is not None:
+             x = x + self.bias
+         return x
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         *,
+         d_in: Union[None, int] = None,
+         d_out: Union[None, int] = None,
+         n_blocks: int,
+         d_block: int,
+         dropout: float,
+         activation: str = 'ReLU',
+     ) -> None:
+         super().__init__()
+
+         d_first = d_block if d_in is None else d_in
+         self.blocks = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.Linear(d_first if i == 0 else d_block, d_block),
+                     getattr(nn, activation)(),
+                     nn.Dropout(dropout),
+                 )
+                 for i in range(n_blocks)
+             ]
+         )
+         self.output = None if d_out is None else nn.Linear(d_block, d_out)
+
+     def forward(self, x: Tensor) -> Tensor:
+         for block in self.blocks:
+             x = block(x)
+         if self.output is not None:
+             x = self.output(x)
+         return x
+
+
+ def make_efficient_ensemble(module: nn.Module, EnsembleLayer, **kwargs) -> None:
+     """Replace linear layers with efficient ensembles of linear layers.
+
+     NOTE
+     In the paper, there are no experiments with networks with normalization layers.
+     Perhaps, their trainable weights (the affine transformations) also need
+     "ensemblification" as in the paper about "FiLM-Ensemble".
+     Additional experiments are required to make conclusions.
+     """
+     for name, submodule in list(module.named_children()):
+         if isinstance(submodule, nn.Linear):
+             module.add_module(
+                 name,
+                 EnsembleLayer(
+                     in_features=submodule.in_features,
+                     out_features=submodule.out_features,
+                     bias=submodule.bias is not None,
+                     **kwargs,
+                 ),
+             )
+         else:
+             make_efficient_ensemble(submodule, EnsembleLayer, **kwargs)
+
+
+ def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
+     if isinstance(backbone, MLP):
+         return backbone.blocks[0][0]  # type: ignore[code]
+     else:
+         raise RuntimeError(f'Unsupported backbone: {backbone}')
+
+
+ @torch.inference_mode()
+ def _init_first_adapter(
+     weight: Tensor,
+     distribution: Literal['normal', 'random-signs'],
+     init_sections: list[int],
+ ) -> None:
+     """Initialize the first adapter.
+
+     NOTE
+     The `init_sections` argument is a historical artifact that accidentally leaked
+     from irrelevant experiments to the final models. Perhaps, the code related
+     to `init_sections` can be simply removed, but this was not tested.
+     """
+     assert weight.ndim == 2
+     assert weight.shape[1] == sum(init_sections)
+
+     if distribution == 'normal':
+         init_fn_ = nn.init.normal_
+     elif distribution == 'random-signs':
+         init_fn_ = init_random_signs_
+     else:
+         raise ValueError(f'Unknown distribution: {distribution}')
+
+     section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()]
+     for i in range(len(init_sections)):
+         # NOTE
+         # As noted above, this section-based initialization is an arbitrary historical
+         # artifact. Consider the first adapter of one ensemble member.
+         # This adapter vector is implicitly split into "sections",
+         # where one section corresponds to one feature. The code below ensures that
+         # the adapter weights in one section are initialized with the same random value
+         # from the given distribution.
+         w = torch.empty((len(weight), 1), dtype=weight.dtype, device=weight.device)
+         init_fn_(w)
+         weight[:, section_bounds[i] : section_bounds[i + 1]] = w
+
+
+ _CUSTOM_MODULES = {
+     # https://docs.python.org/3/library/stdtypes.html#definition.__name__
+     CustomModule.__name__: CustomModule
+     for CustomModule in [
+         rtdl_num_embeddings.LinearEmbeddings,
+         rtdl_num_embeddings.LinearReLUEmbeddings,
+         rtdl_num_embeddings.PeriodicEmbeddings,
+         rtdl_num_embeddings.PiecewiseLinearEmbeddings,
+         MLP,
+     ]
+ }
+
+
+ def make_module(type: str, *args, **kwargs) -> nn.Module:
+     Module = getattr(nn, type, None)
+     if Module is None:
+         Module = _CUSTOM_MODULES[type]
+     return Module(*args, **kwargs)
+
+
+ # ======================================================================================
+ # Optimization
+ # ======================================================================================
+ def default_zero_weight_decay_condition(
+     module_name: str, module: nn.Module, parameter_name: str, parameter: nn.Parameter
+ ):
+     del module_name, parameter
+     return parameter_name.endswith('bias') or isinstance(
+         module,
+         (
+             nn.BatchNorm1d,
+             nn.LayerNorm,
+             nn.InstanceNorm1d,
+             rtdl_num_embeddings.LinearEmbeddings,
+             rtdl_num_embeddings.LinearReLUEmbeddings,
+             _Periodic,
+         ),
+     )
+
+
+ def make_parameter_groups(
+     module: nn.Module,
+     zero_weight_decay_condition=default_zero_weight_decay_condition,
+     custom_groups: Union[None, list[dict[str, Any]]] = None,
+ ) -> list[dict[str, Any]]:
+     if custom_groups is None:
+         custom_groups = []
+     custom_params = frozenset(
+         itertools.chain.from_iterable(group['params'] for group in custom_groups)
+     )
+     assert len(custom_params) == sum(
+         len(group['params']) for group in custom_groups
+     ), 'Parameters in custom_groups must not intersect'
+     zero_wd_params = frozenset(
+         p
+         for mn, m in module.named_modules()
+         for pn, p in m.named_parameters()
+         if p not in custom_params and zero_weight_decay_condition(mn, m, pn, p)
+     )
+     default_group = {
+         'params': [
+             p
+             for p in module.parameters()
+             if p not in custom_params and p not in zero_wd_params
+         ]
+     }
+     return [
+         default_group,
+         {'params': list(zero_wd_params), 'weight_decay': 0.0},
+         *custom_groups,
+     ]
+
+
+ # ======================================================================================
+ # The model
+ # ======================================================================================
+ class Model(nn.Module):
+     """MLP & TabM."""
+
+     def __init__(
+         self,
+         *,
+         n_num_features: int,
+         cat_cardinalities: list[int],
+         n_classes: Union[None, int],
+         backbone: dict,
+         bins: Union[None, list[Tensor]],  # For piecewise-linear encoding/embeddings.
+         num_embeddings: Union[None, dict] = None,
+         arch_type: Literal[
+             # Plain feed-forward network without any kind of ensembling.
+             'plain',
+             #
+             # TabM
+             'tabm',
+             #
+             # TabM-mini
+             'tabm-mini',
+             #
+             # TabM-packed
+             'tabm-packed',
+             #
+             # TabM. The first adapter is initialized from the normal distribution.
+             # This variant was not used in the paper, but it may be useful in practice.
+             'tabm-normal',
+             #
+             # TabM-mini. The adapter is initialized from the normal distribution.
+             # This variant was not used in the paper.
+             'tabm-mini-normal',
+         ],
+         k: Union[None, int] = None,
+         share_training_batches: bool = True,
+     ) -> None:
+         # >>> Validate arguments.
+         assert n_num_features >= 0
+         assert n_num_features or cat_cardinalities
+         if arch_type == 'plain':
+             assert k is None
+             assert (
+                 share_training_batches
+             ), 'If `arch_type` is set to "plain", then `simple` must remain True'
+         else:
+             assert k is not None
+             assert k > 0
+
+         super().__init__()
+
+         # >>> Continuous (numerical) features
+         first_adapter_sections = []  # See the comment in `_init_first_adapter`.
+
+         if n_num_features == 0:
+             assert bins is None
+             self.num_module = None
+             d_num = 0
+
+         elif num_embeddings is None:
+             assert bins is None
+             self.num_module = None
+             d_num = n_num_features
+             first_adapter_sections.extend(1 for _ in range(n_num_features))
+
+         else:
+             if bins is None:
+                 self.num_module = make_module(
+                     **num_embeddings, n_features=n_num_features
+                 )
+             else:
+                 assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
+                 self.num_module = make_module(**num_embeddings, bins=bins)
+             d_num = n_num_features * num_embeddings['d_embedding']
+             first_adapter_sections.extend(
+                 num_embeddings['d_embedding'] for _ in range(n_num_features)
+             )
+
+         # >>> Categorical features
+         self.cat_module = (
+             OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
+         )
+         first_adapter_sections.extend(cat_cardinalities)
+         d_cat = sum(cat_cardinalities)
+
+         # >>> Backbone
+         d_flat = d_num + d_cat
+         self.minimal_ensemble_adapter = None
+         # Any backbone can be here but we provide only MLP
+         self.backbone = make_module(d_in=d_flat, **backbone)
+
+         if arch_type != 'plain':
+             assert k is not None
+             first_adapter_init = (
+                 None
+                 if arch_type == 'tabm-packed'
+                 else 'normal'
+                 if arch_type in ('tabm-mini-normal', 'tabm-normal')
+                 # For other arch_types, the initialization depends
+                 # on the presence of num_embeddings.
+                 else 'random-signs'
+                 if num_embeddings is None
+                 else 'normal'
+             )
+
+             if arch_type in ('tabm', 'tabm-normal'):
+                 # Like BatchEnsemble, but all multiplicative adapters,
+                 # except for the very first one, are initialized with ones.
+                 assert first_adapter_init is not None
+                 make_efficient_ensemble(
+                     self.backbone,
+                     LinearEfficientEnsemble,
+                     k=k,
+                     ensemble_scaling_in=True,
+                     ensemble_scaling_out=True,
+                     ensemble_bias=True,
+                     scaling_init='ones',
+                 )
+                 _init_first_adapter(
+                     _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
+                     first_adapter_init,
+                     first_adapter_sections,
+                 )
+
+             elif arch_type in ('tabm-mini', 'tabm-mini-normal'):
+                 # MiniEnsemble
+                 assert first_adapter_init is not None
+                 self.minimal_ensemble_adapter = ScaleEnsemble(
+                     k,
+                     d_flat,
+                     init='random-signs' if num_embeddings is None else 'normal',
+                 )
+                 _init_first_adapter(
+                     self.minimal_ensemble_adapter.weight,  # type: ignore[code]
+                     first_adapter_init,
+                     first_adapter_sections,
+                 )
+
+             elif arch_type == 'tabm-packed':
+                 # Packed ensemble.
+                 # In terms of the Packed Ensembles paper by Laurent et al.,
+                 # TabM-packed is PackedEnsemble(alpha=k, M=k, gamma=1).
+                 assert first_adapter_init is None
+                 make_efficient_ensemble(self.backbone, NLinear, n=k)
+
+             else:
+                 raise ValueError(f'Unknown arch_type: {arch_type}')
+
+         # >>> Output
+         d_block = backbone['d_block']
+         d_out = 1 if n_classes is None else n_classes
+         self.output = (
+             nn.Linear(d_block, d_out)
+             if arch_type == 'plain'
+             else NLinear(k, d_block, d_out)  # type: ignore[code]
+         )
+
+         # >>>
+         self.arch_type = arch_type
+         self.k = k
+         self.share_training_batches = share_training_batches
+
+     def forward(
+         self, x_num: Union[None, Tensor] = None, x_cat: Union[None, Tensor] = None
+     ) -> Tensor:
+         x = []
+         if x_num is not None:
+             x.append(x_num if self.num_module is None else self.num_module(x_num))
+         if x_cat is None:
+             assert self.cat_module is None
+         else:
+             assert self.cat_module is not None
+             x.append(self.cat_module(x_cat).float())
+         x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
+
+         if self.k is not None:
+             if self.share_training_batches or not self.training:
+                 # (B, D) -> (B, K, D)
+                 x = x[:, None].expand(-1, self.k, -1)
+             else:
+                 # (B * K, D) -> (B, K, D)
+                 x = x.reshape(len(x) // self.k, self.k, *x.shape[1:])
+             if self.minimal_ensemble_adapter is not None:
+                 x = self.minimal_ensemble_adapter(x)
+         else:
+             assert self.minimal_ensemble_adapter is None
+
+         x = self.backbone(x)
+         x = self.output(x)
+         if self.k is None:
+             # Adjust the output shape for plain networks to make them compatible
+             # with the rest of the script (loss, metrics, predictions, ...).
+             # (B, D_OUT) -> (B, 1, D_OUT)
+             x = x[:, None]
+         return x
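To make the vendored TabM reference code above concrete, here is a minimal usage sketch based only on the signatures shown in this diff; the import path is assumed to follow the file location autogluon/tabular/models/tabm/tabm_reference.py, and the hyperparameter values are arbitrary.

    # Minimal sketch, assuming the module path matches the file added above.
    import torch

    from autogluon.tabular.models.tabm.tabm_reference import Model, make_parameter_groups

    # A TabM ensemble of k=8 MLPs over 3 numerical and 2 categorical features.
    model = Model(
        n_num_features=3,
        cat_cardinalities=[4, 7],
        n_classes=2,
        backbone={"type": "MLP", "n_blocks": 2, "d_block": 64, "dropout": 0.1},
        bins=None,
        num_embeddings=None,
        arch_type="tabm",
        k=8,
    )
    # make_parameter_groups turns weight decay off for biases, normalization layers,
    # and numerical embeddings, as implemented above.
    optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=2e-3, weight_decay=3e-4)

    x_num = torch.randn(16, 3)
    x_cat = torch.randint(0, 4, (16, 2))  # category indices; index == cardinality means "unknown"
    logits = model(x_num, x_cat)          # shape (16, 8, 2): one prediction per ensemble member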
autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py
@@ -26,6 +26,8 @@ class TabPFNMixModel(AbstractModel):
 
      TabPFNMix is based off of the TabPFN and TabForestPFN models.
 
+     We recommend using Mitra instead, as it is an improved version of TabPFNMix.
+
      It is a tabular transformer model pre-trained on purely synthetic data.
 
      It currently has several limitations:
@@ -34,10 +36,13 @@ class TabPFNMixModel(AbstractModel):
      3. Does not support GPU
 
      For more information, refer to the `./_internals/README.md` file.
+
+     .. versionadded:: 1.2.0
      """
      ag_key = "TABPFNMIX"
      ag_name = "TabPFNMix"
      ag_priority = 45
+     seed_name = "random_state"
 
      weights_file_name = "model.pt"
 
@@ -119,6 +124,7 @@ class TabPFNMixModel(AbstractModel):
              raise AssertionError(f"Max allowed classes for the model is {max_classes}, " f"but found {self.num_classes} classes.")
 
          params = self._get_model_params()
+         random_state = params.pop(self.seed_name, self.default_random_seed)
          sample_rows = ag_params.get("sample_rows", None)
          sample_rows_val = ag_params.get("sample_rows_val", None)
          max_rows = ag_params.get("max_rows", None)
@@ -129,11 +135,11 @@ class TabPFNMixModel(AbstractModel):
 
          # TODO: Make sample_rows generic
          if sample_rows is not None and isinstance(sample_rows, int) and len(X) > sample_rows:
-             X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows)
+             X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows, random_state=random_state)
 
          # TODO: Make sample_rows generic
          if X_val is not None and y_val is not None and sample_rows_val is not None and isinstance(sample_rows_val, int) and len(X_val) > sample_rows_val:
-             X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val)
+             X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val, random_state=random_state)
 
          from ._internal.core.enums import Task
          if self.problem_type in [REGRESSION, QUANTILE]:
@@ -174,7 +180,7 @@ class TabPFNMixModel(AbstractModel):
          elif weights_path is not None:
              logger.log(15, f'\tLoading pre-trained weights from file... (weights_path="{weights_path}")')
 
-         cfg = ConfigRun(hyperparams=params, task=task, device=device)
+         cfg = ConfigRun(hyperparams=params, task=task, device=device, seed=random_state)
 
          if cfg.hyperparams["max_epochs"] == 0 and cfg.hyperparams["n_ensembles"] != 1:
              logger.log(
@@ -238,7 +244,7 @@ class TabPFNMixModel(AbstractModel):
          return self
 
      # TODO: Make this generic by creating a generic `preprocess_train` and putting this logic prior to `_preprocess`.
-     def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int, random_state=0) -> (pd.DataFrame, pd.Series):
+     def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int, random_state: int | None = 0) -> (pd.DataFrame, pd.Series):
          num_rows_to_drop = len(X) - num_rows
          X, _, y, _ = generate_train_test_split(
              X=X,
@@ -311,11 +317,11 @@ class TabPFNMixModel(AbstractModel):
 
      def _get_maximum_resources(self) -> dict[str, int | float]:
          # torch model trains slower when utilizing virtual cores and this issue scale up when the number of cpu cores increases
-         return {"num_cpus": ResourceManager.get_cpu_count_psutil(logical=False)}
+         return {"num_cpus": ResourceManager.get_cpu_count(only_physical_cores=True)}
 
      def _get_default_resources(self) -> tuple[int, float]:
-         # logical=False is faster in training
-         num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+         # only_physical_cores=True is faster in training
+         num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
          num_gpus = 0
          return num_cpus, num_gpus
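The TabPFNMix changes above thread a single seed through both subsampling and ConfigRun, exposed through the new seed_name = "random_state" attribute. A hypothetical way to set it from user code, assuming the seed is passed like any other model hyperparameter under the "TABPFNMIX" key shown above:

    # Hypothetical sketch: fixing the TabPFNMix seed added by this diff.
    from autogluon.tabular import TabularDataset, TabularPredictor

    train_data = TabularDataset("train.csv")  # placeholder dataset with a "class" column
    predictor = TabularPredictor(label="class").fit(
        train_data,
        hyperparameters={"TABPFNMIX": {"random_state": 42}},
    )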