ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +39 -105
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +11 -9
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/__init__.py +10 -10
  15. ins_pricing/frontend/example_workflows.py +1 -1
  16. ins_pricing/governance/__init__.py +20 -20
  17. ins_pricing/governance/release.py +159 -159
  18. ins_pricing/modelling/__init__.py +147 -92
  19. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
  20. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  21. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
  22. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  29. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  36. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  37. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  38. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  39. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  40. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  42. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  43. ins_pricing/modelling/explain/__init__.py +55 -55
  44. ins_pricing/modelling/explain/metrics.py +27 -174
  45. ins_pricing/modelling/explain/permutation.py +237 -237
  46. ins_pricing/modelling/plotting/__init__.py +40 -36
  47. ins_pricing/modelling/plotting/compat.py +228 -0
  48. ins_pricing/modelling/plotting/curves.py +572 -572
  49. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  50. ins_pricing/modelling/plotting/geo.py +362 -362
  51. ins_pricing/modelling/plotting/importance.py +121 -121
  52. ins_pricing/pricing/__init__.py +27 -27
  53. ins_pricing/production/__init__.py +35 -25
  54. ins_pricing/production/{predict.py → inference.py} +140 -57
  55. ins_pricing/production/monitoring.py +8 -21
  56. ins_pricing/reporting/__init__.py +11 -11
  57. ins_pricing/setup.py +1 -1
  58. ins_pricing/tests/production/test_inference.py +90 -0
  59. ins_pricing/utils/__init__.py +116 -83
  60. ins_pricing/utils/device.py +255 -255
  61. ins_pricing/utils/features.py +53 -0
  62. ins_pricing/utils/io.py +72 -0
  63. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  64. ins_pricing/utils/metrics.py +158 -24
  65. ins_pricing/utils/numerics.py +76 -0
  66. ins_pricing/utils/paths.py +9 -1
  67. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
  68. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  69. ins_pricing/modelling/core/BayesOpt.py +0 -146
  70. ins_pricing/modelling/core/__init__.py +0 -1
  71. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  72. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  73. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  74. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  75. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  76. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  77. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  78. ins_pricing/tests/production/test_predict.py +0 -233
  79. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  80. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  81. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  82. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  83. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  84. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
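
Taken together, the moves above flatten `ins_pricing/modelling/core/bayesopt` into `ins_pricing/modelling/bayesopt`, promote the shared loss/constant helpers into `ins_pricing.utils`, and rename `production/predict.py` to `production/inference.py` (with the tests following suit). A minimal migration sketch for caller imports, using only paths from the list above and symbol names visible in the model_ft_trainer.py hunk below; anything beyond those names is an assumption, not a documented API:

```python
# Hypothetical migration sketch -- illustrates the restructure, not part of the diff itself.

# 0.4.5 layout (old paths, left-hand side of the renames above):
# from ins_pricing.modelling.core.bayesopt.models.model_ft_trainer import FTTransformerSklearn
# from ins_pricing.modelling.core.bayesopt.utils.losses import normalize_loss_name

# 0.5.0 layout (new paths, matching the import block in the hunk below):
from ins_pricing.modelling.bayesopt.models.model_ft_trainer import FTTransformerSklearn
from ins_pricing.utils.losses import normalize_loss_name
from ins_pricing.utils import EPS  # shared constants now live under ins_pricing.utils
# Production scoring moved from ins_pricing.production.predict to ins_pricing.production.inference.
```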
ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py
@@ -1,913 +1,915 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- from contextlib import nullcontext
5
- from typing import Any, Dict, List, Optional
6
-
7
- import numpy as np
8
- import optuna
9
- import pandas as pd
10
- import torch
11
- import torch.distributed as dist
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torch.cuda.amp import autocast, GradScaler
15
- from torch.nn.parallel import DistributedDataParallel as DDP
16
- from torch.nn.utils import clip_grad_norm_
17
-
18
- from ..utils import DistributedUtils, EPS, TorchTrainerMixin
19
- from ..utils.losses import (
20
- infer_loss_name_from_model_name,
21
- normalize_loss_name,
22
- resolve_tweedie_power,
23
- )
24
- from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
25
-
26
-
27
- # --- Helper functions for reconstruction loss computation ---
28
-
29
-
30
- def _compute_numeric_reconstruction_loss(
31
- num_pred: Optional[torch.Tensor],
32
- num_true: Optional[torch.Tensor],
33
- num_mask: Optional[torch.Tensor],
34
- loss_weight: float,
35
- device: torch.device,
36
- ) -> torch.Tensor:
37
- """Compute MSE loss for numeric feature reconstruction.
38
-
39
- Args:
40
- num_pred: Predicted numeric values (N, num_features)
41
- num_true: Ground truth numeric values (N, num_features)
42
- num_mask: Boolean mask indicating which values were masked (N, num_features)
43
- loss_weight: Weight to apply to the loss
44
- device: Target device for computation
45
-
46
- Returns:
47
- Weighted MSE loss for masked numeric features
48
- """
49
- if num_pred is None or num_true is None or num_mask is None:
50
- return torch.zeros((), device=device, dtype=torch.float32)
51
-
52
- num_mask = num_mask.to(dtype=torch.bool)
53
- if not num_mask.any():
54
- return torch.zeros((), device=device, dtype=torch.float32)
55
-
56
- diff = num_pred - num_true
57
- mse = diff * diff
58
- return float(loss_weight) * mse[num_mask].mean()
59
-
60
-
61
- def _compute_categorical_reconstruction_loss(
62
- cat_logits: Optional[List[torch.Tensor]],
63
- cat_true: Optional[torch.Tensor],
64
- cat_mask: Optional[torch.Tensor],
65
- loss_weight: float,
66
- device: torch.device,
67
- ) -> torch.Tensor:
68
- """Compute cross-entropy loss for categorical feature reconstruction.
69
-
70
- Args:
71
- cat_logits: List of logits for each categorical feature
72
- cat_true: Ground truth categorical indices (N, num_cat_features)
73
- cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
74
- loss_weight: Weight to apply to the loss
75
- device: Target device for computation
76
-
77
- Returns:
78
- Weighted cross-entropy loss for masked categorical features
79
- """
80
- if not cat_logits or cat_true is None or cat_mask is None:
81
- return torch.zeros((), device=device, dtype=torch.float32)
82
-
83
- cat_mask = cat_mask.to(dtype=torch.bool)
84
- cat_losses: List[torch.Tensor] = []
85
-
86
- for j, logits in enumerate(cat_logits):
87
- mask_j = cat_mask[:, j]
88
- if not mask_j.any():
89
- continue
90
- targets = cat_true[:, j]
91
- cat_losses.append(
92
- F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
93
- )
94
-
95
- if not cat_losses:
96
- return torch.zeros((), device=device, dtype=torch.float32)
97
-
98
- return float(loss_weight) * torch.stack(cat_losses).mean()
99
-
100
-
101
- def _compute_reconstruction_loss(
102
- num_pred: Optional[torch.Tensor],
103
- cat_logits: Optional[List[torch.Tensor]],
104
- num_true: Optional[torch.Tensor],
105
- num_mask: Optional[torch.Tensor],
106
- cat_true: Optional[torch.Tensor],
107
- cat_mask: Optional[torch.Tensor],
108
- num_loss_weight: float,
109
- cat_loss_weight: float,
110
- device: torch.device,
111
- ) -> torch.Tensor:
112
- """Compute combined reconstruction loss for masked tabular data.
113
-
114
- This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
115
-
116
- Args:
117
- num_pred: Predicted numeric values
118
- cat_logits: List of logits for categorical features
119
- num_true: Ground truth numeric values
120
- num_mask: Mask for numeric features
121
- cat_true: Ground truth categorical indices
122
- cat_mask: Mask for categorical features
123
- num_loss_weight: Weight for numeric loss
124
- cat_loss_weight: Weight for categorical loss
125
- device: Target device for computation
126
-
127
- Returns:
128
- Combined weighted reconstruction loss
129
- """
130
- num_loss = _compute_numeric_reconstruction_loss(
131
- num_pred, num_true, num_mask, num_loss_weight, device
132
- )
133
- cat_loss = _compute_categorical_reconstruction_loss(
134
- cat_logits, cat_true, cat_mask, cat_loss_weight, device
135
- )
136
- return num_loss + cat_loss
137
-
138
-
139
- # Scikit-Learn style wrapper for FTTransformer.
140
-
141
-
142
- class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
143
-
144
- # sklearn-style wrapper:
145
- # - num_cols: numeric feature column names
146
- # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
147
-
148
- @staticmethod
149
- def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
150
- num_cols_count = len(num_cols or [])
151
- if num_cols_count == 0:
152
- return 0
153
- if requested is not None:
154
- count = int(requested)
155
- if count <= 0:
156
- raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
157
- return count
158
- return max(1, num_cols_count)
159
-
160
- def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
161
- n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
162
- task_type: str = 'regression',
163
- tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
164
- weight_decay: float = 0.0,
165
- use_data_parallel: bool = True,
166
- use_ddp: bool = False,
167
- num_numeric_tokens: Optional[int] = None,
168
- loss_name: Optional[str] = None
169
- ):
170
- super().__init__()
171
-
172
- self.use_ddp = use_ddp
173
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
174
- False, 0, 0, 1)
175
- if self.use_ddp:
176
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
177
-
178
- self.model_nme = model_nme
179
- self.num_cols = list(num_cols)
180
- self.cat_cols = list(cat_cols)
181
- self.num_numeric_tokens = self.resolve_numeric_token_count(
182
- self.num_cols,
183
- self.cat_cols,
184
- num_numeric_tokens,
185
- )
186
- self.d_model = d_model
187
- self.n_heads = n_heads
188
- self.n_layers = n_layers
189
- self.dropout = dropout
190
- self.batch_num = batch_num
191
- self.epochs = epochs
192
- self.learning_rate = learning_rate
193
- self.weight_decay = weight_decay
194
- self.task_type = task_type
195
- self.patience = patience
196
- resolved_loss = normalize_loss_name(loss_name, self.task_type)
197
- if self.task_type == 'classification':
198
- self.loss_name = "logloss"
199
- self.tw_power = None # No Tweedie power for classification.
200
- else:
201
- if resolved_loss == "auto":
202
- resolved_loss = infer_loss_name_from_model_name(self.model_nme)
203
- self.loss_name = resolved_loss
204
- if self.loss_name == "tweedie":
205
- self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
206
- else:
207
- self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
208
-
209
- if self.is_ddp_enabled:
210
- self.device = torch.device(f"cuda:{self.local_rank}")
211
- elif torch.cuda.is_available():
212
- self.device = torch.device("cuda")
213
- elif torch.backends.mps.is_available():
214
- self.device = torch.device("mps")
215
- else:
216
- self.device = torch.device("cpu")
217
- self.cat_cardinalities = None
218
- self.cat_categories = {}
219
- self.cat_maps: Dict[str, Dict[Any, int]] = {}
220
- self.cat_str_maps: Dict[str, Dict[str, int]] = {}
221
- self._num_mean = None
222
- self._num_std = None
223
- self.ft = None
224
- self.use_data_parallel = bool(use_data_parallel)
225
- self.num_geo = 0
226
- self._geo_params: Dict[str, Any] = {}
227
- self.loss_curve_path: Optional[str] = None
228
- self.training_history: Dict[str, List[float]] = {
229
- "train": [], "val": []}
230
-
231
- def _build_model(self, X_train):
232
- num_numeric = len(self.num_cols)
233
- cat_cardinalities = []
234
-
235
- if num_numeric > 0:
236
- num_arr = X_train[self.num_cols].to_numpy(
237
- dtype=np.float32, copy=False)
238
- num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
239
- mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
240
- std = num_arr.std(axis=0).astype(np.float32, copy=False)
241
- std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
242
- self._num_mean = mean
243
- self._num_std = std
244
- else:
245
- self._num_mean = None
246
- self._num_std = None
247
-
248
- self.cat_maps = {}
249
- self.cat_str_maps = {}
250
- for col in self.cat_cols:
251
- cats = X_train[col].astype('category')
252
- categories = cats.cat.categories
253
- self.cat_categories[col] = categories # Store full category list from training.
254
- self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
255
- if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
256
- self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
257
-
258
- card = len(categories) + 1 # Reserve one extra class for unknown/missing.
259
- cat_cardinalities.append(card)
260
-
261
- self.cat_cardinalities = cat_cardinalities
262
-
263
- core = FTTransformerCore(
264
- num_numeric=num_numeric,
265
- cat_cardinalities=cat_cardinalities,
266
- d_model=self.d_model,
267
- n_heads=self.n_heads,
268
- n_layers=self.n_layers,
269
- dropout=self.dropout,
270
- task_type=self.task_type,
271
- num_geo=self.num_geo,
272
- num_numeric_tokens=self.num_numeric_tokens
273
- )
274
- use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
275
- if self.is_ddp_enabled:
276
- core = core.to(self.device)
277
- core = DDP(core, device_ids=[
278
- self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
279
- self.use_data_parallel = False
280
- elif use_dp:
281
- if self.use_ddp and not self.is_ddp_enabled:
282
- print(
283
- ">>> DDP requested but not initialized; falling back to DataParallel.")
284
- core = nn.DataParallel(core, device_ids=list(
285
- range(torch.cuda.device_count())))
286
- self.device = torch.device("cuda")
287
- self.use_data_parallel = True
288
- else:
289
- self.use_data_parallel = False
290
- self.ft = core.to(self.device)
291
-
292
- def _encode_cats(self, X):
293
- # Input DataFrame must include all categorical feature columns.
294
- # Return int64 array with shape (N, num_categorical_features).
295
-
296
- if not self.cat_cols:
297
- return np.zeros((len(X), 0), dtype='int64')
298
-
299
- n_rows = len(X)
300
- n_cols = len(self.cat_cols)
301
- X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
302
- for idx, col in enumerate(self.cat_cols):
303
- categories = self.cat_categories[col]
304
- mapping = self.cat_maps.get(col)
305
- if mapping is None:
306
- mapping = {cat: i for i, cat in enumerate(categories)}
307
- self.cat_maps[col] = mapping
308
- unknown_idx = len(categories)
309
- series = X[col]
310
- codes = series.map(mapping)
311
- unmapped = series.notna() & codes.isna()
312
- if unmapped.any():
313
- try:
314
- series_cast = series.astype(categories.dtype)
315
- except Exception:
316
- series_cast = None
317
- if series_cast is not None:
318
- codes = series_cast.map(mapping)
319
- unmapped = series_cast.notna() & codes.isna()
320
- if unmapped.any():
321
- str_map = self.cat_str_maps.get(col)
322
- if str_map is None:
323
- str_map = {str(cat): i for i, cat in enumerate(categories)}
324
- self.cat_str_maps[col] = str_map
325
- codes = series.astype(str).map(str_map)
326
- if pd.api.types.is_categorical_dtype(codes):
327
- codes = codes.astype("float")
328
- codes = codes.fillna(unknown_idx).astype(
329
- "int64", copy=False).to_numpy()
330
- X_cat_np[:, idx] = codes
331
- return X_cat_np
332
-
333
- def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
334
- return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
335
-
336
- def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
337
- return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
338
-
339
- @staticmethod
340
- def _validate_vector(arr, name: str, n_rows: int) -> None:
341
- if arr is None:
342
- return
343
- if isinstance(arr, pd.DataFrame):
344
- if arr.shape[1] != 1:
345
- raise ValueError(f"{name} must be 1d (single column).")
346
- length = len(arr)
347
- else:
348
- arr_np = np.asarray(arr)
349
- if arr_np.ndim == 0:
350
- raise ValueError(f"{name} must be 1d.")
351
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
352
- raise ValueError(f"{name} must be 1d or Nx1.")
353
- length = arr_np.shape[0]
354
- if length != n_rows:
355
- raise ValueError(
356
- f"{name} length {length} does not match X length {n_rows}."
357
- )
358
-
359
- def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
360
- if X is None:
361
- if allow_none:
362
- return None, None, None, None, None, False
363
- raise ValueError("Input features X must not be None.")
364
- if not isinstance(X, pd.DataFrame):
365
- raise ValueError("X must be a pandas DataFrame.")
366
- missing_cols = [
367
- col for col in (self.num_cols + self.cat_cols) if col not in X.columns
368
- ]
369
- if missing_cols:
370
- raise ValueError(f"X is missing required columns: {missing_cols}")
371
- n_rows = len(X)
372
- if y is not None:
373
- self._validate_vector(y, "y", n_rows)
374
- if w is not None:
375
- self._validate_vector(w, "w", n_rows)
376
-
377
- num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
378
- if not num_np.flags["OWNDATA"]:
379
- num_np = num_np.copy()
380
- num_np = np.nan_to_num(num_np, nan=0.0,
381
- posinf=0.0, neginf=0.0, copy=False)
382
- if self._num_mean is not None and self._num_std is not None and num_np.size:
383
- num_np = (num_np - self._num_mean) / self._num_std
384
- X_num = torch.as_tensor(num_np)
385
- if self.cat_cols:
386
- X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
387
- else:
388
- X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
389
-
390
- if geo_tokens is not None:
391
- geo_np = np.asarray(geo_tokens, dtype=np.float32)
392
- if geo_np.shape[0] != n_rows:
393
- raise ValueError(
394
- "geo_tokens length does not match X rows.")
395
- if geo_np.ndim == 1:
396
- geo_np = geo_np.reshape(-1, 1)
397
- elif self.num_geo > 0:
398
- raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
399
- else:
400
- geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
401
- X_geo = torch.as_tensor(geo_np)
402
-
403
- y_tensor = torch.as_tensor(
404
- y.to_numpy(dtype=np.float32, copy=False) if hasattr(
405
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
406
- ).view(-1, 1) if y is not None else None
407
- if y_tensor is None:
408
- w_tensor = None
409
- elif w is not None:
410
- w_tensor = torch.as_tensor(
411
- w.to_numpy(dtype=np.float32, copy=False) if hasattr(
412
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
413
- ).view(-1, 1)
414
- else:
415
- w_tensor = torch.ones_like(y_tensor)
416
- return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
417
-
418
- def fit(self, X_train, y_train, w_train=None,
419
- X_val=None, y_val=None, w_val=None, trial=None,
420
- geo_train=None, geo_val=None):
421
-
422
- # Build the underlying model on first fit.
423
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
424
- if self.ft is None:
425
- self._build_model(X_train)
426
-
427
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
428
- X_train, y_train, w_train, geo_train=geo_train)
429
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
430
- X_val, y_val, w_val, geo_val=geo_val)
431
-
432
- # --- Build DataLoader ---
433
- dataset = TabularDataset(
434
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
435
- )
436
-
437
- dataloader, accum_steps = self._build_dataloader(
438
- dataset,
439
- N=X_num_train.shape[0],
440
- base_bs_gpu=(2048, 1024, 512),
441
- base_bs_cpu=(256, 128),
442
- min_bs=64,
443
- target_effective_cuda=2048,
444
- target_effective_cpu=1024
445
- )
446
-
447
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
448
- self.dataloader_sampler = dataloader.sampler
449
- else:
450
- self.dataloader_sampler = None
451
-
452
- optimizer = torch.optim.Adam(
453
- self.ft.parameters(),
454
- lr=self.learning_rate,
455
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
456
- )
457
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
458
-
459
- X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
460
- val_dataloader = None
461
- if has_val:
462
- val_dataset = TabularDataset(
463
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
464
- )
465
- val_dataloader = self._build_val_dataloader(
466
- val_dataset, dataloader, accum_steps)
467
-
468
- # Check for both DataParallel and DDP wrappers
469
- is_data_parallel = isinstance(self.ft, (nn.DataParallel, DDP))
470
-
471
- def forward_fn(batch):
472
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
473
-
474
- # For DataParallel, inputs are automatically scattered; for DDP, move to local device
475
- if not isinstance(self.ft, nn.DataParallel):
476
- X_num_b = X_num_b.to(self.device, non_blocking=True)
477
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
478
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
479
- y_b = y_b.to(self.device, non_blocking=True)
480
- w_b = w_b.to(self.device, non_blocking=True)
481
-
482
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
483
- return y_pred, y_b, w_b
484
-
485
- def val_forward_fn():
486
- total_loss = 0.0
487
- total_weight = 0.0
488
- for batch in val_dataloader:
489
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
490
- if not isinstance(self.ft, nn.DataParallel):
491
- X_num_b = X_num_b.to(self.device, non_blocking=True)
492
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
493
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
494
- y_b = y_b.to(self.device, non_blocking=True)
495
- w_b = w_b.to(self.device, non_blocking=True)
496
-
497
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
498
-
499
- # Manually compute validation loss.
500
- losses = self._compute_losses(
501
- y_pred, y_b, apply_softplus=False)
502
-
503
- batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
504
- batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
505
-
506
- total_loss += batch_weighted_loss_sum.item()
507
- total_weight += batch_weight_sum.item()
508
-
509
- return total_loss / max(total_weight, EPS)
510
-
511
- clip_fn = None
512
- if self.device.type == 'cuda':
513
- def clip_fn(): return (scaler.unscale_(optimizer),
514
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
515
-
516
- best_state, history = self._train_model(
517
- self.ft,
518
- dataloader,
519
- accum_steps,
520
- optimizer,
521
- scaler,
522
- forward_fn,
523
- val_forward_fn if has_val else None,
524
- apply_softplus=False,
525
- clip_fn=clip_fn,
526
- trial=trial,
527
- loss_curve_path=getattr(self, "loss_curve_path", None)
528
- )
529
-
530
- if has_val and best_state is not None:
531
- # Load state into unwrapped module to match how it was saved
532
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
533
- base_module.load_state_dict(best_state)
534
- self.training_history = history
535
-
536
- def fit_unsupervised(self,
537
- X_train,
538
- X_val=None,
539
- trial: Optional[optuna.trial.Trial] = None,
540
- geo_train=None,
541
- geo_val=None,
542
- mask_prob_num: float = 0.15,
543
- mask_prob_cat: float = 0.15,
544
- num_loss_weight: float = 1.0,
545
- cat_loss_weight: float = 1.0) -> float:
546
- """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
547
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
548
- if self.ft is None:
549
- self._build_model(X_train)
550
-
551
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
552
- X_train, None, None, geo_tokens=geo_train, allow_none=True)
553
- has_val = X_val is not None
554
- if has_val:
555
- X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
556
- X_val, None, None, geo_tokens=geo_val, allow_none=True)
557
- else:
558
- X_num_val = X_cat_val = X_geo_val = None
559
-
560
- N = int(X_num.shape[0])
561
- num_dim = int(X_num.shape[1])
562
- cat_dim = int(X_cat.shape[1])
563
- device_type = self._device_type()
564
-
565
- gen = torch.Generator()
566
- gen.manual_seed(13 + int(getattr(self, "rank", 0)))
567
-
568
- base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
569
- cardinals = getattr(base_model, "cat_cardinalities", None) or []
570
- unknown_idx = torch.tensor(
571
- [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
572
-
573
- means = None
574
- if num_dim > 0:
575
- # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
576
- means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
577
-
578
- def _mask_inputs(X_num_in: torch.Tensor,
579
- X_cat_in: torch.Tensor,
580
- generator: torch.Generator):
581
- n_rows = int(X_num_in.shape[0])
582
- num_mask_local = None
583
- cat_mask_local = None
584
- X_num_masked_local = X_num_in
585
- X_cat_masked_local = X_cat_in
586
- if num_dim > 0:
587
- num_mask_local = (torch.rand(
588
- (n_rows, num_dim), generator=generator) < float(mask_prob_num))
589
- X_num_masked_local = X_num_in.clone()
590
- if num_mask_local.any():
591
- X_num_masked_local[num_mask_local] = means.expand_as(
592
- X_num_masked_local)[num_mask_local]
593
- if cat_dim > 0:
594
- cat_mask_local = (torch.rand(
595
- (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
596
- X_cat_masked_local = X_cat_in.clone()
597
- if cat_mask_local.any():
598
- X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
599
- X_cat_masked_local)[cat_mask_local]
600
- return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
601
-
602
- X_num_true = X_num if num_dim > 0 else None
603
- X_cat_true = X_cat if cat_dim > 0 else None
604
- X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
605
- X_num, X_cat, gen)
606
-
607
- dataset = MaskedTabularDataset(
608
- X_num_masked, X_cat_masked, X_geo,
609
- X_num_true, num_mask,
610
- X_cat_true, cat_mask
611
- )
612
- dataloader, accum_steps = self._build_dataloader(
613
- dataset,
614
- N=N,
615
- base_bs_gpu=(2048, 1024, 512),
616
- base_bs_cpu=(256, 128),
617
- min_bs=64,
618
- target_effective_cuda=2048,
619
- target_effective_cpu=1024
620
- )
621
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
622
- self.dataloader_sampler = dataloader.sampler
623
- else:
624
- self.dataloader_sampler = None
625
-
626
- optimizer = torch.optim.Adam(
627
- self.ft.parameters(),
628
- lr=self.learning_rate,
629
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
630
- )
631
- scaler = GradScaler(enabled=(device_type == 'cuda'))
632
-
633
- train_history: List[float] = []
634
- val_history: List[float] = []
635
- best_loss = float("inf")
636
- best_state = None
637
- patience_counter = 0
638
- is_ddp_model = isinstance(self.ft, DDP)
639
- use_collectives = dist.is_initialized() and is_ddp_model
640
-
641
- clip_fn = None
642
- if self.device.type == 'cuda':
643
- def clip_fn(): return (scaler.unscale_(optimizer),
644
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
645
-
646
- for epoch in range(1, int(self.epochs) + 1):
647
- if self.dataloader_sampler is not None:
648
- self.dataloader_sampler.set_epoch(epoch)
649
-
650
- self.ft.train()
651
- optimizer.zero_grad()
652
- epoch_loss_sum = 0.0
653
- epoch_count = 0.0
654
-
655
- for step, batch in enumerate(dataloader):
656
- is_update_step = ((step + 1) % accum_steps == 0) or \
657
- ((step + 1) == len(dataloader))
658
- sync_cm = self.ft.no_sync if (
659
- is_ddp_model and not is_update_step) else nullcontext
660
- with sync_cm():
661
- with autocast(enabled=(device_type == 'cuda')):
662
- X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
663
- X_num_b = X_num_b.to(self.device, non_blocking=True)
664
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
665
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
666
- num_true_b = None if num_true_b is None else num_true_b.to(
667
- self.device, non_blocking=True)
668
- num_mask_b = None if num_mask_b is None else num_mask_b.to(
669
- self.device, non_blocking=True)
670
- cat_true_b = None if cat_true_b is None else cat_true_b.to(
671
- self.device, non_blocking=True)
672
- cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
673
- self.device, non_blocking=True)
674
-
675
- num_pred, cat_logits = self.ft(
676
- X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
677
- batch_loss = _compute_reconstruction_loss(
678
- num_pred, cat_logits, num_true_b, num_mask_b,
679
- cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
680
- device=X_num_b.device)
681
- local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
682
- global_bad = local_bad
683
- if use_collectives:
684
- bad = torch.tensor(
685
- [local_bad],
686
- device=batch_loss.device,
687
- dtype=torch.int32,
688
- )
689
- dist.all_reduce(bad, op=dist.ReduceOp.MAX)
690
- global_bad = int(bad.item())
691
-
692
- if global_bad:
693
- msg = (
694
- f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
695
- f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
696
- )
697
- should_log = (not dist.is_initialized()
698
- or DistributedUtils.is_main_process())
699
- if should_log:
700
- print(msg, flush=True)
701
- print(
702
- f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
703
- f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
704
- f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
705
- flush=True,
706
- )
707
- if X_geo_b is not None:
708
- print(
709
- f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
710
- f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
711
- f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
712
- flush=True,
713
- )
714
- if trial is not None:
715
- raise optuna.TrialPruned(msg)
716
- raise RuntimeError(msg)
717
- loss_for_backward = batch_loss / float(accum_steps)
718
- scaler.scale(loss_for_backward).backward()
719
-
720
- if is_update_step:
721
- if clip_fn is not None:
722
- clip_fn()
723
- scaler.step(optimizer)
724
- scaler.update()
725
- optimizer.zero_grad()
726
-
727
- epoch_loss_sum += float(batch_loss.detach().item()) * \
728
- float(X_num_b.shape[0])
729
- epoch_count += float(X_num_b.shape[0])
730
-
731
- train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
732
-
733
- if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
734
- should_compute_val = (not dist.is_initialized()
735
- or DistributedUtils.is_main_process())
736
- loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
737
- "cpu")
738
- val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
739
-
740
- if should_compute_val:
741
- self.ft.eval()
742
- with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
743
- val_bs = min(
744
- int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
745
- total_val = 0.0
746
- total_n = 0.0
747
- for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
748
- end = min(
749
- int(X_num_val.shape[0]), start + max(1, val_bs))
750
- X_num_v_true_cpu = X_num_val[start:end]
751
- X_cat_v_true_cpu = X_cat_val[start:end]
752
- X_geo_v = X_geo_val[start:end].to(
753
- self.device, non_blocking=True)
754
- gen_val = torch.Generator()
755
- gen_val.manual_seed(10_000 + epoch + start)
756
- X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
757
- X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
758
- X_num_v_true = X_num_v_true_cpu.to(
759
- self.device, non_blocking=True)
760
- X_cat_v_true = X_cat_v_true_cpu.to(
761
- self.device, non_blocking=True)
762
- X_num_v = X_num_v_cpu.to(
763
- self.device, non_blocking=True)
764
- X_cat_v = X_cat_v_cpu.to(
765
- self.device, non_blocking=True)
766
- val_num_mask = None if val_num_mask is None else val_num_mask.to(
767
- self.device, non_blocking=True)
768
- val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
769
- self.device, non_blocking=True)
770
- num_pred_v, cat_logits_v = self.ft(
771
- X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
772
- loss_v = _compute_reconstruction_loss(
773
- num_pred_v, cat_logits_v,
774
- X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
775
- X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
776
- num_loss_weight, cat_loss_weight,
777
- device=X_num_v.device
778
- )
779
- if not torch.isfinite(loss_v):
780
- total_val = float("inf")
781
- total_n = 1.0
782
- break
783
- total_val += float(loss_v.detach().item()
784
- ) * float(end - start)
785
- total_n += float(end - start)
786
- val_loss_tensor[0] = total_val / max(total_n, 1.0)
787
-
788
- if use_collectives:
789
- dist.broadcast(val_loss_tensor, src=0)
790
- val_loss_value = float(val_loss_tensor.item())
791
- prune_now = False
792
- prune_msg = None
793
- if not np.isfinite(val_loss_value):
794
- prune_now = True
795
- prune_msg = (
796
- f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
797
- f"(epoch={epoch}, val_loss={val_loss_value})"
798
- )
799
- val_history.append(val_loss_value)
800
-
801
- if val_loss_value < best_loss:
802
- best_loss = val_loss_value
803
- # Efficiently clone state_dict - only clone tensor data, not DDP metadata
804
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
805
- best_state = {
806
- k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
807
- for k, v in base_module.state_dict().items()
808
- }
809
- patience_counter = 0
810
- else:
811
- patience_counter += 1
812
- if best_state is not None and patience_counter >= int(self.patience):
813
- break
814
-
815
- if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
816
- trial.report(val_loss_value, epoch)
817
- if trial.should_prune():
818
- prune_now = True
819
-
820
- if use_collectives:
821
- flag = torch.tensor(
822
- [1 if prune_now else 0],
823
- device=loss_tensor_device,
824
- dtype=torch.int32,
825
- )
826
- dist.broadcast(flag, src=0)
827
- prune_now = bool(flag.item())
828
-
829
- if prune_now:
830
- if prune_msg:
831
- raise optuna.TrialPruned(prune_msg)
832
- raise optuna.TrialPruned()
833
-
834
- self.training_history = {"train": train_history, "val": val_history}
835
- self._plot_loss_curve(self.training_history, getattr(
836
- self, "loss_curve_path", None))
837
- if has_val and best_state is not None:
838
- # Load state into unwrapped module to match how it was saved
839
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
840
- base_module.load_state_dict(best_state)
841
- return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
842
-
843
- def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
844
- # X_test must include all numeric/categorical columns; geo_tokens is optional.
845
-
846
- self.ft.eval()
847
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
848
- X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
849
-
850
- num_rows = X_num.shape[0]
851
- if num_rows == 0:
852
- return np.empty(0, dtype=np.float32)
853
-
854
- device = self.device if isinstance(
855
- self.device, torch.device) else torch.device(self.device)
856
-
857
- def resolve_batch_size(n_rows: int) -> int:
858
- if batch_size is not None:
859
- return max(1, min(int(batch_size), n_rows))
860
- # Estimate a safe batch size based on model size to avoid attention OOM.
861
- token_cnt = self.num_numeric_tokens + len(self.cat_cols)
862
- if self.num_geo > 0:
863
- token_cnt += 1
864
- approx_units = max(1, token_cnt * max(1, self.d_model))
865
- if device.type == 'cuda':
866
- if approx_units >= 8192:
867
- base = 512
868
- elif approx_units >= 4096:
869
- base = 1024
870
- else:
871
- base = 2048
872
- else:
873
- base = 512
874
- return max(1, min(base, n_rows))
875
-
876
- eff_batch = resolve_batch_size(num_rows)
877
- preds: List[torch.Tensor] = []
878
-
879
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
880
- with inference_cm():
881
- for start in range(0, num_rows, eff_batch):
882
- end = min(num_rows, start + eff_batch)
883
- X_num_b = X_num[start:end].to(device, non_blocking=True)
884
- X_cat_b = X_cat[start:end].to(device, non_blocking=True)
885
- X_geo_b = X_geo[start:end].to(device, non_blocking=True)
886
- pred_chunk = self.ft(
887
- X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
888
- preds.append(pred_chunk.cpu())
889
-
890
- y_pred = torch.cat(preds, dim=0).numpy()
891
-
892
- if return_embedding:
893
- return y_pred
894
-
895
- if self.task_type == 'classification':
896
- # Convert logits to probabilities.
897
- y_pred = 1 / (1 + np.exp(-y_pred))
898
- else:
899
- # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
900
- y_pred = np.clip(y_pred, 1e-6, None)
901
- return y_pred.ravel()
902
-
903
- def set_params(self, params: dict):
904
-
905
- # Keep sklearn-style behavior.
906
- # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
907
-
908
- for key, value in params.items():
909
- if hasattr(self, key):
910
- setattr(self, key, value)
911
- else:
912
- raise ValueError(f"Parameter {key} not found in model.")
913
- return self
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from contextlib import nullcontext
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import optuna
9
+ import pandas as pd
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.nn.utils import clip_grad_norm_
17
+
18
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
19
+ from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
20
+ from ins_pricing.utils import EPS
21
+ from ins_pricing.utils.losses import (
22
+ infer_loss_name_from_model_name,
23
+ normalize_loss_name,
24
+ resolve_tweedie_power,
25
+ )
26
+ from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
27
+
28
+
29
+ # --- Helper functions for reconstruction loss computation ---
30
+
31
+
32
+ def _compute_numeric_reconstruction_loss(
33
+ num_pred: Optional[torch.Tensor],
34
+ num_true: Optional[torch.Tensor],
35
+ num_mask: Optional[torch.Tensor],
36
+ loss_weight: float,
37
+ device: torch.device,
38
+ ) -> torch.Tensor:
39
+ """Compute MSE loss for numeric feature reconstruction.
40
+
41
+ Args:
42
+ num_pred: Predicted numeric values (N, num_features)
43
+ num_true: Ground truth numeric values (N, num_features)
44
+ num_mask: Boolean mask indicating which values were masked (N, num_features)
45
+ loss_weight: Weight to apply to the loss
46
+ device: Target device for computation
47
+
48
+ Returns:
49
+ Weighted MSE loss for masked numeric features
50
+ """
51
+ if num_pred is None or num_true is None or num_mask is None:
52
+ return torch.zeros((), device=device, dtype=torch.float32)
53
+
54
+ num_mask = num_mask.to(dtype=torch.bool)
55
+ if not num_mask.any():
56
+ return torch.zeros((), device=device, dtype=torch.float32)
57
+
58
+ diff = num_pred - num_true
59
+ mse = diff * diff
60
+ return float(loss_weight) * mse[num_mask].mean()
61
+
62
+
63
+ def _compute_categorical_reconstruction_loss(
64
+ cat_logits: Optional[List[torch.Tensor]],
65
+ cat_true: Optional[torch.Tensor],
66
+ cat_mask: Optional[torch.Tensor],
67
+ loss_weight: float,
68
+ device: torch.device,
69
+ ) -> torch.Tensor:
70
+ """Compute cross-entropy loss for categorical feature reconstruction.
71
+
72
+ Args:
73
+ cat_logits: List of logits for each categorical feature
74
+ cat_true: Ground truth categorical indices (N, num_cat_features)
75
+ cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
76
+ loss_weight: Weight to apply to the loss
77
+ device: Target device for computation
78
+
79
+ Returns:
80
+ Weighted cross-entropy loss for masked categorical features
81
+ """
82
+ if not cat_logits or cat_true is None or cat_mask is None:
83
+ return torch.zeros((), device=device, dtype=torch.float32)
84
+
85
+ cat_mask = cat_mask.to(dtype=torch.bool)
86
+ cat_losses: List[torch.Tensor] = []
87
+
88
+ for j, logits in enumerate(cat_logits):
89
+ mask_j = cat_mask[:, j]
90
+ if not mask_j.any():
91
+ continue
92
+ targets = cat_true[:, j]
93
+ cat_losses.append(
94
+ F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
95
+ )
96
+
97
+ if not cat_losses:
98
+ return torch.zeros((), device=device, dtype=torch.float32)
99
+
100
+ return float(loss_weight) * torch.stack(cat_losses).mean()
101
+
102
+
103
+ def _compute_reconstruction_loss(
104
+ num_pred: Optional[torch.Tensor],
105
+ cat_logits: Optional[List[torch.Tensor]],
106
+ num_true: Optional[torch.Tensor],
107
+ num_mask: Optional[torch.Tensor],
108
+ cat_true: Optional[torch.Tensor],
109
+ cat_mask: Optional[torch.Tensor],
110
+ num_loss_weight: float,
111
+ cat_loss_weight: float,
112
+ device: torch.device,
113
+ ) -> torch.Tensor:
114
+ """Compute combined reconstruction loss for masked tabular data.
115
+
116
+ This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
117
+
118
+ Args:
119
+ num_pred: Predicted numeric values
120
+ cat_logits: List of logits for categorical features
121
+ num_true: Ground truth numeric values
122
+ num_mask: Mask for numeric features
123
+ cat_true: Ground truth categorical indices
124
+ cat_mask: Mask for categorical features
125
+ num_loss_weight: Weight for numeric loss
126
+ cat_loss_weight: Weight for categorical loss
127
+ device: Target device for computation
128
+
129
+ Returns:
130
+ Combined weighted reconstruction loss
131
+ """
132
+ num_loss = _compute_numeric_reconstruction_loss(
133
+ num_pred, num_true, num_mask, num_loss_weight, device
134
+ )
135
+ cat_loss = _compute_categorical_reconstruction_loss(
136
+ cat_logits, cat_true, cat_mask, cat_loss_weight, device
137
+ )
138
+ return num_loss + cat_loss
139
+
140
+
141
+ # Scikit-Learn style wrapper for FTTransformer.
142
+
143
+
144
+ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
145
+
146
+ # sklearn-style wrapper:
147
+ # - num_cols: numeric feature column names
148
+ # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
149
+
150
+ @staticmethod
151
+ def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
152
+ num_cols_count = len(num_cols or [])
153
+ if num_cols_count == 0:
154
+ return 0
155
+ if requested is not None:
156
+ count = int(requested)
157
+ if count <= 0:
158
+ raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
159
+ return count
160
+ return max(1, num_cols_count)
161
+
162
+ def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
163
+ n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
164
+ task_type: str = 'regression',
165
+ tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
166
+ weight_decay: float = 0.0,
167
+ use_data_parallel: bool = True,
168
+ use_ddp: bool = False,
169
+ num_numeric_tokens: Optional[int] = None,
170
+ loss_name: Optional[str] = None
171
+ ):
172
+ super().__init__()
173
+
174
+ self.use_ddp = use_ddp
175
+ self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
176
+ False, 0, 0, 1)
177
+ if self.use_ddp:
178
+ self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
179
+
180
+ self.model_nme = model_nme
181
+ self.num_cols = list(num_cols)
182
+ self.cat_cols = list(cat_cols)
183
+ self.num_numeric_tokens = self.resolve_numeric_token_count(
184
+ self.num_cols,
185
+ self.cat_cols,
186
+ num_numeric_tokens,
187
+ )
188
+ self.d_model = d_model
189
+ self.n_heads = n_heads
190
+ self.n_layers = n_layers
191
+ self.dropout = dropout
192
+ self.batch_num = batch_num
193
+ self.epochs = epochs
194
+ self.learning_rate = learning_rate
195
+ self.weight_decay = weight_decay
196
+ self.task_type = task_type
197
+ self.patience = patience
198
+ resolved_loss = normalize_loss_name(loss_name, self.task_type)
199
+ if self.task_type == 'classification':
200
+ self.loss_name = "logloss"
201
+ self.tw_power = None # No Tweedie power for classification.
202
+ else:
203
+ if resolved_loss == "auto":
204
+ resolved_loss = infer_loss_name_from_model_name(self.model_nme)
205
+ self.loss_name = resolved_loss
206
+ if self.loss_name == "tweedie":
207
+ self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
208
+ else:
209
+ self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
210
+
211
+ if self.is_ddp_enabled:
212
+ self.device = torch.device(f"cuda:{self.local_rank}")
213
+ elif torch.cuda.is_available():
214
+ self.device = torch.device("cuda")
215
+ elif torch.backends.mps.is_available():
216
+ self.device = torch.device("mps")
217
+ else:
218
+ self.device = torch.device("cpu")
219
+ self.cat_cardinalities = None
220
+ self.cat_categories = {}
221
+ self.cat_maps: Dict[str, Dict[Any, int]] = {}
222
+ self.cat_str_maps: Dict[str, Dict[str, int]] = {}
223
+ self._num_mean = None
224
+ self._num_std = None
225
+ self.ft = None
226
+ self.use_data_parallel = bool(use_data_parallel)
227
+ self.num_geo = 0
228
+ self._geo_params: Dict[str, Any] = {}
229
+ self.loss_curve_path: Optional[str] = None
230
+ self.training_history: Dict[str, List[float]] = {
231
+ "train": [], "val": []}
232
+
233
+ def _build_model(self, X_train):
234
+ num_numeric = len(self.num_cols)
235
+ cat_cardinalities = []
236
+
237
+ if num_numeric > 0:
238
+ num_arr = X_train[self.num_cols].to_numpy(
239
+ dtype=np.float32, copy=False)
240
+ num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
241
+ mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
242
+ std = num_arr.std(axis=0).astype(np.float32, copy=False)
243
+ std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
244
+ self._num_mean = mean
245
+ self._num_std = std
246
+ else:
247
+ self._num_mean = None
248
+ self._num_std = None
249
+
250
+ self.cat_maps = {}
251
+ self.cat_str_maps = {}
252
+ for col in self.cat_cols:
253
+ cats = X_train[col].astype('category')
254
+ categories = cats.cat.categories
255
+ self.cat_categories[col] = categories # Store full category list from training.
256
+ self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
257
+ if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
258
+ self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
259
+
260
+ card = len(categories) + 1 # Reserve one extra class for unknown/missing.
261
+ cat_cardinalities.append(card)
262
+
263
+ self.cat_cardinalities = cat_cardinalities
264
+
265
+ core = FTTransformerCore(
266
+ num_numeric=num_numeric,
267
+ cat_cardinalities=cat_cardinalities,
268
+ d_model=self.d_model,
269
+ n_heads=self.n_heads,
270
+ n_layers=self.n_layers,
271
+ dropout=self.dropout,
272
+ task_type=self.task_type,
273
+ num_geo=self.num_geo,
274
+ num_numeric_tokens=self.num_numeric_tokens
275
+ )
276
+ use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
277
+ if self.is_ddp_enabled:
278
+ core = core.to(self.device)
279
+ core = DDP(core, device_ids=[
280
+ self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
281
+ self.use_data_parallel = False
282
+ elif use_dp:
283
+ if self.use_ddp and not self.is_ddp_enabled:
284
+ print(
285
+ ">>> DDP requested but not initialized; falling back to DataParallel.")
286
+ core = nn.DataParallel(core, device_ids=list(
287
+ range(torch.cuda.device_count())))
288
+ self.device = torch.device("cuda")
289
+ self.use_data_parallel = True
290
+ else:
291
+ self.use_data_parallel = False
292
+ self.ft = core.to(self.device)
293
+
294
+ def _encode_cats(self, X):
295
+ # Input DataFrame must include all categorical feature columns.
296
+ # Return int64 array with shape (N, num_categorical_features).
297
+
298
+ if not self.cat_cols:
299
+ return np.zeros((len(X), 0), dtype='int64')
300
+
301
+ n_rows = len(X)
302
+ n_cols = len(self.cat_cols)
303
+ X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
304
+ for idx, col in enumerate(self.cat_cols):
305
+ categories = self.cat_categories[col]
306
+ mapping = self.cat_maps.get(col)
307
+ if mapping is None:
308
+ mapping = {cat: i for i, cat in enumerate(categories)}
309
+ self.cat_maps[col] = mapping
310
+ unknown_idx = len(categories)
311
+ series = X[col]
312
+ codes = series.map(mapping)
313
+ unmapped = series.notna() & codes.isna()
314
+ if unmapped.any():
315
+ try:
316
+ series_cast = series.astype(categories.dtype)
317
+ except Exception:
318
+ series_cast = None
319
+ if series_cast is not None:
320
+ codes = series_cast.map(mapping)
321
+ unmapped = series_cast.notna() & codes.isna()
322
+ if unmapped.any():
323
+ str_map = self.cat_str_maps.get(col)
324
+ if str_map is None:
325
+ str_map = {str(cat): i for i, cat in enumerate(categories)}
326
+ self.cat_str_maps[col] = str_map
327
+ codes = series.astype(str).map(str_map)
328
+ if pd.api.types.is_categorical_dtype(codes):
329
+ codes = codes.astype("float")
330
+ codes = codes.fillna(unknown_idx).astype(
331
+ "int64", copy=False).to_numpy()
332
+ X_cat_np[:, idx] = codes
333
+ return X_cat_np
334
+
335
+ def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
336
+ return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
337
+
338
+ def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
339
+ return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
340
+
341
+ @staticmethod
342
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
343
+ if arr is None:
344
+ return
345
+ if isinstance(arr, pd.DataFrame):
346
+ if arr.shape[1] != 1:
347
+ raise ValueError(f"{name} must be 1d (single column).")
348
+ length = len(arr)
349
+ else:
350
+ arr_np = np.asarray(arr)
351
+ if arr_np.ndim == 0:
352
+ raise ValueError(f"{name} must be 1d.")
353
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
354
+ raise ValueError(f"{name} must be 1d or Nx1.")
355
+ length = arr_np.shape[0]
356
+ if length != n_rows:
357
+ raise ValueError(
358
+ f"{name} length {length} does not match X length {n_rows}."
359
+ )
360
+
361
+ def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
362
+ if X is None:
363
+ if allow_none:
364
+ return None, None, None, None, None, False
365
+ raise ValueError("Input features X must not be None.")
366
+ if not isinstance(X, pd.DataFrame):
367
+ raise ValueError("X must be a pandas DataFrame.")
368
+ missing_cols = [
369
+ col for col in (self.num_cols + self.cat_cols) if col not in X.columns
370
+ ]
371
+ if missing_cols:
372
+ raise ValueError(f"X is missing required columns: {missing_cols}")
373
+ n_rows = len(X)
374
+ if y is not None:
375
+ self._validate_vector(y, "y", n_rows)
376
+ if w is not None:
377
+ self._validate_vector(w, "w", n_rows)
378
+
379
+ num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
380
+ if not num_np.flags["OWNDATA"]:
381
+ num_np = num_np.copy()
382
+ num_np = np.nan_to_num(num_np, nan=0.0,
383
+ posinf=0.0, neginf=0.0, copy=False)
384
+ if self._num_mean is not None and self._num_std is not None and num_np.size:
385
+ num_np = (num_np - self._num_mean) / self._num_std
386
+ X_num = torch.as_tensor(num_np)
387
+ if self.cat_cols:
388
+ X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
389
+ else:
390
+ X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
391
+
392
+ if geo_tokens is not None:
393
+ geo_np = np.asarray(geo_tokens, dtype=np.float32)
394
+ if geo_np.shape[0] != n_rows:
395
+ raise ValueError(
396
+ "geo_tokens length does not match X rows.")
397
+ if geo_np.ndim == 1:
398
+ geo_np = geo_np.reshape(-1, 1)
399
+ elif self.num_geo > 0:
400
+ raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
401
+ else:
402
+ geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
403
+ X_geo = torch.as_tensor(geo_np)
404
+
405
+ y_tensor = torch.as_tensor(
406
+ y.to_numpy(dtype=np.float32, copy=False) if hasattr(
407
+ y, "to_numpy") else np.asarray(y, dtype=np.float32)
408
+ ).view(-1, 1) if y is not None else None
409
+ if y_tensor is None:
410
+ w_tensor = None
411
+ elif w is not None:
412
+ w_tensor = torch.as_tensor(
413
+ w.to_numpy(dtype=np.float32, copy=False) if hasattr(
414
+ w, "to_numpy") else np.asarray(w, dtype=np.float32)
415
+ ).view(-1, 1)
416
+ else:
417
+ w_tensor = torch.ones_like(y_tensor)
418
+ return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
419
+
420
+ def fit(self, X_train, y_train, w_train=None,
421
+ X_val=None, y_val=None, w_val=None, trial=None,
422
+ geo_train=None, geo_val=None):
423
+
424
+ # Build the underlying model on first fit.
425
+ self.num_geo = geo_train.shape[1] if geo_train is not None else 0
426
+ if self.ft is None:
427
+ self._build_model(X_train)
428
+
429
+ X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
430
+ X_train, y_train, w_train, geo_train=geo_train)
431
+ X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
432
+ X_val, y_val, w_val, geo_val=geo_val)
433
+
434
+ # --- Build DataLoader ---
435
+ dataset = TabularDataset(
436
+ X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
437
+ )
438
+
439
+ dataloader, accum_steps = self._build_dataloader(
440
+ dataset,
441
+ N=X_num_train.shape[0],
442
+ base_bs_gpu=(2048, 1024, 512),
443
+ base_bs_cpu=(256, 128),
444
+ min_bs=64,
445
+ target_effective_cuda=2048,
446
+ target_effective_cpu=1024
447
+ )
448
+
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+        X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
+        val_dataloader = None
+        if has_val:
+            val_dataset = TabularDataset(
+                X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
+            )
+            val_dataloader = self._build_val_dataloader(
+                val_dataset, dataloader, accum_steps)
+
+        # Check for both DataParallel and DDP wrappers
+        is_data_parallel = isinstance(self.ft, (nn.DataParallel, DDP))
+
+        def forward_fn(batch):
+            X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+
+            # For DataParallel, inputs are automatically scattered; for DDP, move to local device
+            if not isinstance(self.ft, nn.DataParallel):
+                X_num_b = X_num_b.to(self.device, non_blocking=True)
+                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                y_b = y_b.to(self.device, non_blocking=True)
+                w_b = w_b.to(self.device, non_blocking=True)
+
+            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+            return y_pred, y_b, w_b
+
+        def val_forward_fn():
+            total_loss = 0.0
+            total_weight = 0.0
+            for batch in val_dataloader:
+                X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+                if not isinstance(self.ft, nn.DataParallel):
+                    X_num_b = X_num_b.to(self.device, non_blocking=True)
+                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                    X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                    y_b = y_b.to(self.device, non_blocking=True)
+                    w_b = w_b.to(self.device, non_blocking=True)
+
+                y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+
+                # Manually compute validation loss.
+                losses = self._compute_losses(
+                    y_pred, y_b, apply_softplus=False)
+
+                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                total_loss += batch_weighted_loss_sum.item()
+                total_weight += batch_weight_sum.item()
+
+            return total_loss / max(total_weight, EPS)
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
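The `clip_fn` hook above follows the standard PyTorch AMP recipe: gradients must be unscaled before `clip_grad_norm_`, and `scaler.step`/`scaler.update` still drive the optimizer afterwards. A minimal standalone version of that sequence (generic AMP boilerplate, not code from this package):

```python
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=torch.cuda.is_available())

x, y = torch.randn(8, 4), torch.randn(8, 1)
with autocast(enabled=torch.cuda.is_available()):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                          # unscale before clipping
clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```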
+        best_state, history = self._train_model(
+            self.ft,
+            dataloader,
+            accum_steps,
+            optimizer,
+            scaler,
+            forward_fn,
+            val_forward_fn if has_val else None,
+            apply_softplus=False,
+            clip_fn=clip_fn,
+            trial=trial,
+            loss_curve_path=getattr(self, "loss_curve_path", None)
+        )
+
+        if has_val and best_state is not None:
+            # Load state into unwrapped module to match how it was saved
+            base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+            base_module.load_state_dict(best_state)
+        self.training_history = history
+
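A hedged end-to-end usage sketch for the supervised path above. The constructor arguments and column names are hypothetical; only `fit`, `predict`, and the `w_train`/`geo_*` keywords come from the code shown here:

```python
import pandas as pd

# Hypothetical frequency-style dataset.
X = pd.DataFrame({"veh_age": [3, 7, 1], "region": ["N", "S", "N"]})
y = pd.Series([0.0, 1.0, 0.0])
w = pd.Series([1.0, 0.8, 1.2])

# FTTransformerSklearn is the wrapper defined in this module; the exact
# constructor signature is assumed here for illustration only.
model = FTTransformerSklearn(num_cols=["veh_age"], cat_cols=["region"])
model.fit(X, y, w_train=w)           # optionally pass X_val/y_val/w_val, geo_train/geo_val
preds = model.predict(X)             # flat array of predictions
```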
+    def fit_unsupervised(self,
+                         X_train,
+                         X_val=None,
+                         trial: Optional[optuna.trial.Trial] = None,
+                         geo_train=None,
+                         geo_val=None,
+                         mask_prob_num: float = 0.15,
+                         mask_prob_cat: float = 0.15,
+                         num_loss_weight: float = 1.0,
+                         cat_loss_weight: float = 1.0) -> float:
+        """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
+        self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+        if self.ft is None:
+            self._build_model(X_train)
+
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_train, None, None, geo_tokens=geo_train, allow_none=True)
+        has_val = X_val is not None
+        if has_val:
+            X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
+                X_val, None, None, geo_tokens=geo_val, allow_none=True)
+        else:
+            X_num_val = X_cat_val = X_geo_val = None
+
+        N = int(X_num.shape[0])
+        num_dim = int(X_num.shape[1])
+        cat_dim = int(X_cat.shape[1])
+        device_type = self._device_type()
+
+        gen = torch.Generator()
+        gen.manual_seed(13 + int(getattr(self, "rank", 0)))
+
+        base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
+        cardinals = getattr(base_model, "cat_cardinalities", None) or []
+        unknown_idx = torch.tensor(
+            [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
+
+        means = None
+        if num_dim > 0:
+            # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
+            means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
+
+        def _mask_inputs(X_num_in: torch.Tensor,
+                         X_cat_in: torch.Tensor,
+                         generator: torch.Generator):
+            n_rows = int(X_num_in.shape[0])
+            num_mask_local = None
+            cat_mask_local = None
+            X_num_masked_local = X_num_in
+            X_cat_masked_local = X_cat_in
+            if num_dim > 0:
+                num_mask_local = (torch.rand(
+                    (n_rows, num_dim), generator=generator) < float(mask_prob_num))
+                X_num_masked_local = X_num_in.clone()
+                if num_mask_local.any():
+                    X_num_masked_local[num_mask_local] = means.expand_as(
+                        X_num_masked_local)[num_mask_local]
+            if cat_dim > 0:
+                cat_mask_local = (torch.rand(
+                    (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
+                X_cat_masked_local = X_cat_in.clone()
+                if cat_mask_local.any():
+                    X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
+                        X_cat_masked_local)[cat_mask_local]
+            return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
+
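The masking helper above replaces a random subset of numeric cells with their column means and a random subset of categorical cells with the reserved "unknown" index (cardinality − 1); the model is then trained to reconstruct the original values at the masked positions. A tiny self-contained illustration of the same idea:

```python
import torch

torch.manual_seed(0)
X_num = torch.tensor([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
X_cat = torch.tensor([[0], [1], [2]])
means = X_num.mean(dim=0, keepdim=True)        # fill value for masked numerics
unknown_idx = torch.tensor([[3]])              # last embedding index reserved for "unknown"

num_mask = torch.rand(X_num.shape) < 0.5       # ~50% masking here, just for visibility
cat_mask = torch.rand(X_cat.shape) < 0.5
X_num_masked = torch.where(num_mask, means.expand_as(X_num), X_num)
X_cat_masked = torch.where(cat_mask, unknown_idx.expand_as(X_cat), X_cat)
# Reconstruction loss is evaluated only where num_mask / cat_mask are True.
```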
+        X_num_true = X_num if num_dim > 0 else None
+        X_cat_true = X_cat if cat_dim > 0 else None
+        X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
+            X_num, X_cat, gen)
+
+        dataset = MaskedTabularDataset(
+            X_num_masked, X_cat_masked, X_geo,
+            X_num_true, num_mask,
+            X_cat_true, cat_mask
+        )
+        dataloader, accum_steps = self._build_dataloader(
+            dataset,
+            N=N,
+            base_bs_gpu=(2048, 1024, 512),
+            base_bs_cpu=(256, 128),
+            min_bs=64,
+            target_effective_cuda=2048,
+            target_effective_cpu=1024
+        )
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(device_type == 'cuda'))
+
+        train_history: List[float] = []
+        val_history: List[float] = []
+        best_loss = float("inf")
+        best_state = None
+        patience_counter = 0
+        is_ddp_model = isinstance(self.ft, DDP)
+        use_collectives = dist.is_initialized() and is_ddp_model
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
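The epoch loop below keeps DDP ranks in lockstep with two collectives: an `all_reduce` with `MAX` so every rank aborts together when any rank sees a non-finite loss, and a `broadcast` from rank 0 so all ranks share the same validation loss and pruning decision. A condensed sketch of that agreement pattern (generic `torch.distributed` usage, not the package's code):

```python
import torch
import torch.distributed as dist

def all_ranks_agree_bad(local_bad: bool) -> bool:
    """Return True on every rank if any rank flags a problem."""
    if not dist.is_initialized():
        return local_bad
    flag = torch.tensor([1 if local_bad else 0], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)   # max over ranks: 1 if anyone is bad
    return bool(flag.item())

def share_from_rank0(value: float) -> float:
    """Rank 0 computes a metric; every other rank receives the same number."""
    if not dist.is_initialized():
        return value
    t = torch.tensor([value], dtype=torch.float32)
    dist.broadcast(t, src=0)
    return float(t.item())
```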
+        for epoch in range(1, int(self.epochs) + 1):
+            if self.dataloader_sampler is not None:
+                self.dataloader_sampler.set_epoch(epoch)
+
+            self.ft.train()
+            optimizer.zero_grad()
+            epoch_loss_sum = 0.0
+            epoch_count = 0.0
+
+            for step, batch in enumerate(dataloader):
+                is_update_step = ((step + 1) % accum_steps == 0) or \
+                    ((step + 1) == len(dataloader))
+                sync_cm = self.ft.no_sync if (
+                    is_ddp_model and not is_update_step) else nullcontext
+                with sync_cm():
+                    with autocast(enabled=(device_type == 'cuda')):
+                        X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
+                        X_num_b = X_num_b.to(self.device, non_blocking=True)
+                        X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                        X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                        num_true_b = None if num_true_b is None else num_true_b.to(
+                            self.device, non_blocking=True)
+                        num_mask_b = None if num_mask_b is None else num_mask_b.to(
+                            self.device, non_blocking=True)
+                        cat_true_b = None if cat_true_b is None else cat_true_b.to(
+                            self.device, non_blocking=True)
+                        cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
+                            self.device, non_blocking=True)
+
+                        num_pred, cat_logits = self.ft(
+                            X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
+                        batch_loss = _compute_reconstruction_loss(
+                            num_pred, cat_logits, num_true_b, num_mask_b,
+                            cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                            device=X_num_b.device)
+                    local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
+                    global_bad = local_bad
+                    if use_collectives:
+                        bad = torch.tensor(
+                            [local_bad],
+                            device=batch_loss.device,
+                            dtype=torch.int32,
+                        )
+                        dist.all_reduce(bad, op=dist.ReduceOp.MAX)
+                        global_bad = int(bad.item())
+
+                    if global_bad:
+                        msg = (
+                            f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
+                            f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
+                        )
+                        should_log = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                        if should_log:
+                            print(msg, flush=True)
+                            print(
+                                f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
+                                f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
+                                f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
+                                flush=True,
+                            )
+                            if X_geo_b is not None:
+                                print(
+                                    f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
+                                    f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
+                                    f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
+                                    flush=True,
+                                )
+                        if trial is not None:
+                            raise optuna.TrialPruned(msg)
+                        raise RuntimeError(msg)
+                    loss_for_backward = batch_loss / float(accum_steps)
+                    scaler.scale(loss_for_backward).backward()
+
+                if is_update_step:
+                    if clip_fn is not None:
+                        clip_fn()
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+
+                epoch_loss_sum += float(batch_loss.detach().item()) * \
+                    float(X_num_b.shape[0])
+                epoch_count += float(X_num_b.shape[0])
+
+            train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
+
+            if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
+                should_compute_val = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
+                    "cpu")
+                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                if should_compute_val:
+                    self.ft.eval()
+                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                        val_bs = min(
+                            int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
+                        total_val = 0.0
+                        total_n = 0.0
+                        for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
+                            end = min(
+                                int(X_num_val.shape[0]), start + max(1, val_bs))
+                            X_num_v_true_cpu = X_num_val[start:end]
+                            X_cat_v_true_cpu = X_cat_val[start:end]
+                            X_geo_v = X_geo_val[start:end].to(
+                                self.device, non_blocking=True)
+                            gen_val = torch.Generator()
+                            gen_val.manual_seed(10_000 + epoch + start)
+                            X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
+                                X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
+                            X_num_v_true = X_num_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v_true = X_cat_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_num_v = X_num_v_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v = X_cat_v_cpu.to(
+                                self.device, non_blocking=True)
+                            val_num_mask = None if val_num_mask is None else val_num_mask.to(
+                                self.device, non_blocking=True)
+                            val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
+                                self.device, non_blocking=True)
+                            num_pred_v, cat_logits_v = self.ft(
+                                X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
+                            loss_v = _compute_reconstruction_loss(
+                                num_pred_v, cat_logits_v,
+                                X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
+                                X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                                num_loss_weight, cat_loss_weight,
+                                device=X_num_v.device
+                            )
+                            if not torch.isfinite(loss_v):
+                                total_val = float("inf")
+                                total_n = 1.0
+                                break
+                            total_val += float(loss_v.detach().item()
+                                               ) * float(end - start)
+                            total_n += float(end - start)
+                        val_loss_tensor[0] = total_val / max(total_n, 1.0)
+
+                if use_collectives:
+                    dist.broadcast(val_loss_tensor, src=0)
+                val_loss_value = float(val_loss_tensor.item())
+                prune_now = False
+                prune_msg = None
+                if not np.isfinite(val_loss_value):
+                    prune_now = True
+                    prune_msg = (
+                        f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
+                        f"(epoch={epoch}, val_loss={val_loss_value})"
+                    )
+                val_history.append(val_loss_value)
+
+                if val_loss_value < best_loss:
+                    best_loss = val_loss_value
+                    # Efficiently clone state_dict - only clone tensor data, not DDP metadata
+                    base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+                    best_state = {
+                        k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
+                        for k, v in base_module.state_dict().items()
+                    }
+                    patience_counter = 0
+                else:
+                    patience_counter += 1
+                    if best_state is not None and patience_counter >= int(self.patience):
+                        break
+
+                if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
+                    trial.report(val_loss_value, epoch)
+                    if trial.should_prune():
+                        prune_now = True
+
+                if use_collectives:
+                    flag = torch.tensor(
+                        [1 if prune_now else 0],
+                        device=loss_tensor_device,
+                        dtype=torch.int32,
+                    )
+                    dist.broadcast(flag, src=0)
+                    prune_now = bool(flag.item())
+
+                if prune_now:
+                    if prune_msg:
+                        raise optuna.TrialPruned(prune_msg)
+                    raise optuna.TrialPruned()
+
+        self.training_history = {"train": train_history, "val": val_history}
+        self._plot_loss_curve(self.training_history, getattr(
+            self, "loss_curve_path", None))
+        if has_val and best_state is not None:
+            # Load state into unwrapped module to match how it was saved
+            base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+            base_module.load_state_dict(best_state)
+        return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
+
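A hedged sketch of the pretrain-then-finetune flow implied by the two methods above. The data variables are hypothetical; only the method names, masking keywords, and the returned reconstruction loss come from the code shown:

```python
# Masked-reconstruction pretraining on unlabelled rows, then a supervised fit.
best_recon_loss = model.fit_unsupervised(
    X_unlabelled,                # DataFrame with the same num/cat columns as the supervised data
    X_val=X_unlabelled_val,
    mask_prob_num=0.15,
    mask_prob_cat=0.15,
)
model.fit(X_train, y_train, w_train=w_train, X_val=X_val, y_val=y_val)
```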
+    def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
+        # X_test must include all numeric/categorical columns; geo_tokens is optional.
+
+        self.ft.eval()
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
+
+        num_rows = X_num.shape[0]
+        if num_rows == 0:
+            return np.empty(0, dtype=np.float32)
+
+        device = self.device if isinstance(
+            self.device, torch.device) else torch.device(self.device)
+
+        def resolve_batch_size(n_rows: int) -> int:
+            if batch_size is not None:
+                return max(1, min(int(batch_size), n_rows))
+            # Estimate a safe batch size from token count x d_model to avoid attention OOM.
+            token_cnt = self.num_numeric_tokens + len(self.cat_cols)
+            if self.num_geo > 0:
+                token_cnt += 1
+            approx_units = max(1, token_cnt * max(1, self.d_model))
+            if device.type == 'cuda':
+                if approx_units >= 8192:
+                    base = 512
+                elif approx_units >= 4096:
+                    base = 1024
+                else:
+                    base = 2048
+            else:
+                base = 512
+            return max(1, min(base, n_rows))
+
+        eff_batch = resolve_batch_size(num_rows)
+        preds: List[torch.Tensor] = []
+
+        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+        with inference_cm():
+            for start in range(0, num_rows, eff_batch):
+                end = min(num_rows, start + eff_batch)
+                X_num_b = X_num[start:end].to(device, non_blocking=True)
+                X_cat_b = X_cat[start:end].to(device, non_blocking=True)
+                X_geo_b = X_geo[start:end].to(device, non_blocking=True)
+                pred_chunk = self.ft(
+                    X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
+                preds.append(pred_chunk.cpu())
+
+        y_pred = torch.cat(preds, dim=0).numpy()
+
+        if return_embedding:
+            return y_pred
+
+        if self.task_type == 'classification':
+            # Convert logits to probabilities.
+            y_pred = 1 / (1 + np.exp(-y_pred))
+        else:
+            # The regression head already applies softplus (log(1 + exp(x))), so outputs are positive;
+            # the clip below only guards against numerical underflow.
+            y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.ravel()
+
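A short usage sketch of the prediction API above, assuming `model` is a fitted instance of the wrapper (variable names hypothetical). Classification returns probabilities, regression returns strictly positive values, and `return_embedding=True` returns the model's learned representation instead of predictions:

```python
scores = model.predict(X_test)                               # probabilities or positive predictions
features = model.predict(X_test, return_embedding=True)      # embedding matrix, one row per input
capped = model.predict(X_test, batch_size=256)               # cap the inference batch size explicitly
```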
+    def set_params(self, params: dict):
+
+        # Keep sklearn-style behavior.
+        # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
+
+        for key, value in params.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise ValueError(f"Parameter {key} not found in model.")
+        return self
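As the note in `set_params` says, structural hyperparameters only take effect on the next (re)fit. A hedged usage sketch, assuming `model` is an existing wrapper instance and that the named attributes exist on it (unknown keys raise `ValueError`):

```python
model.set_params({"learning_rate": 5e-4})   # picked up by the next fit call
model.set_params({"d_model": 128})          # structural: takes effect only after a refit
model.fit(X_train, y_train)
```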