ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +9 -6
- ins_pricing/__init__.py +3 -11
- ins_pricing/cli/BayesOpt_entry.py +24 -0
- ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
- ins_pricing/cli/Explain_Run.py +25 -0
- ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
- ins_pricing/cli/Pricing_Run.py +25 -0
- ins_pricing/cli/__init__.py +1 -0
- ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
- ins_pricing/cli/utils/__init__.py +1 -0
- ins_pricing/cli/utils/cli_common.py +320 -0
- ins_pricing/cli/utils/cli_config.py +375 -0
- ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
- {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
- ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
- ins_pricing/docs/modelling/README.md +34 -0
- ins_pricing/modelling/__init__.py +57 -6
- ins_pricing/modelling/core/__init__.py +1 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
- ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
- ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
- ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
- ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
- ins_pricing/modelling/core/evaluation.py +115 -0
- ins_pricing/production/__init__.py +4 -0
- ins_pricing/production/preprocess.py +71 -0
- ins_pricing/setup.py +10 -5
- {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
- ins_pricing-0.2.0.dist-info/RECORD +125 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
- ins_pricing/modelling/BayesOpt_entry.py +0 -633
- ins_pricing/modelling/Explain_Run.py +0 -36
- ins_pricing/modelling/Pricing_Run.py +0 -36
- ins_pricing/modelling/README.md +0 -33
- ins_pricing/modelling/bayesopt/models.py +0 -2196
- ins_pricing/modelling/bayesopt/trainers.py +0 -2446
- ins_pricing/modelling/cli_common.py +0 -136
- ins_pricing/modelling/tests/test_plotting.py +0 -63
- ins_pricing/modelling/watchdog_run.py +0 -211
- ins_pricing-0.1.11.dist-info/RECORD +0 -169
- ins_pricing_gemini/__init__.py +0 -23
- ins_pricing_gemini/governance/__init__.py +0 -20
- ins_pricing_gemini/governance/approval.py +0 -93
- ins_pricing_gemini/governance/audit.py +0 -37
- ins_pricing_gemini/governance/registry.py +0 -99
- ins_pricing_gemini/governance/release.py +0 -159
- ins_pricing_gemini/modelling/Explain_Run.py +0 -36
- ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
- ins_pricing_gemini/modelling/__init__.py +0 -151
- ins_pricing_gemini/modelling/cli_common.py +0 -141
- ins_pricing_gemini/modelling/config.py +0 -249
- ins_pricing_gemini/modelling/config_preprocess.py +0 -254
- ins_pricing_gemini/modelling/core.py +0 -741
- ins_pricing_gemini/modelling/data_container.py +0 -42
- ins_pricing_gemini/modelling/explain/__init__.py +0 -55
- ins_pricing_gemini/modelling/explain/gradients.py +0 -334
- ins_pricing_gemini/modelling/explain/metrics.py +0 -176
- ins_pricing_gemini/modelling/explain/permutation.py +0 -155
- ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
- ins_pricing_gemini/modelling/features.py +0 -215
- ins_pricing_gemini/modelling/model_manager.py +0 -148
- ins_pricing_gemini/modelling/model_plotting.py +0 -463
- ins_pricing_gemini/modelling/models.py +0 -2203
- ins_pricing_gemini/modelling/notebook_utils.py +0 -294
- ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
- ins_pricing_gemini/modelling/plotting/common.py +0 -63
- ins_pricing_gemini/modelling/plotting/curves.py +0 -572
- ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
- ins_pricing_gemini/modelling/plotting/geo.py +0 -362
- ins_pricing_gemini/modelling/plotting/importance.py +0 -121
- ins_pricing_gemini/modelling/run_logging.py +0 -133
- ins_pricing_gemini/modelling/tests/conftest.py +0 -8
- ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
- ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
- ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
- ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
- ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
- ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
- ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
- ins_pricing_gemini/modelling/trainers.py +0 -2447
- ins_pricing_gemini/modelling/utils.py +0 -1020
- ins_pricing_gemini/pricing/__init__.py +0 -27
- ins_pricing_gemini/pricing/calibration.py +0 -39
- ins_pricing_gemini/pricing/data_quality.py +0 -117
- ins_pricing_gemini/pricing/exposure.py +0 -85
- ins_pricing_gemini/pricing/factors.py +0 -91
- ins_pricing_gemini/pricing/monitoring.py +0 -99
- ins_pricing_gemini/pricing/rate_table.py +0 -78
- ins_pricing_gemini/production/__init__.py +0 -21
- ins_pricing_gemini/production/drift.py +0 -30
- ins_pricing_gemini/production/monitoring.py +0 -143
- ins_pricing_gemini/production/scoring.py +0 -40
- ins_pricing_gemini/reporting/__init__.py +0 -11
- ins_pricing_gemini/reporting/report_builder.py +0 -72
- ins_pricing_gemini/reporting/scheduler.py +0 -45
- ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
- ins_pricing_gemini/scripts/Explain_entry.py +0 -545
- ins_pricing_gemini/scripts/__init__.py +0 -1
- ins_pricing_gemini/scripts/train.py +0 -568
- ins_pricing_gemini/setup.py +0 -55
- ins_pricing_gemini/smoke_test.py +0 -28
- /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
- /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
- /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,675 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import torch
|
|
12
|
+
import torch.distributed as dist
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from sklearn.neighbors import NearestNeighbors
|
|
15
|
+
from torch.cuda.amp import autocast, GradScaler
|
|
16
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
17
|
+
from torch.nn.utils import clip_grad_norm_
|
|
18
|
+
|
|
19
|
+
from ..utils import DistributedUtils, EPS, IOUtils, TorchTrainerMixin
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
from torch_geometric.nn import knn_graph
|
|
23
|
+
from torch_geometric.utils import add_self_loops, to_undirected
|
|
24
|
+
_PYG_AVAILABLE = True
|
|
25
|
+
except Exception:
|
|
26
|
+
knn_graph = None # type: ignore
|
|
27
|
+
add_self_loops = None # type: ignore
|
|
28
|
+
to_undirected = None # type: ignore
|
|
29
|
+
_PYG_AVAILABLE = False
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
import pynndescent
|
|
33
|
+
_PYNN_AVAILABLE = True
|
|
34
|
+
except Exception:
|
|
35
|
+
pynndescent = None # type: ignore
|
|
36
|
+
_PYNN_AVAILABLE = False
|
|
37
|
+
|
|
38
|
+
_GNN_MPS_WARNED = False
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# =============================================================================
|
|
42
|
+
# Simplified GNN implementation.
|
|
43
|
+
# =============================================================================
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SimpleGraphLayer(nn.Module):
    """One graph-convolution step: propagate, project, ReLU, then dropout.

    Args:
        in_dim: width of the incoming node features.
        out_dim: width of the outgoing node features.
        dropout: dropout probability; a value ``<= 0`` disables dropout
            (an ``nn.Identity`` is used instead).
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(relu(linear(adj @ x)))``.

        ``adj`` is multiplied with ``torch.sparse.mm``, so it is expected to be
        a sparse ``(N, N)`` matrix (typically the normalized adjacency).
        """
        propagated = torch.sparse.mm(adj, x)
        return self.dropout(self.activation(self.linear(propagated)))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class SimpleGNN(nn.Module):
    """Scalar-output GNN over a fixed sparse adjacency.

    Architecture: ``max(1, num_layers)`` ``SimpleGraphLayer`` blocks, one
    further propagation ``adj @ h``, then a ``Linear(hidden_dim, 1)`` head.
    The head activation is the identity when ``task_type == 'classification'``
    (raw scores) and Softplus otherwise (positive output).

    The adjacency can be passed to ``forward`` or stored beforehand in the
    ``adj_buffer`` buffer.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        super().__init__()
        depth = max(1, num_layers)
        # Only the first layer consumes the raw feature width.
        widths_in = [input_dim] + [hidden_dim] * (depth - 1)
        self.layers = nn.ModuleList(
            [SimpleGraphLayer(width, hidden_dim, dropout=dropout)
             for width in widths_in]
        )
        self.output = nn.Linear(hidden_dim, 1)
        self.output_act = (
            nn.Identity() if task_type == 'classification' else nn.Softplus()
        )
        self.task_type = task_type
        # Adjacency is kept as a buffer so module replicas/copies carry it along.
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return one output per node (shape ``(num_nodes, 1)``).

        Args:
            x: node feature matrix.
            adj: sparse adjacency; when ``None`` the ``adj_buffer`` is used.

        Raises:
            RuntimeError: if no non-empty adjacency is available.
        """
        adj_used = getattr(self, "adj_buffer", None) if adj is None else adj
        if adj_used is None or adj_used.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        hidden = x
        for graph_layer in self.layers:
            hidden = graph_layer(hidden, adj_used)
        # One extra hop of propagation before the readout head.
        hidden = torch.sparse.mm(adj_used, hidden)
        return self.output_act(self.output(hidden))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
|
|
95
|
+
def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
             num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
             learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
             task_type: str = 'regression', tweedie_power: float = 1.5,
             weight_decay: float = 0.0,
             use_data_parallel: bool = False, use_ddp: bool = False,
             use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
             graph_cache_path: Optional[str] = None,
             max_gpu_knn_nodes: Optional[int] = None,
             knn_gpu_mem_ratio: float = 0.9,
             knn_gpu_mem_overhead: float = 2.0,
             knn_cpu_jobs: Optional[int] = -1) -> None:
    """Configure the kNN-graph GNN wrapper and place the network on a device.

    Args:
        model_nme: model name. For non-classification tasks, a name containing
            ``'f'`` forces Tweedie power 1.0, otherwise one containing ``'s'``
            forces 2.0. NOTE(review): presumably frequency/severity naming.
        input_dim: number of input features per node.
        hidden_dim, num_layers, dropout: ``SimpleGNN`` architecture settings.
        k_neighbors: neighbours per node in the kNN graph (clamped to >= 1).
        learning_rate, weight_decay, epochs, patience: optimisation settings.
        task_type: ``'classification'`` (no Tweedie power) or anything else.
        tweedie_power: Tweedie power used when ``model_nme`` does not force one.
        use_data_parallel: accepted for backward compatibility only; see the
            DataParallel note below -- the model always runs on one device.
        use_ddp: request DDP. When ``WORLD_SIZE`` > 1 it is refused and
            training falls back to a single process.
        use_approx_knn, approx_knn_threshold: control use of pynndescent in
            the CPU kNN builder.
        graph_cache_path: optional file used to persist the adjacency.
        max_gpu_knn_nodes: node count above which the GPU kNN builder is skipped.
        knn_gpu_mem_ratio: fraction of free GPU memory the GPU builder may
            use (clamped to [0, 1]).
        knn_gpu_mem_overhead: multiplier on the feature bytes used to estimate
            the GPU builder's memory need (clamped to >= 1).
        knn_cpu_jobs: ``n_jobs`` for the exact CPU neighbour search.
    """
    global _GNN_MPS_WARNED
    super().__init__()
    self.model_nme = model_nme
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers
    self.k_neighbors = max(1, k_neighbors)
    self.dropout = dropout
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.epochs = epochs
    self.patience = patience
    self.task_type = task_type
    self.use_approx_knn = use_approx_knn
    self.approx_knn_threshold = approx_knn_threshold
    self.graph_cache_path = Path(
        graph_cache_path) if graph_cache_path else None
    self.max_gpu_knn_nodes = max_gpu_knn_nodes
    self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
    self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
    self.knn_cpu_jobs = knn_cpu_jobs
    self._knn_warning_emitted = False
    # Single-entry in-memory adjacency cache used by the graph builder.
    self._adj_cache_meta: Optional[Dict[str, Any]] = None
    self._adj_cache_key: Optional[Tuple[Any, ...]] = None
    self._adj_cache_tensor: Optional[torch.Tensor] = None

    if self.task_type == 'classification':
        self.tw_power = None
    elif 'f' in self.model_nme:
        self.tw_power = 1.0
    elif 's' in self.model_nme:
        self.tw_power = 2.0
    else:
        self.tw_power = tweedie_power

    self.ddp_enabled = False
    self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    self.data_parallel_enabled = False
    self._ddp_disabled = False

    if use_ddp:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if world_size > 1:
            print(
                "[GNN] DDP training is not supported; falling back to single process.",
                flush=True,
            )
            self._ddp_disabled = True
            use_ddp = False

    # Older torch builds may not expose torch.backends.mps at all.
    mps_backend = getattr(torch.backends, "mps", None)

    # DDP only works with CUDA; fall back to single process if init fails.
    if use_ddp and torch.cuda.is_available():
        ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
        if ddp_ok:
            self.ddp_enabled = True
            self.local_rank = local_rank
            self.device = torch.device(f'cuda:{local_rank}')
        else:
            self.device = torch.device('cuda')
    elif torch.cuda.is_available():
        if self._ddp_disabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        else:
            self.device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        # Sparse ops are not available on MPS, so run on the CPU instead.
        self.device = torch.device('cpu')
        if not _GNN_MPS_WARNED:
            print(
                "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
                flush=True,
            )
            _GNN_MPS_WARNED = True
    else:
        self.device = torch.device('cpu')
    self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)

    # nn.DataParallel scatters the node-feature rows across GPUs (dim 0) while
    # every replica keeps the full N x N adjacency, so the sparse matmul in
    # SimpleGNN can never line up and the forward pass fails. Refuse it and
    # keep training on a single GPU instead (data_parallel_enabled stays False).
    if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
        print(
            "[GNN] DataParallel is not supported for full-graph training; "
            "using a single GPU.",
            flush=True,
        )

    if self.ddp_enabled:
        self.gnn = DDP(
            self.gnn,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=False
        )
|
|
205
|
+
|
|
206
|
+
@staticmethod
def _validate_vector(arr, name: str, n_rows: int) -> None:
    """Check that ``arr`` is a 1-d (or single-column) vector of length ``n_rows``.

    ``None`` is accepted and ignored.

    Args:
        arr: target/weight values (DataFrame, Series, array-like) or ``None``.
        name: label used in error messages (e.g. ``"y"``).
        n_rows: expected length, i.e. the number of rows of X.

    Raises:
        ValueError: if ``arr`` is a scalar, has more than one column, or its
            length differs from ``n_rows``.
    """
    if arr is None:
        return
    if isinstance(arr, pd.DataFrame):
        if arr.shape[1] != 1:
            raise ValueError(f"{name} must be 1d (single column).")
        n_found = len(arr)
    else:
        values = np.asarray(arr)
        ndim = values.ndim
        if ndim == 0:
            raise ValueError(f"{name} must be 1d.")
        is_column = ndim == 2 and values.shape[1] == 1
        if ndim > 2 or (ndim == 2 and not is_column):
            raise ValueError(f"{name} must be 1d or Nx1.")
        n_found = values.shape[0]
    if n_found != n_rows:
        raise ValueError(
            f"{name} length {n_found} does not match X length {n_rows}."
        )
|
|
225
|
+
|
|
226
|
+
def _unwrap_gnn(self) -> nn.Module:
    """Return the bare GNN module, removing a DDP/DataParallel wrapper if present."""
    wrapped = self.gnn
    is_wrapper = isinstance(wrapped, (DDP, nn.DataParallel))
    return wrapped.module if is_wrapper else wrapped
|
|
230
|
+
|
|
231
|
+
def _set_adj_buffer(self, adj: torch.Tensor) -> None:
    """Store ``adj`` as the ``adj_buffer`` buffer of the unwrapped GNN.

    If the module has no ``adj_buffer`` attribute, a new buffer is registered.
    Otherwise the existing attribute is overwritten.
    """
    target = self._unwrap_gnn()
    if not hasattr(target, "adj_buffer"):
        target.register_buffer("adj_buffer", adj)
        return
    target.adj_buffer = adj
|
|
237
|
+
|
|
238
|
+
def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
    """Fingerprint ``X_df`` together with the kNN settings that shape the graph.

    The SHA-256 digest is fed, in order, with the per-row value hashes, the
    index hashes and the comma-joined column names. The result is compared
    against stored metadata to decide whether a cached adjacency can be reused.

    Returns:
        Dict with ``n_samples``, ``n_features``, ``hash`` (hex digest) and
        ``knn_config`` (graph-construction settings).
    """
    digest = hashlib.sha256()
    for chunk in (
        pd.util.hash_pandas_object(X_df, index=False).values.tobytes(),
        pd.util.hash_pandas_object(X_df.index, index=False).values.tobytes(),
        ",".join(map(str, X_df.columns)).encode("utf-8", errors="ignore"),
    ):
        digest.update(chunk)

    node_cap = self.max_gpu_knn_nodes
    knn_settings = {
        "k_neighbors": int(self.k_neighbors),
        "use_approx_knn": bool(self.use_approx_knn),
        "approx_knn_threshold": int(self.approx_knn_threshold),
        "use_pyg_knn": bool(self.use_pyg_knn),
        "pynndescent_available": bool(_PYNN_AVAILABLE),
        "max_gpu_knn_nodes": int(node_cap) if node_cap is not None else None,
        "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
        "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
    }
    n_rows, n_cols = X_df.shape
    return {
        "n_samples": int(n_rows),
        "n_features": int(n_cols),
        "hash": digest.hexdigest(),
        "knn_config": knn_settings,
    }
|
|
264
|
+
|
|
265
|
+
def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
    """Build a cheap, identity-based key for the in-memory adjacency cache.

    The key combines ``id()`` of the frame, its block manager and its index with
    the shape, the stringified column names and the optional
    ``X_df.attrs["graph_cache_key"]`` tag. CPython can hand an ``id()`` to a new
    object once the old one is garbage collected. A matching key therefore
    only identifies a frame reliably while that frame is still alive.
    """
    block_manager = getattr(X_df, "_mgr", None)
    column_names = tuple(str(col) for col in X_df.columns)
    return (
        id(X_df),
        id(block_manager),
        id(X_df.index),
        X_df.shape,
        column_names,
        X_df.attrs.get("graph_cache_key"),
    )
|
|
274
|
+
|
|
275
|
+
def invalidate_graph_cache(self) -> None:
    """Forget the in-memory adjacency so the next build recomputes or reloads it.

    The on-disk cache at ``graph_cache_path`` (if any) is not touched.
    """
    self._adj_cache_meta = self._adj_cache_key = self._adj_cache_tensor = None
|
|
279
|
+
|
|
280
|
+
def _load_cached_adj(self,
                     X_df: pd.DataFrame,
                     meta_expected: Optional[Dict[str, Any]] = None) -> Optional[torch.Tensor]:
    """Load the adjacency persisted at ``graph_cache_path`` if it is still valid.

    The file must contain a dict with an ``"adj"`` entry whose ``"meta"``
    equals ``meta_expected``. When ``meta_expected`` is ``None``, it is computed
    from ``X_df``.

    Returns:
        The cached adjacency moved to ``self.device``. Returns ``None`` when no
        cache path is set, the file is missing or unreadable, the metadata
        differs, or the payload has an unexpected format (each of the last
        three cases is reported with a ``print``).
    """
    cache_file = self.graph_cache_path
    if not (cache_file and cache_file.exists()):
        return None
    if meta_expected is None:
        meta_expected = self._graph_cache_meta(X_df)
    try:
        payload = torch.load(cache_file, map_location=self.device)
    except Exception as exc:
        print(
            f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
        return None

    if isinstance(payload, dict) and "adj" in payload:
        if payload.get("meta") != meta_expected:
            print(
                f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
            return None
        return payload["adj"].to(self.device)

    if isinstance(payload, torch.Tensor):
        # Legacy payload without metadata cannot be validated.
        print(
            f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
    else:
        print(
            f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
    return None
|
|
307
|
+
|
|
308
|
+
def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
    """Build a symmetric kNN edge index with self loops on the CPU.

    Neighbour search:
    - pynndescent (approximate) is used when it is installed and either
      ``use_approx_knn`` is set or the node count reaches
      ``approx_knn_threshold``.
    - Otherwise, or if pynndescent fails, sklearn's exact ``NearestNeighbors``
      is used.

    The directed kNN lists are symmetrised and one self loop per node is added.
    Each ``(src, dst)`` pair is kept only once. Without this, mutual neighbours
    would be emitted twice and get weight 2 when the adjacency is coalesced,
    unlike the GPU builder, whose ``to_undirected`` call removes duplicates.
    Self matches and invalid neighbour slots (negative or out-of-range
    indices) are dropped.

    Args:
        X_np: ``(n_samples, n_features)`` float feature matrix.

    Returns:
        ``(2, E)`` int64 edge index on ``self.device``, sorted by (src, dst).
    """
    n_samples = X_np.shape[0]
    k = min(self.k_neighbors, max(1, n_samples - 1))
    # +1 because each query point is normally returned as its own neighbour.
    n_neighbors = min(k + 1, n_samples)
    use_approx = (self.use_approx_knn or n_samples >=
                  self.approx_knn_threshold) and _PYNN_AVAILABLE
    indices = None
    if use_approx:
        try:
            nn_index = pynndescent.NNDescent(
                X_np,
                n_neighbors=n_neighbors,
                random_state=0
            )
            indices, _ = nn_index.neighbor_graph
        except Exception as exc:
            print(
                f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
            indices = None

    if indices is None:
        nbrs = NearestNeighbors(
            n_neighbors=n_neighbors,
            algorithm="auto",
            n_jobs=self.knn_cpu_jobs,
        )
        nbrs.fit(X_np)
        _, indices = nbrs.kneighbors(X_np)

    indices = np.asarray(indices, dtype=np.int64)
    rows = np.repeat(np.arange(n_samples, dtype=np.int64), indices.shape[1])
    cols = indices.reshape(-1)
    # Drop self matches and unfilled/invalid neighbour slots.
    keep = (rows != cols) & (cols >= 0) & (cols < n_samples)
    rows, cols = rows[keep], cols[keep]

    self_loops = np.arange(n_samples, dtype=np.int64)
    src = np.concatenate([rows, cols, self_loops])
    dst = np.concatenate([cols, rows, self_loops])
    # Deduplicate (src, dst) pairs via a linear code so every edge has unit weight.
    codes = np.unique(src * n_samples + dst)
    edge_index_np = np.stack([codes // n_samples, codes % n_samples], axis=0)
    return torch.as_tensor(edge_index_np, device=self.device)
|
|
353
|
+
|
|
354
|
+
def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
    """Build a symmetric kNN edge index with self loops using PyG on the GPU.

    Pipeline: ``knn_graph`` (no self loops) -> ``to_undirected`` ->
    ``add_self_loops``. Each node gets ``min(k_neighbors, max(1, N - 1))``
    neighbours.

    Raises:
        RuntimeError: if PyG kNN is disabled or any PyG helper is unavailable.
            Callers are expected to check ``use_pyg_knn`` first.
    """
    pyg_ready = self.use_pyg_knn and all(
        helper is not None for helper in (knn_graph, add_self_loops, to_undirected)
    )
    if not pyg_ready:
        raise RuntimeError(
            "GPU graph builder requested but PyG is unavailable.")

    n_nodes = X_tensor.size(0)
    k = min(self.k_neighbors, max(1, n_nodes - 1))
    directed_edges = knn_graph(X_tensor, k=k, loop=False)
    undirected_edges = to_undirected(directed_edges, num_nodes=n_nodes)
    looped_edges, _ = add_self_loops(undirected_edges, num_nodes=n_nodes)
    return looped_edges
|
|
372
|
+
|
|
373
|
+
def _log_knn_fallback(self, reason: str) -> None:
    """Print the CPU-kNN fallback reason once per instance.

    Under DDP only local rank 0 prints. Every rank still marks the warning as
    emitted.
    """
    if self._knn_warning_emitted:
        return
    is_primary_rank = not self.ddp_enabled or self.local_rank == 0
    if is_primary_rank:
        print(f"[GNN] Falling back to CPU kNN builder: {reason}")
    self._knn_warning_emitted = True
|
|
379
|
+
|
|
380
|
+
def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
    """Decide whether the PyG GPU kNN builder should build this graph.

    Returns ``False`` when PyG kNN is disabled, when ``n_samples`` exceeds
    ``max_gpu_knn_nodes``, or when the estimated temporary memory
    (feature bytes x ``knn_gpu_mem_overhead``) exceeds ``knn_gpu_mem_ratio``
    of the free CUDA memory. The last two cases are logged once via
    ``_log_knn_fallback``. If free memory cannot be queried, the GPU is tried
    anyway.
    """
    if not self.use_pyg_knn:
        return False

    node_cap = self.max_gpu_knn_nodes
    if node_cap is not None and n_samples > node_cap:
        self._log_knn_fallback(
            f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}")
        return False

    if self.device.type != 'cuda' or not torch.cuda.is_available():
        return True

    try:
        dev_idx = self.device.index
        if dev_idx is None:
            dev_idx = torch.cuda.current_device()
        free_bytes, _total_bytes = torch.cuda.mem_get_info(dev_idx)
        feature_bytes = X_tensor.element_size() * X_tensor.nelement()
        needed = int(feature_bytes * self.knn_gpu_mem_overhead)
        allowed = int(free_bytes * self.knn_gpu_mem_ratio)
    except Exception:
        # mem_get_info may be unavailable on older versions or some
        # environments; default to trying the GPU.
        return True

    if needed <= allowed:
        return True
    gib = 1024 ** 3
    self._log_knn_fallback(
        f"requires ~{needed / gib:.2f} GiB temporary GPU memory "
        f"but only {allowed / gib:.2f} GiB free on cuda:{dev_idx}")
    return False
|
|
409
|
+
|
|
410
|
+
def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Return the symmetric-normalized sparse adjacency D^-1/2 A D^-1/2.

    ``A`` has weight 1 per listed edge. Duplicate edges are summed by
    ``coalesce``. Degrees are row sums of ``A``, and ``1e-8`` is added before
    the inverse square root to avoid division by zero.

    Args:
        edge_index: ``(2, E)`` edge index.
        num_nodes: number of nodes N.

    Returns:
        Sparse COO ``(N, N)`` tensor on ``self.device``.
    """
    n_edges = edge_index.shape[1]
    raw = torch.sparse_coo_tensor(
        edge_index.to(self.device),
        torch.ones(n_edges, device=self.device),
        (num_nodes, num_nodes),
    ).coalesce()

    degree = torch.sparse.sum(raw, dim=1).to_dense()
    inv_sqrt_degree = (degree + 1e-8).pow(-0.5)
    coords = raw.indices()
    src, dst = coords[0], coords[1]
    scaled = inv_sqrt_degree[src] * raw.values() * inv_sqrt_degree[dst]
    return torch.sparse_coo_tensor(coords, scaled, size=raw.shape)
|
|
423
|
+
|
|
424
|
+
def _tensorize_split(self, X, y, w, allow_none: bool = False):
    """Convert one data split to float32 tensors on ``self.device``.

    Args:
        X: feature DataFrame (required unless ``allow_none`` and ``X is None``).
        y: targets or ``None``. Validated against ``len(X)``.
        w: sample weights or ``None`` (defaults to all ones).
        allow_none: when true and ``X is None``, return ``(None, None, None)``.

    Returns:
        ``(X_tensor, y_tensor, w_tensor)``. ``y_tensor`` and ``w_tensor`` are
        reshaped to ``(N, 1)``, and ``y_tensor`` is ``None`` when ``y`` is.

    Raises:
        ValueError: if ``X`` is not a DataFrame, or ``y``/``w`` fail validation.
    """
    if X is None and allow_none:
        return None, None, None
    if not isinstance(X, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame for GNN.")
    n_rows = len(X)
    for vector, label in ((y, "y"), (w, "w")):
        if vector is not None:
            self._validate_vector(vector, label, n_rows)

    def _as_float32(obj):
        if hasattr(obj, "to_numpy"):
            return obj.to_numpy(dtype=np.float32, copy=False)
        return np.asarray(obj, dtype=np.float32)

    def _to_device(arr):
        return torch.as_tensor(arr, dtype=torch.float32, device=self.device)

    X_tensor = _to_device(_as_float32(X))
    y_tensor = None if y is None else _to_device(_as_float32(y)).view(-1, 1)
    if w is None:
        w_tensor = torch.ones(
            (len(X), 1), dtype=torch.float32, device=self.device)
    else:
        w_tensor = _to_device(_as_float32(w)).view(-1, 1)
    return X_tensor, y_tensor, w_tensor
|
|
454
|
+
|
|
455
|
+
def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return the normalized adjacency for ``X_df``, reusing caches when possible.

    Cache strategy:
    - With ``graph_cache_path`` set:
      - Reuse is content-based: the in-memory tensor is reused when its
        ``_graph_cache_meta`` fingerprint matches.
      - Otherwise the on-disk cache is tried before building.
      - A freshly built graph is written back to disk; a failed write is
        reported, not raised.
    - Without a cache path:
      - Reuse is identity-based: the ``_graph_cache_key`` must match AND the
        cached entry must have been built from this very DataFrame object.
      - The key is made of ``id()`` values, which CPython recycles once an
        object is garbage collected. A key match alone could therefore return
        the graph of a different, already-freed frame of the same shape, such
        as the previous CV fold.

    Args:
        X_df: feature frame, one node per row.
        X_tensor: optional float32 tensor of ``X_df`` already on ``self.device``.

    Returns:
        Sparse normalized adjacency on ``self.device``.

    Raises:
        ValueError: if ``X_df`` is not a pandas DataFrame.
    """
    if not isinstance(X_df, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame for graph building.")
    meta_expected = None
    cache_key = None
    if self.graph_cache_path:
        meta_expected = self._graph_cache_meta(X_df)
        if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
            cached = self._adj_cache_tensor
            if cached.device != self.device:
                cached = cached.to(self.device)
                self._adj_cache_tensor = cached
            return cached
    else:
        cache_key = self._graph_cache_key(X_df)
        # Guard against id() reuse: only trust the key for the same live object.
        same_frame = getattr(self, "_adj_cache_frame", None) is X_df
        if (same_frame and self._adj_cache_key == cache_key
                and self._adj_cache_tensor is not None):
            cached = self._adj_cache_tensor
            if cached.device != self.device:
                cached = cached.to(self.device)
                self._adj_cache_tensor = cached
            return cached

    X_np = None
    if X_tensor is None:
        X_np = X_df.to_numpy(dtype=np.float32, copy=False)
        X_tensor = torch.as_tensor(
            X_np, dtype=torch.float32, device=self.device)
    if self.graph_cache_path:
        cached = self._load_cached_adj(X_df, meta_expected=meta_expected)
        if cached is not None:
            self._adj_cache_meta = meta_expected
            self._adj_cache_key = None
            self._adj_cache_frame = None
            self._adj_cache_tensor = cached
            return cached

    if self._should_use_gpu_knn(X_df.shape[0], X_tensor):
        edge_index = self._build_edge_index_gpu(X_tensor)
    else:
        if X_np is None:
            X_np = X_df.to_numpy(dtype=np.float32, copy=False)
        edge_index = self._build_edge_index_cpu(X_np)
    adj_norm = self._normalized_adj(edge_index, X_df.shape[0])

    if self.graph_cache_path:
        try:
            IOUtils.ensure_parent_dir(str(self.graph_cache_path))
            torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
        except Exception as exc:
            print(
                f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
        self._adj_cache_meta = meta_expected
        self._adj_cache_key = None
        self._adj_cache_frame = None
    else:
        self._adj_cache_meta = None
        self._adj_cache_key = cache_key
        # Strong reference on purpose: while this entry exists the frame cannot
        # be collected, so its id() cannot be handed to another object (and,
        # unlike a weakref, it keeps the model picklable).
        self._adj_cache_frame = X_df
    self._adj_cache_tensor = adj_norm
    return adj_norm
|
|
510
|
+
|
|
511
|
+
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None,
        trial: Optional[optuna.trial.Trial] = None):
    """Train the GNN full-batch, with optional validation and Optuna pruning.

    Each epoch runs one forward/backward pass over the whole training graph.
    Autocast and GradScaler are enabled only on CUDA, and the gradient norm is
    clipped to 1.0. When a validation split is supplied, its loss feeds
    ``_early_stop_update`` and is reported to ``trial``. The best state seen is
    loaded back into the model at the end.

    Args:
        X_train, y_train, w_train: training features, targets and optional
            sample weights (converted via ``_tensorize_split``).
        X_val, y_val, w_val: optional validation split. It is used only when
            both ``X_val`` and ``y_val`` are given.
        trial: optional Optuna trial. The validation loss is reported each
            epoch; under torch.distributed the prune decision of rank 0 is
            broadcast to every rank.

    Raises:
        optuna.TrialPruned: when the (broadcast) prune decision is positive.
    """
    X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
        X_train, y_train, w_train, allow_none=False)
    has_val = X_val is not None and y_val is not None
    if has_val:
        X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
            X_val, y_val, w_val, allow_none=False)
    else:
        X_val_tensor = y_val_tensor = w_val_tensor = None

    adj_train = self._build_graph_from_df(X_train, X_train_tensor)
    adj_val = self._build_graph_from_df(
        X_val, X_val_tensor) if has_val else None
    # DataParallel needs adjacency cached on the model to avoid scatter.
    self._set_adj_buffer(adj_train)

    base_gnn = self._unwrap_gnn()
    optimizer = torch.optim.Adam(
        base_gnn.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    use_amp = self.device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    best_loss = float('inf')
    best_state = None
    patience_counter = 0
    best_epoch = None

    for epoch in range(1, self.epochs + 1):
        epoch_start_ts = time.time()

        # ---- training step (full graph, single optimizer step) ----
        self.gnn.train()
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            if self.data_parallel_enabled:
                y_pred = self.gnn(X_train_tensor)
            else:
                y_pred = self.gnn(X_train_tensor, adj_train)
            loss = self._compute_weighted_loss(
                y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # ---- validation ----
        val_loss = None
        if has_val:
            self.gnn.eval()
            if self.data_parallel_enabled and adj_val is not None:
                self._set_adj_buffer(adj_val)
            try:
                with torch.no_grad(), autocast(enabled=use_amp):
                    if self.data_parallel_enabled:
                        y_val_pred = self.gnn(X_val_tensor)
                    else:
                        y_val_pred = self.gnn(X_val_tensor, adj_val)
                    val_loss = self._compute_weighted_loss(
                        y_val_pred, y_val_tensor, w_val_tensor,
                        apply_softplus=False)
            finally:
                if self.data_parallel_enabled:
                    # Restore training adjacency even if validation raised.
                    self._set_adj_buffer(adj_train)

        is_best = val_loss is not None and val_loss < best_loss
        best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
            val_loss, best_loss, best_state, patience_counter, base_gnn,
            ignore_keys=["adj_buffer"])
        if is_best:
            best_epoch = epoch

        # Log before any early exit so the final (stopping/pruned) epoch is visible.
        should_log = (not dist.is_initialized()
                      or DistributedUtils.is_main_process())
        if should_log:
            elapsed = int(time.time() - epoch_start_ts)
            if val_loss is None:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )
            else:
                print(
                    f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
                    f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
                    flush=True,
                )

        # ---- Optuna pruning ----
        prune_now = False
        if trial is not None and val_loss is not None:
            # Optuna's report() rejects None, so only real validation
            # losses are reported.
            trial.report(float(val_loss), epoch)
            prune_now = bool(trial.should_prune())

        if dist.is_initialized():
            # Every rank must follow rank 0's pruning decision.
            flag = torch.tensor(
                [1 if prune_now else 0],
                device=self.device,
                dtype=torch.int32,
            )
            dist.broadcast(flag, src=0)
            prune_now = bool(flag.item())

        if prune_now:
            raise optuna.TrialPruned()
        if stop_training:
            break

    if best_state is not None:
        base_gnn.load_state_dict(best_state, strict=False)
    self.best_epoch = int(best_epoch or self.epochs)
|
|
622
|
+
|
|
623
|
+
def predict(self, X: pd.DataFrame) -> np.ndarray:
    """Return one prediction per row of ``X``.

    The adjacency for ``X`` comes from ``_build_graph_from_df``, and a single
    full-graph forward pass runs without gradient tracking.

    Returns:
        1-D array. For ``task_type == 'classification'`` the values are
        sigmoid probabilities. Otherwise they are the raw model outputs
        clipped from below at 1e-6.
    """
    self.gnn.eval()
    X_tensor, _, _ = self._tensorize_split(
        X, None, None, allow_none=False)
    adj = self._build_graph_from_df(X, X_tensor)
    if self.data_parallel_enabled:
        # The DataParallel path reads adjacency from the model buffer
        # instead of taking it as a forward argument.
        self._set_adj_buffer(adj)
    is_classification = self.task_type == 'classification'
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        if self.data_parallel_enabled:
            out = self.gnn(X_tensor)
        else:
            out = self.gnn(X_tensor, adj)
        if is_classification:
            # torch.sigmoid is numerically stable, whereas
            # 1 / (1 + np.exp(-x)) overflows (RuntimeWarning) for large
            # negative logits.
            out = torch.sigmoid(out)
        y_pred = out.cpu().numpy()
    if not is_classification:
        y_pred = np.clip(y_pred, 1e-6, None)
    return y_pred.ravel()
|
|
641
|
+
|
|
642
|
+
def encode(self, X: pd.DataFrame) -> np.ndarray:
    """Compute per-row hidden representations from the GNN layer stack.

    The unwrapped base module's ``layers`` are applied one after another.
    After each layer, the result is propagated once more through the
    adjacency with a sparse matmul. No gradients are tracked.

    Raises:
        RuntimeError: if the base module has no ``layers`` attribute.
    """
    module = self._unwrap_gnn()
    module.eval()
    feats, _unused_y, _unused_w = self._tensorize_split(
        X, None, None, allow_none=False)
    graph = self._build_graph_from_df(X, feats)
    if self.data_parallel_enabled:
        self._set_adj_buffer(graph)
    no_grad_ctx = getattr(torch, "inference_mode", torch.no_grad)
    with no_grad_ctx():
        blocks = getattr(module, "layers", None)
        if blocks is None:
            raise RuntimeError("GNN base module does not expose layers.")
        hidden = feats
        for block in blocks:
            # Layer first, then one extra propagation step over the graph.
            hidden = torch.sparse.mm(graph, block(hidden, graph))
        return hidden.detach().cpu().numpy()
|
|
660
|
+
|
|
661
|
+
def set_params(self, params: Dict[str, Any]):
    """Update existing hyper-parameter attributes and rebuild the backbone.

    All keys are validated before any attribute is modified. An unknown key
    therefore leaves the model untouched instead of partially updated with a
    stale ``self.gnn``.

    Args:
        params: mapping of attribute name to new value. Every key must
            already exist as an attribute on this object.

    Returns:
        ``self``, to allow chaining.

    Raises:
        ValueError: if any key is not an existing attribute.
    """
    unknown = [key for key in params if not hasattr(self, key)]
    if unknown:
        # Same message as before, naming the first offending key.
        raise ValueError(f"Parameter {unknown[0]} not found in GNN model.")
    for key, value in params.items():
        setattr(self, key, value)
    # Rebuild the backbone after structural parameter changes.
    # NOTE(review): the fresh module is not re-wrapped for data parallelism
    # here; confirm that callers re-wrap it when data_parallel_enabled is set.
    self.gnn = SimpleGNN(
        input_dim=self.input_dim,
        hidden_dim=self.hidden_dim,
        num_layers=self.num_layers,
        dropout=self.dropout,
        task_type=self.task_type
    ).to(self.device)
    return self
|