autogluon.tabular 1.3.2b20250709__py3-none-any.whl → 1.3.2b20250710__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. autogluon/tabular/models/__init__.py +3 -0
  2. autogluon/tabular/models/catboost/callbacks.py +3 -2
  3. autogluon/tabular/models/catboost/catboost_model.py +2 -2
  4. autogluon/tabular/models/catboost/catboost_utils.py +7 -3
  5. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +3 -3
  6. autogluon/tabular/models/lgb/lgb_model.py +2 -2
  7. autogluon/tabular/models/realmlp/__init__.py +0 -0
  8. autogluon/tabular/models/realmlp/realmlp_model.py +347 -0
  9. autogluon/tabular/models/rf/rf_model.py +2 -1
  10. autogluon/tabular/models/tabicl/__init__.py +0 -0
  11. autogluon/tabular/models/tabicl/tabicl_model.py +174 -0
  12. autogluon/tabular/models/tabm/__init__.py +0 -0
  13. autogluon/tabular/models/tabm/_tabm_internal.py +544 -0
  14. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +807 -0
  15. autogluon/tabular/models/tabm/tabm_model.py +275 -0
  16. autogluon/tabular/models/tabm/tabm_reference.py +627 -0
  17. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +3 -3
  18. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -3
  19. autogluon/tabular/models/xgboost/xgboost_model.py +2 -2
  20. autogluon/tabular/predictor/predictor.py +5 -3
  21. autogluon/tabular/registry/_ag_model_registry.py +6 -0
  22. autogluon/tabular/testing/fit_helper.py +27 -25
  23. autogluon/tabular/testing/generate_datasets.py +7 -0
  24. autogluon/tabular/trainer/abstract_trainer.py +1 -1
  25. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  26. autogluon/tabular/version.py +1 -1
  27. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/METADATA +21 -13
  28. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/RECORD +35 -26
  29. /autogluon.tabular-1.3.2b20250709-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250710-py3.9-nspkg.pth +0 -0
  30. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/LICENSE +0 -0
  31. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/NOTICE +0 -0
  32. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/WHEEL +0 -0
  33. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/namespace_packages.txt +0 -0
  34. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/top_level.txt +0 -0
  35. {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/zip-safe +0 -0
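The headline change in this release is the addition of three new experimental tabular model families: RealMLP (item 8), TabICL (item 11), and TabM (items 13–16), together with their registration in _ag_model_registry.py and model_presets/presets.py. As a rough illustration only (the registry keys below are assumptions inferred from the new module names and are not shown in this diff), opting into one of the new models from the standard TabularPredictor API would look something like:

from autogluon.tabular import TabularDataset, TabularPredictor

train_data = TabularDataset("train.csv")  # any labeled tabular dataset

# "TABM", "REALMLP", "TABICL" are assumed registry keys inferred from the new
# model directories; check _ag_model_registry.py in this version for the exact
# strings and for any optional dependencies these models require.
predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters={"TABM": {}},
)

The largest single addition, shown in full below, is the new file for item 13 (+544 lines), which contains the TabM training and inference loop.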
autogluon/tabular/models/tabm/_tabm_internal.py (new file)
@@ -0,0 +1,544 @@
+"""Partially adapted from pytabkit's TabM implementation."""
+
+from __future__ import annotations
+
+import logging
+import math
+import random
+import time
+from typing import TYPE_CHECKING, Any, Literal
+
+import numpy as np
+import pandas as pd
+import scipy
+import torch
+from autogluon.core.metrics import compute_metric
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.impute import SimpleImputer
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import OrdinalEncoder, QuantileTransformer
+from sklearn.utils.validation import check_is_fitted
+
+from . import rtdl_num_embeddings, tabm_reference
+from .tabm_reference import make_parameter_groups
+
+if TYPE_CHECKING:
+    from autogluon.core.metrics import Scorer
+
+TaskType = Literal["regression", "binclass", "multiclass"]
+
+logger = logging.getLogger(__name__)
+
+
+def get_tabm_auto_batch_size(n_train: int) -> int:
+    # by Yury Gorishniy, inferred from the choices in the TabM paper.
+    if n_train < 2_800:
+        return 32
+    if n_train < 4_500:
+        return 64
+    if n_train < 6_400:
+        return 128
+    if n_train < 32_000:
+        return 256
+    if n_train < 108_000:
+        return 512
+    return 1024
+
+
+class RTDLQuantileTransformer(BaseEstimator, TransformerMixin):
+    # adapted from pytabkit
+    def __init__(
+        self,
+        noise=1e-5,
+        random_state=None,
+        n_quantiles=1000,
+        subsample=1_000_000_000,
+        output_distribution="normal",
+    ):
+        self.noise = noise
+        self.random_state = random_state
+        self.n_quantiles = n_quantiles
+        self.subsample = subsample
+        self.output_distribution = output_distribution
+
+    def fit(self, X, y=None):
+        # Calculate the number of quantiles based on data size
+        n_quantiles = max(min(X.shape[0] // 30, self.n_quantiles), 10)
+
+        # Initialize QuantileTransformer
+        normalizer = QuantileTransformer(
+            output_distribution=self.output_distribution,
+            n_quantiles=n_quantiles,
+            subsample=self.subsample,
+            random_state=self.random_state,
+        )
+
+        # Add noise if required
+        X_modified = self._add_noise(X) if self.noise > 0 else X
+
+        # Fit the normalizer
+        normalizer.fit(X_modified)
+        # show that it's fitted
+        self.normalizer_ = normalizer
+
+        return self
+
+    def transform(self, X, y=None):
+        check_is_fitted(self)
+        return self.normalizer_.transform(X)
+
+    def _add_noise(self, X):
+        return X + np.random.default_rng(self.random_state).normal(0.0, 1e-5, X.shape).astype(X.dtype)
+
+
+class TabMOrdinalEncoder(BaseEstimator, TransformerMixin):
+    # encodes missing and unknown values to a value one larger than the known values
+    def __init__(self):
+        # No fitted attributes here — only parameters
+        pass
+
+    def fit(self, X, y=None):
+        X = pd.DataFrame(X)
+
+        # Fit internal OrdinalEncoder with NaNs preserved for now
+        self.encoder_ = OrdinalEncoder(
+            handle_unknown="use_encoded_value",
+            unknown_value=np.nan,
+            encoded_missing_value=np.nan,
+        )
+        self.encoder_.fit(X)
+
+        # Cardinalities = number of known categories per column
+        self.cardinalities_ = [len(cats) for cats in self.encoder_.categories_]
+
+        return self
+
+    def transform(self, X):
+        check_is_fitted(self, ["encoder_", "cardinalities_"])
+
+        X = pd.DataFrame(X)
+        X_enc = self.encoder_.transform(X)
+
+        # Replace np.nan (unknown or missing) with cardinality value
+        for col_idx, cardinality in enumerate(self.cardinalities_):
+            mask = np.isnan(X_enc[:, col_idx])
+            X_enc[mask, col_idx] = cardinality
+
+        return X_enc.astype(int)
+
+    def get_cardinalities(self):
+        check_is_fitted(self, ["cardinalities_"])
+        return self.cardinalities_
+
+
+class TabMImplementation:
+    def __init__(self, early_stopping_metric: Scorer, **config):
+        self.config = config
+        self.early_stopping_metric = early_stopping_metric
+
+        self.ord_enc_ = None
+        self.num_prep_ = None
+        self.cat_col_names_ = None
+        self.n_classes_ = None
+        self.task_type_ = None
+        self.device_ = None
+        self.has_num_cols = None
+
+    def fit(
+        self,
+        X_train: pd.DataFrame,
+        y_train: pd.Series,
+        X_val: pd.DataFrame,
+        y_val: pd.Series,
+        cat_col_names: list[Any],
+        time_to_fit_in_seconds: float | None = None,
+    ):
+        start_time = time.time()
+
+        if X_val is None or len(X_val) == 0:
+            raise ValueError("Training without validation set is currently not implemented")
+        seed: int | None = self.config.get("random_state", None)
+        if seed is not None:
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+            random.seed(seed)
+        if "n_threads" in self.config:
+            torch.set_num_threads(self.config["n_threads"])
+
+        # -- Meta parameters
+        problem_type = self.config["problem_type"]
+        task_type: TaskType = "binclass" if problem_type == "binary" else problem_type
+        n_train = len(X_train)
+        n_classes = None
+        device = self.config["device"]
+        device = torch.device(device)
+        self.task_type_ = task_type
+        self.device_ = device
+        self.cat_col_names_ = cat_col_names
+
+        # -- Hyperparameters
+        arch_type = self.config.get("arch_type", "tabm-mini")
+        num_emb_type = self.config.get("num_emb_type", "pwl")
+        n_epochs = self.config.get("n_epochs", 1_000_000_000)
+        patience = self.config.get("patience", 16)
+        batch_size = self.config.get("batch_size", "auto")
+        compile_model = self.config.get("compile_model", False)
+        lr = self.config.get("lr", 2e-3)
+        d_embedding = self.config.get("d_embedding", 16)
+        d_block = self.config.get("d_block", 512)
+        dropout = self.config.get("dropout", 0.1)
+        tabm_k = self.config.get("tabm_k", 32)
+        allow_amp = self.config.get("allow_amp", False)
+        n_blocks = self.config.get("n_blocks", "auto")
+        num_emb_n_bins = self.config.get("num_emb_n_bins", 48)
+        eval_batch_size = self.config.get("eval_batch_size", 1024)
+        share_training_batches = self.config.get("share_training_batches", False)
+        weight_decay = self.config.get("weight_decay", 3e-4)
+        # this is the search space default but not the example default (which is 'none')
+        gradient_clipping_norm = self.config.get("gradient_clipping_norm", 1.0)
+
+        # -- Verify HPs
+        num_emb_n_bins = min(num_emb_n_bins, n_train - 1)
+        if n_train <= 2:
+            num_emb_type = "none"  # there is no valid number of bins for piecewise linear embeddings
+        if batch_size == "auto":
+            batch_size = get_tabm_auto_batch_size(n_train=n_train)
+
+        # -- Preprocessing
+        ds_parts = dict()
+        self.ord_enc_ = (
+            TabMOrdinalEncoder()
+        )  # Unique ordinal encoder -> replaces nan and missing values with the cardinality
+        self.ord_enc_.fit(X_train[self.cat_col_names_])
+        # TODO: fix transformer to be able to work with empty input data like the sklearn default
+        self.num_prep_ = Pipeline(steps=[
+            ("qt", RTDLQuantileTransformer(random_state=self.config.get("random_state", None))),
+            ("imp", SimpleImputer(add_indicator=True)),
+        ])
+        self.has_num_cols = bool(set(X_train.columns) - set(cat_col_names))
+        for part, X, y in [("train", X_train, y_train), ("val", X_val, y_val)]:
+            tensors = dict()
+
+            tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[cat_col_names]), dtype=torch.long)
+
+            if self.has_num_cols:
+                x_cont_np = X.drop(columns=cat_col_names).to_numpy(dtype=np.float32)
+                if part == "train":
+                    self.num_prep_.fit(x_cont_np)
+                tensors["x_cont"] = torch.as_tensor(self.num_prep_.transform(x_cont_np))
+            else:
+                tensors["x_cont"] = torch.empty((len(X), 0), dtype=torch.float32)
+
+            if task_type == "regression":
+                tensors["y"] = torch.as_tensor(y.to_numpy(np.float32))
+                if part == "train":
+                    n_classes = 0
+            else:
+                tensors["y"] = torch.as_tensor(y.to_numpy(np.int32), dtype=torch.long)
+                if part == "train":
+                    n_classes = tensors["y"].max().item() + 1
+
+            ds_parts[part] = tensors
+
+        part_names = ["train", "val"]
+        cat_cardinalities = self.ord_enc_.get_cardinalities()
+        self.n_classes_ = n_classes
+
+        # filter out numerical columns with only a single value
+        # -> AG also does this already but preprocessing might create constant columns again
+        x_cont_train = ds_parts["train"]["x_cont"]
+        self.num_col_mask_ = ~torch.all(x_cont_train == x_cont_train[0:1, :], dim=0)
+        for part in part_names:
+            ds_parts[part]["x_cont"] = ds_parts[part]["x_cont"][:, self.num_col_mask_]
+        # tensor infos are not correct anymore, but might not be used either
+        for part in part_names:
+            for tens_name in ds_parts[part]:
+                ds_parts[part][tens_name] = ds_parts[part][tens_name].to(device)
+
+        # update
+        n_cont_features = ds_parts["train"]["x_cont"].shape[1]
+
+        Y_train = ds_parts["train"]["y"].clone()
+        if task_type == "regression":
+            self.y_mean_ = ds_parts["train"]["y"].mean().item()
+            self.y_std_ = ds_parts["train"]["y"].std(correction=0).item()
+
+            Y_train = (Y_train - self.y_mean_) / (self.y_std_ + 1e-30)
+
+        # the | operator joins dicts (like update() but not in-place)
+        data = {
+            part: dict(x_cont=ds_parts[part]["x_cont"], y=ds_parts[part]["y"])
+            | (dict(x_cat=ds_parts[part]["x_cat"]) if ds_parts[part]["x_cat"].shape[1] > 0 else dict())
+            for part in part_names
+        }
+
+        # adapted from https://github.com/yandex-research/tabm/blob/main/example.ipynb
+
+        # Automatic mixed precision (AMP)
+        # torch.float16 is implemented for completeness,
+        # but it was not tested in the project,
+        # so torch.bfloat16 is used by default.
+        amp_dtype = (
+            torch.bfloat16
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+            else torch.float16
+            if torch.cuda.is_available()
+            else None
+        )
+        # Changing False to True will result in faster training on compatible hardware.
+        amp_enabled = allow_amp and amp_dtype is not None
+        grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore
+
+        # fmt: off
+        logger.log(15, f"Device: {device.type.upper()}"
+                       f"\nAMP: {amp_enabled} (dtype: {amp_dtype})"
+                       f"\ntorch.compile: {compile_model}",
+        )
+        # fmt: on
+
+        bins = (
+            None
+            if num_emb_type != "pwl" or n_cont_features == 0
+            else rtdl_num_embeddings.compute_bins(data["train"]["x_cont"], n_bins=num_emb_n_bins)
+        )
+
+        model = tabm_reference.Model(
+            n_num_features=n_cont_features,
+            cat_cardinalities=cat_cardinalities,
+            n_classes=n_classes if n_classes > 0 else None,
+            backbone={
+                "type": "MLP",
+                "n_blocks": n_blocks if n_blocks != "auto" else (3 if bins is None else 2),
+                "d_block": d_block,
+                "dropout": dropout,
+            },
+            bins=bins,
+            num_embeddings=(
+                None
+                if bins is None
+                else {
+                    "type": "PiecewiseLinearEmbeddings",
+                    "d_embedding": d_embedding,
+                    "activation": False,
+                    "version": "B",
+                }
+            ),
+            arch_type=arch_type,
+            k=tabm_k,
+            share_training_batches=share_training_batches,
+        ).to(device)
+        optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)
+
+        if compile_model:
+            # NOTE
+            # `torch.compile` is intentionally called without the `mode` argument
+            # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
+            model = torch.compile(model)
+            evaluation_mode = torch.no_grad
+        else:
+            evaluation_mode = torch.inference_mode
+
+        @torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
+        def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
+            return (
+                model(
+                    data[part]["x_cont"][idx],
+                    data[part]["x_cat"][idx] if "x_cat" in data[part] else None,
+                )
+                .squeeze(-1)  # Remove the last dimension for regression tasks.
+                .float()
+            )
+
+        # TODO: use BCELoss for binary classification
+        base_loss_fn = torch.nn.functional.mse_loss if task_type == "regression" else torch.nn.functional.cross_entropy
+
+        def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+            # TabM produces k predictions per object. Each of them must be trained separately.
+            # (regression)     y_pred.shape == (batch_size, k)
+            # (classification) y_pred.shape == (batch_size, k, n_classes)
+            k = y_pred.shape[1]
+            return base_loss_fn(
+                y_pred.flatten(0, 1),
+                y_true.repeat_interleave(k) if model.share_training_batches else y_true,
+            )
+
+        @evaluation_mode()
+        def evaluate(part: str) -> float:
+            model.eval()
+
+            # When using torch.compile, you may need to reduce the evaluation batch size.
+            y_pred: np.ndarray = (
+                torch.cat(
+                    [
+                        apply_model(part, idx)
+                        for idx in torch.arange(len(data[part]["y"]), device=device).split(
+                            eval_batch_size,
+                        )
+                    ],
+                )
+                .cpu()
+                .numpy()
+            )
+            if task_type == "regression":
+                # Transform the predictions back to the original label space.
+                y_pred = y_pred * self.y_std_ + self.y_mean_
+
+            # Compute the mean of the k predictions.
+            average_logits = self.config.get("average_logits", False)
+            if average_logits:
+                y_pred = y_pred.mean(1)
+            if task_type != "regression":
+                # For classification, the mean must be computed in the probability space.
+                y_pred = scipy.special.softmax(y_pred, axis=-1)
+                if not average_logits:
+                    y_pred = y_pred.mean(1)
+
+            return compute_metric(
+                y=data[part]["y"].cpu().numpy(),
+                metric=self.early_stopping_metric,
+                y_pred=y_pred if task_type == "regression" else y_pred.argmax(1),
+                y_pred_proba=y_pred[:, 1] if task_type == "binclass" else y_pred,
+                silent=True,
+            )

+        math.ceil(n_train / batch_size)
+        best = {
+            "val": -math.inf,
+            # 'test': -math.inf,
+            "epoch": -1,
+        }
+        best_params = [p.clone() for p in model.parameters()]
+        # Early stopping: the training stops when
+        # there are more than `patience` consecutive bad updates.
+        remaining_patience = patience
+
+        try:
+            if self.config.get("verbosity", 0) >= 1:
+                from tqdm.std import tqdm
+            else:
+                tqdm = lambda arr, desc: arr
+        except ImportError:
+            tqdm = lambda arr, desc: arr
+
+        logger.log(15, "-" * 88 + "\n")
+        for epoch in range(n_epochs):
+            # check time limit
+            if epoch > 0 and time_to_fit_in_seconds is not None:
+                pred_time_after_next_epoch = (epoch + 1) / epoch * (time.time() - start_time)
+                if pred_time_after_next_epoch >= time_to_fit_in_seconds:
+                    break
+
+            batches = (
+                torch.randperm(n_train, device=device).split(batch_size)
+                if model.share_training_batches
+                else [
+                    x.transpose(0, 1).flatten()
+                    for x in torch.rand((model.k, n_train), device=device).argsort(dim=1).split(batch_size, dim=1)
+                ]
+            )
+
+            for batch_idx in tqdm(batches, desc=f"Epoch {epoch}"):
+                model.train()
+                optimizer.zero_grad()
+                loss = loss_fn(apply_model("train", batch_idx), Y_train[batch_idx])
+
+                # added from https://github.com/yandex-research/tabm/blob/main/bin/model.py
+                if gradient_clipping_norm is not None and gradient_clipping_norm != "none":
+                    if grad_scaler is not None:
+                        grad_scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad.clip_grad_norm_(
+                        model.parameters(),
+                        gradient_clipping_norm,
+                    )
+
+                if grad_scaler is None:
+                    loss.backward()
+                    optimizer.step()
+                else:
+                    grad_scaler.scale(loss).backward()  # type: ignore
+                    grad_scaler.step(optimizer)  # Ignores grad scaler might skip steps; should not break anything
+                    grad_scaler.update()
+
+            val_score = evaluate("val")
+            logger.log(15, f"(val) {val_score:.4f}")
+
+            if val_score > best["val"]:
+                logger.log(15, "🌸 New best epoch! 🌸")
+                # best = {'val': val_score, 'test': test_score, 'epoch': epoch}
+                best = {"val": val_score, "epoch": epoch}
+                remaining_patience = patience
+                with torch.no_grad():
+                    for bp, p in zip(best_params, model.parameters(), strict=False):
+                        bp.copy_(p)
+            else:
+                remaining_patience -= 1
+
+            if remaining_patience < 0:
+                break
+
+        logger.log(15, "\n\nResult:")
+        logger.log(15, str(best))
+
+        logger.log(15, "Restoring best model")
+        with torch.no_grad():
+            for bp, p in zip(best_params, model.parameters(), strict=False):
+                p.copy_(bp)
+
+        self.model_ = model
+
+    def predict_raw(self, X: pd.DataFrame) -> torch.Tensor:
+        self.model_.eval()
+
+        tensors = dict()
+        tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[self.cat_col_names_]), dtype=torch.long).to(
+            self.device_,
+        )
+        tensors["x_cont"] = torch.as_tensor(
+            self.num_prep_.transform(X.drop(columns=X[self.cat_col_names_]).to_numpy(dtype=np.float32))
+            if self.has_num_cols
+            else np.empty((len(X), 0), dtype=np.float32),
+        ).to(self.device_)
+
+        tensors["x_cont"] = tensors["x_cont"][:, self.num_col_mask_]
+
+        eval_batch_size = self.config.get("eval_batch_size", 1024)
+        with torch.no_grad():
+            y_pred: torch.Tensor = torch.cat(
+                [
+                    self.model_(
+                        tensors["x_cont"][idx],
+                        tensors["x_cat"][idx] if tensors["x_cat"].numel() != 0 else None,
+                    )
+                    .squeeze(-1)  # Remove the last dimension for regression tasks.
+                    .float()
+                    for idx in torch.arange(tensors["x_cont"].shape[0], device=self.device_).split(
+                        eval_batch_size,
+                    )
+                ],
+            )
+        if self.task_type_ == "regression":
+            # Transform the predictions back to the original label space.
+            y_pred = y_pred * self.y_std_ + self.y_mean_
+            y_pred = y_pred.mean(1)
+            # y_pred = y_pred.unsqueeze(-1)  # add extra "features" dimension
+        else:
+            average_logits = self.config.get("average_logits", False)
+            if average_logits:
+                y_pred = y_pred.mean(1)
+            else:
+                # For classification, the mean must be computed in the probability space.
+                y_pred = torch.log(torch.softmax(y_pred, dim=-1).mean(1) + 1e-30)
+
+        return y_pred.cpu()
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        y_pred = self.predict_raw(X)
+        if self.task_type_ == "regression":
+            return y_pred.numpy()
+        return y_pred.argmax(dim=-1).numpy()
+
+    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
+        probas = torch.softmax(self.predict_raw(X), dim=-1).numpy()
+        if probas.shape[1] == 2:
+            probas = probas[:, 1]
+        return probas
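For orientation, here is a minimal, hypothetical sketch of how the TabMImplementation class added above is driven. In the released package the TabMModel wrapper in tabm_model.py normally handles this; the import path and the get_metric helper are assumptions, and the config keys mirror the defaults read via self.config.get() in fit():

import numpy as np
import pandas as pd
from autogluon.core.metrics import get_metric  # assumed helper for building a Scorer

# Assumed import path for the new module shown in the hunk above.
from autogluon.tabular.models.tabm._tabm_internal import TabMImplementation

# Tiny synthetic binary-classification dataset with one categorical column.
rng = np.random.default_rng(0)
X = pd.DataFrame({
    "x1": rng.normal(size=400),
    "x2": rng.normal(size=400),
    "color": rng.choice(["red", "green", "blue"], size=400),
})
y = (X["x1"] > 0).astype(int)
X_train, y_train = X.iloc[:300], y.iloc[:300]
X_val, y_val = X.iloc[300:], y.iloc[300:]

model = TabMImplementation(
    early_stopping_metric=get_metric("log_loss", problem_type="binary"),
    problem_type="binary",  # required: read via self.config["problem_type"]
    device="cpu",           # required: read via self.config["device"]
    n_epochs=5,             # keep the toy run short; the default is effectively unlimited
    patience=2,
)
model.fit(
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    cat_col_names=["color"],
    time_to_fit_in_seconds=60,
)
proba = model.predict_proba(X_val)  # 1-D positive-class probabilities for binary tasks
labels = model.predict(X_val)

Note that for classification predict_raw() returns mean log-probabilities over the k TabM heads, so predict_proba() recovers probabilities with a final softmax and, for binary tasks, keeps only the positive-class column.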