ins-pricing 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,809 +1,830 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Callable, Dict, List, Optional, Tuple
4
-
5
- import numpy as np
6
- import optuna
7
- import pandas as pd
8
- from sklearn.metrics import log_loss
9
- from sklearn.model_selection import GroupKFold, KFold, TimeSeriesSplit
10
-
11
- from .trainer_base import TrainerBase
12
- from ..models import FTTransformerSklearn
13
- from ..utils.losses import regression_loss
14
-
15
- class FTTrainer(TrainerBase):
16
- def __init__(self, context: "BayesOptModel") -> None:
17
- if context.task_type == 'classification':
18
- super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
19
- else:
20
- super().__init__(context, 'FTTransformer', 'FTTransformer')
21
- self.model: Optional[FTTransformerSklearn] = None
22
- self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
23
- self._cv_geo_warned = False
24
-
25
- def _resolve_numeric_tokens(self) -> int:
26
- requested = getattr(self.ctx.config, "ft_num_numeric_tokens", None)
27
- return FTTransformerSklearn.resolve_numeric_token_count(
28
- self.ctx.num_features,
29
- self.ctx.cate_list,
30
- requested,
31
- )
32
-
33
- def _resolve_adaptive_heads(self,
34
- d_model: int,
35
- requested_heads: Optional[int] = None) -> Tuple[int, bool]:
36
- d_model = int(d_model)
37
- if d_model <= 0:
38
- raise ValueError(f"Invalid d_model={d_model}, expected > 0.")
39
-
40
- default_heads = max(2, d_model // 16)
41
- base_heads = default_heads if requested_heads is None else int(
42
- requested_heads)
43
- base_heads = max(1, min(base_heads, d_model))
44
-
45
- if d_model % base_heads == 0:
46
- return base_heads, False
47
-
48
- for candidate in range(min(d_model, base_heads), 0, -1):
49
- if d_model % candidate == 0:
50
- return candidate, True
51
- return 1, True
52
-
53
- def _build_geo_tokens_for_split(self,
54
- X_train: pd.DataFrame,
55
- X_val: pd.DataFrame,
56
- geo_params: Optional[Dict[str, Any]] = None):
57
- if not self.ctx.config.geo_feature_nmes:
58
- return None
59
- orig_train = self.ctx.train_data
60
- orig_test = self.ctx.test_data
61
- try:
62
- self.ctx.train_data = orig_train.loc[X_train.index].copy()
63
- self.ctx.test_data = orig_train.loc[X_val.index].copy()
64
- return self.ctx._build_geo_tokens(geo_params)
65
- finally:
66
- self.ctx.train_data = orig_train
67
- self.ctx.test_data = orig_test
68
-
69
- def cross_val_unsupervised(self, trial: Optional[optuna.trial.Trial]) -> float:
70
- """Optuna objective A: minimize validation loss for masked reconstruction."""
71
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
72
- param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
73
- "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-3, log=True),
74
- "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
75
- "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
76
- "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
77
- "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
78
- "mask_prob_num": lambda t: t.suggest_float('mask_prob_num', 0.05, 0.4),
79
- "mask_prob_cat": lambda t: t.suggest_float('mask_prob_cat', 0.05, 0.4),
80
- "num_loss_weight": lambda t: t.suggest_float('num_loss_weight', 0.25, 4.0, log=True),
81
- "cat_loss_weight": lambda t: t.suggest_float('cat_loss_weight', 0.25, 4.0, log=True),
82
- }
83
-
84
- params: Optional[Dict[str, Any]] = None
85
- if self._distributed_forced_params is not None:
86
- params = self._distributed_forced_params
87
- self._distributed_forced_params = None
88
- else:
89
- if trial is None:
90
- raise RuntimeError(
91
- "Missing Optuna trial for parameter sampling.")
92
- params = {name: sampler(trial)
93
- for name, sampler in param_space.items()}
94
- if self._should_use_distributed_optuna():
95
- self._distributed_prepare_trial(params)
96
-
97
- X_all = self.ctx.train_data[self.ctx.factor_nmes]
98
- max_rows_for_ft_bo = min(1_000_000, int(len(X_all) / 2))
99
- if max_rows_for_ft_bo > 0 and len(X_all) > max_rows_for_ft_bo:
100
- sampled_idx = self._resolve_time_sample_indices(X_all, max_rows_for_ft_bo)
101
- if sampled_idx is None:
102
- X_all = X_all.sample(
103
- n=max_rows_for_ft_bo,
104
- random_state=self.ctx.rand_seed,
105
- )
106
- else:
107
- X_all = X_all.loc[sampled_idx]
108
-
109
- split = self._resolve_train_val_indices(X_all, allow_default=True)
110
- if split is None:
111
- raise ValueError("Unable to build train/val split for FT unsupervised CV.")
112
- train_idx, val_idx = split
113
- X_train = X_all.iloc[train_idx]
114
- X_val = X_all.iloc[val_idx]
115
- geo_train = geo_val = None
116
- if self.ctx.config.geo_feature_nmes:
117
- built = self._build_geo_tokens_for_split(X_train, X_val, params)
118
- if built is not None:
119
- geo_train, geo_val, _, _ = built
120
- elif not self._cv_geo_warned:
121
- print(
122
- "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
123
- flush=True,
124
- )
125
- self._cv_geo_warned = True
126
-
127
- d_model = int(params["d_model"])
128
- n_layers = int(params["n_layers"])
129
- num_numeric_tokens = self._resolve_numeric_tokens()
130
- token_count = num_numeric_tokens + len(self.ctx.cate_list)
131
- if geo_train is not None:
132
- token_count += 1
133
- approx_units = d_model * n_layers * max(1, token_count)
134
- if approx_units > 12_000_000:
135
- raise optuna.TrialPruned(
136
- f"config exceeds safe memory budget (approx_units={approx_units})")
137
-
138
- adaptive_heads, _ = self._resolve_adaptive_heads(
139
- d_model=d_model,
140
- requested_heads=params.get("n_heads")
141
- )
142
-
143
- mask_prob_num = float(params.get("mask_prob_num", 0.15))
144
- mask_prob_cat = float(params.get("mask_prob_cat", 0.15))
145
- num_loss_weight = float(params.get("num_loss_weight", 1.0))
146
- cat_loss_weight = float(params.get("cat_loss_weight", 1.0))
147
-
148
- model_params = dict(params)
149
- model_params["n_heads"] = adaptive_heads
150
- for k in ("mask_prob_num", "mask_prob_cat", "num_loss_weight", "cat_loss_weight"):
151
- model_params.pop(k, None)
152
-
153
- model = FTTransformerSklearn(
154
- model_nme=self.ctx.model_nme,
155
- num_cols=self.ctx.num_features,
156
- cat_cols=self.ctx.cate_list,
157
- task_type=self.ctx.task_type,
158
- epochs=self.ctx.epochs,
159
- patience=5,
160
- weight_decay=float(params.get("weight_decay", 0.0)),
161
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
162
- use_ddp=self.ctx.config.use_ft_ddp,
163
- num_numeric_tokens=num_numeric_tokens,
164
- loss_name=loss_name,
165
- )
166
- model = self._apply_dataloader_overrides(model)
167
- model.set_params(model_params)
168
- try:
169
- return float(model.fit_unsupervised(
170
- X_train,
171
- X_val=X_val,
172
- trial=trial,
173
- geo_train=geo_train,
174
- geo_val=geo_val,
175
- mask_prob_num=mask_prob_num,
176
- mask_prob_cat=mask_prob_cat,
177
- num_loss_weight=num_loss_weight,
178
- cat_loss_weight=cat_loss_weight
179
- ))
180
- finally:
181
- getattr(getattr(model, "ft", None), "to",
182
- lambda *_args, **_kwargs: None)("cpu")
183
- self._clean_gpu()
184
-
185
- def cross_val(self, trial: optuna.trial.Trial) -> float:
186
- # FT-Transformer CV also focuses on memory control:
187
- # - Shrink search space to avoid oversized models.
188
- # - Release GPU memory after each fold so the next trial can run.
189
- # Slightly shrink hyperparameter space to avoid oversized models.
190
- param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
191
- "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
192
- # "d_model": lambda t: t.suggest_int('d_model', 8, 64, step=8),
193
- "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
194
- "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
195
- "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.2),
196
- "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
197
- }
198
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
199
- if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
200
- param_space["tw_power"] = lambda t: t.suggest_float(
201
- 'tw_power', 1.0, 2.0)
202
- geo_enabled = bool(
203
- self.ctx.geo_token_cols or self.ctx.config.geo_feature_nmes)
204
- if geo_enabled:
205
- # Only tune GNN-related hyperparams when geo tokens are enabled.
206
- param_space.update({
207
- "geo_token_hidden_dim": lambda t: t.suggest_int('geo_token_hidden_dim', 16, 128, step=16),
208
- "geo_token_layers": lambda t: t.suggest_int('geo_token_layers', 1, 4),
209
- "geo_token_k_neighbors": lambda t: t.suggest_int('geo_token_k_neighbors', 5, 20),
210
- "geo_token_dropout": lambda t: t.suggest_float('geo_token_dropout', 0.0, 0.3),
211
- "geo_token_learning_rate": lambda t: t.suggest_float('geo_token_learning_rate', 1e-4, 5e-3, log=True),
212
- })
213
-
214
- metric_ctx: Dict[str, Any] = {}
215
-
216
- def data_provider():
217
- data = self.ctx.train_data
218
- return data[self.ctx.factor_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
219
-
220
- def model_builder(params):
221
- d_model = int(params["d_model"])
222
- n_layers = int(params["n_layers"])
223
- num_numeric_tokens = self._resolve_numeric_tokens()
224
- token_count = num_numeric_tokens + len(self.ctx.cate_list)
225
- if geo_enabled:
226
- token_count += 1
227
- approx_units = d_model * n_layers * max(1, token_count)
228
- if approx_units > 12_000_000:
229
- print(
230
- f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
231
- raise optuna.TrialPruned(
232
- "config exceeds safe memory budget; prune before training")
233
- geo_params_local = {k: v for k, v in params.items()
234
- if k.startswith("geo_token_")}
235
-
236
- tw_power = params.get("tw_power")
237
- if self.ctx.task_type == 'regression':
238
- base_tw = self.ctx.default_tweedie_power()
239
- if loss_name == "tweedie":
240
- tw_power = base_tw if tw_power is None else tw_power
241
- elif loss_name in ("poisson", "gamma"):
242
- tw_power = base_tw
243
- else:
244
- tw_power = None
245
- metric_ctx["tw_power"] = tw_power
246
-
247
- adaptive_heads, _ = self._resolve_adaptive_heads(
248
- d_model=d_model,
249
- requested_heads=params.get("n_heads")
250
- )
251
-
252
- model = FTTransformerSklearn(
253
- model_nme=self.ctx.model_nme,
254
- num_cols=self.ctx.num_features,
255
- cat_cols=self.ctx.cate_list,
256
- d_model=d_model,
257
- n_heads=adaptive_heads,
258
- n_layers=n_layers,
259
- dropout=params["dropout"],
260
- task_type=self.ctx.task_type,
261
- epochs=self.ctx.epochs,
262
- tweedie_power=tw_power,
263
- learning_rate=params["learning_rate"],
264
- patience=5,
265
- weight_decay=float(params.get("weight_decay", 0.0)),
266
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
267
- use_ddp=self.ctx.config.use_ft_ddp,
268
- num_numeric_tokens=num_numeric_tokens,
269
- loss_name=loss_name,
270
- )
271
- model = self._apply_dataloader_overrides(model)
272
- model.set_params({"_geo_params": geo_params_local} if geo_enabled else {})
273
- return model
274
-
275
- def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
276
- geo_train = geo_val = None
277
- if geo_enabled:
278
- geo_params = getattr(model, "_geo_params", {})
279
- built = self._build_geo_tokens_for_split(
280
- X_train, X_val, geo_params)
281
- if built is not None:
282
- geo_train, geo_val, _, _ = built
283
- elif not self._cv_geo_warned:
284
- print(
285
- "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
286
- flush=True,
287
- )
288
- self._cv_geo_warned = True
289
- model.fit(
290
- X_train, y_train, w_train,
291
- X_val, y_val, w_val,
292
- trial=trial_obj,
293
- geo_train=geo_train,
294
- geo_val=geo_val
295
- )
296
- return model.predict(X_val, geo_tokens=geo_val)
297
-
298
- def metric_fn(y_true, y_pred, weight):
299
- if self.ctx.task_type == 'regression':
300
- return regression_loss(
301
- y_true,
302
- y_pred,
303
- weight,
304
- loss_name=loss_name,
305
- tweedie_power=metric_ctx.get("tw_power", 1.5),
306
- )
307
- return log_loss(y_true, y_pred, sample_weight=weight)
308
-
309
- data_for_cap = data_provider()[0]
310
- max_rows_for_ft_bo = min(1000000, int(len(data_for_cap)/2))
311
-
312
- return self.cross_val_generic(
313
- trial=trial,
314
- hyperparameter_space=param_space,
315
- data_provider=data_provider,
316
- model_builder=model_builder,
317
- metric_fn=metric_fn,
318
- sample_limit=max_rows_for_ft_bo if len(
319
- data_for_cap) > max_rows_for_ft_bo > 0 else None,
320
- fit_predict_fn=fit_predict,
321
- cleanup_fn=lambda m: getattr(
322
- getattr(m, "ft", None), "to", lambda *_args, **_kwargs: None)("cpu")
323
- )
324
-
325
- def train(self) -> None:
326
- if not self.best_params:
327
- raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
328
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
329
- resolved_params = dict(self.best_params)
330
- d_model_value = resolved_params.get("d_model", 64)
331
- adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
332
- d_model=d_model_value,
333
- requested_heads=resolved_params.get("n_heads")
334
- )
335
- if heads_adjusted:
336
- print(f"[FTTrainer] Auto-adjusted n_heads from "
337
- f"{resolved_params.get('n_heads')} to {adaptive_heads} "
338
- f"(d_model={d_model_value}).")
339
- resolved_params["n_heads"] = adaptive_heads
340
-
341
- use_refit = bool(getattr(self.ctx.config, "final_refit", True))
342
- refit_epochs = None
343
- X_all = self.ctx.train_data[self.ctx.factor_nmes]
344
- y_all = self.ctx.train_data[self.ctx.resp_nme]
345
- w_all = self.ctx.train_data[self.ctx.weight_nme]
346
- split = self._resolve_train_val_indices(X_all)
347
- if use_refit and split is not None:
348
- train_idx, val_idx = split
349
- tmp_model = FTTransformerSklearn(
350
- model_nme=self.ctx.model_nme,
351
- num_cols=self.ctx.num_features,
352
- cat_cols=self.ctx.cate_list,
353
- task_type=self.ctx.task_type,
354
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
355
- use_ddp=self.ctx.config.use_ft_ddp,
356
- num_numeric_tokens=self._resolve_numeric_tokens(),
357
- weight_decay=float(resolved_params.get("weight_decay", 0.0)),
358
- loss_name=loss_name,
359
- )
360
- tmp_model = self._apply_dataloader_overrides(tmp_model)
361
- tmp_model.set_params(resolved_params)
362
- geo_train_full = self.ctx.train_geo_tokens
363
- geo_train = None if geo_train_full is None else geo_train_full.iloc[train_idx]
364
- geo_val = None if geo_train_full is None else geo_train_full.iloc[val_idx]
365
- tmp_model.fit(
366
- X_all.iloc[train_idx],
367
- y_all.iloc[train_idx],
368
- w_all.iloc[train_idx],
369
- X_all.iloc[val_idx],
370
- y_all.iloc[val_idx],
371
- w_all.iloc[val_idx],
372
- trial=None,
373
- geo_train=geo_train,
374
- geo_val=geo_val,
375
- )
376
- refit_epochs = self._resolve_best_epoch(
377
- getattr(tmp_model, "training_history", None),
378
- default_epochs=int(self.ctx.epochs),
379
- )
380
- getattr(getattr(tmp_model, "ft", None), "to",
381
- lambda *_args, **_kwargs: None)("cpu")
382
- self._clean_gpu()
383
-
384
- self.model = FTTransformerSklearn(
385
- model_nme=self.ctx.model_nme,
386
- num_cols=self.ctx.num_features,
387
- cat_cols=self.ctx.cate_list,
388
- task_type=self.ctx.task_type,
389
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
390
- use_ddp=self.ctx.config.use_ft_ddp,
391
- num_numeric_tokens=self._resolve_numeric_tokens(),
392
- weight_decay=float(resolved_params.get("weight_decay", 0.0)),
393
- loss_name=loss_name,
394
- )
395
- self.model = self._apply_dataloader_overrides(self.model)
396
- if refit_epochs is not None:
397
- self.model.epochs = int(refit_epochs)
398
- self.model.set_params(resolved_params)
399
- self.best_params = resolved_params
400
- loss_plot_path = self.output.plot_path(
401
- f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
402
- self.model.loss_curve_path = loss_plot_path
403
- geo_train = self.ctx.train_geo_tokens
404
- geo_test = self.ctx.test_geo_tokens
405
- fit_kwargs = {}
406
- predict_kwargs_train = None
407
- predict_kwargs_test = None
408
- if geo_train is not None and geo_test is not None:
409
- fit_kwargs["geo_train"] = geo_train
410
- predict_kwargs_train = {"geo_tokens": geo_train}
411
- predict_kwargs_test = {"geo_tokens": geo_test}
412
- self._fit_predict_cache(
413
- self.model,
414
- self.ctx.train_data[self.ctx.factor_nmes],
415
- self.ctx.train_data[self.ctx.resp_nme],
416
- sample_weight=self.ctx.train_data[self.ctx.weight_nme],
417
- pred_prefix='ft',
418
- sample_weight_arg='w_train',
419
- fit_kwargs=fit_kwargs,
420
- predict_kwargs_train=predict_kwargs_train,
421
- predict_kwargs_test=predict_kwargs_test
422
- )
423
- self.ctx.ft_best = self.model
424
-
425
- def ensemble_predict(self, k: int) -> None:
426
- if not self.best_params:
427
- raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
428
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
429
- k = max(2, int(k))
430
- X_all = self.ctx.train_data[self.ctx.factor_nmes]
431
- y_all = self.ctx.train_data[self.ctx.resp_nme]
432
- w_all = self.ctx.train_data[self.ctx.weight_nme]
433
- X_test = self.ctx.test_data[self.ctx.factor_nmes]
434
- n_samples = len(X_all)
435
- geo_train_full = self.ctx.train_geo_tokens
436
- geo_test_full = self.ctx.test_geo_tokens
437
-
438
- resolved_params = dict(self.best_params)
439
- default_d_model = getattr(self.model, "d_model", 64)
440
- adaptive_heads, _ = self._resolve_adaptive_heads(
441
- d_model=resolved_params.get("d_model", default_d_model),
442
- requested_heads=resolved_params.get("n_heads")
443
- )
444
- resolved_params["n_heads"] = adaptive_heads
445
-
446
- split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
447
- if split_iter is None:
448
- print(
449
- f"[FT Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
450
- flush=True,
451
- )
452
- return
453
- preds_train_sum = np.zeros(n_samples, dtype=np.float64)
454
- preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
455
-
456
- split_count = 0
457
- for train_idx, val_idx in split_iter:
458
- model = FTTransformerSklearn(
459
- model_nme=self.ctx.model_nme,
460
- num_cols=self.ctx.num_features,
461
- cat_cols=self.ctx.cate_list,
462
- task_type=self.ctx.task_type,
463
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
464
- use_ddp=self.ctx.config.use_ft_ddp,
465
- num_numeric_tokens=self._resolve_numeric_tokens(),
466
- weight_decay=float(resolved_params.get("weight_decay", 0.0)),
467
- loss_name=loss_name,
468
- )
469
- model = self._apply_dataloader_overrides(model)
470
- model.set_params(resolved_params)
471
-
472
- geo_train = geo_val = None
473
- if geo_train_full is not None:
474
- geo_train = geo_train_full.iloc[train_idx]
475
- geo_val = geo_train_full.iloc[val_idx]
476
-
477
- model.fit(
478
- X_all.iloc[train_idx],
479
- y_all.iloc[train_idx],
480
- w_all.iloc[train_idx],
481
- X_all.iloc[val_idx],
482
- y_all.iloc[val_idx],
483
- w_all.iloc[val_idx],
484
- trial=None,
485
- geo_train=geo_train,
486
- geo_val=geo_val,
487
- )
488
-
489
- pred_train = model.predict(X_all, geo_tokens=geo_train_full)
490
- pred_test = model.predict(X_test, geo_tokens=geo_test_full)
491
- preds_train_sum += np.asarray(pred_train, dtype=np.float64)
492
- preds_test_sum += np.asarray(pred_test, dtype=np.float64)
493
- getattr(getattr(model, "ft", None), "to",
494
- lambda *_args, **_kwargs: None)("cpu")
495
- self._clean_gpu()
496
- split_count += 1
497
-
498
- if split_count < 1:
499
- print(
500
- f"[FT Ensemble] no CV splits generated; skip ensemble.",
501
- flush=True,
502
- )
503
- return
504
- preds_train = preds_train_sum / float(split_count)
505
- preds_test = preds_test_sum / float(split_count)
506
- self._cache_predictions("ft", preds_train, preds_test)
507
-
508
- def _resolve_oof_splitter(self, n_samples: int):
509
- cfg = self.ctx.config
510
- raw_strategy = str(getattr(cfg, "ft_oof_strategy", "auto") or "auto").strip().lower()
511
- base_strategy = str(getattr(cfg, "cv_strategy", "random") or "random").strip().lower()
512
- if raw_strategy == "auto":
513
- strategy = base_strategy
514
- else:
515
- strategy = raw_strategy
516
-
517
- oof_folds = getattr(cfg, "ft_oof_folds", None)
518
- if oof_folds is None:
519
- if strategy in {"random", "group", "grouped"}:
520
- val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test else 0.25
521
- if not (0.0 < val_ratio < 1.0):
522
- val_ratio = 0.25
523
- oof_folds = max(2, int(round(1 / val_ratio)))
524
- else:
525
- oof_folds = 0
526
- oof_folds = int(oof_folds)
527
-
528
- if oof_folds < 2 or n_samples < oof_folds:
529
- return None, None, 0
530
-
531
- if strategy in {"group", "grouped"}:
532
- group_col = getattr(cfg, "cv_group_col", None)
533
- if not group_col:
534
- raise ValueError("cv_group_col is required for FT OOF group strategy.")
535
- if group_col not in self.ctx.train_data.columns:
536
- raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
537
- groups = self.ctx.train_data[group_col]
538
- splitter = GroupKFold(n_splits=oof_folds)
539
- return splitter, groups, oof_folds
540
-
541
- if strategy in {"time", "timeseries", "temporal"}:
542
- time_col = getattr(cfg, "cv_time_col", None)
543
- if not time_col:
544
- raise ValueError("cv_time_col is required for FT OOF time strategy.")
545
- if time_col not in self.ctx.train_data.columns:
546
- raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
547
- ascending = bool(getattr(cfg, "cv_time_ascending", True))
548
- order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
549
- order = self.ctx.train_data.index.get_indexer(order_index)
550
- if n_samples <= oof_folds:
551
- return None, None, 0
552
- splitter = TimeSeriesSplit(n_splits=oof_folds)
553
- return _OrderSplitter(splitter, order), None, oof_folds
554
-
555
- shuffle = bool(getattr(cfg, "ft_oof_shuffle", True))
556
- splitter = KFold(
557
- n_splits=oof_folds,
558
- shuffle=shuffle,
559
- random_state=self.ctx.rand_seed if shuffle else None,
560
- )
561
- return splitter, None, oof_folds
562
-
563
- def _build_ft_feature_model(self, resolved_params: Dict[str, Any]) -> FTTransformerSklearn:
564
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
565
- model = FTTransformerSklearn(
566
- model_nme=self.ctx.model_nme,
567
- num_cols=self.ctx.num_features,
568
- cat_cols=self.ctx.cate_list,
569
- task_type=self.ctx.task_type,
570
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
571
- use_ddp=self.ctx.config.use_ft_ddp,
572
- num_numeric_tokens=self._resolve_numeric_tokens(),
573
- loss_name=loss_name,
574
- )
575
- model = self._apply_dataloader_overrides(model)
576
- adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
577
- d_model=resolved_params.get("d_model", model.d_model),
578
- requested_heads=resolved_params.get("n_heads"),
579
- )
580
- if heads_adjusted:
581
- print(
582
- f"[FTTrainer] Auto-adjusted n_heads from "
583
- f"{resolved_params.get('n_heads')} to {adaptive_heads} "
584
- f"(d_model={resolved_params.get('d_model', model.d_model)})."
585
- )
586
- resolved_params["n_heads"] = adaptive_heads
587
- if resolved_params:
588
- model.set_params(resolved_params)
589
- return model
590
-
591
- def _oof_predict_train(
592
- self,
593
- resolved_params: Dict[str, Any],
594
- *,
595
- feature_mode: str,
596
- geo_train_full: Optional[pd.DataFrame],
597
- ) -> Optional[np.ndarray]:
598
- X_all = self.ctx.train_data[self.ctx.factor_nmes]
599
- y_all = self.ctx.train_data[self.ctx.resp_nme]
600
- w_all = self.ctx.train_data[self.ctx.weight_nme]
601
- splitter, groups, oof_folds = self._resolve_oof_splitter(len(X_all))
602
- if splitter is None:
603
- return None
604
-
605
- preds_train = None
606
- for fold_idx, (train_idx, val_idx) in enumerate(splitter.split(X_all, y_all, groups=groups), start=1):
607
- X_train = X_all.iloc[train_idx]
608
- y_train = y_all.iloc[train_idx]
609
- w_train = w_all.iloc[train_idx]
610
- X_val = X_all.iloc[val_idx]
611
- y_val = y_all.iloc[val_idx]
612
- w_val = w_all.iloc[val_idx]
613
-
614
- geo_train = geo_val = None
615
- if geo_train_full is not None:
616
- geo_train = geo_train_full.iloc[train_idx]
617
- geo_val = geo_train_full.iloc[val_idx]
618
-
619
- model = self._build_ft_feature_model(dict(resolved_params))
620
- model.fit(
621
- X_train,
622
- y_train,
623
- w_train=w_train,
624
- X_val=X_val,
625
- y_val=y_val,
626
- w_val=w_val,
627
- trial=None,
628
- geo_train=geo_train,
629
- geo_val=geo_val,
630
- )
631
-
632
- predict_kwargs = {}
633
- if geo_val is not None:
634
- predict_kwargs["geo_tokens"] = geo_val
635
- if feature_mode == "embedding":
636
- predict_kwargs["return_embedding"] = True
637
- fold_pred = model.predict(X_val, **predict_kwargs)
638
- fold_pred = np.asarray(fold_pred)
639
- if preds_train is None:
640
- preds_train = np.empty((len(X_all),) + fold_pred.shape[1:], dtype=fold_pred.dtype)
641
- preds_train[val_idx] = fold_pred
642
-
643
- getattr(getattr(model, "ft", None), "to", lambda *_a, **_k: None)("cpu")
644
- self._clean_gpu()
645
-
646
- if preds_train is None:
647
- return None
648
- if oof_folds < 2:
649
- return None
650
- return preds_train
651
-
652
- def train_as_feature(self, pred_prefix: str = "ft_feat", feature_mode: str = "prediction") -> None:
653
- """Train FT-Transformer only to generate features (not recorded as final model)."""
654
- if not self.best_params:
655
- raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
656
- resolved_params = dict(self.best_params)
657
- if feature_mode not in ("prediction", "embedding"):
658
- raise ValueError(
659
- f"Unsupported feature_mode='{feature_mode}', expected 'prediction' or 'embedding'.")
660
-
661
- geo_train = self.ctx.train_geo_tokens
662
- geo_test = self.ctx.test_geo_tokens
663
- fit_kwargs = {}
664
- predict_kwargs_train = None
665
- predict_kwargs_test = None
666
- if geo_train is not None and geo_test is not None:
667
- fit_kwargs["geo_train"] = geo_train
668
- predict_kwargs_train = {"geo_tokens": geo_train}
669
- predict_kwargs_test = {"geo_tokens": geo_test}
670
-
671
- if feature_mode == "embedding":
672
- predict_kwargs_train = dict(predict_kwargs_train or {})
673
- predict_kwargs_test = dict(predict_kwargs_test or {})
674
- predict_kwargs_train["return_embedding"] = True
675
- predict_kwargs_test["return_embedding"] = True
676
-
677
- oof_preds = self._oof_predict_train(
678
- resolved_params,
679
- feature_mode=feature_mode,
680
- geo_train_full=geo_train,
681
- )
682
- if oof_preds is not None:
683
- self.model = self._build_ft_feature_model(resolved_params)
684
- self.best_params = resolved_params
685
- self.model.fit(
686
- self.ctx.train_data[self.ctx.factor_nmes],
687
- self.ctx.train_data[self.ctx.resp_nme],
688
- w_train=self.ctx.train_data[self.ctx.weight_nme],
689
- X_val=None,
690
- y_val=None,
691
- w_val=None,
692
- trial=None,
693
- geo_train=geo_train,
694
- geo_val=None,
695
- )
696
- predict_kwargs = dict(predict_kwargs_test or {})
697
- preds_test = self.model.predict(
698
- self.ctx.test_data[self.ctx.factor_nmes],
699
- **predict_kwargs,
700
- )
701
- self._cache_predictions(pred_prefix, oof_preds, preds_test)
702
- return
703
-
704
- self.model = self._build_ft_feature_model(resolved_params)
705
- self.best_params = resolved_params
706
- self._fit_predict_cache(
707
- self.model,
708
- self.ctx.train_data[self.ctx.factor_nmes],
709
- self.ctx.train_data[self.ctx.resp_nme],
710
- sample_weight=self.ctx.train_data[self.ctx.weight_nme],
711
- pred_prefix=pred_prefix,
712
- sample_weight_arg='w_train',
713
- fit_kwargs=fit_kwargs,
714
- predict_kwargs_train=predict_kwargs_train,
715
- predict_kwargs_test=predict_kwargs_test,
716
- record_label=False,
717
- )
718
-
719
- def pretrain_unsupervised_as_feature(self,
720
- pred_prefix: str = "ft_uemb",
721
- params: Optional[Dict[str,
722
- Any]] = None,
723
- mask_prob_num: float = 0.15,
724
- mask_prob_cat: float = 0.15,
725
- num_loss_weight: float = 1.0,
726
- cat_loss_weight: float = 1.0) -> None:
727
- """Self-supervised pretraining (masked reconstruction) and cache embeddings."""
728
- loss_name = getattr(self.ctx, "loss_name", "tweedie")
729
- self.model = FTTransformerSklearn(
730
- model_nme=self.ctx.model_nme,
731
- num_cols=self.ctx.num_features,
732
- cat_cols=self.ctx.cate_list,
733
- task_type=self.ctx.task_type,
734
- use_data_parallel=self.ctx.config.use_ft_data_parallel,
735
- use_ddp=self.ctx.config.use_ft_ddp,
736
- num_numeric_tokens=self._resolve_numeric_tokens(),
737
- loss_name=loss_name,
738
- )
739
- self.model = self._apply_dataloader_overrides(self.model)
740
- resolved_params = dict(params or {})
741
- # Reuse supervised tuning structure params unless explicitly overridden.
742
- if not resolved_params and self.best_params:
743
- resolved_params = dict(self.best_params)
744
-
745
- # If params include masked reconstruction fields, they take precedence.
746
- mask_prob_num = float(resolved_params.pop(
747
- "mask_prob_num", mask_prob_num))
748
- mask_prob_cat = float(resolved_params.pop(
749
- "mask_prob_cat", mask_prob_cat))
750
- num_loss_weight = float(resolved_params.pop(
751
- "num_loss_weight", num_loss_weight))
752
- cat_loss_weight = float(resolved_params.pop(
753
- "cat_loss_weight", cat_loss_weight))
754
-
755
- adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
756
- d_model=resolved_params.get("d_model", self.model.d_model),
757
- requested_heads=resolved_params.get("n_heads")
758
- )
759
- if heads_adjusted:
760
- print(f"[FTTrainer] Auto-adjusted n_heads from "
761
- f"{resolved_params.get('n_heads')} to {adaptive_heads} "
762
- f"(d_model={resolved_params.get('d_model', self.model.d_model)}).")
763
- resolved_params["n_heads"] = adaptive_heads
764
- if resolved_params:
765
- self.model.set_params(resolved_params)
766
-
767
- loss_plot_path = self.output.plot_path(
768
- f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_FTTransformerUnsupervised.png')
769
- self.model.loss_curve_path = loss_plot_path
770
-
771
- # Build a simple holdout split for pretraining early stopping.
772
- X_all = self.ctx.train_data[self.ctx.factor_nmes]
773
- split = self._resolve_train_val_indices(X_all, allow_default=True)
774
- if split is None:
775
- raise ValueError("Unable to build train/val split for FT unsupervised training.")
776
- train_idx, val_idx = split
777
- X_tr = X_all.iloc[train_idx]
778
- X_val = X_all.iloc[val_idx]
779
-
780
- geo_all = self.ctx.train_geo_tokens
781
- geo_tr = geo_val = None
782
- if geo_all is not None:
783
- geo_tr = geo_all.loc[X_tr.index]
784
- geo_val = geo_all.loc[X_val.index]
785
-
786
- self.model.fit_unsupervised(
787
- X_tr,
788
- X_val=X_val,
789
- geo_train=geo_tr,
790
- geo_val=geo_val,
791
- mask_prob_num=mask_prob_num,
792
- mask_prob_cat=mask_prob_cat,
793
- num_loss_weight=num_loss_weight,
794
- cat_loss_weight=cat_loss_weight
795
- )
796
-
797
- geo_train_full = self.ctx.train_geo_tokens
798
- geo_test_full = self.ctx.test_geo_tokens
799
- predict_kwargs_train = {"return_embedding": True}
800
- predict_kwargs_test = {"return_embedding": True}
801
- if geo_train_full is not None and geo_test_full is not None:
802
- predict_kwargs_train["geo_tokens"] = geo_train_full
803
- predict_kwargs_test["geo_tokens"] = geo_test_full
804
-
805
- self._predict_and_cache(
806
- self.model,
807
- pred_prefix=pred_prefix,
808
- predict_kwargs_train=predict_kwargs_train,
809
-
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import optuna
7
+ import pandas as pd
8
+ from sklearn.metrics import log_loss
9
+ from sklearn.model_selection import GroupKFold, KFold, TimeSeriesSplit
10
+
11
+ from .trainer_base import TrainerBase
12
+ from ..models import FTTransformerSklearn
13
+ from ..utils.losses import regression_loss
14
+
15
+
16
+ class FTTrainer(TrainerBase):
17
+ def __init__(self, context: "BayesOptModel") -> None:
18
+ if context.task_type == 'classification':
19
+ super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
20
+ else:
21
+ super().__init__(context, 'FTTransformer', 'FTTransformer')
22
+ self.model: Optional[FTTransformerSklearn] = None
23
+ self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
24
+ self._cv_geo_warned = False
25
+
26
+ def _resolve_numeric_tokens(self) -> int:
27
+ requested = getattr(self.ctx.config, "ft_num_numeric_tokens", None)
28
+ return FTTransformerSklearn.resolve_numeric_token_count(
29
+ self.ctx.num_features,
30
+ self.ctx.cate_list,
31
+ requested,
32
+ )
33
+
34
+ def _resolve_adaptive_heads(self,
35
+ d_model: int,
36
+ requested_heads: Optional[int] = None) -> Tuple[int, bool]:
37
+ d_model = int(d_model)
38
+ if d_model <= 0:
39
+ raise ValueError(f"Invalid d_model={d_model}, expected > 0.")
40
+
41
+ default_heads = max(2, d_model // 16)
42
+ base_heads = default_heads if requested_heads is None else int(
43
+ requested_heads)
44
+ base_heads = max(1, min(base_heads, d_model))
45
+
46
+ if d_model % base_heads == 0:
47
+ return base_heads, False
48
+
49
+ for candidate in range(min(d_model, base_heads), 0, -1):
50
+ if d_model % candidate == 0:
51
+ return candidate, True
52
+ return 1, True
53
+
54
+ def _build_geo_tokens_for_split(self,
55
+ X_train: pd.DataFrame,
56
+ X_val: pd.DataFrame,
57
+ geo_params: Optional[Dict[str, Any]] = None):
58
+ if not self.ctx.config.geo_feature_nmes:
59
+ return None
60
+ orig_train = self.ctx.train_data
61
+ orig_test = self.ctx.test_data
62
+ try:
63
+ self.ctx.train_data = orig_train.loc[X_train.index].copy()
64
+ self.ctx.test_data = orig_train.loc[X_val.index].copy()
65
+ return self.ctx._build_geo_tokens(geo_params)
66
+ finally:
67
+ self.ctx.train_data = orig_train
68
+ self.ctx.test_data = orig_test
69
+
70
+ def cross_val_unsupervised(self, trial: Optional[optuna.trial.Trial]) -> float:
71
+ """Optuna objective A: minimize validation loss for masked reconstruction."""
72
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
73
+ param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
74
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-3, log=True),
75
+ "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
76
+ "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
77
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
78
+ "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
79
+ "mask_prob_num": lambda t: t.suggest_float('mask_prob_num', 0.05, 0.4),
80
+ "mask_prob_cat": lambda t: t.suggest_float('mask_prob_cat', 0.05, 0.4),
81
+ "num_loss_weight": lambda t: t.suggest_float('num_loss_weight', 0.25, 4.0, log=True),
82
+ "cat_loss_weight": lambda t: t.suggest_float('cat_loss_weight', 0.25, 4.0, log=True),
83
+ }
84
+
85
+ params: Optional[Dict[str, Any]] = None
86
+ if self._distributed_forced_params is not None:
87
+ params = self._distributed_forced_params
88
+ self._distributed_forced_params = None
89
+ else:
90
+ if trial is None:
91
+ raise RuntimeError(
92
+ "Missing Optuna trial for parameter sampling.")
93
+ params = {name: sampler(trial)
94
+ for name, sampler in param_space.items()}
95
+ if self._should_use_distributed_optuna():
96
+ self._distributed_prepare_trial(params)
97
+
98
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
99
+ max_rows_for_ft_bo = min(1_000_000, int(len(X_all) / 2))
100
+ if max_rows_for_ft_bo > 0 and len(X_all) > max_rows_for_ft_bo:
101
+ sampled_idx = self._resolve_time_sample_indices(
102
+ X_all, max_rows_for_ft_bo)
103
+ if sampled_idx is None:
104
+ X_all = X_all.sample(
105
+ n=max_rows_for_ft_bo,
106
+ random_state=self.ctx.rand_seed,
107
+ )
108
+ else:
109
+ X_all = X_all.loc[sampled_idx]
110
+
111
+ split = self._resolve_train_val_indices(X_all, allow_default=True)
112
+ if split is None:
113
+ raise ValueError(
114
+ "Unable to build train/val split for FT unsupervised CV.")
115
+ train_idx, val_idx = split
116
+ X_train = X_all.iloc[train_idx]
117
+ X_val = X_all.iloc[val_idx]
118
+ geo_train = geo_val = None
119
+ if self.ctx.config.geo_feature_nmes:
120
+ built = self._build_geo_tokens_for_split(X_train, X_val, params)
121
+ if built is not None:
122
+ geo_train, geo_val, _, _ = built
123
+ elif not self._cv_geo_warned:
124
+ print(
125
+ "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
126
+ flush=True,
127
+ )
128
+ self._cv_geo_warned = True
129
+
130
+ d_model = int(params["d_model"])
131
+ n_layers = int(params["n_layers"])
132
+ num_numeric_tokens = self._resolve_numeric_tokens()
133
+ token_count = num_numeric_tokens + len(self.ctx.cate_list)
134
+ if geo_train is not None:
135
+ token_count += 1
136
+ approx_units = d_model * n_layers * max(1, token_count)
137
+ if approx_units > 12_000_000:
138
+ raise optuna.TrialPruned(
139
+ f"config exceeds safe memory budget (approx_units={approx_units})")
140
+
141
+ adaptive_heads, _ = self._resolve_adaptive_heads(
142
+ d_model=d_model,
143
+ requested_heads=params.get("n_heads")
144
+ )
145
+
146
+ mask_prob_num = float(params.get("mask_prob_num", 0.15))
147
+ mask_prob_cat = float(params.get("mask_prob_cat", 0.15))
148
+ num_loss_weight = float(params.get("num_loss_weight", 1.0))
149
+ cat_loss_weight = float(params.get("cat_loss_weight", 1.0))
150
+
151
+ model_params = dict(params)
152
+ model_params["n_heads"] = adaptive_heads
153
+ for k in ("mask_prob_num", "mask_prob_cat", "num_loss_weight", "cat_loss_weight"):
154
+ model_params.pop(k, None)
155
+
156
+ model = FTTransformerSklearn(
157
+ model_nme=self.ctx.model_nme,
158
+ num_cols=self.ctx.num_features,
159
+ cat_cols=self.ctx.cate_list,
160
+ task_type=self.ctx.task_type,
161
+ epochs=self.ctx.epochs,
162
+ patience=5,
163
+ weight_decay=float(params.get("weight_decay", 0.0)),
164
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
165
+ use_ddp=self.ctx.config.use_ft_ddp,
166
+ num_numeric_tokens=num_numeric_tokens,
167
+ loss_name=loss_name,
168
+ )
169
+ model = self._apply_dataloader_overrides(model)
170
+ model.set_params(model_params)
171
+ try:
172
+ return float(model.fit_unsupervised(
173
+ X_train,
174
+ X_val=X_val,
175
+ trial=trial,
176
+ geo_train=geo_train,
177
+ geo_val=geo_val,
178
+ mask_prob_num=mask_prob_num,
179
+ mask_prob_cat=mask_prob_cat,
180
+ num_loss_weight=num_loss_weight,
181
+ cat_loss_weight=cat_loss_weight
182
+ ))
183
+ finally:
184
+ getattr(getattr(model, "ft", None), "to",
185
+ lambda *_args, **_kwargs: None)("cpu")
186
+ self._clean_gpu()
187
+
188
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
189
+ # FT-Transformer CV also focuses on memory control:
190
+ # - Shrink search space to avoid oversized models.
191
+ # - Release GPU memory after each fold so the next trial can run.
192
+ # Slightly shrink hyperparameter space to avoid oversized models.
193
+ param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
194
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
195
+ # "d_model": lambda t: t.suggest_int('d_model', 8, 64, step=8),
196
+ "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
197
+ "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
198
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.2),
199
+ "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
200
+ }
201
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
202
+ if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
203
+ param_space["tw_power"] = lambda t: t.suggest_float(
204
+ 'tw_power', 1.0, 2.0)
205
+ geo_enabled = bool(
206
+ self.ctx.geo_token_cols or self.ctx.config.geo_feature_nmes)
207
+ if geo_enabled:
208
+ # Only tune GNN-related hyperparams when geo tokens are enabled.
209
+ param_space.update({
210
+ "geo_token_hidden_dim": lambda t: t.suggest_int('geo_token_hidden_dim', 16, 128, step=16),
211
+ "geo_token_layers": lambda t: t.suggest_int('geo_token_layers', 1, 4),
212
+ "geo_token_k_neighbors": lambda t: t.suggest_int('geo_token_k_neighbors', 5, 20),
213
+ "geo_token_dropout": lambda t: t.suggest_float('geo_token_dropout', 0.0, 0.3),
214
+ "geo_token_learning_rate": lambda t: t.suggest_float('geo_token_learning_rate', 1e-4, 5e-3, log=True),
215
+ })
216
+
217
+ metric_ctx: Dict[str, Any] = {}
218
+
219
+ def data_provider():
220
+ data = self.ctx.train_data
221
+ return data[self.ctx.factor_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
222
+
223
+ def model_builder(params):
224
+ d_model = int(params["d_model"])
225
+ n_layers = int(params["n_layers"])
226
+ num_numeric_tokens = self._resolve_numeric_tokens()
227
+ token_count = num_numeric_tokens + len(self.ctx.cate_list)
228
+ if geo_enabled:
229
+ token_count += 1
230
+ approx_units = d_model * n_layers * max(1, token_count)
231
+ if approx_units > 12_000_000:
232
+ print(
233
+ f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
234
+ raise optuna.TrialPruned(
235
+ "config exceeds safe memory budget; prune before training")
236
+ geo_params_local = {k: v for k, v in params.items()
237
+ if k.startswith("geo_token_")}
238
+
239
+ tw_power = params.get("tw_power")
240
+ if self.ctx.task_type == 'regression':
241
+ base_tw = self.ctx.default_tweedie_power()
242
+ if loss_name == "tweedie":
243
+ tw_power = base_tw if tw_power is None else tw_power
244
+ elif loss_name in ("poisson", "gamma"):
245
+ tw_power = base_tw
246
+ else:
247
+ tw_power = None
248
+ metric_ctx["tw_power"] = tw_power
249
+
250
+ adaptive_heads, _ = self._resolve_adaptive_heads(
251
+ d_model=d_model,
252
+ requested_heads=params.get("n_heads")
253
+ )
254
+
255
+ model = FTTransformerSklearn(
256
+ model_nme=self.ctx.model_nme,
257
+ num_cols=self.ctx.num_features,
258
+ cat_cols=self.ctx.cate_list,
259
+ d_model=d_model,
260
+ n_heads=adaptive_heads,
261
+ n_layers=n_layers,
262
+ dropout=params["dropout"],
263
+ task_type=self.ctx.task_type,
264
+ epochs=self.ctx.epochs,
265
+ tweedie_power=tw_power,
266
+ learning_rate=params["learning_rate"],
267
+ patience=5,
268
+ weight_decay=float(params.get("weight_decay", 0.0)),
269
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
270
+ use_ddp=self.ctx.config.use_ft_ddp,
271
+ num_numeric_tokens=num_numeric_tokens,
272
+ loss_name=loss_name,
273
+ )
274
+ model = self._apply_dataloader_overrides(model)
275
+ model.set_params({"_geo_params": geo_params_local}
276
+ if geo_enabled else {})
277
+ return model
278
+
279
+ def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
280
+ geo_train = geo_val = None
281
+ if geo_enabled:
282
+ geo_params = getattr(model, "_geo_params", {})
283
+ built = self._build_geo_tokens_for_split(
284
+ X_train, X_val, geo_params)
285
+ if built is not None:
286
+ geo_train, geo_val, _, _ = built
287
+ elif not self._cv_geo_warned:
288
+ print(
289
+ "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
290
+ flush=True,
291
+ )
292
+ self._cv_geo_warned = True
293
+ model.fit(
294
+ X_train, y_train, w_train,
295
+ X_val, y_val, w_val,
296
+ trial=trial_obj,
297
+ geo_train=geo_train,
298
+ geo_val=geo_val
299
+ )
300
+ return model.predict(X_val, geo_tokens=geo_val)
301
+
302
+ def metric_fn(y_true, y_pred, weight):
303
+ if self.ctx.task_type == 'regression':
304
+ return regression_loss(
305
+ y_true,
306
+ y_pred,
307
+ weight,
308
+ loss_name=loss_name,
309
+ tweedie_power=metric_ctx.get("tw_power", 1.5),
310
+ )
311
+ return log_loss(y_true, y_pred, sample_weight=weight)
312
+
313
+ data_for_cap = data_provider()[0]
314
+ max_rows_for_ft_bo = min(1000000, int(len(data_for_cap)/2))
315
+
316
+ return self.cross_val_generic(
317
+ trial=trial,
318
+ hyperparameter_space=param_space,
319
+ data_provider=data_provider,
320
+ model_builder=model_builder,
321
+ metric_fn=metric_fn,
322
+ sample_limit=max_rows_for_ft_bo if len(
323
+ data_for_cap) > max_rows_for_ft_bo > 0 else None,
324
+ fit_predict_fn=fit_predict,
325
+ cleanup_fn=lambda m: getattr(
326
+ getattr(m, "ft", None), "to", lambda *_args, **_kwargs: None)("cpu")
327
+ )
328
+
329
+ def train(self) -> None:
330
+ if not self.best_params:
331
+ raise RuntimeError(
332
+ "Run tune() first to obtain best FT-Transformer parameters.")
333
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
334
+ resolved_params = dict(self.best_params)
335
+ d_model_value = resolved_params.get("d_model", 64)
336
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
337
+ d_model=d_model_value,
338
+ requested_heads=resolved_params.get("n_heads")
339
+ )
340
+ if heads_adjusted:
341
+ print(f"[FTTrainer] Auto-adjusted n_heads from "
342
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
343
+ f"(d_model={d_model_value}).")
344
+ resolved_params["n_heads"] = adaptive_heads
345
+
346
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
347
+ refit_epochs = None
348
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
349
+ y_all = self.ctx.train_data[self.ctx.resp_nme]
350
+ w_all = self.ctx.train_data[self.ctx.weight_nme]
351
+ split = self._resolve_train_val_indices(X_all)
352
+ if use_refit and split is not None:
353
+ train_idx, val_idx = split
354
+ tmp_model = FTTransformerSklearn(
355
+ model_nme=self.ctx.model_nme,
356
+ num_cols=self.ctx.num_features,
357
+ cat_cols=self.ctx.cate_list,
358
+ task_type=self.ctx.task_type,
359
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
360
+ use_ddp=self.ctx.config.use_ft_ddp,
361
+ num_numeric_tokens=self._resolve_numeric_tokens(),
362
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
363
+ loss_name=loss_name,
364
+ )
365
+ tmp_model = self._apply_dataloader_overrides(tmp_model)
366
+ tmp_model.set_params(resolved_params)
367
+ geo_train_full = self.ctx.train_geo_tokens
368
+ geo_train = None if geo_train_full is None else geo_train_full.iloc[train_idx]
369
+ geo_val = None if geo_train_full is None else geo_train_full.iloc[val_idx]
370
+ tmp_model.fit(
371
+ X_all.iloc[train_idx],
372
+ y_all.iloc[train_idx],
373
+ w_all.iloc[train_idx],
374
+ X_all.iloc[val_idx],
375
+ y_all.iloc[val_idx],
376
+ w_all.iloc[val_idx],
377
+ trial=None,
378
+ geo_train=geo_train,
379
+ geo_val=geo_val,
380
+ )
381
+ refit_epochs = self._resolve_best_epoch(
382
+ getattr(tmp_model, "training_history", None),
383
+ default_epochs=int(self.ctx.epochs),
384
+ )
385
+ getattr(getattr(tmp_model, "ft", None), "to",
386
+ lambda *_args, **_kwargs: None)("cpu")
387
+ self._clean_gpu()
388
+
389
+ self.model = FTTransformerSklearn(
390
+ model_nme=self.ctx.model_nme,
391
+ num_cols=self.ctx.num_features,
392
+ cat_cols=self.ctx.cate_list,
393
+ task_type=self.ctx.task_type,
394
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
395
+ use_ddp=self.ctx.config.use_ft_ddp,
396
+ num_numeric_tokens=self._resolve_numeric_tokens(),
397
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
398
+ loss_name=loss_name,
399
+ )
400
+ self.model = self._apply_dataloader_overrides(self.model)
401
+ if refit_epochs is not None:
402
+ self.model.epochs = int(refit_epochs)
403
+ self.model.set_params(resolved_params)
404
+ self.best_params = resolved_params
405
+ loss_plot_path = self.output.plot_path(
406
+ f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
407
+ self.model.loss_curve_path = loss_plot_path
408
+ geo_train = self.ctx.train_geo_tokens
409
+ geo_test = self.ctx.test_geo_tokens
410
+ fit_kwargs = {}
411
+ predict_kwargs_train = None
412
+ predict_kwargs_test = None
413
+ if geo_train is not None and geo_test is not None:
414
+ fit_kwargs["geo_train"] = geo_train
415
+ predict_kwargs_train = {"geo_tokens": geo_train}
416
+ predict_kwargs_test = {"geo_tokens": geo_test}
417
+ self._fit_predict_cache(
418
+ self.model,
419
+ self.ctx.train_data[self.ctx.factor_nmes],
420
+ self.ctx.train_data[self.ctx.resp_nme],
421
+ sample_weight=self.ctx.train_data[self.ctx.weight_nme],
422
+ pred_prefix='ft',
423
+ sample_weight_arg='w_train',
424
+ fit_kwargs=fit_kwargs,
425
+ predict_kwargs_train=predict_kwargs_train,
426
+ predict_kwargs_test=predict_kwargs_test
427
+ )
428
+ self.ctx.ft_best = self.model
429
+
430
+ def ensemble_predict(self, k: int) -> None:
431
+ if not self.best_params:
432
+ raise RuntimeError(
433
+ "Run tune() first to obtain best FT-Transformer parameters.")
434
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
435
+ k = max(2, int(k))
436
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
437
+ y_all = self.ctx.train_data[self.ctx.resp_nme]
438
+ w_all = self.ctx.train_data[self.ctx.weight_nme]
439
+ X_test = self.ctx.test_data[self.ctx.factor_nmes]
440
+ n_samples = len(X_all)
441
+ geo_train_full = self.ctx.train_geo_tokens
442
+ geo_test_full = self.ctx.test_geo_tokens
443
+
444
+ resolved_params = dict(self.best_params)
445
+ default_d_model = getattr(self.model, "d_model", 64)
446
+ adaptive_heads, _ = self._resolve_adaptive_heads(
447
+ d_model=resolved_params.get("d_model", default_d_model),
448
+ requested_heads=resolved_params.get("n_heads")
449
+ )
450
+ resolved_params["n_heads"] = adaptive_heads
451
+
452
+ split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
453
+ if split_iter is None:
454
+ print(
455
+ f"[FT Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
456
+ flush=True,
457
+ )
458
+ return
459
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
460
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
461
+
462
+ split_count = 0
463
+ for train_idx, val_idx in split_iter:
464
+ model = FTTransformerSklearn(
465
+ model_nme=self.ctx.model_nme,
466
+ num_cols=self.ctx.num_features,
467
+ cat_cols=self.ctx.cate_list,
468
+ task_type=self.ctx.task_type,
469
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
470
+ use_ddp=self.ctx.config.use_ft_ddp,
471
+ num_numeric_tokens=self._resolve_numeric_tokens(),
472
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
473
+ loss_name=loss_name,
474
+ )
475
+ model = self._apply_dataloader_overrides(model)
476
+ model.set_params(resolved_params)
477
+
478
+ geo_train = geo_val = None
479
+ if geo_train_full is not None:
480
+ geo_train = geo_train_full.iloc[train_idx]
481
+ geo_val = geo_train_full.iloc[val_idx]
482
+
483
+ model.fit(
484
+ X_all.iloc[train_idx],
485
+ y_all.iloc[train_idx],
486
+ w_all.iloc[train_idx],
487
+ X_all.iloc[val_idx],
488
+ y_all.iloc[val_idx],
489
+ w_all.iloc[val_idx],
490
+ trial=None,
491
+ geo_train=geo_train,
492
+ geo_val=geo_val,
493
+ )
494
+
495
+ pred_train = model.predict(X_all, geo_tokens=geo_train_full)
496
+ pred_test = model.predict(X_test, geo_tokens=geo_test_full)
497
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
498
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
499
+ getattr(getattr(model, "ft", None), "to",
500
+ lambda *_args, **_kwargs: None)("cpu")
501
+ self._clean_gpu()
502
+ split_count += 1
503
+
504
+ if split_count < 1:
505
+ print(
506
+ f"[FT Ensemble] no CV splits generated; skip ensemble.",
507
+ flush=True,
508
+ )
509
+ return
510
+ preds_train = preds_train_sum / float(split_count)
511
+ preds_test = preds_test_sum / float(split_count)
512
+ self._cache_predictions("ft", preds_train, preds_test)
513
+
514
+ def _resolve_oof_splitter(self, n_samples: int):
515
+ cfg = self.ctx.config
516
+ raw_strategy = str(getattr(cfg, "ft_oof_strategy",
517
+ "auto") or "auto").strip().lower()
518
+ base_strategy = str(
519
+ getattr(cfg, "cv_strategy", "random") or "random").strip().lower()
520
+ if raw_strategy == "auto":
521
+ strategy = base_strategy
522
+ else:
523
+ strategy = raw_strategy
524
+
525
+ oof_folds = getattr(cfg, "ft_oof_folds", None)
526
+ if oof_folds is None:
527
+ if strategy in {"random", "group", "grouped"}:
528
+ val_ratio = float(
529
+ self.ctx.prop_test) if self.ctx.prop_test else 0.25
530
+ if not (0.0 < val_ratio < 1.0):
531
+ val_ratio = 0.25
532
+ oof_folds = max(2, int(round(1 / val_ratio)))
533
+ else:
534
+ oof_folds = 0
535
+ oof_folds = int(oof_folds)
536
+
537
+ if oof_folds < 2 or n_samples < oof_folds:
538
+ return None, None, 0
539
+
540
+ if strategy in {"group", "grouped"}:
541
+ group_col = getattr(cfg, "cv_group_col", None)
542
+ if not group_col:
543
+ raise ValueError(
544
+ "cv_group_col is required for FT OOF group strategy.")
545
+ if group_col not in self.ctx.train_data.columns:
546
+ raise KeyError(
547
+ f"cv_group_col '{group_col}' not in train_data.")
548
+ groups = self.ctx.train_data[group_col]
549
+ splitter = GroupKFold(n_splits=oof_folds)
550
+ return splitter, groups, oof_folds
551
+
552
+ if strategy in {"time", "timeseries", "temporal"}:
553
+ time_col = getattr(cfg, "cv_time_col", None)
554
+ if not time_col:
555
+ raise ValueError(
556
+ "cv_time_col is required for FT OOF time strategy.")
557
+ if time_col not in self.ctx.train_data.columns:
558
+ raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
559
+ ascending = bool(getattr(cfg, "cv_time_ascending", True))
560
+ order_index = self.ctx.train_data[time_col].sort_values(
561
+ ascending=ascending).index
562
+ order = self.ctx.train_data.index.get_indexer(order_index)
563
+ if n_samples <= oof_folds:
564
+ return None, None, 0
565
+ splitter = TimeSeriesSplit(n_splits=oof_folds)
566
+ return _OrderSplitter(splitter, order), None, oof_folds
567
+
568
+ shuffle = bool(getattr(cfg, "ft_oof_shuffle", True))
569
+ splitter = KFold(
570
+ n_splits=oof_folds,
571
+ shuffle=shuffle,
572
+ random_state=self.ctx.rand_seed if shuffle else None,
573
+ )
574
+ return splitter, None, oof_folds
575
+
576
+ def _build_ft_feature_model(self, resolved_params: Dict[str, Any]) -> FTTransformerSklearn:
577
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
578
+ model = FTTransformerSklearn(
579
+ model_nme=self.ctx.model_nme,
580
+ num_cols=self.ctx.num_features,
581
+ cat_cols=self.ctx.cate_list,
582
+ task_type=self.ctx.task_type,
583
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
584
+ use_ddp=self.ctx.config.use_ft_ddp,
585
+ num_numeric_tokens=self._resolve_numeric_tokens(),
586
+ loss_name=loss_name,
587
+ )
588
+ model = self._apply_dataloader_overrides(model)
589
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
590
+ d_model=resolved_params.get("d_model", model.d_model),
591
+ requested_heads=resolved_params.get("n_heads"),
592
+ )
593
+ if heads_adjusted:
594
+ print(
595
+ f"[FTTrainer] Auto-adjusted n_heads from "
596
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
597
+ f"(d_model={resolved_params.get('d_model', model.d_model)})."
598
+ )
599
+ resolved_params["n_heads"] = adaptive_heads
600
+ if resolved_params:
601
+ model.set_params(resolved_params)
602
+ return model
603
+
604
+ def _oof_predict_train(
605
+ self,
606
+ resolved_params: Dict[str, Any],
607
+ *,
608
+ feature_mode: str,
609
+ geo_train_full: Optional[pd.DataFrame],
610
+ ) -> Optional[np.ndarray]:
611
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
612
+ y_all = self.ctx.train_data[self.ctx.resp_nme]
613
+ w_all = self.ctx.train_data[self.ctx.weight_nme]
614
+ splitter, groups, oof_folds = self._resolve_oof_splitter(len(X_all))
615
+ if splitter is None:
616
+ return None
617
+
618
+ preds_train = None
619
+ for fold_idx, (train_idx, val_idx) in enumerate(splitter.split(X_all, y_all, groups=groups), start=1):
620
+ X_train = X_all.iloc[train_idx]
621
+ y_train = y_all.iloc[train_idx]
622
+ w_train = w_all.iloc[train_idx]
623
+ X_val = X_all.iloc[val_idx]
624
+ y_val = y_all.iloc[val_idx]
625
+ w_val = w_all.iloc[val_idx]
626
+
627
+ geo_train = geo_val = None
628
+ if geo_train_full is not None:
629
+ geo_train = geo_train_full.iloc[train_idx]
630
+ geo_val = geo_train_full.iloc[val_idx]
631
+
632
+ model = self._build_ft_feature_model(dict(resolved_params))
633
+ model.fit(
634
+ X_train,
635
+ y_train,
636
+ w_train=w_train,
637
+ X_val=X_val,
638
+ y_val=y_val,
639
+ w_val=w_val,
640
+ trial=None,
641
+ geo_train=geo_train,
642
+ geo_val=geo_val,
643
+ )
644
+
645
+ predict_kwargs = {}
646
+ if geo_val is not None:
647
+ predict_kwargs["geo_tokens"] = geo_val
648
+ if feature_mode == "embedding":
649
+ predict_kwargs["return_embedding"] = True
650
+ fold_pred = model.predict(X_val, **predict_kwargs)
651
+ fold_pred = np.asarray(fold_pred)
652
+ if preds_train is None:
653
+ preds_train = np.empty(
654
+ (len(X_all),) + fold_pred.shape[1:], dtype=fold_pred.dtype)
655
+ preds_train[val_idx] = fold_pred
656
+
657
+ getattr(getattr(model, "ft", None), "to",
658
+ lambda *_a, **_k: None)("cpu")
659
+ self._clean_gpu()
660
+
661
+ if preds_train is None:
662
+ return None
663
+ if oof_folds < 2:
664
+ return None
665
+ return preds_train
666
+
667
+ def train_as_feature(self, pred_prefix: str = "ft_feat", feature_mode: str = "prediction") -> None:
668
+ """Train FT-Transformer only to generate features (not recorded as final model)."""
669
+ if not self.best_params:
670
+ raise RuntimeError(
671
+ "Run tune() first to obtain best FT-Transformer parameters.")
672
+ resolved_params = dict(self.best_params)
673
+ if feature_mode not in ("prediction", "embedding"):
674
+ raise ValueError(
675
+ f"Unsupported feature_mode='{feature_mode}', expected 'prediction' or 'embedding'.")
676
+
677
+ geo_train = self.ctx.train_geo_tokens
678
+ geo_test = self.ctx.test_geo_tokens
679
+ fit_kwargs = {}
680
+ predict_kwargs_train = None
681
+ predict_kwargs_test = None
682
+ if geo_train is not None and geo_test is not None:
683
+ fit_kwargs["geo_train"] = geo_train
684
+ predict_kwargs_train = {"geo_tokens": geo_train}
685
+ predict_kwargs_test = {"geo_tokens": geo_test}
686
+
687
+ if feature_mode == "embedding":
688
+ predict_kwargs_train = dict(predict_kwargs_train or {})
689
+ predict_kwargs_test = dict(predict_kwargs_test or {})
690
+ predict_kwargs_train["return_embedding"] = True
691
+ predict_kwargs_test["return_embedding"] = True
692
+
693
+ oof_preds = self._oof_predict_train(
694
+ resolved_params,
695
+ feature_mode=feature_mode,
696
+ geo_train_full=geo_train,
697
+ )
698
+ if oof_preds is not None:
699
+ self.model = self._build_ft_feature_model(resolved_params)
700
+ self.best_params = resolved_params
701
+ self.model.fit(
702
+ self.ctx.train_data[self.ctx.factor_nmes],
703
+ self.ctx.train_data[self.ctx.resp_nme],
704
+ w_train=self.ctx.train_data[self.ctx.weight_nme],
705
+ X_val=None,
706
+ y_val=None,
707
+ w_val=None,
708
+ trial=None,
709
+ geo_train=geo_train,
710
+ geo_val=None,
711
+ )
712
+ predict_kwargs = dict(predict_kwargs_test or {})
713
+ preds_test = self.model.predict(
714
+ self.ctx.test_data[self.ctx.factor_nmes],
715
+ **predict_kwargs,
716
+ )
717
+ self._cache_predictions(pred_prefix, oof_preds, preds_test)
718
+ return
719
+
720
+ self.model = self._build_ft_feature_model(resolved_params)
721
+ self.best_params = resolved_params
722
+ self._fit_predict_cache(
723
+ self.model,
724
+ self.ctx.train_data[self.ctx.factor_nmes],
725
+ self.ctx.train_data[self.ctx.resp_nme],
726
+ sample_weight=self.ctx.train_data[self.ctx.weight_nme],
727
+ pred_prefix=pred_prefix,
728
+ sample_weight_arg='w_train',
729
+ fit_kwargs=fit_kwargs,
730
+ predict_kwargs_train=predict_kwargs_train,
731
+ predict_kwargs_test=predict_kwargs_test,
732
+ record_label=False,
733
+ )
734
+
735
+ def pretrain_unsupervised_as_feature(self,
736
+ pred_prefix: str = "ft_uemb",
737
+ params: Optional[Dict[str,
738
+ Any]] = None,
739
+ mask_prob_num: float = 0.15,
740
+ mask_prob_cat: float = 0.15,
741
+ num_loss_weight: float = 1.0,
742
+ cat_loss_weight: float = 1.0) -> None:
743
+ """Self-supervised pretraining (masked reconstruction) and cache embeddings."""
744
+ loss_name = getattr(self.ctx, "loss_name", "tweedie")
745
+ self.model = FTTransformerSklearn(
746
+ model_nme=self.ctx.model_nme,
747
+ num_cols=self.ctx.num_features,
748
+ cat_cols=self.ctx.cate_list,
749
+ task_type=self.ctx.task_type,
750
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
751
+ use_ddp=self.ctx.config.use_ft_ddp,
752
+ num_numeric_tokens=self._resolve_numeric_tokens(),
753
+ loss_name=loss_name,
754
+ )
755
+ self.model = self._apply_dataloader_overrides(self.model)
756
+ resolved_params = dict(params or {})
757
+ # Reuse supervised tuning structure params unless explicitly overridden.
758
+ if not resolved_params and self.best_params:
759
+ resolved_params = dict(self.best_params)
760
+
761
+ # If params include masked reconstruction fields, they take precedence.
762
+ mask_prob_num = float(resolved_params.pop(
763
+ "mask_prob_num", mask_prob_num))
764
+ mask_prob_cat = float(resolved_params.pop(
765
+ "mask_prob_cat", mask_prob_cat))
766
+ num_loss_weight = float(resolved_params.pop(
767
+ "num_loss_weight", num_loss_weight))
768
+ cat_loss_weight = float(resolved_params.pop(
769
+ "cat_loss_weight", cat_loss_weight))
770
+
771
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
772
+ d_model=resolved_params.get("d_model", self.model.d_model),
773
+ requested_heads=resolved_params.get("n_heads")
774
+ )
775
+ if heads_adjusted:
776
+ print(f"[FTTrainer] Auto-adjusted n_heads from "
777
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
778
+ f"(d_model={resolved_params.get('d_model', self.model.d_model)}).")
779
+ resolved_params["n_heads"] = adaptive_heads
780
+ if resolved_params:
781
+ self.model.set_params(resolved_params)
782
+
783
+ loss_plot_path = self.output.plot_path(
784
+ f'{self.ctx.model_nme}/loss/loss_{self.ctx.model_nme}_FTTransformerUnsupervised.png')
785
+ self.model.loss_curve_path = loss_plot_path
786
+
787
+ # Build a simple holdout split for pretraining early stopping.
788
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
789
+ split = self._resolve_train_val_indices(X_all, allow_default=True)
790
+ if split is None:
791
+ raise ValueError(
792
+ "Unable to build train/val split for FT unsupervised training.")
793
+ train_idx, val_idx = split
794
+ X_tr = X_all.iloc[train_idx]
795
+ X_val = X_all.iloc[val_idx]
796
+
797
+ geo_all = self.ctx.train_geo_tokens
798
+ geo_tr = geo_val = None
799
+ if geo_all is not None:
800
+ geo_tr = geo_all.loc[X_tr.index]
801
+ geo_val = geo_all.loc[X_val.index]
802
+
803
+ self.model.fit_unsupervised(
804
+ X_tr,
805
+ X_val=X_val,
806
+ geo_train=geo_tr,
807
+ geo_val=geo_val,
808
+ mask_prob_num=mask_prob_num,
809
+ mask_prob_cat=mask_prob_cat,
810
+ num_loss_weight=num_loss_weight,
811
+ cat_loss_weight=cat_loss_weight
812
+ )
813
+
814
+ geo_train_full = self.ctx.train_geo_tokens
815
+ geo_test_full = self.ctx.test_geo_tokens
816
+ predict_kwargs_train = {"return_embedding": True}
817
+ predict_kwargs_test = {"return_embedding": True}
818
+ if geo_train_full is not None and geo_test_full is not None:
819
+ predict_kwargs_train["geo_tokens"] = geo_train_full
820
+ predict_kwargs_test["geo_tokens"] = geo_test_full
821
+
822
+ self._predict_and_cache(
823
+ self.model,
824
+ pred_prefix=pred_prefix,
825
+ predict_kwargs_train=predict_kwargs_train,
826
+ predict_kwargs_test=predict_kwargs_test
827
+ )
828
+
829
+
830
+ # =============================================================================
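A minimal, self-contained sketch (not part of the package) of three heuristics that appear in the listing above: the adaptive head-count adjustment, the approx_units memory guard used to prune Optuna trials, and the default OOF fold count derived from prop_test. The constants mirror the listing; the function names and the example values below are illustrative only.

# Sketch of FTTrainer sizing heuristics; constants mirror the diff above,
# example inputs are made up for illustration.
from typing import Optional, Tuple

def resolve_adaptive_heads(d_model: int, requested_heads: Optional[int] = None) -> Tuple[int, bool]:
    """Return (n_heads, adjusted): largest head count <= the request that divides d_model."""
    d_model = int(d_model)
    if d_model <= 0:
        raise ValueError(f"Invalid d_model={d_model}, expected > 0.")
    # Default mirrors the listing: max(2, d_model // 16), then clamp to [1, d_model].
    base_heads = max(2, d_model // 16) if requested_heads is None else int(requested_heads)
    base_heads = max(1, min(base_heads, d_model))
    if d_model % base_heads == 0:
        return base_heads, False
    for candidate in range(base_heads, 0, -1):
        if d_model % candidate == 0:
            return candidate, True
    return 1, True

def approx_units(d_model: int, n_layers: int, n_numeric_tokens: int,
                 n_categorical: int, has_geo_token: bool = False) -> int:
    """Rough size proxy pruned against the 12_000_000 budget in cross_val."""
    token_count = n_numeric_tokens + n_categorical + (1 if has_geo_token else 0)
    return d_model * n_layers * max(1, token_count)

def default_oof_folds(prop_test: Optional[float]) -> int:
    """Fold count used for FT OOF features when ft_oof_folds is unset: roughly 1 / prop_test."""
    val_ratio = float(prop_test) if prop_test else 0.25
    if not (0.0 < val_ratio < 1.0):
        val_ratio = 0.25
    return max(2, int(round(1 / val_ratio)))

if __name__ == "__main__":
    print(resolve_adaptive_heads(112))      # (7, False): default 112 // 16 = 7 already divides 112
    print(resolve_adaptive_heads(80, 6))    # (5, True): 6 does not divide 80, fall back to 5
    units = approx_units(d_model=128, n_layers=8, n_numeric_tokens=40, n_categorical=30)
    print(units, units > 12_000_000)        # 71680 False: well under the prune budget
    print(default_oof_folds(0.2))           # 5 folds for a 20% validation ratio

Note that the default head count max(2, d_model // 16) always divides a d_model that is a multiple of 16, so the adjustment branch only triggers when an explicit n_heads is supplied that does not divide d_model.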