ins-pricing 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
ins_pricing/modelling/core/bayesopt/models/model_gnn.py (new file)
@@ -0,0 +1,675 @@
+from __future__ import annotations
+
+import hashlib
+import os
+import time
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import optuna  # required at runtime for optuna.TrialPruned and the `trial` argument of fit()
+import pandas as pd
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from sklearn.neighbors import NearestNeighbors
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+
+from ..utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
+
+try:
+    from torch_geometric.nn import knn_graph
+    from torch_geometric.utils import add_self_loops, to_undirected
+    _PYG_AVAILABLE = True
+except Exception:
+    knn_graph = None  # type: ignore
+    add_self_loops = None  # type: ignore
+    to_undirected = None  # type: ignore
+    _PYG_AVAILABLE = False
+
+try:
+    import pynndescent
+    _PYNN_AVAILABLE = True
+except Exception:
+    pynndescent = None  # type: ignore
+    _PYNN_AVAILABLE = False
+
+_GNN_MPS_WARNED = False
+
+
+# =============================================================================
+# Simplified GNN implementation.
+# =============================================================================
+
+
+class SimpleGraphLayer(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
+        super().__init__()
+        self.linear = nn.Linear(in_dim, out_dim)
+        self.activation = nn.ReLU(inplace=True)
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
+        # Message passing with normalized sparse adjacency: A_hat * X * W.
+        h = torch.sparse.mm(adj, x)
+        h = self.linear(h)
+        h = self.activation(h)
+        return self.dropout(h)
+
+
+class SimpleGNN(nn.Module):
+    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
+                 dropout: float = 0.1, task_type: str = 'regression'):
+        super().__init__()
+        layers = []
+        dim_in = input_dim
+        for _ in range(max(1, num_layers)):
+            layers.append(SimpleGraphLayer(
+                dim_in, hidden_dim, dropout=dropout))
+            dim_in = hidden_dim
+        self.layers = nn.ModuleList(layers)
+        self.output = nn.Linear(hidden_dim, 1)
+        if task_type == 'classification':
+            self.output_act = nn.Identity()
+        else:
+            self.output_act = nn.Softplus()
+        self.task_type = task_type
+        # Keep adjacency as a buffer for DataParallel copies.
+        self.register_buffer("adj_buffer", torch.empty(0))
+
+    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
+        adj_used = adj if adj is not None else getattr(
+            self, "adj_buffer", None)
+        if adj_used is None or adj_used.numel() == 0:
+            raise RuntimeError("Adjacency is not set for GNN forward.")
+        h = x
+        for layer in self.layers:
+            h = layer(h, adj_used)
+        h = torch.sparse.mm(adj_used, h)
+        out = self.output(h)
+        return self.output_act(out)
+
+
+class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
+    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
+                 num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
+                 learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
+                 task_type: str = 'regression', tweedie_power: float = 1.5,
+                 weight_decay: float = 0.0,
+                 use_data_parallel: bool = False, use_ddp: bool = False,
+                 use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
+                 graph_cache_path: Optional[str] = None,
+                 max_gpu_knn_nodes: Optional[int] = None,
+                 knn_gpu_mem_ratio: float = 0.9,
+                 knn_gpu_mem_overhead: float = 2.0,
+                 knn_cpu_jobs: Optional[int] = -1) -> None:
+        super().__init__()
+        self.model_nme = model_nme
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.k_neighbors = max(1, k_neighbors)
+        self.dropout = dropout
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.epochs = epochs
+        self.patience = patience
+        self.task_type = task_type
+        self.use_approx_knn = use_approx_knn
+        self.approx_knn_threshold = approx_knn_threshold
+        self.graph_cache_path = Path(
+            graph_cache_path) if graph_cache_path else None
+        self.max_gpu_knn_nodes = max_gpu_knn_nodes
+        self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
+        self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
+        self.knn_cpu_jobs = knn_cpu_jobs
+        self._knn_warning_emitted = False
+        self._adj_cache_meta: Optional[Dict[str, Any]] = None
+        self._adj_cache_key: Optional[Tuple[Any, ...]] = None
+        self._adj_cache_tensor: Optional[torch.Tensor] = None
+
+        if self.task_type == 'classification':
+            self.tw_power = None
+        elif 'f' in self.model_nme:
+            self.tw_power = 1.0
+        elif 's' in self.model_nme:
+            self.tw_power = 2.0
+        else:
+            self.tw_power = tweedie_power
+
+        self.ddp_enabled = False
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.data_parallel_enabled = False
+        self._ddp_disabled = False
+
+        if use_ddp:
+            world_size = int(os.environ.get("WORLD_SIZE", "1"))
+            if world_size > 1:
+                print(
+                    "[GNN] DDP training is not supported; falling back to single process.",
+                    flush=True,
+                )
+                self._ddp_disabled = True
+                use_ddp = False
+
+        # DDP only works with CUDA; fall back to single process if init fails.
+        if use_ddp and torch.cuda.is_available():
+            ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
+            if ddp_ok:
+                self.ddp_enabled = True
+                self.local_rank = local_rank
+                self.device = torch.device(f'cuda:{local_rank}')
+            else:
+                self.device = torch.device('cuda')
+        elif torch.cuda.is_available():
+            if self._ddp_disabled:
+                self.device = torch.device(f'cuda:{self.local_rank}')
+            else:
+                self.device = torch.device('cuda')
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('cpu')
+            global _GNN_MPS_WARNED
+            if not _GNN_MPS_WARNED:
+                print(
+                    "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
+                    flush=True,
+                )
+                _GNN_MPS_WARNED = True
+        else:
+            self.device = torch.device('cpu')
+        self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE
+
+        self.gnn = SimpleGNN(
+            input_dim=self.input_dim,
+            hidden_dim=self.hidden_dim,
+            num_layers=self.num_layers,
+            dropout=self.dropout,
+            task_type=self.task_type
+        ).to(self.device)
+
+        # DataParallel copies the full graph to each GPU and splits features; good for medium graphs.
+        if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
+            self.data_parallel_enabled = True
+            self.gnn = nn.DataParallel(
+                self.gnn, device_ids=list(range(torch.cuda.device_count())))
+            self.device = torch.device('cuda')
+
+        if self.ddp_enabled:
+            self.gnn = DDP(
+                self.gnn,
+                device_ids=[self.local_rank],
+                output_device=self.local_rank,
+                find_unused_parameters=False
+            )
+
+    @staticmethod
+    def _validate_vector(arr, name: str, n_rows: int) -> None:
+        if arr is None:
+            return
+        if isinstance(arr, pd.DataFrame):
+            if arr.shape[1] != 1:
+                raise ValueError(f"{name} must be 1d (single column).")
+            length = len(arr)
+        else:
+            arr_np = np.asarray(arr)
+            if arr_np.ndim == 0:
+                raise ValueError(f"{name} must be 1d.")
+            if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
+                raise ValueError(f"{name} must be 1d or Nx1.")
+            length = arr_np.shape[0]
+        if length != n_rows:
+            raise ValueError(
+                f"{name} length {length} does not match X length {n_rows}."
+            )
+
+    def _unwrap_gnn(self) -> nn.Module:
+        if isinstance(self.gnn, (DDP, nn.DataParallel)):
+            return self.gnn.module
+        return self.gnn
+
+    def _set_adj_buffer(self, adj: torch.Tensor) -> None:
+        base = self._unwrap_gnn()
+        if hasattr(base, "adj_buffer"):
+            base.adj_buffer = adj
+        else:
+            base.register_buffer("adj_buffer", adj)
+
+    def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
+        row_hash = pd.util.hash_pandas_object(X_df, index=False).values
+        idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
+        col_sig = ",".join(map(str, X_df.columns))
+        hasher = hashlib.sha256()
+        hasher.update(row_hash.tobytes())
+        hasher.update(idx_hash.tobytes())
+        hasher.update(col_sig.encode("utf-8", errors="ignore"))
+        knn_config = {
+            "k_neighbors": int(self.k_neighbors),
+            "use_approx_knn": bool(self.use_approx_knn),
+            "approx_knn_threshold": int(self.approx_knn_threshold),
+            "use_pyg_knn": bool(self.use_pyg_knn),
+            "pynndescent_available": bool(_PYNN_AVAILABLE),
+            "max_gpu_knn_nodes": (
+                None if self.max_gpu_knn_nodes is None else int(self.max_gpu_knn_nodes)
+            ),
+            "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
+            "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
+        }
+        return {
+            "n_samples": int(X_df.shape[0]),
+            "n_features": int(X_df.shape[1]),
+            "hash": hasher.hexdigest(),
+            "knn_config": knn_config,
+        }
+
+    def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
+        return (
+            id(X_df),
+            id(getattr(X_df, "_mgr", None)),
+            id(X_df.index),
+            X_df.shape,
+            tuple(map(str, X_df.columns)),
+            X_df.attrs.get("graph_cache_key"),
+        )
+
+    def invalidate_graph_cache(self) -> None:
+        self._adj_cache_meta = None
+        self._adj_cache_key = None
+        self._adj_cache_tensor = None
+
+    def _load_cached_adj(self,
+                         X_df: pd.DataFrame,
+                         meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
+        if self.graph_cache_path and self.graph_cache_path.exists():
+            if meta_expected is None:
+                meta_expected = self._graph_cache_meta(X_df)
+            try:
+                payload = torch.load(self.graph_cache_path,
+                                     map_location=self.device)
+            except Exception as exc:
+                print(
+                    f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
+                return None
+            if isinstance(payload, dict) and "adj" in payload:
+                meta_cached = payload.get("meta")
+                if meta_cached == meta_expected:
+                    return payload["adj"].to(self.device)
+                print(
+                    f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
+                return None
+            if isinstance(payload, torch.Tensor):
+                print(
+                    f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
+                return None
+            print(
+                f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
+        return None
+
+    def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
+        n_samples = X_np.shape[0]
+        k = min(self.k_neighbors, max(1, n_samples - 1))
+        n_neighbors = min(k + 1, n_samples)
+        use_approx = (self.use_approx_knn or n_samples >=
+                      self.approx_knn_threshold) and _PYNN_AVAILABLE
+        indices = None
+        if use_approx:
+            try:
+                nn_index = pynndescent.NNDescent(
+                    X_np,
+                    n_neighbors=n_neighbors,
+                    random_state=0
+                )
+                indices, _ = nn_index.neighbor_graph
+            except Exception as exc:
+                print(
+                    f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
+                use_approx = False
+
+        if indices is None:
+            nbrs = NearestNeighbors(
+                n_neighbors=n_neighbors,
+                algorithm="auto",
+                n_jobs=self.knn_cpu_jobs,
+            )
+            nbrs.fit(X_np)
+            _, indices = nbrs.kneighbors(X_np)
+
+        indices = np.asarray(indices)
+        rows = np.repeat(np.arange(n_samples), n_neighbors).astype(
+            np.int64, copy=False)
+        cols = indices.reshape(-1).astype(np.int64, copy=False)
+        mask = rows != cols
+        rows = rows[mask]
+        cols = cols[mask]
+        rows_base = rows
+        cols_base = cols
+        self_loops = np.arange(n_samples, dtype=np.int64)
+        rows = np.concatenate([rows_base, cols_base, self_loops])
+        cols = np.concatenate([cols_base, rows_base, self_loops])
+
+        edge_index_np = np.stack([rows, cols], axis=0)
+        edge_index = torch.as_tensor(edge_index_np, device=self.device)
+        return edge_index
+
+    def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
+        if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
+            # Defensive: check use_pyg_knn before calling.
+            raise RuntimeError(
+                "GPU graph builder requested but PyG is unavailable.")
+
+        n_samples = X_tensor.size(0)
+        k = min(self.k_neighbors, max(1, n_samples - 1))
+
+        # knn_graph runs on GPU to avoid CPU graph construction bottlenecks.
+        edge_index = knn_graph(
+            X_tensor,
+            k=k,
+            loop=False
+        )
+        edge_index = to_undirected(edge_index, num_nodes=n_samples)
+        edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
+        return edge_index
+
+    def _log_knn_fallback(self, reason: str) -> None:
+        if self._knn_warning_emitted:
+            return
+        if (not self.ddp_enabled) or self.local_rank == 0:
+            print(f"[GNN] Falling back to CPU kNN builder: {reason}")
+        self._knn_warning_emitted = True
+
+    def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
+        if not self.use_pyg_knn:
+            return False
+
+        reason = None
+        if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
+            reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
+        elif self.device.type == 'cuda' and torch.cuda.is_available():
+            try:
+                device_index = self.device.index
+                if device_index is None:
+                    device_index = torch.cuda.current_device()
+                free_mem, total_mem = torch.cuda.mem_get_info(device_index)
+                feature_bytes = X_tensor.element_size() * X_tensor.nelement()
+                required = int(feature_bytes * self.knn_gpu_mem_overhead)
+                budget = int(free_mem * self.knn_gpu_mem_ratio)
+                if required > budget:
+                    required_gb = required / (1024 ** 3)
+                    budget_gb = budget / (1024 ** 3)
+                    reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
+                              f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
+            except Exception:
+                # On older versions or some environments, mem_get_info may be unavailable; default to trying GPU.
+                reason = None
+
+        if reason:
+            self._log_knn_fallback(reason)
+            return False
+        return True
+
+    def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
+        values = torch.ones(edge_index.shape[1], device=self.device)
+        adj = torch.sparse_coo_tensor(
+            edge_index.to(self.device), values, (num_nodes, num_nodes))
+        adj = adj.coalesce()
+
+        deg = torch.sparse.sum(adj, dim=1).to_dense()
+        deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
+        row, col = adj.indices()
+        norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
+        adj_norm = torch.sparse_coo_tensor(
+            adj.indices(), norm_values, size=adj.shape)
+        return adj_norm
+
+    def _tensorize_split(self, X, y, w, allow_none: bool = False):
+        if X is None and allow_none:
+            return None, None, None
+        if not isinstance(X, pd.DataFrame):
+            raise ValueError("X must be a pandas DataFrame for GNN.")
+        n_rows = len(X)
+        if y is not None:
+            self._validate_vector(y, "y", n_rows)
+        if w is not None:
+            self._validate_vector(w, "w", n_rows)
+        X_np = X.to_numpy(dtype=np.float32, copy=False) if hasattr(
+            X, "to_numpy") else np.asarray(X, dtype=np.float32)
+        X_tensor = torch.as_tensor(
+            X_np, dtype=torch.float32, device=self.device)
+        if y is None:
+            y_tensor = None
+        else:
+            y_np = y.to_numpy(dtype=np.float32, copy=False) if hasattr(
+                y, "to_numpy") else np.asarray(y, dtype=np.float32)
+            y_tensor = torch.as_tensor(
+                y_np, dtype=torch.float32, device=self.device).view(-1, 1)
+        if w is None:
+            w_tensor = torch.ones(
+                (len(X), 1), dtype=torch.float32, device=self.device)
+        else:
+            w_np = w.to_numpy(dtype=np.float32, copy=False) if hasattr(
+                w, "to_numpy") else np.asarray(w, dtype=np.float32)
+            w_tensor = torch.as_tensor(
+                w_np, dtype=torch.float32, device=self.device).view(-1, 1)
+        return X_tensor, y_tensor, w_tensor
+
+    def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if not isinstance(X_df, pd.DataFrame):
+            raise ValueError("X must be a pandas DataFrame for graph building.")
+        meta_expected = None
+        cache_key = None
+        if self.graph_cache_path:
+            meta_expected = self._graph_cache_meta(X_df)
+            if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
+                cached = self._adj_cache_tensor
+                if cached.device != self.device:
+                    cached = cached.to(self.device)
+                    self._adj_cache_tensor = cached
+                return cached
+        else:
+            cache_key = self._graph_cache_key(X_df)
+            if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
+                cached = self._adj_cache_tensor
+                if cached.device != self.device:
+                    cached = cached.to(self.device)
+                    self._adj_cache_tensor = cached
+                return cached
+        X_np = None
+        if X_tensor is None:
+            X_np = X_df.to_numpy(dtype=np.float32, copy=False)
+            X_tensor = torch.as_tensor(
+                X_np, dtype=torch.float32, device=self.device)
+        if self.graph_cache_path:
+            cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
+            if cached is not None:
+                self._adj_cache_meta = meta_expected
+                self._adj_cache_key = None
+                self._adj_cache_tensor = cached
+                return cached
+        use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
+        if use_gpu_knn:
+            edge_index = self._build_edge_index_gpu(X_tensor)
+        else:
+            if X_np is None:
+                X_np = X_df.to_numpy(dtype=np.float32, copy=False)
+            edge_index = self._build_edge_index_cpu(X_np)
+        adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
+        if self.graph_cache_path:
+            try:
+                IOUtils.ensure_parent_dir(str(self.graph_cache_path))
+                torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
+            except Exception as exc:
+                print(
+                    f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
+            self._adj_cache_meta = meta_expected
+            self._adj_cache_key = None
+        else:
+            self._adj_cache_meta = None
+            self._adj_cache_key = cache_key
+        self._adj_cache_tensor = adj_norm
+        return adj_norm
+
+    def fit(self, X_train, y_train, w_train=None,
+            X_val=None, y_val=None, w_val=None,
+            trial: Optional[optuna.trial.Trial] = None):
+
+        X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
+            X_train, y_train, w_train, allow_none=False)
+        has_val = X_val is not None and y_val is not None
+        if has_val:
+            X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
+                X_val, y_val, w_val, allow_none=False)
+        else:
+            X_val_tensor = y_val_tensor = w_val_tensor = None
+
+        adj_train = self._build_graph_from_df(X_train, X_train_tensor)
+        adj_val = self._build_graph_from_df(
+            X_val, X_val_tensor) if has_val else None
+        # DataParallel needs adjacency cached on the model to avoid scatter.
+        self._set_adj_buffer(adj_train)
+
+        base_gnn = self._unwrap_gnn()
+        optimizer = torch.optim.Adam(
+            base_gnn.parameters(),
+            lr=self.learning_rate,
+            weight_decay=float(getattr(self, "weight_decay", 0.0)),
+        )
+        scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+        best_loss = float('inf')
+        best_state = None
+        patience_counter = 0
+        best_epoch = None
+
+        for epoch in range(1, self.epochs + 1):
+            epoch_start_ts = time.time()
+            self.gnn.train()
+            optimizer.zero_grad()
+            with autocast(enabled=(self.device.type == 'cuda')):
+                if self.data_parallel_enabled:
+                    y_pred = self.gnn(X_train_tensor)
+                else:
+                    y_pred = self.gnn(X_train_tensor, adj_train)
+                loss = self._compute_weighted_loss(
+                    y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
+            scaler.step(optimizer)
+            scaler.update()
+
+            val_loss = None
+            if has_val:
+                self.gnn.eval()
+                if self.data_parallel_enabled and adj_val is not None:
+                    self._set_adj_buffer(adj_val)
+                with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
+                    if self.data_parallel_enabled:
+                        y_val_pred = self.gnn(X_val_tensor)
+                    else:
+                        y_val_pred = self.gnn(X_val_tensor, adj_val)
+                    val_loss = self._compute_weighted_loss(
+                        y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
+                if self.data_parallel_enabled:
+                    # Restore training adjacency.
+                    self._set_adj_buffer(adj_train)
+
+            is_best = val_loss is not None and val_loss < best_loss
+            best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                val_loss, best_loss, best_state, patience_counter, base_gnn,
+                ignore_keys=["adj_buffer"])
+            if is_best:
+                best_epoch = epoch
+
+            prune_now = False
+            if trial is not None:
+                trial.report(val_loss, epoch)
+                if trial.should_prune():
+                    prune_now = True
+
+            if dist.is_initialized():
+                flag = torch.tensor(
+                    [1 if prune_now else 0],
+                    device=self.device,
+                    dtype=torch.int32,
+                )
+                dist.broadcast(flag, src=0)
+                prune_now = bool(flag.item())
+
+            if prune_now:
+                raise optuna.TrialPruned()
+            if stop_training:
+                break
+
+            should_log = (not dist.is_initialized()
+                          or DistributedUtils.is_main_process())
+            if should_log:
+                elapsed = int(time.time() - epoch_start_ts)
+                if val_loss is None:
+                    print(
+                        f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
+                        flush=True,
+                    )
+                else:
+                    print(
+                        f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
+                        f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
+                        flush=True,
+                    )
+
+        if best_state is not None:
+            base_gnn.load_state_dict(best_state, strict=False)
+        self.best_epoch = int(best_epoch or self.epochs)
+
+    def predict(self, X: pd.DataFrame) -> np.ndarray:
+        self.gnn.eval()
+        X_tensor, _, _ = self._tensorize_split(
+            X, None, None, allow_none=False)
+        adj = self._build_graph_from_df(X, X_tensor)
+        if self.data_parallel_enabled:
+            self._set_adj_buffer(adj)
+        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+        with inference_cm():
+            if self.data_parallel_enabled:
+                y_pred = self.gnn(X_tensor).cpu().numpy()
+            else:
+                y_pred = self.gnn(X_tensor, adj).cpu().numpy()
+        if self.task_type == 'classification':
+            y_pred = 1 / (1 + np.exp(-y_pred))
+        else:
+            y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.ravel()
+
+    def encode(self, X: pd.DataFrame) -> np.ndarray:
+        """Return per-sample node embeddings (hidden representations)."""
+        base = self._unwrap_gnn()
+        base.eval()
+        X_tensor, _, _ = self._tensorize_split(X, None, None, allow_none=False)
+        adj = self._build_graph_from_df(X, X_tensor)
+        if self.data_parallel_enabled:
+            self._set_adj_buffer(adj)
+        inference_cm = getattr(torch, "inference_mode", torch.no_grad)
+        with inference_cm():
+            h = X_tensor
+            layers = getattr(base, "layers", None)
+            if layers is None:
+                raise RuntimeError("GNN base module does not expose layers.")
+            for layer in layers:
+                h = layer(h, adj)
+            h = torch.sparse.mm(adj, h)
+        return h.detach().cpu().numpy()
+
+    def set_params(self, params: Dict[str, Any]):
+        for key, value in params.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise ValueError(f"Parameter {key} not found in GNN model.")
+        # Rebuild the backbone after structural parameter changes.
+        self.gnn = SimpleGNN(
+            input_dim=self.input_dim,
+            hidden_dim=self.hidden_dim,
+            num_layers=self.num_layers,
+            dropout=self.dropout,
+            task_type=self.task_type
+        ).to(self.device)
+        return self
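
The new module wires kNN graph construction, the symmetrically normalized adjacency D^-1/2 (A + I) D^-1/2 (self-loops are added while building the edge index), and the TorchTrainerMixin training loop behind a small sklearn-style surface (fit / predict / encode / set_params). As a rough illustration of that surface, here is a minimal, hypothetical usage sketch: the import path mirrors the new file location listed above, while the synthetic data, column names, and hyperparameter values are assumptions, not anything taken from the package.

# Hypothetical sketch only; data and settings are illustrative.
import numpy as np
import pandas as pd

from ins_pricing.modelling.core.bayesopt.models.model_gnn import GraphNeuralNetSklearn

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(500, 8)), columns=[f"x{i}" for i in range(8)])
y = pd.Series(rng.poisson(lam=0.2, size=500).astype("float32"), name="claim_count")
w = pd.Series(np.ones(500, dtype="float32"), name="exposure")

model = GraphNeuralNetSklearn(
    model_nme="demo_f",        # an 'f' in the name selects the Poisson-style power (tw_power=1.0)
    input_dim=X.shape[1],
    hidden_dim=32,
    num_layers=2,
    k_neighbors=8,
    epochs=20,
    patience=5,
    graph_cache_path="cache/demo_knn_graph.pt",  # optional on-disk adjacency cache
)
model.fit(X, y, w_train=w)

preds = model.predict(X)       # non-negative predictions (Softplus head, clipped at 1e-6)
embeddings = model.encode(X)   # hidden node representations, shape (500, 32)

With graph_cache_path set, the normalized adjacency is keyed on a hash of the rows, index, columns, and kNN settings, so repeated fits on the same frame reuse the cached graph instead of rebuilding it.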