ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
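
The bulk of the changes above are structural: the bayesopt package is promoted from ins_pricing/modelling/core/bayesopt to ins_pricing/modelling/bayesopt, shared helpers (losses, metrics, IO, numerics) are consolidated under ins_pricing/utils, and production/predict.py becomes production/inference.py (with the test module renamed to match). Downstream code that imports from the old core paths will break on upgrade. The sketch below shows one way a consumer could stay compatible with both layouts; the try/except fallback is illustrative, not something the package itself provides:

    # Hypothetical consumer module kept compatible with 0.4.4 and 0.5.0.
    try:
        # 0.5.0 layout: bayesopt lives directly under modelling, utils are top-level.
        from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
        from ins_pricing.utils.losses import regression_loss
    except ImportError:
        # 0.4.4 layout: everything sat under modelling.core.bayesopt.
        from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import TrainerBase
        from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss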
ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py
@@ -1,344 +1,344 @@
- from __future__ import annotations
-
- import os
- from typing import Any, Callable, Dict, List, Optional, Tuple
-
- import numpy as np
- import optuna
- import torch
- from sklearn.metrics import log_loss
-
- from .trainer_base import TrainerBase
- from ..models import GraphNeuralNetSklearn
- from ..utils import EPS
- from ..utils.losses import regression_loss
- from ins_pricing.utils import get_logger
- from ins_pricing.utils.torch_compat import torch_load
-
- _logger = get_logger("ins_pricing.trainer.gnn")
-
- class GNNTrainer(TrainerBase):
-     def __init__(self, context: "BayesOptModel") -> None:
-         super().__init__(context, 'GNN', 'GNN')
-         self.model: Optional[GraphNeuralNetSklearn] = None
-         self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
-
-     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
-         params = params or {}
-         base_tw_power = self.ctx.default_tweedie_power()
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-         tw_power = params.get("tw_power")
-         if self.ctx.task_type == "regression":
-             if loss_name == "tweedie":
-                 tw_power = base_tw_power if tw_power is None else float(tw_power)
-             elif loss_name in ("poisson", "gamma"):
-                 tw_power = base_tw_power
-         else:
-             tw_power = None
-         model = GraphNeuralNetSklearn(
-             model_nme=f"{self.ctx.model_nme}_gnn",
-             input_dim=len(self.ctx.var_nmes),
-             hidden_dim=int(params.get("hidden_dim", 64)),
-             num_layers=int(params.get("num_layers", 2)),
-             k_neighbors=int(params.get("k_neighbors", 10)),
-             dropout=float(params.get("dropout", 0.1)),
-             learning_rate=float(params.get("learning_rate", 1e-3)),
-             epochs=int(params.get("epochs", self.ctx.epochs)),
-             patience=int(params.get("patience", 5)),
-             task_type=self.ctx.task_type,
-             tweedie_power=tw_power,
-             weight_decay=float(params.get("weight_decay", 0.0)),
-             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
-             use_ddp=bool(self.ctx.config.use_gnn_ddp),
-             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
-             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
-             graph_cache_path=self.ctx.config.gnn_graph_cache,
-             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
-             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
-             knn_gpu_mem_overhead=float(
-                 self.ctx.config.gnn_knn_gpu_mem_overhead),
-             loss_name=loss_name,
-         )
-         return self._apply_dataloader_overrides(model)
-
-     def cross_val(self, trial: optuna.trial.Trial) -> float:
-         base_tw_power = self.ctx.default_tweedie_power()
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-         metric_ctx: Dict[str, Any] = {}
-
-         def data_provider():
-             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
-             assert data is not None, "Preprocessed training data is missing."
-             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
-
-         def model_builder(params: Dict[str, Any]):
-             if loss_name == "tweedie":
-                 tw_power = params.get("tw_power", base_tw_power)
-             elif loss_name in ("poisson", "gamma"):
-                 tw_power = base_tw_power
-             else:
-                 tw_power = None
-             metric_ctx["tw_power"] = tw_power
-             if tw_power is None:
-                 params = dict(params)
-                 params.pop("tw_power", None)
-             return self._build_model(params)
-
-         def preprocess_fn(X_train, X_val):
-             X_train_s, X_val_s, _ = self._standardize_fold(
-                 X_train, X_val, self.ctx.num_features)
-             return X_train_s, X_val_s
-
-         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
-             model.fit(
-                 X_train,
-                 y_train,
-                 w_train=w_train,
-                 X_val=X_val,
-                 y_val=y_val,
-                 w_val=w_val,
-                 trial=trial_obj,
-             )
-             return model.predict(X_val)
-
-         def metric_fn(y_true, y_pred, weight):
-             if self.ctx.task_type == 'classification':
-                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
-                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
-             return regression_loss(
-                 y_true,
-                 y_pred,
-                 weight,
-                 loss_name=loss_name,
-                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
-             )
-
-         # Keep GNN BO lightweight: sample during CV, use full data for final training.
-         X_cap = data_provider()[0]
-         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
-
-         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
-             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
-             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
-             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
-             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
-             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
-             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
-         }
-         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
-             param_space["tw_power"] = lambda t: t.suggest_float(
-                 'tw_power', 1.0, 2.0)
-
-         return self.cross_val_generic(
-             trial=trial,
-             hyperparameter_space=param_space,
-             data_provider=data_provider,
-             model_builder=model_builder,
-             metric_fn=metric_fn,
-             sample_limit=sample_limit,
-             preprocess_fn=preprocess_fn,
-             fit_predict_fn=fit_predict,
-             cleanup_fn=lambda m: getattr(
-                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
-         )
-
-     def train(self) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
-
-         data = self.ctx.train_oht_scl_data
-         assert data is not None, "Preprocessed training data is missing."
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-
-         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
-         refit_epochs = None
-
-         split = self._resolve_train_val_indices(X_all)
-         if split is not None:
-             train_idx, val_idx = split
-             X_train = X_all.iloc[train_idx]
-             y_train = y_all.iloc[train_idx]
-             w_train = w_all.iloc[train_idx]
-             X_val = X_all.iloc[val_idx]
-             y_val = y_all.iloc[val_idx]
-             w_val = w_all.iloc[val_idx]
-
-             if use_refit:
-                 tmp_model = self._build_model(self.best_params)
-                 tmp_model.fit(
-                     X_train,
-                     y_train,
-                     w_train=w_train,
-                     X_val=X_val,
-                     y_val=y_val,
-                     w_val=w_val,
-                     trial=None,
-                 )
-                 refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
-                 getattr(getattr(tmp_model, "gnn", None), "to",
-                         lambda *_args, **_kwargs: None)("cpu")
-                 self._clean_gpu()
-             else:
-                 self.model = self._build_model(self.best_params)
-                 self.model.fit(
-                     X_train,
-                     y_train,
-                     w_train=w_train,
-                     X_val=X_val,
-                     y_val=y_val,
-                     w_val=w_val,
-                     trial=None,
-                 )
-         else:
-             use_refit = False
-
-         if use_refit:
-             self.model = self._build_model(self.best_params)
-             if refit_epochs is not None:
-                 self.model.epochs = int(refit_epochs)
-             self.model.fit(
-                 X_all,
-                 y_all,
-                 w_train=w_all,
-                 X_val=None,
-                 y_val=None,
-                 w_val=None,
-                 trial=None,
-             )
-         elif self.model is None:
-             self.model = self._build_model(self.best_params)
-             self.model.fit(
-                 X_all,
-                 y_all,
-                 w_train=w_all,
-                 X_val=None,
-                 y_val=None,
-                 w_val=None,
-                 trial=None,
-             )
-         self.ctx.model_label.append(self.label)
-         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
-         self.ctx.gnn_best = self.model
-
-         # If geo_feature_nmes is set, refresh geo tokens for FT input.
-         if self.ctx.config.geo_feature_nmes:
-             self.prepare_geo_tokens(force=True)
-
-     def ensemble_predict(self, k: int) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
-         data = self.ctx.train_oht_scl_data
-         test_data = self.ctx.test_oht_scl_data
-         if data is None or test_data is None:
-             raise RuntimeError("Missing standardized data for GNN ensemble.")
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-         X_test = test_data[self.ctx.var_nmes]
-
-         k = max(2, int(k))
-         n_samples = len(X_all)
-         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
-         if split_iter is None:
-             print(
-                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
-         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
-
-         split_count = 0
-         for train_idx, val_idx in split_iter:
-             model = self._build_model(self.best_params)
-             model.fit(
-                 X_all.iloc[train_idx],
-                 y_all.iloc[train_idx],
-                 w_train=w_all.iloc[train_idx],
-                 X_val=X_all.iloc[val_idx],
-                 y_val=y_all.iloc[val_idx],
-                 w_val=w_all.iloc[val_idx],
-                 trial=None,
-             )
-             pred_train = model.predict(X_all)
-             pred_test = model.predict(X_test)
-             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
-             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
-             getattr(getattr(model, "gnn", None), "to",
-                     lambda *_args, **_kwargs: None)("cpu")
-             self._clean_gpu()
-             split_count += 1
-
-         if split_count < 1:
-             print(
-                 "[GNN Ensemble] no CV splits generated; skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train = preds_train_sum / float(split_count)
-         preds_test = preds_test_sum / float(split_count)
-         self._cache_predictions("gnn", preds_train, preds_test)
-
-     def prepare_geo_tokens(self, force: bool = False) -> None:
-         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
-         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
-         if not geo_cols:
-             return
-         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
-             return
-
-         result = self.ctx._build_geo_tokens()
-         if result is None:
-             return
-         train_tokens, test_tokens, cols, geo_gnn = result
-         self.ctx.train_geo_tokens = train_tokens
-         self.ctx.test_geo_tokens = test_tokens
-         self.ctx.geo_token_cols = cols
-         self.ctx.geo_gnn_model = geo_gnn
-         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
-
-     def save(self) -> None:
-         if self.model is None:
-             print(f"[save] Warning: No model to save for {self.label}")
-             return
-         path = self.output.model_path(self._get_model_filename())
-         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
-         if base_gnn is not None:
-             base_gnn = base_gnn.to("cpu")
-         state = None if base_gnn is None else base_gnn.state_dict()
-         payload = {
-             "best_params": self.best_params,
-             "state_dict": state,
-             "preprocess_artifacts": self._export_preprocess_artifacts(),
-         }
-         torch.save(payload, path)
-
-     def load(self) -> None:
-         path = self.output.model_path(self._get_model_filename())
-         if not os.path.exists(path):
-             print(f"[load] Warning: Model file not found: {path}")
-             return
-         payload = torch_load(path, map_location='cpu', weights_only=False)
-         if not isinstance(payload, dict):
-             raise ValueError(f"Invalid GNN checkpoint: {path}")
-         params = payload.get("best_params") or {}
-         state_dict = payload.get("state_dict")
-         model = self._build_model(params)
-         if params:
-             model.set_params(dict(params))
-         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
-         if base_gnn is not None and state_dict is not None:
-             # Use strict=True for better error detection, but handle missing keys gracefully
-             try:
-                 base_gnn.load_state_dict(state_dict, strict=True)
-             except RuntimeError as e:
-                 if "Missing key" in str(e) or "Unexpected key" in str(e):
-                     print(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
-                     base_gnn.load_state_dict(state_dict, strict=False)
-                 else:
-                     raise
-         self.model = model
-         self.best_params = dict(params) if isinstance(params, dict) else None
-         self.ctx.gnn_best = self.model
+ from __future__ import annotations
+
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ from sklearn.metrics import log_loss
+
+ from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+ from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
+ from ins_pricing.utils import EPS
+ from ins_pricing.utils.losses import regression_loss
+ from ins_pricing.utils import get_logger
+ from ins_pricing.utils.torch_compat import torch_load
+
+ _logger = get_logger("ins_pricing.trainer.gnn")
+
+ class GNNTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'GNN', 'GNN')
+         self.model: Optional[GraphNeuralNetSklearn] = None
+         self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
+
+     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
+         params = params or {}
+         base_tw_power = self.ctx.default_tweedie_power()
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+         tw_power = params.get("tw_power")
+         if self.ctx.task_type == "regression":
+             if loss_name == "tweedie":
+                 tw_power = base_tw_power if tw_power is None else float(tw_power)
+             elif loss_name in ("poisson", "gamma"):
+                 tw_power = base_tw_power
+         else:
+             tw_power = None
+         model = GraphNeuralNetSklearn(
+             model_nme=f"{self.ctx.model_nme}_gnn",
+             input_dim=len(self.ctx.var_nmes),
+             hidden_dim=int(params.get("hidden_dim", 64)),
+             num_layers=int(params.get("num_layers", 2)),
+             k_neighbors=int(params.get("k_neighbors", 10)),
+             dropout=float(params.get("dropout", 0.1)),
+             learning_rate=float(params.get("learning_rate", 1e-3)),
+             epochs=int(params.get("epochs", self.ctx.epochs)),
+             patience=int(params.get("patience", 5)),
+             task_type=self.ctx.task_type,
+             tweedie_power=tw_power,
+             weight_decay=float(params.get("weight_decay", 0.0)),
+             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
+             use_ddp=bool(self.ctx.config.use_gnn_ddp),
+             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
+             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
+             graph_cache_path=self.ctx.config.gnn_graph_cache,
+             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
+             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
+             knn_gpu_mem_overhead=float(
+                 self.ctx.config.gnn_knn_gpu_mem_overhead),
+             loss_name=loss_name,
+         )
+         return self._apply_dataloader_overrides(model)
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         base_tw_power = self.ctx.default_tweedie_power()
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+         metric_ctx: Dict[str, Any] = {}
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         def model_builder(params: Dict[str, Any]):
+             if loss_name == "tweedie":
+                 tw_power = params.get("tw_power", base_tw_power)
+             elif loss_name in ("poisson", "gamma"):
+                 tw_power = base_tw_power
+             else:
+                 tw_power = None
+             metric_ctx["tw_power"] = tw_power
+             if tw_power is None:
+                 params = dict(params)
+                 params.pop("tw_power", None)
+             return self._build_model(params)
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return X_train_s, X_val_s
+
+         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+             model.fit(
+                 X_train,
+                 y_train,
+                 w_train=w_train,
+                 X_val=X_val,
+                 y_val=y_val,
+                 w_val=w_val,
+                 trial=trial_obj,
+             )
+             return model.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'classification':
+                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+             return regression_loss(
+                 y_true,
+                 y_pred,
+                 weight,
+                 loss_name=loss_name,
+                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
+             )
+
+         # Keep GNN BO lightweight: sample during CV, use full data for final training.
+         X_cap = data_provider()[0]
+         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
+
+         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
+             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
+             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
+             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
+             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
+             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
+         }
+         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
+             param_space["tw_power"] = lambda t: t.suggest_float(
+                 'tw_power', 1.0, 2.0)
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space=param_space,
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             sample_limit=sample_limit,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict,
+             cleanup_fn=lambda m: getattr(
+                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
+         )
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+
+         data = self.ctx.train_oht_scl_data
+         assert data is not None, "Preprocessed training data is missing."
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+
+         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+         refit_epochs = None
+
+         split = self._resolve_train_val_indices(X_all)
+         if split is not None:
+             train_idx, val_idx = split
+             X_train = X_all.iloc[train_idx]
+             y_train = y_all.iloc[train_idx]
+             w_train = w_all.iloc[train_idx]
+             X_val = X_all.iloc[val_idx]
+             y_val = y_all.iloc[val_idx]
+             w_val = w_all.iloc[val_idx]
+
+             if use_refit:
+                 tmp_model = self._build_model(self.best_params)
+                 tmp_model.fit(
+                     X_train,
+                     y_train,
+                     w_train=w_train,
+                     X_val=X_val,
+                     y_val=y_val,
+                     w_val=w_val,
+                     trial=None,
+                 )
+                 refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
+                 getattr(getattr(tmp_model, "gnn", None), "to",
+                         lambda *_args, **_kwargs: None)("cpu")
+                 self._clean_gpu()
+             else:
+                 self.model = self._build_model(self.best_params)
+                 self.model.fit(
+                     X_train,
+                     y_train,
+                     w_train=w_train,
+                     X_val=X_val,
+                     y_val=y_val,
+                     w_val=w_val,
+                     trial=None,
+                 )
+         else:
+             use_refit = False
+
+         if use_refit:
+             self.model = self._build_model(self.best_params)
+             if refit_epochs is not None:
+                 self.model.epochs = int(refit_epochs)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         elif self.model is None:
+             self.model = self._build_model(self.best_params)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         self.ctx.model_label.append(self.label)
+         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
+         self.ctx.gnn_best = self.model
+
+         # If geo_feature_nmes is set, refresh geo tokens for FT input.
+         if self.ctx.config.geo_feature_nmes:
+             self.prepare_geo_tokens(force=True)
+
+     def ensemble_predict(self, k: int) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+         data = self.ctx.train_oht_scl_data
+         test_data = self.ctx.test_oht_scl_data
+         if data is None or test_data is None:
+             raise RuntimeError("Missing standardized data for GNN ensemble.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+         X_test = test_data[self.ctx.var_nmes]
+
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+         if split_iter is None:
+             print(
+                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+         split_count = 0
+         for train_idx, val_idx in split_iter:
+             model = self._build_model(self.best_params)
+             model.fit(
+                 X_all.iloc[train_idx],
+                 y_all.iloc[train_idx],
+                 w_train=w_all.iloc[train_idx],
+                 X_val=X_all.iloc[val_idx],
+                 y_val=y_all.iloc[val_idx],
+                 w_val=w_all.iloc[val_idx],
+                 trial=None,
+             )
+             pred_train = model.predict(X_all)
+             pred_test = model.predict(X_test)
+             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+             getattr(getattr(model, "gnn", None), "to",
+                     lambda *_args, **_kwargs: None)("cpu")
+             self._clean_gpu()
+             split_count += 1
+
+         if split_count < 1:
+             print(
+                 "[GNN Ensemble] no CV splits generated; skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train = preds_train_sum / float(split_count)
+         preds_test = preds_test_sum / float(split_count)
+         self._cache_predictions("gnn", preds_train, preds_test)
+
+     def prepare_geo_tokens(self, force: bool = False) -> None:
+         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
+         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
+         if not geo_cols:
+             return
+         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
+             return
+
+         result = self.ctx._build_geo_tokens()
+         if result is None:
+             return
+         train_tokens, test_tokens, cols, geo_gnn = result
+         self.ctx.train_geo_tokens = train_tokens
+         self.ctx.test_geo_tokens = test_tokens
+         self.ctx.geo_token_cols = cols
+         self.ctx.geo_gnn_model = geo_gnn
+         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
+
+     def save(self) -> None:
+         if self.model is None:
+             print(f"[save] Warning: No model to save for {self.label}")
+             return
+         path = self.output.model_path(self._get_model_filename())
+         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
+         if base_gnn is not None:
+             base_gnn = base_gnn.to("cpu")
+         state = None if base_gnn is None else base_gnn.state_dict()
+         payload = {
+             "best_params": self.best_params,
+             "state_dict": state,
+             "preprocess_artifacts": self._export_preprocess_artifacts(),
+         }
+         torch.save(payload, path)
+
+     def load(self) -> None:
+         path = self.output.model_path(self._get_model_filename())
+         if not os.path.exists(path):
+             print(f"[load] Warning: Model file not found: {path}")
+             return
+         payload = torch_load(path, map_location='cpu', weights_only=False)
+         if not isinstance(payload, dict):
+             raise ValueError(f"Invalid GNN checkpoint: {path}")
+         params = payload.get("best_params") or {}
+         state_dict = payload.get("state_dict")
+         model = self._build_model(params)
+         if params:
+             model.set_params(dict(params))
+         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
+         if base_gnn is not None and state_dict is not None:
+             # Use strict=True for better error detection, but handle missing keys gracefully
+             try:
+                 base_gnn.load_state_dict(state_dict, strict=True)
+             except RuntimeError as e:
+                 if "Missing key" in str(e) or "Unexpected key" in str(e):
+                     print(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
+                     base_gnn.load_state_dict(state_dict, strict=False)
+                 else:
+                     raise
+         self.model = model
+         self.best_params = dict(params) if isinstance(params, dict) else None
+         self.ctx.gnn_best = self.model
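
As the save() and load() methods above show, a GNN checkpoint is a plain dict written with torch.save, carrying best_params, an optional state_dict, and preprocess_artifacts. A minimal sketch of inspecting such a file outside the trainer (the file name is hypothetical; weights_only=False mirrors the loader above, so only open checkpoints you trust):

    import torch

    # Inspect a checkpoint produced by GNNTrainer.save(); the path is illustrative.
    payload = torch.load("model_gnn.pt", map_location="cpu", weights_only=False)

    print(payload["best_params"])      # tuned hyperparameters chosen by Optuna
    state = payload["state_dict"]      # underlying nn.Module weights, or None
    if state is not None:
        for name, tensor in list(state.items())[:5]:
            print(name, tuple(tensor.shape))  # first few parameter names and shapes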