ins-pricing 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
@@ -1,785 +1,788 @@
1
- from __future__ import annotations
2
-
3
- import hashlib
4
- import os
5
- import time
6
- from pathlib import Path
7
- from typing import Any, Dict, Optional, Tuple
8
-
9
- import numpy as np
10
- import pandas as pd
11
- import torch
12
- import torch.distributed as dist
13
- import torch.nn as nn
14
- from sklearn.neighbors import NearestNeighbors
15
- from torch.cuda.amp import autocast, GradScaler
16
- from torch.nn.parallel import DistributedDataParallel as DDP
17
- from torch.nn.utils import clip_grad_norm_
18
-
19
- from ..utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
20
- from ..utils.losses import (
21
- infer_loss_name_from_model_name,
22
- normalize_loss_name,
23
- resolve_tweedie_power,
24
- )
25
-
26
- try:
27
- from torch_geometric.nn import knn_graph
28
- from torch_geometric.utils import add_self_loops, to_undirected
29
- _PYG_AVAILABLE = True
30
- except Exception:
31
- knn_graph = None # type: ignore
32
- add_self_loops = None # type: ignore
33
- to_undirected = None # type: ignore
34
- _PYG_AVAILABLE = False
35
-
36
- try:
37
- import pynndescent
38
- _PYNN_AVAILABLE = True
39
- except Exception:
40
- pynndescent = None # type: ignore
41
- _PYNN_AVAILABLE = False
42
-
43
- _GNN_MPS_WARNED = False
44
-
45
-
46
- # =============================================================================
47
- # Simplified GNN implementation.
48
- # =============================================================================
49
-
50
- def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
51
- """Matrix multiply that supports sparse or dense adjacency."""
52
- if adj.is_sparse:
53
- return torch.sparse.mm(adj, x)
54
- return adj.matmul(x)
55
-
56
-
57
- class SimpleGraphLayer(nn.Module):
58
- def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
59
- super().__init__()
60
- self.linear = nn.Linear(in_dim, out_dim)
61
- self.activation = nn.ReLU(inplace=True)
62
- self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
63
-
64
- def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
65
- # Message passing with normalized sparse adjacency: A_hat * X * W.
66
- h = _adj_mm(adj, x)
67
- h = self.linear(h)
68
- h = self.activation(h)
69
- return self.dropout(h)
70
-
71
-
72
- class SimpleGNN(nn.Module):
73
- def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
74
- dropout: float = 0.1, task_type: str = 'regression'):
75
- super().__init__()
76
- layers = []
77
- dim_in = input_dim
78
- for _ in range(max(1, num_layers)):
79
- layers.append(SimpleGraphLayer(
80
- dim_in, hidden_dim, dropout=dropout))
81
- dim_in = hidden_dim
82
- self.layers = nn.ModuleList(layers)
83
- self.output = nn.Linear(hidden_dim, 1)
84
- if task_type == 'classification':
85
- self.output_act = nn.Identity()
86
- else:
87
- self.output_act = nn.Softplus()
88
- self.task_type = task_type
89
- # Keep adjacency as a buffer for DataParallel copies.
90
- self.register_buffer("adj_buffer", torch.empty(0))
91
-
92
- def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
93
- adj_used = adj if adj is not None else getattr(
94
- self, "adj_buffer", None)
95
- if adj_used is None or adj_used.numel() == 0:
96
- raise RuntimeError("Adjacency is not set for GNN forward.")
97
- h = x
98
- for layer in self.layers:
99
- h = layer(h, adj_used)
100
- h = _adj_mm(adj_used, h)
101
- out = self.output(h)
102
- return self.output_act(out)
103
-
104
-
105
- class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
106
- def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
107
- num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
108
- learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
109
- task_type: str = 'regression', tweedie_power: float = 1.5,
110
- weight_decay: float = 0.0,
111
- use_data_parallel: bool = False, use_ddp: bool = False,
112
- use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
113
- graph_cache_path: Optional[str] = None,
114
- max_gpu_knn_nodes: Optional[int] = None,
115
- knn_gpu_mem_ratio: float = 0.9,
116
- knn_gpu_mem_overhead: float = 2.0,
117
- knn_cpu_jobs: Optional[int] = -1,
118
- loss_name: Optional[str] = None) -> None:
119
- super().__init__()
120
- self.model_nme = model_nme
121
- self.input_dim = input_dim
122
- self.hidden_dim = hidden_dim
123
- self.num_layers = num_layers
124
- self.k_neighbors = max(1, k_neighbors)
125
- self.dropout = dropout
126
- self.learning_rate = learning_rate
127
- self.weight_decay = weight_decay
128
- self.epochs = epochs
129
- self.patience = patience
130
- self.task_type = task_type
131
- self.use_approx_knn = use_approx_knn
132
- self.approx_knn_threshold = approx_knn_threshold
133
- self.graph_cache_path = Path(
134
- graph_cache_path) if graph_cache_path else None
135
- self.max_gpu_knn_nodes = max_gpu_knn_nodes
136
- self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
137
- self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
138
- self.knn_cpu_jobs = knn_cpu_jobs
139
- self.mps_dense_max_nodes = int(
140
- os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
141
- )
142
- self._knn_warning_emitted = False
143
- self._mps_fallback_triggered = False
144
- self._adj_cache_meta: Optional[Dict[str, Any]] = None
145
- self._adj_cache_key: Optional[Tuple[Any, ...]] = None
146
- self._adj_cache_tensor: Optional[torch.Tensor] = None
147
-
148
- resolved_loss = normalize_loss_name(loss_name, self.task_type)
149
- if self.task_type == 'classification':
150
- self.loss_name = "logloss"
151
- self.tw_power = None
152
- else:
153
- if resolved_loss == "auto":
154
- resolved_loss = infer_loss_name_from_model_name(self.model_nme)
155
- self.loss_name = resolved_loss
156
- if self.loss_name == "tweedie":
157
- self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
158
- else:
159
- self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
160
-
161
- self.ddp_enabled = False
162
- self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
163
- self.data_parallel_enabled = False
164
- self._ddp_disabled = False
165
-
166
- if use_ddp:
167
- world_size = int(os.environ.get("WORLD_SIZE", "1"))
168
- if world_size > 1:
169
- print(
170
- "[GNN] DDP training is not supported; falling back to single process.",
171
- flush=True,
172
- )
173
- self._ddp_disabled = True
174
- use_ddp = False
175
-
176
- # DDP only works with CUDA; fall back to single process if init fails.
177
- if use_ddp and torch.cuda.is_available():
178
- ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
179
- if ddp_ok:
180
- self.ddp_enabled = True
181
- self.local_rank = local_rank
182
- self.device = torch.device(f'cuda:{local_rank}')
183
- else:
184
- self.device = torch.device('cuda')
185
- elif torch.cuda.is_available():
186
- if self._ddp_disabled:
187
- self.device = torch.device(f'cuda:{self.local_rank}')
188
- else:
189
- self.device = torch.device('cuda')
190
- elif torch.backends.mps.is_available():
191
- self.device = torch.device('mps')
192
- global _GNN_MPS_WARNED
193
- if not _GNN_MPS_WARNED:
194
- print(
195
- "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
196
- flush=True,
197
- )
198
- _GNN_MPS_WARNED = True
199
- else:
200
- self.device = torch.device('cpu')
201
- self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE
202
-
203
- self.gnn = SimpleGNN(
204
- input_dim=self.input_dim,
205
- hidden_dim=self.hidden_dim,
206
- num_layers=self.num_layers,
207
- dropout=self.dropout,
208
- task_type=self.task_type
209
- ).to(self.device)
210
-
211
- # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
212
- if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
213
- self.data_parallel_enabled = True
214
- self.gnn = nn.DataParallel(
215
- self.gnn, device_ids=list(range(torch.cuda.device_count())))
216
- self.device = torch.device('cuda')
217
-
218
- if self.ddp_enabled:
219
- self.gnn = DDP(
220
- self.gnn,
221
- device_ids=[self.local_rank],
222
- output_device=self.local_rank,
223
- find_unused_parameters=False
224
- )
225
-
226
- @staticmethod
227
- def _validate_vector(arr, name: str, n_rows: int) -> None:
228
- if arr is None:
229
- return
230
- if isinstance(arr, pd.DataFrame):
231
- if arr.shape[1] != 1:
232
- raise ValueError(f"{name} must be 1d (single column).")
233
- length = len(arr)
234
- else:
235
- arr_np = np.asarray(arr)
236
- if arr_np.ndim == 0:
237
- raise ValueError(f"{name} must be 1d.")
238
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
239
- raise ValueError(f"{name} must be 1d or Nx1.")
240
- length = arr_np.shape[0]
241
- if length != n_rows:
242
- raise ValueError(
243
- f"{name} length {length} does not match X length {n_rows}."
244
- )
245
-
246
- def _unwrap_gnn(self) -> nn.Module:
247
- if isinstance(self.gnn, (DDP, nn.DataParallel)):
248
- return self.gnn.module
249
- return self.gnn
250
-
251
- def _set_adj_buffer(self, adj: torch.Tensor) -> None:
252
- base = self._unwrap_gnn()
253
- if hasattr(base, "adj_buffer"):
254
- base.adj_buffer = adj
255
- else:
256
- base.register_buffer("adj_buffer", adj)
257
-
258
- @staticmethod
259
- def _is_mps_unsupported_error(exc: BaseException) -> bool:
260
- msg = str(exc).lower()
261
- if "mps" not in msg:
262
- return False
263
- if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
264
- return True
265
- return "sparse" in msg
266
-
267
- def _fallback_to_cpu(self, reason: str) -> None:
268
- if self.device.type != "mps" or self._mps_fallback_triggered:
269
- return
270
- self._mps_fallback_triggered = True
271
- print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
272
- self.device = torch.device("cpu")
273
- self.use_pyg_knn = False
274
- self.data_parallel_enabled = False
275
- self.ddp_enabled = False
276
- base = self._unwrap_gnn()
277
- try:
278
- base = base.to(self.device)
279
- except Exception:
280
- pass
281
- self.gnn = base
282
- self.invalidate_graph_cache()
283
-
284
- def _run_with_mps_fallback(self, fn, *args, **kwargs):
285
- try:
286
- return fn(*args, **kwargs)
287
- except (RuntimeError, NotImplementedError) as exc:
288
- if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
289
- self._fallback_to_cpu(str(exc))
290
- return fn(*args, **kwargs)
291
- raise
292
-
293
- def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
294
- row_hash = pd.util.hash_pandas_object(X_df, index=False).values
295
- idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
296
- col_sig = ",".join(map(str, X_df.columns))
297
- hasher = hashlib.sha256()
298
- hasher.update(row_hash.tobytes())
299
- hasher.update(idx_hash.tobytes())
300
- hasher.update(col_sig.encode("utf-8", errors="ignore"))
301
- knn_config = {
302
- "k_neighbors": int(self.k_neighbors),
303
- "use_approx_knn": bool(self.use_approx_knn),
304
- "approx_knn_threshold": int(self.approx_knn_threshold),
305
- "use_pyg_knn": bool(self.use_pyg_knn),
306
- "pynndescent_available": bool(_PYNN_AVAILABLE),
307
- "max_gpu_knn_nodes": (
308
- None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
309
- ),
310
- "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
311
- "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
312
- }
313
- adj_format = "dense" if self.device.type == "mps" else "sparse"
314
- return {
315
- "n_samples": int(X_df.shape[0]),
316
- "n_features": int(X_df.shape[1]),
317
- "hash": hasher.hexdigest(),
318
- "knn_config": knn_config,
319
- "adj_format": adj_format,
320
- "device_type": self.device.type,
321
- }
322
-
323
- def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
324
- return (
325
- id(X_df),
326
- id(getattr(X_df, "_mgr", None)),
327
- id(X_df.index),
328
- X_df.shape,
329
- tuple(map(str, X_df.columns)),
330
- X_df.attrs.get("graph_cache_key"),
331
- )
332
-
333
- def invalidate_graph_cache(self) -> None:
334
- self._adj_cache_meta = None
335
- self._adj_cache_key = None
336
- self._adj_cache_tensor = None
337
-
338
- def _load_cached_adj(self,
339
- X_df: pd.DataFrame,
340
- meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
341
- if self.graph_cache_path and self.graph_cache_path.exists():
342
- if meta_expected is None:
343
- meta_expected = self._graph_cache_meta(X_df)
344
- try:
345
- payload = torch.load(self.graph_cache_path, map_location="cpu")
346
- except Exception as exc:
347
- print(
348
- f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
349
- return None
350
- if isinstance(payload, dict) and "adj" in payload:
351
- meta_cached = payload.get("meta")
352
- if meta_cached == meta_expected:
353
- adj = payload["adj"]
354
- if self.device.type == "mps" and getattr(adj, "is_sparse", False):
355
- print(
356
- f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
357
- )
358
- return None
359
- return adj.to(self.device)
360
- print(
361
- f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
362
- return None
363
- if isinstance(payload, torch.Tensor):
364
- print(
365
- f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
366
- return None
367
- print(
368
- f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
369
- return None
370
-
371
- def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
372
- n_samples = X_np.shape[0]
373
- k = min(self.k_neighbors, max(1, n_samples - 1))
374
- n_neighbors = min(k + 1, n_samples)
375
- use_approx = (self.use_approx_knn or n_samples >=
376
- self.approx_knn_threshold) and _PYNN_AVAILABLE
377
- indices = None
378
- if use_approx:
379
- try:
380
- nn_index = pynndescent.NNDescent(
381
- X_np,
382
- n_neighbors=n_neighbors,
383
- random_state=0
384
- )
385
- indices, _ = nn_index.neighbor_graph
386
- except Exception as exc:
387
- print(
388
- f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
389
- use_approx = False
390
-
391
- if indices is None:
392
- nbrs = NearestNeighbors(
393
- n_neighbors=n_neighbors,
394
- algorithm="auto",
395
- n_jobs=self.knn_cpu_jobs,
396
- )
397
- nbrs.fit(X_np)
398
- _, indices = nbrs.kneighbors(X_np)
399
-
400
- indices = np.asarray(indices)
401
- rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
402
- np.int64, copy=False)
403
- cols = indices.reshape(-1).astype(np.int64, copy=False)
404
- mask = rows != cols
405
- rows = rows[mask]
406
- cols = cols[mask]
407
- rows_base = rows
408
- cols_base = cols
409
- self_loops = np.arange(n_samples, dtype=np.int64)
410
- rows = np.concatenate([rows_base, cols_base, self_loops])
411
- cols = np.concatenate([cols_base, rows_base, self_loops])
412
-
413
- edge_index_np = np.stack([rows, cols], axis=0)
414
- edge_index = torch.as_tensor(edge_index_np, device=self.device)
415
- return edge_index
416
-
417
- def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
418
- if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
419
- # Defensive: check use_pyg_knn before calling.
420
- raise RuntimeError(
421
- "GPU graph builder requested but PyG is unavailable.")
422
-
423
- n_samples = X_tensor.size(0)
424
- k = min(self.k_neighbors, max(1, n_samples - 1))
425
-
426
- # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
427
- edge_index = knn_graph(
428
- X_tensor,
429
- k=k,
430
- loop=False
431
- )
432
- edge_index = to_undirected(edge_index, num_nodes=n_samples)
433
- edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
434
- return edge_index
435
-
436
- def _log_knn_fallback(self, reason: str) -> None:
437
- if self._knn_warning_emitted:
438
- return
439
- if (not self.ddp_enabled) or self.local_rank == 0:
440
- print(f"[GNN] Falling back to CPU kNN builder: {reason}")
441
- self._knn_warning_emitted = True
442
-
443
- def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
444
- if not self.use_pyg_knn:
445
- return False
446
-
447
- reason = None
448
- if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
449
- reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
450
- elif self.device.type == 'cuda' and torch.cuda.is_available():
451
- try:
452
- device_index = self.device.index
453
- if device_index is None:
454
- device_index = torch.cuda.current_device()
455
- free_mem, total_mem = torch.cuda.mem_get_info(device_index)
456
- feature_bytes = X_tensor.element_size() * X_tensor.nelement()
457
- required = int(feature_bytes * self.knn_gpu_mem_overhead)
458
- budget = int(free_mem * self.knn_gpu_mem_ratio)
459
- if required > budget:
460
- required_gb = required / (1024 ** 3)
461
- budget_gb = budget / (1024 ** 3)
462
- reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
463
- f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
464
- except Exception:
465
- # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
466
- reason = None
467
-
468
- if reason:
469
- self._log_knn_fallback(reason)
470
- return False
471
- return True
472
-
473
- def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
474
- if self.device.type == "mps":
475
- return self._normalized_adj_dense(edge_index, num_nodes)
476
- return self._normalized_adj_sparse(edge_index, num_nodes)
477
-
478
- def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
479
- values = torch.ones(edge_index.shape[1], device=self.device)
480
- adj = torch.sparse_coo_tensor(
481
- edge_index.to(self.device), values, (num_nodes, num_nodes))
482
- adj = adj.coalesce()
483
-
484
- deg = torch.sparse.sum(adj, dim=1).to_dense()
485
- deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
486
- row, col = adj.indices()
487
- norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
488
- adj_norm = torch.sparse_coo_tensor(
489
- adj.indices(), norm_values, size=adj.shape)
490
- return adj_norm
491
-
492
- def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
493
- if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
494
- raise RuntimeError(
495
- f"MPS dense adjacency not supported for {num_nodes} nodes; "
496
- f"max={self.mps_dense_max_nodes}. Falling back to CPU."
497
- )
498
- edge_index = edge_index.to(self.device)
499
- adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
500
- adj[edge_index[0], edge_index[1]] = 1.0
501
- deg = adj.sum(dim=1)
502
- deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
503
- adj = adj * deg_inv_sqrt.view(-1, 1)
504
- adj = adj * deg_inv_sqrt.view(1, -1)
505
- return adj
506
-
507
- def _tensorize_split(self, X, y, w, allow_none: bool = False):
508
- if X is None and allow_none:
509
- return None, None, None
510
- if not isinstance(X, pd.DataFrame):
511
- raise ValueError("X must be a pandas DataFrame for GNN.")
512
- n_rows = len(X)
513
- if y is not None:
514
- self._validate_vector(y, "y", n_rows)
515
- if w is not None:
516
- self._validate_vector(w, "w", n_rows)
517
- X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
518
- X, "to_numpy") else np.asarray(X, dtype=np.float32)
519
- X_tensor = torch.as_tensor(
520
- X_np, dtype=torch.float32, device=self.device)
521
- if y is None:
522
- y_tensor = None
523
- else:
524
- y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
525
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
526
- y_tensor = torch.as_tensor(
527
- y_np, dtype=torch.float32, device=self.device).view(-1, 1)
528
- if w is None:
529
- w_tensor = torch.ones(
530
- (len(X), 1), dtype=torch.float32, device=self.device)
531
- else:
532
- w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
533
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
534
- w_tensor = torch.as_tensor(
535
- w_np, dtype=torch.float32, device=self.device).view(-1, 1)
536
- return X_tensor, y_tensor, w_tensor
537
-
538
- def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
539
- if not isinstance(X_df, pd.DataFrame):
540
- raise ValueError("X must be a pandas DataFrame for graph building.")
541
- meta_expected = None
542
- cache_key = None
543
- if self.graph_cache_path:
544
- meta_expected = self._graph_cache_meta(X_df)
545
- if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
546
- cached = self._adj_cache_tensor
547
- if cached.device != self.device:
548
- if self.device.type == "mps" and getattr(cached, "is_sparse", False):
549
- self._adj_cache_tensor = None
550
- else:
551
- cached = cached.to(self.device)
552
- self._adj_cache_tensor = cached
553
- if self._adj_cache_tensor is not None:
554
- return self._adj_cache_tensor
555
- else:
556
- cache_key = self._graph_cache_key(X_df)
557
- if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
558
- cached = self._adj_cache_tensor
559
- if cached.device != self.device:
560
- if self.device.type == "mps" and getattr(cached, "is_sparse", False):
561
- self._adj_cache_tensor = None
562
- else:
563
- cached = cached.to(self.device)
564
- self._adj_cache_tensor = cached
565
- if self._adj_cache_tensor is not None:
566
- return self._adj_cache_tensor
567
- X_np = None
568
- if X_tensor is None:
569
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
570
- X_tensor = torch.as_tensor(
571
- X_np, dtype=torch.float32, device=self.device)
572
- if self.graph_cache_path:
573
- cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
574
- if cached is not None:
575
- self._adj_cache_meta = meta_expected
576
- self._adj_cache_key = None
577
- self._adj_cache_tensor = cached
578
- return cached
579
- use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
580
- if use_gpu_knn:
581
- edge_index = self._build_edge_index_gpu(X_tensor)
582
- else:
583
- if X_np is None:
584
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
585
- edge_index = self._build_edge_index_cpu(X_np)
586
- adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
587
- if self.graph_cache_path:
588
- try:
589
- IOUtils.ensure_parent_dir(str(self.graph_cache_path))
590
- torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
591
- except Exception as exc:
592
- print(
593
- f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
594
- self._adj_cache_meta = meta_expected
595
- self._adj_cache_key = None
596
- else:
597
- self._adj_cache_meta = None
598
- self._adj_cache_key = cache_key
599
- self._adj_cache_tensor = adj_norm
600
- return adj_norm
601
-
602
- def fit(self, X_train, y_train, w_train=None,
603
- X_val=None, y_val=None, w_val=None,
604
- trial: Optional[optuna.trial.Trial] = None):
605
- return self._run_with_mps_fallback(
606
- self._fit_impl,
607
- X_train,
608
- y_train,
609
- w_train,
610
- X_val,
611
- y_val,
612
- w_val,
613
- trial,
614
- )
615
-
616
- def _fit_impl(self, X_train, y_train, w_train=None,
617
- X_val=None, y_val=None, w_val=None,
618
- trial: Optional[optuna.trial.Trial] = None):
619
- X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
620
- X_train, y_train, w_train, allow_none=False)
621
- has_val = X_val is not None and y_val is not None
622
- if has_val:
623
- X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
624
- X_val, y_val, w_val, allow_none=False)
625
- else:
626
- X_val_tensor = y_val_tensor = w_val_tensor = None
627
-
628
- adj_train = self._build_graph_from_df(X_train, X_train_tensor)
629
- adj_val = self._build_graph_from_df(
630
- X_val, X_val_tensor) if has_val else None
631
- # DataParallel needs adjacency cached on the model to avoid scatter.
632
- self._set_adj_buffer(adj_train)
633
-
634
- base_gnn = self._unwrap_gnn()
635
- optimizer = torch.optim.Adam(
636
- base_gnn.parameters(),
637
- lr=self.learning_rate,
638
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
639
- )
640
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
641
-
642
- best_loss = float('inf')
643
- best_state = None
644
- patience_counter = 0
645
- best_epoch = None
646
-
647
- for epoch in range(1, self.epochs + 1):
648
- epoch_start_ts = time.time()
649
- self.gnn.train()
650
- optimizer.zero_grad()
651
- with autocast(enabled=(self.device.type == 'cuda')):
652
- if self.data_parallel_enabled:
653
- y_pred = self.gnn(X_train_tensor)
654
- else:
655
- y_pred = self.gnn(X_train_tensor, adj_train)
656
- loss = self._compute_weighted_loss(
657
- y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
658
- scaler.scale(loss).backward()
659
- scaler.unscale_(optimizer)
660
- clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
661
- scaler.step(optimizer)
662
- scaler.update()
663
-
664
- val_loss = None
665
- if has_val:
666
- self.gnn.eval()
667
- if self.data_parallel_enabled and adj_val is not None:
668
- self._set_adj_buffer(adj_val)
669
- with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
670
- if self.data_parallel_enabled:
671
- y_val_pred = self.gnn(X_val_tensor)
672
- else:
673
- y_val_pred = self.gnn(X_val_tensor, adj_val)
674
- val_loss = self._compute_weighted_loss(
675
- y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
676
- if self.data_parallel_enabled:
677
- # Restore training adjacency.
678
- self._set_adj_buffer(adj_train)
679
-
680
- is_best = val_loss is not None and val_loss < best_loss
681
- best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
682
- val_loss, best_loss, best_state, patience_counter, base_gnn,
683
- ignore_keys=["adj_buffer"])
684
- if is_best:
685
- best_epoch = epoch
686
-
687
- prune_now = False
688
- if trial is not None:
689
- trial.report(val_loss, epoch)
690
- if trial.should_prune():
691
- prune_now = True
692
-
693
- if dist.is_initialized():
694
- flag = torch.tensor(
695
- [1 if prune_now else 0],
696
- device=self.device,
697
- dtype=torch.int32,
698
- )
699
- dist.broadcast(flag, src=0)
700
- prune_now = bool(flag.item())
701
-
702
- if prune_now:
703
- raise optuna.TrialPruned()
704
- if stop_training:
705
- break
706
-
707
- should_log = (not dist.is_initialized()
708
- or DistributedUtils.is_main_process())
709
- if should_log:
710
- elapsed = int(time.time() - epoch_start_ts)
711
- if val_loss is None:
712
- print(
713
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
714
- flush=True,
715
- )
716
- else:
717
- print(
718
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
719
- f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
720
- flush=True,
721
- )
722
-
723
- if best_state is not None:
724
- base_gnn.load_state_dict(best_state, strict=False)
725
- self.best_epoch = int(best_epoch or self.epochs)
726
-
727
- def predict(self, X: pd.DataFrame) -> np.ndarray:
728
- return self._run_with_mps_fallback(self._predict_impl, X)
729
-
730
- def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
731
- self.gnn.eval()
732
- X_tensor, _, _ = self._tensorize_split(
733
- X, None, None, allow_none=False)
734
- adj = self._build_graph_from_df(X, X_tensor)
735
- if self.data_parallel_enabled:
736
- self._set_adj_buffer(adj)
737
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
738
- with inference_cm():
739
- if self.data_parallel_enabled:
740
- y_pred = self.gnn(X_tensor).cpu().numpy()
741
- else:
742
- y_pred = self.gnn(X_tensor, adj).cpu().numpy()
743
- if self.task_type == 'classification':
744
- y_pred = 1 / (1 + np.exp(-y_pred))
745
- else:
746
- y_pred = np.clip(y_pred, 1e-6, None)
747
- return y_pred.ravel()
748
-
749
- def encode(self, X: pd.DataFrame) -> np.ndarray:
750
- return self._run_with_mps_fallback(self._encode_impl, X)
751
-
752
- def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
753
- """Return per-sample node embeddings (hidden representations)."""
754
- base = self._unwrap_gnn()
755
- base.eval()
756
- X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
757
- adj = self._build_graph_from_df(X, X_tensor)
758
- if self.data_parallel_enabled:
759
- self._set_adj_buffer(adj)
760
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
761
- with inference_cm():
762
- h = X_tensor
763
- layers = getattr(base, "layers", None)
764
- if layers is None:
765
- raise RuntimeError("GNN base module does not expose layers.")
766
- for layer in layers:
767
- h = layer(h, adj)
768
- h = _adj_mm(adj, h)
769
- return h.detach().cpu().numpy()
770
-
771
- def set_params(self, params: Dict[str, Any]):
772
- for key, value in params.items():
773
- if hasattr(self, key):
774
- setattr(self, key, value)
775
- else:
776
- raise ValueError(f"Parameter {key} not found in GNN model.")
777
- # Rebuild the backbone after structural parameter changes.
778
- self.gnn = SimpleGNN(
779
- input_dim=self.input_dim,
780
- hidden_dim=self.hidden_dim,
781
- num_layers=self.num_layers,
782
- dropout=self.dropout,
783
- task_type=self.task_type
784
- ).to(self.device)
785
- return self
1
from __future__ import annotations

import hashlib
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import optuna
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
from sklearn.neighbors import NearestNeighbors
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
from ins_pricing.utils import EPS
from ins_pricing.utils.io import IOUtils
from ins_pricing.utils.losses import (
    infer_loss_name_from_model_name,
    normalize_loss_name,
    resolve_tweedie_power,
)
28
+
29
# Optional accelerator: torch_geometric provides a GPU kNN graph builder.
try:
    from torch_geometric.nn import knn_graph
    from torch_geometric.utils import add_self_loops, to_undirected
    _PYG_AVAILABLE = True
except Exception:
    knn_graph = None  # type: ignore
    add_self_loops = None  # type: ignore
    to_undirected = None  # type: ignore
    _PYG_AVAILABLE = False

# Optional approximate nearest-neighbour search for large CPU-side graphs.
try:
    import pynndescent
    _PYNN_AVAILABLE = True
except Exception:
    pynndescent = None  # type: ignore
    _PYNN_AVAILABLE = False

# Module-level flag so the MPS warning is printed at most once per process.
_GNN_MPS_WARNED = False


# =============================================================================
# Simplified GNN implementation.
# =============================================================================
53
+ def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
54
+ """Matrix multiply that supports sparse or dense adjacency."""
55
+ if adj.is_sparse:
56
+ return torch.sparse.mm(adj, x)
57
+ return adj.matmul(x)
58
+
59
+
60
class SimpleGraphLayer(nn.Module):
    """Single graph-convolution step: aggregate neighbours, project, ReLU, dropout."""

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        # A zero rate is a no-op, so skip the Dropout module entirely.
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Normalized message passing: A_hat @ X, then the learned projection.
        aggregated = _adj_mm(adj, x)
        projected = self.linear(aggregated)
        return self.dropout(self.activation(projected))
73
+
74
+
75
class SimpleGNN(nn.Module):
    """Stack of SimpleGraphLayer blocks with a 1-unit linear head.

    Regression outputs pass through Softplus to stay positive; classification
    outputs are raw logits (Identity).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        super().__init__()
        blocks = []
        prev_dim = input_dim
        # Guarantee at least one layer even for num_layers <= 0.
        for _ in range(max(1, num_layers)):
            blocks.append(SimpleGraphLayer(prev_dim, hidden_dim, dropout=dropout))
            prev_dim = hidden_dim
        self.layers = nn.ModuleList(blocks)
        self.output = nn.Linear(hidden_dim, 1)
        self.output_act = (
            nn.Identity() if task_type == 'classification' else nn.Softplus()
        )
        self.task_type = task_type
        # Keep adjacency as a buffer so DataParallel replicas all see it.
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        adj_used = adj if adj is not None else getattr(self, "adj_buffer", None)
        if adj_used is None or adj_used.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        h = x
        for block in self.layers:
            h = block(h, adj_used)
        # NOTE(review): one extra propagation after the layer stack; diff lost
        # the nesting here — confirm against the released source.
        h = _adj_mm(adj_used, h)
        return self.output_act(self.output(h))
106
+
107
+
108
class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
    """Sklearn-style trainer wrapping :class:`SimpleGNN` on a kNN sample graph.

    Builds a (possibly approximate) k-nearest-neighbour graph over the rows of
    a DataFrame, trains a SimpleGNN on it, and exposes fit/predict/encode.
    """

    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
                 learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
                 task_type: str = 'regression', tweedie_power: float = 1.5,
                 weight_decay: float = 0.0,
                 use_data_parallel: bool = False, use_ddp: bool = False,
                 use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
                 graph_cache_path: Optional[str] = None,
                 max_gpu_knn_nodes: Optional[int] = None,
                 knn_gpu_mem_ratio: float = 0.9,
                 knn_gpu_mem_overhead: float = 2.0,
                 knn_cpu_jobs: Optional[int] = -1,
                 loss_name: Optional[str] = None) -> None:
        super().__init__()
        # Core hyper-parameters.
        self.model_nme = model_nme
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.k_neighbors = max(1, k_neighbors)  # at least one neighbour
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.patience = patience
        self.task_type = task_type
        # kNN graph construction knobs.
        self.use_approx_knn = use_approx_knn
        self.approx_knn_threshold = approx_knn_threshold
        self.graph_cache_path = Path(
            graph_cache_path) if graph_cache_path else None
        self.max_gpu_knn_nodes = max_gpu_knn_nodes
        self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))  # clamp to [0, 1]
        self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)      # clamp to >= 1
        self.knn_cpu_jobs = knn_cpu_jobs
        # Node cap for the dense-adjacency path used on MPS (env-overridable).
        self.mps_dense_max_nodes = int(
            os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
        )
        self._knn_warning_emitted = False
        self._mps_fallback_triggered = False
        # In-memory adjacency cache (metadata-keyed or identity-keyed).
        self._adj_cache_meta: Optional[Dict[str, Any]] = None
        self._adj_cache_key: Optional[Tuple[Any, ...]] = None
        self._adj_cache_tensor: Optional[torch.Tensor] = None

        # Resolve the training loss; classification is hard-wired to logloss.
        resolved_loss = normalize_loss_name(loss_name, self.task_type)
        if self.task_type == 'classification':
            self.loss_name = "logloss"
            self.tw_power = None
        else:
            if resolved_loss == "auto":
                resolved_loss = infer_loss_name_from_model_name(self.model_nme)
            self.loss_name = resolved_loss
            if self.loss_name == "tweedie":
                self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
            else:
                self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)

        self.ddp_enabled = False
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.data_parallel_enabled = False
        self._ddp_disabled = False

        # Multi-process DDP is explicitly unsupported for this model.
        if use_ddp:
            world_size = int(os.environ.get("WORLD_SIZE", "1"))
            if world_size > 1:
                print(
                    "[GNN] DDP training is not supported; falling back to single process.",
                    flush=True,
                )
                self._ddp_disabled = True
                use_ddp = False

        # DDP only works with CUDA; fall back to single process if init fails.
        if use_ddp and torch.cuda.is_available():
            ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
            if ddp_ok:
                self.ddp_enabled = True
                self.local_rank = local_rank
                self.device = torch.device(f'cuda:{local_rank}')
            else:
                self.device = torch.device('cuda')
        elif torch.cuda.is_available():
            if self._ddp_disabled:
                self.device = torch.device(f'cuda:{self.local_rank}')
            else:
                self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
            global _GNN_MPS_WARNED
            if not _GNN_MPS_WARNED:
                print(
                    "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
                    flush=True,
                )
                _GNN_MPS_WARNED = True
        else:
            self.device = torch.device('cpu')
        # GPU kNN only makes sense on CUDA with torch_geometric installed.
        self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

        self.gnn = SimpleGNN(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=self.dropout,
            task_type=self.task_type
        ).to(self.device)

        # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
        if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            self.data_parallel_enabled = True
            self.gnn = nn.DataParallel(
                self.gnn, device_ids=list(range(torch.cuda.device_count())))
            self.device = torch.device('cuda')

        if self.ddp_enabled:
            self.gnn = DDP(
                self.gnn,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=False
            )
228
+
229
+ @staticmethod
230
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
231
+ if arr is None:
232
+ return
233
+ if isinstance(arr, pd.DataFrame):
234
+ if arr.shape[1] != 1:
235
+ raise ValueError(f"{name} must be 1d (single column).")
236
+ length = len(arr)
237
+ else:
238
+ arr_np = np.asarray(arr)
239
+ if arr_np.ndim == 0:
240
+ raise ValueError(f"{name} must be 1d.")
241
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
242
+ raise ValueError(f"{name} must be 1d or Nx1.")
243
+ length = arr_np.shape[0]
244
+ if length != n_rows:
245
+ raise ValueError(
246
+ f"{name} length {length} does not match X length {n_rows}."
247
+ )
248
+
249
+ def _unwrap_gnn(self) -> nn.Module:
250
+ if isinstance(self.gnn, (DDP, nn.DataParallel)):
251
+ return self.gnn.module
252
+ return self.gnn
253
+
254
+ def _set_adj_buffer(self, adj: torch.Tensor) -> None:
255
+ base = self._unwrap_gnn()
256
+ if hasattr(base, "adj_buffer"):
257
+ base.adj_buffer = adj
258
+ else:
259
+ base.register_buffer("adj_buffer", adj)
260
+
261
+ @staticmethod
262
+ def _is_mps_unsupported_error(exc: BaseException) -> bool:
263
+ msg = str(exc).lower()
264
+ if "mps" not in msg:
265
+ return False
266
+ if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
267
+ return True
268
+ return "sparse" in msg
269
+
270
    def _fallback_to_cpu(self, reason: str) -> None:
        """Permanently switch this model from MPS to CPU (runs at most once).

        Disables GPU-only features (PyG kNN, DataParallel, DDP), moves the
        base module to CPU best-effort, and drops any cached adjacency that
        may be in an MPS-specific format.
        """
        if self.device.type != "mps" or self._mps_fallback_triggered:
            return
        self._mps_fallback_triggered = True
        print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
        self.device = torch.device("cpu")
        self.use_pyg_knn = False
        self.data_parallel_enabled = False
        self.ddp_enabled = False
        base = self._unwrap_gnn()
        try:
            base = base.to(self.device)
        except Exception:
            # Best-effort move; keep the current module if transfer fails.
            pass
        self.gnn = base
        # Cached adjacency may be dense/MPS-specific; force a rebuild.
        self.invalidate_graph_cache()
286
+
287
+ def _run_with_mps_fallback(self, fn, *args, **kwargs):
288
+ try:
289
+ return fn(*args, **kwargs)
290
+ except (RuntimeError, NotImplementedError) as exc:
291
+ if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
292
+ self._fallback_to_cpu(str(exc))
293
+ return fn(*args, **kwargs)
294
+ raise
295
+
296
    def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
        """Fingerprint ``X_df`` plus kNN settings for on-disk cache validation.

        The SHA-256 hash covers row values, the index, and column names, so
        any data change invalidates a cached graph; ``knn_config`` invalidates
        it when graph-construction parameters change.
        """
        row_hash = pd.util.hash_pandas_object(X_df, index=False).values
        idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
        col_sig = ",".join(map(str, X_df.columns))
        hasher = hashlib.sha256()
        hasher.update(row_hash.tobytes())
        hasher.update(idx_hash.tobytes())
        hasher.update(col_sig.encode("utf-8", errors="ignore"))
        knn_config = {
            "k_neighbors": int(self.k_neighbors),
            "use_approx_knn": bool(self.use_approx_knn),
            "approx_knn_threshold": int(self.approx_knn_threshold),
            "use_pyg_knn": bool(self.use_pyg_knn),
            "pynndescent_available": bool(_PYNN_AVAILABLE),
            "max_gpu_knn_nodes": (
                None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
            ),
            "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
            "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
        }
        # MPS cannot use sparse tensors, so the stored format matters too.
        adj_format = "dense" if self.device.type == "mps" else "sparse"
        return {
            "n_samples": int(X_df.shape[0]),
            "n_features": int(X_df.shape[1]),
            "hash": hasher.hexdigest(),
            "knn_config": knn_config,
            "adj_format": adj_format,
            "device_type": self.device.type,
        }
325
+
326
+ def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
327
+ return (
328
+ id(X_df),
329
+ id(getattr(X_df, "_mgr", None)),
330
+ id(X_df.index),
331
+ X_df.shape,
332
+ tuple(map(str, X_df.columns)),
333
+ X_df.attrs.get("graph_cache_key"),
334
+ )
335
+
336
    def invalidate_graph_cache(self) -> None:
        """Drop the in-memory adjacency cache (metadata, key, and tensor)."""
        self._adj_cache_meta = None
        self._adj_cache_key = None
        self._adj_cache_tensor = None
340
+
341
    def _load_cached_adj(self,
                         X_df: pd.DataFrame,
                         meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
        """Load a cached adjacency from disk if its metadata matches ``X_df``.

        Returns ``None`` (caller rebuilds) on load failure, metadata mismatch,
        legacy/invalid file formats, or an MPS-incompatible sparse tensor.
        """
        if self.graph_cache_path and self.graph_cache_path.exists():
            if meta_expected is None:
                meta_expected = self._graph_cache_meta(X_df)
            try:
                payload = torch.load(self.graph_cache_path, map_location="cpu")
            except Exception as exc:
                print(
                    f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
                return None
            if isinstance(payload, dict) and "adj" in payload:
                meta_cached = payload.get("meta")
                if meta_cached == meta_expected:
                    adj = payload["adj"]
                    # MPS has no sparse support; a sparse cache entry is unusable there.
                    if self.device.type == "mps" and getattr(adj, "is_sparse", False):
                        print(
                            f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
                        )
                        return None
                    return adj.to(self.device)
                print(
                    f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
                return None
            if isinstance(payload, torch.Tensor):
                # Legacy cache files stored the bare tensor without metadata.
                print(
                    f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
                return None
            print(
                f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
            return None
373
+
374
+ def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
375
+ n_samples = X_np.shape[0]
376
+ k = min(self.k_neighbors, max(1, n_samples - 1))
377
+ n_neighbors = min(k + 1, n_samples)
378
+ use_approx = (self.use_approx_knn or n_samples >=
379
+ self.approx_knn_threshold) and _PYNN_AVAILABLE
380
+ indices = None
381
+ if use_approx:
382
+ try:
383
+ nn_index = pynndescent.NNDescent(
384
+ X_np,
385
+ n_neighbors=n_neighbors,
386
+ random_state=0
387
+ )
388
+ indices, _ = nn_index.neighbor_graph
389
+ except Exception as exc:
390
+ print(
391
+ f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
392
+ use_approx = False
393
+
394
+ if indices is None:
395
+ nbrs = NearestNeighbors(
396
+ n_neighbors=n_neighbors,
397
+ algorithm="auto",
398
+ n_jobs=self.knn_cpu_jobs,
399
+ )
400
+ nbrs.fit(X_np)
401
+ _, indices = nbrs.kneighbors(X_np)
402
+
403
+ indices = np.asarray(indices)
404
+ rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
405
+ np.int64, copy=False)
406
+ cols = indices.reshape(-1).astype(np.int64, copy=False)
407
+ mask = rows != cols
408
+ rows = rows[mask]
409
+ cols = cols[mask]
410
+ rows_base = rows
411
+ cols_base = cols
412
+ self_loops = np.arange(n_samples, dtype=np.int64)
413
+ rows = np.concatenate([rows_base, cols_base, self_loops])
414
+ cols = np.concatenate([cols_base, rows_base, self_loops])
415
+
416
+ edge_index_np = np.stack([rows, cols], axis=0)
417
+ edge_index = torch.as_tensor(edge_index_np, device=self.device)
418
+ return edge_index
419
+
420
+ def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
421
+ if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
422
+ # Defensive: check use_pyg_knn before calling.
423
+ raise RuntimeError(
424
+ "GPU graph builder requested but PyG is unavailable.")
425
+
426
+ n_samples = X_tensor.size(0)
427
+ k = min(self.k_neighbors, max(1, n_samples - 1))
428
+
429
+ # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
430
+ edge_index = knn_graph(
431
+ X_tensor,
432
+ k=k,
433
+ loop=False
434
+ )
435
+ edge_index = to_undirected(edge_index, num_nodes=n_samples)
436
+ edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
437
+ return edge_index
438
+
439
+ def _log_knn_fallback(self, reason: str) -> None:
440
+ if self._knn_warning_emitted:
441
+ return
442
+ if (not self.ddp_enabled) or self.local_rank == 0:
443
+ print(f"[GNN] Falling back to CPU kNN builder: {reason}")
444
+ self._knn_warning_emitted = True
445
+
446
    def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
        """Decide whether the PyG GPU kNN builder can be used safely.

        Declines (logging the reason once) when the node count exceeds the
        configured cap, or when the estimated temporary GPU memory would not
        fit into the free-memory budget on the current CUDA device.
        """
        if not self.use_pyg_knn:
            return False

        reason = None
        if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
            reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
        elif self.device.type == 'cuda' and torch.cuda.is_available():
            try:
                device_index = self.device.index
                if device_index is None:
                    device_index = torch.cuda.current_device()
                free_mem, total_mem = torch.cuda.mem_get_info(device_index)
                # Rough estimate: feature bytes times a configurable overhead factor.
                feature_bytes = X_tensor.element_size() * X_tensor.nelement()
                required = int(feature_bytes * self.knn_gpu_mem_overhead)
                budget = int(free_mem * self.knn_gpu_mem_ratio)
                if required > budget:
                    required_gb = required / (1024 ** 3)
                    budget_gb = budget / (1024 ** 3)
                    reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
                              f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
            except Exception:
                # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
                reason = None

        if reason:
            self._log_knn_fallback(reason)
            return False
        return True
475
+
476
+ def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
477
+ if self.device.type == "mps":
478
+ return self._normalized_adj_dense(edge_index, num_nodes)
479
+ return self._normalized_adj_sparse(edge_index, num_nodes)
480
+
481
+ def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
482
+ values = torch.ones(edge_index.shape[1], device=self.device)
483
+ adj = torch.sparse_coo_tensor(
484
+ edge_index.to(self.device), values, (num_nodes, num_nodes))
485
+ adj = adj.coalesce()
486
+
487
+ deg = torch.sparse.sum(adj, dim=1).to_dense()
488
+ deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
489
+ row, col = adj.indices()
490
+ norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
491
+ adj_norm = torch.sparse_coo_tensor(
492
+ adj.indices(), norm_values, size=adj.shape)
493
+ return adj_norm
494
+
495
+ def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
496
+ if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
497
+ raise RuntimeError(
498
+ f"MPS dense adjacency not supported for {num_nodes} nodes; "
499
+ f"max={self.mps_dense_max_nodes}. Falling back to CPU."
500
+ )
501
+ edge_index = edge_index.to(self.device)
502
+ adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
503
+ adj[edge_index[0], edge_index[1]] = 1.0
504
+ deg = adj.sum(dim=1)
505
+ deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
506
+ adj = adj * deg_inv_sqrt.view(-1, 1)
507
+ adj = adj * deg_inv_sqrt.view(1, -1)
508
+ return adj
509
+
510
    def _tensorize_split(self, X, y, w, allow_none: bool = False):
        """Convert an ``(X, y, w)`` split into float32 tensors on ``self.device``.

        Returns ``(X_tensor, y_tensor, w_tensor)``; ``y_tensor`` is ``None``
        when ``y`` is, and a missing ``w`` defaults to all-ones weights.
        Raises ``ValueError`` if ``X`` is not a DataFrame or y/w mismatch X.
        """
        if X is None and allow_none:
            return None, None, None
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame for GNN.")
        n_rows = len(X)
        if y is not None:
            self._validate_vector(y, "y", n_rows)
        if w is not None:
            self._validate_vector(w, "w", n_rows)
        X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
            X, "to_numpy") else np.asarray(X, dtype=np.float32)
        X_tensor = torch.as_tensor(
            X_np, dtype=torch.float32, device=self.device)
        if y is None:
            y_tensor = None
        else:
            y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
                y, "to_numpy") else np.asarray(y, dtype=np.float32)
            # Targets and weights are column vectors (N, 1) for the loss helpers.
            y_tensor = torch.as_tensor(
                y_np, dtype=torch.float32, device=self.device).view(-1, 1)
        if w is None:
            w_tensor = torch.ones(
                (len(X), 1), dtype=torch.float32, device=self.device)
        else:
            w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
                w, "to_numpy") else np.asarray(w, dtype=np.float32)
            w_tensor = torch.as_tensor(
                w_np, dtype=torch.float32, device=self.device).view(-1, 1)
        return X_tensor, y_tensor, w_tensor
540
+
541
    def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the normalized adjacency for ``X_df``, using caches when possible.

        Lookup order: in-memory cache (metadata-keyed when a disk cache path
        is configured, identity-keyed otherwise), then the on-disk cache,
        then a fresh kNN build (GPU when feasible, else CPU).
        """
        if not isinstance(X_df, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame for graph building.")
        meta_expected = None
        cache_key = None
        if self.graph_cache_path:
            meta_expected = self._graph_cache_meta(X_df)
            if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
                cached = self._adj_cache_tensor
                if cached.device != self.device:
                    # Sparse tensors cannot be moved onto MPS; drop and rebuild.
                    if self.device.type == "mps" and getattr(cached, "is_sparse", False):
                        self._adj_cache_tensor = None
                    else:
                        cached = cached.to(self.device)
                        self._adj_cache_tensor = cached
                if self._adj_cache_tensor is not None:
                    return self._adj_cache_tensor
        else:
            cache_key = self._graph_cache_key(X_df)
            if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
                cached = self._adj_cache_tensor
                if cached.device != self.device:
                    if self.device.type == "mps" and getattr(cached, "is_sparse", False):
                        self._adj_cache_tensor = None
                    else:
                        cached = cached.to(self.device)
                        self._adj_cache_tensor = cached
                if self._adj_cache_tensor is not None:
                    return self._adj_cache_tensor
        X_np = None
        if X_tensor is None:
            X_np = X_df.to_numpy(dtype=np.float32, copy=False)
            X_tensor = torch.as_tensor(
                X_np, dtype=torch.float32, device=self.device)
        if self.graph_cache_path:
            cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
            if cached is not None:
                self._adj_cache_meta = meta_expected
                self._adj_cache_key = None
                self._adj_cache_tensor = cached
                return cached
        use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
        if use_gpu_knn:
            edge_index = self._build_edge_index_gpu(X_tensor)
        else:
            if X_np is None:
                X_np = X_df.to_numpy(dtype=np.float32, copy=False)
            edge_index = self._build_edge_index_cpu(X_np)
        adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
        if self.graph_cache_path:
            # Persist best-effort; a failed save only costs a rebuild later.
            try:
                IOUtils.ensure_parent_dir(str(self.graph_cache_path))
                torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
            except Exception as exc:
                print(
                    f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
            self._adj_cache_meta = meta_expected
            self._adj_cache_key = None
        else:
            self._adj_cache_meta = None
            self._adj_cache_key = cache_key
        self._adj_cache_tensor = adj_norm
        return adj_norm
604
+
605
+ def fit(self, X_train, y_train, w_train=None,
606
+ X_val=None, y_val=None, w_val=None,
607
+ trial: Optional[optuna.trial.Trial] = None):
608
+ return self._run_with_mps_fallback(
609
+ self._fit_impl,
610
+ X_train,
611
+ y_train,
612
+ w_train,
613
+ X_val,
614
+ y_val,
615
+ w_val,
616
+ trial,
617
+ )
618
+
619
+ def _fit_impl(self, X_train, y_train, w_train=None,
620
+ X_val=None, y_val=None, w_val=None,
621
+ trial: Optional[optuna.trial.Trial] = None):
622
+ X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
623
+ X_train, y_train, w_train, allow_none=False)
624
+ has_val = X_val is not None and y_val is not None
625
+ if has_val:
626
+ X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
627
+ X_val, y_val, w_val, allow_none=False)
628
+ else:
629
+ X_val_tensor = y_val_tensor = w_val_tensor = None
630
+
631
+ adj_train = self._build_graph_from_df(X_train, X_train_tensor)
632
+ adj_val = self._build_graph_from_df(
633
+ X_val, X_val_tensor) if has_val else None
634
+ # DataParallel needs adjacency cached on the model to avoid scatter.
635
+ self._set_adj_buffer(adj_train)
636
+
637
+ base_gnn = self._unwrap_gnn()
638
+ optimizer = torch.optim.Adam(
639
+ base_gnn.parameters(),
640
+ lr=self.learning_rate,
641
+ weight_decay=float(getattr(self, "weight_decay", 0.0)),
642
+ )
643
+ scaler = GradScaler(enabled=(self.device.type == 'cuda'))
644
+
645
+ best_loss = float('inf')
646
+ best_state = None
647
+ patience_counter = 0
648
+ best_epoch = None
649
+
650
+ for epoch in range(1, self.epochs + 1):
651
+ epoch_start_ts = time.time()
652
+ self.gnn.train()
653
+ optimizer.zero_grad()
654
+ with autocast(enabled=(self.device.type == 'cuda')):
655
+ if self.data_parallel_enabled:
656
+ y_pred = self.gnn(X_train_tensor)
657
+ else:
658
+ y_pred = self.gnn(X_train_tensor, adj_train)
659
+ loss = self._compute_weighted_loss(
660
+ y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
661
+ scaler.scale(loss).backward()
662
+ scaler.unscale_(optimizer)
663
+ clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
664
+ scaler.step(optimizer)
665
+ scaler.update()
666
+
667
+ val_loss = None
668
+ if has_val:
669
+ self.gnn.eval()
670
+ if self.data_parallel_enabled and adj_val is not None:
671
+ self._set_adj_buffer(adj_val)
672
+ with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
673
+ if self.data_parallel_enabled:
674
+ y_val_pred = self.gnn(X_val_tensor)
675
+ else:
676
+ y_val_pred = self.gnn(X_val_tensor, adj_val)
677
+ val_loss = self._compute_weighted_loss(
678
+ y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
679
+ if self.data_parallel_enabled:
680
+ # Restore training adjacency.
681
+ self._set_adj_buffer(adj_train)
682
+
683
+ is_best = val_loss is not None and val_loss < best_loss
684
+ best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
685
+ val_loss, best_loss, best_state, patience_counter, base_gnn,
686
+ ignore_keys=["adj_buffer"])
687
+ if is_best:
688
+ best_epoch = epoch
689
+
690
+ prune_now = False
691
+ if trial is not None:
692
+ trial.report(val_loss, epoch)
693
+ if trial.should_prune():
694
+ prune_now = True
695
+
696
+ if dist.is_initialized():
697
+ flag = torch.tensor(
698
+ [1 if prune_now else 0],
699
+ device=self.device,
700
+ dtype=torch.int32,
701
+ )
702
+ dist.broadcast(flag, src=0)
703
+ prune_now = bool(flag.item())
704
+
705
+ if prune_now:
706
+ raise optuna.TrialPruned()
707
+ if stop_training:
708
+ break
709
+
710
+ should_log = (not dist.is_initialized()
711
+ or DistributedUtils.is_main_process())
712
+ if should_log:
713
+ elapsed = int(time.time() - epoch_start_ts)
714
+ if val_loss is None:
715
+ print(
716
+ f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
717
+ flush=True,
718
+ )
719
+ else:
720
+ print(
721
+ f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
722
+ f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
723
+ flush=True,
724
+ )
725
+
726
+ if best_state is not None:
727
+ base_gnn.load_state_dict(best_state, strict=False)
728
+ self.best_epoch = int(best_epoch or self.epochs)
729
+
730
    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Predict on ``X``, retrying on CPU if an MPS op is unsupported."""
        return self._run_with_mps_fallback(self._predict_impl, X)
732
+
733
+ def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
734
+ self.gnn.eval()
735
+ X_tensor, _, _ = self._tensorize_split(
736
+ X, None, None, allow_none=False)
737
+ adj = self._build_graph_from_df(X, X_tensor)
738
+ if self.data_parallel_enabled:
739
+ self._set_adj_buffer(adj)
740
+ inference_cm = getattr(torch, "inference_mode", torch.no_grad)
741
+ with inference_cm():
742
+ if self.data_parallel_enabled:
743
+ y_pred = self.gnn(X_tensor).cpu().numpy()
744
+ else:
745
+ y_pred = self.gnn(X_tensor, adj).cpu().numpy()
746
+ if self.task_type == 'classification':
747
+ y_pred = 1 / (1 + np.exp(-y_pred))
748
+ else:
749
+ y_pred = np.clip(y_pred, 1e-6, None)
750
+ return y_pred.ravel()
751
+
752
    def encode(self, X: pd.DataFrame) -> np.ndarray:
        """Compute node embeddings, retrying on CPU if an MPS op is unsupported."""
        return self._run_with_mps_fallback(self._encode_impl, X)
754
+
755
    def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
        """Return per-sample node embeddings (hidden representations)."""
        base = self._unwrap_gnn()
        base.eval()
        X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
        adj = self._build_graph_from_df(X, X_tensor)
        if self.data_parallel_enabled:
            self._set_adj_buffer(adj)
        # inference_mode is faster than no_grad where available (torch >= 1.9).
        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
        with inference_cm():
            h = X_tensor
            layers = getattr(base, "layers", None)
            if layers is None:
                raise RuntimeError("GNN base module does not expose layers.")
            for layer in layers:
                h = layer(h, adj)
            # NOTE(review): final propagation mirrors SimpleGNN.forward minus the
            # output head — confirm the intended nesting against the released source.
            h = _adj_mm(adj, h)
        return h.detach().cpu().numpy()
773
+
774
+ def set_params(self, params: Dict[str, Any]):
775
+ for key, value in params.items():
776
+ if hasattr(self, key):
777
+ setattr(self, key, value)
778
+ else:
779
+ raise ValueError(f"Parameter {key} not found in GNN model.")
780
+ # Rebuild the backbone after structural parameter changes.
781
+ self.gnn = SimpleGNN(
782
+ input_dim=self.input_dim,
783
+ hidden_dim=self.hidden_dim,
784
+ num_layers=self.num_layers,
785
+ dropout=self.dropout,
786
+ task_type=self.task_type
787
+ ).to(self.device)
788
+ return self