ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
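The dominant change in this release is visible in the rename column above: `ins_pricing/modelling/core/bayesopt/` is flattened to `ins_pricing/modelling/bayesopt/`, `modelling/core/evaluation.py` moves up to `modelling/evaluation.py`, and `production/predict.py` becomes `production/inference.py` (with `test_predict.py` superseded by `test_inference.py`). A minimal migration sketch for downstream imports, assuming no compatibility shims beyond what the diff below shows; the module paths come from the rename list above, and no symbols beyond `BayesOptConfig` are assumed:

    # 0.4.5 import paths (removed in 0.5.1):
    from ins_pricing.modelling.core.bayesopt.config_preprocess import BayesOptConfig
    from ins_pricing.production import predict

    # 0.5.1 equivalents, per the renames listed above:
    from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig
    from ins_pricing.production import inference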
@@ -1,438 +1,442 @@
- from __future__ import annotations
-
- from datetime import timedelta
- import gc
- import os
- from pathlib import Path
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
-
- import joblib
- import numpy as np
- import optuna
- import pandas as pd
- import torch
- try: # pragma: no cover
-     import torch.distributed as dist # type: ignore
- except Exception: # pragma: no cover
-     dist = None # type: ignore
- from sklearn.model_selection import (
-     GroupKFold,
-     GroupShuffleSplit,
-     KFold,
-     ShuffleSplit,
-     TimeSeriesSplit,
- )
- from sklearn.preprocessing import StandardScaler
-
- from ..config_preprocess import BayesOptConfig, OutputManager
- from ..utils import DistributedUtils, EPS, ensure_parent_dir
- from ins_pricing.utils import get_logger, GPUMemoryManager, DeviceManager
- from ins_pricing.utils.torch_compat import torch_load
-
- # Module-level logger
- _logger = get_logger("ins_pricing.trainer")
-
- class _OrderSplitter:
-     def __init__(self, splitter, order: np.ndarray) -> None:
-         self._splitter = splitter
-         self._order = np.asarray(order)
-
-     def split(self, X, y=None, groups=None):
-         order = self._order
-         X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
-         for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
-             yield order[tr_idx], order[val_idx]
-
-
- # =============================================================================
- # CV Strategy Resolution Helper
- # =============================================================================
-
-
- class CVStrategyResolver:
-     """Helper class to resolve cross-validation splitting strategies.
-
-     This encapsulates the logic for determining how to split data based on the
-     configured strategy (random, time, group). It provides methods to:
-     - Get time-ordered indices for a dataset
-     - Get group values for a dataset
-     - Create appropriate sklearn splitters
-     """
-
-     TIME_STRATEGIES = {"time", "timeseries", "temporal"}
-     GROUP_STRATEGIES = {"group", "grouped"}
-
-     def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
-         """Initialize the resolver.
-
-         Args:
-             config: BayesOptConfig with cv_strategy, cv_time_col, cv_group_col, etc.
-             train_data: The training DataFrame (needed for column access)
-             rand_seed: Random seed for reproducible splits
-         """
-         self.config = config
-         self.train_data = train_data
-         self.rand_seed = rand_seed
-         self._strategy = self._normalize_strategy()
-
-     def _normalize_strategy(self) -> str:
-         """Normalize the strategy string to lowercase."""
-         raw = str(getattr(self.config, "cv_strategy", "random") or "random")
-         return raw.strip().lower()
-
-     @property
-     def strategy(self) -> str:
-         """Return the normalized CV strategy."""
-         return self._strategy
-
-     def is_time_strategy(self) -> bool:
-         """Check if using a time-based splitting strategy."""
-         return self._strategy in self.TIME_STRATEGIES
-
-     def is_group_strategy(self) -> bool:
-         """Check if using a group-based splitting strategy."""
-         return self._strategy in self.GROUP_STRATEGIES
-
-     def get_time_col(self) -> str:
-         """Get and validate the time column.
-
-         Raises:
-             ValueError: If time column is not configured
-             KeyError: If time column not found in train_data
-         """
-         time_col = getattr(self.config, "cv_time_col", None)
-         if not time_col:
-             raise ValueError("cv_time_col is required for time cv_strategy.")
-         if time_col not in self.train_data.columns:
-             raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-         return time_col
-
-     def get_time_ascending(self) -> bool:
-         """Get the time ordering preference."""
-         return bool(getattr(self.config, "cv_time_ascending", True))
-
-     def get_group_col(self) -> str:
-         """Get and validate the group column.
-
-         Raises:
-             ValueError: If group column is not configured
-             KeyError: If group column not found in train_data
-         """
-         group_col = getattr(self.config, "cv_group_col", None)
-         if not group_col:
-             raise ValueError("cv_group_col is required for group cv_strategy.")
-         if group_col not in self.train_data.columns:
-             raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-         return group_col
-
-     def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
-         """Get indices ordered by time for the given dataset.
-
-         Args:
-             X_all: DataFrame to get indices for (must have index compatible with train_data)
-
-         Returns:
-             Array of positional indices into X_all, ordered by time
-         """
-         time_col = self.get_time_col()
-         ascending = self.get_time_ascending()
-         order_index = self.train_data[time_col].sort_values(ascending=ascending).index
-         index_set = set(X_all.index)
-         order_index = [idx for idx in order_index if idx in index_set]
-         order = X_all.index.get_indexer(order_index)
-         return order[order >= 0]
-
-     def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
-         """Get group labels for the given dataset.
-
-         Args:
-             X_all: DataFrame to get groups for
-
-         Returns:
-             Series of group labels aligned with X_all
-         """
-         group_col = self.get_group_col()
-         return self.train_data.reindex(X_all.index)[group_col]
-
-     def create_train_val_splitter(
-         self,
-         X_all: pd.DataFrame,
-         val_ratio: float,
-     ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
-         """Create a single train/val split based on strategy.
-
-         Args:
-             X_all: DataFrame to split
-             val_ratio: Fraction of data for validation
-
-         Returns:
-             Tuple of ((train_idx, val_idx), groups) where groups is None for non-group strategies
-         """
-         if self.is_time_strategy():
-             order = self.get_time_ordered_indices(X_all)
-             cutoff = int(len(order) * (1.0 - val_ratio))
-             if cutoff <= 0 or cutoff >= len(order):
-                 raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
-             return (order[:cutoff], order[cutoff:]), None
-
-         if self.is_group_strategy():
-             groups = self.get_groups(X_all)
-             splitter = GroupShuffleSplit(
-                 n_splits=1, test_size=val_ratio, random_state=self.rand_seed
-             )
-             train_idx, val_idx = next(splitter.split(X_all, groups=groups))
-             return (train_idx, val_idx), groups
-
-         # Random strategy
-         splitter = ShuffleSplit(
-             n_splits=1, test_size=val_ratio, random_state=self.rand_seed
-         )
-         train_idx, val_idx = next(splitter.split(X_all))
-         return (train_idx, val_idx), None
-
-     def create_cv_splitter(
-         self,
-         X_all: pd.DataFrame,
-         y_all: Optional[pd.Series],
-         n_splits: int,
-         val_ratio: float,
-     ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
-         """Create a cross-validation splitter based on strategy.
-
-         Args:
-             X_all: DataFrame to split
-             y_all: Target series (used by some splitters)
-             n_splits: Number of CV folds
-             val_ratio: Validation ratio (for ShuffleSplit)
-
-         Returns:
-             Tuple of (split_iterator, actual_n_splits)
-         """
-         n_splits = max(2, int(n_splits))
-
-         if self.is_group_strategy():
-             groups = self.get_groups(X_all)
-             n_groups = int(groups.nunique(dropna=False))
-             if n_groups < 2:
-                 return iter([]), 0
-             n_splits = min(n_splits, n_groups)
-             if n_splits < 2:
-                 return iter([]), 0
-             splitter = GroupKFold(n_splits=n_splits)
-             return splitter.split(X_all, y_all, groups=groups), n_splits
-
-         if self.is_time_strategy():
-             order = self.get_time_ordered_indices(X_all)
-             if len(order) < 2:
-                 return iter([]), 0
-             n_splits = min(n_splits, max(2, len(order) - 1))
-             if n_splits < 2:
-                 return iter([]), 0
-             splitter = TimeSeriesSplit(n_splits=n_splits)
-             return _OrderSplitter(splitter, order).split(X_all), n_splits
-
-         # Random strategy
-         if len(X_all) < n_splits:
-             n_splits = len(X_all)
-         if n_splits < 2:
-             return iter([]), 0
-         splitter = ShuffleSplit(
-             n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
-         )
-         return splitter.split(X_all), n_splits
-
-     def create_kfold_splitter(
-         self,
-         X_all: pd.DataFrame,
-         k: int,
-     ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
-         """Create a K-fold splitter for ensemble training.
-
-         Args:
-             X_all: DataFrame to split
-             k: Number of folds
-
-         Returns:
-             Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
-         """
-         k = max(2, int(k))
-         n_samples = len(X_all)
-         if n_samples < 2:
-             return None, 0
-
-         if self.is_group_strategy():
-             groups = self.get_groups(X_all)
-             n_groups = int(groups.nunique(dropna=False))
-             if n_groups < 2:
-                 return None, 0
-             k = min(k, n_groups)
-             if k < 2:
-                 return None, 0
-             splitter = GroupKFold(n_splits=k)
-             return splitter.split(X_all, y=None, groups=groups), k
-
-         if self.is_time_strategy():
-             order = self.get_time_ordered_indices(X_all)
-             if len(order) < 2:
-                 return None, 0
-             k = min(k, max(2, len(order) - 1))
-             if k < 2:
-                 return None, 0
-             splitter = TimeSeriesSplit(n_splits=k)
-             return _OrderSplitter(splitter, order).split(X_all), k
-
-         # Random strategy with KFold
-         k = min(k, n_samples)
-         if k < 2:
-             return None, 0
-         splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
-         return splitter.split(X_all), k
-
-
- # =============================================================================
- # Trainer system
- # =============================================================================
-
-
- class TrainerBase:
-     def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
-         self.ctx = context
-         self.label = label
-         self.model_name_prefix = model_name_prefix
-         self.model = None
-         self.best_params: Optional[Dict[str, Any]] = None
-         self.best_trial = None
-         self.study_name: Optional[str] = None
-         self.enable_distributed_optuna: bool = False
-         self._distributed_forced_params: Optional[Dict[str, Any]] = None
-
-     def _apply_dataloader_overrides(self, model: Any) -> Any:
-         """Apply dataloader-related overrides from config to a model."""
-         cfg = getattr(self.ctx, "config", None)
-         if cfg is None:
-             return model
-         workers = getattr(cfg, "dataloader_workers", None)
-         if workers is not None:
-             model.dataloader_workers = int(workers)
-         profile = getattr(cfg, "resource_profile", None)
-         if profile:
-             model.resource_profile = str(profile)
-         return model
-
-     def _export_preprocess_artifacts(self) -> Dict[str, Any]:
-         dummy_columns: List[str] = []
-         if getattr(self.ctx, "train_oht_data", None) is not None:
-             dummy_columns = list(self.ctx.train_oht_data.columns)
-         return {
-             "factor_nmes": list(getattr(self.ctx, "factor_nmes", []) or []),
-             "cate_list": list(getattr(self.ctx, "cate_list", []) or []),
-             "num_features": list(getattr(self.ctx, "num_features", []) or []),
-             "var_nmes": list(getattr(self.ctx, "var_nmes", []) or []),
-             "cat_categories": dict(getattr(self.ctx, "cat_categories_for_shap", {}) or {}),
-             "dummy_columns": dummy_columns,
-             "numeric_scalers": dict(getattr(self.ctx, "numeric_scalers", {}) or {}),
-             "weight_nme": str(getattr(self.ctx, "weight_nme", "")),
-             "resp_nme": str(getattr(self.ctx, "resp_nme", "")),
-             "binary_resp_nme": getattr(self.ctx, "binary_resp_nme", None),
-             "drop_first": True,
-         }
-
-     def _dist_barrier(self, reason: str) -> None:
-         """DDP barrier wrapper used by distributed Optuna.
-
-         To debug "trial finished but next trial never starts" hangs, set these
-         environment variables (either in shell or config.json `env`):
-         - `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
-         - `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
-         - `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
-         """
-         if dist is None:
-             return
-         try:
-             if not getattr(dist, "is_available", lambda: False)():
-                 return
-             if not dist.is_initialized():
-                 return
-         except Exception:
-             return
-
-         timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
-         debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
-         rank = None
-         world = None
-         if debug_barrier:
-             try:
-                 rank = dist.get_rank()
-                 world = dist.get_world_size()
-                 print(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
-             except Exception:
-                 debug_barrier = False
-         try:
-             timeout = timedelta(seconds=timeout_seconds)
-             backend = None
-             try:
-                 backend = dist.get_backend()
-             except Exception:
-                 backend = None
-
-             # `monitored_barrier` is only implemented for GLOO; using it under NCCL
-             # will raise and can itself trigger a secondary hang. Prefer an async
-             # barrier with timeout for NCCL.
-             monitored = getattr(dist, "monitored_barrier", None)
-             if backend == "gloo" and callable(monitored):
-                 monitored(timeout=timeout)
-             else:
-                 work = None
-                 try:
-                     work = dist.barrier(async_op=True)
-                 except TypeError:
-                     work = None
-                 if work is not None:
-                     wait = getattr(work, "wait", None)
-                     if callable(wait):
-                         try:
-                             wait(timeout=timeout)
-                         except TypeError:
-                             wait()
-                     else:
-                         dist.barrier()
-                 else:
-                     dist.barrier()
-             if debug_barrier:
-                 print(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
-         except Exception as exc:
-             print(
-                 f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
-                 flush=True,
-             )
-             raise
-
-     @property
-     def config(self) -> BayesOptConfig:
-         return self.ctx.config
-
-     @property
-     def output(self) -> OutputManager:
-         return self.ctx.output_manager
-
-     def _get_model_filename(self) -> str:
-         ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
-         return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
-
-     def _resolve_optuna_storage_url(self) -> Optional[str]:
-         storage = getattr(self.config, "optuna_storage", None)
-         if not storage:
-             return None
-         storage_str = str(storage).strip()
-         if not storage_str:
-             return None
-         if "://" in storage_str or storage_str == ":memory:":
-             return storage_str
-         path = Path(storage_str)
-         path = path.resolve()
-         ensure_parent_dir(str(path))
-         return f"sqlite:///{path.as_posix()}"
-
+ from __future__ import annotations
+
+ from datetime import timedelta
+ import gc
+ import os
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+
+ import joblib
+ import numpy as np
+ import optuna
+ import pandas as pd
+ import torch
+ try: # pragma: no cover
+     import torch.distributed as dist # type: ignore
+ except Exception: # pragma: no cover
+     dist = None # type: ignore
+ from sklearn.model_selection import (
+     GroupKFold,
+     GroupShuffleSplit,
+     KFold,
+     ShuffleSplit,
+     TimeSeriesSplit,
+ )
+ from sklearn.preprocessing import StandardScaler
+
+ from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig, OutputManager
+ from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
+ from ins_pricing.utils import EPS, ensure_parent_dir, get_logger, GPUMemoryManager, DeviceManager, log_print
+ from ins_pricing.utils.torch_compat import torch_load
+
+ # Module-level logger
+ _logger = get_logger("ins_pricing.trainer")
+
+
+ def _log(*args, **kwargs) -> None:
+     log_print(_logger, *args, **kwargs)
+
+ class _OrderSplitter:
+     def __init__(self, splitter, order: np.ndarray) -> None:
+         self._splitter = splitter
+         self._order = np.asarray(order)
+
+     def split(self, X, y=None, groups=None):
+         order = self._order
+         X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
+         for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
+             yield order[tr_idx], order[val_idx]
+
+
+ # =============================================================================
+ # CV Strategy Resolution Helper
+ # =============================================================================
+
+
+ class CVStrategyResolver:
+     """Helper class to resolve cross-validation splitting strategies.
+
+     This encapsulates the logic for determining how to split data based on the
+     configured strategy (random, time, group). It provides methods to:
+     - Get time-ordered indices for a dataset
+     - Get group values for a dataset
+     - Create appropriate sklearn splitters
+     """
+
+     TIME_STRATEGIES = {"time", "timeseries", "temporal"}
+     GROUP_STRATEGIES = {"group", "grouped"}
+
+     def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
+         """Initialize the resolver.
+
+         Args:
+             config: BayesOptConfig with cv_strategy, cv_time_col, cv_group_col, etc.
+             train_data: The training DataFrame (needed for column access)
+             rand_seed: Random seed for reproducible splits
+         """
+         self.config = config
+         self.train_data = train_data
+         self.rand_seed = rand_seed
+         self._strategy = self._normalize_strategy()
+
+     def _normalize_strategy(self) -> str:
+         """Normalize the strategy string to lowercase."""
+         raw = str(getattr(self.config, "cv_strategy", "random") or "random")
+         return raw.strip().lower()
+
+     @property
+     def strategy(self) -> str:
+         """Return the normalized CV strategy."""
+         return self._strategy
+
+     def is_time_strategy(self) -> bool:
+         """Check if using a time-based splitting strategy."""
+         return self._strategy in self.TIME_STRATEGIES
+
+     def is_group_strategy(self) -> bool:
+         """Check if using a group-based splitting strategy."""
+         return self._strategy in self.GROUP_STRATEGIES
+
+     def get_time_col(self) -> str:
+         """Get and validate the time column.
+
+         Raises:
+             ValueError: If time column is not configured
+             KeyError: If time column not found in train_data
+         """
+         time_col = getattr(self.config, "cv_time_col", None)
+         if not time_col:
+             raise ValueError("cv_time_col is required for time cv_strategy.")
+         if time_col not in self.train_data.columns:
+             raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+         return time_col
+
+     def get_time_ascending(self) -> bool:
+         """Get the time ordering preference."""
+         return bool(getattr(self.config, "cv_time_ascending", True))
+
+     def get_group_col(self) -> str:
+         """Get and validate the group column.
+
+         Raises:
+             ValueError: If group column is not configured
+             KeyError: If group column not found in train_data
+         """
+         group_col = getattr(self.config, "cv_group_col", None)
+         if not group_col:
+             raise ValueError("cv_group_col is required for group cv_strategy.")
+         if group_col not in self.train_data.columns:
+             raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+         return group_col
+
+     def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
+         """Get indices ordered by time for the given dataset.
+
+         Args:
+             X_all: DataFrame to get indices for (must have index compatible with train_data)
+
+         Returns:
+             Array of positional indices into X_all, ordered by time
+         """
+         time_col = self.get_time_col()
+         ascending = self.get_time_ascending()
+         order_index = self.train_data[time_col].sort_values(ascending=ascending).index
+         index_set = set(X_all.index)
+         order_index = [idx for idx in order_index if idx in index_set]
+         order = X_all.index.get_indexer(order_index)
+         return order[order >= 0]
+
+     def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
+         """Get group labels for the given dataset.
+
+         Args:
+             X_all: DataFrame to get groups for
+
+         Returns:
+             Series of group labels aligned with X_all
+         """
+         group_col = self.get_group_col()
+         return self.train_data.reindex(X_all.index)[group_col]
+
+     def create_train_val_splitter(
+         self,
+         X_all: pd.DataFrame,
+         val_ratio: float,
+     ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
+         """Create a single train/val split based on strategy.
+
+         Args:
+             X_all: DataFrame to split
+             val_ratio: Fraction of data for validation
+
+         Returns:
+             Tuple of ((train_idx, val_idx), groups) where groups is None for non-group strategies
+         """
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             cutoff = int(len(order) * (1.0 - val_ratio))
+             if cutoff <= 0 or cutoff >= len(order):
+                 raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
+             return (order[:cutoff], order[cutoff:]), None
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             splitter = GroupShuffleSplit(
+                 n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+             )
+             train_idx, val_idx = next(splitter.split(X_all, groups=groups))
+             return (train_idx, val_idx), groups
+
+         # Random strategy
+         splitter = ShuffleSplit(
+             n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+         )
+         train_idx, val_idx = next(splitter.split(X_all))
+         return (train_idx, val_idx), None
+
+     def create_cv_splitter(
+         self,
+         X_all: pd.DataFrame,
+         y_all: Optional[pd.Series],
+         n_splits: int,
+         val_ratio: float,
+     ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
+         """Create a cross-validation splitter based on strategy.
+
+         Args:
+             X_all: DataFrame to split
+             y_all: Target series (used by some splitters)
+             n_splits: Number of CV folds
+             val_ratio: Validation ratio (for ShuffleSplit)
+
+         Returns:
+             Tuple of (split_iterator, actual_n_splits)
+         """
+         n_splits = max(2, int(n_splits))
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             n_groups = int(groups.nunique(dropna=False))
+             if n_groups < 2:
+                 return iter([]), 0
+             n_splits = min(n_splits, n_groups)
+             if n_splits < 2:
+                 return iter([]), 0
+             splitter = GroupKFold(n_splits=n_splits)
+             return splitter.split(X_all, y_all, groups=groups), n_splits
+
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             if len(order) < 2:
+                 return iter([]), 0
+             n_splits = min(n_splits, max(2, len(order) - 1))
+             if n_splits < 2:
+                 return iter([]), 0
+             splitter = TimeSeriesSplit(n_splits=n_splits)
+             return _OrderSplitter(splitter, order).split(X_all), n_splits
+
+         # Random strategy
+         if len(X_all) < n_splits:
+             n_splits = len(X_all)
+         if n_splits < 2:
+             return iter([]), 0
+         splitter = ShuffleSplit(
+             n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
+         )
+         return splitter.split(X_all), n_splits
+
+     def create_kfold_splitter(
+         self,
+         X_all: pd.DataFrame,
+         k: int,
+     ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+         """Create a K-fold splitter for ensemble training.
+
+         Args:
+             X_all: DataFrame to split
+             k: Number of folds
+
+         Returns:
+             Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+         """
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         if n_samples < 2:
+             return None, 0
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             n_groups = int(groups.nunique(dropna=False))
+             if n_groups < 2:
+                 return None, 0
+             k = min(k, n_groups)
+             if k < 2:
+                 return None, 0
+             splitter = GroupKFold(n_splits=k)
+             return splitter.split(X_all, y=None, groups=groups), k
+
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             if len(order) < 2:
+                 return None, 0
+             k = min(k, max(2, len(order) - 1))
+             if k < 2:
+                 return None, 0
+             splitter = TimeSeriesSplit(n_splits=k)
+             return _OrderSplitter(splitter, order).split(X_all), k
+
+         # Random strategy with KFold
+         k = min(k, n_samples)
+         if k < 2:
+             return None, 0
+         splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
+         return splitter.split(X_all), k
+
+
+ # =============================================================================
+ # Trainer system
+ # =============================================================================
+
+
+ class TrainerBase:
+     def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
+         self.ctx = context
+         self.label = label
+         self.model_name_prefix = model_name_prefix
+         self.model = None
+         self.best_params: Optional[Dict[str, Any]] = None
+         self.best_trial = None
+         self.study_name: Optional[str] = None
+         self.enable_distributed_optuna: bool = False
+         self._distributed_forced_params: Optional[Dict[str, Any]] = None
+
+     def _apply_dataloader_overrides(self, model: Any) -> Any:
+         """Apply dataloader-related overrides from config to a model."""
+         cfg = getattr(self.ctx, "config", None)
+         if cfg is None:
+             return model
+         workers = getattr(cfg, "dataloader_workers", None)
+         if workers is not None:
+             model.dataloader_workers = int(workers)
+         profile = getattr(cfg, "resource_profile", None)
+         if profile:
+             model.resource_profile = str(profile)
+         return model
+
+     def _export_preprocess_artifacts(self) -> Dict[str, Any]:
+         dummy_columns: List[str] = []
+         if getattr(self.ctx, "train_oht_data", None) is not None:
+             dummy_columns = list(self.ctx.train_oht_data.columns)
+         return {
+             "factor_nmes": list(getattr(self.ctx, "factor_nmes", []) or []),
+             "cate_list": list(getattr(self.ctx, "cate_list", []) or []),
+             "num_features": list(getattr(self.ctx, "num_features", []) or []),
+             "var_nmes": list(getattr(self.ctx, "var_nmes", []) or []),
+             "cat_categories": dict(getattr(self.ctx, "cat_categories_for_shap", {}) or {}),
+             "dummy_columns": dummy_columns,
+             "numeric_scalers": dict(getattr(self.ctx, "numeric_scalers", {}) or {}),
+             "weight_nme": str(getattr(self.ctx, "weight_nme", "")),
+             "resp_nme": str(getattr(self.ctx, "resp_nme", "")),
+             "binary_resp_nme": getattr(self.ctx, "binary_resp_nme", None),
+             "drop_first": True,
+         }
+
+     def _dist_barrier(self, reason: str) -> None:
+         """DDP barrier wrapper used by distributed Optuna.
+
+         To debug "trial finished but next trial never starts" hangs, set these
+         environment variables (either in shell or config.json `env`):
+         - `BAYESOPT_DDP_BARRIER_DEBUG=1` to print barrier enter/exit per-rank
+         - `BAYESOPT_DDP_BARRIER_TIMEOUT=300` to fail fast instead of waiting forever
+         - `TORCH_DISTRIBUTED_DEBUG=DETAIL` and `NCCL_DEBUG=INFO` for PyTorch/NCCL logs
+         """
+         if dist is None:
+             return
+         try:
+             if not getattr(dist, "is_available", lambda: False)():
+                 return
+             if not dist.is_initialized():
+                 return
+         except Exception:
+             return
+
+         timeout_seconds = int(os.environ.get("BAYESOPT_DDP_BARRIER_TIMEOUT", "1800"))
+         debug_barrier = os.environ.get("BAYESOPT_DDP_BARRIER_DEBUG", "").strip() in {"1", "true", "TRUE", "yes", "YES"}
+         rank = None
+         world = None
+         if debug_barrier:
+             try:
+                 rank = dist.get_rank()
+                 world = dist.get_world_size()
+                 _log(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
+             except Exception:
+                 debug_barrier = False
+         try:
+             timeout = timedelta(seconds=timeout_seconds)
+             backend = None
+             try:
+                 backend = dist.get_backend()
+             except Exception:
+                 backend = None
+
+             # `monitored_barrier` is only implemented for GLOO; using it under NCCL
+             # will raise and can itself trigger a secondary hang. Prefer an async
+             # barrier with timeout for NCCL.
+             monitored = getattr(dist, "monitored_barrier", None)
+             if backend == "gloo" and callable(monitored):
+                 monitored(timeout=timeout)
+             else:
+                 work = None
+                 try:
+                     work = dist.barrier(async_op=True)
+                 except TypeError:
+                     work = None
+                 if work is not None:
+                     wait = getattr(work, "wait", None)
+                     if callable(wait):
+                         try:
+                             wait(timeout=timeout)
+                         except TypeError:
+                             wait()
+                     else:
+                         dist.barrier()
+                 else:
+                     dist.barrier()
+             if debug_barrier:
+                 _log(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
+         except Exception as exc:
+             _log(
+                 f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
+                 flush=True,
+             )
+             raise
+
+     @property
+     def config(self) -> BayesOptConfig:
+         return self.ctx.config
+
+     @property
+     def output(self) -> OutputManager:
+         return self.ctx.output_manager
+
+     def _get_model_filename(self) -> str:
+         ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
+         return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
+
+     def _resolve_optuna_storage_url(self) -> Optional[str]:
+         storage = getattr(self.config, "optuna_storage", None)
+         if not storage:
+             return None
+         storage_str = str(storage).strip()
+         if not storage_str:
+             return None
+         if "://" in storage_str or storage_str == ":memory:":
+             return storage_str
+         path = Path(storage_str)
+         path = path.resolve()
+         ensure_parent_dir(str(path))
+         return f"sqlite:///{path.as_posix()}"
+
      def _resolve_optuna_study_name(self) -> str:
          prefix = getattr(self.config, "optuna_study_prefix",
                           None) or "bayesopt"
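Before the next hunk, two notes on the code carried over above. `_resolve_optuna_storage_url` passes through any value containing `://` (or the literal `:memory:`) and otherwise resolves the string as a filesystem path, returning `sqlite:///<absolute path>`. And `CVStrategyResolver` dispatches purely on config attributes, so a plain namespace is enough to drive it; below is a minimal usage sketch of the time-based path, with hypothetical config and data (not from the package) and assuming the class is importable from its new 0.5.1 module path:

    import pandas as pd
    from types import SimpleNamespace

    from ins_pricing.modelling.bayesopt.trainers.trainer_base import CVStrategyResolver

    # Hypothetical config: hold out the most recent 25% of rows by 'year'.
    cfg = SimpleNamespace(cv_strategy="time", cv_time_col="year", cv_time_ascending=True)
    df = pd.DataFrame({"year": [2018, 2019, 2020, 2021], "x": range(4)})

    resolver = CVStrategyResolver(cfg, df, rand_seed=42)
    (train_idx, val_idx), groups = resolver.create_train_val_splitter(df, val_ratio=0.25)
    # train_idx covers 2018-2020, val_idx covers 2021; groups is None for the time strategy.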
@@ -440,869 +444,877 @@ class TrainerBase:
440
444
  safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
441
445
  return safe.lower()
442
446
 
443
- def tune(self, max_evals: int, objective_fn=None) -> None:
444
- # Generic Optuna tuning loop.
445
- if objective_fn is None:
446
- # If subclass doesn't provide objective_fn, default to cross_val.
447
- objective_fn = self.cross_val
448
-
449
- if self._should_use_distributed_optuna():
450
- self._distributed_tune(max_evals, objective_fn)
451
- return
452
-
453
- total_trials = max(1, int(max_evals))
454
- progress_counter = {"count": 0}
455
-
456
- def objective_wrapper(trial: optuna.trial.Trial) -> float:
457
- should_log = DistributedUtils.is_main_process()
458
- if should_log:
459
- current_idx = progress_counter["count"] + 1
460
- print(
461
- f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
462
- f"(trial_id={trial.number})."
463
- )
464
- try:
465
- result = objective_fn(trial)
447
+ def _optuna_cleanup_sync(self) -> bool:
448
+ return bool(getattr(self.config, "optuna_cleanup_synchronize", False))
449
+
450
+ def tune(self, max_evals: int, objective_fn=None) -> None:
451
+ # Generic Optuna tuning loop.
452
+ if objective_fn is None:
453
+ # If subclass doesn't provide objective_fn, default to cross_val.
454
+ objective_fn = self.cross_val
455
+
456
+ if self._should_use_distributed_optuna():
457
+ self._distributed_tune(max_evals, objective_fn)
458
+ return
459
+
460
+ total_trials = max(1, int(max_evals))
461
+ progress_counter = {"count": 0}
462
+
463
+ def objective_wrapper(trial: optuna.trial.Trial) -> float:
464
+ should_log = DistributedUtils.is_main_process()
465
+ if should_log:
466
+ current_idx = progress_counter["count"] + 1
467
+ _log(
468
+ f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
469
+ f"(trial_id={trial.number})."
470
+ )
471
+ try:
472
+ result = objective_fn(trial)
466
473
  except RuntimeError as exc:
467
474
  if "out of memory" in str(exc).lower():
468
- print(
475
+ _log(
469
476
  f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
470
477
  )
471
- self._clean_gpu()
478
+ self._clean_gpu(synchronize=True)
472
479
  raise optuna.TrialPruned() from exc
473
480
  raise
474
481
  finally:
475
- self._clean_gpu()
476
- if should_log:
477
- progress_counter["count"] = progress_counter["count"] + 1
478
- trial_state = getattr(trial, "state", None)
479
- state_repr = getattr(trial_state, "name", "OK")
480
- print(
481
- f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
482
- f"(status={state_repr})."
483
- )
484
- return result
485
-
486
- storage_url = self._resolve_optuna_storage_url()
487
- study_name = self._resolve_optuna_study_name()
488
- study_kwargs: Dict[str, Any] = {
489
- "direction": "minimize",
490
- "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
491
- }
492
- if storage_url:
493
- study_kwargs.update(
494
- storage=storage_url,
495
- study_name=study_name,
496
- load_if_exists=True,
497
- )
498
-
499
- study = optuna.create_study(**study_kwargs)
500
- self.study_name = getattr(study, "study_name", None)
501
-
502
- def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
503
- # Persist best_params after each trial to allow safe resume.
504
- try:
505
- best = getattr(check_study, "best_trial", None)
506
- if best is None:
507
- return
508
- best_params = getattr(best, "params", None)
509
- if not best_params:
510
- return
511
- params_path = self.output.result_path(
512
- f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
513
- )
514
- pd.DataFrame(best_params, index=[0]).to_csv(
515
- params_path, index=False)
516
- except Exception:
517
- return
518
-
519
- completed_states = (
520
- optuna.trial.TrialState.COMPLETE,
521
- optuna.trial.TrialState.PRUNED,
522
- optuna.trial.TrialState.FAIL,
523
- )
524
- completed = len(study.get_trials(states=completed_states))
525
- progress_counter["count"] = completed
526
- remaining = max(0, total_trials - completed)
527
- if remaining > 0:
528
- study.optimize(
529
- objective_wrapper,
530
- n_trials=remaining,
531
- callbacks=[checkpoint_callback],
532
- )
533
- self.best_params = study.best_params
534
- self.best_trial = study.best_trial
535
-
536
- # Save best params to CSV for reproducibility.
537
- params_path = self.output.result_path(
538
- f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
539
- )
540
- pd.DataFrame(self.best_params, index=[0]).to_csv(
541
- params_path, index=False)
542
-
543
- def train(self) -> None:
544
- raise NotImplementedError
545
-
546
- def _unwrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
547
- """Unwrap DDP or DataParallel wrapper to get the base module."""
548
- from torch.nn.parallel import DistributedDataParallel as DDP
549
- if isinstance(module, (DDP, torch.nn.DataParallel)):
550
- return module.module
551
- return module
552
-
553
- def save(self) -> None:
554
- if self.model is None:
555
- print(f"[save] Warning: No model to save for {self.label}")
556
- return
557
-
558
- path = self.output.model_path(self._get_model_filename())
559
- if self.label in ['Xgboost', 'GLM']:
560
- payload = {
561
- "model": self.model,
562
- "preprocess_artifacts": self._export_preprocess_artifacts(),
563
- }
564
- joblib.dump(payload, path)
565
- else:
566
- # PyTorch models: save state_dict without DDP/DataParallel wrappers
567
- # to ensure cross-platform compatibility.
568
- payload = {
569
- "preprocess_artifacts": self._export_preprocess_artifacts(),
570
- }
571
- if hasattr(self.model, 'resnet'): # ResNetSklearn model
572
- # Unwrap DDP/DataParallel and move to CPU
573
- resnet = self._unwrap_module(self.model.resnet)
574
- resnet_cpu = resnet.to("cpu")
575
- payload["state_dict"] = resnet_cpu.state_dict()
576
- payload["best_params"] = dict(self.best_params or {})
577
- elif hasattr(self.model, 'ft'): # FTTransformerSklearn model
578
- # Unwrap DDP/DataParallel and save state_dict (not full model object)
579
- # to avoid serialization issues with DDP wrappers
580
- ft = self._unwrap_module(self.model.ft)
581
- ft_cpu = ft.to("cpu")
582
- payload["state_dict"] = ft_cpu.state_dict()
583
- payload["best_params"] = dict(self.best_params or {})
584
- # Save model configuration for reconstruction
585
- payload["model_config"] = {
586
- "model_nme": getattr(self.model, "model_nme", ""),
587
- "num_cols": list(getattr(self.model, "num_cols", [])),
588
- "cat_cols": list(getattr(self.model, "cat_cols", [])),
589
- "d_model": getattr(self.model, "d_model", 64),
590
- "n_heads": getattr(self.model, "n_heads", 8),
591
- "n_layers": getattr(self.model, "n_layers", 4),
592
- "dropout": getattr(self.model, "dropout", 0.1),
593
- "task_type": getattr(self.model, "task_type", "regression"),
594
- "loss_name": getattr(self.model, "loss_name", None),
595
- "tw_power": getattr(self.model, "tw_power", 1.5),
596
- "num_geo": getattr(self.model, "num_geo", 0),
597
- "num_numeric_tokens": getattr(self.model, "num_numeric_tokens", None),
598
- "cat_cardinalities": getattr(self.model, "cat_cardinalities", None),
599
- "cat_categories": {k: list(v) for k, v in getattr(self.model, "cat_categories", {}).items()},
600
- "_num_mean": getattr(self.model, "_num_mean", None),
601
- "_num_std": getattr(self.model, "_num_std", None),
602
- }
603
- # Convert numpy arrays to lists for JSON serialization
604
- if payload["model_config"]["_num_mean"] is not None:
605
- payload["model_config"]["_num_mean"] = payload["model_config"]["_num_mean"].tolist() if hasattr(payload["model_config"]["_num_mean"], "tolist") else payload["model_config"]["_num_mean"]
606
- if payload["model_config"]["_num_std"] is not None:
607
- payload["model_config"]["_num_std"] = payload["model_config"]["_num_std"].tolist() if hasattr(payload["model_config"]["_num_std"], "tolist") else payload["model_config"]["_num_std"]
608
- else:
609
- # Generic PyTorch model fallback
610
- if hasattr(self.model, 'to'):
611
- self.model.to("cpu")
612
- payload["model"] = self.model
613
- torch.save(payload, path)
614
-
615
- def load(self) -> None:
616
- path = self.output.model_path(self._get_model_filename())
617
- if not os.path.exists(path):
618
- print(f"[load] Warning: Model file not found: {path}")
619
- return
620
-
621
- if self.label in ['Xgboost', 'GLM']:
622
- loaded = joblib.load(path)
623
- if isinstance(loaded, dict) and "model" in loaded:
624
- self.model = loaded.get("model")
625
- else:
626
- self.model = loaded
627
- else:
628
- # PyTorch loading depends on the model structure.
629
- if self.label == 'ResNet' or self.label == 'ResNetClassifier':
630
- # ResNet requires reconstructing the skeleton; handled by subclass.
631
- pass
632
- else:
633
- # FT-Transformer: load state_dict and reconstruct model
634
- loaded = torch_load(path, map_location='cpu', weights_only=False)
635
- if isinstance(loaded, dict):
636
- if "state_dict" in loaded and "model_config" in loaded:
637
- # New format: state_dict + model_config
638
- state_dict = loaded.get("state_dict")
639
- model_config = loaded.get("model_config", {})
640
- self.best_params = loaded.get("best_params", {})
641
-
642
- # Import FTTransformerSklearn for reconstruction
643
- from ..models import FTTransformerSklearn
644
-
645
- # Reconstruct model from config
646
- model = FTTransformerSklearn(
647
- model_nme=model_config.get("model_nme", ""),
648
- num_cols=model_config.get("num_cols", []),
649
- cat_cols=model_config.get("cat_cols", []),
650
- d_model=model_config.get("d_model", 64),
651
- n_heads=model_config.get("n_heads", 8),
652
- n_layers=model_config.get("n_layers", 4),
653
- dropout=model_config.get("dropout", 0.1),
654
- task_type=model_config.get("task_type", "regression"),
655
- loss_name=model_config.get("loss_name", None),
656
- tweedie_power=model_config.get("tw_power", 1.5),
657
- num_numeric_tokens=model_config.get("num_numeric_tokens"),
658
- use_data_parallel=False,
659
- use_ddp=False,
660
- )
661
- # Restore internal state
662
- model.num_geo = model_config.get("num_geo", 0)
663
- model.cat_cardinalities = model_config.get("cat_cardinalities")
664
- model.cat_categories = {k: pd.Index(v) for k, v in model_config.get("cat_categories", {}).items()}
665
- if model_config.get("_num_mean") is not None:
666
- model._num_mean = np.array(model_config["_num_mean"], dtype=np.float32)
667
- if model_config.get("_num_std") is not None:
668
- model._num_std = np.array(model_config["_num_std"], dtype=np.float32)
669
-
670
- # Build the model architecture and load weights
671
- # We need a dummy dataframe to initialize the model
672
- if model.cat_cardinalities is not None:
673
- from ..models.model_ft_components import FTTransformerCore
674
- core = FTTransformerCore(
675
- num_numeric=len(model.num_cols),
676
- cat_cardinalities=model.cat_cardinalities,
677
- d_model=model.d_model,
678
- n_heads=model.n_heads,
679
- n_layers=model.n_layers,
680
- dropout=model.dropout,
681
- task_type=model.task_type,
682
- num_geo=model.num_geo,
683
- num_numeric_tokens=model.num_numeric_tokens,
684
- )
685
- model.ft = core
686
- model.ft.load_state_dict(state_dict)
687
-
688
- self._move_to_device(model)
689
- self.model = model
690
- elif "model" in loaded:
691
- # Legacy format: full model object
692
- loaded_model = loaded.get("model")
693
- if loaded_model is not None:
694
- self._move_to_device(loaded_model)
695
- self.model = loaded_model
696
- else:
697
- # Unknown format
698
- print(f"[load] Warning: Unknown model format in {path}")
699
- else:
700
- # Very old format: direct model object
701
- if loaded is not None:
702
- self._move_to_device(loaded)
703
- self.model = loaded
704
-
705
- def _move_to_device(self, model_obj):
706
- """Move model to the best available device using shared DeviceManager."""
707
- DeviceManager.move_to_device(model_obj)
708
-
709
- def _should_use_distributed_optuna(self) -> bool:
710
- if not self.enable_distributed_optuna:
711
- return False
712
- rank_env = os.environ.get("RANK")
713
- world_env = os.environ.get("WORLD_SIZE")
714
- local_env = os.environ.get("LOCAL_RANK")
715
- if rank_env is None or world_env is None or local_env is None:
716
- return False
717
- try:
718
- world_size = int(world_env)
719
- except Exception:
720
- return False
721
- return world_size > 1
722
-
723
- def _distributed_is_main(self) -> bool:
724
- return DistributedUtils.is_main_process()
725
-
726
- def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
727
- if not self._should_use_distributed_optuna() or not self._distributed_is_main():
728
- return
729
- if dist is None:
730
- return
731
- DistributedUtils.setup_ddp()
732
- if not dist.is_initialized():
733
- return
734
- message = [payload]
735
- dist.broadcast_object_list(message, src=0)
736
-
737
- def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
738
- if not self._should_use_distributed_optuna():
739
- return
740
- if not self._distributed_is_main():
741
- return
742
- if dist is None:
743
- return
744
- self._distributed_send_command({"type": "RUN", "params": params})
745
- if not dist.is_initialized():
746
- return
747
- # STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
748
- self._dist_barrier("prepare_trial")
749
-
750
- def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
751
- if dist is None:
752
- print(
753
- f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
754
- flush=True,
755
- )
756
- return
757
- DistributedUtils.setup_ddp()
758
- if not dist.is_initialized():
759
- print(
760
- f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
761
- flush=True,
762
- )
763
- return
764
- while True:
765
- message = [None]
766
- dist.broadcast_object_list(message, src=0)
767
- payload = message[0]
768
- if not isinstance(payload, dict):
769
- continue
770
- cmd = payload.get("type")
771
- if cmd == "STOP":
772
- best_params = payload.get("best_params")
773
- if best_params is not None:
774
- self.best_params = best_params
775
- break
776
- if cmd == "RUN":
777
- params = payload.get("params") or {}
778
- self._distributed_forced_params = params
779
- # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
780
- self._dist_barrier("worker_start")
781
- try:
782
- objective_fn(None)
783
- except optuna.TrialPruned:
784
- pass
785
- except Exception as exc:
786
- print(
787
- f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
482
+ self._clean_gpu(synchronize=self._optuna_cleanup_sync())
483
+ if should_log:
484
+ progress_counter["count"] = progress_counter["count"] + 1
485
+ trial_state = getattr(trial, "state", None)
486
+ state_repr = getattr(trial_state, "name", "OK")
487
+ _log(
488
+ f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
489
+ f"(status={state_repr})."
490
+ )
491
+ return result
492
+
493
+ storage_url = self._resolve_optuna_storage_url()
494
+ study_name = self._resolve_optuna_study_name()
495
+ study_kwargs: Dict[str, Any] = {
496
+ "direction": "minimize",
497
+ "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
498
+ }
499
+ if storage_url:
500
+ study_kwargs.update(
501
+ storage=storage_url,
502
+ study_name=study_name,
503
+ load_if_exists=True,
504
+ )
505
+
506
+ study = optuna.create_study(**study_kwargs)
507
+ self.study_name = getattr(study, "study_name", None)
508
+
509
+ def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
510
+ # Persist best_params after each trial to allow safe resume.
511
+ try:
512
+ best = getattr(check_study, "best_trial", None)
513
+ if best is None:
514
+ return
515
+ best_params = getattr(best, "params", None)
516
+ if not best_params:
517
+ return
518
+ params_path = self.output.result_path(
519
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
520
+ )
521
+ pd.DataFrame(best_params, index=[0]).to_csv(
522
+ params_path, index=False)
523
+ except Exception:
524
+ return
525
+
526
+ completed_states = (
527
+ optuna.trial.TrialState.COMPLETE,
528
+ optuna.trial.TrialState.PRUNED,
529
+ optuna.trial.TrialState.FAIL,
530
+ )
531
+ completed = len(study.get_trials(states=completed_states))
532
+ progress_counter["count"] = completed
533
+ remaining = max(0, total_trials - completed)
534
+ if remaining > 0:
535
+ study.optimize(
536
+ objective_wrapper,
537
+ n_trials=remaining,
538
+ callbacks=[checkpoint_callback],
539
+ )
540
+ self.best_params = study.best_params
541
+ self.best_trial = study.best_trial
542
+
543
+ # Save best params to CSV for reproducibility.
544
+ params_path = self.output.result_path(
545
+ f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
546
+ )
547
+ pd.DataFrame(self.best_params, index=[0]).to_csv(
548
+ params_path, index=False)
549
+
550
+ def train(self) -> None:
551
+ raise NotImplementedError
552
+
553
+ def _unwrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
554
+ """Unwrap DDP or DataParallel wrapper to get the base module."""
555
+ from torch.nn.parallel import DistributedDataParallel as DDP
556
+ if isinstance(module, (DDP, torch.nn.DataParallel)):
557
+ return module.module
558
+ return module
559
+
560
+    def save(self) -> None:
+        if self.model is None:
+            _log(f"[save] Warning: No model to save for {self.label}")
+            return
+
+        path = self.output.model_path(self._get_model_filename())
+        if self.label in ['Xgboost', 'GLM']:
+            payload = {
+                "model": self.model,
+                "preprocess_artifacts": self._export_preprocess_artifacts(),
+            }
+            joblib.dump(payload, path)
+        else:
+            # PyTorch models: save state_dict without DDP/DataParallel wrappers
+            # to ensure cross-platform compatibility.
+            payload = {
+                "preprocess_artifacts": self._export_preprocess_artifacts(),
+            }
+            if hasattr(self.model, 'resnet'):  # ResNetSklearn model
+                # Unwrap DDP/DataParallel and move to CPU
+                resnet = self._unwrap_module(self.model.resnet)
+                resnet_cpu = resnet.to("cpu")
+                payload["state_dict"] = resnet_cpu.state_dict()
+                payload["best_params"] = dict(self.best_params or {})
+            elif hasattr(self.model, 'ft'):  # FTTransformerSklearn model
+                # Unwrap DDP/DataParallel and save state_dict (not full model object)
+                # to avoid serialization issues with DDP wrappers
+                ft = self._unwrap_module(self.model.ft)
+                ft_cpu = ft.to("cpu")
+                payload["state_dict"] = ft_cpu.state_dict()
+                payload["best_params"] = dict(self.best_params or {})
+                # Save model configuration for reconstruction
+                payload["model_config"] = {
+                    "model_nme": getattr(self.model, "model_nme", ""),
+                    "num_cols": list(getattr(self.model, "num_cols", [])),
+                    "cat_cols": list(getattr(self.model, "cat_cols", [])),
+                    "d_model": getattr(self.model, "d_model", 64),
+                    "n_heads": getattr(self.model, "n_heads", 8),
+                    "n_layers": getattr(self.model, "n_layers", 4),
+                    "dropout": getattr(self.model, "dropout", 0.1),
+                    "task_type": getattr(self.model, "task_type", "regression"),
+                    "loss_name": getattr(self.model, "loss_name", None),
+                    "tw_power": getattr(self.model, "tw_power", 1.5),
+                    "num_geo": getattr(self.model, "num_geo", 0),
+                    "num_numeric_tokens": getattr(self.model, "num_numeric_tokens", None),
+                    "cat_cardinalities": getattr(self.model, "cat_cardinalities", None),
+                    "cat_categories": {k: list(v) for k, v in getattr(self.model, "cat_categories", {}).items()},
+                    "_num_mean": getattr(self.model, "_num_mean", None),
+                    "_num_std": getattr(self.model, "_num_std", None),
+                }
+                # Convert numpy arrays to lists for JSON serialization
+                if payload["model_config"]["_num_mean"] is not None:
+                    payload["model_config"]["_num_mean"] = payload["model_config"]["_num_mean"].tolist() if hasattr(payload["model_config"]["_num_mean"], "tolist") else payload["model_config"]["_num_mean"]
+                if payload["model_config"]["_num_std"] is not None:
+                    payload["model_config"]["_num_std"] = payload["model_config"]["_num_std"].tolist() if hasattr(payload["model_config"]["_num_std"], "tolist") else payload["model_config"]["_num_std"]
+            else:
+                # Generic PyTorch model fallback
+                if hasattr(self.model, 'to'):
+                    self.model.to("cpu")
+                payload["model"] = self.model
+            torch.save(payload, path)
+
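The save path stores a plain `state_dict` plus enough configuration to rebuild the network, instead of pickling the wrapped module. A minimal sketch of the same round-trip on a toy module (the names and file path here are illustrative, not the package's):

```python
import torch
import torch.nn as nn

def build_net(hidden: int) -> nn.Module:
    # The saved config must fully determine the architecture.
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 1))

net = build_net(hidden=16)
maybe_wrapped = nn.DataParallel(net) if torch.cuda.device_count() > 1 else net

# Unwrap before saving so the state_dict keys carry no "module." prefix.
base = maybe_wrapped.module if isinstance(maybe_wrapped, nn.DataParallel) else maybe_wrapped
payload = {"model_config": {"hidden": 16}, "state_dict": base.to("cpu").state_dict()}
torch.save(payload, "ckpt.pt")

# Load side: rebuild from config, then restore weights.
loaded = torch.load("ckpt.pt", map_location="cpu")
restored = build_net(**loaded["model_config"])
restored.load_state_dict(loaded["state_dict"])
```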
+    def load(self) -> None:
+        path = self.output.model_path(self._get_model_filename())
+        if not os.path.exists(path):
+            _log(f"[load] Warning: Model file not found: {path}")
+            return
+
+        if self.label in ['Xgboost', 'GLM']:
+            loaded = joblib.load(path)
+            if isinstance(loaded, dict) and "model" in loaded:
+                self.model = loaded.get("model")
+            else:
+                self.model = loaded
+        else:
+            # PyTorch loading depends on the model structure.
+            if self.label == 'ResNet' or self.label == 'ResNetClassifier':
+                # ResNet requires reconstructing the skeleton; handled by subclass.
+                pass
+            else:
+                # FT-Transformer: load state_dict and reconstruct model
+                loaded = torch_load(path, map_location='cpu', weights_only=False)
+                if isinstance(loaded, dict):
+                    if "state_dict" in loaded and "model_config" in loaded:
+                        # New format: state_dict + model_config
+                        state_dict = loaded.get("state_dict")
+                        model_config = loaded.get("model_config", {})
+                        self.best_params = loaded.get("best_params", {})
+
+                        # Import FTTransformerSklearn for reconstruction
+                        from ins_pricing.modelling.bayesopt.models import FTTransformerSklearn
+
+                        # Reconstruct model from config
+                        model = FTTransformerSklearn(
+                            model_nme=model_config.get("model_nme", ""),
+                            num_cols=model_config.get("num_cols", []),
+                            cat_cols=model_config.get("cat_cols", []),
+                            d_model=model_config.get("d_model", 64),
+                            n_heads=model_config.get("n_heads", 8),
+                            n_layers=model_config.get("n_layers", 4),
+                            dropout=model_config.get("dropout", 0.1),
+                            task_type=model_config.get("task_type", "regression"),
+                            loss_name=model_config.get("loss_name", None),
+                            tweedie_power=model_config.get("tw_power", 1.5),
+                            num_numeric_tokens=model_config.get("num_numeric_tokens"),
+                            use_data_parallel=False,
+                            use_ddp=False,
+                        )
+                        # Restore internal state
+                        model.num_geo = model_config.get("num_geo", 0)
+                        model.cat_cardinalities = model_config.get("cat_cardinalities")
+                        model.cat_categories = {k: pd.Index(v) for k, v in model_config.get("cat_categories", {}).items()}
+                        if model_config.get("_num_mean") is not None:
+                            model._num_mean = np.array(model_config["_num_mean"], dtype=np.float32)
+                        if model_config.get("_num_std") is not None:
+                            model._num_std = np.array(model_config["_num_std"], dtype=np.float32)
+
+                        # Build the architecture from the saved config, then load the
+                        # weights; the stored cat_cardinalities make this possible
+                        # without touching any training data.
+                        if model.cat_cardinalities is not None:
+                            from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore
+                            core = FTTransformerCore(
+                                num_numeric=len(model.num_cols),
+                                cat_cardinalities=model.cat_cardinalities,
+                                d_model=model.d_model,
+                                n_heads=model.n_heads,
+                                n_layers=model.n_layers,
+                                dropout=model.dropout,
+                                task_type=model.task_type,
+                                num_geo=model.num_geo,
+                                num_numeric_tokens=model.num_numeric_tokens,
+                            )
+                            model.ft = core
+                            model.ft.load_state_dict(state_dict)
+
+                        self._move_to_device(model)
+                        self.model = model
+                    elif "model" in loaded:
+                        # Legacy format: full model object
+                        loaded_model = loaded.get("model")
+                        if loaded_model is not None:
+                            self._move_to_device(loaded_model)
+                            self.model = loaded_model
+                    else:
+                        # Unknown format
+                        _log(f"[load] Warning: Unknown model format in {path}")
+                else:
+                    # Very old format: direct model object
+                    if loaded is not None:
+                        self._move_to_device(loaded)
+                        self.model = loaded
+
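Note that the loader passes `weights_only=False`: the payload mixes tensors with pickled Python objects (pandas indexes, best-params dicts), so it cannot be read back under PyTorch's weights-only mode, whose default flipped to `True` in recent releases. A guarded sketch that tolerates both old and new PyTorch (the helper name is illustrative):

```python
import torch

def load_checkpoint(path: str) -> dict:
    """Load a mixed tensor/object payload across PyTorch versions."""
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch without the weights_only keyword.
        return torch.load(path, map_location="cpu")
```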
+    def _move_to_device(self, model_obj):
+        """Move model to the best available device using shared DeviceManager."""
+        DeviceManager.move_to_device(model_obj)
+
+    def _should_use_distributed_optuna(self) -> bool:
+        if not self.enable_distributed_optuna:
+            return False
+        rank_env = os.environ.get("RANK")
+        world_env = os.environ.get("WORLD_SIZE")
+        local_env = os.environ.get("LOCAL_RANK")
+        if rank_env is None or world_env is None or local_env is None:
+            return False
+        try:
+            world_size = int(world_env)
+        except Exception:
+            return False
+        return world_size > 1
+
+    def _distributed_is_main(self) -> bool:
+        return DistributedUtils.is_main_process()
+
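`_should_use_distributed_optuna` keys off the environment variables that torchrun injects into every worker process; outside such a launcher they are absent and the trainer stays single-process. A standalone probe of the same contract:

```python
import os

def launched_by_torchrun() -> bool:
    """True only when torchrun-style env vars exist and world size exceeds 1."""
    required = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    if any(os.environ.get(k) is None for k in required):
        return False
    try:
        return int(os.environ["WORLD_SIZE"]) > 1
    except ValueError:
        return False
```

Running `torchrun --nproc_per_node=2 script.py` sets all three variables for each spawned process, so the probe returns True there and False under a plain `python script.py`.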
+    def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna() or not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            return
+        message = [payload]
+        dist.broadcast_object_list(message, src=0)
+
+    def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
+        if not self._should_use_distributed_optuna():
+            return
+        if not self._distributed_is_main():
+            return
+        if dist is None:
+            return
+        self._distributed_send_command({"type": "RUN", "params": params})
+        if not dist.is_initialized():
+            return
+        # STEP 2 (DDP/Optuna): make sure all ranks start the trial together.
+        self._dist_barrier("prepare_trial")
+
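The command channel is nothing more than `broadcast_object_list` from rank 0: every rank calls the same collective, rank 0 supplies the payload, the others receive it. A minimal sketch of that handshake (assumes a process group is already initialized, e.g. via torchrun):

```python
import torch.distributed as dist

def send_or_receive_command(payload=None):
    """Rank 0 broadcasts `payload`; every other rank returns the received object."""
    message = [payload if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(message, src=0)
    return message[0]

# rank 0:    send_or_receive_command({"type": "RUN", "params": {"lr": 1e-3}})
# rank > 0:  cmd = send_or_receive_command()
```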
+    def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
+        if dist is None:
+            _log(
+                f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
+                flush=True,
+            )
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            _log(
+                f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
+                flush=True,
+            )
+            return
+        while True:
+            message = [None]
+            dist.broadcast_object_list(message, src=0)
+            payload = message[0]
+            if not isinstance(payload, dict):
+                continue
+            cmd = payload.get("type")
+            if cmd == "STOP":
+                best_params = payload.get("best_params")
+                if best_params is not None:
+                    self.best_params = best_params
+                break
+            if cmd == "RUN":
+                params = payload.get("params") or {}
+                self._distributed_forced_params = params
+                # STEP 2 (DDP/Optuna): align worker with rank0 before running objective_fn.
+                self._dist_barrier("worker_start")
+                try:
+                    objective_fn(None)
+                except optuna.TrialPruned:
+                    pass
+                except Exception as exc:
+                    _log(
+                        f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
                 finally:
-                    self._clean_gpu()
+                    self._clean_gpu(synchronize=self._optuna_cleanup_sync())
                 # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
                 self._dist_barrier("worker_end")
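Together with `_distributed_tune` below, this gives a driver/worker split: rank 0 runs the Optuna loop and broadcasts each sampled parameter set, while the other ranks sit in this loop, replay the same parameters through the objective so DDP collectives line up, and exit on STOP. A stripped-down sketch of the worker side (assumes `init_process_group` has already run; error handling and logging elided):

```python
import torch.distributed as dist

def worker_loop(objective):
    """Replay loop for ranks > 0; `objective` takes a params dict."""
    while True:
        msg = [None]
        dist.broadcast_object_list(msg, src=0)   # blocks until rank 0 sends
        cmd = msg[0] or {}
        if cmd.get("type") == "STOP":
            break
        if cmd.get("type") == "RUN":
            dist.barrier()                        # start the trial in lockstep
            try:
                objective(cmd.get("params") or {})
            except Exception:
                pass                              # rank 0 records the failure
            dist.barrier()                        # end the trial in lockstep
```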
-
-    def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
-        if dist is None:
-            print(
-                f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
-                flush=True,
-            )
-            prev = self.enable_distributed_optuna
-            self.enable_distributed_optuna = False
-            try:
-                self.tune(max_evals, objective_fn)
-            finally:
-                self.enable_distributed_optuna = prev
-            return
-        DistributedUtils.setup_ddp()
-        if not dist.is_initialized():
-            rank_env = os.environ.get("RANK", "0")
-            if str(rank_env) != "0":
-                print(
-                    f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
-                    flush=True,
-                )
-                return
-            print(
-                f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
-                flush=True,
-            )
-            prev = self.enable_distributed_optuna
-            self.enable_distributed_optuna = False
-            try:
-                self.tune(max_evals, objective_fn)
-            finally:
-                self.enable_distributed_optuna = prev
-            return
-        if not self._distributed_is_main():
-            self._distributed_worker_loop(objective_fn)
-            return
-
-        total_trials = max(1, int(max_evals))
-        progress_counter = {"count": 0}
-
-        def objective_wrapper(trial: optuna.trial.Trial) -> float:
-            should_log = True
-            if should_log:
-                current_idx = progress_counter["count"] + 1
-                print(
-                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
-                    f"(trial_id={trial.number})."
-                )
-            try:
-                result = objective_fn(trial)
+
+    def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
+        if dist is None:
+            _log(
+                f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        DistributedUtils.setup_ddp()
+        if not dist.is_initialized():
+            rank_env = os.environ.get("RANK", "0")
+            if str(rank_env) != "0":
+                _log(
+                    f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
+                    flush=True,
+                )
+                return
+            _log(
+                f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
+                flush=True,
+            )
+            prev = self.enable_distributed_optuna
+            self.enable_distributed_optuna = False
+            try:
+                self.tune(max_evals, objective_fn)
+            finally:
+                self.enable_distributed_optuna = prev
+            return
+        if not self._distributed_is_main():
+            self._distributed_worker_loop(objective_fn)
+            return
+
+        total_trials = max(1, int(max_evals))
+        progress_counter = {"count": 0}
+
+        def objective_wrapper(trial: optuna.trial.Trial) -> float:
+            should_log = True
+            if should_log:
+                current_idx = progress_counter["count"] + 1
+                _log(
+                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                    f"(trial_id={trial.number})."
+                )
+            try:
+                result = objective_fn(trial)
             except RuntimeError as exc:
                 if "out of memory" in str(exc).lower():
-                    print(
+                    _log(
                         f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
                     )
-                    self._clean_gpu()
+                    self._clean_gpu(synchronize=True)
                     raise optuna.TrialPruned() from exc
                 raise
             finally:
-                self._clean_gpu()
-                if should_log:
-                    progress_counter["count"] = progress_counter["count"] + 1
-                    trial_state = getattr(trial, "state", None)
-                    state_repr = getattr(trial_state, "name", "OK")
-                    print(
-                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
-                        f"(status={state_repr})."
-                    )
-                # STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
-                self._dist_barrier("trial_end")
-            return result
-
-        storage_url = self._resolve_optuna_storage_url()
-        study_name = self._resolve_optuna_study_name()
-        study_kwargs: Dict[str, Any] = {
-            "direction": "minimize",
-            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
-        }
-        if storage_url:
-            study_kwargs.update(
-                storage=storage_url,
-                study_name=study_name,
-                load_if_exists=True,
-            )
-        study = optuna.create_study(**study_kwargs)
-        self.study_name = getattr(study, "study_name", None)
-
-        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
-            try:
-                best = getattr(check_study, "best_trial", None)
-                if best is None:
-                    return
-                best_params = getattr(best, "params", None)
-                if not best_params:
-                    return
-                params_path = self.output.result_path(
-                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
-                )
-                pd.DataFrame(best_params, index=[0]).to_csv(
-                    params_path, index=False)
-            except Exception:
-                return
-
-        completed_states = (
-            optuna.trial.TrialState.COMPLETE,
-            optuna.trial.TrialState.PRUNED,
-            optuna.trial.TrialState.FAIL,
-        )
-        completed = len(study.get_trials(states=completed_states))
-        progress_counter["count"] = completed
-        remaining = max(0, total_trials - completed)
-        try:
-            if remaining > 0:
-                study.optimize(
-                    objective_wrapper,
-                    n_trials=remaining,
-                    callbacks=[checkpoint_callback],
-                )
-            self.best_params = study.best_params
-            self.best_trial = study.best_trial
-            params_path = self.output.result_path(
-                f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
-            )
-            pd.DataFrame(self.best_params, index=[0]).to_csv(
-                params_path, index=False)
-        finally:
-            self._distributed_send_command(
-                {"type": "STOP", "best_params": self.best_params})
-
-    def _clean_gpu(self):
-        """Clean up GPU memory using shared GPUMemoryManager."""
-        GPUMemoryManager.clean()
-
-    def _standardize_fold(self,
-                          X_train: pd.DataFrame,
-                          X_val: pd.DataFrame,
-                          columns: Optional[List[str]] = None
-                          ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
-        """Fit StandardScaler on the training fold and transform train/val features.
-
-        Args:
-            X_train: training features.
-            X_val: validation features.
-            columns: columns to scale (default: all).
-
-        Returns:
-            Scaled train/val features and the fitted scaler.
-        """
-        scaler = StandardScaler()
-        cols = list(columns) if columns else list(X_train.columns)
-        X_train_scaled = X_train.copy(deep=True)
-        X_val_scaled = X_val.copy(deep=True)
-        if cols:
-            scaler.fit(X_train_scaled[cols])
-            X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
-            X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
-        return X_train_scaled, X_val_scaled, scaler
-
-    def _resolve_train_val_indices(
-        self,
-        X_all: pd.DataFrame,
-        *,
-        allow_default: bool = False,
-    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
-        """Resolve train/validation split indices based on configured CV strategy.
-
-        Args:
-            X_all: DataFrame to split
-            allow_default: If True, use default val_ratio when config is invalid
-
-        Returns:
-            Tuple of (train_indices, val_indices) or None if not enough data
-        """
-        val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
-        if not (0.0 < val_ratio < 1.0):
-            if not allow_default:
-                return None
-            val_ratio = 0.25
-        if len(X_all) < 10:
-            return None
-
-        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
-        (train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
-        return train_idx, val_idx
-
-    def _resolve_time_sample_indices(
-        self,
-        X_all: pd.DataFrame,
-        sample_limit: int,
-    ) -> Optional[pd.Index]:
-        """Get the most recent indices for time-based sampling.
-
-        For time-based CV strategies, returns the last `sample_limit` indices
-        ordered by time. For other strategies, returns None.
-
-        Args:
-            X_all: DataFrame to sample from
-            sample_limit: Maximum number of samples to return
-
-        Returns:
-            Index of sampled rows, or None if not using time-based strategy
-        """
-        if sample_limit <= 0:
-            return None
-
-        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
-        if not resolver.is_time_strategy():
-            return None
-
-        order = resolver.get_time_ordered_indices(X_all)
-        if len(order) == 0:
-            return None
-
-        # Get the last sample_limit indices (most recent in time)
-        if len(order) > sample_limit:
-            order = order[-sample_limit:]
-
-        return X_all.index[order]
-
-    def _resolve_ensemble_splits(
+                self._clean_gpu(synchronize=self._optuna_cleanup_sync())
+                if should_log:
+                    progress_counter["count"] = progress_counter["count"] + 1
+                    trial_state = getattr(trial, "state", None)
+                    state_repr = getattr(trial_state, "name", "OK")
+                    _log(
+                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                        f"(status={state_repr})."
+                    )
+                # STEP 2 (DDP/Optuna): a trial-end sync point; debug with BAYESOPT_DDP_BARRIER_DEBUG=1.
+                self._dist_barrier("trial_end")
+            return result
+
+        storage_url = self._resolve_optuna_storage_url()
+        study_name = self._resolve_optuna_study_name()
+        study_kwargs: Dict[str, Any] = {
+            "direction": "minimize",
+            "sampler": optuna.samplers.TPESampler(seed=self.ctx.rand_seed),
+        }
+        if storage_url:
+            study_kwargs.update(
+                storage=storage_url,
+                study_name=study_name,
+                load_if_exists=True,
+            )
+        study = optuna.create_study(**study_kwargs)
+        self.study_name = getattr(study, "study_name", None)
+
+        def checkpoint_callback(check_study: optuna.study.Study, _trial) -> None:
+            try:
+                best = getattr(check_study, "best_trial", None)
+                if best is None:
+                    return
+                best_params = getattr(best, "params", None)
+                if not best_params:
+                    return
+                params_path = self.output.result_path(
+                    f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+                )
+                pd.DataFrame(best_params, index=[0]).to_csv(
+                    params_path, index=False)
+            except Exception:
+                return
+
+        completed_states = (
+            optuna.trial.TrialState.COMPLETE,
+            optuna.trial.TrialState.PRUNED,
+            optuna.trial.TrialState.FAIL,
+        )
+        completed = len(study.get_trials(states=completed_states))
+        progress_counter["count"] = completed
+        remaining = max(0, total_trials - completed)
+        try:
+            if remaining > 0:
+                study.optimize(
+                    objective_wrapper,
+                    n_trials=remaining,
+                    callbacks=[checkpoint_callback],
+                )
+            self.best_params = study.best_params
+            self.best_trial = study.best_trial
+            params_path = self.output.result_path(
+                f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+            )
+            pd.DataFrame(self.best_params, index=[0]).to_csv(
+                params_path, index=False)
+        finally:
+            self._distributed_send_command(
+                {"type": "STOP", "best_params": self.best_params})
+
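The per-trial `checkpoint_callback` means a crash never loses the incumbent: Optuna invokes every callback after each trial, so the latest best parameters are already on disk. A standalone sketch of the same idea (the JSON filename is illustrative):

```python
import json
import optuna

def save_best_callback(study: optuna.study.Study,
                       trial: optuna.trial.FrozenTrial) -> None:
    # Runs after every trial; tolerate "no completed trial yet".
    try:
        best = study.best_trial
    except ValueError:
        return
    with open("best_params.json", "w") as fh:
        json.dump(best.params, fh)

study = optuna.create_study(direction="minimize")
study.optimize(lambda t: (t.suggest_float("x", -5, 5) - 1) ** 2,
               n_trials=5, callbacks=[save_best_callback])
```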
+    def _clean_gpu(
         self,
-        X_all: pd.DataFrame,
         *,
-        k: int,
-    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
-        """Resolve K-fold splits for ensemble training based on configured CV strategy.
-
-        Args:
-            X_all: DataFrame to split
-            k: Number of folds requested
-
-        Returns:
-            Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
-        """
-        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
-        return resolver.create_kfold_splitter(X_all, k)
-
-    def cross_val_generic(
-        self,
-        trial: optuna.trial.Trial,
-        hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
-        data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
-        model_builder: Callable[[Dict[str, Any]], Any],
-        metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
-        sample_limit: Optional[int] = None,
-        preprocess_fn: Optional[Callable[[
-            pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
-        fit_predict_fn: Optional[
-            Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
-                      pd.DataFrame, pd.Series, Optional[pd.Series],
-                      optuna.trial.Trial], np.ndarray]
-        ] = None,
-        cleanup_fn: Optional[Callable[[Any], None]] = None,
-        splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
-        """Generic holdout/CV helper to reuse tuning workflows.
-
-        Args:
-            trial: current Optuna trial.
-            hyperparameter_space: sampler dict keyed by parameter name.
-            data_provider: callback returning (X, y, sample_weight).
-            model_builder: callback to build a model per fold.
-            metric_fn: loss/score function taking y_true, y_pred, weight.
-            sample_limit: optional sample cap; random sample if exceeded.
-            preprocess_fn: optional per-fold preprocessing (X_train, X_val).
-            fit_predict_fn: optional custom fit/predict logic for validation.
-            cleanup_fn: optional cleanup callback per fold.
-            splitter: optional (train_idx, val_idx) iterator; defaults to cv_strategy config.
-
-        Returns:
-            Mean validation metric across folds.
-        """
-        params: Optional[Dict[str, Any]] = None
-        if self._distributed_forced_params is not None:
-            params = self._distributed_forced_params
-            self._distributed_forced_params = None
-        else:
-            if trial is None:
-                raise RuntimeError(
-                    "Missing Optuna trial for parameter sampling.")
-            params = {name: sampler(trial)
-                      for name, sampler in hyperparameter_space.items()}
-        if self._should_use_distributed_optuna():
-            self._distributed_prepare_trial(params)
-        X_all, y_all, w_all = data_provider()
-        cfg_limit = getattr(self.ctx.config, "bo_sample_limit", None)
-        if cfg_limit is not None:
-            cfg_limit = int(cfg_limit)
-            if cfg_limit > 0:
-                sample_limit = cfg_limit if sample_limit is None else min(sample_limit, cfg_limit)
-        if sample_limit is not None and len(X_all) > sample_limit:
-            sampled_idx = self._resolve_time_sample_indices(X_all, int(sample_limit))
-            if sampled_idx is None:
-                sampled_idx = X_all.sample(
-                    n=sample_limit,
-                    random_state=self.ctx.rand_seed
-                ).index
-            X_all = X_all.loc[sampled_idx]
-            y_all = y_all.loc[sampled_idx]
-            w_all = w_all.loc[sampled_idx] if w_all is not None else None
-
-        if splitter is None:
-            val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
-            if not (0.0 < val_ratio < 1.0):
-                val_ratio = 0.25
-            cv_splits = getattr(self.ctx.config, "cv_splits", None)
-            if cv_splits is None:
-                cv_splits = max(2, int(round(1 / val_ratio)))
-            cv_splits = max(2, int(cv_splits))
-
-            resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
-            split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
-            if actual_splits < 2:
-                raise ValueError("Not enough samples for cross-validation.")
-        else:
-            if hasattr(splitter, "split"):
-                split_iter = splitter.split(X_all, y_all, groups=None)
-            else:
-                split_iter = splitter
-
-        losses: List[float] = []
-        for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
-            X_train = X_all.iloc[train_idx]
-            y_train = y_all.iloc[train_idx]
-            X_val = X_all.iloc[val_idx]
-            y_val = y_all.iloc[val_idx]
-            w_train = w_all.iloc[train_idx] if w_all is not None else None
-            w_val = w_all.iloc[val_idx] if w_all is not None else None
-
-            if preprocess_fn:
-                X_train, X_val = preprocess_fn(X_train, X_val)
-
-            model = model_builder(params)
-            try:
-                if fit_predict_fn:
-                    # Avoid duplicate Optuna step reports across folds.
-                    trial_for_fold = trial if fold_idx == 0 else None
-                    y_pred = fit_predict_fn(
-                        model, X_train, y_train, w_train,
-                        X_val, y_val, w_val, trial_for_fold
-                    )
-                else:
-                    fit_kwargs = {}
-                    if w_train is not None:
-                        fit_kwargs["sample_weight"] = w_train
-                    model.fit(X_train, y_train, **fit_kwargs)
-                    y_pred = model.predict(X_val)
-                losses.append(metric_fn(y_val, y_pred, w_val))
-            finally:
-                if cleanup_fn:
-                    cleanup_fn(model)
-                self._clean_gpu()
-
-        return float(np.mean(losses))
-
-    # Prediction + caching logic.
-    def _predict_and_cache(self,
-                           model,
-                           pred_prefix: str,
-                           use_oht: bool = False,
-                           design_fn=None,
-                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
-                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
-                           predict_fn: Optional[Callable[..., Any]] = None) -> None:
-        if design_fn:
-            X_train = design_fn(train=True)
-            X_test = design_fn(train=False)
-        elif use_oht:
-            X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
-            X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
-        else:
-            X_train = self.ctx.train_data[self.ctx.factor_nmes]
-            X_test = self.ctx.test_data[self.ctx.factor_nmes]
-
-        predictor = predict_fn or model.predict
-        preds_train = predictor(X_train, **(predict_kwargs_train or {}))
-        preds_test = predictor(X_test, **(predict_kwargs_test or {}))
-        preds_train = np.asarray(preds_train)
-        preds_test = np.asarray(preds_test)
-
-        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
-            col_name = f'pred_{pred_prefix}'
-            self.ctx.train_data[col_name] = preds_train.reshape(-1)
-            self.ctx.test_data[col_name] = preds_test.reshape(-1)
-            self.ctx.train_data[f'w_{col_name}'] = (
-                self.ctx.train_data[col_name] *
-                self.ctx.train_data[self.ctx.weight_nme]
-            )
-            self.ctx.test_data[f'w_{col_name}'] = (
-                self.ctx.test_data[col_name] *
-                self.ctx.test_data[self.ctx.weight_nme]
-            )
-            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
-            return
-
-        # Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
-        if preds_train.ndim != 2:
-            raise ValueError(
-                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
-        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
-            raise ValueError(
-                f"Train/test prediction dims mismatch for '{pred_prefix}': "
-                f"{preds_train.shape} vs {preds_test.shape}")
-        for j in range(preds_train.shape[1]):
-            col_name = f'pred_{pred_prefix}_{j}'
-            self.ctx.train_data[col_name] = preds_train[:, j]
-            self.ctx.test_data[col_name] = preds_test[:, j]
-        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
-
-    def _cache_predictions(self,
-                           pred_prefix: str,
-                           preds_train,
-                           preds_test) -> None:
-        preds_train = np.asarray(preds_train)
-        preds_test = np.asarray(preds_test)
-        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
-            if preds_test.ndim > 1:
-                preds_test = preds_test.reshape(-1)
-            col_name = f'pred_{pred_prefix}'
-            self.ctx.train_data[col_name] = preds_train.reshape(-1)
-            self.ctx.test_data[col_name] = preds_test.reshape(-1)
-            self.ctx.train_data[f'w_{col_name}'] = (
-                self.ctx.train_data[col_name] *
-                self.ctx.train_data[self.ctx.weight_nme]
-            )
-            self.ctx.test_data[f'w_{col_name}'] = (
-                self.ctx.test_data[col_name] *
-                self.ctx.test_data[self.ctx.weight_nme]
-            )
-            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
-            return
-
-        if preds_train.ndim != 2:
-            raise ValueError(
-                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
-        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
-            raise ValueError(
-                f"Train/test prediction dims mismatch for '{pred_prefix}': "
-                f"{preds_train.shape} vs {preds_test.shape}")
-        for j in range(preds_train.shape[1]):
-            col_name = f'pred_{pred_prefix}_{j}'
-            self.ctx.train_data[col_name] = preds_train[:, j]
-            self.ctx.test_data[col_name] = preds_test[:, j]
-        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
-
-    def _maybe_cache_predictions(self, pred_prefix: str, preds_train, preds_test) -> None:
-        cfg = getattr(self.ctx, "config", None)
-        if cfg is None or not bool(getattr(cfg, "cache_predictions", False)):
-            return
-        fmt = str(getattr(cfg, "prediction_cache_format", "parquet") or "parquet").lower()
-        cache_dir = getattr(cfg, "prediction_cache_dir", None)
-        if cache_dir:
-            target_dir = Path(str(cache_dir))
-            if not target_dir.is_absolute():
-                target_dir = Path(self.output.result_dir) / target_dir
-        else:
-            target_dir = Path(self.output.result_dir) / "predictions"
-        target_dir.mkdir(parents=True, exist_ok=True)
-
-        def _build_frame(preds, split_label: str) -> pd.DataFrame:
-            arr = np.asarray(preds)
-            if arr.ndim <= 1:
-                return pd.DataFrame({f"pred_{pred_prefix}": arr.reshape(-1)})
-            cols = [f"pred_{pred_prefix}_{i}" for i in range(arr.shape[1])]
-            return pd.DataFrame(arr, columns=cols)
-
-        for split_label, preds in [("train", preds_train), ("test", preds_test)]:
-            frame = _build_frame(preds, split_label)
-            filename = f"{self.ctx.model_nme}_{pred_prefix}_{split_label}.{ 'csv' if fmt == 'csv' else 'parquet' }"
-            path = target_dir / filename
-            try:
-                if fmt == "csv":
-                    frame.to_csv(path, index=False)
-                else:
-                    frame.to_parquet(path, index=False)
-            except Exception:
-                pass
-
-    def _resolve_best_epoch(self,
-                            history: Optional[Dict[str, List[float]]],
-                            default_epochs: int) -> int:
-        if not history:
-            return max(1, int(default_epochs))
-        vals = history.get("val") or []
-        if not vals:
-            return max(1, int(default_epochs))
-        best_idx = int(np.nanargmin(vals))
-        return max(1, best_idx + 1)
-
-    def _fit_predict_cache(self,
-                           model,
-                           X_train,
-                           y_train,
-                           sample_weight,
-                           pred_prefix: str,
-                           use_oht: bool = False,
-                           design_fn=None,
-                           fit_kwargs: Optional[Dict[str, Any]] = None,
-                           sample_weight_arg: Optional[str] = 'sample_weight',
-                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
-                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
-                           predict_fn: Optional[Callable[..., Any]] = None,
-                           record_label: bool = True) -> None:
-        fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
-        if sample_weight is not None and sample_weight_arg:
-            fit_kwargs.setdefault(sample_weight_arg, sample_weight)
-        model.fit(X_train, y_train, **fit_kwargs)
-        if record_label:
-            self.ctx.model_label.append(self.label)
-        self._predict_and_cache(
-            model,
-            pred_prefix,
-            use_oht=use_oht,
-            design_fn=design_fn,
-            predict_kwargs_train=predict_kwargs_train,
-            predict_kwargs_test=predict_kwargs_test,
-            predict_fn=predict_fn)
+        synchronize: bool = True,
+        empty_cache: bool = True,
+    ) -> None:
+        """Clean up GPU memory using shared GPUMemoryManager."""
+        GPUMemoryManager.clean(synchronize=synchronize, empty_cache=empty_cache)
+
+    def _standardize_fold(self,
+                          X_train: pd.DataFrame,
+                          X_val: pd.DataFrame,
+                          columns: Optional[List[str]] = None
+                          ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
+        """Fit StandardScaler on the training fold and transform train/val features.
+
+        Args:
+            X_train: training features.
+            X_val: validation features.
+            columns: columns to scale (default: all).
+
+        Returns:
+            Scaled train/val features and the fitted scaler.
+        """
+        scaler = StandardScaler()
+        cols = list(columns) if columns else list(X_train.columns)
+        X_train_scaled = X_train.copy(deep=True)
+        X_val_scaled = X_val.copy(deep=True)
+        if cols:
+            scaler.fit(X_train_scaled[cols])
+            X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
+            X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
+        return X_train_scaled, X_val_scaled, scaler
+
+    def _resolve_train_val_indices(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        allow_default: bool = False,
+    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
+        """Resolve train/validation split indices based on configured CV strategy.
+
+        Args:
+            X_all: DataFrame to split
+            allow_default: If True, use default val_ratio when config is invalid
+
+        Returns:
+            Tuple of (train_indices, val_indices) or None if not enough data
+        """
+        val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+        if not (0.0 < val_ratio < 1.0):
+            if not allow_default:
+                return None
+            val_ratio = 0.25
+        if len(X_all) < 10:
+            return None
+
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        (train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
+        return train_idx, val_idx
+
+    def _resolve_time_sample_indices(
+        self,
+        X_all: pd.DataFrame,
+        sample_limit: int,
+    ) -> Optional[pd.Index]:
+        """Get the most recent indices for time-based sampling.
+
+        For time-based CV strategies, returns the last `sample_limit` indices
+        ordered by time. For other strategies, returns None.
+
+        Args:
+            X_all: DataFrame to sample from
+            sample_limit: Maximum number of samples to return
+
+        Returns:
+            Index of sampled rows, or None if not using time-based strategy
+        """
+        if sample_limit <= 0:
+            return None
+
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        if not resolver.is_time_strategy():
+            return None
+
+        order = resolver.get_time_ordered_indices(X_all)
+        if len(order) == 0:
+            return None
+
+        # Get the last sample_limit indices (most recent in time)
+        if len(order) > sample_limit:
+            order = order[-sample_limit:]
+
+        return X_all.index[order]
+
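For time-ordered data, capping the tuning sample by taking the most recent rows (instead of a uniform random draw) keeps the subsample representative of current experience. The equivalent standalone operation, assuming a hypothetical `date` column:

```python
import pandas as pd

def most_recent_tail(df: pd.DataFrame, date_col: str, limit: int) -> pd.DataFrame:
    # Stable sort preserves within-day order; keep the last `limit` rows.
    order = df[date_col].argsort(kind="stable")
    return df.iloc[order[-limit:]]
```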
+    def _resolve_ensemble_splits(
+        self,
+        X_all: pd.DataFrame,
+        *,
+        k: int,
+    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+        """Resolve K-fold splits for ensemble training based on configured CV strategy.
+
+        Args:
+            X_all: DataFrame to split
+            k: Number of folds requested
+
+        Returns:
+            Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+        """
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        return resolver.create_kfold_splitter(X_all, k)
+
+    def cross_val_generic(
+        self,
+        trial: optuna.trial.Trial,
+        hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
+        data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
+        model_builder: Callable[[Dict[str, Any]], Any],
+        metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
+        sample_limit: Optional[int] = None,
+        preprocess_fn: Optional[Callable[[
+            pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
+        fit_predict_fn: Optional[
+            Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
+                      pd.DataFrame, pd.Series, Optional[pd.Series],
+                      optuna.trial.Trial], np.ndarray]
+        ] = None,
+        cleanup_fn: Optional[Callable[[Any], None]] = None,
+        splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
+        """Generic holdout/CV helper to reuse tuning workflows.
+
+        Args:
+            trial: current Optuna trial.
+            hyperparameter_space: sampler dict keyed by parameter name.
+            data_provider: callback returning (X, y, sample_weight).
+            model_builder: callback to build a model per fold.
+            metric_fn: loss/score function taking y_true, y_pred, weight.
+            sample_limit: optional sample cap; random sample if exceeded.
+            preprocess_fn: optional per-fold preprocessing (X_train, X_val).
+            fit_predict_fn: optional custom fit/predict logic for validation.
+            cleanup_fn: optional cleanup callback per fold.
+            splitter: optional (train_idx, val_idx) iterator; defaults to cv_strategy config.
+
+        Returns:
+            Mean validation metric across folds.
+        """
+        params: Optional[Dict[str, Any]] = None
+        if self._distributed_forced_params is not None:
+            params = self._distributed_forced_params
+            self._distributed_forced_params = None
+        else:
+            if trial is None:
+                raise RuntimeError(
+                    "Missing Optuna trial for parameter sampling.")
+            params = {name: sampler(trial)
+                      for name, sampler in hyperparameter_space.items()}
+        if self._should_use_distributed_optuna():
+            self._distributed_prepare_trial(params)
+        X_all, y_all, w_all = data_provider()
+        cfg_limit = getattr(self.ctx.config, "bo_sample_limit", None)
+        if cfg_limit is not None:
+            cfg_limit = int(cfg_limit)
+            if cfg_limit > 0:
+                sample_limit = cfg_limit if sample_limit is None else min(sample_limit, cfg_limit)
+        if sample_limit is not None and len(X_all) > sample_limit:
+            sampled_idx = self._resolve_time_sample_indices(X_all, int(sample_limit))
+            if sampled_idx is None:
+                sampled_idx = X_all.sample(
+                    n=sample_limit,
+                    random_state=self.ctx.rand_seed
+                ).index
+            X_all = X_all.loc[sampled_idx]
+            y_all = y_all.loc[sampled_idx]
+            w_all = w_all.loc[sampled_idx] if w_all is not None else None
+
+        if splitter is None:
+            val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
+            if not (0.0 < val_ratio < 1.0):
+                val_ratio = 0.25
+            cv_splits = getattr(self.ctx.config, "cv_splits", None)
+            if cv_splits is None:
+                cv_splits = max(2, int(round(1 / val_ratio)))
+            cv_splits = max(2, int(cv_splits))
+
+            resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+            split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
+            if actual_splits < 2:
+                raise ValueError("Not enough samples for cross-validation.")
+        else:
+            if hasattr(splitter, "split"):
+                split_iter = splitter.split(X_all, y_all, groups=None)
+            else:
+                split_iter = splitter
+
+        losses: List[float] = []
+        for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
+            X_train = X_all.iloc[train_idx]
+            y_train = y_all.iloc[train_idx]
+            X_val = X_all.iloc[val_idx]
+            y_val = y_all.iloc[val_idx]
+            w_train = w_all.iloc[train_idx] if w_all is not None else None
+            w_val = w_all.iloc[val_idx] if w_all is not None else None
+
+            if preprocess_fn:
+                X_train, X_val = preprocess_fn(X_train, X_val)
+
+            model = model_builder(params)
+            try:
+                if fit_predict_fn:
+                    # Avoid duplicate Optuna step reports across folds.
+                    trial_for_fold = trial if fold_idx == 0 else None
+                    y_pred = fit_predict_fn(
+                        model, X_train, y_train, w_train,
+                        X_val, y_val, w_val, trial_for_fold
+                    )
+                else:
+                    fit_kwargs = {}
+                    if w_train is not None:
+                        fit_kwargs["sample_weight"] = w_train
+                    model.fit(X_train, y_train, **fit_kwargs)
+                    y_pred = model.predict(X_val)
+                losses.append(metric_fn(y_val, y_pred, w_val))
+            finally:
+                if cleanup_fn:
+                    cleanup_fn(model)
+                self._clean_gpu()
+
+        return float(np.mean(losses))
+
+    # Prediction + caching logic.
+    def _predict_and_cache(self,
+                           model,
+                           pred_prefix: str,
+                           use_oht: bool = False,
+                           design_fn=None,
+                           predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                           predict_kwargs_test: Optional[Dict[str, Any]] = None,
+                           predict_fn: Optional[Callable[..., Any]] = None) -> None:
+        if design_fn:
+            X_train = design_fn(train=True)
+            X_test = design_fn(train=False)
+        elif use_oht:
+            X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
+            X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
+        else:
+            X_train = self.ctx.train_data[self.ctx.factor_nmes]
+            X_test = self.ctx.test_data[self.ctx.factor_nmes]
+
+        predictor = predict_fn or model.predict
+        preds_train = predictor(X_train, **(predict_kwargs_train or {}))
+        preds_test = predictor(X_test, **(predict_kwargs_test or {}))
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        # Vector outputs (e.g., embeddings) are expanded into pred_<prefix>_0.. columns.
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
+    def _cache_predictions(self,
+                           pred_prefix: str,
+                           preds_train,
+                           preds_test) -> None:
+        preds_train = np.asarray(preds_train)
+        preds_test = np.asarray(preds_test)
+        if preds_train.ndim <= 1 or (preds_train.ndim == 2 and preds_train.shape[1] == 1):
+            if preds_test.ndim > 1:
+                preds_test = preds_test.reshape(-1)
+            col_name = f'pred_{pred_prefix}'
+            self.ctx.train_data[col_name] = preds_train.reshape(-1)
+            self.ctx.test_data[col_name] = preds_test.reshape(-1)
+            self.ctx.train_data[f'w_{col_name}'] = (
+                self.ctx.train_data[col_name] *
+                self.ctx.train_data[self.ctx.weight_nme]
+            )
+            self.ctx.test_data[f'w_{col_name}'] = (
+                self.ctx.test_data[col_name] *
+                self.ctx.test_data[self.ctx.weight_nme]
+            )
+            self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+            return
+
+        if preds_train.ndim != 2:
+            raise ValueError(
+                f"Unexpected prediction shape for '{pred_prefix}': {preds_train.shape}")
+        if preds_test.ndim != 2 or preds_test.shape[1] != preds_train.shape[1]:
+            raise ValueError(
+                f"Train/test prediction dims mismatch for '{pred_prefix}': "
+                f"{preds_train.shape} vs {preds_test.shape}")
+        for j in range(preds_train.shape[1]):
+            col_name = f'pred_{pred_prefix}_{j}'
+            self.ctx.train_data[col_name] = preds_train[:, j]
+            self.ctx.test_data[col_name] = preds_test[:, j]
+        self._maybe_cache_predictions(pred_prefix, preds_train, preds_test)
+
+    def _maybe_cache_predictions(self, pred_prefix: str, preds_train, preds_test) -> None:
+        cfg = getattr(self.ctx, "config", None)
+        if cfg is None or not bool(getattr(cfg, "cache_predictions", False)):
+            return
+        fmt = str(getattr(cfg, "prediction_cache_format", "parquet") or "parquet").lower()
+        cache_dir = getattr(cfg, "prediction_cache_dir", None)
+        if cache_dir:
+            target_dir = Path(str(cache_dir))
+            if not target_dir.is_absolute():
+                target_dir = Path(self.output.result_dir) / target_dir
+        else:
+            target_dir = Path(self.output.result_dir) / "predictions"
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        def _build_frame(preds, split_label: str) -> pd.DataFrame:
+            arr = np.asarray(preds)
+            if arr.ndim <= 1:
+                return pd.DataFrame({f"pred_{pred_prefix}": arr.reshape(-1)})
+            cols = [f"pred_{pred_prefix}_{i}" for i in range(arr.shape[1])]
+            return pd.DataFrame(arr, columns=cols)
+
+        for split_label, preds in [("train", preds_train), ("test", preds_test)]:
+            frame = _build_frame(preds, split_label)
+            filename = f"{self.ctx.model_nme}_{pred_prefix}_{split_label}.{ 'csv' if fmt == 'csv' else 'parquet' }"
+            path = target_dir / filename
+            try:
+                if fmt == "csv":
+                    frame.to_csv(path, index=False)
+                else:
+                    frame.to_parquet(path, index=False)
+            except Exception:
+                pass
+
+    def _resolve_best_epoch(self,
+                            history: Optional[Dict[str, List[float]]],
+                            default_epochs: int) -> int:
+        if not history:
+            return max(1, int(default_epochs))
+        vals = history.get("val") or []
+        if not vals:
+            return max(1, int(default_epochs))
+        best_idx = int(np.nanargmin(vals))
+        return max(1, best_idx + 1)
+
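`np.nanargmin` on the validation curve picks the best epoch while ignoring NaN entries (e.g. diverged epochs), and the `+ 1` converts the 0-based index into an epoch count. For example:

```python
import numpy as np

history = {"val": [0.92, 0.74, float("nan"), 0.71, 0.78]}
best_idx = int(np.nanargmin(history["val"]))  # -> 3 (the 0.71 entry)
best_epoch = max(1, best_idx + 1)             # -> retrain for 4 epochs
```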
+
1293
+ def _fit_predict_cache(self,
1294
+ model,
1295
+ X_train,
1296
+ y_train,
1297
+ sample_weight,
1298
+ pred_prefix: str,
1299
+ use_oht: bool = False,
1300
+ design_fn=None,
1301
+ fit_kwargs: Optional[Dict[str, Any]] = None,
1302
+ sample_weight_arg: Optional[str] = 'sample_weight',
1303
+ predict_kwargs_train: Optional[Dict[str, Any]] = None,
1304
+ predict_kwargs_test: Optional[Dict[str, Any]] = None,
1305
+ predict_fn: Optional[Callable[..., Any]] = None,
1306
+ record_label: bool = True) -> None:
1307
+ fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
1308
+ if sample_weight is not None and sample_weight_arg:
1309
+ fit_kwargs.setdefault(sample_weight_arg, sample_weight)
1310
+ model.fit(X_train, y_train, **fit_kwargs)
1311
+ if record_label:
1312
+ self.ctx.model_label.append(self.label)
1313
+ self._predict_and_cache(
1314
+ model,
1315
+ pred_prefix,
1316
+ use_oht=use_oht,
1317
+ design_fn=design_fn,
1318
+ predict_kwargs_train=predict_kwargs_train,
1319
+ predict_kwargs_test=predict_kwargs_test,
1320
+ predict_fn=predict_fn)