ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
@@ -1,785 +1,794 @@
1
- from __future__ import annotations
2
-
3
- import hashlib
4
- import os
5
- import time
6
- from pathlib import Path
7
- from typing import Any, Dict, Optional, Tuple
8
-
9
- import numpy as np
10
- import pandas as pd
11
- import torch
12
- import torch.distributed as dist
13
- import torch.nn as nn
14
- from sklearn.neighbors import NearestNeighbors
15
- from torch.cuda.amp import autocast, GradScaler
16
- from torch.nn.parallel import DistributedDataParallel as DDP
17
- from torch.nn.utils import clip_grad_norm_
18
-
19
- from ..utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
20
- from ..utils.losses import (
21
- infer_loss_name_from_model_name,
22
- normalize_loss_name,
23
- resolve_tweedie_power,
24
- )
25
-
26
- try:
27
- from torch_geometric.nn import knn_graph
28
- from torch_geometric.utils import add_self_loops, to_undirected
29
- _PYG_AVAILABLE = True
30
- except Exception:
31
- knn_graph = None # type: ignore
32
- add_self_loops = None # type: ignore
33
- to_undirected = None # type: ignore
34
- _PYG_AVAILABLE = False
35
-
36
- try:
37
- import pynndescent
38
- _PYNN_AVAILABLE = True
39
- except Exception:
40
- pynndescent = None # type: ignore
41
- _PYNN_AVAILABLE = False
42
-
43
- _GNN_MPS_WARNED = False
44
-
45
-
46
- # =============================================================================
47
- # Simplified GNN implementation.
48
- # =============================================================================
49
-
50
- def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
51
- """Matrix multiply that supports sparse or dense adjacency."""
52
- if adj.is_sparse:
53
- return torch.sparse.mm(adj, x)
54
- return adj.matmul(x)
55
-
56
-
57
- class SimpleGraphLayer(nn.Module):
58
- def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
59
- super().__init__()
60
- self.linear = nn.Linear(in_dim, out_dim)
61
- self.activation = nn.ReLU(inplace=True)
62
- self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
63
-
64
- def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
65
- # Message passing with normalized sparse adjacency: A_hat * X * W.
66
- h = _adj_mm(adj, x)
67
- h = self.linear(h)
68
- h = self.activation(h)
69
- return self.dropout(h)
70
-
71
-
72
- class SimpleGNN(nn.Module):
73
- def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
74
- dropout: float = 0.1, task_type: str = 'regression'):
75
- super().__init__()
76
- layers = []
77
- dim_in = input_dim
78
- for _ in range(max(1, num_layers)):
79
- layers.append(SimpleGraphLayer(
80
- dim_in, hidden_dim, dropout=dropout))
81
- dim_in = hidden_dim
82
- self.layers = nn.ModuleList(layers)
83
- self.output = nn.Linear(hidden_dim, 1)
84
- if task_type == 'classification':
85
- self.output_act = nn.Identity()
86
- else:
87
- self.output_act = nn.Softplus()
88
- self.task_type = task_type
89
- # Keep adjacency as a buffer for DataParallel copies.
90
- self.register_buffer("adj_buffer", torch.empty(0))
91
-
92
- def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
93
- adj_used = adj if adj is not None else getattr(
94
- self, "adj_buffer", None)
95
- if adj_used is None or adj_used.numel() == 0:
96
- raise RuntimeError("Adjacency is not set for GNN forward.")
97
- h = x
98
- for layer in self.layers:
99
- h = layer(h, adj_used)
100
- h = _adj_mm(adj_used, h)
101
- out = self.output(h)
102
- return self.output_act(out)
103
-
104
-
105
- class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
106
- def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
107
- num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
108
- learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
109
- task_type: str = 'regression', tweedie_power: float = 1.5,
110
- weight_decay: float = 0.0,
111
- use_data_parallel: bool = False, use_ddp: bool = False,
112
- use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
113
- graph_cache_path: Optional[str] = None,
114
- max_gpu_knn_nodes: Optional[int] = None,
115
- knn_gpu_mem_ratio: float = 0.9,
116
- knn_gpu_mem_overhead: float = 2.0,
117
- knn_cpu_jobs: Optional[int] = -1,
118
- loss_name: Optional[str] = None) -> None:
119
- super().__init__()
120
- self.model_nme = model_nme
121
- self.input_dim = input_dim
122
- self.hidden_dim = hidden_dim
123
- self.num_layers = num_layers
124
- self.k_neighbors = max(1, k_neighbors)
125
- self.dropout = dropout
126
- self.learning_rate = learning_rate
127
- self.weight_decay = weight_decay
128
- self.epochs = epochs
129
- self.patience = patience
130
- self.task_type = task_type
131
- self.use_approx_knn = use_approx_knn
132
- self.approx_knn_threshold = approx_knn_threshold
133
- self.graph_cache_path = Path(
134
- graph_cache_path) if graph_cache_path else None
135
- self.max_gpu_knn_nodes = max_gpu_knn_nodes
136
- self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
137
- self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
138
- self.knn_cpu_jobs = knn_cpu_jobs
139
- self.mps_dense_max_nodes = int(
140
- os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
141
- )
142
- self._knn_warning_emitted = False
143
- self._mps_fallback_triggered = False
144
- self._adj_cache_meta: Optional[Dict[str, Any]] = None
145
- self._adj_cache_key: Optional[Tuple[Any, ...]] = None
146
- self._adj_cache_tensor: Optional[torch.Tensor] = None
147
-
148
- resolved_loss = normalize_loss_name(loss_name, self.task_type)
149
- if self.task_type == 'classification':
150
- self.loss_name = "logloss"
151
- self.tw_power = None
152
- else:
153
- if resolved_loss == "auto":
154
- resolved_loss = infer_loss_name_from_model_name(self.model_nme)
155
- self.loss_name = resolved_loss
156
- if self.loss_name == "tweedie":
157
- self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
158
- else:
159
- self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
160
-
161
- self.ddp_enabled = False
162
- self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
163
- self.data_parallel_enabled = False
164
- self._ddp_disabled = False
165
-
166
- if use_ddp:
167
- world_size = int(os.environ.get("WORLD_SIZE", "1"))
168
- if world_size > 1:
169
- print(
170
- "[GNN] DDP training is not supported; falling back to single process.",
171
- flush=True,
172
- )
173
- self._ddp_disabled = True
174
- use_ddp = False
175
-
176
- # DDP only works with CUDA; fall back to single process if init fails.
177
- if use_ddp and torch.cuda.is_available():
178
- ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
179
- if ddp_ok:
180
- self.ddp_enabled = True
181
- self.local_rank = local_rank
182
- self.device = torch.device(f'cuda:{local_rank}')
183
- else:
184
- self.device = torch.device('cuda')
185
- elif torch.cuda.is_available():
186
- if self._ddp_disabled:
187
- self.device = torch.device(f'cuda:{self.local_rank}')
188
- else:
189
- self.device = torch.device('cuda')
190
- elif torch.backends.mps.is_available():
191
- self.device = torch.device('mps')
192
- global _GNN_MPS_WARNED
193
- if not _GNN_MPS_WARNED:
194
- print(
195
- "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
196
- flush=True,
197
- )
198
- _GNN_MPS_WARNED = True
199
- else:
200
- self.device = torch.device('cpu')
201
- self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE
202
-
203
- self.gnn = SimpleGNN(
204
- input_dim=self.input_dim,
205
- hidden_dim=self.hidden_dim,
206
- num_layers=self.num_layers,
207
- dropout=self.dropout,
208
- task_type=self.task_type
209
- ).to(self.device)
210
-
211
- # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
212
- if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
213
- self.data_parallel_enabled = True
214
- self.gnn = nn.DataParallel(
215
- self.gnn, device_ids=list(range(torch.cuda.device_count())))
216
- self.device = torch.device('cuda')
217
-
218
- if self.ddp_enabled:
219
- self.gnn = DDP(
220
- self.gnn,
221
- device_ids=[self.local_rank],
222
- output_device=self.local_rank,
223
- find_unused_parameters=False
224
- )
225
-
226
- @staticmethod
227
- def _validate_vector(arr, name: str, n_rows: int) -> None:
228
- if arr is None:
229
- return
230
- if isinstance(arr, pd.DataFrame):
231
- if arr.shape[1] != 1:
232
- raise ValueError(f"{name} must be 1d (single column).")
233
- length = len(arr)
234
- else:
235
- arr_np = np.asarray(arr)
236
- if arr_np.ndim == 0:
237
- raise ValueError(f"{name} must be 1d.")
238
- if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
239
- raise ValueError(f"{name} must be 1d or Nx1.")
240
- length = arr_np.shape[0]
241
- if length != n_rows:
242
- raise ValueError(
243
- f"{name} length {length} does not match X length {n_rows}."
244
- )
245
-
246
- def _unwrap_gnn(self) -> nn.Module:
247
- if isinstance(self.gnn, (DDP, nn.DataParallel)):
248
- return self.gnn.module
249
- return self.gnn
250
-
251
- def _set_adj_buffer(self, adj: torch.Tensor) -> None:
252
- base = self._unwrap_gnn()
253
- if hasattr(base, "adj_buffer"):
254
- base.adj_buffer = adj
255
- else:
256
- base.register_buffer("adj_buffer", adj)
257
-
258
- @staticmethod
259
- def _is_mps_unsupported_error(exc: BaseException) -> bool:
260
- msg = str(exc).lower()
261
- if "mps" not in msg:
262
- return False
263
- if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
264
- return True
265
- return "sparse" in msg
266
-
267
- def _fallback_to_cpu(self, reason: str) -> None:
268
- if self.device.type != "mps" or self._mps_fallback_triggered:
269
- return
270
- self._mps_fallback_triggered = True
271
- print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
272
- self.device = torch.device("cpu")
273
- self.use_pyg_knn = False
274
- self.data_parallel_enabled = False
275
- self.ddp_enabled = False
276
- base = self._unwrap_gnn()
277
- try:
278
- base = base.to(self.device)
279
- except Exception:
280
- pass
281
- self.gnn = base
282
- self.invalidate_graph_cache()
283
-
284
- def _run_with_mps_fallback(self, fn, *args, **kwargs):
285
- try:
286
- return fn(*args, **kwargs)
287
- except (RuntimeError, NotImplementedError) as exc:
288
- if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
289
- self._fallback_to_cpu(str(exc))
290
- return fn(*args, **kwargs)
291
- raise
292
-
293
- def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
294
- row_hash = pd.util.hash_pandas_object(X_df, index=False).values
295
- idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
296
- col_sig = ",".join(map(str, X_df.columns))
297
- hasher = hashlib.sha256()
298
- hasher.update(row_hash.tobytes())
299
- hasher.update(idx_hash.tobytes())
300
- hasher.update(col_sig.encode("utf-8", errors="ignore"))
301
- knn_config = {
302
- "k_neighbors": int(self.k_neighbors),
303
- "use_approx_knn": bool(self.use_approx_knn),
304
- "approx_knn_threshold": int(self.approx_knn_threshold),
305
- "use_pyg_knn": bool(self.use_pyg_knn),
306
- "pynndescent_available": bool(_PYNN_AVAILABLE),
307
- "max_gpu_knn_nodes": (
308
- None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
309
- ),
310
- "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
311
- "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
312
- }
313
- adj_format = "dense" if self.device.type == "mps" else "sparse"
314
- return {
315
- "n_samples": int(X_df.shape[0]),
316
- "n_features": int(X_df.shape[1]),
317
- "hash": hasher.hexdigest(),
318
- "knn_config": knn_config,
319
- "adj_format": adj_format,
320
- "device_type": self.device.type,
321
- }
322
-
323
- def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
324
- return (
325
- id(X_df),
326
- id(getattr(X_df, "_mgr", None)),
327
- id(X_df.index),
328
- X_df.shape,
329
- tuple(map(str, X_df.columns)),
330
- X_df.attrs.get("graph_cache_key"),
331
- )
332
-
333
- def invalidate_graph_cache(self) -> None:
334
- self._adj_cache_meta = None
335
- self._adj_cache_key = None
336
- self._adj_cache_tensor = None
337
-
338
- def _load_cached_adj(self,
339
- X_df: pd.DataFrame,
340
- meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
341
- if self.graph_cache_path and self.graph_cache_path.exists():
342
- if meta_expected is None:
343
- meta_expected = self._graph_cache_meta(X_df)
344
- try:
345
- payload = torch.load(self.graph_cache_path, map_location="cpu")
346
- except Exception as exc:
347
- print(
348
- f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
349
- return None
350
- if isinstance(payload, dict) and "adj" in payload:
351
- meta_cached = payload.get("meta")
352
- if meta_cached == meta_expected:
353
- adj = payload["adj"]
354
- if self.device.type == "mps" and getattr(adj, "is_sparse", False):
355
- print(
356
- f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
357
- )
358
- return None
359
- return adj.to(self.device)
360
- print(
361
- f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
362
- return None
363
- if isinstance(payload, torch.Tensor):
364
- print(
365
- f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
366
- return None
367
- print(
368
- f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
369
- return None
370
-
371
- def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
372
- n_samples = X_np.shape[0]
373
- k = min(self.k_neighbors, max(1, n_samples - 1))
374
- n_neighbors = min(k + 1, n_samples)
375
- use_approx = (self.use_approx_knn or n_samples >=
376
- self.approx_knn_threshold) and _PYNN_AVAILABLE
377
- indices = None
378
- if use_approx:
379
- try:
380
- nn_index = pynndescent.NNDescent(
381
- X_np,
382
- n_neighbors=n_neighbors,
383
- random_state=0
384
- )
385
- indices, _ = nn_index.neighbor_graph
386
- except Exception as exc:
387
- print(
388
- f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
389
- use_approx = False
390
-
391
- if indices is None:
392
- nbrs = NearestNeighbors(
393
- n_neighbors=n_neighbors,
394
- algorithm="auto",
395
- n_jobs=self.knn_cpu_jobs,
396
- )
397
- nbrs.fit(X_np)
398
- _, indices = nbrs.kneighbors(X_np)
399
-
400
- indices = np.asarray(indices)
401
- rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
402
- np.int64, copy=False)
403
- cols = indices.reshape(-1).astype(np.int64, copy=False)
404
- mask = rows != cols
405
- rows = rows[mask]
406
- cols = cols[mask]
407
- rows_base = rows
408
- cols_base = cols
409
- self_loops = np.arange(n_samples, dtype=np.int64)
410
- rows = np.concatenate([rows_base, cols_base, self_loops])
411
- cols = np.concatenate([cols_base, rows_base, self_loops])
412
-
413
- edge_index_np = np.stack([rows, cols], axis=0)
414
- edge_index = torch.as_tensor(edge_index_np, device=self.device)
415
- return edge_index
416
-
417
- def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
418
- if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
419
- # Defensive: check use_pyg_knn before calling.
420
- raise RuntimeError(
421
- "GPU graph builder requested but PyG is unavailable.")
422
-
423
- n_samples = X_tensor.size(0)
424
- k = min(self.k_neighbors, max(1, n_samples - 1))
425
-
426
- # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
427
- edge_index = knn_graph(
428
- X_tensor,
429
- k=k,
430
- loop=False
431
- )
432
- edge_index = to_undirected(edge_index, num_nodes=n_samples)
433
- edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
434
- return edge_index
435
-
436
- def _log_knn_fallback(self, reason: str) -> None:
437
- if self._knn_warning_emitted:
438
- return
439
- if (not self.ddp_enabled) or self.local_rank == 0:
440
- print(f"[GNN] Falling back to CPU kNN builder: {reason}")
441
- self._knn_warning_emitted = True
442
-
443
- def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
444
- if not self.use_pyg_knn:
445
- return False
446
-
447
- reason = None
448
- if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
449
- reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
450
- elif self.device.type == 'cuda' and torch.cuda.is_available():
451
- try:
452
- device_index = self.device.index
453
- if device_index is None:
454
- device_index = torch.cuda.current_device()
455
- free_mem, total_mem = torch.cuda.mem_get_info(device_index)
456
- feature_bytes = X_tensor.element_size() * X_tensor.nelement()
457
- required = int(feature_bytes * self.knn_gpu_mem_overhead)
458
- budget = int(free_mem * self.knn_gpu_mem_ratio)
459
- if required > budget:
460
- required_gb = required / (1024 ** 3)
461
- budget_gb = budget / (1024 ** 3)
462
- reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
463
- f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
464
- except Exception:
465
- # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
466
- reason = None
467
-
468
- if reason:
469
- self._log_knn_fallback(reason)
470
- return False
471
- return True
472
-
473
- def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
474
- if self.device.type == "mps":
475
- return self._normalized_adj_dense(edge_index, num_nodes)
476
- return self._normalized_adj_sparse(edge_index, num_nodes)
477
-
478
- def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
479
- values = torch.ones(edge_index.shape[1], device=self.device)
480
- adj = torch.sparse_coo_tensor(
481
- edge_index.to(self.device), values, (num_nodes, num_nodes))
482
- adj = adj.coalesce()
483
-
484
- deg = torch.sparse.sum(adj, dim=1).to_dense()
485
- deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
486
- row, col = adj.indices()
487
- norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
488
- adj_norm = torch.sparse_coo_tensor(
489
- adj.indices(), norm_values, size=adj.shape)
490
- return adj_norm
491
-
492
- def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
493
- if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
494
- raise RuntimeError(
495
- f"MPS dense adjacency not supported for {num_nodes} nodes; "
496
- f"max={self.mps_dense_max_nodes}. Falling back to CPU."
497
- )
498
- edge_index = edge_index.to(self.device)
499
- adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
500
- adj[edge_index[0], edge_index[1]] = 1.0
501
- deg = adj.sum(dim=1)
502
- deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
503
- adj = adj * deg_inv_sqrt.view(-1, 1)
504
- adj = adj * deg_inv_sqrt.view(1, -1)
505
- return adj
506
-
507
- def _tensorize_split(self, X, y, w, allow_none: bool = False):
508
- if X is None and allow_none:
509
- return None, None, None
510
- if not isinstance(X, pd.DataFrame):
511
- raise ValueError("X must be a pandas DataFrame for GNN.")
512
- n_rows = len(X)
513
- if y is not None:
514
- self._validate_vector(y, "y", n_rows)
515
- if w is not None:
516
- self._validate_vector(w, "w", n_rows)
517
- X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
518
- X, "to_numpy") else np.asarray(X, dtype=np.float32)
519
- X_tensor = torch.as_tensor(
520
- X_np, dtype=torch.float32, device=self.device)
521
- if y is None:
522
- y_tensor = None
523
- else:
524
- y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
525
- y, "to_numpy") else np.asarray(y, dtype=np.float32)
526
- y_tensor = torch.as_tensor(
527
- y_np, dtype=torch.float32, device=self.device).view(-1, 1)
528
- if w is None:
529
- w_tensor = torch.ones(
530
- (len(X), 1), dtype=torch.float32, device=self.device)
531
- else:
532
- w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
533
- w, "to_numpy") else np.asarray(w, dtype=np.float32)
534
- w_tensor = torch.as_tensor(
535
- w_np, dtype=torch.float32, device=self.device).view(-1, 1)
536
- return X_tensor, y_tensor, w_tensor
537
-
538
- def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
539
- if not isinstance(X_df, pd.DataFrame):
540
- raise ValueError("X must be a pandas DataFrame for graph building.")
541
- meta_expected = None
542
- cache_key = None
543
- if self.graph_cache_path:
544
- meta_expected = self._graph_cache_meta(X_df)
545
- if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
546
- cached = self._adj_cache_tensor
547
- if cached.device != self.device:
548
- if self.device.type == "mps" and getattr(cached, "is_sparse", False):
549
- self._adj_cache_tensor = None
550
- else:
551
- cached = cached.to(self.device)
552
- self._adj_cache_tensor = cached
553
- if self._adj_cache_tensor is not None:
554
- return self._adj_cache_tensor
555
- else:
556
- cache_key = self._graph_cache_key(X_df)
557
- if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
558
- cached = self._adj_cache_tensor
559
- if cached.device != self.device:
560
- if self.device.type == "mps" and getattr(cached, "is_sparse", False):
561
- self._adj_cache_tensor = None
562
- else:
563
- cached = cached.to(self.device)
564
- self._adj_cache_tensor = cached
565
- if self._adj_cache_tensor is not None:
566
- return self._adj_cache_tensor
567
- X_np = None
568
- if X_tensor is None:
569
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
570
- X_tensor = torch.as_tensor(
571
- X_np, dtype=torch.float32, device=self.device)
572
- if self.graph_cache_path:
573
- cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
574
- if cached is not None:
575
- self._adj_cache_meta = meta_expected
576
- self._adj_cache_key = None
577
- self._adj_cache_tensor = cached
578
- return cached
579
- use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
580
- if use_gpu_knn:
581
- edge_index = self._build_edge_index_gpu(X_tensor)
582
- else:
583
- if X_np is None:
584
- X_np = X_df.to_numpy(dtype=np.float32, copy=False)
585
- edge_index = self._build_edge_index_cpu(X_np)
586
- adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
587
- if self.graph_cache_path:
588
- try:
589
- IOUtils.ensure_parent_dir(str(self.graph_cache_path))
590
- torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
591
- except Exception as exc:
592
- print(
593
- f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
594
- self._adj_cache_meta = meta_expected
595
- self._adj_cache_key = None
596
- else:
597
- self._adj_cache_meta = None
598
- self._adj_cache_key = cache_key
599
- self._adj_cache_tensor = adj_norm
600
- return adj_norm
601
-
602
- def fit(self, X_train, y_train, w_train=None,
603
- X_val=None, y_val=None, w_val=None,
604
- trial: Optional[optuna.trial.Trial] = None):
605
- return self._run_with_mps_fallback(
606
- self._fit_impl,
607
- X_train,
608
- y_train,
609
- w_train,
610
- X_val,
611
- y_val,
612
- w_val,
613
- trial,
614
- )
615
-
616
- def _fit_impl(self, X_train, y_train, w_train=None,
617
- X_val=None, y_val=None, w_val=None,
618
- trial: Optional[optuna.trial.Trial] = None):
619
- X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
620
- X_train, y_train, w_train, allow_none=False)
621
- has_val = X_val is not None and y_val is not None
622
- if has_val:
623
- X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
624
- X_val, y_val, w_val, allow_none=False)
625
- else:
626
- X_val_tensor = y_val_tensor = w_val_tensor = None
627
-
628
- adj_train = self._build_graph_from_df(X_train, X_train_tensor)
629
- adj_val = self._build_graph_from_df(
630
- X_val, X_val_tensor) if has_val else None
631
- # DataParallel needs adjacency cached on the model to avoid scatter.
632
- self._set_adj_buffer(adj_train)
633
-
634
- base_gnn = self._unwrap_gnn()
635
- optimizer = torch.optim.Adam(
636
- base_gnn.parameters(),
637
- lr=self.learning_rate,
638
- weight_decay=float(getattr(self, "weight_decay", 0.0)),
639
- )
640
- scaler = GradScaler(enabled=(self.device.type == 'cuda'))
641
-
642
- best_loss = float('inf')
643
- best_state = None
644
- patience_counter = 0
645
- best_epoch = None
646
-
647
- for epoch in range(1, self.epochs + 1):
648
- epoch_start_ts = time.time()
649
- self.gnn.train()
650
- optimizer.zero_grad()
651
- with autocast(enabled=(self.device.type == 'cuda')):
652
- if self.data_parallel_enabled:
653
- y_pred = self.gnn(X_train_tensor)
654
- else:
655
- y_pred = self.gnn(X_train_tensor, adj_train)
656
- loss = self._compute_weighted_loss(
657
- y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
658
- scaler.scale(loss).backward()
659
- scaler.unscale_(optimizer)
660
- clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
661
- scaler.step(optimizer)
662
- scaler.update()
663
-
664
- val_loss = None
665
- if has_val:
666
- self.gnn.eval()
667
- if self.data_parallel_enabled and adj_val is not None:
668
- self._set_adj_buffer(adj_val)
669
- with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
670
- if self.data_parallel_enabled:
671
- y_val_pred = self.gnn(X_val_tensor)
672
- else:
673
- y_val_pred = self.gnn(X_val_tensor, adj_val)
674
- val_loss = self._compute_weighted_loss(
675
- y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
676
- if self.data_parallel_enabled:
677
- # Restore training adjacency.
678
- self._set_adj_buffer(adj_train)
679
-
680
- is_best = val_loss is not None and val_loss < best_loss
681
- best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
682
- val_loss, best_loss, best_state, patience_counter, base_gnn,
683
- ignore_keys=["adj_buffer"])
684
- if is_best:
685
- best_epoch = epoch
686
-
687
- prune_now = False
688
- if trial is not None:
689
- trial.report(val_loss, epoch)
690
- if trial.should_prune():
691
- prune_now = True
692
-
693
- if dist.is_initialized():
694
- flag = torch.tensor(
695
- [1 if prune_now else 0],
696
- device=self.device,
697
- dtype=torch.int32,
698
- )
699
- dist.broadcast(flag, src=0)
700
- prune_now = bool(flag.item())
701
-
702
- if prune_now:
703
- raise optuna.TrialPruned()
704
- if stop_training:
705
- break
706
-
707
- should_log = (not dist.is_initialized()
708
- or DistributedUtils.is_main_process())
709
- if should_log:
710
- elapsed = int(time.time() - epoch_start_ts)
711
- if val_loss is None:
712
- print(
713
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
714
- flush=True,
715
- )
716
- else:
717
- print(
718
- f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
719
- f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
720
- flush=True,
721
- )
722
-
723
- if best_state is not None:
724
- base_gnn.load_state_dict(best_state, strict=False)
725
- self.best_epoch = int(best_epoch or self.epochs)
726
-
727
- def predict(self, X: pd.DataFrame) -> np.ndarray:
728
- return self._run_with_mps_fallback(self._predict_impl, X)
729
-
730
- def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
731
- self.gnn.eval()
732
- X_tensor, _, _ = self._tensorize_split(
733
- X, None, None, allow_none=False)
734
- adj = self._build_graph_from_df(X, X_tensor)
735
- if self.data_parallel_enabled:
736
- self._set_adj_buffer(adj)
737
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
738
- with inference_cm():
739
- if self.data_parallel_enabled:
740
- y_pred = self.gnn(X_tensor).cpu().numpy()
741
- else:
742
- y_pred = self.gnn(X_tensor, adj).cpu().numpy()
743
- if self.task_type == 'classification':
744
- y_pred = 1 / (1 + np.exp(-y_pred))
745
- else:
746
- y_pred = np.clip(y_pred, 1e-6, None)
747
- return y_pred.ravel()
748
-
749
- def encode(self, X: pd.DataFrame) -> np.ndarray:
750
- return self._run_with_mps_fallback(self._encode_impl, X)
751
-
752
- def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
753
- """Return per-sample node embeddings (hidden representations)."""
754
- base = self._unwrap_gnn()
755
- base.eval()
756
- X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
757
- adj = self._build_graph_from_df(X, X_tensor)
758
- if self.data_parallel_enabled:
759
- self._set_adj_buffer(adj)
760
- inference_cm = getattr(torch, "inference_mode", torch.no_grad)
761
- with inference_cm():
762
- h = X_tensor
763
- layers = getattr(base, "layers", None)
764
- if layers is None:
765
- raise RuntimeError("GNN base module does not expose layers.")
766
- for layer in layers:
767
- h = layer(h, adj)
768
- h = _adj_mm(adj, h)
769
- return h.detach().cpu().numpy()
770
-
771
- def set_params(self, params: Dict[str, Any]):
772
- for key, value in params.items():
773
- if hasattr(self, key):
774
- setattr(self, key, value)
775
- else:
776
- raise ValueError(f"Parameter {key} not found in GNN model.")
777
- # Rebuild the backbone after structural parameter changes.
778
- self.gnn = SimpleGNN(
779
- input_dim=self.input_dim,
780
- hidden_dim=self.hidden_dim,
781
- num_layers=self.num_layers,
782
- dropout=self.dropout,
783
- task_type=self.task_type
784
- ).to(self.device)
785
- return self
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn as nn
14
+ from sklearn.neighbors import NearestNeighbors
15
+ from torch.cuda.amp import autocast, GradScaler
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.nn.utils import clip_grad_norm_
18
+
19
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
20
+ from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
21
+ from ins_pricing.utils import EPS, get_logger, log_print
22
+ from ins_pricing.utils.io import IOUtils
23
+ from ins_pricing.utils.losses import (
24
+ infer_loss_name_from_model_name,
25
+ normalize_loss_name,
26
+ resolve_tweedie_power,
27
+ )
28
+
29
+ try:
30
+ from torch_geometric.nn import knn_graph
31
+ from torch_geometric.utils import add_self_loops, to_undirected
32
+ _PYG_AVAILABLE = True
33
+ except Exception:
34
+ knn_graph = None # type: ignore
35
+ add_self_loops = None # type: ignore
36
+ to_undirected = None # type: ignore
37
+ _PYG_AVAILABLE = False
38
+
39
+ try:
40
+ import pynndescent
41
+ _PYNN_AVAILABLE = True
42
+ except Exception:
43
+ pynndescent = None # type: ignore
44
+ _PYNN_AVAILABLE = False
45
+
46
+ _GNN_MPS_WARNED = False
47
+
48
+ _logger = get_logger("ins_pricing.modelling.bayesopt.models.model_gnn")
49
+
50
+
51
def _log(*args, **kwargs) -> None:
    """Print-style logging helper: forwards everything to ``log_print``
    so messages are routed through this module's logger."""
    log_print(_logger, *args, **kwargs)
53
+
54
+
55
+ # =============================================================================
56
+ # Simplified GNN implementation.
57
+ # =============================================================================
58
+
59
+ def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
60
+ """Matrix multiply that supports sparse or dense adjacency."""
61
+ if adj.is_sparse:
62
+ return torch.sparse.mm(adj, x)
63
+ return adj.matmul(x)
64
+
65
+
66
class SimpleGraphLayer(nn.Module):
    """One graph-convolution step: neighbour aggregation, then Linear + ReLU + Dropout."""

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        # A non-positive dropout rate degenerates to a pass-through.
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Message passing with normalized adjacency: dropout(relu(W · (A_hat x)))."""
        aggregated = _adj_mm(adj, x)
        return self.dropout(self.activation(self.linear(aggregated)))
79
+
80
+
81
class SimpleGNN(nn.Module):
    """Stack of ``SimpleGraphLayer`` blocks followed by a linear head.

    Regression uses a Softplus head (strictly positive outputs); classification
    emits raw logits. The adjacency can be passed to ``forward`` explicitly or
    pre-cached in the ``adj_buffer`` buffer (the DataParallel path).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        super().__init__()
        blocks = []
        prev_dim = input_dim
        for _ in range(max(1, num_layers)):
            blocks.append(SimpleGraphLayer(prev_dim, hidden_dim, dropout=dropout))
            prev_dim = hidden_dim
        self.layers = nn.ModuleList(blocks)
        self.output = nn.Linear(hidden_dim, 1)
        self.output_act = (
            nn.Identity() if task_type == 'classification' else nn.Softplus()
        )
        self.task_type = task_type
        # Registered as a buffer so DataParallel replicas carry the adjacency.
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        adj_used = adj if adj is not None else getattr(self, "adj_buffer", None)
        if adj_used is None or adj_used.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        hidden = x
        for block in self.layers:
            hidden = block(hidden, adj_used)
        # One extra neighbourhood aggregation before the output head.
        hidden = _adj_mm(adj_used, hidden)
        return self.output_act(self.output(hidden))
112
+
113
+
114
+ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
115
    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
                 learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
                 task_type: str = 'regression', tweedie_power: float = 1.5,
                 weight_decay: float = 0.0,
                 use_data_parallel: bool = False, use_ddp: bool = False,
                 use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
                 graph_cache_path: Optional[str] = None,
                 max_gpu_knn_nodes: Optional[int] = None,
                 knn_gpu_mem_ratio: float = 0.9,
                 knn_gpu_mem_overhead: float = 2.0,
                 knn_cpu_jobs: Optional[int] = -1,
                 loss_name: Optional[str] = None) -> None:
        """Sklearn-style GNN wrapper: stores hyper-parameters, resolves the
        loss, selects a device (CUDA > MPS > CPU), builds the SimpleGNN
        backbone, and optionally wraps it in DataParallel/DDP.

        Notes:
        - ``k_neighbors`` is clamped to >= 1; ``knn_gpu_mem_ratio`` to [0, 1];
          ``knn_gpu_mem_overhead`` to >= 1.
        - When WORLD_SIZE > 1, DDP is explicitly disabled (not supported here).
        """
        super().__init__()
        self.model_nme = model_nme
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.k_neighbors = max(1, k_neighbors)
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.patience = patience
        self.task_type = task_type
        self.use_approx_knn = use_approx_knn
        self.approx_knn_threshold = approx_knn_threshold
        self.graph_cache_path = Path(
            graph_cache_path) if graph_cache_path else None
        self.max_gpu_knn_nodes = max_gpu_knn_nodes
        self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
        self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
        self.knn_cpu_jobs = knn_cpu_jobs
        # Dense-adjacency node cap on MPS (sparse ops unsupported there);
        # overridable via environment variable.
        self.mps_dense_max_nodes = int(
            os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
        )
        self._knn_warning_emitted = False
        self._mps_fallback_triggered = False
        # In-memory adjacency cache: either metadata-keyed (file cache mode)
        # or identity-keyed (no cache file).
        self._adj_cache_meta: Optional[Dict[str, Any]] = None
        self._adj_cache_key: Optional[Tuple[Any, ...]] = None
        self._adj_cache_tensor: Optional[torch.Tensor] = None

        # Classification always uses logloss; regression resolves "auto"
        # from the model name and derives a Tweedie power when relevant.
        resolved_loss = normalize_loss_name(loss_name, self.task_type)
        if self.task_type == 'classification':
            self.loss_name = "logloss"
            self.tw_power = None
        else:
            if resolved_loss == "auto":
                resolved_loss = infer_loss_name_from_model_name(self.model_nme)
            self.loss_name = resolved_loss
            if self.loss_name == "tweedie":
                self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
            else:
                self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)

        self.ddp_enabled = False
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.data_parallel_enabled = False
        self._ddp_disabled = False

        if use_ddp:
            world_size = int(os.environ.get("WORLD_SIZE", "1"))
            if world_size > 1:
                _log(
                    "[GNN] DDP training is not supported; falling back to single process.",
                    flush=True,
                )
                self._ddp_disabled = True
                use_ddp = False

        # DDP only works with CUDA; fall back to single process if init fails.
        # NOTE(review): after the block above, this branch can only run with
        # WORLD_SIZE == 1 — confirm setup_ddp() is meaningful in that case.
        if use_ddp and torch.cuda.is_available():
            ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
            if ddp_ok:
                self.ddp_enabled = True
                self.local_rank = local_rank
                self.device = torch.device(f'cuda:{local_rank}')
            else:
                self.device = torch.device('cuda')
        elif torch.cuda.is_available():
            if self._ddp_disabled:
                self.device = torch.device(f'cuda:{self.local_rank}')
            else:
                self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
            global _GNN_MPS_WARNED
            if not _GNN_MPS_WARNED:
                _log(
                    "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
                    flush=True,
                )
                _GNN_MPS_WARNED = True
        else:
            self.device = torch.device('cpu')
        # GPU kNN graph building requires both CUDA and torch_geometric.
        self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

        self.gnn = SimpleGNN(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=self.dropout,
            task_type=self.task_type
        ).to(self.device)

        # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
        if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            self.data_parallel_enabled = True
            self.gnn = nn.DataParallel(
                self.gnn, device_ids=list(range(torch.cuda.device_count())))
            self.device = torch.device('cuda')

        if self.ddp_enabled:
            self.gnn = DDP(
                self.gnn,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=False
            )
234
+
235
+ @staticmethod
236
+ def _validate_vector(arr, name: str, n_rows: int) -> None:
237
+ if arr is None:
238
+ return
239
+ if isinstance(arr, pd.DataFrame):
240
+ if arr.shape[1] != 1:
241
+ raise ValueError(f"{name} must be 1d (single column).")
242
+ length = len(arr)
243
+ else:
244
+ arr_np = np.asarray(arr)
245
+ if arr_np.ndim == 0:
246
+ raise ValueError(f"{name} must be 1d.")
247
+ if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
248
+ raise ValueError(f"{name} must be 1d or Nx1.")
249
+ length = arr_np.shape[0]
250
+ if length != n_rows:
251
+ raise ValueError(
252
+ f"{name} length {length} does not match X length {n_rows}."
253
+ )
254
+
255
+ def _unwrap_gnn(self) -> nn.Module:
256
+ if isinstance(self.gnn, (DDP, nn.DataParallel)):
257
+ return self.gnn.module
258
+ return self.gnn
259
+
260
+ def _set_adj_buffer(self, adj: torch.Tensor) -> None:
261
+ base = self._unwrap_gnn()
262
+ if hasattr(base, "adj_buffer"):
263
+ base.adj_buffer = adj
264
+ else:
265
+ base.register_buffer("adj_buffer", adj)
266
+
267
+ @staticmethod
268
+ def _is_mps_unsupported_error(exc: BaseException) -> bool:
269
+ msg = str(exc).lower()
270
+ if "mps" not in msg:
271
+ return False
272
+ if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
273
+ return True
274
+ return "sparse" in msg
275
+
276
    def _fallback_to_cpu(self, reason: str) -> None:
        """Permanently move this model from MPS to CPU after an unsupported op.

        No-op unless the current device is MPS and the fallback has not already
        been taken (one-shot latch). Disables GPU-only code paths and clears
        the adjacency cache so graphs are rebuilt on CPU.
        """
        if self.device.type != "mps" or self._mps_fallback_triggered:
            return
        self._mps_fallback_triggered = True
        _log(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
        self.device = torch.device("cpu")
        self.use_pyg_knn = False
        self.data_parallel_enabled = False
        self.ddp_enabled = False
        base = self._unwrap_gnn()
        try:
            base = base.to(self.device)
        except Exception:
            # Best-effort move; keep the module usable even if .to() fails.
            pass
        self.gnn = base
        self.invalidate_graph_cache()
292
+
293
+ def _run_with_mps_fallback(self, fn, *args, **kwargs):
294
+ try:
295
+ return fn(*args, **kwargs)
296
+ except (RuntimeError, NotImplementedError) as exc:
297
+ if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
298
+ self._fallback_to_cpu(str(exc))
299
+ return fn(*args, **kwargs)
300
+ raise
301
+
302
+ def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
303
+ row_hash = pd.util.hash_pandas_object(X_df, index=False).values
304
+ idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
305
+ col_sig = ",".join(map(str, X_df.columns))
306
+ hasher = hashlib.sha256()
307
+ hasher.update(row_hash.tobytes())
308
+ hasher.update(idx_hash.tobytes())
309
+ hasher.update(col_sig.encode("utf-8", errors="ignore"))
310
+ knn_config = {
311
+ "k_neighbors": int(self.k_neighbors),
312
+ "use_approx_knn": bool(self.use_approx_knn),
313
+ "approx_knn_threshold": int(self.approx_knn_threshold),
314
+ "use_pyg_knn": bool(self.use_pyg_knn),
315
+ "pynndescent_available": bool(_PYNN_AVAILABLE),
316
+ "max_gpu_knn_nodes": (
317
+ None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
318
+ ),
319
+ "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
320
+ "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
321
+ }
322
+ adj_format = "dense" if self.device.type == "mps" else "sparse"
323
+ return {
324
+ "n_samples": int(X_df.shape[0]),
325
+ "n_features": int(X_df.shape[1]),
326
+ "hash": hasher.hexdigest(),
327
+ "knn_config": knn_config,
328
+ "adj_format": adj_format,
329
+ "device_type": self.device.type,
330
+ }
331
+
332
+ def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
333
+ return (
334
+ id(X_df),
335
+ id(getattr(X_df, "_mgr", None)),
336
+ id(X_df.index),
337
+ X_df.shape,
338
+ tuple(map(str, X_df.columns)),
339
+ X_df.attrs.get("graph_cache_key"),
340
+ )
341
+
342
+ def invalidate_graph_cache(self) -> None:
343
+ self._adj_cache_meta = None
344
+ self._adj_cache_key = None
345
+ self._adj_cache_tensor = None
346
+
347
    def _load_cached_adj(self,
                         X_df: pd.DataFrame,
                         meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
        """Load a cached normalized adjacency from ``graph_cache_path``.

        Returns the adjacency moved to ``self.device`` when the cached
        metadata matches ``meta_expected`` (computed from ``X_df`` when not
        supplied), otherwise ``None`` so the caller rebuilds the graph.
        Also returns ``None`` when no cache path is configured or the file
        does not exist.

        NOTE(review): ``torch.load`` unpickles the file — the cache path must
        be trusted.
        """
        if self.graph_cache_path and self.graph_cache_path.exists():
            if meta_expected is None:
                meta_expected = self._graph_cache_meta(X_df)
            try:
                payload = torch.load(self.graph_cache_path, map_location="cpu")
            except Exception as exc:
                _log(
                    f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
                return None
            if isinstance(payload, dict) and "adj" in payload:
                meta_cached = payload.get("meta")
                if meta_cached == meta_expected:
                    adj = payload["adj"]
                    # MPS cannot host sparse tensors; force a rebuild there.
                    if self.device.type == "mps" and getattr(adj, "is_sparse", False):
                        _log(
                            f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
                        )
                        return None
                    return adj.to(self.device)
                _log(
                    f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
                return None
            # Legacy format: a bare tensor without metadata cannot be trusted.
            if isinstance(payload, torch.Tensor):
                _log(
                    f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
                return None
            _log(
                f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
            return None
379
+
380
    def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
        """Build a symmetric kNN edge index (with self loops) on the CPU.

        Uses pynndescent (approximate) when requested/available or the sample
        count crosses ``approx_knn_threshold``; otherwise falls back to exact
        sklearn ``NearestNeighbors``. Returns a ``2 x E`` int64 tensor on
        ``self.device``.
        """
        n_samples = X_np.shape[0]
        k = min(self.k_neighbors, max(1, n_samples - 1))
        # +1 neighbour because each point's nearest neighbour is itself.
        n_neighbors = min(k + 1, n_samples)
        use_approx = (self.use_approx_knn or n_samples >=
                      self.approx_knn_threshold) and _PYNN_AVAILABLE
        indices = None
        if use_approx:
            try:
                nn_index = pynndescent.NNDescent(
                    X_np,
                    n_neighbors=n_neighbors,
                    random_state=0
                )
                indices, _ = nn_index.neighbor_graph
            except Exception as exc:
                _log(
                    f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
                use_approx = False

        if indices is None:
            nbrs = NearestNeighbors(
                n_neighbors=n_neighbors,
                algorithm="auto",
                n_jobs=self.knn_cpu_jobs,
            )
            nbrs.fit(X_np)
            _, indices = nbrs.kneighbors(X_np)

        indices = np.asarray(indices)
        rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
            np.int64, copy=False)
        cols = indices.reshape(-1).astype(np.int64, copy=False)
        # Drop self matches; explicit self loops are appended below.
        mask = rows != cols
        rows = rows[mask]
        cols = cols[mask]
        rows_base = rows
        cols_base = cols
        self_loops = np.arange(n_samples, dtype=np.int64)
        # Symmetrize (add both directions) and append one self loop per node.
        # NOTE(review): edges already present in both directions end up
        # duplicated here; downstream sparse ``coalesce`` sums their weights.
        rows = np.concatenate([rows_base, cols_base, self_loops])
        cols = np.concatenate([cols_base, rows_base, self_loops])

        edge_index_np = np.stack([rows, cols], axis=0)
        edge_index = torch.as_tensor(edge_index_np, device=self.device)
        return edge_index
425
+
426
+ def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
427
+ if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
428
+ # Defensive: check use_pyg_knn before calling.
429
+ raise RuntimeError(
430
+ "GPU graph builder requested but PyG is unavailable.")
431
+
432
+ n_samples = X_tensor.size(0)
433
+ k = min(self.k_neighbors, max(1, n_samples - 1))
434
+
435
+ # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
436
+ edge_index = knn_graph(
437
+ X_tensor,
438
+ k=k,
439
+ loop=False
440
+ )
441
+ edge_index = to_undirected(edge_index, num_nodes=n_samples)
442
+ edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
443
+ return edge_index
444
+
445
+ def _log_knn_fallback(self, reason: str) -> None:
446
+ if self._knn_warning_emitted:
447
+ return
448
+ if (not self.ddp_enabled) or self.local_rank == 0:
449
+ _log(f"[GNN] Falling back to CPU kNN builder: {reason}")
450
+ self._knn_warning_emitted = True
451
+
452
+ def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
453
+ if not self.use_pyg_knn:
454
+ return False
455
+
456
+ reason = None
457
+ if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
458
+ reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
459
+ elif self.device.type == 'cuda' and torch.cuda.is_available():
460
+ try:
461
+ device_index = self.device.index
462
+ if device_index is None:
463
+ device_index = torch.cuda.current_device()
464
+ free_mem, total_mem = torch.cuda.mem_get_info(device_index)
465
+ feature_bytes = X_tensor.element_size() * X_tensor.nelement()
466
+ required = int(feature_bytes * self.knn_gpu_mem_overhead)
467
+ budget = int(free_mem * self.knn_gpu_mem_ratio)
468
+ if required > budget:
469
+ required_gb = required / (1024 ** 3)
470
+ budget_gb = budget / (1024 ** 3)
471
+ reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
472
+ f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
473
+ except Exception:
474
+ # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
475
+ reason = None
476
+
477
+ if reason:
478
+ self._log_knn_fallback(reason)
479
+ return False
480
+ return True
481
+
482
+ def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
483
+ if self.device.type == "mps":
484
+ return self._normalized_adj_dense(edge_index, num_nodes)
485
+ return self._normalized_adj_sparse(edge_index, num_nodes)
486
+
487
+ def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
488
+ values = torch.ones(edge_index.shape[1], device=self.device)
489
+ adj = torch.sparse_coo_tensor(
490
+ edge_index.to(self.device), values, (num_nodes, num_nodes))
491
+ adj = adj.coalesce()
492
+
493
+ deg = torch.sparse.sum(adj, dim=1).to_dense()
494
+ deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
495
+ row, col = adj.indices()
496
+ norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
497
+ adj_norm = torch.sparse_coo_tensor(
498
+ adj.indices(), norm_values, size=adj.shape)
499
+ return adj_norm
500
+
501
+ def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
502
+ if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
503
+ raise RuntimeError(
504
+ f"MPS dense adjacency not supported for {num_nodes} nodes; "
505
+ f"max={self.mps_dense_max_nodes}. Falling back to CPU."
506
+ )
507
+ edge_index = edge_index.to(self.device)
508
+ adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
509
+ adj[edge_index[0], edge_index[1]] = 1.0
510
+ deg = adj.sum(dim=1)
511
+ deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
512
+ adj = adj * deg_inv_sqrt.view(-1, 1)
513
+ adj = adj * deg_inv_sqrt.view(1, -1)
514
+ return adj
515
+
516
+ def _tensorize_split(self, X, y, w, allow_none: bool = False):
517
+ if X is None and allow_none:
518
+ return None, None, None
519
+ if not isinstance(X, pd.DataFrame):
520
+ raise ValueError("X must be a pandas DataFrame for GNN.")
521
+ n_rows = len(X)
522
+ if y is not None:
523
+ self._validate_vector(y, "y", n_rows)
524
+ if w is not None:
525
+ self._validate_vector(w, "w", n_rows)
526
+ X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
527
+ X, "to_numpy") else np.asarray(X, dtype=np.float32)
528
+ X_tensor = torch.as_tensor(
529
+ X_np, dtype=torch.float32, device=self.device)
530
+ if y is None:
531
+ y_tensor = None
532
+ else:
533
+ y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
534
+ y, "to_numpy") else np.asarray(y, dtype=np.float32)
535
+ y_tensor = torch.as_tensor(
536
+ y_np, dtype=torch.float32, device=self.device).view(-1, 1)
537
+ if w is None:
538
+ w_tensor = torch.ones(
539
+ (len(X), 1), dtype=torch.float32, device=self.device)
540
+ else:
541
+ w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
542
+ w, "to_numpy") else np.asarray(w, dtype=np.float32)
543
+ w_tensor = torch.as_tensor(
544
+ w_np, dtype=torch.float32, device=self.device).view(-1, 1)
545
+ return X_tensor, y_tensor, w_tensor
546
+
547
    def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the normalized adjacency for ``X_df``, caching aggressively.

        Lookup order:
        1. In-memory cache — keyed by content metadata when a file cache is
           configured, otherwise by a cheap object-identity key.
        2. On-disk cache (``graph_cache_path``), validated via metadata.
        3. Fresh build: GPU (PyG) or CPU kNN edge index, then normalization.

        A sparse cached tensor is discarded (rebuilt) on MPS, which cannot
        host sparse tensors.
        """
        if not isinstance(X_df, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame for graph building.")
        meta_expected = None
        cache_key = None
        if self.graph_cache_path:
            meta_expected = self._graph_cache_meta(X_df)
            if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
                cached = self._adj_cache_tensor
                if cached.device != self.device:
                    if self.device.type == "mps" and getattr(cached, "is_sparse", False):
                        # Sparse tensors cannot move to MPS; force a rebuild.
                        self._adj_cache_tensor = None
                    else:
                        cached = cached.to(self.device)
                        self._adj_cache_tensor = cached
                if self._adj_cache_tensor is not None:
                    return self._adj_cache_tensor
        else:
            cache_key = self._graph_cache_key(X_df)
            if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
                cached = self._adj_cache_tensor
                if cached.device != self.device:
                    if self.device.type == "mps" and getattr(cached, "is_sparse", False):
                        self._adj_cache_tensor = None
                    else:
                        cached = cached.to(self.device)
                        self._adj_cache_tensor = cached
                if self._adj_cache_tensor is not None:
                    return self._adj_cache_tensor
        X_np = None
        if X_tensor is None:
            X_np = X_df.to_numpy(dtype=np.float32, copy=False)
            X_tensor = torch.as_tensor(
                X_np, dtype=torch.float32, device=self.device)
        if self.graph_cache_path:
            cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
            if cached is not None:
                self._adj_cache_meta = meta_expected
                self._adj_cache_key = None
                self._adj_cache_tensor = cached
                return cached
        use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
        if use_gpu_knn:
            edge_index = self._build_edge_index_gpu(X_tensor)
        else:
            if X_np is None:
                X_np = X_df.to_numpy(dtype=np.float32, copy=False)
            edge_index = self._build_edge_index_cpu(X_np)
        adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
        if self.graph_cache_path:
            try:
                # Persist on CPU so the cache is device-agnostic.
                IOUtils.ensure_parent_dir(str(self.graph_cache_path))
                torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
            except Exception as exc:
                _log(
                    f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
            self._adj_cache_meta = meta_expected
            self._adj_cache_key = None
        else:
            self._adj_cache_meta = None
            self._adj_cache_key = cache_key
        self._adj_cache_tensor = adj_norm
        return adj_norm
610
+
611
    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None,
            trial: Optional[optuna.trial.Trial] = None):
        """Train the GNN full-batch, retrying once on CPU if MPS lacks an op.

        NOTE(review): the annotation references ``optuna``, which this module
        does not import at top level; it stays unevaluated only because of
        ``from __future__ import annotations`` — confirm tooling that
        evaluates annotations (e.g. ``typing.get_type_hints``) is not used on
        this class.
        """
        return self._run_with_mps_fallback(
            self._fit_impl,
            X_train,
            y_train,
            w_train,
            X_val,
            y_val,
            w_val,
            trial,
        )
624
+
625
+ def _fit_impl(self, X_train, y_train, w_train=None,
626
+ X_val=None, y_val=None, w_val=None,
627
+ trial: Optional[optuna.trial.Trial] = None):
628
+ X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
629
+ X_train, y_train, w_train, allow_none=False)
630
+ has_val = X_val is not None and y_val is not None
631
+ if has_val:
632
+ X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
633
+ X_val, y_val, w_val, allow_none=False)
634
+ else:
635
+ X_val_tensor = y_val_tensor = w_val_tensor = None
636
+
637
+ adj_train = self._build_graph_from_df(X_train, X_train_tensor)
638
+ adj_val = self._build_graph_from_df(
639
+ X_val, X_val_tensor) if has_val else None
640
+ # DataParallel needs adjacency cached on the model to avoid scatter.
641
+ self._set_adj_buffer(adj_train)
642
+
643
+ base_gnn = self._unwrap_gnn()
644
+ optimizer = torch.optim.Adam(
645
+ base_gnn.parameters(),
646
+ lr=self.learning_rate,
647
+ weight_decay=float(getattr(self, "weight_decay", 0.0)),
648
+ )
649
+ scaler = GradScaler(enabled=(self.device.type == 'cuda'))
650
+
651
+ best_loss = float('inf')
652
+ best_state = None
653
+ patience_counter = 0
654
+ best_epoch = None
655
+
656
+ for epoch in range(1, self.epochs + 1):
657
+ epoch_start_ts = time.time()
658
+ self.gnn.train()
659
+ optimizer.zero_grad()
660
+ with autocast(enabled=(self.device.type == 'cuda')):
661
+ if self.data_parallel_enabled:
662
+ y_pred = self.gnn(X_train_tensor)
663
+ else:
664
+ y_pred = self.gnn(X_train_tensor, adj_train)
665
+ loss = self._compute_weighted_loss(
666
+ y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
667
+ scaler.scale(loss).backward()
668
+ scaler.unscale_(optimizer)
669
+ clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
670
+ scaler.step(optimizer)
671
+ scaler.update()
672
+
673
+ val_loss = None
674
+ if has_val:
675
+ self.gnn.eval()
676
+ if self.data_parallel_enabled and adj_val is not None:
677
+ self._set_adj_buffer(adj_val)
678
+ with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
679
+ if self.data_parallel_enabled:
680
+ y_val_pred = self.gnn(X_val_tensor)
681
+ else:
682
+ y_val_pred = self.gnn(X_val_tensor, adj_val)
683
+ val_loss = self._compute_weighted_loss(
684
+ y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
685
+ if self.data_parallel_enabled:
686
+ # Restore training adjacency.
687
+ self._set_adj_buffer(adj_train)
688
+
689
+ is_best = val_loss is not None and val_loss < best_loss
690
+ best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
691
+ val_loss, best_loss, best_state, patience_counter, base_gnn,
692
+ ignore_keys=["adj_buffer"])
693
+ if is_best:
694
+ best_epoch = epoch
695
+
696
+ prune_now = False
697
+ if trial is not None:
698
+ trial.report(val_loss, epoch)
699
+ if trial.should_prune():
700
+ prune_now = True
701
+
702
+ if dist.is_initialized():
703
+ flag = torch.tensor(
704
+ [1 if prune_now else 0],
705
+ device=self.device,
706
+ dtype=torch.int32,
707
+ )
708
+ dist.broadcast(flag, src=0)
709
+ prune_now = bool(flag.item())
710
+
711
+ if prune_now:
712
+ raise optuna.TrialPruned()
713
+ if stop_training:
714
+ break
715
+
716
+ should_log = (not dist.is_initialized()
717
+ or DistributedUtils.is_main_process())
718
+ if should_log:
719
+ elapsed = int(time.time() - epoch_start_ts)
720
+ if val_loss is None:
721
+ _log(
722
+ f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
723
+ flush=True,
724
+ )
725
+ else:
726
+ _log(
727
+ f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
728
+ f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
729
+ flush=True,
730
+ )
731
+
732
+ if best_state is not None:
733
+ base_gnn.load_state_dict(best_state, strict=False)
734
+ self.best_epoch = int(best_epoch or self.epochs)
735
+
736
    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Predict on ``X``; transparently retries on CPU after MPS failures."""
        return self._run_with_mps_fallback(self._predict_impl, X)
738
+
739
+ def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
740
+ self.gnn.eval()
741
+ X_tensor, _, _ = self._tensorize_split(
742
+ X, None, None, allow_none=False)
743
+ adj = self._build_graph_from_df(X, X_tensor)
744
+ if self.data_parallel_enabled:
745
+ self._set_adj_buffer(adj)
746
+ inference_cm = getattr(torch, "inference_mode", torch.no_grad)
747
+ with inference_cm():
748
+ if self.data_parallel_enabled:
749
+ y_pred = self.gnn(X_tensor).cpu().numpy()
750
+ else:
751
+ y_pred = self.gnn(X_tensor, adj).cpu().numpy()
752
+ if self.task_type == 'classification':
753
+ y_pred = 1 / (1 + np.exp(-y_pred))
754
+ else:
755
+ y_pred = np.clip(y_pred, 1e-6, None)
756
+ return y_pred.ravel()
757
+
758
    def encode(self, X: pd.DataFrame) -> np.ndarray:
        """Return node embeddings for ``X``; retries on CPU after MPS failures."""
        return self._run_with_mps_fallback(self._encode_impl, X)
760
+
761
+ def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
762
+ """Return per-sample node embeddings (hidden representations)."""
763
+ base = self._unwrap_gnn()
764
+ base.eval()
765
+ X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
766
+ adj = self._build_graph_from_df(X, X_tensor)
767
+ if self.data_parallel_enabled:
768
+ self._set_adj_buffer(adj)
769
+ inference_cm = getattr(torch, "inference_mode", torch.no_grad)
770
+ with inference_cm():
771
+ h = X_tensor
772
+ layers = getattr(base, "layers", None)
773
+ if layers is None:
774
+ raise RuntimeError("GNN base module does not expose layers.")
775
+ for layer in layers:
776
+ h = layer(h, adj)
777
+ h = _adj_mm(adj, h)
778
+ return h.detach().cpu().numpy()
779
+
780
+ def set_params(self, params: Dict[str, Any]):
781
+ for key, value in params.items():
782
+ if hasattr(self, key):
783
+ setattr(self, key, value)
784
+ else:
785
+ raise ValueError(f"Parameter {key} not found in GNN model.")
786
+ # Rebuild the backbone after structural parameter changes.
787
+ self.gnn = SimpleGNN(
788
+ input_dim=self.input_dim,
789
+ hidden_dim=self.hidden_dim,
790
+ num_layers=self.num_layers,
791
+ dropout=self.dropout,
792
+ task_type=self.task_type
793
+ ).to(self.device)
794
+ return self