ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +9 -6
- ins_pricing/__init__.py +3 -11
- ins_pricing/cli/BayesOpt_entry.py +24 -0
- ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
- ins_pricing/cli/Explain_Run.py +25 -0
- ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
- ins_pricing/cli/Pricing_Run.py +25 -0
- ins_pricing/cli/__init__.py +1 -0
- ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
- ins_pricing/cli/utils/__init__.py +1 -0
- ins_pricing/cli/utils/cli_common.py +320 -0
- ins_pricing/cli/utils/cli_config.py +375 -0
- ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
- {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
- ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
- ins_pricing/docs/modelling/README.md +34 -0
- ins_pricing/modelling/__init__.py +57 -6
- ins_pricing/modelling/core/__init__.py +1 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
- ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
- ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
- ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
- ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
- ins_pricing/modelling/core/evaluation.py +115 -0
- ins_pricing/production/__init__.py +4 -0
- ins_pricing/production/preprocess.py +71 -0
- ins_pricing/setup.py +10 -5
- {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
- ins_pricing-0.2.0.dist-info/RECORD +125 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
- ins_pricing/modelling/BayesOpt_entry.py +0 -633
- ins_pricing/modelling/Explain_Run.py +0 -36
- ins_pricing/modelling/Pricing_Run.py +0 -36
- ins_pricing/modelling/README.md +0 -33
- ins_pricing/modelling/bayesopt/models.py +0 -2196
- ins_pricing/modelling/bayesopt/trainers.py +0 -2446
- ins_pricing/modelling/cli_common.py +0 -136
- ins_pricing/modelling/tests/test_plotting.py +0 -63
- ins_pricing/modelling/watchdog_run.py +0 -211
- ins_pricing-0.1.11.dist-info/RECORD +0 -169
- ins_pricing_gemini/__init__.py +0 -23
- ins_pricing_gemini/governance/__init__.py +0 -20
- ins_pricing_gemini/governance/approval.py +0 -93
- ins_pricing_gemini/governance/audit.py +0 -37
- ins_pricing_gemini/governance/registry.py +0 -99
- ins_pricing_gemini/governance/release.py +0 -159
- ins_pricing_gemini/modelling/Explain_Run.py +0 -36
- ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
- ins_pricing_gemini/modelling/__init__.py +0 -151
- ins_pricing_gemini/modelling/cli_common.py +0 -141
- ins_pricing_gemini/modelling/config.py +0 -249
- ins_pricing_gemini/modelling/config_preprocess.py +0 -254
- ins_pricing_gemini/modelling/core.py +0 -741
- ins_pricing_gemini/modelling/data_container.py +0 -42
- ins_pricing_gemini/modelling/explain/__init__.py +0 -55
- ins_pricing_gemini/modelling/explain/gradients.py +0 -334
- ins_pricing_gemini/modelling/explain/metrics.py +0 -176
- ins_pricing_gemini/modelling/explain/permutation.py +0 -155
- ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
- ins_pricing_gemini/modelling/features.py +0 -215
- ins_pricing_gemini/modelling/model_manager.py +0 -148
- ins_pricing_gemini/modelling/model_plotting.py +0 -463
- ins_pricing_gemini/modelling/models.py +0 -2203
- ins_pricing_gemini/modelling/notebook_utils.py +0 -294
- ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
- ins_pricing_gemini/modelling/plotting/common.py +0 -63
- ins_pricing_gemini/modelling/plotting/curves.py +0 -572
- ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
- ins_pricing_gemini/modelling/plotting/geo.py +0 -362
- ins_pricing_gemini/modelling/plotting/importance.py +0 -121
- ins_pricing_gemini/modelling/run_logging.py +0 -133
- ins_pricing_gemini/modelling/tests/conftest.py +0 -8
- ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
- ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
- ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
- ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
- ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
- ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
- ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
- ins_pricing_gemini/modelling/trainers.py +0 -2447
- ins_pricing_gemini/modelling/utils.py +0 -1020
- ins_pricing_gemini/pricing/__init__.py +0 -27
- ins_pricing_gemini/pricing/calibration.py +0 -39
- ins_pricing_gemini/pricing/data_quality.py +0 -117
- ins_pricing_gemini/pricing/exposure.py +0 -85
- ins_pricing_gemini/pricing/factors.py +0 -91
- ins_pricing_gemini/pricing/monitoring.py +0 -99
- ins_pricing_gemini/pricing/rate_table.py +0 -78
- ins_pricing_gemini/production/__init__.py +0 -21
- ins_pricing_gemini/production/drift.py +0 -30
- ins_pricing_gemini/production/monitoring.py +0 -143
- ins_pricing_gemini/production/scoring.py +0 -40
- ins_pricing_gemini/reporting/__init__.py +0 -11
- ins_pricing_gemini/reporting/report_builder.py +0 -72
- ins_pricing_gemini/reporting/scheduler.py +0 -45
- ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
- ins_pricing_gemini/scripts/Explain_entry.py +0 -545
- ins_pricing_gemini/scripts/__init__.py +0 -1
- ins_pricing_gemini/scripts/train.py +0 -568
- ins_pricing_gemini/setup.py +0 -55
- ins_pricing_gemini/smoke_test.py +0 -28
- /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
- /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
- /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -1,2203 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import copy
|
|
4
|
-
import hashlib
|
|
5
|
-
import math
|
|
6
|
-
import os
|
|
7
|
-
import time
|
|
8
|
-
from contextlib import nullcontext
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
11
|
-
|
|
12
|
-
import numpy as np
|
|
13
|
-
import optuna
|
|
14
|
-
import pandas as pd
|
|
15
|
-
import torch
|
|
16
|
-
import torch.distributed as dist
|
|
17
|
-
import torch.nn as nn
|
|
18
|
-
import torch.nn.functional as F
|
|
19
|
-
from sklearn.neighbors import NearestNeighbors
|
|
20
|
-
from torch.cuda.amp import autocast, GradScaler
|
|
21
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
22
|
-
from torch.nn.utils import clip_grad_norm_
|
|
23
|
-
from torch.utils.data import Dataset, TensorDataset
|
|
24
|
-
|
|
25
|
-
from .utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
|
|
26
|
-
|
|
27
|
-
try:
|
|
28
|
-
from torch_geometric.nn import knn_graph
|
|
29
|
-
from torch_geometric.utils import add_self_loops, to_undirected
|
|
30
|
-
_PYG_AVAILABLE = True
|
|
31
|
-
except Exception:
|
|
32
|
-
knn_graph = None # type: ignore
|
|
33
|
-
add_self_loops = None # type: ignore
|
|
34
|
-
to_undirected = None # type: ignore
|
|
35
|
-
_PYG_AVAILABLE = False
|
|
36
|
-
|
|
37
|
-
try:
|
|
38
|
-
import pynndescent
|
|
39
|
-
_PYNN_AVAILABLE = True
|
|
40
|
-
except Exception:
|
|
41
|
-
pynndescent = None # type: ignore
|
|
42
|
-
_PYNN_AVAILABLE = False
|
|
43
|
-
|
|
44
|
-
_GNN_MPS_WARNED = False
|
|
45
|
-
|
|
46
|
-
# =============================================================================
|
|
47
|
-
# ResNet model and sklearn-style wrapper
|
|
48
|
-
# =============================================================================
|
|
49
|
-
|
|
50
|
-
# ResNet model definition
|
|
51
|
-
# Residual block: two linear layers + ReLU + residual connection
|
|
52
|
-
# ResBlock inherits nn.Module
|
|
53
|
-
class ResBlock(nn.Module):
    """Pre-activation residual block.

    The branch is ``norm -> linear -> ReLU -> dropout -> linear``. Its output is
    multiplied by the learnable scalar ``res_scale``, optionally dropped per
    sample (stochastic depth, training only) and added back onto the input.
    """

    def __init__(self, dim: int, dropout: float = 0.1,
                 use_layernorm: bool = False, residual_scale: float = 0.1,
                 stochastic_depth: float = 0.0
                 ):
        super().__init__()
        self.use_layernorm = use_layernorm

        # LayerNorm normalizes the last dimension; BatchNorm1d stays available
        # behind the same switch.
        norm_factory = nn.LayerNorm if use_layernorm else nn.BatchNorm1d

        self.norm1 = norm_factory(dim)
        self.fc1 = nn.Linear(dim, dim, bias=True)
        self.act = nn.ReLU(inplace=True)
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        # A second norm after the activation could be enabled here:
        # self.norm2 = norm_factory(dim)
        self.fc2 = nn.Linear(dim, dim, bias=True)

        # Learnable residual scaling to stabilize early training.
        initial_scale = torch.tensor(residual_scale, dtype=torch.float32)
        self.res_scale = nn.Parameter(initial_scale)
        self.stochastic_depth = max(0.0, float(stochastic_depth))

    def _drop_path(self, x: torch.Tensor) -> torch.Tensor:
        """Zero the branch for random samples and rescale the survivors.

        Identity in eval mode or when the drop rate is not positive; all zeros
        when the keep probability is not positive.
        """
        drop_rate = self.stochastic_depth
        if not self.training or drop_rate <= 0.0:
            return x

        keep_prob = 1.0 - drop_rate
        if keep_prob <= 0.0:
            return torch.zeros_like(x)

        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        noise = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        keep_mask = torch.floor(keep_prob + noise)
        return x * keep_mask / keep_prob

    def forward(self, x):
        """Return ``x`` plus the scaled (and possibly dropped) branch output."""
        # Pre-activation structure: normalize first, then transform.
        branch = self.fc1(self.norm1(x))
        branch = self.dropout(self.act(branch))
        # If a second norm is enabled: branch = self.norm2(branch)
        branch = self.fc2(branch)
        # Scale the branch, apply stochastic depth, then add the skip path.
        branch = self._drop_path(self.res_scale * branch)
        return x + branch
|
|
103
|
-
|
|
104
|
-
# ResNetSequential defines the full network
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
class ResNetSequential(nn.Module):
    """Fully connected residual network for inputs of shape (batch, input_dim).

    Layout: input linear layer -> ``block_num`` residual blocks -> linear head
    with a single output -> Softplus (regression) or Identity (classification,
    so the network emits raw logits).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1, stochastic_depth: float = 0.0,
                 task_type: str = 'regression'):
        super().__init__()

        layers = nn.Sequential()
        layers.add_module('fc1', nn.Linear(input_dim, hidden_dim))

        # Optional explicit normalization / activation after the first layer:
        #   layers.add_module('norm1', nn.LayerNorm(hidden_dim))
        #   layers.add_module('norm1', nn.BatchNorm1d(hidden_dim))
        #   layers.add_module('relu1', nn.ReLU(inplace=True))

        # Residual blocks. The drop-path rate ramps linearly from 0 on the
        # first block up to the full rate on the last one; a single block
        # receives the full rate.
        max_drop = max(0.0, float(stochastic_depth))
        for idx in range(block_num):
            if block_num > 1:
                rate = max_drop * (idx / (block_num - 1))
            else:
                rate = max_drop
            block = ResBlock(
                hidden_dim,
                dropout=dropout,
                use_layernorm=use_layernorm,
                residual_scale=residual_scale,
                stochastic_depth=rate)
            layers.add_module(f'ResBlk_{idx+1}', block)

        layers.add_module('fc_out', nn.Linear(hidden_dim, 1))

        # The final module is always registered as 'softplus', regardless of
        # which activation it actually holds.
        is_classifier = task_type == 'classification'
        output_act = nn.Identity() if is_classifier else nn.Softplus()
        layers.add_module('softplus', output_act)

        self.net = layers

    def forward(self, x):
        """Run the network; report the device once on the first training call."""
        already_reported = hasattr(self, '_printed_device')
        if self.training and not already_reported:
            print(f">>> ResNetSequential executing on device: {x.device}")
            self._printed_device = True
        return self.net(x)
|
|
158
|
-
|
|
159
|
-
# Define the ResNet sklearn-style wrapper.
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
class ResNetSklearn(TorchTrainerMixin, nn.Module):
    """Sklearn-style wrapper (``fit`` / ``predict`` / ``set_params``) around ResNetSequential.

    The training-loop helpers used by ``fit`` (``_build_dataloader``,
    ``_build_val_dataloader``, ``_compute_losses``, ``_train_model``) are not
    defined in this class; they are expected to be supplied by
    ``TorchTrainerMixin``.
    """

    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 block_num: int = 2, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1,
                 stochastic_depth: float = 0.0,
                 weight_decay: float = 1e-4,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False):
        """Store hyperparameters, pick a device and build the (possibly wrapped) network.

        Multi-GPU handling: an initialized DDP setup wins; otherwise
        ``nn.DataParallel`` is used when requested and more than one CUDA device
        is visible; otherwise the plain module is used.
        """
        super(ResNetSklearn, self).__init__()

        self.use_ddp = use_ddp
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)

        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.block_num = block_num
        self.batch_num = batch_num
        self.epochs = epochs
        self.task_type = task_type
        self.model_nme = model_nme
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.patience = patience
        self.use_layernorm = use_layernorm
        self.dropout = dropout
        self.residual_scale = residual_scale
        self.stochastic_depth = max(0.0, float(stochastic_depth))
        self.loss_curve_path: Optional[str] = None
        self.training_history: Dict[str, List[float]] = {
            "train": [], "val": []}
        self.use_data_parallel = bool(use_data_parallel)

        # Device selection: cuda > mps > cpu.
        # ``torch.backends.mps`` does not exist on older torch builds, so look it
        # up defensively instead of letting the attribute access raise.
        mps_backend = getattr(torch.backends, "mps", None)
        if self.is_ddp_enabled:
            if torch.cuda.is_available():
                self.device = torch.device(f'cuda:{self.local_rank}')
            else:
                self.device = torch.device('cpu')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif mps_backend is not None and mps_backend.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        # Tweedie power (unused for classification). The power is inferred from
        # letters in the model name: 'f' -> 1, 's' -> 2, otherwise the explicit
        # ``tweedie_power`` argument.
        # NOTE(review): presumably 'f'/'s' mark frequency/severity models - confirm.
        if self.task_type == 'classification':
            self.tw_power = None
        elif 'f' in self.model_nme:
            self.tw_power = 1
        elif 's' in self.model_nme:
            self.tw_power = 2
        else:
            self.tw_power = tweedie_power

        # Build network (construct on CPU first).
        core = ResNetSequential(
            self.input_dim,
            self.hidden_dim,
            self.block_num,
            use_layernorm=self.use_layernorm,
            dropout=self.dropout,
            residual_scale=self.residual_scale,
            stochastic_depth=self.stochastic_depth,
            task_type=self.task_type
        )

        # ===== Multi-GPU: DataParallel vs DistributedDataParallel =====
        if self.is_ddp_enabled:
            core = core.to(self.device)
            if self.device.type == 'cuda':
                core = DDP(core, device_ids=[self.local_rank], output_device=self.local_rank)
            else:
                # CPU/Gloo DDP
                core = DDP(core)
            self.use_data_parallel = False
        elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            if self.use_ddp and not self.is_ddp_enabled:
                print(
                    ">>> DDP requested but not initialized; falling back to DataParallel.")
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            # DataParallel scatters inputs, but the primary device remains cuda:0.
            self.device = torch.device('cuda')
            self.use_data_parallel = True
        else:
            self.use_data_parallel = False

        self.resnet = core.to(self.device)

    # ================ Internal helpers ================
    @staticmethod
    def _validate_vector(arr, name: str, n_rows: int) -> None:
        """Check that ``arr`` is 1d (or Nx1) with ``n_rows`` entries; ``None`` is accepted.

        Raises:
            ValueError: if the shape is not 1d / Nx1 or the length differs from ``n_rows``.
        """
        if arr is None:
            return
        if isinstance(arr, pd.DataFrame):
            if arr.shape[1] != 1:
                raise ValueError(f"{name} must be 1d (single column).")
            length = len(arr)
        else:
            arr_np = np.asarray(arr)
            if arr_np.ndim == 0:
                raise ValueError(f"{name} must be 1d.")
            if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
                raise ValueError(f"{name} must be 1d or Nx1.")
            length = arr_np.shape[0]
        if length != n_rows:
            raise ValueError(
                f"{name} length {length} does not match X length {n_rows}."
            )

    def _validate_inputs(self, X, y, w, label: str) -> None:
        """Require ``X`` and ``y`` and check that ``y`` / ``w`` line up with ``X``.

        Raises:
            ValueError: if ``X`` or ``y`` is ``None`` or a vector fails validation.
        """
        if X is None:
            raise ValueError(f"{label} X cannot be None.")
        n_rows = len(X)
        if y is None:
            raise ValueError(f"{label} y cannot be None.")
        self._validate_vector(y, f"{label} y", n_rows)
        self._validate_vector(w, f"{label} w", n_rows)

    def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
        """Validate the inputs and convert them to float32 CPU tensors.

        Targets and weights are reshaped to column vectors; missing weights
        default to ones.

        Returns:
            ``(X, y, w, X_val, y_val, w_val, has_val)``; the validation entries
            are ``None`` when no validation set was supplied.

        Raises:
            ValueError: if any validation piece is given without both ``X_val``
                and ``y_val``, or if an input fails validation.
        """
        self._validate_inputs(X_train, y_train, w_train, "train")
        if X_val is not None or y_val is not None or w_val is not None:
            if X_val is None or y_val is None:
                raise ValueError("validation X and y must both be provided.")
            self._validate_inputs(X_val, y_val, w_val, "val")

        def _to_numpy(arr):
            # pandas objects expose to_numpy; everything else goes through numpy.
            if hasattr(arr, "to_numpy"):
                return arr.to_numpy(dtype=np.float32, copy=False)
            return np.asarray(arr, dtype=np.float32)

        X_tensor = torch.as_tensor(_to_numpy(X_train))
        y_tensor = torch.as_tensor(_to_numpy(y_train)).view(-1, 1)
        w_tensor = (
            torch.as_tensor(_to_numpy(w_train)).view(-1, 1)
            if w_train is not None else torch.ones_like(y_tensor)
        )

        has_val = X_val is not None and y_val is not None
        if has_val:
            X_val_tensor = torch.as_tensor(_to_numpy(X_val))
            y_val_tensor = torch.as_tensor(_to_numpy(y_val)).view(-1, 1)
            w_val_tensor = (
                torch.as_tensor(_to_numpy(w_val)).view(-1, 1)
                if w_val is not None else torch.ones_like(y_val_tensor)
            )
        else:
            X_val_tensor = y_val_tensor = w_val_tensor = None
        return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val

    def forward(self, x):
        """Run the wrapped network on ``x`` (tensor, or NumPy array as passed by SHAP)."""
        # Handle SHAP NumPy input.
        if isinstance(x, np.ndarray):
            x_tensor = torch.as_tensor(x, dtype=torch.float32)
        else:
            x_tensor = x

        x_tensor = x_tensor.to(self.device)
        y_pred = self.resnet(x_tensor)
        return y_pred

    # ---------------- Training ----------------

    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None, trial=None):
        """Train the network, optionally tracking a validation set.

        Builds the data loaders, optimizer and gradient scaler, then delegates
        the loop to ``_train_model``. When a validation set is given and
        ``_train_model`` returns a state dict, that state is loaded back into
        the network. The returned history is stored on ``training_history``.

        ``trial`` is forwarded unchanged to ``_train_model``.
        """

        X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
            self._build_train_val_tensors(
                X_train, y_train, w_train, X_val, y_val, w_val)

        dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
        dataloader, accum_steps = self._build_dataloader(
            dataset,
            N=X_tensor.shape[0],
            base_bs_gpu=(2048, 1024, 512),
            base_bs_cpu=(256, 128),
            min_bs=64,
            target_effective_cuda=2048,
            target_effective_cpu=1024
        )

        # Keep a handle on the sampler so its epoch can be set per epoch,
        # which keeps DDP shuffling deterministic.
        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
            self.dataloader_sampler = dataloader.sampler
        else:
            self.dataloader_sampler = None

        # === Optimizer and AMP ===
        self.optimizer = torch.optim.Adam(
            self.resnet.parameters(),
            lr=self.learning_rate,
            weight_decay=float(self.weight_decay),
        )
        self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        val_dataloader = None
        if has_val:
            # Build validation DataLoader.
            val_dataset = TensorDataset(
                X_val_tensor, y_val_tensor, w_val_tensor)
            # No backward pass in validation; batch size can be larger for throughput.
            val_dataloader = self._build_val_dataloader(
                val_dataset, dataloader, accum_steps)
            # Validation does not use a DDP sampler here.

        is_data_parallel = isinstance(self.resnet, nn.DataParallel)

        def forward_fn(batch):
            X_batch, y_batch, w_batch = batch

            # DataParallel scatters the inputs itself, so only move X manually otherwise.
            if not is_data_parallel:
                X_batch = X_batch.to(self.device, non_blocking=True)
            # Keep targets and weights on the main device for loss computation.
            y_batch = y_batch.to(self.device, non_blocking=True)
            w_batch = w_batch.to(self.device, non_blocking=True)

            y_pred = self.resnet(X_batch)
            return y_pred, y_batch, w_batch

        def val_forward_fn():
            total_loss = 0.0
            total_weight = 0.0
            for batch in val_dataloader:
                X_b, y_b, w_b = batch
                if not is_data_parallel:
                    X_b = X_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

                y_pred = self.resnet(X_b)

                # Manually compute weighted loss for accurate aggregation.
                losses = self._compute_losses(
                    y_pred, y_b, apply_softplus=False)

                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

                total_loss += batch_weighted_loss_sum.item()
                total_weight += batch_weight_sum.item()

            return total_loss / max(total_weight, EPS)

        # Gradient clipping is only wired up on CUDA, where gradients must be
        # unscaled before clipping.
        clip_fn = None
        if self.device.type == 'cuda':
            def clip_fn():
                return (self.scaler.unscale_(self.optimizer),
                        clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))

        best_state, history = self._train_model(
            self.resnet,
            dataloader,
            accum_steps,
            self.optimizer,
            self.scaler,
            forward_fn,
            val_forward_fn if has_val else None,
            apply_softplus=False,
            clip_fn=clip_fn,
            trial=trial,
            loss_curve_path=getattr(self, "loss_curve_path", None)
        )

        if has_val and best_state is not None:
            self.resnet.load_state_dict(best_state)
        self.training_history = history

    # ---------------- Prediction ----------------

    def predict(self, X_test):
        """Return flat NumPy predictions for ``X_test``.

        Classification outputs are passed through a sigmoid (probabilities);
        regression outputs are clipped below at ``1e-6``.
        """
        self.resnet.eval()
        if isinstance(X_test, pd.DataFrame):
            X_np = X_test.to_numpy(dtype=np.float32, copy=False)
        else:
            X_np = np.asarray(X_test, dtype=np.float32)

        # Fall back to no_grad on torch versions without inference_mode.
        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
        with inference_cm():
            y_pred = self(X_np).cpu().numpy()

        if self.task_type == 'classification':
            # Sigmoid converts logits to probabilities. exp() overflowing to inf
            # for very negative logits still yields the correct limit (0.0), so
            # only the spurious overflow warning is silenced.
            with np.errstate(over='ignore'):
                y_pred = 1 / (1 + np.exp(-y_pred))
        else:
            y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.flatten()

    # ---------------- Set Params ----------------

    def set_params(self, params):
        """Overwrite existing attributes from ``params`` and return ``self``.

        Only attribute values are updated; the already-built network is not
        reconstructed here.

        Raises:
            ValueError: if a key does not match an existing attribute.
        """
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Parameter {key} not found in model.")
        return self
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
# =============================================================================
|
|
474
|
-
# FT-Transformer model and sklearn-style wrapper.
|
|
475
|
-
# =============================================================================
|
|
476
|
-
# Define FT-Transformer model structure.
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
class FeatureTokenizer(nn.Module):
    """Map numeric/categorical/geo tokens into transformer input tokens.

    Token layout along dim 1 of the output: ``num_numeric_tokens`` numeric
    tokens (if any numeric features), one token per categorical column, then
    one geo token (if ``num_geo > 0``).
    """

    def __init__(
        self,
        num_numeric: int,
        cat_cardinalities,
        d_model: int,
        num_geo: int = 0,
        num_numeric_tokens: int = 1,
    ):
        """Build the per-feature projections.

        Args:
            num_numeric: number of numeric input features (0 disables them).
            cat_cardinalities: iterable with one cardinality per categorical
                column; ``None`` is treated as "no categorical columns".
            d_model: embedding width of every token.
            num_geo: width of the geo feature vector (0 disables the geo token).
            num_numeric_tokens: how many tokens the numeric features are
                projected into.

        Raises:
            ValueError: if numeric features exist but ``num_numeric_tokens < 1``.
        """
        super().__init__()

        self.num_numeric = num_numeric
        self.num_geo = num_geo
        self.has_geo = num_geo > 0

        if num_numeric > 0:
            if int(num_numeric_tokens) <= 0:
                raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
            self.num_numeric_tokens = int(num_numeric_tokens)
            self.has_numeric = True
            # One linear map produces all numeric tokens at once; forward() splits them.
            self.num_linear = nn.Linear(num_numeric, d_model * self.num_numeric_tokens)
        else:
            self.num_numeric_tokens = 0
            self.has_numeric = False

        # Accept None as "no categorical features" instead of failing on iteration.
        if cat_cardinalities is None:
            cat_cardinalities = []
        self.embeddings = nn.ModuleList([
            nn.Embedding(card, d_model) for card in cat_cardinalities
        ])

        if self.has_geo:
            # Map geo tokens with a linear layer to avoid one-hot on raw strings;
            # upstream is encoded/normalized.
            self.geo_linear = nn.Linear(num_geo, d_model)

    def forward(self, X_num, X_cat, X_geo=None):
        """Return the concatenated tokens with shape (batch, n_tokens, d_model).

        ``X_cat`` is indexed column-wise, one column per embedding table.

        Raises:
            RuntimeError: if geo tokens are enabled but ``X_geo`` is ``None``.
        """
        tokens = []

        if self.has_numeric:
            batch_size = X_num.shape[0]
            num_token = self.num_linear(X_num)
            # Split the flat projection into separate numeric tokens.
            num_token = num_token.view(batch_size, self.num_numeric_tokens, -1)
            tokens.append(num_token)

        for i, emb in enumerate(self.embeddings):
            tok = emb(X_cat[:, i])
            tokens.append(tok.unsqueeze(1))

        if self.has_geo:
            if X_geo is None:
                raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
            geo_token = self.geo_linear(X_geo)
            tokens.append(geo_token.unsqueeze(1))

        x = torch.cat(tokens, dim=1)
        return x
|
|
535
|
-
|
|
536
|
-
# Encoder layer with residual scaling.
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
class ScaledTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer whose attention and feed-forward branches are
    multiplied by fixed residual scaling factors before being added back.

    Inputs are batch-first: (batch, seq_len, d_model).
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, residual_scale_attn: float = 1.0,
                 residual_scale_ffn: float = 1.0, norm_first: bool = True,
                 ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()
        # If you prefer ReLU, set: self.activation = nn.ReLU()
        self.norm_first = norm_first

        # Residual scaling coefficients (plain floats, not learned).
        self.res_scale_attn = residual_scale_attn
        self.res_scale_ffn = residual_scale_ffn

    def forward(self, src, src_mask=None, src_key_padding_mask=None,
                is_causal: bool = False):
        """Apply self-attention and feed-forward sub-layers with residuals.

        Args:
            src: input of shape (batch, seq_len, d_model).
            src_mask: optional attention mask forwarded to the attention module.
            src_key_padding_mask: optional key padding mask forwarded likewise.
            is_causal: accepted so the layer can be stacked by
                ``nn.TransformerEncoder``, which passes this keyword to every
                layer in newer PyTorch releases. It is not used here; any
                masking is applied through ``src_mask``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        x = src

        if self.norm_first:
            # Pre-norm: normalize the input of each sub-layer.
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm: normalize after each residual sum.
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention branch: attention -> dropout -> residual scaling."""
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )
        return self.res_scale_attn * self.dropout1(attn_out)

    def _ff_block(self, x):
        """Feed-forward branch: linear -> activation -> dropout -> linear -> dropout -> scaling."""
        x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.res_scale_ffn * self.dropout2(x2)
|
|
602
|
-
|
|
603
|
-
# FT-Transformer core model.
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
class FTTransformerCore(nn.Module):
    """Minimal FT-Transformer for tabular inputs.

    Pipeline:
      1) ``FeatureTokenizer`` turns numeric / categorical / geo features
         into tokens;
      2) a stack of ``ScaledTransformerEncoderLayer`` models feature
         interactions;
      3) tokens are mean-pooled and sent through a linear head. The head
         ends with Softplus (positive outputs for Tweedie/Gamma) unless
         ``task_type == 'classification'``, where raw logits are returned
         for BCEWithLogitsLoss.

    Separate linear heads reconstruct the numeric and categorical inputs
    from the pooled embedding (masked self-supervised modelling).
    """

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
                 n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
                 task_type: str = 'regression', num_geo: int = 0,
                 num_numeric_tokens: int = 1
                 ):
        super().__init__()

        self.num_numeric = int(num_numeric)
        self.cat_cardinalities = list(cat_cardinalities or [])

        self.tokenizer = FeatureTokenizer(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=d_model,
            num_geo=num_geo,
            num_numeric_tokens=num_numeric_tokens
        )

        # Both residual branches share the same 1/sqrt(depth) scaling.
        branch_scale = 1.0 / math.sqrt(n_layers)
        layer_template = ScaledTransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            residual_scale_attn=branch_scale,
            residual_scale_ffn=branch_scale,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            layer_template, num_layers=n_layers)
        self.n_layers = n_layers

        # Single linear projection followed by the task-specific output
        # activation; a deeper head (LayerNorm / Linear / GELU) could be
        # inserted in front of the projection if ever needed.
        projection = nn.Linear(d_model, 1)
        if task_type == 'classification':
            output_activation = nn.Identity()
        else:
            output_activation = nn.Softplus()
        self.head = nn.Sequential(projection, output_activation)

        # ---- Self-supervised reconstruction heads (masked modeling) ----
        if self.num_numeric > 0:
            self.num_recon_head = nn.Linear(d_model, self.num_numeric)
        else:
            self.num_recon_head = None
        self.cat_recon_heads = nn.ModuleList(
            [nn.Linear(d_model, int(card)) for card in self.cat_cardinalities]
        )

    def forward(
            self,
            X_num,
            X_cat,
            X_geo=None,
            return_embedding: bool = False,
            return_reconstruction: bool = False):
        """Encode a batch and return predictions, embeddings or reconstructions.

        Inputs (per the original inline notes):
            X_num -> float32 tensor, (batch, num_numeric_features)
            X_cat -> long tensor, (batch, num_categorical_features)
            X_geo -> float32 tensor, (batch, geo_token_dim)

        Returns:
            * default: head output, (batch, 1);
            * ``return_embedding`` only: pooled embedding, (batch, d_model);
            * ``return_reconstruction`` only: ``(num_pred, cat_logits)``;
            * both flags: ``(embedding, num_pred, cat_logits)``.
            ``cat_logits`` is always a tuple (possibly empty).
        """
        if self.training and not hasattr(self, '_printed_device'):
            # One-off diagnostic on the first training-mode call.
            print(f">>> FTTransformerCore executing on device: {X_num.device}")
            self._printed_device = True

        # (batch, token_num, d_model) before and after the encoder.
        encoded = self.encoder(self.tokenizer(X_num, X_cat, X_geo))

        # Mean-pool over tokens => (batch, d_model).
        pooled = encoded.mean(dim=1)

        if not return_reconstruction:
            if return_embedding:
                return pooled
            # (batch, 1); Softplus keeps regression outputs positive.
            return self.head(pooled)

        num_pred, cat_logits = self.reconstruct(pooled)
        cat_logits_out = tuple(
            cat_logits) if cat_logits is not None else tuple()
        if return_embedding:
            return pooled, num_pred, cat_logits_out
        return num_pred, cat_logits_out

    def reconstruct(self, embedding: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[torch.Tensor]]:
        """Reconstruct numeric/categorical inputs from pooled embedding (batch, d_model)."""
        if self.num_recon_head is None:
            num_pred = None
        else:
            num_pred = self.num_recon_head(embedding)
        cat_logits = [recon_head(embedding)
                      for recon_head in self.cat_recon_heads]
        return num_pred, cat_logits
|
|
715
|
-
|
|
716
|
-
# TabularDataset.
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
class TabularDataset(Dataset):
    """Row-indexable bundle of feature, target and weight tensors.

    Expected inputs (per the original inline notes):
        X_num: torch.float32, shape=(N, num_numeric_features)
        X_cat: torch.long, shape=(N, num_categorical_features)
        X_geo: torch.float32, shape=(N, geo_token_dim), can be empty
        y: torch.float32, shape=(N, 1)
        w: torch.float32, shape=(N, 1)
    """

    def __init__(self, X_num, X_cat, X_geo, y, w):
        self.X_num, self.X_cat, self.X_geo = X_num, X_cat, X_geo
        self.y, self.w = y, w

    def __len__(self):
        # The target tensor defines the number of samples.
        return self.y.shape[0]

    def __getitem__(self, idx):
        fields = (self.X_num, self.X_cat, self.X_geo, self.y, self.w)
        return tuple(field[idx] for field in fields)
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
class MaskedTabularDataset(Dataset):
    """Dataset of masked inputs plus optional reconstruction targets.

    Each item is a 7-tuple ``(X_num_masked, X_cat_masked, X_geo, X_num_true,
    num_mask, X_cat_true, cat_mask)``. The last four entries are ``None``
    when the corresponding tensor was not supplied.
    """

    def __init__(self,
                 X_num_masked: torch.Tensor,
                 X_cat_masked: torch.Tensor,
                 X_geo: torch.Tensor,
                 X_num_true: Optional[torch.Tensor],
                 num_mask: Optional[torch.Tensor],
                 X_cat_true: Optional[torch.Tensor],
                 cat_mask: Optional[torch.Tensor]):
        # Always-present model inputs.
        self.X_num_masked = X_num_masked
        self.X_cat_masked = X_cat_masked
        self.X_geo = X_geo
        # Optional ground truth and masks for the reconstruction loss.
        self.X_num_true, self.num_mask = X_num_true, num_mask
        self.X_cat_true, self.cat_mask = X_cat_true, cat_mask

    def __len__(self):
        return self.X_num_masked.shape[0]

    def __getitem__(self, idx):
        def row_or_none(tensor):
            # Optional tensors propagate as None instead of being indexed.
            return tensor[idx] if tensor is not None else None

        return (
            self.X_num_masked[idx],
            self.X_cat_masked[idx],
            self.X_geo[idx],
            row_or_none(self.X_num_true),
            row_or_none(self.num_mask),
            row_or_none(self.X_cat_true),
            row_or_none(self.cat_mask),
        )
|
|
778
|
-
|
|
779
|
-
# Scikit-Learn style wrapper for FTTransformer.
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
|
|
783
|
-
|
|
784
|
-
# sklearn-style wrapper:
|
|
785
|
-
# - num_cols: numeric feature column names
|
|
786
|
-
# - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
|
|
787
|
-
|
|
788
|
-
@staticmethod
|
|
789
|
-
def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
|
|
790
|
-
num_cols_count = len(num_cols or [])
|
|
791
|
-
if num_cols_count == 0:
|
|
792
|
-
return 0
|
|
793
|
-
if requested is not None:
|
|
794
|
-
count = int(requested)
|
|
795
|
-
if count <= 0:
|
|
796
|
-
raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
|
|
797
|
-
return count
|
|
798
|
-
return max(1, num_cols_count)
|
|
799
|
-
|
|
800
|
-
def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
             n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
             task_type: str = 'regression',
             tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
             weight_decay: float = 0.0,
             use_data_parallel: bool = True,
             use_ddp: bool = False,
             num_numeric_tokens: Optional[int] = None
             ):
    """Store hyper-parameters and prepare the (not yet built) model state.

    Args:
        model_nme: Model name. For non-classification tasks a name
            containing ``'f'`` selects Tweedie power 1.0, otherwise one
            containing ``'s'`` selects 2.0, otherwise ``tweedie_power``.
        num_cols: Numeric feature column names.
        cat_cols: Categorical feature column names.
        d_model, n_heads, n_layers, dropout: Transformer hyper-parameters,
            passed to ``FTTransformerCore`` when the model is built.
        batch_num, epochs, learning_rate, patience, weight_decay: Training
            settings stored on the instance.
        task_type: ``'classification'`` disables the Tweedie power.
        tweedie_power: Fallback Tweedie power (see ``model_nme``).
        use_data_parallel: Allow ``nn.DataParallel`` when several GPUs
            are visible.
        use_ddp: Request DistributedDataParallel via
            ``DistributedUtils.setup_ddp()``.
        num_numeric_tokens: Optional explicit numeric token count; see
            ``resolve_numeric_token_count``.
    """
    super().__init__()

    self.use_ddp = use_ddp
    self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
        False, 0, 0, 1)
    if self.use_ddp:
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

    self.model_nme = model_nme
    self.num_cols = list(num_cols)
    self.cat_cols = list(cat_cols)
    self.num_numeric_tokens = self.resolve_numeric_token_count(
        self.num_cols,
        self.cat_cols,
        num_numeric_tokens,
    )
    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.dropout = dropout
    self.batch_num = batch_num
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.task_type = task_type
    self.patience = patience
    if self.task_type == 'classification':
        self.tw_power = None  # No Tweedie power for classification.
    elif 'f' in self.model_nme:
        self.tw_power = 1.0
    elif 's' in self.model_nme:
        self.tw_power = 2.0
    else:
        self.tw_power = tweedie_power

    if self.is_ddp_enabled:
        # Allow CPU DDP (e.g. gloo) if CUDA is not available
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.device = torch.device("cpu")
    elif getattr(self, "device", None) is None:
        # Without DDP nothing above assigns a device, yet _build_model and
        # fit dereference self.device.type unconditionally. Provide a
        # default unless a device was already set elsewhere.
        # NOTE(review): confirm no base class is expected to supply it.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    # Populated by _build_model on first fit.
    self.cat_cardinalities = None
    self.cat_categories = {}
    self.cat_maps: Dict[str, Dict[Any, int]] = {}
    self.cat_str_maps: Dict[str, Dict[str, int]] = {}
    self._num_mean = None
    self._num_std = None
    self.ft = None
    self.use_data_parallel = bool(use_data_parallel)
    self.num_geo = 0
    self._geo_params: Dict[str, Any] = {}
    self.loss_curve_path: Optional[str] = None
    self.training_history: Dict[str, List[float]] = {
        "train": [], "val": []}
|
|
863
|
-
|
|
864
|
-
def _build_model(self, X_train):
    """Fit preprocessing state on ``X_train`` and instantiate the network.

    Side effects:
        * ``_num_mean`` / ``_num_std``: per-column statistics of the
          numeric features (NaN/inf replaced by 0 first; std below 1e-6
          replaced by 1.0), or ``None`` without numeric columns;
        * ``cat_categories`` / ``cat_maps`` / ``cat_str_maps`` /
          ``cat_cardinalities``: category vocabularies, with one extra
          class reserved per column for unknown/missing values;
        * ``ft``: the ``FTTransformerCore`` wrapped in DDP or
          ``nn.DataParallel`` where applicable, moved to ``self.device``.
    """
    n_numeric = len(self.num_cols)

    if n_numeric <= 0:
        self._num_mean = None
        self._num_std = None
    else:
        values = X_train[self.num_cols].to_numpy(
            dtype=np.float32, copy=False)
        values = np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
        col_mean = values.mean(axis=0).astype(np.float32, copy=False)
        col_std = values.std(axis=0).astype(np.float32, copy=False)
        # Guard against division by (near) zero for constant columns.
        col_std = np.where(col_std < 1e-6, 1.0, col_std).astype(
            np.float32, copy=False)
        self._num_mean = col_mean
        self._num_std = col_std

    self.cat_maps = {}
    self.cat_str_maps = {}
    cardinalities = []
    for col in self.cat_cols:
        levels = X_train[col].astype('category').cat.categories
        # Store full category list from training.
        self.cat_categories[col] = levels
        self.cat_maps[col] = {
            level: code for code, level in enumerate(levels)}
        textual = levels.dtype == object or pd.api.types.is_string_dtype(
            levels.dtype)
        if textual:
            self.cat_str_maps[col] = {
                str(level): code for code, level in enumerate(levels)}
        # Reserve one extra class for unknown/missing.
        cardinalities.append(len(levels) + 1)
    self.cat_cardinalities = cardinalities

    core = FTTransformerCore(
        num_numeric=n_numeric,
        cat_cardinalities=cardinalities,
        d_model=self.d_model,
        n_heads=self.n_heads,
        n_layers=self.n_layers,
        dropout=self.dropout,
        task_type=self.task_type,
        num_geo=self.num_geo,
        num_numeric_tokens=self.num_numeric_tokens
    )
    use_dp = self.use_data_parallel and (self.device.type == "cuda") and (
        torch.cuda.device_count() > 1)

    if self.is_ddp_enabled:
        core = core.to(self.device)
        ddp_kwargs = {"find_unused_parameters": True}
        if self.device.type == 'cuda':
            ddp_kwargs.update(
                device_ids=[self.local_rank], output_device=self.local_rank)
        # Without CUDA this is CPU/Gloo DDP and no device ids are passed.
        core = DDP(core, **ddp_kwargs)
        self.use_data_parallel = False
    elif use_dp:
        if self.use_ddp and not self.is_ddp_enabled:
            print(
                ">>> DDP requested but not initialized; falling back to DataParallel.")
        gpu_ids = list(range(torch.cuda.device_count()))
        core = nn.DataParallel(core, device_ids=gpu_ids)
        self.device = torch.device("cuda")
        self.use_data_parallel = True
    else:
        self.use_data_parallel = False

    self.ft = core.to(self.device)
|
|
927
|
-
|
|
928
|
-
def _encode_cats(self, X):
|
|
929
|
-
# Input DataFrame must include all categorical feature columns.
|
|
930
|
-
# Return int64 array with shape (N, num_categorical_features).
|
|
931
|
-
|
|
932
|
-
if not self.cat_cols:
|
|
933
|
-
return np.zeros((len(X), 0), dtype='int64')
|
|
934
|
-
|
|
935
|
-
n_rows = len(X)
|
|
936
|
-
n_cols = len(self.cat_cols)
|
|
937
|
-
X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
|
|
938
|
-
for idx, col in enumerate(self.cat_cols):
|
|
939
|
-
categories = self.cat_categories[col]
|
|
940
|
-
mapping = self.cat_maps.get(col)
|
|
941
|
-
if mapping is None:
|
|
942
|
-
mapping = {cat: i for i, cat in enumerate(categories)}
|
|
943
|
-
self.cat_maps[col] = mapping
|
|
944
|
-
unknown_idx = len(categories)
|
|
945
|
-
series = X[col]
|
|
946
|
-
codes = series.map(mapping)
|
|
947
|
-
unmapped = series.notna() & codes.isna()
|
|
948
|
-
if unmapped.any():
|
|
949
|
-
try:
|
|
950
|
-
series_cast = series.astype(categories.dtype)
|
|
951
|
-
except Exception:
|
|
952
|
-
series_cast = None
|
|
953
|
-
if series_cast is not None:
|
|
954
|
-
codes = series_cast.map(mapping)
|
|
955
|
-
unmapped = series_cast.notna() & codes.isna()
|
|
956
|
-
if unmapped.any():
|
|
957
|
-
str_map = self.cat_str_maps.get(col)
|
|
958
|
-
if str_map is None:
|
|
959
|
-
str_map = {str(cat): i for i, cat in enumerate(categories)}
|
|
960
|
-
self.cat_str_maps[col] = str_map
|
|
961
|
-
codes = series.astype(str).map(str_map)
|
|
962
|
-
if pd.api.types.is_categorical_dtype(codes):
|
|
963
|
-
codes = codes.astype("float")
|
|
964
|
-
codes = codes.fillna(unknown_idx).astype(
|
|
965
|
-
"int64", copy=False).to_numpy()
|
|
966
|
-
X_cat_np[:, idx] = codes
|
|
967
|
-
return X_cat_np
|
|
968
|
-
|
|
969
|
-
def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
|
|
970
|
-
return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
|
|
971
|
-
|
|
972
|
-
def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
|
|
973
|
-
return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
|
|
974
|
-
|
|
975
|
-
@staticmethod
|
|
976
|
-
def _validate_vector(arr, name: str, n_rows: int) -> None:
|
|
977
|
-
if arr is None:
|
|
978
|
-
return
|
|
979
|
-
if isinstance(arr, pd.DataFrame):
|
|
980
|
-
if arr.shape[1] != 1:
|
|
981
|
-
raise ValueError(f"{name} must be 1d (single column).")
|
|
982
|
-
length = len(arr)
|
|
983
|
-
else:
|
|
984
|
-
arr_np = np.asarray(arr)
|
|
985
|
-
if arr_np.ndim == 0:
|
|
986
|
-
raise ValueError(f"{name} must be 1d.")
|
|
987
|
-
if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
|
|
988
|
-
raise ValueError(f"{name} must be 1d or Nx1.")
|
|
989
|
-
length = arr_np.shape[0]
|
|
990
|
-
if length != n_rows:
|
|
991
|
-
raise ValueError(
|
|
992
|
-
f"{name} length {length} does not match X length {n_rows}."
|
|
993
|
-
)
|
|
994
|
-
|
|
995
|
-
def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
|
|
996
|
-
if X is None:
|
|
997
|
-
if allow_none:
|
|
998
|
-
return None, None, None, None, None, False
|
|
999
|
-
raise ValueError("Input features X must not be None.")
|
|
1000
|
-
if not isinstance(X, pd.DataFrame):
|
|
1001
|
-
raise ValueError("X must be a pandas DataFrame.")
|
|
1002
|
-
missing_cols = [
|
|
1003
|
-
col for col in (self.num_cols + self.cat_cols) if col not in X.columns
|
|
1004
|
-
]
|
|
1005
|
-
if missing_cols:
|
|
1006
|
-
raise ValueError(f"X is missing required columns: {missing_cols}")
|
|
1007
|
-
n_rows = len(X)
|
|
1008
|
-
if y is not None:
|
|
1009
|
-
self._validate_vector(y, "y", n_rows)
|
|
1010
|
-
if w is not None:
|
|
1011
|
-
self._validate_vector(w, "w", n_rows)
|
|
1012
|
-
|
|
1013
|
-
num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
|
|
1014
|
-
if not num_np.flags["OWNDATA"]:
|
|
1015
|
-
num_np = num_np.copy()
|
|
1016
|
-
num_np = np.nan_to_num(num_np, nan=0.0,
|
|
1017
|
-
posinf=0.0, neginf=0.0, copy=False)
|
|
1018
|
-
if self._num_mean is not None and self._num_std is not None and num_np.size:
|
|
1019
|
-
num_np = (num_np - self._num_mean) / self._num_std
|
|
1020
|
-
X_num = torch.as_tensor(num_np)
|
|
1021
|
-
if self.cat_cols:
|
|
1022
|
-
X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
|
|
1023
|
-
else:
|
|
1024
|
-
X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
|
|
1025
|
-
|
|
1026
|
-
if geo_tokens is not None:
|
|
1027
|
-
geo_np = np.asarray(geo_tokens, dtype=np.float32)
|
|
1028
|
-
if geo_np.shape[0] != n_rows:
|
|
1029
|
-
raise ValueError(
|
|
1030
|
-
"geo_tokens length does not match X rows.")
|
|
1031
|
-
if geo_np.ndim == 1:
|
|
1032
|
-
geo_np = geo_np.reshape(-1, 1)
|
|
1033
|
-
elif self.num_geo > 0:
|
|
1034
|
-
raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
|
|
1035
|
-
else:
|
|
1036
|
-
geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
|
|
1037
|
-
X_geo = torch.as_tensor(geo_np)
|
|
1038
|
-
|
|
1039
|
-
y_tensor = torch.as_tensor(
|
|
1040
|
-
y.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
1041
|
-
y, "to_numpy") else np.asarray(y, dtype=np.float32)
|
|
1042
|
-
).view(-1, 1) if y is not None else None
|
|
1043
|
-
if y_tensor is None:
|
|
1044
|
-
w_tensor = None
|
|
1045
|
-
elif w is not None:
|
|
1046
|
-
w_tensor = torch.as_tensor(
|
|
1047
|
-
w.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
1048
|
-
w, "to_numpy") else np.asarray(w, dtype=np.float32)
|
|
1049
|
-
).view(-1, 1)
|
|
1050
|
-
else:
|
|
1051
|
-
w_tensor = torch.ones_like(y_tensor)
|
|
1052
|
-
return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
|
|
1053
|
-
|
|
1054
|
-
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None, trial=None,
        geo_train=None, geo_val=None):
    """Train the model on a labelled split, optionally with validation.

    Args:
        X_train, y_train, w_train: Training features (DataFrame), targets
            and optional sample weights.
        X_val, y_val, w_val: Optional validation split. Validation is only
            used when both ``X_val`` and ``y_val`` are provided.
        trial: Forwarded unchanged to ``_train_model`` (presumably an
            Optuna trial used for pruning — confirm against the mixin).
        geo_train, geo_val: Optional 2-d geo token arrays; the number of
            columns of ``geo_train`` defines ``self.num_geo``.

    Side effects:
        Builds the network on first call, stores the sampler in
        ``self.dataloader_sampler`` (DDP only), restores the best state
        returned by ``_train_model`` when validation was used, and
        records the loss history in ``self.training_history``.
    """
    # Build the underlying model on first fit.
    self.num_geo = geo_train.shape[1] if geo_train is not None else 0
    if self.ft is None:
        self._build_model(X_train)

    X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
        X_train, y_train, w_train, geo_train=geo_train)
    X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
        X_val, y_val, w_val, geo_val=geo_val)

    # --- Build DataLoader ---
    dataset = TabularDataset(
        X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
    )

    dataloader, accum_steps = self._build_dataloader(
        dataset,
        N=X_num_train.shape[0],
        base_bs_gpu=(2048, 1024, 512),
        base_bs_cpu=(256, 128),
        min_bs=64,
        target_effective_cuda=2048,
        target_effective_cpu=1024
    )

    # Keep a handle on the sampler only when it supports per-epoch
    # reshuffling under DDP.
    if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
        self.dataloader_sampler = dataloader.sampler
    else:
        self.dataloader_sampler = None

    optimizer = torch.optim.Adam(
        self.ft.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    scaler = GradScaler(enabled=(self.device.type == 'cuda'))

    val_dataloader = None
    if has_val:
        val_dataset = TabularDataset(
            X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
        )
        val_dataloader = self._build_val_dataloader(
            val_dataset, dataloader, accum_steps)

    is_data_parallel = isinstance(self.ft, nn.DataParallel)

    def _batch_to_device(batch):
        # Shared by the training and validation closures.
        X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
        if not is_data_parallel:
            # DataParallel distributes its inputs itself, so features are
            # moved explicitly only for the single-device / DDP model.
            X_num_b = X_num_b.to(self.device, non_blocking=True)
            X_cat_b = X_cat_b.to(self.device, non_blocking=True)
            X_geo_b = X_geo_b.to(self.device, non_blocking=True)
        # Targets and weights always go to self.device so that they sit
        # next to the model output for the loss computation.
        y_b = y_b.to(self.device, non_blocking=True)
        w_b = w_b.to(self.device, non_blocking=True)
        return X_num_b, X_cat_b, X_geo_b, y_b, w_b

    def forward_fn(batch):
        X_num_b, X_cat_b, X_geo_b, y_b, w_b = _batch_to_device(batch)
        y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
        return y_pred, y_b, w_b

    def val_forward_fn():
        # Weighted mean of the per-sample validation loss over all batches.
        total_loss = 0.0
        total_weight = 0.0
        for batch in val_dataloader:
            X_num_b, X_cat_b, X_geo_b, y_b, w_b = _batch_to_device(batch)
            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)

            # Manually compute validation loss.
            losses = self._compute_losses(
                y_pred, y_b, apply_softplus=False)

            batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
            batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

            total_loss += batch_weighted_loss_sum.item()
            total_weight += batch_weight_sum.item()

        return total_loss / max(total_weight, EPS)

    clip_fn = None
    if self.device.type == 'cuda':
        # Gradients must be unscaled before their norm is clipped.
        def clip_fn(): return (scaler.unscale_(optimizer),
                               clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

    best_state, history = self._train_model(
        self.ft,
        dataloader,
        accum_steps,
        optimizer,
        scaler,
        forward_fn,
        val_forward_fn if has_val else None,
        apply_softplus=False,
        clip_fn=clip_fn,
        trial=trial,
        loss_curve_path=getattr(self, "loss_curve_path", None)
    )

    if has_val and best_state is not None:
        self.ft.load_state_dict(best_state)
    self.training_history = history
|
|
1167
|
-
|
|
1168
|
-
def fit_unsupervised(self,
|
|
1169
|
-
X_train,
|
|
1170
|
-
X_val=None,
|
|
1171
|
-
trial: Optional[optuna.trial.Trial] = None,
|
|
1172
|
-
geo_train=None,
|
|
1173
|
-
geo_val=None,
|
|
1174
|
-
mask_prob_num: float = 0.15,
|
|
1175
|
-
mask_prob_cat: float = 0.15,
|
|
1176
|
-
num_loss_weight: float = 1.0,
|
|
1177
|
-
cat_loss_weight: float = 1.0) -> float:
|
|
1178
|
-
"""Self-supervised pretraining via masked reconstruction (supports raw string categories)."""
|
|
1179
|
-
self.num_geo = geo_train.shape[1] if geo_train is not None else 0
|
|
1180
|
-
if self.ft is None:
|
|
1181
|
-
self._build_model(X_train)
|
|
1182
|
-
|
|
1183
|
-
X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
|
|
1184
|
-
X_train, None, None, geo_tokens=geo_train, allow_none=True)
|
|
1185
|
-
has_val = X_val is not None
|
|
1186
|
-
if has_val:
|
|
1187
|
-
X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
|
|
1188
|
-
X_val, None, None, geo_tokens=geo_val, allow_none=True)
|
|
1189
|
-
else:
|
|
1190
|
-
X_num_val = X_cat_val = X_geo_val = None
|
|
1191
|
-
|
|
1192
|
-
N = int(X_num.shape[0])
|
|
1193
|
-
num_dim = int(X_num.shape[1])
|
|
1194
|
-
cat_dim = int(X_cat.shape[1])
|
|
1195
|
-
device_type = self._device_type()
|
|
1196
|
-
|
|
1197
|
-
gen = torch.Generator()
|
|
1198
|
-
gen.manual_seed(13 + int(getattr(self, "rank", 0)))
|
|
1199
|
-
|
|
1200
|
-
base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
|
|
1201
|
-
cardinals = getattr(base_model, "cat_cardinalities", None) or []
|
|
1202
|
-
unknown_idx = torch.tensor(
|
|
1203
|
-
[int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)
|
|
1204
|
-
|
|
1205
|
-
means = None
|
|
1206
|
-
if num_dim > 0:
|
|
1207
|
-
# Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
|
|
1208
|
-
means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)
|
|
1209
|
-
|
|
1210
|
-
def _mask_inputs(X_num_in: torch.Tensor,
|
|
1211
|
-
X_cat_in: torch.Tensor,
|
|
1212
|
-
generator: torch.Generator):
|
|
1213
|
-
n_rows = int(X_num_in.shape[0])
|
|
1214
|
-
num_mask_local = None
|
|
1215
|
-
cat_mask_local = None
|
|
1216
|
-
X_num_masked_local = X_num_in
|
|
1217
|
-
X_cat_masked_local = X_cat_in
|
|
1218
|
-
if num_dim > 0:
|
|
1219
|
-
num_mask_local = (torch.rand(
|
|
1220
|
-
(n_rows, num_dim), generator=generator) < float(mask_prob_num))
|
|
1221
|
-
X_num_masked_local = X_num_in.clone()
|
|
1222
|
-
if num_mask_local.any():
|
|
1223
|
-
X_num_masked_local[num_mask_local] = means.expand_as(
|
|
1224
|
-
X_num_masked_local)[num_mask_local]
|
|
1225
|
-
if cat_dim > 0:
|
|
1226
|
-
cat_mask_local = (torch.rand(
|
|
1227
|
-
(n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
|
|
1228
|
-
X_cat_masked_local = X_cat_in.clone()
|
|
1229
|
-
if cat_mask_local.any():
|
|
1230
|
-
X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
|
|
1231
|
-
X_cat_masked_local)[cat_mask_local]
|
|
1232
|
-
return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local
|
|
1233
|
-
|
|
1234
|
-
X_num_true = X_num if num_dim > 0 else None
|
|
1235
|
-
X_cat_true = X_cat if cat_dim > 0 else None
|
|
1236
|
-
X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
|
|
1237
|
-
X_num, X_cat, gen)
|
|
1238
|
-
|
|
1239
|
-
dataset = MaskedTabularDataset(
|
|
1240
|
-
X_num_masked, X_cat_masked, X_geo,
|
|
1241
|
-
X_num_true, num_mask,
|
|
1242
|
-
X_cat_true, cat_mask
|
|
1243
|
-
)
|
|
1244
|
-
dataloader, accum_steps = self._build_dataloader(
|
|
1245
|
-
dataset,
|
|
1246
|
-
N=N,
|
|
1247
|
-
base_bs_gpu=(2048, 1024, 512),
|
|
1248
|
-
base_bs_cpu=(256, 128),
|
|
1249
|
-
min_bs=64,
|
|
1250
|
-
target_effective_cuda=2048,
|
|
1251
|
-
target_effective_cpu=1024
|
|
1252
|
-
)
|
|
1253
|
-
if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
|
|
1254
|
-
self.dataloader_sampler = dataloader.sampler
|
|
1255
|
-
else:
|
|
1256
|
-
self.dataloader_sampler = None
|
|
1257
|
-
|
|
1258
|
-
optimizer = torch.optim.Adam(
|
|
1259
|
-
self.ft.parameters(),
|
|
1260
|
-
lr=self.learning_rate,
|
|
1261
|
-
weight_decay=float(getattr(self, "weight_decay", 0.0)),
|
|
1262
|
-
)
|
|
1263
|
-
scaler = GradScaler(enabled=(device_type == 'cuda'))
|
|
1264
|
-
|
|
1265
|
-
def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
|
|
1266
|
-
loss = torch.zeros((), device=device, dtype=torch.float32)
|
|
1267
|
-
|
|
1268
|
-
if num_pred is not None and num_true_b is not None and num_mask_b is not None:
|
|
1269
|
-
num_mask_b = num_mask_b.to(dtype=torch.bool)
|
|
1270
|
-
if num_mask_b.any():
|
|
1271
|
-
diff = num_pred - num_true_b
|
|
1272
|
-
mse = diff * diff
|
|
1273
|
-
loss = loss + float(num_loss_weight) * \
|
|
1274
|
-
mse[num_mask_b].mean()
|
|
1275
|
-
|
|
1276
|
-
if cat_logits and cat_true_b is not None and cat_mask_b is not None:
|
|
1277
|
-
cat_mask_b = cat_mask_b.to(dtype=torch.bool)
|
|
1278
|
-
cat_losses: List[torch.Tensor] = []
|
|
1279
|
-
for j, logits in enumerate(cat_logits):
|
|
1280
|
-
mask_j = cat_mask_b[:, j]
|
|
1281
|
-
if not mask_j.any():
|
|
1282
|
-
continue
|
|
1283
|
-
targets = cat_true_b[:, j]
|
|
1284
|
-
cat_losses.append(
|
|
1285
|
-
F.cross_entropy(logits, targets, reduction='none')[
|
|
1286
|
-
mask_j].mean()
|
|
1287
|
-
)
|
|
1288
|
-
if cat_losses:
|
|
1289
|
-
loss = loss + float(cat_loss_weight) * \
|
|
1290
|
-
torch.stack(cat_losses).mean()
|
|
1291
|
-
return loss
|
|
1292
|
-
|
|
1293
|
-
train_history: List[float] = []
|
|
1294
|
-
val_history: List[float] = []
|
|
1295
|
-
best_loss = float("inf")
|
|
1296
|
-
best_state = None
|
|
1297
|
-
patience_counter = 0
|
|
1298
|
-
is_ddp_model = isinstance(self.ft, DDP)
|
|
1299
|
-
|
|
1300
|
-
clip_fn = None
|
|
1301
|
-
if self.device.type == 'cuda':
|
|
1302
|
-
def clip_fn(): return (scaler.unscale_(optimizer),
|
|
1303
|
-
clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
|
|
1304
|
-
|
|
1305
|
-
for epoch in range(1, int(self.epochs) + 1):
|
|
1306
|
-
if self.dataloader_sampler is not None:
|
|
1307
|
-
self.dataloader_sampler.set_epoch(epoch)
|
|
1308
|
-
|
|
1309
|
-
self.ft.train()
|
|
1310
|
-
optimizer.zero_grad()
|
|
1311
|
-
epoch_loss_sum = 0.0
|
|
1312
|
-
epoch_count = 0.0
|
|
1313
|
-
|
|
1314
|
-
for step, batch in enumerate(dataloader):
|
|
1315
|
-
is_update_step = ((step + 1) % accum_steps == 0) or \
|
|
1316
|
-
((step + 1) == len(dataloader))
|
|
1317
|
-
sync_cm = self.ft.no_sync if (
|
|
1318
|
-
is_ddp_model and not is_update_step) else nullcontext
|
|
1319
|
-
with sync_cm():
|
|
1320
|
-
with autocast(enabled=(device_type == 'cuda')):
|
|
1321
|
-
X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
|
|
1322
|
-
X_num_b = X_num_b.to(self.device, non_blocking=True)
|
|
1323
|
-
X_cat_b = X_cat_b.to(self.device, non_blocking=True)
|
|
1324
|
-
X_geo_b = X_geo_b.to(self.device, non_blocking=True)
|
|
1325
|
-
num_true_b = None if num_true_b is None else num_true_b.to(
|
|
1326
|
-
self.device, non_blocking=True)
|
|
1327
|
-
num_mask_b = None if num_mask_b is None else num_mask_b.to(
|
|
1328
|
-
self.device, non_blocking=True)
|
|
1329
|
-
cat_true_b = None if cat_true_b is None else cat_true_b.to(
|
|
1330
|
-
self.device, non_blocking=True)
|
|
1331
|
-
cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
|
|
1332
|
-
self.device, non_blocking=True)
|
|
1333
|
-
|
|
1334
|
-
num_pred, cat_logits = self.ft(
|
|
1335
|
-
X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
|
|
1336
|
-
batch_loss = _batch_recon_loss(
|
|
1337
|
-
num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
|
|
1338
|
-
local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
|
|
1339
|
-
global_bad = local_bad
|
|
1340
|
-
if dist.is_initialized():
|
|
1341
|
-
bad = torch.tensor(
|
|
1342
|
-
[local_bad],
|
|
1343
|
-
device=batch_loss.device,
|
|
1344
|
-
dtype=torch.int32,
|
|
1345
|
-
)
|
|
1346
|
-
dist.all_reduce(bad, op=dist.ReduceOp.MAX)
|
|
1347
|
-
global_bad = int(bad.item())
|
|
1348
|
-
|
|
1349
|
-
if global_bad:
|
|
1350
|
-
msg = (
|
|
1351
|
-
f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
|
|
1352
|
-
f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
|
|
1353
|
-
)
|
|
1354
|
-
should_log = (not dist.is_initialized()
|
|
1355
|
-
or DistributedUtils.is_main_process())
|
|
1356
|
-
if should_log:
|
|
1357
|
-
print(msg, flush=True)
|
|
1358
|
-
print(
|
|
1359
|
-
f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
|
|
1360
|
-
f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
|
|
1361
|
-
f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
|
|
1362
|
-
flush=True,
|
|
1363
|
-
)
|
|
1364
|
-
if X_geo_b is not None:
|
|
1365
|
-
print(
|
|
1366
|
-
f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
|
|
1367
|
-
f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
|
|
1368
|
-
f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
|
|
1369
|
-
flush=True,
|
|
1370
|
-
)
|
|
1371
|
-
if trial is not None:
|
|
1372
|
-
raise optuna.TrialPruned(msg)
|
|
1373
|
-
raise RuntimeError(msg)
|
|
1374
|
-
loss_for_backward = batch_loss / float(accum_steps)
|
|
1375
|
-
scaler.scale(loss_for_backward).backward()
|
|
1376
|
-
|
|
1377
|
-
if is_update_step:
|
|
1378
|
-
if clip_fn is not None:
|
|
1379
|
-
clip_fn()
|
|
1380
|
-
scaler.step(optimizer)
|
|
1381
|
-
scaler.update()
|
|
1382
|
-
optimizer.zero_grad()
|
|
1383
|
-
|
|
1384
|
-
epoch_loss_sum += float(batch_loss.detach().item()) * \
|
|
1385
|
-
float(X_num_b.shape[0])
|
|
1386
|
-
epoch_count += float(X_num_b.shape[0])
|
|
1387
|
-
|
|
1388
|
-
train_history.append(epoch_loss_sum / max(epoch_count, 1.0))
|
|
1389
|
-
|
|
1390
|
-
if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
|
|
1391
|
-
should_compute_val = (not dist.is_initialized()
|
|
1392
|
-
or DistributedUtils.is_main_process())
|
|
1393
|
-
loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
|
|
1394
|
-
"cpu")
|
|
1395
|
-
val_loss_tensor = torch.zeros(1, device=loss_tensor_device)
|
|
1396
|
-
|
|
1397
|
-
if should_compute_val:
|
|
1398
|
-
self.ft.eval()
|
|
1399
|
-
with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
|
|
1400
|
-
val_bs = min(
|
|
1401
|
-
int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
|
|
1402
|
-
total_val = 0.0
|
|
1403
|
-
total_n = 0.0
|
|
1404
|
-
for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
|
|
1405
|
-
end = min(
|
|
1406
|
-
int(X_num_val.shape[0]), start + max(1, val_bs))
|
|
1407
|
-
X_num_v_true_cpu = X_num_val[start:end]
|
|
1408
|
-
X_cat_v_true_cpu = X_cat_val[start:end]
|
|
1409
|
-
X_geo_v = X_geo_val[start:end].to(
|
|
1410
|
-
self.device, non_blocking=True)
|
|
1411
|
-
gen_val = torch.Generator()
|
|
1412
|
-
gen_val.manual_seed(10_000 + epoch + start)
|
|
1413
|
-
X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
|
|
1414
|
-
X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
|
|
1415
|
-
X_num_v_true = X_num_v_true_cpu.to(
|
|
1416
|
-
self.device, non_blocking=True)
|
|
1417
|
-
X_cat_v_true = X_cat_v_true_cpu.to(
|
|
1418
|
-
self.device, non_blocking=True)
|
|
1419
|
-
X_num_v = X_num_v_cpu.to(
|
|
1420
|
-
self.device, non_blocking=True)
|
|
1421
|
-
X_cat_v = X_cat_v_cpu.to(
|
|
1422
|
-
self.device, non_blocking=True)
|
|
1423
|
-
val_num_mask = None if val_num_mask is None else val_num_mask.to(
|
|
1424
|
-
self.device, non_blocking=True)
|
|
1425
|
-
val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
|
|
1426
|
-
self.device, non_blocking=True)
|
|
1427
|
-
num_pred_v, cat_logits_v = self.ft(
|
|
1428
|
-
X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
|
|
1429
|
-
loss_v = _batch_recon_loss(
|
|
1430
|
-
num_pred_v, cat_logits_v,
|
|
1431
|
-
X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
|
|
1432
|
-
X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
|
|
1433
|
-
device=X_num_v.device
|
|
1434
|
-
)
|
|
1435
|
-
if not torch.isfinite(loss_v):
|
|
1436
|
-
total_val = float("inf")
|
|
1437
|
-
total_n = 1.0
|
|
1438
|
-
break
|
|
1439
|
-
total_val += float(loss_v.detach().item()
|
|
1440
|
-
) * float(end - start)
|
|
1441
|
-
total_n += float(end - start)
|
|
1442
|
-
val_loss_tensor[0] = total_val / max(total_n, 1.0)
|
|
1443
|
-
|
|
1444
|
-
if dist.is_initialized():
|
|
1445
|
-
dist.broadcast(val_loss_tensor, src=0)
|
|
1446
|
-
val_loss_value = float(val_loss_tensor.item())
|
|
1447
|
-
prune_now = False
|
|
1448
|
-
prune_msg = None
|
|
1449
|
-
if not np.isfinite(val_loss_value):
|
|
1450
|
-
prune_now = True
|
|
1451
|
-
prune_msg = (
|
|
1452
|
-
f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
|
|
1453
|
-
f"(epoch={epoch}, val_loss={val_loss_value})"
|
|
1454
|
-
)
|
|
1455
|
-
val_history.append(val_loss_value)
|
|
1456
|
-
|
|
1457
|
-
if val_loss_value < best_loss:
|
|
1458
|
-
best_loss = val_loss_value
|
|
1459
|
-
best_state = {
|
|
1460
|
-
k: (v.clone() if isinstance(
|
|
1461
|
-
v, torch.Tensor) else copy.deepcopy(v))
|
|
1462
|
-
for k, v in self.ft.state_dict().items()
|
|
1463
|
-
}
|
|
1464
|
-
patience_counter = 0
|
|
1465
|
-
else:
|
|
1466
|
-
patience_counter += 1
|
|
1467
|
-
if best_state is not None and patience_counter >= int(self.patience):
|
|
1468
|
-
break
|
|
1469
|
-
|
|
1470
|
-
if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
|
|
1471
|
-
trial.report(val_loss_value, epoch)
|
|
1472
|
-
if trial.should_prune():
|
|
1473
|
-
prune_now = True
|
|
1474
|
-
|
|
1475
|
-
if dist.is_initialized():
|
|
1476
|
-
flag = torch.tensor(
|
|
1477
|
-
[1 if prune_now else 0],
|
|
1478
|
-
device=loss_tensor_device,
|
|
1479
|
-
dtype=torch.int32,
|
|
1480
|
-
)
|
|
1481
|
-
dist.broadcast(flag, src=0)
|
|
1482
|
-
prune_now = bool(flag.item())
|
|
1483
|
-
|
|
1484
|
-
if prune_now:
|
|
1485
|
-
if prune_msg:
|
|
1486
|
-
raise optuna.TrialPruned(prune_msg)
|
|
1487
|
-
raise optuna.TrialPruned()
|
|
1488
|
-
|
|
1489
|
-
self.training_history = {"train": train_history, "val": val_history}
|
|
1490
|
-
self._plot_loss_curve(self.training_history, getattr(
|
|
1491
|
-
self, "loss_curve_path", None))
|
|
1492
|
-
if has_val and best_state is not None:
|
|
1493
|
-
self.ft.load_state_dict(best_state)
|
|
1494
|
-
return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
|
|
1495
|
-
|
|
1496
|
-
def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
    """Run batched inference with the fitted FT model.

    Args:
        X_test: Input rows; must include all numeric/categorical columns.
        geo_tokens: Optional geo tokens, forwarded to ``_tensorize_split``.
        batch_size: Explicit batch size. When ``None`` a size is derived
            from the token count and ``d_model`` to avoid attention OOM.
        return_embedding: When true, the model is called with
            ``return_embedding=True`` and its output is returned without
            any post-processing.

    Returns:
        ``np.ndarray``. An empty float32 vector when there are no rows;
        the raw concatenated model output when ``return_embedding`` is
        true; otherwise a flattened 1-D array of probabilities
        (classification) or values clipped to ``>= 1e-6`` (other tasks).
    """
    # X_test must include all numeric/categorical columns; geo_tokens is optional.
    self.ft.eval()
    X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
        X_test, None, None, geo_tokens=geo_tokens, allow_none=True)

    num_rows = X_num.shape[0]
    if num_rows == 0:
        return np.empty(0, dtype=np.float32)

    device = self.device if isinstance(
        self.device, torch.device) else torch.device(self.device)

    def resolve_batch_size(n_rows: int) -> int:
        """Return the caller's batch size or a heuristic one, within [1, n_rows]."""
        if batch_size is not None:
            return max(1, min(int(batch_size), n_rows))
        # Estimate a safe batch size based on model size to avoid attention OOM.
        token_cnt = self.num_numeric_tokens + len(self.cat_cols)
        if self.num_geo > 0:
            token_cnt += 1
        approx_units = max(1, token_cnt * max(1, self.d_model))
        if device.type == 'cuda':
            if approx_units >= 8192:
                base = 512
            elif approx_units >= 4096:
                base = 1024
            else:
                base = 2048
        else:
            base = 512
        return max(1, min(base, n_rows))

    eff_batch = resolve_batch_size(num_rows)
    preds: List[torch.Tensor] = []

    # inference_mode is missing on older torch releases; fall back to no_grad.
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        for start in range(0, num_rows, eff_batch):
            end = min(num_rows, start + eff_batch)
            X_num_b = X_num[start:end].to(device, non_blocking=True)
            X_cat_b = X_cat[start:end].to(device, non_blocking=True)
            X_geo_b = X_geo[start:end].to(device, non_blocking=True)
            pred_chunk = self.ft(
                X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
            preds.append(pred_chunk.cpu())

    y_pred = torch.cat(preds, dim=0).numpy()

    if return_embedding:
        return y_pred

    if self.task_type == 'classification':
        # Convert logits to probabilities. exp() is only ever evaluated on a
        # non-positive argument, so extreme logits cannot overflow (the naive
        # 1 / (1 + exp(-x)) form emits RuntimeWarning for large negative x).
        decay = np.exp(-np.abs(y_pred))
        y_pred = np.where(y_pred >= 0, 1 / (1 + decay), decay / (1 + decay))
    else:
        # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
        y_pred = np.clip(y_pred, 1e-6, None)
    return y_pred.ravel()
|
|
1555
|
-
|
|
1556
|
-
def set_params(self, params: dict):
    """Assign existing attributes from ``params`` and return ``self``.

    Keeps sklearn-style behavior (returns the estimator for chaining).
    Note: changing structural params (e.g., d_model/n_heads) requires refit
    to take effect.

    All keys are checked before anything is assigned, so a rejected call
    leaves the object unchanged instead of partially updated.

    Args:
        params: Mapping of attribute name to new value.

    Raises:
        ValueError: If a key does not name an existing attribute.
    """
    # Validate first: report the first unknown key without mutating state.
    for key in params:
        if not hasattr(self, key):
            raise ValueError(f"Parameter {key} not found in model.")
    for key, value in params.items():
        setattr(self, key, value)
    return self
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
# =============================================================================
|
|
1570
|
-
# Simplified GNN implementation.
|
|
1571
|
-
# =============================================================================
|
|
1572
|
-
|
|
1573
|
-
|
|
1574
|
-
class SimpleGraphLayer(nn.Module):
    """One graph layer: aggregate over ``adj``, then Linear -> ReLU -> dropout."""

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        """Create the linear map, the in-place ReLU and the dropout stage.

        A non-positive ``dropout`` installs ``nn.Identity`` in place of
        ``nn.Dropout``.
        """
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply one round of message passing to ``x`` using ``adj``."""
        # Message passing with normalized sparse adjacency: A_hat * X * W.
        aggregated = torch.sparse.mm(adj, x)
        return self.dropout(self.activation(self.linear(aggregated)))
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
class SimpleGNN(nn.Module):
    """Stack of ``SimpleGraphLayer`` blocks followed by a single-output head.

    The output activation is ``nn.Identity`` for ``'classification'`` and
    ``nn.Softplus`` for every other task type.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        """Build at least one graph layer, the output head and the adjacency buffer."""
        super().__init__()
        depth = max(1, num_layers)
        # The first layer consumes the raw features; the rest map hidden -> hidden.
        widths = [input_dim] + [hidden_dim] * (depth - 1)
        self.layers = nn.ModuleList(
            [SimpleGraphLayer(width, hidden_dim, dropout=dropout)
             for width in widths])
        self.output = nn.Linear(hidden_dim, 1)
        is_classifier = task_type == 'classification'
        self.output_act = nn.Identity() if is_classifier else nn.Softplus()
        self.task_type = task_type
        # Keep adjacency as a buffer for DataParallel copies.
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run every layer, aggregate once more over the graph, apply the head.

        Uses ``adj`` when given, otherwise the stored ``adj_buffer``.

        Raises:
            RuntimeError: If no adjacency is available or it has no elements.
        """
        graph = adj
        if graph is None:
            graph = getattr(self, "adj_buffer", None)
        if graph is None or graph.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        hidden = x
        for block in self.layers:
            hidden = block(hidden, graph)
        smoothed = torch.sparse.mm(graph, hidden)
        return self.output_act(self.output(smoothed))
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
|
|
1623
|
-
def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
             num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
             learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
             task_type: str = 'regression', tweedie_power: float = 1.5,
             weight_decay: float = 0.0,
             use_data_parallel: bool = False, use_ddp: bool = False,
             use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
             graph_cache_path: Optional[str] = None,
             max_gpu_knn_nodes: Optional[int] = None,
             knn_gpu_mem_ratio: float = 0.9,
             knn_gpu_mem_overhead: float = 2.0,
             knn_cpu_jobs: Optional[int] = -1) -> None:
    """Store hyper-parameters, pick a device and build the wrapped ``SimpleGNN``.

    Values normalised on the way in: ``k_neighbors`` is forced to at least
    1, ``knn_gpu_mem_ratio`` is clamped to ``[0, 1]``,
    ``knn_gpu_mem_overhead`` is forced to at least 1, and
    ``graph_cache_path`` becomes a ``Path`` (or ``None`` when falsy).

    ``tw_power`` is ``None`` for classification, ``1.0`` when ``model_nme``
    contains ``'f'``, ``2.0`` when it contains ``'s'``, otherwise
    ``tweedie_power``.

    Device selection: CUDA when available (``cuda:<LOCAL_RANK>`` when DDP
    was requested with ``WORLD_SIZE > 1`` and therefore disabled, or when
    DDP setup succeeds), otherwise CPU. An available MPS backend also maps
    to CPU, with a one-time notice.
    """
    global _GNN_MPS_WARNED
    super().__init__()
    self.model_nme = model_nme
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.k_neighbors = max(1, k_neighbors)
    self.dropout = dropout
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.epochs = epochs
    self.patience = patience
    self.task_type = task_type
    self.use_approx_knn = use_approx_knn
    self.approx_knn_threshold = approx_knn_threshold
    self.graph_cache_path = Path(
        graph_cache_path) if graph_cache_path else None
    self.max_gpu_knn_nodes = max_gpu_knn_nodes
    self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
    self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
    self.knn_cpu_jobs = knn_cpu_jobs
    self._knn_warning_emitted = False
    # In-memory adjacency cache (see invalidate_graph_cache for the reset).
    self._adj_cache_meta: Optional[Dict[str, Any]] = None
    self._adj_cache_key: Optional[Tuple[Any, ...]] = None
    self._adj_cache_tensor: Optional[torch.Tensor] = None

    if self.task_type == 'classification':
        self.tw_power = None
    elif 'f' in self.model_nme:
        self.tw_power = 1.0
    elif 's' in self.model_nme:
        self.tw_power = 2.0
    else:
        self.tw_power = tweedie_power

    self.ddp_enabled = False
    self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    self.data_parallel_enabled = False
    self._ddp_disabled = False

    if use_ddp:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if world_size > 1:
            print(
                "[GNN] DDP training is not supported; falling back to single process.",
                flush=True,
            )
            self._ddp_disabled = True
            use_ddp = False

    # DDP only works with CUDA; fall back to single process if init fails.
    if use_ddp and torch.cuda.is_available():
        ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
        if ddp_ok:
            self.ddp_enabled = True
            self.local_rank = local_rank
            self.device = torch.device(f'cuda:{local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.cuda.is_available():
        if self._ddp_disabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        else:
            self.device = torch.device('cuda')
    elif (getattr(torch.backends, "mps", None) is not None
          and torch.backends.mps.is_available()):
        # torch.backends.mps is absent on older torch builds, hence the
        # getattr guard; without it this branch raised AttributeError.
        self.device = torch.device('cpu')
        if not _GNN_MPS_WARNED:
            print(
                "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
                flush=True,
            )
            _GNN_MPS_WARNED = True
    else:
        self.device = torch.device('cpu')
    self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)

    # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
    if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
        self.data_parallel_enabled = True
        self.gnn = nn.DataParallel(
            self.gnn, device_ids=list(range(torch.cuda.device_count())))
        self.device = torch.device('cuda')

    if self.ddp_enabled:
        self.gnn = DDP(
            self.gnn,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=False
        )
|
|
1733
|
-
|
|
1734
|
-
@staticmethod
|
|
1735
|
-
def _validate_vector(arr, name: str, n_rows: int) -> None:
|
|
1736
|
-
if arr is None:
|
|
1737
|
-
return
|
|
1738
|
-
if isinstance(arr, pd.DataFrame):
|
|
1739
|
-
if arr.shape[1] != 1:
|
|
1740
|
-
raise ValueError(f"{name} must be 1d (single column).")
|
|
1741
|
-
length = len(arr)
|
|
1742
|
-
else:
|
|
1743
|
-
arr_np = np.asarray(arr)
|
|
1744
|
-
if arr_np.ndim == 0:
|
|
1745
|
-
raise ValueError(f"{name} must be 1d.")
|
|
1746
|
-
if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
|
|
1747
|
-
raise ValueError(f"{name} must be 1d or Nx1.")
|
|
1748
|
-
length = arr_np.shape[0]
|
|
1749
|
-
if length != n_rows:
|
|
1750
|
-
raise ValueError(
|
|
1751
|
-
f"{name} length {length} does not match X length {n_rows}."
|
|
1752
|
-
)
|
|
1753
|
-
|
|
1754
|
-
def _unwrap_gnn(self) -> nn.Module:
|
|
1755
|
-
if isinstance(self.gnn, (DDP, nn.DataParallel)):
|
|
1756
|
-
return self.gnn.module
|
|
1757
|
-
return self.gnn
|
|
1758
|
-
|
|
1759
|
-
def _set_adj_buffer(self, adj: torch.Tensor) -> None:
|
|
1760
|
-
base = self._unwrap_gnn()
|
|
1761
|
-
if hasattr(base, "adj_buffer"):
|
|
1762
|
-
base.adj_buffer = adj
|
|
1763
|
-
else:
|
|
1764
|
-
base.register_buffer("adj_buffer", adj)
|
|
1765
|
-
|
|
1766
|
-
def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
    """Describe ``X_df`` and the kNN settings for cache validation.

    The SHA-256 digest is fed, in order, the per-row content hashes, the
    index hashes and the comma-joined column names.

    Returns:
        Dict with ``n_samples``, ``n_features``, ``hash`` and ``knn_config``.
    """
    content_hashes = pd.util.hash_pandas_object(X_df, index=False).values
    index_hashes = pd.util.hash_pandas_object(X_df.index, index=False).values
    column_signature = ",".join(map(str, X_df.columns))

    digest = hashlib.sha256()
    for chunk in (
        content_hashes.tobytes(),
        index_hashes.tobytes(),
        column_signature.encode("utf-8", errors="ignore"),
    ):
        digest.update(chunk)

    gpu_node_cap = self.max_gpu_knn_nodes
    if gpu_node_cap is not None:
        gpu_node_cap = int(gpu_node_cap)

    meta: Dict[str, Any] = {
        "n_samples": int(X_df.shape[0]),
        "n_features": int(X_df.shape[1]),
        "hash": digest.hexdigest(),
    }
    meta["knn_config"] = {
        "k_neighbors": int(self.k_neighbors),
        "use_approx_knn": bool(self.use_approx_knn),
        "approx_knn_threshold": int(self.approx_knn_threshold),
        "use_pyg_knn": bool(self.use_pyg_knn),
        "pynndescent_available": bool(_PYNN_AVAILABLE),
        "max_gpu_knn_nodes": gpu_node_cap,
        "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
        "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
    }
    return meta
|
|
1792
|
-
|
|
1793
|
-
def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
|
|
1794
|
-
return (
|
|
1795
|
-
id(X_df),
|
|
1796
|
-
id(getattr(X_df, "_mgr", None)),
|
|
1797
|
-
id(X_df.index),
|
|
1798
|
-
X_df.shape,
|
|
1799
|
-
tuple(map(str, X_df.columns)),
|
|
1800
|
-
X_df.attrs.get("graph_cache_key"),
|
|
1801
|
-
)
|
|
1802
|
-
|
|
1803
|
-
def invalidate_graph_cache(self) -> None:
    """Forget the in-memory adjacency cache (metadata, key and tensor)."""
    self._adj_cache_meta = self._adj_cache_key = self._adj_cache_tensor = None
|
|
1807
|
-
|
|
1808
|
-
def _load_cached_adj(self,
|
|
1809
|
-
X_df: pd.DataFrame,
|
|
1810
|
-
meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
|
|
1811
|
-
if self.graph_cache_path and self.graph_cache_path.exists():
|
|
1812
|
-
if meta_expected is None:
|
|
1813
|
-
meta_expected = self._graph_cache_meta(X_df)
|
|
1814
|
-
try:
|
|
1815
|
-
payload = torch.load(self.graph_cache_path,
|
|
1816
|
-
map_location=self.device)
|
|
1817
|
-
except Exception as exc:
|
|
1818
|
-
print(
|
|
1819
|
-
f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
|
|
1820
|
-
return None
|
|
1821
|
-
if isinstance(payload, dict) and "adj" in payload:
|
|
1822
|
-
meta_cached = payload.get("meta")
|
|
1823
|
-
if meta_cached == meta_expected:
|
|
1824
|
-
return payload["adj"].to(self.device)
|
|
1825
|
-
print(
|
|
1826
|
-
f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
|
|
1827
|
-
return None
|
|
1828
|
-
if isinstance(payload, torch.Tensor):
|
|
1829
|
-
print(
|
|
1830
|
-
f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
|
|
1831
|
-
return None
|
|
1832
|
-
print(
|
|
1833
|
-
f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
|
|
1834
|
-
return None
|
|
1835
|
-
|
|
1836
|
-
def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
    """Build a symmetric kNN edge index with self loops on the CPU.

    Approximate search (pynndescent) is used when it is available and
    either ``use_approx_knn`` is set or the row count reaches
    ``approx_knn_threshold``; if it fails, exact search
    (``NearestNeighbors``) is used instead.

    Args:
        X_np: Feature matrix with one row per node.

    Returns:
        Integer tensor of shape ``(2, n_edges)`` on ``self.device`` holding
        each neighbour edge in both directions plus one self loop per node.
    """
    n_samples = X_np.shape[0]
    k = min(self.k_neighbors, max(1, n_samples - 1))
    # Ask for one extra neighbour: each point normally finds itself, and
    # those self matches are dropped below.
    n_neighbors = min(k + 1, n_samples)
    use_approx = (self.use_approx_knn or n_samples >=
                  self.approx_knn_threshold) and _PYNN_AVAILABLE
    indices = None
    if use_approx:
        try:
            nn_index = pynndescent.NNDescent(
                X_np,
                n_neighbors=n_neighbors,
                random_state=0
            )
            indices, _ = nn_index.neighbor_graph
        except Exception as exc:
            print(
                f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")

    if indices is None:
        nbrs = NearestNeighbors(
            n_neighbors=n_neighbors,
            algorithm="auto",
            n_jobs=self.knn_cpu_jobs,
        )
        nbrs.fit(X_np)
        _, indices = nbrs.kneighbors(X_np)

    indices = np.asarray(indices)
    rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
        np.int64, copy=False)
    cols = indices.reshape(-1).astype(np.int64, copy=False)
    # Drop self matches, and the -1 placeholders an approximate search can
    # leave when it finds fewer than n_neighbors neighbours; a negative
    # column would be an invalid sparse index downstream.
    mask = (rows != cols) & (cols >= 0)
    rows_base = rows[mask]
    cols_base = cols[mask]
    self_loops = np.arange(n_samples, dtype=np.int64)
    # Emit every edge in both directions, then one self loop per node.
    rows = np.concatenate([rows_base, cols_base, self_loops])
    cols = np.concatenate([cols_base, rows_base, self_loops])

    edge_index_np = np.stack([rows, cols], axis=0)
    edge_index = torch.as_tensor(edge_index_np, device=self.device)
    return edge_index
|
|
1881
|
-
|
|
1882
|
-
def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
|
|
1883
|
-
if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
|
|
1884
|
-
# Defensive: check use_pyg_knn before calling.
|
|
1885
|
-
raise RuntimeError(
|
|
1886
|
-
"GPU graph builder requested but PyG is unavailable.")
|
|
1887
|
-
|
|
1888
|
-
n_samples = X_tensor.size(0)
|
|
1889
|
-
k = min(self.k_neighbors, max(1, n_samples - 1))
|
|
1890
|
-
|
|
1891
|
-
# knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
|
|
1892
|
-
edge_index = knn_graph(
|
|
1893
|
-
X_tensor,
|
|
1894
|
-
k=k,
|
|
1895
|
-
loop=False
|
|
1896
|
-
)
|
|
1897
|
-
edge_index = to_undirected(edge_index, num_nodes=n_samples)
|
|
1898
|
-
edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
|
|
1899
|
-
return edge_index
|
|
1900
|
-
|
|
1901
|
-
def _log_knn_fallback(self, reason: str) -> None:
|
|
1902
|
-
if self._knn_warning_emitted:
|
|
1903
|
-
return
|
|
1904
|
-
if (not self.ddp_enabled) or self.local_rank == 0:
|
|
1905
|
-
print(f"[GNN] Falling back to CPU kNN builder: {reason}")
|
|
1906
|
-
self._knn_warning_emitted = True
|
|
1907
|
-
|
|
1908
|
-
def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
|
|
1909
|
-
if not self.use_pyg_knn:
|
|
1910
|
-
return False
|
|
1911
|
-
|
|
1912
|
-
reason = None
|
|
1913
|
-
if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
|
|
1914
|
-
reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
|
|
1915
|
-
elif self.device.type == 'cuda' and torch.cuda.is_available():
|
|
1916
|
-
try:
|
|
1917
|
-
device_index = self.device.index
|
|
1918
|
-
if device_index is None:
|
|
1919
|
-
device_index = torch.cuda.current_device()
|
|
1920
|
-
free_mem, total_mem = torch.cuda.mem_get_info(device_index)
|
|
1921
|
-
feature_bytes = X_tensor.element_size() * X_tensor.nelement()
|
|
1922
|
-
required = int(feature_bytes * self.knn_gpu_mem_overhead)
|
|
1923
|
-
budget = int(free_mem * self.knn_gpu_mem_ratio)
|
|
1924
|
-
if required > budget:
|
|
1925
|
-
required_gb = required / (1024 ** 3)
|
|
1926
|
-
budget_gb = budget / (1024 ** 3)
|
|
1927
|
-
reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
|
|
1928
|
-
f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
|
|
1929
|
-
except Exception:
|
|
1930
|
-
# On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
|
|
1931
|
-
reason = None
|
|
1932
|
-
|
|
1933
|
-
if reason:
|
|
1934
|
-
self._log_knn_fallback(reason)
|
|
1935
|
-
return False
|
|
1936
|
-
return True
|
|
1937
|
-
|
|
1938
|
-
def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
|
|
1939
|
-
values = torch.ones(edge_index.shape[1], device=self.device)
|
|
1940
|
-
adj = torch.sparse_coo_tensor(
|
|
1941
|
-
edge_index.to(self.device), values, (num_nodes, num_nodes))
|
|
1942
|
-
adj = adj.coalesce()
|
|
1943
|
-
|
|
1944
|
-
deg = torch.sparse.sum(adj, dim=1).to_dense()
|
|
1945
|
-
deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
|
|
1946
|
-
row, col = adj.indices()
|
|
1947
|
-
norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
|
|
1948
|
-
adj_norm = torch.sparse_coo_tensor(
|
|
1949
|
-
adj.indices(), norm_values, size=adj.shape)
|
|
1950
|
-
return adj_norm
|
|
1951
|
-
|
|
1952
|
-
def _tensorize_split(self, X, y, w, allow_none: bool = False):
|
|
1953
|
-
if X is None and allow_none:
|
|
1954
|
-
return None, None, None
|
|
1955
|
-
if not isinstance(X, pd.DataFrame):
|
|
1956
|
-
raise ValueError("X must be a pandas DataFrame for GNN.")
|
|
1957
|
-
n_rows = len(X)
|
|
1958
|
-
if y is not None:
|
|
1959
|
-
self._validate_vector(y, "y", n_rows)
|
|
1960
|
-
if w is not None:
|
|
1961
|
-
self._validate_vector(w, "w", n_rows)
|
|
1962
|
-
X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
1963
|
-
X, "to_numpy") else np.asarray(X, dtype=np.float32)
|
|
1964
|
-
X_tensor = torch.as_tensor(
|
|
1965
|
-
X_np, dtype=torch.float32, device=self.device)
|
|
1966
|
-
if y is None:
|
|
1967
|
-
y_tensor = None
|
|
1968
|
-
else:
|
|
1969
|
-
y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
1970
|
-
y, "to_numpy") else np.asarray(y, dtype=np.float32)
|
|
1971
|
-
y_tensor = torch.as_tensor(
|
|
1972
|
-
y_np, dtype=torch.float32, device=self.device).view(-1, 1)
|
|
1973
|
-
if w is None:
|
|
1974
|
-
w_tensor = torch.ones(
|
|
1975
|
-
(len(X), 1), dtype=torch.float32, device=self.device)
|
|
1976
|
-
else:
|
|
1977
|
-
w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
|
|
1978
|
-
w, "to_numpy") else np.asarray(w, dtype=np.float32)
|
|
1979
|
-
w_tensor = torch.as_tensor(
|
|
1980
|
-
w_np, dtype=torch.float32, device=self.device).view(-1, 1)
|
|
1981
|
-
return X_tensor, y_tensor, w_tensor
|
|
1982
|
-
|
|
1983
|
-
def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return the normalized adjacency for ``X_df``, reusing caches when possible.

    Lookup order: the in-memory cache (matched on cache metadata when
    ``self.graph_cache_path`` is set, otherwise on the cache key), then the
    on-disk cache via ``self._load_cached_adj``, and finally a fresh build
    from a GPU or CPU edge index followed by ``self._normalized_adj``.

    Args:
        X_df: Feature table the graph is built from.
        X_tensor: Optional tensor form of ``X_df``; created on
            ``self.device`` when not supplied.

    Returns:
        The adjacency tensor, which is also stored as the in-memory cache.

    Raises:
        ValueError: If ``X_df`` is not a DataFrame.
    """
    if not isinstance(X_df, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame for graph building.")

    uses_disk_cache = bool(self.graph_cache_path)
    meta_expected = None
    cache_key = None

    # In-memory lookup: metadata identifies the graph when a cache file is
    # configured, the cache key otherwise.
    if uses_disk_cache:
        meta_expected = self._graph_cache_meta(X_df)
        memory_hit = (self._adj_cache_meta == meta_expected
                      and self._adj_cache_tensor is not None)
    else:
        cache_key = self._graph_cache_key(X_df)
        memory_hit = (self._adj_cache_key == cache_key
                      and self._adj_cache_tensor is not None)
    if memory_hit:
        reused = self._adj_cache_tensor
        if reused.device != self.device:
            # Move once and remember the moved copy.
            reused = reused.to(self.device)
            self._adj_cache_tensor = reused
        return reused

    features_np = None
    if X_tensor is None:
        features_np = X_df.to_numpy(dtype=np.float32, copy=False)
        X_tensor = torch.as_tensor(
            features_np, dtype=torch.float32, device=self.device)

    if uses_disk_cache:
        from_disk = self._load_cached_adj(X_df, meta_expected=meta_expected)
        if from_disk is not None:
            self._adj_cache_meta = meta_expected
            self._adj_cache_key = None
            self._adj_cache_tensor = from_disk
            return from_disk

    node_count = X_df.shape[0]
    if self._should_use_gpu_knn(node_count, X_tensor):
        edge_index = self._build_edge_index_gpu(X_tensor)
    else:
        if features_np is None:
            features_np = X_df.to_numpy(dtype=np.float32, copy=False)
        edge_index = self._build_edge_index_cpu(features_np)
    adj_norm = self._normalized_adj(edge_index, X_df.shape[0])

    if uses_disk_cache:
        # Persisting is best-effort: a failed write is reported, not raised.
        try:
            IOUtils.ensure_parent_dir(str(self.graph_cache_path))
            torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
        except Exception as exc:
            print(
                f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
        self._adj_cache_meta = meta_expected
        self._adj_cache_key = None
    else:
        self._adj_cache_meta = None
        self._adj_cache_key = cache_key
    self._adj_cache_tensor = adj_norm
    return adj_norm
|
|
2038
|
-
|
|
2039
|
-
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None,
        trial: Optional[optuna.trial.Trial] = None):
    """Train the GNN, one optimizer step per epoch over the whole training graph.

    Args:
        X_train, y_train, w_train: Training features, targets and optional
            sample weights (see ``_tensorize_split`` for accepted types).
        X_val, y_val, w_val: Optional validation split. Validation is used
            only when both ``X_val`` and ``y_val`` are given.
        trial: Optional Optuna trial. The validation loss is reported to it
            every epoch; without a validation split nothing is reported and
            the trial is never pruned from here.

    Raises:
        optuna.TrialPruned: When the trial asks for pruning (after the
            decision is broadcast from rank 0 if torch.distributed is
            initialized).

    Side effects:
        Restores the best weights seen by ``_early_stop_update`` (if any)
        and sets ``self.best_epoch``.
    """
    X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
        X_train, y_train, w_train, allow_none=False)
    has_val = X_val is not None and y_val is not None
    if has_val:
        X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
            X_val, y_val, w_val, allow_none=False)
    else:
        X_val_tensor = y_val_tensor = w_val_tensor = None

    adj_train = self._build_graph_from_df(X_train, X_train_tensor)
    adj_val = self._build_graph_from_df(
        X_val, X_val_tensor) if has_val else None
    # DataParallel needs adjacency cached on the model to avoid scatter.
    self._set_adj_buffer(adj_train)

    base_gnn = self._unwrap_gnn()
    optimizer = torch.optim.Adam(
        base_gnn.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    # Mixed precision is only switched on for CUDA devices.
    use_amp = self.device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    best_loss = float('inf')
    best_state = None
    patience_counter = 0
    best_epoch = None

    for epoch in range(1, self.epochs + 1):
        epoch_start_ts = time.time()
        self.gnn.train()
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            if self.data_parallel_enabled:
                y_pred = self.gnn(X_train_tensor)
            else:
                y_pred = self.gnn(X_train_tensor, adj_train)
            loss = self._compute_weighted_loss(
                y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping by norm.
        scaler.unscale_(optimizer)
        clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        val_loss = None
        if has_val:
            self.gnn.eval()
            if self.data_parallel_enabled and adj_val is not None:
                self._set_adj_buffer(adj_val)
            with torch.no_grad(), autocast(enabled=use_amp):
                if self.data_parallel_enabled:
                    y_val_pred = self.gnn(X_val_tensor)
                else:
                    y_val_pred = self.gnn(X_val_tensor, adj_val)
                val_loss = self._compute_weighted_loss(
                    y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
            if self.data_parallel_enabled:
                # Restore training adjacency.
                self._set_adj_buffer(adj_train)

        is_best = val_loss is not None and val_loss < best_loss
        best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
            val_loss, best_loss, best_state, patience_counter, base_gnn,
            ignore_keys=["adj_buffer"])
        if is_best:
            best_epoch = epoch

        prune_now = False
        # Optuna rejects a non-numeric intermediate value, so only report
        # when a validation loss was actually computed this epoch.
        if trial is not None and val_loss is not None:
            trial.report(float(val_loss), epoch)
            if trial.should_prune():
                prune_now = True

        if dist.is_initialized():
            # Every rank adopts rank 0's pruning decision.
            flag = torch.tensor(
                [1 if prune_now else 0],
                device=self.device,
                dtype=torch.int32,
            )
            dist.broadcast(flag, src=0)
            prune_now = bool(flag.item())

        if prune_now:
            raise optuna.TrialPruned()
        if stop_training:
            break

        should_log = (not dist.is_initialized()
                      or DistributedUtils.is_main_process())
        if should_log:
            elapsed = int(time.time() - epoch_start_ts)
            if val_loss is None:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )
            else:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
                    f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )

    if best_state is not None:
        # strict=False because "adj_buffer" was excluded from the snapshot.
        base_gnn.load_state_dict(best_state, strict=False)
    # Fall back to the configured epoch count when no best epoch was recorded.
    self.best_epoch = int(best_epoch or self.epochs)
|
|
2150
|
-
|
|
2151
|
-
def predict(self, X: pd.DataFrame) -> np.ndarray:
    """Predict for every row of ``X`` and return a flat numpy array.

    A graph is built (or fetched from cache) for ``X`` and the network is
    run in eval mode without gradient tracking. For
    ``task_type == 'classification'`` the raw outputs are passed through a
    sigmoid; otherwise they are clipped from below at ``1e-6``.

    Args:
        X: Feature table, a ``pandas.DataFrame``.

    Returns:
        One-dimensional array of predictions.
    """
    self.gnn.eval()
    X_tensor, _, _ = self._tensorize_split(
        X, None, None, allow_none=False)
    adj = self._build_graph_from_df(X, X_tensor)
    if self.data_parallel_enabled:
        # The wrapped model reads the adjacency from its buffer.
        self._set_adj_buffer(adj)
    # Older torch builds have no inference_mode; no_grad is the fallback.
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        if self.data_parallel_enabled:
            y_pred = self.gnn(X_tensor).cpu().numpy()
        else:
            y_pred = self.gnn(X_tensor, adj).cpu().numpy()
    if self.task_type == 'classification':
        # Stable sigmoid: 1 / (1 + exp(-x)) == exp(-log(1 + exp(-x))).
        # The naive form overflows in float32 for strongly negative logits.
        y_pred = np.exp(-np.logaddexp(0.0, -y_pred))
    else:
        y_pred = np.clip(y_pred, 1e-6, None)
    return y_pred.ravel()
|
|
2169
|
-
|
|
2170
|
-
def encode(self, X: pd.DataFrame) -> np.ndarray:
    """Compute hidden node representations, one row per sample of ``X``.

    The features are passed through each entry of the unwrapped model's
    ``layers`` together with the adjacency built for ``X``, then multiplied
    by the adjacency once more.

    Raises:
        RuntimeError: If the unwrapped model has no ``layers`` attribute.
    """
    backbone = self._unwrap_gnn()
    backbone.eval()
    features, _, _ = self._tensorize_split(X, None, None, allow_none=False)
    adjacency = self._build_graph_from_df(X, features)
    if self.data_parallel_enabled:
        self._set_adj_buffer(adjacency)
    # Prefer inference_mode where this torch build offers it.
    grad_free = getattr(torch, "inference_mode", torch.no_grad)
    with grad_free():
        stack = getattr(backbone, "layers", None)
        if stack is None:
            raise RuntimeError("GNN base module does not expose layers.")
        hidden = features
        for block in stack:
            hidden = block(hidden, adjacency)
        hidden = torch.sparse.mm(adjacency, hidden)
    return hidden.detach().cpu().numpy()
|
|
2188
|
-
|
|
2189
|
-
def set_params(self, params: Dict[str, Any]):
    """Update existing attributes from ``params`` and rebuild the network.

    All keys are checked before anything is assigned, so a rejected call
    leaves the object unchanged.

    Args:
        params: Mapping of attribute name to new value. Every key must
            already exist as an attribute on this object.

    Returns:
        ``self``, to allow chaining.

    Raises:
        ValueError: If any key is not an existing attribute.
    """
    for key in params:
        if not hasattr(self, key):
            raise ValueError(f"Parameter {key} not found in GNN model.")
    for key, value in params.items():
        setattr(self, key, value)
    # Rebuild the backbone after structural parameter changes.
    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)
    return self
|