ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
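The dominant change in 0.5.1 is a flattening of the package layout: `modelling/core/bayesopt` becomes `modelling/bayesopt`, shared helpers (losses, constants, logging, metrics) consolidate under `ins_pricing.utils`, and `production/predict.py` is renamed to `production/inference.py` (with `test_predict.py` replaced by `test_inference.py`). The hunk below, from item 41 (`trainers/trainer_gnn.py`), shows the resulting import rewrites. For downstream code, a minimal migration sketch, assuming 0.5.1 ships no compatibility shims for the old paths (none are visible in this diff):

```python
# Hypothetical caller-side migration from ins-pricing 0.4.5 to 0.5.1.
# Old (0.4.5) paths, removed in this release:
#   from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import TrainerBase
#   from ins_pricing.modelling.core.bayesopt.models import GraphNeuralNetSklearn
#   from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss
# New (0.5.1) paths, matching the renames listed above:
from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
from ins_pricing.utils import EPS, get_logger, log_print
from ins_pricing.utils.losses import regression_loss
```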
@@ -1,172 +1,184 @@
- from __future__ import annotations
-
- import os
- from typing import Any, Dict, List, Optional, Tuple
-
- import numpy as np
- import optuna
- import torch
- from sklearn.metrics import log_loss
-
- from .trainer_base import TrainerBase
- from ..models import GraphNeuralNetSklearn
- from ..utils import EPS
- from ..utils.losses import regression_loss
- from ins_pricing.utils import get_logger
- from ins_pricing.utils.torch_compat import torch_load
-
- _logger = get_logger("ins_pricing.trainer.gnn")
-
+ from __future__ import annotations
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ from sklearn.metrics import log_loss
+
+ from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+ from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
+ from ins_pricing.utils import EPS, get_logger, log_print
+ from ins_pricing.utils.losses import regression_loss
+ from ins_pricing.utils.torch_compat import torch_load
+
+ _logger = get_logger("ins_pricing.trainer.gnn")
+
+
+ def _log(*args, **kwargs) -> None:
+     log_print(_logger, *args, **kwargs)
+
  class GNNTrainer(TrainerBase):
      def __init__(self, context: "BayesOptModel") -> None:
          super().__init__(context, 'GNN', 'GNN')
          self.model: Optional[GraphNeuralNetSklearn] = None
          self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)

-     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
-         params = params or {}
-         base_tw_power = self.ctx.default_tweedie_power()
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-         tw_power = params.get("tw_power")
-         if self.ctx.task_type == "regression":
-             if loss_name == "tweedie":
-                 tw_power = base_tw_power if tw_power is None else float(tw_power)
-             elif loss_name in ("poisson", "gamma"):
-                 tw_power = base_tw_power
-             else:
-                 tw_power = None
-         model = GraphNeuralNetSklearn(
-             model_nme=f"{self.ctx.model_nme}_gnn",
-             input_dim=len(self.ctx.var_nmes),
-             hidden_dim=int(params.get("hidden_dim", 64)),
-             num_layers=int(params.get("num_layers", 2)),
-             k_neighbors=int(params.get("k_neighbors", 10)),
-             dropout=float(params.get("dropout", 0.1)),
-             learning_rate=float(params.get("learning_rate", 1e-3)),
-             epochs=int(params.get("epochs", self.ctx.epochs)),
-             patience=int(params.get("patience", 5)),
-             task_type=self.ctx.task_type,
-             tweedie_power=tw_power,
-             weight_decay=float(params.get("weight_decay", 0.0)),
-             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
-             use_ddp=bool(self.ctx.config.use_gnn_ddp),
-             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
-             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
-             graph_cache_path=self.ctx.config.gnn_graph_cache,
-             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
-             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
-             knn_gpu_mem_overhead=float(
-                 self.ctx.config.gnn_knn_gpu_mem_overhead),
-             loss_name=loss_name,
-         )
-         return self._apply_dataloader_overrides(model)
-
-     def cross_val(self, trial: optuna.trial.Trial) -> float:
-         base_tw_power = self.ctx.default_tweedie_power()
-         loss_name = getattr(self.ctx, "loss_name", "tweedie")
-         metric_ctx: Dict[str, Any] = {}
-
-         def data_provider():
-             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
-             assert data is not None, "Preprocessed training data is missing."
-             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
-
-         def model_builder(params: Dict[str, Any]):
-             if loss_name == "tweedie":
-                 tw_power = params.get("tw_power", base_tw_power)
-             elif loss_name in ("poisson", "gamma"):
-                 tw_power = base_tw_power
-             else:
-                 tw_power = None
-             metric_ctx["tw_power"] = tw_power
-             if tw_power is None:
-                 params = dict(params)
-                 params.pop("tw_power", None)
-             return self._build_model(params)
-
-         def preprocess_fn(X_train, X_val):
-             X_train_s, X_val_s, _ = self._standardize_fold(
-                 X_train, X_val, self.ctx.num_features)
-             return X_train_s, X_val_s
-
-         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
-             model.fit(
-                 X_train,
-                 y_train,
-                 w_train=w_train,
-                 X_val=X_val,
-                 y_val=y_val,
-                 w_val=w_val,
-                 trial=trial_obj,
-             )
-             return model.predict(X_val)
-
-         def metric_fn(y_true, y_pred, weight):
-             if self.ctx.task_type == 'classification':
-                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
-                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
-             return regression_loss(
-                 y_true,
-                 y_pred,
-                 weight,
-                 loss_name=loss_name,
-                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
-             )
-
-         # Keep GNN BO lightweight: sample during CV, use full data for final training.
-         X_cap = data_provider()[0]
-         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
-
-         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
-             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
-             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
-             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
-             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
-             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
-             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
-         }
-         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
-             param_space["tw_power"] = lambda t: t.suggest_float(
-                 'tw_power', 1.0, 2.0)
-
-         return self.cross_val_generic(
-             trial=trial,
-             hyperparameter_space=param_space,
-             data_provider=data_provider,
-             model_builder=model_builder,
-             metric_fn=metric_fn,
-             sample_limit=sample_limit,
-             preprocess_fn=preprocess_fn,
-             fit_predict_fn=fit_predict,
-             cleanup_fn=lambda m: getattr(
-                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
-         )
-
-     def train(self) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
-
-         data = self.ctx.train_oht_scl_data
-         assert data is not None, "Preprocessed training data is missing."
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-
-         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
-         refit_epochs = None
-
-         split = self._resolve_train_val_indices(X_all)
-         if split is not None:
-             train_idx, val_idx = split
-             X_train = X_all.iloc[train_idx]
-             y_train = y_all.iloc[train_idx]
-             w_train = w_all.iloc[train_idx]
-             X_val = X_all.iloc[val_idx]
-             y_val = y_all.iloc[val_idx]
-             w_val = w_all.iloc[val_idx]
-
-             if use_refit:
-                 tmp_model = self._build_model(self.best_params)
+     def _maybe_cleanup_gpu(self, model: Optional[GraphNeuralNetSklearn]) -> None:
+         if not bool(getattr(self.ctx.config, "gnn_cleanup_per_fold", False)):
+             return
+         if model is not None:
+             getattr(getattr(model, "gnn", None), "to",
+                     lambda *_args, **_kwargs: None)("cpu")
+         synchronize = bool(getattr(self.ctx.config, "gnn_cleanup_synchronize", False))
+         self._clean_gpu(synchronize=synchronize)
+
+     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
+         params = params or {}
+         base_tw_power = self.ctx.default_tweedie_power()
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+         tw_power = params.get("tw_power")
+         if self.ctx.task_type == "regression":
+             if loss_name == "tweedie":
+                 tw_power = base_tw_power if tw_power is None else float(tw_power)
+             elif loss_name in ("poisson", "gamma"):
+                 tw_power = base_tw_power
+             else:
+                 tw_power = None
+         model = GraphNeuralNetSklearn(
+             model_nme=f"{self.ctx.model_nme}_gnn",
+             input_dim=len(self.ctx.var_nmes),
+             hidden_dim=int(params.get("hidden_dim", 64)),
+             num_layers=int(params.get("num_layers", 2)),
+             k_neighbors=int(params.get("k_neighbors", 10)),
+             dropout=float(params.get("dropout", 0.1)),
+             learning_rate=float(params.get("learning_rate", 1e-3)),
+             epochs=int(params.get("epochs", self.ctx.epochs)),
+             patience=int(params.get("patience", 5)),
+             task_type=self.ctx.task_type,
+             tweedie_power=tw_power,
+             weight_decay=float(params.get("weight_decay", 0.0)),
+             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
+             use_ddp=bool(self.ctx.config.use_gnn_ddp),
+             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
+             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
+             graph_cache_path=self.ctx.config.gnn_graph_cache,
+             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
+             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
+             knn_gpu_mem_overhead=float(
+                 self.ctx.config.gnn_knn_gpu_mem_overhead),
+             loss_name=loss_name,
+         )
+         return self._apply_dataloader_overrides(model)
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         base_tw_power = self.ctx.default_tweedie_power()
+         loss_name = getattr(self.ctx, "loss_name", "tweedie")
+         metric_ctx: Dict[str, Any] = {}
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         def model_builder(params: Dict[str, Any]):
+             if loss_name == "tweedie":
+                 tw_power = params.get("tw_power", base_tw_power)
+             elif loss_name in ("poisson", "gamma"):
+                 tw_power = base_tw_power
+             else:
+                 tw_power = None
+             metric_ctx["tw_power"] = tw_power
+             if tw_power is None:
+                 params = dict(params)
+                 params.pop("tw_power", None)
+             return self._build_model(params)
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return X_train_s, X_val_s
+
+         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+             model.fit(
+                 X_train,
+                 y_train,
+                 w_train=w_train,
+                 X_val=X_val,
+                 y_val=y_val,
+                 w_val=w_val,
+                 trial=trial_obj,
+             )
+             return model.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'classification':
+                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+             return regression_loss(
+                 y_true,
+                 y_pred,
+                 weight,
+                 loss_name=loss_name,
+                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
+             )
+
+         # Keep GNN BO lightweight: sample during CV, use full data for final training.
+         X_cap = data_provider()[0]
+         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
+
+         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
+             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
+             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
+             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
+             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
+             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
+         }
+         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
+             param_space["tw_power"] = lambda t: t.suggest_float(
+                 'tw_power', 1.0, 2.0)
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space=param_space,
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             sample_limit=sample_limit,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict,
+             cleanup_fn=lambda m: getattr(
+                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
+         )
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+
+         data = self.ctx.train_oht_scl_data
+         assert data is not None, "Preprocessed training data is missing."
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+
+         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+         refit_epochs = None
+
+         split = self._resolve_train_val_indices(X_all)
+         if split is not None:
+             train_idx, val_idx = split
+             X_train = X_all.iloc[train_idx]
+             y_train = y_all.iloc[train_idx]
+             w_train = w_all.iloc[train_idx]
+             X_val = X_all.iloc[val_idx]
+             y_val = y_all.iloc[val_idx]
+             w_val = w_all.iloc[val_idx]
+
+             if use_refit:
+                 tmp_model = self._build_model(self.best_params)
                  tmp_model.fit(
                      X_train,
                      y_train,
@@ -177,168 +189,164 @@ class GNNTrainer(TrainerBase):
                      trial=None,
                  )
                  refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
-                 getattr(getattr(tmp_model, "gnn", None), "to",
-                         lambda *_args, **_kwargs: None)("cpu")
-                 self._clean_gpu()
-             else:
-                 self.model = self._build_model(self.best_params)
-                 self.model.fit(
-                     X_train,
-                     y_train,
-                     w_train=w_train,
-                     X_val=X_val,
-                     y_val=y_val,
-                     w_val=w_val,
-                     trial=None,
-                 )
-         else:
-             use_refit = False
-
-         if use_refit:
-             self.model = self._build_model(self.best_params)
-             if refit_epochs is not None:
-                 self.model.epochs = int(refit_epochs)
-             self.model.fit(
-                 X_all,
-                 y_all,
-                 w_train=w_all,
-                 X_val=None,
-                 y_val=None,
-                 w_val=None,
-                 trial=None,
-             )
-         elif self.model is None:
-             self.model = self._build_model(self.best_params)
-             self.model.fit(
-                 X_all,
-                 y_all,
-                 w_train=w_all,
-                 X_val=None,
-                 y_val=None,
-                 w_val=None,
-                 trial=None,
-             )
-         self.ctx.model_label.append(self.label)
-         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
-         self.ctx.gnn_best = self.model
-
-         # If geo_feature_nmes is set, refresh geo tokens for FT input.
-         if self.ctx.config.geo_feature_nmes:
-             self.prepare_geo_tokens(force=True)
-
-     def ensemble_predict(self, k: int) -> None:
-         if not self.best_params:
-             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
-         data = self.ctx.train_oht_scl_data
-         test_data = self.ctx.test_oht_scl_data
-         if data is None or test_data is None:
-             raise RuntimeError("Missing standardized data for GNN ensemble.")
-         X_all = data[self.ctx.var_nmes]
-         y_all = data[self.ctx.resp_nme]
-         w_all = data[self.ctx.weight_nme]
-         X_test = test_data[self.ctx.var_nmes]
-
-         k = max(2, int(k))
-         n_samples = len(X_all)
-         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
-         if split_iter is None:
-             print(
-                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
-         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
-
-         split_count = 0
-         for train_idx, val_idx in split_iter:
-             model = self._build_model(self.best_params)
-             model.fit(
-                 X_all.iloc[train_idx],
-                 y_all.iloc[train_idx],
-                 w_train=w_all.iloc[train_idx],
-                 X_val=X_all.iloc[val_idx],
-                 y_val=y_all.iloc[val_idx],
-                 w_val=w_all.iloc[val_idx],
-                 trial=None,
-             )
-             pred_train = model.predict(X_all)
-             pred_test = model.predict(X_test)
+                 self._maybe_cleanup_gpu(tmp_model)
+             else:
+                 self.model = self._build_model(self.best_params)
+                 self.model.fit(
+                     X_train,
+                     y_train,
+                     w_train=w_train,
+                     X_val=X_val,
+                     y_val=y_val,
+                     w_val=w_val,
+                     trial=None,
+                 )
+         else:
+             use_refit = False
+
+         if use_refit:
+             self.model = self._build_model(self.best_params)
+             if refit_epochs is not None:
+                 self.model.epochs = int(refit_epochs)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         elif self.model is None:
+             self.model = self._build_model(self.best_params)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         self.ctx.model_label.append(self.label)
+         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
+         self.ctx.gnn_best = self.model
+
+         # If geo_feature_nmes is set, refresh geo tokens for FT input.
+         if self.ctx.config.geo_feature_nmes:
+             self.prepare_geo_tokens(force=True)
+
+     def ensemble_predict(self, k: int) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+         data = self.ctx.train_oht_scl_data
+         test_data = self.ctx.test_oht_scl_data
+         if data is None or test_data is None:
+             raise RuntimeError("Missing standardized data for GNN ensemble.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+         X_test = test_data[self.ctx.var_nmes]
+
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+         if split_iter is None:
+             _log(
+                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+         split_count = 0
+         for train_idx, val_idx in split_iter:
+             model = self._build_model(self.best_params)
+             model.fit(
+                 X_all.iloc[train_idx],
+                 y_all.iloc[train_idx],
+                 w_train=w_all.iloc[train_idx],
+                 X_val=X_all.iloc[val_idx],
+                 y_val=y_all.iloc[val_idx],
+                 w_val=w_all.iloc[val_idx],
+                 trial=None,
+             )
+             pred_train = model.predict(X_all)
+             pred_test = model.predict(X_test)
              preds_train_sum += np.asarray(pred_train, dtype=np.float64)
              preds_test_sum += np.asarray(pred_test, dtype=np.float64)
-             getattr(getattr(model, "gnn", None), "to",
-                     lambda *_args, **_kwargs: None)("cpu")
-             self._clean_gpu()
+             self._maybe_cleanup_gpu(model)
              split_count += 1
-
-         if split_count < 1:
-             print(
-                 f"[GNN Ensemble] no CV splits generated; skip ensemble.",
-                 flush=True,
-             )
-             return
-         preds_train = preds_train_sum / float(split_count)
-         preds_test = preds_test_sum / float(split_count)
-         self._cache_predictions("gnn", preds_train, preds_test)
-
-     def prepare_geo_tokens(self, force: bool = False) -> None:
-         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
-         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
-         if not geo_cols:
-             return
-         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
-             return
-
-         result = self.ctx._build_geo_tokens()
-         if result is None:
-             return
-         train_tokens, test_tokens, cols, geo_gnn = result
-         self.ctx.train_geo_tokens = train_tokens
-         self.ctx.test_geo_tokens = test_tokens
-         self.ctx.geo_token_cols = cols
-         self.ctx.geo_gnn_model = geo_gnn
-         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
-
-     def save(self) -> None:
-         if self.model is None:
-             print(f"[save] Warning: No model to save for {self.label}")
-             return
-         path = self.output.model_path(self._get_model_filename())
-         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
-         if base_gnn is not None:
-             base_gnn = base_gnn.to("cpu")
-         state = None if base_gnn is None else base_gnn.state_dict()
-         payload = {
-             "best_params": self.best_params,
-             "state_dict": state,
-             "preprocess_artifacts": self._export_preprocess_artifacts(),
-         }
-         torch.save(payload, path)
-
-     def load(self) -> None:
-         path = self.output.model_path(self._get_model_filename())
-         if not os.path.exists(path):
-             print(f"[load] Warning: Model file not found: {path}")
-             return
-         payload = torch_load(path, map_location='cpu', weights_only=False)
-         if not isinstance(payload, dict):
-             raise ValueError(f"Invalid GNN checkpoint: {path}")
-         params = payload.get("best_params") or {}
-         state_dict = payload.get("state_dict")
-         model = self._build_model(params)
-         if params:
-             model.set_params(dict(params))
-         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
-         if base_gnn is not None and state_dict is not None:
-             # Use strict=True for better error detection, but handle missing keys gracefully
-             try:
-                 base_gnn.load_state_dict(state_dict, strict=True)
-             except RuntimeError as e:
-                 if "Missing key" in str(e) or "Unexpected key" in str(e):
-                     print(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
-                     base_gnn.load_state_dict(state_dict, strict=False)
-                 else:
-                     raise
-         self.model = model
-         self.best_params = dict(params) if isinstance(params, dict) else None
-         self.ctx.gnn_best = self.model
- self.ctx.gnn_best = self.model
281
+
282
+ if split_count < 1:
283
+ _log(
284
+ f"[GNN Ensemble] no CV splits generated; skip ensemble.",
285
+ flush=True,
286
+ )
287
+ return
288
+ preds_train = preds_train_sum / float(split_count)
289
+ preds_test = preds_test_sum / float(split_count)
290
+ self._cache_predictions("gnn", preds_train, preds_test)
291
+
292
+ def prepare_geo_tokens(self, force: bool = False) -> None:
293
+ """Train/update the GNN encoder for geo tokens and inject them into FT input."""
294
+ geo_cols = list(self.ctx.config.geo_feature_nmes or [])
295
+ if not geo_cols:
296
+ return
297
+ if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
298
+ return
299
+
300
+ result = self.ctx._build_geo_tokens()
301
+ if result is None:
302
+ return
303
+ train_tokens, test_tokens, cols, geo_gnn = result
304
+ self.ctx.train_geo_tokens = train_tokens
305
+ self.ctx.test_geo_tokens = test_tokens
306
+ self.ctx.geo_token_cols = cols
307
+ self.ctx.geo_gnn_model = geo_gnn
308
+ _log(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
309
+
310
+ def save(self) -> None:
311
+ if self.model is None:
312
+ _log(f"[save] Warning: No model to save for {self.label}")
313
+ return
314
+ path = self.output.model_path(self._get_model_filename())
315
+ base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
316
+ if base_gnn is not None:
317
+ base_gnn = base_gnn.to("cpu")
318
+ state = None if base_gnn is None else base_gnn.state_dict()
319
+ payload = {
320
+ "best_params": self.best_params,
321
+ "state_dict": state,
322
+ "preprocess_artifacts": self._export_preprocess_artifacts(),
323
+ }
324
+ torch.save(payload, path)
325
+
326
+ def load(self) -> None:
327
+ path = self.output.model_path(self._get_model_filename())
328
+ if not os.path.exists(path):
329
+ _log(f"[load] Warning: Model file not found: {path}")
330
+ return
331
+ payload = torch_load(path, map_location='cpu', weights_only=False)
332
+ if not isinstance(payload, dict):
333
+ raise ValueError(f"Invalid GNN checkpoint: {path}")
334
+ params = payload.get("best_params") or {}
335
+ state_dict = payload.get("state_dict")
336
+ model = self._build_model(params)
337
+ if params:
338
+ model.set_params(dict(params))
339
+ base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
340
+ if base_gnn is not None and state_dict is not None:
341
+ # Use strict=True for better error detection, but handle missing keys gracefully
342
+ try:
343
+ base_gnn.load_state_dict(state_dict, strict=True)
344
+ except RuntimeError as e:
345
+ if "Missing key" in str(e) or "Unexpected key" in str(e):
346
+ _log(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
347
+ base_gnn.load_state_dict(state_dict, strict=False)
348
+ else:
349
+ raise
350
+ self.model = model
351
+ self.best_params = dict(params) if isinstance(params, dict) else None
352
+ self.ctx.gnn_best = self.model
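Beyond the import rewrites, this hunk makes two behavioural changes: per-fold GPU cleanup is now gated behind the config flags `gnn_cleanup_per_fold` and `gnn_cleanup_synchronize` (via the new `_maybe_cleanup_gpu` helper) instead of running unconditionally, and bare `print(...)` calls are routed through `_log`, which forwards to `log_print` from `ins_pricing.utils`. The call sites still pass print-style arguments (f-strings plus `flush=True`), so `log_print` presumably accepts `print`-like args and drops print-only keywords. A minimal sketch under that assumption, not the packaged implementation (which lives in `ins_pricing/utils/logging.py` and is not shown in this diff):

```python
import logging


def log_print(logger: logging.Logger, *args, level: int = logging.INFO, **kwargs) -> None:
    # Sketch only (assumed behaviour): emulate print()-style formatting,
    # then discard keywords that only make sense for print().
    sep = kwargs.pop("sep", " ")
    kwargs.pop("end", None)
    kwargs.pop("flush", None)
    logger.log(level, sep.join(str(a) for a in args))


# Usage mirroring the module-level helper added in the hunk above:
_logger = logging.getLogger("ins_pricing.trainer.gnn")


def _log(*args, **kwargs) -> None:
    log_print(_logger, *args, **kwargs)
```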