ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
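The bulk of this release is a module relocation: `modelling/core/bayesopt` flattens to `modelling/bayesopt`, the bayesopt loss/IO helpers move under `ins_pricing.utils`, and `production/predict.py` becomes `production/inference.py`. Below is a minimal, hypothetical migration sketch for downstream code, based only on the renames listed above and the new import paths visible in the updated `core.py` diff that follows; verify the exact re-exports against the published 0.5.0 wheel before relying on them.

```python
# Old 0.4.4 import paths (per the file list above):
# from ins_pricing.modelling.core.bayesopt.config_preprocess import BayesOptConfig
# from ins_pricing.modelling.core.bayesopt.utils.losses import normalize_loss_name

# New 0.5.0 import paths (as they appear in the updated core.py below):
from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig
from ins_pricing.utils.losses import normalize_loss_name
from ins_pricing.utils.io import IOUtils
```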
@@ -1,962 +1,965 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import asdict
4
- from datetime import datetime
5
- import os
6
- from pathlib import Path
7
- from typing import Any, Dict, List, Optional
8
- import numpy as np
9
- import pandas as pd
10
- import torch
11
- from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit
12
- from sklearn.preprocessing import StandardScaler
13
-
14
- from .config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
15
- from .model_explain_mixin import BayesOptExplainMixin
16
- from .model_plotting_mixin import BayesOptPlottingMixin
17
- from .models import GraphNeuralNetSklearn
18
- from .trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
19
- from .utils import EPS, infer_factor_and_cate_list, set_global_seed
20
- from .utils.losses import (
21
- infer_loss_name_from_model_name,
22
- normalize_loss_name,
23
- resolve_tweedie_power,
24
- resolve_xgb_objective,
25
- )
26
-
27
-
28
- class _CVSplitter:
29
- """Wrapper to carry optional groups or time order for CV splits."""
30
-
31
- def __init__(
32
- self,
33
- splitter,
34
- *,
35
- groups: Optional[pd.Series] = None,
36
- order: Optional[np.ndarray] = None,
37
- ) -> None:
38
- self._splitter = splitter
39
- self._groups = groups
40
- self._order = order
41
-
42
- def split(self, X, y=None, groups=None):
43
- if self._order is not None:
44
- order = np.asarray(self._order)
45
- X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
46
- for tr_idx, val_idx in self._splitter.split(X_ord, y=y):
47
- yield order[tr_idx], order[val_idx]
48
- return
49
- use_groups = groups if groups is not None else self._groups
50
- for tr_idx, val_idx in self._splitter.split(X, y=y, groups=use_groups):
51
- yield tr_idx, val_idx
52
-
53
- # BayesOpt orchestration and SHAP utilities
54
- # =============================================================================
55
- class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
56
- def __init__(self, train_data, test_data,
57
- config: Optional[BayesOptConfig] = None,
58
- # Backward compatibility: individual parameters (DEPRECATED)
59
- model_nme=None, resp_nme=None, weight_nme=None,
60
- factor_nmes: Optional[List[str]] = None, task_type='regression',
61
- binary_resp_nme=None,
62
- cate_list=None, prop_test=0.25, rand_seed=None,
63
- epochs=100, use_gpu=True,
64
- use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
65
- use_gnn_data_parallel: bool = False,
66
- use_resn_ddp: bool = False, use_ft_ddp: bool = False,
67
- use_gnn_ddp: bool = False,
68
- output_dir: Optional[str] = None,
69
- gnn_use_approx_knn: bool = True,
70
- gnn_approx_knn_threshold: int = 50000,
71
- gnn_graph_cache: Optional[str] = None,
72
- gnn_max_gpu_knn_nodes: Optional[int] = 200000,
73
- gnn_knn_gpu_mem_ratio: float = 0.9,
74
- gnn_knn_gpu_mem_overhead: float = 2.0,
75
- ft_role: str = "model",
76
- ft_feature_prefix: str = "ft_emb",
77
- ft_num_numeric_tokens: Optional[int] = None,
78
- infer_categorical_max_unique: int = 50,
79
- infer_categorical_max_ratio: float = 0.05,
80
- reuse_best_params: bool = False,
81
- xgb_max_depth_max: int = 25,
82
- xgb_n_estimators_max: int = 500,
83
- resn_weight_decay: Optional[float] = None,
84
- final_ensemble: bool = False,
85
- final_ensemble_k: int = 3,
86
- final_refit: bool = True,
87
- optuna_storage: Optional[str] = None,
88
- optuna_study_prefix: Optional[str] = None,
89
- best_params_files: Optional[Dict[str, str]] = None,
90
- cv_strategy: Optional[str] = None,
91
- cv_splits: Optional[int] = None,
92
- cv_group_col: Optional[str] = None,
93
- cv_time_col: Optional[str] = None,
94
- cv_time_ascending: bool = True,
95
- ft_oof_folds: Optional[int] = None,
96
- ft_oof_strategy: Optional[str] = None,
97
- ft_oof_shuffle: bool = True,
98
- save_preprocess: bool = False,
99
- preprocess_artifact_path: Optional[str] = None,
100
- plot_path_style: Optional[str] = None,
101
- bo_sample_limit: Optional[int] = None,
102
- cache_predictions: bool = False,
103
- prediction_cache_dir: Optional[str] = None,
104
- prediction_cache_format: Optional[str] = None,
105
- region_province_col: Optional[str] = None,
106
- region_city_col: Optional[str] = None,
107
- region_effect_alpha: Optional[float] = None,
108
- geo_feature_nmes: Optional[List[str]] = None,
109
- geo_token_hidden_dim: Optional[int] = None,
110
- geo_token_layers: Optional[int] = None,
111
- geo_token_dropout: Optional[float] = None,
112
- geo_token_k_neighbors: Optional[int] = None,
113
- geo_token_learning_rate: Optional[float] = None,
114
- geo_token_epochs: Optional[int] = None):
115
- """Orchestrate BayesOpt training across multiple trainers.
116
-
117
- Args:
118
- train_data: Training DataFrame.
119
- test_data: Test DataFrame.
120
- config: BayesOptConfig instance with all configuration (RECOMMENDED).
121
- If provided, all other parameters are ignored.
122
-
123
- # DEPRECATED: Individual parameters (use config instead)
124
- model_nme: Model name prefix used in outputs.
125
- resp_nme: Target column name.
126
- weight_nme: Sample weight column name.
127
- factor_nmes: Feature column list.
128
- task_type: "regression" or "classification".
129
- binary_resp_nme: Optional binary target for lift curves.
130
- cate_list: Categorical feature list.
131
- prop_test: Validation split ratio in CV.
132
- rand_seed: Random seed.
133
- epochs: NN training epochs.
134
- use_gpu: Prefer GPU when available.
135
- use_resn_data_parallel: Enable DataParallel for ResNet.
136
- use_ft_data_parallel: Enable DataParallel for FTTransformer.
137
- use_gnn_data_parallel: Enable DataParallel for GNN.
138
- use_resn_ddp: Enable DDP for ResNet.
139
- use_ft_ddp: Enable DDP for FTTransformer.
140
- use_gnn_ddp: Enable DDP for GNN.
141
- output_dir: Output root for models/results/plots.
142
- gnn_use_approx_knn: Use approximate kNN when available.
143
- gnn_approx_knn_threshold: Row threshold to switch to approximate kNN.
144
- gnn_graph_cache: Optional adjacency cache path.
145
- gnn_max_gpu_knn_nodes: Force CPU kNN above this node count to avoid OOM.
146
- gnn_knn_gpu_mem_ratio: Fraction of free GPU memory for kNN.
147
- gnn_knn_gpu_mem_overhead: Temporary memory multiplier for GPU kNN.
148
- ft_num_numeric_tokens: Number of numeric tokens for FT (None = auto).
149
- final_ensemble: Enable k-fold model averaging at the final stage.
150
- final_ensemble_k: Number of folds for averaging.
151
- final_refit: Refit on full data using best stopping point.
152
-
153
- Examples:
154
- # New style (recommended):
155
- config = BayesOptConfig(
156
- model_nme="my_model",
157
- resp_nme="target",
158
- weight_nme="weight",
159
- factor_nmes=["feat1", "feat2"]
160
- )
161
- model = BayesOptModel(train_df, test_df, config=config)
162
-
163
- # Old style (deprecated, for backward compatibility):
164
- model = BayesOptModel(
165
- train_df, test_df,
166
- model_nme="my_model",
167
- resp_nme="target",
168
- weight_nme="weight",
169
- factor_nmes=["feat1", "feat2"]
170
- )
171
- """
172
- # Detect which API is being used
173
- if config is not None:
174
- # New API: config object provided
175
- if isinstance(config, BayesOptConfig):
176
- cfg = config
177
- else:
178
- raise TypeError(
179
- f"config must be a BayesOptConfig instance, got {type(config).__name__}"
180
- )
181
- else:
182
- # Old API: individual parameters (backward compatibility)
183
- # Show deprecation warning
184
- import warnings
185
- warnings.warn(
186
- "Passing individual parameters to BayesOptModel.__init__ is deprecated. "
187
- "Use the 'config' parameter with a BayesOptConfig instance instead:\n"
188
- " config = BayesOptConfig(model_nme=..., resp_nme=..., ...)\n"
189
- " model = BayesOptModel(train_data, test_data, config=config)\n"
190
- "Individual parameters will be removed in v0.4.0.",
191
- DeprecationWarning,
192
- stacklevel=2
193
- )
194
-
195
- # Validate required parameters
196
- if model_nme is None:
197
- raise ValueError("model_nme is required when not using config parameter")
198
- if resp_nme is None:
199
- raise ValueError("resp_nme is required when not using config parameter")
200
- if weight_nme is None:
201
- raise ValueError("weight_nme is required when not using config parameter")
202
-
203
- # Infer categorical features if needed
204
- inferred_factors, inferred_cats = infer_factor_and_cate_list(
205
- train_df=train_data,
206
- test_df=test_data,
207
- resp_nme=resp_nme,
208
- weight_nme=weight_nme,
209
- binary_resp_nme=binary_resp_nme,
210
- factor_nmes=factor_nmes,
211
- cate_list=cate_list,
212
- infer_categorical_max_unique=int(infer_categorical_max_unique),
213
- infer_categorical_max_ratio=float(infer_categorical_max_ratio),
214
- )
215
-
216
- # Construct config from individual parameters
217
- cfg = BayesOptConfig(
218
- model_nme=model_nme,
219
- task_type=task_type,
220
- resp_nme=resp_nme,
221
- weight_nme=weight_nme,
222
- factor_nmes=list(inferred_factors),
223
- binary_resp_nme=binary_resp_nme,
224
- cate_list=list(inferred_cats) if inferred_cats else None,
225
- prop_test=prop_test,
226
- rand_seed=rand_seed,
227
- epochs=epochs,
228
- use_gpu=use_gpu,
229
- xgb_max_depth_max=int(xgb_max_depth_max),
230
- xgb_n_estimators_max=int(xgb_n_estimators_max),
231
- use_resn_data_parallel=use_resn_data_parallel,
232
- use_ft_data_parallel=use_ft_data_parallel,
233
- use_resn_ddp=use_resn_ddp,
234
- use_gnn_data_parallel=use_gnn_data_parallel,
235
- use_ft_ddp=use_ft_ddp,
236
- use_gnn_ddp=use_gnn_ddp,
237
- gnn_use_approx_knn=gnn_use_approx_knn,
238
- gnn_approx_knn_threshold=gnn_approx_knn_threshold,
239
- gnn_graph_cache=gnn_graph_cache,
240
- gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
241
- gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
242
- gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
243
- output_dir=output_dir,
244
- optuna_storage=optuna_storage,
245
- optuna_study_prefix=optuna_study_prefix,
246
- best_params_files=best_params_files,
247
- ft_role=str(ft_role or "model"),
248
- ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
249
- ft_num_numeric_tokens=ft_num_numeric_tokens,
250
- reuse_best_params=bool(reuse_best_params),
251
- resn_weight_decay=float(resn_weight_decay)
252
- if resn_weight_decay is not None
253
- else 1e-4,
254
- final_ensemble=bool(final_ensemble),
255
- final_ensemble_k=int(final_ensemble_k),
256
- final_refit=bool(final_refit),
257
- cv_strategy=str(cv_strategy or "random"),
258
- cv_splits=cv_splits,
259
- cv_group_col=cv_group_col,
260
- cv_time_col=cv_time_col,
261
- cv_time_ascending=bool(cv_time_ascending),
262
- ft_oof_folds=ft_oof_folds,
263
- ft_oof_strategy=ft_oof_strategy,
264
- ft_oof_shuffle=bool(ft_oof_shuffle),
265
- save_preprocess=bool(save_preprocess),
266
- preprocess_artifact_path=preprocess_artifact_path,
267
- plot_path_style=str(plot_path_style or "nested"),
268
- bo_sample_limit=bo_sample_limit,
269
- cache_predictions=bool(cache_predictions),
270
- prediction_cache_dir=prediction_cache_dir,
271
- prediction_cache_format=str(prediction_cache_format or "parquet"),
272
- region_province_col=region_province_col,
273
- region_city_col=region_city_col,
274
- region_effect_alpha=float(region_effect_alpha)
275
- if region_effect_alpha is not None
276
- else 50.0,
277
- geo_feature_nmes=list(geo_feature_nmes)
278
- if geo_feature_nmes is not None
279
- else None,
280
- geo_token_hidden_dim=int(geo_token_hidden_dim)
281
- if geo_token_hidden_dim is not None
282
- else 32,
283
- geo_token_layers=int(geo_token_layers)
284
- if geo_token_layers is not None
285
- else 2,
286
- geo_token_dropout=float(geo_token_dropout)
287
- if geo_token_dropout is not None
288
- else 0.1,
289
- geo_token_k_neighbors=int(geo_token_k_neighbors)
290
- if geo_token_k_neighbors is not None
291
- else 10,
292
- geo_token_learning_rate=float(geo_token_learning_rate)
293
- if geo_token_learning_rate is not None
294
- else 1e-3,
295
- geo_token_epochs=int(geo_token_epochs)
296
- if geo_token_epochs is not None
297
- else 50,
298
- )
299
- self.config = cfg
300
- self.model_nme = cfg.model_nme
301
- self.task_type = cfg.task_type
302
- normalized_loss = normalize_loss_name(getattr(cfg, "loss_name", None), self.task_type)
303
- if self.task_type == "classification":
304
- self.loss_name = "logloss" if normalized_loss == "auto" else normalized_loss
305
- else:
306
- if normalized_loss == "auto":
307
- self.loss_name = infer_loss_name_from_model_name(self.model_nme)
308
- else:
309
- self.loss_name = normalized_loss
310
- self.resp_nme = cfg.resp_nme
311
- self.weight_nme = cfg.weight_nme
312
- self.factor_nmes = cfg.factor_nmes
313
- self.binary_resp_nme = cfg.binary_resp_nme
314
- self.cate_list = list(cfg.cate_list or [])
315
- self.prop_test = cfg.prop_test
316
- self.epochs = cfg.epochs
317
- self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
318
- 1, 10000)
319
- set_global_seed(int(self.rand_seed))
320
- self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
321
- self.output_manager = OutputManager(
322
- cfg.output_dir or os.getcwd(), self.model_nme)
323
-
324
- preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
325
- self.train_data = preprocessor.train_data
326
- self.test_data = preprocessor.test_data
327
- self.train_oht_data = preprocessor.train_oht_data
328
- self.test_oht_data = preprocessor.test_oht_data
329
- self.train_oht_scl_data = preprocessor.train_oht_scl_data
330
- self.test_oht_scl_data = preprocessor.test_oht_scl_data
331
- self.var_nmes = preprocessor.var_nmes
332
- self.num_features = preprocessor.num_features
333
- self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
334
- self.numeric_scalers = preprocessor.numeric_scalers
335
- if getattr(self.config, "save_preprocess", False):
336
- artifact_path = getattr(self.config, "preprocess_artifact_path", None)
337
- if artifact_path:
338
- target = Path(str(artifact_path))
339
- if not target.is_absolute():
340
- target = Path(self.output_manager.result_dir) / target
341
- else:
342
- target = Path(self.output_manager.result_path(
343
- f"{self.model_nme}_preprocess.json"
344
- ))
345
- preprocessor.save_artifacts(target)
346
- self.geo_token_cols: List[str] = []
347
- self.train_geo_tokens: Optional[pd.DataFrame] = None
348
- self.test_geo_tokens: Optional[pd.DataFrame] = None
349
- self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
350
- self._add_region_effect()
351
-
352
- self.cv = self._build_cv_splitter()
353
- if self.task_type == 'classification':
354
- self.obj = 'binary:logistic'
355
- else: # regression task
356
- self.obj = resolve_xgb_objective(self.loss_name)
357
- self.fit_params = {
358
- 'sample_weight': self.train_data[self.weight_nme].values
359
- }
360
- self.model_label: List[str] = []
361
- self.optuna_storage = cfg.optuna_storage
362
- self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
363
-
364
- # Keep trainers in a dict for unified access and easy extension.
365
- self.trainers: Dict[str, TrainerBase] = {
366
- 'glm': GLMTrainer(self),
367
- 'xgb': XGBTrainer(self),
368
- 'resn': ResNetTrainer(self),
369
- 'ft': FTTrainer(self),
370
- 'gnn': GNNTrainer(self),
371
- }
372
- self._prepare_geo_tokens()
373
- self.xgb_best = None
374
- self.resn_best = None
375
- self.gnn_best = None
376
- self.glm_best = None
377
- self.ft_best = None
378
- self.best_xgb_params = None
379
- self.best_resn_params = None
380
- self.best_gnn_params = None
381
- self.best_ft_params = None
382
- self.best_xgb_trial = None
383
- self.best_resn_trial = None
384
- self.best_gnn_trial = None
385
- self.best_ft_trial = None
386
- self.best_glm_params = None
387
- self.best_glm_trial = None
388
- self.xgb_load = None
389
- self.resn_load = None
390
- self.gnn_load = None
391
- self.ft_load = None
392
- self.version_manager = VersionManager(self.output_manager)
393
-
394
- def _build_cv_splitter(self) -> _CVSplitter:
395
- strategy = str(getattr(self.config, "cv_strategy", "random") or "random").strip().lower()
396
- val_ratio = float(self.prop_test) if self.prop_test is not None else 0.25
397
- if not (0.0 < val_ratio < 1.0):
398
- val_ratio = 0.25
399
- cv_splits = getattr(self.config, "cv_splits", None)
400
- if cv_splits is None:
401
- cv_splits = max(2, int(round(1 / val_ratio)))
402
- cv_splits = max(2, int(cv_splits))
403
-
404
- if strategy in {"group", "grouped"}:
405
- group_col = getattr(self.config, "cv_group_col", None)
406
- if not group_col:
407
- raise ValueError("cv_group_col is required for group cv_strategy.")
408
- if group_col not in self.train_data.columns:
409
- raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
410
- groups = self.train_data[group_col]
411
- splitter = GroupKFold(n_splits=cv_splits)
412
- return _CVSplitter(splitter, groups=groups)
413
-
414
- if strategy in {"time", "timeseries", "temporal"}:
415
- time_col = getattr(self.config, "cv_time_col", None)
416
- if not time_col:
417
- raise ValueError("cv_time_col is required for time cv_strategy.")
418
- if time_col not in self.train_data.columns:
419
- raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
420
- ascending = bool(getattr(self.config, "cv_time_ascending", True))
421
- order_index = self.train_data[time_col].sort_values(ascending=ascending).index
422
- order = self.train_data.index.get_indexer(order_index)
423
- splitter = TimeSeriesSplit(n_splits=cv_splits)
424
- return _CVSplitter(splitter, order=order)
425
-
426
- splitter = ShuffleSplit(
427
- n_splits=cv_splits,
428
- test_size=val_ratio,
429
- random_state=self.rand_seed,
430
- )
431
- return _CVSplitter(splitter)
432
-
433
- def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
434
- if self.task_type == 'classification':
435
- return None
436
- loss_name = getattr(self, "loss_name", None)
437
- if loss_name:
438
- resolved = resolve_tweedie_power(str(loss_name), default=1.5)
439
- if resolved is not None:
440
- return resolved
441
- objective = obj or getattr(self, "obj", None)
442
- if objective == 'count:poisson':
443
- return 1.0
444
- if objective == 'reg:gamma':
445
- return 2.0
446
- return 1.5
447
-
448
- def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
449
- """Internal builder; allows trial overrides and returns None on failure."""
450
- geo_cols = list(self.config.geo_feature_nmes or [])
451
- if not geo_cols:
452
- return None
453
-
454
- available = [c for c in geo_cols if c in self.train_data.columns]
455
- if not available:
456
- return None
457
-
458
- # Preprocess text/numeric: fill numeric with median, label-encode text, map unknowns.
459
- proc_train = {}
460
- proc_test = {}
461
- for col in available:
462
- s_train = self.train_data[col]
463
- s_test = self.test_data[col]
464
- if pd.api.types.is_numeric_dtype(s_train):
465
- tr = pd.to_numeric(s_train, errors="coerce")
466
- te = pd.to_numeric(s_test, errors="coerce")
467
- med = np.nanmedian(tr)
468
- proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
469
- proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
470
- else:
471
- cats = pd.Categorical(s_train.astype(str))
472
- tr_codes = cats.codes.astype(np.float32, copy=True)
473
- tr_codes[tr_codes < 0] = len(cats.categories)
474
- te_cats = pd.Categorical(
475
- s_test.astype(str), categories=cats.categories)
476
- te_codes = te_cats.codes.astype(np.float32, copy=True)
477
- te_codes[te_codes < 0] = len(cats.categories)
478
- proc_train[col] = tr_codes
479
- proc_test[col] = te_codes
480
-
481
- train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
482
- test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
483
-
484
- scaler = StandardScaler()
485
- train_geo = pd.DataFrame(
486
- scaler.fit_transform(train_geo_raw),
487
- columns=available,
488
- index=self.train_data.index
489
- )
490
- test_geo = pd.DataFrame(
491
- scaler.transform(test_geo_raw),
492
- columns=available,
493
- index=self.test_data.index
494
- )
495
-
496
- tw_power = self.default_tweedie_power()
497
-
498
- cfg = params_override or {}
499
- try:
500
- geo_gnn = GraphNeuralNetSklearn(
501
- model_nme=f"{self.model_nme}_geo",
502
- input_dim=len(available),
503
- hidden_dim=cfg.get("geo_token_hidden_dim",
504
- self.config.geo_token_hidden_dim),
505
- num_layers=cfg.get("geo_token_layers",
506
- self.config.geo_token_layers),
507
- k_neighbors=cfg.get("geo_token_k_neighbors",
508
- self.config.geo_token_k_neighbors),
509
- dropout=cfg.get("geo_token_dropout",
510
- self.config.geo_token_dropout),
511
- learning_rate=cfg.get(
512
- "geo_token_learning_rate", self.config.geo_token_learning_rate),
513
- epochs=int(cfg.get("geo_token_epochs",
514
- self.config.geo_token_epochs)),
515
- patience=5,
516
- task_type=self.task_type,
517
- tweedie_power=tw_power,
518
- loss_name=self.loss_name,
519
- use_data_parallel=False,
520
- use_ddp=False,
521
- use_approx_knn=self.config.gnn_use_approx_knn,
522
- approx_knn_threshold=self.config.gnn_approx_knn_threshold,
523
- graph_cache_path=None,
524
- max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
525
- knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
526
- knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
527
- )
528
- geo_gnn.fit(
529
- train_geo,
530
- self.train_data[self.resp_nme],
531
- self.train_data[self.weight_nme]
532
- )
533
- train_embed = geo_gnn.encode(train_geo)
534
- test_embed = geo_gnn.encode(test_geo)
535
- cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
536
- train_tokens = pd.DataFrame(
537
- train_embed, index=self.train_data.index, columns=cols)
538
- test_tokens = pd.DataFrame(
539
- test_embed, index=self.test_data.index, columns=cols)
540
- return train_tokens, test_tokens, cols, geo_gnn
541
- except Exception as exc:
542
- print(f"[GeoToken] Generation failed: {exc}")
543
- return None
544
-
545
- def _prepare_geo_tokens(self) -> None:
546
- """Build and persist geo tokens with default config values."""
547
- gnn_trainer = self.trainers.get("gnn")
548
- if gnn_trainer is not None and hasattr(gnn_trainer, "prepare_geo_tokens"):
549
- try:
550
- gnn_trainer.prepare_geo_tokens(force=False) # type: ignore[attr-defined]
551
- return
552
- except Exception as exc:
553
- print(f"[GeoToken] GNNTrainer generation failed: {exc}")
554
-
555
- result = self._build_geo_tokens()
556
- if result is None:
557
- return
558
- train_tokens, test_tokens, cols, geo_gnn = result
559
- self.train_geo_tokens = train_tokens
560
- self.test_geo_tokens = test_tokens
561
- self.geo_token_cols = cols
562
- self.geo_gnn_model = geo_gnn
563
- print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; injecting into FT.")
564
-
565
- def _add_region_effect(self) -> None:
566
- """Partial pooling over province/city to create a smoothed region_effect feature."""
567
- prov_col = self.config.region_province_col
568
- city_col = self.config.region_city_col
569
- if not prov_col or not city_col:
570
- return
571
- for col in [prov_col, city_col]:
572
- if col not in self.train_data.columns:
573
- print(f"[RegionEffect] Missing column {col}; skipped.")
574
- return
575
-
576
- def safe_mean(df: pd.DataFrame) -> float:
577
- w = df[self.weight_nme]
578
- y = df[self.resp_nme]
579
- denom = max(float(w.sum()), EPS)
580
- return float((y * w).sum() / denom)
581
-
582
- global_mean = safe_mean(self.train_data)
583
- alpha = max(float(self.config.region_effect_alpha), 0.0)
584
-
585
- w_all = self.train_data[self.weight_nme]
586
- y_all = self.train_data[self.resp_nme]
587
- yw_all = y_all * w_all
588
-
589
- prov_sumw = w_all.groupby(self.train_data[prov_col]).sum()
590
- prov_sumyw = yw_all.groupby(self.train_data[prov_col]).sum()
591
- prov_mean = (prov_sumyw / prov_sumw.clip(lower=EPS)).astype(float)
592
- prov_mean = prov_mean.fillna(global_mean)
593
-
594
- city_sumw = self.train_data.groupby([prov_col, city_col])[
595
- self.weight_nme].sum()
596
- city_sumyw = yw_all.groupby(
597
- [self.train_data[prov_col], self.train_data[city_col]]).sum()
598
- city_df = pd.DataFrame({
599
- "sum_w": city_sumw,
600
- "sum_yw": city_sumyw,
601
- })
602
- city_df["prior"] = city_df.index.get_level_values(0).map(
603
- prov_mean).fillna(global_mean)
604
- city_df["effect"] = (
605
- city_df["sum_yw"] + alpha * city_df["prior"]
606
- ) / (city_df["sum_w"] + alpha).clip(lower=EPS)
607
- city_effect = city_df["effect"]
608
-
609
- def lookup_effect(df: pd.DataFrame) -> pd.Series:
610
- idx = pd.MultiIndex.from_frame(df[[prov_col, city_col]])
611
- effects = city_effect.reindex(idx).to_numpy(dtype=np.float64)
612
- prov_fallback = df[prov_col].map(
613
- prov_mean).fillna(global_mean).to_numpy(dtype=np.float64)
614
- effects = np.where(np.isfinite(effects), effects, prov_fallback)
615
- effects = np.where(np.isfinite(effects), effects, global_mean)
616
- return pd.Series(effects, index=df.index, dtype=np.float32)
617
-
618
- re_train = lookup_effect(self.train_data)
619
- re_test = lookup_effect(self.test_data)
620
-
621
- col_name = "region_effect"
622
- self.train_data[col_name] = re_train
623
- self.test_data[col_name] = re_test
624
-
625
- # Sync into one-hot and scaled variants.
626
- for df in [self.train_oht_data, self.test_oht_data]:
627
- if df is not None:
628
- df[col_name] = re_train if df is self.train_oht_data else re_test
629
-
630
- # Standardize region_effect and propagate.
631
- scaler = StandardScaler()
632
- re_train_s = scaler.fit_transform(
633
- re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
634
- re_test_s = scaler.transform(
635
- re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
636
- for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
637
- if df is not None:
638
- df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
639
-
640
- # Update feature lists.
641
- if col_name not in self.factor_nmes:
642
- self.factor_nmes.append(col_name)
643
- if col_name not in self.num_features:
644
- self.num_features.append(col_name)
645
- if self.train_oht_scl_data is not None:
646
- excluded = {self.weight_nme, self.resp_nme}
647
- self.var_nmes = [
648
- col for col in self.train_oht_scl_data.columns if col not in excluded
649
- ]
650
-
651
- def _require_trainer(self, model_key: str) -> "TrainerBase":
652
- trainer = self.trainers.get(model_key)
653
- if trainer is None:
654
- raise KeyError(f"Unknown model key: {model_key}")
655
- return trainer
656
-
657
- def _pred_vector_columns(self, pred_prefix: str) -> List[str]:
658
- """Return vector feature columns like pred_<prefix>_0.. sorted by suffix."""
659
- col_prefix = f"pred_{pred_prefix}_"
660
- cols = [c for c in self.train_data.columns if c.startswith(col_prefix)]
661
-
662
- def sort_key(name: str):
663
- tail = name.rsplit("_", 1)[-1]
664
- try:
665
- return (0, int(tail))
666
- except Exception:
667
- return (1, tail)
668
-
669
- cols.sort(key=sort_key)
670
- return cols
671
-
672
- def _inject_pred_features(self, pred_prefix: str) -> List[str]:
673
- """Inject pred_<prefix> or pred_<prefix>_i columns into features and return names."""
674
- cols = self._pred_vector_columns(pred_prefix)
675
- if cols:
676
- self.add_numeric_features_from_columns(cols)
677
- return cols
678
- scalar_col = f"pred_{pred_prefix}"
679
- if scalar_col in self.train_data.columns:
680
- self.add_numeric_feature_from_column(scalar_col)
681
- return [scalar_col]
682
- return []
683
-
684
- def _maybe_load_best_params(self, model_key: str, trainer: "TrainerBase") -> None:
685
- # 1) If best_params_files is specified, load and skip tuning.
686
- best_params_files = getattr(self.config, "best_params_files", None) or {}
687
- best_params_file = best_params_files.get(model_key)
688
- if best_params_file and not trainer.best_params:
689
- trainer.best_params = IOUtils.load_params_file(best_params_file)
690
- trainer.best_trial = None
691
- print(
692
- f"[Optuna][{trainer.label}] Loaded best_params from {best_params_file}; skip tuning."
693
- )
694
-
695
- # 2) If reuse_best_params is enabled, prefer version snapshots; else load legacy CSV.
696
- reuse_params = bool(getattr(self.config, "reuse_best_params", False))
697
- if reuse_params and not trainer.best_params:
698
- payload = self.version_manager.load_latest(f"{model_key}_best")
699
- best_params = None if payload is None else payload.get("best_params")
700
- if best_params:
701
- trainer.best_params = best_params
702
- trainer.best_trial = None
703
- trainer.study_name = payload.get(
704
- "study_name") if isinstance(payload, dict) else None
705
- print(
706
- f"[Optuna][{trainer.label}] Reusing best_params from versions snapshot.")
707
- return
708
-
709
- params_path = self.output_manager.result_path(
710
- f'{self.model_nme}_bestparams_{trainer.label.lower()}.csv'
711
- )
712
- if os.path.exists(params_path):
713
- try:
714
- trainer.best_params = IOUtils.load_params_file(params_path)
715
- trainer.best_trial = None
716
- print(
717
- f"[Optuna][{trainer.label}] Reusing best_params from {params_path}.")
718
- except ValueError:
719
- # Legacy compatibility: ignore empty files and continue tuning.
720
- pass
721
-
722
- # Generic optimization entry point.
723
- def optimize_model(self, model_key: str, max_evals: int = 100):
724
- if model_key not in self.trainers:
725
- print(f"Warning: Unknown model key: {model_key}")
726
- return
727
-
728
- trainer = self._require_trainer(model_key)
729
- self._maybe_load_best_params(model_key, trainer)
730
-
731
- should_tune = not trainer.best_params
732
- if should_tune:
733
- if model_key == "ft" and str(self.config.ft_role) == "unsupervised_embedding":
734
- if hasattr(trainer, "cross_val_unsupervised"):
735
- trainer.tune(
736
- max_evals,
737
- objective_fn=getattr(trainer, "cross_val_unsupervised")
738
- )
739
- else:
740
- raise RuntimeError(
741
- "FT trainer does not support unsupervised Optuna objective.")
742
- else:
743
- trainer.tune(max_evals)
744
-
745
- if model_key == "ft" and str(self.config.ft_role) != "model":
746
- prefix = str(self.config.ft_feature_prefix or "ft_emb")
747
- role = str(self.config.ft_role)
748
- if role == "embedding":
749
- trainer.train_as_feature(
750
- pred_prefix=prefix, feature_mode="embedding")
751
- elif role == "unsupervised_embedding":
752
- trainer.pretrain_unsupervised_as_feature(
753
- pred_prefix=prefix,
754
- params=trainer.best_params
755
- )
756
- else:
757
- raise ValueError(
758
- f"Unsupported ft_role='{role}', expected 'model'/'embedding'/'unsupervised_embedding'.")
759
-
760
- # Inject generated prediction/embedding columns as features (scalar or vector).
761
- self._inject_pred_features(prefix)
762
- # Do not add FT as a standalone model label; downstream models handle evaluation.
763
- else:
764
- trainer.train()
765
-
766
- if bool(getattr(self.config, "final_ensemble", False)):
767
- k = int(getattr(self.config, "final_ensemble_k", 3) or 3)
768
- if k > 1:
769
- if model_key == "ft" and str(self.config.ft_role) != "model":
770
- pass
771
- elif hasattr(trainer, "ensemble_predict"):
772
- trainer.ensemble_predict(k)
773
- else:
774
- print(
775
- f"[Ensemble] Trainer '{model_key}' does not support ensemble prediction.",
776
- flush=True,
777
- )
778
-
779
- # Update context fields for backward compatibility.
780
- setattr(self, f"{model_key}_best", trainer.model)
781
- setattr(self, f"best_{model_key}_params", trainer.best_params)
782
- setattr(self, f"best_{model_key}_trial", trainer.best_trial)
783
- # Save a snapshot for traceability.
784
- study_name = getattr(trainer, "study_name", None)
785
- if study_name is None and trainer.best_trial is not None:
786
- study_obj = getattr(trainer.best_trial, "study", None)
787
- study_name = getattr(study_obj, "study_name", None)
788
- snapshot = {
789
- "model_key": model_key,
790
- "timestamp": datetime.now().isoformat(),
791
- "best_params": trainer.best_params,
792
- "study_name": study_name,
793
- "config": asdict(self.config),
794
- }
795
- self.version_manager.save(f"{model_key}_best", snapshot)
796
-
797
- def add_numeric_feature_from_column(self, col_name: str) -> None:
798
- """Add an existing column as a feature and sync one-hot/scaled tables."""
799
- if col_name not in self.train_data.columns or col_name not in self.test_data.columns:
800
- raise KeyError(
801
- f"Column '{col_name}' must exist in both train_data and test_data.")
802
-
803
- if col_name not in self.factor_nmes:
804
- self.factor_nmes.append(col_name)
805
- if col_name not in self.config.factor_nmes:
806
- self.config.factor_nmes.append(col_name)
807
-
808
- if col_name not in self.cate_list and col_name not in self.num_features:
809
- self.num_features.append(col_name)
810
-
811
- if self.train_oht_data is not None and self.test_oht_data is not None:
812
- self.train_oht_data[col_name] = self.train_data[col_name].values
813
- self.test_oht_data[col_name] = self.test_data[col_name].values
814
- if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
815
- scaler = StandardScaler()
816
- tr = self.train_data[col_name].to_numpy(
817
- dtype=np.float32, copy=False).reshape(-1, 1)
818
- te = self.test_data[col_name].to_numpy(
819
- dtype=np.float32, copy=False).reshape(-1, 1)
820
- self.train_oht_scl_data[col_name] = scaler.fit_transform(
821
- tr).reshape(-1)
822
- self.test_oht_scl_data[col_name] = scaler.transform(te).reshape(-1)
823
-
824
- if col_name not in self.var_nmes:
825
- self.var_nmes.append(col_name)
826
-
827
- def add_numeric_features_from_columns(self, col_names: List[str]) -> None:
828
- if not col_names:
829
- return
830
-
831
- missing = [
832
- col for col in col_names
833
- if col not in self.train_data.columns or col not in self.test_data.columns
834
- ]
835
- if missing:
836
- raise KeyError(
837
- f"Column(s) {missing} must exist in both train_data and test_data."
838
- )
839
-
840
- for col_name in col_names:
841
- if col_name not in self.factor_nmes:
842
- self.factor_nmes.append(col_name)
843
- if col_name not in self.config.factor_nmes:
844
- self.config.factor_nmes.append(col_name)
845
- if col_name not in self.cate_list and col_name not in self.num_features:
846
- self.num_features.append(col_name)
847
- if col_name not in self.var_nmes:
848
- self.var_nmes.append(col_name)
849
-
850
- if self.train_oht_data is not None and self.test_oht_data is not None:
851
- self.train_oht_data[col_names] = self.train_data[col_names].to_numpy(copy=False)
852
- self.test_oht_data[col_names] = self.test_data[col_names].to_numpy(copy=False)
853
-
854
- if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
855
- scaler = StandardScaler()
856
- tr = self.train_data[col_names].to_numpy(dtype=np.float32, copy=False)
857
- te = self.test_data[col_names].to_numpy(dtype=np.float32, copy=False)
858
- self.train_oht_scl_data[col_names] = scaler.fit_transform(tr)
859
- self.test_oht_scl_data[col_names] = scaler.transform(te)
860
-
861
- def prepare_ft_as_feature(self, max_evals: int = 50, pred_prefix: str = "ft_feat") -> str:
862
- """Train FT as a feature generator and return the downstream column name."""
863
- ft_trainer = self._require_trainer("ft")
864
- ft_trainer.tune(max_evals=max_evals)
865
- if hasattr(ft_trainer, "train_as_feature"):
866
- ft_trainer.train_as_feature(pred_prefix=pred_prefix)
867
- else:
868
- ft_trainer.train()
869
- feature_col = f"pred_{pred_prefix}"
870
- self.add_numeric_feature_from_column(feature_col)
871
- return feature_col
872
-
873
- def prepare_ft_embedding_as_features(self, max_evals: int = 50, pred_prefix: str = "ft_emb") -> List[str]:
874
- """Train FT and inject pooled embeddings as vector features pred_<prefix>_0.. ."""
875
- ft_trainer = self._require_trainer("ft")
876
- ft_trainer.tune(max_evals=max_evals)
877
- if hasattr(ft_trainer, "train_as_feature"):
878
- ft_trainer.train_as_feature(
879
- pred_prefix=pred_prefix, feature_mode="embedding")
880
- else:
881
- raise RuntimeError(
882
- "FT trainer does not support embedding feature mode.")
883
- cols = self._pred_vector_columns(pred_prefix)
884
- if not cols:
885
- raise RuntimeError(
886
- f"No embedding columns were generated for prefix '{pred_prefix}'.")
887
- self.add_numeric_features_from_columns(cols)
888
- return cols
889
-
890
- def prepare_ft_unsupervised_embedding_as_features(self,
891
- pred_prefix: str = "ft_uemb",
892
- params: Optional[Dict[str,
893
- Any]] = None,
894
- mask_prob_num: float = 0.15,
895
- mask_prob_cat: float = 0.15,
896
- num_loss_weight: float = 1.0,
897
- cat_loss_weight: float = 1.0) -> List[str]:
898
- """Export embeddings after FT self-supervised masked reconstruction pretraining."""
899
- ft_trainer = self._require_trainer("ft")
900
- if not hasattr(ft_trainer, "pretrain_unsupervised_as_feature"):
901
- raise RuntimeError(
902
- "FT trainer does not support unsupervised pretraining.")
903
- ft_trainer.pretrain_unsupervised_as_feature(
904
- pred_prefix=pred_prefix,
905
- params=params,
906
- mask_prob_num=mask_prob_num,
907
- mask_prob_cat=mask_prob_cat,
908
- num_loss_weight=num_loss_weight,
909
- cat_loss_weight=cat_loss_weight
910
- )
911
- cols = self._pred_vector_columns(pred_prefix)
912
- if not cols:
913
- raise RuntimeError(
914
- f"No embedding columns were generated for prefix '{pred_prefix}'.")
915
- self.add_numeric_features_from_columns(cols)
916
- return cols
917
-
918
- # GLM Bayesian optimization wrapper.
919
- def bayesopt_glm(self, max_evals=50):
920
- self.optimize_model('glm', max_evals)
921
-
922
- # XGBoost Bayesian optimization wrapper.
923
- def bayesopt_xgb(self, max_evals=100):
924
- self.optimize_model('xgb', max_evals)
925
-
926
- # ResNet Bayesian optimization wrapper.
927
- def bayesopt_resnet(self, max_evals=100):
928
- self.optimize_model('resn', max_evals)
929
-
930
- # GNN Bayesian optimization wrapper.
931
- def bayesopt_gnn(self, max_evals=50):
932
- self.optimize_model('gnn', max_evals)
933
-
934
- # FT-Transformer Bayesian optimization wrapper.
935
- def bayesopt_ft(self, max_evals=50):
936
- self.optimize_model('ft', max_evals)
937
-
938
- def save_model(self, model_name=None):
939
- keys = [model_name] if model_name else self.trainers.keys()
940
- for key in keys:
941
- if key in self.trainers:
942
- self.trainers[key].save()
943
- else:
944
- if model_name: # Only warn when the user specifies a model name.
945
- print(f"[save_model] Warning: Unknown model key {key}")
946
-
947
- def load_model(self, model_name=None):
948
- keys = [model_name] if model_name else self.trainers.keys()
949
- for key in keys:
950
- if key in self.trainers:
951
- self.trainers[key].load()
952
- # Sync context fields.
953
- trainer = self.trainers[key]
954
- if trainer.model is not None:
955
- setattr(self, f"{key}_best", trainer.model)
956
- # For legacy compatibility, also update xxx_load.
957
- # Old versions only tracked xgb_load/resn_load/ft_load (not glm_load/gnn_load).
958
- if key in ['xgb', 'resn', 'ft', 'gnn']:
959
- setattr(self, f"{key}_load", trainer.model)
960
- else:
961
- if model_name:
962
- print(f"[load_model] Warning: Unknown model key {key}")
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict
4
+ from datetime import datetime
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit
12
+ from sklearn.preprocessing import StandardScaler
13
+
14
+ from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
15
+ from ins_pricing.modelling.bayesopt.model_explain_mixin import BayesOptExplainMixin
16
+ from ins_pricing.modelling.bayesopt.model_plotting_mixin import BayesOptPlottingMixin
17
+ from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
18
+ from ins_pricing.modelling.bayesopt.trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
19
+ from ins_pricing.utils import EPS, infer_factor_and_cate_list, set_global_seed
20
+ from ins_pricing.utils.io import IOUtils
21
+ from ins_pricing.utils.losses import (
22
+ infer_loss_name_from_model_name,
23
+ normalize_loss_name,
24
+ resolve_tweedie_power,
25
+ resolve_xgb_objective,
26
+ )
27
+
28
+
29
+ class _CVSplitter:
30
+ """Wrapper to carry optional groups or time order for CV splits."""
31
+
32
+ def __init__(
33
+ self,
34
+ splitter,
35
+ *,
36
+ groups: Optional[pd.Series] = None,
37
+ order: Optional[np.ndarray] = None,
38
+ ) -> None:
39
+ self._splitter = splitter
40
+ self._groups = groups
41
+ self._order = order
42
+
43
+ def split(self, X, y=None, groups=None):
44
+ if self._order is not None:
45
+ order = np.asarray(self._order)
46
+ X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
47
+ for tr_idx, val_idx in self._splitter.split(X_ord, y=y):
48
+ yield order[tr_idx], order[val_idx]
49
+ return
50
+ use_groups = groups if groups is not None else self._groups
51
+ for tr_idx, val_idx in self._splitter.split(X, y=y, groups=use_groups):
52
+ yield tr_idx, val_idx
53
+
54
+ # BayesOpt orchestration and SHAP utilities
55
+ # =============================================================================
56
+ class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
57
+ def __init__(self, train_data, test_data,
58
+ config: Optional[BayesOptConfig] = None,
59
+ # Backward compatibility: individual parameters (DEPRECATED)
60
+ model_nme=None, resp_nme=None, weight_nme=None,
61
+ factor_nmes: Optional[List[str]] = None, task_type='regression',
62
+ binary_resp_nme=None,
63
+ cate_list=None, prop_test=0.25, rand_seed=None,
64
+ epochs=100, use_gpu=True,
65
+ use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
66
+ use_gnn_data_parallel: bool = False,
67
+ use_resn_ddp: bool = False, use_ft_ddp: bool = False,
68
+ use_gnn_ddp: bool = False,
69
+ output_dir: Optional[str] = None,
70
+ gnn_use_approx_knn: bool = True,
71
+ gnn_approx_knn_threshold: int = 50000,
72
+ gnn_graph_cache: Optional[str] = None,
73
+ gnn_max_gpu_knn_nodes: Optional[int] = 200000,
74
+ gnn_knn_gpu_mem_ratio: float = 0.9,
75
+ gnn_knn_gpu_mem_overhead: float = 2.0,
76
+ ft_role: str = "model",
77
+ ft_feature_prefix: str = "ft_emb",
78
+ ft_num_numeric_tokens: Optional[int] = None,
79
+ infer_categorical_max_unique: int = 50,
80
+ infer_categorical_max_ratio: float = 0.05,
81
+ reuse_best_params: bool = False,
82
+ xgb_max_depth_max: int = 25,
83
+ xgb_n_estimators_max: int = 500,
84
+ resn_weight_decay: Optional[float] = None,
85
+ final_ensemble: bool = False,
86
+ final_ensemble_k: int = 3,
87
+ final_refit: bool = True,
88
+ optuna_storage: Optional[str] = None,
89
+ optuna_study_prefix: Optional[str] = None,
90
+ best_params_files: Optional[Dict[str, str]] = None,
91
+ cv_strategy: Optional[str] = None,
92
+ cv_splits: Optional[int] = None,
93
+ cv_group_col: Optional[str] = None,
94
+ cv_time_col: Optional[str] = None,
95
+ cv_time_ascending: bool = True,
96
+ ft_oof_folds: Optional[int] = None,
97
+ ft_oof_strategy: Optional[str] = None,
98
+ ft_oof_shuffle: bool = True,
99
+ save_preprocess: bool = False,
100
+ preprocess_artifact_path: Optional[str] = None,
101
+ plot_path_style: Optional[str] = None,
102
+ bo_sample_limit: Optional[int] = None,
103
+ cache_predictions: bool = False,
104
+ prediction_cache_dir: Optional[str] = None,
105
+ prediction_cache_format: Optional[str] = None,
106
+ region_province_col: Optional[str] = None,
107
+ region_city_col: Optional[str] = None,
108
+ region_effect_alpha: Optional[float] = None,
109
+ geo_feature_nmes: Optional[List[str]] = None,
110
+ geo_token_hidden_dim: Optional[int] = None,
111
+ geo_token_layers: Optional[int] = None,
112
+ geo_token_dropout: Optional[float] = None,
113
+ geo_token_k_neighbors: Optional[int] = None,
114
+ geo_token_learning_rate: Optional[float] = None,
115
+ geo_token_epochs: Optional[int] = None):
116
+ """Orchestrate BayesOpt training across multiple trainers.
117
+
118
+ Args:
119
+ train_data: Training DataFrame.
120
+ test_data: Test DataFrame.
121
+ config: BayesOptConfig instance with all configuration (RECOMMENDED).
122
+ If provided, all other parameters are ignored.
123
+
124
+ # DEPRECATED: Individual parameters (use config instead)
125
+ model_nme: Model name prefix used in outputs.
126
+ resp_nme: Target column name.
127
+ weight_nme: Sample weight column name.
128
+ factor_nmes: Feature column list.
129
+ task_type: "regression" or "classification".
130
+ binary_resp_nme: Optional binary target for lift curves.
131
+ cate_list: Categorical feature list.
132
+ prop_test: Validation split ratio in CV.
133
+ rand_seed: Random seed.
134
+ epochs: NN training epochs.
135
+ use_gpu: Prefer GPU when available.
136
+ use_resn_data_parallel: Enable DataParallel for ResNet.
137
+ use_ft_data_parallel: Enable DataParallel for FTTransformer.
138
+ use_gnn_data_parallel: Enable DataParallel for GNN.
139
+ use_resn_ddp: Enable DDP for ResNet.
140
+ use_ft_ddp: Enable DDP for FTTransformer.
141
+ use_gnn_ddp: Enable DDP for GNN.
142
+ output_dir: Output root for models/results/plots.
143
+ gnn_use_approx_knn: Use approximate kNN when available.
144
+ gnn_approx_knn_threshold: Row threshold to switch to approximate kNN.
145
+ gnn_graph_cache: Optional adjacency cache path.
146
+ gnn_max_gpu_knn_nodes: Force CPU kNN above this node count to avoid OOM.
147
+ gnn_knn_gpu_mem_ratio: Fraction of free GPU memory for kNN.
148
+ gnn_knn_gpu_mem_overhead: Temporary memory multiplier for GPU kNN.
149
+ ft_num_numeric_tokens: Number of numeric tokens for FT (None = auto).
150
+ final_ensemble: Enable k-fold model averaging at the final stage.
151
+ final_ensemble_k: Number of folds for averaging.
152
+ final_refit: Refit on full data using best stopping point.
153
+
154
+ Examples:
155
+ # New style (recommended):
156
+ config = BayesOptConfig(
157
+ model_nme="my_model",
158
+ resp_nme="target",
159
+ weight_nme="weight",
160
+ factor_nmes=["feat1", "feat2"]
161
+ )
162
+ model = BayesOptModel(train_df, test_df, config=config)
163
+
164
+ # Old style (deprecated, for backward compatibility):
165
+ model = BayesOptModel(
166
+ train_df, test_df,
167
+ model_nme="my_model",
168
+ resp_nme="target",
169
+ weight_nme="weight",
170
+ factor_nmes=["feat1", "feat2"]
171
+ )
172
+ """
173
+ # Detect which API is being used
174
+ if config is not None:
175
+ # New API: config object provided
176
+ if isinstance(config, BayesOptConfig):
177
+ cfg = config
178
+ else:
179
+ raise TypeError(
180
+ f"config must be a BayesOptConfig instance, got {type(config).__name__}"
181
+ )
182
+ else:
183
+ # Old API: individual parameters (backward compatibility)
184
+ # Show deprecation warning
185
+ import warnings
186
+ warnings.warn(
187
+ "Passing individual parameters to BayesOptModel.__init__ is deprecated. "
188
+ "Use the 'config' parameter with a BayesOptConfig instance instead:\n"
189
+ " config = BayesOptConfig(model_nme=..., resp_nme=..., ...)\n"
190
+ " model = BayesOptModel(train_data, test_data, config=config)\n"
191
+ "Individual parameters will be removed in v0.4.0.",
192
+ DeprecationWarning,
193
+ stacklevel=2
194
+ )
195
+
196
+ # Validate required parameters
197
+ if model_nme is None:
198
+ raise ValueError("model_nme is required when not using config parameter")
199
+ if resp_nme is None:
200
+ raise ValueError("resp_nme is required when not using config parameter")
201
+ if weight_nme is None:
202
+ raise ValueError("weight_nme is required when not using config parameter")
203
+
204
+ # Infer categorical features if needed
205
+ # Only use user-specified categorical list for one-hot; do not auto-infer.
206
+ user_cate_list = [] if cate_list is None else list(cate_list)
207
+ inferred_factors, inferred_cats = infer_factor_and_cate_list(
208
+ train_df=train_data,
209
+ test_df=test_data,
210
+ resp_nme=resp_nme,
211
+ weight_nme=weight_nme,
212
+ binary_resp_nme=binary_resp_nme,
213
+ factor_nmes=factor_nmes,
214
+ cate_list=user_cate_list,
215
+ infer_categorical_max_unique=int(infer_categorical_max_unique),
216
+ infer_categorical_max_ratio=float(infer_categorical_max_ratio),
217
+ )
218
+
219
+ # Construct config from individual parameters
220
+ cfg = BayesOptConfig(
221
+ model_nme=model_nme,
222
+ task_type=task_type,
223
+ resp_nme=resp_nme,
224
+ weight_nme=weight_nme,
225
+ factor_nmes=list(inferred_factors),
226
+ binary_resp_nme=binary_resp_nme,
227
+ cate_list=list(inferred_cats) if inferred_cats else None,
228
+ prop_test=prop_test,
229
+ rand_seed=rand_seed,
230
+ epochs=epochs,
231
+ use_gpu=use_gpu,
232
+ xgb_max_depth_max=int(xgb_max_depth_max),
233
+ xgb_n_estimators_max=int(xgb_n_estimators_max),
234
+ use_resn_data_parallel=use_resn_data_parallel,
235
+ use_ft_data_parallel=use_ft_data_parallel,
236
+ use_resn_ddp=use_resn_ddp,
237
+ use_gnn_data_parallel=use_gnn_data_parallel,
238
+ use_ft_ddp=use_ft_ddp,
239
+ use_gnn_ddp=use_gnn_ddp,
240
+ gnn_use_approx_knn=gnn_use_approx_knn,
241
+ gnn_approx_knn_threshold=gnn_approx_knn_threshold,
242
+ gnn_graph_cache=gnn_graph_cache,
243
+ gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
244
+ gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
245
+ gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
246
+ output_dir=output_dir,
247
+ optuna_storage=optuna_storage,
248
+ optuna_study_prefix=optuna_study_prefix,
249
+ best_params_files=best_params_files,
250
+ ft_role=str(ft_role or "model"),
251
+ ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
252
+ ft_num_numeric_tokens=ft_num_numeric_tokens,
253
+ reuse_best_params=bool(reuse_best_params),
254
+ resn_weight_decay=float(resn_weight_decay)
255
+ if resn_weight_decay is not None
256
+ else 1e-4,
257
+ final_ensemble=bool(final_ensemble),
258
+ final_ensemble_k=int(final_ensemble_k),
259
+ final_refit=bool(final_refit),
260
+ cv_strategy=str(cv_strategy or "random"),
261
+ cv_splits=cv_splits,
262
+ cv_group_col=cv_group_col,
263
+ cv_time_col=cv_time_col,
264
+ cv_time_ascending=bool(cv_time_ascending),
265
+ ft_oof_folds=ft_oof_folds,
266
+ ft_oof_strategy=ft_oof_strategy,
267
+ ft_oof_shuffle=bool(ft_oof_shuffle),
268
+ save_preprocess=bool(save_preprocess),
269
+ preprocess_artifact_path=preprocess_artifact_path,
270
+ plot_path_style=str(plot_path_style or "nested"),
271
+ bo_sample_limit=bo_sample_limit,
272
+ cache_predictions=bool(cache_predictions),
273
+ prediction_cache_dir=prediction_cache_dir,
274
+ prediction_cache_format=str(prediction_cache_format or "parquet"),
275
+ region_province_col=region_province_col,
276
+ region_city_col=region_city_col,
277
+ region_effect_alpha=float(region_effect_alpha)
278
+ if region_effect_alpha is not None
279
+ else 50.0,
280
+ geo_feature_nmes=list(geo_feature_nmes)
281
+ if geo_feature_nmes is not None
282
+ else None,
283
+ geo_token_hidden_dim=int(geo_token_hidden_dim)
284
+ if geo_token_hidden_dim is not None
285
+ else 32,
286
+ geo_token_layers=int(geo_token_layers)
287
+ if geo_token_layers is not None
288
+ else 2,
289
+ geo_token_dropout=float(geo_token_dropout)
290
+ if geo_token_dropout is not None
291
+ else 0.1,
292
+ geo_token_k_neighbors=int(geo_token_k_neighbors)
293
+ if geo_token_k_neighbors is not None
294
+ else 10,
295
+ geo_token_learning_rate=float(geo_token_learning_rate)
296
+ if geo_token_learning_rate is not None
297
+ else 1e-3,
298
+ geo_token_epochs=int(geo_token_epochs)
299
+ if geo_token_epochs is not None
300
+ else 50,
301
+ )
302
+ self.config = cfg
303
+ self.model_nme = cfg.model_nme
304
+ self.task_type = cfg.task_type
305
+ normalized_loss = normalize_loss_name(getattr(cfg, "loss_name", None), self.task_type)
306
+ if self.task_type == "classification":
307
+ self.loss_name = "logloss" if normalized_loss == "auto" else normalized_loss
308
+ else:
309
+ if normalized_loss == "auto":
310
+ self.loss_name = infer_loss_name_from_model_name(self.model_nme)
311
+ else:
312
+ self.loss_name = normalized_loss
313
+ self.resp_nme = cfg.resp_nme
314
+ self.weight_nme = cfg.weight_nme
315
+ self.factor_nmes = cfg.factor_nmes
316
+ self.binary_resp_nme = cfg.binary_resp_nme
317
+ self.cate_list = list(cfg.cate_list or [])
318
+ self.prop_test = cfg.prop_test
319
+ self.epochs = cfg.epochs
320
+ self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
321
+ 1, 10000)
322
+ set_global_seed(int(self.rand_seed))
323
+ self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
324
+ self.output_manager = OutputManager(
325
+ cfg.output_dir or os.getcwd(), self.model_nme)
326
+
327
+ preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
328
+ self.train_data = preprocessor.train_data
329
+ self.test_data = preprocessor.test_data
330
+ self.train_oht_data = preprocessor.train_oht_data
331
+ self.test_oht_data = preprocessor.test_oht_data
332
+ self.train_oht_scl_data = preprocessor.train_oht_scl_data
333
+ self.test_oht_scl_data = preprocessor.test_oht_scl_data
334
+ self.var_nmes = preprocessor.var_nmes
335
+ self.num_features = preprocessor.num_features
336
+ self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
337
+ self.numeric_scalers = preprocessor.numeric_scalers
338
+ if getattr(self.config, "save_preprocess", False):
339
+ artifact_path = getattr(self.config, "preprocess_artifact_path", None)
340
+ if artifact_path:
341
+ target = Path(str(artifact_path))
342
+ if not target.is_absolute():
343
+ target = Path(self.output_manager.result_dir) / target
344
+ else:
345
+ target = Path(self.output_manager.result_path(
346
+ f"{self.model_nme}_preprocess.json"
347
+ ))
348
+ preprocessor.save_artifacts(target)
349
+ self.geo_token_cols: List[str] = []
350
+ self.train_geo_tokens: Optional[pd.DataFrame] = None
351
+ self.test_geo_tokens: Optional[pd.DataFrame] = None
352
+ self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
353
+ self._add_region_effect()
354
+
355
+ self.cv = self._build_cv_splitter()
356
+ if self.task_type == 'classification':
357
+ self.obj = 'binary:logistic'
358
+ else: # regression task
359
+ self.obj = resolve_xgb_objective(self.loss_name)
360
+ self.fit_params = {
361
+ 'sample_weight': self.train_data[self.weight_nme].values
362
+ }
363
+ self.model_label: List[str] = []
364
+ self.optuna_storage = cfg.optuna_storage
365
+ self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
366
+
367
+ # Keep trainers in a dict for unified access and easy extension.
368
+ self.trainers: Dict[str, TrainerBase] = {
369
+ 'glm': GLMTrainer(self),
370
+ 'xgb': XGBTrainer(self),
371
+ 'resn': ResNetTrainer(self),
372
+ 'ft': FTTrainer(self),
373
+ 'gnn': GNNTrainer(self),
374
+ }
375
+ self._prepare_geo_tokens()
376
+ self.xgb_best = None
377
+ self.resn_best = None
378
+ self.gnn_best = None
379
+ self.glm_best = None
380
+ self.ft_best = None
381
+ self.best_xgb_params = None
382
+ self.best_resn_params = None
383
+ self.best_gnn_params = None
384
+ self.best_ft_params = None
385
+ self.best_xgb_trial = None
386
+ self.best_resn_trial = None
387
+ self.best_gnn_trial = None
388
+ self.best_ft_trial = None
389
+ self.best_glm_params = None
390
+ self.best_glm_trial = None
391
+ self.xgb_load = None
392
+ self.resn_load = None
393
+ self.gnn_load = None
394
+ self.ft_load = None
395
+ self.version_manager = VersionManager(self.output_manager)
396
+
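For orientation, a minimal sketch of the config-based construction path that the deprecation warning above points to; the column names are placeholders, train_data/test_data are assumed to be pandas DataFrames, and the import assumes both classes are re-exported from ins_pricing.modelling.bayesopt.

    from ins_pricing.modelling.bayesopt import BayesOptConfig, BayesOptModel

    config = BayesOptConfig(
        model_nme="motor_freq",        # placeholder model name
        task_type="regression",
        resp_nme="claim_count",        # placeholder response column
        weight_nme="exposure",         # placeholder weight column
        factor_nmes=["driver_age", "region", "vehicle_group"],
    )
    model = BayesOptModel(train_data, test_data, config=config)

    # Attributes populated by the constructor above:
    model.train_oht_data        # one-hot encoded view of train_data
    model.train_oht_scl_data    # one-hot + standardised view
    sorted(model.trainers)      # ['ft', 'glm', 'gnn', 'resn', 'xgb']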
397
+ def _build_cv_splitter(self) -> _CVSplitter:
398
+ strategy = str(getattr(self.config, "cv_strategy", "random") or "random").strip().lower()
399
+ val_ratio = float(self.prop_test) if self.prop_test is not None else 0.25
400
+ if not (0.0 < val_ratio < 1.0):
401
+ val_ratio = 0.25
402
+ cv_splits = getattr(self.config, "cv_splits", None)
403
+ if cv_splits is None:
404
+ cv_splits = max(2, int(round(1 / val_ratio)))
405
+ cv_splits = max(2, int(cv_splits))
406
+
407
+ if strategy in {"group", "grouped"}:
408
+ group_col = getattr(self.config, "cv_group_col", None)
409
+ if not group_col:
410
+ raise ValueError("cv_group_col is required for group cv_strategy.")
411
+ if group_col not in self.train_data.columns:
412
+ raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
413
+ groups = self.train_data[group_col]
414
+ splitter = GroupKFold(n_splits=cv_splits)
415
+ return _CVSplitter(splitter, groups=groups)
416
+
417
+ if strategy in {"time", "timeseries", "temporal"}:
418
+ time_col = getattr(self.config, "cv_time_col", None)
419
+ if not time_col:
420
+ raise ValueError("cv_time_col is required for time cv_strategy.")
421
+ if time_col not in self.train_data.columns:
422
+ raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
423
+ ascending = bool(getattr(self.config, "cv_time_ascending", True))
424
+ order_index = self.train_data[time_col].sort_values(ascending=ascending).index
425
+ order = self.train_data.index.get_indexer(order_index)
426
+ splitter = TimeSeriesSplit(n_splits=cv_splits)
427
+ return _CVSplitter(splitter, order=order)
428
+
429
+ splitter = ShuffleSplit(
430
+ n_splits=cv_splits,
431
+ test_size=val_ratio,
432
+ random_state=self.rand_seed,
433
+ )
434
+ return _CVSplitter(splitter)
435
+
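Each cv_strategy maps onto a standard scikit-learn splitter; a rough standalone sketch of that mapping (not the package's _CVSplitter wrapper itself):

    import numpy as np
    from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit

    def build_splitter(strategy: str, n_splits: int = 4, val_ratio: float = 0.25, seed: int = 0):
        # 'group' -> GroupKFold (groups must also be passed to split()),
        # 'time'  -> TimeSeriesSplit over rows pre-sorted by the time column,
        # default -> ShuffleSplit holdouts of size val_ratio.
        if strategy in {"group", "grouped"}:
            return GroupKFold(n_splits=n_splits)
        if strategy in {"time", "timeseries", "temporal"}:
            return TimeSeriesSplit(n_splits=n_splits)
        return ShuffleSplit(n_splits=n_splits, test_size=val_ratio, random_state=seed)

    folds = list(build_splitter("random").split(np.arange(100).reshape(-1, 1)))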
436
+ def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
437
+ if self.task_type == 'classification':
438
+ return None
439
+ loss_name = getattr(self, "loss_name", None)
440
+ if loss_name:
441
+ resolved = resolve_tweedie_power(str(loss_name), default=1.5)
442
+ if resolved is not None:
443
+ return resolved
444
+ objective = obj or getattr(self, "obj", None)
445
+ if objective == 'count:poisson':
446
+ return 1.0
447
+ if objective == 'reg:gamma':
448
+ return 2.0
449
+ return 1.5
450
+
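The fallback from XGBoost objective to Tweedie variance power, restated as a tiny standalone helper for reference:

    def tweedie_power_for_objective(objective: str) -> float:
        # Poisson sits at the Tweedie boundary p=1, gamma at p=2; otherwise a
        # compound Poisson-gamma default of 1.5 is used.
        if objective == "count:poisson":
            return 1.0
        if objective == "reg:gamma":
            return 2.0
        return 1.5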
451
+ def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
452
+ """Internal builder; allows trial overrides and returns None on failure."""
453
+ geo_cols = list(self.config.geo_feature_nmes or [])
454
+ if not geo_cols:
455
+ return None
456
+
457
+ available = [c for c in geo_cols if c in self.train_data.columns]
458
+ if not available:
459
+ return None
460
+
461
+ # Preprocess text/numeric: fill numeric with median, label-encode text, map unknowns.
462
+ proc_train = {}
463
+ proc_test = {}
464
+ for col in available:
465
+ s_train = self.train_data[col]
466
+ s_test = self.test_data[col]
467
+ if pd.api.types.is_numeric_dtype(s_train):
468
+ tr = pd.to_numeric(s_train, errors="coerce")
469
+ te = pd.to_numeric(s_test, errors="coerce")
470
+ med = np.nanmedian(tr)
471
+ proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
472
+ proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
473
+ else:
474
+ cats = pd.Categorical(s_train.astype(str))
475
+ tr_codes = cats.codes.astype(np.float32, copy=True)
476
+ tr_codes[tr_codes < 0] = len(cats.categories)
477
+ te_cats = pd.Categorical(
478
+ s_test.astype(str), categories=cats.categories)
479
+ te_codes = te_cats.codes.astype(np.float32, copy=True)
480
+ te_codes[te_codes < 0] = len(cats.categories)
481
+ proc_train[col] = tr_codes
482
+ proc_test[col] = te_codes
483
+
484
+ train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
485
+ test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
486
+
487
+ scaler = StandardScaler()
488
+ train_geo = pd.DataFrame(
489
+ scaler.fit_transform(train_geo_raw),
490
+ columns=available,
491
+ index=self.train_data.index
492
+ )
493
+ test_geo = pd.DataFrame(
494
+ scaler.transform(test_geo_raw),
495
+ columns=available,
496
+ index=self.test_data.index
497
+ )
498
+
499
+ tw_power = self.default_tweedie_power()
500
+
501
+ cfg = params_override or {}
502
+ try:
503
+ geo_gnn = GraphNeuralNetSklearn(
504
+ model_nme=f"{self.model_nme}_geo",
505
+ input_dim=len(available),
506
+ hidden_dim=cfg.get("geo_token_hidden_dim",
507
+ self.config.geo_token_hidden_dim),
508
+ num_layers=cfg.get("geo_token_layers",
509
+ self.config.geo_token_layers),
510
+ k_neighbors=cfg.get("geo_token_k_neighbors",
511
+ self.config.geo_token_k_neighbors),
512
+ dropout=cfg.get("geo_token_dropout",
513
+ self.config.geo_token_dropout),
514
+ learning_rate=cfg.get(
515
+ "geo_token_learning_rate", self.config.geo_token_learning_rate),
516
+ epochs=int(cfg.get("geo_token_epochs",
517
+ self.config.geo_token_epochs)),
518
+ patience=5,
519
+ task_type=self.task_type,
520
+ tweedie_power=tw_power,
521
+ loss_name=self.loss_name,
522
+ use_data_parallel=False,
523
+ use_ddp=False,
524
+ use_approx_knn=self.config.gnn_use_approx_knn,
525
+ approx_knn_threshold=self.config.gnn_approx_knn_threshold,
526
+ graph_cache_path=None,
527
+ max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
528
+ knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
529
+ knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
530
+ )
531
+ geo_gnn.fit(
532
+ train_geo,
533
+ self.train_data[self.resp_nme],
534
+ self.train_data[self.weight_nme]
535
+ )
536
+ train_embed = geo_gnn.encode(train_geo)
537
+ test_embed = geo_gnn.encode(test_geo)
538
+ cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
539
+ train_tokens = pd.DataFrame(
540
+ train_embed, index=self.train_data.index, columns=cols)
541
+ test_tokens = pd.DataFrame(
542
+ test_embed, index=self.test_data.index, columns=cols)
543
+ return train_tokens, test_tokens, cols, geo_gnn
544
+ except Exception as exc:
545
+ print(f"[GeoToken] Generation failed: {exc}")
546
+ return None
547
+
548
+ def _prepare_geo_tokens(self) -> None:
549
+ """Build and persist geo tokens with default config values."""
550
+ gnn_trainer = self.trainers.get("gnn")
551
+ if gnn_trainer is not None and hasattr(gnn_trainer, "prepare_geo_tokens"):
552
+ try:
553
+ gnn_trainer.prepare_geo_tokens(force=False) # type: ignore[attr-defined]
554
+ return
555
+ except Exception as exc:
556
+ print(f"[GeoToken] GNNTrainer generation failed: {exc}")
557
+
558
+ result = self._build_geo_tokens()
559
+ if result is None:
560
+ return
561
+ train_tokens, test_tokens, cols, geo_gnn = result
562
+ self.train_geo_tokens = train_tokens
563
+ self.test_geo_tokens = test_tokens
564
+ self.geo_token_cols = cols
565
+ self.geo_gnn_model = geo_gnn
566
+ print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; injecting into FT.")
567
+
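Geo tokens are only built when geographic columns are configured; a hedged config fragment (column names and dimensions are illustrative):

    config = BayesOptConfig(
        model_nme="motor_sev",
        task_type="regression",
        resp_nme="claim_amount",
        weight_nme="claim_count",
        factor_nmes=["driver_age", "vehicle_group", "latitude", "longitude"],
        geo_feature_nmes=["latitude", "longitude"],   # triggers geo-token generation
        geo_token_hidden_dim=32,
        geo_token_k_neighbors=10,
        geo_token_epochs=50,
    )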
568
+ def _add_region_effect(self) -> None:
569
+ """Partial pooling over province/city to create a smoothed region_effect feature."""
570
+ prov_col = self.config.region_province_col
571
+ city_col = self.config.region_city_col
572
+ if not prov_col or not city_col:
573
+ return
574
+ for col in [prov_col, city_col]:
575
+ if col not in self.train_data.columns:
576
+ print(f"[RegionEffect] Missing column {col}; skipped.")
577
+ return
578
+
579
+ def safe_mean(df: pd.DataFrame) -> float:
580
+ w = df[self.weight_nme]
581
+ y = df[self.resp_nme]
582
+ denom = max(float(w.sum()), EPS)
583
+ return float((y * w).sum() / denom)
584
+
585
+ global_mean = safe_mean(self.train_data)
586
+ alpha = max(float(self.config.region_effect_alpha), 0.0)
587
+
588
+ w_all = self.train_data[self.weight_nme]
589
+ y_all = self.train_data[self.resp_nme]
590
+ yw_all = y_all * w_all
591
+
592
+ prov_sumw = w_all.groupby(self.train_data[prov_col]).sum()
593
+ prov_sumyw = yw_all.groupby(self.train_data[prov_col]).sum()
594
+ prov_mean = (prov_sumyw / prov_sumw.clip(lower=EPS)).astype(float)
595
+ prov_mean = prov_mean.fillna(global_mean)
596
+
597
+ city_sumw = self.train_data.groupby([prov_col, city_col])[
598
+ self.weight_nme].sum()
599
+ city_sumyw = yw_all.groupby(
600
+ [self.train_data[prov_col], self.train_data[city_col]]).sum()
601
+ city_df = pd.DataFrame({
602
+ "sum_w": city_sumw,
603
+ "sum_yw": city_sumyw,
604
+ })
605
+ city_df["prior"] = city_df.index.get_level_values(0).map(
606
+ prov_mean).fillna(global_mean)
607
+ city_df["effect"] = (
608
+ city_df["sum_yw"] + alpha * city_df["prior"]
609
+ ) / (city_df["sum_w"] + alpha).clip(lower=EPS)
610
+ city_effect = city_df["effect"]
611
+
612
+ def lookup_effect(df: pd.DataFrame) -> pd.Series:
613
+ idx = pd.MultiIndex.from_frame(df[[prov_col, city_col]])
614
+ effects = city_effect.reindex(idx).to_numpy(dtype=np.float64)
615
+ prov_fallback = df[prov_col].map(
616
+ prov_mean).fillna(global_mean).to_numpy(dtype=np.float64)
617
+ effects = np.where(np.isfinite(effects), effects, prov_fallback)
618
+ effects = np.where(np.isfinite(effects), effects, global_mean)
619
+ return pd.Series(effects, index=df.index, dtype=np.float32)
620
+
621
+ re_train = lookup_effect(self.train_data)
622
+ re_test = lookup_effect(self.test_data)
623
+
624
+ col_name = "region_effect"
625
+ self.train_data[col_name] = re_train
626
+ self.test_data[col_name] = re_test
627
+
628
+ # Sync into one-hot and scaled variants.
629
+ for df in [self.train_oht_data, self.test_oht_data]:
630
+ if df is not None:
631
+ df[col_name] = re_train if df is self.train_oht_data else re_test
632
+
633
+ # Standardize region_effect and propagate.
634
+ scaler = StandardScaler()
635
+ re_train_s = scaler.fit_transform(
636
+ re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
637
+ re_test_s = scaler.transform(
638
+ re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
639
+ for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
640
+ if df is not None:
641
+ df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
642
+
643
+ # Update feature lists.
644
+ if col_name not in self.factor_nmes:
645
+ self.factor_nmes.append(col_name)
646
+ if col_name not in self.num_features:
647
+ self.num_features.append(col_name)
648
+ if self.train_oht_scl_data is not None:
649
+ excluded = {self.weight_nme, self.resp_nme}
650
+ self.var_nmes = [
651
+ col for col in self.train_oht_scl_data.columns if col not in excluded
652
+ ]
653
+
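The city-level smoothing above is a credibility-style partial pooling, effect = (sum_yw + alpha * province_prior) / (sum_w + alpha); a toy numeric sketch:

    alpha = 50.0            # region_effect_alpha: pseudo-exposure added to the denominator
    prov_prior = 0.08       # weighted mean response of the parent province
    city_sum_w = 120.0      # exposure observed in the city
    city_sum_yw = 12.0      # weighted response total in the city

    raw_city_mean = city_sum_yw / city_sum_w                            # 0.100
    effect = (city_sum_yw + alpha * prov_prior) / (city_sum_w + alpha)  # ~0.094
    # The city mean is shrunk toward the province prior; cities with little
    # exposure land near the prior, large cities near their own mean.
    print(round(raw_city_mean, 3), round(effect, 3))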
654
+ def _require_trainer(self, model_key: str) -> "TrainerBase":
655
+ trainer = self.trainers.get(model_key)
656
+ if trainer is None:
657
+ raise KeyError(f"Unknown model key: {model_key}")
658
+ return trainer
659
+
660
+ def _pred_vector_columns(self, pred_prefix: str) -> List[str]:
661
+ """Return vector feature columns like pred_<prefix>_0.. sorted by suffix."""
662
+ col_prefix = f"pred_{pred_prefix}_"
663
+ cols = [c for c in self.train_data.columns if c.startswith(col_prefix)]
664
+
665
+ def sort_key(name: str):
666
+ tail = name.rsplit("_", 1)[-1]
667
+ try:
668
+ return (0, int(tail))
669
+ except Exception:
670
+ return (1, tail)
671
+
672
+ cols.sort(key=sort_key)
673
+ return cols
674
+
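The numeric-suffix sort keeps vector features in embedding order rather than lexicographic order, for example:

    cols = ["pred_ft_emb_10", "pred_ft_emb_2", "pred_ft_emb_1"]
    cols.sort(key=lambda n: (0, int(n.rsplit("_", 1)[-1]))
              if n.rsplit("_", 1)[-1].isdigit() else (1, n.rsplit("_", 1)[-1]))
    print(cols)  # ['pred_ft_emb_1', 'pred_ft_emb_2', 'pred_ft_emb_10']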
675
+ def _inject_pred_features(self, pred_prefix: str) -> List[str]:
676
+ """Inject pred_<prefix> or pred_<prefix>_i columns into features and return names."""
677
+ cols = self._pred_vector_columns(pred_prefix)
678
+ if cols:
679
+ self.add_numeric_features_from_columns(cols)
680
+ return cols
681
+ scalar_col = f"pred_{pred_prefix}"
682
+ if scalar_col in self.train_data.columns:
683
+ self.add_numeric_feature_from_column(scalar_col)
684
+ return [scalar_col]
685
+ return []
686
+
687
+ def _maybe_load_best_params(self, model_key: str, trainer: "TrainerBase") -> None:
688
+ # 1) If best_params_files is specified, load and skip tuning.
689
+ best_params_files = getattr(self.config, "best_params_files", None) or {}
690
+ best_params_file = best_params_files.get(model_key)
691
+ if best_params_file and not trainer.best_params:
692
+ trainer.best_params = IOUtils.load_params_file(best_params_file)
693
+ trainer.best_trial = None
694
+ print(
695
+ f"[Optuna][{trainer.label}] Loaded best_params from {best_params_file}; skip tuning."
696
+ )
697
+
698
+ # 2) If reuse_best_params is enabled, prefer version snapshots; else load legacy CSV.
699
+ reuse_params = bool(getattr(self.config, "reuse_best_params", False))
700
+ if reuse_params and not trainer.best_params:
701
+ payload = self.version_manager.load_latest(f"{model_key}_best")
702
+ best_params = None if payload is None else payload.get("best_params")
703
+ if best_params:
704
+ trainer.best_params = best_params
705
+ trainer.best_trial = None
706
+ trainer.study_name = payload.get(
707
+ "study_name") if isinstance(payload, dict) else None
708
+ print(
709
+ f"[Optuna][{trainer.label}] Reusing best_params from versions snapshot.")
710
+ return
711
+
712
+ params_path = self.output_manager.result_path(
713
+ f'{self.model_nme}_bestparams_{trainer.label.lower()}.csv'
714
+ )
715
+ if os.path.exists(params_path):
716
+ try:
717
+ trainer.best_params = IOUtils.load_params_file(params_path)
718
+ trainer.best_trial = None
719
+ print(
720
+ f"[Optuna][{trainer.label}] Reusing best_params from {params_path}.")
721
+ except ValueError:
722
+ # Legacy compatibility: ignore empty files and continue tuning.
723
+ pass
724
+
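Pre-computed parameters can be supplied per model key so tuning is skipped; a hedged config fragment (the path and file format are placeholders):

    config = BayesOptConfig(
        model_nme="motor_freq",
        resp_nme="claim_count",
        weight_nme="exposure",
        factor_nmes=["driver_age", "region"],
        best_params_files={"xgb": "artifacts/xgb_best_params.csv"},  # placeholder path
        reuse_best_params=True,   # otherwise fall back to version snapshots / legacy CSVs
    )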
725
+ # Generic optimization entry point.
726
+ def optimize_model(self, model_key: str, max_evals: int = 100):
727
+ if model_key not in self.trainers:
728
+ print(f"Warning: Unknown model key: {model_key}")
729
+ return
730
+
731
+ trainer = self._require_trainer(model_key)
732
+ self._maybe_load_best_params(model_key, trainer)
733
+
734
+ should_tune = not trainer.best_params
735
+ if should_tune:
736
+ if model_key == "ft" and str(self.config.ft_role) == "unsupervised_embedding":
737
+ if hasattr(trainer, "cross_val_unsupervised"):
738
+ trainer.tune(
739
+ max_evals,
740
+ objective_fn=getattr(trainer, "cross_val_unsupervised")
741
+ )
742
+ else:
743
+ raise RuntimeError(
744
+ "FT trainer does not support unsupervised Optuna objective.")
745
+ else:
746
+ trainer.tune(max_evals)
747
+
748
+ if model_key == "ft" and str(self.config.ft_role) != "model":
749
+ prefix = str(self.config.ft_feature_prefix or "ft_emb")
750
+ role = str(self.config.ft_role)
751
+ if role == "embedding":
752
+ trainer.train_as_feature(
753
+ pred_prefix=prefix, feature_mode="embedding")
754
+ elif role == "unsupervised_embedding":
755
+ trainer.pretrain_unsupervised_as_feature(
756
+ pred_prefix=prefix,
757
+ params=trainer.best_params
758
+ )
759
+ else:
760
+ raise ValueError(
761
+ f"Unsupported ft_role='{role}', expected 'model'/'embedding'/'unsupervised_embedding'.")
762
+
763
+ # Inject generated prediction/embedding columns as features (scalar or vector).
764
+ self._inject_pred_features(prefix)
765
+ # Do not add FT as a standalone model label; downstream models handle evaluation.
766
+ else:
767
+ trainer.train()
768
+
769
+ if bool(getattr(self.config, "final_ensemble", False)):
770
+ k = int(getattr(self.config, "final_ensemble_k", 3) or 3)
771
+ if k > 1:
772
+ if model_key == "ft" and str(self.config.ft_role) != "model":
773
+ pass
774
+ elif hasattr(trainer, "ensemble_predict"):
775
+ trainer.ensemble_predict(k)
776
+ else:
777
+ print(
778
+ f"[Ensemble] Trainer '{model_key}' does not support ensemble prediction.",
779
+ flush=True,
780
+ )
781
+
782
+ # Update context fields for backward compatibility.
783
+ setattr(self, f"{model_key}_best", trainer.model)
784
+ setattr(self, f"best_{model_key}_params", trainer.best_params)
785
+ setattr(self, f"best_{model_key}_trial", trainer.best_trial)
786
+ # Save a snapshot for traceability.
787
+ study_name = getattr(trainer, "study_name", None)
788
+ if study_name is None and trainer.best_trial is not None:
789
+ study_obj = getattr(trainer.best_trial, "study", None)
790
+ study_name = getattr(study_obj, "study_name", None)
791
+ snapshot = {
792
+ "model_key": model_key,
793
+ "timestamp": datetime.now().isoformat(),
794
+ "best_params": trainer.best_params,
795
+ "study_name": study_name,
796
+ "config": asdict(self.config),
797
+ }
798
+ self.version_manager.save(f"{model_key}_best", snapshot)
799
+
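A minimal usage sketch of the optimisation entry point (evaluation budgets are illustrative):

    model.optimize_model("xgb", max_evals=50)   # tune (unless best params were reused) and fit
    model.optimize_model("ft", max_evals=30)    # FT behaviour depends on config.ft_role
    model.bayesopt_resnet(max_evals=40)         # thin wrapper over optimize_model("resn", ...)

    print(model.best_xgb_params)                # backward-compatible context fields
    print(model.xgb_best)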
800
+ def add_numeric_feature_from_column(self, col_name: str) -> None:
801
+ """Add an existing column as a feature and sync one-hot/scaled tables."""
802
+ if col_name not in self.train_data.columns or col_name not in self.test_data.columns:
803
+ raise KeyError(
804
+ f"Column '{col_name}' must exist in both train_data and test_data.")
805
+
806
+ if col_name not in self.factor_nmes:
807
+ self.factor_nmes.append(col_name)
808
+ if col_name not in self.config.factor_nmes:
809
+ self.config.factor_nmes.append(col_name)
810
+
811
+ if col_name not in self.cate_list and col_name not in self.num_features:
812
+ self.num_features.append(col_name)
813
+
814
+ if self.train_oht_data is not None and self.test_oht_data is not None:
815
+ self.train_oht_data[col_name] = self.train_data[col_name].values
816
+ self.test_oht_data[col_name] = self.test_data[col_name].values
817
+ if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
818
+ scaler = StandardScaler()
819
+ tr = self.train_data[col_name].to_numpy(
820
+ dtype=np.float32, copy=False).reshape(-1, 1)
821
+ te = self.test_data[col_name].to_numpy(
822
+ dtype=np.float32, copy=False).reshape(-1, 1)
823
+ self.train_oht_scl_data[col_name] = scaler.fit_transform(
824
+ tr).reshape(-1)
825
+ self.test_oht_scl_data[col_name] = scaler.transform(te).reshape(-1)
826
+
827
+ if col_name not in self.var_nmes:
828
+ self.var_nmes.append(col_name)
829
+
830
+ def add_numeric_features_from_columns(self, col_names: List[str]) -> None:
831
+ if not col_names:
832
+ return
833
+
834
+ missing = [
835
+ col for col in col_names
836
+ if col not in self.train_data.columns or col not in self.test_data.columns
837
+ ]
838
+ if missing:
839
+ raise KeyError(
840
+ f"Column(s) {missing} must exist in both train_data and test_data."
841
+ )
842
+
843
+ for col_name in col_names:
844
+ if col_name not in self.factor_nmes:
845
+ self.factor_nmes.append(col_name)
846
+ if col_name not in self.config.factor_nmes:
847
+ self.config.factor_nmes.append(col_name)
848
+ if col_name not in self.cate_list and col_name not in self.num_features:
849
+ self.num_features.append(col_name)
850
+ if col_name not in self.var_nmes:
851
+ self.var_nmes.append(col_name)
852
+
853
+ if self.train_oht_data is not None and self.test_oht_data is not None:
854
+ self.train_oht_data[col_names] = self.train_data[col_names].to_numpy(copy=False)
855
+ self.test_oht_data[col_names] = self.test_data[col_names].to_numpy(copy=False)
856
+
857
+ if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
858
+ scaler = StandardScaler()
859
+ tr = self.train_data[col_names].to_numpy(dtype=np.float32, copy=False)
860
+ te = self.test_data[col_names].to_numpy(dtype=np.float32, copy=False)
861
+ self.train_oht_scl_data[col_names] = scaler.fit_transform(tr)
862
+ self.test_oht_scl_data[col_names] = scaler.transform(te)
863
+
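For example, a score computed outside the pipeline could be injected as an extra numeric feature (the column name and values below are synthetic placeholders):

    import numpy as np

    model.train_data["external_score"] = np.random.rand(len(model.train_data))
    model.test_data["external_score"] = np.random.rand(len(model.test_data))
    model.add_numeric_feature_from_column("external_score")   # syncs one-hot/scaled tables too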
864
+ def prepare_ft_as_feature(self, max_evals: int = 50, pred_prefix: str = "ft_feat") -> str:
865
+ """Train FT as a feature generator and return the downstream column name."""
866
+ ft_trainer = self._require_trainer("ft")
867
+ ft_trainer.tune(max_evals=max_evals)
868
+ if hasattr(ft_trainer, "train_as_feature"):
869
+ ft_trainer.train_as_feature(pred_prefix=pred_prefix)
870
+ else:
871
+ ft_trainer.train()
872
+ feature_col = f"pred_{pred_prefix}"
873
+ self.add_numeric_feature_from_column(feature_col)
874
+ return feature_col
875
+
876
+ def prepare_ft_embedding_as_features(self, max_evals: int = 50, pred_prefix: str = "ft_emb") -> List[str]:
877
+ """Train FT and inject pooled embeddings as vector features pred_<prefix>_0.. ."""
878
+ ft_trainer = self._require_trainer("ft")
879
+ ft_trainer.tune(max_evals=max_evals)
880
+ if hasattr(ft_trainer, "train_as_feature"):
881
+ ft_trainer.train_as_feature(
882
+ pred_prefix=pred_prefix, feature_mode="embedding")
883
+ else:
884
+ raise RuntimeError(
885
+ "FT trainer does not support embedding feature mode.")
886
+ cols = self._pred_vector_columns(pred_prefix)
887
+ if not cols:
888
+ raise RuntimeError(
889
+ f"No embedding columns were generated for prefix '{pred_prefix}'.")
890
+ self.add_numeric_features_from_columns(cols)
891
+ return cols
892
+
893
+ def prepare_ft_unsupervised_embedding_as_features(self,
894
+ pred_prefix: str = "ft_uemb",
895
+ params: Optional[Dict[str,
896
+ Any]] = None,
897
+ mask_prob_num: float = 0.15,
898
+ mask_prob_cat: float = 0.15,
899
+ num_loss_weight: float = 1.0,
900
+ cat_loss_weight: float = 1.0) -> List[str]:
901
+ """Export embeddings after FT self-supervised masked reconstruction pretraining."""
902
+ ft_trainer = self._require_trainer("ft")
903
+ if not hasattr(ft_trainer, "pretrain_unsupervised_as_feature"):
904
+ raise RuntimeError(
905
+ "FT trainer does not support unsupervised pretraining.")
906
+ ft_trainer.pretrain_unsupervised_as_feature(
907
+ pred_prefix=pred_prefix,
908
+ params=params,
909
+ mask_prob_num=mask_prob_num,
910
+ mask_prob_cat=mask_prob_cat,
911
+ num_loss_weight=num_loss_weight,
912
+ cat_loss_weight=cat_loss_weight
913
+ )
914
+ cols = self._pred_vector_columns(pred_prefix)
915
+ if not cols:
916
+ raise RuntimeError(
917
+ f"No embedding columns were generated for prefix '{pred_prefix}'.")
918
+ self.add_numeric_features_from_columns(cols)
919
+ return cols
920
+
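The three FT-as-feature flows, sketched with illustrative prefixes and budgets:

    # 1) Supervised FT prediction injected as a single scalar feature pred_ft_feat
    col = model.prepare_ft_as_feature(max_evals=20, pred_prefix="ft_feat")

    # 2) Supervised FT pooled embedding injected as pred_ft_emb_0, pred_ft_emb_1, ...
    emb_cols = model.prepare_ft_embedding_as_features(max_evals=20, pred_prefix="ft_emb")

    # 3) Self-supervised masked-reconstruction embedding; no tuning step is run here
    uemb_cols = model.prepare_ft_unsupervised_embedding_as_features(pred_prefix="ft_uemb")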
921
+ # GLM Bayesian optimization wrapper.
922
+ def bayesopt_glm(self, max_evals=50):
923
+ self.optimize_model('glm', max_evals)
924
+
925
+ # XGBoost Bayesian optimization wrapper.
926
+ def bayesopt_xgb(self, max_evals=100):
927
+ self.optimize_model('xgb', max_evals)
928
+
929
+ # ResNet Bayesian optimization wrapper.
930
+ def bayesopt_resnet(self, max_evals=100):
931
+ self.optimize_model('resn', max_evals)
932
+
933
+ # GNN Bayesian optimization wrapper.
934
+ def bayesopt_gnn(self, max_evals=50):
935
+ self.optimize_model('gnn', max_evals)
936
+
937
+ # FT-Transformer Bayesian optimization wrapper.
938
+ def bayesopt_ft(self, max_evals=50):
939
+ self.optimize_model('ft', max_evals)
940
+
941
+ def save_model(self, model_name=None):
942
+ keys = [model_name] if model_name else self.trainers.keys()
943
+ for key in keys:
944
+ if key in self.trainers:
945
+ self.trainers[key].save()
946
+ else:
947
+ if model_name: # Only warn when the user specifies a model name.
948
+ print(f"[save_model] Warning: Unknown model key {key}")
949
+
950
+ def load_model(self, model_name=None):
951
+ keys = [model_name] if model_name else self.trainers.keys()
952
+ for key in keys:
953
+ if key in self.trainers:
954
+ self.trainers[key].load()
955
+ # Sync context fields.
956
+ trainer = self.trainers[key]
957
+ if trainer.model is not None:
958
+ setattr(self, f"{key}_best", trainer.model)
959
+ # For legacy compatibility, also update xxx_load.
960
+ # Old versions tracked only xgb_load/resn_load/ft_load; gnn_load is now set as well, while glm_load is not.
961
+ if key in ['xgb', 'resn', 'ft', 'gnn']:
962
+ setattr(self, f"{key}_load", trainer.model)
963
+ else:
964
+ if model_name:
965
+ print(f"[load_model] Warning: Unknown model key {key}")