ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
ins_pricing/modelling/bayesopt/trainers.py
@@ -0,0 +1,2446 @@
1
+ # =============================================================================
2
+ from __future__ import annotations
3
+
4
+ from datetime import timedelta
5
+ import gc
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
9
+
10
+ import joblib
11
+ import numpy as np
12
+ import optuna
13
+ import pandas as pd
14
+ import torch
15
+ try: # pragma: no cover
16
+ import torch.distributed as dist # type: ignore
17
+ except Exception: # pragma: no cover
18
+ dist = None # type: ignore
19
+ import xgboost as xgb
20
+ from sklearn.metrics import log_loss, mean_tweedie_deviance
21
+ from sklearn.model_selection import KFold, ShuffleSplit
22
+ from sklearn.preprocessing import StandardScaler
23
+
24
+ import statsmodels.api as sm
25
+
26
+ from .config_preprocess import BayesOptConfig, OutputManager
27
+ from .models import FTTransformerSklearn, GraphNeuralNetSklearn, ResNetSklearn
28
+ from .utils import DistributedUtils, EPS, ensure_parent_dir
29
+
30
+ _XGB_CUDA_CHECKED = False
31
+ _XGB_HAS_CUDA = False
32
+
33
+
34
+ def _xgb_cuda_available() -> bool:
35
+ # Best-effort check for XGBoost CUDA build; cached to avoid repeated checks.
36
+ global _XGB_CUDA_CHECKED, _XGB_HAS_CUDA
37
+ if _XGB_CUDA_CHECKED:
38
+ return _XGB_HAS_CUDA
39
+ _XGB_CUDA_CHECKED = True
40
+ if not torch.cuda.is_available():
41
+ _XGB_HAS_CUDA = False
42
+ return False
43
+ try:
44
+ build_info = getattr(xgb, "build_info", None)
45
+ if callable(build_info):
46
+ info = build_info()
47
+ for key in ("USE_CUDA", "use_cuda", "cuda"):
48
+ if key in info:
49
+ val = info[key]
50
+ if isinstance(val, str):
51
+ _XGB_HAS_CUDA = val.strip().upper() in (
52
+ "ON", "YES", "TRUE", "1")
53
+ else:
54
+ _XGB_HAS_CUDA = bool(val)
55
+ return _XGB_HAS_CUDA
56
+ except Exception:
57
+ pass
58
+ try:
59
+ has_cuda = getattr(getattr(xgb, "core", None), "_has_cuda_support", None)
60
+ if callable(has_cuda):
61
+ _XGB_HAS_CUDA = bool(has_cuda())
62
+ return _XGB_HAS_CUDA
63
+ except Exception:
64
+ pass
65
+ _XGB_HAS_CUDA = False
66
+ return False
67
+
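# --- Illustrative sketch (not part of the released package) ---------------------
# Probing an XGBoost build for CUDA support the same way _xgb_cuda_available()
# does above; assumes the installed XGBoost exposes xgb.build_info().
import xgboost as xgb

info = xgb.build_info() if callable(getattr(xgb, "build_info", None)) else {}
flag = info.get("USE_CUDA", False)          # may be a bool or a string such as "ON"
has_cuda = str(flag).strip().upper() in ("ON", "YES", "TRUE", "1")
print("XGBoost reports a CUDA build:", has_cuda)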
68
+ # =============================================================================
69
+ # Trainer system
70
+ # =============================================================================
71
+
72
+
73
+ class TrainerBase:
74
+ def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
75
+ self.ctx = context
76
+ self.label = label
77
+ self.model_name_prefix = model_name_prefix
78
+ self.model = None
79
+ self.best_params: Optional[Dict[str, Any]] = None
80
+ self.best_trial = None
81
+ self.study_name: Optional[str] = None
82
+ self.enable_distributed_optuna: bool = False
83
+ self._distributed_forced_params: Optional[Dict[str, Any]] = None
84
+
85
+ def _dist_barrier(self, reason: str) -> None:
86
+ """DDP barrier wrapper used by distributed Optuna.
87
+
88
+ To debug "trial finished but next trial never starts" hangs, set these
89
+ environment variables (either in shell or config.json `env`):
90
+ - `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
91
+ - `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
92
+ - `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
93
+ """
94
+ if dist is None:
95
+ return
96
+ try:
97
+ if not getattr(dist, "is_available", lambda: False)():
98
+ return
99
+ if not dist.is_initialized():
100
+ return
101
+ except Exception:
102
+ return
103
+
104
+ timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
105
+ debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
106
+ rank = None
107
+ world = None
108
+ if debug_barrier:
109
+ try:
110
+ rank = dist.get_rank()
111
+ world = dist.get_world_size()
112
+ print(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
113
+ except Exception:
114
+ debug_barrier = False
115
+ try:
116
+ timeout = timedelta(seconds=timeout_seconds)
117
+ backend = None
118
+ try:
119
+ backend = dist.get_backend()
120
+ except Exception:
121
+ backend = None
122
+
123
+ # `monitored_barrier` is only implemented for GLOO; using it under NCCL
124
+ # will raise and can itself trigger a secondary hang. Prefer an async
125
+ # barrier with timeout for NCCL.
126
+ monitored = getattr(dist, "monitored_barrier", None)
127
+ if backend == "gloo" and callable(monitored):
128
+ monitored(timeout=timeout)
129
+ else:
130
+ work = None
131
+ try:
132
+ work = dist.barrier(async_op=True)
133
+ except TypeError:
134
+ work = None
135
+ if work is not None:
136
+ wait = getattr(work, "wait", None)
137
+ if callable(wait):
138
+ try:
139
+ wait(timeout=timeout)
140
+ except TypeError:
141
+ wait()
142
+ else:
143
+ dist.barrier()
144
+ else:
145
+ dist.barrier()
146
+ if debug_barrier:
147
+ print(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
148
+ except Exception as exc:
149
+ print(
150
+ f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
151
+ flush=True,
152
+ )
153
+ raise
154
+
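# --- Illustrative sketch (not part of the released package) ---------------------
# Enabling the DDP barrier debug aids documented in _dist_barrier() above, e.g. at
# the top of the training entry script; the values shown are examples, not defaults.
import os

os.environ["BAYESOPT_DDP_BARRIER_DEBUG"] = "1"      # print barrier enter/exit per rank
os.environ["BAYESOPT_DDP_BARRIER_TIMEOUT"] = "300"  # fail after 300s instead of hanging
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"    # extra PyTorch distributed logging
os.environ["NCCL_DEBUG"] = "INFO"                   # extra NCCL logging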
155
+ @property
156
+ def config(self) -> BayesOptConfig:
157
+ return self.ctx.config
158
+
159
+ @property
160
+ def output(self) -> OutputManager:
161
+ return self.ctx.output_manager
162
+
163
+ def _get_model_filename(self) -> str:
164
+ ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
165
+ return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
166
+
167
+ def _resolve_optuna_storage_url(self) -> Optional[str]:
168
+ storage = getattr(self.config, "optuna_storage", None)
169
+ if not storage:
170
+ return None
171
+ storage_str = str(storage).strip()
172
+ if not storage_str:
173
+ return None
174
+ if "://" in storage_str or storage_str == ":memory:":
175
+ return storage_str
176
+ path = Path(storage_str)
177
+ path = path.resolve()
178
+ ensure_parent_dir(str(path))
179
+ return f"sqlite:///{path.as_posix()}"
180
+
181
+ def _resolve_optuna_study_name(self) -> str:
182
+ prefix = getattr(self.config, "optuna_study_prefix",
183
+ None) or "bayesopt"
184
+ raw = f"{prefix}_{self.ctx.model_nme}_{self.model_name_prefix}"
185
+ safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
186
+ return safe.lower()
187
+
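# --- Illustrative sketch (not part of the released package) ---------------------
# What the two helpers above resolve to, assuming config.optuna_storage is the
# relative path "optuna/studies.db" and a made-up model name "motor freq".
from pathlib import Path

storage_url = f"sqlite:///{Path('optuna/studies.db').resolve().as_posix()}"
raw = "bayesopt_motor freq_Xgboost"
study_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in raw).lower()
# study_name -> "bayesopt_motor_freq_xgboost"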
188
+ def tune(self, max_evals: int, objective_fn=None) -> None:
189
+ # Generic Optuna tuning loop.
190
+ if objective_fn is None:
191
+ # If subclass doesn't provide objective_fn, default to cross_val.
192
+ objective_fn = self.cross_val
193
+
194
+ if self._should_use_distributed_optuna():
195
+ self._distributed_tune(max_evals, objective_fn)
196
+ return
197
+
198
+ total_trials = max(1, int(max_evals))
199
+ progress_counter = {"count": 0}
200
+
201
+ def objective_wrapper(trial: optuna.trial.Trial) -> float:
202
+ should_log = DistributedUtils.is_main_process()
203
+ if should_log:
204
+ current_idx = progress_counter["count"] + 1
205
+ print(
206
+ f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
207
+ f"(trial_id={trial.number})."
208
+ )
209
+ try:
210
+ result = objective_fn(trial)
211
+ except RuntimeError as exc:
212
+ if "out of memory" in str(exc).lower():
213
+ print(
214
+ f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
215
+ )
216
+ self._clean_gpu()
217
+ raise optuna.TrialPruned() from exc
218
+ raise
219
+ finally:
220
+ self._clean_gpu()
221
+ if should_log:
222
+ progress_counter["count"] = progress_counter["count"] + 1
223
+ trial_state = getattr(trial, "state", None)
224
+ state_repr = getattr(trial_state, "name", "OK")
225
+ print(
226
+ f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
227
+ f"(status={state_repr})."
228
+ )
229
+ return result
230
+
231
+ storage_url = self._resolve_optuna_storage_url()
232
+ study_name = self._resolve_optuna_study_name()
233
+ study_kwargs: Dict[str, Any] = {
234
+ "direction": "minimize",
235
+ "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
236
+ }
237
+ if storage_url:
238
+ study_kwargs.update(
239
+ storage=storage_url,
240
+ study_name=study_name,
241
+ load_if_exists=True,
242
+ )
243
+
244
+ study = optuna.create_study(**study_kwargs)
245
+ self.study_name = getattr(study, "study_name", None)
246
+
247
+ def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
248
+ # Persist best_params after each trial to allow safe resume.
249
+ try:
250
+ best = getattr(check_study, "best_trial", None)
251
+ if best is None:
252
+ return
253
+ best_params = getattr(best, "params", None)
254
+ if not best_params:
255
+ return
256
+ params_path = self.output.result_path(
257
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
258
+ )
259
+ pd.DataFrame(best_params, index=[0]).to_csv(
260
+ params_path, index=False)
261
+ except Exception:
262
+ return
263
+
264
+ completed_states = (
265
+ optuna.trial.TrialState.COMPLETE,
266
+ optuna.trial.TrialState.PRUNED,
267
+ optuna.trial.TrialState.FAIL,
268
+ )
269
+ completed = len(study.get_trials(states=completed_states))
270
+ progress_counter["count"] = completed
271
+ remaining = max(0, total_trials - completed)
272
+ if remaining > 0:
273
+ study.optimize(
274
+ objective_wrapper,
275
+ n_trials=remaining,
276
+ callbacks=[checkpoint_callback],
277
+ )
278
+ self.best_params = study.best_params
279
+ self.best_trial = study.best_trial
280
+
281
+ # Save best params to CSV for reproducibility.
282
+ params_path = self.output.result_path(
283
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
284
+ )
285
+ pd.DataFrame(self.best_params, index=[0]).to_csv(
286
+ params_path, index=False)
287
+
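# --- Illustrative sketch (not part of the released package) ---------------------
# The resume-aware Optuna pattern used by tune() above, reduced to a standalone
# example; the quadratic objective and the 20-trial budget are made up.
import optuna

def objective(trial: optuna.trial.Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2

study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    storage="sqlite:///demo_study.db",
    study_name="demo_study",
    load_if_exists=True,              # re-running the script resumes the same study
)
finished = len(study.get_trials(states=(
    optuna.trial.TrialState.COMPLETE,
    optuna.trial.TrialState.PRUNED,
    optuna.trial.TrialState.FAIL,
)))
remaining = max(0, 20 - finished)     # only run the trials that are still missing
if remaining:
    study.optimize(objective, n_trials=remaining)
print(study.best_params)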
288
+ def train(self) -> None:
289
+ raise NotImplementedError
290
+
291
+ def save(self) -> None:
292
+ if self.model is None:
293
+ print(f"[save] Warning: No model to save for {self.label}")
294
+ return
295
+
296
+ path = self.output.model_path(self._get_model_filename())
297
+ if self.label in ['Xgboost', 'GLM']:
298
+ joblib.dump(self.model, path)
299
+ else:
300
+ # PyTorch models can save state_dict or the full object.
301
+ # Legacy behavior: ResNetTrainer saves state_dict; FTTrainer saves full object.
302
+ if hasattr(self.model, 'resnet'): # ResNetSklearn model
303
+ torch.save(self.model.resnet.state_dict(), path)
304
+ else: # FTTransformerSklearn or other PyTorch model
305
+ torch.save(self.model, path)
306
+
307
+ def load(self) -> None:
308
+ path = self.output.model_path(self._get_model_filename())
309
+ if not os.path.exists(path):
310
+ print(f"[load] Warning: Model file not found: {path}")
311
+ return
312
+
313
+ if self.label in ['Xgboost', 'GLM']:
314
+ self.model = joblib.load(path)
315
+ else:
316
+ # PyTorch loading depends on the model structure.
317
+ if self.label == 'ResNet' or self.label == 'ResNetClassifier':
318
+ # ResNet requires reconstructing the skeleton; handled by subclass.
319
+ pass
320
+ else:
321
+ # FT-Transformer serializes the whole object; load then move to device.
322
+ loaded = torch.load(path, map_location='cpu')
323
+ self._move_to_device(loaded)
324
+ self.model = loaded
325
+
326
+ def _move_to_device(self, model_obj):
327
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
328
+ if hasattr(model_obj, 'device'):
329
+ model_obj.device = device
330
+ if hasattr(model_obj, 'to'):
331
+ model_obj.to(device)
332
+ # Move nested submodules (ft/resnet/gnn) to the same device.
333
+ if hasattr(model_obj, 'ft'):
334
+ model_obj.ft.to(device)
335
+ if hasattr(model_obj, 'resnet'):
336
+ model_obj.resnet.to(device)
337
+ if hasattr(model_obj, 'gnn'):
338
+ model_obj.gnn.to(device)
339
+
340
+ def _should_use_distributed_optuna(self) -> bool:
341
+ if not self.enable_distributed_optuna:
342
+ return False
343
+ rank_env = os.environ.get("RANK")
344
+ world_env = os.environ.get("WORLD_SIZE")
345
+ local_env = os.environ.get("LOCAL_RANK")
346
+ if rank_env is None or world_env is None or local_env is None:
347
+ return False
348
+ try:
349
+ world_size = int(world_env)
350
+ except Exception:
351
+ return False
352
+ return world_size > 1
353
+
354
+ def _distributed_is_main(self) -> bool:
355
+ return DistributedUtils.is_main_process()
356
+
357
+ def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
358
+ if not self._should_use_distributed_optuna() or not self._distributed_is_main():
359
+ return
360
+ if dist is None:
361
+ return
362
+ DistributedUtils.setup_ddp()
363
+ if not dist.is_initialized():
364
+ return
365
+ message = [payload]
366
+ dist.broadcast_object_list(message, src=0)
367
+
368
+ def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
369
+ if not self._should_use_distributed_optuna():
370
+ return
371
+ if not self._distributed_is_main():
372
+ return
373
+ if dist is None:
374
+ return
375
+ self._distributed_send_command({"type": "RUN", "params": params})
376
+ if not dist.is_initialized():
377
+ return
378
+ # STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
379
+ self._dist_barrier("prepare_trial")
380
+
381
+ def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
382
+ if dist is None:
383
+ print(
384
+ f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
385
+ flush=True,
386
+ )
387
+ return
388
+ DistributedUtils.setup_ddp()
389
+ if not dist.is_initialized():
390
+ print(
391
+ f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
392
+ flush=True,
393
+ )
394
+ return
395
+ while True:
396
+ message = [None]
397
+ dist.broadcast_object_list(message, src=0)
398
+ payload = message[0]
399
+ if not isinstance(payload, dict):
400
+ continue
401
+ cmd = payload.get("type")
402
+ if cmd == "STOP":
403
+ best_params = payload.get("best_params")
404
+ if best_params is not None:
405
+ self.best_params = best_params
406
+ break
407
+ if cmd == "RUN":
408
+ params = payload.get("params") or {}
409
+ self._distributed_forced_params = params
410
+ # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
411
+ self._dist_barrier("worker_start")
412
+ try:
413
+ objective_fn(None)
414
+ except optuna.TrialPruned:
415
+ pass
416
+ except Exception as exc:
417
+ print(
418
+ f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
419
+ finally:
420
+ self._clean_gpu()
421
+ # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
422
+ self._dist_barrier("worker_end")
423
+
424
+ def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
425
+ if dist is None:
426
+ print(
427
+ f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
428
+ flush=True,
429
+ )
430
+ prev = self.enable_distributed_optuna
431
+ self.enable_distributed_optuna = False
432
+ try:
433
+ self.tune(max_evals, objective_fn)
434
+ finally:
435
+ self.enable_distributed_optuna = prev
436
+ return
437
+ DistributedUtils.setup_ddp()
438
+ if not dist.is_initialized():
439
+ rank_env = os.environ.get("RANK", "0")
440
+ if str(rank_env) != "0":
441
+ print(
442
+ f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
443
+ flush=True,
444
+ )
445
+ return
446
+ print(
447
+ f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
448
+ flush=True,
449
+ )
450
+ prev = self.enable_distributed_optuna
451
+ self.enable_distributed_optuna = False
452
+ try:
453
+ self.tune(max_evals, objective_fn)
454
+ finally:
455
+ self.enable_distributed_optuna = prev
456
+ return
457
+ if not self._distributed_is_main():
458
+ self._distributed_worker_loop(objective_fn)
459
+ return
460
+
461
+ total_trials = max(1, int(max_evals))
462
+ progress_counter = {"count": 0}
463
+
464
+ def objective_wrapper(trial: optuna.trial.Trial) -> float:
465
+ should_log = True
466
+ if should_log:
467
+ current_idx = progress_counter["count"] + 1
468
+ print(
469
+ f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
470
+ f"(trial_id={trial.number})."
471
+ )
472
+ try:
473
+ result = objective_fn(trial)
474
+ except RuntimeError as exc:
475
+ if "out of memory" in str(exc).lower():
476
+ print(
477
+ f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
478
+ )
479
+ self._clean_gpu()
480
+ raise optuna.TrialPruned() from exc
481
+ raise
482
+ finally:
483
+ self._clean_gpu()
484
+ if should_log:
485
+ progress_counter["count"] = progress_counter["count"] + 1
486
+ trial_state = getattr(trial, "state", None)
487
+ state_repr = getattr(trial_state, "name", "OK")
488
+ print(
489
+ f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
490
+ f"(status={state_repr})."
491
+ )
492
+ # STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
493
+ self._dist_barrier("trial_end")
494
+ return result
495
+
496
+ storage_url = self._resolve_optuna_storage_url()
497
+ study_name = self._resolve_optuna_study_name()
498
+ study_kwargs: Dict[str, Any] = {
499
+ "direction": "minimize",
500
+ "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
501
+ }
502
+ if storage_url:
503
+ study_kwargs.update(
504
+ storage=storage_url,
505
+ study_name=study_name,
506
+ load_if_exists=True,
507
+ )
508
+ study = optuna.create_study(**study_kwargs)
509
+ self.study_name = getattr(study, "study_name", None)
510
+
511
+ def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
512
+ try:
513
+ best = getattr(check_study, "best_trial", None)
514
+ if best is None:
515
+ return
516
+ best_params = getattr(best, "params", None)
517
+ if not best_params:
518
+ return
519
+ params_path = self.output.result_path(
520
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
521
+ )
522
+ pd.DataFrame(best_params, index=[0]).to_csv(
523
+ params_path, index=False)
524
+ except Exception:
525
+ return
526
+
527
+ completed_states = (
528
+ optuna.trial.TrialState.COMPLETE,
529
+ optuna.trial.TrialState.PRUNED,
530
+ optuna.trial.TrialState.FAIL,
531
+ )
532
+ completed = len(study.get_trials(states=completed_states))
533
+ progress_counter["count"] = completed
534
+ remaining = max(0, total_trials - completed)
535
+ try:
536
+ if remaining > 0:
537
+ study.optimize(
538
+ objective_wrapper,
539
+ n_trials=remaining,
540
+ callbacks=[checkpoint_callback],
541
+ )
542
+ self.best_params = study.best_params
543
+ self.best_trial = study.best_trial
544
+ params_path = self.output.result_path(
545
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
546
+ )
547
+ pd.DataFrame(self.best_params, index=[0]).to_csv(
548
+ params_path, index=False)
549
+ finally:
550
+ self._distributed_send_command(
551
+ {"type": "STOP", "best_params": self.best_params})
552
+
553
+ def _clean_gpu(self):
554
+ gc.collect()
555
+ if torch.cuda.is_available():
556
+ device = None
557
+ try:
558
+ device = getattr(self, "device", None)
559
+ except Exception:
560
+ device = None
561
+ if isinstance(device, torch.device):
562
+ try:
563
+ torch.cuda.set_device(device)
564
+ except Exception:
565
+ pass
566
+ torch.cuda.empty_cache()
567
+ do_ipc_collect = os.environ.get("BAYESOPT_CUDA_IPC_COLLECT", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
568
+ do_sync = os.environ.get("BAYESOPT_CUDA_SYNC", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
569
+ if do_ipc_collect:
570
+ torch.cuda.ipc_collect()
571
+ if do_sync:
572
+ torch.cuda.synchronize()
573
+
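# --- Illustrative sketch (not part of the released package) ---------------------
# Opting in to the extra CUDA cleanup steps that _clean_gpu() above reads from the
# environment; both flags are off unless set.
import os

os.environ["BAYESOPT_CUDA_IPC_COLLECT"] = "1"   # also call torch.cuda.ipc_collect()
os.environ["BAYESOPT_CUDA_SYNC"] = "1"          # also call torch.cuda.synchronize()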
574
+ def _standardize_fold(self,
575
+ X_train: pd.DataFrame,
576
+ X_val: pd.DataFrame,
577
+ columns: Optional[List[str]] = None
578
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
579
+ """Fit StandardScaler on the training fold and transform train/val features.
580
+
581
+ Args:
582
+ X_train: training features.
583
+ X_val: validation features.
584
+ columns: columns to scale (default: all).
585
+
586
+ Returns:
587
+ Scaled train/val features and the fitted scaler.
588
+ """
589
+ scaler = StandardScaler()
590
+ cols = list(columns) if columns else list(X_train.columns)
591
+ X_train_scaled = X_train.copy(deep=True)
592
+ X_val_scaled = X_val.copy(deep=True)
593
+ if cols:
594
+ scaler.fit(X_train_scaled[cols])
595
+ X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
596
+ X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
597
+ return X_train_scaled, X_val_scaled, scaler
598
+
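# --- Illustrative sketch (not part of the released package) ---------------------
# The per-fold scaling contract of _standardize_fold() above: fit the scaler on the
# training fold only, then transform both folds with it (toy data, assumed columns).
import pandas as pd
from sklearn.preprocessing import StandardScaler

X_train = pd.DataFrame({"age": [20.0, 35.0, 50.0], "bonus": [0.5, 0.7, 0.9]})
X_val = pd.DataFrame({"age": [28.0, 62.0], "bonus": [0.6, 1.0]})
scaler = StandardScaler().fit(X_train)
X_train_scl = pd.DataFrame(scaler.transform(X_train), columns=X_train.columns)
X_val_scl = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)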
599
+ def cross_val_generic(
600
+ self,
601
+ trial: optuna.trial.Trial,
602
+ hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
603
+ data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
604
+ model_builder: Callable[[Dict[str, Any]], Any],
605
+ metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
606
+ sample_limit: Optional[int] = None,
607
+ preprocess_fn: Optional[Callable[[
608
+ pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
609
+ fit_predict_fn: Optional[
610
+ Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
611
+ pd.DataFrame, pd.Series, Optional[pd.Series],
612
+ optuna.trial.Trial], np.ndarray]
613
+ ] = None,
614
+ cleanup_fn: Optional[Callable[[Any], None]] = None,
615
+ splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
616
+ """Generic holdout/CV helper to reuse tuning workflows.
617
+
618
+ Args:
619
+ trial: current Optuna trial.
620
+ hyperparameter_space: sampler dict keyed by parameter name.
621
+ data_provider: callback returning (X, y, sample_weight).
622
+ model_builder: callback to build a model per fold.
623
+ metric_fn: loss/score function taking y_true, y_pred, weight.
624
+ sample_limit: optional sample cap; random sample if exceeded.
625
+ preprocess_fn: optional per-fold preprocessing (X_train, X_val).
626
+ fit_predict_fn: optional custom fit/predict logic for validation.
627
+ cleanup_fn: optional cleanup callback per fold.
628
+ splitter: optional (train_idx, val_idx) iterator; defaults to ShuffleSplit.
629
+
630
+ Returns:
631
+ Mean validation metric across folds.
632
+ """
633
+ params: Optional[Dict[str, Any]] = None
634
+ if self._distributed_forced_params is not None:
635
+ params = self._distributed_forced_params
636
+ self._distributed_forced_params = None
637
+ else:
638
+ if trial is None:
639
+ raise RuntimeError(
640
+ "Missing Optuna trial for parameter sampling.")
641
+ params = {name: sampler(trial)
642
+ for name, sampler in hyperparameter_space.items()}
643
+ if self._should_use_distributed_optuna():
644
+ self._distributed_prepare_trial(params)
645
+ X_all, y_all, w_all = data_provider()
646
+ if sample_limit is not None and len(X_all) > sample_limit:
647
+ sampled_idx = X_all.sample(
648
+ n=sample_limit,
649
+ random_state=self.ctx.rand_seed
650
+ ).index
651
+ X_all = X_all.loc[sampled_idx]
652
+ y_all = y_all.loc[sampled_idx]
653
+ w_all = w_all.loc[sampled_idx] if w_all is not None else None
654
+
655
+ split_iter = splitter or ShuffleSplit(
656
+ n_splits=int(1 / self.ctx.prop_test),
657
+ test_size=self.ctx.prop_test,
658
+ random_state=self.ctx.rand_seed
659
+ ).split(X_all)
660
+
661
+ losses: List[float] = []
662
+ for train_idx, val_idx in split_iter:
663
+ X_train = X_all.iloc[train_idx]
664
+ y_train = y_all.iloc[train_idx]
665
+ X_val = X_all.iloc[val_idx]
666
+ y_val = y_all.iloc[val_idx]
667
+ w_train = w_all.iloc[train_idx] if w_all is not None else None
668
+ w_val = w_all.iloc[val_idx] if w_all is not None else None
669
+
670
+ if preprocess_fn:
671
+ X_train, X_val = preprocess_fn(X_train, X_val)
672
+
673
+ model = model_builder(params)
674
+ try:
675
+ if fit_predict_fn:
676
+ y_pred = fit_predict_fn(
677
+ model, X_train, y_train, w_train,
678
+ X_val, y_val, w_val, trial
679
+ )
680
+ else:
681
+ fit_kwargs = {}
682
+ if w_train is not None:
683
+ fit_kwargs["sample_weight"] = w_train
684
+ model.fit(X_train, y_train, **fit_kwargs)
685
+ y_pred = model.predict(X_val)
686
+ losses.append(metric_fn(y_val, y_pred, w_val))
687
+ finally:
688
+ if cleanup_fn:
689
+ cleanup_fn(model)
690
+ self._clean_gpu()
691
+
692
+ return float(np.mean(losses))
693
+
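# --- Illustrative sketch (not part of the released package) ---------------------
# The callback shapes cross_val_generic() above expects, using sklearn pieces as
# stand-ins (Ridge, a random toy frame and the MSE metric are assumptions).
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

space = {"alpha": lambda t: t.suggest_float("alpha", 1e-3, 10.0, log=True)}

def data_provider():
    X = pd.DataFrame({"x1": np.random.rand(500), "x2": np.random.rand(500)})
    y = pd.Series(3.0 * X["x1"] - X["x2"] + np.random.normal(scale=0.1, size=500))
    return X, y, None                            # (features, target, sample_weight)

def model_builder(params):
    return Ridge(alpha=params["alpha"])

def metric_fn(y_true, y_pred, weight):
    return mean_squared_error(y_true, y_pred, sample_weight=weight)

# Inside a TrainerBase subclass the pieces plug in as:
#     self.cross_val_generic(trial, space, data_provider, model_builder, metric_fn)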
694
+ # Prediction + caching logic.
695
+ def _predict_and_cache(self,
696
+ model,
697
+ pred_prefix: str,
698
+ use_oht: bool = False,
699
+ design_fn=None,
700
+ predict_kwargs_train: Optional[Dict[str, Any]] = None,
701
+ predict_kwargs_test: Optional[Dict[str, Any]] = None,
702
+ predict_fn: Optional[Callable[..., Any]] = None) -> None:
703
+ if design_fn:
704
+ X_train = design_fn(train=True)
705
+ X_test = design_fn(train=False)
706
+ elif use_oht:
707
+ X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
708
+ X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
709
+ else:
710
+ X_train = self.ctx.train_data[self.ctx.factor_nmes]
711
+ X_test = self.ctx.test_data[self.ctx.factor_nmes]
712
+
713
+ predictor = predict_fn or model.predict
714
+ preds_train = predictor(X_train, **(predict_kwargs_train or {}))
715
+ preds_test = predictor(X_test, **(predict_kwargs_test or {}))
716
+ preds_train = np.asarray(preds_train)
717
+ preds_test = np.asarray(preds_test)
718
+
719
+ if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
720
+ col_name = f'pred_{pred_prefix}'
721
+ self.ctx.train_data[col_name] = preds_train.reshape(-1)
722
+ self.ctx.test_data[col_name] = preds_test.reshape(-1)
723
+ self.ctx.train_data[f'w_{col_name}'] = (
724
+ self.ctx.train_data[col_name] *
725
+ self.ctx.train_data[self.ctx.weight_nme]
726
+ )
727
+ self.ctx.test_data[f'w_{col_name}'] = (
728
+ self.ctx.test_data[col_name] *
729
+ self.ctx.test_data[self.ctx.weight_nme]
730
+ )
731
+ return
732
+
733
+ # Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
734
+ if preds_train.ndim != 2:
735
+ raise ValueError(
736
+ f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
737
+ if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
738
+ raise ValueError(
739
+ f"Train/test prediction dims mismatch for '{pred_prefix}': "
740
+ f"{preds_train.shape} vs {preds_test.shape}")
741
+ for j in range(preds_train.shape[1]):
742
+ col_name = f'pred_{pred_prefix}_{j}'
743
+ self.ctx.train_data[col_name] = preds_train[:, j]
744
+ self.ctx.test_data[col_name] = preds_test[:, j]
745
+
746
+ def _cache_predictions(self,
747
+ pred_prefix: str,
748
+ preds_train,
749
+ preds_test) -> None:
750
+ preds_train = np.asarray(preds_train)
751
+ preds_test = np.asarray(preds_test)
752
+ if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
753
+ if preds_test.ndim > 1:
754
+ preds_test = preds_test.reshape(-1)
755
+ col_name = f'pred_{pred_prefix}'
756
+ self.ctx.train_data[col_name] = preds_train.reshape(-1)
757
+ self.ctx.test_data[col_name] = preds_test.reshape(-1)
758
+ self.ctx.train_data[f'w_{col_name}'] = (
759
+ self.ctx.train_data[col_name] *
760
+ self.ctx.train_data[self.ctx.weight_nme]
761
+ )
762
+ self.ctx.test_data[f'w_{col_name}'] = (
763
+ self.ctx.test_data[col_name] *
764
+ self.ctx.test_data[self.ctx.weight_nme]
765
+ )
766
+ return
767
+
768
+ if preds_train.ndim != 2:
769
+ raise ValueError(
770
+ f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
771
+ if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
772
+ raise ValueError(
773
+ f"Train/test prediction dims mismatch for '{pred_prefix}': "
774
+ f"{preds_train.shape} vs {preds_test.shape}")
775
+ for j in range(preds_train.shape[1]):
776
+ col_name = f'pred_{pred_prefix}_{j}'
777
+ self.ctx.train_data[col_name] = preds_train[:, j]
778
+ self.ctx.test_data[col_name] = preds_test[:, j]
779
+
780
+ def _resolve_best_epoch(self,
781
+ history: Optional[Dict[str, List[float]]],
782
+ default_epochs: int) -> int:
783
+ if not history:
784
+ return max(1, int(default_epochs))
785
+ vals = history.get("val") or []
786
+ if not vals:
787
+ return max(1, int(default_epochs))
788
+ best_idx = int(np.nanargmin(vals))
789
+ return max(1, best_idx + 1)
790
+
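# --- Illustrative sketch (not part of the released package) ---------------------
# What _resolve_best_epoch() above returns for a toy early-stopping history.
import numpy as np

history = {"val": [0.92, 0.71, 0.65, 0.68]}      # per-epoch validation losses
best_epoch = int(np.nanargmin(history["val"])) + 1
assert best_epoch == 3                           # epochs are reported 1-based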
791
+ def _fit_predict_cache(self,
792
+ model,
793
+ X_train,
794
+ y_train,
795
+ sample_weight,
796
+ pred_prefix: str,
797
+ use_oht: bool = False,
798
+ design_fn=None,
799
+ fit_kwargs: Optional[Dict[str, Any]] = None,
800
+ sample_weight_arg: Optional[str] = 'sample_weight',
801
+ predict_kwargs_train: Optional[Dict[str, Any]] = None,
802
+ predict_kwargs_test: Optional[Dict[str, Any]] = None,
803
+ predict_fn: Optional[Callable[..., Any]] = None,
804
+ record_label: bool = True) -> None:
805
+ fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
806
+ if sample_weight is not None and sample_weight_arg:
807
+ fit_kwargs.setdefault(sample_weight_arg, sample_weight)
808
+ model.fit(X_train, y_train, **fit_kwargs)
809
+ if record_label:
810
+ self.ctx.model_label.append(self.label)
811
+ self._predict_and_cache(
812
+ model,
813
+ pred_prefix,
814
+ use_oht=use_oht,
815
+ design_fn=design_fn,
816
+ predict_kwargs_train=predict_kwargs_train,
817
+ predict_kwargs_test=predict_kwargs_test,
818
+ predict_fn=predict_fn)
819
+
820
+
821
+ class GNNTrainer(TrainerBase):
822
+ def __init__(self, context: "BayesOptModel") -> None:
823
+ super().__init__(context, 'GNN', 'GNN')
824
+ self.model: Optional[GraphNeuralNetSklearn] = None
825
+ self.enable_distributed_optuna = bool(context.config.use_gnn_ddp)
826
+
827
+ def _build_model(self, params: Optional[Dict[str, Any]] = None) -> GraphNeuralNetSklearn:
828
+ params = params or {}
829
+ base_tw_power = self.ctx.default_tweedie_power()
830
+ model = GraphNeuralNetSklearn(
831
+ model_nme=f"{self.ctx.model_nme}_gnn",
832
+ input_dim=len(self.ctx.var_nmes),
833
+ hidden_dim=int(params.get("hidden_dim", 64)),
834
+ num_layers=int(params.get("num_layers", 2)),
835
+ k_neighbors=int(params.get("k_neighbors", 10)),
836
+ dropout=float(params.get("dropout", 0.1)),
837
+ learning_rate=float(params.get("learning_rate", 1e-3)),
838
+ epochs=int(params.get("epochs", self.ctx.epochs)),
839
+ patience=int(params.get("patience", 5)),
840
+ task_type=self.ctx.task_type,
841
+ tweedie_power=float(params.get("tw_power", base_tw_power or 1.5)),
842
+ weight_decay=float(params.get("weight_decay", 0.0)),
843
+ use_data_parallel=bool(self.ctx.config.use_gnn_data_parallel),
844
+ use_ddp=bool(self.ctx.config.use_gnn_ddp),
845
+ use_approx_knn=bool(self.ctx.config.gnn_use_approx_knn),
846
+ approx_knn_threshold=int(self.ctx.config.gnn_approx_knn_threshold),
847
+ graph_cache_path=self.ctx.config.gnn_graph_cache,
848
+ max_gpu_knn_nodes=self.ctx.config.gnn_max_gpu_knn_nodes,
849
+ knn_gpu_mem_ratio=float(self.ctx.config.gnn_knn_gpu_mem_ratio),
850
+ knn_gpu_mem_overhead=float(
851
+ self.ctx.config.gnn_knn_gpu_mem_overhead),
852
+ )
853
+ return model
854
+
855
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
856
+ base_tw_power = self.ctx.default_tweedie_power()
857
+ metric_ctx: Dict[str, Any] = {}
858
+
859
+ def data_provider():
860
+ data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
861
+ assert data is not None, "Preprocessed training data is missing."
862
+ return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
863
+
864
+ def model_builder(params: Dict[str, Any]):
865
+ tw_power = params.get("tw_power", base_tw_power)
866
+ metric_ctx["tw_power"] = tw_power
867
+ return self._build_model(params)
868
+
869
+ def preprocess_fn(X_train, X_val):
870
+ X_train_s, X_val_s, _ = self._standardize_fold(
871
+ X_train, X_val, self.ctx.num_features)
872
+ return X_train_s, X_val_s
873
+
874
+ def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
875
+ model.fit(
876
+ X_train,
877
+ y_train,
878
+ w_train=w_train,
879
+ X_val=X_val,
880
+ y_val=y_val,
881
+ w_val=w_val,
882
+ trial=trial_obj,
883
+ )
884
+ return model.predict(X_val)
885
+
886
+ def metric_fn(y_true, y_pred, weight):
887
+ if self.ctx.task_type == 'classification':
888
+ y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
889
+ return log_loss(y_true, y_pred_clipped, sample_weight=weight)
890
+ y_pred_safe = np.maximum(y_pred, EPS)
891
+ power = metric_ctx.get("tw_power", base_tw_power or 1.5)
892
+ return mean_tweedie_deviance(
893
+ y_true,
894
+ y_pred_safe,
895
+ sample_weight=weight,
896
+ power=power,
897
+ )
898
+
899
+ # Keep GNN BO lightweight: sample during CV, use full data for final training.
900
+ X_cap = data_provider()[0]
901
+ sample_limit = min(200000, len(X_cap)) if len(X_cap) > 200000 else None
902
+
903
+ param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
904
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-4, 5e-3, log=True),
905
+ "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
906
+ "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
907
+ "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 30),
908
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
909
+ "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
910
+ }
911
+ if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
912
+ param_space["tw_power"] = lambda t: t.suggest_float(
913
+ 'tw_power', 1.0, 2.0)
914
+
915
+ return self.cross_val_generic(
916
+ trial=trial,
917
+ hyperparameter_space=param_space,
918
+ data_provider=data_provider,
919
+ model_builder=model_builder,
920
+ metric_fn=metric_fn,
921
+ sample_limit=sample_limit,
922
+ preprocess_fn=preprocess_fn,
923
+ fit_predict_fn=fit_predict,
924
+ cleanup_fn=lambda m: getattr(
925
+ getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
926
+ )
927
+
928
+ def train(self) -> None:
929
+ if not self.best_params:
930
+ raise RuntimeError("Run tune() first to obtain best GNN parameters.")
931
+
932
+ data = self.ctx.train_oht_scl_data
933
+ assert data is not None, "Preprocessed training data is missing."
934
+ X_all = data[self.ctx.var_nmes]
935
+ y_all = data[self.ctx.resp_nme]
936
+ w_all = data[self.ctx.weight_nme]
937
+
938
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
939
+ refit_epochs = None
940
+
941
+ if 0.0 < float(self.ctx.prop_test) < 1.0 and len(X_all) >= 10:
942
+ splitter = ShuffleSplit(
943
+ n_splits=1,
944
+ test_size=self.ctx.prop_test,
945
+ random_state=self.ctx.rand_seed,
946
+ )
947
+ train_idx, val_idx = next(splitter.split(X_all))
948
+ X_train = X_all.iloc[train_idx]
949
+ y_train = y_all.iloc[train_idx]
950
+ w_train = w_all.iloc[train_idx]
951
+ X_val = X_all.iloc[val_idx]
952
+ y_val = y_all.iloc[val_idx]
953
+ w_val = w_all.iloc[val_idx]
954
+
955
+ if use_refit:
956
+ tmp_model = self._build_model(self.best_params)
957
+ tmp_model.fit(
958
+ X_train,
959
+ y_train,
960
+ w_train=w_train,
961
+ X_val=X_val,
962
+ y_val=y_val,
963
+ w_val=w_val,
964
+ trial=None,
965
+ )
966
+ refit_epochs = int(getattr(tmp_model, "best_epoch", None) or self.ctx.epochs)
967
+ getattr(getattr(tmp_model, "gnn", None), "to",
968
+ lambda *_args, **_kwargs: None)("cpu")
969
+ self._clean_gpu()
970
+ else:
971
+ self.model = self._build_model(self.best_params)
972
+ self.model.fit(
973
+ X_train,
974
+ y_train,
975
+ w_train=w_train,
976
+ X_val=X_val,
977
+ y_val=y_val,
978
+ w_val=w_val,
979
+ trial=None,
980
+ )
981
+ else:
982
+ use_refit = False
983
+
984
+ if use_refit:
985
+ self.model = self._build_model(self.best_params)
986
+ if refit_epochs is not None:
987
+ self.model.epochs = int(refit_epochs)
988
+ self.model.fit(
989
+ X_all,
990
+ y_all,
991
+ w_train=w_all,
992
+ X_val=None,
993
+ y_val=None,
994
+ w_val=None,
995
+ trial=None,
996
+ )
997
+ elif self.model is None:
998
+ self.model = self._build_model(self.best_params)
999
+ self.model.fit(
1000
+ X_all,
1001
+ y_all,
1002
+ w_train=w_all,
1003
+ X_val=None,
1004
+ y_val=None,
1005
+ w_val=None,
1006
+ trial=None,
1007
+ )
1008
+ self.ctx.model_label.append(self.label)
1009
+ self._predict_and_cache(self.model, pred_prefix='gnn', use_oht=True)
1010
+ self.ctx.gnn_best = self.model
1011
+
1012
+ # If geo_feature_nmes is set, refresh geo tokens for FT input.
1013
+ if self.ctx.config.geo_feature_nmes:
1014
+ self.prepare_geo_tokens(force=True)
1015
+
1016
+ def ensemble_predict(self, k: int) -> None:
1017
+ if not self.best_params:
1018
+ raise RuntimeError("Run tune() first to obtain best GNN parameters.")
1019
+ data = self.ctx.train_oht_scl_data
1020
+ test_data = self.ctx.test_oht_scl_data
1021
+ if data is None or test_data is None:
1022
+ raise RuntimeError("Missing standardized data for GNN ensemble.")
1023
+ X_all = data[self.ctx.var_nmes]
1024
+ y_all = data[self.ctx.resp_nme]
1025
+ w_all = data[self.ctx.weight_nme]
1026
+ X_test = test_data[self.ctx.var_nmes]
1027
+
1028
+ k = max(2, int(k))
1029
+ n_samples = len(X_all)
1030
+ if n_samples < k:
1031
+ print(
1032
+ f"[GNN Ensemble] n_samples={n_samples} < k={k}; skip ensemble.",
1033
+ flush=True,
1034
+ )
1035
+ return
1036
+
1037
+ splitter = KFold(
1038
+ n_splits=k,
1039
+ shuffle=True,
1040
+ random_state=self.ctx.rand_seed,
1041
+ )
1042
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
1043
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
1044
+
1045
+ for train_idx, val_idx in splitter.split(X_all):
1046
+ model = self._build_model(self.best_params)
1047
+ model.fit(
1048
+ X_all.iloc[train_idx],
1049
+ y_all.iloc[train_idx],
1050
+ w_train=w_all.iloc[train_idx],
1051
+ X_val=X_all.iloc[val_idx],
1052
+ y_val=y_all.iloc[val_idx],
1053
+ w_val=w_all.iloc[val_idx],
1054
+ trial=None,
1055
+ )
1056
+ pred_train = model.predict(X_all)
1057
+ pred_test = model.predict(X_test)
1058
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
1059
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
1060
+ getattr(getattr(model, "gnn", None), "to",
1061
+ lambda *_args, **_kwargs: None)("cpu")
1062
+ self._clean_gpu()
1063
+
1064
+ preds_train = preds_train_sum / float(k)
1065
+ preds_test = preds_test_sum / float(k)
1066
+ self._cache_predictions("gnn", preds_train, preds_test)
1067
+
1068
+ def prepare_geo_tokens(self, force: bool = False) -> None:
1069
+ """Train/update the GNN encoder for geo tokens and inject them into FT input."""
1070
+ geo_cols = list(self.ctx.config.geo_feature_nmes or [])
1071
+ if not geo_cols:
1072
+ return
1073
+ if (not force) and self.ctx.train_geo_tokens is not None and self.ctx.test_geo_tokens is not None:
1074
+ return
1075
+
1076
+ result = self.ctx._build_geo_tokens()
1077
+ if result is None:
1078
+ return
1079
+ train_tokens, test_tokens, cols, geo_gnn = result
1080
+ self.ctx.train_geo_tokens = train_tokens
1081
+ self.ctx.test_geo_tokens = test_tokens
1082
+ self.ctx.geo_token_cols = cols
1083
+ self.ctx.geo_gnn_model = geo_gnn
1084
+ print(f"[GeoToken][GNNTrainer] Generated {len(cols)} dims and injected into FT.", flush=True)
1085
+
1086
+ def save(self) -> None:
1087
+ if self.model is None:
1088
+ print(f"[save] Warning: No model to save for {self.label}")
1089
+ return
1090
+ path = self.output.model_path(self._get_model_filename())
1091
+ base_gnn = getattr(self.model, "_unwrap_gnn", lambda: None)()
1092
+ state = None if base_gnn is None else base_gnn.state_dict()
1093
+ payload = {
1094
+ "best_params": self.best_params,
1095
+ "state_dict": state,
1096
+ }
1097
+ torch.save(payload, path)
1098
+
1099
+ def load(self) -> None:
1100
+ path = self.output.model_path(self._get_model_filename())
1101
+ if not os.path.exists(path):
1102
+ print(f"[load] Warning: Model file not found: {path}")
1103
+ return
1104
+ payload = torch.load(path, map_location='cpu')
1105
+ if not isinstance(payload, dict):
1106
+ raise ValueError(f"Invalid GNN checkpoint: {path}")
1107
+ params = payload.get("best_params") or {}
1108
+ state_dict = payload.get("state_dict")
1109
+ model = self._build_model(params)
1110
+ if params:
1111
+ model.set_params(dict(params))
1112
+ base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
1113
+ if base_gnn is not None and state_dict is not None:
1114
+ base_gnn.load_state_dict(state_dict, strict=False)
1115
+ self.model = model
1116
+ self.best_params = dict(params) if isinstance(params, dict) else None
1117
+ self.ctx.gnn_best = self.model
1118
+
1119
+
1120
+ class XGBTrainer(TrainerBase):
1121
+ def __init__(self, context: "BayesOptModel") -> None:
1122
+ super().__init__(context, 'Xgboost', 'Xgboost')
1123
+ self.model: Optional[xgb.XGBModel] = None
1124
+ self._xgb_use_gpu = False
1125
+ self._xgb_gpu_warned = False
1126
+
1127
+ def _build_estimator(self) -> xgb.XGBModel:
1128
+ use_gpu = bool(self.ctx.use_gpu and _xgb_cuda_available())
1129
+ self._xgb_use_gpu = use_gpu
1130
+ params = dict(
1131
+ objective=self.ctx.obj,
1132
+ random_state=self.ctx.rand_seed,
1133
+ subsample=0.9,
1134
+ tree_method='gpu_hist' if use_gpu else 'hist',
1135
+ enable_categorical=True,
1136
+ predictor='gpu_predictor' if use_gpu else 'cpu_predictor'
1137
+ )
1138
+ if self.ctx.use_gpu and not use_gpu and not self._xgb_gpu_warned:
1139
+ print(
1140
+ "[XGBoost] CUDA requested but not available; falling back to CPU.",
1141
+ flush=True,
1142
+ )
1143
+ self._xgb_gpu_warned = True
1144
+ if use_gpu:
1145
+ params['gpu_id'] = 0
1146
+ print(f">>> XGBoost using GPU ID: 0 (Single GPU Mode)")
1147
+ if self.ctx.task_type == 'classification':
1148
+ params.setdefault("eval_metric", "logloss")
1149
+ return xgb.XGBClassifier(**params)
1150
+ return xgb.XGBRegressor(**params)
1151
+
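# --- Illustrative sketch (not part of the released package) ---------------------
# The same estimator settings expressed for XGBoost >= 2.0 (an assumption about the
# installed version), where device="cuda" replaces the gpu_hist / gpu_predictor /
# gpu_id knobs used by _build_estimator() above; use device="cpu" without CUDA.
import xgboost as xgb

reg = xgb.XGBRegressor(
    objective="reg:tweedie",
    random_state=42,
    subsample=0.9,
    tree_method="hist",
    device="cuda",
    enable_categorical=True,
)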
1152
+ def _resolve_early_stopping_rounds(self, n_estimators: int) -> int:
1153
+ n_estimators = max(1, int(n_estimators))
1154
+ base = max(5, n_estimators // 10)
1155
+ return min(50, base)
1156
+
1157
+ def _build_fit_kwargs(self,
1158
+ w_train,
1159
+ X_val=None,
1160
+ y_val=None,
1161
+ w_val=None,
1162
+ n_estimators: Optional[int] = None) -> Dict[str, Any]:
1163
+ fit_kwargs = dict(self.ctx.fit_params or {})
1164
+ fit_kwargs.pop("sample_weight", None)
1165
+ fit_kwargs["sample_weight"] = w_train
1166
+
1167
+ if "eval_set" not in fit_kwargs and X_val is not None and y_val is not None:
1168
+ fit_kwargs["eval_set"] = [(X_val, y_val)]
1169
+ if w_val is not None:
1170
+ fit_kwargs["sample_weight_eval_set"] = [w_val]
1171
+
1172
+ if "eval_metric" not in fit_kwargs:
1173
+ fit_kwargs["eval_metric"] = "logloss" if self.ctx.task_type == 'classification' else "rmse"
1174
+
1175
+ if "early_stopping_rounds" not in fit_kwargs and "eval_set" in fit_kwargs:
1176
+ rounds = self._resolve_early_stopping_rounds(n_estimators or 100)
1177
+ fit_kwargs["early_stopping_rounds"] = rounds
1178
+
1179
+ fit_kwargs.setdefault("verbose", False)
1180
+ return fit_kwargs
1181
+
1182
+ def ensemble_predict(self, k: int) -> None:
1183
+ if not self.best_params:
1184
+ raise RuntimeError("Run tune() first to obtain best XGB parameters.")
1185
+ k = max(2, int(k))
1186
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
1187
+ y_all = self.ctx.train_data[self.ctx.resp_nme].values
1188
+ w_all = self.ctx.train_data[self.ctx.weight_nme].values
1189
+ X_test = self.ctx.test_data[self.ctx.factor_nmes]
1190
+ n_samples = len(X_all)
1191
+ if n_samples < k:
1192
+ print(
1193
+ f"[XGB Ensemble] n_samples={n_samples} < k={k}; skip ensemble.",
1194
+ flush=True,
1195
+ )
1196
+ return
1197
+
1198
+ splitter = KFold(
1199
+ n_splits=k,
1200
+ shuffle=True,
1201
+ random_state=self.ctx.rand_seed,
1202
+ )
1203
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
1204
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
1205
+
1206
+ for train_idx, val_idx in splitter.split(X_all):
1207
+ X_train = X_all.iloc[train_idx]
1208
+ y_train = y_all[train_idx]
1209
+ w_train = w_all[train_idx]
1210
+ X_val = X_all.iloc[val_idx]
1211
+ y_val = y_all[val_idx]
1212
+ w_val = w_all[val_idx]
1213
+
1214
+ clf = self._build_estimator()
1215
+ clf.set_params(**self.best_params)
1216
+ fit_kwargs = self._build_fit_kwargs(
1217
+ w_train=w_train,
1218
+ X_val=X_val,
1219
+ y_val=y_val,
1220
+ w_val=w_val,
1221
+ n_estimators=self.best_params.get("n_estimators", 100),
1222
+ )
1223
+ clf.fit(X_train, y_train, **fit_kwargs)
1224
+
1225
+ if self.ctx.task_type == 'classification':
1226
+ pred_train = clf.predict_proba(X_all)[:, 1]
1227
+ pred_test = clf.predict_proba(X_test)[:, 1]
1228
+ else:
1229
+ pred_train = clf.predict(X_all)
1230
+ pred_test = clf.predict(X_test)
1231
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
1232
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
1233
+ self._clean_gpu()
1234
+
1235
+ preds_train = preds_train_sum / float(k)
1236
+ preds_test = preds_test_sum / float(k)
1237
+ self._cache_predictions("xgb", preds_train, preds_test)
1238
+
1239
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
1240
+ learning_rate = trial.suggest_float(
1241
+ 'learning_rate', 1e-5, 1e-1, log=True)
1242
+ gamma = trial.suggest_float('gamma', 0, 10000)
1243
+ max_depth_max = max(
1244
+ 3, int(getattr(self.config, "xgb_max_depth_max", 25)))
1245
+ n_estimators_max = max(
1246
+ 10, int(getattr(self.config, "xgb_n_estimators_max", 500)))
1247
+ max_depth = trial.suggest_int('max_depth', 3, max_depth_max)
1248
+ n_estimators = trial.suggest_int(
1249
+ 'n_estimators', 10, n_estimators_max, step=10)
1250
+ min_child_weight = trial.suggest_int(
1251
+ 'min_child_weight', 100, 10000, step=100)
1252
+ reg_alpha = trial.suggest_float('reg_alpha', 1e-10, 1, log=True)
1253
+ reg_lambda = trial.suggest_float('reg_lambda', 1e-10, 1, log=True)
1254
+ if trial is not None:
1255
+ print(
1256
+ f"[Optuna][Xgboost] trial_id={trial.number} max_depth={max_depth} "
1257
+ f"n_estimators={n_estimators}",
1258
+ flush=True,
1259
+ )
1260
+ if max_depth >= 20 and n_estimators >= 300:
1261
+ raise optuna.TrialPruned(
1262
+ "XGB config is likely too slow (max_depth>=20 & n_estimators>=300)")
1263
+ clf = self._build_estimator()
1264
+ params = {
1265
+ 'learning_rate': learning_rate,
1266
+ 'gamma': gamma,
1267
+ 'max_depth': max_depth,
1268
+ 'n_estimators': n_estimators,
1269
+ 'min_child_weight': min_child_weight,
1270
+ 'reg_alpha': reg_alpha,
1271
+ 'reg_lambda': reg_lambda
1272
+ }
1273
+ tweedie_variance_power = None
1274
+ if self.ctx.task_type != 'classification':
1275
+ if self.ctx.obj == 'reg:tweedie':
1276
+ tweedie_variance_power = trial.suggest_float(
1277
+ 'tweedie_variance_power', 1, 2)
1278
+ params['tweedie_variance_power'] = tweedie_variance_power
1279
+ elif self.ctx.obj == 'count:poisson':
1280
+ tweedie_variance_power = 1
1281
+ elif self.ctx.obj == 'reg:gamma':
1282
+ tweedie_variance_power = 2
1283
+ else:
1284
+ tweedie_variance_power = 1.5
1285
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
1286
+ y_all = self.ctx.train_data[self.ctx.resp_nme].values
1287
+ w_all = self.ctx.train_data[self.ctx.weight_nme].values
1288
+
1289
+ losses: List[float] = []
1290
+ for train_idx, val_idx in self.ctx.cv.split(X_all):
1291
+ X_train = X_all.iloc[train_idx]
1292
+ y_train = y_all[train_idx]
1293
+ w_train = w_all[train_idx]
1294
+ X_val = X_all.iloc[val_idx]
1295
+ y_val = y_all[val_idx]
1296
+ w_val = w_all[val_idx]
1297
+
1298
+ clf = self._build_estimator()
1299
+ clf.set_params(**params)
1300
+ fit_kwargs = self._build_fit_kwargs(
1301
+ w_train=w_train,
1302
+ X_val=X_val,
1303
+ y_val=y_val,
1304
+ w_val=w_val,
1305
+ n_estimators=n_estimators,
1306
+ )
1307
+ clf.fit(X_train, y_train, **fit_kwargs)
1308
+
1309
+ if self.ctx.task_type == 'classification':
1310
+ y_pred = clf.predict_proba(X_val)[:, 1]
1311
+ y_pred = np.clip(y_pred, EPS, 1 - EPS)
1312
+ loss = log_loss(y_val, y_pred, sample_weight=w_val)
1313
+ else:
1314
+ y_pred = clf.predict(X_val)
1315
+ y_pred_safe = np.maximum(y_pred, EPS)
1316
+ loss = mean_tweedie_deviance(
1317
+ y_val,
1318
+ y_pred_safe,
1319
+ sample_weight=w_val,
1320
+ power=tweedie_variance_power,
1321
+ )
1322
+ losses.append(float(loss))
1323
+ self._clean_gpu()
1324
+
1325
+ return float(np.mean(losses))
1326
+
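cross_val is written as a standard Optuna objective: it takes a trial, samples parameters, and returns the mean weighted CV loss (raising TrialPruned for hopeless configs). It is presumably driven by something like the following; the trainer instance and trial budget are assumptions for illustration, not the package's actual tune() wiring:

import optuna

# Hypothetical driver: minimize the CV loss returned by cross_val.
study = optuna.create_study(direction="minimize")
study.optimize(trainer.cross_val, n_trials=50)   # trainer: an already-constructed XGB trainer (assumed)
print(study.best_params)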
1327
+ def train(self) -> None:
1328
+ if not self.best_params:
1329
+ raise RuntimeError("Run tune() first to obtain best XGB parameters.")
1330
+ self.model = self._build_estimator()
1331
+ self.model.set_params(**self.best_params)
1332
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
1333
+ predict_fn = None
1334
+ if self.ctx.task_type == 'classification':
1335
+ def _predict_proba(X, **_kwargs):
1336
+ return self.model.predict_proba(X)[:, 1]
1337
+ predict_fn = _predict_proba
1338
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
1339
+ y_all = self.ctx.train_data[self.ctx.resp_nme].values
1340
+ w_all = self.ctx.train_data[self.ctx.weight_nme].values
1341
+
1342
+ use_split = 0.0 < float(self.ctx.prop_test) < 1.0 and len(X_all) >= 10
1343
+ if use_split:
1344
+ splitter = ShuffleSplit(
1345
+ n_splits=1,
1346
+ test_size=self.ctx.prop_test,
1347
+ random_state=self.ctx.rand_seed,
1348
+ )
1349
+ train_idx, val_idx = next(splitter.split(X_all))
1350
+ X_train = X_all.iloc[train_idx]
1351
+ y_train = y_all[train_idx]
1352
+ w_train = w_all[train_idx]
1353
+ X_val = X_all.iloc[val_idx]
1354
+ y_val = y_all[val_idx]
1355
+ w_val = w_all[val_idx]
1356
+ fit_kwargs = self._build_fit_kwargs(
1357
+ w_train=w_train,
1358
+ X_val=X_val,
1359
+ y_val=y_val,
1360
+ w_val=w_val,
1361
+ n_estimators=self.best_params.get("n_estimators", 100),
1362
+ )
1363
+ self.model.fit(X_train, y_train, **fit_kwargs)
1364
+ best_iter = getattr(self.model, "best_iteration", None)
1365
+ if use_refit and best_iter is not None:
1366
+ refit_model = self._build_estimator()
1367
+ refit_params = dict(self.best_params)
1368
+ refit_params["n_estimators"] = int(best_iter) + 1
1369
+ refit_model.set_params(**refit_params)
1370
+ refit_kwargs = dict(self.ctx.fit_params or {})
1371
+ refit_kwargs.setdefault("sample_weight", w_all)
1372
+ refit_kwargs.pop("eval_set", None)
1373
+ refit_kwargs.pop("sample_weight_eval_set", None)
1374
+ refit_kwargs.pop("early_stopping_rounds", None)
1375
+ refit_kwargs.pop("eval_metric", None)
1376
+ refit_kwargs.setdefault("verbose", False)
1377
+ refit_model.fit(X_all, y_all, **refit_kwargs)
1378
+ self.model = refit_model
1379
+ else:
1380
+ fit_kwargs = dict(self.ctx.fit_params or {})
1381
+ fit_kwargs.setdefault("sample_weight", w_all)
1382
+ self.model.fit(X_all, y_all, **fit_kwargs)
1383
+
1384
+ self.ctx.model_label.append(self.label)
1385
+ self._predict_and_cache(
1386
+ self.model,
1387
+ pred_prefix='xgb',
1388
+ predict_fn=predict_fn
1389
+ )
1390
+ self.ctx.xgb_best = self.model
1391
+
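train() above probes with a holdout split to find best_iteration, then refits on all rows with exactly that many trees. A condensed sketch of the same idea against the plain xgboost sklearn API (assuming xgboost >= 1.6, where early_stopping_rounds is a constructor argument; parameter names are illustrative):

from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

def fit_with_holdout_then_refit(X, y, w, params, prop_test=0.25, seed=0):
    """Early-stop on a holdout, then refit on the full data with the tuned tree count."""
    X_tr, X_val, y_tr, y_val, w_tr, w_val = train_test_split(
        X, y, w, test_size=prop_test, random_state=seed)
    probe = XGBRegressor(early_stopping_rounds=50, **params)
    probe.fit(X_tr, y_tr, sample_weight=w_tr,
              eval_set=[(X_val, y_val)], sample_weight_eval_set=[w_val],
              verbose=False)
    best_iter = getattr(probe, "best_iteration", None)
    final_params = dict(params)
    if best_iter is not None:
        final_params["n_estimators"] = int(best_iter) + 1   # keep only the useful trees
    final = XGBRegressor(**final_params)
    final.fit(X, y, sample_weight=w, verbose=False)
    return final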
1392
+
1393
+ class GLMTrainer(TrainerBase):
1394
+ def __init__(self, context: "BayesOptModel") -> None:
1395
+ super().__init__(context, 'GLM', 'GLM')
1396
+ self.model = None
1397
+
1398
+ def _select_family(self, tweedie_power: Optional[float] = None):
1399
+ if self.ctx.task_type == 'classification':
1400
+ return sm.families.Binomial()
1401
+ if self.ctx.obj == 'count:poisson':
1402
+ return sm.families.Poisson()
1403
+ if self.ctx.obj == 'reg:gamma':
1404
+ return sm.families.Gamma()
1405
+ power = tweedie_power if tweedie_power is not None else 1.5
1406
+ return sm.families.Tweedie(var_power=power, link=sm.families.links.log())
1407
+
1408
+ def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
1409
+ # Add intercept to the statsmodels design matrix.
1410
+ X = data[self.ctx.var_nmes]
1411
+ return sm.add_constant(X, has_constant='add')
1412
+
1413
+ def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
1414
+ if isinstance(family, sm.families.Poisson):
1415
+ return 1.0
1416
+ if isinstance(family, sm.families.Gamma):
1417
+ return 2.0
1418
+ if isinstance(family, sm.families.Tweedie):
1419
+ return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
1420
+ return 1.5
1421
+
1422
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
1423
+ param_space = {
1424
+ "alpha": lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True),
1425
+ "l1_ratio": lambda t: t.suggest_float('l1_ratio', 0.0, 1.0)
1426
+ }
1427
+ if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
1428
+ param_space["tweedie_power"] = lambda t: t.suggest_float(
1429
+ 'tweedie_power', 1.0, 2.0)
1430
+
1431
+ def data_provider():
1432
+ data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
1433
+ assert data is not None, "Preprocessed training data is missing."
1434
+ return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
1435
+
1436
+ def preprocess_fn(X_train, X_val):
1437
+ X_train_s, X_val_s, _ = self._standardize_fold(
1438
+ X_train, X_val, self.ctx.num_features)
1439
+ return self._prepare_design(X_train_s), self._prepare_design(X_val_s)
1440
+
1441
+ metric_ctx: Dict[str, Any] = {}
1442
+
1443
+ def model_builder(params):
1444
+ family = self._select_family(params.get("tweedie_power"))
1445
+ metric_ctx["family"] = family
1446
+ metric_ctx["tweedie_power"] = params.get("tweedie_power")
1447
+ return {
1448
+ "family": family,
1449
+ "alpha": params["alpha"],
1450
+ "l1_ratio": params["l1_ratio"],
1451
+ "tweedie_power": params.get("tweedie_power")
1452
+ }
1453
+
1454
+ def fit_predict(model_cfg, X_train, y_train, w_train, X_val, y_val, w_val, _trial):
1455
+ glm = sm.GLM(y_train, X_train,
1456
+ family=model_cfg["family"],
1457
+ freq_weights=w_train)
1458
+ result = glm.fit_regularized(
1459
+ alpha=model_cfg["alpha"],
1460
+ L1_wt=model_cfg["l1_ratio"],
1461
+ maxiter=200
1462
+ )
1463
+ return result.predict(X_val)
1464
+
1465
+ def metric_fn(y_true, y_pred, weight):
1466
+ if self.ctx.task_type == 'classification':
1467
+ y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
1468
+ return log_loss(y_true, y_pred_clipped, sample_weight=weight)
1469
+ y_pred_safe = np.maximum(y_pred, EPS)
1470
+ return mean_tweedie_deviance(
1471
+ y_true,
1472
+ y_pred_safe,
1473
+ sample_weight=weight,
1474
+ power=self._metric_power(
1475
+ metric_ctx.get("family"), metric_ctx.get("tweedie_power"))
1476
+ )
1477
+
1478
+ return self.cross_val_generic(
1479
+ trial=trial,
1480
+ hyperparameter_space=param_space,
1481
+ data_provider=data_provider,
1482
+ model_builder=model_builder,
1483
+ metric_fn=metric_fn,
1484
+ preprocess_fn=preprocess_fn,
1485
+ fit_predict_fn=fit_predict,
1486
+             splitter=self.ctx.cv.split(
+                 (self.ctx.train_oht_data
+                  if self.ctx.train_oht_data is not None
+                  else self.ctx.train_oht_scl_data)[self.ctx.var_nmes])
1488
+ )
1489
+
1490
+ def train(self) -> None:
1491
+ if not self.best_params:
1492
+ raise RuntimeError("Run tune() first to obtain best GLM parameters.")
1493
+ tweedie_power = self.best_params.get('tweedie_power')
1494
+ family = self._select_family(tweedie_power)
1495
+
1496
+ X_train = self._prepare_design(self.ctx.train_oht_scl_data)
1497
+ y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
1498
+ w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
1499
+
1500
+ glm = sm.GLM(y_train, X_train, family=family,
1501
+ freq_weights=w_train)
1502
+ self.model = glm.fit_regularized(
1503
+ alpha=self.best_params['alpha'],
1504
+ L1_wt=self.best_params['l1_ratio'],
1505
+ maxiter=300
1506
+ )
1507
+
1508
+ self.ctx.glm_best = self.model
1509
+ self.ctx.model_label += [self.label]
1510
+ self._predict_and_cache(
1511
+ self.model,
1512
+ 'glm',
1513
+ design_fn=lambda train: self._prepare_design(
1514
+ self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
1515
+ )
1516
+ )
1517
+
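For reference, the core statsmodels call used by train() above, on a tiny synthetic frame (toy data; the alpha/L1_wt values are arbitrary):

import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)
X = pd.DataFrame({"x1": rng.normal(size=500), "x2": rng.normal(size=500)})
y = rng.gamma(shape=2.0, scale=np.exp(0.3 * X["x1"].to_numpy()))  # toy positive response
w = np.ones(len(X))

design = sm.add_constant(X, has_constant="add")   # same intercept handling as _prepare_design
family = sm.families.Tweedie(var_power=1.5, link=sm.families.links.log())
result = sm.GLM(y, design, family=family, freq_weights=w).fit_regularized(
    alpha=0.01, L1_wt=0.5, maxiter=300)           # elastic net: alpha = strength, L1_wt = L1 share
fitted = result.predict(design)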
1518
+ def ensemble_predict(self, k: int) -> None:
1519
+ if not self.best_params:
1520
+ raise RuntimeError("Run tune() first to obtain best GLM parameters.")
1521
+ k = max(2, int(k))
1522
+ data = self.ctx.train_oht_scl_data
1523
+ if data is None:
1524
+ raise RuntimeError("Missing standardized data for GLM ensemble.")
1525
+ X_all = data[self.ctx.var_nmes]
1526
+ y_all = data[self.ctx.resp_nme]
1527
+ w_all = data[self.ctx.weight_nme]
1528
+ X_test = self.ctx.test_oht_scl_data
1529
+ if X_test is None:
1530
+ raise RuntimeError("Missing standardized test data for GLM ensemble.")
1531
+
1532
+ n_samples = len(X_all)
1533
+ if n_samples < k:
1534
+ print(
1535
+ f"[GLM Ensemble] n_samples={n_samples} < k={k}; skip ensemble.",
1536
+ flush=True,
1537
+ )
1538
+ return
1539
+
1540
+ X_all_design = self._prepare_design(data)
1541
+ X_test_design = self._prepare_design(X_test)
1542
+ tweedie_power = self.best_params.get('tweedie_power')
1543
+ family = self._select_family(tweedie_power)
1544
+
1545
+ splitter = KFold(
1546
+ n_splits=k,
1547
+ shuffle=True,
1548
+ random_state=self.ctx.rand_seed,
1549
+ )
1550
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
1551
+ preds_test_sum = np.zeros(len(X_test_design), dtype=np.float64)
1552
+
1553
+ for train_idx, _val_idx in splitter.split(X_all):
1554
+ X_train = X_all_design.iloc[train_idx]
1555
+ y_train = y_all.iloc[train_idx]
1556
+ w_train = w_all.iloc[train_idx]
1557
+
1558
+ glm = sm.GLM(y_train, X_train, family=family, freq_weights=w_train)
1559
+ result = glm.fit_regularized(
1560
+ alpha=self.best_params['alpha'],
1561
+ L1_wt=self.best_params['l1_ratio'],
1562
+ maxiter=300
1563
+ )
1564
+ pred_train = result.predict(X_all_design)
1565
+ pred_test = result.predict(X_test_design)
1566
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
1567
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
1568
+
1569
+ preds_train = preds_train_sum / float(k)
1570
+ preds_test = preds_test_sum / float(k)
1571
+ self._cache_predictions("glm", preds_train, preds_test)
1572
+
1573
+
1574
+ class ResNetTrainer(TrainerBase):
1575
+ def __init__(self, context: "BayesOptModel") -> None:
1576
+ if context.task_type == 'classification':
1577
+ super().__init__(context, 'ResNetClassifier', 'ResNet')
1578
+ else:
1579
+ super().__init__(context, 'ResNet', 'ResNet')
1580
+ self.model: Optional[ResNetSklearn] = None
1581
+ self.enable_distributed_optuna = bool(context.config.use_resn_ddp)
1582
+
1583
+ def _resolve_input_dim(self) -> int:
1584
+ data = getattr(self.ctx, "train_oht_scl_data", None)
1585
+ if data is not None and getattr(self.ctx, "var_nmes", None):
1586
+ return int(data[self.ctx.var_nmes].shape[1])
1587
+ return int(len(self.ctx.var_nmes or []))
1588
+
1589
+ def _build_model(self, params: Optional[Dict[str, Any]] = None) -> ResNetSklearn:
1590
+ params = params or {}
1591
+ power = params.get("tw_power", self.ctx.default_tweedie_power())
1592
+ if power is not None:
1593
+ power = float(power)
1594
+ resn_weight_decay = float(
1595
+ params.get(
1596
+ "weight_decay",
1597
+ getattr(self.ctx.config, "resn_weight_decay", 1e-4),
1598
+ )
1599
+ )
1600
+ return ResNetSklearn(
1601
+ model_nme=self.ctx.model_nme,
1602
+ input_dim=self._resolve_input_dim(),
1603
+ hidden_dim=int(params.get("hidden_dim", 64)),
1604
+ block_num=int(params.get("block_num", 2)),
1605
+ task_type=self.ctx.task_type,
1606
+ epochs=self.ctx.epochs,
1607
+ tweedie_power=power,
1608
+ learning_rate=float(params.get("learning_rate", 0.01)),
1609
+ patience=int(params.get("patience", 10)),
1610
+ use_layernorm=True,
1611
+ dropout=float(params.get("dropout", 0.1)),
1612
+ residual_scale=float(params.get("residual_scale", 0.1)),
1613
+ stochastic_depth=float(params.get("stochastic_depth", 0.0)),
1614
+ weight_decay=resn_weight_decay,
1615
+ use_data_parallel=self.ctx.config.use_resn_data_parallel,
1616
+ use_ddp=self.ctx.config.use_resn_ddp
1617
+ )
1618
+
1619
+ # ========= Cross-validation (for BayesOpt) =========
1620
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
1621
+ # ResNet CV focuses on memory control:
1622
+ # - Create a ResNetSklearn per fold and release it immediately after.
1623
+ # - Move model to CPU, delete, and call gc/empty_cache after each fold.
1624
+ # - Optionally sample part of training data during BayesOpt to reduce memory.
1625
+
1626
+ base_tw_power = self.ctx.default_tweedie_power()
1627
+
1628
+ def data_provider():
1629
+ data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
1630
+ assert data is not None, "Preprocessed training data is missing."
1631
+ return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
1632
+
1633
+ metric_ctx: Dict[str, Any] = {}
1634
+
1635
+ def model_builder(params):
1636
+ power = params.get("tw_power", base_tw_power)
1637
+ metric_ctx["tw_power"] = power
1638
+ params_local = dict(params)
1639
+ params_local["tw_power"] = power
1640
+ return self._build_model(params_local)
1641
+
1642
+ def preprocess_fn(X_train, X_val):
1643
+ X_train_s, X_val_s, _ = self._standardize_fold(
1644
+ X_train, X_val, self.ctx.num_features)
1645
+ return X_train_s, X_val_s
1646
+
1647
+ def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
1648
+ model.fit(
1649
+ X_train, y_train, w_train,
1650
+ X_val, y_val, w_val,
1651
+ trial=trial_obj
1652
+ )
1653
+ return model.predict(X_val)
1654
+
1655
+ def metric_fn(y_true, y_pred, weight):
1656
+ if self.ctx.task_type == 'regression':
1657
+ return mean_tweedie_deviance(
1658
+ y_true,
1659
+ y_pred,
1660
+ sample_weight=weight,
1661
+ power=metric_ctx.get("tw_power", base_tw_power)
1662
+ )
1663
+ return log_loss(y_true, y_pred, sample_weight=weight)
1664
+
1665
+ sample_cap = data_provider()[0]
1666
+ max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
1667
+
1668
+ return self.cross_val_generic(
1669
+ trial=trial,
1670
+ hyperparameter_space={
1671
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
1672
+ "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
1673
+ "block_num": lambda t: t.suggest_int('block_num', 2, 10),
1674
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3, step=0.05),
1675
+ "residual_scale": lambda t: t.suggest_float('residual_scale', 0.05, 0.3, step=0.05),
1676
+ "patience": lambda t: t.suggest_int('patience', 3, 12),
1677
+ "stochastic_depth": lambda t: t.suggest_float('stochastic_depth', 0.0, 0.2, step=0.05),
1678
+ **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie' else {})
1679
+ },
1680
+ data_provider=data_provider,
1681
+ model_builder=model_builder,
1682
+ metric_fn=metric_fn,
1683
+ sample_limit=max_rows_for_resnet_bo if len(
1684
+ sample_cap) > max_rows_for_resnet_bo > 0 else None,
1685
+ preprocess_fn=preprocess_fn,
1686
+ fit_predict_fn=fit_predict,
1687
+ cleanup_fn=lambda m: getattr(
1688
+ getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
1689
+ )
1690
+
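The per-fold memory discipline described at the top of cross_val (move the network to CPU, drop references, empty the CUDA cache) boils down to roughly the following, shown generically with torch; the helper name is illustrative and the cache call is a no-op on CPU-only machines:

import gc
import torch

def release_fold_model(model: torch.nn.Module) -> None:
    """Move a fold's network off the GPU and release cached CUDA memory."""
    model.to("cpu")
    del model                      # drops this local reference; callers should drop theirs too
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()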
1691
+ # ========= Train final ResNet with best hyperparameters =========
1692
+ def train(self) -> None:
1693
+ if not self.best_params:
1694
+ raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
1695
+
1696
+ params = dict(self.best_params)
1697
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
1698
+ data = self.ctx.train_oht_scl_data
1699
+ if data is None:
1700
+ raise RuntimeError("Missing standardized data for ResNet training.")
1701
+ X_all = data[self.ctx.var_nmes]
1702
+ y_all = data[self.ctx.resp_nme]
1703
+ w_all = data[self.ctx.weight_nme]
1704
+
1705
+ refit_epochs = None
1706
+ if use_refit and 0.0 < float(self.ctx.prop_test) < 1.0 and len(X_all) >= 10:
1707
+ splitter = ShuffleSplit(
1708
+ n_splits=1,
1709
+ test_size=self.ctx.prop_test,
1710
+ random_state=self.ctx.rand_seed,
1711
+ )
1712
+ train_idx, val_idx = next(splitter.split(X_all))
1713
+ tmp_model = self._build_model(params)
1714
+ tmp_model.fit(
1715
+ X_all.iloc[train_idx],
1716
+ y_all.iloc[train_idx],
1717
+ w_all.iloc[train_idx],
1718
+ X_all.iloc[val_idx],
1719
+ y_all.iloc[val_idx],
1720
+ w_all.iloc[val_idx],
1721
+ trial=None,
1722
+ )
1723
+ refit_epochs = self._resolve_best_epoch(
1724
+ getattr(tmp_model, "training_history", None),
1725
+ default_epochs=int(self.ctx.epochs),
1726
+ )
1727
+ getattr(getattr(tmp_model, "resnet", None), "to",
1728
+ lambda *_args, **_kwargs: None)("cpu")
1729
+ self._clean_gpu()
1730
+
1731
+ self.model = self._build_model(params)
1732
+ if refit_epochs is not None:
1733
+ self.model.epochs = int(refit_epochs)
1734
+ self.best_params = params
1735
+ loss_plot_path = self.output.plot_path(
1736
+ f'loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
1737
+ self.model.loss_curve_path = loss_plot_path
1738
+
1739
+ self._fit_predict_cache(
1740
+ self.model,
1741
+ X_all,
1742
+ y_all,
1743
+ sample_weight=w_all,
1744
+ pred_prefix='resn',
1745
+ use_oht=True,
1746
+ sample_weight_arg='w_train'
1747
+ )
1748
+
1749
+ # Convenience wrapper for external callers.
1750
+ self.ctx.resn_best = self.model
1751
+
1752
+ def ensemble_predict(self, k: int) -> None:
1753
+ if not self.best_params:
1754
+ raise RuntimeError("Run tune() first to obtain best ResNet parameters.")
1755
+ data = self.ctx.train_oht_scl_data
1756
+ test_data = self.ctx.test_oht_scl_data
1757
+ if data is None or test_data is None:
1758
+ raise RuntimeError("Missing standardized data for ResNet ensemble.")
1759
+ X_all = data[self.ctx.var_nmes]
1760
+ y_all = data[self.ctx.resp_nme]
1761
+ w_all = data[self.ctx.weight_nme]
1762
+ X_test = test_data[self.ctx.var_nmes]
1763
+
1764
+ k = max(2, int(k))
1765
+ n_samples = len(X_all)
1766
+ if n_samples < k:
1767
+ print(
1768
+ f"[ResNet Ensemble] n_samples={n_samples} < k={k}; skip ensemble.",
1769
+ flush=True,
1770
+ )
1771
+ return
1772
+
1773
+ splitter = KFold(
1774
+ n_splits=k,
1775
+ shuffle=True,
1776
+ random_state=self.ctx.rand_seed,
1777
+ )
1778
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
1779
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
1780
+
1781
+ for train_idx, val_idx in splitter.split(X_all):
1782
+ model = self._build_model(self.best_params)
1783
+ model.fit(
1784
+ X_all.iloc[train_idx],
1785
+ y_all.iloc[train_idx],
1786
+ w_all.iloc[train_idx],
1787
+ X_all.iloc[val_idx],
1788
+ y_all.iloc[val_idx],
1789
+ w_all.iloc[val_idx],
1790
+ trial=None,
1791
+ )
1792
+ pred_train = model.predict(X_all)
1793
+ pred_test = model.predict(X_test)
1794
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
1795
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
1796
+ getattr(getattr(model, "resnet", None), "to",
1797
+ lambda *_args, **_kwargs: None)("cpu")
1798
+ self._clean_gpu()
1799
+
1800
+ preds_train = preds_train_sum / float(k)
1801
+ preds_test = preds_test_sum / float(k)
1802
+ self._cache_predictions("resn", preds_train, preds_test)
1803
+
1804
+ # ========= Save / Load =========
1805
+ # ResNet is saved as state_dict and needs a custom load path.
1806
+ # Save logic is implemented in TrainerBase (checks .resnet attribute).
1807
+
1808
+ def load(self) -> None:
1809
+ # Load ResNet weights to the current device to match context.
1810
+ path = self.output.model_path(self._get_model_filename())
1811
+ if os.path.exists(path):
1812
+ resn_loaded = self._build_model(self.best_params)
1813
+ state_dict = torch.load(path, map_location='cpu')
1814
+ resn_loaded.resnet.load_state_dict(state_dict)
1815
+
1816
+ self._move_to_device(resn_loaded)
1817
+ self.model = resn_loaded
1818
+ self.ctx.resn_best = self.model
1819
+ else:
1820
+ print(f"[ResNetTrainer.load] Model file not found: {path}")
1821
+
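load() above rebuilds the sklearn-style wrapper and restores weights from a state_dict. The torch half of that round trip, with a stand-in module and an illustrative filename (the real wrapper exposes the network as .resnet):

import torch
import torch.nn as nn

# Stand-in architecture; save only the weights (state_dict), as the TrainerBase note describes.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
torch.save(net.state_dict(), "resnet_weights.pt")

# Load: rebuild the same architecture, restore weights on CPU, then move to the target device.
rebuilt = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
rebuilt.load_state_dict(torch.load("resnet_weights.pt", map_location="cpu"))
rebuilt.to("cuda" if torch.cuda.is_available() else "cpu")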
1822
+
1823
+ class FTTrainer(TrainerBase):
1824
+ def __init__(self, context: "BayesOptModel") -> None:
1825
+ if context.task_type == 'classification':
1826
+ super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
1827
+ else:
1828
+ super().__init__(context, 'FTTransformer', 'FTTransformer')
1829
+ self.model: Optional[FTTransformerSklearn] = None
1830
+ self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
1831
+ self._cv_geo_warned = False
1832
+
1833
+ def _resolve_numeric_tokens(self) -> int:
1834
+ requested = getattr(self.ctx.config, "ft_num_numeric_tokens", None)
1835
+ return FTTransformerSklearn.resolve_numeric_token_count(
1836
+ self.ctx.num_features,
1837
+ self.ctx.cate_list,
1838
+ requested,
1839
+ )
1840
+
1841
+ def _resolve_adaptive_heads(self,
1842
+ d_model: int,
1843
+ requested_heads: Optional[int] = None) -> Tuple[int, bool]:
1844
+ d_model = int(d_model)
1845
+ if d_model <= 0:
1846
+ raise ValueError(f"Invalid d_model={d_model}, expected > 0.")
1847
+
1848
+ default_heads = max(2, d_model // 16)
1849
+ base_heads = default_heads if requested_heads is None else int(
1850
+ requested_heads)
1851
+ base_heads = max(1, min(base_heads, d_model))
1852
+
1853
+ if d_model % base_heads == 0:
1854
+ return base_heads, False
1855
+
1856
+ for candidate in range(min(d_model, base_heads), 0, -1):
1857
+ if d_model % candidate == 0:
1858
+ return candidate, True
1859
+ return 1, True
1860
+
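The head-resolution rule above defaults to d_model // 16 heads (at least 2), clamps the request into [1, d_model], and falls back to the largest divisor of d_model when the request does not divide it. A standalone restatement with a couple of worked values (function name is illustrative):

from typing import Optional, Tuple

def resolve_heads(d_model: int, requested: Optional[int] = None) -> Tuple[int, bool]:
    """Return (n_heads, adjusted): n_heads always divides d_model."""
    if d_model <= 0:
        raise ValueError(f"Invalid d_model={d_model}, expected > 0.")
    base = max(2, d_model // 16) if requested is None else int(requested)
    base = max(1, min(base, d_model))
    if d_model % base == 0:
        return base, False
    for candidate in range(min(d_model, base), 0, -1):
        if d_model % candidate == 0:
            return candidate, True
    return 1, True

assert resolve_heads(64) == (4, False)      # default: 64 // 16 = 4
assert resolve_heads(48, 5) == (4, True)    # 5 does not divide 48, drop to the divisor 4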
1861
+ def _build_geo_tokens_for_split(self,
1862
+ X_train: pd.DataFrame,
1863
+ X_val: pd.DataFrame,
1864
+ geo_params: Optional[Dict[str, Any]] = None):
1865
+ if not self.ctx.config.geo_feature_nmes:
1866
+ return None
1867
+ orig_train = self.ctx.train_data
1868
+ orig_test = self.ctx.test_data
1869
+ try:
1870
+ self.ctx.train_data = orig_train.loc[X_train.index].copy()
1871
+ self.ctx.test_data = orig_train.loc[X_val.index].copy()
1872
+ return self.ctx._build_geo_tokens(geo_params)
1873
+ finally:
1874
+ self.ctx.train_data = orig_train
1875
+ self.ctx.test_data = orig_test
1876
+
1877
+ def cross_val_unsupervised(self, trial: Optional[optuna.trial.Trial]) -> float:
1878
+ """Optuna objective A: minimize validation loss for masked reconstruction."""
1879
+ param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
1880
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-3, log=True),
1881
+ "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
1882
+ "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
1883
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
1884
+ "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
1885
+ "mask_prob_num": lambda t: t.suggest_float('mask_prob_num', 0.05, 0.4),
1886
+ "mask_prob_cat": lambda t: t.suggest_float('mask_prob_cat', 0.05, 0.4),
1887
+ "num_loss_weight": lambda t: t.suggest_float('num_loss_weight', 0.25, 4.0, log=True),
1888
+ "cat_loss_weight": lambda t: t.suggest_float('cat_loss_weight', 0.25, 4.0, log=True),
1889
+ }
1890
+
1891
+ params: Optional[Dict[str, Any]] = None
1892
+ if self._distributed_forced_params is not None:
1893
+ params = self._distributed_forced_params
1894
+ self._distributed_forced_params = None
1895
+ else:
1896
+ if trial is None:
1897
+ raise RuntimeError(
1898
+ "Missing Optuna trial for parameter sampling.")
1899
+ params = {name: sampler(trial)
1900
+ for name, sampler in param_space.items()}
1901
+ if self._should_use_distributed_optuna():
1902
+ self._distributed_prepare_trial(params)
1903
+
1904
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
1905
+ max_rows_for_ft_bo = min(1_000_000, int(len(X_all) / 2))
1906
+ if max_rows_for_ft_bo > 0 and len(X_all) > max_rows_for_ft_bo:
1907
+ X_all = X_all.sample(n=max_rows_for_ft_bo,
1908
+ random_state=self.ctx.rand_seed)
1909
+
1910
+ splitter = ShuffleSplit(
1911
+ n_splits=1,
1912
+ test_size=self.ctx.prop_test,
1913
+ random_state=self.ctx.rand_seed
1914
+ )
1915
+ train_idx, val_idx = next(splitter.split(X_all))
1916
+ X_train = X_all.iloc[train_idx]
1917
+ X_val = X_all.iloc[val_idx]
1918
+ geo_train = geo_val = None
1919
+ if self.ctx.config.geo_feature_nmes:
1920
+ built = self._build_geo_tokens_for_split(X_train, X_val, params)
1921
+ if built is not None:
1922
+ geo_train, geo_val, _, _ = built
1923
+ elif not self._cv_geo_warned:
1924
+ print(
1925
+ "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
1926
+ flush=True,
1927
+ )
1928
+ self._cv_geo_warned = True
1929
+
1930
+ d_model = int(params["d_model"])
1931
+ n_layers = int(params["n_layers"])
1932
+ num_numeric_tokens = self._resolve_numeric_tokens()
1933
+ token_count = num_numeric_tokens + len(self.ctx.cate_list)
1934
+ if geo_train is not None:
1935
+ token_count += 1
1936
+ approx_units = d_model * n_layers * max(1, token_count)
1937
+ if approx_units > 12_000_000:
1938
+ raise optuna.TrialPruned(
1939
+ f"config exceeds safe memory budget (approx_units={approx_units})")
1940
+
1941
+ adaptive_heads, _ = self._resolve_adaptive_heads(
1942
+ d_model=d_model,
1943
+ requested_heads=params.get("n_heads")
1944
+ )
1945
+
1946
+ mask_prob_num = float(params.get("mask_prob_num", 0.15))
1947
+ mask_prob_cat = float(params.get("mask_prob_cat", 0.15))
1948
+ num_loss_weight = float(params.get("num_loss_weight", 1.0))
1949
+ cat_loss_weight = float(params.get("cat_loss_weight", 1.0))
1950
+
1951
+ model_params = dict(params)
1952
+ model_params["n_heads"] = adaptive_heads
1953
+ for k in ("mask_prob_num", "mask_prob_cat", "num_loss_weight", "cat_loss_weight"):
1954
+ model_params.pop(k, None)
1955
+
1956
+ model = FTTransformerSklearn(
1957
+ model_nme=self.ctx.model_nme,
1958
+ num_cols=self.ctx.num_features,
1959
+ cat_cols=self.ctx.cate_list,
1960
+ task_type=self.ctx.task_type,
1961
+ epochs=self.ctx.epochs,
1962
+ patience=5,
1963
+ weight_decay=float(params.get("weight_decay", 0.0)),
1964
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
1965
+ use_ddp=self.ctx.config.use_ft_ddp,
1966
+ num_numeric_tokens=num_numeric_tokens,
1967
+ )
1968
+ model.set_params(model_params)
1969
+ try:
1970
+ return float(model.fit_unsupervised(
1971
+ X_train,
1972
+ X_val=X_val,
1973
+ trial=trial,
1974
+ geo_train=geo_train,
1975
+ geo_val=geo_val,
1976
+ mask_prob_num=mask_prob_num,
1977
+ mask_prob_cat=mask_prob_cat,
1978
+ num_loss_weight=num_loss_weight,
1979
+ cat_loss_weight=cat_loss_weight
1980
+ ))
1981
+ finally:
1982
+ getattr(getattr(model, "ft", None), "to",
1983
+ lambda *_args, **_kwargs: None)("cpu")
1984
+ self._clean_gpu()
1985
+
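Both FT objectives prune a configuration before any training whenever d_model * n_layers * token_count exceeds a fixed budget (12,000,000 above). The check in isolation, with illustrative numbers:

def exceeds_token_budget(d_model: int, n_layers: int, n_numeric_tokens: int,
                         n_categorical: int, has_geo_token: bool = False,
                         budget: int = 12_000_000) -> bool:
    """Rough size guard: prune configs whose width/depth/token product is too large."""
    token_count = max(1, n_numeric_tokens + n_categorical + (1 if has_geo_token else 0))
    return d_model * n_layers * token_count > budget

# 128 * 8 * 50 = 51_200 -> comfortably under the budget, so this config would not be pruned.
print(exceeds_token_budget(d_model=128, n_layers=8, n_numeric_tokens=30, n_categorical=20))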
1986
+ def cross_val(self, trial: optuna.trial.Trial) -> float:
1987
+ # FT-Transformer CV also focuses on memory control:
1988
+ # - Shrink search space to avoid oversized models.
1989
+ # - Release GPU memory after each fold so the next trial can run.
1990
+ # Slightly shrink hyperparameter space to avoid oversized models.
1991
+ param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
1992
+ "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
1993
+ # "d_model": lambda t: t.suggest_int('d_model', 8, 64, step=8),
1994
+ "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
1995
+ "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
1996
+ "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.2),
1997
+ "weight_decay": lambda t: t.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
1998
+ }
1999
+ if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
2000
+ param_space["tw_power"] = lambda t: t.suggest_float(
2001
+ 'tw_power', 1.0, 2.0)
2002
+ geo_enabled = bool(
2003
+ self.ctx.geo_token_cols or self.ctx.config.geo_feature_nmes)
2004
+ if geo_enabled:
2005
+ # Only tune GNN-related hyperparams when geo tokens are enabled.
2006
+ param_space.update({
2007
+ "geo_token_hidden_dim": lambda t: t.suggest_int('geo_token_hidden_dim', 16, 128, step=16),
2008
+ "geo_token_layers": lambda t: t.suggest_int('geo_token_layers', 1, 4),
2009
+ "geo_token_k_neighbors": lambda t: t.suggest_int('geo_token_k_neighbors', 5, 20),
2010
+ "geo_token_dropout": lambda t: t.suggest_float('geo_token_dropout', 0.0, 0.3),
2011
+ "geo_token_learning_rate": lambda t: t.suggest_float('geo_token_learning_rate', 1e-4, 5e-3, log=True),
2012
+ })
2013
+
2014
+ metric_ctx: Dict[str, Any] = {}
2015
+
2016
+ def data_provider():
2017
+ data = self.ctx.train_data
2018
+ return data[self.ctx.factor_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
2019
+
2020
+ def model_builder(params):
2021
+ d_model = int(params["d_model"])
2022
+ n_layers = int(params["n_layers"])
2023
+ num_numeric_tokens = self._resolve_numeric_tokens()
2024
+ token_count = num_numeric_tokens + len(self.ctx.cate_list)
2025
+ if geo_enabled:
2026
+ token_count += 1
2027
+ approx_units = d_model * n_layers * max(1, token_count)
2028
+ if approx_units > 12_000_000:
2029
+ print(
2030
+ f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
2031
+ raise optuna.TrialPruned(
2032
+ "config exceeds safe memory budget; prune before training")
2033
+ geo_params_local = {k: v for k, v in params.items()
2034
+ if k.startswith("geo_token_")}
2035
+
2036
+ tw_power = params.get("tw_power")
2037
+ if self.ctx.task_type == 'regression':
2038
+ base_tw = self.ctx.default_tweedie_power()
2039
+ if self.ctx.obj in ('count:poisson', 'reg:gamma'):
2040
+ tw_power = base_tw
2041
+ elif tw_power is None:
2042
+ tw_power = base_tw
2043
+ metric_ctx["tw_power"] = tw_power
2044
+
2045
+ adaptive_heads, _ = self._resolve_adaptive_heads(
2046
+ d_model=d_model,
2047
+ requested_heads=params.get("n_heads")
2048
+ )
2049
+
2050
+ return FTTransformerSklearn(
2051
+ model_nme=self.ctx.model_nme,
2052
+ num_cols=self.ctx.num_features,
2053
+ cat_cols=self.ctx.cate_list,
2054
+ d_model=d_model,
2055
+ n_heads=adaptive_heads,
2056
+ n_layers=n_layers,
2057
+ dropout=params["dropout"],
2058
+ task_type=self.ctx.task_type,
2059
+ epochs=self.ctx.epochs,
2060
+ tweedie_power=tw_power,
2061
+ learning_rate=params["learning_rate"],
2062
+ patience=5,
2063
+ weight_decay=float(params.get("weight_decay", 0.0)),
2064
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2065
+ use_ddp=self.ctx.config.use_ft_ddp,
2066
+ num_numeric_tokens=num_numeric_tokens,
2067
+ ).set_params({"_geo_params": geo_params_local} if geo_enabled else {})
2068
+
2069
+ def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
2070
+ geo_train = geo_val = None
2071
+ if geo_enabled:
2072
+ geo_params = getattr(model, "_geo_params", {})
2073
+ built = self._build_geo_tokens_for_split(
2074
+ X_train, X_val, geo_params)
2075
+ if built is not None:
2076
+ geo_train, geo_val, _, _ = built
2077
+ elif not self._cv_geo_warned:
2078
+ print(
2079
+ "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
2080
+ flush=True,
2081
+ )
2082
+ self._cv_geo_warned = True
2083
+ model.fit(
2084
+ X_train, y_train, w_train,
2085
+ X_val, y_val, w_val,
2086
+ trial=trial_obj,
2087
+ geo_train=geo_train,
2088
+ geo_val=geo_val
2089
+ )
2090
+ return model.predict(X_val, geo_tokens=geo_val)
2091
+
2092
+ def metric_fn(y_true, y_pred, weight):
2093
+ if self.ctx.task_type == 'regression':
2094
+ return mean_tweedie_deviance(
2095
+ y_true,
2096
+ y_pred,
2097
+ sample_weight=weight,
2098
+ power=metric_ctx.get("tw_power", 1.5)
2099
+ )
2100
+ return log_loss(y_true, y_pred, sample_weight=weight)
2101
+
2102
+ data_for_cap = data_provider()[0]
2103
+         max_rows_for_ft_bo = min(1_000_000, int(len(data_for_cap) / 2))
2104
+
2105
+ return self.cross_val_generic(
2106
+ trial=trial,
2107
+ hyperparameter_space=param_space,
2108
+ data_provider=data_provider,
2109
+ model_builder=model_builder,
2110
+ metric_fn=metric_fn,
2111
+ sample_limit=max_rows_for_ft_bo if len(
2112
+ data_for_cap) > max_rows_for_ft_bo > 0 else None,
2113
+ fit_predict_fn=fit_predict,
2114
+ cleanup_fn=lambda m: getattr(
2115
+ getattr(m, "ft", None), "to", lambda *_args, **_kwargs: None)("cpu")
2116
+ )
2117
+
2118
+ def train(self) -> None:
2119
+ if not self.best_params:
2120
+ raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
2121
+ resolved_params = dict(self.best_params)
2122
+ d_model_value = resolved_params.get("d_model", 64)
2123
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
2124
+ d_model=d_model_value,
2125
+ requested_heads=resolved_params.get("n_heads")
2126
+ )
2127
+ if heads_adjusted:
2128
+ print(f"[FTTrainer] Auto-adjusted n_heads from "
2129
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
2130
+ f"(d_model={d_model_value}).")
2131
+ resolved_params["n_heads"] = adaptive_heads
2132
+
2133
+ use_refit = bool(getattr(self.ctx.config, "final_refit", True))
2134
+ refit_epochs = None
2135
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
2136
+ y_all = self.ctx.train_data[self.ctx.resp_nme]
2137
+ w_all = self.ctx.train_data[self.ctx.weight_nme]
2138
+ if use_refit and 0.0 < float(self.ctx.prop_test) < 1.0 and len(X_all) >= 10:
2139
+ splitter = ShuffleSplit(
2140
+ n_splits=1,
2141
+ test_size=self.ctx.prop_test,
2142
+ random_state=self.ctx.rand_seed,
2143
+ )
2144
+ train_idx, val_idx = next(splitter.split(X_all))
2145
+ tmp_model = FTTransformerSklearn(
2146
+ model_nme=self.ctx.model_nme,
2147
+ num_cols=self.ctx.num_features,
2148
+ cat_cols=self.ctx.cate_list,
2149
+ task_type=self.ctx.task_type,
2150
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2151
+ use_ddp=self.ctx.config.use_ft_ddp,
2152
+ num_numeric_tokens=self._resolve_numeric_tokens(),
2153
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
2154
+ )
2155
+ tmp_model.set_params(resolved_params)
2156
+ geo_train_full = self.ctx.train_geo_tokens
2157
+ geo_train = None if geo_train_full is None else geo_train_full.iloc[train_idx]
2158
+ geo_val = None if geo_train_full is None else geo_train_full.iloc[val_idx]
2159
+ tmp_model.fit(
2160
+ X_all.iloc[train_idx],
2161
+ y_all.iloc[train_idx],
2162
+ w_all.iloc[train_idx],
2163
+ X_all.iloc[val_idx],
2164
+ y_all.iloc[val_idx],
2165
+ w_all.iloc[val_idx],
2166
+ trial=None,
2167
+ geo_train=geo_train,
2168
+ geo_val=geo_val,
2169
+ )
2170
+ refit_epochs = self._resolve_best_epoch(
2171
+ getattr(tmp_model, "training_history", None),
2172
+ default_epochs=int(self.ctx.epochs),
2173
+ )
2174
+ getattr(getattr(tmp_model, "ft", None), "to",
2175
+ lambda *_args, **_kwargs: None)("cpu")
2176
+ self._clean_gpu()
2177
+
2178
+ self.model = FTTransformerSklearn(
2179
+ model_nme=self.ctx.model_nme,
2180
+ num_cols=self.ctx.num_features,
2181
+ cat_cols=self.ctx.cate_list,
2182
+ task_type=self.ctx.task_type,
2183
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2184
+ use_ddp=self.ctx.config.use_ft_ddp,
2185
+ num_numeric_tokens=self._resolve_numeric_tokens(),
2186
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
2187
+ )
2188
+ if refit_epochs is not None:
2189
+ self.model.epochs = int(refit_epochs)
2190
+ self.model.set_params(resolved_params)
2191
+ self.best_params = resolved_params
2192
+ loss_plot_path = self.output.plot_path(
2193
+ f'loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
2194
+ self.model.loss_curve_path = loss_plot_path
2195
+ geo_train = self.ctx.train_geo_tokens
2196
+ geo_test = self.ctx.test_geo_tokens
2197
+ fit_kwargs = {}
2198
+ predict_kwargs_train = None
2199
+ predict_kwargs_test = None
2200
+ if geo_train is not None and geo_test is not None:
2201
+ fit_kwargs["geo_train"] = geo_train
2202
+ predict_kwargs_train = {"geo_tokens": geo_train}
2203
+ predict_kwargs_test = {"geo_tokens": geo_test}
2204
+ self._fit_predict_cache(
2205
+ self.model,
2206
+ self.ctx.train_data[self.ctx.factor_nmes],
2207
+ self.ctx.train_data[self.ctx.resp_nme],
2208
+ sample_weight=self.ctx.train_data[self.ctx.weight_nme],
2209
+ pred_prefix='ft',
2210
+ sample_weight_arg='w_train',
2211
+ fit_kwargs=fit_kwargs,
2212
+ predict_kwargs_train=predict_kwargs_train,
2213
+ predict_kwargs_test=predict_kwargs_test
2214
+ )
2215
+ self.ctx.ft_best = self.model
2216
+
2217
+ def ensemble_predict(self, k: int) -> None:
2218
+ if not self.best_params:
2219
+ raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
2220
+ k = max(2, int(k))
2221
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
2222
+ y_all = self.ctx.train_data[self.ctx.resp_nme]
2223
+ w_all = self.ctx.train_data[self.ctx.weight_nme]
2224
+ X_test = self.ctx.test_data[self.ctx.factor_nmes]
2225
+ n_samples = len(X_all)
2226
+ if n_samples < k:
2227
+ print(
2228
+ f"[FT Ensemble] n_samples={n_samples} < k={k}; skip ensemble.",
2229
+ flush=True,
2230
+ )
2231
+ return
2232
+
2233
+ geo_train_full = self.ctx.train_geo_tokens
2234
+ geo_test_full = self.ctx.test_geo_tokens
2235
+
2236
+ resolved_params = dict(self.best_params)
2237
+ default_d_model = getattr(self.model, "d_model", 64)
2238
+ adaptive_heads, _ = self._resolve_adaptive_heads(
2239
+ d_model=resolved_params.get("d_model", default_d_model),
2240
+ requested_heads=resolved_params.get("n_heads")
2241
+ )
2242
+ resolved_params["n_heads"] = adaptive_heads
2243
+
2244
+ splitter = KFold(
2245
+ n_splits=k,
2246
+ shuffle=True,
2247
+ random_state=self.ctx.rand_seed,
2248
+ )
2249
+ preds_train_sum = np.zeros(n_samples, dtype=np.float64)
2250
+ preds_test_sum = np.zeros(len(X_test), dtype=np.float64)
2251
+
2252
+ for train_idx, val_idx in splitter.split(X_all):
2253
+ model = FTTransformerSklearn(
2254
+ model_nme=self.ctx.model_nme,
2255
+ num_cols=self.ctx.num_features,
2256
+ cat_cols=self.ctx.cate_list,
2257
+ task_type=self.ctx.task_type,
2258
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2259
+ use_ddp=self.ctx.config.use_ft_ddp,
2260
+ num_numeric_tokens=self._resolve_numeric_tokens(),
2261
+ weight_decay=float(resolved_params.get("weight_decay", 0.0)),
2262
+ )
2263
+ model.set_params(resolved_params)
2264
+
2265
+ geo_train = geo_val = None
2266
+ if geo_train_full is not None:
2267
+ geo_train = geo_train_full.iloc[train_idx]
2268
+ geo_val = geo_train_full.iloc[val_idx]
2269
+
2270
+ model.fit(
2271
+ X_all.iloc[train_idx],
2272
+ y_all.iloc[train_idx],
2273
+ w_all.iloc[train_idx],
2274
+ X_all.iloc[val_idx],
2275
+ y_all.iloc[val_idx],
2276
+ w_all.iloc[val_idx],
2277
+ trial=None,
2278
+ geo_train=geo_train,
2279
+ geo_val=geo_val,
2280
+ )
2281
+
2282
+ pred_train = model.predict(X_all, geo_tokens=geo_train_full)
2283
+ pred_test = model.predict(X_test, geo_tokens=geo_test_full)
2284
+ preds_train_sum += np.asarray(pred_train, dtype=np.float64)
2285
+ preds_test_sum += np.asarray(pred_test, dtype=np.float64)
2286
+ getattr(getattr(model, "ft", None), "to",
2287
+ lambda *_args, **_kwargs: None)("cpu")
2288
+ self._clean_gpu()
2289
+
2290
+ preds_train = preds_train_sum / float(k)
2291
+ preds_test = preds_test_sum / float(k)
2292
+ self._cache_predictions("ft", preds_train, preds_test)
2293
+
2294
+ def train_as_feature(self, pred_prefix: str = "ft_feat", feature_mode: str = "prediction") -> None:
2295
+         """Train the FT-Transformer only to generate features (not recorded as a final model)."""
2296
+ if not self.best_params:
2297
+ raise RuntimeError("Run tune() first to obtain best FT-Transformer parameters.")
2298
+ self.model = FTTransformerSklearn(
2299
+ model_nme=self.ctx.model_nme,
2300
+ num_cols=self.ctx.num_features,
2301
+ cat_cols=self.ctx.cate_list,
2302
+ task_type=self.ctx.task_type,
2303
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2304
+ use_ddp=self.ctx.config.use_ft_ddp,
2305
+ num_numeric_tokens=self._resolve_numeric_tokens(),
2306
+ )
2307
+ resolved_params = dict(self.best_params)
2308
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
2309
+ d_model=resolved_params.get("d_model", self.model.d_model),
2310
+ requested_heads=resolved_params.get("n_heads")
2311
+ )
2312
+ if heads_adjusted:
2313
+ print(f"[FTTrainer] Auto-adjusted n_heads from "
2314
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
2315
+ f"(d_model={resolved_params.get('d_model', self.model.d_model)}).")
2316
+ resolved_params["n_heads"] = adaptive_heads
2317
+ self.model.set_params(resolved_params)
2318
+ self.best_params = resolved_params
2319
+
2320
+ geo_train = self.ctx.train_geo_tokens
2321
+ geo_test = self.ctx.test_geo_tokens
2322
+ fit_kwargs = {}
2323
+ predict_kwargs_train = None
2324
+ predict_kwargs_test = None
2325
+ if geo_train is not None and geo_test is not None:
2326
+ fit_kwargs["geo_train"] = geo_train
2327
+ predict_kwargs_train = {"geo_tokens": geo_train}
2328
+ predict_kwargs_test = {"geo_tokens": geo_test}
2329
+
2330
+ if feature_mode not in ("prediction", "embedding"):
2331
+ raise ValueError(
2332
+ f"Unsupported feature_mode='{feature_mode}', expected 'prediction' or 'embedding'.")
2333
+ if feature_mode == "embedding":
2334
+ predict_kwargs_train = dict(predict_kwargs_train or {})
2335
+ predict_kwargs_test = dict(predict_kwargs_test or {})
2336
+ predict_kwargs_train["return_embedding"] = True
2337
+ predict_kwargs_test["return_embedding"] = True
2338
+
2339
+ self._fit_predict_cache(
2340
+ self.model,
2341
+ self.ctx.train_data[self.ctx.factor_nmes],
2342
+ self.ctx.train_data[self.ctx.resp_nme],
2343
+ sample_weight=self.ctx.train_data[self.ctx.weight_nme],
2344
+ pred_prefix=pred_prefix,
2345
+ sample_weight_arg='w_train',
2346
+ fit_kwargs=fit_kwargs,
2347
+ predict_kwargs_train=predict_kwargs_train,
2348
+ predict_kwargs_test=predict_kwargs_test,
2349
+ record_label=False
2350
+ )
2351
+
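A hypothetical call pattern for the feature path above, assuming an FTTrainer has already been constructed and tuned elsewhere (the variable name ft_trainer is illustrative):

# Cache FT-Transformer point predictions as an extra feature column:
ft_trainer.train_as_feature(pred_prefix="ft_feat", feature_mode="prediction")

# Or cache pooled embeddings instead of predictions:
ft_trainer.train_as_feature(pred_prefix="ft_emb", feature_mode="embedding")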
2352
+ def pretrain_unsupervised_as_feature(self,
2353
+ pred_prefix: str = "ft_uemb",
2354
+                                          params: Optional[Dict[str, Any]] = None,
2356
+ mask_prob_num: float = 0.15,
2357
+ mask_prob_cat: float = 0.15,
2358
+ num_loss_weight: float = 1.0,
2359
+ cat_loss_weight: float = 1.0) -> None:
2360
+         """Run self-supervised pretraining (masked reconstruction) and cache the resulting embeddings."""
2361
+ self.model = FTTransformerSklearn(
2362
+ model_nme=self.ctx.model_nme,
2363
+ num_cols=self.ctx.num_features,
2364
+ cat_cols=self.ctx.cate_list,
2365
+ task_type=self.ctx.task_type,
2366
+ use_data_parallel=self.ctx.config.use_ft_data_parallel,
2367
+ use_ddp=self.ctx.config.use_ft_ddp,
2368
+ num_numeric_tokens=self._resolve_numeric_tokens(),
2369
+ )
2370
+ resolved_params = dict(params or {})
2371
+         # Reuse structural params from supervised tuning unless explicitly overridden.
2372
+ if not resolved_params and self.best_params:
2373
+ resolved_params = dict(self.best_params)
2374
+
2375
+ # If params include masked reconstruction fields, they take precedence.
2376
+ mask_prob_num = float(resolved_params.pop(
2377
+ "mask_prob_num", mask_prob_num))
2378
+ mask_prob_cat = float(resolved_params.pop(
2379
+ "mask_prob_cat", mask_prob_cat))
2380
+ num_loss_weight = float(resolved_params.pop(
2381
+ "num_loss_weight", num_loss_weight))
2382
+ cat_loss_weight = float(resolved_params.pop(
2383
+ "cat_loss_weight", cat_loss_weight))
2384
+
2385
+ adaptive_heads, heads_adjusted = self._resolve_adaptive_heads(
2386
+ d_model=resolved_params.get("d_model", self.model.d_model),
2387
+ requested_heads=resolved_params.get("n_heads")
2388
+ )
2389
+ if heads_adjusted:
2390
+ print(f"[FTTrainer] Auto-adjusted n_heads from "
2391
+ f"{resolved_params.get('n_heads')} to {adaptive_heads} "
2392
+ f"(d_model={resolved_params.get('d_model', self.model.d_model)}).")
2393
+ resolved_params["n_heads"] = adaptive_heads
2394
+ if resolved_params:
2395
+ self.model.set_params(resolved_params)
2396
+
2397
+ loss_plot_path = self.output.plot_path(
2398
+ f'loss_{self.ctx.model_nme}_FTTransformerUnsupervised.png')
2399
+ self.model.loss_curve_path = loss_plot_path
2400
+
2401
+ # Build a simple holdout split for pretraining early stopping.
2402
+ X_all = self.ctx.train_data[self.ctx.factor_nmes]
2403
+ idx = np.arange(len(X_all))
2404
+ splitter = ShuffleSplit(
2405
+ n_splits=1,
2406
+ test_size=self.ctx.prop_test,
2407
+ random_state=self.ctx.rand_seed
2408
+ )
2409
+ train_idx, val_idx = next(splitter.split(idx))
2410
+ X_tr = X_all.iloc[train_idx]
2411
+ X_val = X_all.iloc[val_idx]
2412
+
2413
+ geo_all = self.ctx.train_geo_tokens
2414
+ geo_tr = geo_val = None
2415
+ if geo_all is not None:
2416
+ geo_tr = geo_all.loc[X_tr.index]
2417
+ geo_val = geo_all.loc[X_val.index]
2418
+
2419
+ self.model.fit_unsupervised(
2420
+ X_tr,
2421
+ X_val=X_val,
2422
+ geo_train=geo_tr,
2423
+ geo_val=geo_val,
2424
+ mask_prob_num=mask_prob_num,
2425
+ mask_prob_cat=mask_prob_cat,
2426
+ num_loss_weight=num_loss_weight,
2427
+ cat_loss_weight=cat_loss_weight
2428
+ )
2429
+
2430
+ geo_train_full = self.ctx.train_geo_tokens
2431
+ geo_test_full = self.ctx.test_geo_tokens
2432
+ predict_kwargs_train = {"return_embedding": True}
2433
+ predict_kwargs_test = {"return_embedding": True}
2434
+ if geo_train_full is not None and geo_test_full is not None:
2435
+ predict_kwargs_train["geo_tokens"] = geo_train_full
2436
+ predict_kwargs_test["geo_tokens"] = geo_test_full
2437
+
2438
+ self._predict_and_cache(
2439
+ self.model,
2440
+ pred_prefix=pred_prefix,
2441
+ predict_kwargs_train=predict_kwargs_train,
2442
+ predict_kwargs_test=predict_kwargs_test
2443
+ )
2444
+
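And the companion self-supervised path, again with an assumed ft_trainer instance; the masking rates and loss weights shown simply restate the method defaults:

# Masked-reconstruction pretraining; resulting embeddings are cached under the "ft_uemb" prefix.
ft_trainer.pretrain_unsupervised_as_feature(
    pred_prefix="ft_uemb",
    mask_prob_num=0.15,
    mask_prob_cat=0.15,
    num_loss_weight=1.0,
    cat_loss_weight=1.0,
)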
2445
+
2446
+ # =============================================================================