autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
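The most visible additions in this release are the new tabular model families listed above (EBM, Mitra, RealMLP, TabICL, TabM, TabPFNv2), alongside a 2025 zeroshot portfolio and new pipeline presets. As rough orientation only (not taken from this diff), the sketch below shows how such models are typically requested through TabularPredictor; the model keys used here are assumed registry names, so check the 1.4 release notes for the exact keys and presets.

    # Illustrative sketch; the model keys below are assumptions, not confirmed by this diff.
    from autogluon.tabular import TabularPredictor

    predictor = TabularPredictor(label="target").fit(
        train_data,              # a pandas DataFrame containing the "target" column
        hyperparameters={
            "TABM": {},          # TabM ensemble MLP (tabm/tabm_model.py)
            "REALMLP": {},       # RealMLP (realmlp/realmlp_model.py)
            "EBM": {},           # Explainable Boosting Machine (ebm/ebm_model.py)
            # "MITRA": {},       # Mitra tabular foundation model (GPU recommended)
        },
        time_limit=600,
    )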
autogluon/tabular/models/tabm/_tabm_internal.py
@@ -0,0 +1,545 @@
+ """Partially adapted from pytabkit's TabM implementation."""
+
+ from __future__ import annotations
+
+ import logging
+ import math
+ import random
+ import time
+ from typing import TYPE_CHECKING, Any, Literal
+
+ import numpy as np
+ import pandas as pd
+ import scipy
+ import torch
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.impute import SimpleImputer
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import OrdinalEncoder, QuantileTransformer
+ from sklearn.utils.validation import check_is_fitted
+
+ from autogluon.core.metrics import compute_metric
+
+ from . import rtdl_num_embeddings, tabm_reference
+ from .tabm_reference import make_parameter_groups
+
+ if TYPE_CHECKING:
+     from autogluon.core.metrics import Scorer
+
+ TaskType = Literal["regression", "binclass", "multiclass"]
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_tabm_auto_batch_size(n_train: int) -> int:
+     # by Yury Gorishniy, inferred from the choices in the TabM paper.
+     if n_train < 2_800:
+         return 32
+     if n_train < 4_500:
+         return 64
+     if n_train < 6_400:
+         return 128
+     if n_train < 32_000:
+         return 256
+     if n_train < 108_000:
+         return 512
+     return 1024
+
+
+ class RTDLQuantileTransformer(BaseEstimator, TransformerMixin):
+     # adapted from pytabkit
+     def __init__(
+         self,
+         noise=1e-5,
+         random_state=None,
+         n_quantiles=1000,
+         subsample=1_000_000_000,
+         output_distribution="normal",
+     ):
+         self.noise = noise
+         self.random_state = random_state
+         self.n_quantiles = n_quantiles
+         self.subsample = subsample
+         self.output_distribution = output_distribution
+
+     def fit(self, X, y=None):
+         # Calculate the number of quantiles based on data size
+         n_quantiles = max(min(X.shape[0] // 30, self.n_quantiles), 10)
+
+         # Initialize QuantileTransformer
+         normalizer = QuantileTransformer(
+             output_distribution=self.output_distribution,
+             n_quantiles=n_quantiles,
+             subsample=self.subsample,
+             random_state=self.random_state,
+         )
+
+         # Add noise if required
+         X_modified = self._add_noise(X) if self.noise > 0 else X
+
+         # Fit the normalizer
+         normalizer.fit(X_modified)
+         # show that it's fitted
+         self.normalizer_ = normalizer
+
+         return self
+
+     def transform(self, X, y=None):
+         check_is_fitted(self)
+         return self.normalizer_.transform(X)
+
+     def _add_noise(self, X):
+         return X + np.random.default_rng(self.random_state).normal(0.0, 1e-5, X.shape).astype(X.dtype)
+
+
+ class TabMOrdinalEncoder(BaseEstimator, TransformerMixin):
+     # encodes missing and unknown values to a value one larger than the known values
+     def __init__(self):
+         # No fitted attributes here, only parameters
+         pass
+
+     def fit(self, X, y=None):
+         X = pd.DataFrame(X)
+
+         # Fit internal OrdinalEncoder with NaNs preserved for now
+         self.encoder_ = OrdinalEncoder(
+             handle_unknown="use_encoded_value",
+             unknown_value=np.nan,
+             encoded_missing_value=np.nan,
+         )
+         self.encoder_.fit(X)
+
+         # Cardinalities = number of known categories per column
+         self.cardinalities_ = [len(cats) for cats in self.encoder_.categories_]
+
+         return self
+
+     def transform(self, X):
+         check_is_fitted(self, ["encoder_", "cardinalities_"])
+
+         X = pd.DataFrame(X)
+         X_enc = self.encoder_.transform(X)
+
+         # Replace np.nan (unknown or missing) with the cardinality value
+         for col_idx, cardinality in enumerate(self.cardinalities_):
+             mask = np.isnan(X_enc[:, col_idx])
+             X_enc[mask, col_idx] = cardinality
+
+         return X_enc.astype(int)
+
+     def get_cardinalities(self):
+         check_is_fitted(self, ["cardinalities_"])
+         return self.cardinalities_
+
+
+ class TabMImplementation:
+     def __init__(self, early_stopping_metric: Scorer, **config):
+         self.config = config
+         self.early_stopping_metric = early_stopping_metric
+
+         self.ord_enc_ = None
+         self.num_prep_ = None
+         self.cat_col_names_ = None
+         self.n_classes_ = None
+         self.task_type_ = None
+         self.device_ = None
+         self.has_num_cols = None
+
+     def fit(
+         self,
+         X_train: pd.DataFrame,
+         y_train: pd.Series,
+         X_val: pd.DataFrame,
+         y_val: pd.Series,
+         cat_col_names: list[Any],
+         time_to_fit_in_seconds: float | None = None,
+     ):
+         start_time = time.time()
+
+         if X_val is None or len(X_val) == 0:
+             raise ValueError("Training without a validation set is currently not implemented")
+         seed: int | None = self.config.get("random_state", None)
+         if seed is not None:
+             torch.manual_seed(seed)
+             np.random.seed(seed)
+             random.seed(seed)
+         if "n_threads" in self.config:
+             torch.set_num_threads(self.config["n_threads"])
+
+         # -- Meta parameters
+         problem_type = self.config["problem_type"]
+         task_type: TaskType = "binclass" if problem_type == "binary" else problem_type
+         n_train = len(X_train)
+         n_classes = None
+         device = self.config["device"]
+         device = torch.device(device)
+         self.task_type_ = task_type
+         self.device_ = device
+         self.cat_col_names_ = cat_col_names
+
+         # -- Hyperparameters
+         arch_type = self.config.get("arch_type", "tabm-mini")
+         num_emb_type = self.config.get("num_emb_type", "pwl")
+         n_epochs = self.config.get("n_epochs", 1_000_000_000)
+         patience = self.config.get("patience", 16)
+         batch_size = self.config.get("batch_size", "auto")
+         compile_model = self.config.get("compile_model", False)
+         lr = self.config.get("lr", 2e-3)
+         d_embedding = self.config.get("d_embedding", 16)
+         d_block = self.config.get("d_block", 512)
+         dropout = self.config.get("dropout", 0.1)
+         tabm_k = self.config.get("tabm_k", 32)
+         allow_amp = self.config.get("allow_amp", False)
+         n_blocks = self.config.get("n_blocks", "auto")
+         num_emb_n_bins = self.config.get("num_emb_n_bins", 48)
+         eval_batch_size = self.config.get("eval_batch_size", 1024)
+         share_training_batches = self.config.get("share_training_batches", False)
+         weight_decay = self.config.get("weight_decay", 3e-4)
+         # this is the search space default but not the example default (which is 'none')
+         gradient_clipping_norm = self.config.get("gradient_clipping_norm", 1.0)
+
+         # -- Verify HPs
+         num_emb_n_bins = min(num_emb_n_bins, n_train - 1)
+         if n_train <= 2:
+             num_emb_type = "none"  # there is no valid number of bins for piecewise linear embeddings
+         if batch_size == "auto":
+             batch_size = get_tabm_auto_batch_size(n_train=n_train)
+
+         # -- Preprocessing
+         ds_parts = dict()
+         self.ord_enc_ = (
+             TabMOrdinalEncoder()
+         )  # Unique ordinal encoder -> replaces nan and missing values with the cardinality
+         self.ord_enc_.fit(X_train[self.cat_col_names_])
+         # TODO: fix transformer to be able to work with empty input data like the sklearn default
+         self.num_prep_ = Pipeline(steps=[
+             ("qt", RTDLQuantileTransformer(random_state=self.config.get("random_state", None))),
+             ("imp", SimpleImputer(add_indicator=True)),
+         ])
+         self.has_num_cols = bool(set(X_train.columns) - set(cat_col_names))
+         for part, X, y in [("train", X_train, y_train), ("val", X_val, y_val)]:
+             tensors = dict()
+
+             tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[cat_col_names]), dtype=torch.long)
+
+             if self.has_num_cols:
+                 x_cont_np = X.drop(columns=cat_col_names).to_numpy(dtype=np.float32)
+                 if part == "train":
+                     self.num_prep_.fit(x_cont_np)
+                 tensors["x_cont"] = torch.as_tensor(self.num_prep_.transform(x_cont_np))
+             else:
+                 tensors["x_cont"] = torch.empty((len(X), 0), dtype=torch.float32)
+
+             if task_type == "regression":
+                 tensors["y"] = torch.as_tensor(y.to_numpy(np.float32))
+                 if part == "train":
+                     n_classes = 0
+             else:
+                 tensors["y"] = torch.as_tensor(y.to_numpy(np.int32), dtype=torch.long)
+                 if part == "train":
+                     n_classes = tensors["y"].max().item() + 1
+
+             ds_parts[part] = tensors
+
+         part_names = ["train", "val"]
+         cat_cardinalities = self.ord_enc_.get_cardinalities()
+         self.n_classes_ = n_classes
+
+         # filter out numerical columns with only a single value
+         # -> AG also does this already but preprocessing might create constant columns again
+         x_cont_train = ds_parts["train"]["x_cont"]
+         self.num_col_mask_ = ~torch.all(x_cont_train == x_cont_train[0:1, :], dim=0)
+         for part in part_names:
+             ds_parts[part]["x_cont"] = ds_parts[part]["x_cont"][:, self.num_col_mask_]
+         # tensor infos are not correct anymore, but might not be used either
+         for part in part_names:
+             for tens_name in ds_parts[part]:
+                 ds_parts[part][tens_name] = ds_parts[part][tens_name].to(device)
+
+         # update
+         n_cont_features = ds_parts["train"]["x_cont"].shape[1]
+
+         Y_train = ds_parts["train"]["y"].clone()
+         if task_type == "regression":
+             self.y_mean_ = ds_parts["train"]["y"].mean().item()
+             self.y_std_ = ds_parts["train"]["y"].std(correction=0).item()
+
+             Y_train = (Y_train - self.y_mean_) / (self.y_std_ + 1e-30)
+
+         # the | operator joins dicts (like update() but not in-place)
+         data = {
+             part: dict(x_cont=ds_parts[part]["x_cont"], y=ds_parts[part]["y"])
+             | (dict(x_cat=ds_parts[part]["x_cat"]) if ds_parts[part]["x_cat"].shape[1] > 0 else dict())
+             for part in part_names
+         }
+
+         # adapted from https://github.com/yandex-research/tabm/blob/main/example.ipynb
+
+         # Automatic mixed precision (AMP)
+         # torch.float16 is implemented for completeness,
+         # but it was not tested in the project,
+         # so torch.bfloat16 is used by default.
+         amp_dtype = (
+             torch.bfloat16
+             if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+             else torch.float16
+             if torch.cuda.is_available()
+             else None
+         )
+         # Changing False to True will result in faster training on compatible hardware.
+         amp_enabled = allow_amp and amp_dtype is not None
+         grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore
+
+         # fmt: off
+         logger.log(15, f"Device: {device.type.upper()}"
+                        f"\nAMP: {amp_enabled} (dtype: {amp_dtype})"
+                        f"\ntorch.compile: {compile_model}",
+         )
+         # fmt: on
+
+         bins = (
+             None
+             if num_emb_type != "pwl" or n_cont_features == 0
+             else rtdl_num_embeddings.compute_bins(data["train"]["x_cont"], n_bins=num_emb_n_bins)
+         )
+
+         model = tabm_reference.Model(
+             n_num_features=n_cont_features,
+             cat_cardinalities=cat_cardinalities,
+             n_classes=n_classes if n_classes > 0 else None,
+             backbone={
+                 "type": "MLP",
+                 "n_blocks": n_blocks if n_blocks != "auto" else (3 if bins is None else 2),
+                 "d_block": d_block,
+                 "dropout": dropout,
+             },
+             bins=bins,
+             num_embeddings=(
+                 None
+                 if bins is None
+                 else {
+                     "type": "PiecewiseLinearEmbeddings",
+                     "d_embedding": d_embedding,
+                     "activation": False,
+                     "version": "B",
+                 }
+             ),
+             arch_type=arch_type,
+             k=tabm_k,
+             share_training_batches=share_training_batches,
+         ).to(device)
+         optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)
+
+         if compile_model:
+             # NOTE
+             # `torch.compile` is intentionally called without the `mode` argument
+             # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
+             model = torch.compile(model)
+             evaluation_mode = torch.no_grad
+         else:
+             evaluation_mode = torch.inference_mode
+
+         @torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
+         def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
+             return (
+                 model(
+                     data[part]["x_cont"][idx],
+                     data[part]["x_cat"][idx] if "x_cat" in data[part] else None,
+                 )
+                 .squeeze(-1)  # Remove the last dimension for regression tasks.
+                 .float()
+             )
+
+         # TODO: use BCELoss for binary classification
+         base_loss_fn = torch.nn.functional.mse_loss if task_type == "regression" else torch.nn.functional.cross_entropy
+
+         def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+             # TabM produces k predictions per object. Each of them must be trained separately.
+             # (regression)     y_pred.shape == (batch_size, k)
+             # (classification) y_pred.shape == (batch_size, k, n_classes)
+             k = y_pred.shape[1]
+             return base_loss_fn(
+                 y_pred.flatten(0, 1),
+                 y_true.repeat_interleave(k) if model.share_training_batches else y_true,
+             )
+
+         @evaluation_mode()
+         def evaluate(part: str) -> float:
+             model.eval()
+
+             # When using torch.compile, you may need to reduce the evaluation batch size.
+             y_pred: np.ndarray = (
+                 torch.cat(
+                     [
+                         apply_model(part, idx)
+                         for idx in torch.arange(len(data[part]["y"]), device=device).split(
+                             eval_batch_size,
+                         )
+                     ],
+                 )
+                 .cpu()
+                 .numpy()
+             )
+             if task_type == "regression":
+                 # Transform the predictions back to the original label space.
+                 y_pred = y_pred * self.y_std_ + self.y_mean_
+
+             # Compute the mean of the k predictions.
+             average_logits = self.config.get("average_logits", False)
+             if average_logits:
+                 y_pred = y_pred.mean(1)
+             if task_type != "regression":
+                 # For classification, the mean must be computed in the probability space.
+                 y_pred = scipy.special.softmax(y_pred, axis=-1)
+                 if not average_logits:
+                     y_pred = y_pred.mean(1)
+
+             return compute_metric(
+                 y=data[part]["y"].cpu().numpy(),
+                 metric=self.early_stopping_metric,
+                 y_pred=y_pred if task_type == "regression" else y_pred.argmax(1),
+                 y_pred_proba=y_pred[:, 1] if task_type == "binclass" else y_pred,
+                 silent=True,
+             )
+
+         math.ceil(n_train / batch_size)
+         best = {
+             "val": -math.inf,
+             # 'test': -math.inf,
+             "epoch": -1,
+         }
+         best_params = [p.clone() for p in model.parameters()]
+         # Early stopping: the training stops when
+         # there are more than `patience` consecutive bad updates.
+         remaining_patience = patience
+
+         try:
+             if self.config.get("verbosity", 0) >= 1:
+                 from tqdm.std import tqdm
+             else:
+                 tqdm = lambda arr, desc: arr
+         except ImportError:
+             tqdm = lambda arr, desc: arr
+
+ logger.log(15, "-" * 88 + "\n")
425
+ for epoch in range(n_epochs):
426
+ # check time limit
427
+ if epoch > 0 and time_to_fit_in_seconds is not None:
428
+ pred_time_after_next_epoch = (epoch + 1) / epoch * (time.time() - start_time)
429
+ if pred_time_after_next_epoch >= time_to_fit_in_seconds:
430
+ break
431
+
432
+ batches = (
433
+ torch.randperm(n_train, device=device).split(batch_size)
434
+ if model.share_training_batches
435
+ else [
436
+ x.transpose(0, 1).flatten()
437
+ for x in torch.rand((model.k, n_train), device=device).argsort(dim=1).split(batch_size, dim=1)
438
+ ]
439
+ )
440
+
441
+ for batch_idx in tqdm(batches, desc=f"Epoch {epoch}"):
442
+ model.train()
443
+ optimizer.zero_grad()
444
+ loss = loss_fn(apply_model("train", batch_idx), Y_train[batch_idx])
445
+
446
+ # added from https://github.com/yandex-research/tabm/blob/main/bin/model.py
447
+ if gradient_clipping_norm is not None and gradient_clipping_norm != "none":
448
+ if grad_scaler is not None:
449
+ grad_scaler.unscale_(optimizer)
450
+ torch.nn.utils.clip_grad.clip_grad_norm_(
451
+ model.parameters(),
452
+ gradient_clipping_norm,
453
+ )
454
+
455
+ if grad_scaler is None:
456
+ loss.backward()
457
+ optimizer.step()
458
+ else:
459
+ grad_scaler.scale(loss).backward() # type: ignore
460
+ grad_scaler.step(optimizer) # Ignores grad scaler might skip steps; should not break anything
461
+ grad_scaler.update()
462
+
463
+             val_score = evaluate("val")
+             logger.log(15, f"(val) {val_score:.4f}")
+
+             if val_score > best["val"]:
+                 logger.log(15, "🌸 New best epoch! 🌸")
+                 # best = {'val': val_score, 'test': test_score, 'epoch': epoch}
+                 best = {"val": val_score, "epoch": epoch}
+                 remaining_patience = patience
+                 with torch.no_grad():
+                     for bp, p in zip(best_params, model.parameters()):
+                         bp.copy_(p)
+             else:
+                 remaining_patience -= 1
+
+             if remaining_patience < 0:
+                 break
+
+         logger.log(15, "\n\nResult:")
+         logger.log(15, str(best))
+
+         logger.log(15, "Restoring best model")
+         with torch.no_grad():
+             for bp, p in zip(best_params, model.parameters()):
+                 p.copy_(bp)
+
+         self.model_ = model
+
+     def predict_raw(self, X: pd.DataFrame) -> torch.Tensor:
+         self.model_.eval()
+
+         tensors = dict()
+         tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[self.cat_col_names_]), dtype=torch.long).to(
+             self.device_,
+         )
+         tensors["x_cont"] = torch.as_tensor(
+             self.num_prep_.transform(X.drop(columns=self.cat_col_names_).to_numpy(dtype=np.float32))
+             if self.has_num_cols
+             else np.empty((len(X), 0), dtype=np.float32),
+         ).to(self.device_)
+
+         tensors["x_cont"] = tensors["x_cont"][:, self.num_col_mask_]
+
+         eval_batch_size = self.config.get("eval_batch_size", 1024)
+         with torch.no_grad():
+             y_pred: torch.Tensor = torch.cat(
+                 [
+                     self.model_(
+                         tensors["x_cont"][idx],
+                         tensors["x_cat"][idx] if tensors["x_cat"].numel() != 0 else None,
+                     )
+                     .squeeze(-1)  # Remove the last dimension for regression tasks.
+                     .float()
+                     for idx in torch.arange(tensors["x_cont"].shape[0], device=self.device_).split(
+                         eval_batch_size,
+                     )
+                 ],
+             )
+         if self.task_type_ == "regression":
+             # Transform the predictions back to the original label space.
+             y_pred = y_pred * self.y_std_ + self.y_mean_
+             y_pred = y_pred.mean(1)
+             # y_pred = y_pred.unsqueeze(-1)  # add extra "features" dimension
+         else:
+             average_logits = self.config.get("average_logits", False)
+             if average_logits:
+                 y_pred = y_pred.mean(1)
+             else:
+                 # For classification, the mean must be computed in the probability space.
+                 y_pred = torch.log(torch.softmax(y_pred, dim=-1).mean(1) + 1e-30)
+
+         return y_pred.cpu()
+
+     def predict(self, X: pd.DataFrame) -> np.ndarray:
+         y_pred = self.predict_raw(X)
+         if self.task_type_ == "regression":
+             return y_pred.numpy()
+         return y_pred.argmax(dim=-1).numpy()
+
+     def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
+         probas = torch.softmax(self.predict_raw(X), dim=-1).numpy()
+         if probas.shape[1] == 2:
+             probas = probas[:, 1]
+         return probas
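To trace the training loop above end to end, the following is a minimal, hypothetical standalone driver for TabMImplementation on synthetic data. It is not an official AutoGluon entry point (the tabm_model.py wrapper listed above is the normal way this class gets constructed); the config keys passed here ("problem_type", "device", "n_epochs", "patience") are the ones read by fit() in this file, and get_metric is assumed to be the usual autogluon.core.metrics helper for building a Scorer.

    import numpy as np
    import pandas as pd
    from autogluon.core.metrics import get_metric
    from autogluon.tabular.models.tabm._tabm_internal import TabMImplementation

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "num_a": rng.normal(size=500),
        "num_b": rng.normal(size=500),
        "cat_a": rng.choice(["x", "y", "z"], size=500),
    })
    y = ((X["num_a"] + (X["cat_a"] == "x")) > 0.5).astype(int)

    model = TabMImplementation(
        early_stopping_metric=get_metric("log_loss", problem_type="binary"),
        problem_type="binary",  # read by fit() via self.config["problem_type"]
        device="cpu",           # read by fit() via self.config["device"]
        n_epochs=5,             # keep the sketch fast
        patience=2,
    )
    model.fit(
        X_train=X.iloc[:400], y_train=y.iloc[:400],
        X_val=X.iloc[400:], y_val=y.iloc[400:],
        cat_col_names=["cat_a"],
    )
    proba = model.predict_proba(X.iloc[400:])  # positive-class probability for binary tasks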