ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- ins_pricing/README.md +9 -6
- ins_pricing/__init__.py +3 -11
- ins_pricing/cli/BayesOpt_entry.py +24 -0
- ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
- ins_pricing/cli/Explain_Run.py +25 -0
- ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
- ins_pricing/cli/Pricing_Run.py +25 -0
- ins_pricing/cli/__init__.py +1 -0
- ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
- ins_pricing/cli/utils/__init__.py +1 -0
- ins_pricing/cli/utils/cli_common.py +320 -0
- ins_pricing/cli/utils/cli_config.py +375 -0
- ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
- {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
- ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
- ins_pricing/docs/modelling/README.md +34 -0
- ins_pricing/modelling/__init__.py +57 -6
- ins_pricing/modelling/core/__init__.py +1 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
- ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
- ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
- ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
- ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
- ins_pricing/modelling/core/evaluation.py +115 -0
- ins_pricing/production/__init__.py +4 -0
- ins_pricing/production/preprocess.py +71 -0
- ins_pricing/setup.py +10 -5
- {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
- ins_pricing-0.2.0.dist-info/RECORD +125 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
- ins_pricing/modelling/BayesOpt_entry.py +0 -633
- ins_pricing/modelling/Explain_Run.py +0 -36
- ins_pricing/modelling/Pricing_Run.py +0 -36
- ins_pricing/modelling/README.md +0 -33
- ins_pricing/modelling/bayesopt/models.py +0 -2196
- ins_pricing/modelling/bayesopt/trainers.py +0 -2446
- ins_pricing/modelling/cli_common.py +0 -136
- ins_pricing/modelling/tests/test_plotting.py +0 -63
- ins_pricing/modelling/watchdog_run.py +0 -211
- ins_pricing-0.1.11.dist-info/RECORD +0 -169
- ins_pricing_gemini/__init__.py +0 -23
- ins_pricing_gemini/governance/__init__.py +0 -20
- ins_pricing_gemini/governance/approval.py +0 -93
- ins_pricing_gemini/governance/audit.py +0 -37
- ins_pricing_gemini/governance/registry.py +0 -99
- ins_pricing_gemini/governance/release.py +0 -159
- ins_pricing_gemini/modelling/Explain_Run.py +0 -36
- ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
- ins_pricing_gemini/modelling/__init__.py +0 -151
- ins_pricing_gemini/modelling/cli_common.py +0 -141
- ins_pricing_gemini/modelling/config.py +0 -249
- ins_pricing_gemini/modelling/config_preprocess.py +0 -254
- ins_pricing_gemini/modelling/core.py +0 -741
- ins_pricing_gemini/modelling/data_container.py +0 -42
- ins_pricing_gemini/modelling/explain/__init__.py +0 -55
- ins_pricing_gemini/modelling/explain/gradients.py +0 -334
- ins_pricing_gemini/modelling/explain/metrics.py +0 -176
- ins_pricing_gemini/modelling/explain/permutation.py +0 -155
- ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
- ins_pricing_gemini/modelling/features.py +0 -215
- ins_pricing_gemini/modelling/model_manager.py +0 -148
- ins_pricing_gemini/modelling/model_plotting.py +0 -463
- ins_pricing_gemini/modelling/models.py +0 -2203
- ins_pricing_gemini/modelling/notebook_utils.py +0 -294
- ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
- ins_pricing_gemini/modelling/plotting/common.py +0 -63
- ins_pricing_gemini/modelling/plotting/curves.py +0 -572
- ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
- ins_pricing_gemini/modelling/plotting/geo.py +0 -362
- ins_pricing_gemini/modelling/plotting/importance.py +0 -121
- ins_pricing_gemini/modelling/run_logging.py +0 -133
- ins_pricing_gemini/modelling/tests/conftest.py +0 -8
- ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
- ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
- ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
- ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
- ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
- ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
- ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
- ins_pricing_gemini/modelling/trainers.py +0 -2447
- ins_pricing_gemini/modelling/utils.py +0 -1020
- ins_pricing_gemini/pricing/__init__.py +0 -27
- ins_pricing_gemini/pricing/calibration.py +0 -39
- ins_pricing_gemini/pricing/data_quality.py +0 -117
- ins_pricing_gemini/pricing/exposure.py +0 -85
- ins_pricing_gemini/pricing/factors.py +0 -91
- ins_pricing_gemini/pricing/monitoring.py +0 -99
- ins_pricing_gemini/pricing/rate_table.py +0 -78
- ins_pricing_gemini/production/__init__.py +0 -21
- ins_pricing_gemini/production/drift.py +0 -30
- ins_pricing_gemini/production/monitoring.py +0 -143
- ins_pricing_gemini/production/scoring.py +0 -40
- ins_pricing_gemini/reporting/__init__.py +0 -11
- ins_pricing_gemini/reporting/report_builder.py +0 -72
- ins_pricing_gemini/reporting/scheduler.py +0 -45
- ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
- ins_pricing_gemini/scripts/Explain_entry.py +0 -545
- ins_pricing_gemini/scripts/__init__.py +0 -1
- ins_pricing_gemini/scripts/train.py +0 -568
- ins_pricing_gemini/setup.py +0 -55
- ins_pricing_gemini/smoke_test.py +0 -28
- /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
- /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
- /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py (new file)

@@ -0,0 +1,1020 @@
+from __future__ import annotations
+
+from datetime import timedelta
+import gc
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+import joblib
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+try:  # pragma: no cover
+    import torch.distributed as dist  # type: ignore
+except Exception:  # pragma: no cover
+    dist = None  # type: ignore
+from sklearn.model_selection import (
+    GroupKFold,
+    GroupShuffleSplit,
+    KFold,
+    ShuffleSplit,
+    TimeSeriesSplit,
+)
+from sklearn.preprocessing import StandardScaler
+
+from ..config_preprocess import BayesOptConfig, OutputManager
+from ..utils import DistributedUtils, EPS, ensure_parent_dir
+
+class _OrderSplitter:
+    def __init__(self, splitter, order: np.ndarray) -> None:
+        self._splitter = splitter
+        self._order = np.asarray(order)
+
+    def split(self, X, y=None, groups=None):
+        order = self._order
+        X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
+        for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
+            yield order[tr_idx], order[val_idx]
+
+# =============================================================================
+# Trainer system
+# =============================================================================
+
+
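Note (not part of the diff): `_OrderSplitter` exists because `TimeSeriesSplit` assumes rows are already in chronological order; the wrapper splits a time-sorted view and maps each fold's indices back to the original positions. A minimal sketch with illustrative data:

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import TimeSeriesSplit

    df = pd.DataFrame({"t": [3, 1, 2, 5, 4], "x": range(5)})   # rows out of time order
    order = np.argsort(df["t"].to_numpy())                     # positions, oldest first

    for tr_idx, val_idx in _OrderSplitter(TimeSeriesSplit(n_splits=2), order).split(df):
        # Indices point into the original frame, yet train strictly precedes val in time.
        assert df["t"].iloc[tr_idx].max() < df["t"].iloc[val_idx].min()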
+class TrainerBase:
+    def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
+        self.ctx = context
+        self.label = label
+        self.model_name_prefix = model_name_prefix
+        self.model = None
+        self.best_params: Optional[Dict[str, Any]] = None
+        self.best_trial = None
+        self.study_name: Optional[str] = None
+        self.enable_distributed_optuna: bool = False
+        self._distributed_forced_params: Optional[Dict[str, Any]] = None
+
+    def _dist_barrier(self, reason: str) -> None:
+        """DDP barrier wrapper used by distributed Optuna.
+
+        To debug "trial finished but next trial never starts" hangs, set these
+        environment variables (either in shell or config.json `env`):
+        - `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
+        - `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
+        - `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
+        """
+        if dist is None:
+            return
+        try:
+            if not getattr(dist, "is_available", lambda: False)():
+                return
+            if not dist.is_initialized():
+                return
+        except Exception:
+            return
+
+        timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
+        debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+        rank = None
+        world = None
+        if debug_barrier:
+            try:
+                rank = dist.get_rank()
+                world = dist.get_world_size()
+                print(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
+            except Exception:
+                debug_barrier = False
+        try:
+            timeout = timedelta(seconds=timeout_seconds)
+            backend = None
+            try:
+                backend = dist.get_backend()
+            except Exception:
+                backend = None
+
+            # `monitored_barrier` is only implemented for GLOO; using it under NCCL
+            # will raise and can itself trigger a secondary hang. Prefer an async
+            # barrier with timeout for NCCL.
+            monitored = getattr(dist, "monitored_barrier", None)
+            if backend == "gloo" and callable(monitored):
+                monitored(timeout=timeout)
+            else:
+                work = None
+                try:
+                    work = dist.barrier(async_op=True)
+                except TypeError:
+                    work = None
+                if work is not None:
+                    wait = getattr(work, "wait", None)
+                    if callable(wait):
+                        try:
+                            wait(timeout=timeout)
+                        except TypeError:
+                            wait()
+                    else:
+                        dist.barrier()
+                else:
+                    dist.barrier()
+            if debug_barrier:
+                print(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
+        except Exception as exc:
+            print(
+                f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
+                flush=True,
+            )
+            raise
+
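The debug switches named in the docstring are plain environment variables, so they can be exported before launch or set at the top of an entry script; a sketch (values illustrative):

    import os

    os.environ["BAYESOPT_DDP_BARRIER_TIMEOUT"] = "300"   # fail fast instead of hanging
    os.environ["BAYESOPT_DDP_BARRIER_DEBUG"] = "1"       # trace barrier enter/exit per rank
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"     # extra PyTorch diagnostics
    os.environ["NCCL_DEBUG"] = "INFO"                    # NCCL-level logs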
+    @property
+    def config(self) -> BayesOptConfig:
+        return self.ctx.config
+
+    @property
+    def output(self) -> OutputManager:
+        return self.ctx.output_manager
+
+    def _get_model_filename(self) -> str:
+        ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
+        return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
+
+    def _resolve_optuna_storage_url(self) -> Optional[str]:
+        storage = getattr(self.config, "optuna_storage", None)
+        if not storage:
+            return None
+        storage_str = str(storage).strip()
+        if not storage_str:
+            return None
+        if "://" in storage_str or storage_str == ":memory:":
+            return storage_str
+        path = Path(storage_str)
+        path = path.resolve()
+        ensure_parent_dir(str(path))
+        return f"sqlite:///{path.as_posix()}"
+
+    def _resolve_optuna_study_name(self) -> str:
+        prefix = getattr(self.config, "optuna_study_prefix",
+                         None) or "bayesopt"
+        raw = f"{prefix}_{self.ctx.model_nme}_{self.model_name_prefix}"
+        safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
+        return safe.lower()
+
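To summarize the two resolvers above (paths illustrative): anything containing `://` or equal to `:memory:` is handed to Optuna unchanged, a bare filesystem path becomes an absolute SQLite URL, and study names are sanitized to alphanumerics plus `._-` and lower-cased:

    # "postgresql://user@host/db"  ->  unchanged
    # ":memory:"                   ->  unchanged
    # "studies/bo.db"              ->  "sqlite:///<abs path>/studies/bo.db"
    raw = "bayesopt_MyModel_FT v2"            # illustrative raw study name
    safe = "".join(c if c.isalnum() or c in "._-" else "_" for c in raw).lower()
    assert safe == "bayesopt_mymodel_ft_v2"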
+    def tune(self, max_evals: int, objective_fn=None) -> None:
+        # Generic Optuna tuning loop.
+        if objective_fn is None:
+            # If subclass doesn't provide objective_fn, default to cross_val.
+            objective_fn = self.cross_val
+
+        if self._should_use_distributed_optuna():
+            self._distributed_tune(max_evals, objective_fn)
+            return
+
+        total_trials = max(1, int(max_evals))
+        progress_counter = {"count": 0}
+
+        def objective_wrapper(trial: optuna.trial.Trial) -> float:
+            should_log = DistributedUtils.is_main_process()
+            if should_log:
+                current_idx = progress_counter["count"] + 1
+                print(
+                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                    f"(trial_id={trial.number})."
+                )
+            try:
+                result = objective_fn(trial)
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    print(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu()
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu()
+                if should_log:
+                    progress_counter["count"] = progress_counter["count"] + 1
+                    trial_state = getattr(trial, "state", None)
+                    state_repr = getattr(trial_state, "name", "OK")
+                    print(
+                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                        f"(status={state_repr})."
+                    )
+            return result
+
+        storage_url = self._resolve_optuna_storage_url()
+        study_name = self._resolve_optuna_study_name()
+        study_kwargs: Dict[str, Any] = {
+            "direction": "minimize",
+            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
+        }
+        if storage_url:
+            study_kwargs.update(
+                storage=storage_url,
+                study_name=study_name,
+                load_if_exists=True,
+            )
+
+        study = optuna.create_study(**study_kwargs)
+        self.study_name = getattr(study, "study_name", None)
+
+        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
+            # Persist best_params after each trial to allow safe resume.
+            try:
+                best = getattr(check_study, "best_trial", None)
+                if best is None:
+                    return
+                best_params = getattr(best, "params", None)
+                if not best_params:
+                    return
+                params_path = self.output.result_path(
+                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+                )
+                pd.DataFrame(best_params, index=[0]).to_csv(
+                    params_path, index=False)
+            except Exception:
+                return
+
+        completed_states = (
+            optuna.trial.TrialState.COMPLETE,
+            optuna.trial.TrialState.PRUNED,
+            optuna.trial.TrialState.FAIL,
+        )
+        completed = len(study.get_trials(states=completed_states))
+        progress_counter["count"] = completed
+        remaining = max(0, total_trials - completed)
+        if remaining > 0:
+            study.optimize(
+                objective_wrapper,
+                n_trials=remaining,
+                callbacks=[checkpoint_callback],
+            )
+        self.best_params = study.best_params
+        self.best_trial = study.best_trial
+
+        # Save best params to CSV for reproducibility.
+        params_path = self.output.result_path(
+            f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+        )
+        pd.DataFrame(self.best_params, index=[0]).to_csv(
+            params_path, index=False)
+
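Because the study lives in persistent storage with `load_if_exists=True` and only `remaining` trials are requested, an interrupted run resumes where it stopped when relaunched with the same settings. A self-contained sketch of that resume pattern (study name, path, and objective are illustrative):

    import optuna

    def run_or_resume(n_total: int) -> optuna.study.Study:
        study = optuna.create_study(
            direction="minimize",
            storage="sqlite:///bo_demo.db",   # illustrative path
            study_name="bayesopt_demo",
            load_if_exists=True,              # reattach after a crash
        )
        done = len(study.get_trials(states=(
            optuna.trial.TrialState.COMPLETE,
            optuna.trial.TrialState.PRUNED,
            optuna.trial.TrialState.FAIL,
        )))
        remaining = max(0, n_total - done)
        if remaining > 0:
            study.optimize(lambda t: t.suggest_float("x", -1.0, 1.0) ** 2,
                           n_trials=remaining)
        return study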
+    def train(self) -> None:
+        raise NotImplementedError
+
+    def save(self) -> None:
+        if self.model is None:
+            print(f"[save] Warning: No model to save for {self.label}")
+            return
+
+        path = self.output.model_path(self._get_model_filename())
+        if self.label in ['Xgboost', 'GLM']:
+            joblib.dump(self.model, path)
+        else:
+            # PyTorch models can save state_dict or the full object.
+            # Legacy behavior: ResNetTrainer saves state_dict; FTTrainer saves full object.
+            if hasattr(self.model, 'resnet'):  # ResNetSklearn model
+                torch.save(self.model.resnet.state_dict(), path)
+            else:  # FTTransformerSklearn or other PyTorch model
+                torch.save(self.model, path)
+
+    def load(self) -> None:
+        path = self.output.model_path(self._get_model_filename())
+        if not os.path.exists(path):
+            print(f"[load] Warning: Model file not found: {path}")
+            return
+
+        if self.label in ['Xgboost', 'GLM']:
+            self.model = joblib.load(path)
+        else:
+            # PyTorch loading depends on the model structure.
+            if self.label == 'ResNet' or self.label == 'ResNetClassifier':
+                # ResNet requires reconstructing the skeleton; handled by subclass.
+                pass
+            else:
+                # FT-Transformer serializes the whole object; load then move to device.
+                loaded = torch.load(path, map_location='cpu')
+                self._move_to_device(loaded)
+                self.model = loaded
+
+    def _move_to_device(self, model_obj):
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if hasattr(model_obj, 'device'):
+            model_obj.device = device
+        if hasattr(model_obj, 'to'):
+            model_obj.to(device)
+        # Move nested submodules (ft/resnet/gnn) to the same device.
+        if hasattr(model_obj, 'ft'):
+            model_obj.ft.to(device)
+        if hasattr(model_obj, 'resnet'):
+            model_obj.resnet.to(device)
+        if hasattr(model_obj, 'gnn'):
+            model_obj.gnn.to(device)
+
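Serialization thus follows the trainer label: sklearn-style estimators go through joblib, while torch models are saved either as a bare `state_dict` (ResNet, rebuilt by its subclass on load) or as a whole pickled module (FT-Transformer and friends, reloaded onto CPU and then moved to the right device). A tiny standalone illustration of the two round-trips (stand-in objects, not the package's models):

    import io
    import joblib
    import torch

    buf = io.BytesIO()
    joblib.dump({"demo": "estimator"}, buf)              # stand-in for an sklearn model
    buf.seek(0)
    restored = joblib.load(buf)

    state_buf = io.BytesIO()
    torch.save(torch.nn.Linear(4, 1).state_dict(), state_buf)   # state_dict route
    state_buf.seek(0)
    state = torch.load(state_buf, map_location="cpu")           # load on CPU first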
+    def _should_use_distributed_optuna(self) -> bool:
+        if not self.enable_distributed_optuna:
+            return False
+        rank_env = os.environ.get("RANK")
+        world_env = os.environ.get("WORLD_SIZE")
+        local_env = os.environ.get("LOCAL_RANK")
+        if rank_env is None or world_env is None or local_env is None:
+            return False
+        try:
+            world_size = int(world_env)
+        except Exception:
+            return False
+        return world_size > 1
+
+    def _distributed_is_main(self) -> bool:
+        return DistributedUtils.is_main_process()
+
+    def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna() or not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            return
+        message = [payload]
+        dist.broadcast_object_list(message, src=0)
+
+    def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna():
+            return
+        if not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        self._distributed_send_command({"type": "RUN", "params": params})
+        if not dist.is_initialized():
+            return
+        # STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
+        self._dist_barrier("prepare_trial")
+
+    def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
+        if dist is None:
+            print(
+                f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
+                flush=True,
+            )
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            print(
+                f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
+                flush=True,
+            )
+            return
+        while True:
+            message = [None]
+            dist.broadcast_object_list(message, src=0)
+            payload = message[0]
+            if not isinstance(payload, dict):
+                continue
+            cmd = payload.get("type")
+            if cmd == "STOP":
+                best_params = payload.get("best_params")
+                if best_params is not None:
+                    self.best_params = best_params
+                break
+            if cmd == "RUN":
+                params = payload.get("params") or {}
+                self._distributed_forced_params = params
+                # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
+                self._dist_barrier("worker_start")
+                try:
+                    objective_fn(None)
+                except optuna.TrialPruned:
+                    pass
+                except Exception as exc:
+                    print(
+                        f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
+                finally:
+                    self._clean_gpu()
+                    # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
+                    self._dist_barrier("worker_end")
+
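Taken together, `_distributed_send_command`, `_distributed_prepare_trial`, and `_distributed_worker_loop` form a small rank-0-driven protocol over `dist.broadcast_object_list`: one `{"type": "RUN", "params": ...}` message per trial, one `{"type": "STOP", "best_params": ...}` message at shutdown, and matching barriers on both sides. A schematic of one trial (comments only; a real run needs an initialized process group):

    # rank 0 (study owner)                      ranks 1..N-1 (workers)
    # dist.broadcast_object_list(               message = [None]
    #     [{"type": "RUN", "params": p}],       dist.broadcast_object_list(message, src=0)
    #     src=0)                                params = message[0]["params"]
    # _dist_barrier("prepare_trial")            _dist_barrier("worker_start")
    # objective_fn(trial)   # joint DDP fit     objective_fn(None)   # same DDP fit
    # _dist_barrier("trial_end")                _dist_barrier("worker_end")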
+    def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
+        if dist is None:
+            print(
+                f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            rank_env = os.environ.get("RANK", "0")
+            if str(rank_env) != "0":
+                print(
+                    f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
+                    flush=True,
+                )
+                return
+            print(
+                f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        if not self._distributed_is_main():
+            self._distributed_worker_loop(objective_fn)
+            return
+
+        total_trials = max(1, int(max_evals))
+        progress_counter = {"count": 0}
+
+        def objective_wrapper(trial: optuna.trial.Trial) -> float:
+            should_log = True
+            if should_log:
+                current_idx = progress_counter["count"] + 1
+                print(
+                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                    f"(trial_id={trial.number})."
+                )
+            try:
+                result = objective_fn(trial)
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    print(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu()
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu()
+                if should_log:
+                    progress_counter["count"] = progress_counter["count"] + 1
+                    trial_state = getattr(trial, "state", None)
+                    state_repr = getattr(trial_state, "name", "OK")
+                    print(
+                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                        f"(status={state_repr})."
+                    )
+                # STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
+                self._dist_barrier("trial_end")
+            return result
+
+        storage_url = self._resolve_optuna_storage_url()
+        study_name = self._resolve_optuna_study_name()
+        study_kwargs: Dict[str, Any] = {
+            "direction": "minimize",
+            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
+        }
+        if storage_url:
+            study_kwargs.update(
+                storage=storage_url,
+                study_name=study_name,
+                load_if_exists=True,
+            )
+        study = optuna.create_study(**study_kwargs)
+        self.study_name = getattr(study, "study_name", None)
+
+        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
+            try:
+                best = getattr(check_study, "best_trial", None)
+                if best is None:
+                    return
+                best_params = getattr(best, "params", None)
+                if not best_params:
+                    return
+                params_path = self.output.result_path(
+                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+                )
+                pd.DataFrame(best_params, index=[0]).to_csv(
+                    params_path, index=False)
+            except Exception:
+                return
+
+        completed_states = (
+            optuna.trial.TrialState.COMPLETE,
+            optuna.trial.TrialState.PRUNED,
+            optuna.trial.TrialState.FAIL,
+        )
+        completed = len(study.get_trials(states=completed_states))
+        progress_counter["count"] = completed
+        remaining = max(0, total_trials - completed)
+        try:
+            if remaining > 0:
+                study.optimize(
+                    objective_wrapper,
+                    n_trials=remaining,
+                    callbacks=[checkpoint_callback],
+                )
+            self.best_params = study.best_params
+            self.best_trial = study.best_trial
+            params_path = self.output.result_path(
+                f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+            )
+            pd.DataFrame(self.best_params, index=[0]).to_csv(
+                params_path, index=False)
+        finally:
+            self._distributed_send_command(
+                {"type": "STOP", "best_params": self.best_params})
+
+    def _clean_gpu(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            device = None
+            try:
+                device = getattr(self, "device", None)
+            except Exception:
+                device = None
+            if isinstance(device, torch.device):
+                try:
+                    torch.cuda.set_device(device)
+                except Exception:
+                    pass
+            torch.cuda.empty_cache()
+            do_ipc_collect = os.environ.get("BAYESOPT_CUDA_IPC_COLLECT", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+            do_sync = os.environ.get("BAYESOPT_CUDA_SYNC", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+            if do_ipc_collect:
+                torch.cuda.ipc_collect()
+            if do_sync:
+                torch.cuda.synchronize()
+
+    def _standardize_fold(self,
+                          X_train: pd.DataFrame,
+                          X_val: pd.DataFrame,
+                          columns: Optional[List[str]] = None
+                          ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
+        """Fit StandardScaler on the training fold and transform train/val features.
+
+        Args:
+            X_train: training features.
+            X_val: validation features.
+            columns: columns to scale (default: all).
+
+        Returns:
+            Scaled train/val features and the fitted scaler.
+        """
+        scaler = StandardScaler()
+        cols = list(columns) if columns else list(X_train.columns)
+        X_train_scaled = X_train.copy(deep=True)
+        X_val_scaled = X_val.copy(deep=True)
+        if cols:
+            scaler.fit(X_train_scaled[cols])
+            X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
+            X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
+        return X_train_scaled, X_val_scaled, scaler
+
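`_standardize_fold` fits the scaler on the training fold only, so validation rows never leak into the scaling statistics. The same pattern in a standalone sketch with illustrative data:

    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    X_train = pd.DataFrame({"a": [0.0, 1.0, 2.0], "b": [10.0, 20.0, 30.0]})
    X_val = pd.DataFrame({"a": [3.0], "b": [40.0]})

    scaler = StandardScaler().fit(X_train)   # statistics from the train fold only
    X_train_scl = pd.DataFrame(scaler.transform(X_train), columns=X_train.columns)
    X_val_scl = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)
    # Validation values may fall outside the train range; that is expected.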
+    def _resolve_train_val_indices(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        allow_default: bool = False,
+    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
+        val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+        if not (0.0 < val_ratio < 1.0):
+            if not allow_default:
+                return None
+            val_ratio = 0.25
+        if len(X_all) < 10:
+            return None
+
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy in {"time", "timeseries", "temporal"}:
+            time_col = getattr(self.ctx.config, "cv_time_col", None)
+            if not time_col:
+                raise ValueError("cv_time_col is required for time cv_strategy.")
+            if time_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+            index_set = set(X_all.index)
+            order_index = [idx for idx in order_index if idx in index_set]
+            order = X_all.index.get_indexer(order_index)
+            order = order[order >= 0]
+            cutoff = int(len(order) * (1.0 - val_ratio))
+            if cutoff <= 0 or cutoff >= len(order):
+                raise ValueError(
+                    f"prop_test={val_ratio} leaves no data for train/val split.")
+            return order[:cutoff], order[cutoff:]
+
+        if strategy in {"group", "grouped"}:
+            group_col = getattr(self.ctx.config, "cv_group_col", None)
+            if not group_col:
+                raise ValueError("cv_group_col is required for group cv_strategy.")
+            if group_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+            splitter = GroupShuffleSplit(
+                n_splits=1,
+                test_size=val_ratio,
+                random_state=self.ctx.rand_seed,
+            )
+            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
+            return train_idx, val_idx
+
+        splitter = ShuffleSplit(
+            n_splits=1,
+            test_size=val_ratio,
+            random_state=self.ctx.rand_seed,
+        )
+        train_idx, val_idx = next(splitter.split(X_all))
+        return train_idx, val_idx
+
+    def _resolve_time_sample_indices(
+        self,
+        X_all: pd.DataFrame,
+        sample_limit: int,
+    ) -> Optional[pd.Index]:
+        if sample_limit <= 0:
+            return None
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy not in {"time", "timeseries", "temporal"}:
+            return None
+        time_col = getattr(self.ctx.config, "cv_time_col", None)
+        if not time_col:
+            raise ValueError("cv_time_col is required for time cv_strategy.")
+        if time_col not in self.ctx.train_data.columns:
+            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+        ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+        order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+        index_set = set(X_all.index)
+        order_index = [idx for idx in order_index if idx in index_set]
+        if not order_index:
+            return None
+        if len(order_index) > sample_limit:
+            order_index = order_index[-sample_limit:]
+        return pd.Index(order_index)
+
+    def _resolve_ensemble_splits(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        k: int,
+    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+        k = max(2, int(k))
+        n_samples = len(X_all)
+        if n_samples < 2:
+            return None, 0
+
+        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+        if strategy in {"group", "grouped"}:
+            group_col = getattr(self.ctx.config, "cv_group_col", None)
+            if not group_col:
+                raise ValueError("cv_group_col is required for group cv_strategy.")
+            if group_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+            n_groups = int(groups.nunique(dropna=False))
+            if n_groups < 2:
+                return None, 0
+            if k > n_groups:
+                k = n_groups
+            if k < 2:
+                return None, 0
+            splitter = GroupKFold(n_splits=k)
+            return splitter.split(X_all, y=None, groups=groups), k
+
+        if strategy in {"time", "timeseries", "temporal"}:
+            time_col = getattr(self.ctx.config, "cv_time_col", None)
+            if not time_col:
+                raise ValueError("cv_time_col is required for time cv_strategy.")
+            if time_col not in self.ctx.train_data.columns:
+                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+            index_set = set(X_all.index)
+            order_index = [idx for idx in order_index if idx in index_set]
+            order = X_all.index.get_indexer(order_index)
+            order = order[order >= 0]
+            if len(order) < 2:
+                return None, 0
+            if len(order) <= k:
+                k = max(2, len(order) - 1)
+            if k < 2:
+                return None, 0
+            splitter = TimeSeriesSplit(n_splits=k)
+            return _OrderSplitter(splitter, order).split(X_all), k
+
+        if n_samples < k:
+            k = n_samples
+        if k < 2:
+            return None, 0
+        splitter = KFold(
+            n_splits=k,
+            shuffle=True,
+            random_state=self.ctx.rand_seed,
+        )
+        return splitter.split(X_all), k
+
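All three resolvers above key off the same configuration fields. A sketch of the shapes they expect (the column names are illustrative):

    # cv_strategy: "random" (default) | "group"/"grouped" | "time"/"timeseries"/"temporal"
    cv_config_examples = [
        {"cv_strategy": "random"},
        {"cv_strategy": "group", "cv_group_col": "policy_id"},
        {"cv_strategy": "time", "cv_time_col": "uw_month", "cv_time_ascending": True},
    ]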
+    def cross_val_generic(
+            self,
+            trial: optuna.trial.Trial,
+            hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
+            data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
+            model_builder: Callable[[Dict[str, Any]], Any],
+            metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
+            sample_limit: Optional[int] = None,
+            preprocess_fn: Optional[Callable[[
+                pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
+            fit_predict_fn: Optional[
+                Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
+                          pd.DataFrame, pd.Series, Optional[pd.Series],
+                          optuna.trial.Trial], np.ndarray]
+            ] = None,
+            cleanup_fn: Optional[Callable[[Any], None]] = None,
+            splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
+        """Generic holdout/CV helper to reuse tuning workflows.
+
+        Args:
+            trial: current Optuna trial.
+            hyperparameter_space: sampler dict keyed by parameter name.
+            data_provider: callback returning (X, y, sample_weight).
+            model_builder: callback to build a model per fold.
+            metric_fn: loss/score function taking y_true, y_pred, weight.
+            sample_limit: optional sample cap; random sample if exceeded.
+            preprocess_fn: optional per-fold preprocessing (X_train, X_val).
+            fit_predict_fn: optional custom fit/predict logic for validation.
+            cleanup_fn: optional cleanup callback per fold.
+            splitter: optional (train_idx, val_idx) iterator; defaults to cv_strategy config.
+
+        Returns:
+            Mean validation metric across folds.
+        """
+        params: Optional[Dict[str, Any]] = None
+        if self._distributed_forced_params is not None:
+            params = self._distributed_forced_params
+            self._distributed_forced_params = None
+        else:
+            if trial is None:
+                raise RuntimeError(
+                    "Missing Optuna trial for parameter sampling.")
+            params = {name: sampler(trial)
+                      for name, sampler in hyperparameter_space.items()}
+            if self._should_use_distributed_optuna():
+                self._distributed_prepare_trial(params)
+        X_all, y_all, w_all = data_provider()
+        cfg_limit = getattr(self.ctx.config, "bo_sample_limit", None)
+        if cfg_limit is not None:
+            cfg_limit = int(cfg_limit)
+            if cfg_limit > 0:
+                sample_limit = cfg_limit if sample_limit is None else min(sample_limit, cfg_limit)
+        if sample_limit is not None and len(X_all) > sample_limit:
+            sampled_idx = self._resolve_time_sample_indices(X_all, int(sample_limit))
+            if sampled_idx is None:
+                sampled_idx = X_all.sample(
+                    n=sample_limit,
+                    random_state=self.ctx.rand_seed
+                ).index
+            X_all = X_all.loc[sampled_idx]
+            y_all = y_all.loc[sampled_idx]
+            w_all = w_all.loc[sampled_idx] if w_all is not None else None
+
+        if splitter is None:
+            strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
+            val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+            if not (0.0 < val_ratio < 1.0):
+                val_ratio = 0.25
+            cv_splits = getattr(self.ctx.config, "cv_splits", None)
+            if cv_splits is None:
+                cv_splits = max(2, int(round(1 / val_ratio)))
+            cv_splits = max(2, int(cv_splits))
+
+            if strategy in {"group", "grouped"}:
+                group_col = getattr(self.ctx.config, "cv_group_col", None)
+                if not group_col:
+                    raise ValueError("cv_group_col is required for group cv_strategy.")
+                if group_col not in self.ctx.train_data.columns:
+                    raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+                groups = self.ctx.train_data.reindex(X_all.index)[group_col]
+                split_iter = GroupKFold(n_splits=cv_splits).split(X_all, y_all, groups=groups)
+            elif strategy in {"time", "timeseries", "temporal"}:
+                time_col = getattr(self.ctx.config, "cv_time_col", None)
+                if not time_col:
+                    raise ValueError("cv_time_col is required for time cv_strategy.")
+                if time_col not in self.ctx.train_data.columns:
+                    raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+                ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
+                order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
+                index_set = set(X_all.index)
+                order_index = [idx for idx in order_index if idx in index_set]
+                order = X_all.index.get_indexer(order_index)
+                order = order[order >= 0]
+                if len(order) <= cv_splits:
+                    cv_splits = max(2, len(order) - 1)
+                if cv_splits < 2:
+                    raise ValueError("Not enough samples for time-series CV.")
+                split_iter = _OrderSplitter(TimeSeriesSplit(n_splits=cv_splits), order).split(X_all)
+            else:
+                split_iter = ShuffleSplit(
+                    n_splits=cv_splits,
+                    test_size=val_ratio,
+                    random_state=self.ctx.rand_seed
+                ).split(X_all)
+        else:
+            if hasattr(splitter, "split"):
+                split_iter = splitter.split(X_all, y_all, groups=None)
+            else:
+                split_iter = splitter
+
+        losses: List[float] = []
+        for train_idx, val_idx in split_iter:
+            X_train = X_all.iloc[train_idx]
+            y_train = y_all.iloc[train_idx]
+            X_val = X_all.iloc[val_idx]
+            y_val = y_all.iloc[val_idx]
+            w_train = w_all.iloc[train_idx] if w_all is not None else None
+            w_val = w_all.iloc[val_idx] if w_all is not None else None
+
+            if preprocess_fn:
+                X_train, X_val = preprocess_fn(X_train, X_val)
+
+            model = model_builder(params)
+            try:
+                if fit_predict_fn:
+                    y_pred = fit_predict_fn(
+                        model, X_train, y_train, w_train,
+                        X_val, y_val, w_val, trial
+                    )
+                else:
+                    fit_kwargs = {}
+                    if w_train is not None:
+                        fit_kwargs["sample_weight"] = w_train
+                    model.fit(X_train, y_train, **fit_kwargs)
+                    y_pred = model.predict(X_val)
+                losses.append(metric_fn(y_val, y_pred, w_val))
+            finally:
+                if cleanup_fn:
+                    cleanup_fn(model)
+                self._clean_gpu()
+
+        return float(np.mean(losses))
+
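A hedged sketch of how a subclass objective might wire into `cross_val_generic` (the Ridge estimator, the search range, and the `resp_nme` attribute are illustrative, not the package's actual trainers):

    from sklearn.linear_model import Ridge
    from sklearn.metrics import mean_squared_error

    def cross_val(self, trial):
        return self.cross_val_generic(
            trial,
            hyperparameter_space={
                "alpha": lambda t: t.suggest_float("alpha", 1e-3, 10.0, log=True),
            },
            data_provider=lambda: (
                self.ctx.train_data[self.ctx.factor_nmes],
                self.ctx.train_data[self.ctx.resp_nme],   # hypothetical target column
                None,                                     # no sample weights
            ),
            model_builder=lambda params: Ridge(alpha=params["alpha"]),
            metric_fn=lambda y, y_pred, w: mean_squared_error(y, y_pred, sample_weight=w),
        )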
+    # Prediction + caching logic.
+    def _predict_and_cache(self,
+                           model,
+                           pred_prefix: str,
+                           use_oht: bool = False,
+                           design_fn=None,
+                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
+                           predict_fn: Optional[Callable[..., Any]] = None) -> None:
+        if design_fn:
+            X_train = design_fn(train=True)
+            X_test = design_fn(train=False)
+        elif use_oht:
+            X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
+            X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
+        else:
+            X_train = self.ctx.train_data[self.ctx.factor_nmes]
+            X_test = self.ctx.test_data[self.ctx.factor_nmes]
+
+        predictor = predict_fn or model.predict
+        preds_train = predictor(X_train, **(predict_kwargs_train or {}))
+        preds_test = predictor(X_test, **(predict_kwargs_test or {}))
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        # Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
+    def _cache_predictions(self,
+                           pred_prefix: str,
+                           preds_train,
+                           preds_test) -> None:
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            if preds_test.ndim > 1:
+                preds_test = preds_test.reshape(-1)
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
+    def _maybe_cache_predictions(self, pred_prefix: str, preds_train, preds_test) -> None:
+        cfg = getattr(self.ctx, "config", None)
+        if cfg is None or not bool(getattr(cfg, "cache_predictions", False)):
+            return
+        fmt = str(getattr(cfg, "prediction_cache_format", "parquet") or "parquet").lower()
+        cache_dir = getattr(cfg, "prediction_cache_dir", None)
+        if cache_dir:
+            target_dir = Path(str(cache_dir))
+            if not target_dir.is_absolute():
+                target_dir = Path(self.output.result_dir) / target_dir
+        else:
+            target_dir = Path(self.output.result_dir) / "predictions"
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        def _build_frame(preds, split_label: str) -> pd.DataFrame:
+            arr = np.asarray(preds)
+            if arr.ndim <= 1:
+                return pd.DataFrame({f"pred_{pred_prefix}": arr.reshape(-1)})
+            cols = [f"pred_{pred_prefix}_{i}" for i in range(arr.shape[1])]
+            return pd.DataFrame(arr, columns=cols)
+
+        for split_label, preds in [("train", preds_train), ("test", preds_test)]:
+            frame = _build_frame(preds, split_label)
+            filename = f"{self.ctx.model_nme}_{pred_prefix}_{split_label}.{'csv' if fmt == 'csv' else 'parquet'}"
+            path = target_dir / filename
+            try:
+                if fmt == "csv":
+                    frame.to_csv(path, index=False)
+                else:
+                    frame.to_parquet(path, index=False)
+            except Exception:
+                pass
+
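Prediction caching is opt-in: `cache_predictions` enables it, `prediction_cache_format` chooses csv or parquet, and `prediction_cache_dir` picks the target folder, with relative paths resolving under the result directory and a default of `predictions/`. An illustrative config fragment:

    prediction_cache_config = {
        "cache_predictions": True,
        "prediction_cache_format": "parquet",   # or "csv"
        "prediction_cache_dir": "predictions",  # relative -> under result_dir
    }
    # Files land as <model_nme>_<pred_prefix>_train.parquet / ..._test.parquet.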
+    def _resolve_best_epoch(self,
+                            history: Optional[Dict[str, List[float]]],
+                            default_epochs: int) -> int:
+        if not history:
+            return max(1, int(default_epochs))
+        vals = history.get("val") or []
+        if not vals:
+            return max(1, int(default_epochs))
+        best_idx = int(np.nanargmin(vals))
+        return max(1, best_idx + 1)
+
+    def _fit_predict_cache(self,
+                           model,
+                           X_train,
+                           y_train,
+                           sample_weight,
+                           pred_prefix: str,
+                           use_oht: bool = False,
+                           design_fn=None,
+                           fit_kwargs: Optional[Dict[str, Any]] = None,
+                           sample_weight_arg: Optional[str] = 'sample_weight',
+                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
+                           predict_fn: Optional[Callable[..., Any]] = None,
+                           record_label: bool = True) -> None:
+        fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
+        if sample_weight is not None and sample_weight_arg:
+            fit_kwargs.setdefault(sample_weight_arg, sample_weight)
+        model.fit(X_train, y_train, **fit_kwargs)
+        if record_label:
+            self.ctx.model_label.append(self.label)
+        self._predict_and_cache(
+            model,
+            pred_prefix,
+            use_oht=use_oht,
+            design_fn=design_fn,
+            predict_kwargs_train=predict_kwargs_train,
+            predict_kwargs_test=predict_kwargs_test,
+            predict_fn=predict_fn)
+
+