autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/METADATA +27 -27
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260117-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/zip-safe +0 -0
autogluon/tabular/models/tabm/rtdl_num_embeddings.py
@@ -3,15 +3,15 @@
 
 from __future__ import annotations
 
-__version__ = '0.0.12'
+__version__ = "0.0.12"
 
 __all__ = [
-    'LinearEmbeddings',
-    'LinearReLUEmbeddings',
-    'PeriodicEmbeddings',
-    'PiecewiseLinearEmbeddings',
-    'PiecewiseLinearEncoding',
-    'compute_bins',
+    "LinearEmbeddings",
+    "LinearReLUEmbeddings",
+    "PeriodicEmbeddings",
+    "PiecewiseLinearEmbeddings",
+    "PiecewiseLinearEncoding",
+    "compute_bins",
 ]
 
 
@@ -37,13 +37,10 @@ except ImportError:
 
 def _check_input_shape(x: Tensor, expected_n_features: int) -> None:
     if x.ndim < 1:
-        raise ValueError(
-            f'The input must have at least one dimension, however: {x.ndim=}'
-        )
+        raise ValueError(f"The input must have at least one dimension, however: {x.ndim=}")
     if x.shape[-1] != expected_n_features:
         raise ValueError(
-            'The last dimension of the input was expected to be'
-            f' {expected_n_features}, however, {x.shape[-1]=}'
+            f"The last dimension of the input was expected to be {expected_n_features}, however, {x.shape[-1]=}"
         )
 
 
@@ -75,9 +72,9 @@ class LinearEmbeddings(nn.Module):
             d_embedding: the embedding size.
         """
         if n_features <= 0:
-            raise ValueError(f'n_features must be positive, however: {n_features=}')
+            raise ValueError(f"n_features must be positive, however: {n_features=}")
         if d_embedding <= 0:
-            raise ValueError(f'd_embedding must be positive, however: {d_embedding=}')
+            raise ValueError(f"d_embedding must be positive, however: {d_embedding=}")
 
         super().__init__()
         self.weight = Parameter(torch.empty(n_features, d_embedding))
@@ -153,7 +150,7 @@ class _Periodic(nn.Module):
 
     def __init__(self, n_features: int, k: int, sigma: float) -> None:
         if sigma <= 0.0:
-            raise ValueError(f'sigma must be positive, however: {sigma=}')
+            raise ValueError(f"sigma must be positive, however: {sigma=}")
 
         super().__init__()
         self._sigma = sigma
@@ -185,9 +182,7 @@ class _NLinear(nn.Module):
     each feature embedding is transformed by its own dedicated linear layer.
     """
 
-    def __init__(
-        self, n: int, in_features: int, out_features: int, bias: bool = True
-    ) -> None:
+    def __init__(self, n: int, in_features: int, out_features: int, bias: bool = True) -> None:
         super().__init__()
         self.weight = Parameter(torch.empty(n, in_features, out_features))
         self.bias = Parameter(torch.empty(n, out_features)) if bias else None
@@ -204,8 +199,8 @@ class _NLinear(nn.Module):
         """Do the forward pass."""
         if x.ndim != 3:
             raise ValueError(
-                '_NLinear supports only inputs with exactly one batch dimension,'
-                ' so `x` must have a shape like (BATCH_SIZE, N_FEATURES, D_EMBEDDING).'
+                "_NLinear supports only inputs with exactly one batch dimension,"
+                " so `x` must have a shape like (BATCH_SIZE, N_FEATURES, D_EMBEDDING)."
             )
         assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]
 
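For orientation, the contract enforced above is: weight has shape (N, D_IN, D_OUT), x must be (BATCH_SIZE, N_FEATURES, D_IN), and feature n is transformed by its own weight[n]. A minimal sketch of that computation via einsum (an illustration only, not the vendored forward pass):

import torch

batch_size, n_features, d_in, d_out = 32, 8, 16, 24
weight = torch.randn(n_features, d_in, d_out)
bias = torch.randn(n_features, d_out)
x = torch.randn(batch_size, n_features, d_in)

# For every feature n, multiply the (batch, d_in) slice by its own weight[n];
# the (n_features, d_out) bias broadcasts over the batch dimension.
y = torch.einsum("bni,nio->bno", x, weight) + bias
assert y.shape == (batch_size, n_features, d_out)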
@@ -286,7 +281,7 @@ class PeriodicEmbeddings(nn.Module):
             # The lite variation was introduced in a different paper
             # (about the TabR model).
             if not activation:
-                raise ValueError('lite=True is allowed only when activation=True')
+                raise ValueError("lite=True is allowed only when activation=True")
             self.linear = nn.Linear(2 * n_frequencies, d_embedding)
         else:
             self.linear = _NLinear(n_features, 2 * n_frequencies, d_embedding)
@@ -296,9 +291,7 @@ class PeriodicEmbeddings(nn.Module):
         """Get the output shape without the batch dimensions."""
         n_features = self.periodic.weight.shape[0]
         d_embedding = (
-            self.linear.weight.shape[0]
-            if isinstance(self.linear, nn.Linear)
-            else self.linear.weight.shape[-1]
+            self.linear.weight.shape[0] if isinstance(self.linear, nn.Linear) else self.linear.weight.shape[-1]
         )
         return torch.Size((n_features, d_embedding))
 
@@ -313,32 +306,23 @@
 
 def _check_bins(bins: list[Tensor]) -> None:
     if not bins:
-        raise ValueError('The list of bins must not be empty')
+        raise ValueError("The list of bins must not be empty")
     for i, feature_bins in enumerate(bins):
         if not isinstance(feature_bins, Tensor):
-            raise ValueError(
-                'bins must be a list of PyTorch tensors. '
-                f'However, for {i=}: {type(bins[i])=}'
-            )
+            raise ValueError(f"bins must be a list of PyTorch tensors. However, for {i=}: {type(bins[i])=}")
         if feature_bins.ndim != 1:
             raise ValueError(
-                'Each item of the bin list must have exactly one dimension.'
-                f' However, for {i=}: {bins[i].ndim=}'
+                f"Each item of the bin list must have exactly one dimension. However, for {i=}: {bins[i].ndim=}"
             )
         if len(feature_bins) < 2:
-            raise ValueError(
-                'All features must have at least two bin edges.'
-                f' However, for {i=}: {len(bins[i])=}'
-            )
+            raise ValueError(f"All features must have at least two bin edges. However, for {i=}: {len(bins[i])=}")
         if not feature_bins.isfinite().all():
             raise ValueError(
-                'Bin edges must not contain nan/inf/-inf.'
-                f' However, this is not true for the {i}-th feature'
+                f"Bin edges must not contain nan/inf/-inf. However, this is not true for the {i}-th feature"
            )
         if (feature_bins[:-1] >= feature_bins[1:]).any():
             raise ValueError(
-                'Bin edges must be sorted.'
-                f' However, the for the {i}-th feature, the bin edges are not sorted'
+                f"Bin edges must be sorted. However, the for the {i}-th feature, the bin edges are not sorted"
             )
         # Commented out due to spaming warnings.
         # if len(feature_bins) == 2:
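For reference, _check_bins accepts a non-empty list containing, per feature, one 1-D tensor of at least two finite, strictly increasing bin edges. A quick illustration (not part of the diff):

import torch

good_bins = [
    torch.tensor([0.0, 0.5, 1.0]),  # three edges -> two bins
    torch.tensor([-1.0, 1.0]),      # the minimum: two edges -> one bin
]

# Each of these would raise ValueError in _check_bins:
too_few_edges = [torch.tensor([1.0])]
not_sorted = [torch.tensor([0.0, 0.0, 1.0])]
non_finite = [torch.tensor([0.0, float("inf")])]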
@@ -399,53 +383,49 @@ def compute_bins(
         - the minimum possible number of bin edges is ``1``.
     """  # noqa: E501
     if not isinstance(X, Tensor):
-        raise ValueError(f'X must be a PyTorch tensor, however: {type(X)=}')
+        raise ValueError(f"X must be a PyTorch tensor, however: {type(X)=}")
     if X.ndim != 2:
-        raise ValueError(f'X must have exactly two dimensions, however: {X.ndim=}')
+        raise ValueError(f"X must have exactly two dimensions, however: {X.ndim=}")
     if X.shape[0] < 2:
-        raise ValueError(f'X must have at least two rows, however: {X.shape[0]=}')
+        raise ValueError(f"X must have at least two rows, however: {X.shape[0]=}")
     if X.shape[1] < 1:
-        raise ValueError(f'X must have at least one column, however: {X.shape[1]=}')
+        raise ValueError(f"X must have at least one column, however: {X.shape[1]=}")
     if not X.isfinite().all():
-        raise ValueError('X must not contain nan/inf/-inf.')
+        raise ValueError("X must not contain nan/inf/-inf.")
     if (X == X[0]).all(dim=0).any():
         raise ValueError(
-            'All columns of X must have at least two distinct values.'
-            ' However, X contains columns with just one distinct value.'
+            "All columns of X must have at least two distinct values."
+            " However, X contains columns with just one distinct value."
         )
     if n_bins <= 1 or n_bins >= len(X):
-        raise ValueError(
-            'n_bins must be more than 1, but less than len(X), however:'
-            f' {n_bins=}, {len(X)=}'
-        )
+        raise ValueError(f"n_bins must be more than 1, but less than len(X), however: {n_bins=}, {len(X)=}")
 
     if tree_kwargs is None:
         if y is not None or regression is not None or verbose:
             raise ValueError(
-                'If tree_kwargs is None, then y must be None, regression must be None'
-                ' and verbose must be False'
+                "If tree_kwargs is None, then y must be None, regression must be None and verbose must be False"
             )
 
         _upper = 2**24  # 16_777_216
         if len(X) > _upper:
             warnings.warn(
-                f'Computing quantile-based bins for more than {_upper} million objects'
-                ' may not be possible due to the limitation of PyTorch'
-                ' (for details, see https://github.com/pytorch/pytorch/issues/64947;'
-                ' if that issue is successfully resolved, this warning may be irrelevant).'  # noqa
-                ' As a workaround, subsample the data, i.e. instead of'
-                '\ncompute_bins(X, ...)'
-                '\ndo'
-                '\ncompute_bins(X[torch.randperm(len(X), device=X.device)[:16_777_216]], ...)'  # noqa
-                '\nOn CUDA, the computation can still fail with OOM even after'
-                ' subsampling. If this is the case, try passing features by groups:'
-                '\nbins = sum('
-                '\n    compute_bins(X[:, idx], ...)'
-                '\n    for idx in torch.arange(len(X), device=X.device).split(group_size),'  # noqa
-                '\n    start=[]'
-                '\n)'
-                '\nAnother option is to perform the computation on CPU:'
-                '\ncompute_bins(X.cpu(), ...)'
+                f"Computing quantile-based bins for more than {_upper} million objects"
+                " may not be possible due to the limitation of PyTorch"
+                " (for details, see https://github.com/pytorch/pytorch/issues/64947;"
+                " if that issue is successfully resolved, this warning may be irrelevant)."  # noqa
+                " As a workaround, subsample the data, i.e. instead of"
+                "\ncompute_bins(X, ...)"
+                "\ndo"
+                "\ncompute_bins(X[torch.randperm(len(X), device=X.device)[:16_777_216]], ...)"  # noqa
+                "\nOn CUDA, the computation can still fail with OOM even after"
+                " subsampling. If this is the case, try passing features by groups:"
+                "\nbins = sum("
+                "\n    compute_bins(X[:, idx], ...)"
+                "\n    for idx in torch.arange(len(X), device=X.device).split(group_size),"  # noqa
+                "\n    start=[]"
+                "\n)"
+                "\nAnother option is to perform the computation on CPU:"
+                "\ncompute_bins(X.cpu(), ...)"
             )
         del _upper
 
@@ -458,53 +438,38 @@ def compute_bins(
         # https://github.com/yandex-research/tabular-dl-num-embeddings/blob/c1d9eb63c0685b51d7e1bc081cdce6ffdb8886a8/bin/train4.py#L612C30-L612C30
         # (explanation: limiting the number of quantiles by the number of distinct
         # values is NOT the same as removing identical quantiles after computing them).
-        bins = [
-            q.unique()
-            for q in torch.quantile(
-                X, torch.linspace(0.0, 1.0, n_bins + 1).to(X), dim=0
-            ).T
-        ]
+        bins = [q.unique() for q in torch.quantile(X, torch.linspace(0.0, 1.0, n_bins + 1).to(X), dim=0).T]
         _check_bins(bins)
         return bins
 
     else:
         if sklearn_tree is None:
-            raise RuntimeError(
-                'The scikit-learn package is missing.'
-                ' See README.md for installation instructions'
-            )
+            raise RuntimeError("The scikit-learn package is missing. See README.md for installation instructions")
         if y is None or regression is None:
-            raise ValueError(
-                'If tree_kwargs is not None, then y and regression must not be None'
-            )
+            raise ValueError("If tree_kwargs is not None, then y and regression must not be None")
         if y.ndim != 1:
-            raise ValueError(f'y must have exactly one dimension, however: {y.ndim=}')
+            raise ValueError(f"y must have exactly one dimension, however: {y.ndim=}")
         if len(y) != len(X):
-            raise ValueError(
-                f'len(y) must be equal to len(X), however: {len(y)=}, {len(X)=}'
-            )
+            raise ValueError(f"len(y) must be equal to len(X), however: {len(y)=}, {len(X)=}")
         if y is None or regression is None:
+            raise ValueError("If tree_kwargs is not None, then y and regression must not be None")
+        if "max_leaf_nodes" in tree_kwargs:
             raise ValueError(
-                'If tree_kwargs is not None, then y and regression must not be None'
-            )
-        if 'max_leaf_nodes' in tree_kwargs:
-            raise ValueError(
-                'tree_kwargs must not contain the key "max_leaf_nodes"'
-                ' (it will be set to n_bins automatically).'
+                'tree_kwargs must not contain the key "max_leaf_nodes" (it will be set to n_bins automatically).'
             )
 
         if verbose:
             if tqdm is None:
-                raise ImportError('If verbose is True, tqdm must be installed')
+                raise ImportError("If verbose is True, tqdm must be installed")
             tqdm_ = tqdm
         else:
             tqdm_ = lambda x: x  # noqa: E731
 
-        if X.device.type != 'cpu' or y.device.type != 'cpu':
+        if X.device.type != "cpu" or y.device.type != "cpu":
             warnings.warn(
-                'Computing tree-based bins involves the conversion of the input PyTorch'
-                ' tensors to NumPy arrays. The provided PyTorch tensors are not'
-                ' located on CPU, so the conversion has some overhead.',
+                "Computing tree-based bins involves the conversion of the input PyTorch"
+                " tensors to NumPy arrays. The provided PyTorch tensors are not"
+                " located on CPU, so the conversion has some overhead.",
                 UserWarning,
             )
         X_numpy = X.cpu().numpy()
@@ -513,11 +478,9 @@ def compute_bins(
         for column in tqdm_(X_numpy.T):
             feature_bin_edges = [float(column.min()), float(column.max())]
             tree = (
-                (
-                    sklearn_tree.DecisionTreeRegressor
-                    if regression
-                    else sklearn_tree.DecisionTreeClassifier
-                )(max_leaf_nodes=n_bins, **tree_kwargs)
+                (sklearn_tree.DecisionTreeRegressor if regression else sklearn_tree.DecisionTreeClassifier)(
+                    max_leaf_nodes=n_bins, **tree_kwargs
+                )
                 .fit(column.reshape(-1, 1), y_numpy)
                 .tree_
             )
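Taken together, the two branches above give compute_bins two calling conventions. A hedged usage sketch, assuming the vendored module path and that keyword names match the ones visible in this diff:

import torch
from autogluon.tabular.models.tabm.rtdl_num_embeddings import compute_bins

X = torch.randn(1_000, 5)
y = torch.randn(1_000)

# Quantile-based bins: per-column unique quantiles; no labels involved.
quantile_bins = compute_bins(X, n_bins=48)

# Tree-based bins: edges come from decision-tree split thresholds, so y and
# regression are required, and tree_kwargs must not set max_leaf_nodes.
tree_bins = compute_bins(
    X,
    n_bins=48,
    y=y,
    regression=True,
    tree_kwargs={"min_samples_leaf": 64},
)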
@@ -605,16 +568,14 @@ class _PiecewiseLinearEncodingImpl(nn.Module):
         n_bins = [len(x) - 1 for x in bins]
         max_n_bins = max(n_bins)
 
-        self.register_buffer('weight', torch.zeros(n_features, max_n_bins))
-        self.register_buffer('bias', torch.zeros(n_features, max_n_bins))
+        self.register_buffer("weight", torch.zeros(n_features, max_n_bins))
+        self.register_buffer("bias", torch.zeros(n_features, max_n_bins))
 
         single_bin_mask = torch.tensor(n_bins) == 1
-        self.register_buffer(
-            'single_bin_mask', single_bin_mask if single_bin_mask.any() else None
-        )
+        self.register_buffer("single_bin_mask", single_bin_mask if single_bin_mask.any() else None)
 
         self.register_buffer(
-            'mask',
+            "mask",
             # The mask is needed if features have different number of bins.
             None
             if all(len(x) == len(bins[0]) for x in bins)
@@ -713,9 +674,7 @@ class PiecewiseLinearEncoding(nn.Module):
     def get_output_shape(self) -> torch.Size:
         """Get the output shape without the batch dimensions."""
         total_n_bins = (
-            self.impl.weight.shape.numel()
-            if self.impl.mask is None
-            else int(self.impl.mask.long().sum().cpu().item())
+            self.impl.weight.shape.numel() if self.impl.mask is None else int(self.impl.mask.long().sum().cpu().item())
         )
         return torch.Size((total_n_bins,))
 
@@ -740,7 +699,7 @@ class PiecewiseLinearEmbeddings(nn.Module):
         d_embedding: int,
         *,
         activation: bool,
-        version: Literal[None, 'A', 'B'] = None,
+        version: Literal[None, "A", "B"] = None,
     ) -> None:
         """
         Args:
@@ -751,28 +710,24 @@ class PiecewiseLinearEmbeddings(nn.Module):
                 parametrization and initialization. See README for details.
         """
         if d_embedding <= 0:
-            raise ValueError(
-                f'd_embedding must be a positive integer, however: {d_embedding=}'
-            )
+            raise ValueError(f"d_embedding must be a positive integer, however: {d_embedding=}")
         _check_bins(bins)
         if version is None:
             warnings.warn(
                 'The `version` argument is not provided, so version="A" will be used'
-                ' for backward compatibility.'
-                ' See README for recommendations regarding `version`.'
-                ' In future, omitting this argument will result in an exception.'
+                " for backward compatibility."
+                " See README for recommendations regarding `version`."
+                " In future, omitting this argument will result in an exception."
             )
-            version = 'A'
+            version = "A"
 
         super().__init__()
         n_features = len(bins)
         # NOTE[DIFF]
         # version="B" was introduced in a different paper (about the TabM model).
-        is_version_B = version == 'B'
+        is_version_B = version == "B"
 
-        self.linear0 = (
-            LinearEmbeddings(n_features, d_embedding) if is_version_B else None
-        )
+        self.linear0 = LinearEmbeddings(n_features, d_embedding) if is_version_B else None
         self.impl = _PiecewiseLinearEncodingImpl(bins)
         self.linear = _NLinear(
             len(bins),
797
752
  def forward(self, x: Tensor) -> Tensor:
798
753
  """Do the forward pass."""
799
754
  if x.ndim != 2:
800
- raise ValueError(
801
- 'For now, only inputs with exactly one batch dimension are supported.'
802
- )
755
+ raise ValueError("For now, only inputs with exactly one batch dimension are supported.")
803
756
 
804
757
  x_linear = None if self.linear0 is None else self.linear0(x)
805
758
 
@@ -807,4 +760,4 @@ class PiecewiseLinearEmbeddings(nn.Module):
         x_ple = self.linear(x_ple)
         if self.activation is not None:
             x_ple = self.activation(x_ple)
-        return x_ple if x_linear is None else x_linear + x_ple
\ No newline at end of file
+        return x_ple if x_linear is None else x_linear + x_ple
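The reformatting leaves this file's public API unchanged. An end-to-end usage sketch, assuming the vendored module path; shapes and hyperparameter values are illustrative:

import torch
from autogluon.tabular.models.tabm.rtdl_num_embeddings import (
    PiecewiseLinearEmbeddings,
    compute_bins,
)

X = torch.randn(256, 6)            # forward() requires x.ndim == 2
bins = compute_bins(X, n_bins=16)  # quantile-based edges, one tensor per feature
emb = PiecewiseLinearEmbeddings(
    bins,
    d_embedding=24,
    activation=False,
    version="B",  # "B" enables the extra linear0 term, per the diff above
)
out = emb(X)
assert out.shape == (256, 6, 24)   # (batch, n_features, d_embedding)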
autogluon/tabular/models/tabm/tabm_model.py
@@ -36,6 +36,7 @@ class TabMModel(AbstractTorchModel):
 
     .. versionadded:: 1.4.0
     """
+
     ag_key = "TABM"
     ag_name = "TabM"
     ag_priority = 85
@@ -239,9 +240,12 @@ class TabMModel(AbstractTorchModel):
 
         # not completely sure
         n_params_num_emb = n_numerical * (num_emb_n_bins + 1) * d_embedding
-        n_params_mlp = (n_numerical + sum(cat_sizes)) * d_embedding * (d_block + tabm_k) \
-            + (n_blocks - 1) * d_block ** 2 \
-            + n_blocks * d_block + d_block * (1 + max(1, n_classes))
+        n_params_mlp = (
+            (n_numerical + sum(cat_sizes)) * d_embedding * (d_block + tabm_k)
+            + (n_blocks - 1) * d_block**2
+            + n_blocks * d_block
+            + d_block * (1 + max(1, n_classes))
+        )
         # 4 bytes per float, up to 5 copies of parameters (1 standard, 1 .grad, 2 adam, 1 best_epoch)
         mem_params = 4 * 5 * (n_params_num_emb + n_params_mlp)
 
@@ -259,7 +263,7 @@ class TabMModel(AbstractTorchModel):
         mem_ds = n_samples * (4 * n_numerical + 8 * len(cat_sizes))
 
         # some safety constants and offsets (the 5 is probably excessive)
-        mem_total = 5 * mem_ds + 1.2 * mem_forward_backward + 1.2 * mem_params + 0.3 * (1024 ** 3)
+        mem_total = 5 * mem_ds + 1.2 * mem_forward_backward + 1.2 * mem_params + 0.3 * (1024**3)
 
         return mem_total
 
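To make the memory estimate concrete, here is the same arithmetic with illustrative sizes (all numbers are invented; mem_forward_backward comes from code outside this diff and is omitted):

n_numerical, cat_sizes, n_classes = 10, [4, 7], 3
num_emb_n_bins, d_embedding, d_block, n_blocks, tabm_k = 48, 16, 512, 3, 32
n_samples = 100_000

n_params_num_emb = n_numerical * (num_emb_n_bins + 1) * d_embedding
n_params_mlp = (
    (n_numerical + sum(cat_sizes)) * d_embedding * (d_block + tabm_k)
    + (n_blocks - 1) * d_block**2
    + n_blocks * d_block
    + d_block * (1 + max(1, n_classes))
)
# 4 bytes per float, up to 5 copies (params, grads, 2x Adam state, best epoch)
mem_params = 4 * 5 * (n_params_num_emb + n_params_mlp)
mem_ds = n_samples * (4 * n_numerical + 8 * len(cat_sizes))
print(f"params: {mem_params / 1024**2:.1f} MiB, dataset: {mem_ds / 1024**2:.1f} MiB")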