ins-pricing 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,2203 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import hashlib
5
+ import math
6
+ import os
7
+ import time
8
+ from contextlib import nullcontext
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import optuna
14
+ import pandas as pd
15
+ import torch
16
+ import torch.distributed as dist
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from sklearn.neighbors import NearestNeighbors
20
+ from torch.cuda.amp import autocast, GradScaler
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from torch.nn.utils import clip_grad_norm_
23
+ from torch.utils.data import Dataset, TensorDataset
24
+
25
+ from .utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
26
+
27
+ try:
28
+ from torch_geometric.nn import knn_graph
29
+ from torch_geometric.utils import add_self_loops, to_undirected
30
+ _PYG_AVAILABLE = True
31
+ except Exception:
32
+ knn_graph = None # type: ignore
33
+ add_self_loops = None # type: ignore
34
+ to_undirected = None # type: ignore
35
+ _PYG_AVAILABLE = False
36
+
37
+ try:
38
+ import pynndescent
39
+ _PYNN_AVAILABLE = True
40
+ except Exception:
41
+ pynndescent = None # type: ignore
42
+ _PYNN_AVAILABLE = False
43
+
44
+ _GNN_MPS_WARNED = False
45
+
46
+ # =============================================================================
47
+ # ResNet model and sklearn-style wrapper
48
+ # =============================================================================
49
+
50
+ # ResNet model definition
51
+ # Residual block: two linear layers + ReLU + residual connection
52
+ # ResBlock inherits nn.Module
53
class ResBlock(nn.Module):
    """Pre-activation residual block: norm -> fc -> ReLU -> dropout -> fc.

    The residual branch is multiplied by a learnable scale (stabilizes
    early training) and may be dropped per-sample via stochastic depth.
    """

    def __init__(self, dim: int, dropout: float = 0.1,
                 use_layernorm: bool = False, residual_scale: float = 0.1,
                 stochastic_depth: float = 0.0):
        super().__init__()
        self.use_layernorm = use_layernorm

        # LayerNorm normalizes the last dim; BatchNorm1d kept as a switch.
        norm_factory = nn.LayerNorm if use_layernorm else nn.BatchNorm1d

        self.norm1 = norm_factory(dim)
        self.fc1 = nn.Linear(dim, dim, bias=True)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Identity() if dropout <= 0.0 else nn.Dropout(dropout)
        # A second norm after fc2 could be enabled here: self.norm2 = norm_factory(dim)
        self.fc2 = nn.Linear(dim, dim, bias=True)

        # Learnable residual scaling to stabilize early training.
        self.res_scale = nn.Parameter(
            torch.tensor(residual_scale, dtype=torch.float32))
        self.stochastic_depth = max(0.0, float(stochastic_depth))

    def _drop_path(self, x: torch.Tensor) -> torch.Tensor:
        """Randomly zero the whole residual branch per sample (train only)."""
        if not self.training or self.stochastic_depth <= 0.0:
            return x
        keep_prob = 1.0 - self.stochastic_depth
        if keep_prob <= 0.0:
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast across feature dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        noise = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        keep_mask = torch.floor(noise + keep_prob)
        return x * keep_mask / keep_prob

    def forward(self, x):
        # Pre-activation ordering: norm first, then the two linear layers.
        branch = self.fc1(self.norm1(x))
        branch = self.dropout(self.act(branch))
        # If a second norm is enabled: branch = self.norm2(branch)
        branch = self.fc2(branch)
        # Scale the branch, optionally drop it, then add the skip connection.
        branch = self._drop_path(self.res_scale * branch)
        return x + branch
103
+
104
+ # ResNetSequential defines the full network
105
+
106
+
107
class ResNetSequential(nn.Module):
    """Fully-connected ResNet: input FC, stacked residual blocks, 1-unit head.

    Input shape: (batch, input_dim). Regression ends in Softplus (positive
    outputs for Tweedie/Gamma targets); classification emits raw logits.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1, stochastic_depth: float = 0.0,
                 task_type: str = 'regression'):
        super(ResNetSequential, self).__init__()

        self.net = nn.Sequential()
        self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))

        # Optional extra norm/activation after the first layer could be
        # inserted here (nn.LayerNorm / nn.BatchNorm1d / nn.ReLU).

        # Stochastic-depth rate grows linearly with the block index.
        max_drop = max(0.0, float(stochastic_depth))
        denom = block_num - 1 if block_num > 1 else None
        for idx in range(block_num):
            rate = max_drop if denom is None else max_drop * (idx / denom)
            self.net.add_module(
                f'ResBlk_{idx+1}',
                ResBlock(
                    hidden_dim,
                    dropout=dropout,
                    use_layernorm=use_layernorm,
                    residual_scale=residual_scale,
                    stochastic_depth=rate)
            )

        self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))

        # Head module name 'softplus' is kept in both branches for
        # state_dict compatibility; classification keeps logits.
        head = nn.Identity() if task_type == 'classification' else nn.Softplus()
        self.net.add_module('softplus', head)

    def forward(self, x):
        # One-time device log on the first training forward pass.
        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> ResNetSequential executing on device: {x.device}")
            self._printed_device = True
        return self.net(x)
158
+
159
+ # Define the ResNet sklearn-style wrapper.
160
+
161
+
162
+ class ResNetSklearn(TorchTrainerMixin, nn.Module):
163
    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 block_num: int = 2, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1,
                 stochastic_depth: float = 0.0,
                 weight_decay: float = 1e-4,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False):
        """Sklearn-style wrapper around ResNetSequential.

        Handles device selection (cuda > mps > cpu), optional DDP or
        DataParallel wrapping, and a Tweedie-power convention derived from
        ``model_nme`` (see below).
        """
        super(ResNetSklearn, self).__init__()

        self.use_ddp = use_ddp
        # Single-process defaults; overwritten below if DDP initializes.
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)

        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        # Hyper-parameters and bookkeeping.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.block_num = block_num
        self.batch_num = batch_num
        self.epochs = epochs
        self.task_type = task_type
        self.model_nme = model_nme
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.patience = patience
        self.use_layernorm = use_layernorm
        self.dropout = dropout
        self.residual_scale = residual_scale
        self.stochastic_depth = max(0.0, float(stochastic_depth))
        # Where fit() writes the loss-curve plot (None disables it).
        self.loss_curve_path: Optional[str] = None
        # Per-epoch train/val losses, filled in by fit().
        self.training_history: Dict[str, List[float]] = {
            "train": [], "val": []}
        self.use_data_parallel = bool(use_data_parallel)

        # Device selection: cuda > mps > cpu.
        if self.is_ddp_enabled:
            if torch.cuda.is_available():
                self.device = torch.device(f'cuda:{self.local_rank}')
            else:
                self.device = torch.device('cpu')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        # Tweedie power (unused for classification).
        # NOTE(review): these are substring tests on the model name — any 'f'
        # anywhere in model_nme selects power 1, any 's' selects power 2;
        # confirm the naming convention upstream.
        if self.task_type == 'classification':
            self.tw_power = None
        elif 'f' in self.model_nme:
            self.tw_power = 1
        elif 's' in self.model_nme:
            self.tw_power = 2
        else:
            self.tw_power = tweedie_power

        # Build network (construct on CPU first).
        core = ResNetSequential(
            self.input_dim,
            self.hidden_dim,
            self.block_num,
            use_layernorm=self.use_layernorm,
            dropout=self.dropout,
            residual_scale=self.residual_scale,
            stochastic_depth=self.stochastic_depth,
            task_type=self.task_type
        )

        # ===== Multi-GPU: DataParallel vs DistributedDataParallel =====
        if self.is_ddp_enabled:
            core = core.to(self.device)
            if self.device.type == 'cuda':
                core = DDP(core, device_ids=[self.local_rank], output_device=self.local_rank)
            else:
                # CPU/Gloo DDP
                core = DDP(core)
            self.use_data_parallel = False
        elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            if self.use_ddp and not self.is_ddp_enabled:
                print(
                    ">>> DDP requested but not initialized; falling back to DataParallel.")
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            # DataParallel scatters inputs, but the primary device remains cuda:0.
            self.device = torch.device('cuda')
            self.use_data_parallel = True
        else:
            self.use_data_parallel = False

        self.resnet = core.to(self.device)
258
+
259
+ # ================ Internal helpers ================
260
+ @staticmethod
261
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
262
+ if arr is None:
263
+ return
264
+ if isinstance(arr, pd.DataFrame):
265
+ if arr.shape[1] != 1:
266
+ raise ValueError(f"{name} must be 1d (single column).")
267
+ length = len(arr)
268
+ else:
269
+ arr_np = np.asarray(arr)
270
+ if arr_np.ndim == 0:
271
+ raise ValueError(f"{name} must be 1d.")
272
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
273
+ raise ValueError(f"{name} must be 1d or Nx1.")
274
+ length = arr_np.shape[0]
275
+ if length != n_rows:
276
+ raise ValueError(
277
+ f"{name} length {length} does not match X length {n_rows}."
278
+ )
279
+
280
+ def _validate_inputs(self, X, y, w, label: str) -> None:
281
+ if X is None:
282
+ raise ValueError(f"{label} X cannot be None.")
283
+ n_rows = len(X)
284
+ if y is None:
285
+ raise ValueError(f"{label} y cannot be None.")
286
+ self._validate_vector(y, f"{label} y", n_rows)
287
+ self._validate_vector(w, f"{label} w", n_rows)
288
+
289
+ def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
290
+ self._validate_inputs(X_train, y_train, w_train, "train")
291
+ if X_val is not None or y_val is not None or w_val is not None:
292
+ if X_val is None or y_val is None:
293
+ raise ValueError("validation X and y must both be provided.")
294
+ self._validate_inputs(X_val, y_val, w_val, "val")
295
+
296
+ def _to_numpy(arr):
297
+ if hasattr(arr, "to_numpy"):
298
+ return arr.to_numpy(dtype=np.float32, copy=False)
299
+ return np.asarray(arr, dtype=np.float32)
300
+
301
+ X_tensor = torch.as_tensor(_to_numpy(X_train))
302
+ y_tensor = torch.as_tensor(_to_numpy(y_train)).view(-1, 1)
303
+ w_tensor = (
304
+ torch.as_tensor(_to_numpy(w_train)).view(-1, 1)
305
+ if w_train is not None else torch.ones_like(y_tensor)
306
+ )
307
+
308
+ has_val = X_val is not None and y_val is not None
309
+ if has_val:
310
+ X_val_tensor = torch.as_tensor(_to_numpy(X_val))
311
+ y_val_tensor = torch.as_tensor(_to_numpy(y_val)).view(-1, 1)
312
+ w_val_tensor = (
313
+ torch.as_tensor(_to_numpy(w_val)).view(-1, 1)
314
+ if w_val is not None else torch.ones_like(y_val_tensor)
315
+ )
316
+ else:
317
+ X_val_tensor = y_val_tensor = w_val_tensor = None
318
+ return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val
319
+
320
+ def forward(self, x):
321
+ # Handle SHAP NumPy input.
322
+ if isinstance(x, np.ndarray):
323
+ x_tensor = torch.as_tensor(x, dtype=torch.float32)
324
+ else:
325
+ x_tensor = x
326
+
327
+ x_tensor = x_tensor.to(self.device)
328
+ y_pred = self.resnet(x_tensor)
329
+ return y_pred
330
+
331
+ # ---------------- Training ----------------
332
+
333
    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None, trial=None):
        """Train the network with mini-batch Adam and optional CUDA AMP.

        Parameters: training features/targets/weights (weights default to
        ones), an optional validation split (best validation state is
        restored afterwards), and an optional Optuna *trial* forwarded to
        ``_train_model`` for pruning. The epoch loop itself lives in the
        TorchTrainerMixin's ``_train_model``.
        """
        # Validate and convert all inputs to float32 tensors.
        X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
            self._build_train_val_tensors(
                X_train, y_train, w_train, X_val, y_val, w_val)

        dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
        # Batch-size / gradient-accumulation policy is chosen by the mixin
        # from dataset size and device type.
        dataloader, accum_steps = self._build_dataloader(
            dataset,
            N=X_tensor.shape[0],
            base_bs_gpu=(2048, 1024, 512),
            base_bs_cpu=(256, 128),
            min_bs=64,
            target_effective_cuda=2048,
            target_effective_cpu=1024
        )

        # Keep a handle to the DDP sampler so set_epoch can be called at the
        # start of each epoch to keep shuffling deterministic across ranks.
        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
            self.dataloader_sampler = dataloader.sampler
        else:
            self.dataloader_sampler = None

        # === 4. Optimizer and AMP (loss scaling only active on CUDA) ===
        self.optimizer = torch.optim.Adam(
            self.resnet.parameters(),
            lr=self.learning_rate,
            weight_decay=float(self.weight_decay),
        )
        self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        X_val_dev = y_val_dev = w_val_dev = None
        val_dataloader = None
        if has_val:
            # Build validation DataLoader.
            val_dataset = TensorDataset(
                X_val_tensor, y_val_tensor, w_val_tensor)
            # No backward pass in validation; batch size can be larger for throughput.
            val_dataloader = self._build_val_dataloader(
                val_dataset, dataloader, accum_steps)
            # Validation usually does not need a DDP sampler because we validate on the main process
            # or aggregate results. For simplicity, keep validation on a single GPU or the main process.

        is_data_parallel = isinstance(self.resnet, nn.DataParallel)

        def forward_fn(batch):
            # One training step's forward pass -> (pred, target, weight).
            X_batch, y_batch, w_batch = batch

            if not is_data_parallel:
                # DataParallel scatters inputs itself; otherwise move X here.
                X_batch = X_batch.to(self.device, non_blocking=True)
            # Keep targets and weights on the main device for loss computation.
            y_batch = y_batch.to(self.device, non_blocking=True)
            w_batch = w_batch.to(self.device, non_blocking=True)

            y_pred = self.resnet(X_batch)
            return y_pred, y_batch, w_batch

        def val_forward_fn():
            # Weighted mean validation loss over the whole loader.
            total_loss = 0.0
            total_weight = 0.0
            for batch in val_dataloader:
                X_b, y_b, w_b = batch
                if not is_data_parallel:
                    X_b = X_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

                y_pred = self.resnet(X_b)

                # Manually compute weighted loss for accurate aggregation.
                losses = self._compute_losses(
                    y_pred, y_b, apply_softplus=False)

                # Clamp avoids division blow-up when a batch's weights sum to ~0.
                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

                total_loss += batch_weighted_loss_sum.item()
                total_weight += batch_weight_sum.item()

            return total_loss / max(total_weight, EPS)

        clip_fn = None
        if self.device.type == 'cuda':
            # Unscale before clipping so max_norm applies to true gradients.
            def clip_fn(): return (self.scaler.unscale_(self.optimizer),
                                   clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))

        # Under DDP, only the main process prints logs and saves models.
        if self.is_ddp_enabled and not DistributedUtils.is_main_process():
            # Non-main processes skip validation callback logging (handled inside _train_model).
            pass

        best_state, history = self._train_model(
            self.resnet,
            dataloader,
            accum_steps,
            self.optimizer,
            self.scaler,
            forward_fn,
            val_forward_fn if has_val else None,
            apply_softplus=False,
            clip_fn=clip_fn,
            trial=trial,
            loss_curve_path=getattr(self, "loss_curve_path", None)
        )

        # Restore the best validation weights, then record the loss curves.
        if has_val and best_state is not None:
            self.resnet.load_state_dict(best_state)
        self.training_history = history
442
+
443
+ # ---------------- Prediction ----------------
444
+
445
+ def predict(self, X_test):
446
+ self.resnet.eval()
447
+ if isinstance(X_test, pd.DataFrame):
448
+ X_np = X_test.to_numpy(dtype=np.float32, copy=False)
449
+ else:
450
+ X_np = np.asarray(X_test, dtype=np.float32)
451
+
452
+ inference_cm = getattr(torch, "inference_mode", torch.no_grad)
453
+ with inference_cm():
454
+ y_pred = self(X_np).cpu().numpy()
455
+
456
+ if self.task_type == 'classification':
457
+ y_pred = 1 / (1 + np.exp(-y_pred)) # Sigmoid converts logits to probabilities.
458
+ else:
459
+ y_pred = np.clip(y_pred, 1e-6, None)
460
+ return y_pred.flatten()
461
+
462
+ # ---------------- Set Params ----------------
463
+
464
+ def set_params(self, params):
465
+ for key, value in params.items():
466
+ if hasattr(self, key):
467
+ setattr(self, key, value)
468
+ else:
469
+ raise ValueError(f"Parameter {key} not found in model.")
470
+ return self
471
+
472
+
473
+ # =============================================================================
474
+ # FT-Transformer model and sklearn-style wrapper.
475
+ # =============================================================================
476
+ # Define FT-Transformer model structure.
477
+
478
+
479
class FeatureTokenizer(nn.Module):
    """Map numeric/categorical/geo features into transformer input tokens.

    All numeric columns share one linear projection producing
    ``num_numeric_tokens`` tokens; each categorical column gets its own
    embedding token; geo features (already encoded upstream) get a single
    linear token.
    """

    def __init__(
        self,
        num_numeric: int,
        cat_cardinalities,
        d_model: int,
        num_geo: int = 0,
        num_numeric_tokens: int = 1,
    ):
        super().__init__()

        self.num_numeric = num_numeric
        self.num_geo = num_geo
        self.has_geo = num_geo > 0

        if num_numeric > 0:
            if int(num_numeric_tokens) <= 0:
                raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
            self.num_numeric_tokens = int(num_numeric_tokens)
            self.has_numeric = True
            # One projection fans out into num_numeric_tokens tokens.
            self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
        else:
            self.num_numeric_tokens = 0
            self.has_numeric = False

        self.embeddings = nn.ModuleList(
            nn.Embedding(card, d_model) for card in cat_cardinalities)

        if self.has_geo:
            # Map geo tokens with a linear layer to avoid one-hot on raw strings; upstream is encoded/normalized.
            self.geo_linear = nn.Linear(num_geo, d_model)

    def forward(self, X_num, X_cat, X_geo=None):
        """Return stacked tokens of shape (batch, n_tokens, d_model)."""
        token_list = []

        if self.has_numeric:
            n_rows = X_num.shape[0]
            projected = self.num_linear(X_num)
            token_list.append(
                projected.view(n_rows, self.num_numeric_tokens, -1))

        # One embedding token per categorical column.
        for col, emb in enumerate(self.embeddings):
            token_list.append(emb(X_cat[:, col]).unsqueeze(1))

        if self.has_geo:
            if X_geo is None:
                raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
            token_list.append(self.geo_linear(X_geo).unsqueeze(1))

        return torch.cat(token_list, dim=1)
535
+
536
+ # Encoder layer with residual scaling.
537
+
538
+
539
class ScaledTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer with scaled residual branches.

    Both the attention and feed-forward residual branches are multiplied
    by fixed coefficients, which helps deep stacks train stably. Supports
    pre-norm (default) and post-norm wiring. Inputs are
    (batch, seq_len, d_model) since attention uses batch_first=True.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, residual_scale_attn: float = 1.0,
                 residual_scale_ffn: float = 1.0, norm_first: bool = True,
                 ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Norms and per-branch dropouts.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()
        # If you prefer ReLU, set: self.activation = nn.ReLU()
        self.norm_first = norm_first

        # Residual scaling coefficients (plain floats, not parameters).
        self.res_scale_attn = residual_scale_attn
        self.res_scale_ffn = residual_scale_ffn

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode one layer; src shape (batch, seq_len, d_model) preserved."""
        if self.norm_first:
            # Pre-norm: normalize, transform, then add the skip.
            out = src + self._sa_block(self.norm1(src), src_mask,
                                       src_key_padding_mask)
            out = out + self._ff_block(self.norm2(out))
        else:
            # Post-norm (usually disabled).
            out = self.norm1(
                src + self._sa_block(src, src_mask, src_key_padding_mask))
            out = self.norm2(out + self._ff_block(out))
        return out

    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Self-attention branch, scaled before being added to the skip.
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )
        return self.res_scale_attn * self.dropout1(attn_out)

    def _ff_block(self, x):
        # Feed-forward branch, scaled before being added to the skip.
        hidden = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.res_scale_ffn * self.dropout2(hidden)
602
+
603
+ # FT-Transformer core model.
604
+
605
+
606
class FTTransformerCore(nn.Module):
    """Minimal FT-Transformer.

    Pipeline:
      1) FeatureTokenizer turns numeric/categorical/geo features into tokens;
      2) a stack of residual-scaled Transformer encoder layers models
         feature interactions;
      3) mean pooling + linear head (Softplus for regression) yields the
         prediction; classification emits raw logits for BCEWithLogitsLoss.

    Reconstruction heads for masked self-supervised pretraining are attached
    alongside the supervised head.
    """

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
                 n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
                 task_type: str = 'regression', num_geo: int = 0,
                 num_numeric_tokens: int = 1
                 ):
        super().__init__()

        self.num_numeric = int(num_numeric)
        self.cat_cardinalities = list(cat_cardinalities or [])

        self.tokenizer = FeatureTokenizer(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=d_model,
            num_geo=num_geo,
            num_numeric_tokens=num_numeric_tokens
        )

        # Scale each residual branch by 1/sqrt(n_layers) (recommended default).
        residual_scale = 1.0 / math.sqrt(n_layers)
        self.encoder = nn.TransformerEncoder(
            ScaledTransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                residual_scale_attn=residual_scale,
                residual_scale_ffn=residual_scale,
                norm_first=True,
            ),
            num_layers=n_layers,
        )
        self.n_layers = n_layers

        # Classification keeps raw logits; regression goes through Softplus so
        # outputs stay positive for Tweedie/Gamma losses.
        final_activation = (
            nn.Identity() if task_type == 'classification' else nn.Softplus())
        self.head = nn.Sequential(nn.Linear(d_model, 1), final_activation)

        # ---- Self-supervised reconstruction heads (masked modeling) ----
        self.num_recon_head = (
            nn.Linear(d_model, self.num_numeric)
            if self.num_numeric > 0 else None)
        self.cat_recon_heads = nn.ModuleList(
            nn.Linear(d_model, int(card)) for card in self.cat_cardinalities)

    def forward(
            self,
            X_num,
            X_cat,
            X_geo=None,
            return_embedding: bool = False,
            return_reconstruction: bool = False):
        """Run tokenize -> encode -> pool, then head/reconstruction as requested.

        Inputs:
          X_num -> float32 tensor, shape (batch, num_numeric_features)
          X_cat -> long tensor, shape (batch, num_categorical_features)
          X_geo -> float32 tensor, shape (batch, geo_token_dim)
        """
        # One-time diagnostic so training logs record the execution device.
        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> FTTransformerCore executing on device: {X_num.device}")
            self._printed_device = True

        tokens = self.tokenizer(X_num, X_cat, X_geo)   # (batch, tokens, d_model)
        encoded = self.encoder(tokens)                 # (batch, tokens, d_model)
        pooled = encoded.mean(dim=1)                   # (batch, d_model)

        if return_reconstruction:
            num_pred, cat_logits = self.reconstruct(pooled)
            cat_logits_out = (
                tuple() if cat_logits is None else tuple(cat_logits))
            if return_embedding:
                return pooled, num_pred, cat_logits_out
            return num_pred, cat_logits_out

        if return_embedding:
            return pooled

        # (batch, 1); Softplus in the head keeps regression output positive.
        return self.head(pooled)

    def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
        """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
        numeric_pred = (
            None if self.num_recon_head is None
            else self.num_recon_head(embedding))
        categorical_logits = [head(embedding) for head in self.cat_recon_heads]
        return numeric_pred, categorical_logits
715
+
716
+ # TabularDataset.
717
+
718
+
719
class TabularDataset(Dataset):
    """In-memory dataset of pre-built tensors.

    Expected layout (N = sample count):
      X_num: float32, shape (N, num_numeric_features)
      X_cat: long,    shape (N, num_categorical_features)
      X_geo: float32, shape (N, geo_token_dim); may have zero columns
      y:     float32, shape (N, 1)
      w:     float32, shape (N, 1)
    """

    def __init__(self, X_num, X_cat, X_geo, y, w):
        self.X_num = X_num
        self.X_cat = X_cat
        self.X_geo = X_geo
        self.y = y
        self.w = w

    def __len__(self):
        # Targets define the sample count.
        return self.y.shape[0]

    def __getitem__(self, idx):
        row = (self.X_num[idx], self.X_cat[idx], self.X_geo[idx],
               self.y[idx], self.w[idx])
        return row
746
+
747
+
748
class MaskedTabularDataset(Dataset):
    """Dataset for masked-reconstruction pretraining.

    Each item yields the masked inputs plus the (optional) ground truth and
    boolean masks. Any of the truth/mask tensors may be None when the
    corresponding modality is absent; None is passed through unchanged.
    """

    def __init__(self,
                 X_num_masked: torch.Tensor,
                 X_cat_masked: torch.Tensor,
                 X_geo: torch.Tensor,
                 X_num_true: Optional[torch.Tensor],
                 num_mask: Optional[torch.Tensor],
                 X_cat_true: Optional[torch.Tensor],
                 cat_mask: Optional[torch.Tensor]):
        self.X_num_masked = X_num_masked
        self.X_cat_masked = X_cat_masked
        self.X_geo = X_geo
        self.X_num_true = X_num_true
        self.num_mask = num_mask
        self.X_cat_true = X_cat_true
        self.cat_mask = cat_mask

    def __len__(self):
        return self.X_num_masked.shape[0]

    def __getitem__(self, idx):
        def _pick(tensor):
            # Keep None for absent modalities; index otherwise.
            return None if tensor is None else tensor[idx]

        return (
            self.X_num_masked[idx],
            self.X_cat_masked[idx],
            self.X_geo[idx],
            _pick(self.X_num_true),
            _pick(self.num_mask),
            _pick(self.X_cat_true),
            _pick(self.cat_mask),
        )
778
+
779
+ # Scikit-Learn style wrapper for FTTransformer.
780
+
781
+
782
+ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
783
+
784
+ # sklearn-style wrapper:
785
+ # - num_cols: numeric feature column names
786
+ # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
787
+
788
+ @staticmethod
789
+ def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
790
+ num_cols_count = len(num_cols or [])
791
+ if num_cols_count == 0:
792
+ return 0
793
+ if requested is not None:
794
+ count = int(requested)
795
+ if count <= 0:
796
+ raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
797
+ return count
798
+ return max(1, num_cols_count)
799
+
800
def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
             n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
             task_type: str = 'regression',
             tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
             weight_decay: float = 0.0,
             use_data_parallel: bool = True,
             use_ddp: bool = False,
             num_numeric_tokens: Optional[int] = None
             ):
    """Store hyper-parameters and training configuration.

    The underlying FTTransformerCore is built lazily (see `_build_model`)
    on the first `fit`/`fit_unsupervised` call, once the training data's
    categorical cardinalities are known.

    Args:
        model_nme: Model identifier. Also drives the Tweedie-power
            heuristic below: a name containing 'f' -> p=1.0 (Poisson-like),
            containing 's' -> p=2.0 (Gamma-like), else `tweedie_power`.
        num_cols: Numeric feature column names.
        cat_cols: Categorical feature column names.
        d_model / n_heads / n_layers / dropout: Transformer dimensions.
        batch_num, epochs, learning_rate, patience, weight_decay: training
            schedule and optimizer settings.
        task_type: 'regression' or 'classification'.
        use_data_parallel: Allow multi-GPU nn.DataParallel fallback.
        use_ddp: Request DistributedDataParallel initialization.
        num_numeric_tokens: Optional explicit numeric-token count; resolved
            via `resolve_numeric_token_count`.
    """
    super().__init__()

    self.use_ddp = use_ddp
    # Single-process defaults; overwritten by setup_ddp() when DDP is on.
    self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
        False, 0, 0, 1)
    if self.use_ddp:
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

    self.model_nme = model_nme
    self.num_cols = list(num_cols)
    self.cat_cols = list(cat_cols)
    self.num_numeric_tokens = self.resolve_numeric_token_count(
        self.num_cols,
        self.cat_cols,
        num_numeric_tokens,
    )
    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.dropout = dropout
    self.batch_num = batch_num
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.task_type = task_type
    self.patience = patience
    if self.task_type == 'classification':
        self.tw_power = None  # No Tweedie power for classification.
    elif 'f' in self.model_nme:
        # Substring heuristic: 'f' (frequency) models -> Poisson deviance.
        self.tw_power = 1.0
    elif 's' in self.model_nme:
        # Substring heuristic: 's' (severity) models -> Gamma deviance.
        self.tw_power = 2.0
    else:
        self.tw_power = tweedie_power

    if self.is_ddp_enabled:
        # Allow CPU DDP (e.g. gloo) if CUDA is not available
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.device = torch.device("cpu")
    # NOTE(review): in the non-DDP path, self.device is not assigned here;
    # presumably TorchTrainerMixin supplies it -- confirm against the mixin.
    self.cat_cardinalities = None
    self.cat_categories = {}
    self.cat_maps: Dict[str, Dict[Any, int]] = {}
    self.cat_str_maps: Dict[str, Dict[str, int]] = {}
    self._num_mean = None   # Per-feature training means for normalization.
    self._num_std = None    # Per-feature training stds for normalization.
    self.ft = None          # Lazily built torch module (see _build_model).
    self.use_data_parallel = bool(use_data_parallel)
    self.num_geo = 0
    self._geo_params: Dict[str, Any] = {}
    self.loss_curve_path: Optional[str] = None
    self.training_history: Dict[str, List[float]] = {
        "train": [], "val": []}
863
+
864
def _build_model(self, X_train):
    """Build the FTTransformerCore from training data and wrap for multi-GPU.

    Side effects: fits numeric normalization stats (`_num_mean`/`_num_std`),
    fits categorical encoders (`cat_categories`/`cat_maps`/`cat_str_maps`),
    sets `cat_cardinalities`, may rewrite `self.device` and
    `self.use_data_parallel`, and assigns the final module to `self.ft`.
    """
    num_numeric = len(self.num_cols)
    cat_cardinalities = []

    if num_numeric > 0:
        # Fit per-feature mean/std on sanitized training data; guard tiny
        # stds so later normalization never divides by ~0.
        num_arr = X_train[self.num_cols].to_numpy(
            dtype=np.float32, copy=False)
        num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
        mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
        std = num_arr.std(axis=0).astype(np.float32, copy=False)
        std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
        self._num_mean = mean
        self._num_std = std
    else:
        self._num_mean = None
        self._num_std = None

    self.cat_maps = {}
    self.cat_str_maps = {}
    for col in self.cat_cols:
        cats = X_train[col].astype('category')
        categories = cats.cat.categories
        self.cat_categories[col] = categories  # Store full category list from training.
        self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
        # Keep a string-keyed map too so string-typed inference data can be
        # matched against the training categories.
        if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
            self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}

        card = len(categories) + 1  # Reserve one extra class for unknown/missing.
        cat_cardinalities.append(card)

    self.cat_cardinalities = cat_cardinalities

    core = FTTransformerCore(
        num_numeric=num_numeric,
        cat_cardinalities=cat_cardinalities,
        d_model=self.d_model,
        n_heads=self.n_heads,
        n_layers=self.n_layers,
        dropout=self.dropout,
        task_type=self.task_type,
        num_geo=self.num_geo,
        num_numeric_tokens=self.num_numeric_tokens
    )
    # DataParallel only makes sense on multi-GPU CUDA hosts.
    use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
    if self.is_ddp_enabled:
        core = core.to(self.device)
        if self.device.type == 'cuda':
            core = DDP(core, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
        else:
            # CPU/Gloo DDP
            core = DDP(core, find_unused_parameters=True)
        self.use_data_parallel = False
    elif use_dp:
        if self.use_ddp and not self.is_ddp_enabled:
            print(
                ">>> DDP requested but not initialized; falling back to DataParallel.")
        core = nn.DataParallel(core, device_ids=list(
            range(torch.cuda.device_count())))
        # DataParallel expects the master replica on the default CUDA device.
        self.device = torch.device("cuda")
        self.use_data_parallel = True
    else:
        self.use_data_parallel = False
    self.ft = core.to(self.device)
927
+
928
+ def _encode_cats(self, X):
929
+ # Input DataFrame must include all categorical feature columns.
930
+ # Return int64 array with shape (N, num_categorical_features).
931
+
932
+ if not self.cat_cols:
933
+ return np.zeros((len(X), 0), dtype='int64')
934
+
935
+ n_rows = len(X)
936
+ n_cols = len(self.cat_cols)
937
+ X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
938
+ for idx, col in enumerate(self.cat_cols):
939
+ categories = self.cat_categories[col]
940
+ mapping = self.cat_maps.get(col)
941
+ if mapping is None:
942
+ mapping = {cat: i for i, cat in enumerate(categories)}
943
+ self.cat_maps[col] = mapping
944
+ unknown_idx = len(categories)
945
+ series = X[col]
946
+ codes = series.map(mapping)
947
+ unmapped = series.notna() & codes.isna()
948
+ if unmapped.any():
949
+ try:
950
+ series_cast = series.astype(categories.dtype)
951
+ except Exception:
952
+ series_cast = None
953
+ if series_cast is not None:
954
+ codes = series_cast.map(mapping)
955
+ unmapped = series_cast.notna() & codes.isna()
956
+ if unmapped.any():
957
+ str_map = self.cat_str_maps.get(col)
958
+ if str_map is None:
959
+ str_map = {str(cat): i for i, cat in enumerate(categories)}
960
+ self.cat_str_maps[col] = str_map
961
+ codes = series.astype(str).map(str_map)
962
+ if pd.api.types.is_categorical_dtype(codes):
963
+ codes = codes.astype("float")
964
+ codes = codes.fillna(unknown_idx).astype(
965
+ "int64", copy=False).to_numpy()
966
+ X_cat_np[:, idx] = codes
967
+ return X_cat_np
968
+
969
+ def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
970
+ return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
971
+
972
+ def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
973
+ return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
974
+
975
+ @staticmethod
976
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
977
+ if arr is None:
978
+ return
979
+ if isinstance(arr, pd.DataFrame):
980
+ if arr.shape[1] != 1:
981
+ raise ValueError(f"{name} must be 1d (single column).")
982
+ length = len(arr)
983
+ else:
984
+ arr_np = np.asarray(arr)
985
+ if arr_np.ndim == 0:
986
+ raise ValueError(f"{name} must be 1d.")
987
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
988
+ raise ValueError(f"{name} must be 1d or Nx1.")
989
+ length = arr_np.shape[0]
990
+ if length != n_rows:
991
+ raise ValueError(
992
+ f"{name} length {length} does not match X length {n_rows}."
993
+ )
994
+
995
def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
    """Convert one data split into model-ready tensors.

    Returns a 6-tuple
        (X_num, X_cat, X_geo, y_tensor, w_tensor, has_target)
    where the first three are always tensors (possibly with zero columns)
    and y/w are None when no target is supplied. When `allow_none` is True
    and X is None, returns all-None with has_target=False.

    Raises:
        ValueError: X is None (and not allowed), not a DataFrame, missing
            required columns, or y/w/geo_tokens have mismatched lengths.
        RuntimeError: the model was built with geo tokens (`num_geo > 0`)
            but none were supplied for this split.
    """
    if X is None:
        if allow_none:
            return None, None, None, None, None, False
        raise ValueError("Input features X must not be None.")
    if not isinstance(X, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame.")
    missing_cols = [
        col for col in (self.num_cols + self.cat_cols) if col not in X.columns
    ]
    if missing_cols:
        raise ValueError(f"X is missing required columns: {missing_cols}")
    n_rows = len(X)
    if y is not None:
        self._validate_vector(y, "y", n_rows)
    if w is not None:
        self._validate_vector(w, "w", n_rows)

    num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
    # to_numpy(copy=False) may return a view into the DataFrame; copy before
    # the in-place nan_to_num below so we never mutate caller data.
    if not num_np.flags["OWNDATA"]:
        num_np = num_np.copy()
    num_np = np.nan_to_num(num_np, nan=0.0,
                           posinf=0.0, neginf=0.0, copy=False)
    # Apply training normalization stats when they were fitted.
    if self._num_mean is not None and self._num_std is not None and num_np.size:
        num_np = (num_np - self._num_mean) / self._num_std
    X_num = torch.as_tensor(num_np)
    if self.cat_cols:
        X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
    else:
        X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)

    if geo_tokens is not None:
        geo_np = np.asarray(geo_tokens, dtype=np.float32)
        if geo_np.shape[0] != n_rows:
            raise ValueError(
                "geo_tokens length does not match X rows.")
        if geo_np.ndim == 1:
            # Promote a flat vector to a single-column matrix.
            geo_np = geo_np.reshape(-1, 1)
    elif self.num_geo > 0:
        raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
    else:
        geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
    X_geo = torch.as_tensor(geo_np)

    # Targets/weights are reshaped to (N, 1) column vectors.
    y_tensor = torch.as_tensor(
        y.to_numpy(dtype=np.float32, copy=False) if hasattr(
            y, "to_numpy") else np.asarray(y, dtype=np.float32)
    ).view(-1, 1) if y is not None else None
    if y_tensor is None:
        w_tensor = None
    elif w is not None:
        w_tensor = torch.as_tensor(
            w.to_numpy(dtype=np.float32, copy=False) if hasattr(
                w, "to_numpy") else np.asarray(w, dtype=np.float32)
        ).view(-1, 1)
    else:
        # Missing weights default to uniform ones.
        w_tensor = torch.ones_like(y_tensor)
    return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
1053
+
1054
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None, trial=None,
        geo_train=None, geo_val=None):
    """Supervised training with optional validation-based early stopping.

    Args:
        X_train: DataFrame containing all numeric/categorical columns.
        y_train: Target vector aligned with X_train.
        w_train: Optional sample weights (defaults to ones downstream).
        X_val/y_val/w_val: Optional validation split; enables per-epoch
            validation and restoring the best state.
        trial: Optional Optuna trial, forwarded to `_train_model` for pruning.
        geo_train/geo_val: Optional geo-token arrays aligned with the rows.

    Side effects: builds `self.ft` on first call, stores the epoch loss
    history in `self.training_history`.
    """
    # Build the underlying model on first fit.
    self.num_geo = geo_train.shape[1] if geo_train is not None else 0
    if self.ft is None:
        self._build_model(X_train)

    X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
        X_train, y_train, w_train, geo_train=geo_train)
    X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
        X_val, y_val, w_val, geo_val=geo_val)

    # --- Build DataLoader ---
    dataset = TabularDataset(
        X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
    )

    # Mixin picks batch size / gradient-accumulation steps from data size
    # and the available hardware.
    dataloader, accum_steps = self._build_dataloader(
        dataset,
        N=X_num_train.shape[0],
        base_bs_gpu=(2048, 1024, 512),
        base_bs_cpu=(256, 128),
        min_bs=64,
        target_effective_cuda=2048,
        target_effective_cpu=1024
    )

    # Remember the distributed sampler so epochs can be reshuffled per rank.
    if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
        self.dataloader_sampler = dataloader.sampler
    else:
        self.dataloader_sampler = None

    optimizer = torch.optim.Adam(
        self.ft.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    # AMP loss scaling only on CUDA; a disabled scaler is a no-op on CPU.
    scaler = GradScaler(enabled=(self.device.type == 'cuda'))

    # NOTE(review): these placeholders are never read in this method.
    X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
    val_dataloader = None
    if has_val:
        val_dataset = TabularDataset(
            X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
        )
        val_dataloader = self._build_val_dataloader(
            val_dataset, dataloader, accum_steps)

    # DataParallel scatters inputs itself, so skip per-batch .to(device).
    is_data_parallel = isinstance(self.ft, nn.DataParallel)

    def forward_fn(batch):
        # One training step's forward pass; returns (pred, target, weight).
        X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch

        if not is_data_parallel:
            X_num_b = X_num_b.to(self.device, non_blocking=True)
            X_cat_b = X_cat_b.to(self.device, non_blocking=True)
            X_geo_b = X_geo_b.to(self.device, non_blocking=True)
            y_b = y_b.to(self.device, non_blocking=True)
            w_b = w_b.to(self.device, non_blocking=True)

        y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
        return y_pred, y_b, w_b

    def val_forward_fn():
        # Weighted mean validation loss over the whole validation loader.
        total_loss = 0.0
        total_weight = 0.0
        for batch in val_dataloader:
            X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
            if not is_data_parallel:
                X_num_b = X_num_b.to(self.device, non_blocking=True)
                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                X_geo_b = X_geo_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)

            # Manually compute validation loss.
            losses = self._compute_losses(
                y_pred, y_b, apply_softplus=False)

            batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
            batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

            total_loss += batch_weighted_loss_sum.item()
            total_weight += batch_weight_sum.item()

        return total_loss / max(total_weight, EPS)

    # Gradient clipping needs gradients unscaled first under AMP.
    clip_fn = None
    if self.device.type == 'cuda':
        def clip_fn(): return (scaler.unscale_(optimizer),
                               clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

    best_state, history = self._train_model(
        self.ft,
        dataloader,
        accum_steps,
        optimizer,
        scaler,
        forward_fn,
        val_forward_fn if has_val else None,
        apply_softplus=False,
        clip_fn=clip_fn,
        trial=trial,
        loss_curve_path=getattr(self, "loss_curve_path", None)
    )

    # Restore the best validation checkpoint when one was recorded.
    if has_val and best_state is not None:
        self.ft.load_state_dict(best_state)
    self.training_history = history
1167
+
1168
def fit_unsupervised(self,
                     X_train,
                     X_val=None,
                     trial: Optional[optuna.trial.Trial] = None,
                     geo_train=None,
                     geo_val=None,
                     mask_prob_num: float = 0.15,
                     mask_prob_cat: float = 0.15,
                     num_loss_weight: float = 1.0,
                     cat_loss_weight: float = 1.0) -> float:
    """Self-supervised pretraining via masked reconstruction (supports raw string categories).

    Randomly masks numeric cells (replaced with the column mean) and
    categorical cells (replaced with the reserved unknown index), then
    trains the reconstruction heads to recover the masked values
    (MSE for numeric, cross-entropy for categorical, both restricted to
    masked positions).

    Returns the best validation loss when a validation split is provided,
    otherwise the last epoch's training loss. Raises optuna.TrialPruned on
    pruning or non-finite validation loss when a `trial` is supplied.
    """
    self.num_geo = geo_train.shape[1] if geo_train is not None else 0
    if self.ft is None:
        self._build_model(X_train)

    X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
        X_train, None, None, geo_tokens=geo_train, allow_none=True)
    has_val = X_val is not None
    if has_val:
        X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
            X_val, None, None, geo_tokens=geo_val, allow_none=True)
    else:
        X_num_val = X_cat_val = X_geo_val = None

    N = int(X_num.shape[0])
    num_dim = int(X_num.shape[1])
    cat_dim = int(X_cat.shape[1])
    device_type = self._device_type()

    # Rank-dependent seed so DDP workers mask different cells.
    gen = torch.Generator()
    gen.manual_seed(13 + int(getattr(self, "rank", 0)))

    # Unwrap DDP/DataParallel to reach the core module's attributes.
    base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
    cardinals = getattr(base_model, "cat_cardinalities", None) or []
    # card - 1 is the reserved unknown/missing class for each column.
    unknown_idx = torch.tensor(
        [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)

    means = None
    if num_dim > 0:
        # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
        means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)

    def _mask_inputs(X_num_in: torch.Tensor,
                     X_cat_in: torch.Tensor,
                     generator: torch.Generator):
        # Returns (masked_num, masked_cat, num_mask, cat_mask); masks are
        # None when the corresponding modality has zero columns.
        n_rows = int(X_num_in.shape[0])
        num_mask_local = None
        cat_mask_local = None
        X_num_masked_local = X_num_in
        X_cat_masked_local = X_cat_in
        if num_dim > 0:
            num_mask_local = (torch.rand(
                (n_rows, num_dim), generator=generator) < float(mask_prob_num))
            X_num_masked_local = X_num_in.clone()
            if num_mask_local.any():
                # Masked numeric cells are filled with the column mean.
                X_num_masked_local[num_mask_local] = means.expand_as(
                    X_num_masked_local)[num_mask_local]
        if cat_dim > 0:
            cat_mask_local = (torch.rand(
                (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
            X_cat_masked_local = X_cat_in.clone()
            if cat_mask_local.any():
                # Masked categorical cells become the unknown class.
                X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
                    X_cat_masked_local)[cat_mask_local]
        return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local

    X_num_true = X_num if num_dim > 0 else None
    X_cat_true = X_cat if cat_dim > 0 else None
    # Training masks are drawn once up front (fixed for all epochs).
    X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
        X_num, X_cat, gen)

    dataset = MaskedTabularDataset(
        X_num_masked, X_cat_masked, X_geo,
        X_num_true, num_mask,
        X_cat_true, cat_mask
    )
    dataloader, accum_steps = self._build_dataloader(
        dataset,
        N=N,
        base_bs_gpu=(2048, 1024, 512),
        base_bs_cpu=(256, 128),
        min_bs=64,
        target_effective_cuda=2048,
        target_effective_cpu=1024
    )
    if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
        self.dataloader_sampler = dataloader.sampler
    else:
        self.dataloader_sampler = None

    optimizer = torch.optim.Adam(
        self.ft.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    scaler = GradScaler(enabled=(device_type == 'cuda'))

    def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
        # Masked-position reconstruction loss: weighted MSE (numeric) plus
        # weighted mean cross-entropy (categorical). Zero when nothing is
        # masked in the batch.
        loss = torch.zeros((), device=device, dtype=torch.float32)

        if num_pred is not None and num_true_b is not None and num_mask_b is not None:
            num_mask_b = num_mask_b.to(dtype=torch.bool)
            if num_mask_b.any():
                diff = num_pred - num_true_b
                mse = diff * diff
                loss = loss + float(num_loss_weight) * \
                    mse[num_mask_b].mean()

        if cat_logits and cat_true_b is not None and cat_mask_b is not None:
            cat_mask_b = cat_mask_b.to(dtype=torch.bool)
            cat_losses: List[torch.Tensor] = []
            for j, logits in enumerate(cat_logits):
                mask_j = cat_mask_b[:, j]
                if not mask_j.any():
                    continue
                targets = cat_true_b[:, j]
                cat_losses.append(
                    F.cross_entropy(logits, targets, reduction='none')[
                        mask_j].mean()
                )
            if cat_losses:
                loss = loss + float(cat_loss_weight) * \
                    torch.stack(cat_losses).mean()
        return loss

    train_history: List[float] = []
    val_history: List[float] = []
    best_loss = float("inf")
    best_state = None
    patience_counter = 0
    is_ddp_model = isinstance(self.ft, DDP)

    # Under AMP, unscale gradients before clipping.
    clip_fn = None
    if self.device.type == 'cuda':
        def clip_fn(): return (scaler.unscale_(optimizer),
                               clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

    for epoch in range(1, int(self.epochs) + 1):
        if self.dataloader_sampler is not None:
            self.dataloader_sampler.set_epoch(epoch)

        self.ft.train()
        optimizer.zero_grad()
        epoch_loss_sum = 0.0
        epoch_count = 0.0

        for step, batch in enumerate(dataloader):
            # Optimizer steps happen every accum_steps batches (and on the
            # final batch); DDP gradient sync is skipped on non-update steps.
            is_update_step = ((step + 1) % accum_steps == 0) or \
                ((step + 1) == len(dataloader))
            sync_cm = self.ft.no_sync if (
                is_ddp_model and not is_update_step) else nullcontext
            with sync_cm():
                with autocast(enabled=(device_type == 'cuda')):
                    X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
                    X_num_b = X_num_b.to(self.device, non_blocking=True)
                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                    X_geo_b = X_geo_b.to(self.device, non_blocking=True)
                    num_true_b = None if num_true_b is None else num_true_b.to(
                        self.device, non_blocking=True)
                    num_mask_b = None if num_mask_b is None else num_mask_b.to(
                        self.device, non_blocking=True)
                    cat_true_b = None if cat_true_b is None else cat_true_b.to(
                        self.device, non_blocking=True)
                    cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
                        self.device, non_blocking=True)

                    num_pred, cat_logits = self.ft(
                        X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
                    batch_loss = _batch_recon_loss(
                        num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
                # All ranks must agree a loss is bad before anyone raises,
                # otherwise DDP collectives would deadlock.
                local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                global_bad = local_bad
                if dist.is_initialized():
                    bad = torch.tensor(
                        [local_bad],
                        device=batch_loss.device,
                        dtype=torch.int32,
                    )
                    dist.all_reduce(bad, op=dist.ReduceOp.MAX)
                    global_bad = int(bad.item())

                if global_bad:
                    msg = (
                        f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
                        f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
                    )
                    should_log = (not dist.is_initialized()
                                  or DistributedUtils.is_main_process())
                    if should_log:
                        print(msg, flush=True)
                        print(
                            f"  X_num: finite={bool(torch.isfinite(X_num_b).all())} "
                            f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
                            f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
                            flush=True,
                        )
                        if X_geo_b is not None:
                            print(
                                f"  X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
                                f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
                                f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
                                flush=True,
                            )
                    if trial is not None:
                        raise optuna.TrialPruned(msg)
                    raise RuntimeError(msg)
                # Normalize so accumulated gradients match the full batch.
                loss_for_backward = batch_loss / float(accum_steps)
                scaler.scale(loss_for_backward).backward()

            if is_update_step:
                if clip_fn is not None:
                    clip_fn()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            epoch_loss_sum += float(batch_loss.detach().item()) * \
                float(X_num_b.shape[0])
            epoch_count += float(X_num_b.shape[0])

        train_history.append(epoch_loss_sum / max(epoch_count, 1.0))

        if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
            # Validation runs on rank 0 only; the result is broadcast so all
            # ranks agree on early stopping / pruning decisions.
            should_compute_val = (not dist.is_initialized()
                                  or DistributedUtils.is_main_process())
            loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
                "cpu")
            val_loss_tensor = torch.zeros(1, device=loss_tensor_device)

            if should_compute_val:
                self.ft.eval()
                with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
                    val_bs = min(
                        int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
                    total_val = 0.0
                    total_n = 0.0
                    for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
                        end = min(
                            int(X_num_val.shape[0]), start + max(1, val_bs))
                        X_num_v_true_cpu = X_num_val[start:end]
                        X_cat_v_true_cpu = X_cat_val[start:end]
                        X_geo_v = X_geo_val[start:end].to(
                            self.device, non_blocking=True)
                        # Deterministic per-(epoch, chunk) masking so the
                        # validation loss is reproducible.
                        gen_val = torch.Generator()
                        gen_val.manual_seed(10_000 + epoch + start)
                        X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
                            X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
                        X_num_v_true = X_num_v_true_cpu.to(
                            self.device, non_blocking=True)
                        X_cat_v_true = X_cat_v_true_cpu.to(
                            self.device, non_blocking=True)
                        X_num_v = X_num_v_cpu.to(
                            self.device, non_blocking=True)
                        X_cat_v = X_cat_v_cpu.to(
                            self.device, non_blocking=True)
                        val_num_mask = None if val_num_mask is None else val_num_mask.to(
                            self.device, non_blocking=True)
                        val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
                            self.device, non_blocking=True)
                        num_pred_v, cat_logits_v = self.ft(
                            X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
                        loss_v = _batch_recon_loss(
                            num_pred_v, cat_logits_v,
                            X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                            X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
                            device=X_num_v.device
                        )
                        if not torch.isfinite(loss_v):
                            total_val = float("inf")
                            total_n = 1.0
                            break
                        total_val += float(loss_v.detach().item()
                                           ) * float(end - start)
                        total_n += float(end - start)
                    val_loss_tensor[0] = total_val / max(total_n, 1.0)

            if dist.is_initialized():
                dist.broadcast(val_loss_tensor, src=0)
            val_loss_value = float(val_loss_tensor.item())
            prune_now = False
            prune_msg = None
            if not np.isfinite(val_loss_value):
                prune_now = True
                prune_msg = (
                    f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
                    f"(epoch={epoch}, val_loss={val_loss_value})"
                )
            val_history.append(val_loss_value)

            if val_loss_value < best_loss:
                best_loss = val_loss_value
                # Deep-copy the checkpoint so later training can't mutate it.
                best_state = {
                    k: (v.clone() if isinstance(
                        v, torch.Tensor) else copy.deepcopy(v))
                    for k, v in self.ft.state_dict().items()
                }
                patience_counter = 0
            else:
                patience_counter += 1
                if best_state is not None and patience_counter >= int(self.patience):
                    break

            if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
                trial.report(val_loss_value, epoch)
                if trial.should_prune():
                    prune_now = True

            # Broadcast the prune decision so every rank raises together.
            if dist.is_initialized():
                flag = torch.tensor(
                    [1 if prune_now else 0],
                    device=loss_tensor_device,
                    dtype=torch.int32,
                )
                dist.broadcast(flag, src=0)
                prune_now = bool(flag.item())

            if prune_now:
                if prune_msg:
                    raise optuna.TrialPruned(prune_msg)
                raise optuna.TrialPruned()

    self.training_history = {"train": train_history, "val": val_history}
    self._plot_loss_curve(self.training_history, getattr(
        self, "loss_curve_path", None))
    if has_val and best_state is not None:
        self.ft.load_state_dict(best_state)
    return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
1495
+
1496
def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
    """Batched inference over ``X_test``.

    Parameters
    ----------
    X_test : feature table containing all numeric/categorical columns used at fit time.
    geo_tokens : optional geo-token input forwarded to ``_tensorize_split``.
    batch_size : explicit inference batch size; when None a size is estimated
        from the model width to avoid attention OOM.
    return_embedding : when True, return the backbone's raw output per chunk
        without the classification/regression post-processing.

    Returns
    -------
    numpy array: 1-d predictions (sigmoid probabilities for classification,
    values clipped to >= 1e-6 for regression), or the concatenated embedding
    array when ``return_embedding`` is True.
    """
    # X_test must include all numeric/categorical columns; geo_tokens is optional.

    self.ft.eval()
    X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
        X_test, None, None, geo_tokens=geo_tokens, allow_none=True)

    num_rows = X_num.shape[0]
    if num_rows == 0:
        # Empty input: return an empty float32 vector instead of failing downstream.
        return np.empty(0, dtype=np.float32)

    device = self.device if isinstance(
        self.device, torch.device) else torch.device(self.device)

    def resolve_batch_size(n_rows: int) -> int:
        # Honor an explicit batch size, clamped to [1, n_rows].
        if batch_size is not None:
            return max(1, min(int(batch_size), n_rows))
        # Estimate a safe batch size based on model size to avoid attention OOM.
        token_cnt = self.num_numeric_tokens + len(self.cat_cols)
        if self.num_geo > 0:
            token_cnt += 1
        approx_units = max(1, token_cnt * max(1, self.d_model))
        if device.type == 'cuda':
            # Wider models (more tokens * d_model) get smaller batches.
            if approx_units >= 8192:
                base = 512
            elif approx_units >= 4096:
                base = 1024
            else:
                base = 2048
        else:
            base = 512
        return max(1, min(base, n_rows))

    eff_batch = resolve_batch_size(num_rows)
    preds: List[torch.Tensor] = []

    # Prefer torch.inference_mode when available; fall back to no_grad.
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        for start in range(0, num_rows, eff_batch):
            end = min(num_rows, start + eff_batch)
            X_num_b = X_num[start:end].to(device, non_blocking=True)
            X_cat_b = X_cat[start:end].to(device, non_blocking=True)
            X_geo_b = X_geo[start:end].to(device, non_blocking=True)
            pred_chunk = self.ft(
                X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
            # Move each chunk back to CPU immediately to bound GPU memory use.
            preds.append(pred_chunk.cpu())

    y_pred = torch.cat(preds, dim=0).numpy()

    if return_embedding:
        return y_pred

    if self.task_type == 'classification':
        # Convert logits to probabilities.
        y_pred = 1 / (1 + np.exp(-y_pred))
    else:
        # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
        y_pred = np.clip(y_pred, 1e-6, None)
    return y_pred.ravel()
1555
+
1556
def set_params(self, params: dict):
    """Update known hyperparameters in place and return self (sklearn-style).

    Raises ValueError for any key that is not an existing attribute.
    Structural settings (e.g. d_model/n_heads) only take effect after a refit.
    """
    for key, new_value in params.items():
        if not hasattr(self, key):
            raise ValueError(f"Parameter {key} not found in model.")
        setattr(self, key, new_value)
    return self
1567
+
1568
+
1569
+ # =============================================================================
1570
+ # Simplified GNN implementation.
1571
+ # =============================================================================
1572
+
1573
+
1574
class SimpleGraphLayer(nn.Module):
    """One GCN-style layer: aggregate neighbours via a sparse adjacency,
    then project, apply ReLU and (optional) dropout."""

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        # Dropout disabled entirely when rate is zero.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # A_hat @ X, then W projection, ReLU, dropout.
        aggregated = torch.sparse.mm(adj, x)
        return self.dropout(self.activation(self.linear(aggregated)))
1587
+
1588
+
1589
class SimpleGNN(nn.Module):
    """Stack of SimpleGraphLayer blocks with a single-unit output head.

    Regression uses a Softplus head (positive outputs); classification emits
    raw logits (Identity head). The normalized adjacency may be supplied per
    call or pre-installed in ``adj_buffer`` (DataParallel path).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        super().__init__()
        layers = []
        dim_in = input_dim
        # Always at least one layer; each maps dim_in -> hidden_dim.
        for _ in range(max(1, num_layers)):
            layers.append(SimpleGraphLayer(
                dim_in, hidden_dim, dropout=dropout))
            dim_in = hidden_dim
        self.layers = nn.ModuleList(layers)
        self.output = nn.Linear(hidden_dim, 1)
        if task_type == 'classification':
            self.output_act = nn.Identity()
        else:
            self.output_act = nn.Softplus()
        self.task_type = task_type
        # Keep adjacency as a buffer for DataParallel copies.
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        # When adj is not passed explicitly (DataParallel), fall back to the
        # buffer installed by the trainer via _set_adj_buffer.
        adj_used = adj if adj is not None else getattr(
            self, "adj_buffer", None)
        if adj_used is None or adj_used.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        h = x
        for layer in self.layers:
            h = layer(h, adj_used)
        # NOTE(review): one extra propagation pass after the layer stack; each
        # layer already aggregates once internally. Placement relative to the
        # loop was ambiguous in the source rendering — confirm it matches
        # GraphNeuralNetSklearn.encode.
        h = torch.sparse.mm(adj_used, h)
        out = self.output(h)
        return self.output_act(out)
1620
+
1621
+
1622
+ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
1623
def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
             num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
             learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
             task_type: str = 'regression', tweedie_power: float = 1.5,
             weight_decay: float = 0.0,
             use_data_parallel: bool = False, use_ddp: bool = False,
             use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
             graph_cache_path: Optional[str] = None,
             max_gpu_knn_nodes: Optional[int] = None,
             knn_gpu_mem_ratio: float = 0.9,
             knn_gpu_mem_overhead: float = 2.0,
             knn_cpu_jobs: Optional[int] = -1) -> None:
    """Sklearn-style wrapper around SimpleGNN with kNN graph construction.

    Parameters (selection)
    ----------------------
    model_nme : model identifier; also drives the Tweedie-power heuristic below.
    input_dim/hidden_dim/num_layers/dropout : SimpleGNN architecture settings.
    k_neighbors : neighbours per node when building the graph (floored at 1).
    use_approx_knn/approx_knn_threshold : prefer approximate kNN (pynndescent)
        when enabled or when the sample count reaches the threshold.
    graph_cache_path : optional on-disk cache for the normalized adjacency.
    max_gpu_knn_nodes / knn_gpu_mem_ratio / knn_gpu_mem_overhead : guardrails
        deciding whether the GPU (PyG) kNN builder may be used.
    knn_cpu_jobs : n_jobs passed to the sklearn NearestNeighbors fallback.
    """
    super().__init__()
    self.model_nme = model_nme
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.k_neighbors = max(1, k_neighbors)
    self.dropout = dropout
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.epochs = epochs
    self.patience = patience
    self.task_type = task_type
    self.use_approx_knn = use_approx_knn
    self.approx_knn_threshold = approx_knn_threshold
    self.graph_cache_path = Path(
        graph_cache_path) if graph_cache_path else None
    self.max_gpu_knn_nodes = max_gpu_knn_nodes
    # Clamp memory knobs into sane ranges.
    self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
    self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
    self.knn_cpu_jobs = knn_cpu_jobs
    self._knn_warning_emitted = False
    # In-memory adjacency cache slots (see _build_graph_from_df).
    self._adj_cache_meta: Optional[Dict[str, Any]] = None
    self._adj_cache_key: Optional[Tuple[Any, ...]] = None
    self._adj_cache_tensor: Optional[torch.Tensor] = None

    # Tweedie power heuristic keyed off the model name: 'f' -> 1.0, 's' -> 2.0.
    # NOTE(review): substring match — presumably frequency/severity naming; confirm.
    if self.task_type == 'classification':
        self.tw_power = None
    elif 'f' in self.model_nme:
        self.tw_power = 1.0
    elif 's' in self.model_nme:
        self.tw_power = 2.0
    else:
        self.tw_power = tweedie_power

    self.ddp_enabled = False
    self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    self.data_parallel_enabled = False
    self._ddp_disabled = False

    # Full-graph training cannot be sharded across processes; refuse DDP.
    if use_ddp:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if world_size > 1:
            print(
                "[GNN] DDP training is not supported; falling back to single process.",
                flush=True,
            )
            self._ddp_disabled = True
            use_ddp = False

    # DDP only works with CUDA; fall back to single process if init fails.
    if use_ddp and torch.cuda.is_available():
        ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
        if ddp_ok:
            self.ddp_enabled = True
            self.local_rank = local_rank
            self.device = torch.device(f'cuda:{local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.cuda.is_available():
        # When DDP was explicitly disabled, pin this process to its local rank's GPU.
        if self._ddp_disabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        self.device = torch.device('cpu')
        global _GNN_MPS_WARNED
        # Warn only once per process about the MPS sparse-op limitation.
        if not _GNN_MPS_WARNED:
            print(
                "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
                flush=True,
            )
            _GNN_MPS_WARNED = True
    else:
        self.device = torch.device('cpu')
    # GPU kNN requires both CUDA and a PyG installation.
    self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)

    # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
    if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
        self.data_parallel_enabled = True
        self.gnn = nn.DataParallel(
            self.gnn, device_ids=list(range(torch.cuda.device_count())))
        self.device = torch.device('cuda')

    if self.ddp_enabled:
        self.gnn = DDP(
            self.gnn,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=False
        )
1733
+
1734
@staticmethod
def _validate_vector(arr, name: str, n_rows: int) -> None:
    """Check that *arr* is a 1-d (or Nx1) vector of length *n_rows*.

    ``None`` is accepted as "not provided"; any other shape mismatch raises
    ValueError with *name* in the message.
    """
    if arr is None:
        return  # optional vector — nothing to validate
    if isinstance(arr, pd.DataFrame):
        if arr.shape[1] != 1:
            raise ValueError(f"{name} must be 1d (single column).")
        length = len(arr)
    else:
        values = np.asarray(arr)
        ndim = values.ndim
        if ndim == 0:
            raise ValueError(f"{name} must be 1d.")
        is_column = ndim == 1 or (ndim == 2 and values.shape[1] == 1)
        if not is_column:
            raise ValueError(f"{name} must be 1d or Nx1.")
        length = values.shape[0]
    if length != n_rows:
        raise ValueError(
            f"{name} length {length} does not match X length {n_rows}."
        )
1753
+
1754
def _unwrap_gnn(self) -> nn.Module:
    """Return the underlying SimpleGNN, stripping any DDP/DataParallel wrapper."""
    wrapped = self.gnn
    return wrapped.module if isinstance(wrapped, (DDP, nn.DataParallel)) else wrapped
1758
+
1759
def _set_adj_buffer(self, adj: torch.Tensor) -> None:
    """Attach the adjacency to the base module so replicas can read it
    (DataParallel forwards without an explicit adj argument)."""
    module = self._unwrap_gnn()
    if not hasattr(module, "adj_buffer"):
        module.register_buffer("adj_buffer", adj)
    else:
        module.adj_buffer = adj
1765
+
1766
def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
    """Fingerprint X_df plus all kNN-relevant settings for cache validation.

    The SHA-256 digest covers row values, index and column names, so any data
    change invalidates a cached graph; ``knn_config`` captures every builder
    setting that changes the resulting edge set.
    """
    row_hash = pd.util.hash_pandas_object(X_df, index=False).values
    idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
    col_sig = ",".join(map(str, X_df.columns))
    hasher = hashlib.sha256()
    hasher.update(row_hash.tobytes())
    hasher.update(idx_hash.tobytes())
    hasher.update(col_sig.encode("utf-8", errors="ignore"))
    knn_config = {
        "k_neighbors": int(self.k_neighbors),
        "use_approx_knn": bool(self.use_approx_knn),
        "approx_knn_threshold": int(self.approx_knn_threshold),
        "use_pyg_knn": bool(self.use_pyg_knn),
        "pynndescent_available": bool(_PYNN_AVAILABLE),
        "max_gpu_knn_nodes": (
            None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
        ),
        "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
        "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
    }
    return {
        "n_samples": int(X_df.shape[0]),
        "n_features": int(X_df.shape[1]),
        "hash": hasher.hexdigest(),
        "knn_config": knn_config,
    }
1792
+
1793
def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
    """Cheap identity-based key for the in-memory adjacency cache.

    Uses object ids rather than content hashes, so it only matches when the
    very same DataFrame object is passed again unchanged. NOTE(review): the
    key does not include kNN settings; the cache must be invalidated whenever
    those change (see invalidate_graph_cache).
    """
    return (
        id(X_df),
        # Internal block-manager identity — changes on most in-place mutations.
        id(getattr(X_df, "_mgr", None)),
        id(X_df.index),
        X_df.shape,
        tuple(map(str, X_df.columns)),
        # Optional caller-supplied cache token via DataFrame.attrs.
        X_df.attrs.get("graph_cache_key"),
    )
1802
+
1803
def invalidate_graph_cache(self) -> None:
    """Drop the in-memory adjacency cache so the next call rebuilds the graph."""
    self._adj_cache_meta = self._adj_cache_key = self._adj_cache_tensor = None
1807
+
1808
def _load_cached_adj(self,
                     X_df: pd.DataFrame,
                     meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
    """Load a cached normalized adjacency from disk, or return None.

    None is returned (and a message printed) when no cache file exists, the
    file is unreadable, the payload has an unexpected format, or the stored
    metadata does not match the fingerprint of X_df (data or kNN settings
    changed). Callers then rebuild the graph.
    """
    if self.graph_cache_path and self.graph_cache_path.exists():
        if meta_expected is None:
            meta_expected = self._graph_cache_meta(X_df)
        try:
            payload = torch.load(self.graph_cache_path,
                                 map_location=self.device)
        except Exception as exc:
            # Corrupt or incompatible file: report and rebuild instead of failing.
            print(
                f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
            return None
        if isinstance(payload, dict) and "adj" in payload:
            meta_cached = payload.get("meta")
            if meta_cached == meta_expected:
                return payload["adj"].to(self.device)
            print(
                f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
            return None
        if isinstance(payload, torch.Tensor):
            # Legacy cache format (bare tensor, no metadata) cannot be validated.
            print(
                f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
            return None
        print(
            f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
        return None
1835
+
1836
def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
    """Build a symmetric kNN edge index (with self loops) on the CPU.

    Uses pynndescent when approximate search is enabled (or the sample count
    reaches the threshold) and the package is available; otherwise exact
    sklearn NearestNeighbors. Returns a (2, E) int64 tensor on self.device.
    """
    n_samples = X_np.shape[0]
    # k cannot exceed the number of other nodes; query k+1 because each point
    # is returned as its own nearest neighbour.
    k = min(self.k_neighbors, max(1, n_samples - 1))
    n_neighbors = min(k + 1, n_samples)
    use_approx = (self.use_approx_knn or n_samples >=
                  self.approx_knn_threshold) and _PYNN_AVAILABLE
    indices = None
    if use_approx:
        try:
            nn_index = pynndescent.NNDescent(
                X_np,
                n_neighbors=n_neighbors,
                random_state=0
            )
            indices, _ = nn_index.neighbor_graph
        except Exception as exc:
            # Approximate index construction can fail on degenerate data.
            print(
                f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
            use_approx = False

    if indices is None:
        nbrs = NearestNeighbors(
            n_neighbors=n_neighbors,
            algorithm="auto",
            n_jobs=self.knn_cpu_jobs,
        )
        nbrs.fit(X_np)
        _, indices = nbrs.kneighbors(X_np)

    indices = np.asarray(indices)
    rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
        np.int64, copy=False)
    cols = indices.reshape(-1).astype(np.int64, copy=False)
    # Drop self matches, then symmetrize (both directions) and re-add explicit
    # self loops exactly once per node.
    mask = rows != cols
    rows = rows[mask]
    cols = cols[mask]
    rows_base = rows
    cols_base = cols
    self_loops = np.arange(n_samples, dtype=np.int64)
    rows = np.concatenate([rows_base, cols_base, self_loops])
    cols = np.concatenate([cols_base, rows_base, self_loops])

    edge_index_np = np.stack([rows, cols], axis=0)
    edge_index = torch.as_tensor(edge_index_np, device=self.device)
    return edge_index
1881
+
1882
def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
    """Build a symmetric kNN edge index (with self loops) using PyG on the GPU.

    Raises RuntimeError when PyG helpers are unavailable; callers are expected
    to gate on ``use_pyg_knn`` first.
    """
    pyg_missing = (
        knn_graph is None or add_self_loops is None or to_undirected is None
    )
    if (not self.use_pyg_knn) or pyg_missing:
        # Defensive: check use_pyg_knn before calling.
        raise RuntimeError(
            "GPU graph builder requested but PyG is unavailable.")

    node_count = X_tensor.size(0)
    k = min(self.k_neighbors, max(1, node_count - 1))

    # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
    edges = knn_graph(X_tensor, k=k, loop=False)
    edges = to_undirected(edges, num_nodes=node_count)
    edges, _ = add_self_loops(edges, num_nodes=node_count)
    return edges
1900
+
1901
def _log_knn_fallback(self, reason: str) -> None:
    """Print the CPU-kNN fallback reason once per instance (rank 0 only under DDP)."""
    if not self._knn_warning_emitted:
        is_logging_rank = (not self.ddp_enabled) or self.local_rank == 0
        if is_logging_rank:
            print(f"[GNN] Falling back to CPU kNN builder: {reason}")
        self._knn_warning_emitted = True
1907
+
1908
def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
    """Decide whether the PyG GPU kNN builder is safe to use.

    Returns False (logging the reason once) when GPU kNN is disabled, the node
    count exceeds ``max_gpu_knn_nodes``, or the estimated temporary memory
    would not fit in the configured share of free GPU memory.
    """
    if not self.use_pyg_knn:
        return False

    reason = None
    if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
        reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
    elif self.device.type == 'cuda' and torch.cuda.is_available():
        try:
            device_index = self.device.index
            if device_index is None:
                device_index = torch.cuda.current_device()
            free_mem, total_mem = torch.cuda.mem_get_info(device_index)
            # Rough estimate: feature storage times a safety overhead factor.
            feature_bytes = X_tensor.element_size() * X_tensor.nelement()
            required = int(feature_bytes * self.knn_gpu_mem_overhead)
            budget = int(free_mem * self.knn_gpu_mem_ratio)
            if required > budget:
                required_gb = required / (1024 ** 3)
                budget_gb = budget / (1024 ** 3)
                reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
                          f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
        except Exception:
            # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
            reason = None

    if reason:
        self._log_knn_fallback(reason)
        return False
    return True
1937
+
1938
def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Build the symmetrically normalized adjacency D^-1/2 A D^-1/2 as sparse COO."""
    edge_weights = torch.ones(edge_index.shape[1], device=self.device)
    raw_adj = torch.sparse_coo_tensor(
        edge_index.to(self.device), edge_weights, (num_nodes, num_nodes)
    ).coalesce()

    # Node degrees; epsilon keeps isolated nodes from dividing by zero.
    degree = torch.sparse.sum(raw_adj, dim=1).to_dense()
    inv_sqrt_degree = torch.pow(degree + 1e-8, -0.5)
    src, dst = raw_adj.indices()
    scaled_values = inv_sqrt_degree[src] * raw_adj.values() * inv_sqrt_degree[dst]
    return torch.sparse_coo_tensor(
        raw_adj.indices(), scaled_values, size=raw_adj.shape)
1951
+
1952
def _tensorize_split(self, X, y, w, allow_none: bool = False):
    """Convert (X, y, w) into float32 tensors on ``self.device``.

    Returns (X_tensor, y_tensor, w_tensor). y_tensor is None when y is None;
    w defaults to a column of ones. With allow_none=True, X=None yields
    (None, None, None).
    """
    if X is None and allow_none:
        return None, None, None
    if not isinstance(X, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame for GNN.")

    n_rows = len(X)
    if y is not None:
        self._validate_vector(y, "y", n_rows)
    if w is not None:
        self._validate_vector(w, "w", n_rows)

    def to_f32(values):
        # Accept pandas objects (zero-copy when possible) or anything array-like.
        if hasattr(values, "to_numpy"):
            return values.to_numpy(dtype=np.float32, copy=False)
        return np.asarray(values, dtype=np.float32)

    X_tensor = torch.as_tensor(
        to_f32(X), dtype=torch.float32, device=self.device)

    y_tensor = None
    if y is not None:
        y_tensor = torch.as_tensor(
            to_f32(y), dtype=torch.float32, device=self.device).view(-1, 1)

    if w is None:
        w_tensor = torch.ones(
            (n_rows, 1), dtype=torch.float32, device=self.device)
    else:
        w_tensor = torch.as_tensor(
            to_f32(w), dtype=torch.float32, device=self.device).view(-1, 1)
    return X_tensor, y_tensor, w_tensor
1982
+
1983
def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return the normalized adjacency for X_df, using caches where possible.

    Lookup order: in-memory cache (content-hash metadata when a disk cache
    path is configured, cheap identity key otherwise), then the on-disk cache,
    then a fresh kNN build (GPU when allowed, else CPU).
    """
    if not isinstance(X_df, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame for graph building.")
    meta_expected = None
    cache_key = None
    if self.graph_cache_path:
        # Disk caching configured: validate against a content fingerprint.
        meta_expected = self._graph_cache_meta(X_df)
        if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
            cached = self._adj_cache_tensor
            if cached.device != self.device:
                cached = cached.to(self.device)
                self._adj_cache_tensor = cached
            return cached
    else:
        # No disk cache: use a cheap identity-based key for in-memory reuse.
        cache_key = self._graph_cache_key(X_df)
        if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
            cached = self._adj_cache_tensor
            if cached.device != self.device:
                cached = cached.to(self.device)
                self._adj_cache_tensor = cached
            return cached
    X_np = None
    if X_tensor is None:
        X_np = X_df.to_numpy(dtype=np.float32, copy=False)
        X_tensor = torch.as_tensor(
            X_np, dtype=torch.float32, device=self.device)
    if self.graph_cache_path:
        cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
        if cached is not None:
            self._adj_cache_meta = meta_expected
            self._adj_cache_key = None
            self._adj_cache_tensor = cached
            return cached
    use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
    if use_gpu_knn:
        edge_index = self._build_edge_index_gpu(X_tensor)
    else:
        if X_np is None:
            X_np = X_df.to_numpy(dtype=np.float32, copy=False)
        edge_index = self._build_edge_index_cpu(X_np)
    adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
    if self.graph_cache_path:
        # Best-effort persistence; failures only cost a rebuild next time.
        try:
            IOUtils.ensure_parent_dir(str(self.graph_cache_path))
            torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
        except Exception as exc:
            print(
                f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
        self._adj_cache_meta = meta_expected
        self._adj_cache_key = None
    else:
        self._adj_cache_meta = None
        self._adj_cache_key = cache_key
    self._adj_cache_tensor = adj_norm
    return adj_norm
2038
+
2039
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None,
        trial: Optional[optuna.trial.Trial] = None):
    """Full-graph training with AMP, early stopping and optional Optuna pruning.

    Builds one kNN graph per split, trains with Adam + gradient clipping,
    tracks the best validation loss, and restores the best weights at the end.

    Parameters
    ----------
    X_train/y_train/w_train : training features, targets, optional weights.
    X_val/y_val/w_val : optional validation split; enables early stopping.
    trial : optional Optuna trial for per-epoch reporting/pruning.

    Raises
    ------
    optuna.TrialPruned : when the trial (or any DDP rank) requests pruning.
    """
    X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
        X_train, y_train, w_train, allow_none=False)
    has_val = X_val is not None and y_val is not None
    if has_val:
        X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
            X_val, y_val, w_val, allow_none=False)
    else:
        X_val_tensor = y_val_tensor = w_val_tensor = None

    adj_train = self._build_graph_from_df(X_train, X_train_tensor)
    adj_val = self._build_graph_from_df(
        X_val, X_val_tensor) if has_val else None
    # DataParallel needs adjacency cached on the model to avoid scatter.
    self._set_adj_buffer(adj_train)

    base_gnn = self._unwrap_gnn()
    optimizer = torch.optim.Adam(
        base_gnn.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    scaler = GradScaler(enabled=(self.device.type == 'cuda'))

    best_loss = float('inf')
    best_state = None
    patience_counter = 0
    best_epoch = None

    for epoch in range(1, self.epochs + 1):
        epoch_start_ts = time.time()
        self.gnn.train()
        optimizer.zero_grad()
        with autocast(enabled=(self.device.type == 'cuda')):
            if self.data_parallel_enabled:
                # DataParallel path reads adjacency from the module buffer.
                y_pred = self.gnn(X_train_tensor)
            else:
                y_pred = self.gnn(X_train_tensor, adj_train)
            loss = self._compute_weighted_loss(
                y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        val_loss = None
        if has_val:
            self.gnn.eval()
            if self.data_parallel_enabled and adj_val is not None:
                self._set_adj_buffer(adj_val)
            with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
                if self.data_parallel_enabled:
                    y_val_pred = self.gnn(X_val_tensor)
                else:
                    y_val_pred = self.gnn(X_val_tensor, adj_val)
                val_loss = self._compute_weighted_loss(
                    y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
            if self.data_parallel_enabled:
                # Restore training adjacency.
                self._set_adj_buffer(adj_train)

        is_best = val_loss is not None and val_loss < best_loss
        best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
            val_loss, best_loss, best_state, patience_counter, base_gnn,
            ignore_keys=["adj_buffer"])
        if is_best:
            best_epoch = epoch

        prune_now = False
        # BUGFIX: only report when a validation loss exists — the original
        # called trial.report(None, epoch) when no validation split was given,
        # which raises inside Optuna. Cast the tensor to a plain float.
        if trial is not None and val_loss is not None:
            trial.report(float(val_loss), epoch)
            if trial.should_prune():
                prune_now = True

        if dist.is_initialized():
            # Keep all ranks in lockstep on the pruning decision.
            flag = torch.tensor(
                [1 if prune_now else 0],
                device=self.device,
                dtype=torch.int32,
            )
            dist.broadcast(flag, src=0)
            prune_now = bool(flag.item())

        if prune_now:
            raise optuna.TrialPruned()
        if stop_training:
            break

        should_log = (not dist.is_initialized()
                      or DistributedUtils.is_main_process())
        if should_log:
            elapsed = int(time.time() - epoch_start_ts)
            if val_loss is None:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )
            else:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
                    f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )

    if best_state is not None:
        # Restore the best validation checkpoint (adjacency buffer excluded above).
        base_gnn.load_state_dict(best_state, strict=False)
    self.best_epoch = int(best_epoch or self.epochs)
2150
+
2151
def predict(self, X: pd.DataFrame) -> np.ndarray:
    """Predict for every row of X, building (or reusing) the kNN graph over X."""
    self.gnn.eval()
    features, _, _ = self._tensorize_split(
        X, None, None, allow_none=False)
    adjacency = self._build_graph_from_df(X, features)
    if self.data_parallel_enabled:
        # DataParallel reads the adjacency from the module buffer, not an argument.
        self._set_adj_buffer(adjacency)
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        if self.data_parallel_enabled:
            raw = self.gnn(features)
        else:
            raw = self.gnn(features, adjacency)
        outputs = raw.cpu().numpy()
    if self.task_type == 'classification':
        # Logits -> probabilities.
        outputs = 1 / (1 + np.exp(-outputs))
    else:
        # Softplus head already enforces positivity; clip guards exact zeros.
        outputs = np.clip(outputs, 1e-6, None)
    return outputs.ravel()
2169
+
2170
def encode(self, X: pd.DataFrame) -> np.ndarray:
    """Return per-sample node embeddings (hidden representations)."""
    # Mirrors SimpleGNN.forward minus the output head, so the embeddings match
    # the representation the prediction head sees.
    base = self._unwrap_gnn()
    base.eval()
    X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
    adj = self._build_graph_from_df(X, X_tensor)
    if self.data_parallel_enabled:
        self._set_adj_buffer(adj)
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        h = X_tensor
        layers = getattr(base, "layers", None)
        if layers is None:
            raise RuntimeError("GNN base module does not expose layers.")
        for layer in layers:
            h = layer(h, adj)
        # NOTE(review): final propagation pass, assumed to match the placement
        # in SimpleGNN.forward (after the layer stack) — confirm they agree.
        h = torch.sparse.mm(adj, h)
    return h.detach().cpu().numpy()
2188
+
2189
def set_params(self, params: Dict[str, Any]):
    """Update hyperparameters and rebuild the GNN backbone; return self.

    Raises ValueError on unknown keys. Note: the rebuilt backbone is a plain
    SimpleGNN — any DataParallel/DDP wrapping applied in __init__ is not
    re-applied here.
    """
    for key, value in params.items():
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            raise ValueError(f"Parameter {key} not found in GNN model.")
    # BUGFIX: graph-construction settings (e.g. k_neighbors) may have changed,
    # and the identity-based in-memory cache key does not include them — drop
    # any cached adjacency so the next fit/predict rebuilds the graph.
    self.invalidate_graph_cache()
    # Rebuild the backbone after structural parameter changes.
    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)
    return self