ins-pricing 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
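Many of the renames above relocate modules (modelling/bayesopt → modelling/core/bayesopt, entry points and CLI helpers → cli/), so import paths change in 0.2.0. A rough before/after sketch inferred from the file moves; the new __init__.py files may keep compatibility re-exports, so treat this as illustrative only:

    # ins-pricing 0.1.11 layout
    import ins_pricing.modelling.bayesopt.core as bayesopt_core
    import ins_pricing.modelling.notebook_utils as notebook_utils

    # ins-pricing 0.2.0 layout (paths follow the moves listed above)
    import ins_pricing.modelling.core.bayesopt.core as bayesopt_core
    import ins_pricing.cli.utils.notebook_utils as notebook_utils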
ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py (new file, 808 lines)
@@ -0,0 +1,808 @@
+from __future__ import annotations
+
+import copy
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+
+from ..utils import DistributedUtils, EPS, TorchTrainerMixin
+from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
+
+
+# Scikit-Learn style wrapper for FTTransformer.
+
+
+class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
+
+    # sklearn-style wrapper:
+    # - num_cols: numeric feature column names
+    # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
+
+    @staticmethod
+    def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
+        num_cols_count = len(num_cols or [])
+        if num_cols_count == 0:
+            return 0
+        if requested is not None:
+            count = int(requested)
+            if count <= 0:
+                raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
+            return count
+        return max(1, num_cols_count)
+
+    def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
+                 n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
+                 task_type: str = 'regression',
+                 tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
+                 weight_decay: float = 0.0,
+                 use_data_parallel: bool = True,
+                 use_ddp: bool = False,
+                 num_numeric_tokens: Optional[int] = None
+                 ):
+        super().__init__()
+
+        self.use_ddp = use_ddp
+        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
+            False, 0, 0, 1)
+        if self.use_ddp:
+            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
+
+        self.model_nme = model_nme
+        self.num_cols = list(num_cols)
+        self.cat_cols = list(cat_cols)
+        self.num_numeric_tokens = self.resolve_numeric_token_count(
+            self.num_cols,
+            self.cat_cols,
+            num_numeric_tokens,
+        )
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.dropout = dropout
+        self.batch_num = batch_num
+        self.epochs = epochs
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.task_type = task_type
+        self.patience = patience
+        if self.task_type == 'classification':
+            self.tw_power = None  # No Tweedie power for classification.
+        elif 'f' in self.model_nme:
+            self.tw_power = 1.0
+        elif 's' in self.model_nme:
+            self.tw_power = 2.0
+        else:
+            self.tw_power = tweedie_power
+
+        if self.is_ddp_enabled:
+            self.device = torch.device(f"cuda:{self.local_rank}")
+        elif torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        self.cat_cardinalities = None
+        self.cat_categories = {}
+        self.cat_maps: Dict[str, Dict[Any, int]] = {}
+        self.cat_str_maps: Dict[str, Dict[str, int]] = {}
+        self._num_mean = None
+        self._num_std = None
+        self.ft = None
+        self.use_data_parallel = bool(use_data_parallel)
+        self.num_geo = 0
+        self._geo_params: Dict[str, Any] = {}
+        self.loss_curve_path: Optional[str] = None
+        self.training_history: Dict[str, List[float]] = {
+            "train": [], "val": []}
+
+    def _build_model(self, X_train):
+        num_numeric = len(self.num_cols)
+        cat_cardinalities = []
+
+        if num_numeric > 0:
+            num_arr = X_train[self.num_cols].to_numpy(
+                dtype=np.float32, copy=False)
+            num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
+            mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
+            std = num_arr.std(axis=0).astype(np.float32, copy=False)
+            std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
+            self._num_mean = mean
+            self._num_std = std
+        else:
+            self._num_mean = None
+            self._num_std = None
+
+        self.cat_maps = {}
+        self.cat_str_maps = {}
+        for col in self.cat_cols:
+            cats = X_train[col].astype('category')
+            categories = cats.cat.categories
+            self.cat_categories[col] = categories  # Store full category list from training.
+            self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
+            if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
+                self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
+
+            card = len(categories) + 1  # Reserve one extra class for unknown/missing.
+            cat_cardinalities.append(card)
+
+        self.cat_cardinalities = cat_cardinalities
+
+        core = FTTransformerCore(
+            num_numeric=num_numeric,
+            cat_cardinalities=cat_cardinalities,
+            d_model=self.d_model,
+            n_heads=self.n_heads,
+            n_layers=self.n_layers,
+            dropout=self.dropout,
+            task_type=self.task_type,
+            num_geo=self.num_geo,
+            num_numeric_tokens=self.num_numeric_tokens
+        )
+        use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
+        if self.is_ddp_enabled:
+            core = core.to(self.device)
+            core = DDP(core, device_ids=[
+                self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
+            self.use_data_parallel = False
+        elif use_dp:
+            if self.use_ddp and not self.is_ddp_enabled:
+                print(
+                    ">>> DDP requested but not initialized; falling back to DataParallel.")
+            core = nn.DataParallel(core, device_ids=list(
+                range(torch.cuda.device_count())))
+            self.device = torch.device("cuda")
+            self.use_data_parallel = True
+        else:
+            self.use_data_parallel = False
+        self.ft = core.to(self.device)
+
+    def _encode_cats(self, X):
+        # Input DataFrame must include all categorical feature columns.
+        # Return int64 array with shape (N, num_categorical_features).
+
+        if not self.cat_cols:
+            return np.zeros((len(X), 0), dtype='int64')
+
+        n_rows = len(X)
+        n_cols = len(self.cat_cols)
+        X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
+        for idx, col in enumerate(self.cat_cols):
+            categories = self.cat_categories[col]
+            mapping = self.cat_maps.get(col)
+            if mapping is None:
+                mapping = {cat: i for i, cat in enumerate(categories)}
+                self.cat_maps[col] = mapping
+            unknown_idx = len(categories)
+            series = X[col]
+            codes = series.map(mapping)
+            unmapped = series.notna() & codes.isna()
+            if unmapped.any():
+                try:
+                    series_cast = series.astype(categories.dtype)
+                except Exception:
+                    series_cast = None
+                if series_cast is not None:
+                    codes = series_cast.map(mapping)
+                    unmapped = series_cast.notna() & codes.isna()
+            if unmapped.any():
+                str_map = self.cat_str_maps.get(col)
+                if str_map is None:
+                    str_map = {str(cat): i for i, cat in enumerate(categories)}
+                    self.cat_str_maps[col] = str_map
+                codes = series.astype(str).map(str_map)
+            if pd.api.types.is_categorical_dtype(codes):
+                codes = codes.astype("float")
+            codes = codes.fillna(unknown_idx).astype(
+                "int64", copy=False).to_numpy()
+            X_cat_np[:, idx] = codes
+        return X_cat_np
+
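To illustrate the fallback chain above (exact-value map, dtype-cast map, then string map, with unseen or missing values sent to the reserved index len(categories)), a small hypothetical example; the column name and category values are made up:

    import pandas as pd

    # Suppose _build_model saw categories ['A', 'B'] for column 'region', so
    # cat_maps['region'] == {'A': 0, 'B': 1} and the reserved unknown index is 2
    # (cardinality 3 == len(categories) + 1).
    X_new = pd.DataFrame({'region': ['B', 'C', None]})
    # _encode_cats(X_new) would return [[1], [2], [2]]: 'C' is unseen and None is
    # missing, so both fall back to the reserved index 2.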
+    def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
+        return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
+
+    def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
+        return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
+
+    @staticmethod
+    def _validate_vector(arr, name: str, n_rows: int) -> None:
+        if arr is None:
+            return
+        if isinstance(arr, pd.DataFrame):
+            if arr.shape[1] != 1:
+                raise ValueError(f"{name} must be 1d (single column).")
+            length = len(arr)
+        else:
+            arr_np = np.asarray(arr)
+            if arr_np.ndim == 0:
+                raise ValueError(f"{name} must be 1d.")
+            if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
+                raise ValueError(f"{name} must be 1d or Nx1.")
+            length = arr_np.shape[0]
+        if length != n_rows:
+            raise ValueError(
+                f"{name} length {length} does not match X length {n_rows}."
+            )
+
+    def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
+        if X is None:
+            if allow_none:
+                return None, None, None, None, None, False
+            raise ValueError("Input features X must not be None.")
+        if not isinstance(X, pd.DataFrame):
+            raise ValueError("X must be a pandas DataFrame.")
+        missing_cols = [
+            col for col in (self.num_cols + self.cat_cols) if col not in X.columns
+        ]
+        if missing_cols:
+            raise ValueError(f"X is missing required columns: {missing_cols}")
+        n_rows = len(X)
+        if y is not None:
+            self._validate_vector(y, "y", n_rows)
+        if w is not None:
+            self._validate_vector(w, "w", n_rows)
+
+        num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
+        if not num_np.flags["OWNDATA"]:
+            num_np = num_np.copy()
+        num_np = np.nan_to_num(num_np, nan=0.0,
+                               posinf=0.0, neginf=0.0, copy=False)
+        if self._num_mean is not None and self._num_std is not None and num_np.size:
+            num_np = (num_np - self._num_mean) / self._num_std
+        X_num = torch.as_tensor(num_np)
+        if self.cat_cols:
+            X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
+        else:
+            X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
+
+        if geo_tokens is not None:
+            geo_np = np.asarray(geo_tokens, dtype=np.float32)
+            if geo_np.shape[0] != n_rows:
+                raise ValueError(
+                    "geo_tokens length does not match X rows.")
+            if geo_np.ndim == 1:
+                geo_np = geo_np.reshape(-1, 1)
+        elif self.num_geo > 0:
+            raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
+        else:
+            geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
+        X_geo = torch.as_tensor(geo_np)
+
+        y_tensor = torch.as_tensor(
+            y.to_numpy(dtype=np.float32, copy=False) if hasattr(
+                y, "to_numpy") else np.asarray(y, dtype=np.float32)
+        ).view(-1, 1) if y is not None else None
+        if y_tensor is None:
+            w_tensor = None
+        elif w is not None:
+            w_tensor = torch.as_tensor(
+                w.to_numpy(dtype=np.float32, copy=False) if hasattr(
+                    w, "to_numpy") else np.asarray(w, dtype=np.float32)
+            ).view(-1, 1)
+        else:
+            w_tensor = torch.ones_like(y_tensor)
+        return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
+
+    def fit(self, X_train, y_train, w_train=None,
+            X_val=None, y_val=None, w_val=None, trial=None,
+            geo_train=None, geo_val=None):
+
+        # Build the underlying model on first fit.
+        self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+        if self.ft is None:
+            self._build_model(X_train)
+
+        X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
+            X_train, y_train, w_train, geo_train=geo_train)
+        X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
+            X_val, y_val, w_val, geo_val=geo_val)
+
+        # --- Build DataLoader ---
+        dataset = TabularDataset(
+            X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
+        )
+
+        dataloader, accum_steps = self._build_dataloader(
+            dataset,
+            N=X_num_train.shape[0],
+            base_bs_gpu=(2048, 1024, 512),
+            base_bs_cpu=(256, 128),
+            min_bs=64,
+            target_effective_cuda=2048,
+            target_effective_cpu=1024
+        )
+
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+        X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
+        val_dataloader = None
+        if has_val:
+            val_dataset = TabularDataset(
+                X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
+            )
+            val_dataloader = self._build_val_dataloader(
+                val_dataset, dataloader, accum_steps)
+
+        is_data_parallel = isinstance(self.ft, nn.DataParallel)
+
+        def forward_fn(batch):
+            X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+
+            if not is_data_parallel:
+                X_num_b = X_num_b.to(self.device, non_blocking=True)
+                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                y_b = y_b.to(self.device, non_blocking=True)
+                w_b = w_b.to(self.device, non_blocking=True)
+
+            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+            return y_pred, y_b, w_b
+
+        def val_forward_fn():
+            total_loss = 0.0
+            total_weight = 0.0
+            for batch in val_dataloader:
+                X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+                if not is_data_parallel:
+                    X_num_b = X_num_b.to(self.device, non_blocking=True)
+                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                    X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                    y_b = y_b.to(self.device, non_blocking=True)
+                    w_b = w_b.to(self.device, non_blocking=True)
+
+                y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+
+                # Manually compute validation loss.
+                losses = self._compute_losses(
+                    y_pred, y_b, apply_softplus=False)
+
+                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                total_loss += batch_weighted_loss_sum.item()
+                total_weight += batch_weight_sum.item()
+
+            return total_loss / max(total_weight, EPS)
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+        best_state, history = self._train_model(
+            self.ft,
+            dataloader,
+            accum_steps,
+            optimizer,
+            scaler,
+            forward_fn,
+            val_forward_fn if has_val else None,
+            apply_softplus=False,
+            clip_fn=clip_fn,
+            trial=trial,
+            loss_curve_path=getattr(self, "loss_curve_path", None)
+        )
+
+        if has_val and best_state is not None:
+            self.ft.load_state_dict(best_state)
+        self.training_history = history
+
+    def fit_unsupervised(self,
+                         X_train,
+                         X_val=None,
+                         trial: Optional[optuna.trial.Trial] = None,
+                         geo_train=None,
+                         geo_val=None,
+                         mask_prob_num: float = 0.15,
+                         mask_prob_cat: float = 0.15,
+                         num_loss_weight: float = 1.0,
+                         cat_loss_weight: float = 1.0) -> float:
+        """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
+        self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+        if self.ft is None:
+            self._build_model(X_train)
+
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_train, None, None, geo_tokens=geo_train, allow_none=True)
+        has_val = X_val is not None
+        if has_val:
+            X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
+                X_val, None, None, geo_tokens=geo_val, allow_none=True)
+        else:
+            X_num_val = X_cat_val = X_geo_val = None
+
+        N = int(X_num.shape[0])
+        num_dim = int(X_num.shape[1])
+        cat_dim = int(X_cat.shape[1])
+        device_type = self._device_type()
+
+        gen = torch.Generator()
+        gen.manual_seed(13 + int(getattr(self, "rank", 0)))
+
+        base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
+        cardinals = getattr(base_model, "cat_cardinalities", None) or []
+        unknown_idx = torch.tensor(
+            [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
+
+        means = None
+        if num_dim > 0:
+            # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
+            means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
+
+        def _mask_inputs(X_num_in: torch.Tensor,
+                         X_cat_in: torch.Tensor,
+                         generator: torch.Generator):
+            n_rows = int(X_num_in.shape[0])
+            num_mask_local = None
+            cat_mask_local = None
+            X_num_masked_local = X_num_in
+            X_cat_masked_local = X_cat_in
+            if num_dim > 0:
+                num_mask_local = (torch.rand(
+                    (n_rows, num_dim), generator=generator) < float(mask_prob_num))
+                X_num_masked_local = X_num_in.clone()
+                if num_mask_local.any():
+                    X_num_masked_local[num_mask_local] = means.expand_as(
+                        X_num_masked_local)[num_mask_local]
+            if cat_dim > 0:
+                cat_mask_local = (torch.rand(
+                    (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
+                X_cat_masked_local = X_cat_in.clone()
+                if cat_mask_local.any():
+                    X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
+                        X_cat_masked_local)[cat_mask_local]
+            return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
+
+        X_num_true = X_num if num_dim > 0 else None
+        X_cat_true = X_cat if cat_dim > 0 else None
+        X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
+            X_num, X_cat, gen)
+
+        dataset = MaskedTabularDataset(
+            X_num_masked, X_cat_masked, X_geo,
+            X_num_true, num_mask,
+            X_cat_true, cat_mask
+        )
+        dataloader, accum_steps = self._build_dataloader(
+            dataset,
+            N=N,
+            base_bs_gpu=(2048, 1024, 512),
+            base_bs_cpu=(256, 128),
+            min_bs=64,
+            target_effective_cuda=2048,
+            target_effective_cpu=1024
+        )
+        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+            self.dataloader_sampler = dataloader.sampler
+        else:
+            self.dataloader_sampler = None
+
+        optimizer = torch.optim.Adam(
+            self.ft.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(device_type == 'cuda'))
+
+        def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
+            loss = torch.zeros((), device=device, dtype=torch.float32)
+
+            if num_pred is not None and num_true_b is not None and num_mask_b is not None:
+                num_mask_b = num_mask_b.to(dtype=torch.bool)
+                if num_mask_b.any():
+                    diff = num_pred - num_true_b
+                    mse = diff * diff
+                    loss = loss + float(num_loss_weight) * \
+                        mse[num_mask_b].mean()
+
+            if cat_logits and cat_true_b is not None and cat_mask_b is not None:
+                cat_mask_b = cat_mask_b.to(dtype=torch.bool)
+                cat_losses: List[torch.Tensor] = []
+                for j, logits in enumerate(cat_logits):
+                    mask_j = cat_mask_b[:, j]
+                    if not mask_j.any():
+                        continue
+                    targets = cat_true_b[:, j]
+                    cat_losses.append(
+                        F.cross_entropy(logits, targets, reduction='none')[
+                            mask_j].mean()
+                    )
+                if cat_losses:
+                    loss = loss + float(cat_loss_weight) * \
+                        torch.stack(cat_losses).mean()
+            return loss
+
+        train_history: List[float] = []
+        val_history: List[float] = []
+        best_loss = float("inf")
+        best_state = None
+        patience_counter = 0
+        is_ddp_model = isinstance(self.ft, DDP)
+
+        clip_fn = None
+        if self.device.type == 'cuda':
+            def clip_fn(): return (scaler.unscale_(optimizer),
+                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+        for epoch in range(1, int(self.epochs) + 1):
+            if self.dataloader_sampler is not None:
+                self.dataloader_sampler.set_epoch(epoch)
+
+            self.ft.train()
+            optimizer.zero_grad()
+            epoch_loss_sum = 0.0
+            epoch_count = 0.0
+
+            for step, batch in enumerate(dataloader):
+                is_update_step = ((step + 1) % accum_steps == 0) or \
+                    ((step + 1) == len(dataloader))
+                sync_cm = self.ft.no_sync if (
+                    is_ddp_model and not is_update_step) else nullcontext
+                with sync_cm():
+                    with autocast(enabled=(device_type == 'cuda')):
+                        X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
+                        X_num_b = X_num_b.to(self.device, non_blocking=True)
+                        X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                        X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                        num_true_b = None if num_true_b is None else num_true_b.to(
+                            self.device, non_blocking=True)
+                        num_mask_b = None if num_mask_b is None else num_mask_b.to(
+                            self.device, non_blocking=True)
+                        cat_true_b = None if cat_true_b is None else cat_true_b.to(
+                            self.device, non_blocking=True)
+                        cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
+                            self.device, non_blocking=True)
+
+                        num_pred, cat_logits = self.ft(
+                            X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
+                        batch_loss = _batch_recon_loss(
+                            num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
+                        local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
+                        global_bad = local_bad
+                        if dist.is_initialized():
+                            bad = torch.tensor(
+                                [local_bad],
+                                device=batch_loss.device,
+                                dtype=torch.int32,
+                            )
+                            dist.all_reduce(bad, op=dist.ReduceOp.MAX)
+                            global_bad = int(bad.item())
+
+                        if global_bad:
+                            msg = (
+                                f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
+                                f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
+                            )
+                            should_log = (not dist.is_initialized()
+                                          or DistributedUtils.is_main_process())
+                            if should_log:
+                                print(msg, flush=True)
+                                print(
+                                    f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
+                                    f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
+                                    f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
+                                    flush=True,
+                                )
+                                if X_geo_b is not None:
+                                    print(
+                                        f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
+                                        f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
+                                        f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
+                                        flush=True,
+                                    )
+                            if trial is not None:
+                                raise optuna.TrialPruned(msg)
+                            raise RuntimeError(msg)
+                        loss_for_backward = batch_loss / float(accum_steps)
+                        scaler.scale(loss_for_backward).backward()
+
+                if is_update_step:
+                    if clip_fn is not None:
+                        clip_fn()
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+
+                epoch_loss_sum += float(batch_loss.detach().item()) * \
+                    float(X_num_b.shape[0])
+                epoch_count += float(X_num_b.shape[0])
+
+            train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
+
+            if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
+                should_compute_val = (not dist.is_initialized()
+                                      or DistributedUtils.is_main_process())
+                loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
+                    "cpu")
+                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
+
+                if should_compute_val:
+                    self.ft.eval()
+                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                        val_bs = min(
+                            int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
+                        total_val = 0.0
+                        total_n = 0.0
+                        for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
+                            end = min(
+                                int(X_num_val.shape[0]), start + max(1, val_bs))
+                            X_num_v_true_cpu = X_num_val[start:end]
+                            X_cat_v_true_cpu = X_cat_val[start:end]
+                            X_geo_v = X_geo_val[start:end].to(
+                                self.device, non_blocking=True)
+                            gen_val = torch.Generator()
+                            gen_val.manual_seed(10_000 + epoch + start)
+                            X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
+                                X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
+                            X_num_v_true = X_num_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v_true = X_cat_v_true_cpu.to(
+                                self.device, non_blocking=True)
+                            X_num_v = X_num_v_cpu.to(
+                                self.device, non_blocking=True)
+                            X_cat_v = X_cat_v_cpu.to(
+                                self.device, non_blocking=True)
+                            val_num_mask = None if val_num_mask is None else val_num_mask.to(
+                                self.device, non_blocking=True)
+                            val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
+                                self.device, non_blocking=True)
+                            num_pred_v, cat_logits_v = self.ft(
+                                X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
+                            loss_v = _batch_recon_loss(
+                                num_pred_v, cat_logits_v,
+                                X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
+                                X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                                device=X_num_v.device
+                            )
+                            if not torch.isfinite(loss_v):
+                                total_val = float("inf")
+                                total_n = 1.0
+                                break
+                            total_val += float(loss_v.detach().item()
+                                               ) * float(end - start)
+                            total_n += float(end - start)
+                        val_loss_tensor[0] = total_val / max(total_n, 1.0)
+
+                if dist.is_initialized():
+                    dist.broadcast(val_loss_tensor, src=0)
+                val_loss_value = float(val_loss_tensor.item())
+                prune_now = False
+                prune_msg = None
+                if not np.isfinite(val_loss_value):
+                    prune_now = True
+                    prune_msg = (
+                        f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
+                        f"(epoch={epoch}, val_loss={val_loss_value})"
+                    )
+                val_history.append(val_loss_value)
+
+                if val_loss_value < best_loss:
+                    best_loss = val_loss_value
+                    best_state = {
+                        k: (v.clone() if isinstance(
+                            v, torch.Tensor) else copy.deepcopy(v))
+                        for k, v in self.ft.state_dict().items()
+                    }
+                    patience_counter = 0
+                else:
+                    patience_counter += 1
+                    if best_state is not None and patience_counter >= int(self.patience):
+                        break
+
+                if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
+                    trial.report(val_loss_value, epoch)
+                    if trial.should_prune():
+                        prune_now = True
+
+                if dist.is_initialized():
+                    flag = torch.tensor(
+                        [1 if prune_now else 0],
+                        device=loss_tensor_device,
+                        dtype=torch.int32,
+                    )
+                    dist.broadcast(flag, src=0)
+                    prune_now = bool(flag.item())
+
+                if prune_now:
+                    if prune_msg:
+                        raise optuna.TrialPruned(prune_msg)
+                    raise optuna.TrialPruned()
+
+        self.training_history = {"train": train_history, "val": val_history}
+        self._plot_loss_curve(self.training_history, getattr(
+            self, "loss_curve_path", None))
+        if has_val and best_state is not None:
+            self.ft.load_state_dict(best_state)
+        return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
+
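The masked-reconstruction loop above is intended as a pretraining pass on the same feature frame before the supervised fit; a minimal sketch of that call pattern, where the data frames and holdout split are placeholders supplied by the caller:

    # Hypothetical pretrain-then-finetune sequence on an already-constructed model.
    model.fit_unsupervised(X_pretrain, X_val=X_holdout,
                           mask_prob_num=0.15, mask_prob_cat=0.15)
    model.fit(X_train, y_train, w_train,
              X_val=X_holdout, y_val=y_holdout, w_val=w_holdout)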
+    def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
+        # X_test must include all numeric/categorical columns; geo_tokens is optional.
+
+        self.ft.eval()
+        X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+            X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
+
+        num_rows = X_num.shape[0]
+        if num_rows == 0:
+            return np.empty(0, dtype=np.float32)
+
+        device = self.device if isinstance(
+            self.device, torch.device) else torch.device(self.device)
+
+        def resolve_batch_size(n_rows: int) -> int:
+            if batch_size is not None:
+                return max(1, min(int(batch_size), n_rows))
+            # Estimate a safe batch size based on model size to avoid attention OOM.
+            token_cnt = self.num_numeric_tokens + len(self.cat_cols)
+            if self.num_geo > 0:
+                token_cnt += 1
+            approx_units = max(1, token_cnt * max(1, self.d_model))
+            if device.type == 'cuda':
+                if approx_units >= 8192:
+                    base = 512
+                elif approx_units >= 4096:
+                    base = 1024
+                else:
+                    base = 2048
+            else:
+                base = 512
+            return max(1, min(base, n_rows))
+
+        eff_batch = resolve_batch_size(num_rows)
+        preds: List[torch.Tensor] = []
+
+        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+        with inference_cm():
+            for start in range(0, num_rows, eff_batch):
+                end = min(num_rows, start + eff_batch)
+                X_num_b = X_num[start:end].to(device, non_blocking=True)
+                X_cat_b = X_cat[start:end].to(device, non_blocking=True)
+                X_geo_b = X_geo[start:end].to(device, non_blocking=True)
+                pred_chunk = self.ft(
+                    X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
+                preds.append(pred_chunk.cpu())
+
+        y_pred = torch.cat(preds, dim=0).numpy()
+
+        if return_embedding:
+            return y_pred
+
+        if self.task_type == 'classification':
+            # Convert logits to probabilities.
+            y_pred = 1 / (1 + np.exp(-y_pred))
+        else:
+            # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
+            y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.ravel()
+
+    def set_params(self, params: dict):
+
+        # Keep sklearn-style behavior.
+        # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
+
+        for key, value in params.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise ValueError(f"Parameter {key} not found in model.")
+        return self
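
End to end, the wrapper added in this file is used roughly as follows. The import path assumes the new models subpackage re-exports the class (see its __init__.py in the file list above); the column names, model name, and data frames are illustrative placeholders:

    import pandas as pd
    from ins_pricing.modelling.core.bayesopt.models import FTTransformerSklearn

    num_cols = ['veh_age', 'drv_age']
    cat_cols = ['region', 'fuel']
    # 'freq_ft' contains 'f', so the wrapper sets tw_power = 1.0 per the naming
    # convention in __init__.
    model = FTTransformerSklearn('freq_ft', num_cols, cat_cols,
                                 d_model=64, n_heads=8, epochs=50,
                                 task_type='regression')
    # X_train/y_train/w_train and the validation split are pandas objects
    # prepared upstream (placeholders here).
    model.fit(X_train, y_train, w_train,
              X_val=X_val, y_val=y_val, w_val=w_val)
    y_hat = model.predict(X_test)                          # 1-D numpy array
    emb = model.predict(X_test, return_embedding=True)     # pooled embeddings instead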