ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import pandas as pd
+ import statsmodels.api as sm
+ from sklearn.metrics import log_loss, mean_tweedie_deviance
+
+ from .trainer_base import TrainerBase
+ from ..utils import EPS
+
+ class GLMTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'GLM', 'GLM')
+         self.model = None
+
+     def _select_family(self, tweedie_power: Optional[float] = None):
+         if self.ctx.task_type == 'classification':
+             return sm.families.Binomial()
+         if self.ctx.obj == 'count:poisson':
+             return sm.families.Poisson()
+         if self.ctx.obj == 'reg:gamma':
+             return sm.families.Gamma()
+         power = tweedie_power if tweedie_power is not None else 1.5
+         return sm.families.Tweedie(var_power=power, link=sm.families.links.log())
+
+     def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
+         # Add intercept to the statsmodels design matrix.
+         X = data[self.ctx.var_nmes]
+         return sm.add_constant(X, has_constant='add')
+
+     def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
+         if isinstance(family, sm.families.Poisson):
+             return 1.0
+         if isinstance(family, sm.families.Gamma):
+             return 2.0
+         if isinstance(family, sm.families.Tweedie):
+             return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
+         return 1.5
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         param_space = {
+             "alpha": lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True),
+             "l1_ratio": lambda t: t.suggest_float('l1_ratio', 0.0, 1.0)
+         }
+         if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+             param_space["tweedie_power"] = lambda t: t.suggest_float(
+                 'tweedie_power', 1.0, 2.0)
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return self._prepare_design(X_train_s), self._prepare_design(X_val_s)
+
+         metric_ctx: Dict[str, Any] = {}
+
+         def model_builder(params):
+             family = self._select_family(params.get("tweedie_power"))
+             metric_ctx["family"] = family
+             metric_ctx["tweedie_power"] = params.get("tweedie_power")
+             return {
+                 "family": family,
+                 "alpha": params["alpha"],
+                 "l1_ratio": params["l1_ratio"],
+                 "tweedie_power": params.get("tweedie_power")
+             }
+
+         def fit_predict(model_cfg, X_train, y_train, w_train, X_val, y_val, w_val, _trial):
+             glm = sm.GLM(y_train, X_train,
+                          family=model_cfg["family"],
+                          freq_weights=w_train)
+             result = glm.fit_regularized(
+                 alpha=model_cfg["alpha"],
+                 L1_wt=model_cfg["l1_ratio"],
+                 maxiter=200
+             )
+             return result.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'classification':
+                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+             y_pred_safe = np.maximum(y_pred, EPS)
+             return mean_tweedie_deviance(
+                 y_true,
+                 y_pred_safe,
+                 sample_weight=weight,
+                 power=self._metric_power(
+                     metric_ctx.get("family"), metric_ctx.get("tweedie_power"))
+             )
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space=param_space,
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict
+         )
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GLM parameters.")
+         tweedie_power = self.best_params.get('tweedie_power')
+         family = self._select_family(tweedie_power)
+
+         X_train = self._prepare_design(self.ctx.train_oht_scl_data)
+         y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
+         w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
+
+         glm = sm.GLM(y_train, X_train, family=family,
+                      freq_weights=w_train)
+         self.model = glm.fit_regularized(
+             alpha=self.best_params['alpha'],
+             L1_wt=self.best_params['l1_ratio'],
+             maxiter=300
+         )
+
+         self.ctx.glm_best = self.model
+         self.ctx.model_label += [self.label]
+         self._predict_and_cache(
+             self.model,
+             'glm',
+             design_fn=lambda train: self._prepare_design(
+                 self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
+             )
+         )
+
+     def ensemble_predict(self, k: int) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GLM parameters.")
+         k = max(2, int(k))
+         data = self.ctx.train_oht_scl_data
+         if data is None:
+             raise RuntimeError("Missing standardized data for GLM ensemble.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+         X_test = self.ctx.test_oht_scl_data
+         if X_test is None:
+             raise RuntimeError("Missing standardized test data for GLM ensemble.")
+
+         n_samples = len(X_all)
+         X_all_design = self._prepare_design(data)
+         X_test_design = self._prepare_design(X_test)
+         tweedie_power = self.best_params.get('tweedie_power')
+         family = self._select_family(tweedie_power)
+
+         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+         if split_iter is None:
+             print(
+                 f"[GLM Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+         preds_test_sum = np.zeros(len(X_test_design), dtype=np.float64)
+
+         split_count = 0
+         for train_idx, _val_idx in split_iter:
+             X_train = X_all_design.iloc[train_idx]
+             y_train = y_all.iloc[train_idx]
+             w_train = w_all.iloc[train_idx]
+
+             glm = sm.GLM(y_train, X_train, family=family, freq_weights=w_train)
+             result = glm.fit_regularized(
+                 alpha=self.best_params['alpha'],
+                 L1_wt=self.best_params['l1_ratio'],
+                 maxiter=300
+             )
+             pred_train = result.predict(X_all_design)
+             pred_test = result.predict(X_test_design)
+             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+             split_count += 1
+
+         if split_count < 1:
+             print(
+                 "[GLM Ensemble] no CV splits generated; skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train = preds_train_sum / float(split_count)
+         preds_test = preds_test_sum / float(split_count)
+         self._cache_predictions("glm", preds_train, preds_test)
+
+
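For orientation, a minimal usage sketch of the lifecycle GLMTrainer exposes above. It assumes `ctx` is an already-configured BayesOptModel (defined elsewhere in this package) and that tuning is driven through an Optuna study, as `cross_val` suggests; everything outside this file is an assumption, not part of the diff:

    import optuna

    trainer = GLMTrainer(ctx)  # ctx: assumed BayesOptModel instance, not shown in this diff
    study = optuna.create_study(direction="minimize")
    study.optimize(trainer.cross_val, n_trials=50)  # each trial returns weighted CV deviance / log-loss
    trainer.best_params = study.best_params         # train() raises RuntimeError without this
    trainer.train()                                 # refit on the full standardized training set
    trainer.ensemble_predict(k=5)                   # optional: average k CV refits into cached predictions

GNNTrainer below follows the same cross_val/train/ensemble_predict contract inherited from TrainerBase.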
ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py
@@ -0,0 +1,312 @@
+ from __future__ import annotations
+
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import optuna
+ import torch
+ from sklearn.metrics import log_loss, mean_tweedie_deviance
+
+ from .trainer_base import TrainerBase
+ from ..models import GraphNeuralNetSklearn
+ from ..utils import EPS
+
+ class GNNTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'GNN', 'GNN')
+         self.model: Optional[GraphNeuralNetSklearn] = None
+         self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
+
+     def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
+         params = params or {}
+         base_tw_power = self.ctx.default_tweedie_power()
+         model = GraphNeuralNetSklearn(
+             model_nme=f"{self.ctx.model_nme}_gnn",
+             input_dim=len(self.ctx.var_nmes),
+             hidden_dim=int(params.get("hidden_dim", 64)),
+             num_layers=int(params.get("num_layers", 2)),
+             k_neighbors=int(params.get("k_neighbors", 10)),
+             dropout=float(params.get("dropout", 0.1)),
+             learning_rate=float(params.get("learning_rate", 1e-3)),
+             epochs=int(params.get("epochs", self.ctx.epochs)),
+             patience=int(params.get("patience", 5)),
+             task_type=self.ctx.task_type,
+             tweedie_power=float(params.get("tw_power", base_tw_power or 1.5)),
+             weight_decay=float(params.get("weight_decay", 0.0)),
+             use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
+             use_ddp=bool(self.ctx.config.use_gnn_ddp),
+             use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
+             approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
+             graph_cache_path=self.ctx.config.gnn_graph_cache,
+             max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
+             knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
+             knn_gpu_mem_overhead=float(
+                 self.ctx.config.gnn_knn_gpu_mem_overhead),
+         )
+         return model
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         base_tw_power = self.ctx.default_tweedie_power()
+         metric_ctx: Dict[str, Any] = {}
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         def model_builder(params: Dict[str, Any]):
+             tw_power = params.get("tw_power", base_tw_power)
+             metric_ctx["tw_power"] = tw_power
+             return self._build_model(params)
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return X_train_s, X_val_s
+
+         def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+             model.fit(
+                 X_train,
+                 y_train,
+                 w_train=w_train,
+                 X_val=X_val,
+                 y_val=y_val,
+                 w_val=w_val,
+                 trial=trial_obj,
+             )
+             return model.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'classification':
+                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+             y_pred_safe = np.maximum(y_pred, EPS)
+             power = metric_ctx.get("tw_power", base_tw_power or 1.5)
+             return mean_tweedie_deviance(
+                 y_true,
+                 y_pred_safe,
+                 sample_weight=weight,
+                 power=power,
+             )
+
+         # Keep GNN BO lightweight: sample during CV, use full data for final training.
+         X_cap = data_provider()[0]
+         sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
+
+         param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+             "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
+             "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
+             "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
+             "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
+             "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
+             "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
+         }
+         if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+             param_space["tw_power"] = lambda t: t.suggest_float(
+                 'tw_power', 1.0, 2.0)
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space=param_space,
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             sample_limit=sample_limit,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict,
+             cleanup_fn=lambda m: getattr(
+                 getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
+         )
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+
+         data = self.ctx.train_oht_scl_data
+         assert data is not None, "Preprocessed training data is missing."
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+
+         use_refit = bool(getattr(self.ctx.config, "final_refit", True))
+         refit_epochs = None
+
+         split = self._resolve_train_val_indices(X_all)
+         if split is not None:
+             train_idx, val_idx = split
+             X_train = X_all.iloc[train_idx]
+             y_train = y_all.iloc[train_idx]
+             w_train = w_all.iloc[train_idx]
+             X_val = X_all.iloc[val_idx]
+             y_val = y_all.iloc[val_idx]
+             w_val = w_all.iloc[val_idx]
+
+             if use_refit:
+                 tmp_model = self._build_model(self.best_params)
+                 tmp_model.fit(
+                     X_train,
+                     y_train,
+                     w_train=w_train,
+                     X_val=X_val,
+                     y_val=y_val,
+                     w_val=w_val,
+                     trial=None,
+                 )
+                 refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
+                 getattr(getattr(tmp_model, "gnn", None), "to",
+                         lambda *_args, **_kwargs: None)("cpu")
+                 self._clean_gpu()
+             else:
+                 self.model = self._build_model(self.best_params)
+                 self.model.fit(
+                     X_train,
+                     y_train,
+                     w_train=w_train,
+                     X_val=X_val,
+                     y_val=y_val,
+                     w_val=w_val,
+                     trial=None,
+                 )
+         else:
+             use_refit = False
+
+         if use_refit:
+             self.model = self._build_model(self.best_params)
+             if refit_epochs is not None:
+                 self.model.epochs = int(refit_epochs)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         elif self.model is None:
+             self.model = self._build_model(self.best_params)
+             self.model.fit(
+                 X_all,
+                 y_all,
+                 w_train=w_all,
+                 X_val=None,
+                 y_val=None,
+                 w_val=None,
+                 trial=None,
+             )
+         self.ctx.model_label.append(self.label)
+         self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
+         self.ctx.gnn_best = self.model
+
+         # If geo_feature_nmes is set, refresh geo tokens for FT input.
+         if self.ctx.config.geo_feature_nmes:
+             self.prepare_geo_tokens(force=True)
+
+     def ensemble_predict(self, k: int) -> None:
+         if not self.best_params:
+             raise RuntimeError("Run tune() first to obtain best GNN parameters.")
+         data = self.ctx.train_oht_scl_data
+         test_data = self.ctx.test_oht_scl_data
+         if data is None or test_data is None:
+             raise RuntimeError("Missing standardized data for GNN ensemble.")
+         X_all = data[self.ctx.var_nmes]
+         y_all = data[self.ctx.resp_nme]
+         w_all = data[self.ctx.weight_nme]
+         X_test = test_data[self.ctx.var_nmes]
+
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
+         if split_iter is None:
+             print(
+                 f"[GNN Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train_sum = np.zeros(n_samples, dtype=np.float64)
+         preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
+
+         split_count = 0
+         for train_idx, val_idx in split_iter:
+             model = self._build_model(self.best_params)
+             model.fit(
+                 X_all.iloc[train_idx],
+                 y_all.iloc[train_idx],
+                 w_train=w_all.iloc[train_idx],
+                 X_val=X_all.iloc[val_idx],
+                 y_val=y_all.iloc[val_idx],
+                 w_val=w_all.iloc[val_idx],
+                 trial=None,
+             )
+             pred_train = model.predict(X_all)
+             pred_test = model.predict(X_test)
+             preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+             preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+             getattr(getattr(model, "gnn", None), "to",
+                     lambda *_args, **_kwargs: None)("cpu")
+             self._clean_gpu()
+             split_count += 1
+
+         if split_count < 1:
+             print(
+                 "[GNN Ensemble] no CV splits generated; skip ensemble.",
+                 flush=True,
+             )
+             return
+         preds_train = preds_train_sum / float(split_count)
+         preds_test = preds_test_sum / float(split_count)
+         self._cache_predictions("gnn", preds_train, preds_test)
+
+     def prepare_geo_tokens(self, force: bool = False) -> None:
+         """Train/update the GNN encoder for geo tokens and inject them into FT input."""
+         geo_cols = list(self.ctx.config.geo_feature_nmes or [])
+         if not geo_cols:
+             return
+         if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
+             return
+
+         result = self.ctx._build_geo_tokens()
+         if result is None:
+             return
+         train_tokens, test_tokens, cols, geo_gnn = result
+         self.ctx.train_geo_tokens = train_tokens
+         self.ctx.test_geo_tokens = test_tokens
+         self.ctx.geo_token_cols = cols
+         self.ctx.geo_gnn_model = geo_gnn
+         print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
+
+     def save(self) -> None:
+         if self.model is None:
+             print(f"[save] Warning: No model to save for {self.label}")
+             return
+         path = self.output.model_path(self._get_model_filename())
+         base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
+         state = None if base_gnn is None else base_gnn.state_dict()
+         payload = {
+             "best_params": self.best_params,
+             "state_dict": state,
+         }
+         torch.save(payload, path)
+
+     def load(self) -> None:
+         path = self.output.model_path(self._get_model_filename())
+         if not os.path.exists(path):
+             print(f"[load] Warning: Model file not found: {path}")
+             return
+         payload = torch.load(path, map_location='cpu')
+         if not isinstance(payload, dict):
+             raise ValueError(f"Invalid GNN checkpoint: {path}")
+         params = payload.get("best_params") or {}
+         state_dict = payload.get("state_dict")
+         model = self._build_model(params)
+         if params:
+             model.set_params(dict(params))
+         base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
+         if base_gnn is not None and state_dict is not None:
+             base_gnn.load_state_dict(state_dict, strict=False)
+         self.model = model
+         self.best_params = dict(params) if isinstance(params, dict) else None
+         self.ctx.gnn_best = self.model
+
+
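The save/load pair above makes the GNN checkpoint self-describing: a plain dict holding the tuned hyperparameters next to the underlying torch module's state_dict. A minimal sketch of inspecting such a checkpoint outside the trainer, with the file name hypothetical:

    import torch

    payload = torch.load("pricing_gnn.pt", map_location="cpu")  # hypothetical path
    if not isinstance(payload, dict):
        raise ValueError("not a GNN checkpoint")
    best_params = payload.get("best_params") or {}  # Optuna-tuned hyperparameters
    state_dict = payload.get("state_dict")          # module weights; may be None

Restoring with strict=False, as load() does, tolerates missing or unexpected keys when the module definition has drifted between package versions.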