autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260117__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of autogluon.tabular might be problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/METADATA +27 -27
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260117-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/zip-safe +0 -0
@@ -50,9 +50,7 @@ class NLinear(nn.Module):
     any number of batch dimensions.
     """
 
-    def __init__(
-        self, n: int, in_features: int, out_features: int, bias: bool = True
-    ) -> None:
+    def __init__(self, n: int, in_features: int, out_features: int, bias: bool = True) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.empty(n, in_features, out_features))
         self.bias = nn.Parameter(torch.empty(n, out_features)) if bias else None
@@ -117,7 +115,7 @@ class ScaleEnsemble(nn.Module):
         k: int,
         d: int,
         *,
-        init: Literal['ones', 'normal', 'random-signs'],
+        init: Literal["ones", "normal", "random-signs"],
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.empty(k, d))
@@ -125,14 +123,14 @@ class ScaleEnsemble(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        if self._weight_init == 'ones':
+        if self._weight_init == "ones":
            nn.init.ones_(self.weight)
-        elif self._weight_init == 'normal':
+        elif self._weight_init == "normal":
            nn.init.normal_(self.weight)
-        elif self._weight_init == 'random-signs':
+        elif self._weight_init == "random-signs":
            init_random_signs_(self.weight)
        else:
-            raise ValueError(f'Unknown weight_init: {self._weight_init}')
+            raise ValueError(f"Unknown weight_init: {self._weight_init}")
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.ndim >= 2
@@ -175,7 +173,7 @@ class LinearEfficientEnsemble(nn.Module):
         ensemble_scaling_in: bool,
         ensemble_scaling_out: bool,
         ensemble_bias: bool,
-        scaling_init: Literal['ones', 'random-signs'],
+        scaling_init: Literal["ones", "random-signs"],
     ):
         assert k > 0
         if ensemble_bias:
@@ -184,23 +182,15 @@ class LinearEfficientEnsemble(nn.Module):
 
         self.weight = nn.Parameter(torch.empty(out_features, in_features))
         self.register_parameter(
-            'r',
-            (
-                nn.Parameter(torch.empty(k, in_features))
-                if ensemble_scaling_in
-                else None
-            ),  # type: ignore[code]
+            "r",
+            (nn.Parameter(torch.empty(k, in_features)) if ensemble_scaling_in else None),  # type: ignore[code]
         )
         self.register_parameter(
-            's',
-            (
-                nn.Parameter(torch.empty(k, out_features))
-                if ensemble_scaling_out
-                else None
-            ),  # type: ignore[code]
+            "s",
+            (nn.Parameter(torch.empty(k, out_features)) if ensemble_scaling_out else None),  # type: ignore[code]
        )
         self.register_parameter(
-            'bias',
+            "bias",
             (
                 nn.Parameter(torch.empty(out_features))  # type: ignore[code]
                 if bias and not ensemble_bias
@@ -219,9 +209,7 @@ class LinearEfficientEnsemble(nn.Module):
 
     def reset_parameters(self):
         init_rsqrt_uniform_(self.weight, self.in_features)
-        scaling_init_fn = {'ones': nn.init.ones_, 'random-signs': init_random_signs_}[
-            self.scaling_init
-        ]
+        scaling_init_fn = {"ones": nn.init.ones_, "random-signs": init_random_signs_}[self.scaling_init]
         if self.r is not None:
             scaling_init_fn(self.r)
         if self.s is not None:
@@ -266,7 +254,7 @@ class MLP(nn.Module):
         n_blocks: int,
         d_block: int,
         dropout: float,
-        activation: str = 'ReLU',
+        activation: str = "ReLU",
     ) -> None:
         super().__init__()
 
@@ -319,13 +307,13 @@ def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
     if isinstance(backbone, MLP):
         return backbone.blocks[0][0]  # type: ignore[code]
     else:
-        raise RuntimeError(f'Unsupported backbone: {backbone}')
+        raise RuntimeError(f"Unsupported backbone: {backbone}")
 
 
 @torch.inference_mode()
 def _init_first_adapter(
     weight: Tensor,
-    distribution: Literal['normal', 'random-signs'],
+    distribution: Literal["normal", "random-signs"],
     init_sections: list[int],
 ) -> None:
     """Initialize the first adapter.
@@ -338,12 +326,12 @@ def _init_first_adapter(
     assert weight.ndim == 2
     assert weight.shape[1] == sum(init_sections)
 
-    if distribution == 'normal':
+    if distribution == "normal":
         init_fn_ = nn.init.normal_
-    elif distribution == 'random-signs':
+    elif distribution == "random-signs":
         init_fn_ = init_random_signs_
     else:
-        raise ValueError(f'Unknown distribution: {distribution}')
+        raise ValueError(f"Unknown distribution: {distribution}")
 
     section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()]
     for i in range(len(init_sections)):
@@ -386,7 +374,7 @@ def default_zero_weight_decay_condition(
     module_name: str, module: nn.Module, parameter_name: str, parameter: nn.Parameter
 ):
     del module_name, parameter
-    return parameter_name.endswith('bias') or isinstance(
+    return parameter_name.endswith("bias") or isinstance(
         module,
         (
             nn.BatchNorm1d,
@@ -406,28 +394,20 @@ def make_parameter_groups(
 ) -> list[dict[str, Any]]:
     if custom_groups is None:
         custom_groups = []
-    custom_params = frozenset(
-        itertools.chain.from_iterable(group['params'] for group in custom_groups)
+    custom_params = frozenset(itertools.chain.from_iterable(group["params"] for group in custom_groups))
+    assert len(custom_params) == sum(len(group["params"]) for group in custom_groups), (
+        "Parameters in custom_groups must not intersect"
     )
-    assert len(custom_params) == sum(
-        len(group['params']) for group in custom_groups
-    ), 'Parameters in custom_groups must not intersect'
     zero_wd_params = frozenset(
         p
         for mn, m in module.named_modules()
         for pn, p in m.named_parameters()
         if p not in custom_params and zero_weight_decay_condition(mn, m, pn, p)
     )
-    default_group = {
-        'params': [
-            p
-            for p in module.parameters()
-            if p not in custom_params and p not in zero_wd_params
-        ]
-    }
+    default_group = {"params": [p for p in module.parameters() if p not in custom_params and p not in zero_wd_params]}
     return [
         default_group,
-        {'params': list(zero_wd_params), 'weight_decay': 0.0},
+        {"params": list(zero_wd_params), "weight_decay": 0.0},
         *custom_groups,
     ]
 
@@ -449,24 +429,24 @@ class Model(nn.Module):
         num_embeddings: Union[None, dict] = None,
         arch_type: Literal[
             # Plain feed-forward network without any kind of ensembling.
-            'plain',
+            "plain",
             #
             # TabM
-            'tabm',
+            "tabm",
             #
             # TabM-mini
-            'tabm-mini',
+            "tabm-mini",
             #
             # TabM-packed
-            'tabm-packed',
+            "tabm-packed",
             #
             # TabM. The first adapter is initialized from the normal distribution.
             # This variant was not used in the paper, but it may be useful in practice.
-            'tabm-normal',
+            "tabm-normal",
             #
             # TabM-mini. The adapter is initialized from the normal distribution.
             # This variant was not used in the paper.
-            'tabm-mini-normal',
+            "tabm-mini-normal",
         ],
         k: Union[None, int] = None,
         share_training_batches: bool = True,
@@ -474,11 +454,9 @@ class Model(nn.Module):
         # >>> Validate arguments.
         assert n_num_features >= 0
         assert n_num_features or cat_cardinalities
-        if arch_type == 'plain':
+        if arch_type == "plain":
             assert k is None
-            assert (
-                share_training_batches
-            ), 'If `arch_type` is set to "plain", then `simple` must remain True'
+            assert share_training_batches, 'If `arch_type` is set to "plain", then `simple` must remain True'
         else:
             assert k is not None
             assert k > 0
@@ -501,21 +479,15 @@ class Model(nn.Module):
 
         else:
             if bins is None:
-                self.num_module = make_module(
-                    **num_embeddings, n_features=n_num_features
-                )
+                self.num_module = make_module(**num_embeddings, n_features=n_num_features)
             else:
-                assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
+                assert num_embeddings["type"].startswith("PiecewiseLinearEmbeddings")
                 self.num_module = make_module(**num_embeddings, bins=bins)
-            d_num = n_num_features * num_embeddings['d_embedding']
-            first_adapter_sections.extend(
-                num_embeddings['d_embedding'] for _ in range(n_num_features)
-            )
+            d_num = n_num_features * num_embeddings["d_embedding"]
+            first_adapter_sections.extend(num_embeddings["d_embedding"] for _ in range(n_num_features))
 
         # >>> Categorical features
-        self.cat_module = (
-            OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
-        )
+        self.cat_module = OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
         first_adapter_sections.extend(cat_cardinalities)
         d_cat = sum(cat_cardinalities)
 
@@ -525,21 +497,21 @@ class Model(nn.Module):
         # Any backbone can be here but we provide only MLP
         self.backbone = make_module(d_in=d_flat, **backbone)
 
-        if arch_type != 'plain':
+        if arch_type != "plain":
             assert k is not None
             first_adapter_init = (
                 None
-                if arch_type == 'tabm-packed'
-                else 'normal'
-                if arch_type in ('tabm-mini-normal', 'tabm-normal')
+                if arch_type == "tabm-packed"
+                else "normal"
+                if arch_type in ("tabm-mini-normal", "tabm-normal")
                 # For other arch_types, the initialization depends
                 # on the presence of num_embeddings.
-                else 'random-signs'
+                else "random-signs"
                 if num_embeddings is None
-                else 'normal'
+                else "normal"
             )
 
-            if arch_type in ('tabm', 'tabm-normal'):
+            if arch_type in ("tabm", "tabm-normal"):
                 # Like BatchEnsemble, but all multiplicative adapters,
                 # except for the very first one, are initialized with ones.
                 assert first_adapter_init is not None
@@ -550,7 +522,7 @@ class Model(nn.Module):
                     ensemble_scaling_in=True,
                     ensemble_scaling_out=True,
                     ensemble_bias=True,
-                    scaling_init='ones',
+                    scaling_init="ones",
                 )
                 _init_first_adapter(
                     _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
@@ -558,13 +530,13 @@ class Model(nn.Module):
                     first_adapter_sections,
                 )
 
-            elif arch_type in ('tabm-mini', 'tabm-mini-normal'):
+            elif arch_type in ("tabm-mini", "tabm-mini-normal"):
                 # MiniEnsemble
                 assert first_adapter_init is not None
                 self.minimal_ensemble_adapter = ScaleEnsemble(
                     k,
                     d_flat,
-                    init='random-signs' if num_embeddings is None else 'normal',
+                    init="random-signs" if num_embeddings is None else "normal",
                 )
                 _init_first_adapter(
                     self.minimal_ensemble_adapter.weight,  # type: ignore[code]
@@ -572,7 +544,7 @@ class Model(nn.Module):
                     first_adapter_sections,
                 )
 
-            elif arch_type == 'tabm-packed':
+            elif arch_type == "tabm-packed":
                 # Packed ensemble.
                 # In terms of the Packed Ensembles paper by Laurent et al.,
                 # TabM-packed is PackedEnsemble(alpha=k, M=k, gamma=1).
@@ -580,15 +552,13 @@ class Model(nn.Module):
                 make_efficient_ensemble(self.backbone, NLinear, n=k)
 
             else:
-                raise ValueError(f'Unknown arch_type: {arch_type}')
+                raise ValueError(f"Unknown arch_type: {arch_type}")
 
         # >>> Output
-        d_block = backbone['d_block']
+        d_block = backbone["d_block"]
         d_out = 1 if n_classes is None else n_classes
         self.output = (
-            nn.Linear(d_block, d_out)
-            if arch_type == 'plain'
-            else NLinear(k, d_block, d_out)  # type: ignore[code]
+            nn.Linear(d_block, d_out) if arch_type == "plain" else NLinear(k, d_block, d_out)  # type: ignore[code]
        )
 
         # >>>
@@ -596,9 +566,7 @@ class Model(nn.Module):
         self.k = k
         self.share_training_batches = share_training_batches
 
-    def forward(
-        self, x_num: Union[None, Tensor] = None, x_cat: Union[None, Tensor] = None
-    ) -> Tensor:
+    def forward(self, x_num: Union[None, Tensor] = None, x_cat: Union[None, Tensor] = None) -> Tensor:
         x = []
         if x_num is not None:
             x.append(x_num if self.num_module is None else self.num_module(x_num))
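For orientation, here is a minimal construction sketch for the ensembling `Model` reformatted above, assuming the constructor arguments visible in the context lines; the backbone spec, the `"type"` key, and the feature counts are illustrative assumptions, not values taken from the package.

```python
# A hedged sketch: arch_type selects between a plain MLP and the TabM ensembling
# variants, and k is the ensemble size required for every non-"plain" variant.
model = Model(
    n_num_features=8,                 # numerical features (illustrative)
    cat_cardinalities=[3, 5],         # one cardinality per categorical feature
    n_classes=2,                      # None would mean a single regression output
    backbone={"type": "MLP", "n_blocks": 2, "d_block": 256, "dropout": 0.1},  # assumed spec
    bins=None,
    num_embeddings=None,
    arch_type="tabm",                 # or "plain", "tabm-mini", "tabm-packed", ...
    k=32,                             # must be None for "plain", > 0 otherwise
    share_training_batches=True,
)
```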
@@ -9,19 +9,15 @@ import numpy as np
 import torch
 
 
-class EarlyStopping():
-
+class EarlyStopping:
     def __init__(self, patience=10, delta=0.0001):
-
         self.patience = patience
         self.counter = 0
         self.best_score = None
         self.early_stop = False
         self.delta = delta
 
-
     def __call__(self, val_loss):
-
         score = -val_loss
 
         if self.best_score is None:
@@ -56,7 +52,7 @@ class Checkpoint:
         self.buffer = io.BytesIO()
         self.best_model = None
         self.best_epoch = None
-
+
     def reset(self):
         self.curr_best_loss = np.inf
         self.best_model = None
@@ -70,7 +66,7 @@ class Checkpoint:
            self.best_epoch = epoch
            if self.save_best:
                self.save()
-
+
     def save(self):
         if self.in_memory:
             self.buffer = io.BytesIO()
@@ -87,15 +83,12 @@ class Checkpoint:
         return torch.load(self.path)  # nosec B614
 
 
-
-
-class EpochStatistics():
-
+class EpochStatistics:
     def __init__(self) -> None:
         self.n = 0
         self.loss = 0
         self.score = 0
-
+
     def update(self, loss, score, n):
         self.n += n
         self.loss += loss * n
@@ -103,11 +96,9 @@ class EpochStatistics():
 
     def get(self):
         return self.loss / self.n, self.score / self.n
-
-
 
-class TrackOutput():
 
+class TrackOutput:
     def __init__(self) -> None:
         self.y_true: list[np.ndarray] = []
         self.y_pred: list[np.ndarray] = []
@@ -117,4 +108,4 @@ class TrackOutput():
         self.y_pred.append(y_pred)
 
     def get(self):
-        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
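A small usage sketch of the `EarlyStopping` helper cleaned up above, assuming the usual patience/counter behavior implied by its attributes; the loop and loss values are illustrative, not from the package.

```python
# Hedged sketch: drive EarlyStopping from a validation loop.
stopper = EarlyStopping(patience=3, delta=0.0001)
for epoch, val_loss in enumerate([0.52, 0.48, 0.48, 0.48, 0.48, 0.48]):
    stopper(val_loss)          # internally tracks score = -val_loss
    if stopper.early_stop:     # set once `patience` calls pass without improvement > delta
        print(f"early stop at epoch {epoch}")
        break
```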
@@ -3,44 +3,36 @@ from typing import Optional
 import torch
 
 
-class CollatorWithPadding():
-
-    def __init__(
-        self,
-        pad_to_n_support_samples: Optional[int]
-    ) -> None:
-
+class CollatorWithPadding:
+    def __init__(self, pad_to_n_support_samples: Optional[int]) -> None:
         self.pad_to_n_support_samples = pad_to_n_support_samples
 
-
     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
-
         if self.pad_to_n_support_samples is not None:
-            assert all(dataset['x_support'].shape[0] <= self.pad_to_n_support_samples for dataset in batch)
+            assert all(dataset["x_support"].shape[0] <= self.pad_to_n_support_samples for dataset in batch)
             self.n_support_samples = self.pad_to_n_support_samples
         else:
-            self.n_support_samples = max(dataset['x_support'].shape[0] for dataset in batch)
+            self.n_support_samples = max(dataset["x_support"].shape[0] for dataset in batch)
 
-        max_query_samples = max(dataset['x_query'].shape[0] for dataset in batch)
+        max_query_samples = max(dataset["x_query"].shape[0] for dataset in batch)
 
-        n_support_features = batch[0]['x_support'].shape[1]
-        n_query_features = batch[0]['x_query'].shape[1]
-        y_dtype = batch[0]['y_support'].dtype
+        n_support_features = batch[0]["x_support"].shape[1]
+        n_query_features = batch[0]["x_query"].shape[1]
+        y_dtype = batch[0]["y_support"].dtype
 
         batch_size = len(batch)
 
         tensor_dict = {
-            'x_support': torch.zeros((batch_size, self.n_support_samples, n_support_features), dtype=torch.float32),
-            'y_support': torch.zeros((batch_size, self.n_support_samples), dtype=y_dtype),
-            'x_query': torch.zeros((batch_size, max_query_samples, n_query_features), dtype=torch.float32),
-            'y_query': torch.zeros((batch_size, max_query_samples), dtype=y_dtype)
+            "x_support": torch.zeros((batch_size, self.n_support_samples, n_support_features), dtype=torch.float32),
+            "y_support": torch.zeros((batch_size, self.n_support_samples), dtype=y_dtype),
+            "x_query": torch.zeros((batch_size, max_query_samples, n_query_features), dtype=torch.float32),
+            "y_query": torch.zeros((batch_size, max_query_samples), dtype=y_dtype),
         }
 
         for i, dataset in enumerate(batch):
-            tensor_dict['x_support'][i, :dataset['x_support'].shape[0], :] = dataset['x_support']
-            tensor_dict['y_support'][i, :dataset['y_support'].shape[0]] = dataset['y_support']
-            tensor_dict['x_query'][i, :dataset['x_query'].shape[0], :] = dataset['x_query']
-            tensor_dict['y_query'][i, :dataset['y_query'].shape[0]] = dataset['y_query']
+            tensor_dict["x_support"][i, : dataset["x_support"].shape[0], :] = dataset["x_support"]
+            tensor_dict["y_support"][i, : dataset["y_support"].shape[0]] = dataset["y_support"]
+            tensor_dict["x_query"][i, : dataset["x_query"].shape[0], :] = dataset["x_query"]
+            tensor_dict["y_query"][i, : dataset["y_query"].shape[0]] = dataset["y_query"]
 
         return tensor_dict
-
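A short sketch of the batch layout the `CollatorWithPadding` above consumes and produces, assuming classification-style integer labels; shapes and values below are illustrative.

```python
import torch

# Hedged sketch: two episodes with different support/query sizes are zero-padded
# to a common support length by the collator shown above.
collate = CollatorWithPadding(pad_to_n_support_samples=128)
batch = [
    {
        "x_support": torch.randn(100, 12), "y_support": torch.zeros(100, dtype=torch.long),
        "x_query": torch.randn(30, 12), "y_query": torch.zeros(30, dtype=torch.long),
    },
    {
        "x_support": torch.randn(80, 12), "y_support": torch.ones(80, dtype=torch.long),
        "x_query": torch.randn(25, 12), "y_query": torch.ones(25, dtype=torch.long),
    },
]
out = collate(batch)
# out["x_support"].shape == (2, 128, 12); out["x_query"].shape == (2, 30, 12)
```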
@@ -5,7 +5,9 @@ from sklearn.model_selection import StratifiedKFold, train_test_split
 from .enums import Task
 
 
-def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, random_state: Generator = None) -> tuple[np.ndarray, ...]:
+def make_dataset_split(
+    x: np.ndarray, y: np.ndarray, task: Task, random_state: Generator = None
+) -> tuple[np.ndarray, ...]:
     # Splits the dataset into train and validation sets with ratio 80/20
 
     if task == Task.CLASSIFICATION and np.min(np.bincount(y)) >= 5:
@@ -13,10 +15,9 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, random_state: Generator = None) -> tuple[np.ndarray, ...]:
         return make_stratified_dataset_split(x, y, rng=random_state)
     else:
         return make_standard_dataset_split(x, y, rng=random_state)
-
 
-def make_stratified_dataset_split(x, y, rng: Generator = None):
 
+def make_stratified_dataset_split(x, y, rng: Generator = None):
     # Stratify doesn't shuffle the data, so we shuffle it first
     permutation = rng.permutation(len(y))
     x, y = x[permutation], y[permutation]
@@ -30,7 +31,4 @@ def make_stratified_dataset_split(x, y, rng: Generator = None):
 
 
 def make_standard_dataset_split(x, y, rng: Generator = None):
-
-    return train_test_split(
-        x, y, test_size=0.2, random_state=rng.integers(low=0, high=1000000)
-    )
+    return train_test_split(x, y, test_size=0.2, random_state=rng.integers(low=0, high=1000000))
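A usage sketch for the split helpers above: an 80/20 split that is stratified when every class has at least 5 samples, and otherwise falls back to a plain `train_test_split`; the data below is illustrative.

```python
import numpy as np

# Hedged sketch: with a regression target this takes the make_standard_dataset_split
# path, i.e. sklearn's train_test_split with test_size=0.2 seeded from the rng.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 6))
y = rng.normal(size=200)
x_train, x_val, y_train, y_val = make_dataset_split(x, y, task=Task.REGRESSION, random_state=rng)
```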
@@ -1,5 +1,3 @@
-
-
 class Task:
     CLASSIFICATION = "classification"
     REGRESSION = "regression"
@@ -4,7 +4,6 @@ from .enums import Task
 
 
 def get_loss(task: Task):
-
     if task == Task.REGRESSION:
         return torch.nn.MSELoss()
     else:
@@ -3,30 +3,19 @@ from torch.optim import SGD, Adam, AdamW
 
 
 def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
-
     optimizer: torch.optim.Optimizer
 
-    if hyperparams['optimizer'] == "adam":
+    if hyperparams["optimizer"] == "adam":
         optimizer = Adam(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            betas=(0.9, 0.999),
-            weight_decay=hyperparams['weight_decay']
+            model.parameters(), lr=hyperparams["lr"], betas=(0.9, 0.999), weight_decay=hyperparams["weight_decay"]
         )
-    elif hyperparams['optimizer'] == "adamw":
+    elif hyperparams["optimizer"] == "adamw":
         optimizer = AdamW(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            betas=(0.9, 0.999),
-            weight_decay=hyperparams['weight_decay']
-        )
-    elif hyperparams['optimizer'] == "sgd":
-        optimizer = SGD(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            weight_decay=hyperparams['weight_decay']
+            model.parameters(), lr=hyperparams["lr"], betas=(0.9, 0.999), weight_decay=hyperparams["weight_decay"]
         )
+    elif hyperparams["optimizer"] == "sgd":
+        optimizer = SGD(model.parameters(), lr=hyperparams["lr"], weight_decay=hyperparams["weight_decay"])
     else:
         raise ValueError("Optimizer not recognized")
-
+
     return optimizer
@@ -3,20 +3,9 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 
 def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer):
-
-    if hyperparams['lr_scheduler']:
-        scheduler = ReduceLROnPlateau(
-            optimizer,
-            patience=hyperparams['lr_scheduler_patience'],
-            min_lr=0,
-            factor=0.2
-        )
+    if hyperparams["lr_scheduler"]:
+        scheduler = ReduceLROnPlateau(optimizer, patience=hyperparams["lr_scheduler_patience"], min_lr=0, factor=0.2)
     else:
-        scheduler = ReduceLROnPlateau(
-            optimizer,
-            patience=10000000,
-            min_lr=0,
-            factor=0.2
-        )
+        scheduler = ReduceLROnPlateau(optimizer, patience=10000000, min_lr=0, factor=0.2)
 
     return scheduler
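Finally, a sketch of the hyperparameter dict consumed by `get_optimizer` and `get_scheduler` above; the key names come from the diffs, while the model and values are illustrative.

```python
import torch.nn as nn

# Hedged sketch: wire both factories together for a toy module.
hyperparams = {
    "optimizer": "adamw",          # "adam", "adamw" or "sgd"
    "lr": 1e-4,
    "weight_decay": 0.01,
    "lr_scheduler": True,          # False falls back to an effectively inert ReduceLROnPlateau
    "lr_scheduler_patience": 25,
}
model = nn.Linear(10, 2)
optimizer = get_optimizer(hyperparams, model)
scheduler = get_scheduler(hyperparams, optimizer)
# Per epoch: scheduler.step(val_loss) multiplies the LR by 0.2 once
# `lr_scheduler_patience` epochs pass without improvement.
```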