autogluon.tabular 1.3.2b20250709__py3-none-any.whl → 1.3.2b20250711__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (35)
  1. autogluon/tabular/models/__init__.py +3 -0
  2. autogluon/tabular/models/catboost/callbacks.py +3 -2
  3. autogluon/tabular/models/catboost/catboost_model.py +2 -2
  4. autogluon/tabular/models/catboost/catboost_utils.py +7 -3
  5. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +3 -3
  6. autogluon/tabular/models/lgb/lgb_model.py +2 -2
  7. autogluon/tabular/models/realmlp/__init__.py +0 -0
  8. autogluon/tabular/models/realmlp/realmlp_model.py +347 -0
  9. autogluon/tabular/models/rf/rf_model.py +2 -1
  10. autogluon/tabular/models/tabicl/__init__.py +0 -0
  11. autogluon/tabular/models/tabicl/tabicl_model.py +174 -0
  12. autogluon/tabular/models/tabm/__init__.py +0 -0
  13. autogluon/tabular/models/tabm/_tabm_internal.py +544 -0
  14. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +807 -0
  15. autogluon/tabular/models/tabm/tabm_model.py +275 -0
  16. autogluon/tabular/models/tabm/tabm_reference.py +627 -0
  17. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +3 -3
  18. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -3
  19. autogluon/tabular/models/xgboost/xgboost_model.py +2 -2
  20. autogluon/tabular/predictor/predictor.py +5 -3
  21. autogluon/tabular/registry/_ag_model_registry.py +6 -0
  22. autogluon/tabular/testing/fit_helper.py +27 -25
  23. autogluon/tabular/testing/generate_datasets.py +7 -0
  24. autogluon/tabular/trainer/abstract_trainer.py +1 -1
  25. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  26. autogluon/tabular/version.py +1 -1
  27. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/METADATA +21 -13
  28. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/RECORD +35 -26
  29. /autogluon.tabular-1.3.2b20250709-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth +0 -0
  30. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/LICENSE +0 -0
  31. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/NOTICE +0 -0
  32. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/WHEEL +0 -0
  33. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/namespace_packages.txt +0 -0
  34. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/top_level.txt +0 -0
  35. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250711.dist-info}/zip-safe +0 -0
autogluon/tabular/models/tabm/tabm_reference.py (new file)
@@ -0,0 +1,627 @@
+# License: https://github.com/yandex-research/tabm/blob/main/LICENSE
+
+# NOTE
+# The minimum required versions of the dependencies are specified in README.md.
+
+import itertools
+from typing import Any, Literal
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from . import rtdl_num_embeddings
+from .rtdl_num_embeddings import _Periodic
+
+
+# ======================================================================================
+# Initialization
+# ======================================================================================
+def init_rsqrt_uniform_(x: Tensor, d: int) -> Tensor:
+    assert d > 0
+    d_rsqrt = d**-0.5
+    return nn.init.uniform_(x, -d_rsqrt, d_rsqrt)
+
+
+@torch.inference_mode()
+def init_random_signs_(x: Tensor) -> Tensor:
+    return x.bernoulli_(0.5).mul_(2).add_(-1)
+
+
+# ======================================================================================
+# Modules
+# ======================================================================================
+class NLinear(nn.Module):
+    """N linear layers applied in parallel to N disjoint parts of the input.
+
+    **Shape**
+
+    - Input: ``(B, N, in_features)``
+    - Output: ``(B, N, out_features)``
+
+    The i-th linear layer is applied to the i-th matrix of the shape (B, in_features).
+
+    Technically, this is a simplified version of delu.nn.NLinear:
+    https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html.
+    The difference is that this layer supports only 3D inputs
+    with exactly one batch dimension. By contrast, delu.nn.NLinear supports
+    any number of batch dimensions.
+    """
+
+    def __init__(
+        self, n: int, in_features: int, out_features: int, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(n, in_features, out_features))
+        self.bias = nn.Parameter(torch.empty(n, out_features)) if bias else None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        d = self.weight.shape[-2]
+        init_rsqrt_uniform_(self.weight, d)
+        if self.bias is not None:
+            init_rsqrt_uniform_(self.bias, d)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 3
+        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]
+
+        x = x.transpose(0, 1)
+        x = x @ self.weight
+        x = x.transpose(0, 1)
+        if self.bias is not None:
+            x = x + self.bias
+        return x
+
+
+class OneHotEncoding0d(nn.Module):
+    # Input: (*, n_cat_features=len(cardinalities))
+    # Output: (*, sum(cardinalities))
+
+    def __init__(self, cardinalities: list[int]) -> None:
+        super().__init__()
+        self._cardinalities = cardinalities
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.ndim >= 1
+        assert x.shape[-1] == len(self._cardinalities)
+
+        return torch.cat(
+            [
+                # NOTE
+                # This is a quick hack to support out-of-vocabulary categories.
+                #
+                # Recall that lib.data.transform_cat encodes categorical features
+                # as follows:
+                # - In-vocabulary values receive indices from `range(cardinality)`.
+                # - All out-of-vocabulary values (i.e. new categories in validation
+                #   and test data that are not presented in the training data)
+                #   receive the index `cardinality`.
+                #
+                # As such, the line below will produce the standard one-hot encoding for
+                # known categories, and the all-zeros encoding for unknown categories.
+                # This may not be the best approach to deal with unknown values,
+                # but should be enough for our purposes.
+                nn.functional.one_hot(x[..., i], cardinality + 1)[..., :-1]
+                for i, cardinality in enumerate(self._cardinalities)
+            ],
+            -1,
+        )
+
+
+class ScaleEnsemble(nn.Module):
+    def __init__(
+        self,
+        k: int,
+        d: int,
+        *,
+        init: Literal['ones', 'normal', 'random-signs'],
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(k, d))
+        self._weight_init = init
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self._weight_init == 'ones':
+            nn.init.ones_(self.weight)
+        elif self._weight_init == 'normal':
+            nn.init.normal_(self.weight)
+        elif self._weight_init == 'random-signs':
+            init_random_signs_(self.weight)
+        else:
+            raise ValueError(f'Unknown weight_init: {self._weight_init}')
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.ndim >= 2
+        return x * self.weight
+
+
+class LinearEfficientEnsemble(nn.Module):
+    """
+    This layer is a more configurable version of the "BatchEnsemble" layer
+    from the paper
+    "BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning"
+    (link: https://arxiv.org/abs/2002.06715).
+
+    First, this layer allows to select only some of the "ensembled" parts:
+    - the input scaling (r_i in the BatchEnsemble paper)
+    - the output scaling (s_i in the BatchEnsemble paper)
+    - the output bias (not mentioned in the BatchEnsemble paper,
+      but is presented in public implementations)
+
+    Second, the initialization of the scaling weights is configurable
+    through the `scaling_init` argument.
+
+    NOTE
+    The term "adapter" is used in the TabM paper only to tell the story.
+    The original BatchEnsemble paper does NOT use this term. So this class also
+    avoids the term "adapter".
+    """
+
+    r: None | Tensor
+    s: None | Tensor
+    bias: None | Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        *,
+        k: int,
+        ensemble_scaling_in: bool,
+        ensemble_scaling_out: bool,
+        ensemble_bias: bool,
+        scaling_init: Literal['ones', 'random-signs'],
+    ):
+        assert k > 0
+        if ensemble_bias:
+            assert bias
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.empty(out_features, in_features))
+        self.register_parameter(
+            'r',
+            (
+                nn.Parameter(torch.empty(k, in_features))
+                if ensemble_scaling_in
+                else None
+            ),  # type: ignore[code]
+        )
+        self.register_parameter(
+            's',
+            (
+                nn.Parameter(torch.empty(k, out_features))
+                if ensemble_scaling_out
+                else None
+            ),  # type: ignore[code]
+        )
+        self.register_parameter(
+            'bias',
+            (
+                nn.Parameter(torch.empty(out_features))  # type: ignore[code]
+                if bias and not ensemble_bias
+                else nn.Parameter(torch.empty(k, out_features))
+                if ensemble_bias
+                else None
+            ),
+        )
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.k = k
+        self.scaling_init = scaling_init
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init_rsqrt_uniform_(self.weight, self.in_features)
+        scaling_init_fn = {'ones': nn.init.ones_, 'random-signs': init_random_signs_}[
+            self.scaling_init
+        ]
+        if self.r is not None:
+            scaling_init_fn(self.r)
+        if self.s is not None:
+            scaling_init_fn(self.s)
+        if self.bias is not None:
+            bias_init = torch.empty(
+                # NOTE: the shape of bias_init is (out_features,) not (k, out_features).
+                # It means that all biases have the same initialization.
+                # This is similar to having one shared bias plus
+                # k zero-initialized non-shared biases.
+                self.out_features,
+                dtype=self.weight.dtype,
+                device=self.weight.device,
+            )
+            bias_init = init_rsqrt_uniform_(bias_init, self.in_features)
+            with torch.inference_mode():
+                self.bias.copy_(bias_init)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x.shape == (B, K, D)
+        assert x.ndim == 3
+
+        # >>> The equation (5) from the BatchEnsemble paper (arXiv v2).
+        if self.r is not None:
+            x = x * self.r
+        x = x @ self.weight.T
+        if self.s is not None:
+            x = x * self.s
+        # <<<
+
+        if self.bias is not None:
+            x = x + self.bias
+        return x
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        *,
+        d_in: None | int = None,
+        d_out: None | int = None,
+        n_blocks: int,
+        d_block: int,
+        dropout: float,
+        activation: str = 'ReLU',
+    ) -> None:
+        super().__init__()
+
+        d_first = d_block if d_in is None else d_in
+        self.blocks = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(d_first if i == 0 else d_block, d_block),
+                    getattr(nn, activation)(),
+                    nn.Dropout(dropout),
+                )
+                for i in range(n_blocks)
+            ]
+        )
+        self.output = None if d_out is None else nn.Linear(d_block, d_out)
+
+    def forward(self, x: Tensor) -> Tensor:
+        for block in self.blocks:
+            x = block(x)
+        if self.output is not None:
+            x = self.output(x)
+        return x
+
+
+def make_efficient_ensemble(module: nn.Module, EnsembleLayer, **kwargs) -> None:
+    """Replace linear layers with efficient ensembles of linear layers.
+
+    NOTE
+    In the paper, there are no experiments with networks with normalization layers.
+    Perhaps, their trainable weights (the affine transformations) also need
+    "ensemblification" as in the paper about "FiLM-Ensemble".
+    Additional experiments are required to make conclusions.
+    """
+    for name, submodule in list(module.named_children()):
+        if isinstance(submodule, nn.Linear):
+            module.add_module(
+                name,
+                EnsembleLayer(
+                    in_features=submodule.in_features,
+                    out_features=submodule.out_features,
+                    bias=submodule.bias is not None,
+                    **kwargs,
+                ),
+            )
+        else:
+            make_efficient_ensemble(submodule, EnsembleLayer, **kwargs)
+
+
+def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
+    if isinstance(backbone, MLP):
+        return backbone.blocks[0][0]  # type: ignore[code]
+    else:
+        raise RuntimeError(f'Unsupported backbone: {backbone}')
+
+
+@torch.inference_mode()
+def _init_first_adapter(
+    weight: Tensor,
+    distribution: Literal['normal', 'random-signs'],
+    init_sections: list[int],
+) -> None:
+    """Initialize the first adapter.
+
+    NOTE
+    The `init_sections` argument is a historical artifact that accidentally leaked
+    from irrelevant experiments to the final models. Perhaps, the code related
+    to `init_sections` can be simply removed, but this was not tested.
+    """
+    assert weight.ndim == 2
+    assert weight.shape[1] == sum(init_sections)
+
+    if distribution == 'normal':
+        init_fn_ = nn.init.normal_
+    elif distribution == 'random-signs':
+        init_fn_ = init_random_signs_
+    else:
+        raise ValueError(f'Unknown distribution: {distribution}')
+
+    section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()]
+    for i in range(len(init_sections)):
+        # NOTE
+        # As noted above, this section-based initialization is an arbitrary historical
+        # artifact. Consider the first adapter of one ensemble member.
+        # This adapter vector is implicitly split into "sections",
+        # where one section corresponds to one feature. The code below ensures that
+        # the adapter weights in one section are initialized with the same random value
+        # from the given distribution.
+        w = torch.empty((len(weight), 1), dtype=weight.dtype, device=weight.device)
+        init_fn_(w)
+        weight[:, section_bounds[i] : section_bounds[i + 1]] = w
+
+
+_CUSTOM_MODULES = {
+    # https://docs.python.org/3/library/stdtypes.html#definition.__name__
+    CustomModule.__name__: CustomModule
+    for CustomModule in [
+        rtdl_num_embeddings.LinearEmbeddings,
+        rtdl_num_embeddings.LinearReLUEmbeddings,
+        rtdl_num_embeddings.PeriodicEmbeddings,
+        rtdl_num_embeddings.PiecewiseLinearEmbeddings,
+        MLP,
+    ]
+}
+
+
+def make_module(type: str, *args, **kwargs) -> nn.Module:
+    Module = getattr(nn, type, None)
+    if Module is None:
+        Module = _CUSTOM_MODULES[type]
+    return Module(*args, **kwargs)
+
+
+# ======================================================================================
+# Optimization
+# ======================================================================================
+def default_zero_weight_decay_condition(
+    module_name: str, module: nn.Module, parameter_name: str, parameter: nn.Parameter
+):
+    del module_name, parameter
+    return parameter_name.endswith('bias') or isinstance(
+        module,
+        nn.BatchNorm1d
+        | nn.LayerNorm
+        | nn.InstanceNorm1d
+        | rtdl_num_embeddings.LinearEmbeddings
+        | rtdl_num_embeddings.LinearReLUEmbeddings
+        | _Periodic,
+    )
+
+
+def make_parameter_groups(
+    module: nn.Module,
+    zero_weight_decay_condition=default_zero_weight_decay_condition,
+    custom_groups: None | list[dict[str, Any]] = None,
+) -> list[dict[str, Any]]:
+    if custom_groups is None:
+        custom_groups = []
+    custom_params = frozenset(
+        itertools.chain.from_iterable(group['params'] for group in custom_groups)
+    )
+    assert len(custom_params) == sum(
+        len(group['params']) for group in custom_groups
+    ), 'Parameters in custom_groups must not intersect'
+    zero_wd_params = frozenset(
+        p
+        for mn, m in module.named_modules()
+        for pn, p in m.named_parameters()
+        if p not in custom_params and zero_weight_decay_condition(mn, m, pn, p)
+    )
+    default_group = {
+        'params': [
+            p
+            for p in module.parameters()
+            if p not in custom_params and p not in zero_wd_params
+        ]
+    }
+    return [
+        default_group,
+        {'params': list(zero_wd_params), 'weight_decay': 0.0},
+        *custom_groups,
+    ]
+
+
+# ======================================================================================
+# The model
+# ======================================================================================
+class Model(nn.Module):
+    """MLP & TabM."""
+
+    def __init__(
+        self,
+        *,
+        n_num_features: int,
+        cat_cardinalities: list[int],
+        n_classes: None | int,
+        backbone: dict,
+        bins: None | list[Tensor],  # For piecewise-linear encoding/embeddings.
+        num_embeddings: None | dict = None,
+        arch_type: Literal[
+            # Plain feed-forward network without any kind of ensembling.
+            'plain',
+            #
+            # TabM
+            'tabm',
+            #
+            # TabM-mini
+            'tabm-mini',
+            #
+            # TabM-packed
+            'tabm-packed',
+            #
+            # TabM. The first adapter is initialized from the normal distribution.
+            # This variant was not used in the paper, but it may be useful in practice.
+            'tabm-normal',
+            #
+            # TabM-mini. The adapter is initialized from the normal distribution.
+            # This variant was not used in the paper.
+            'tabm-mini-normal',
+        ],
+        k: None | int = None,
+        share_training_batches: bool = True,
+    ) -> None:
+        # >>> Validate arguments.
+        assert n_num_features >= 0
+        assert n_num_features or cat_cardinalities
+        if arch_type == 'plain':
+            assert k is None
+            assert (
+                share_training_batches
+            ), 'If `arch_type` is set to "plain", then `simple` must remain True'
+        else:
+            assert k is not None
+            assert k > 0
+
+        super().__init__()
+
+        # >>> Continuous (numerical) features
+        first_adapter_sections = []  # See the comment in `_init_first_adapter`.
+
+        if n_num_features == 0:
+            assert bins is None
+            self.num_module = None
+            d_num = 0
+
+        elif num_embeddings is None:
+            assert bins is None
+            self.num_module = None
+            d_num = n_num_features
+            first_adapter_sections.extend(1 for _ in range(n_num_features))
+
+        else:
+            if bins is None:
+                self.num_module = make_module(
+                    **num_embeddings, n_features=n_num_features
+                )
+            else:
+                assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
+                self.num_module = make_module(**num_embeddings, bins=bins)
+            d_num = n_num_features * num_embeddings['d_embedding']
+            first_adapter_sections.extend(
+                num_embeddings['d_embedding'] for _ in range(n_num_features)
+            )
+
+        # >>> Categorical features
+        self.cat_module = (
+            OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
+        )
+        first_adapter_sections.extend(cat_cardinalities)
+        d_cat = sum(cat_cardinalities)
+
+        # >>> Backbone
+        d_flat = d_num + d_cat
+        self.minimal_ensemble_adapter = None
+        # Any backbone can be here but we provide only MLP
+        self.backbone = make_module(d_in=d_flat, **backbone)
+
+        if arch_type != 'plain':
+            assert k is not None
+            first_adapter_init = (
+                None
+                if arch_type == 'tabm-packed'
+                else 'normal'
+                if arch_type in ('tabm-mini-normal', 'tabm-normal')
+                # For other arch_types, the initialization depends
+                # on the presence of num_embeddings.
+                else 'random-signs'
+                if num_embeddings is None
+                else 'normal'
+            )
+
+            if arch_type in ('tabm', 'tabm-normal'):
+                # Like BatchEnsemble, but all multiplicative adapters,
+                # except for the very first one, are initialized with ones.
+                assert first_adapter_init is not None
+                make_efficient_ensemble(
+                    self.backbone,
+                    LinearEfficientEnsemble,
+                    k=k,
+                    ensemble_scaling_in=True,
+                    ensemble_scaling_out=True,
+                    ensemble_bias=True,
+                    scaling_init='ones',
+                )
+                _init_first_adapter(
+                    _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
+                    first_adapter_init,
+                    first_adapter_sections,
+                )
+
+            elif arch_type in ('tabm-mini', 'tabm-mini-normal'):
+                # MiniEnsemble
+                assert first_adapter_init is not None
+                self.minimal_ensemble_adapter = ScaleEnsemble(
+                    k,
+                    d_flat,
+                    init='random-signs' if num_embeddings is None else 'normal',
+                )
+                _init_first_adapter(
+                    self.minimal_ensemble_adapter.weight,  # type: ignore[code]
+                    first_adapter_init,
+                    first_adapter_sections,
+                )
+
+            elif arch_type == 'tabm-packed':
+                # Packed ensemble.
+                # In terms of the Packed Ensembles paper by Laurent et al.,
+                # TabM-packed is PackedEnsemble(alpha=k, M=k, gamma=1).
+                assert first_adapter_init is None
+                make_efficient_ensemble(self.backbone, NLinear, n=k)
+
+            else:
+                raise ValueError(f'Unknown arch_type: {arch_type}')
+
+        # >>> Output
+        d_block = backbone['d_block']
+        d_out = 1 if n_classes is None else n_classes
+        self.output = (
+            nn.Linear(d_block, d_out)
+            if arch_type == 'plain'
+            else NLinear(k, d_block, d_out)  # type: ignore[code]
+        )
+
+        # >>>
+        self.arch_type = arch_type
+        self.k = k
+        self.share_training_batches = share_training_batches
+
+    def forward(
+        self, x_num: None | Tensor = None, x_cat: None | Tensor = None
+    ) -> Tensor:
+        x = []
+        if x_num is not None:
+            x.append(x_num if self.num_module is None else self.num_module(x_num))
+        if x_cat is None:
+            assert self.cat_module is None
+        else:
+            assert self.cat_module is not None
+            x.append(self.cat_module(x_cat).float())
+        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
+
+        if self.k is not None:
+            if self.share_training_batches or not self.training:
+                # (B, D) -> (B, K, D)
+                x = x[:, None].expand(-1, self.k, -1)
+            else:
+                # (B * K, D) -> (B, K, D)
+                x = x.reshape(len(x) // self.k, self.k, *x.shape[1:])
+            if self.minimal_ensemble_adapter is not None:
+                x = self.minimal_ensemble_adapter(x)
+        else:
+            assert self.minimal_ensemble_adapter is None
+
+        x = self.backbone(x)
+        x = self.output(x)
+        if self.k is None:
+            # Adjust the output shape for plain networks to make them compatible
+            # with the rest of the script (loss, metrics, predictions, ...).
+            # (B, D_OUT) -> (B, 1, D_OUT)
+            x = x[:, None]
+        return x
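The `LinearEfficientEnsemble` above implements equation (5) of the BatchEnsemble paper: for ensemble member i, the input is scaled elementwise by r_i, passed through the shared weight matrix, scaled by s_i, and shifted by a per-member bias. A minimal editorial sketch (not part of the diff) that checks this member by member, assuming the new module is importable from the installed wheel:

```python
# Editorial sketch, not part of the wheel contents shown above.
import torch

from autogluon.tabular.models.tabm.tabm_reference import LinearEfficientEnsemble

layer = LinearEfficientEnsemble(
    in_features=4,
    out_features=5,
    k=3,
    ensemble_scaling_in=True,
    ensemble_scaling_out=True,
    ensemble_bias=True,
    scaling_init="random-signs",
)

x = torch.randn(7, 3, 4)  # (batch, k, in_features)
y = layer(x)              # (batch, k, out_features); one shared weight matrix

# Member i behaves like an ordinary linear layer whose inputs are rescaled by
# r[i] and whose outputs are rescaled by s[i], plus a per-member bias.
i = 1
manual = (x[:, i] * layer.r[i]) @ layer.weight.T * layer.s[i] + layer.bias[i]
assert torch.allclose(y[:, i], manual, atol=1e-6)
```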
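For an end-to-end picture of how the pieces added in this file fit together, here is a second editorial sketch with illustrative hyperparameters: a TabM with k = 8 members on a batch with six numerical and two categorical features, and optimizer parameter groups built by `make_parameter_groups` so that biases and normalization/embedding weights receive zero weight decay. Output shapes follow the `forward` method above: `(B, k, n_classes)`, which is typically averaged over the k heads at prediction time.

```python
# Editorial sketch, not part of the wheel contents shown above.
import torch

from autogluon.tabular.models.tabm.tabm_reference import Model, make_parameter_groups

k = 8
model = Model(
    n_num_features=6,
    cat_cardinalities=[3, 5],   # two categorical features
    n_classes=2,                # binary classification -> d_out = 2
    backbone={"type": "MLP", "n_blocks": 2, "d_block": 256, "dropout": 0.1},
    bins=None,                  # no piecewise-linear embeddings
    num_embeddings=None,
    arch_type="tabm",           # BatchEnsemble-style weight sharing
    k=k,
)
optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=2e-3, weight_decay=3e-4)

x_num = torch.randn(32, 6)             # (B, n_num_features)
x_cat = torch.randint(0, 3, (32, 2))   # in-vocabulary category indices
logits = model(x_num, x_cat)           # (B, k, n_classes)
probs = logits.softmax(-1).mean(1)     # average the k heads for prediction
```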
autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py
@@ -311,11 +311,11 @@ class TabPFNMixModel(AbstractModel):
 
     def _get_maximum_resources(self) -> dict[str, int | float]:
         # torch model trains slower when utilizing virtual cores and this issue scale up when the number of cpu cores increases
-        return {"num_cpus": ResourceManager.get_cpu_count_psutil(logical=False)}
+        return {"num_cpus": ResourceManager.get_cpu_count(only_physical_cores=True)}
 
     def _get_default_resources(self) -> tuple[int, float]:
-        # logical=False is faster in training
-        num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py
@@ -814,11 +814,11 @@ class TabularNeuralNetTorchModel(AbstractNeuralNetworkModel):
 
     def _get_maximum_resources(self) -> Dict[str, Union[int, float]]:
         # torch model trains slower when utilizing virtual cores and this issue scale up when the number of cpu cores increases
-        return {"num_cpus": ResourceManager.get_cpu_count_psutil(logical=False)}
+        return {"num_cpus": ResourceManager.get_cpu_count(only_physical_cores=True)}
 
     def _get_default_resources(self):
-        # logical=False is faster in training
-        num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
autogluon/tabular/models/xgboost/xgboost_model.py
@@ -310,8 +310,8 @@ class XGBoostModel(AbstractModel):
 
     @disable_if_lite_mode(ret=(1, 0))
     def _get_default_resources(self):
-        # logical=False is faster in training
-        num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 