ins-pricing 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
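The listing shows the main restructuring in 0.2.0: the bayesopt core moves from ins_pricing/modelling/bayesopt to ins_pricing/modelling/core/bayesopt (with trainers and models split into subpackages), the CLI entry points and helpers move into ins_pricing/cli, docs move under ins_pricing/docs, and the parallel ins_pricing_gemini tree is removed. Below is a minimal sketch of how downstream code might bridge the renamed modules; the path mapping is taken from the file list above, while the shim itself and the import_compat helper are illustrative assumptions, not something the package ships.

# Hypothetical compatibility shim for the 0.1.11 -> 0.2.0 module moves listed above.
# Only the module paths come from the listing; the shim and its name are illustrative.
import importlib

OLD_TO_NEW = {
    "ins_pricing.modelling.bayesopt.core": "ins_pricing.modelling.core.bayesopt.core",
    "ins_pricing.modelling.bayesopt.utils": "ins_pricing.modelling.core.bayesopt.utils",
    "ins_pricing.modelling.cli_common": "ins_pricing.cli.utils.cli_common",
    "ins_pricing.modelling.notebook_utils": "ins_pricing.cli.utils.notebook_utils",
    "ins_pricing.modelling.Explain_entry": "ins_pricing.cli.Explain_entry",
}

def import_compat(module_path: str):
    """Try the new 0.2.0 location first, then fall back to the old path."""
    try:
        return importlib.import_module(OLD_TO_NEW.get(module_path, module_path))
    except ImportError:
        return importlib.import_module(module_path)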
ins_pricing/modelling/bayesopt/models.py (deleted)
@@ -1,2196 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- import hashlib
5
- import math
6
- import os
7
- import time
8
- from contextlib import nullcontext
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Tuple
11
-
12
- import numpy as np
13
- import optuna
14
- import pandas as pd
15
- import torch
16
- import torch.distributed as dist
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
- from sklearn.neighbors import NearestNeighbors
20
- from torch.cuda.amp import autocast, GradScaler
21
- from torch.nn.parallel import DistributedDataParallel as DDP
22
- from torch.nn.utils import clip_grad_norm_
23
- from torch.utils.data import Dataset, TensorDataset
24
-
25
- from .utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
26
-
27
- try:
28
- from torch_geometric.nn import knn_graph
29
- from torch_geometric.utils import add_self_loops, to_undirected
30
- _PYG_AVAILABLE = True
31
- except Exception:
32
- knn_graph = None # type: ignore
33
- add_self_loops = None # type: ignore
34
- to_undirected = None # type: ignore
35
- _PYG_AVAILABLE = False
36
-
37
- try:
38
- import pynndescent
39
- _PYNN_AVAILABLE = True
40
- except Exception:
41
- pynndescent = None # type: ignore
42
- _PYNN_AVAILABLE = False
43
-
44
- _GNN_MPS_WARNED = False
45
-
46
- # =============================================================================
47
- # ResNet model and sklearn-style wrapper
48
- # =============================================================================
49
-
50
- # ResNet model definition
51
- # Residual block: two linear layers + ReLU + residual connection
52
- # ResBlock inherits nn.Module
53
- class ResBlock(nn.Module):
54
- def __init__(self, dim: int, dropout: float = 0.1,
55
- use_layernorm: bool = False, residual_scale: float = 0.1,
56
- stochastic_depth: float = 0.0
57
- ):
58
- super().__init__()
59
- self.use_layernorm = use_layernorm
60
-
61
- if use_layernorm:
62
- Norm = nn.LayerNorm # Normalize the last dimension
63
- else:
64
- def Norm(d): return nn.BatchNorm1d(d) # Keep a switch to try BN
65
-
66
- self.norm1 = Norm(dim)
67
- self.fc1 = nn.Linear(dim, dim, bias=True)
68
- self.act = nn.ReLU(inplace=True)
69
- self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
70
- # Enable post-second-layer norm if needed: self.norm2 = Norm(dim)
71
- self.fc2 = nn.Linear(dim, dim, bias=True)
72
-
73
- # Residual scaling to stabilize early training
74
- self.res_scale = nn.Parameter(
75
- torch.tensor(residual_scale, dtype=torch.float32)
76
- )
77
- self.stochastic_depth = max(0.0, float(stochastic_depth))
78
-
79
- def _drop_path(self, x: torch.Tensor) -> torch.Tensor:
80
- if self.stochastic_depth <= 0.0 or not self.training:
81
- return x
82
- keep_prob = 1.0 - self.stochastic_depth
83
- if keep_prob <= 0.0:
84
- return torch.zeros_like(x)
85
- shape = (x.shape[0],) + (1,) * (x.ndim - 1)
86
- random_tensor = keep_prob + torch.rand(
87
- shape, dtype=x.dtype, device=x.device)
88
- binary_tensor = torch.floor(random_tensor)
89
- return x * binary_tensor / keep_prob
90
-
91
- def forward(self, x):
92
- # Pre-activation structure
93
- out = self.norm1(x)
94
- out = self.fc1(out)
95
- out = self.act(out)
96
- out = self.dropout(out)
97
- # If a second norm is enabled: out = self.norm2(out)
98
- out = self.fc2(out)
99
- # Apply residual scaling then add
100
- out = self.res_scale * out
101
- out = self._drop_path(out)
102
- return x + out
103
-
104
- # ResNetSequential defines the full network
105
-
106
-
107
- class ResNetSequential(nn.Module):
108
- # Input shape: (batch, input_dim)
109
- # Network: FC + norm + ReLU, stack residual blocks, output Softplus
110
-
111
- def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
112
- use_layernorm: bool = True, dropout: float = 0.1,
113
- residual_scale: float = 0.1, stochastic_depth: float = 0.0,
114
- task_type: str = 'regression'):
115
- super(ResNetSequential, self).__init__()
116
-
117
- self.net = nn.Sequential()
118
- self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))
119
-
120
- # Optional explicit normalization after the first layer:
121
- # For LayerNorm:
122
- # self.net.add_module('norm1', nn.LayerNorm(hidden_dim))
123
- # Or BatchNorm:
124
- # self.net.add_module('norm1', nn.BatchNorm1d(hidden_dim))
125
-
126
- # If desired, insert ReLU before residual blocks:
127
- # self.net.add_module('relu1', nn.ReLU(inplace=True))
128
-
129
- # Residual blocks
130
- drop_path_rate = max(0.0, float(stochastic_depth))
131
- for i in range(block_num):
132
- if block_num > 1:
133
- block_drop = drop_path_rate * (i / (block_num - 1))
134
- else:
135
- block_drop = drop_path_rate
136
- self.net.add_module(
137
- f'ResBlk_{i+1}',
138
- ResBlock(
139
- hidden_dim,
140
- dropout=dropout,
141
- use_layernorm=use_layernorm,
142
- residual_scale=residual_scale,
143
- stochastic_depth=block_drop)
144
- )
145
-
146
- self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))
147
-
148
- if task_type == 'classification':
149
- self.net.add_module('softplus', nn.Identity())
150
- else:
151
- self.net.add_module('softplus', nn.Softplus())
152
-
153
- def forward(self, x):
154
- if self.training and not hasattr(self, '_printed_device'):
155
- print(f">>> ResNetSequential executing on device: {x.device}")
156
- self._printed_device = True
157
- return self.net(x)
158
-
159
- # Define the ResNet sklearn-style wrapper.
160
-
161
-
162
- class ResNetSklearn(TorchTrainerMixin, nn.Module):
163
- def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
164
- block_num: int = 2, batch_num: int = 100, epochs: int = 100,
165
- task_type: str = 'regression',
166
- tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
167
- use_layernorm: bool = True, dropout: float = 0.1,
168
- residual_scale: float = 0.1,
169
- stochastic_depth: float = 0.0,
170
- weight_decay: float = 1e-4,
171
- use_data_parallel: bool = True,
172
- use_ddp: bool = False):
173
- super(ResNetSklearn, self).__init__()
174
-
175
- self.use_ddp = use_ddp
176
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
177
- False, 0, 0, 1)
178
-
179
- if self.use_ddp:
180
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
181
-
182
- self.input_dim = input_dim
183
- self.hidden_dim = hidden_dim
184
- self.block_num = block_num
185
- self.batch_num = batch_num
186
- self.epochs = epochs
187
- self.task_type = task_type
188
- self.model_nme = model_nme
189
- self.learning_rate = learning_rate
190
- self.weight_decay = weight_decay
191
- self.patience = patience
192
- self.use_layernorm = use_layernorm
193
- self.dropout = dropout
194
- self.residual_scale = residual_scale
195
- self.stochastic_depth = max(0.0, float(stochastic_depth))
196
- self.loss_curve_path: Optional[str] = None
197
- self.training_history: Dict[str, List[float]] = {
198
- "train": [], "val": []}
199
- self.use_data_parallel = bool(use_data_parallel)
200
-
201
- # Device selection: cuda > mps > cpu
202
- if self.is_ddp_enabled:
203
- self.device = torch.device(f'cuda:{self.local_rank}')
204
- elif torch.cuda.is_available():
205
- self.device = torch.device('cuda')
206
- elif torch.backends.mps.is_available():
207
- self.device = torch.device('mps')
208
- else:
209
- self.device = torch.device('cpu')
210
-
211
- # Tweedie power (unused for classification)
212
- if self.task_type == 'classification':
213
- self.tw_power = None
214
- elif 'f' in self.model_nme:
215
- self.tw_power = 1
216
- elif 's' in self.model_nme:
217
- self.tw_power = 2
218
- else:
219
- self.tw_power = tweedie_power
220
-
221
- # Build network (construct on CPU first)
222
- core = ResNetSequential(
223
- self.input_dim,
224
- self.hidden_dim,
225
- self.block_num,
226
- use_layernorm=self.use_layernorm,
227
- dropout=self.dropout,
228
- residual_scale=self.residual_scale,
229
- stochastic_depth=self.stochastic_depth,
230
- task_type=self.task_type
231
- )
232
-
233
- # ===== Multi-GPU: DataParallel vs DistributedDataParallel =====
234
- if self.is_ddp_enabled:
235
- core = core.to(self.device)
236
- core = DDP(core, device_ids=[
237
- self.local_rank], output_device=self.local_rank)
238
- self.use_data_parallel = False
239
- elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
240
- if self.use_ddp and not self.is_ddp_enabled:
241
- print(
242
- ">>> DDP requested but not initialized; falling back to DataParallel.")
243
- core = nn.DataParallel(core, device_ids=list(
244
- range(torch.cuda.device_count())))
245
- # DataParallel scatters inputs, but the primary device remains cuda:0.
246
- self.device = torch.device('cuda')
247
- self.use_data_parallel = True
248
- else:
249
- self.use_data_parallel = False
250
-
251
- self.resnet = core.to(self.device)
252
-
253
- # ================ Internal helpers ================
254
- @staticmethod
255
- def _validate_vector(arr, name: str, n_rows: int) -> None:
256
- if arr is None:
257
- return
258
- if isinstance(arr, pd.DataFrame):
259
- if arr.shape[1] != 1:
260
- raise ValueError(f"{name} must be 1d (single column).")
261
- length = len(arr)
262
- else:
263
- arr_np = np.asarray(arr)
264
- if arr_np.ndim == 0:
265
- raise ValueError(f"{name} must be 1d.")
266
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
267
- raise ValueError(f"{name} must be 1d or Nx1.")
268
- length = arr_np.shape[0]
269
- if length != n_rows:
270
- raise ValueError(
271
- f"{name} length {length} does not match X length {n_rows}."
272
- )
273
-
274
- def _validate_inputs(self, X, y, w, label: str) -> None:
275
- if X is None:
276
- raise ValueError(f"{label} X cannot be None.")
277
- n_rows = len(X)
278
- if y is None:
279
- raise ValueError(f"{label} y cannot be None.")
280
- self._validate_vector(y, f"{label} y", n_rows)
281
- self._validate_vector(w, f"{label} w", n_rows)
282
-
283
- def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
284
- self._validate_inputs(X_train, y_train, w_train, "train")
285
- if X_val is not None or y_val is not None or w_val is not None:
286
- if X_val is None or y_val is None:
287
- raise ValueError("validation X and y must both be provided.")
288
- self._validate_inputs(X_val, y_val, w_val, "val")
289
-
290
- def _to_numpy(arr):
291
- if hasattr(arr, "to_numpy"):
292
- return arr.to_numpy(dtype=np.float32, copy=False)
293
- return np.asarray(arr, dtype=np.float32)
294
-
295
- X_tensor = torch.as_tensor(_to_numpy(X_train))
296
- y_tensor = torch.as_tensor(_to_numpy(y_train)).view(-1, 1)
297
- w_tensor = (
298
- torch.as_tensor(_to_numpy(w_train)).view(-1, 1)
299
- if w_train is not None else torch.ones_like(y_tensor)
300
- )
301
-
302
- has_val = X_val is not None and y_val is not None
303
- if has_val:
304
- X_val_tensor = torch.as_tensor(_to_numpy(X_val))
305
- y_val_tensor = torch.as_tensor(_to_numpy(y_val)).view(-1, 1)
306
- w_val_tensor = (
307
- torch.as_tensor(_to_numpy(w_val)).view(-1, 1)
308
- if w_val is not None else torch.ones_like(y_val_tensor)
309
- )
310
- else:
311
- X_val_tensor = y_val_tensor = w_val_tensor = None
312
- return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val
313
-
314
- def forward(self, x):
315
- # Handle SHAP NumPy input.
316
- if isinstance(x, np.ndarray):
317
- x_tensor = torch.as_tensor(x, dtype=torch.float32)
318
- else:
319
- x_tensor = x
320
-
321
- x_tensor = x_tensor.to(self.device)
322
- y_pred = self.resnet(x_tensor)
323
- return y_pred
324
-
325
- # ---------------- Training ----------------
326
-
327
- def fit(self, X_train, y_train, w_train=None,
328
- X_val=None, y_val=None, w_val=None, trial=None):
329
-
330
- X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
331
- self._build_train_val_tensors(
332
- X_train, y_train, w_train, X_val, y_val, w_val)
333
-
334
- dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
335
- dataloader, accum_steps = self._build_dataloader(
336
- dataset,
337
- N=X_tensor.shape[0],
338
- base_bs_gpu=(2048, 1024, 512),
339
- base_bs_cpu=(256, 128),
340
- min_bs=64,
341
- target_effective_cuda=2048,
342
- target_effective_cpu=1024
343
- )
344
-
345
- # Set sampler epoch at the start of each epoch to keep shuffling deterministic.
346
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
347
- self.dataloader_sampler = dataloader.sampler
348
- else:
349
- self.dataloader_sampler = None
350
-
351
- # === 4. Optimizer and AMP ===
352
- self.optimizer = torch.optim.Adam(
353
- self.resnet.parameters(),
354
- lr=self.learning_rate,
355
- weight_decay=float(self.weight_decay),
356
- )
357
- self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))
358
-
359
- X_val_dev = y_val_dev = w_val_dev = None
360
- val_dataloader = None
361
- if has_val:
362
- # Build validation DataLoader.
363
- val_dataset = TensorDataset(
364
- X_val_tensor, y_val_tensor, w_val_tensor)
365
- # No backward pass in validation; batch size can be larger for throughput.
366
- val_dataloader = self._build_val_dataloader(
367
- val_dataset, dataloader, accum_steps)
368
- # Validation usually does not need a DDP sampler because we validate on the main process
369
- # or aggregate results. For simplicity, keep validation on a single GPU or the main process.
370
-
371
- is_data_parallel = isinstance(self.resnet, nn.DataParallel)
372
-
373
- def forward_fn(batch):
374
- X_batch, y_batch, w_batch = batch
375
-
376
- if not is_data_parallel:
377
- X_batch = X_batch.to(self.device, non_blocking=True)
378
- # Keep targets and weights on the main device for loss computation.
379
- y_batch = y_batch.to(self.device, non_blocking=True)
380
- w_batch = w_batch.to(self.device, non_blocking=True)
381
-
382
- y_pred = self.resnet(X_batch)
383
- return y_pred, y_batch, w_batch
384
-
385
- def val_forward_fn():
386
- total_loss = 0.0
387
- total_weight = 0.0
388
- for batch in val_dataloader:
389
- X_b, y_b, w_b = batch
390
- if not is_data_parallel:
391
- X_b = X_b.to(self.device, non_blocking=True)
392
- y_b = y_b.to(self.device, non_blocking=True)
393
- w_b = w_b.to(self.device, non_blocking=True)
394
-
395
- y_pred = self.resnet(X_b)
396
-
397
- # Manually compute weighted loss for accurate aggregation.
398
- losses = self._compute_losses(
399
- y_pred, y_b, apply_softplus=False)
400
-
401
- batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
402
- batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
403
-
404
- total_loss += batch_weighted_loss_sum.item()
405
- total_weight += batch_weight_sum.item()
406
-
407
- return total_loss / max(total_weight, EPS)
408
-
409
- clip_fn = None
410
- if self.device.type == 'cuda':
411
- def clip_fn(): return (self.scaler.unscale_(self.optimizer),
412
- clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))
413
-
414
- # Under DDP, only the main process prints logs and saves models.
415
- if self.is_ddp_enabled and not DistributedUtils.is_main_process():
416
- # Non-main processes skip validation callback logging (handled inside _train_model).
417
- pass
418
-
419
- best_state, history = self._train_model(
420
- self.resnet,
421
- dataloader,
422
- accum_steps,
423
- self.optimizer,
424
- self.scaler,
425
- forward_fn,
426
- val_forward_fn if has_val else None,
427
- apply_softplus=False,
428
- clip_fn=clip_fn,
429
- trial=trial,
430
- loss_curve_path=getattr(self, "loss_curve_path", None)
431
- )
432
-
433
- if has_val and best_state is not None:
434
- self.resnet.load_state_dict(best_state)
435
- self.training_history = history
436
-
437
- # ---------------- Prediction ----------------
438
-
439
- def predict(self, X_test):
440
- self.resnet.eval()
441
- if isinstance(X_test, pd.DataFrame):
442
- X_np = X_test.to_numpy(dtype=np.float32, copy=False)
443
- else:
444
- X_np = np.asarray(X_test, dtype=np.float32)
445
-
446
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
447
- with inference_cm():
448
- y_pred = self(X_np).cpu().numpy()
449
-
450
- if self.task_type == 'classification':
451
- y_pred = 1 / (1 + np.exp(-y_pred)) # Sigmoid converts logits to probabilities.
452
- else:
453
- y_pred = np.clip(y_pred, 1e-6, None)
454
- return y_pred.flatten()
455
-
456
- # ---------------- Set Params ----------------
457
-
458
- def set_params(self, params):
459
- for key, value in params.items():
460
- if hasattr(self, key):
461
- setattr(self, key, value)
462
- else:
463
- raise ValueError(f"Parameter {key} not found in model.")
464
- return self
465
-
466
-
467
- # =============================================================================
468
- # FT-Transformer model and sklearn-style wrapper.
469
- # =============================================================================
470
- # Define FT-Transformer model structure.
471
-
472
-
473
- class FeatureTokenizer(nn.Module):
474
- """Map numeric/categorical/geo tokens into transformer input tokens."""
475
-
476
- def __init__(
477
- self,
478
- num_numeric: int,
479
- cat_cardinalities,
480
- d_model: int,
481
- num_geo: int = 0,
482
- num_numeric_tokens: int = 1,
483
- ):
484
- super().__init__()
485
-
486
- self.num_numeric = num_numeric
487
- self.num_geo = num_geo
488
- self.has_geo = num_geo > 0
489
-
490
- if num_numeric > 0:
491
- if int(num_numeric_tokens) <= 0:
492
- raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
493
- self.num_numeric_tokens = int(num_numeric_tokens)
494
- self.has_numeric = True
495
- self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
496
- else:
497
- self.num_numeric_tokens = 0
498
- self.has_numeric = False
499
-
500
- self.embeddings = nn.ModuleList([
501
- nn.Embedding(card, d_model) for card in cat_cardinalities
502
- ])
503
-
504
- if self.has_geo:
505
- # Map geo tokens with a linear layer to avoid one-hot on raw strings; upstream is encoded/normalized.
506
- self.geo_linear = nn.Linear(num_geo, d_model)
507
-
508
- def forward(self, X_num, X_cat, X_geo=None):
509
- tokens = []
510
-
511
- if self.has_numeric:
512
- batch_size = X_num.shape[0]
513
- num_token = self.num_linear(X_num)
514
- num_token = num_token.view(batch_size, self.num_numeric_tokens, -1)
515
- tokens.append(num_token)
516
-
517
- for i, emb in enumerate(self.embeddings):
518
- tok = emb(X_cat[:, i])
519
- tokens.append(tok.unsqueeze(1))
520
-
521
- if self.has_geo:
522
- if X_geo is None:
523
- raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
524
- geo_token = self.geo_linear(X_geo)
525
- tokens.append(geo_token.unsqueeze(1))
526
-
527
- x = torch.cat(tokens, dim=1)
528
- return x
529
-
530
- # Encoder layer with residual scaling.
531
-
532
-
533
- class ScaledTransformerEncoderLayer(nn.Module):
534
- def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
535
- dropout: float = 0.1, residual_scale_attn: float = 1.0,
536
- residual_scale_ffn: float = 1.0, norm_first: bool = True,
537
- ):
538
- super().__init__()
539
- self.self_attn = nn.MultiheadAttention(
540
- embed_dim=d_model,
541
- num_heads=nhead,
542
- dropout=dropout,
543
- batch_first=True
544
- )
545
-
546
- # Feed-forward network.
547
- self.linear1 = nn.Linear(d_model, dim_feedforward)
548
- self.dropout = nn.Dropout(dropout)
549
- self.linear2 = nn.Linear(dim_feedforward, d_model)
550
-
551
- # Normalization and dropout.
552
- self.norm1 = nn.LayerNorm(d_model)
553
- self.norm2 = nn.LayerNorm(d_model)
554
- self.dropout1 = nn.Dropout(dropout)
555
- self.dropout2 = nn.Dropout(dropout)
556
-
557
- self.activation = nn.GELU()
558
- # If you prefer ReLU, set: self.activation = nn.ReLU()
559
- self.norm_first = norm_first
560
-
561
- # Residual scaling coefficients.
562
- self.res_scale_attn = residual_scale_attn
563
- self.res_scale_ffn = residual_scale_ffn
564
-
565
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
566
- # Input tensor shape: (batch, seq_len, d_model).
567
- x = src
568
-
569
- if self.norm_first:
570
- # Pre-norm before attention.
571
- x = x + self._sa_block(self.norm1(x), src_mask,
572
- src_key_padding_mask)
573
- x = x + self._ff_block(self.norm2(x))
574
- else:
575
- # Post-norm (usually disabled).
576
- x = self.norm1(
577
- x + self._sa_block(x, src_mask, src_key_padding_mask))
578
- x = self.norm2(x + self._ff_block(x))
579
-
580
- return x
581
-
582
- def _sa_block(self, x, attn_mask, key_padding_mask):
583
- # Self-attention with residual scaling.
584
- attn_out, _ = self.self_attn(
585
- x, x, x,
586
- attn_mask=attn_mask,
587
- key_padding_mask=key_padding_mask,
588
- need_weights=False
589
- )
590
- return self.res_scale_attn * self.dropout1(attn_out)
591
-
592
- def _ff_block(self, x):
593
- # Feed-forward block with residual scaling.
594
- x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
595
- return self.res_scale_ffn * self.dropout2(x2)
596
-
597
- # FT-Transformer core model.
598
-
599
-
600
- class FTTransformerCore(nn.Module):
601
- # Minimal FT-Transformer built from:
602
- # 1) FeatureTokenizer: convert numeric/categorical features to tokens;
603
- # 2) TransformerEncoder: model feature interactions;
604
- # 3) Pooling + MLP + Softplus: positive outputs for Tweedie/Gamma tasks.
605
-
606
- def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
607
- n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
608
- task_type: str = 'regression', num_geo: int = 0,
609
- num_numeric_tokens: int = 1
610
- ):
611
- super().__init__()
612
-
613
- self.num_numeric = int(num_numeric)
614
- self.cat_cardinalities = list(cat_cardinalities or [])
615
-
616
- self.tokenizer = FeatureTokenizer(
617
- num_numeric=num_numeric,
618
- cat_cardinalities=cat_cardinalities,
619
- d_model=d_model,
620
- num_geo=num_geo,
621
- num_numeric_tokens=num_numeric_tokens
622
- )
623
- scale = 1.0 / math.sqrt(n_layers) # Recommended default.
624
- encoder_layer = ScaledTransformerEncoderLayer(
625
- d_model=d_model,
626
- nhead=n_heads,
627
- dim_feedforward=d_model * 4,
628
- dropout=dropout,
629
- residual_scale_attn=scale,
630
- residual_scale_ffn=scale,
631
- norm_first=True,
632
- )
633
- self.encoder = nn.TransformerEncoder(
634
- encoder_layer,
635
- num_layers=n_layers
636
- )
637
- self.n_layers = n_layers
638
-
639
- layers = [
640
- # If you need a deeper head, enable the sample layers below:
641
- # nn.LayerNorm(d_model), # Extra normalization
642
- # nn.Linear(d_model, d_model), # Extra fully connected layer
643
- # nn.GELU(), # Activation
644
- nn.Linear(d_model, 1),
645
- ]
646
-
647
- if task_type == 'classification':
648
- # Classification outputs logits for BCEWithLogitsLoss.
649
- layers.append(nn.Identity())
650
- else:
651
- # Regression keeps positive outputs for Tweedie/Gamma.
652
- layers.append(nn.Softplus())
653
-
654
- self.head = nn.Sequential(*layers)
655
-
656
- # ---- Self-supervised reconstruction head (masked modeling) ----
657
- self.num_recon_head = nn.Linear(
658
- d_model, self.num_numeric) if self.num_numeric > 0 else None
659
- self.cat_recon_heads = nn.ModuleList([
660
- nn.Linear(d_model, int(card)) for card in self.cat_cardinalities
661
- ])
662
-
663
- def forward(
664
- self,
665
- X_num,
666
- X_cat,
667
- X_geo=None,
668
- return_embedding: bool = False,
669
- return_reconstruction: bool = False):
670
-
671
- # Inputs:
672
- # X_num -> float32 tensor with shape (batch, num_numeric_features)
673
- # X_cat -> long tensor with shape (batch, num_categorical_features)
674
- # X_geo -> float32 tensor with shape (batch, geo_token_dim)
675
-
676
- if self.training and not hasattr(self, '_printed_device'):
677
- print(f">>> FTTransformerCore executing on device: {X_num.device}")
678
- self._printed_device = True
679
-
680
- # => tensor shape (batch, token_num, d_model)
681
- tokens = self.tokenizer(X_num, X_cat, X_geo)
682
- # => tensor shape (batch, token_num, d_model)
683
- x = self.encoder(tokens)
684
-
685
- # Mean-pool tokens, then send to the head.
686
- x = x.mean(dim=1) # => tensor shape (batch, d_model)
687
-
688
- if return_reconstruction:
689
- num_pred, cat_logits = self.reconstruct(x)
690
- cat_logits_out = tuple(
691
- cat_logits) if cat_logits is not None else tuple()
692
- if return_embedding:
693
- return x, num_pred, cat_logits_out
694
- return num_pred, cat_logits_out
695
-
696
- if return_embedding:
697
- return x
698
-
699
- # => tensor shape (batch, 1); Softplus keeps it positive.
700
- out = self.head(x)
701
- return out
702
-
703
- def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
704
- """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
705
- num_pred = self.num_recon_head(
706
- embedding) if self.num_recon_head is not None else None
707
- cat_logits = [head(embedding) for head in self.cat_recon_heads]
708
- return num_pred, cat_logits
709
-
710
- # TabularDataset.
711
-
712
-
713
- class TabularDataset(Dataset):
714
- def __init__(self, X_num, X_cat, X_geo, y, w):
715
-
716
- # Input tensors:
717
- # X_num: torch.float32, shape=(N, num_numeric_features)
718
- # X_cat: torch.long, shape=(N, num_categorical_features)
719
- # X_geo: torch.float32, shape=(N, geo_token_dim), can be empty
720
- # y: torch.float32, shape=(N, 1)
721
- # w: torch.float32, shape=(N, 1)
722
-
723
- self.X_num = X_num
724
- self.X_cat = X_cat
725
- self.X_geo = X_geo
726
- self.y = y
727
- self.w = w
728
-
729
- def __len__(self):
730
- return self.y.shape[0]
731
-
732
- def __getitem__(self, idx):
733
- return (
734
- self.X_num[idx],
735
- self.X_cat[idx],
736
- self.X_geo[idx],
737
- self.y[idx],
738
- self.w[idx],
739
- )
740
-
741
-
742
- class MaskedTabularDataset(Dataset):
743
- def __init__(self,
744
- X_num_masked: torch.Tensor,
745
- X_cat_masked: torch.Tensor,
746
- X_geo: torch.Tensor,
747
- X_num_true: Optional[torch.Tensor],
748
- num_mask: Optional[torch.Tensor],
749
- X_cat_true: Optional[torch.Tensor],
750
- cat_mask: Optional[torch.Tensor]):
751
- self.X_num_masked = X_num_masked
752
- self.X_cat_masked = X_cat_masked
753
- self.X_geo = X_geo
754
- self.X_num_true = X_num_true
755
- self.num_mask = num_mask
756
- self.X_cat_true = X_cat_true
757
- self.cat_mask = cat_mask
758
-
759
- def __len__(self):
760
- return self.X_num_masked.shape[0]
761
-
762
- def __getitem__(self, idx):
763
- return (
764
- self.X_num_masked[idx],
765
- self.X_cat_masked[idx],
766
- self.X_geo[idx],
767
- None if self.X_num_true is None else self.X_num_true[idx],
768
- None if self.num_mask is None else self.num_mask[idx],
769
- None if self.X_cat_true is None else self.X_cat_true[idx],
770
- None if self.cat_mask is None else self.cat_mask[idx],
771
- )
772
-
773
- # Scikit-Learn style wrapper for FTTransformer.
774
-
775
-
776
- class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
777
-
778
- # sklearn-style wrapper:
779
- # - num_cols: numeric feature column names
780
- # - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
781
-
782
- @staticmethod
783
- def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
784
- num_cols_count = len(num_cols or [])
785
- if num_cols_count == 0:
786
- return 0
787
- if requested is not None:
788
- count = int(requested)
789
- if count <= 0:
790
- raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
791
- return count
792
- return max(1, num_cols_count)
793
-
794
- def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
795
- n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
796
- task_type: str = 'regression',
797
- tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
798
- weight_decay: float = 0.0,
799
- use_data_parallel: bool = True,
800
- use_ddp: bool = False,
801
- num_numeric_tokens: Optional[int] = None
802
- ):
803
- super().__init__()
804
-
805
- self.use_ddp = use_ddp
806
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
807
- False, 0, 0, 1)
808
- if self.use_ddp:
809
- self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
810
-
811
- self.model_nme = model_nme
812
- self.num_cols = list(num_cols)
813
- self.cat_cols = list(cat_cols)
814
- self.num_numeric_tokens = self.resolve_numeric_token_count(
815
- self.num_cols,
816
- self.cat_cols,
817
- num_numeric_tokens,
818
- )
819
- self.d_model = d_model
820
- self.n_heads = n_heads
821
- self.n_layers = n_layers
822
- self.dropout = dropout
823
- self.batch_num = batch_num
824
- self.epochs = epochs
825
- self.learning_rate = learning_rate
826
- self.weight_decay = weight_decay
827
- self.task_type = task_type
828
- self.patience = patience
829
- if self.task_type == 'classification':
830
- self.tw_power = None # No Tweedie power for classification.
831
- elif 'f' in self.model_nme:
832
- self.tw_power = 1.0
833
- elif 's' in self.model_nme:
834
- self.tw_power = 2.0
835
- else:
836
- self.tw_power = tweedie_power
837
-
838
- if self.is_ddp_enabled:
839
- self.device = torch.device(f"cuda:{self.local_rank}")
840
- elif torch.cuda.is_available():
841
- self.device = torch.device("cuda")
842
- elif torch.backends.mps.is_available():
843
- self.device = torch.device("mps")
844
- else:
845
- self.device = torch.device("cpu")
846
- self.cat_cardinalities = None
847
- self.cat_categories = {}
848
- self.cat_maps: Dict[str, Dict[Any, int]] = {}
849
- self.cat_str_maps: Dict[str, Dict[str, int]] = {}
850
- self._num_mean = None
851
- self._num_std = None
852
- self.ft = None
853
- self.use_data_parallel = bool(use_data_parallel)
854
- self.num_geo = 0
855
- self._geo_params: Dict[str, Any] = {}
856
- self.loss_curve_path: Optional[str] = None
857
- self.training_history: Dict[str, List[float]] = {
858
- "train": [], "val": []}
859
-
860
- def _build_model(self, X_train):
861
- num_numeric = len(self.num_cols)
862
- cat_cardinalities = []
863
-
864
- if num_numeric > 0:
865
- num_arr = X_train[self.num_cols].to_numpy(
866
- dtype=np.float32, copy=False)
867
- num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
868
- mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
869
- std = num_arr.std(axis=0).astype(np.float32, copy=False)
870
- std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
871
- self._num_mean = mean
872
- self._num_std = std
873
- else:
874
- self._num_mean = None
875
- self._num_std = None
876
-
877
- self.cat_maps = {}
878
- self.cat_str_maps = {}
879
- for col in self.cat_cols:
880
- cats = X_train[col].astype('category')
881
- categories = cats.cat.categories
882
- self.cat_categories[col] = categories # Store full category list from training.
883
- self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
884
- if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
885
- self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
886
-
887
- card = len(categories) + 1 # Reserve one extra class for unknown/missing.
888
- cat_cardinalities.append(card)
889
-
890
- self.cat_cardinalities = cat_cardinalities
891
-
892
- core = FTTransformerCore(
893
- num_numeric=num_numeric,
894
- cat_cardinalities=cat_cardinalities,
895
- d_model=self.d_model,
896
- n_heads=self.n_heads,
897
- n_layers=self.n_layers,
898
- dropout=self.dropout,
899
- task_type=self.task_type,
900
- num_geo=self.num_geo,
901
- num_numeric_tokens=self.num_numeric_tokens
902
- )
903
- use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
904
- if self.is_ddp_enabled:
905
- core = core.to(self.device)
906
- core = DDP(core, device_ids=[
907
- self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
908
- self.use_data_parallel = False
909
- elif use_dp:
910
- if self.use_ddp and not self.is_ddp_enabled:
911
- print(
912
- ">>> DDP requested but not initialized; falling back to DataParallel.")
913
- core = nn.DataParallel(core, device_ids=list(
914
- range(torch.cuda.device_count())))
915
- self.device = torch.device("cuda")
916
- self.use_data_parallel = True
917
- else:
918
- self.use_data_parallel = False
919
- self.ft = core.to(self.device)
920
-
921
- def _encode_cats(self, X):
922
- # Input DataFrame must include all categorical feature columns.
923
- # Return int64 array with shape (N, num_categorical_features).
924
-
925
- if not self.cat_cols:
926
- return np.zeros((len(X), 0), dtype='int64')
927
-
928
- n_rows = len(X)
929
- n_cols = len(self.cat_cols)
930
- X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
931
- for idx, col in enumerate(self.cat_cols):
932
- categories = self.cat_categories[col]
933
- mapping = self.cat_maps.get(col)
934
- if mapping is None:
935
- mapping = {cat: i for i, cat in enumerate(categories)}
936
- self.cat_maps[col] = mapping
937
- unknown_idx = len(categories)
938
- series = X[col]
939
- codes = series.map(mapping)
940
- unmapped = series.notna() & codes.isna()
941
- if unmapped.any():
942
- try:
943
- series_cast = series.astype(categories.dtype)
944
- except Exception:
945
- series_cast = None
946
- if series_cast is not None:
947
- codes = series_cast.map(mapping)
948
- unmapped = series_cast.notna() & codes.isna()
949
- if unmapped.any():
950
- str_map = self.cat_str_maps.get(col)
951
- if str_map is None:
952
- str_map = {str(cat): i for i, cat in enumerate(categories)}
953
- self.cat_str_maps[col] = str_map
954
- codes = series.astype(str).map(str_map)
955
- if pd.api.types.is_categorical_dtype(codes):
956
- codes = codes.astype("float")
957
- codes = codes.fillna(unknown_idx).astype(
958
- "int64", copy=False).to_numpy()
959
- X_cat_np[:, idx] = codes
960
- return X_cat_np
961
-
962
- def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
963
- return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
964
-
965
- def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
966
- return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
967
-
968
- @staticmethod
969
- def _validate_vector(arr, name: str, n_rows: int) -> None:
970
- if arr is None:
971
- return
972
- if isinstance(arr, pd.DataFrame):
973
- if arr.shape[1] != 1:
974
- raise ValueError(f"{name} must be 1d (single column).")
975
- length = len(arr)
976
- else:
977
- arr_np = np.asarray(arr)
978
- if arr_np.ndim == 0:
979
- raise ValueError(f"{name} must be 1d.")
980
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
981
- raise ValueError(f"{name} must be 1d or Nx1.")
982
- length = arr_np.shape[0]
983
- if length != n_rows:
984
- raise ValueError(
985
- f"{name} length {length} does not match X length {n_rows}."
986
- )
987
-
988
- def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
989
- if X is None:
990
- if allow_none:
991
- return None, None, None, None, None, False
992
- raise ValueError("Input features X must not be None.")
993
- if not isinstance(X, pd.DataFrame):
994
- raise ValueError("X must be a pandas DataFrame.")
995
- missing_cols = [
996
- col for col in (self.num_cols + self.cat_cols) if col not in X.columns
997
- ]
998
- if missing_cols:
999
- raise ValueError(f"X is missing required columns: {missing_cols}")
1000
- n_rows = len(X)
1001
- if y is not None:
1002
- self._validate_vector(y, "y", n_rows)
1003
- if w is not None:
1004
- self._validate_vector(w, "w", n_rows)
1005
-
1006
- num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
1007
- if not num_np.flags["OWNDATA"]:
1008
- num_np = num_np.copy()
1009
- num_np = np.nan_to_num(num_np, nan=0.0,
1010
- posinf=0.0, neginf=0.0, copy=False)
1011
- if self._num_mean is not None and self._num_std is not None and num_np.size:
1012
- num_np = (num_np - self._num_mean) / self._num_std
1013
- X_num = torch.as_tensor(num_np)
1014
- if self.cat_cols:
1015
- X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
1016
- else:
1017
- X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
1018
-
1019
- if geo_tokens is not None:
1020
- geo_np = np.asarray(geo_tokens, dtype=np.float32)
1021
- if geo_np.shape[0] != n_rows:
1022
- raise ValueError(
1023
- "geo_tokens length does not match X rows.")
1024
- if geo_np.ndim == 1:
1025
- geo_np = geo_np.reshape(-1, 1)
1026
- elif self.num_geo > 0:
1027
- raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
1028
- else:
1029
- geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
1030
- X_geo = torch.as_tensor(geo_np)
1031
-
1032
- y_tensor = torch.as_tensor(
1033
- y.to_numpy(dtype=np.float32, copy=False) if hasattr(
1034
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
1035
- ).view(-1, 1) if y is not None else None
1036
- if y_tensor is None:
1037
- w_tensor = None
1038
- elif w is not None:
1039
- w_tensor = torch.as_tensor(
1040
- w.to_numpy(dtype=np.float32, copy=False) if hasattr(
1041
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
1042
- ).view(-1, 1)
1043
- else:
1044
- w_tensor = torch.ones_like(y_tensor)
1045
- return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
1046
-
1047
- def fit(self, X_train, y_train, w_train=None,
1048
- X_val=None, y_val=None, w_val=None, trial=None,
1049
- geo_train=None, geo_val=None):
1050
-
1051
- # Build the underlying model on first fit.
1052
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
1053
- if self.ft is None:
1054
- self._build_model(X_train)
1055
-
1056
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
1057
- X_train, y_train, w_train, geo_train=geo_train)
1058
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
1059
- X_val, y_val, w_val, geo_val=geo_val)
1060
-
1061
- # --- Build DataLoader ---
1062
- dataset = TabularDataset(
1063
- X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
1064
- )
1065
-
1066
- dataloader, accum_steps = self._build_dataloader(
1067
- dataset,
1068
- N=X_num_train.shape[0],
1069
- base_bs_gpu=(2048, 1024, 512),
1070
- base_bs_cpu=(256, 128),
1071
- min_bs=64,
1072
- target_effective_cuda=2048,
1073
- target_effective_cpu=1024
1074
- )
1075
-
1076
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
1077
- self.dataloader_sampler = dataloader.sampler
1078
- else:
1079
- self.dataloader_sampler = None
1080
-
1081
- optimizer = torch.optim.Adam(
1082
- self.ft.parameters(),
1083
- lr=self.learning_rate,
1084
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
1085
- )
1086
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
1087
-
1088
- X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
1089
- val_dataloader = None
1090
- if has_val:
1091
- val_dataset = TabularDataset(
1092
- X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
1093
- )
1094
- val_dataloader = self._build_val_dataloader(
1095
- val_dataset, dataloader, accum_steps)
1096
-
1097
- is_data_parallel = isinstance(self.ft, nn.DataParallel)
1098
-
1099
- def forward_fn(batch):
1100
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
1101
-
1102
- if not is_data_parallel:
1103
- X_num_b = X_num_b.to(self.device, non_blocking=True)
1104
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
1105
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
1106
- y_b = y_b.to(self.device, non_blocking=True)
1107
- w_b = w_b.to(self.device, non_blocking=True)
1108
-
1109
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
1110
- return y_pred, y_b, w_b
1111
-
1112
- def val_forward_fn():
1113
- total_loss = 0.0
1114
- total_weight = 0.0
1115
- for batch in val_dataloader:
1116
- X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
1117
- if not is_data_parallel:
1118
- X_num_b = X_num_b.to(self.device, non_blocking=True)
1119
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
1120
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
1121
- y_b = y_b.to(self.device, non_blocking=True)
1122
- w_b = w_b.to(self.device, non_blocking=True)
1123
-
1124
- y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
1125
-
1126
- # Manually compute validation loss.
1127
- losses = self._compute_losses(
1128
- y_pred, y_b, apply_softplus=False)
1129
-
1130
- batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
1131
- batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
1132
-
1133
- total_loss += batch_weighted_loss_sum.item()
1134
- total_weight += batch_weight_sum.item()
1135
-
1136
- return total_loss / max(total_weight, EPS)
1137
-
1138
- clip_fn = None
1139
- if self.device.type == 'cuda':
1140
- def clip_fn(): return (scaler.unscale_(optimizer),
1141
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
1142
-
1143
- best_state, history = self._train_model(
1144
- self.ft,
1145
- dataloader,
1146
- accum_steps,
1147
- optimizer,
1148
- scaler,
1149
- forward_fn,
1150
- val_forward_fn if has_val else None,
1151
- apply_softplus=False,
1152
- clip_fn=clip_fn,
1153
- trial=trial,
1154
- loss_curve_path=getattr(self, "loss_curve_path", None)
1155
- )
1156
-
1157
- if has_val and best_state is not None:
1158
- self.ft.load_state_dict(best_state)
1159
- self.training_history = history
1160
-
1161
- def fit_unsupervised(self,
1162
- X_train,
1163
- X_val=None,
1164
- trial: Optional[optuna.trial.Trial] = None,
1165
- geo_train=None,
1166
- geo_val=None,
1167
- mask_prob_num: float = 0.15,
1168
- mask_prob_cat: float = 0.15,
1169
- num_loss_weight: float = 1.0,
1170
- cat_loss_weight: float = 1.0) -> float:
1171
- """Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
1172
- self.num_geo = geo_train.shape[1] if geo_train is not None else 0
1173
- if self.ft is None:
1174
- self._build_model(X_train)
1175
-
1176
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
1177
- X_train, None, None, geo_tokens=geo_train, allow_none=True)
1178
- has_val = X_val is not None
1179
- if has_val:
1180
- X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
1181
- X_val, None, None, geo_tokens=geo_val, allow_none=True)
1182
- else:
1183
- X_num_val = X_cat_val = X_geo_val = None
1184
-
1185
- N = int(X_num.shape[0])
1186
- num_dim = int(X_num.shape[1])
1187
- cat_dim = int(X_cat.shape[1])
1188
- device_type = self._device_type()
1189
-
1190
- gen = torch.Generator()
1191
- gen.manual_seed(13 + int(getattr(self, "rank", 0)))
1192
-
1193
- base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
1194
- cardinals = getattr(base_model, "cat_cardinalities", None) or []
1195
- unknown_idx = torch.tensor(
1196
- [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
1197
-
1198
- means = None
1199
- if num_dim > 0:
1200
- # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
1201
- means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
1202
-
1203
- def _mask_inputs(X_num_in: torch.Tensor,
1204
- X_cat_in: torch.Tensor,
1205
- generator: torch.Generator):
1206
- n_rows = int(X_num_in.shape[0])
1207
- num_mask_local = None
1208
- cat_mask_local = None
1209
- X_num_masked_local = X_num_in
1210
- X_cat_masked_local = X_cat_in
1211
- if num_dim > 0:
1212
- num_mask_local = (torch.rand(
1213
- (n_rows, num_dim), generator=generator) < float(mask_prob_num))
1214
- X_num_masked_local = X_num_in.clone()
1215
- if num_mask_local.any():
1216
- X_num_masked_local[num_mask_local] = means.expand_as(
1217
- X_num_masked_local)[num_mask_local]
1218
- if cat_dim > 0:
1219
- cat_mask_local = (torch.rand(
1220
- (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
1221
- X_cat_masked_local = X_cat_in.clone()
1222
- if cat_mask_local.any():
1223
- X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
1224
- X_cat_masked_local)[cat_mask_local]
1225
- return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
1226
-
1227
- X_num_true = X_num if num_dim > 0 else None
1228
- X_cat_true = X_cat if cat_dim > 0 else None
1229
- X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
1230
- X_num, X_cat, gen)
1231
-
1232
- dataset = MaskedTabularDataset(
1233
- X_num_masked, X_cat_masked, X_geo,
1234
- X_num_true, num_mask,
1235
- X_cat_true, cat_mask
1236
- )
1237
- dataloader, accum_steps = self._build_dataloader(
1238
- dataset,
1239
- N=N,
1240
- base_bs_gpu=(2048, 1024, 512),
1241
- base_bs_cpu=(256, 128),
1242
- min_bs=64,
1243
- target_effective_cuda=2048,
1244
- target_effective_cpu=1024
1245
- )
1246
- if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
1247
- self.dataloader_sampler = dataloader.sampler
1248
- else:
1249
- self.dataloader_sampler = None
1250
-
1251
- optimizer = torch.optim.Adam(
1252
- self.ft.parameters(),
1253
- lr=self.learning_rate,
1254
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
1255
- )
1256
- scaler = GradScaler(enabled=(device_type == 'cuda'))
1257
-
1258
- def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
1259
- loss = torch.zeros((), device=device, dtype=torch.float32)
1260
-
1261
- if num_pred is not None and num_true_b is not None and num_mask_b is not None:
1262
- num_mask_b = num_mask_b.to(dtype=torch.bool)
1263
- if num_mask_b.any():
1264
- diff = num_pred - num_true_b
1265
- mse = diff * diff
1266
- loss = loss + float(num_loss_weight) * \
1267
- mse[num_mask_b].mean()
1268
-
1269
- if cat_logits and cat_true_b is not None and cat_mask_b is not None:
1270
- cat_mask_b = cat_mask_b.to(dtype=torch.bool)
1271
- cat_losses: List[torch.Tensor] = []
1272
- for j, logits in enumerate(cat_logits):
1273
- mask_j = cat_mask_b[:, j]
1274
- if not mask_j.any():
1275
- continue
1276
- targets = cat_true_b[:, j]
1277
- cat_losses.append(
1278
- F.cross_entropy(logits, targets, reduction='none')[
1279
- mask_j].mean()
1280
- )
1281
- if cat_losses:
1282
- loss = loss + float(cat_loss_weight) * \
1283
- torch.stack(cat_losses).mean()
1284
- return loss
1285
-
1286
- train_history: List[float] = []
1287
- val_history: List[float] = []
1288
- best_loss = float("inf")
1289
- best_state = None
1290
- patience_counter = 0
1291
- is_ddp_model = isinstance(self.ft, DDP)
1292
-
1293
- clip_fn = None
1294
- if self.device.type == 'cuda':
1295
- def clip_fn(): return (scaler.unscale_(optimizer),
1296
- clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
1297
-
1298
- for epoch in range(1, int(self.epochs) + 1):
1299
- if self.dataloader_sampler is not None:
1300
- self.dataloader_sampler.set_epoch(epoch)
1301
-
1302
- self.ft.train()
1303
- optimizer.zero_grad()
1304
- epoch_loss_sum = 0.0
1305
- epoch_count = 0.0
1306
-
1307
- for step, batch in enumerate(dataloader):
1308
- is_update_step = ((step + 1) % accum_steps == 0) or \
1309
- ((step + 1) == len(dataloader))
1310
- sync_cm = self.ft.no_sync if (
1311
- is_ddp_model and not is_update_step) else nullcontext
1312
- with sync_cm():
1313
- with autocast(enabled=(device_type == 'cuda')):
1314
- X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
1315
- X_num_b = X_num_b.to(self.device, non_blocking=True)
1316
- X_cat_b = X_cat_b.to(self.device, non_blocking=True)
1317
- X_geo_b = X_geo_b.to(self.device, non_blocking=True)
1318
- num_true_b = None if num_true_b is None else num_true_b.to(
1319
- self.device, non_blocking=True)
1320
- num_mask_b = None if num_mask_b is None else num_mask_b.to(
1321
- self.device, non_blocking=True)
1322
- cat_true_b = None if cat_true_b is None else cat_true_b.to(
1323
- self.device, non_blocking=True)
1324
- cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
1325
- self.device, non_blocking=True)
1326
-
1327
- num_pred, cat_logits = self.ft(
1328
- X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
1329
- batch_loss = _batch_recon_loss(
1330
- num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
1331
- local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
1332
- global_bad = local_bad
1333
- if dist.is_initialized():
1334
- bad = torch.tensor(
1335
- [local_bad],
1336
- device=batch_loss.device,
1337
- dtype=torch.int32,
1338
- )
1339
- dist.all_reduce(bad, op=dist.ReduceOp.MAX)
1340
- global_bad = int(bad.item())
1341
-
1342
- if global_bad:
1343
- msg = (
1344
- f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
1345
- f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
1346
- )
1347
- should_log = (not dist.is_initialized()
1348
- or DistributedUtils.is_main_process())
1349
- if should_log:
1350
- print(msg, flush=True)
1351
- print(
1352
- f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
1353
- f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
1354
- f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
1355
- flush=True,
1356
- )
1357
- if X_geo_b is not None:
1358
- print(
1359
- f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
1360
- f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
1361
- f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
1362
- flush=True,
1363
- )
1364
- if trial is not None:
1365
- raise optuna.TrialPruned(msg)
1366
- raise RuntimeError(msg)
1367
- loss_for_backward = batch_loss / float(accum_steps)
1368
- scaler.scale(loss_for_backward).backward()
1369
-
1370
- if is_update_step:
1371
- if clip_fn is not None:
1372
- clip_fn()
1373
- scaler.step(optimizer)
1374
- scaler.update()
1375
- optimizer.zero_grad()
1376
-
1377
- epoch_loss_sum += float(batch_loss.detach().item()) * \
1378
- float(X_num_b.shape[0])
1379
- epoch_count += float(X_num_b.shape[0])
1380
-
1381
- train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
1382
-
1383
- if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
1384
- should_compute_val = (not dist.is_initialized()
1385
- or DistributedUtils.is_main_process())
1386
- loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
1387
- "cpu")
1388
- val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
1389
-
1390
- if should_compute_val:
1391
- self.ft.eval()
1392
- with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
1393
- val_bs = min(
1394
- int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
1395
- total_val = 0.0
1396
- total_n = 0.0
1397
- for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
1398
- end = min(
1399
- int(X_num_val.shape[0]), start + max(1, val_bs))
1400
- X_num_v_true_cpu = X_num_val[start:end]
1401
- X_cat_v_true_cpu = X_cat_val[start:end]
1402
- X_geo_v = X_geo_val[start:end].to(
1403
- self.device, non_blocking=True)
1404
- gen_val = torch.Generator()
1405
- gen_val.manual_seed(10_000 + epoch + start)
1406
- X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
1407
- X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
1408
- X_num_v_true = X_num_v_true_cpu.to(
1409
- self.device, non_blocking=True)
1410
- X_cat_v_true = X_cat_v_true_cpu.to(
1411
- self.device, non_blocking=True)
1412
- X_num_v = X_num_v_cpu.to(
1413
- self.device, non_blocking=True)
1414
- X_cat_v = X_cat_v_cpu.to(
1415
- self.device, non_blocking=True)
1416
- val_num_mask = None if val_num_mask is None else val_num_mask.to(
1417
- self.device, non_blocking=True)
1418
- val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
1419
- self.device, non_blocking=True)
1420
- num_pred_v, cat_logits_v = self.ft(
1421
- X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
1422
- loss_v = _batch_recon_loss(
1423
- num_pred_v, cat_logits_v,
1424
- X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
1425
- X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
1426
- device=X_num_v.device
1427
- )
1428
- if not torch.isfinite(loss_v):
1429
- total_val = float("inf")
1430
- total_n = 1.0
1431
- break
1432
- total_val += float(loss_v.detach().item()
1433
- ) * float(end - start)
1434
- total_n += float(end - start)
1435
- val_loss_tensor[0] = total_val / max(total_n, 1.0)
1436
-
1437
- if dist.is_initialized():
1438
- dist.broadcast(val_loss_tensor, src=0)
1439
- val_loss_value = float(val_loss_tensor.item())
1440
- prune_now = False
1441
- prune_msg = None
1442
- if not np.isfinite(val_loss_value):
1443
- prune_now = True
1444
- prune_msg = (
1445
- f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
1446
- f"(epoch={epoch}, val_loss={val_loss_value})"
1447
- )
1448
- val_history.append(val_loss_value)
1449
-
1450
- if val_loss_value < best_loss:
1451
- best_loss = val_loss_value
1452
- best_state = {
1453
- k: (v.clone() if isinstance(
1454
- v, torch.Tensor) else copy.deepcopy(v))
1455
- for k, v in self.ft.state_dict().items()
1456
- }
1457
- patience_counter = 0
1458
- else:
1459
- patience_counter += 1
1460
- if best_state is not None and patience_counter >= int(self.patience):
1461
- break
1462
-
1463
- if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
1464
- trial.report(val_loss_value, epoch)
1465
- if trial.should_prune():
1466
- prune_now = True
1467
-
1468
- if dist.is_initialized():
1469
- flag = torch.tensor(
1470
- [1 if prune_now else 0],
1471
- device=loss_tensor_device,
1472
- dtype=torch.int32,
1473
- )
1474
- dist.broadcast(flag, src=0)
1475
- prune_now = bool(flag.item())
1476
-
1477
- if prune_now:
1478
- if prune_msg:
1479
- raise optuna.TrialPruned(prune_msg)
1480
- raise optuna.TrialPruned()
1481
-
1482
- self.training_history = {"train": train_history, "val": val_history}
1483
- self._plot_loss_curve(self.training_history, getattr(
1484
- self, "loss_curve_path", None))
1485
- if has_val and best_state is not None:
1486
- self.ft.load_state_dict(best_state)
1487
- return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
1488
-
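`fit_unsupervised` accumulates gradients over `accum_steps` micro-batches, scales each loss by `1/accum_steps`, and only unscales, clips, and steps on update steps (skipping DDP gradient sync in between via `no_sync`). A simplified single-process sketch of that accumulation pattern, assuming a generic `model`, `optimizer`, and `loader`:

```python
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_

def train_epoch(model, loader, optimizer, accum_steps=4):
    use_amp = torch.cuda.is_available()
    scaler = GradScaler(enabled=use_amp)
    model.train()
    optimizer.zero_grad()
    for step, (xb, yb) in enumerate(loader):
        is_update = ((step + 1) % accum_steps == 0) or (step + 1 == len(loader))
        with autocast(enabled=use_amp):
            loss = torch.nn.functional.mse_loss(model(xb), yb)
        # Divide so the accumulated gradient matches a single full-batch step.
        scaler.scale(loss / accum_steps).backward()
        if is_update:
            if use_amp:
                scaler.unscale_(optimizer)  # unscale before clipping AMP-scaled grads
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
```

The full implementation above additionally all-reduces a "non-finite loss" flag across ranks and broadcasts pruning decisions so every DDP process stops together.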
1489
- def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
1490
- # X_test must include all numeric/categorical columns; geo_tokens is optional.
1491
-
1492
- self.ft.eval()
1493
- X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
1494
- X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
1495
-
1496
- num_rows = X_num.shape[0]
1497
- if num_rows == 0:
1498
- return np.empty(0, dtype=np.float32)
1499
-
1500
- device = self.device if isinstance(
1501
- self.device, torch.device) else torch.device(self.device)
1502
-
1503
- def resolve_batch_size(n_rows: int) -> int:
1504
- if batch_size is not None:
1505
- return max(1, min(int(batch_size), n_rows))
1506
- # Estimate a safe batch size based on model size to avoid attention OOM.
1507
- token_cnt = self.num_numeric_tokens + len(self.cat_cols)
1508
- if self.num_geo > 0:
1509
- token_cnt += 1
1510
- approx_units = max(1, token_cnt * max(1, self.d_model))
1511
- if device.type == 'cuda':
1512
- if approx_units >= 8192:
1513
- base = 512
1514
- elif approx_units >= 4096:
1515
- base = 1024
1516
- else:
1517
- base = 2048
1518
- else:
1519
- base = 512
1520
- return max(1, min(base, n_rows))
1521
-
1522
- eff_batch = resolve_batch_size(num_rows)
1523
- preds: List[torch.Tensor] = []
1524
-
1525
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
1526
- with inference_cm():
1527
- for start in range(0, num_rows, eff_batch):
1528
- end = min(num_rows, start + eff_batch)
1529
- X_num_b = X_num[start:end].to(device, non_blocking=True)
1530
- X_cat_b = X_cat[start:end].to(device, non_blocking=True)
1531
- X_geo_b = X_geo[start:end].to(device, non_blocking=True)
1532
- pred_chunk = self.ft(
1533
- X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
1534
- preds.append(pred_chunk.cpu())
1535
-
1536
- y_pred = torch.cat(preds, dim=0).numpy()
1537
-
1538
- if return_embedding:
1539
- return y_pred
1540
-
1541
- if self.task_type == 'classification':
1542
- # Convert logits to probabilities.
1543
- y_pred = 1 / (1 + np.exp(-y_pred))
1544
- else:
1545
- # The model head already applies softplus, so outputs are non-negative; clip to keep predictions strictly positive (an extra log(1 + exp(y_pred)) smoothing is optional).
1546
- y_pred = np.clip(y_pred, 1e-6, None)
1547
- return y_pred.ravel()
1548
-
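`predict` runs the transformer in chunks sized by a token-count heuristic, moves each chunk to the device, collects results on CPU, then applies a sigmoid for classification or a positive clip for regression. A generic chunked-inference sketch (the model, batch size, and flags here are placeholders):

```python
import numpy as np
import torch

def predict_in_chunks(model, X, batch_size=1024, device="cpu", classification=False):
    model.eval()
    outputs = []
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        for start in range(0, X.shape[0], batch_size):
            xb = X[start:start + batch_size].to(device)
            outputs.append(model(xb).cpu())
    y = torch.cat(outputs).numpy().ravel()
    if classification:
        return 1.0 / (1.0 + np.exp(-y))   # logits -> probabilities
    return np.clip(y, 1e-6, None)         # keep regression outputs strictly positive
```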
1549
- def set_params(self, params: dict):
1550
-
1551
- # Keep sklearn-style behavior.
1552
- # Note: changing structural params (e.g., d_model/n_heads) requires refit to take effect.
1553
-
1554
- for key, value in params.items():
1555
- if hasattr(self, key):
1556
- setattr(self, key, value)
1557
- else:
1558
- raise ValueError(f"Parameter {key} not found in model.")
1559
- return self
1560
-
1561
-
1562
- # =============================================================================
1563
- # Simplified GNN implementation.
1564
- # =============================================================================
1565
-
1566
-
1567
- class SimpleGraphLayer(nn.Module):
1568
- def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
1569
- super().__init__()
1570
- self.linear = nn.Linear(in_dim, out_dim)
1571
- self.activation = nn.ReLU(inplace=True)
1572
- self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1573
-
1574
- def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
1575
- # Message passing with normalized sparse adjacency: A_hat * X * W.
1576
- h = torch.sparse.mm(adj, x)
1577
- h = self.linear(h)
1578
- h = self.activation(h)
1579
- return self.dropout(h)
1580
-
1581
-
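`SimpleGraphLayer` performs one round of message passing, `A_hat @ X @ W`, on a normalized sparse adjacency. A toy example of the same propagation on a 3-node graph (adjacency and features invented for illustration):

```python
import torch
import torch.nn as nn

# 3-node path graph with self-loops, stored as a sparse COO tensor.
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                           [1, 0, 2, 1, 0, 1, 2]])
values = torch.ones(edge_index.shape[1])
adj = torch.sparse_coo_tensor(edge_index, values, (3, 3)).coalesce()

x = torch.randn(3, 4)            # 3 nodes, 4 input features
linear = nn.Linear(4, 8)

h = torch.sparse.mm(adj, x)      # aggregate neighbour features
h = torch.relu(linear(h))        # transform and activate
print(h.shape)                   # torch.Size([3, 8])
```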
1582
- class SimpleGNN(nn.Module):
1583
- def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
1584
- dropout: float = 0.1, task_type: str = 'regression'):
1585
- super().__init__()
1586
- layers = []
1587
- dim_in = input_dim
1588
- for _ in range(max(1, num_layers)):
1589
- layers.append(SimpleGraphLayer(
1590
- dim_in, hidden_dim, dropout=dropout))
1591
- dim_in = hidden_dim
1592
- self.layers = nn.ModuleList(layers)
1593
- self.output = nn.Linear(hidden_dim, 1)
1594
- if task_type == 'classification':
1595
- self.output_act = nn.Identity()
1596
- else:
1597
- self.output_act = nn.Softplus()
1598
- self.task_type = task_type
1599
- # Keep adjacency as a buffer for DataParallel copies.
1600
- self.register_buffer("adj_buffer", torch.empty(0))
1601
-
1602
- def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
1603
- adj_used = adj if adj is not None else getattr(
1604
- self, "adj_buffer", None)
1605
- if adj_used is None or adj_used.numel() == 0:
1606
- raise RuntimeError("Adjacency is not set for GNN forward.")
1607
- h = x
1608
- for layer in self.layers:
1609
- h = layer(h, adj_used)
1610
- h = torch.sparse.mm(adj_used, h)
1611
- out = self.output(h)
1612
- return self.output_act(out)
1613
-
1614
-
1615
- class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
1616
- def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
1617
- num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
1618
- learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
1619
- task_type: str = 'regression', tweedie_power: float = 1.5,
1620
- weight_decay: float = 0.0,
1621
- use_data_parallel: bool = False, use_ddp: bool = False,
1622
- use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
1623
- graph_cache_path: Optional[str] = None,
1624
- max_gpu_knn_nodes: Optional[int] = None,
1625
- knn_gpu_mem_ratio: float = 0.9,
1626
- knn_gpu_mem_overhead: float = 2.0,
1627
- knn_cpu_jobs: Optional[int] = -1) -> None:
1628
- super().__init__()
1629
- self.model_nme = model_nme
1630
- self.input_dim = input_dim
1631
- self.hidden_dim = hidden_dim
1632
- self.num_layers = num_layers
1633
- self.k_neighbors = max(1, k_neighbors)
1634
- self.dropout = dropout
1635
- self.learning_rate = learning_rate
1636
- self.weight_decay = weight_decay
1637
- self.epochs = epochs
1638
- self.patience = patience
1639
- self.task_type = task_type
1640
- self.use_approx_knn = use_approx_knn
1641
- self.approx_knn_threshold = approx_knn_threshold
1642
- self.graph_cache_path = Path(
1643
- graph_cache_path) if graph_cache_path else None
1644
- self.max_gpu_knn_nodes = max_gpu_knn_nodes
1645
- self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
1646
- self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
1647
- self.knn_cpu_jobs = knn_cpu_jobs
1648
- self._knn_warning_emitted = False
1649
- self._adj_cache_meta: Optional[Dict[str, Any]] = None
1650
- self._adj_cache_key: Optional[Tuple[Any, ...]] = None
1651
- self._adj_cache_tensor: Optional[torch.Tensor] = None
1652
-
1653
- if self.task_type == 'classification':
1654
- self.tw_power = None
1655
- elif 'f' in self.model_nme:
1656
- self.tw_power = 1.0
1657
- elif 's' in self.model_nme:
1658
- self.tw_power = 2.0
1659
- else:
1660
- self.tw_power = tweedie_power
1661
-
1662
- self.ddp_enabled = False
1663
- self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
1664
- self.data_parallel_enabled = False
1665
- self._ddp_disabled = False
1666
-
1667
- if use_ddp:
1668
- world_size = int(os.environ.get("WORLD_SIZE", "1"))
1669
- if world_size > 1:
1670
- print(
1671
- "[GNN] DDP training is not supported; falling back to single process.",
1672
- flush=True,
1673
- )
1674
- self._ddp_disabled = True
1675
- use_ddp = False
1676
-
1677
- # DDP only works with CUDA; fall back to single process if init fails.
1678
- if use_ddp and torch.cuda.is_available():
1679
- ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
1680
- if ddp_ok:
1681
- self.ddp_enabled = True
1682
- self.local_rank = local_rank
1683
- self.device = torch.device(f'cuda:{local_rank}')
1684
- else:
1685
- self.device = torch.device('cuda')
1686
- elif torch.cuda.is_available():
1687
- if self._ddp_disabled:
1688
- self.device = torch.device(f'cuda:{self.local_rank}')
1689
- else:
1690
- self.device = torch.device('cuda')
1691
- elif torch.backends.mps.is_available():
1692
- self.device = torch.device('cpu')
1693
- global _GNN_MPS_WARNED
1694
- if not _GNN_MPS_WARNED:
1695
- print(
1696
- "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
1697
- flush=True,
1698
- )
1699
- _GNN_MPS_WARNED = True
1700
- else:
1701
- self.device = torch.device('cpu')
1702
- self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE
1703
-
1704
- self.gnn = SimpleGNN(
1705
- input_dim=self.input_dim,
1706
- hidden_dim=self.hidden_dim,
1707
- num_layers=self.num_layers,
1708
- dropout=self.dropout,
1709
- task_type=self.task_type
1710
- ).to(self.device)
1711
-
1712
- # DataParallel replicates the adjacency buffer on every GPU and splits the node-feature rows across devices; workable for medium-sized graphs.
1713
- if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
1714
- self.data_parallel_enabled = True
1715
- self.gnn = nn.DataParallel(
1716
- self.gnn, device_ids=list(range(torch.cuda.device_count())))
1717
- self.device = torch.device('cuda')
1718
-
1719
- if self.ddp_enabled:
1720
- self.gnn = DDP(
1721
- self.gnn,
1722
- device_ids=[self.local_rank],
1723
- output_device=self.local_rank,
1724
- find_unused_parameters=False
1725
- )
1726
-
1727
- @staticmethod
1728
- def _validate_vector(arr, name: str, n_rows: int) -> None:
1729
- if arr is None:
1730
- return
1731
- if isinstance(arr, pd.DataFrame):
1732
- if arr.shape[1] != 1:
1733
- raise ValueError(f"{name} must be 1d (single column).")
1734
- length = len(arr)
1735
- else:
1736
- arr_np = np.asarray(arr)
1737
- if arr_np.ndim == 0:
1738
- raise ValueError(f"{name} must be 1d.")
1739
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
1740
- raise ValueError(f"{name} must be 1d or Nx1.")
1741
- length = arr_np.shape[0]
1742
- if length != n_rows:
1743
- raise ValueError(
1744
- f"{name} length {length} does not match X length {n_rows}."
1745
- )
1746
-
1747
- def _unwrap_gnn(self) -> nn.Module:
1748
- if isinstance(self.gnn, (DDP, nn.DataParallel)):
1749
- return self.gnn.module
1750
- return self.gnn
1751
-
1752
- def _set_adj_buffer(self, adj: torch.Tensor) -> None:
1753
- base = self._unwrap_gnn()
1754
- if hasattr(base, "adj_buffer"):
1755
- base.adj_buffer = adj
1756
- else:
1757
- base.register_buffer("adj_buffer", adj)
1758
-
1759
- def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
1760
- row_hash = pd.util.hash_pandas_object(X_df, index=False).values
1761
- idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
1762
- col_sig = ",".join(map(str, X_df.columns))
1763
- hasher = hashlib.sha256()
1764
- hasher.update(row_hash.tobytes())
1765
- hasher.update(idx_hash.tobytes())
1766
- hasher.update(col_sig.encode("utf-8", errors="ignore"))
1767
- knn_config = {
1768
- "k_neighbors": int(self.k_neighbors),
1769
- "use_approx_knn": bool(self.use_approx_knn),
1770
- "approx_knn_threshold": int(self.approx_knn_threshold),
1771
- "use_pyg_knn": bool(self.use_pyg_knn),
1772
- "pynndescent_available": bool(_PYNN_AVAILABLE),
1773
- "max_gpu_knn_nodes": (
1774
- None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
1775
- ),
1776
- "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
1777
- "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
1778
- }
1779
- return {
1780
- "n_samples": int(X_df.shape[0]),
1781
- "n_features": int(X_df.shape[1]),
1782
- "hash": hasher.hexdigest(),
1783
- "knn_config": knn_config,
1784
- }
1785
-
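`_graph_cache_meta` fingerprints the training frame (row values, index, and column names) together with the kNN settings, so a cached graph is only reused when both the data and the graph configuration are unchanged. A minimal sketch of that fingerprinting, assuming any pandas DataFrame `df`:

```python
import hashlib
import pandas as pd

def dataframe_fingerprint(df: pd.DataFrame) -> str:
    hasher = hashlib.sha256()
    hasher.update(pd.util.hash_pandas_object(df, index=False).values.tobytes())
    hasher.update(pd.util.hash_pandas_object(df.index, index=False).values.tobytes())
    hasher.update(",".join(map(str, df.columns)).encode("utf-8", errors="ignore"))
    return hasher.hexdigest()

df = pd.DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
print(dataframe_fingerprint(df))                      # stable for identical frames
print(dataframe_fingerprint(df.assign(a=[9, 2, 3])))  # changes when values change
```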
1786
- def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
1787
- return (
1788
- id(X_df),
1789
- id(getattr(X_df, "_mgr", None)),
1790
- id(X_df.index),
1791
- X_df.shape,
1792
- tuple(map(str, X_df.columns)),
1793
- X_df.attrs.get("graph_cache_key"),
1794
- )
1795
-
1796
- def invalidate_graph_cache(self) -> None:
1797
- self._adj_cache_meta = None
1798
- self._adj_cache_key = None
1799
- self._adj_cache_tensor = None
1800
-
1801
- def _load_cached_adj(self,
1802
- X_df: pd.DataFrame,
1803
- meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
1804
- if self.graph_cache_path and self.graph_cache_path.exists():
1805
- if meta_expected is None:
1806
- meta_expected = self._graph_cache_meta(X_df)
1807
- try:
1808
- payload = torch.load(self.graph_cache_path,
1809
- map_location=self.device)
1810
- except Exception as exc:
1811
- print(
1812
- f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
1813
- return None
1814
- if isinstance(payload, dict) and "adj" in payload:
1815
- meta_cached = payload.get("meta")
1816
- if meta_cached == meta_expected:
1817
- return payload["adj"].to(self.device)
1818
- print(
1819
- f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
1820
- return None
1821
- if isinstance(payload, torch.Tensor):
1822
- print(
1823
- f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
1824
- return None
1825
- print(
1826
- f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
1827
- return None
1828
-
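`_load_cached_adj` only accepts a cached payload when it is a dict containing both the adjacency and metadata matching the current fingerprint; anything else triggers a rebuild. A small sketch of that save/validate/load round trip (the path and metadata values are illustrative):

```python
import torch

def save_graph(path, adj, meta):
    torch.save({"adj": adj.cpu(), "meta": meta}, path)

def load_graph(path, meta_expected, device="cpu"):
    try:
        payload = torch.load(path, map_location=device)
    except Exception:
        return None
    if isinstance(payload, dict) and payload.get("meta") == meta_expected:
        return payload["adj"].to(device)
    return None  # missing or mismatched metadata -> caller rebuilds the graph

meta = {"n_samples": 3, "hash": "abc123", "k_neighbors": 2}
adj = torch.eye(3)  # dense here just for the demo
save_graph("/tmp/adj_cache.pt", adj, meta)
print(load_graph("/tmp/adj_cache.pt", meta) is not None)        # True
print(load_graph("/tmp/adj_cache.pt", {**meta, "hash": "x"}))   # None (mismatch)
```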
1829
- def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
1830
- n_samples = X_np.shape[0]
1831
- k = min(self.k_neighbors, max(1, n_samples - 1))
1832
- n_neighbors = min(k + 1, n_samples)
1833
- use_approx = (self.use_approx_knn or n_samples >=
1834
- self.approx_knn_threshold) and _PYNN_AVAILABLE
1835
- indices = None
1836
- if use_approx:
1837
- try:
1838
- nn_index = pynndescent.NNDescent(
1839
- X_np,
1840
- n_neighbors=n_neighbors,
1841
- random_state=0
1842
- )
1843
- indices, _ = nn_index.neighbor_graph
1844
- except Exception as exc:
1845
- print(
1846
- f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
1847
- use_approx = False
1848
-
1849
- if indices is None:
1850
- nbrs = NearestNeighbors(
1851
- n_neighbors=n_neighbors,
1852
- algorithm="auto",
1853
- n_jobs=self.knn_cpu_jobs,
1854
- )
1855
- nbrs.fit(X_np)
1856
- _, indices = nbrs.kneighbors(X_np)
1857
-
1858
- indices = np.asarray(indices)
1859
- rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
1860
- np.int64, copy=False)
1861
- cols = indices.reshape(-1).astype(np.int64, copy=False)
1862
- mask = rows != cols
1863
- rows = rows[mask]
1864
- cols = cols[mask]
1865
- rows_base = rows
1866
- cols_base = cols
1867
- self_loops = np.arange(n_samples, dtype=np.int64)
1868
- rows = np.concatenate([rows_base, cols_base, self_loops])
1869
- cols = np.concatenate([cols_base, rows_base, self_loops])
1870
-
1871
- edge_index_np = np.stack([rows, cols], axis=0)
1872
- edge_index = torch.as_tensor(edge_index_np, device=self.device)
1873
- return edge_index
1874
-
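The CPU builder queries k+1 neighbours (each point matches itself), drops self matches, mirrors every edge so the graph is undirected, and finally adds explicit self-loops. A compact sketch with scikit-learn's NearestNeighbors on random data:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_edge_index(X, k=3):
    n = X.shape[0]
    nbrs = NearestNeighbors(n_neighbors=min(k + 1, n)).fit(X)
    _, idx = nbrs.kneighbors(X)
    rows = np.repeat(np.arange(n), idx.shape[1])
    cols = idx.reshape(-1)
    keep = rows != cols                 # drop self matches
    rows, cols = rows[keep], cols[keep]
    loops = np.arange(n)
    # Mirror edges (undirected) and add self-loops.
    return np.stack([np.concatenate([rows, cols, loops]),
                     np.concatenate([cols, rows, loops])])

X = np.random.default_rng(0).normal(size=(10, 4)).astype(np.float32)
print(knn_edge_index(X, k=3).shape)     # (2, num_edges)
```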
1875
- def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
1876
- if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
1877
- # Defensive guard: callers are expected to check use_pyg_knn before calling this.
1878
- raise RuntimeError(
1879
- "GPU graph builder requested but PyG is unavailable.")
1880
-
1881
- n_samples = X_tensor.size(0)
1882
- k = min(self.k_neighbors, max(1, n_samples - 1))
1883
-
1884
- # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
1885
- edge_index = knn_graph(
1886
- X_tensor,
1887
- k=k,
1888
- loop=False
1889
- )
1890
- edge_index = to_undirected(edge_index, num_nodes=n_samples)
1891
- edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
1892
- return edge_index
1893
-
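When PyTorch Geometric is available and the data fits on the device, the graph is built directly on GPU with `knn_graph`, then symmetrized and given self-loops. A sketch of that path, assuming `torch_geometric` (with its torch-cluster backend) is installed, which is an optional dependency here:

```python
import torch
from torch_geometric.nn import knn_graph
from torch_geometric.utils import add_self_loops, to_undirected

x = torch.randn(100, 8)                      # 100 nodes, 8 features (CPU works too)
edge_index = knn_graph(x, k=5, loop=False)   # directed kNN edges
edge_index = to_undirected(edge_index, num_nodes=x.size(0))
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print(edge_index.shape)                      # (2, num_edges)
```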
1894
- def _log_knn_fallback(self, reason: str) -> None:
1895
- if self._knn_warning_emitted:
1896
- return
1897
- if (not self.ddp_enabled) or self.local_rank == 0:
1898
- print(f"[GNN] Falling back to CPU kNN builder: {reason}")
1899
- self._knn_warning_emitted = True
1900
-
1901
- def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
1902
- if not self.use_pyg_knn:
1903
- return False
1904
-
1905
- reason = None
1906
- if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
1907
- reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
1908
- elif self.device.type == 'cuda' and torch.cuda.is_available():
1909
- try:
1910
- device_index = self.device.index
1911
- if device_index is None:
1912
- device_index = torch.cuda.current_device()
1913
- free_mem, total_mem = torch.cuda.mem_get_info(device_index)
1914
- feature_bytes = X_tensor.element_size() * X_tensor.nelement()
1915
- required = int(feature_bytes * self.knn_gpu_mem_overhead)
1916
- budget = int(free_mem * self.knn_gpu_mem_ratio)
1917
- if required > budget:
1918
- required_gb = required / (1024 ** 3)
1919
- budget_gb = budget / (1024 ** 3)
1920
- reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
1921
- f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
1922
- except Exception:
1923
- # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
1924
- reason = None
1925
-
1926
- if reason:
1927
- self._log_knn_fallback(reason)
1928
- return False
1929
- return True
1930
-
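`_should_use_gpu_knn` estimates the temporary memory the GPU kNN build needs (feature bytes times an overhead factor) and compares it against a fraction of the currently free device memory; when the budget is exceeded it falls back to the CPU builder. A standalone sketch of that check (the overhead and ratio values are illustrative defaults):

```python
import torch

def fits_on_gpu(x: torch.Tensor, mem_ratio=0.9, overhead=2.0, device_index=0) -> bool:
    """True if a temporary buffer ~overhead x the feature bytes fits in free GPU memory."""
    if not torch.cuda.is_available():
        return False
    try:
        free_mem, _total = torch.cuda.mem_get_info(device_index)
    except Exception:
        return True  # mem_get_info unavailable -> optimistically try the GPU
    required = int(x.element_size() * x.nelement() * overhead)
    return required <= int(free_mem * mem_ratio)
```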
1931
- def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
1932
- values = torch.ones(edge_index.shape[1], device=self.device)
1933
- adj = torch.sparse_coo_tensor(
1934
- edge_index.to(self.device), values, (num_nodes, num_nodes))
1935
- adj = adj.coalesce()
1936
-
1937
- deg = torch.sparse.sum(adj, dim=1).to_dense()
1938
- deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
1939
- row, col = adj.indices()
1940
- norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
1941
- adj_norm = torch.sparse_coo_tensor(
1942
- adj.indices(), norm_values, size=adj.shape)
1943
- return adj_norm
1944
-
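`_normalized_adj` applies the usual GCN-style symmetric normalization, `A_norm = D^{-1/2} A D^{-1/2}`, to the sparse COO adjacency. The same computation on a tiny dense example, just to show the degree scaling:

```python
import torch

# Adjacency of a 3-node path graph with self-loops.
A = torch.tensor([[1., 1., 0.],
                  [1., 1., 1.],
                  [0., 1., 1.]])
deg = A.sum(dim=1)
d_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
A_norm = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)   # D^-1/2 A D^-1/2
print(A_norm)
```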
1945
- def _tensorize_split(self, X, y, w, allow_none: bool = False):
1946
- if X is None and allow_none:
1947
- return None, None, None
1948
- if not isinstance(X, pd.DataFrame):
1949
- raise ValueError("X must be a pandas DataFrame for GNN.")
1950
- n_rows = len(X)
1951
- if y is not None:
1952
- self._validate_vector(y, "y", n_rows)
1953
- if w is not None:
1954
- self._validate_vector(w, "w", n_rows)
1955
- X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
1956
- X, "to_numpy") else np.asarray(X, dtype=np.float32)
1957
- X_tensor = torch.as_tensor(
1958
- X_np, dtype=torch.float32, device=self.device)
1959
- if y is None:
1960
- y_tensor = None
1961
- else:
1962
- y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
1963
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
1964
- y_tensor = torch.as_tensor(
1965
- y_np, dtype=torch.float32, device=self.device).view(-1, 1)
1966
- if w is None:
1967
- w_tensor = torch.ones(
1968
- (len(X), 1), dtype=torch.float32, device=self.device)
1969
- else:
1970
- w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
1971
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
1972
- w_tensor = torch.as_tensor(
1973
- w_np, dtype=torch.float32, device=self.device).view(-1, 1)
1974
- return X_tensor, y_tensor, w_tensor
1975
-
1976
- def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
1977
- if not isinstance(X_df, pd.DataFrame):
1978
- raise ValueError("X must be a pandas DataFrame for graph building.")
1979
- meta_expected = None
1980
- cache_key = None
1981
- if self.graph_cache_path:
1982
- meta_expected = self._graph_cache_meta(X_df)
1983
- if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
1984
- cached = self._adj_cache_tensor
1985
- if cached.device != self.device:
1986
- cached = cached.to(self.device)
1987
- self._adj_cache_tensor = cached
1988
- return cached
1989
- else:
1990
- cache_key = self._graph_cache_key(X_df)
1991
- if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
1992
- cached = self._adj_cache_tensor
1993
- if cached.device != self.device:
1994
- cached = cached.to(self.device)
1995
- self._adj_cache_tensor = cached
1996
- return cached
1997
- X_np = None
1998
- if X_tensor is None:
1999
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
2000
- X_tensor = torch.as_tensor(
2001
- X_np, dtype=torch.float32, device=self.device)
2002
- if self.graph_cache_path:
2003
- cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
2004
- if cached is not None:
2005
- self._adj_cache_meta = meta_expected
2006
- self._adj_cache_key = None
2007
- self._adj_cache_tensor = cached
2008
- return cached
2009
- use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
2010
- if use_gpu_knn:
2011
- edge_index = self._build_edge_index_gpu(X_tensor)
2012
- else:
2013
- if X_np is None:
2014
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
2015
- edge_index = self._build_edge_index_cpu(X_np)
2016
- adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
2017
- if self.graph_cache_path:
2018
- try:
2019
- IOUtils.ensure_parent_dir(str(self.graph_cache_path))
2020
- torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
2021
- except Exception as exc:
2022
- print(
2023
- f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
2024
- self._adj_cache_meta = meta_expected
2025
- self._adj_cache_key = None
2026
- else:
2027
- self._adj_cache_meta = None
2028
- self._adj_cache_key = cache_key
2029
- self._adj_cache_tensor = adj_norm
2030
- return adj_norm
2031
-
2032
- def fit(self, X_train, y_train, w_train=None,
2033
- X_val=None, y_val=None, w_val=None,
2034
- trial: Optional[optuna.trial.Trial] = None):
2035
-
2036
- X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
2037
- X_train, y_train, w_train, allow_none=False)
2038
- has_val = X_val is not None and y_val is not None
2039
- if has_val:
2040
- X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
2041
- X_val, y_val, w_val, allow_none=False)
2042
- else:
2043
- X_val_tensor = y_val_tensor = w_val_tensor = None
2044
-
2045
- adj_train = self._build_graph_from_df(X_train, X_train_tensor)
2046
- adj_val = self._build_graph_from_df(
2047
- X_val, X_val_tensor) if has_val else None
2048
- # Cache the adjacency on the model as a buffer so DataParallel replicates it instead of trying to scatter the sparse tensor.
2049
- self._set_adj_buffer(adj_train)
2050
-
2051
- base_gnn = self._unwrap_gnn()
2052
- optimizer = torch.optim.Adam(
2053
- base_gnn.parameters(),
2054
- lr=self.learning_rate,
2055
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
2056
- )
2057
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
2058
-
2059
- best_loss = float('inf')
2060
- best_state = None
2061
- patience_counter = 0
2062
- best_epoch = None
2063
-
2064
- for epoch in range(1, self.epochs + 1):
2065
- epoch_start_ts = time.time()
2066
- self.gnn.train()
2067
- optimizer.zero_grad()
2068
- with autocast(enabled=(self.device.type == 'cuda')):
2069
- if self.data_parallel_enabled:
2070
- y_pred = self.gnn(X_train_tensor)
2071
- else:
2072
- y_pred = self.gnn(X_train_tensor, adj_train)
2073
- loss = self._compute_weighted_loss(
2074
- y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
2075
- scaler.scale(loss).backward()
2076
- scaler.unscale_(optimizer)
2077
- clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
2078
- scaler.step(optimizer)
2079
- scaler.update()
2080
-
2081
- val_loss = None
2082
- if has_val:
2083
- self.gnn.eval()
2084
- if self.data_parallel_enabled and adj_val is not None:
2085
- self._set_adj_buffer(adj_val)
2086
- with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
2087
- if self.data_parallel_enabled:
2088
- y_val_pred = self.gnn(X_val_tensor)
2089
- else:
2090
- y_val_pred = self.gnn(X_val_tensor, adj_val)
2091
- val_loss = self._compute_weighted_loss(
2092
- y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
2093
- if self.data_parallel_enabled:
2094
- # Restore training adjacency.
2095
- self._set_adj_buffer(adj_train)
2096
-
2097
- is_best = val_loss is not None and val_loss < best_loss
2098
- best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
2099
- val_loss, best_loss, best_state, patience_counter, base_gnn,
2100
- ignore_keys=["adj_buffer"])
2101
- if is_best:
2102
- best_epoch = epoch
2103
-
2104
- prune_now = False
2105
- if trial is not None and val_loss is not None:
2106
- trial.report(float(val_loss), epoch)
2107
- if trial.should_prune():
2108
- prune_now = True
2109
-
2110
- if dist.is_initialized():
2111
- flag = torch.tensor(
2112
- [1 if prune_now else 0],
2113
- device=self.device,
2114
- dtype=torch.int32,
2115
- )
2116
- dist.broadcast(flag, src=0)
2117
- prune_now = bool(flag.item())
2118
-
2119
- if prune_now:
2120
- raise optuna.TrialPruned()
2121
- if stop_training:
2122
- break
2123
-
2124
- should_log = (not dist.is_initialized()
2125
- or DistributedUtils.is_main_process())
2126
- if should_log:
2127
- elapsed = int(time.time() - epoch_start_ts)
2128
- if val_loss is None:
2129
- print(
2130
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
2131
- flush=True,
2132
- )
2133
- else:
2134
- print(
2135
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
2136
- f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
2137
- flush=True,
2138
- )
2139
-
2140
- if best_state is not None:
2141
- base_gnn.load_state_dict(best_state, strict=False)
2142
- self.best_epoch = int(best_epoch or self.epochs)
2143
-
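`fit` trains full-graph epochs under AMP, keeps the best validation state (excluding the adjacency buffer), stops after `patience` non-improving epochs, and reports to Optuna for pruning. A stripped-down sketch of the early-stopping bookkeeping, independent of the trainer mixin used here (the real `_early_stop_update` also clones tensors and can ignore keys such as `adj_buffer`):

```python
import copy

def early_stop_update(val_loss, best_loss, best_state, patience_counter,
                      model, patience=10):
    """Track the best model state and signal when training should stop."""
    if val_loss is not None and val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
    stop = best_state is not None and patience_counter >= patience
    return best_loss, best_state, patience_counter, stop
```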
2144
- def predict(self, X: pd.DataFrame) -> np.ndarray:
2145
- self.gnn.eval()
2146
- X_tensor, _, _ = self._tensorize_split(
2147
- X, None, None, allow_none=False)
2148
- adj = self._build_graph_from_df(X, X_tensor)
2149
- if self.data_parallel_enabled:
2150
- self._set_adj_buffer(adj)
2151
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
2152
- with inference_cm():
2153
- if self.data_parallel_enabled:
2154
- y_pred = self.gnn(X_tensor).cpu().numpy()
2155
- else:
2156
- y_pred = self.gnn(X_tensor, adj).cpu().numpy()
2157
- if self.task_type == 'classification':
2158
- y_pred = 1 / (1 + np.exp(-y_pred))
2159
- else:
2160
- y_pred = np.clip(y_pred, 1e-6, None)
2161
- return y_pred.ravel()
2162
-
2163
- def encode(self, X: pd.DataFrame) -> np.ndarray:
2164
- """Return per-sample node embeddings (hidden representations)."""
2165
- base = self._unwrap_gnn()
2166
- base.eval()
2167
- X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
2168
- adj = self._build_graph_from_df(X, X_tensor)
2169
- if self.data_parallel_enabled:
2170
- self._set_adj_buffer(adj)
2171
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
2172
- with inference_cm():
2173
- h = X_tensor
2174
- layers = getattr(base, "layers", None)
2175
- if layers is None:
2176
- raise RuntimeError("GNN base module does not expose layers.")
2177
- for layer in layers:
2178
- h = layer(h, adj)
2179
- h = torch.sparse.mm(adj, h)
2180
- return h.detach().cpu().numpy()
2181
-
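`encode()` returns the hidden node representations after all graph layers (plus one final propagation), which can be reused as engineered features for a downstream model. A hedged usage sketch, where `model.encode(X_train)` stands in for a fitted GraphNeuralNetSklearn-style object and the arrays below are synthetic placeholders:

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor

# embeddings = model.encode(X_train)   # shape: (n_rows, hidden_dim)
# Illustrative stand-ins so the snippet runs on its own:
embeddings = np.random.default_rng(0).normal(size=(200, 64))
claim_counts = np.random.default_rng(1).poisson(lam=0.2, size=200)

glm = PoissonRegressor(alpha=1.0).fit(embeddings, claim_counts)
print(glm.predict(embeddings[:5]))
```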
2182
- def set_params(self, params: Dict[str, Any]):
2183
- for key, value in params.items():
2184
- if hasattr(self, key):
2185
- setattr(self, key, value)
2186
- else:
2187
- raise ValueError(f"Parameter {key} not found in GNN model.")
2188
- # Rebuild the backbone after structural parameter changes.
2189
- self.gnn = SimpleGNN(
2190
- input_dim=self.input_dim,
2191
- hidden_dim=self.hidden_dim,
2192
- num_layers=self.num_layers,
2193
- dropout=self.dropout,
2194
- task_type=self.task_type
2195
- ).to(self.device)
2196
- return self
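For reference, a hedged end-to-end sketch of how this wrapper (now refactored into the new trainer modules) was typically driven; the constructor arguments shown are a subset, the import path depends on the release, and the data is synthetic:

```python
import numpy as np
import pandas as pd

# from the pre-0.2.0 layout, e.g.:
# from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(500, 8)), columns=[f"x{i}" for i in range(8)])
y = rng.poisson(lam=0.3, size=500).astype(float)

# model = GraphNeuralNetSklearn("demo_f", input_dim=8, hidden_dim=32,
#                               num_layers=2, k_neighbors=5, epochs=20)
# model.fit(X, y)
# preds = model.predict(X)
```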