autogluon.tabular 1.3.2b20250709__py3-none-any.whl → 1.3.2b20250710__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- autogluon/tabular/models/__init__.py +3 -0
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +2 -2
- autogluon/tabular/models/catboost/catboost_utils.py +7 -3
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +3 -3
- autogluon/tabular/models/lgb/lgb_model.py +2 -2
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +347 -0
- autogluon/tabular/models/rf/rf_model.py +2 -1
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +174 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +544 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +807 -0
- autogluon/tabular/models/tabm/tabm_model.py +275 -0
- autogluon/tabular/models/tabm/tabm_reference.py +627 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +3 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -3
- autogluon/tabular/models/xgboost/xgboost_model.py +2 -2
- autogluon/tabular/predictor/predictor.py +5 -3
- autogluon/tabular/registry/_ag_model_registry.py +6 -0
- autogluon/tabular/testing/fit_helper.py +27 -25
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +1 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/METADATA +21 -13
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/RECORD +35 -26
- /autogluon.tabular-1.3.2b20250709-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250710-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250709.dist-info → autogluon.tabular-1.3.2b20250710.dist-info}/zip-safe +0 -0
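The most significant additions in this release are three new tabular model families: RealMLP (`realmlp/realmlp_model.py`), TabICL (`tabicl/tabicl_model.py`), and TabM (`tabm/`), together with their registration in `_ag_model_registry.py` and `model_presets/presets.py`. As rough orientation only (not taken from this diff), the sketch below shows how such models are typically requested through `TabularPredictor`; the hyperparameter keys "REALMLP", "TABICL", and "TABM" are assumptions inferred from the registry naming convention, not confirmed by this diff.

# Hypothetical usage sketch: the model keys below are assumptions, not confirmed by this diff.
import pandas as pd
from sklearn.datasets import make_classification
from autogluon.tabular import TabularPredictor

# Small synthetic dataset purely for illustration.
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
train_data = pd.DataFrame(X, columns=[f"f{i}" for i in range(10)])
train_data["target"] = y

predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters={"REALMLP": {}, "TABICL": {}, "TABM": {}},  # assumed registry keys
)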
autogluon/tabular/models/tabm/_tabm_internal.py
@@ -0,0 +1,544 @@
+"""Partially adapted from pytabkit's TabM implementation."""
+
+from __future__ import annotations
+
+import logging
+import math
+import random
+import time
+from typing import TYPE_CHECKING, Any, Literal
+
+import numpy as np
+import pandas as pd
+import scipy
+import torch
+from autogluon.core.metrics import compute_metric
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.impute import SimpleImputer
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import OrdinalEncoder, QuantileTransformer
+from sklearn.utils.validation import check_is_fitted
+
+from . import rtdl_num_embeddings, tabm_reference
+from .tabm_reference import make_parameter_groups
+
+if TYPE_CHECKING:
+    from autogluon.core.metrics import Scorer
+
+TaskType = Literal["regression", "binclass", "multiclass"]
+
+logger = logging.getLogger(__name__)
+
+
+def get_tabm_auto_batch_size(n_train: int) -> int:
+    # by Yury Gorishniy, inferred from the choices in the TabM paper.
+    if n_train < 2_800:
+        return 32
+    if n_train < 4_500:
+        return 64
+    if n_train < 6_400:
+        return 128
+    if n_train < 32_000:
+        return 256
+    if n_train < 108_000:
+        return 512
+    return 1024
+
+
+class RTDLQuantileTransformer(BaseEstimator, TransformerMixin):
+    # adapted from pytabkit
+    def __init__(
+        self,
+        noise=1e-5,
+        random_state=None,
+        n_quantiles=1000,
+        subsample=1_000_000_000,
+        output_distribution="normal",
+    ):
+        self.noise = noise
+        self.random_state = random_state
+        self.n_quantiles = n_quantiles
+        self.subsample = subsample
+        self.output_distribution = output_distribution
+
+    def fit(self, X, y=None):
+        # Calculate the number of quantiles based on data size
+        n_quantiles = max(min(X.shape[0] // 30, self.n_quantiles), 10)
+
+        # Initialize QuantileTransformer
+        normalizer = QuantileTransformer(
+            output_distribution=self.output_distribution,
+            n_quantiles=n_quantiles,
+            subsample=self.subsample,
+            random_state=self.random_state,
+        )
+
+        # Add noise if required
+        X_modified = self._add_noise(X) if self.noise > 0 else X
+
+        # Fit the normalizer
+        normalizer.fit(X_modified)
+        # show that it's fitted
+        self.normalizer_ = normalizer
+
+        return self
+
+    def transform(self, X, y=None):
+        check_is_fitted(self)
+        return self.normalizer_.transform(X)
+
+    def _add_noise(self, X):
+        return X + np.random.default_rng(self.random_state).normal(0.0, 1e-5, X.shape).astype(X.dtype)
+
+
+class TabMOrdinalEncoder(BaseEstimator, TransformerMixin):
+    # encodes missing and unknown values to a value one larger than the known values
+    def __init__(self):
+        # No fitted attributes here — only parameters
+        pass
+
+    def fit(self, X, y=None):
+        X = pd.DataFrame(X)
+
+        # Fit internal OrdinalEncoder with NaNs preserved for now
+        self.encoder_ = OrdinalEncoder(
+            handle_unknown="use_encoded_value",
+            unknown_value=np.nan,
+            encoded_missing_value=np.nan,
+        )
+        self.encoder_.fit(X)
+
+        # Cardinalities = number of known categories per column
+        self.cardinalities_ = [len(cats) for cats in self.encoder_.categories_]
+
+        return self
+
+    def transform(self, X):
+        check_is_fitted(self, ["encoder_", "cardinalities_"])
+
+        X = pd.DataFrame(X)
+        X_enc = self.encoder_.transform(X)
+
+        # Replace np.nan (unknown or missing) with cardinality value
+        for col_idx, cardinality in enumerate(self.cardinalities_):
+            mask = np.isnan(X_enc[:, col_idx])
+            X_enc[mask, col_idx] = cardinality
+
+        return X_enc.astype(int)
+
+    def get_cardinalities(self):
+        check_is_fitted(self, ["cardinalities_"])
+        return self.cardinalities_
+
+
+class TabMImplementation:
+    def __init__(self, early_stopping_metric: Scorer, **config):
+        self.config = config
+        self.early_stopping_metric = early_stopping_metric
+
+        self.ord_enc_ = None
+        self.num_prep_ = None
+        self.cat_col_names_ = None
+        self.n_classes_ = None
+        self.task_type_ = None
+        self.device_ = None
+        self.has_num_cols = None
+
+    def fit(
+        self,
+        X_train: pd.DataFrame,
+        y_train: pd.Series,
+        X_val: pd.DataFrame,
+        y_val: pd.Series,
+        cat_col_names: list[Any],
+        time_to_fit_in_seconds: float | None = None,
+    ):
+        start_time = time.time()
+
+        if X_val is None or len(X_val) == 0:
+            raise ValueError("Training without validation set is currently not implemented")
+        seed: int | None = self.config.get("random_state", None)
+        if seed is not None:
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+            random.seed(seed)
+        if "n_threads" in self.config:
+            torch.set_num_threads(self.config["n_threads"])
+
+        # -- Meta parameters
+        problem_type = self.config["problem_type"]
+        task_type: TaskType = "binclass" if problem_type == "binary" else problem_type
+        n_train = len(X_train)
+        n_classes = None
+        device = self.config["device"]
+        device = torch.device(device)
+        self.task_type_ = task_type
+        self.device_ = device
+        self.cat_col_names_ = cat_col_names
+
+        # -- Hyperparameters
+        arch_type = self.config.get("arch_type", "tabm-mini")
+        num_emb_type = self.config.get("num_emb_type", "pwl")
+        n_epochs = self.config.get("n_epochs", 1_000_000_000)
+        patience = self.config.get("patience", 16)
+        batch_size = self.config.get("batch_size", "auto")
+        compile_model = self.config.get("compile_model", False)
+        lr = self.config.get("lr", 2e-3)
+        d_embedding = self.config.get("d_embedding", 16)
+        d_block = self.config.get("d_block", 512)
+        dropout = self.config.get("dropout", 0.1)
+        tabm_k = self.config.get("tabm_k", 32)
+        allow_amp = self.config.get("allow_amp", False)
+        n_blocks = self.config.get("n_blocks", "auto")
+        num_emb_n_bins = self.config.get("num_emb_n_bins", 48)
+        eval_batch_size = self.config.get("eval_batch_size", 1024)
+        share_training_batches = self.config.get("share_training_batches", False)
+        weight_decay = self.config.get("weight_decay", 3e-4)
+        # this is the search space default but not the example default (which is 'none')
+        gradient_clipping_norm = self.config.get("gradient_clipping_norm", 1.0)
+
+        # -- Verify HPs
+        num_emb_n_bins = min(num_emb_n_bins, n_train - 1)
+        if n_train <= 2:
+            num_emb_type = "none"  # there is no valid number of bins for piecewise linear embeddings
+        if batch_size == "auto":
+            batch_size = get_tabm_auto_batch_size(n_train=n_train)
+
+        # -- Preprocessing
+        ds_parts = dict()
+        self.ord_enc_ = (
+            TabMOrdinalEncoder()
+        )  # Unique ordinal encoder -> replaces nan and missing values with the cardinality
+        self.ord_enc_.fit(X_train[self.cat_col_names_])
+        # TODO: fix transformer to be able to work with empty input data like the sklearn default
+        self.num_prep_ = Pipeline(steps=[
+            ("qt", RTDLQuantileTransformer(random_state=self.config.get("random_state", None))),
+            ("imp", SimpleImputer(add_indicator=True)),
+        ])
+        self.has_num_cols = bool(set(X_train.columns) - set(cat_col_names))
+        for part, X, y in [("train", X_train, y_train), ("val", X_val, y_val)]:
+            tensors = dict()
+
+            tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[cat_col_names]), dtype=torch.long)
+
+            if self.has_num_cols:
+                x_cont_np = X.drop(columns=cat_col_names).to_numpy(dtype=np.float32)
+                if part == "train":
+                    self.num_prep_.fit(x_cont_np)
+                tensors["x_cont"] = torch.as_tensor(self.num_prep_.transform(x_cont_np))
+            else:
+                tensors["x_cont"] = torch.empty((len(X), 0), dtype=torch.float32)
+
+            if task_type == "regression":
+                tensors["y"] = torch.as_tensor(y.to_numpy(np.float32))
+                if part == "train":
+                    n_classes = 0
+            else:
+                tensors["y"] = torch.as_tensor(y.to_numpy(np.int32), dtype=torch.long)
+                if part == "train":
+                    n_classes = tensors["y"].max().item() + 1
+
+            ds_parts[part] = tensors
+
+        part_names = ["train", "val"]
+        cat_cardinalities = self.ord_enc_.get_cardinalities()
+        self.n_classes_ = n_classes
+
+        # filter out numerical columns with only a single value
+        # -> AG also does this already but preprocessing might create constant columns again
+        x_cont_train = ds_parts["train"]["x_cont"]
+        self.num_col_mask_ = ~torch.all(x_cont_train == x_cont_train[0:1, :], dim=0)
+        for part in part_names:
+            ds_parts[part]["x_cont"] = ds_parts[part]["x_cont"][:, self.num_col_mask_]
+        # tensor infos are not correct anymore, but might not be used either
+        for part in part_names:
+            for tens_name in ds_parts[part]:
+                ds_parts[part][tens_name] = ds_parts[part][tens_name].to(device)
+
+        # update
+        n_cont_features = ds_parts["train"]["x_cont"].shape[1]
+
+        Y_train = ds_parts["train"]["y"].clone()
+        if task_type == "regression":
+            self.y_mean_ = ds_parts["train"]["y"].mean().item()
+            self.y_std_ = ds_parts["train"]["y"].std(correction=0).item()
+
+            Y_train = (Y_train - self.y_mean_) / (self.y_std_ + 1e-30)
+
+        # the | operator joins dicts (like update() but not in-place)
+        data = {
+            part: dict(x_cont=ds_parts[part]["x_cont"], y=ds_parts[part]["y"])
+            | (dict(x_cat=ds_parts[part]["x_cat"]) if ds_parts[part]["x_cat"].shape[1] > 0 else dict())
+            for part in part_names
+        }
+
+        # adapted from https://github.com/yandex-research/tabm/blob/main/example.ipynb
+
+        # Automatic mixed precision (AMP)
+        # torch.float16 is implemented for completeness,
+        # but it was not tested in the project,
+        # so torch.bfloat16 is used by default.
+        amp_dtype = (
+            torch.bfloat16
+            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+            else torch.float16
+            if torch.cuda.is_available()
+            else None
+        )
+        # Changing False to True will result in faster training on compatible hardware.
+        amp_enabled = allow_amp and amp_dtype is not None
+        grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore
+
+        # fmt: off
+        logger.log(15, f"Device: {device.type.upper()}"
+                       f"\nAMP: {amp_enabled} (dtype: {amp_dtype})"
+                       f"\ntorch.compile: {compile_model}",
+        )
+        # fmt: on
+
+        bins = (
+            None
+            if num_emb_type != "pwl" or n_cont_features == 0
+            else rtdl_num_embeddings.compute_bins(data["train"]["x_cont"], n_bins=num_emb_n_bins)
+        )
+
+        model = tabm_reference.Model(
+            n_num_features=n_cont_features,
+            cat_cardinalities=cat_cardinalities,
+            n_classes=n_classes if n_classes > 0 else None,
+            backbone={
+                "type": "MLP",
+                "n_blocks": n_blocks if n_blocks != "auto" else (3 if bins is None else 2),
+                "d_block": d_block,
+                "dropout": dropout,
+            },
+            bins=bins,
+            num_embeddings=(
+                None
+                if bins is None
+                else {
+                    "type": "PiecewiseLinearEmbeddings",
+                    "d_embedding": d_embedding,
+                    "activation": False,
+                    "version": "B",
+                }
+            ),
+            arch_type=arch_type,
+            k=tabm_k,
+            share_training_batches=share_training_batches,
+        ).to(device)
+        optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay)
+
+        if compile_model:
+            # NOTE
+            # `torch.compile` is intentionally called without the `mode` argument
+            # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
+            model = torch.compile(model)
+            evaluation_mode = torch.no_grad
+        else:
+            evaluation_mode = torch.inference_mode
+
+        @torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
+        def apply_model(part: str, idx: torch.Tensor) -> torch.Tensor:
+            return (
+                model(
+                    data[part]["x_cont"][idx],
+                    data[part]["x_cat"][idx] if "x_cat" in data[part] else None,
+                )
+                .squeeze(-1)  # Remove the last dimension for regression tasks.
+                .float()
+            )
+
+        # TODO: use BCELoss for binary classification
+        base_loss_fn = torch.nn.functional.mse_loss if task_type == "regression" else torch.nn.functional.cross_entropy
+
+        def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+            # TabM produces k predictions per object. Each of them must be trained separately.
+            # (regression) y_pred.shape == (batch_size, k)
+            # (classification) y_pred.shape == (batch_size, k, n_classes)
+            k = y_pred.shape[1]
+            return base_loss_fn(
+                y_pred.flatten(0, 1),
+                y_true.repeat_interleave(k) if model.share_training_batches else y_true,
+            )
+
+        @evaluation_mode()
+        def evaluate(part: str) -> float:
+            model.eval()
+
+            # When using torch.compile, you may need to reduce the evaluation batch size.
+            y_pred: np.ndarray = (
+                torch.cat(
+                    [
+                        apply_model(part, idx)
+                        for idx in torch.arange(len(data[part]["y"]), device=device).split(
+                            eval_batch_size,
+                        )
+                    ],
+                )
+                .cpu()
+                .numpy()
+            )
+            if task_type == "regression":
+                # Transform the predictions back to the original label space.
+                y_pred = y_pred * self.y_std_ + self.y_mean_
+
+            # Compute the mean of the k predictions.
+            average_logits = self.config.get("average_logits", False)
+            if average_logits:
+                y_pred = y_pred.mean(1)
+            if task_type != "regression":
+                # For classification, the mean must be computed in the probability space.
+                y_pred = scipy.special.softmax(y_pred, axis=-1)
+                if not average_logits:
+                    y_pred = y_pred.mean(1)
+
+            return compute_metric(
+                y=data[part]["y"].cpu().numpy(),
+                metric=self.early_stopping_metric,
+                y_pred=y_pred if task_type == "regression" else y_pred.argmax(1),
+                y_pred_proba=y_pred[:, 1] if task_type == "binclass" else y_pred,
+                silent=True,
+            )
+
+        math.ceil(n_train / batch_size)
+        best = {
+            "val": -math.inf,
+            # 'test': -math.inf,
+            "epoch": -1,
+        }
+        best_params = [p.clone() for p in model.parameters()]
+        # Early stopping: the training stops when
+        # there are more than `patience` consecutive bad updates.
+        remaining_patience = patience
+
+        try:
+            if self.config.get("verbosity", 0) >= 1:
+                from tqdm.std import tqdm
+            else:
+                tqdm = lambda arr, desc: arr
+        except ImportError:
+            tqdm = lambda arr, desc: arr
+
+        logger.log(15, "-" * 88 + "\n")
+        for epoch in range(n_epochs):
+            # check time limit
+            if epoch > 0 and time_to_fit_in_seconds is not None:
+                pred_time_after_next_epoch = (epoch + 1) / epoch * (time.time() - start_time)
+                if pred_time_after_next_epoch >= time_to_fit_in_seconds:
+                    break
+
+            batches = (
+                torch.randperm(n_train, device=device).split(batch_size)
+                if model.share_training_batches
+                else [
+                    x.transpose(0, 1).flatten()
+                    for x in torch.rand((model.k, n_train), device=device).argsort(dim=1).split(batch_size, dim=1)
+                ]
+            )
+
+            for batch_idx in tqdm(batches, desc=f"Epoch {epoch}"):
+                model.train()
+                optimizer.zero_grad()
+                loss = loss_fn(apply_model("train", batch_idx), Y_train[batch_idx])
+
+                # added from https://github.com/yandex-research/tabm/blob/main/bin/model.py
+                if gradient_clipping_norm is not None and gradient_clipping_norm != "none":
+                    if grad_scaler is not None:
+                        grad_scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad.clip_grad_norm_(
+                        model.parameters(),
+                        gradient_clipping_norm,
+                    )
+
+                if grad_scaler is None:
+                    loss.backward()
+                    optimizer.step()
+                else:
+                    grad_scaler.scale(loss).backward()  # type: ignore
+                    grad_scaler.step(optimizer)  # Ignores grad scaler might skip steps; should not break anything
+                    grad_scaler.update()
+
+            val_score = evaluate("val")
+            logger.log(15, f"(val) {val_score:.4f}")
+
+            if val_score > best["val"]:
+                logger.log(15, "🌸 New best epoch! 🌸")
+                # best = {'val': val_score, 'test': test_score, 'epoch': epoch}
+                best = {"val": val_score, "epoch": epoch}
+                remaining_patience = patience
+                with torch.no_grad():
+                    for bp, p in zip(best_params, model.parameters(), strict=False):
+                        bp.copy_(p)
+            else:
+                remaining_patience -= 1
+
+            if remaining_patience < 0:
+                break
+
+        logger.log(15, "\n\nResult:")
+        logger.log(15, str(best))
+
+        logger.log(15, "Restoring best model")
+        with torch.no_grad():
+            for bp, p in zip(best_params, model.parameters(), strict=False):
+                p.copy_(bp)
+
+        self.model_ = model
+
+    def predict_raw(self, X: pd.DataFrame) -> torch.Tensor:
+        self.model_.eval()
+
+        tensors = dict()
+        tensors["x_cat"] = torch.as_tensor(self.ord_enc_.transform(X[self.cat_col_names_]), dtype=torch.long).to(
+            self.device_,
+        )
+        tensors["x_cont"] = torch.as_tensor(
+            self.num_prep_.transform(X.drop(columns=X[self.cat_col_names_]).to_numpy(dtype=np.float32))
+            if self.has_num_cols
+            else np.empty((len(X), 0), dtype=np.float32),
+        ).to(self.device_)
+
+        tensors["x_cont"] = tensors["x_cont"][:, self.num_col_mask_]
+
+        eval_batch_size = self.config.get("eval_batch_size", 1024)
+        with torch.no_grad():
+            y_pred: torch.Tensor = torch.cat(
+                [
+                    self.model_(
+                        tensors["x_cont"][idx],
+                        tensors["x_cat"][idx] if tensors["x_cat"].numel() != 0 else None,
+                    )
+                    .squeeze(-1)  # Remove the last dimension for regression tasks.
+                    .float()
+                    for idx in torch.arange(tensors["x_cont"].shape[0], device=self.device_).split(
+                        eval_batch_size,
+                    )
+                ],
+            )
+        if self.task_type_ == "regression":
+            # Transform the predictions back to the original label space.
+            y_pred = y_pred * self.y_std_ + self.y_mean_
+            y_pred = y_pred.mean(1)
+            # y_pred = y_pred.unsqueeze(-1)  # add extra "features" dimension
+        else:
+            average_logits = self.config.get("average_logits", False)
+            if average_logits:
+                y_pred = y_pred.mean(1)
+            else:
+                # For classification, the mean must be computed in the probability space.
+                y_pred = torch.log(torch.softmax(y_pred, dim=-1).mean(1) + 1e-30)
+
+        return y_pred.cpu()
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        y_pred = self.predict_raw(X)
+        if self.task_type_ == "regression":
+            return y_pred.numpy()
+        return y_pred.argmax(dim=-1).numpy()
+
+    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
+        probas = torch.softmax(self.predict_raw(X), dim=-1).numpy()
+        if probas.shape[1] == 2:
+            probas = probas[:, 1]
+        return probas
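For readers skimming the new `_tabm_internal.py` above: `TabMImplementation` is an internal class driven by `tabm_model.py`, not a public API. The sketch below exercises it directly purely for illustration; the synthetic dataset, the `accuracy` early-stopping metric, and the small epoch/patience settings are assumptions chosen to keep the example fast, while the required config keys (`problem_type`, `device`) and the `fit`/`predict_proba` signatures are taken from the code shown in this diff.

# Illustrative only: TabMImplementation is an internal helper; AutoGluon users
# normally go through TabularPredictor / the TabM model wrapper instead.
import numpy as np
import pandas as pd
from autogluon.core.metrics import get_metric
from autogluon.tabular.models.tabm._tabm_internal import TabMImplementation

rng = np.random.default_rng(0)
X = pd.DataFrame({
    "num": rng.normal(size=1_000),
    "cat": rng.choice(["a", "b", "c"], size=1_000),
})
y = pd.Series((X["num"] > 0).astype(int))  # binary target encoded as 0/1

model = TabMImplementation(
    early_stopping_metric=get_metric("accuracy"),  # Scorer used by evaluate()
    problem_type="binary",  # required: read via self.config["problem_type"]
    device="cpu",           # required: read via self.config["device"]
    n_epochs=5,             # small values keep the sketch fast
    patience=2,
)
model.fit(
    X_train=X.iloc[:800], y_train=y.iloc[:800],
    X_val=X.iloc[800:], y_val=y.iloc[800:],
    cat_col_names=["cat"],
    time_to_fit_in_seconds=60,
)
proba = model.predict_proba(X.iloc[800:])  # P(class 1) for the binary case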