ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +52 -50
- ins_pricing/cli/BayesOpt_incremental.py +39 -105
- ins_pricing/cli/Explain_Run.py +31 -23
- ins_pricing/cli/Explain_entry.py +532 -579
- ins_pricing/cli/Pricing_Run.py +31 -23
- ins_pricing/cli/bayesopt_entry_runner.py +11 -9
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +375 -375
- ins_pricing/cli/utils/import_resolver.py +382 -365
- ins_pricing/cli/utils/notebook_utils.py +340 -340
- ins_pricing/cli/watchdog_run.py +209 -201
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +116 -83
- ins_pricing/utils/device.py +255 -255
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
- ins_pricing-0.5.0.dist-info/RECORD +131 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
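
The renames above amount to flattening `ins_pricing/modelling/core/bayesopt/` into `ins_pricing/modelling/bayesopt/` and promoting shared helpers (losses, constants, IO, metrics) into `ins_pricing/utils/`. As a rough orientation, the sketch below shows how import paths shift between 0.4.5 and 0.5.0; the 0.5.0 paths are taken from the trainer_gnn.py diff that follows, while the 0.4.5 paths are inferred from the old directory layout and may not match the exact symbols that were publicly importable.

```python
# Import-path migration sketch.
# 0.4.5 paths are inferred from the old layout (not confirmed by this diff);
# 0.5.0 paths appear verbatim in trainer_gnn.py below.

# ins-pricing 0.4.5 (modelling/core/bayesopt/...):
# from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import TrainerBase
# from ins_pricing.modelling.core.bayesopt.utils.losses import regression_loss

# ins-pricing 0.5.0 (modelling/bayesopt/... with shared helpers under utils/):
from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
from ins_pricing.utils.losses import regression_loss
from ins_pricing.utils import EPS, get_logger
```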

ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py

@@ -1,344 +1,344 @@
 from __future__ import annotations
 
 import os
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import optuna
 import torch
 from sklearn.metrics import log_loss
 
-from .trainer_base import TrainerBase
-from
-from
-from
+from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
+from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
+from ins_pricing.utils import EPS
+from ins_pricing.utils.losses import regression_loss
 from ins_pricing.utils import get_logger
 from ins_pricing.utils.torch_compat import torch_load
 
 _logger = get_logger("ins_pricing.trainer.gnn")
 
 class GNNTrainer(TrainerBase):
     def __init__(self, context: "BayesOptModel") -> None:
         super().__init__(context, 'GNN', 'GNN')
         self.model: Optional[GraphNeuralNetSklearn] = None
         self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
 
     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
         params = params or {}
         base_tw_power = self.ctx.default_tweedie_power()
         loss_name = getattr(self.ctx, "loss_name", "tweedie")
         tw_power = params.get("tw_power")
         if self.ctx.task_type == "regression":
             if loss_name == "tweedie":
                 tw_power = base_tw_power if tw_power is None else float(tw_power)
             elif loss_name in ("poisson", "gamma"):
                 tw_power = base_tw_power
         else:
             tw_power = None
         model = GraphNeuralNetSklearn(
             model_nme=f"{self.ctx.model_nme}_gnn",
             input_dim=len(self.ctx.var_nmes),
             hidden_dim=int(params.get("hidden_dim", 64)),
             num_layers=int(params.get("num_layers", 2)),
             k_neighbors=int(params.get("k_neighbors", 10)),
             dropout=float(params.get("dropout", 0.1)),
             learning_rate=float(params.get("learning_rate", 1e-3)),
             epochs=int(params.get("epochs", self.ctx.epochs)),
             patience=int(params.get("patience", 5)),
             task_type=self.ctx.task_type,
             tweedie_power=tw_power,
             weight_decay=float(params.get("weight_decay", 0.0)),
             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
             use_ddp=bool(self.ctx.config.use_gnn_ddp),
             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
             graph_cache_path=self.ctx.config.gnn_graph_cache,
             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
             knn_gpu_mem_overhead=float(
                 self.ctx.config.gnn_knn_gpu_mem_overhead),
             loss_name=loss_name,
         )
         return self._apply_dataloader_overrides(model)
 
     def cross_val(self, trial: optuna.trial.Trial) -> float:
         base_tw_power = self.ctx.default_tweedie_power()
         loss_name = getattr(self.ctx, "loss_name", "tweedie")
         metric_ctx: Dict[str, Any] = {}
 
         def data_provider():
             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
             assert data is not None, "Preprocessed training data is missing."
             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
 
         def model_builder(params: Dict[str, Any]):
             if loss_name == "tweedie":
                 tw_power = params.get("tw_power", base_tw_power)
             elif loss_name in ("poisson", "gamma"):
                 tw_power = base_tw_power
             else:
                 tw_power = None
             metric_ctx["tw_power"] = tw_power
             if tw_power is None:
                 params = dict(params)
                 params.pop("tw_power", None)
             return self._build_model(params)
 
         def preprocess_fn(X_train, X_val):
             X_train_s, X_val_s, _ = self._standardize_fold(
                 X_train, X_val, self.ctx.num_features)
             return X_train_s, X_val_s
 
         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
             model.fit(
                 X_train,
                 y_train,
                 w_train=w_train,
                 X_val=X_val,
                 y_val=y_val,
                 w_val=w_val,
                 trial=trial_obj,
             )
             return model.predict(X_val)
 
         def metric_fn(y_true, y_pred, weight):
             if self.ctx.task_type == 'classification':
                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
             return regression_loss(
                 y_true,
                 y_pred,
                 weight,
                 loss_name=loss_name,
                 tweedie_power=metric_ctx.get("tw_power", base_tw_power),
             )
 
         # Keep GNN BO lightweight: sample during CV, use full data for final training.
         X_cap = data_provider()[0]
         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
 
         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
         }
         if self.ctx.task_type == 'regression' and loss_name == 'tweedie':
             param_space["tw_power"] = lambda t: t.suggest_float(
                 'tw_power', 1.0, 2.0)
 
         return self.cross_val_generic(
             trial=trial,
             hyperparameter_space=param_space,
             data_provider=data_provider,
             model_builder=model_builder,
             metric_fn=metric_fn,
             sample_limit=sample_limit,
             preprocess_fn=preprocess_fn,
             fit_predict_fn=fit_predict,
             cleanup_fn=lambda m: getattr(
                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
         )
 
     def train(self) -> None:
         if not self.best_params:
             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
 
         data = self.ctx.train_oht_scl_data
         assert data is not None, "Preprocessed training data is missing."
         X_all = data[self.ctx.var_nmes]
         y_all = data[self.ctx.resp_nme]
         w_all = data[self.ctx.weight_nme]
 
         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
         refit_epochs = None
 
         split = self._resolve_train_val_indices(X_all)
         if split is not None:
             train_idx, val_idx = split
             X_train = X_all.iloc[train_idx]
             y_train = y_all.iloc[train_idx]
             w_train = w_all.iloc[train_idx]
             X_val = X_all.iloc[val_idx]
             y_val = y_all.iloc[val_idx]
             w_val = w_all.iloc[val_idx]
 
             if use_refit:
                 tmp_model = self._build_model(self.best_params)
                 tmp_model.fit(
                     X_train,
                     y_train,
                     w_train=w_train,
                     X_val=X_val,
                     y_val=y_val,
                     w_val=w_val,
                     trial=None,
                 )
                 refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
                 getattr(getattr(tmp_model, "gnn", None), "to",
                         lambda *_args, **_kwargs: None)("cpu")
                 self._clean_gpu()
             else:
                 self.model = self._build_model(self.best_params)
                 self.model.fit(
                     X_train,
                     y_train,
                     w_train=w_train,
                     X_val=X_val,
                     y_val=y_val,
                     w_val=w_val,
                     trial=None,
                 )
         else:
             use_refit = False
 
         if use_refit:
             self.model = self._build_model(self.best_params)
             if refit_epochs is not None:
                 self.model.epochs = int(refit_epochs)
             self.model.fit(
                 X_all,
                 y_all,
                 w_train=w_all,
                 X_val=None,
                 y_val=None,
                 w_val=None,
                 trial=None,
             )
         elif self.model is None:
             self.model = self._build_model(self.best_params)
             self.model.fit(
                 X_all,
                 y_all,
                 w_train=w_all,
                 X_val=None,
                 y_val=None,
                 w_val=None,
                 trial=None,
             )
         self.ctx.model_label.append(self.label)
         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
         self.ctx.gnn_best = self.model
 
         # If geo_feature_nmes is set, refresh geo tokens for FT input.
         if self.ctx.config.geo_feature_nmes:
             self.prepare_geo_tokens(force=True)
 
     def ensemble_predict(self, k: int) -> None:
         if not self.best_params:
             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
         data = self.ctx.train_oht_scl_data
         test_data = self.ctx.test_oht_scl_data
         if data is None or test_data is None:
             raise RuntimeError("Missing standardized data for GNN ensemble.")
         X_all = data[self.ctx.var_nmes]
         y_all = data[self.ctx.resp_nme]
         w_all = data[self.ctx.weight_nme]
         X_test = test_data[self.ctx.var_nmes]
 
         k = max(2, int(k))
         n_samples = len(X_all)
         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
         if split_iter is None:
             print(
                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
                 flush=True,
             )
             return
         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
 
         split_count = 0
         for train_idx, val_idx in split_iter:
             model = self._build_model(self.best_params)
             model.fit(
                 X_all.iloc[train_idx],
                 y_all.iloc[train_idx],
                 w_train=w_all.iloc[train_idx],
                 X_val=X_all.iloc[val_idx],
                 y_val=y_all.iloc[val_idx],
                 w_val=w_all.iloc[val_idx],
                 trial=None,
             )
             pred_train = model.predict(X_all)
             pred_test = model.predict(X_test)
             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
             getattr(getattr(model, "gnn", None), "to",
                     lambda *_args, **_kwargs: None)("cpu")
             self._clean_gpu()
             split_count += 1
 
         if split_count < 1:
             print(
                 f"[GNN Ensemble] no CV splits generated; skip ensemble.",
                 flush=True,
             )
             return
         preds_train = preds_train_sum / float(split_count)
         preds_test = preds_test_sum / float(split_count)
         self._cache_predictions("gnn", preds_train, preds_test)
 
     def prepare_geo_tokens(self, force: bool = False) -> None:
         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
         if not geo_cols:
             return
         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
             return
 
         result = self.ctx._build_geo_tokens()
         if result is None:
             return
         train_tokens, test_tokens, cols, geo_gnn = result
         self.ctx.train_geo_tokens = train_tokens
         self.ctx.test_geo_tokens = test_tokens
         self.ctx.geo_token_cols = cols
         self.ctx.geo_gnn_model = geo_gnn
         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
 
     def save(self) -> None:
         if self.model is None:
             print(f"[save] Warning: No model to save for {self.label}")
             return
         path = self.output.model_path(self._get_model_filename())
         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
         if base_gnn is not None:
             base_gnn = base_gnn.to("cpu")
         state = None if base_gnn is None else base_gnn.state_dict()
         payload = {
             "best_params": self.best_params,
             "state_dict": state,
             "preprocess_artifacts": self._export_preprocess_artifacts(),
         }
         torch.save(payload, path)
 
     def load(self) -> None:
         path = self.output.model_path(self._get_model_filename())
         if not os.path.exists(path):
             print(f"[load] Warning: Model file not found: {path}")
             return
         payload = torch_load(path, map_location='cpu', weights_only=False)
         if not isinstance(payload, dict):
             raise ValueError(f"Invalid GNN checkpoint: {path}")
         params = payload.get("best_params") or {}
         state_dict = payload.get("state_dict")
         model = self._build_model(params)
         if params:
             model.set_params(dict(params))
         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
         if base_gnn is not None and state_dict is not None:
             # Use strict=True for better error detection, but handle missing keys gracefully
             try:
                 base_gnn.load_state_dict(state_dict, strict=True)
             except RuntimeError as e:
                 if "Missing key" in str(e) or "Unexpected key" in str(e):
                     print(f"[GNN load] Warning: State dict mismatch, loading with strict=False: {e}")
                     base_gnn.load_state_dict(state_dict, strict=False)
                 else:
                     raise
         self.model = model
         self.best_params = dict(params) if isinstance(params, dict) else None
         self.ctx.gnn_best = self.model
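
For reference, the `save()`/`load()` pair above persists the GNN as a plain dict checkpoint (best hyperparameters, `state_dict`, preprocessing artifacts) rather than a pickled estimator. Below is a minimal sketch of inspecting such a checkpoint outside the trainer; the path is a placeholder, and plain `torch.load` stands in for the package's `torch_load` compatibility wrapper.

```python
import torch

# Placeholder path; the real filename comes from the trainer's output helper.
ckpt_path = "outputs/models/example_gnn.pt"

# weights_only=False because the payload is a dict carrying best_params and
# preprocessing artifacts in addition to tensor state.
payload = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(sorted(payload))             # ['best_params', 'preprocess_artifacts', 'state_dict']
print(payload["best_params"])      # Optuna-selected hyperparameters used to rebuild the model
state = payload["state_dict"]      # underlying nn.Module state dict (may be None)
if state is not None:
    for name, tensor in list(state.items())[:5]:
        print(name, tuple(tensor.shape))
```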