ins_pricing-0.4.5-py3-none-any.whl → ins_pricing-0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
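The dominant change in 0.5.1 is structural: `ins_pricing/modelling/core/bayesopt/` is flattened to `ins_pricing/modelling/bayesopt/` (entries 25–51), several helpers consolidate under `ins_pricing/utils/` (for example, entry 74 moves `losses.py` there), and `production/predict.py` is renamed to `production/inference.py` (entries 64, 68, 90). Callers that import from the old `core` namespace must follow the move. Below is a minimal, version-tolerant import sketch, assuming only the module paths changed and the class and module names stayed the same; the package's public re-exports are not shown in this diff, so treat the exact paths as an assumption grounded in the file list above.

```python
# Hypothetical compatibility shim for code that must run against both
# ins-pricing 0.4.x and 0.5.x. Module paths follow the renames in the
# file list above; which names are re-exported is an assumption.
try:
    # 0.5.x layout: bayesopt promoted out of modelling.core,
    # predict.py renamed to inference.py.
    from ins_pricing.modelling.bayesopt.models.model_ft_trainer import FTTransformerSklearn
    from ins_pricing.production import inference as production_scoring
except ImportError:
    # 0.4.x layout: bayesopt nested under modelling.core,
    # production scoring still in predict.py.
    from ins_pricing.modelling.core.bayesopt.models.model_ft_trainer import FTTransformerSklearn
    from ins_pricing.production import predict as production_scoring
```

The new absolute imports at the top of the `+` block in the `model_ft_trainer.py` hunk below (for example, `from ins_pricing.utils.losses import ...`) follow the same layout.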
ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py
@@ -1,913 +1,921 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- from contextlib import nullcontext
5
- from typing import Any, Dict, List, Optional
6
-
7
- import numpy as np
8
- import optuna
9
- import pandas as pd
10
- import torch
11
- import torch.distributed as dist
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torch.cuda.amp import autocast, GradScaler
15
- from torch.nn.parallel import DistributedDataParallel as DDP
16
- from torch.nn.utils import clip_grad_norm_
17
-
18
- from ..utils import DistributedUtils, EPS, TorchTrainerMixin
19
- from ..utils.losses import (
20
- infer_loss_name_from_model_name,
21
- normalize_loss_name,
22
- resolve_tweedie_power,
23
- )
24
- from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
25
-
26
-
27
- # --- Helper functions for reconstruction loss computation ---
28
-
29
-
30
- def _compute_numeric_reconstruction_loss(
31
- num_pred: Optional[torch.Tensor],
32
- num_true: Optional[torch.Tensor],
33
- num_mask: Optional[torch.Tensor],
34
- loss_weight: float,
35
- device: torch.device,
36
- ) -> torch.Tensor:
37
- """Compute MSE loss for numeric feature reconstruction.
38
-
39
- Args:
40
- num_pred: Predicted numeric values (N, num_features)
41
- num_true: Ground truth numeric values (N, num_features)
42
- num_mask: Boolean mask indicating which values were masked (N, num_features)
43
- loss_weight: Weight to apply to the loss
44
- device: Target device for computation
45
-
46
- Returns:
47
- Weighted MSE loss for masked numeric features
48
- """
49
- if num_pred is None or num_true is None or num_mask is None:
50
- return torch.zeros((), device=device, dtype=torch.float32)
51
-
52
- num_mask = num_mask.to(dtype=torch.bool)
53
- if not num_mask.any():
54
- return torch.zeros((), device=device, dtype=torch.float32)
55
-
56
- diff = num_pred - num_true
57
- mse = diff * diff
58
- return float(loss_weight) * mse[num_mask].mean()
59
-
60
-
61
- def _compute_categorical_reconstruction_loss(
62
- cat_logits: Optional[List[torch.Tensor]],
63
- cat_true: Optional[torch.Tensor],
64
- cat_mask: Optional[torch.Tensor],
65
- loss_weight: float,
66
- device: torch.device,
67
- ) -> torch.Tensor:
68
- """Compute cross-entropy loss for categorical feature reconstruction.
69
-
70
- Args:
71
- cat_logits: List of logits for each categorical feature
72
- cat_true: Ground truth categorical indices (N, num_cat_features)
73
- cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
74
- loss_weight: Weight to apply to the loss
75
- device: Target device for computation
76
-
77
- Returns:
78
- Weighted cross-entropy loss for masked categorical features
79
- """
80
- if not cat_logits or cat_true is None or cat_mask is None:
81
- return torch.zeros((), device=device, dtype=torch.float32)
82
-
83
- cat_mask = cat_mask.to(dtype=torch.bool)
84
- cat_losses: List[torch.Tensor] = []
85
-
86
- for j, logits in enumerate(cat_logits):
87
- mask_j = cat_mask[:, j]
88
- if not mask_j.any():
89
- continue
90
- targets = cat_true[:, j]
91
- cat_losses.append(
92
- F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
93
- )
94
-
95
- if not cat_losses:
96
- return torch.zeros((), device=device, dtype=torch.float32)
97
-
98
- return float(loss_weight) * torch.stack(cat_losses).mean()
99
-
100
-
101
- def _compute_reconstruction_loss(
102
- num_pred: Optional[torch.Tensor],
103
- cat_logits: Optional[List[torch.Tensor]],
104
- num_true: Optional[torch.Tensor],
105
- num_mask: Optional[torch.Tensor],
106
- cat_true: Optional[torch.Tensor],
107
- cat_mask: Optional[torch.Tensor],
108
- num_loss_weight: float,
109
- cat_loss_weight: float,
110
- device: torch.device,
111
- ) -> torch.Tensor:
112
- """Compute combined reconstruction loss for masked tabular data.
113
-
114
- This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
115
-
116
- Args:
117
- num_pred: Predicted numeric values
118
- cat_logits: List of logits for categorical features
119
- num_true: Ground truth numeric values
120
- num_mask: Mask for numeric features
121
- cat_true: Ground truth categorical indices
122
- cat_mask: Mask for categorical features
123
- num_loss_weight: Weight for numeric loss
124
- cat_loss_weight: Weight for categorical loss
125
- device: Target device for computation
126
-
127
- Returns:
128
- Combined weighted reconstruction loss
129
- """
130
- num_loss = _compute_numeric_reconstruction_loss(
131
- num_pred, num_true, num_mask, num_loss_weight, device
132
- )
133
- cat_loss = _compute_categorical_reconstruction_loss(
134
- cat_logits, cat_true, cat_mask, cat_loss_weight, device
135
- )
136
- return num_loss + cat_loss
137
-
138
-
139
- # Scikit-Learn style wrapper for FTTransformer.
140
-
141
-
142
- class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
143
-
144
- # sklearn-style wrapper:
145
- # - num_cols: numeric feature column names
146
- # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
147
-
148
- @staticmethod
149
- def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
150
- num_cols_count = len(num_cols or [])
151
- if num_cols_count == 0:
152
- return 0
153
- if requested is not None:
154
- count = int(requested)
155
- if count <= 0:
156
- raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
157
- return count
158
- return max(1, num_cols_count)
159
-
160
- def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
161
- n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
162
- task_type: str = 'regression',
163
- tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
164
- weight_decay: float = 0.0,
165
- use_data_parallel: bool = True,
166
- use_ddp: bool = False,
167
- num_numeric_tokens: Optional[int] = None,
168
- loss_name: Optional[str] = None
169
- ):
170
- super().__init__()
171
-
172
- self.use_ddp = use_ddp
173
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
174
- False, 0, 0, 1)
175
- if self.use_ddp:
176
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
177
-
178
- self.model_nme = model_nme
179
- self.num_cols = list(num_cols)
180
- self.cat_cols = list(cat_cols)
181
- self.num_numeric_tokens = self.resolve_numeric_token_count(
182
- self.num_cols,
183
- self.cat_cols,
184
- num_numeric_tokens,
185
- )
186
- self.d_model = d_model
187
- self.n_heads = n_heads
188
- self.n_layers = n_layers
189
- self.dropout = dropout
190
- self.batch_num = batch_num
191
- self.epochs = epochs
192
- self.learning_rate = learning_rate
193
- self.weight_decay = weight_decay
194
- self.task_type = task_type
195
- self.patience = patience
196
- resolved_loss = normalize_loss_name(loss_name, self.task_type)
197
- if self.task_type == 'classification':
198
- self.loss_name = "logloss"
199
- self.tw_power = None # No Tweedie power for classification.
200
- else:
201
- if resolved_loss == "auto":
202
- resolved_loss = infer_loss_name_from_model_name(self.model_nme)
203
- self.loss_name = resolved_loss
204
- if self.loss_name == "tweedie":
205
- self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
206
- else:
207
- self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
208
-
209
- if self.is_ddp_enabled:
210
- self.device = torch.device(f"cuda:{self.local_rank}")
211
- elif torch.cuda.is_available():
212
- self.device = torch.device("cuda")
213
- elif torch.backends.mps.is_available():
214
- self.device = torch.device("mps")
215
- else:
216
- self.device = torch.device("cpu")
217
- self.cat_cardinalities = None
218
- self.cat_categories = {}
219
- self.cat_maps: Dict[str, Dict[Any, int]] = {}
220
- self.cat_str_maps: Dict[str, Dict[str, int]] = {}
221
- self._num_mean = None
222
- self._num_std = None
223
- self.ft = None
224
- self.use_data_parallel = bool(use_data_parallel)
225
- self.num_geo = 0
226
- self._geo_params: Dict[str, Any] = {}
227
- self.loss_curve_path: Optional[str] = None
228
- self.training_history: Dict[str, List[float]] = {
229
- "train": [], "val": []}
230
-
231
- def _build_model(self, X_train):
232
- num_numeric = len(self.num_cols)
233
- cat_cardinalities = []
234
-
235
- if num_numeric > 0:
236
- num_arr = X_train[self.num_cols].to_numpy(
237
- dtype=np.float32, copy=False)
238
- num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
239
- mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
240
- std = num_arr.std(axis=0).astype(np.float32, copy=False)
241
- std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
242
- self._num_mean = mean
243
- self._num_std = std
244
- else:
245
- self._num_mean = None
246
- self._num_std = None
247
-
248
- self.cat_maps = {}
249
- self.cat_str_maps = {}
250
- for col in self.cat_cols:
251
- cats = X_train[col].astype('category')
252
- categories = cats.cat.categories
253
- self.cat_categories[col] = categories # Store full category list from training.
254
- self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
255
- if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
256
- self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
257
-
258
- card = len(categories) + 1 # Reserve one extra class for unknown/missing.
259
- cat_cardinalities.append(card)
260
-
261
- self.cat_cardinalities = cat_cardinalities
262
-
263
- core = FTTransformerCore(
264
- num_numeric=num_numeric,
265
- cat_cardinalities=cat_cardinalities,
266
- d_model=self.d_model,
267
- n_heads=self.n_heads,
268
- n_layers=self.n_layers,
269
- dropout=self.dropout,
270
- task_type=self.task_type,
271
- num_geo=self.num_geo,
272
- num_numeric_tokens=self.num_numeric_tokens
273
- )
274
- use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
275
- if self.is_ddp_enabled:
276
- core = core.to(self.device)
277
- core = DDP(core, device_ids=[
278
- self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
279
- self.use_data_parallel = False
280
- elif use_dp:
281
- if self.use_ddp and not self.is_ddp_enabled:
282
- print(
283
- ">>> DDP requested but not initialized; falling back to DataParallel.")
284
- core = nn.DataParallel(core, device_ids=list(
285
- range(torch.cuda.device_count())))
286
- self.device = torch.device("cuda")
287
- self.use_data_parallel = True
288
- else:
289
- self.use_data_parallel = False
290
- self.ft = core.to(self.device)
291
-
292
- def _encode_cats(self, X):
293
- # Input DataFrame must include all categorical feature columns.
294
- # Return int64 array with shape (N, num_categorical_features).
295
-
296
- if not self.cat_cols:
297
- return np.zeros((len(X), 0), dtype='int64')
298
-
299
- n_rows = len(X)
300
- n_cols = len(self.cat_cols)
301
- X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
302
- for idx, col in enumerate(self.cat_cols):
303
- categories = self.cat_categories[col]
304
- mapping = self.cat_maps.get(col)
305
- if mapping is None:
306
- mapping = {cat: i for i, cat in enumerate(categories)}
307
- self.cat_maps[col] = mapping
308
- unknown_idx = len(categories)
309
- series = X[col]
310
- codes = series.map(mapping)
311
- unmapped = series.notna() & codes.isna()
312
- if unmapped.any():
313
- try:
314
- series_cast = series.astype(categories.dtype)
315
- except Exception:
316
- series_cast = None
317
- if series_cast is not None:
318
- codes = series_cast.map(mapping)
319
- unmapped = series_cast.notna() & codes.isna()
320
- if unmapped.any():
321
- str_map = self.cat_str_maps.get(col)
322
- if str_map is None:
323
- str_map = {str(cat): i for i, cat in enumerate(categories)}
324
- self.cat_str_maps[col] = str_map
325
- codes = series.astype(str).map(str_map)
326
- if pd.api.types.is_categorical_dtype(codes):
327
- codes = codes.astype("float")
328
- codes = codes.fillna(unknown_idx).astype(
329
- "int64", copy=False).to_numpy()
330
- X_cat_np[:, idx] = codes
331
- return X_cat_np
332
-
333
- def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
334
- return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
335
-
336
- def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
337
- return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
338
-
339
- @staticmethod
340
- def _validate_vector(arr, name: str, n_rows: int) -> None:
341
- if arr is None:
342
- return
343
- if isinstance(arr, pd.DataFrame):
344
- if arr.shape[1] != 1:
345
- raise ValueError(f"{name} must be 1d (single column).")
346
- length = len(arr)
347
- else:
348
- arr_np = np.asarray(arr)
349
- if arr_np.ndim == 0:
350
- raise ValueError(f"{name} must be 1d.")
351
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
352
- raise ValueError(f"{name} must be 1d or Nx1.")
353
- length = arr_np.shape[0]
354
- if length != n_rows:
355
- raise ValueError(
356
- f"{name} length {length} does not match X length {n_rows}."
357
- )
358
-
359
- def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
360
- if X is None:
361
- if allow_none:
362
- return None, None, None, None, None, False
363
- raise ValueError("Input features X must not be None.")
364
- if not isinstance(X, pd.DataFrame):
365
- raise ValueError("X must be a pandas DataFrame.")
366
- missing_cols = [
367
- col for col in (self.num_cols + self.cat_cols) if col not in X.columns
368
- ]
369
- if missing_cols:
370
- raise ValueError(f"X is missing required columns: {missing_cols}")
371
- n_rows = len(X)
372
- if y is not None:
373
- self._validate_vector(y, "y", n_rows)
374
- if w is not None:
375
- self._validate_vector(w, "w", n_rows)
376
-
377
- num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
378
- if not num_np.flags["OWNDATA"]:
379
- num_np = num_np.copy()
380
- num_np = np.nan_to_num(num_np, nan=0.0,
381
- posinf=0.0, neginf=0.0, copy=False)
382
- if self._num_mean is not None and self._num_std is not None and num_np.size:
383
- num_np = (num_np - self._num_mean) / self._num_std
384
- X_num = torch.as_tensor(num_np)
385
- if self.cat_cols:
386
- X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
387
- else:
388
- X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
389
-
390
- if geo_tokens is not None:
391
- geo_np = np.asarray(geo_tokens, dtype=np.float32)
392
- if geo_np.shape[0] != n_rows:
393
- raise ValueError(
394
- "geo_tokens length does not match X rows.")
395
- if geo_np.ndim == 1:
396
- geo_np = geo_np.reshape(-1, 1)
397
- elif self.num_geo > 0:
398
- raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
399
- else:
400
- geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
401
- X_geo = torch.as_tensor(geo_np)
402
-
403
- y_tensor = torch.as_tensor(
404
- y.to_numpy(dtype=np.float32, copy=False) if hasattr(
405
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
406
- ).view(-1, 1) if y is not None else None
407
- if y_tensor is None:
408
- w_tensor = None
409
- elif w is not None:
410
- w_tensor = torch.as_tensor(
411
- w.to_numpy(dtype=np.float32, copy=False) if hasattr(
412
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
413
- ).view(-1, 1)
414
- else:
415
- w_tensor = torch.ones_like(y_tensor)
416
- return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
417
-
418
- def fit(self, X_train, y_train, w_train=None,
419
- X_val=None, y_val=None, w_val=None, trial=None,
420
- geo_train=None, geo_val=None):
421
-
422
- # Build the underlying model on first fit.
423
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
424
- if self.ft is None:
425
- self._build_model(X_train)
426
-
427
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
428
- X_train, y_train, w_train, geo_train=geo_train)
429
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
430
- X_val, y_val, w_val, geo_val=geo_val)
431
-
432
- # --- Build DataLoader ---
433
- dataset = TabularDataset(
434
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
435
- )
436
-
437
- dataloader, accum_steps = self._build_dataloader(
438
- dataset,
439
- N=X_num_train.shape[0],
440
- base_bs_gpu=(2048, 1024, 512),
441
- base_bs_cpu=(256, 128),
442
- min_bs=64,
443
- target_effective_cuda=2048,
444
- target_effective_cpu=1024
445
- )
446
-
447
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
448
- self.dataloader_sampler = dataloader.sampler
449
- else:
450
- self.dataloader_sampler = None
451
-
452
- optimizer = torch.optim.Adam(
453
- self.ft.parameters(),
454
- lr=self.learning_rate,
455
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
456
- )
457
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
458
-
459
- X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
460
- val_dataloader = None
461
- if has_val:
462
- val_dataset = TabularDataset(
463
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
464
- )
465
- val_dataloader = self._build_val_dataloader(
466
- val_dataset, dataloader, accum_steps)
467
-
468
- # Check for both DataParallel and DDP wrappers
469
- is_data_parallel = isinstance(self.ft, (nn.DataParallel, DDP))
470
-
471
- def forward_fn(batch):
472
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
473
-
474
- # For DataParallel, inputs are automatically scattered; for DDP, move to local device
475
- if not isinstance(self.ft, nn.DataParallel):
476
- X_num_b = X_num_b.to(self.device, non_blocking=True)
477
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
478
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
479
- y_b = y_b.to(self.device, non_blocking=True)
480
- w_b = w_b.to(self.device, non_blocking=True)
481
-
482
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
483
- return y_pred, y_b, w_b
484
-
485
- def val_forward_fn():
486
- total_loss = 0.0
487
- total_weight = 0.0
488
- for batch in val_dataloader:
489
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
490
- if not isinstance(self.ft, nn.DataParallel):
491
- X_num_b = X_num_b.to(self.device, non_blocking=True)
492
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
493
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
494
- y_b = y_b.to(self.device, non_blocking=True)
495
- w_b = w_b.to(self.device, non_blocking=True)
496
-
497
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
498
-
499
- # Manually compute validation loss.
500
- losses = self._compute_losses(
501
- y_pred, y_b, apply_softplus=False)
502
-
503
- batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
504
- batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
505
-
506
- total_loss += batch_weighted_loss_sum.item()
507
- total_weight += batch_weight_sum.item()
508
-
509
- return total_loss / max(total_weight, EPS)
510
-
511
- clip_fn = None
512
- if self.device.type == 'cuda':
513
- def clip_fn(): return (scaler.unscale_(optimizer),
514
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
515
-
516
- best_state, history = self._train_model(
517
- self.ft,
518
- dataloader,
519
- accum_steps,
520
- optimizer,
521
- scaler,
522
- forward_fn,
523
- val_forward_fn if has_val else None,
524
- apply_softplus=False,
525
- clip_fn=clip_fn,
526
- trial=trial,
527
- loss_curve_path=getattr(self, "loss_curve_path", None)
528
- )
529
-
530
- if has_val and best_state is not None:
531
- # Load state into unwrapped module to match how it was saved
532
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
533
- base_module.load_state_dict(best_state)
534
- self.training_history = history
535
-
536
- def fit_unsupervised(self,
537
- X_train,
538
- X_val=None,
539
- trial: Optional[optuna.trial.Trial] = None,
540
- geo_train=None,
541
- geo_val=None,
542
- mask_prob_num: float = 0.15,
543
- mask_prob_cat: float = 0.15,
544
- num_loss_weight: float = 1.0,
545
- cat_loss_weight: float = 1.0) -> float:
546
- """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
547
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
548
- if self.ft is None:
549
- self._build_model(X_train)
550
-
551
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
552
- X_train, None, None, geo_tokens=geo_train, allow_none=True)
553
- has_val = X_val is not None
554
- if has_val:
555
- X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
556
- X_val, None, None, geo_tokens=geo_val, allow_none=True)
557
- else:
558
- X_num_val = X_cat_val = X_geo_val = None
559
-
560
- N = int(X_num.shape[0])
561
- num_dim = int(X_num.shape[1])
562
- cat_dim = int(X_cat.shape[1])
563
- device_type = self._device_type()
564
-
565
- gen = torch.Generator()
566
- gen.manual_seed(13 + int(getattr(self, "rank", 0)))
567
-
568
- base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
569
- cardinals = getattr(base_model, "cat_cardinalities", None) or []
570
- unknown_idx = torch.tensor(
571
- [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
572
-
573
- means = None
574
- if num_dim > 0:
575
- # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
576
- means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
577
-
578
- def _mask_inputs(X_num_in: torch.Tensor,
579
- X_cat_in: torch.Tensor,
580
- generator: torch.Generator):
581
- n_rows = int(X_num_in.shape[0])
582
- num_mask_local = None
583
- cat_mask_local = None
584
- X_num_masked_local = X_num_in
585
- X_cat_masked_local = X_cat_in
586
- if num_dim > 0:
587
- num_mask_local = (torch.rand(
588
- (n_rows, num_dim), generator=generator) < float(mask_prob_num))
589
- X_num_masked_local = X_num_in.clone()
590
- if num_mask_local.any():
591
- X_num_masked_local[num_mask_local] = means.expand_as(
592
- X_num_masked_local)[num_mask_local]
593
- if cat_dim > 0:
594
- cat_mask_local = (torch.rand(
595
- (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
596
- X_cat_masked_local = X_cat_in.clone()
597
- if cat_mask_local.any():
598
- X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
599
- X_cat_masked_local)[cat_mask_local]
600
- return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
601
-
602
- X_num_true = X_num if num_dim > 0 else None
603
- X_cat_true = X_cat if cat_dim > 0 else None
604
- X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
605
- X_num, X_cat, gen)
606
-
607
- dataset = MaskedTabularDataset(
608
- X_num_masked, X_cat_masked, X_geo,
609
- X_num_true, num_mask,
610
- X_cat_true, cat_mask
611
- )
612
- dataloader, accum_steps = self._build_dataloader(
613
- dataset,
614
- N=N,
615
- base_bs_gpu=(2048, 1024, 512),
616
- base_bs_cpu=(256, 128),
617
- min_bs=64,
618
- target_effective_cuda=2048,
619
- target_effective_cpu=1024
620
- )
621
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
622
- self.dataloader_sampler = dataloader.sampler
623
- else:
624
- self.dataloader_sampler = None
625
-
626
- optimizer = torch.optim.Adam(
627
- self.ft.parameters(),
628
- lr=self.learning_rate,
629
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
630
- )
631
- scaler = GradScaler(enabled=(device_type == 'cuda'))
632
-
633
- train_history: List[float] = []
634
- val_history: List[float] = []
635
- best_loss = float("inf")
636
- best_state = None
637
- patience_counter = 0
638
- is_ddp_model = isinstance(self.ft, DDP)
639
- use_collectives = dist.is_initialized() and is_ddp_model
640
-
641
- clip_fn = None
642
- if self.device.type == 'cuda':
643
- def clip_fn(): return (scaler.unscale_(optimizer),
644
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
645
-
646
- for epoch in range(1, int(self.epochs) + 1):
647
- if self.dataloader_sampler is not None:
648
- self.dataloader_sampler.set_epoch(epoch)
649
-
650
- self.ft.train()
651
- optimizer.zero_grad()
652
- epoch_loss_sum = 0.0
653
- epoch_count = 0.0
654
-
655
- for step, batch in enumerate(dataloader):
656
- is_update_step = ((step + 1) % accum_steps == 0) or \
657
- ((step + 1) == len(dataloader))
658
- sync_cm = self.ft.no_sync if (
659
- is_ddp_model and not is_update_step) else nullcontext
660
- with sync_cm():
661
- with autocast(enabled=(device_type == 'cuda')):
662
- X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
663
- X_num_b = X_num_b.to(self.device, non_blocking=True)
664
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
665
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
666
- num_true_b = None if num_true_b is None else num_true_b.to(
667
- self.device, non_blocking=True)
668
- num_mask_b = None if num_mask_b is None else num_mask_b.to(
669
- self.device, non_blocking=True)
670
- cat_true_b = None if cat_true_b is None else cat_true_b.to(
671
- self.device, non_blocking=True)
672
- cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
673
- self.device, non_blocking=True)
674
-
675
- num_pred, cat_logits = self.ft(
676
- X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
677
- batch_loss = _compute_reconstruction_loss(
678
- num_pred, cat_logits, num_true_b, num_mask_b,
679
- cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
680
- device=X_num_b.device)
681
- local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
682
- global_bad = local_bad
683
- if use_collectives:
684
- bad = torch.tensor(
685
- [local_bad],
686
- device=batch_loss.device,
687
- dtype=torch.int32,
688
- )
689
- dist.all_reduce(bad, op=dist.ReduceOp.MAX)
690
- global_bad = int(bad.item())
691
-
692
- if global_bad:
693
- msg = (
694
- f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
695
- f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
696
- )
697
- should_log = (not dist.is_initialized()
698
- or DistributedUtils.is_main_process())
699
- if should_log:
700
- print(msg, flush=True)
701
- print(
702
- f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
703
- f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
704
- f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
705
- flush=True,
706
- )
707
- if X_geo_b is not None:
708
- print(
709
- f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
710
- f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
711
- f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
712
- flush=True,
713
- )
714
- if trial is not None:
715
- raise optuna.TrialPruned(msg)
716
- raise RuntimeError(msg)
717
- loss_for_backward = batch_loss / float(accum_steps)
718
- scaler.scale(loss_for_backward).backward()
719
-
720
- if is_update_step:
721
- if clip_fn is not None:
722
- clip_fn()
723
- scaler.step(optimizer)
724
- scaler.update()
725
- optimizer.zero_grad()
726
-
727
- epoch_loss_sum += float(batch_loss.detach().item()) * \
728
- float(X_num_b.shape[0])
729
- epoch_count += float(X_num_b.shape[0])
730
-
731
- train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
732
-
733
- if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
734
- should_compute_val = (not dist.is_initialized()
735
- or DistributedUtils.is_main_process())
736
- loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
737
- "cpu")
738
- val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
739
-
740
- if should_compute_val:
741
- self.ft.eval()
742
- with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
743
- val_bs = min(
744
- int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
745
- total_val = 0.0
746
- total_n = 0.0
747
- for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
748
- end = min(
749
- int(X_num_val.shape[0]), start + max(1, val_bs))
750
- X_num_v_true_cpu = X_num_val[start:end]
751
- X_cat_v_true_cpu = X_cat_val[start:end]
752
- X_geo_v = X_geo_val[start:end].to(
753
- self.device, non_blocking=True)
754
- gen_val = torch.Generator()
755
- gen_val.manual_seed(10_000 + epoch + start)
756
- X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
757
- X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
758
- X_num_v_true = X_num_v_true_cpu.to(
759
- self.device, non_blocking=True)
760
- X_cat_v_true = X_cat_v_true_cpu.to(
761
- self.device, non_blocking=True)
762
- X_num_v = X_num_v_cpu.to(
763
- self.device, non_blocking=True)
764
- X_cat_v = X_cat_v_cpu.to(
765
- self.device, non_blocking=True)
766
- val_num_mask = None if val_num_mask is None else val_num_mask.to(
767
- self.device, non_blocking=True)
768
- val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
769
- self.device, non_blocking=True)
770
- num_pred_v, cat_logits_v = self.ft(
771
- X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
772
- loss_v = _compute_reconstruction_loss(
773
- num_pred_v, cat_logits_v,
774
- X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
775
- X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
776
- num_loss_weight, cat_loss_weight,
777
- device=X_num_v.device
778
- )
779
- if not torch.isfinite(loss_v):
780
- total_val = float("inf")
781
- total_n = 1.0
782
- break
783
- total_val += float(loss_v.detach().item()
784
- ) * float(end - start)
785
- total_n += float(end - start)
786
- val_loss_tensor[0] = total_val / max(total_n, 1.0)
787
-
788
- if use_collectives:
789
- dist.broadcast(val_loss_tensor, src=0)
790
- val_loss_value = float(val_loss_tensor.item())
791
- prune_now = False
792
- prune_msg = None
793
- if not np.isfinite(val_loss_value):
794
- prune_now = True
795
- prune_msg = (
796
- f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
797
- f"(epoch={epoch}, val_loss={val_loss_value})"
798
- )
799
- val_history.append(val_loss_value)
800
-
801
- if val_loss_value < best_loss:
802
- best_loss = val_loss_value
803
- # Efficiently clone state_dict - only clone tensor data, not DDP metadata
804
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
805
- best_state = {
806
- k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
807
- for k, v in base_module.state_dict().items()
808
- }
809
- patience_counter = 0
810
- else:
811
- patience_counter += 1
812
- if best_state is not None and patience_counter >= int(self.patience):
813
- break
814
-
815
- if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
816
- trial.report(val_loss_value, epoch)
817
- if trial.should_prune():
818
- prune_now = True
819
-
820
- if use_collectives:
821
- flag = torch.tensor(
822
- [1 if prune_now else 0],
823
- device=loss_tensor_device,
824
- dtype=torch.int32,
825
- )
826
- dist.broadcast(flag, src=0)
827
- prune_now = bool(flag.item())
828
-
829
- if prune_now:
830
- if prune_msg:
831
- raise optuna.TrialPruned(prune_msg)
832
- raise optuna.TrialPruned()
833
-
834
- self.training_history = {"train": train_history, "val": val_history}
835
- self._plot_loss_curve(self.training_history, getattr(
836
- self, "loss_curve_path", None))
837
- if has_val and best_state is not None:
838
- # Load state into unwrapped module to match how it was saved
839
- base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
840
- base_module.load_state_dict(best_state)
841
- return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
842
-
843
- def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
844
- # X_test must include all numeric/categorical columns; geo_tokens is optional.
845
-
846
- self.ft.eval()
847
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
848
- X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
849
-
850
- num_rows = X_num.shape[0]
851
- if num_rows == 0:
852
- return np.empty(0, dtype=np.float32)
853
-
854
- device = self.device if isinstance(
855
- self.device, torch.device) else torch.device(self.device)
856
-
857
- def resolve_batch_size(n_rows: int) -> int:
858
- if batch_size is not None:
859
- return max(1, min(int(batch_size), n_rows))
860
- # Estimate a safe batch size based on model size to avoid attention OOM.
861
- token_cnt = self.num_numeric_tokens + len(self.cat_cols)
862
- if self.num_geo > 0:
863
- token_cnt += 1
864
- approx_units = max(1, token_cnt * max(1, self.d_model))
865
- if device.type == 'cuda':
866
- if approx_units >= 8192:
867
- base = 512
868
- elif approx_units >= 4096:
869
- base = 1024
870
- else:
871
- base = 2048
872
- else:
873
- base = 512
874
- return max(1, min(base, n_rows))
875
-
876
- eff_batch = resolve_batch_size(num_rows)
877
- preds: List[torch.Tensor] = []
878
-
879
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
880
- with inference_cm():
881
- for start in range(0, num_rows, eff_batch):
882
- end = min(num_rows, start + eff_batch)
883
- X_num_b = X_num[start:end].to(device, non_blocking=True)
884
- X_cat_b = X_cat[start:end].to(device, non_blocking=True)
885
- X_geo_b = X_geo[start:end].to(device, non_blocking=True)
886
- pred_chunk = self.ft(
887
- X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
888
- preds.append(pred_chunk.cpu())
889
-
890
- y_pred = torch.cat(preds, dim=0).numpy()
891
-
892
- if return_embedding:
893
- return y_pred
894
-
895
- if self.task_type == 'classification':
896
- # Convert logits to probabilities.
897
- y_pred = 1 / (1 + np.exp(-y_pred))
898
- else:
899
- # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
900
- y_pred = np.clip(y_pred, 1e-6, None)
901
- return y_pred.ravel()
902
-
903
- def set_params(self, params: dict):
904
-
905
- # Keep sklearn-style behavior.
906
- # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
907
-
908
- for key, value in params.items():
909
- if hasattr(self, key):
910
- setattr(self, key, value)
911
- else:
912
- raise ValueError(f"Parameter {key} not found in model.")
913
- return self
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from contextlib import nullcontext
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import optuna
9
+ import pandas as pd
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.nn.utils import clip_grad_norm_
17
+
18
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
19
+ from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
20
+ from ins_pricing.utils import EPS, get_logger, log_print
21
+ from ins_pricing.utils.losses import (
22
+ infer_loss_name_from_model_name,
23
+ normalize_loss_name,
24
+ resolve_tweedie_power,
25
+ )
26
+ from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
27
+
28
+ _logger = get_logger("ins_pricing.modelling.bayesopt.models.model_ft_trainer")
29
+
30
+
31
+ def _log(*args, **kwargs) -> None:
32
+ log_print(_logger, *args, **kwargs)
33
+
34
+
35
+ # --- Helper functions for reconstruction loss computation ---
36
+
37
+
38
+ def _compute_numeric_reconstruction_loss(
39
+ num_pred: Optional[torch.Tensor],
40
+ num_true: Optional[torch.Tensor],
41
+ num_mask: Optional[torch.Tensor],
42
+ loss_weight: float,
43
+ device: torch.device,
44
+ ) -> torch.Tensor:
45
+ """Compute MSE loss for numeric feature reconstruction.
46
+
47
+ Args:
48
+ num_pred: Predicted numeric values (N, num_features)
49
+ num_true: Ground truth numeric values (N, num_features)
50
+ num_mask: Boolean mask indicating which values were masked (N, num_features)
51
+ loss_weight: Weight to apply to the loss
52
+ device: Target device for computation
53
+
54
+ Returns:
55
+ Weighted MSE loss for masked numeric features
56
+ """
57
+ if num_pred is None or num_true is None or num_mask is None:
58
+ return torch.zeros((), device=device, dtype=torch.float32)
59
+
60
+ num_mask = num_mask.to(dtype=torch.bool)
61
+ if not num_mask.any():
62
+ return torch.zeros((), device=device, dtype=torch.float32)
63
+
64
+ diff = num_pred - num_true
65
+ mse = diff * diff
66
+ return float(loss_weight) * mse[num_mask].mean()
67
+
68
+
69
+ def _compute_categorical_reconstruction_loss(
70
+ cat_logits: Optional[List[torch.Tensor]],
71
+ cat_true: Optional[torch.Tensor],
72
+ cat_mask: Optional[torch.Tensor],
73
+ loss_weight: float,
74
+ device: torch.device,
75
+ ) -> torch.Tensor:
76
+ """Compute cross-entropy loss for categorical feature reconstruction.
77
+
78
+ Args:
79
+ cat_logits: List of logits for each categorical feature
80
+ cat_true: Ground truth categorical indices (N, num_cat_features)
81
+ cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
82
+ loss_weight: Weight to apply to the loss
83
+ device: Target device for computation
84
+
85
+ Returns:
86
+ Weighted cross-entropy loss for masked categorical features
87
+ """
88
+ if not cat_logits or cat_true is None or cat_mask is None:
89
+ return torch.zeros((), device=device, dtype=torch.float32)
90
+
91
+ cat_mask = cat_mask.to(dtype=torch.bool)
92
+ cat_losses: List[torch.Tensor] = []
93
+
94
+ for j, logits in enumerate(cat_logits):
95
+ mask_j = cat_mask[:, j]
96
+ if not mask_j.any():
97
+ continue
98
+ targets = cat_true[:, j]
99
+ cat_losses.append(
100
+ F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
101
+ )
102
+
103
+ if not cat_losses:
104
+ return torch.zeros((), device=device, dtype=torch.float32)
105
+
106
+ return float(loss_weight) * torch.stack(cat_losses).mean()
107
+
108
+
109
+ def _compute_reconstruction_loss(
110
+ num_pred: Optional[torch.Tensor],
111
+ cat_logits: Optional[List[torch.Tensor]],
112
+ num_true: Optional[torch.Tensor],
113
+ num_mask: Optional[torch.Tensor],
114
+ cat_true: Optional[torch.Tensor],
115
+ cat_mask: Optional[torch.Tensor],
116
+ num_loss_weight: float,
117
+ cat_loss_weight: float,
118
+ device: torch.device,
119
+ ) -> torch.Tensor:
120
+ """Compute combined reconstruction loss for masked tabular data.
121
+
122
+ This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
123
+
124
+ Args:
125
+ num_pred: Predicted numeric values
126
+ cat_logits: List of logits for categorical features
127
+ num_true: Ground truth numeric values
128
+ num_mask: Mask for numeric features
129
+ cat_true: Ground truth categorical indices
130
+ cat_mask: Mask for categorical features
131
+ num_loss_weight: Weight for numeric loss
132
+ cat_loss_weight: Weight for categorical loss
133
+ device: Target device for computation
134
+
135
+ Returns:
136
+ Combined weighted reconstruction loss
137
+ """
138
+ num_loss = _compute_numeric_reconstruction_loss(
139
+ num_pred, num_true, num_mask, num_loss_weight, device
140
+ )
141
+ cat_loss = _compute_categorical_reconstruction_loss(
142
+ cat_logits, cat_true, cat_mask, cat_loss_weight, device
143
+ )
144
+ return num_loss + cat_loss
145
+
146
+
147
+ # Scikit-Learn style wrapper for FTTransformer.
148
+
149
+
150
+ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
151
+
152
+ # sklearn-style wrapper:
153
+ # - num_cols: numeric feature column names
154
+ # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
155
+
156
+ @staticmethod
157
+ def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
158
+ num_cols_count = len(num_cols or [])
159
+ if num_cols_count == 0:
160
+ return 0
161
+ if requested is not None:
162
+ count = int(requested)
163
+ if count <= 0:
164
+ raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
165
+ return count
166
+ return max(1, num_cols_count)
167
+
168
+ def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
169
+ n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
170
+ task_type: str = 'regression',
171
+ tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
172
+ weight_decay: float = 0.0,
173
+ use_data_parallel: bool = True,
174
+ use_ddp: bool = False,
175
+ num_numeric_tokens: Optional[int] = None,
176
+ loss_name: Optional[str] = None
177
+ ):
178
+ super().__init__()
179
+
180
+ self.use_ddp = use_ddp
181
+ self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
182
+ False, 0, 0, 1)
183
+ if self.use_ddp:
184
+ self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
185
+
186
+ self.model_nme = model_nme
187
+ self.num_cols = list(num_cols)
188
+ self.cat_cols = list(cat_cols)
189
+ self.num_numeric_tokens = self.resolve_numeric_token_count(
190
+ self.num_cols,
191
+ self.cat_cols,
192
+ num_numeric_tokens,
193
+ )
194
+ self.d_model = d_model
195
+ self.n_heads = n_heads
196
+ self.n_layers = n_layers
197
+ self.dropout = dropout
198
+ self.batch_num = batch_num
199
+ self.epochs = epochs
200
+ self.learning_rate = learning_rate
201
+ self.weight_decay = weight_decay
202
+ self.task_type = task_type
203
+ self.patience = patience
204
+ resolved_loss = normalize_loss_name(loss_name, self.task_type)
205
+ if self.task_type == 'classification':
206
+ self.loss_name = "logloss"
207
+ self.tw_power = None # No Tweedie power for classification.
208
+ else:
209
+ if resolved_loss == "auto":
210
+ resolved_loss = infer_loss_name_from_model_name(self.model_nme)
211
+ self.loss_name = resolved_loss
212
+ if self.loss_name == "tweedie":
213
+ self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
214
+ else:
215
+ self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
216
+
217
+ if self.is_ddp_enabled:
218
+ self.device = torch.device(f"cuda:{self.local_rank}")
219
+ elif torch.cuda.is_available():
220
+ self.device = torch.device("cuda")
221
+ elif torch.backends.mps.is_available():
222
+ self.device = torch.device("mps")
223
+ else:
224
+ self.device = torch.device("cpu")
225
+ self.cat_cardinalities = None
226
+ self.cat_categories = {}
227
+ self.cat_maps: Dict[str, Dict[Any, int]] = {}
228
+ self.cat_str_maps: Dict[str, Dict[str, int]] = {}
229
+ self._num_mean = None
230
+ self._num_std = None
231
+ self.ft = None
232
+ self.use_data_parallel = bool(use_data_parallel)
233
+ self.num_geo = 0
234
+ self._geo_params: Dict[str, Any] = {}
235
+ self.loss_curve_path: Optional[str] = None
236
+ self.training_history: Dict[str, List[float]] = {
237
+ "train": [], "val": []}
238
+
239
+ def _build_model(self, X_train):
240
+ num_numeric = len(self.num_cols)
241
+ cat_cardinalities = []
242
+
243
+ if num_numeric > 0:
244
+ num_arr = X_train[self.num_cols].to_numpy(
245
+ dtype=np.float32, copy=False)
246
+ num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
247
+ mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
248
+ std = num_arr.std(axis=0).astype(np.float32, copy=False)
249
+ std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
250
+ self._num_mean = mean
251
+ self._num_std = std
252
+ else:
253
+ self._num_mean = None
254
+ self._num_std = None
255
+
256
+ self.cat_maps = {}
257
+ self.cat_str_maps = {}
258
+ for col in self.cat_cols:
259
+ cats = X_train[col].astype('category')
260
+ categories = cats.cat.categories
261
+ self.cat_categories[col] = categories # Store full category list from training.
262
+ self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
263
+ if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
264
+ self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
265
+
266
+ card = len(categories) + 1 # Reserve one extra class for unknown/missing.
267
+ cat_cardinalities.append(card)
268
+
269
+ self.cat_cardinalities = cat_cardinalities
270
+
271
+ core = FTTransformerCore(
272
+ num_numeric=num_numeric,
273
+ cat_cardinalities=cat_cardinalities,
274
+ d_model=self.d_model,
275
+ n_heads=self.n_heads,
276
+ n_layers=self.n_layers,
277
+ dropout=self.dropout,
278
+ task_type=self.task_type,
279
+ num_geo=self.num_geo,
280
+ num_numeric_tokens=self.num_numeric_tokens
281
+ )
282
+ use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
283
+ if self.is_ddp_enabled:
284
+ core = core.to(self.device)
285
+ core = DDP(core, device_ids=[
286
+ self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
287
+ self.use_data_parallel = False
288
+ elif use_dp:
289
+ if self.use_ddp and not self.is_ddp_enabled:
290
+ _log(
291
+ ">>> DDP requested but not initialized; falling back to DataParallel.")
292
+ core = nn.DataParallel(core, device_ids=list(
293
+ range(torch.cuda.device_count())))
294
+ self.device = torch.device("cuda")
295
+ self.use_data_parallel = True
296
+ else:
297
+ self.use_data_parallel = False
298
+ self.ft = core.to(self.device)
299
+
300
+ def _encode_cats(self, X):
301
+ # Input DataFrame must include all categorical feature columns.
302
+ # Return int64 array with shape (N, num_categorical_features).
303
+
304
+ if not self.cat_cols:
305
+ return np.zeros((len(X), 0), dtype='int64')
306
+
307
+ n_rows = len(X)
308
+ n_cols = len(self.cat_cols)
309
+ X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
310
+ for idx, col in enumerate(self.cat_cols):
311
+ categories = self.cat_categories[col]
312
+ mapping = self.cat_maps.get(col)
313
+ if mapping is None:
314
+ mapping = {cat: i for i, cat in enumerate(categories)}
315
+ self.cat_maps[col] = mapping
316
+ unknown_idx = len(categories)
317
+ series = X[col]
318
+ codes = series.map(mapping)
319
+ unmapped = series.notna() & codes.isna()
320
+ if unmapped.any():
321
+ try:
322
+ series_cast = series.astype(categories.dtype)
323
+ except Exception:
324
+ series_cast = None
325
+ if series_cast is not None:
326
+ codes = series_cast.map(mapping)
327
+ unmapped = series_cast.notna() & codes.isna()
328
+ if unmapped.any():
329
+ str_map = self.cat_str_maps.get(col)
330
+ if str_map is None:
331
+ str_map = {str(cat): i for i, cat in enumerate(categories)}
332
+ self.cat_str_maps[col] = str_map
333
+ codes = series.astype(str).map(str_map)
334
+ if pd.api.types.is_categorical_dtype(codes):
335
+ codes = codes.astype("float")
336
+ codes = codes.fillna(unknown_idx).astype(
337
+ "int64", copy=False).to_numpy()
338
+ X_cat_np[:, idx] = codes
339
+ return X_cat_np
340
+
341
+ def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
342
+ return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
343
+
344
+ def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
345
+ return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
346
+
347
+ @staticmethod
348
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
349
+ if arr is None:
350
+ return
351
+ if isinstance(arr, pd.DataFrame):
352
+ if arr.shape[1] != 1:
353
+ raise ValueError(f"{name} must be 1d (single column).")
354
+ length = len(arr)
355
+ else:
356
+ arr_np = np.asarray(arr)
357
+ if arr_np.ndim == 0:
358
+ raise ValueError(f"{name} must be 1d.")
359
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
360
+ raise ValueError(f"{name} must be 1d or Nx1.")
361
+ length = arr_np.shape[0]
362
+ if length != n_rows:
363
+ raise ValueError(
364
+ f"{name} length {length} does not match X length {n_rows}."
365
+ )
366
+
367
+ def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
368
+ if X is None:
369
+ if allow_none:
370
+ return None, None, None, None, None, False
371
+ raise ValueError("Input features X must not be None.")
372
+ if not isinstance(X, pd.DataFrame):
373
+ raise ValueError("X must be a pandas DataFrame.")
374
+ missing_cols = [
375
+ col for col in (self.num_cols + self.cat_cols) if col not in X.columns
376
+ ]
377
+ if missing_cols:
378
+ raise ValueError(f"X is missing required columns: {missing_cols}")
379
+ n_rows = len(X)
380
+ if y is not None:
381
+ self._validate_vector(y, "y", n_rows)
382
+ if w is not None:
383
+ self._validate_vector(w, "w", n_rows)
384
+
385
+ num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
386
+ if not num_np.flags["OWNDATA"]:
387
+ num_np = num_np.copy()
388
+ num_np = np.nan_to_num(num_np, nan=0.0,
389
+ posinf=0.0, neginf=0.0, copy=False)
390
+ if self._num_mean is not None and self._num_std is not None and num_np.size:
391
+ num_np = (num_np - self._num_mean) / self._num_std
392
+ X_num = torch.as_tensor(num_np)
393
+ if self.cat_cols:
394
+ X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
395
+ else:
396
+ X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
397
+
398
+ if geo_tokens is not None:
399
+ geo_np = np.asarray(geo_tokens, dtype=np.float32)
400
+ if geo_np.shape[0] != n_rows:
401
+ raise ValueError(
402
+ "geo_tokens length does not match X rows.")
403
+ if geo_np.ndim == 1:
404
+ geo_np = geo_np.reshape(-1, 1)
405
+ elif self.num_geo > 0:
406
+ raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
407
+ else:
408
+ geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
409
+ X_geo = torch.as_tensor(geo_np)
410
+
411
+ y_tensor = torch.as_tensor(
412
+ y.to_numpy(dtype=np.float32, copy=False) if hasattr(
413
+ y, "to_numpy") else np.asarray(y, dtype=np.float32)
414
+ ).view(-1, 1) if y is not None else None
415
+ if y_tensor is None:
416
+ w_tensor = None
417
+ elif w is not None:
418
+ w_tensor = torch.as_tensor(
419
+ w.to_numpy(dtype=np.float32, copy=False) if hasattr(
420
+ w, "to_numpy") else np.asarray(w, dtype=np.float32)
421
+ ).view(-1, 1)
422
+ else:
423
+ w_tensor = torch.ones_like(y_tensor)
424
+ return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
425
+
426
+    def fit(self, X_train, y_train, w_train=None,
+            X_val=None, y_val=None, w_val=None, trial=None,
+            geo_train=None, geo_val=None):
+
+        # Build the underlying model on first fit.
+        self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+        if self.ft is None:
+            self._build_model(X_train)
+
+        X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
+            X_train, y_train, w_train, geo_train=geo_train)
+        X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
+            X_val, y_val, w_val, geo_val=geo_val)
+
+        # --- Build DataLoader ---
+        dataset = TabularDataset(
+            X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
+        )
+
+        dataloader, accum_steps = self._build_dataloader(
+            dataset,
+            N=X_num_train.shape[0],
+            base_bs_gpu=(2048, 1024, 512),
+            base_bs_cpu=(256, 128),
+            min_bs=64,
+            target_effective_cuda=2048,
+            target_effective_cpu=1024
+        )
+
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+        X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
+        val_dataloader = None
+        if has_val:
+            val_dataset = TabularDataset(
+                X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
+            )
+            val_dataloader = self._build_val_dataloader(
+                val_dataset, dataloader, accum_steps)
+
+        # Check for both DataParallel and DDP wrappers
+        is_data_parallel = isinstance(self.ft, (nn.DataParallel, DDP))
+
+        def forward_fn(batch):
+            X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+
+            # For DataParallel, inputs are automatically scattered; for DDP, move to local device
+            if not isinstance(self.ft, nn.DataParallel):
+                X_num_b = X_num_b.to(self.device, non_blocking=True)
+                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                y_b = y_b.to(self.device, non_blocking=True)
+                w_b = w_b.to(self.device, non_blocking=True)
+
+            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+            return y_pred, y_b, w_b
+
+        def val_forward_fn():
+            total_loss = 0.0
+            total_weight = 0.0
+            for batch in val_dataloader:
+                X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+                if not isinstance(self.ft, nn.DataParallel):
+                    X_num_b = X_num_b.to(self.device, non_blocking=True)
+                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                    X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                    y_b = y_b.to(self.device, non_blocking=True)
+                    w_b = w_b.to(self.device, non_blocking=True)
+
+                y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+
+                # Manually compute validation loss.
+                losses = self._compute_losses(
+                    y_pred, y_b, apply_softplus=False)
+
+                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                total_loss += batch_weighted_loss_sum.item()
+                total_weight += batch_weight_sum.item()
+
+            return total_loss / max(total_weight, EPS)
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+        best_state, history = self._train_model(
+            self.ft,
+            dataloader,
+            accum_steps,
+            optimizer,
+            scaler,
+            forward_fn,
+            val_forward_fn if has_val else None,
+            apply_softplus=False,
+            clip_fn=clip_fn,
+            trial=trial,
+            loss_curve_path=getattr(self, "loss_curve_path", None)
+        )
+
+        if has_val and best_state is not None:
+            # Load state into unwrapped module to match how it was saved
+            base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+            base_module.load_state_dict(best_state)
+        self.training_history = history
+
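The dataloader above targets an effective batch of 2048 samples on CUDA (1024 on CPU) via gradient accumulation; the relationship between per-step batch size and accumulation steps is roughly the following (a sketch of the arithmetic, not the actual `_build_dataloader` implementation):

import math

def accumulation_steps(per_step_batch: int, target_effective: int) -> int:
    """How many micro-batches are accumulated before each optimizer step."""
    return max(1, math.ceil(target_effective / max(1, per_step_batch)))

assert accumulation_steps(512, 2048) == 4    # e.g. a GPU batch of 512 -> 4 accumulation steps
assert accumulation_steps(1024, 1024) == 1   # already at the target -> no accumulation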
+    def fit_unsupervised(self,
+                         X_train,
+                         X_val=None,
+                         trial: Optional[optuna.trial.Trial] = None,
+                         geo_train=None,
+                         geo_val=None,
+                         mask_prob_num: float = 0.15,
+                         mask_prob_cat: float = 0.15,
+                         num_loss_weight: float = 1.0,
+                         cat_loss_weight: float = 1.0) -> float:
+        """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
+        self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+        if self.ft is None:
+            self._build_model(X_train)
+
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_train, None, None, geo_tokens=geo_train, allow_none=True)
+        has_val = X_val is not None
+        if has_val:
+            X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
+                X_val, None, None, geo_tokens=geo_val, allow_none=True)
+        else:
+            X_num_val = X_cat_val = X_geo_val = None
+
+        N = int(X_num.shape[0])
+        num_dim = int(X_num.shape[1])
+        cat_dim = int(X_cat.shape[1])
+        device_type = self._device_type()
+
+        gen = torch.Generator()
+        gen.manual_seed(13 + int(getattr(self, "rank", 0)))
+
+        base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
+        cardinals = getattr(base_model, "cat_cardinalities", None) or []
+        unknown_idx = torch.tensor(
+            [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
+
+        means = None
+        if num_dim > 0:
+            # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
+            means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
+
+        def _mask_inputs(X_num_in: torch.Tensor,
+                         X_cat_in: torch.Tensor,
+                         generator: torch.Generator):
+            n_rows = int(X_num_in.shape[0])
+            num_mask_local = None
+            cat_mask_local = None
+            X_num_masked_local = X_num_in
+            X_cat_masked_local = X_cat_in
+            if num_dim > 0:
+                num_mask_local = (torch.rand(
+                    (n_rows, num_dim), generator=generator) < float(mask_prob_num))
+                X_num_masked_local = X_num_in.clone()
+                if num_mask_local.any():
+                    X_num_masked_local[num_mask_local] = means.expand_as(
+                        X_num_masked_local)[num_mask_local]
+            if cat_dim > 0:
+                cat_mask_local = (torch.rand(
+                    (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
+                X_cat_masked_local = X_cat_in.clone()
+                if cat_mask_local.any():
+                    X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
+                        X_cat_masked_local)[cat_mask_local]
+            return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
+
+        X_num_true = X_num if num_dim > 0 else None
+        X_cat_true = X_cat if cat_dim > 0 else None
+        X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
+            X_num, X_cat, gen)
+
+        dataset = MaskedTabularDataset(
+            X_num_masked, X_cat_masked, X_geo,
+            X_num_true, num_mask,
+            X_cat_true, cat_mask
+        )
+        dataloader, accum_steps = self._build_dataloader(
+            dataset,
+            N=N,
+            base_bs_gpu=(2048, 1024, 512),
+            base_bs_cpu=(256, 128),
+            min_bs=64,
+            target_effective_cuda=2048,
+            target_effective_cpu=1024
+        )
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(device_type == 'cuda'))
+
+        train_history: List[float] = []
+        val_history: List[float] = []
+        best_loss = float("inf")
+        best_state = None
+        patience_counter = 0
+        is_ddp_model = isinstance(self.ft, DDP)
+        use_collectives = dist.is_initialized() and is_ddp_model
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+        for epoch in range(1, int(self.epochs) + 1):
+            if self.dataloader_sampler is not None:
+                self.dataloader_sampler.set_epoch(epoch)
+
+            self.ft.train()
+            optimizer.zero_grad()
+            epoch_loss_sum = 0.0
+            epoch_count = 0.0
+
+            for step, batch in enumerate(dataloader):
+                is_update_step = ((step + 1) % accum_steps == 0) or \
+                    ((step + 1) == len(dataloader))
+                sync_cm = self.ft.no_sync if (
+                    is_ddp_model and not is_update_step) else nullcontext
+                with sync_cm():
+                    with autocast(enabled=(device_type == 'cuda')):
+                        X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
+                        X_num_b = X_num_b.to(self.device, non_blocking=True)
+                        X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                        X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                        num_true_b = None if num_true_b is None else num_true_b.to(
+                            self.device, non_blocking=True)
+                        num_mask_b = None if num_mask_b is None else num_mask_b.to(
+                            self.device, non_blocking=True)
+                        cat_true_b = None if cat_true_b is None else cat_true_b.to(
+                            self.device, non_blocking=True)
+                        cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
+                            self.device, non_blocking=True)
+
+                        num_pred, cat_logits = self.ft(
+                            X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
+                        batch_loss = _compute_reconstruction_loss(
+                            num_pred, cat_logits, num_true_b, num_mask_b,
+                            cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                            device=X_num_b.device)
+                    local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
+                    global_bad = local_bad
+                    if use_collectives:
+                        bad = torch.tensor(
+                            [local_bad],
+                            device=batch_loss.device,
+                            dtype=torch.int32,
+                        )
+                        dist.all_reduce(bad, op=dist.ReduceOp.MAX)
+                        global_bad = int(bad.item())
+
+                    if global_bad:
+                        msg = (
+                            f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
+                            f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
+                        )
+                        should_log = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                        if should_log:
+                            _log(msg, flush=True)
+                            _log(
+                                f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
+                                f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
+                                f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
+                                flush=True,
+                            )
+                            if X_geo_b is not None:
+                                _log(
+                                    f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
+                                    f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
+                                    f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
+                                    flush=True,
+                                )
+                        if trial is not None:
+                            raise optuna.TrialPruned(msg)
+                        raise RuntimeError(msg)
+                    loss_for_backward = batch_loss / float(accum_steps)
+                    scaler.scale(loss_for_backward).backward()
+
+                if is_update_step:
+                    if clip_fn is not None:
+                        clip_fn()
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+
+                epoch_loss_sum += float(batch_loss.detach().item()) * \
+                    float(X_num_b.shape[0])
+                epoch_count += float(X_num_b.shape[0])
+
+            train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
+
+            if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
+                should_compute_val = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
+                    "cpu")
+                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                if should_compute_val:
+                    self.ft.eval()
+                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                        val_bs = min(
+                            int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
+                        total_val = 0.0
+                        total_n = 0.0
+                        for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
+                            end = min(
+                                int(X_num_val.shape[0]), start + max(1, val_bs))
+                            X_num_v_true_cpu = X_num_val[start:end]
+                            X_cat_v_true_cpu = X_cat_val[start:end]
+                            X_geo_v = X_geo_val[start:end].to(
+                                self.device, non_blocking=True)
+                            gen_val = torch.Generator()
+                            gen_val.manual_seed(10_000 + epoch + start)
+                            X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
+                                X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
+                            X_num_v_true = X_num_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v_true = X_cat_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_num_v = X_num_v_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v = X_cat_v_cpu.to(
+                                self.device, non_blocking=True)
+                            val_num_mask = None if val_num_mask is None else val_num_mask.to(
+                                self.device, non_blocking=True)
+                            val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
+                                self.device, non_blocking=True)
+                            num_pred_v, cat_logits_v = self.ft(
+                                X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
+                            loss_v = _compute_reconstruction_loss(
+                                num_pred_v, cat_logits_v,
+                                X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
+                                X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                                num_loss_weight, cat_loss_weight,
+                                device=X_num_v.device
+                            )
+                            if not torch.isfinite(loss_v):
+                                total_val = float("inf")
+                                total_n = 1.0
+                                break
+                            total_val += float(loss_v.detach().item()) * float(end - start)
+                            total_n += float(end - start)
+                        val_loss_tensor[0] = total_val / max(total_n, 1.0)
+
+                if use_collectives:
+                    dist.broadcast(val_loss_tensor, src=0)
+                val_loss_value = float(val_loss_tensor.item())
+                prune_now = False
+                prune_msg = None
+                if not np.isfinite(val_loss_value):
+                    prune_now = True
+                    prune_msg = (
+                        f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
+                        f"(epoch={epoch}, val_loss={val_loss_value})"
+                    )
+                val_history.append(val_loss_value)
+
+                if val_loss_value < best_loss:
+                    best_loss = val_loss_value
+                    # Efficiently clone state_dict - only clone tensor data, not DDP metadata
+                    base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+                    best_state = {
+                        k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
+                        for k, v in base_module.state_dict().items()
+                    }
+                    patience_counter = 0
+                else:
+                    patience_counter += 1
+                    if best_state is not None and patience_counter >= int(self.patience):
+                        break
+
+                if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
+                    trial.report(val_loss_value, epoch)
+                    if trial.should_prune():
+                        prune_now = True
+
+                if use_collectives:
+                    flag = torch.tensor(
+                        [1 if prune_now else 0],
+                        device=loss_tensor_device,
+                        dtype=torch.int32,
+                    )
+                    dist.broadcast(flag, src=0)
+                    prune_now = bool(flag.item())
+
+                if prune_now:
+                    if prune_msg:
+                        raise optuna.TrialPruned(prune_msg)
+                    raise optuna.TrialPruned()
+
+        self.training_history = {"train": train_history, "val": val_history}
+        self._plot_loss_curve(self.training_history, getattr(
+            self, "loss_curve_path", None))
+        if has_val and best_state is not None:
+            # Load state into unwrapped module to match how it was saved
+            base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
+            base_module.load_state_dict(best_state)
+        return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
+
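The masking scheme used above can be summarised in isolation: each numeric cell is replaced by its column mean and each categorical cell by a reserved unknown index, independently with the configured probabilities. A standalone sketch for the numeric case (illustrative, not the package code):

import torch

def mask_numeric(x_num: torch.Tensor, mask_prob: float, generator: torch.Generator):
    """Replace a random subset of numeric cells by their column means; return the masked copy and the mask."""
    means = x_num.mean(dim=0, keepdim=True)
    mask = torch.rand(x_num.shape, generator=generator) < mask_prob
    x_masked = x_num.clone()
    x_masked[mask] = means.expand_as(x_num)[mask]
    return x_masked, mask

gen = torch.Generator().manual_seed(13)
x_masked, mask = mask_numeric(torch.randn(8, 3), mask_prob=0.15, generator=gen)
# The reconstruction loss is then evaluated only on the masked positions.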
+    def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
+        # X_test must include all numeric/categorical columns; geo_tokens is optional.
+
+        self.ft.eval()
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
+
+        num_rows = X_num.shape[0]
+        if num_rows == 0:
+            return np.empty(0, dtype=np.float32)
+
+        device = self.device if isinstance(
+            self.device, torch.device) else torch.device(self.device)
+
+        def resolve_batch_size(n_rows: int) -> int:
+            if batch_size is not None:
+                return max(1, min(int(batch_size), n_rows))
+            # Estimate a safe batch size based on model size to avoid attention OOM.
+            token_cnt = self.num_numeric_tokens + len(self.cat_cols)
+            if self.num_geo > 0:
+                token_cnt += 1
+            approx_units = max(1, token_cnt * max(1, self.d_model))
+            if device.type == 'cuda':
+                if approx_units >= 8192:
+                    base = 512
+                elif approx_units >= 4096:
+                    base = 1024
+                else:
+                    base = 2048
+            else:
+                base = 512
+            return max(1, min(base, n_rows))
+
+        eff_batch = resolve_batch_size(num_rows)
+        preds: List[torch.Tensor] = []
+
+        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+        with inference_cm():
+            for start in range(0, num_rows, eff_batch):
+                end = min(num_rows, start + eff_batch)
+                X_num_b = X_num[start:end].to(device, non_blocking=True)
+                X_cat_b = X_cat[start:end].to(device, non_blocking=True)
+                X_geo_b = X_geo[start:end].to(device, non_blocking=True)
+                pred_chunk = self.ft(
+                    X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
+                preds.append(pred_chunk.cpu())
+
+        y_pred = torch.cat(preds, dim=0).numpy()
+
+        if return_embedding:
+            return y_pred
+
+        if self.task_type == 'classification':
+            # Convert logits to probabilities.
+            y_pred = 1 / (1 + np.exp(-y_pred))
+        else:
+            # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
+            y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.ravel()
+
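Restating the inference batch-size heuristic above on its own (a sketch mirroring `resolve_batch_size`: larger token-count times model-width products get smaller batches):

def inference_batch_size(token_cnt: int, d_model: int, n_rows: int, on_cuda: bool = True) -> int:
    approx_units = max(1, token_cnt * max(1, d_model))
    if on_cuda:
        base = 512 if approx_units >= 8192 else (1024 if approx_units >= 4096 else 2048)
    else:
        base = 512
    return max(1, min(base, n_rows))

assert inference_batch_size(40, 64, n_rows=100_000) == 2048   # 2560 units -> largest batch tier
assert inference_batch_size(80, 128, n_rows=100_000) == 512   # 10240 units -> smallest batch tier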
+    def set_params(self, params: dict):
+
+        # Keep sklearn-style behavior.
+        # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
+
+        for key, value in params.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise ValueError(f"Parameter {key} not found in model.")
+        return self
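Usage is sklearn-style: unknown keys are rejected, and structural parameters only take effect on the next fit because the network is built lazily. An illustrative call (the `model` instance is hypothetical; `learning_rate` and `weight_decay` are attributes referenced in the trainers above):

model.set_params({"learning_rate": 5e-4, "weight_decay": 1e-5})   # updates known attributes in place
# model.set_params({"not_a_param": 1})  # would raise ValueError("Parameter not_a_param not found in model.")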