ins-pricing 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- ins_pricing/README.md +9 -6
- ins_pricing/__init__.py +3 -11
- ins_pricing/cli/BayesOpt_entry.py +24 -0
- ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
- ins_pricing/cli/Explain_Run.py +25 -0
- ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
- ins_pricing/cli/Pricing_Run.py +25 -0
- ins_pricing/cli/__init__.py +1 -0
- ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
- ins_pricing/cli/utils/__init__.py +1 -0
- ins_pricing/cli/utils/cli_common.py +320 -0
- ins_pricing/cli/utils/cli_config.py +375 -0
- ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
- {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
- ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
- ins_pricing/docs/modelling/README.md +34 -0
- ins_pricing/modelling/__init__.py +57 -6
- ins_pricing/modelling/core/__init__.py +1 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
- ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
- ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
- ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
- ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
- ins_pricing/modelling/core/evaluation.py +115 -0
- ins_pricing/production/__init__.py +4 -0
- ins_pricing/production/preprocess.py +71 -0
- ins_pricing/setup.py +10 -5
- {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
- ins_pricing-0.2.0.dist-info/RECORD +125 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
- ins_pricing/modelling/BayesOpt_entry.py +0 -633
- ins_pricing/modelling/Explain_Run.py +0 -36
- ins_pricing/modelling/Pricing_Run.py +0 -36
- ins_pricing/modelling/README.md +0 -33
- ins_pricing/modelling/bayesopt/models.py +0 -2196
- ins_pricing/modelling/bayesopt/trainers.py +0 -2446
- ins_pricing/modelling/cli_common.py +0 -136
- ins_pricing/modelling/tests/test_plotting.py +0 -63
- ins_pricing/modelling/watchdog_run.py +0 -211
- ins_pricing-0.1.11.dist-info/RECORD +0 -169
- ins_pricing_gemini/__init__.py +0 -23
- ins_pricing_gemini/governance/__init__.py +0 -20
- ins_pricing_gemini/governance/approval.py +0 -93
- ins_pricing_gemini/governance/audit.py +0 -37
- ins_pricing_gemini/governance/registry.py +0 -99
- ins_pricing_gemini/governance/release.py +0 -159
- ins_pricing_gemini/modelling/Explain_Run.py +0 -36
- ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
- ins_pricing_gemini/modelling/__init__.py +0 -151
- ins_pricing_gemini/modelling/cli_common.py +0 -141
- ins_pricing_gemini/modelling/config.py +0 -249
- ins_pricing_gemini/modelling/config_preprocess.py +0 -254
- ins_pricing_gemini/modelling/core.py +0 -741
- ins_pricing_gemini/modelling/data_container.py +0 -42
- ins_pricing_gemini/modelling/explain/__init__.py +0 -55
- ins_pricing_gemini/modelling/explain/gradients.py +0 -334
- ins_pricing_gemini/modelling/explain/metrics.py +0 -176
- ins_pricing_gemini/modelling/explain/permutation.py +0 -155
- ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
- ins_pricing_gemini/modelling/features.py +0 -215
- ins_pricing_gemini/modelling/model_manager.py +0 -148
- ins_pricing_gemini/modelling/model_plotting.py +0 -463
- ins_pricing_gemini/modelling/models.py +0 -2203
- ins_pricing_gemini/modelling/notebook_utils.py +0 -294
- ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
- ins_pricing_gemini/modelling/plotting/common.py +0 -63
- ins_pricing_gemini/modelling/plotting/curves.py +0 -572
- ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
- ins_pricing_gemini/modelling/plotting/geo.py +0 -362
- ins_pricing_gemini/modelling/plotting/importance.py +0 -121
- ins_pricing_gemini/modelling/run_logging.py +0 -133
- ins_pricing_gemini/modelling/tests/conftest.py +0 -8
- ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
- ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
- ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
- ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
- ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
- ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
- ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
- ins_pricing_gemini/modelling/trainers.py +0 -2447
- ins_pricing_gemini/modelling/utils.py +0 -1020
- ins_pricing_gemini/pricing/__init__.py +0 -27
- ins_pricing_gemini/pricing/calibration.py +0 -39
- ins_pricing_gemini/pricing/data_quality.py +0 -117
- ins_pricing_gemini/pricing/exposure.py +0 -85
- ins_pricing_gemini/pricing/factors.py +0 -91
- ins_pricing_gemini/pricing/monitoring.py +0 -99
- ins_pricing_gemini/pricing/rate_table.py +0 -78
- ins_pricing_gemini/production/__init__.py +0 -21
- ins_pricing_gemini/production/drift.py +0 -30
- ins_pricing_gemini/production/monitoring.py +0 -143
- ins_pricing_gemini/production/scoring.py +0 -40
- ins_pricing_gemini/reporting/__init__.py +0 -11
- ins_pricing_gemini/reporting/report_builder.py +0 -72
- ins_pricing_gemini/reporting/scheduler.py +0 -45
- ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
- ins_pricing_gemini/scripts/Explain_entry.py +0 -545
- ins_pricing_gemini/scripts/__init__.py +0 -1
- ins_pricing_gemini/scripts/train.py +0 -568
- ins_pricing_gemini/setup.py +0 -55
- ins_pricing_gemini/smoke_test.py +0 -28
- /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
- /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
- /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
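Taken together, the listing is a package-level restructuring rather than a set of local edits: the entire `ins_pricing_gemini` tree is removed (and with it one entry in `top_level.txt`), CLI entry points and their helpers move under `ins_pricing/cli/`, tests move to `ins_pricing/tests/`, and the monolithic `bayesopt/models.py` (2196 lines) and `bayesopt/trainers.py` (2446 lines) are split into per-model modules under `modelling/core/bayesopt/`. Downstream imports therefore change. A hedged sketch of the migration, assuming the trainer classes keep the names they have in the new modules (shown in the hunks below) and were importable from the old monolithic module:

```python
# Hypothetical 0.1.11 -> 0.2.0 compatibility shim. The new paths follow the
# renames in the file list above; whether ins_pricing.modelling re-exports
# them (its __init__.py changed by +57/-6 lines) is not shown in this diff.
try:
    # 0.2.0 layout: trainers split into per-model modules
    from ins_pricing.modelling.core.bayesopt.trainers.trainer_glm import GLMTrainer
    from ins_pricing.modelling.core.bayesopt.trainers.trainer_gnn import GNNTrainer
except ImportError:
    # 0.1.11 layout: one monolithic trainers module (assumed class names)
    from ins_pricing.modelling.bayesopt.trainers import GLMTrainer, GNNTrainer
```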
ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py (new file)

@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import optuna
+import pandas as pd
+import statsmodels.api as sm
+from sklearn.metrics import log_loss, mean_tweedie_deviance
+
+from .trainer_base import TrainerBase
+from ..utils import EPS
+
+class GLMTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        super().__init__(context, 'GLM', 'GLM')
+        self.model = None
+
+    def _select_family(self, tweedie_power: Optional[float] = None):
+        if self.ctx.task_type == 'classification':
+            return sm.families.Binomial()
+        if self.ctx.obj == 'count:poisson':
+            return sm.families.Poisson()
+        if self.ctx.obj == 'reg:gamma':
+            return sm.families.Gamma()
+        power = tweedie_power if tweedie_power is not None else 1.5
+        return sm.families.Tweedie(var_power=power, link=sm.families.links.log())
+
+    def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
+        # Add intercept to the statsmodels design matrix.
+        X = data[self.ctx.var_nmes]
+        return sm.add_constant(X, has_constant='add')
+
+    def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
+        if isinstance(family, sm.families.Poisson):
+            return 1.0
+        if isinstance(family, sm.families.Gamma):
+            return 2.0
+        if isinstance(family, sm.families.Tweedie):
+            return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
+        return 1.5
+
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        param_space = {
+            "alpha": lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True),
+            "l1_ratio": lambda t: t.suggest_float('l1_ratio', 0.0, 1.0)
+        }
+        if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+            param_space["tweedie_power"] = lambda t: t.suggest_float(
+                'tweedie_power', 1.0, 2.0)
+
+        def data_provider():
+            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+            assert data is not None, "Preprocessed training data is missing."
+            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        def preprocess_fn(X_train, X_val):
+            X_train_s, X_val_s, _ = self._standardize_fold(
+                X_train, X_val, self.ctx.num_features)
+            return self._prepare_design(X_train_s), self._prepare_design(X_val_s)
+
+        metric_ctx: Dict[str, Any] = {}
+
+        def model_builder(params):
+            family = self._select_family(params.get("tweedie_power"))
+            metric_ctx["family"] = family
+            metric_ctx["tweedie_power"] = params.get("tweedie_power")
+            return {
+                "family": family,
+                "alpha": params["alpha"],
+                "l1_ratio": params["l1_ratio"],
+                "tweedie_power": params.get("tweedie_power")
+            }
+
+        def fit_predict(model_cfg, X_train, y_train, w_train, X_val, y_val, w_val, _trial):
+            glm = sm.GLM(y_train, X_train,
+                         family=model_cfg["family"],
+                         freq_weights=w_train)
+            result = glm.fit_regularized(
+                alpha=model_cfg["alpha"],
+                L1_wt=model_cfg["l1_ratio"],
+                maxiter=200
+            )
+            return result.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'classification':
+                y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+            y_pred_safe = np.maximum(y_pred, EPS)
+            return mean_tweedie_deviance(
+                y_true,
+                y_pred_safe,
+                sample_weight=weight,
+                power=self._metric_power(
+                    metric_ctx.get("family"), metric_ctx.get("tweedie_power"))
+            )
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space=param_space,
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            preprocess_fn=preprocess_fn,
+            fit_predict_fn=fit_predict
+        )
+
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError("Run tune() first to obtain best GLM parameters.")
+        tweedie_power = self.best_params.get('tweedie_power')
+        family = self._select_family(tweedie_power)
+
+        X_train = self._prepare_design(self.ctx.train_oht_scl_data)
+        y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
+        w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
+
+        glm = sm.GLM(y_train, X_train, family=family,
+                     freq_weights=w_train)
+        self.model = glm.fit_regularized(
+            alpha=self.best_params['alpha'],
+            L1_wt=self.best_params['l1_ratio'],
+            maxiter=300
+        )
+
+        self.ctx.glm_best = self.model
+        self.ctx.model_label += [self.label]
+        self._predict_and_cache(
+            self.model,
+            'glm',
+            design_fn=lambda train: self._prepare_design(
+                self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
+            )
+        )
+
+    def ensemble_predict(self, k: int) -> None:
+        if not self.best_params:
+            raise RuntimeError("Run tune() first to obtain best GLM parameters.")
+        k = max(2, int(k))
+        data = self.ctx.train_oht_scl_data
+        if data is None:
+            raise RuntimeError("Missing standardized data for GLM ensemble.")
+        X_all = data[self.ctx.var_nmes]
+        y_all = data[self.ctx.resp_nme]
+        w_all = data[self.ctx.weight_nme]
+        X_test = self.ctx.test_oht_scl_data
+        if X_test is None:
+            raise RuntimeError("Missing standardized test data for GLM ensemble.")
+
+        n_samples = len(X_all)
+        X_all_design = self._prepare_design(data)
+        X_test_design = self._prepare_design(X_test)
+        tweedie_power = self.best_params.get('tweedie_power')
+        family = self._select_family(tweedie_power)
+
+        split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+        if split_iter is None:
+            print(
+                f"[GLM Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                flush=True,
+            )
+            return
+        preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+        preds_test_sum = np.zeros(len(X_test_design), dtype=np.float64)
+
+        split_count = 0
+        for train_idx, _val_idx in split_iter:
+            X_train = X_all_design.iloc[train_idx]
+            y_train = y_all.iloc[train_idx]
+            w_train = w_all.iloc[train_idx]
+
+            glm = sm.GLM(y_train, X_train, family=family, freq_weights=w_train)
+            result = glm.fit_regularized(
+                alpha=self.best_params['alpha'],
+                L1_wt=self.best_params['l1_ratio'],
+                maxiter=300
+            )
+            pred_train = result.predict(X_all_design)
+            pred_test = result.predict(X_test_design)
+            preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+            preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+            split_count += 1
+
+        if split_count < 1:
+            print(
+                "[GLM Ensemble] no CV splits generated; skip ensemble.",
+                flush=True,
+            )
+            return
+        preds_train = preds_train_sum / float(split_count)
+        preds_test = preds_test_sum / float(split_count)
+        self._cache_predictions("glm", preds_train, preds_test)
+
+
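The hunk above is the new standalone GLM trainer. Its `fit_predict` delegates to statsmodels' elastic-net path: build an `sm.GLM` with a weighted exponential-family likelihood, then call `fit_regularized(alpha=..., L1_wt=...)`. A minimal self-contained sketch of that call pattern, with synthetic data and invented column names (not the package's own API):

```python
# Minimal sketch of the elastic-net GLM fit used by GLMTrainer.fit_predict
# above; data and names here are synthetic stand-ins, not package objects.
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import mean_tweedie_deviance

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(500, 3)), columns=["x1", "x2", "x3"])
mu = np.exp(0.3 * X["x1"] - 0.2 * X["x2"])
y = rng.poisson(mu) * rng.gamma(2.0, 0.5, size=500)  # zeros plus skewed mass
w = np.ones(len(y))                                  # unit exposure weights

design = sm.add_constant(X, has_constant="add")      # intercept, as in _prepare_design
family = sm.families.Tweedie(var_power=1.5)          # log link by default
result = sm.GLM(y, design, family=family, freq_weights=w).fit_regularized(
    alpha=1e-3,   # overall penalty strength
    L1_wt=0.5,    # elastic-net mix: 0 = ridge, 1 = lasso
    maxiter=200,
)
pred = np.maximum(result.predict(design), 1e-8)      # same clipping idea as EPS
print(mean_tweedie_deviance(y, pred, sample_weight=w, power=1.5))
```

Tuning `tweedie_power` jointly with the penalty, as `cross_val` does, matters because the deviance used for scoring depends on the same power that defines the likelihood.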
ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py (new file)

@@ -0,0 +1,312 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import optuna
+import torch
+from sklearn.metrics import log_loss, mean_tweedie_deviance
+
+from .trainer_base import TrainerBase
+from ..models import GraphNeuralNetSklearn
+from ..utils import EPS
+
+class GNNTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        super().__init__(context, 'GNN', 'GNN')
+        self.model: Optional[GraphNeuralNetSklearn] = None
+        self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
+
+    def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
+        params = params or {}
+        base_tw_power = self.ctx.default_tweedie_power()
+        model = GraphNeuralNetSklearn(
+            model_nme=f"{self.ctx.model_nme}_gnn",
+            input_dim=len(self.ctx.var_nmes),
+            hidden_dim=int(params.get("hidden_dim", 64)),
+            num_layers=int(params.get("num_layers", 2)),
+            k_neighbors=int(params.get("k_neighbors", 10)),
+            dropout=float(params.get("dropout", 0.1)),
+            learning_rate=float(params.get("learning_rate", 1e-3)),
+            epochs=int(params.get("epochs", self.ctx.epochs)),
+            patience=int(params.get("patience", 5)),
+            task_type=self.ctx.task_type,
+            tweedie_power=float(params.get("tw_power", base_tw_power or 1.5)),
+            weight_decay=float(params.get("weight_decay", 0.0)),
+            use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
+            use_ddp=bool(self.ctx.config.use_gnn_ddp),
+            use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
+            approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
+            graph_cache_path=self.ctx.config.gnn_graph_cache,
+            max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
+            knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
+            knn_gpu_mem_overhead=float(
+                self.ctx.config.gnn_knn_gpu_mem_overhead),
+        )
+        return model
+
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        base_tw_power = self.ctx.default_tweedie_power()
+        metric_ctx: Dict[str, Any] = {}
+
+        def data_provider():
+            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+            assert data is not None, "Preprocessed training data is missing."
+            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        def model_builder(params: Dict[str, Any]):
+            tw_power = params.get("tw_power", base_tw_power)
+            metric_ctx["tw_power"] = tw_power
+            return self._build_model(params)
+
+        def preprocess_fn(X_train, X_val):
+            X_train_s, X_val_s, _ = self._standardize_fold(
+                X_train, X_val, self.ctx.num_features)
+            return X_train_s, X_val_s
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            model.fit(
+                X_train,
+                y_train,
+                w_train=w_train,
+                X_val=X_val,
+                y_val=y_val,
+                w_val=w_val,
+                trial=trial_obj,
+            )
+            return model.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'classification':
+                y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+            y_pred_safe = np.maximum(y_pred, EPS)
+            power = metric_ctx.get("tw_power", base_tw_power or 1.5)
+            return mean_tweedie_deviance(
+                y_true,
+                y_pred_safe,
+                sample_weight=weight,
+                power=power,
+            )
+
+        # Keep GNN BO lightweight: sample during CV, use full data for final training.
+        X_cap = data_provider()[0]
+        sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
+
+        param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+            "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
+            "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
+            "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
+            "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
+            "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
+            "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
+        }
+        if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+            param_space["tw_power"] = lambda t: t.suggest_float(
+                'tw_power', 1.0, 2.0)
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space=param_space,
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=sample_limit,
+            preprocess_fn=preprocess_fn,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+
+        data = self.ctx.train_oht_scl_data
+        assert data is not None, "Preprocessed training data is missing."
+        X_all = data[self.ctx.var_nmes]
+        y_all = data[self.ctx.resp_nme]
+        w_all = data[self.ctx.weight_nme]
+
+        use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+        refit_epochs = None
+
+        split = self._resolve_train_val_indices(X_all)
+        if split is not None:
+            train_idx, val_idx = split
+            X_train = X_all.iloc[train_idx]
+            y_train = y_all.iloc[train_idx]
+            w_train = w_all.iloc[train_idx]
+            X_val = X_all.iloc[val_idx]
+            y_val = y_all.iloc[val_idx]
+            w_val = w_all.iloc[val_idx]
+
+            if use_refit:
+                tmp_model = self._build_model(self.best_params)
+                tmp_model.fit(
+                    X_train,
+                    y_train,
+                    w_train=w_train,
+                    X_val=X_val,
+                    y_val=y_val,
+                    w_val=w_val,
+                    trial=None,
+                )
+                refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
+                getattr(getattr(tmp_model, "gnn", None), "to",
+                        lambda *_args, **_kwargs: None)("cpu")
+                self._clean_gpu()
+            else:
+                self.model = self._build_model(self.best_params)
+                self.model.fit(
+                    X_train,
+                    y_train,
+                    w_train=w_train,
+                    X_val=X_val,
+                    y_val=y_val,
+                    w_val=w_val,
+                    trial=None,
+                )
+        else:
+            use_refit = False
+
+        if use_refit:
+            self.model = self._build_model(self.best_params)
+            if refit_epochs is not None:
+                self.model.epochs = int(refit_epochs)
+            self.model.fit(
+                X_all,
+                y_all,
+                w_train=w_all,
+                X_val=None,
+                y_val=None,
+                w_val=None,
+                trial=None,
+            )
+        elif self.model is None:
+            self.model = self._build_model(self.best_params)
+            self.model.fit(
+                X_all,
+                y_all,
+                w_train=w_all,
+                X_val=None,
+                y_val=None,
+                w_val=None,
+                trial=None,
+            )
+        self.ctx.model_label.append(self.label)
+        self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
+        self.ctx.gnn_best = self.model
+
+        # If geo_feature_nmes is set, refresh geo tokens for FT input.
+        if self.ctx.config.geo_feature_nmes:
+            self.prepare_geo_tokens(force=True)
+
+    def ensemble_predict(self, k: int) -> None:
+        if not self.best_params:
+            raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+        data = self.ctx.train_oht_scl_data
+        test_data = self.ctx.test_oht_scl_data
+        if data is None or test_data is None:
+            raise RuntimeError("Missing standardized data for GNN ensemble.")
+        X_all = data[self.ctx.var_nmes]
+        y_all = data[self.ctx.resp_nme]
+        w_all = data[self.ctx.weight_nme]
+        X_test = test_data[self.ctx.var_nmes]
+
+        k = max(2, int(k))
+        n_samples = len(X_all)
+        split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+        if split_iter is None:
+            print(
+                f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                flush=True,
+            )
+            return
+        preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+        preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+        split_count = 0
+        for train_idx, val_idx in split_iter:
+            model = self._build_model(self.best_params)
+            model.fit(
+                X_all.iloc[train_idx],
+                y_all.iloc[train_idx],
+                w_train=w_all.iloc[train_idx],
+                X_val=X_all.iloc[val_idx],
+                y_val=y_all.iloc[val_idx],
+                w_val=w_all.iloc[val_idx],
+                trial=None,
+            )
+            pred_train = model.predict(X_all)
+            pred_test = model.predict(X_test)
+            preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+            preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+            getattr(getattr(model, "gnn", None), "to",
+                    lambda *_args, **_kwargs: None)("cpu")
+            self._clean_gpu()
+            split_count += 1
+
+        if split_count < 1:
+            print(
+                "[GNN Ensemble] no CV splits generated; skip ensemble.",
+                flush=True,
+            )
+            return
+        preds_train = preds_train_sum / float(split_count)
+        preds_test = preds_test_sum / float(split_count)
+        self._cache_predictions("gnn", preds_train, preds_test)
+
+    def prepare_geo_tokens(self, force: bool = False) -> None:
+        """Train/update the GNN encoder for geo tokens and inject them into FT input."""
+        geo_cols = list(self.ctx.config.geo_feature_nmes or [])
+        if not geo_cols:
+            return
+        if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
+            return
+
+        result = self.ctx._build_geo_tokens()
+        if result is None:
+            return
+        train_tokens, test_tokens, cols, geo_gnn = result
+        self.ctx.train_geo_tokens = train_tokens
+        self.ctx.test_geo_tokens = test_tokens
+        self.ctx.geo_token_cols = cols
+        self.ctx.geo_gnn_model = geo_gnn
+        print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
+
+    def save(self) -> None:
+        if self.model is None:
+            print(f"[save] Warning: No model to save for {self.label}")
+            return
+        path = self.output.model_path(self._get_model_filename())
+        base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
+        state = None if base_gnn is None else base_gnn.state_dict()
+        payload = {
+            "best_params": self.best_params,
+            "state_dict": state,
+        }
+        torch.save(payload, path)
+
+    def load(self) -> None:
+        path = self.output.model_path(self._get_model_filename())
+        if not os.path.exists(path):
+            print(f"[load] Warning: Model file not found: {path}")
+            return
+        payload = torch.load(path, map_location='cpu')
+        if not isinstance(payload, dict):
+            raise ValueError(f"Invalid GNN checkpoint: {path}")
+        params = payload.get("best_params") or {}
+        state_dict = payload.get("state_dict")
+        model = self._build_model(params)
+        if params:
+            model.set_params(dict(params))
+        base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
+        if base_gnn is not None and state_dict is not None:
+            base_gnn.load_state_dict(state_dict, strict=False)
+        self.model = model
+        self.best_params = dict(params) if isinstance(params, dict) else None
+        self.ctx.gnn_best = self.model
+
+