ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the changes between the versions as published to their public registry.
Files changed (84)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +39 -105
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +11 -9
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/__init__.py +10 -10
  15. ins_pricing/frontend/example_workflows.py +1 -1
  16. ins_pricing/governance/__init__.py +20 -20
  17. ins_pricing/governance/release.py +159 -159
  18. ins_pricing/modelling/__init__.py +147 -92
  19. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
  20. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  21. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
  22. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  29. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  36. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  37. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  38. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  39. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  40. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  42. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  43. ins_pricing/modelling/explain/__init__.py +55 -55
  44. ins_pricing/modelling/explain/metrics.py +27 -174
  45. ins_pricing/modelling/explain/permutation.py +237 -237
  46. ins_pricing/modelling/plotting/__init__.py +40 -36
  47. ins_pricing/modelling/plotting/compat.py +228 -0
  48. ins_pricing/modelling/plotting/curves.py +572 -572
  49. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  50. ins_pricing/modelling/plotting/geo.py +362 -362
  51. ins_pricing/modelling/plotting/importance.py +121 -121
  52. ins_pricing/pricing/__init__.py +27 -27
  53. ins_pricing/production/__init__.py +35 -25
  54. ins_pricing/production/{predict.py → inference.py} +140 -57
  55. ins_pricing/production/monitoring.py +8 -21
  56. ins_pricing/reporting/__init__.py +11 -11
  57. ins_pricing/setup.py +1 -1
  58. ins_pricing/tests/production/test_inference.py +90 -0
  59. ins_pricing/utils/__init__.py +116 -83
  60. ins_pricing/utils/device.py +255 -255
  61. ins_pricing/utils/features.py +53 -0
  62. ins_pricing/utils/io.py +72 -0
  63. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  64. ins_pricing/utils/metrics.py +158 -24
  65. ins_pricing/utils/numerics.py +76 -0
  66. ins_pricing/utils/paths.py +9 -1
  67. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
  68. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  69. ins_pricing/modelling/core/BayesOpt.py +0 -146
  70. ins_pricing/modelling/core/__init__.py +0 -1
  71. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  72. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  73. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  74. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  75. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  76. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  77. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  78. ins_pricing/tests/production/test_predict.py +0 -233
  79. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  80. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  81. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  82. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  83. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  84. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
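
The dominant change in 0.5.0 is structural: ins_pricing.modelling.core.bayesopt is flattened into ins_pricing.modelling.bayesopt, shared helpers (losses, metrics, io, numerics) are promoted into ins_pricing.utils, and ins_pricing/production/predict.py is renamed to inference.py with its test renamed to match. A minimal migration sketch for downstream imports, assuming the new module paths in the list above are importable as-is (any compatibility re-exports are not visible in this diff):

# 0.4.5 (old paths)
# from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import TrainerBase
# from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss

# 0.5.0 (new paths, matching the moves listed above)
from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
from ins_pricing.utils.losses import regression_loss  # bayesopt utils/losses.py moved to utils/losses.py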
ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py
@@ -1,344 +1,344 @@
 from __future__ import annotations
 
 import os
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import optuna
 import torch
 from sklearn.metrics import log_loss
 
-from .trainer_base import TrainerBase
-from ..models import GraphNeuralNetSklearn
-from ..utils import EPS
-from ..utils.losses import regression_loss
+from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
+from ins_pricing.utils import EPS
+from ins_pricing.utils.losses import regression_loss
 from ins_pricing.utils import get_logger
 from ins_pricing.utils.torch_compat import torch_load
 
 _logger = get_logger("ins_pricing.trainer.gnn")
 
 class GNNTrainer(TrainerBase):
     def __init__(self, context: "BayesOptModel") -> None:
         super().__init__(context, 'GNN', 'GNN')
         self.model: Optional[GraphNeuralNetSklearn] = None
         self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
 
     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
         params = params or {}
         base_tw_power = self.ctx.default_tweedie_power()
         loss_name = getattr(self.ctx, "loss_name", "tweedie")
         tw_power = params.get("tw_power")
         if self.ctx.task_type == "regression":
             if loss_name == "tweedie":
                 tw_power = base_tw_power if tw_power is None else float(tw_power)
             elif loss_name in ("poisson", "gamma"):
                 tw_power = base_tw_power
         else:
             tw_power = None
         model = GraphNeuralNetSklearn(
             model_nme=f"{self.ctx.model_nme}_gnn",
             input_dim=len(self.ctx.var_nmes),
             hidden_dim=int(params.get("hidden_dim", 64)),
             num_layers=int(params.get("num_layers", 2)),
             k_neighbors=int(params.get("k_neighbors", 10)),
             dropout=float(params.get("dropout", 0.1)),
             learning_rate=float(params.get("learning_rate", 1e-3)),
             epochs=int(params.get("epochs", self.ctx.epochs)),
             patience=int(params.get("patience", 5)),
             task_type=self.ctx.task_type,
             tweedie_power=tw_power,
             weight_decay=float(params.get("weight_decay", 0.0)),
             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
             use_ddp=bool(self.ctx.config.use_gnn_ddp),
             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
             graph_cache_path=self.ctx.config.gnn_graph_cache,
             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
             knn_gpu_mem_overhead=float(
                 self.ctx.config.gnn_knn_gpu_mem_overhead),
             loss_name=loss_name,
         )
         return self._apply_dataloader_overrides(model)
 
     def cross_val(self, trial: optuna.trial.Trial) -> float:
         base_tw_power = self.ctx.default_tweedie_power()
         loss_name = getattr(self.ctx, "loss_name", "tweedie")
         metric_ctx: Dict[str, Any] = {}
 
         def data_provider():
             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
             assert data is not None, "Preprocessed training data is missing."
             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
 
         def model_builder(params: Dict[str, Any]):
             if loss_name == "tweedie":
                 tw_power = params.get("tw_power", base_tw_power)
             elif loss_name in ("poisson", "gamma"):
                 tw_power = base_tw_power
             else:
                 tw_power = None
             metric_ctx["tw_power"] = tw_power
             if tw_power is None:
                 params = dict(params)
                 params.pop("tw_power", None)
             return self._build_model(params)
 
         def preprocess_fn(X_train, X_val):
             X_train_s, X_val_s, _ = self._standardize_fold(
                 X_train, X_val, self.ctx.num_features)
             return X_train_s, X_val_s
 
         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
             model.fit(
                 X_train,
                 y_train,
                 w_train=w_train,
                 X_val=X_val,
                 y_val=y_val,
                 w_val=w_val,
                 trial=trial_obj,
             )
             return model.predict(X_val)
 
         def metric_fn(y_true, y_pred, weight):
             if self.ctx.task_type == 'classification':
                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
             return regression_loss(
                 y_true,
                 y_pred,
                 weight,
                 loss_name=loss_name,
                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
             )
 
         # Keep GNN BO lightweight: sample during CV, use full data for final training.
         X_cap = data_provider()[0]
         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
 
         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
         }
         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
             param_space["tw_power"] = lambda t: t.suggest_float(
                 'tw_power', 1.0, 2.0)
 
         return self.cross_val_generic(
             trial=trial,
             hyperparameter_space=param_space,
             data_provider=data_provider,
             model_builder=model_builder,
             metric_fn=metric_fn,
             sample_limit=sample_limit,
             preprocess_fn=preprocess_fn,
             fit_predict_fn=fit_predict,
             cleanup_fn=lambda m: getattr(
                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
         )
 
     def train(self) -> None:
         if not self.best_params:
             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
 
         data = self.ctx.train_oht_scl_data
         assert data is not None, "Preprocessed training data is missing."
         X_all = data[self.ctx.var_nmes]
         y_all = data[self.ctx.resp_nme]
         w_all = data[self.ctx.weight_nme]
 
         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
         refit_epochs = None
 
         split = self._resolve_train_val_indices(X_all)
         if split is not None:
             train_idx, val_idx = split
             X_train = X_all.iloc[train_idx]
             y_train = y_all.iloc[train_idx]
             w_train = w_all.iloc[train_idx]
             X_val = X_all.iloc[val_idx]
             y_val = y_all.iloc[val_idx]
             w_val = w_all.iloc[val_idx]
 
             if use_refit:
                 tmp_model = self._build_model(self.best_params)
                 tmp_model.fit(
                     X_train,
                     y_train,
                     w_train=w_train,
                     X_val=X_val,
                     y_val=y_val,
                     w_val=w_val,
                     trial=None,
                 )
                 refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
                 getattr(getattr(tmp_model, "gnn", None), "to",
                         lambda *_args, **_kwargs: None)("cpu")
                 self._clean_gpu()
             else:
                 self.model = self._build_model(self.best_params)
                 self.model.fit(
                     X_train,
                     y_train,
                     w_train=w_train,
                     X_val=X_val,
                     y_val=y_val,
                     w_val=w_val,
                     trial=None,
                 )
         else:
             use_refit = False
 
         if use_refit:
             self.model = self._build_model(self.best_params)
             if refit_epochs is not None:
                 self.model.epochs = int(refit_epochs)
             self.model.fit(
                 X_all,
                 y_all,
                 w_train=w_all,
                 X_val=None,
                 y_val=None,
                 w_val=None,
                 trial=None,
             )
         elif self.model is None:
             self.model = self._build_model(self.best_params)
             self.model.fit(
                 X_all,
                 y_all,
                 w_train=w_all,
                 X_val=None,
                 y_val=None,
                 w_val=None,
                 trial=None,
             )
         self.ctx.model_label.append(self.label)
         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
         self.ctx.gnn_best = self.model
 
         # If geo_feature_nmes is set, refresh geo tokens for FT input.
         if self.ctx.config.geo_feature_nmes:
             self.prepare_geo_tokens(force=True)
 
     def ensemble_predict(self, k: int) -> None:
         if not self.best_params:
             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
         data = self.ctx.train_oht_scl_data
         test_data = self.ctx.test_oht_scl_data
         if data is None or test_data is None:
             raise RuntimeError("Missing standardized data for GNN ensemble.")
         X_all = data[self.ctx.var_nmes]
         y_all = data[self.ctx.resp_nme]
         w_all = data[self.ctx.weight_nme]
         X_test = test_data[self.ctx.var_nmes]
 
         k = max(2, int(k))
         n_samples = len(X_all)
         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
         if split_iter is None:
             print(
                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
                 flush=True,
             )
             return
         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
 
         split_count = 0
         for train_idx, val_idx in split_iter:
             model = self._build_model(self.best_params)
             model.fit(
                 X_all.iloc[train_idx],
                 y_all.iloc[train_idx],
                 w_train=w_all.iloc[train_idx],
                 X_val=X_all.iloc[val_idx],
                 y_val=y_all.iloc[val_idx],
                 w_val=w_all.iloc[val_idx],
                 trial=None,
             )
             pred_train = model.predict(X_all)
             pred_test = model.predict(X_test)
             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
             getattr(getattr(model, "gnn", None), "to",
                     lambda *_args, **_kwargs: None)("cpu")
             self._clean_gpu()
             split_count += 1
 
         if split_count < 1:
             print(
                 f"[GNN Ensemble] no CV splits generated; skip ensemble.",
                 flush=True,
             )
             return
         preds_train = preds_train_sum / float(split_count)
         preds_test = preds_test_sum / float(split_count)
         self._cache_predictions("gnn", preds_train, preds_test)
 
     def prepare_geo_tokens(self, force: bool = False) -> None:
         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
         if not geo_cols:
             return
         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
             return
 
         result = self.ctx._build_geo_tokens()
         if result is None:
             return
         train_tokens, test_tokens, cols, geo_gnn = result
         self.ctx.train_geo_tokens = train_tokens
         self.ctx.test_geo_tokens = test_tokens
         self.ctx.geo_token_cols = cols
         self.ctx.geo_gnn_model = geo_gnn
         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
 
     def save(self) -> None:
         if self.model is None:
             print(f"[save] Warning: No model to save for {self.label}")
             return
         path = self.output.model_path(self._get_model_filename())
         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
         if base_gnn is not None:
             base_gnn = base_gnn.to("cpu")
         state = None if base_gnn is None else base_gnn.state_dict()
         payload = {
             "best_params": self.best_params,
             "state_dict": state,
             "preprocess_artifacts": self._export_preprocess_artifacts(),
         }
         torch.save(payload, path)
 
     def load(self) -> None:
         path = self.output.model_path(self._get_model_filename())
         if not os.path.exists(path):
             print(f"[load] Warning: Model file not found: {path}")
             return
         payload = torch_load(path, map_location='cpu', weights_only=False)
         if not isinstance(payload, dict):
             raise ValueError(f"Invalid GNN checkpoint: {path}")
         params = payload.get("best_params") or {}
         state_dict = payload.get("state_dict")
         model = self._build_model(params)
         if params:
             model.set_params(dict(params))
         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
         if base_gnn is not None and state_dict is not None:
             # Use strict=True for better error detection, but handle missing keys gracefully
             try:
                 base_gnn.load_state_dict(state_dict, strict=True)
             except RuntimeError as e:
                 if "Missing key" in str(e) or "Unexpected key" in str(e):
                     print(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
                     base_gnn.load_state_dict(state_dict, strict=False)
                 else:
                     raise
         self.model = model
         self.best_params = dict(params) if isinstance(params, dict) else None
         self.ctx.gnn_best = self.model
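
Taken together, the module implies this trainer lifecycle: cross_val() scores a single Optuna trial, train() refuses to run until best_params is set, ensemble_predict(k) bags k fold models, and save()/load() round-trip a dict payload through torch.save/torch_load. A hypothetical driver sketch, assuming ctx is a configured BayesOptModel and that the tune() step referenced by the error messages populates best_params (both live outside this file):

import optuna

trainer = GNNTrainer(ctx)  # ctx: configured BayesOptModel (assumed, built elsewhere)

# tune() is normally provided by TrainerBase; shown inline here for illustration.
study = optuna.create_study(direction="minimize")  # metric_fn returns a loss to minimize
study.optimize(trainer.cross_val, n_trials=30)
trainer.best_params = study.best_params

trainer.train()              # early-stops on a holdout, then refits on all rows when final_refit is set
trainer.ensemble_predict(5)  # k-fold bagging; averaged predictions cached under the "gnn" prefix
trainer.save()               # payload: {"best_params", "state_dict", "preprocess_artifacts"}
trainer.load()               # rebuilds the model; strict state-dict load with lenient fallback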