ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py (new file, entry 29 above)
@@ -0,0 +1,1020 @@
+from __future__ import annotations
+
+from datetime import timedelta
+import gc
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+import joblib
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+try:  # pragma: no cover
+    import torch.distributed as dist  # type: ignore
+except Exception:  # pragma: no cover
+    dist = None  # type: ignore
+from sklearn.model_selection import (
+    GroupKFold,
+    GroupShuffleSplit,
+    KFold,
+    ShuffleSplit,
+    TimeSeriesSplit,
+)
+from sklearn.preprocessing import StandardScaler
+
+from ..config_preprocess import BayesOptConfig, OutputManager
+from ..utils import DistributedUtils, EPS, ensure_parent_dir
+
+class _OrderSplitter:
+    def __init__(self, splitter, order: np.ndarray) -> None:
+        self._splitter = splitter
+        self._order = np.asarray(order)
+
+    def split(self, X, y=None, groups=None):
+        order = self._order
+        X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
+        for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
+            yield order[tr_idx], order[val_idx]
+
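The `_OrderSplitter` wrapper above remaps position-based splits through a caller-supplied row ordering; this is how the time-series strategies further down apply `TimeSeriesSplit` to frames whose rows are not already sorted. A minimal sketch of the contract, assuming the class as defined above (the toy data is illustrative):

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import TimeSeriesSplit

    # Rows stored out of chronological order; `order` sorts them by date.
    X = pd.DataFrame({"date": [3, 1, 2, 5, 4, 6]})
    order = np.argsort(X["date"].to_numpy())

    for tr_idx, val_idx in _OrderSplitter(TimeSeriesSplit(n_splits=2), order).split(X):
        # Yielded indices point into the original (unsorted) frame, yet every
        # validation row is chronologically after every training row.
        assert X["date"].iloc[tr_idx].max() < X["date"].iloc[val_idx].min()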
+# =============================================================================
+# Trainer system
+# =============================================================================
+
+
+class TrainerBase:
+    def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
+        self.ctx = context
+        self.label = label
+        self.model_name_prefix = model_name_prefix
+        self.model = None
+        self.best_params: Optional[Dict[str, Any]] = None
+        self.best_trial = None
+        self.study_name: Optional[str] = None
+        self.enable_distributed_optuna: bool = False
+        self._distributed_forced_params: Optional[Dict[str, Any]] = None
+
+    def _dist_barrier(self, reason: str) -> None:
+        """DDP barrier wrapper used by distributed Optuna.
+
+        To debug "trial finished but next trial never starts" hangs, set these
+        environment variables (either in shell or config.json `env`):
+        - `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
+        - `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
+        - `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
+        """
+        if dist is None:
+            return
+        try:
+            if not getattr(dist, "is_available", lambda: False)():
+                return
+            if not dist.is_initialized():
+                return
+        except Exception:
+            return
+
+        timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
+        debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+        rank = None
+        world = None
+        if debug_barrier:
+            try:
+                rank = dist.get_rank()
+                world = dist.get_world_size()
+                print(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
+            except Exception:
+                debug_barrier = False
+        try:
+            timeout = timedelta(seconds=timeout_seconds)
+            backend = None
+            try:
+                backend = dist.get_backend()
+            except Exception:
+                backend = None
+
+            # `monitored_barrier` is only implemented for GLOO; using it under NCCL
+            # will raise and can itself trigger a secondary hang. Prefer an async
+            # barrier with timeout for NCCL.
+            monitored = getattr(dist, "monitored_barrier", None)
+            if backend == "gloo" and callable(monitored):
+                monitored(timeout=timeout)
+            else:
+                work = None
+                try:
+                    work = dist.barrier(async_op=True)
+                except TypeError:
+                    work = None
+                if work is not None:
+                    wait = getattr(work, "wait", None)
+                    if callable(wait):
+                        try:
+                            wait(timeout=timeout)
+                        except TypeError:
+                            wait()
+                    else:
+                        dist.barrier()
+                else:
+                    dist.barrier()
+            if debug_barrier:
+                print(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
+        except Exception as exc:
+            print(
+                f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
+                flush=True,
+            )
+            raise
+
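As the `_dist_barrier` docstring notes, barrier hangs are diagnosed through environment variables. A launcher-side sketch (the variable names come from the code above; setting them in the shell or in config.json `env` works equally):

    import os

    # Per-rank barrier tracing plus a 5-minute fail-fast timeout, set before
    # the training processes are spawned (e.g. via torchrun).
    os.environ["BAYESOPT_DDP_BARRIER_DEBUG"] = "1"
    os.environ["BAYESOPT_DDP_BARRIER_TIMEOUT"] = "300"  # seconds; default 1800
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"    # PyTorch-level logs
    os.environ["NCCL_DEBUG"] = "INFO"                   # NCCL-level logs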
+    @property
+    def config(self) -> BayesOptConfig:
+        return self.ctx.config
+
+    @property
+    def output(self) -> OutputManager:
+        return self.ctx.output_manager
+
+    def _get_model_filename(self) -> str:
+        ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
+        return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
+
+    def _resolve_optuna_storage_url(self) -> Optional[str]:
+        storage = getattr(self.config, "optuna_storage", None)
+        if not storage:
+            return None
+        storage_str = str(storage).strip()
+        if not storage_str:
+            return None
+        if "://" in storage_str or storage_str == ":memory:":
+            return storage_str
+        path = Path(storage_str)
+        path = path.resolve()
+        ensure_parent_dir(str(path))
+        return f"sqlite:///{path.as_posix()}"
+
+    def _resolve_optuna_study_name(self) -> str:
+        prefix = getattr(self.config, "optuna_study_prefix",
+                         None) or "bayesopt"
+        raw = f"{prefix}_{self.ctx.model_nme}_{self.model_name_prefix}"
+        safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
+        return safe.lower()
+
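These two resolvers make Optuna studies resumable: a bare filesystem path in `optuna_storage` becomes an absolute SQLite URL, and study names are slugified. A standalone sketch mirroring the logic above (hypothetical values; directory creation via `ensure_parent_dir` omitted):

    from pathlib import Path

    def storage_url(storage: str) -> str:
        # Mirrors _resolve_optuna_storage_url: URLs and ":memory:" pass
        # through; bare paths become absolute sqlite:/// URLs.
        s = storage.strip()
        if "://" in s or s == ":memory:":
            return s
        return f"sqlite:///{Path(s).resolve().as_posix()}"

    print(storage_url("studies/bo.db"))   # -> sqlite:///<abs cwd>/studies/bo.db
    print(storage_url(":memory:"))        # -> :memory:

    raw = "bayesopt_Motor TPL_xgb"        # prefix + model_nme + model prefix
    safe = "".join(c if c.isalnum() or c in "._-" else "_" for c in raw).lower()
    assert safe == "bayesopt_motor_tpl_xgb"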
+    def tune(self, max_evals: int, objective_fn=None) -> None:
+        # Generic Optuna tuning loop.
+        if objective_fn is None:
+            # If subclass doesn't provide objective_fn, default to cross_val.
+            objective_fn = self.cross_val
+
+        if self._should_use_distributed_optuna():
+            self._distributed_tune(max_evals, objective_fn)
+            return
+
+        total_trials = max(1, int(max_evals))
+        progress_counter = {"count": 0}
+
+        def objective_wrapper(trial: optuna.trial.Trial) -> float:
+            should_log = DistributedUtils.is_main_process()
+            if should_log:
+                current_idx = progress_counter["count"] + 1
+                print(
+                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                    f"(trial_id={trial.number})."
+                )
+            try:
+                result = objective_fn(trial)
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    print(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu()
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu()
+                if should_log:
+                    progress_counter["count"] = progress_counter["count"] + 1
+                    trial_state = getattr(trial, "state", None)
+                    state_repr = getattr(trial_state, "name", "OK")
+                    print(
+                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                        f"(status={state_repr})."
+                    )
+            return result
+
+        storage_url = self._resolve_optuna_storage_url()
+        study_name = self._resolve_optuna_study_name()
+        study_kwargs: Dict[str, Any] = {
+            "direction": "minimize",
+            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
+        }
+        if storage_url:
+            study_kwargs.update(
+                storage=storage_url,
+                study_name=study_name,
+                load_if_exists=True,
+            )
+
+        study = optuna.create_study(**study_kwargs)
+        self.study_name = getattr(study, "study_name", None)
+
+        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
+            # Persist best_params after each trial to allow safe resume.
+            try:
+                best = getattr(check_study, "best_trial", None)
+                if best is None:
+                    return
+                best_params = getattr(best, "params", None)
+                if not best_params:
+                    return
+                params_path = self.output.result_path(
+                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+                )
+                pd.DataFrame(best_params, index=[0]).to_csv(
+                    params_path, index=False)
+            except Exception:
+                return
+
+        completed_states = (
+            optuna.trial.TrialState.COMPLETE,
+            optuna.trial.TrialState.PRUNED,
+            optuna.trial.TrialState.FAIL,
+        )
+        completed = len(study.get_trials(states=completed_states))
+        progress_counter["count"] = completed
+        remaining = max(0, total_trials - completed)
+        if remaining > 0:
+            study.optimize(
+                objective_wrapper,
+                n_trials=remaining,
+                callbacks=[checkpoint_callback],
+            )
+        self.best_params = study.best_params
+        self.best_trial = study.best_trial
+
+        # Save best params to CSV for reproducibility.
+        params_path = self.output.result_path(
+            f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+        )
+        pd.DataFrame(self.best_params, index=[0]).to_csv(
+            params_path, index=False)
+
+    def train(self) -> None:
+        raise NotImplementedError
+
+    def save(self) -> None:
+        if self.model is None:
+            print(f"[save] Warning: No model to save for {self.label}")
+            return
+
+        path = self.output.model_path(self._get_model_filename())
+        if self.label in ['Xgboost', 'GLM']:
+            joblib.dump(self.model, path)
+        else:
+            # PyTorch models can save state_dict or the full object.
+            # Legacy behavior: ResNetTrainer saves state_dict; FTTrainer saves full object.
+            if hasattr(self.model, 'resnet'):  # ResNetSklearn model
+                torch.save(self.model.resnet.state_dict(), path)
+            else:  # FTTransformerSklearn or other PyTorch model
+                torch.save(self.model, path)
+
+    def load(self) -> None:
+        path = self.output.model_path(self._get_model_filename())
+        if not os.path.exists(path):
+            print(f"[load] Warning: Model file not found: {path}")
+            return
+
+        if self.label in ['Xgboost', 'GLM']:
+            self.model = joblib.load(path)
+        else:
+            # PyTorch loading depends on the model structure.
+            if self.label == 'ResNet' or self.label == 'ResNetClassifier':
+                # ResNet requires reconstructing the skeleton; handled by subclass.
+                pass
+            else:
+                # FT-Transformer serializes the whole object; load then move to device.
+                loaded = torch.load(path, map_location='cpu')
+                self._move_to_device(loaded)
+                self.model = loaded
+
+    def _move_to_device(self, model_obj):
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if hasattr(model_obj, 'device'):
+            model_obj.device = device
+        if hasattr(model_obj, 'to'):
+            model_obj.to(device)
+        # Move nested submodules (ft/resnet/gnn) to the same device.
+        if hasattr(model_obj, 'ft'):
+            model_obj.ft.to(device)
+        if hasattr(model_obj, 'resnet'):
+            model_obj.resnet.to(device)
+        if hasattr(model_obj, 'gnn'):
+            model_obj.gnn.to(device)
+
+    def _should_use_distributed_optuna(self) -> bool:
+        if not self.enable_distributed_optuna:
+            return False
+        rank_env = os.environ.get("RANK")
+        world_env = os.environ.get("WORLD_SIZE")
+        local_env = os.environ.get("LOCAL_RANK")
+        if rank_env is None or world_env is None or local_env is None:
+            return False
+        try:
+            world_size = int(world_env)
+        except Exception:
+            return False
+        return world_size > 1
+
+    def _distributed_is_main(self) -> bool:
+        return DistributedUtils.is_main_process()
+
+    def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna() or not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            return
+        message = [payload]
+        dist.broadcast_object_list(message, src=0)
+
+    def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna():
+            return
+        if not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        self._distributed_send_command({"type": "RUN", "params": params})
+        if not dist.is_initialized():
+            return
+        # STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
+        self._dist_barrier("prepare_trial")
+
+    def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
+        if dist is None:
+            print(
+                f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
+                flush=True,
+            )
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            print(
+                f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
+                flush=True,
+            )
+            return
+        while True:
+            message = [None]
+            dist.broadcast_object_list(message, src=0)
+            payload = message[0]
+            if not isinstance(payload, dict):
+                continue
+            cmd = payload.get("type")
+            if cmd == "STOP":
+                best_params = payload.get("best_params")
+                if best_params is not None:
+                    self.best_params = best_params
+                break
+            if cmd == "RUN":
+                params = payload.get("params") or {}
+                self._distributed_forced_params = params
+                # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
+                self._dist_barrier("worker_start")
+                try:
+                    objective_fn(None)
+                except optuna.TrialPruned:
+                    pass
+                except Exception as exc:
+                    print(
+                        f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
+                finally:
+                    self._clean_gpu()
+                    # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
+                    self._dist_barrier("worker_end")
+
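Rank 0 drives the loop above through `broadcast_object_list` with exactly two message shapes. One tuning round under this protocol, schematically (payload keys as in the code; the side-by-side timeline is illustrative):

    # rank 0 (driver)                      ranks 1..N-1 (worker loop)
    # ----------------                     --------------------------
    # sample params from the trial         message = [None]
    # broadcast {"type": "RUN",            ...receives RUN, stores params in
    #            "params": params}            _distributed_forced_params
    # _dist_barrier("prepare_trial")       _dist_barrier("worker_start")
    # objective_fn(trial)   # DDP fit      objective_fn(None)   # same DDP fit
    # _dist_barrier("trial_end")           _dist_barrier("worker_end")
    # ...repeat per trial; finally:
    # broadcast {"type": "STOP",           ...receives STOP, adopts
    #            "best_params": best}         best_params, breaks the loop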
+    def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
+        if dist is None:
+            print(
+                f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            rank_env = os.environ.get("RANK", "0")
+            if str(rank_env) != "0":
+                print(
+                    f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
+                    flush=True,
+                )
+                return
+            print(
+                f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        if not self._distributed_is_main():
+            self._distributed_worker_loop(objective_fn)
+            return
+
+        total_trials = max(1, int(max_evals))
+        progress_counter = {"count": 0}
+
+        def objective_wrapper(trial: optuna.trial.Trial) -> float:
+            should_log = True
+            if should_log:
+                current_idx = progress_counter["count"] + 1
+                print(
+                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                    f"(trial_id={trial.number})."
+                )
+            try:
+                result = objective_fn(trial)
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    print(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu()
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu()
+                if should_log:
+                    progress_counter["count"] = progress_counter["count"] + 1
+                    trial_state = getattr(trial, "state", None)
+                    state_repr = getattr(trial_state, "name", "OK")
+                    print(
+                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                        f"(status={state_repr})."
+                    )
+                # STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
+                self._dist_barrier("trial_end")
+            return result
+
+        storage_url = self._resolve_optuna_storage_url()
+        study_name = self._resolve_optuna_study_name()
+        study_kwargs: Dict[str, Any] = {
+            "direction": "minimize",
+            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
+        }
+        if storage_url:
+            study_kwargs.update(
+                storage=storage_url,
+                study_name=study_name,
+                load_if_exists=True,
+            )
+        study = optuna.create_study(**study_kwargs)
+        self.study_name = getattr(study, "study_name", None)
+
+        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
+            try:
+                best = getattr(check_study, "best_trial", None)
+                if best is None:
+                    return
+                best_params = getattr(best, "params", None)
+                if not best_params:
+                    return
+                params_path = self.output.result_path(
+                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+                )
+                pd.DataFrame(best_params, index=[0]).to_csv(
+                    params_path, index=False)
+            except Exception:
+                return
+
+        completed_states = (
+            optuna.trial.TrialState.COMPLETE,
+            optuna.trial.TrialState.PRUNED,
+            optuna.trial.TrialState.FAIL,
+        )
+        completed = len(study.get_trials(states=completed_states))
+        progress_counter["count"] = completed
+        remaining = max(0, total_trials - completed)
+        try:
+            if remaining > 0:
+                study.optimize(
+                    objective_wrapper,
+                    n_trials=remaining,
+                    callbacks=[checkpoint_callback],
+                )
+            self.best_params = study.best_params
+            self.best_trial = study.best_trial
+            params_path = self.output.result_path(
+                f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+            )
+            pd.DataFrame(self.best_params, index=[0]).to_csv(
+                params_path, index=False)
+        finally:
+            self._distributed_send_command(
+                {"type": "STOP", "best_params": self.best_params})
+
+    def _clean_gpu(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            device = None
+            try:
+                device = getattr(self, "device", None)
+            except Exception:
+                device = None
+            if isinstance(device, torch.device):
+                try:
+                    torch.cuda.set_device(device)
+                except Exception:
+                    pass
+            torch.cuda.empty_cache()
+            do_ipc_collect = os.environ.get("BAYESOPT_CUDA_IPC_COLLECT", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+            do_sync = os.environ.get("BAYESOPT_CUDA_SYNC", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+            if do_ipc_collect:
+                torch.cuda.ipc_collect()
+            if do_sync:
+                torch.cuda.synchronize()
+
+    def _standardize_fold(self,
+                          X_train: pd.DataFrame,
+                          X_val: pd.DataFrame,
+                          columns: Optional[List[str]] = None
+                          ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
+        """Fit StandardScaler on the training fold and transform train/val features.
+
+        Args:
+            X_train: training features.
+            X_val: validation features.
+            columns: columns to scale (default: all).
+
+        Returns:
+            Scaled train/val features and the fitted scaler.
+        """
+        scaler = StandardScaler()
+        cols = list(columns) if columns else list(X_train.columns)
+        X_train_scaled = X_train.copy(deep=True)
+        X_val_scaled = X_val.copy(deep=True)
+        if cols:
+            scaler.fit(X_train_scaled[cols])
+            X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
+            X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
+        return X_train_scaled, X_val_scaled, scaler
+
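A small sketch of the per-fold scaling contract: the scaler is fitted on the training fold only, so validation rows can legitimately map outside the training range (no leakage). `trainer` stands for any TrainerBase instance; the data is illustrative:

    import numpy as np
    import pandas as pd

    X_tr = pd.DataFrame({"a": [0.0, 2.0, 4.0], "b": [1.0, 1.0, 1.0]})
    X_va = pd.DataFrame({"a": [6.0], "b": [1.0]})

    # Only column "a" is scaled; "b" is passed through untouched.
    X_tr_s, X_va_s, scaler = trainer._standardize_fold(X_tr, X_va, columns=["a"])
    assert np.isclose(scaler.mean_[0], 2.0)   # statistics come from X_tr alone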
+    def _resolve_train_val_indices(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        allow_default: bool = False,
+    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
+        val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+        if not (0.0 < val_ratio < 1.0):
+            if not allow_default:
+                return None
+            val_ratio = 0.25
+        if len(X_all) < 10:
+            return None
+
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy in {"time", "timeseries", "temporal"}:
+            time_col = getattr(self.ctx.config, "cv_time_col", None)
+            if not time_col:
+                raise ValueError("cv_time_col is required for time cv_strategy.")
+            if time_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+            index_set = set(X_all.index)
+            order_index = [idx for idx in order_index if idx in index_set]
+            order = X_all.index.get_indexer(order_index)
+            order = order[order >= 0]
+            cutoff = int(len(order) * (1.0 - val_ratio))
+            if cutoff <= 0 or cutoff >= len(order):
+                raise ValueError(
+                    f"prop_test={val_ratio} leaves no data for train/val split.")
+            return order[:cutoff], order[cutoff:]
+
+        if strategy in {"group", "grouped"}:
+            group_col = getattr(self.ctx.config, "cv_group_col", None)
+            if not group_col:
+                raise ValueError("cv_group_col is required for group cv_strategy.")
+            if group_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+            splitter = GroupShuffleSplit(
+                n_splits=1,
+                test_size=val_ratio,
+                random_state=self.ctx.rand_seed,
+            )
+            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
+            return train_idx, val_idx
+
+        splitter = ShuffleSplit(
+            n_splits=1,
+            test_size=val_ratio,
+            random_state=self.ctx.rand_seed,
+        )
+        train_idx, val_idx = next(splitter.split(X_all))
+        return train_idx, val_idx
+
+    def _resolve_time_sample_indices(
+        self,
+        X_all: pd.DataFrame,
+        sample_limit: int,
+    ) -> Optional[pd.Index]:
+        if sample_limit <= 0:
+            return None
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy not in {"time", "timeseries", "temporal"}:
+            return None
+        time_col = getattr(self.ctx.config, "cv_time_col", None)
+        if not time_col:
+            raise ValueError("cv_time_col is required for time cv_strategy.")
+        if time_col not in self.ctx.train_data.columns:
+            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+        ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+        order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+        index_set = set(X_all.index)
+        order_index = [idx for idx in order_index if idx in index_set]
+        if not order_index:
+            return None
+        if len(order_index) > sample_limit:
+            order_index = order_index[-sample_limit:]
+        return pd.Index(order_index)
+
+    def _resolve_ensemble_splits(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        k: int,
+    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+        k = max(2, int(k))
+        n_samples = len(X_all)
+        if n_samples < 2:
+            return None, 0
+
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy in {"group", "grouped"}:
+            group_col = getattr(self.ctx.config, "cv_group_col", None)
+            if not group_col:
+                raise ValueError("cv_group_col is required for group cv_strategy.")
+            if group_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+            n_groups = int(groups.nunique(dropna=False))
+            if n_groups < 2:
+                return None, 0
+            if k > n_groups:
+                k = n_groups
+            if k < 2:
+                return None, 0
+            splitter = GroupKFold(n_splits=k)
+            return splitter.split(X_all, y=None, groups=groups), k
+
+        if strategy in {"time", "timeseries", "temporal"}:
+            time_col = getattr(self.ctx.config, "cv_time_col", None)
+            if not time_col:
+                raise ValueError("cv_time_col is required for time cv_strategy.")
+            if time_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+            index_set = set(X_all.index)
+            order_index = [idx for idx in order_index if idx in index_set]
+            order = X_all.index.get_indexer(order_index)
+            order = order[order >= 0]
+            if len(order) < 2:
+                return None, 0
+            if len(order) <= k:
+                k = max(2, len(order) - 1)
+            if k < 2:
+                return None, 0
+            splitter = TimeSeriesSplit(n_splits=k)
+            return _OrderSplitter(splitter, order).split(X_all), k
+
+        if n_samples < k:
+            k = n_samples
+        if k < 2:
+            return None, 0
+        splitter = KFold(
+            n_splits=k,
+            shuffle=True,
+            random_state=self.ctx.rand_seed,
+        )
+        return splitter.split(X_all), k
+
+    def cross_val_generic(
+            self,
+            trial: optuna.trial.Trial,
+            hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
+            data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
+            model_builder: Callable[[Dict[str, Any]], Any],
+            metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
+            sample_limit: Optional[int] = None,
+            preprocess_fn: Optional[Callable[[
+                pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
+            fit_predict_fn: Optional[
+                Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
+                          pd.DataFrame, pd.Series, Optional[pd.Series],
+                          optuna.trial.Trial], np.ndarray]
+            ] = None,
+            cleanup_fn: Optional[Callable[[Any], None]] = None,
+            splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
+        """Generic holdout/CV helper to reuse tuning workflows.
+
+        Args:
+            trial: current Optuna trial.
+            hyperparameter_space: sampler dict keyed by parameter name.
+            data_provider: callback returning (X, y, sample_weight).
+            model_builder: callback to build a model per fold.
+            metric_fn: loss/score function taking y_true, y_pred, weight.
+            sample_limit: optional sample cap; random sample if exceeded.
+            preprocess_fn: optional per-fold preprocessing (X_train, X_val).
+            fit_predict_fn: optional custom fit/predict logic for validation.
+            cleanup_fn: optional cleanup callback per fold.
+            splitter: optional (train_idx, val_idx) iterator; defaults to cv_strategy config.
+
+        Returns:
+            Mean validation metric across folds.
+        """
+        params: Optional[Dict[str, Any]] = None
+        if self._distributed_forced_params is not None:
+            params = self._distributed_forced_params
+            self._distributed_forced_params = None
+        else:
+            if trial is None:
+                raise RuntimeError(
+                    "Missing Optuna trial for parameter sampling.")
+            params = {name: sampler(trial)
+                      for name, sampler in hyperparameter_space.items()}
+        if self._should_use_distributed_optuna():
+            self._distributed_prepare_trial(params)
+        X_all, y_all, w_all = data_provider()
+        cfg_limit = getattr(self.ctx.config, "bo_sample_limit", None)
+        if cfg_limit is not None:
+            cfg_limit = int(cfg_limit)
+            if cfg_limit > 0:
+                sample_limit = cfg_limit if sample_limit is None else min(sample_limit, cfg_limit)
+        if sample_limit is not None and len(X_all) > sample_limit:
+            sampled_idx = self._resolve_time_sample_indices(X_all, int(sample_limit))
+            if sampled_idx is None:
+                sampled_idx = X_all.sample(
+                    n=sample_limit,
+                    random_state=self.ctx.rand_seed
+                ).index
+            X_all = X_all.loc[sampled_idx]
+            y_all = y_all.loc[sampled_idx]
+            w_all = w_all.loc[sampled_idx] if w_all is not None else None
+
+        if splitter is None:
+            strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+            val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+            if not (0.0 < val_ratio < 1.0):
+                val_ratio = 0.25
+            cv_splits = getattr(self.ctx.config, "cv_splits", None)
+            if cv_splits is None:
+                cv_splits = max(2, int(round(1 / val_ratio)))
+            cv_splits = max(2, int(cv_splits))
+
+            if strategy in {"group", "grouped"}:
+                group_col = getattr(self.ctx.config, "cv_group_col", None)
+                if not group_col:
+                    raise ValueError("cv_group_col is required for group cv_strategy.")
+                if group_col not in self.ctx.train_data.columns:
+                    raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+                groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+                split_iter = GroupKFold(n_splits=cv_splits).split(X_all, y_all, groups=groups)
+            elif strategy in {"time", "timeseries", "temporal"}:
+                time_col = getattr(self.ctx.config, "cv_time_col", None)
+                if not time_col:
+                    raise ValueError("cv_time_col is required for time cv_strategy.")
+                if time_col not in self.ctx.train_data.columns:
+                    raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+                ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+                order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+                index_set = set(X_all.index)
+                order_index = [idx for idx in order_index if idx in index_set]
+                order = X_all.index.get_indexer(order_index)
+                order = order[order >= 0]
+                if len(order) <= cv_splits:
+                    cv_splits = max(2, len(order) - 1)
+                if cv_splits < 2:
+                    raise ValueError("Not enough samples for time-series CV.")
+                split_iter = _OrderSplitter(TimeSeriesSplit(n_splits=cv_splits), order).split(X_all)
+            else:
+                split_iter = ShuffleSplit(
+                    n_splits=cv_splits,
+                    test_size=val_ratio,
+                    random_state=self.ctx.rand_seed
+                ).split(X_all)
+        else:
+            if hasattr(splitter, "split"):
+                split_iter = splitter.split(X_all, y_all, groups=None)
+            else:
+                split_iter = splitter
+
+        losses: List[float] = []
+        for train_idx, val_idx in split_iter:
+            X_train = X_all.iloc[train_idx]
+            y_train = y_all.iloc[train_idx]
+            X_val = X_all.iloc[val_idx]
+            y_val = y_all.iloc[val_idx]
+            w_train = w_all.iloc[train_idx] if w_all is not None else None
+            w_val = w_all.iloc[val_idx] if w_all is not None else None
+
+            if preprocess_fn:
+                X_train, X_val = preprocess_fn(X_train, X_val)
+
+            model = model_builder(params)
+            try:
+                if fit_predict_fn:
+                    y_pred = fit_predict_fn(
+                        model, X_train, y_train, w_train,
+                        X_val, y_val, w_val, trial
+                    )
+                else:
+                    fit_kwargs = {}
+                    if w_train is not None:
+                        fit_kwargs["sample_weight"] = w_train
+                    model.fit(X_train, y_train, **fit_kwargs)
+                    y_pred = model.predict(X_val)
+                losses.append(metric_fn(y_val, y_pred, w_val))
+            finally:
+                if cleanup_fn:
+                    cleanup_fn(model)
+                self._clean_gpu()
+
+        return float(np.mean(losses))
+
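A concrete trainer wires `cross_val_generic` into `tune` roughly as follows. This is a hypothetical subclass sketch: the `Ridge` model and weighted-MSE metric are stand-ins, and the `resp_nme` attribute is assumed alongside the `factor_nmes`/`weight_nme` names used elsewhere in this file:

    import numpy as np
    from sklearn.linear_model import Ridge

    class RidgeTrainer(TrainerBase):
        def cross_val(self, trial):
            return self.cross_val_generic(
                trial,
                hyperparameter_space={
                    "alpha": lambda t: t.suggest_float("alpha", 1e-3, 10.0, log=True),
                },
                data_provider=lambda: (
                    self.ctx.train_data[self.ctx.factor_nmes],
                    self.ctx.train_data[self.ctx.resp_nme],      # assumed name
                    self.ctx.train_data[self.ctx.weight_nme],
                ),
                model_builder=lambda params: Ridge(**params),
                metric_fn=lambda y, y_pred, w: float(
                    np.average((np.asarray(y) - y_pred) ** 2, weights=w)
                ),
                preprocess_fn=lambda X_tr, X_va: self._standardize_fold(X_tr, X_va)[:2],
            )

    # trainer.tune(max_evals=50) then samples "alpha" per trial, runs the
    # configured cv_strategy, and checkpoints best params to CSV as above.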
+    # Prediction + caching logic.
+    def _predict_and_cache(self,
+                           model,
+                           pred_prefix: str,
+                           use_oht: bool = False,
+                           design_fn=None,
+                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
+                           predict_fn: Optional[Callable[..., Any]] = None) -> None:
+        if design_fn:
+            X_train = design_fn(train=True)
+            X_test = design_fn(train=False)
+        elif use_oht:
+            X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
+            X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
+        else:
+            X_train = self.ctx.train_data[self.ctx.factor_nmes]
+            X_test = self.ctx.test_data[self.ctx.factor_nmes]
+
+        predictor = predict_fn or model.predict
+        preds_train = predictor(X_train, **(predict_kwargs_train or {}))
+        preds_test = predictor(X_test, **(predict_kwargs_test or {}))
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        # Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
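The resulting column convention, for example (prefixes and shapes illustrative):

    # predictor returns shape (n,)   with pred_prefix="xgb":
    #   train/test gain "pred_xgb" plus "w_pred_xgb" = pred_xgb * weight_nme
    # predictor returns shape (n, 4) with pred_prefix="gnn_emb":
    #   train/test gain "pred_gnn_emb_0" ... "pred_gnn_emb_3"
    #   (vector outputs get no weighted companion columns)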
+    def _cache_predictions(self,
+                           pred_prefix: str,
+                           preds_train,
+                           preds_test) -> None:
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            if preds_test.ndim > 1:
+                preds_test = preds_test.reshape(-1)
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
+    def _maybe_cache_predictions(self, pred_prefix: str, preds_train, preds_test) -> None:
+        cfg = getattr(self.ctx, "config", None)
+        if cfg is None or not bool(getattr(cfg, "cache_predictions", False)):
+            return
+        fmt = str(getattr(cfg, "prediction_cache_format", "parquet") or "parquet").lower()
+        cache_dir = getattr(cfg, "prediction_cache_dir", None)
+        if cache_dir:
+            target_dir = Path(str(cache_dir))
+            if not target_dir.is_absolute():
+                target_dir = Path(self.output.result_dir) / target_dir
+        else:
+            target_dir = Path(self.output.result_dir) / "predictions"
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        def _build_frame(preds, split_label: str) -> pd.DataFrame:
+            arr = np.asarray(preds)
+            if arr.ndim <= 1:
+                return pd.DataFrame({f"pred_{pred_prefix}": arr.reshape(-1)})
+            cols = [f"pred_{pred_prefix}_{i}" for i in range(arr.shape[1])]
+            return pd.DataFrame(arr, columns=cols)
+
+        for split_label, preds in [("train", preds_train), ("test", preds_test)]:
+            frame = _build_frame(preds, split_label)
+            filename = f"{self.ctx.model_nme}_{pred_prefix}_{split_label}.{'csv' if fmt == 'csv' else 'parquet'}"
+            path = target_dir / filename
+            try:
+                if fmt == "csv":
+                    frame.to_csv(path, index=False)
+                else:
+                    frame.to_parquet(path, index=False)
+            except Exception:
+                pass
+
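Prediction caching is opt-in and driven by three config attributes read above; a hypothetical fragment:

    # Hypothetical config fields consumed by _maybe_cache_predictions:
    config_fragment = {
        "cache_predictions": True,         # default False: skip caching entirely
        "prediction_cache_format": "csv",  # anything but "csv" falls back to parquet
        "prediction_cache_dir": "preds",   # relative paths resolve under result_dir
    }
    # Writes <model_nme>_<prefix>_train.csv and <model_nme>_<prefix>_test.csv;
    # write failures are deliberately swallowed (caching is best-effort).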
+    def _resolve_best_epoch(self,
+                            history: Optional[Dict[str, List[float]]],
+                            default_epochs: int) -> int:
+        if not history:
+            return max(1, int(default_epochs))
+        vals = history.get("val") or []
+        if not vals:
+            return max(1, int(default_epochs))
+        best_idx = int(np.nanargmin(vals))
+        return max(1, best_idx + 1)
+
+    def _fit_predict_cache(self,
+                           model,
+                           X_train,
+                           y_train,
+                           sample_weight,
+                           pred_prefix: str,
+                           use_oht: bool = False,
+                           design_fn=None,
+                           fit_kwargs: Optional[Dict[str, Any]] = None,
+                           sample_weight_arg: Optional[str] = 'sample_weight',
+                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
+                           predict_fn: Optional[Callable[..., Any]] = None,
+                           record_label: bool = True) -> None:
+        fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
+        if sample_weight is not None and sample_weight_arg:
+            fit_kwargs.setdefault(sample_weight_arg, sample_weight)
+        model.fit(X_train, y_train, **fit_kwargs)
+        if record_label:
+            self.ctx.model_label.append(self.label)
+        self._predict_and_cache(
+            model,
+            pred_prefix,
+            use_oht=use_oht,
+            design_fn=design_fn,
+            predict_kwargs_train=predict_kwargs_train,
+            predict_kwargs_test=predict_kwargs_test,
+            predict_fn=predict_fn)
+
+
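Taken together, a concrete trainer only has to supply `train` (plus an objective such as the `cross_val` sketched earlier); persistence, device moves, and prediction caching come from the base class. A minimal hypothetical continuation of the Ridge stand-in (`resp_nme` again assumed):

    from sklearn.linear_model import Ridge

    class RidgeTrainer(TrainerBase):  # continuing the cross_val sketch above
        def __init__(self, context):
            # label "GLM" routes save()/load() through joblib, not torch.save.
            super().__init__(context, label="GLM", model_name_prefix="ridge")

        def train(self) -> None:
            model = Ridge(**(self.best_params or {"alpha": 1.0}))
            self._fit_predict_cache(
                model,
                self.ctx.train_data[self.ctx.factor_nmes],
                self.ctx.train_data[self.ctx.resp_nme],   # assumed name
                self.ctx.train_data[self.ctx.weight_nme],
                pred_prefix="ridge",
            )
            self.model = model  # so save() persists it via joblib

    # Typical driver flow: trainer.tune(50); trainer.train(); trainer.save()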