ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
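
Most of this release is a package re-layout: `ins_pricing/modelling/core/bayesopt/` moves up to `ins_pricing/modelling/bayesopt/`, and `production/predict.py` is renamed to `production/inference.py`. A minimal sketch of how downstream imports would change, assuming the public import paths simply mirror the file moves listed above (the `BayesOptModel` re-export from the package `__init__` is an assumption, not something this diff confirms):

# Hypothetical before/after import update for the 0.4.4 -> 0.5.0 re-layout.
# Module paths follow the renames in the file list; the imported names are assumptions.

# 0.4.4 layout
# from ins_pricing.modelling.core.bayesopt import BayesOptModel
# from ins_pricing.production import predict

# 0.5.0 layout
from ins_pricing.modelling.bayesopt import BayesOptModel   # was modelling.core.bayesopt
from ins_pricing.production import inference               # was production.predict
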
@@ -1,1442 +1,1444 @@
1
- """
2
- CLI entry point generated from BayesOpt_AutoPricing.ipynb so the workflow can
3
- run non‑interactively (e.g., via torchrun).
4
-
5
- Example:
6
- python -m torch.distributed.run --standalone --nproc_per_node=2 \\
7
- ins_pricing/cli/BayesOpt_entry.py \\
8
- --config-json ins_pricing/examples/modelling/config_template.json \\
9
- --model-keys ft --max-evals 50 --use-ft-ddp
10
- """
11
-
12
- from __future__ import annotations
13
-
1
+ """
2
+ CLI entry point generated from BayesOpt_AutoPricing.ipynb so the workflow can
3
+ run non‑interactively (e.g., via torchrun).
4
+
5
+ Example:
6
+ python -m torch.distributed.run --standalone --nproc_per_node=2 \\
7
+ ins_pricing/cli/BayesOpt_entry.py \\
8
+ --config-json examples/config_template.json \\
9
+ --model-keys ft --max-evals 50 --use-ft-ddp
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
14
  from pathlib import Path
15
+ import importlib.util
15
16
  import sys
16
17
 
17
18
  if __package__ in {None, ""}:
18
- repo_root = Path(__file__).resolve().parents[2]
19
- if str(repo_root) not in sys.path:
20
- sys.path.insert(0, str(repo_root))
21
-
22
- import argparse
23
- import hashlib
24
- import json
25
- import os
26
- from datetime import datetime
27
- from typing import Any, Dict, List, Optional
28
-
29
- import numpy as np
30
- import pandas as pd
31
-
32
- # Use unified import resolver to eliminate nested try/except chains
33
- from .utils.import_resolver import resolve_imports, setup_sys_path
34
- from .utils.evaluation_context import (
35
- EvaluationContext,
36
- TrainingContext,
37
- ModelIdentity,
38
- DataFingerprint,
39
- CalibrationConfig,
40
- ThresholdConfig,
41
- BootstrapConfig,
42
- ReportConfig,
43
- RegistryConfig,
44
- )
45
-
46
- # Resolve all imports from a single location
47
- setup_sys_path()
48
- _imports = resolve_imports()
49
-
50
- ropt = _imports.bayesopt
51
- PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
52
- PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
53
- build_model_names = _imports.build_model_names
54
- dedupe_preserve_order = _imports.dedupe_preserve_order
55
- load_dataset = _imports.load_dataset
56
- parse_model_pairs = _imports.parse_model_pairs
57
- resolve_data_path = _imports.resolve_data_path
58
- resolve_path = _imports.resolve_path
59
- fingerprint_file = _imports.fingerprint_file
60
- coerce_dataset_types = _imports.coerce_dataset_types
61
- split_train_test = _imports.split_train_test
62
-
63
- add_config_json_arg = _imports.add_config_json_arg
64
- add_output_dir_arg = _imports.add_output_dir_arg
65
- resolve_and_load_config = _imports.resolve_and_load_config
66
- resolve_data_config = _imports.resolve_data_config
67
- resolve_report_config = _imports.resolve_report_config
68
- resolve_split_config = _imports.resolve_split_config
69
- resolve_runtime_config = _imports.resolve_runtime_config
70
- resolve_output_dirs = _imports.resolve_output_dirs
71
-
72
- bootstrap_ci = _imports.bootstrap_ci
73
- calibrate_predictions = _imports.calibrate_predictions
74
- eval_metrics_report = _imports.metrics_report
75
- select_threshold = _imports.select_threshold
76
-
77
- ModelArtifact = _imports.ModelArtifact
78
- ModelRegistry = _imports.ModelRegistry
79
- drift_psi_report = _imports.drift_psi_report
80
- group_metrics = _imports.group_metrics
81
- ReportPayload = _imports.ReportPayload
82
- write_report = _imports.write_report
83
-
84
- configure_run_logging = _imports.configure_run_logging
85
- plot_loss_curve_common = _imports.plot_loss_curve
86
-
87
- import matplotlib
88
-
89
- if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
90
- matplotlib.use("Agg")
91
- import matplotlib.pyplot as plt
92
-
93
-
94
- def _parse_args() -> argparse.Namespace:
95
- parser = argparse.ArgumentParser(
96
- description="Batch trainer generated from BayesOpt_AutoPricing notebook."
97
- )
98
- add_config_json_arg(
99
- parser,
100
- help_text="Path to the JSON config describing datasets and feature columns.",
101
- )
102
- parser.add_argument(
103
- "--model-keys",
104
- nargs="+",
105
- default=["ft"],
106
- choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
107
- help="Space-separated list of trainers to run (e.g., --model-keys glm xgb). Include 'all' to run every trainer.",
108
- )
109
- parser.add_argument(
110
- "--stack-model-keys",
111
- nargs="+",
112
- default=None,
113
- choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
114
- help=(
115
- "Only used when ft_role != 'model' (FT runs as feature generator). "
116
- "When provided (or when config defines stack_model_keys), these trainers run after FT features "
117
- "are generated. Use 'all' to run every non-FT trainer."
118
- ),
119
- )
120
- parser.add_argument(
121
- "--max-evals",
122
- type=int,
123
- default=50,
124
- help="Optuna trial count per dataset.",
125
- )
126
- parser.add_argument(
127
- "--use-resn-ddp",
128
- action="store_true",
129
- help="Force ResNet trainer to use DistributedDataParallel.",
130
- )
131
- parser.add_argument(
132
- "--use-ft-ddp",
133
- action="store_true",
134
- help="Force FT-Transformer trainer to use DistributedDataParallel.",
135
- )
136
- parser.add_argument(
137
- "--use-resn-dp",
138
- action="store_true",
139
- help="Enable ResNet DataParallel fall-back regardless of config.",
140
- )
141
- parser.add_argument(
142
- "--use-ft-dp",
143
- action="store_true",
144
- help="Enable FT-Transformer DataParallel fall-back regardless of config.",
145
- )
146
- parser.add_argument(
147
- "--use-gnn-dp",
148
- action="store_true",
149
- help="Enable GNN DataParallel fall-back regardless of config.",
150
- )
151
- parser.add_argument(
152
- "--use-gnn-ddp",
153
- action="store_true",
154
- help="Force GNN trainer to use DistributedDataParallel.",
155
- )
156
- parser.add_argument(
157
- "--gnn-no-ann",
158
- action="store_true",
159
- help="Disable approximate k-NN for GNN graph construction and use exact search.",
160
- )
161
- parser.add_argument(
162
- "--gnn-ann-threshold",
163
- type=int,
164
- default=None,
165
- help="Row threshold above which approximate k-NN is preferred (overrides config).",
166
- )
167
- parser.add_argument(
168
- "--gnn-graph-cache",
169
- default=None,
170
- help="Optional path to persist/load cached adjacency matrix for GNN.",
171
- )
172
- parser.add_argument(
173
- "--gnn-max-gpu-nodes",
174
- type=int,
175
- default=None,
176
- help="Overrides the maximum node count allowed for GPU k-NN graph construction.",
177
- )
178
- parser.add_argument(
179
- "--gnn-gpu-mem-ratio",
180
- type=float,
181
- default=None,
182
- help="Overrides the fraction of free GPU memory the k-NN builder may consume.",
183
- )
184
- parser.add_argument(
185
- "--gnn-gpu-mem-overhead",
186
- type=float,
187
- default=None,
188
- help="Overrides the temporary GPU memory overhead multiplier for k-NN estimation.",
189
- )
190
- add_output_dir_arg(
191
- parser,
192
- help_text="Override output root for models/results/plots.",
193
- )
194
- parser.add_argument(
195
- "--plot-curves",
196
- action="store_true",
197
- help="Enable lift/diagnostic plots after training (config file may also request plotting).",
198
- )
199
- parser.add_argument(
200
- "--ft-as-feature",
201
- action="store_true",
202
- help="Alias for --ft-role embedding (keep tuning, export embeddings; skip FT plots/SHAP).",
203
- )
204
- parser.add_argument(
205
- "--ft-role",
206
- default=None,
207
- choices=["model", "embedding", "unsupervised_embedding"],
208
- help="How to use FT: model (default), embedding (export pooling embeddings), or unsupervised_embedding.",
209
- )
210
- parser.add_argument(
211
- "--ft-feature-prefix",
212
- default="ft_feat",
213
- help="Prefix used for generated FT features (columns: pred_<prefix>_0.. or pred_<prefix>).",
214
- )
215
- parser.add_argument(
216
- "--reuse-best-params",
217
- action="store_true",
218
- help="Skip Optuna and reuse best_params saved in Results/versions or bestparams CSV when available.",
219
- )
220
- return parser.parse_args()
221
-
222
-
223
- def _plot_curves_for_model(model: ropt.BayesOptModel, trained_keys: List[str], cfg: Dict) -> None:
224
- plot_cfg = cfg.get("plot", {})
225
- legacy_lift_flags = {
226
- "glm": cfg.get("plot_lift_glm", False),
227
- "xgb": cfg.get("plot_lift_xgb", False),
228
- "resn": cfg.get("plot_lift_resn", False),
229
- "ft": cfg.get("plot_lift_ft", False),
230
- }
231
- plot_enabled = plot_cfg.get("enable", any(legacy_lift_flags.values()))
232
- if not plot_enabled:
233
- return
234
-
235
- n_bins = int(plot_cfg.get("n_bins", 10))
236
- oneway_enabled = plot_cfg.get("oneway", True)
237
-
238
- available_models = dedupe_preserve_order(
239
- [m for m in trained_keys if m in PLOT_MODEL_LABELS]
240
- )
241
-
242
- lift_models = plot_cfg.get("lift_models")
243
- if lift_models is None:
244
- lift_models = [
245
- m for m, enabled in legacy_lift_flags.items() if enabled]
246
- if not lift_models:
247
- lift_models = available_models
248
- lift_models = dedupe_preserve_order(
249
- [m for m in lift_models if m in available_models]
250
- )
251
-
252
- if oneway_enabled:
253
- oneway_pred = bool(plot_cfg.get("oneway_pred", False))
254
- oneway_pred_models = plot_cfg.get("oneway_pred_models")
255
- pred_plotted = False
256
- if oneway_pred:
257
- if oneway_pred_models is None:
258
- oneway_pred_models = lift_models or available_models
259
- oneway_pred_models = dedupe_preserve_order(
260
- [m for m in oneway_pred_models if m in available_models]
261
- )
262
- for model_key in oneway_pred_models:
263
- label, pred_nme = PLOT_MODEL_LABELS[model_key]
264
- if pred_nme not in model.train_data.columns:
265
- print(
266
- f"[Oneway] Missing prediction column '{pred_nme}'; skip.",
267
- flush=True,
268
- )
269
- continue
270
- model.plot_oneway(
271
- n_bins=n_bins,
272
- pred_col=pred_nme,
273
- pred_label=label,
274
- plot_subdir="oneway/post",
275
- )
276
- pred_plotted = True
277
- if not oneway_pred or not pred_plotted:
278
- model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/post")
279
-
280
- if not available_models:
281
- return
282
-
283
- for model_key in lift_models:
284
- label, pred_nme = PLOT_MODEL_LABELS[model_key]
285
- model.plot_lift(model_label=label, pred_nme=pred_nme, n_bins=n_bins)
286
-
287
- if not plot_cfg.get("double_lift", True) or len(available_models) < 2:
288
- return
289
-
290
- raw_pairs = plot_cfg.get("double_lift_pairs")
291
- if raw_pairs:
292
- pairs = [
293
- (a, b)
294
- for a, b in parse_model_pairs(raw_pairs)
295
- if a in available_models and b in available_models and a != b
296
- ]
297
- else:
298
- pairs = [(a, b) for i, a in enumerate(available_models)
299
- for b in available_models[i + 1:]]
300
-
301
- for first, second in pairs:
302
- model.plot_dlift([first, second], n_bins=n_bins)
303
-
304
-
305
- def _plot_loss_curve_for_trainer(model_name: str, trainer) -> None:
306
- model_obj = getattr(trainer, "model", None)
307
- history = None
308
- if model_obj is not None:
309
- history = getattr(model_obj, "training_history", None)
310
- if not history:
311
- history = getattr(trainer, "training_history", None)
312
- if not history:
313
- return
314
- train_hist = list(history.get("train") or [])
315
- val_hist = list(history.get("val") or [])
316
- if not train_hist and not val_hist:
317
- return
318
- try:
319
- plot_dir = trainer.output.plot_path(
320
- f"{model_name}/loss/loss_{model_name}_{trainer.model_name_prefix}.png"
321
- )
322
- except Exception:
323
- default_dir = Path("plot") / model_name / "loss"
324
- default_dir.mkdir(parents=True, exist_ok=True)
325
- plot_dir = str(
326
- default_dir / f"loss_{model_name}_{trainer.model_name_prefix}.png")
327
- if plot_loss_curve_common is not None:
328
- plot_loss_curve_common(
329
- history=history,
330
- title=f"{trainer.model_name_prefix} Loss Curve ({model_name})",
331
- save_path=plot_dir,
332
- show=False,
333
- )
334
- else:
335
- epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
336
- fig, ax = plt.subplots(figsize=(8, 4))
337
- if train_hist:
338
- ax.plot(range(1, len(train_hist) + 1),
339
- train_hist, label="Train Loss", color="tab:blue")
340
- if val_hist:
341
- ax.plot(range(1, len(val_hist) + 1),
342
- val_hist, label="Validation Loss", color="tab:orange")
343
- ax.set_xlabel("Epoch")
344
- ax.set_ylabel("Weighted Loss")
345
- ax.set_title(
346
- f"{trainer.model_name_prefix} Loss Curve ({model_name})")
347
- ax.grid(True, linestyle="--", alpha=0.3)
348
- ax.legend()
349
- plt.tight_layout()
350
- plt.savefig(plot_dir, dpi=300)
351
- plt.close(fig)
352
- print(
353
- f"[Plot] Saved loss curve for {model_name}/{trainer.label} -> {plot_dir}")
354
-
355
-
356
- def _sample_arrays(
357
- y_true: np.ndarray,
358
- y_pred: np.ndarray,
359
- *,
360
- max_rows: Optional[int],
361
- seed: Optional[int],
362
- ) -> tuple[np.ndarray, np.ndarray]:
363
- if max_rows is None or max_rows <= 0:
364
- return y_true, y_pred
365
- n = len(y_true)
366
- if n <= max_rows:
367
- return y_true, y_pred
368
- rng = np.random.default_rng(seed)
369
- idx = rng.choice(n, size=int(max_rows), replace=False)
370
- return y_true[idx], y_pred[idx]
371
-
372
-
373
- def _compute_psi_report(
374
- model: ropt.BayesOptModel,
375
- *,
376
- features: Optional[List[str]],
377
- bins: int,
378
- strategy: str,
379
- ) -> Optional[pd.DataFrame]:
380
- if drift_psi_report is None:
381
- return None
382
- psi_features = features or list(getattr(model, "factor_nmes", []))
383
- psi_features = [
384
- f for f in psi_features if f in model.train_data.columns and f in model.test_data.columns]
385
- if not psi_features:
386
- return None
387
- try:
388
- return drift_psi_report(
389
- model.train_data[psi_features],
390
- model.test_data[psi_features],
391
- features=psi_features,
392
- bins=int(bins),
393
- strategy=str(strategy),
394
- )
395
- except Exception as exc:
396
- print(f"[Report] PSI computation failed: {exc}")
397
- return None
398
-
399
-
400
- # --- Refactored helper functions for _evaluate_and_report ---
401
-
402
-
403
- def _apply_calibration(
404
- y_true_train: np.ndarray,
405
- y_pred_train: np.ndarray,
406
- y_pred_test: np.ndarray,
407
- calibration_cfg: Dict[str, Any],
408
- model_name: str,
409
- model_key: str,
410
- ) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
411
- """Apply calibration to predictions for classification tasks.
412
-
413
- Returns:
414
- Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
415
- """
416
- cal_cfg = dict(calibration_cfg or {})
417
- cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))
418
-
419
- if not cal_enabled or calibrate_predictions is None:
420
- return y_pred_train, y_pred_test, None
421
-
422
- method = cal_cfg.get("method", "sigmoid")
423
- max_rows = cal_cfg.get("max_rows")
424
- seed = cal_cfg.get("seed")
425
- y_cal, p_cal = _sample_arrays(
426
- y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
427
-
428
- try:
429
- calibrator = calibrate_predictions(y_cal, p_cal, method=method)
430
- calibrated_train = calibrator.predict(y_pred_train)
431
- calibrated_test = calibrator.predict(y_pred_test)
432
- calibration_info = {"method": calibrator.method, "max_rows": max_rows}
433
- return calibrated_train, calibrated_test, calibration_info
434
- except Exception as exc:
435
- print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
436
- return y_pred_train, y_pred_test, None
437
-
438
-
439
- def _select_classification_threshold(
440
- y_true_train: np.ndarray,
441
- y_pred_train_eval: np.ndarray,
442
- threshold_cfg: Dict[str, Any],
443
- ) -> tuple[float, Optional[Dict[str, Any]]]:
444
- """Select threshold for classification predictions.
445
-
446
- Returns:
447
- Tuple of (threshold_value, threshold_info)
448
- """
449
- thr_cfg = dict(threshold_cfg or {})
450
- thr_enabled = bool(
451
- thr_cfg.get("enable", False)
452
- or thr_cfg.get("metric")
453
- or thr_cfg.get("value") is not None
454
- )
455
-
456
- if thr_cfg.get("value") is not None:
457
- threshold_value = float(thr_cfg["value"])
458
- return threshold_value, {"threshold": threshold_value, "source": "fixed"}
459
-
460
- if thr_enabled and select_threshold is not None:
461
- max_rows = thr_cfg.get("max_rows")
462
- seed = thr_cfg.get("seed")
463
- y_thr, p_thr = _sample_arrays(
464
- y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
465
- threshold_info = select_threshold(
466
- y_thr,
467
- p_thr,
468
- metric=thr_cfg.get("metric", "f1"),
469
- min_positive_rate=thr_cfg.get("min_positive_rate"),
470
- grid=thr_cfg.get("grid", 99),
471
- )
472
- return float(threshold_info.get("threshold", 0.5)), threshold_info
473
-
474
- return 0.5, None
475
-
476
-
477
- def _compute_classification_metrics(
478
- y_true_test: np.ndarray,
479
- y_pred_test_eval: np.ndarray,
480
- threshold_value: float,
481
- ) -> Dict[str, Any]:
482
- """Compute metrics for classification task."""
483
- metrics = eval_metrics_report(
484
- y_true_test,
485
- y_pred_test_eval,
486
- task_type="classification",
487
- threshold=threshold_value,
488
- )
489
- precision = float(metrics.get("precision", 0.0))
490
- recall = float(metrics.get("recall", 0.0))
491
- f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
492
- metrics["f1"] = float(f1)
493
- metrics["threshold"] = float(threshold_value)
494
- return metrics
495
-
496
-
497
- def _compute_bootstrap_ci(
498
- y_true_test: np.ndarray,
499
- y_pred_test_eval: np.ndarray,
500
- weight_test: Optional[np.ndarray],
501
- metrics: Dict[str, Any],
502
- bootstrap_cfg: Dict[str, Any],
503
- task_type: str,
504
- ) -> Dict[str, Dict[str, float]]:
505
- """Compute bootstrap confidence intervals for metrics."""
506
- if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
507
- return {}
508
-
509
- metric_names = bootstrap_cfg.get("metrics")
510
- if not metric_names:
511
- metric_names = [name for name in metrics.keys() if name != "threshold"]
512
- n_samples = int(bootstrap_cfg.get("n_samples", 200))
513
- ci = float(bootstrap_cfg.get("ci", 0.95))
514
- seed = bootstrap_cfg.get("seed")
515
-
516
- def _metric_fn(y_true, y_pred, weight=None):
517
- vals = eval_metrics_report(
518
- y_true,
519
- y_pred,
520
- task_type=task_type,
521
- weight=weight,
522
- threshold=metrics.get("threshold", 0.5),
523
- )
524
- if task_type == "classification":
525
- prec = float(vals.get("precision", 0.0))
526
- rec = float(vals.get("recall", 0.0))
527
- vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
528
- return vals
529
-
530
- bootstrap_results: Dict[str, Dict[str, float]] = {}
531
- for name in metric_names:
532
- if name not in metrics:
533
- continue
534
- ci_result = bootstrap_ci(
535
- lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
536
- y_true_test,
537
- y_pred_test_eval,
538
- weight=weight_test,
539
- n_samples=n_samples,
540
- ci=ci,
541
- seed=seed,
542
- )
543
- bootstrap_results[str(name)] = ci_result
544
-
545
- return bootstrap_results
546
-
547
-
548
- def _compute_validation_table(
549
- model: ropt.BayesOptModel,
550
- pred_col: str,
551
- report_group_cols: Optional[List[str]],
552
- weight_col: Optional[str],
553
- model_name: str,
554
- model_key: str,
555
- ) -> Optional[pd.DataFrame]:
556
- """Compute grouped validation metrics table."""
557
- if not report_group_cols or group_metrics is None:
558
- return None
559
-
560
- available_groups = [
561
- col for col in report_group_cols if col in model.test_data.columns
562
- ]
563
- if not available_groups:
564
- return None
565
-
566
- try:
567
- validation_table = group_metrics(
568
- model.test_data,
569
- actual_col=model.resp_nme,
570
- pred_col=pred_col,
571
- group_cols=available_groups,
572
- weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
573
- )
574
- counts = (
575
- model.test_data.groupby(available_groups, dropna=False)
576
- .size()
577
- .reset_index(name="count")
578
- )
579
- return validation_table.merge(counts, on=available_groups, how="left")
580
- except Exception as exc:
581
- print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
582
- return None
583
-
584
-
585
- def _compute_risk_trend(
586
- model: ropt.BayesOptModel,
587
- pred_col: str,
588
- report_time_col: Optional[str],
589
- report_time_freq: str,
590
- report_time_ascending: bool,
591
- weight_col: Optional[str],
592
- model_name: str,
593
- model_key: str,
594
- ) -> Optional[pd.DataFrame]:
595
- """Compute time-series risk trend metrics."""
596
- if not report_time_col or group_metrics is None:
597
- return None
598
-
599
- if report_time_col not in model.test_data.columns:
600
- return None
601
-
602
- try:
603
- time_df = model.test_data.copy()
604
- time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
605
- time_df = time_df.loc[time_series.notna()].copy()
606
-
607
- if time_df.empty:
608
- return None
609
-
610
- time_df["_time_bucket"] = (
611
- pd.to_datetime(time_df[report_time_col], errors="coerce")
612
- .dt.to_period(report_time_freq)
613
- .dt.to_timestamp()
614
- )
615
- risk_trend = group_metrics(
616
- time_df,
617
- actual_col=model.resp_nme,
618
- pred_col=pred_col,
619
- group_cols=["_time_bucket"],
620
- weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
621
- )
622
- counts = (
623
- time_df.groupby("_time_bucket", dropna=False)
624
- .size()
625
- .reset_index(name="count")
626
- )
627
- risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
628
- risk_trend = risk_trend.sort_values(
629
- "_time_bucket", ascending=bool(report_time_ascending)
630
- ).reset_index(drop=True)
631
- return risk_trend.rename(columns={"_time_bucket": report_time_col})
632
- except Exception as exc:
633
- print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
634
- return None
635
-
636
-
637
- def _write_metrics_json(
638
- report_root: Path,
639
- model_name: str,
640
- model_key: str,
641
- version: str,
642
- metrics: Dict[str, Any],
643
- threshold_info: Optional[Dict[str, Any]],
644
- calibration_info: Optional[Dict[str, Any]],
645
- bootstrap_results: Dict[str, Dict[str, float]],
646
- data_path: Path,
647
- data_fingerprint: Dict[str, Any],
648
- config_sha: str,
649
- pred_col: str,
650
- task_type: str,
651
- ) -> Path:
652
- """Write metrics to JSON file and return the path."""
653
- metrics_payload = {
654
- "model_name": model_name,
655
- "model_key": model_key,
656
- "model_version": version,
657
- "metrics": metrics,
658
- "threshold": threshold_info,
659
- "calibration": calibration_info,
660
- "bootstrap": bootstrap_results,
661
- "data_path": str(data_path),
662
- "data_fingerprint": data_fingerprint,
663
- "config_sha256": config_sha,
664
- "pred_col": pred_col,
665
- "task_type": task_type,
666
- }
667
- metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
668
- metrics_path.write_text(
669
- json.dumps(metrics_payload, indent=2, ensure_ascii=True),
670
- encoding="utf-8",
671
- )
672
- return metrics_path
673
-
674
-
675
- def _write_model_report(
676
- report_root: Path,
677
- model_name: str,
678
- model_key: str,
679
- version: str,
680
- metrics: Dict[str, Any],
681
- risk_trend: Optional[pd.DataFrame],
682
- psi_report_df: Optional[pd.DataFrame],
683
- validation_table: Optional[pd.DataFrame],
684
- calibration_info: Optional[Dict[str, Any]],
685
- threshold_info: Optional[Dict[str, Any]],
686
- bootstrap_results: Dict[str, Dict[str, float]],
687
- config_sha: str,
688
- data_fingerprint: Dict[str, Any],
689
- ) -> Optional[Path]:
690
- """Write model report and return the path."""
691
- if ReportPayload is None or write_report is None:
692
- return None
693
-
694
- notes_lines = [
695
- f"- Config SHA256: {config_sha}",
696
- f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
697
- ]
698
- if calibration_info:
699
- notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
700
- if threshold_info:
701
- notes_lines.append(f"- Threshold selection: {threshold_info}")
702
- if bootstrap_results:
703
- notes_lines.append("- Bootstrap: see metrics JSON for CI")
704
-
705
- payload = ReportPayload(
706
- model_name=f"{model_name}/{model_key}",
707
- model_version=version,
708
- metrics={k: float(v) for k, v in metrics.items()},
709
- risk_trend=risk_trend,
710
- drift_report=psi_report_df,
711
- validation_table=validation_table,
712
- extra_notes="\n".join(notes_lines),
713
- )
714
- return write_report(
715
- payload,
716
- report_root / f"{model_name}_{model_key}_report.md",
717
- )
718
-
719
-
720
- def _register_model_to_registry(
721
- model: ropt.BayesOptModel,
722
- model_name: str,
723
- model_key: str,
724
- version: str,
725
- metrics: Dict[str, Any],
726
- task_type: str,
727
- data_path: Path,
728
- data_fingerprint: Dict[str, Any],
729
- config_sha: str,
730
- registry_path: Optional[str],
731
- registry_tags: Dict[str, Any],
732
- registry_status: str,
733
- report_path: Optional[Path],
734
- metrics_path: Path,
735
- cfg: Dict[str, Any],
736
- ) -> None:
737
- """Register model artifacts to the model registry."""
738
- if ModelRegistry is None or ModelArtifact is None:
739
- return
740
-
741
- registry = ModelRegistry(
742
- registry_path
743
- if registry_path
744
- else Path(model.output_manager.result_dir) / "model_registry.json"
745
- )
746
-
747
- tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
748
- tags.update({
749
- "model_key": str(model_key),
750
- "task_type": str(task_type),
751
- "data_path": str(data_path),
752
- "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
753
- "data_size": str(data_fingerprint.get("size", "")),
754
- "data_mtime": str(data_fingerprint.get("mtime", "")),
755
- "config_sha256": str(config_sha),
756
- })
757
-
758
- artifacts = _collect_model_artifacts(
759
- model, model_name, model_key, report_path, metrics_path, cfg
760
- )
761
-
762
- registry.register(
763
- name=str(model_name),
764
- version=version,
765
- metrics={k: float(v) for k, v in metrics.items()},
766
- tags=tags,
767
- artifacts=artifacts,
768
- status=str(registry_status or "candidate"),
769
- notes=f"model_key={model_key}",
770
- )
771
-
772
-
773
- def _collect_model_artifacts(
774
- model: ropt.BayesOptModel,
775
- model_name: str,
776
- model_key: str,
777
- report_path: Optional[Path],
778
- metrics_path: Path,
779
- cfg: Dict[str, Any],
780
- ) -> List:
781
- """Collect all model artifacts for registry."""
782
- artifacts = []
783
-
784
- # Trained model artifact
785
- trainer = model.trainers.get(model_key)
786
- if trainer is not None:
787
- try:
788
- model_path = trainer.output.model_path(trainer._get_model_filename())
789
- if os.path.exists(model_path):
790
- artifacts.append(ModelArtifact(path=model_path, description="trained model"))
791
- except Exception:
792
- pass
793
-
794
- # Report artifact
795
- if report_path is not None:
796
- artifacts.append(ModelArtifact(path=str(report_path), description="model report"))
797
-
798
- # Metrics JSON artifact
799
- if metrics_path.exists():
800
- artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))
801
-
802
- # Preprocess artifacts
803
- if bool(cfg.get("save_preprocess", False)):
804
- artifact_path = cfg.get("preprocess_artifact_path")
805
- if artifact_path:
806
- preprocess_path = Path(str(artifact_path))
807
- if not preprocess_path.is_absolute():
808
- preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
809
- else:
810
- preprocess_path = Path(model.output_manager.result_path(
811
- f"{model.model_nme}_preprocess.json"
812
- ))
813
- if preprocess_path.exists():
814
- artifacts.append(
815
- ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
816
- )
817
-
818
- # Prediction cache artifacts
819
- if bool(cfg.get("cache_predictions", False)):
820
- cache_dir = cfg.get("prediction_cache_dir")
821
- if cache_dir:
822
- pred_root = Path(str(cache_dir))
823
- if not pred_root.is_absolute():
824
- pred_root = Path(model.output_manager.result_dir) / pred_root
825
- else:
826
- pred_root = Path(model.output_manager.result_dir) / "predictions"
827
- ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
828
- for split_label in ("train", "test"):
829
- pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
830
- if pred_path.exists():
831
- artifacts.append(
832
- ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
833
- )
834
-
835
- return artifacts
836
-
837
-
838
- def _evaluate_and_report(
839
- model: ropt.BayesOptModel,
840
- *,
841
- model_name: str,
842
- model_key: str,
843
- cfg: Dict[str, Any],
844
- data_path: Path,
845
- data_fingerprint: Dict[str, Any],
846
- report_output_dir: Optional[str],
847
- report_group_cols: Optional[List[str]],
848
- report_time_col: Optional[str],
849
- report_time_freq: str,
850
- report_time_ascending: bool,
851
- psi_report_df: Optional[pd.DataFrame],
852
- calibration_cfg: Dict[str, Any],
853
- threshold_cfg: Dict[str, Any],
854
- bootstrap_cfg: Dict[str, Any],
855
- register_model: bool,
856
- registry_path: Optional[str],
857
- registry_tags: Dict[str, Any],
858
- registry_status: str,
859
- run_id: str,
860
- config_sha: str,
861
- ) -> None:
862
- """Evaluate model predictions and generate reports.
863
-
864
- This function orchestrates the evaluation pipeline:
865
- 1. Extract predictions and ground truth
866
- 2. Apply calibration (for classification)
867
- 3. Select threshold (for classification)
868
- 4. Compute metrics
869
- 5. Compute bootstrap confidence intervals
870
- 6. Generate validation tables and risk trends
871
- 7. Write reports and register model
872
- """
873
- if eval_metrics_report is None:
874
- print("[Report] Skip evaluation: metrics module unavailable.")
875
- return
876
-
877
- pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
878
- if pred_col not in model.test_data.columns:
879
- print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
880
- return
881
-
882
- # Extract predictions and weights
883
- weight_col = getattr(model, "weight_nme", None)
884
- y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
885
- y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
886
- y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
887
- y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
888
- weight_test = (
889
- model.test_data[weight_col].to_numpy(dtype=float, copy=False)
890
- if weight_col and weight_col in model.test_data.columns
891
- else None
892
- )
893
-
894
- task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))
895
-
896
- # Process based on task type
897
- if task_type == "classification":
898
- y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
899
- y_pred_test = np.clip(y_pred_test, 0.0, 1.0)
900
-
901
- y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
902
- y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
903
- )
904
- threshold_value, threshold_info = _select_classification_threshold(
905
- y_true_train, y_pred_train_eval, threshold_cfg
906
- )
907
- metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
908
- else:
909
- y_pred_test_eval = y_pred_test
910
- calibration_info = None
911
- threshold_info = None
912
- metrics = eval_metrics_report(
913
- y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
914
- )
915
-
916
- # Compute bootstrap confidence intervals
917
- bootstrap_results = _compute_bootstrap_ci(
918
- y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
919
- )
920
-
921
- # Compute validation table and risk trend
922
- validation_table = _compute_validation_table(
923
- model, pred_col, report_group_cols, weight_col, model_name, model_key
924
- )
925
- risk_trend = _compute_risk_trend(
926
- model, pred_col, report_time_col, report_time_freq,
927
- report_time_ascending, weight_col, model_name, model_key
928
- )
929
-
930
- # Setup output directory
931
- report_root = (
932
- Path(report_output_dir)
933
- if report_output_dir
934
- else Path(model.output_manager.result_dir) / "reports"
935
- )
936
- report_root.mkdir(parents=True, exist_ok=True)
937
- version = f"{model_key}_{run_id}"
938
-
939
- # Write metrics JSON
940
- metrics_path = _write_metrics_json(
941
- report_root, model_name, model_key, version, metrics,
942
- threshold_info, calibration_info, bootstrap_results,
943
- data_path, data_fingerprint, config_sha, pred_col, task_type
944
- )
945
-
946
- # Write model report
947
- report_path = _write_model_report(
948
- report_root, model_name, model_key, version, metrics,
949
- risk_trend, psi_report_df, validation_table,
950
- calibration_info, threshold_info, bootstrap_results,
951
- config_sha, data_fingerprint
952
- )
953
-
954
- # Register model
955
- if register_model:
956
- _register_model_to_registry(
957
- model, model_name, model_key, version, metrics, task_type,
958
- data_path, data_fingerprint, config_sha, registry_path,
959
- registry_tags, registry_status, report_path, metrics_path, cfg
960
- )
961
-
962
-
963
- def _evaluate_with_context(
964
- model: ropt.BayesOptModel,
965
- ctx: EvaluationContext,
966
- ) -> None:
967
- """Evaluate model predictions using context object.
968
-
969
- This is a cleaner interface that uses the EvaluationContext dataclass
970
- instead of 19+ individual parameters.
971
- """
972
- _evaluate_and_report(
973
- model,
974
- model_name=ctx.identity.model_name,
975
- model_key=ctx.identity.model_key,
976
- cfg=ctx.cfg,
977
- data_path=ctx.data_path,
978
- data_fingerprint=ctx.data_fingerprint.to_dict(),
979
- report_output_dir=ctx.report.output_dir,
980
- report_group_cols=ctx.report.group_cols,
981
- report_time_col=ctx.report.time_col,
982
- report_time_freq=ctx.report.time_freq,
983
- report_time_ascending=ctx.report.time_ascending,
984
- psi_report_df=ctx.psi_report_df,
985
- calibration_cfg={
986
- "enable": ctx.calibration.enable,
987
- "method": ctx.calibration.method,
988
- "max_rows": ctx.calibration.max_rows,
989
- "seed": ctx.calibration.seed,
990
- },
991
- threshold_cfg={
992
- "enable": ctx.threshold.enable,
993
- "metric": ctx.threshold.metric,
994
- "value": ctx.threshold.value,
995
- "min_positive_rate": ctx.threshold.min_positive_rate,
996
- "grid": ctx.threshold.grid,
997
- "max_rows": ctx.threshold.max_rows,
998
- "seed": ctx.threshold.seed,
999
- },
1000
- bootstrap_cfg={
1001
- "enable": ctx.bootstrap.enable,
1002
- "metrics": ctx.bootstrap.metrics,
1003
- "n_samples": ctx.bootstrap.n_samples,
1004
- "ci": ctx.bootstrap.ci,
1005
- "seed": ctx.bootstrap.seed,
1006
- },
1007
- register_model=ctx.registry.register,
1008
- registry_path=ctx.registry.path,
1009
- registry_tags=ctx.registry.tags,
1010
- registry_status=ctx.registry.status,
1011
- run_id=ctx.run_id,
1012
- config_sha=ctx.config_sha,
1013
- )
1014
-
1015
-
1016
- def _create_ddp_barrier(dist_ctx: TrainingContext):
1017
- """Create a DDP barrier function for distributed training synchronization."""
1018
- def _ddp_barrier(reason: str) -> None:
1019
- if not dist_ctx.is_distributed:
1020
- return
1021
- torch_mod = getattr(ropt, "torch", None)
1022
- dist_mod = getattr(torch_mod, "distributed", None)
1023
- if dist_mod is None:
1024
- return
1025
- try:
1026
- if not getattr(dist_mod, "is_available", lambda: False)():
1027
- return
1028
- if not dist_mod.is_initialized():
1029
- ddp_ok, _, _, _ = ropt.DistributedUtils.setup_ddp()
1030
- if not ddp_ok or not dist_mod.is_initialized():
1031
- return
1032
- dist_mod.barrier()
1033
- except Exception as exc:
1034
- print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
1035
- raise
1036
- return _ddp_barrier
1037
-
1038
-
1039
- def train_from_config(args: argparse.Namespace) -> None:
1040
- script_dir = Path(__file__).resolve().parents[1]
1041
- config_path, cfg = resolve_and_load_config(
1042
- args.config_json,
1043
- script_dir,
1044
- required_keys=["data_dir", "model_list",
1045
- "model_categories", "target", "weight"],
1046
- )
1047
- plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
1048
- config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
1049
- run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
1050
-
1051
- # Use TrainingContext for distributed training state
1052
- dist_ctx = TrainingContext.from_env()
1053
- dist_world_size = dist_ctx.world_size
1054
- dist_rank = dist_ctx.rank
1055
- dist_active = dist_ctx.is_distributed
1056
- is_main_process = dist_ctx.is_main_process
1057
- _ddp_barrier = _create_ddp_barrier(dist_ctx)
1058
-
1059
- data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
1060
- cfg,
1061
- config_path,
1062
- create_data_dir=True,
1063
- )
1064
- runtime_cfg = resolve_runtime_config(cfg)
1065
- ddp_min_rows = runtime_cfg["ddp_min_rows"]
1066
- bo_sample_limit = runtime_cfg["bo_sample_limit"]
1067
- cache_predictions = runtime_cfg["cache_predictions"]
1068
- prediction_cache_dir = runtime_cfg["prediction_cache_dir"]
1069
- prediction_cache_format = runtime_cfg["prediction_cache_format"]
1070
- report_cfg = resolve_report_config(cfg)
1071
- report_output_dir = report_cfg["report_output_dir"]
1072
- report_group_cols = report_cfg["report_group_cols"]
1073
- report_time_col = report_cfg["report_time_col"]
1074
- report_time_freq = report_cfg["report_time_freq"]
1075
- report_time_ascending = report_cfg["report_time_ascending"]
1076
- psi_bins = report_cfg["psi_bins"]
1077
- psi_strategy = report_cfg["psi_strategy"]
1078
- psi_features = report_cfg["psi_features"]
1079
- calibration_cfg = report_cfg["calibration_cfg"]
1080
- threshold_cfg = report_cfg["threshold_cfg"]
1081
- bootstrap_cfg = report_cfg["bootstrap_cfg"]
1082
- register_model = report_cfg["register_model"]
1083
- registry_path = report_cfg["registry_path"]
1084
- registry_tags = report_cfg["registry_tags"]
1085
- registry_status = report_cfg["registry_status"]
1086
- data_fingerprint_max_bytes = report_cfg["data_fingerprint_max_bytes"]
1087
- report_enabled = report_cfg["report_enabled"]
1088
-
1089
- split_cfg = resolve_split_config(cfg)
1090
- prop_test = split_cfg["prop_test"]
1091
- holdout_ratio = split_cfg["holdout_ratio"]
1092
- val_ratio = split_cfg["val_ratio"]
1093
- split_strategy = split_cfg["split_strategy"]
1094
- split_group_col = split_cfg["split_group_col"]
1095
- split_time_col = split_cfg["split_time_col"]
1096
- split_time_ascending = split_cfg["split_time_ascending"]
1097
- cv_strategy = split_cfg["cv_strategy"]
1098
- cv_group_col = split_cfg["cv_group_col"]
1099
- cv_time_col = split_cfg["cv_time_col"]
1100
- cv_time_ascending = split_cfg["cv_time_ascending"]
1101
- cv_splits = split_cfg["cv_splits"]
1102
- ft_oof_folds = split_cfg["ft_oof_folds"]
1103
- ft_oof_strategy = split_cfg["ft_oof_strategy"]
1104
- ft_oof_shuffle = split_cfg["ft_oof_shuffle"]
1105
- save_preprocess = runtime_cfg["save_preprocess"]
1106
- preprocess_artifact_path = runtime_cfg["preprocess_artifact_path"]
1107
- rand_seed = runtime_cfg["rand_seed"]
1108
- epochs = runtime_cfg["epochs"]
1109
- output_cfg = resolve_output_dirs(
1110
- cfg,
1111
- config_path,
1112
- output_override=args.output_dir,
1113
- )
1114
- output_dir = output_cfg["output_dir"]
1115
- reuse_best_params = bool(
1116
- args.reuse_best_params or runtime_cfg["reuse_best_params"])
1117
- xgb_max_depth_max = runtime_cfg["xgb_max_depth_max"]
1118
- xgb_n_estimators_max = runtime_cfg["xgb_n_estimators_max"]
1119
- optuna_storage = runtime_cfg["optuna_storage"]
1120
- optuna_study_prefix = runtime_cfg["optuna_study_prefix"]
1121
- best_params_files = runtime_cfg["best_params_files"]
1122
- plot_path_style = runtime_cfg["plot_path_style"]
1123
-
1124
- model_names = build_model_names(
1125
- cfg["model_list"], cfg["model_categories"])
1126
- if not model_names:
1127
- raise ValueError(
1128
- "No model names generated from model_list/model_categories.")
1129
-
1130
- results: Dict[str, ropt.BayesOptModel] = {}
1131
- trained_keys_by_model: Dict[str, List[str]] = {}
1132
-
1133
- for model_name in model_names:
1134
- # Per-dataset training loop: load data, split train/test, and train requested models.
1135
- data_path = resolve_data_path(
1136
- data_dir,
1137
- model_name,
1138
- data_format=data_format,
1139
- path_template=data_path_template,
1140
- )
1141
- if not data_path.exists():
1142
- raise FileNotFoundError(f"Missing dataset: {data_path}")
1143
- data_fingerprint = {"path": str(data_path)}
1144
- if report_enabled and is_main_process:
1145
- data_fingerprint = fingerprint_file(
1146
- data_path,
1147
- max_bytes=data_fingerprint_max_bytes,
1148
- )
1149
-
1150
- print(f"\n=== Processing model {model_name} ===")
1151
- raw = load_dataset(
1152
- data_path,
1153
- data_format=data_format,
1154
- dtype_map=dtype_map,
1155
- low_memory=False,
1156
- )
1157
- raw = coerce_dataset_types(raw)
1158
-
1159
- train_df, test_df = split_train_test(
1160
- raw,
1161
- holdout_ratio=holdout_ratio,
1162
- strategy=split_strategy,
1163
- group_col=split_group_col,
1164
- time_col=split_time_col,
1165
- time_ascending=split_time_ascending,
1166
- rand_seed=rand_seed,
1167
- reset_index_mode="time_group",
1168
- ratio_label="holdout_ratio",
1169
- )
1170
-
1171
- use_resn_dp = args.use_resn_dp or cfg.get(
1172
- "use_resn_data_parallel", False)
1173
- use_ft_dp = args.use_ft_dp or cfg.get("use_ft_data_parallel", True)
1174
- dataset_rows = len(raw)
1175
- ddp_enabled = bool(dist_active and (dataset_rows >= int(ddp_min_rows)))
1176
- use_resn_ddp = (args.use_resn_ddp or cfg.get(
1177
- "use_resn_ddp", False)) and ddp_enabled
1178
- use_ft_ddp = (args.use_ft_ddp or cfg.get(
1179
- "use_ft_ddp", False)) and ddp_enabled
1180
- use_gnn_dp = args.use_gnn_dp or cfg.get("use_gnn_data_parallel", False)
1181
- use_gnn_ddp = (args.use_gnn_ddp or cfg.get(
1182
- "use_gnn_ddp", False)) and ddp_enabled
1183
- gnn_use_ann = cfg.get("gnn_use_approx_knn", True)
1184
- if args.gnn_no_ann:
1185
- gnn_use_ann = False
1186
- gnn_threshold = args.gnn_ann_threshold if args.gnn_ann_threshold is not None else cfg.get(
1187
- "gnn_approx_knn_threshold", 50000)
1188
- gnn_graph_cache = args.gnn_graph_cache or cfg.get("gnn_graph_cache")
1189
- if isinstance(gnn_graph_cache, str) and gnn_graph_cache.strip():
1190
- resolved_cache = resolve_path(gnn_graph_cache, config_path.parent)
1191
- if resolved_cache is not None:
1192
- gnn_graph_cache = str(resolved_cache)
1193
- gnn_max_gpu_nodes = args.gnn_max_gpu_nodes if args.gnn_max_gpu_nodes is not None else cfg.get(
1194
- "gnn_max_gpu_knn_nodes", 200000)
1195
- gnn_gpu_mem_ratio = args.gnn_gpu_mem_ratio if args.gnn_gpu_mem_ratio is not None else cfg.get(
1196
- "gnn_knn_gpu_mem_ratio", 0.9)
1197
- gnn_gpu_mem_overhead = args.gnn_gpu_mem_overhead if args.gnn_gpu_mem_overhead is not None else cfg.get(
1198
- "gnn_knn_gpu_mem_overhead", 2.0)
1199
-
1200
- binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
1201
- task_type = str(cfg.get("task_type", "regression"))
1202
- feature_list = cfg.get("feature_list")
1203
- categorical_features = cfg.get("categorical_features")
1204
- use_gpu = bool(cfg.get("use_gpu", True))
1205
- region_province_col = cfg.get("region_province_col")
1206
- region_city_col = cfg.get("region_city_col")
1207
- region_effect_alpha = cfg.get("region_effect_alpha")
1208
- geo_feature_nmes = cfg.get("geo_feature_nmes")
1209
- geo_token_hidden_dim = cfg.get("geo_token_hidden_dim")
1210
- geo_token_layers = cfg.get("geo_token_layers")
1211
- geo_token_dropout = cfg.get("geo_token_dropout")
1212
- geo_token_k_neighbors = cfg.get("geo_token_k_neighbors")
1213
- geo_token_learning_rate = cfg.get("geo_token_learning_rate")
1214
- geo_token_epochs = cfg.get("geo_token_epochs")
1215
-
1216
- ft_role = args.ft_role or cfg.get("ft_role", "model")
1217
- if args.ft_as_feature and args.ft_role is None:
1218
- # Keep legacy behavior as a convenience alias only when the config
1219
- # didn't already request a non-default FT role.
1220
- if str(cfg.get("ft_role", "model")) == "model":
1221
- ft_role = "embedding"
1222
- ft_feature_prefix = str(
1223
- cfg.get("ft_feature_prefix", args.ft_feature_prefix))
1224
- ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")
1225
-
1226
- config_fields = getattr(ropt.BayesOptConfig,
1227
- "__dataclass_fields__", {})
1228
- allowed_config_keys = set(config_fields.keys())
1229
- config_payload = {k: v for k,
1230
- v in cfg.items() if k in allowed_config_keys}
1231
- config_payload.update({
1232
- "model_nme": model_name,
1233
- "resp_nme": cfg["target"],
1234
- "weight_nme": cfg["weight"],
1235
- "factor_nmes": feature_list,
1236
- "task_type": task_type,
1237
- "binary_resp_nme": binary_target,
1238
- "cate_list": categorical_features,
1239
- "prop_test": val_ratio,
1240
- "rand_seed": rand_seed,
1241
- "epochs": epochs,
1242
- "use_gpu": use_gpu,
1243
- "use_resn_data_parallel": use_resn_dp,
1244
- "use_ft_data_parallel": use_ft_dp,
1245
- "use_gnn_data_parallel": use_gnn_dp,
1246
- "use_resn_ddp": use_resn_ddp,
1247
- "use_ft_ddp": use_ft_ddp,
1248
- "use_gnn_ddp": use_gnn_ddp,
1249
- "output_dir": output_dir,
1250
- "xgb_max_depth_max": xgb_max_depth_max,
1251
- "xgb_n_estimators_max": xgb_n_estimators_max,
1252
- "resn_weight_decay": cfg.get("resn_weight_decay"),
1253
- "final_ensemble": bool(cfg.get("final_ensemble", False)),
1254
- "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
1255
- "final_refit": bool(cfg.get("final_refit", True)),
1256
- "optuna_storage": optuna_storage,
1257
- "optuna_study_prefix": optuna_study_prefix,
1258
- "best_params_files": best_params_files,
1259
- "gnn_use_approx_knn": gnn_use_ann,
1260
- "gnn_approx_knn_threshold": gnn_threshold,
1261
- "gnn_graph_cache": gnn_graph_cache,
1262
- "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
1263
- "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
1264
- "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
1265
- "region_province_col": region_province_col,
1266
- "region_city_col": region_city_col,
1267
- "region_effect_alpha": region_effect_alpha,
1268
- "geo_feature_nmes": geo_feature_nmes,
1269
- "geo_token_hidden_dim": geo_token_hidden_dim,
1270
- "geo_token_layers": geo_token_layers,
1271
- "geo_token_dropout": geo_token_dropout,
1272
- "geo_token_k_neighbors": geo_token_k_neighbors,
1273
- "geo_token_learning_rate": geo_token_learning_rate,
1274
- "geo_token_epochs": geo_token_epochs,
1275
- "ft_role": ft_role,
1276
- "ft_feature_prefix": ft_feature_prefix,
1277
- "ft_num_numeric_tokens": ft_num_numeric_tokens,
1278
- "reuse_best_params": reuse_best_params,
1279
- "bo_sample_limit": bo_sample_limit,
1280
- "cache_predictions": cache_predictions,
1281
- "prediction_cache_dir": prediction_cache_dir,
1282
- "prediction_cache_format": prediction_cache_format,
1283
- "cv_strategy": cv_strategy or split_strategy,
1284
- "cv_group_col": cv_group_col or split_group_col,
1285
- "cv_time_col": cv_time_col or split_time_col,
1286
- "cv_time_ascending": cv_time_ascending,
1287
- "cv_splits": cv_splits,
1288
- "ft_oof_folds": ft_oof_folds,
1289
- "ft_oof_strategy": ft_oof_strategy,
1290
- "ft_oof_shuffle": ft_oof_shuffle,
1291
- "save_preprocess": save_preprocess,
1292
- "preprocess_artifact_path": preprocess_artifact_path,
1293
- "plot_path_style": plot_path_style or "nested",
1294
- })
1295
- config_payload = {
1296
- k: v for k, v in config_payload.items() if v is not None}
1297
- config = ropt.BayesOptConfig(**config_payload)
1298
- model = ropt.BayesOptModel(train_df, test_df, config=config)
1299
-
1300
- if plot_requested:
1301
- plot_cfg = cfg.get("plot", {})
1302
- legacy_lift_flags = {
1303
- "glm": cfg.get("plot_lift_glm", False),
1304
- "xgb": cfg.get("plot_lift_xgb", False),
1305
- "resn": cfg.get("plot_lift_resn", False),
1306
- "ft": cfg.get("plot_lift_ft", False),
1307
- }
1308
- plot_enabled = plot_cfg.get(
1309
- "enable", any(legacy_lift_flags.values()))
1310
- if plot_enabled and plot_cfg.get("pre_oneway", False) and plot_cfg.get("oneway", True):
1311
- n_bins = int(plot_cfg.get("n_bins", 10))
1312
- model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/pre")
1313
-
1314
- if "all" in args.model_keys:
1315
- requested_keys = ["glm", "xgb", "resn", "ft", "gnn"]
1316
- else:
1317
- requested_keys = args.model_keys
1318
- requested_keys = dedupe_preserve_order(requested_keys)
1319
-
1320
- if ft_role != "model":
1321
- requested_keys = [k for k in requested_keys if k != "ft"]
1322
- if not requested_keys:
1323
- stack_keys = args.stack_model_keys or cfg.get(
1324
- "stack_model_keys")
1325
- if stack_keys:
1326
- if "all" in stack_keys:
1327
- requested_keys = ["glm", "xgb", "resn", "gnn"]
1328
- else:
1329
- requested_keys = [k for k in stack_keys if k != "ft"]
1330
- requested_keys = dedupe_preserve_order(requested_keys)
1331
- if dist_active and ddp_enabled:
1332
- ft_trainer = model.trainers.get("ft")
1333
- if ft_trainer is None:
1334
- raise ValueError("FT trainer is not available.")
1335
- ft_trainer_uses_ddp = bool(
1336
- getattr(ft_trainer, "enable_distributed_optuna", False))
1337
- if not ft_trainer_uses_ddp:
1338
- raise ValueError(
1339
- "FT embedding under torchrun requires enabling FT DDP (use --use-ft-ddp or set use_ft_ddp=true)."
1340
- )
1341
- missing = [key for key in requested_keys if key not in model.trainers]
1342
- if missing:
1343
- raise ValueError(
1344
- f"Trainer(s) {missing} not available for {model_name}")
1345
-
1346
- executed_keys: List[str] = []
1347
- if ft_role != "model":
1348
- if dist_active and not ddp_enabled:
1349
- _ddp_barrier("start_ft_embedding")
1350
- if dist_rank != 0:
1351
- _ddp_barrier("finish_ft_embedding")
1352
- continue
1353
- print(
1354
- f"Optimizing ft as {ft_role} for {model_name} (max_evals={args.max_evals})")
1355
- model.optimize_model("ft", max_evals=args.max_evals)
1356
- model.trainers["ft"].save()
1357
- if getattr(ropt, "torch", None) is not None and ropt.torch.cuda.is_available():
1358
- ropt.free_cuda()
1359
- if dist_active and not ddp_enabled:
1360
- _ddp_barrier("finish_ft_embedding")
1361
- for key in requested_keys:
1362
- trainer = model.trainers[key]
1363
- trainer_uses_ddp = bool(
1364
- getattr(trainer, "enable_distributed_optuna", False))
1365
- if dist_active and not trainer_uses_ddp:
1366
- if dist_rank != 0:
1367
- print(
1368
- f"[Rank {dist_rank}] Skip {model_name}/{key} because trainer is not DDP-enabled."
1369
- )
1370
- _ddp_barrier(f"start_non_ddp_{model_name}_{key}")
1371
- if dist_rank != 0:
1372
- _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1373
- continue
1374
-
1375
- print(
1376
- f"Optimizing {key} for {model_name} (max_evals={args.max_evals})")
1377
- model.optimize_model(key, max_evals=args.max_evals)
1378
- model.trainers[key].save()
1379
- _plot_loss_curve_for_trainer(model_name, model.trainers[key])
1380
- if key in PYTORCH_TRAINERS:
1381
- ropt.free_cuda()
1382
- if dist_active and not trainer_uses_ddp:
1383
- _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1384
- executed_keys.append(key)
1385
-
1386
- if not executed_keys:
1387
- continue
1388
-
1389
- results[model_name] = model
1390
- trained_keys_by_model[model_name] = executed_keys
1391
- if report_enabled and is_main_process:
1392
- psi_report_df = _compute_psi_report(
1393
- model,
1394
- features=psi_features,
1395
- bins=psi_bins,
1396
- strategy=str(psi_strategy),
1397
- )
1398
- for key in executed_keys:
1399
- _evaluate_and_report(
1400
- model,
1401
- model_name=model_name,
1402
- model_key=key,
1403
- cfg=cfg,
1404
- data_path=data_path,
1405
- data_fingerprint=data_fingerprint,
1406
- report_output_dir=report_output_dir,
1407
- report_group_cols=report_group_cols,
1408
- report_time_col=report_time_col,
1409
- report_time_freq=str(report_time_freq),
1410
- report_time_ascending=bool(report_time_ascending),
1411
- psi_report_df=psi_report_df,
1412
- calibration_cfg=calibration_cfg,
1413
- threshold_cfg=threshold_cfg,
1414
- bootstrap_cfg=bootstrap_cfg,
1415
- register_model=register_model,
1416
- registry_path=registry_path,
1417
- registry_tags=registry_tags,
1418
- registry_status=registry_status,
1419
- run_id=run_id,
1420
- config_sha=config_sha,
1421
- )
1422
-
1423
- if not plot_requested:
1424
- return
1425
-
1426
- for name, model in results.items():
1427
- _plot_curves_for_model(
1428
- model,
1429
- trained_keys_by_model.get(name, []),
1430
- cfg,
1431
- )
1432
-
1433
-
1434
- def main() -> None:
1435
- if configure_run_logging:
1436
- configure_run_logging(prefix="bayesopt_entry")
1437
- args = _parse_args()
1438
- train_from_config(args)
1439
-
1440
-
1441
- if __name__ == "__main__":
1442
- main()
19
+ if importlib.util.find_spec("ins_pricing") is None:
20
+ repo_root = Path(__file__).resolve().parents[2]
21
+ if str(repo_root) not in sys.path:
22
+ sys.path.insert(0, str(repo_root))
23
+
24
+ import argparse
25
+ import hashlib
26
+ import json
27
+ import os
28
+ from datetime import datetime
29
+ from typing import Any, Dict, List, Optional
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+
34
+ # Use unified import resolver to eliminate nested try/except chains
35
+ from ins_pricing.cli.utils.import_resolver import resolve_imports, setup_sys_path
36
+ from ins_pricing.cli.utils.evaluation_context import (
37
+ EvaluationContext,
38
+ TrainingContext,
39
+ ModelIdentity,
40
+ DataFingerprint,
41
+ CalibrationConfig,
42
+ ThresholdConfig,
43
+ BootstrapConfig,
44
+ ReportConfig,
45
+ RegistryConfig,
46
+ )
47
+
48
+ # Resolve all imports from a single location
49
+ setup_sys_path()
50
+ _imports = resolve_imports()
51
+
52
+ ropt = _imports.bayesopt
53
+ PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
54
+ PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
55
+ build_model_names = _imports.build_model_names
56
+ dedupe_preserve_order = _imports.dedupe_preserve_order
57
+ load_dataset = _imports.load_dataset
58
+ parse_model_pairs = _imports.parse_model_pairs
59
+ resolve_data_path = _imports.resolve_data_path
60
+ resolve_path = _imports.resolve_path
61
+ fingerprint_file = _imports.fingerprint_file
62
+ coerce_dataset_types = _imports.coerce_dataset_types
63
+ split_train_test = _imports.split_train_test
64
+
65
+ add_config_json_arg = _imports.add_config_json_arg
66
+ add_output_dir_arg = _imports.add_output_dir_arg
67
+ resolve_and_load_config = _imports.resolve_and_load_config
68
+ resolve_data_config = _imports.resolve_data_config
69
+ resolve_report_config = _imports.resolve_report_config
70
+ resolve_split_config = _imports.resolve_split_config
71
+ resolve_runtime_config = _imports.resolve_runtime_config
72
+ resolve_output_dirs = _imports.resolve_output_dirs
73
+
74
+ bootstrap_ci = _imports.bootstrap_ci
75
+ calibrate_predictions = _imports.calibrate_predictions
76
+ eval_metrics_report = _imports.metrics_report
77
+ select_threshold = _imports.select_threshold
78
+
79
+ ModelArtifact = _imports.ModelArtifact
80
+ ModelRegistry = _imports.ModelRegistry
81
+ drift_psi_report = _imports.drift_psi_report
82
+ group_metrics = _imports.group_metrics
83
+ ReportPayload = _imports.ReportPayload
84
+ write_report = _imports.write_report
85
+
86
+ configure_run_logging = _imports.configure_run_logging
87
+ plot_loss_curve_common = _imports.plot_loss_curve
88
+
89
+ import matplotlib
90
+
91
+ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
92
+ matplotlib.use("Agg")
93
+ import matplotlib.pyplot as plt
94
+
95
+
96
+ def _parse_args() -> argparse.Namespace:
97
+ parser = argparse.ArgumentParser(
98
+ description="Batch trainer generated from BayesOpt_AutoPricing notebook."
99
+ )
100
+ add_config_json_arg(
101
+ parser,
102
+ help_text="Path to the JSON config describing datasets and feature columns.",
103
+ )
104
+ parser.add_argument(
105
+ "--model-keys",
106
+ nargs="+",
107
+ default=["ft"],
108
+ choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
109
+ help="Space-separated list of trainers to run (e.g., --model-keys glm xgb). Include 'all' to run every trainer.",
110
+ )
111
+ parser.add_argument(
112
+ "--stack-model-keys",
113
+ nargs="+",
114
+ default=None,
115
+ choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
116
+ help=(
117
+ "Only used when ft_role != 'model' (FT runs as feature generator). "
118
+ "When provided (or when config defines stack_model_keys), these trainers run after FT features "
119
+ "are generated. Use 'all' to run every non-FT trainer."
120
+ ),
121
+ )
122
+ parser.add_argument(
123
+ "--max-evals",
124
+ type=int,
125
+ default=50,
126
+ help="Optuna trial count per dataset.",
127
+ )
128
+ parser.add_argument(
129
+ "--use-resn-ddp",
130
+ action="store_true",
131
+ help="Force ResNet trainer to use DistributedDataParallel.",
132
+ )
133
+ parser.add_argument(
134
+ "--use-ft-ddp",
135
+ action="store_true",
136
+ help="Force FT-Transformer trainer to use DistributedDataParallel.",
137
+ )
138
+ parser.add_argument(
139
+ "--use-resn-dp",
140
+ action="store_true",
141
+ help="Enable ResNet DataParallel fall-back regardless of config.",
142
+ )
143
+ parser.add_argument(
144
+ "--use-ft-dp",
145
+ action="store_true",
146
+ help="Enable FT-Transformer DataParallel fall-back regardless of config.",
147
+ )
148
+ parser.add_argument(
149
+ "--use-gnn-dp",
150
+ action="store_true",
151
+ help="Enable GNN DataParallel fall-back regardless of config.",
152
+ )
153
+ parser.add_argument(
154
+ "--use-gnn-ddp",
155
+ action="store_true",
156
+ help="Force GNN trainer to use DistributedDataParallel.",
157
+ )
158
+ parser.add_argument(
159
+ "--gnn-no-ann",
160
+ action="store_true",
161
+ help="Disable approximate k-NN for GNN graph construction and use exact search.",
162
+ )
163
+ parser.add_argument(
164
+ "--gnn-ann-threshold",
165
+ type=int,
166
+ default=None,
167
+ help="Row threshold above which approximate k-NN is preferred (overrides config).",
168
+ )
169
+ parser.add_argument(
170
+ "--gnn-graph-cache",
171
+ default=None,
172
+ help="Optional path to persist/load cached adjacency matrix for GNN.",
173
+ )
174
+ parser.add_argument(
175
+ "--gnn-max-gpu-nodes",
176
+ type=int,
177
+ default=None,
178
+ help="Overrides the maximum node count allowed for GPU k-NN graph construction.",
179
+ )
180
+ parser.add_argument(
181
+ "--gnn-gpu-mem-ratio",
182
+ type=float,
183
+ default=None,
184
+ help="Overrides the fraction of free GPU memory the k-NN builder may consume.",
185
+ )
186
+ parser.add_argument(
187
+ "--gnn-gpu-mem-overhead",
188
+ type=float,
189
+ default=None,
190
+ help="Overrides the temporary GPU memory overhead multiplier for k-NN estimation.",
191
+ )
192
+ add_output_dir_arg(
193
+ parser,
194
+ help_text="Override output root for models/results/plots.",
195
+ )
196
+ parser.add_argument(
197
+ "--plot-curves",
198
+ action="store_true",
199
+ help="Enable lift/diagnostic plots after training (config file may also request plotting).",
200
+ )
201
+ parser.add_argument(
202
+ "--ft-as-feature",
203
+ action="store_true",
204
+ help="Alias for --ft-role embedding (keep tuning, export embeddings; skip FT plots/SHAP).",
205
+ )
206
+ parser.add_argument(
207
+ "--ft-role",
208
+ default=None,
209
+ choices=["model", "embedding", "unsupervised_embedding"],
210
+ help="How to use FT: model (default), embedding (export pooling embeddings), or unsupervised_embedding.",
211
+ )
212
+ parser.add_argument(
213
+ "--ft-feature-prefix",
214
+ default="ft_feat",
215
+ help="Prefix used for generated FT features (columns: pred_<prefix>_0.. or pred_<prefix>).",
216
+ )
217
+ parser.add_argument(
218
+ "--reuse-best-params",
219
+ action="store_true",
220
+ help="Skip Optuna and reuse best_params saved in Results/versions or bestparams CSV when available.",
221
+ )
222
+ return parser.parse_args()
223
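The parser above relies on `nargs="+"` plus a fixed `choices` list for `--model-keys`. The fragment below is not the package's parser, only a minimal standalone sketch of that same argparse pattern, to show how multiple trainer keys are collected into a list.

```python
import argparse

# Minimal sketch of the nargs="+" / choices pattern used by --model-keys above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-keys",
    nargs="+",
    default=["ft"],
    choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
)
parser.add_argument("--max-evals", type=int, default=50)

args = parser.parse_args(["--model-keys", "glm", "xgb", "--max-evals", "10"])
print(args.model_keys)  # ['glm', 'xgb']
print(args.max_evals)   # 10

# An unknown key (e.g. "svm") would make argparse exit with an error,
# because it is not in the choices list.
```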
+
224
+
225
+ def _plot_curves_for_model(model: ropt.BayesOptModel, trained_keys: List[str], cfg: Dict) -> None:
226
+ plot_cfg = cfg.get("plot", {})
227
+ legacy_lift_flags = {
228
+ "glm": cfg.get("plot_lift_glm", False),
229
+ "xgb": cfg.get("plot_lift_xgb", False),
230
+ "resn": cfg.get("plot_lift_resn", False),
231
+ "ft": cfg.get("plot_lift_ft", False),
232
+ }
233
+ plot_enabled = plot_cfg.get("enable", any(legacy_lift_flags.values()))
234
+ if not plot_enabled:
235
+ return
236
+
237
+ n_bins = int(plot_cfg.get("n_bins", 10))
238
+ oneway_enabled = plot_cfg.get("oneway", True)
239
+
240
+ available_models = dedupe_preserve_order(
241
+ [m for m in trained_keys if m in PLOT_MODEL_LABELS]
242
+ )
243
+
244
+ lift_models = plot_cfg.get("lift_models")
245
+ if lift_models is None:
246
+ lift_models = [
247
+ m for m, enabled in legacy_lift_flags.items() if enabled]
248
+ if not lift_models:
249
+ lift_models = available_models
250
+ lift_models = dedupe_preserve_order(
251
+ [m for m in lift_models if m in available_models]
252
+ )
253
+
254
+ if oneway_enabled:
255
+ oneway_pred = bool(plot_cfg.get("oneway_pred", False))
256
+ oneway_pred_models = plot_cfg.get("oneway_pred_models")
257
+ pred_plotted = False
258
+ if oneway_pred:
259
+ if oneway_pred_models is None:
260
+ oneway_pred_models = lift_models or available_models
261
+ oneway_pred_models = dedupe_preserve_order(
262
+ [m for m in oneway_pred_models if m in available_models]
263
+ )
264
+ for model_key in oneway_pred_models:
265
+ label, pred_nme = PLOT_MODEL_LABELS[model_key]
266
+ if pred_nme not in model.train_data.columns:
267
+ print(
268
+ f"[Oneway] Missing prediction column '{pred_nme}'; skip.",
269
+ flush=True,
270
+ )
271
+ continue
272
+ model.plot_oneway(
273
+ n_bins=n_bins,
274
+ pred_col=pred_nme,
275
+ pred_label=label,
276
+ plot_subdir="oneway/post",
277
+ )
278
+ pred_plotted = True
279
+ if not oneway_pred or not pred_plotted:
280
+ model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/post")
281
+
282
+ if not available_models:
283
+ return
284
+
285
+ for model_key in lift_models:
286
+ label, pred_nme = PLOT_MODEL_LABELS[model_key]
287
+ model.plot_lift(model_label=label, pred_nme=pred_nme, n_bins=n_bins)
288
+
289
+ if not plot_cfg.get("double_lift", True) or len(available_models) < 2:
290
+ return
291
+
292
+ raw_pairs = plot_cfg.get("double_lift_pairs")
293
+ if raw_pairs:
294
+ pairs = [
295
+ (a, b)
296
+ for a, b in parse_model_pairs(raw_pairs)
297
+ if a in available_models and b in available_models and a != b
298
+ ]
299
+ else:
300
+ pairs = [(a, b) for i, a in enumerate(available_models)
301
+ for b in available_models[i + 1:]]
302
+
303
+ for first, second in pairs:
304
+ model.plot_dlift([first, second], n_bins=n_bins)
305
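When no explicit `double_lift_pairs` are configured, `_plot_curves_for_model` builds every unordered pair of available models via a nested comprehension. The sketch below just demonstrates that this is equivalent to `itertools.combinations`; the model keys are illustrative.

```python
from itertools import combinations

available_models = ["glm", "xgb", "resn"]

# Nested-comprehension form, as used above.
pairs_a = [(a, b) for i, a in enumerate(available_models)
           for b in available_models[i + 1:]]

# Equivalent itertools form.
pairs_b = list(combinations(available_models, 2))

assert pairs_a == pairs_b
print(pairs_a)  # [('glm', 'xgb'), ('glm', 'resn'), ('xgb', 'resn')]
```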
+
306
+
307
+ def _plot_loss_curve_for_trainer(model_name: str, trainer) -> None:
308
+ model_obj = getattr(trainer, "model", None)
309
+ history = None
310
+ if model_obj is not None:
311
+ history = getattr(model_obj, "training_history", None)
312
+ if not history:
313
+ history = getattr(trainer, "training_history", None)
314
+ if not history:
315
+ return
316
+ train_hist = list(history.get("train") or [])
317
+ val_hist = list(history.get("val") or [])
318
+ if not train_hist and not val_hist:
319
+ return
320
+ try:
321
+ plot_dir = trainer.output.plot_path(
322
+ f"{model_name}/loss/loss_{model_name}_{trainer.model_name_prefix}.png"
323
+ )
324
+ except Exception:
325
+ default_dir = Path("plot") / model_name / "loss"
326
+ default_dir.mkdir(parents=True, exist_ok=True)
327
+ plot_dir = str(
328
+ default_dir / f"loss_{model_name}_{trainer.model_name_prefix}.png")
329
+ if plot_loss_curve_common is not None:
330
+ plot_loss_curve_common(
331
+ history=history,
332
+ title=f"{trainer.model_name_prefix} Loss Curve ({model_name})",
333
+ save_path=plot_dir,
334
+ show=False,
335
+ )
336
+ else:
337
+ epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
338
+ fig, ax = plt.subplots(figsize=(8, 4))
339
+ if train_hist:
340
+ ax.plot(range(1, len(train_hist) + 1),
341
+ train_hist, label="Train Loss", color="tab:blue")
342
+ if val_hist:
343
+ ax.plot(range(1, len(val_hist) + 1),
344
+ val_hist, label="Validation Loss", color="tab:orange")
345
+ ax.set_xlabel("Epoch")
346
+ ax.set_ylabel("Weighted Loss")
347
+ ax.set_title(
348
+ f"{trainer.model_name_prefix} Loss Curve ({model_name})")
349
+ ax.grid(True, linestyle="--", alpha=0.3)
350
+ ax.legend()
351
+ plt.tight_layout()
352
+ plt.savefig(plot_dir, dpi=300)
353
+ plt.close(fig)
354
+ print(
355
+ f"[Plot] Saved loss curve for {model_name}/{trainer.label} -> {plot_dir}")
356
+
357
+
358
+ def _sample_arrays(
359
+ y_true: np.ndarray,
360
+ y_pred: np.ndarray,
361
+ *,
362
+ max_rows: Optional[int],
363
+ seed: Optional[int],
364
+ ) -> tuple[np.ndarray, np.ndarray]:
365
+ if max_rows is None or max_rows <= 0:
366
+ return y_true, y_pred
367
+ n = len(y_true)
368
+ if n <= max_rows:
369
+ return y_true, y_pred
370
+ rng = np.random.default_rng(seed)
371
+ idx = rng.choice(n, size=int(max_rows), replace=False)
372
+ return y_true[idx], y_pred[idx]
373
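`_sample_arrays` caps the rows used for calibration and threshold search by drawing a seeded sample without replacement. A standalone sketch of the same NumPy pattern, showing that the same seed always selects the same subsample:

```python
import numpy as np

y_true = np.arange(1_000, dtype=float)
y_pred = y_true + np.random.default_rng(0).normal(size=1_000)

max_rows, seed = 100, 42
rng = np.random.default_rng(seed)
idx = rng.choice(len(y_true), size=max_rows, replace=False)

y_true_s, y_pred_s = y_true[idx], y_pred[idx]
print(y_true_s.shape, y_pred_s.shape)  # (100,) (100,)

# Re-seeding reproduces the identical row selection.
assert np.array_equal(
    idx, np.random.default_rng(seed).choice(1_000, size=100, replace=False)
)
```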
+
374
+
375
+ def _compute_psi_report(
376
+ model: ropt.BayesOptModel,
377
+ *,
378
+ features: Optional[List[str]],
379
+ bins: int,
380
+ strategy: str,
381
+ ) -> Optional[pd.DataFrame]:
382
+ if drift_psi_report is None:
383
+ return None
384
+ psi_features = features or list(getattr(model, "factor_nmes", []))
385
+ psi_features = [
386
+ f for f in psi_features if f in model.train_data.columns and f in model.test_data.columns]
387
+ if not psi_features:
388
+ return None
389
+ try:
390
+ return drift_psi_report(
391
+ model.train_data[psi_features],
392
+ model.test_data[psi_features],
393
+ features=psi_features,
394
+ bins=int(bins),
395
+ strategy=str(strategy),
396
+ )
397
+ except Exception as exc:
398
+ print(f"[Report] PSI computation failed: {exc}")
399
+ return None
400
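`_compute_psi_report` delegates to the package's `drift_psi_report`. The sketch below is not that function; it is only the textbook Population Stability Index for a single numeric feature, binned on training quantiles, to make explicit what a PSI-style drift report measures.

```python
import numpy as np
import pandas as pd

def psi_single_feature(expected: pd.Series, actual: pd.Series, bins: int = 10) -> float:
    """PSI between a train ('expected') and test ('actual') column."""
    # Bin edges from the training distribution (quantile strategy).
    edges = np.unique(np.quantile(expected.dropna(), np.linspace(0, 1, bins + 1)))
    edges[0], edges[-1] = -np.inf, np.inf

    exp_counts = pd.cut(expected, edges).value_counts(sort=False).to_numpy(dtype=float)
    act_counts = pd.cut(actual, edges).value_counts(sort=False).to_numpy(dtype=float)

    # Convert to proportions, flooring empty buckets to avoid log(0).
    exp_pct = np.clip(exp_counts / exp_counts.sum(), 1e-6, None)
    act_pct = np.clip(act_counts / act_counts.sum(), 1e-6, None)

    return float(np.sum((act_pct - exp_pct) * np.log(act_pct / exp_pct)))

rng = np.random.default_rng(0)
train_col = pd.Series(rng.normal(0.0, 1.0, 5_000))
test_col = pd.Series(rng.normal(0.3, 1.0, 5_000))  # shifted distribution
print(round(psi_single_feature(train_col, test_col), 4))  # a 0.3-sigma mean shift gives PSI on the order of 0.1
```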
+
401
+
402
+ # --- Refactored helper functions for _evaluate_and_report ---
403
+
404
+
405
+ def _apply_calibration(
406
+ y_true_train: np.ndarray,
407
+ y_pred_train: np.ndarray,
408
+ y_pred_test: np.ndarray,
409
+ calibration_cfg: Dict[str, Any],
410
+ model_name: str,
411
+ model_key: str,
412
+ ) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
413
+ """Apply calibration to predictions for classification tasks.
414
+
415
+ Returns:
416
+ Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
417
+ """
418
+ cal_cfg = dict(calibration_cfg or {})
419
+ cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))
420
+
421
+ if not cal_enabled or calibrate_predictions is None:
422
+ return y_pred_train, y_pred_test, None
423
+
424
+ method = cal_cfg.get("method", "sigmoid")
425
+ max_rows = cal_cfg.get("max_rows")
426
+ seed = cal_cfg.get("seed")
427
+ y_cal, p_cal = _sample_arrays(
428
+ y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
429
+
430
+ try:
431
+ calibrator = calibrate_predictions(y_cal, p_cal, method=method)
432
+ calibrated_train = calibrator.predict(y_pred_train)
433
+ calibrated_test = calibrator.predict(y_pred_test)
434
+ calibration_info = {"method": calibrator.method, "max_rows": max_rows}
435
+ return calibrated_train, calibrated_test, calibration_info
436
+ except Exception as exc:
437
+ print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
438
+ return y_pred_train, y_pred_test, None
439
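`_apply_calibration` relies on the package's `calibrate_predictions` helper, whose internals are not shown here. As an illustration only, sigmoid (Platt) calibration can be fit on raw scores with scikit-learn's `LogisticRegression`; this is one common way such a helper is implemented, not a claim about the package's actual code.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, 5_000)
# Over-confident raw scores: informative but poorly calibrated.
raw = np.clip(0.5 + (y_true - 0.5) * 0.9 + rng.normal(0, 0.2, 5_000), 0.01, 0.99)

# Platt scaling: fit a logistic regression on the logit of the raw score.
logit = np.log(raw / (1.0 - raw)).reshape(-1, 1)
platt = LogisticRegression().fit(logit, y_true)

def calibrate(p: np.ndarray) -> np.ndarray:
    p = np.clip(p, 1e-6, 1 - 1e-6)
    z = np.log(p / (1.0 - p)).reshape(-1, 1)
    return platt.predict_proba(z)[:, 1]

print(raw[:3].round(3), calibrate(raw[:3]).round(3))
```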
+
440
+
441
+ def _select_classification_threshold(
442
+ y_true_train: np.ndarray,
443
+ y_pred_train_eval: np.ndarray,
444
+ threshold_cfg: Dict[str, Any],
445
+ ) -> tuple[float, Optional[Dict[str, Any]]]:
446
+ """Select threshold for classification predictions.
447
+
448
+ Returns:
449
+ Tuple of (threshold_value, threshold_info)
450
+ """
451
+ thr_cfg = dict(threshold_cfg or {})
452
+ thr_enabled = bool(
453
+ thr_cfg.get("enable", False)
454
+ or thr_cfg.get("metric")
455
+ or thr_cfg.get("value") is not None
456
+ )
457
+
458
+ if thr_cfg.get("value") is not None:
459
+ threshold_value = float(thr_cfg["value"])
460
+ return threshold_value, {"threshold": threshold_value, "source": "fixed"}
461
+
462
+ if thr_enabled and select_threshold is not None:
463
+ max_rows = thr_cfg.get("max_rows")
464
+ seed = thr_cfg.get("seed")
465
+ y_thr, p_thr = _sample_arrays(
466
+ y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
467
+ threshold_info = select_threshold(
468
+ y_thr,
469
+ p_thr,
470
+ metric=thr_cfg.get("metric", "f1"),
471
+ min_positive_rate=thr_cfg.get("min_positive_rate"),
472
+ grid=thr_cfg.get("grid", 99),
473
+ )
474
+ return float(threshold_info.get("threshold", 0.5)), threshold_info
475
+
476
+ return 0.5, None
477
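When no fixed threshold is configured, `_select_classification_threshold` asks the package's `select_threshold` to search a probability grid. A minimal standalone version of that idea (F1-maximising grid search with an optional minimum predicted-positive rate), purely illustrative:

```python
import numpy as np
from sklearn.metrics import f1_score

def pick_threshold(y_true, y_prob, grid=99, min_positive_rate=None):
    best = {"threshold": 0.5, "f1": -1.0}
    for thr in np.linspace(0.01, 0.99, grid):
        y_hat = (y_prob >= thr).astype(int)
        if min_positive_rate is not None and y_hat.mean() < min_positive_rate:
            continue  # too few predicted positives at this cut-off
        score = f1_score(y_true, y_hat, zero_division=0)
        if score > best["f1"]:
            best = {"threshold": float(thr), "f1": float(score)}
    return best

rng = np.random.default_rng(1)
y_true = rng.integers(0, 2, 2_000)
y_prob = np.clip(y_true * 0.6 + rng.random(2_000) * 0.5, 0, 1)
print(pick_threshold(y_true, y_prob, min_positive_rate=0.2))
```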
+
478
+
479
+ def _compute_classification_metrics(
480
+ y_true_test: np.ndarray,
481
+ y_pred_test_eval: np.ndarray,
482
+ threshold_value: float,
483
+ ) -> Dict[str, Any]:
484
+ """Compute metrics for classification task."""
485
+ metrics = eval_metrics_report(
486
+ y_true_test,
487
+ y_pred_test_eval,
488
+ task_type="classification",
489
+ threshold=threshold_value,
490
+ )
491
+ precision = float(metrics.get("precision", 0.0))
492
+ recall = float(metrics.get("recall", 0.0))
493
+ f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
494
+ metrics["f1"] = float(f1)
495
+ metrics["threshold"] = float(threshold_value)
496
+ return metrics
497
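`_compute_classification_metrics` derives F1 as the harmonic mean of the precision and recall returned by the metrics report. A one-line worked check of that formula:

```python
precision, recall = 0.80, 0.60
f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
print(round(f1, 4))  # 0.6857, below the arithmetic mean of 0.70
```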
+
498
+
499
+ def _compute_bootstrap_ci(
500
+ y_true_test: np.ndarray,
501
+ y_pred_test_eval: np.ndarray,
502
+ weight_test: Optional[np.ndarray],
503
+ metrics: Dict[str, Any],
504
+ bootstrap_cfg: Dict[str, Any],
505
+ task_type: str,
506
+ ) -> Dict[str, Dict[str, float]]:
507
+ """Compute bootstrap confidence intervals for metrics."""
508
+ if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
509
+ return {}
510
+
511
+ metric_names = bootstrap_cfg.get("metrics")
512
+ if not metric_names:
513
+ metric_names = [name for name in metrics.keys() if name != "threshold"]
514
+ n_samples = int(bootstrap_cfg.get("n_samples", 200))
515
+ ci = float(bootstrap_cfg.get("ci", 0.95))
516
+ seed = bootstrap_cfg.get("seed")
517
+
518
+ def _metric_fn(y_true, y_pred, weight=None):
519
+ vals = eval_metrics_report(
520
+ y_true,
521
+ y_pred,
522
+ task_type=task_type,
523
+ weight=weight,
524
+ threshold=metrics.get("threshold", 0.5),
525
+ )
526
+ if task_type == "classification":
527
+ prec = float(vals.get("precision", 0.0))
528
+ rec = float(vals.get("recall", 0.0))
529
+ vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
530
+ return vals
531
+
532
+ bootstrap_results: Dict[str, Dict[str, float]] = {}
533
+ for name in metric_names:
534
+ if name not in metrics:
535
+ continue
536
+ ci_result = bootstrap_ci(
537
+ lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
538
+ y_true_test,
539
+ y_pred_test_eval,
540
+ weight=weight_test,
541
+ n_samples=n_samples,
542
+ ci=ci,
543
+ seed=seed,
544
+ )
545
+ bootstrap_results[str(name)] = ci_result
546
+
547
+ return bootstrap_results
548
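`_compute_bootstrap_ci` wraps each metric in the package's `bootstrap_ci` helper. The sketch below shows the underlying percentile-bootstrap idea (resample rows with replacement, recompute the metric, take empirical quantiles); it is an assumption about how such a helper typically works, not the package's implementation.

```python
import numpy as np

def percentile_bootstrap_ci(metric_fn, y_true, y_pred, n_samples=200, ci=0.95, seed=0):
    rng = np.random.default_rng(seed)
    n = len(y_true)
    stats = np.empty(n_samples)
    for i in range(n_samples):
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        stats[i] = metric_fn(y_true[idx], y_pred[idx])
    alpha = (1.0 - ci) / 2.0
    lo, hi = np.quantile(stats, [alpha, 1.0 - alpha])
    return {"point": float(metric_fn(y_true, y_pred)), "lower": float(lo), "upper": float(hi)}

rng = np.random.default_rng(0)
y_true = rng.normal(100, 20, 3_000)
y_pred = y_true + rng.normal(0, 10, 3_000)
mae = lambda a, b: float(np.mean(np.abs(a - b)))
print(percentile_bootstrap_ci(mae, y_true, y_pred))
```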
+
549
+
550
+ def _compute_validation_table(
551
+ model: ropt.BayesOptModel,
552
+ pred_col: str,
553
+ report_group_cols: Optional[List[str]],
554
+ weight_col: Optional[str],
555
+ model_name: str,
556
+ model_key: str,
557
+ ) -> Optional[pd.DataFrame]:
558
+ """Compute grouped validation metrics table."""
559
+ if not report_group_cols or group_metrics is None:
560
+ return None
561
+
562
+ available_groups = [
563
+ col for col in report_group_cols if col in model.test_data.columns
564
+ ]
565
+ if not available_groups:
566
+ return None
567
+
568
+ try:
569
+ validation_table = group_metrics(
570
+ model.test_data,
571
+ actual_col=model.resp_nme,
572
+ pred_col=pred_col,
573
+ group_cols=available_groups,
574
+ weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
575
+ )
576
+ counts = (
577
+ model.test_data.groupby(available_groups, dropna=False)
578
+ .size()
579
+ .reset_index(name="count")
580
+ )
581
+ return validation_table.merge(counts, on=available_groups, how="left")
582
+ except Exception as exc:
583
+ print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
584
+ return None
585
+
586
+
587
+ def _compute_risk_trend(
588
+ model: ropt.BayesOptModel,
589
+ pred_col: str,
590
+ report_time_col: Optional[str],
591
+ report_time_freq: str,
592
+ report_time_ascending: bool,
593
+ weight_col: Optional[str],
594
+ model_name: str,
595
+ model_key: str,
596
+ ) -> Optional[pd.DataFrame]:
597
+ """Compute time-series risk trend metrics."""
598
+ if not report_time_col or group_metrics is None:
599
+ return None
600
+
601
+ if report_time_col not in model.test_data.columns:
602
+ return None
603
+
604
+ try:
605
+ time_df = model.test_data.copy()
606
+ time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
607
+ time_df = time_df.loc[time_series.notna()].copy()
608
+
609
+ if time_df.empty:
610
+ return None
611
+
612
+ time_df["_time_bucket"] = (
613
+ pd.to_datetime(time_df[report_time_col], errors="coerce")
614
+ .dt.to_period(report_time_freq)
615
+ .dt.to_timestamp()
616
+ )
617
+ risk_trend = group_metrics(
618
+ time_df,
619
+ actual_col=model.resp_nme,
620
+ pred_col=pred_col,
621
+ group_cols=["_time_bucket"],
622
+ weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
623
+ )
624
+ counts = (
625
+ time_df.groupby("_time_bucket", dropna=False)
626
+ .size()
627
+ .reset_index(name="count")
628
+ )
629
+ risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
630
+ risk_trend = risk_trend.sort_values(
631
+ "_time_bucket", ascending=bool(report_time_ascending)
632
+ ).reset_index(drop=True)
633
+ return risk_trend.rename(columns={"_time_bucket": report_time_col})
634
+ except Exception as exc:
635
+ print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
636
+ return None
637
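`_compute_risk_trend` buckets the test set by a time column using pandas periods before computing grouped metrics. A small sketch of the `to_period(...).dt.to_timestamp()` bucketing plus a grouped actual-versus-predicted summary; this uses plain pandas, not the package's `group_metrics`, and the column names are illustrative.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "policy_date": pd.date_range("2023-01-01", periods=365, freq="D"),
    "actual": rng.poisson(2.0, 365).astype(float),
    "pred": rng.normal(2.0, 0.3, 365),
})

# Bucket each row to the start of its month ("M" frequency), as in _compute_risk_trend.
df["_time_bucket"] = pd.to_datetime(df["policy_date"]).dt.to_period("M").dt.to_timestamp()

trend = (
    df.groupby("_time_bucket", dropna=False)
      .agg(actual_mean=("actual", "mean"), pred_mean=("pred", "mean"), count=("actual", "size"))
      .reset_index()
      .sort_values("_time_bucket")
)
print(trend.head())
```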
+
638
+
639
+ def _write_metrics_json(
640
+ report_root: Path,
641
+ model_name: str,
642
+ model_key: str,
643
+ version: str,
644
+ metrics: Dict[str, Any],
645
+ threshold_info: Optional[Dict[str, Any]],
646
+ calibration_info: Optional[Dict[str, Any]],
647
+ bootstrap_results: Dict[str, Dict[str, float]],
648
+ data_path: Path,
649
+ data_fingerprint: Dict[str, Any],
650
+ config_sha: str,
651
+ pred_col: str,
652
+ task_type: str,
653
+ ) -> Path:
654
+ """Write metrics to JSON file and return the path."""
655
+ metrics_payload = {
656
+ "model_name": model_name,
657
+ "model_key": model_key,
658
+ "model_version": version,
659
+ "metrics": metrics,
660
+ "threshold": threshold_info,
661
+ "calibration": calibration_info,
662
+ "bootstrap": bootstrap_results,
663
+ "data_path": str(data_path),
664
+ "data_fingerprint": data_fingerprint,
665
+ "config_sha256": config_sha,
666
+ "pred_col": pred_col,
667
+ "task_type": task_type,
668
+ }
669
+ metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
670
+ metrics_path.write_text(
671
+ json.dumps(metrics_payload, indent=2, ensure_ascii=True),
672
+ encoding="utf-8",
673
+ )
674
+ return metrics_path
675
+
676
+
677
+ def _write_model_report(
678
+ report_root: Path,
679
+ model_name: str,
680
+ model_key: str,
681
+ version: str,
682
+ metrics: Dict[str, Any],
683
+ risk_trend: Optional[pd.DataFrame],
684
+ psi_report_df: Optional[pd.DataFrame],
685
+ validation_table: Optional[pd.DataFrame],
686
+ calibration_info: Optional[Dict[str, Any]],
687
+ threshold_info: Optional[Dict[str, Any]],
688
+ bootstrap_results: Dict[str, Dict[str, float]],
689
+ config_sha: str,
690
+ data_fingerprint: Dict[str, Any],
691
+ ) -> Optional[Path]:
692
+ """Write model report and return the path."""
693
+ if ReportPayload is None or write_report is None:
694
+ return None
695
+
696
+ notes_lines = [
697
+ f"- Config SHA256: {config_sha}",
698
+ f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
699
+ ]
700
+ if calibration_info:
701
+ notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
702
+ if threshold_info:
703
+ notes_lines.append(f"- Threshold selection: {threshold_info}")
704
+ if bootstrap_results:
705
+ notes_lines.append("- Bootstrap: see metrics JSON for CI")
706
+
707
+ payload = ReportPayload(
708
+ model_name=f"{model_name}/{model_key}",
709
+ model_version=version,
710
+ metrics={k: float(v) for k, v in metrics.items()},
711
+ risk_trend=risk_trend,
712
+ drift_report=psi_report_df,
713
+ validation_table=validation_table,
714
+ extra_notes="\n".join(notes_lines),
715
+ )
716
+ return write_report(
717
+ payload,
718
+ report_root / f"{model_name}_{model_key}_report.md",
719
+ )
720
+
721
+
722
+ def _register_model_to_registry(
723
+ model: ropt.BayesOptModel,
724
+ model_name: str,
725
+ model_key: str,
726
+ version: str,
727
+ metrics: Dict[str, Any],
728
+ task_type: str,
729
+ data_path: Path,
730
+ data_fingerprint: Dict[str, Any],
731
+ config_sha: str,
732
+ registry_path: Optional[str],
733
+ registry_tags: Dict[str, Any],
734
+ registry_status: str,
735
+ report_path: Optional[Path],
736
+ metrics_path: Path,
737
+ cfg: Dict[str, Any],
738
+ ) -> None:
739
+ """Register model artifacts to the model registry."""
740
+ if ModelRegistry is None or ModelArtifact is None:
741
+ return
742
+
743
+ registry = ModelRegistry(
744
+ registry_path
745
+ if registry_path
746
+ else Path(model.output_manager.result_dir) / "model_registry.json"
747
+ )
748
+
749
+ tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
750
+ tags.update({
751
+ "model_key": str(model_key),
752
+ "task_type": str(task_type),
753
+ "data_path": str(data_path),
754
+ "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
755
+ "data_size": str(data_fingerprint.get("size", "")),
756
+ "data_mtime": str(data_fingerprint.get("mtime", "")),
757
+ "config_sha256": str(config_sha),
758
+ })
759
+
760
+ artifacts = _collect_model_artifacts(
761
+ model, model_name, model_key, report_path, metrics_path, cfg
762
+ )
763
+
764
+ registry.register(
765
+ name=str(model_name),
766
+ version=version,
767
+ metrics={k: float(v) for k, v in metrics.items()},
768
+ tags=tags,
769
+ artifacts=artifacts,
770
+ status=str(registry_status or "candidate"),
771
+ notes=f"model_key={model_key}",
772
+ )
773
+
774
+
775
+ def _collect_model_artifacts(
776
+ model: ropt.BayesOptModel,
777
+ model_name: str,
778
+ model_key: str,
779
+ report_path: Optional[Path],
780
+ metrics_path: Path,
781
+ cfg: Dict[str, Any],
782
+ ) -> List:
783
+ """Collect all model artifacts for registry."""
784
+ artifacts = []
785
+
786
+ # Trained model artifact
787
+ trainer = model.trainers.get(model_key)
788
+ if trainer is not None:
789
+ try:
790
+ model_path = trainer.output.model_path(trainer._get_model_filename())
791
+ if os.path.exists(model_path):
792
+ artifacts.append(ModelArtifact(path=model_path, description="trained model"))
793
+ except Exception:
794
+ pass
795
+
796
+ # Report artifact
797
+ if report_path is not None:
798
+ artifacts.append(ModelArtifact(path=str(report_path), description="model report"))
799
+
800
+ # Metrics JSON artifact
801
+ if metrics_path.exists():
802
+ artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))
803
+
804
+ # Preprocess artifacts
805
+ if bool(cfg.get("save_preprocess", False)):
806
+ artifact_path = cfg.get("preprocess_artifact_path")
807
+ if artifact_path:
808
+ preprocess_path = Path(str(artifact_path))
809
+ if not preprocess_path.is_absolute():
810
+ preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
811
+ else:
812
+ preprocess_path = Path(model.output_manager.result_path(
813
+ f"{model.model_nme}_preprocess.json"
814
+ ))
815
+ if preprocess_path.exists():
816
+ artifacts.append(
817
+ ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
818
+ )
819
+
820
+ # Prediction cache artifacts
821
+ if bool(cfg.get("cache_predictions", False)):
822
+ cache_dir = cfg.get("prediction_cache_dir")
823
+ if cache_dir:
824
+ pred_root = Path(str(cache_dir))
825
+ if not pred_root.is_absolute():
826
+ pred_root = Path(model.output_manager.result_dir) / pred_root
827
+ else:
828
+ pred_root = Path(model.output_manager.result_dir) / "predictions"
829
+ ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
830
+ for split_label in ("train", "test"):
831
+ pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
832
+ if pred_path.exists():
833
+ artifacts.append(
834
+ ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
835
+ )
836
+
837
+ return artifacts
838
+
839
+
840
+ def _evaluate_and_report(
841
+ model: ropt.BayesOptModel,
842
+ *,
843
+ model_name: str,
844
+ model_key: str,
845
+ cfg: Dict[str, Any],
846
+ data_path: Path,
847
+ data_fingerprint: Dict[str, Any],
848
+ report_output_dir: Optional[str],
849
+ report_group_cols: Optional[List[str]],
850
+ report_time_col: Optional[str],
851
+ report_time_freq: str,
852
+ report_time_ascending: bool,
853
+ psi_report_df: Optional[pd.DataFrame],
854
+ calibration_cfg: Dict[str, Any],
855
+ threshold_cfg: Dict[str, Any],
856
+ bootstrap_cfg: Dict[str, Any],
857
+ register_model: bool,
858
+ registry_path: Optional[str],
859
+ registry_tags: Dict[str, Any],
860
+ registry_status: str,
861
+ run_id: str,
862
+ config_sha: str,
863
+ ) -> None:
864
+ """Evaluate model predictions and generate reports.
865
+
866
+ This function orchestrates the evaluation pipeline:
867
+ 1. Extract predictions and ground truth
868
+ 2. Apply calibration (for classification)
869
+ 3. Select threshold (for classification)
870
+ 4. Compute metrics
871
+ 5. Compute bootstrap confidence intervals
872
+ 6. Generate validation tables and risk trends
873
+ 7. Write reports and register model
874
+ """
875
+ if eval_metrics_report is None:
876
+ print("[Report] Skip evaluation: metrics module unavailable.")
877
+ return
878
+
879
+ pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
880
+ if pred_col not in model.test_data.columns:
881
+ print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
882
+ return
883
+
884
+ # Extract predictions and weights
885
+ weight_col = getattr(model, "weight_nme", None)
886
+ y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
887
+ y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
888
+ y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
889
+ y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
890
+ weight_test = (
891
+ model.test_data[weight_col].to_numpy(dtype=float, copy=False)
892
+ if weight_col and weight_col in model.test_data.columns
893
+ else None
894
+ )
895
+
896
+ task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))
897
+
898
+ # Process based on task type
899
+ if task_type == "classification":
900
+ y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
901
+ y_pred_test = np.clip(y_pred_test, 0.0, 1.0)
902
+
903
+ y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
904
+ y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
905
+ )
906
+ threshold_value, threshold_info = _select_classification_threshold(
907
+ y_true_train, y_pred_train_eval, threshold_cfg
908
+ )
909
+ metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
910
+ else:
911
+ y_pred_test_eval = y_pred_test
912
+ calibration_info = None
913
+ threshold_info = None
914
+ metrics = eval_metrics_report(
915
+ y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
916
+ )
917
+
918
+ # Compute bootstrap confidence intervals
919
+ bootstrap_results = _compute_bootstrap_ci(
920
+ y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
921
+ )
922
+
923
+ # Compute validation table and risk trend
924
+ validation_table = _compute_validation_table(
925
+ model, pred_col, report_group_cols, weight_col, model_name, model_key
926
+ )
927
+ risk_trend = _compute_risk_trend(
928
+ model, pred_col, report_time_col, report_time_freq,
929
+ report_time_ascending, weight_col, model_name, model_key
930
+ )
931
+
932
+ # Setup output directory
933
+ report_root = (
934
+ Path(report_output_dir)
935
+ if report_output_dir
936
+ else Path(model.output_manager.result_dir) / "reports"
937
+ )
938
+ report_root.mkdir(parents=True, exist_ok=True)
939
+ version = f"{model_key}_{run_id}"
940
+
941
+ # Write metrics JSON
942
+ metrics_path = _write_metrics_json(
943
+ report_root, model_name, model_key, version, metrics,
944
+ threshold_info, calibration_info, bootstrap_results,
945
+ data_path, data_fingerprint, config_sha, pred_col, task_type
946
+ )
947
+
948
+ # Write model report
949
+ report_path = _write_model_report(
950
+ report_root, model_name, model_key, version, metrics,
951
+ risk_trend, psi_report_df, validation_table,
952
+ calibration_info, threshold_info, bootstrap_results,
953
+ config_sha, data_fingerprint
954
+ )
955
+
956
+ # Register model
957
+ if register_model:
958
+ _register_model_to_registry(
959
+ model, model_name, model_key, version, metrics, task_type,
960
+ data_path, data_fingerprint, config_sha, registry_path,
961
+ registry_tags, registry_status, report_path, metrics_path, cfg
962
+ )
963
+
964
+
965
+ def _evaluate_with_context(
966
+ model: ropt.BayesOptModel,
967
+ ctx: EvaluationContext,
968
+ ) -> None:
969
+ """Evaluate model predictions using context object.
970
+
971
+ This is a cleaner interface that uses the EvaluationContext dataclass
972
+ instead of 19+ individual parameters.
973
+ """
974
+ _evaluate_and_report(
975
+ model,
976
+ model_name=ctx.identity.model_name,
977
+ model_key=ctx.identity.model_key,
978
+ cfg=ctx.cfg,
979
+ data_path=ctx.data_path,
980
+ data_fingerprint=ctx.data_fingerprint.to_dict(),
981
+ report_output_dir=ctx.report.output_dir,
982
+ report_group_cols=ctx.report.group_cols,
983
+ report_time_col=ctx.report.time_col,
984
+ report_time_freq=ctx.report.time_freq,
985
+ report_time_ascending=ctx.report.time_ascending,
986
+ psi_report_df=ctx.psi_report_df,
987
+ calibration_cfg={
988
+ "enable": ctx.calibration.enable,
989
+ "method": ctx.calibration.method,
990
+ "max_rows": ctx.calibration.max_rows,
991
+ "seed": ctx.calibration.seed,
992
+ },
993
+ threshold_cfg={
994
+ "enable": ctx.threshold.enable,
995
+ "metric": ctx.threshold.metric,
996
+ "value": ctx.threshold.value,
997
+ "min_positive_rate": ctx.threshold.min_positive_rate,
998
+ "grid": ctx.threshold.grid,
999
+ "max_rows": ctx.threshold.max_rows,
1000
+ "seed": ctx.threshold.seed,
1001
+ },
1002
+ bootstrap_cfg={
1003
+ "enable": ctx.bootstrap.enable,
1004
+ "metrics": ctx.bootstrap.metrics,
1005
+ "n_samples": ctx.bootstrap.n_samples,
1006
+ "ci": ctx.bootstrap.ci,
1007
+ "seed": ctx.bootstrap.seed,
1008
+ },
1009
+ register_model=ctx.registry.register,
1010
+ registry_path=ctx.registry.path,
1011
+ registry_tags=ctx.registry.tags,
1012
+ registry_status=ctx.registry.status,
1013
+ run_id=ctx.run_id,
1014
+ config_sha=ctx.config_sha,
1015
+ )
1016
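`_evaluate_with_context` exists so callers can pass one `EvaluationContext` object instead of roughly twenty keyword arguments. The real field layout lives in `ins_pricing.cli.utils.evaluation_context`; the sketch below only illustrates the general pattern of bundling related settings into nested dataclasses, and every field name in it is made up for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ThresholdSettings:
    enable: bool = False
    metric: str = "f1"
    value: Optional[float] = None

@dataclass
class ReportSettings:
    output_dir: Optional[str] = None
    group_cols: List[str] = field(default_factory=list)

@dataclass
class EvalSettings:
    model_name: str
    model_key: str
    threshold: ThresholdSettings = field(default_factory=ThresholdSettings)
    report: ReportSettings = field(default_factory=ReportSettings)

def evaluate(ctx: EvalSettings) -> None:
    # A single context argument replaces a long keyword-argument list.
    print(ctx.model_name, ctx.model_key, ctx.threshold.metric, ctx.report.group_cols)

evaluate(EvalSettings(model_name="bi_freq", model_key="xgb",
                      report=ReportSettings(group_cols=["region"])))
```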
+
1017
+
1018
+ def _create_ddp_barrier(dist_ctx: TrainingContext):
1019
+ """Create a DDP barrier function for distributed training synchronization."""
1020
+ def _ddp_barrier(reason: str) -> None:
1021
+ if not dist_ctx.is_distributed:
1022
+ return
1023
+ torch_mod = getattr(ropt, "torch", None)
1024
+ dist_mod = getattr(torch_mod, "distributed", None)
1025
+ if dist_mod is None:
1026
+ return
1027
+ try:
1028
+ if not getattr(dist_mod, "is_available", lambda: False)():
1029
+ return
1030
+ if not dist_mod.is_initialized():
1031
+ ddp_ok, _, _, _ = ropt.DistributedUtils.setup_ddp()
1032
+ if not ddp_ok or not dist_mod.is_initialized():
1033
+ return
1034
+ dist_mod.barrier()
1035
+ except Exception as exc:
1036
+ print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
1037
+ raise
1038
+ return _ddp_barrier
1039
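`_create_ddp_barrier` returns a closure that only calls `torch.distributed.barrier()` when a process group is actually available and initialised, so single-process runs pass through untouched. A guarded-barrier sketch in plain PyTorch (assumes `torch` is installed; it is a no-op unless launched under an initialised process group, e.g. via torchrun):

```python
import torch.distributed as dist

def safe_barrier(reason: str) -> None:
    """Synchronise all ranks, but silently no-op outside an initialised process group."""
    if not dist.is_available() or not dist.is_initialized():
        return
    try:
        dist.barrier()
    except Exception as exc:  # surface which synchronisation point failed
        print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
        raise

safe_barrier("start_ft_embedding")  # no-op in a plain single-process run
```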
+
1040
+
1041
+ def train_from_config(args: argparse.Namespace) -> None:
1042
+ script_dir = Path(__file__).resolve().parents[1]
1043
+ config_path, cfg = resolve_and_load_config(
1044
+ args.config_json,
1045
+ script_dir,
1046
+ required_keys=["data_dir", "model_list",
1047
+ "model_categories", "target", "weight"],
1048
+ )
1049
+ plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
1050
+ config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
1051
+ run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
1052
+
1053
+ # Use TrainingContext for distributed training state
1054
+ dist_ctx = TrainingContext.from_env()
1055
+ dist_world_size = dist_ctx.world_size
1056
+ dist_rank = dist_ctx.rank
1057
+ dist_active = dist_ctx.is_distributed
1058
+ is_main_process = dist_ctx.is_main_process
1059
+ _ddp_barrier = _create_ddp_barrier(dist_ctx)
1060
+
1061
+ data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
1062
+ cfg,
1063
+ config_path,
1064
+ create_data_dir=True,
1065
+ )
1066
+ runtime_cfg = resolve_runtime_config(cfg)
1067
+ ddp_min_rows = runtime_cfg["ddp_min_rows"]
1068
+ bo_sample_limit = runtime_cfg["bo_sample_limit"]
1069
+ cache_predictions = runtime_cfg["cache_predictions"]
1070
+ prediction_cache_dir = runtime_cfg["prediction_cache_dir"]
1071
+ prediction_cache_format = runtime_cfg["prediction_cache_format"]
1072
+ report_cfg = resolve_report_config(cfg)
1073
+ report_output_dir = report_cfg["report_output_dir"]
1074
+ report_group_cols = report_cfg["report_group_cols"]
1075
+ report_time_col = report_cfg["report_time_col"]
1076
+ report_time_freq = report_cfg["report_time_freq"]
1077
+ report_time_ascending = report_cfg["report_time_ascending"]
1078
+ psi_bins = report_cfg["psi_bins"]
1079
+ psi_strategy = report_cfg["psi_strategy"]
1080
+ psi_features = report_cfg["psi_features"]
1081
+ calibration_cfg = report_cfg["calibration_cfg"]
1082
+ threshold_cfg = report_cfg["threshold_cfg"]
1083
+ bootstrap_cfg = report_cfg["bootstrap_cfg"]
1084
+ register_model = report_cfg["register_model"]
1085
+ registry_path = report_cfg["registry_path"]
1086
+ registry_tags = report_cfg["registry_tags"]
1087
+ registry_status = report_cfg["registry_status"]
1088
+ data_fingerprint_max_bytes = report_cfg["data_fingerprint_max_bytes"]
1089
+ report_enabled = report_cfg["report_enabled"]
1090
+
1091
+ split_cfg = resolve_split_config(cfg)
1092
+ prop_test = split_cfg["prop_test"]
1093
+ holdout_ratio = split_cfg["holdout_ratio"]
1094
+ val_ratio = split_cfg["val_ratio"]
1095
+ split_strategy = split_cfg["split_strategy"]
1096
+ split_group_col = split_cfg["split_group_col"]
1097
+ split_time_col = split_cfg["split_time_col"]
1098
+ split_time_ascending = split_cfg["split_time_ascending"]
1099
+ cv_strategy = split_cfg["cv_strategy"]
1100
+ cv_group_col = split_cfg["cv_group_col"]
1101
+ cv_time_col = split_cfg["cv_time_col"]
1102
+ cv_time_ascending = split_cfg["cv_time_ascending"]
1103
+ cv_splits = split_cfg["cv_splits"]
1104
+ ft_oof_folds = split_cfg["ft_oof_folds"]
1105
+ ft_oof_strategy = split_cfg["ft_oof_strategy"]
1106
+ ft_oof_shuffle = split_cfg["ft_oof_shuffle"]
1107
+ save_preprocess = runtime_cfg["save_preprocess"]
1108
+ preprocess_artifact_path = runtime_cfg["preprocess_artifact_path"]
1109
+ rand_seed = runtime_cfg["rand_seed"]
1110
+ epochs = runtime_cfg["epochs"]
1111
+ output_cfg = resolve_output_dirs(
1112
+ cfg,
1113
+ config_path,
1114
+ output_override=args.output_dir,
1115
+ )
1116
+ output_dir = output_cfg["output_dir"]
1117
+ reuse_best_params = bool(
1118
+ args.reuse_best_params or runtime_cfg["reuse_best_params"])
1119
+ xgb_max_depth_max = runtime_cfg["xgb_max_depth_max"]
1120
+ xgb_n_estimators_max = runtime_cfg["xgb_n_estimators_max"]
1121
+ optuna_storage = runtime_cfg["optuna_storage"]
1122
+ optuna_study_prefix = runtime_cfg["optuna_study_prefix"]
1123
+ best_params_files = runtime_cfg["best_params_files"]
1124
+ plot_path_style = runtime_cfg["plot_path_style"]
1125
+
1126
+ model_names = build_model_names(
1127
+ cfg["model_list"], cfg["model_categories"])
1128
+ if not model_names:
1129
+ raise ValueError(
1130
+ "No model names generated from model_list/model_categories.")
1131
+
1132
+ results: Dict[str, ropt.BayesOptModel] = {}
1133
+ trained_keys_by_model: Dict[str, List[str]] = {}
1134
+
1135
+ for model_name in model_names:
1136
+ # Per-dataset training loop: load data, split train/test, and train requested models.
1137
+ data_path = resolve_data_path(
1138
+ data_dir,
1139
+ model_name,
1140
+ data_format=data_format,
1141
+ path_template=data_path_template,
1142
+ )
1143
+ if not data_path.exists():
1144
+ raise FileNotFoundError(f"Missing dataset: {data_path}")
1145
+ data_fingerprint = {"path": str(data_path)}
1146
+ if report_enabled and is_main_process:
1147
+ data_fingerprint = fingerprint_file(
1148
+ data_path,
1149
+ max_bytes=data_fingerprint_max_bytes,
1150
+ )
1151
+
1152
+ print(f"\n=== Processing model {model_name} ===")
1153
+ raw = load_dataset(
1154
+ data_path,
1155
+ data_format=data_format,
1156
+ dtype_map=dtype_map,
1157
+ low_memory=False,
1158
+ )
1159
+ raw = coerce_dataset_types(raw)
1160
+
1161
+ train_df, test_df = split_train_test(
1162
+ raw,
1163
+ holdout_ratio=holdout_ratio,
1164
+ strategy=split_strategy,
1165
+ group_col=split_group_col,
1166
+ time_col=split_time_col,
1167
+ time_ascending=split_time_ascending,
1168
+ rand_seed=rand_seed,
1169
+ reset_index_mode="time_group",
1170
+ ratio_label="holdout_ratio",
1171
+ )
1172
+
1173
+ use_resn_dp = args.use_resn_dp or cfg.get(
1174
+ "use_resn_data_parallel", False)
1175
+ use_ft_dp = args.use_ft_dp or cfg.get("use_ft_data_parallel", True)
1176
+ dataset_rows = len(raw)
1177
+ ddp_enabled = bool(dist_active and (dataset_rows >= int(ddp_min_rows)))
1178
+ use_resn_ddp = (args.use_resn_ddp or cfg.get(
1179
+ "use_resn_ddp", False)) and ddp_enabled
1180
+ use_ft_ddp = (args.use_ft_ddp or cfg.get(
1181
+ "use_ft_ddp", False)) and ddp_enabled
1182
+ use_gnn_dp = args.use_gnn_dp or cfg.get("use_gnn_data_parallel", False)
1183
+ use_gnn_ddp = (args.use_gnn_ddp or cfg.get(
1184
+ "use_gnn_ddp", False)) and ddp_enabled
1185
+ gnn_use_ann = cfg.get("gnn_use_approx_knn", True)
1186
+ if args.gnn_no_ann:
1187
+ gnn_use_ann = False
1188
+ gnn_threshold = args.gnn_ann_threshold if args.gnn_ann_threshold is not None else cfg.get(
1189
+ "gnn_approx_knn_threshold", 50000)
1190
+ gnn_graph_cache = args.gnn_graph_cache or cfg.get("gnn_graph_cache")
1191
+ if isinstance(gnn_graph_cache, str) and gnn_graph_cache.strip():
1192
+ resolved_cache = resolve_path(gnn_graph_cache, config_path.parent)
1193
+ if resolved_cache is not None:
1194
+ gnn_graph_cache = str(resolved_cache)
1195
+ gnn_max_gpu_nodes = args.gnn_max_gpu_nodes if args.gnn_max_gpu_nodes is not None else cfg.get(
1196
+ "gnn_max_gpu_knn_nodes", 200000)
1197
+ gnn_gpu_mem_ratio = args.gnn_gpu_mem_ratio if args.gnn_gpu_mem_ratio is not None else cfg.get(
1198
+ "gnn_knn_gpu_mem_ratio", 0.9)
1199
+ gnn_gpu_mem_overhead = args.gnn_gpu_mem_overhead if args.gnn_gpu_mem_overhead is not None else cfg.get(
1200
+ "gnn_knn_gpu_mem_overhead", 2.0)
1201
+
1202
+ binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
1203
+ task_type = str(cfg.get("task_type", "regression"))
1204
+ feature_list = cfg.get("feature_list")
1205
+ categorical_features = cfg.get("categorical_features")
1206
+ use_gpu = bool(cfg.get("use_gpu", True))
1207
+ region_province_col = cfg.get("region_province_col")
1208
+ region_city_col = cfg.get("region_city_col")
1209
+ region_effect_alpha = cfg.get("region_effect_alpha")
1210
+ geo_feature_nmes = cfg.get("geo_feature_nmes")
1211
+ geo_token_hidden_dim = cfg.get("geo_token_hidden_dim")
1212
+ geo_token_layers = cfg.get("geo_token_layers")
1213
+ geo_token_dropout = cfg.get("geo_token_dropout")
1214
+ geo_token_k_neighbors = cfg.get("geo_token_k_neighbors")
1215
+ geo_token_learning_rate = cfg.get("geo_token_learning_rate")
1216
+ geo_token_epochs = cfg.get("geo_token_epochs")
1217
+
1218
+ ft_role = args.ft_role or cfg.get("ft_role", "model")
1219
+ if args.ft_as_feature and args.ft_role is None:
1220
+ # Keep legacy behavior as a convenience alias only when the config
1221
+ # didn't already request a non-default FT role.
1222
+ if str(cfg.get("ft_role", "model")) == "model":
1223
+ ft_role = "embedding"
1224
+ ft_feature_prefix = str(
1225
+ cfg.get("ft_feature_prefix", args.ft_feature_prefix))
1226
+ ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")
1227
+
1228
+ config_fields = getattr(ropt.BayesOptConfig,
1229
+ "__dataclass_fields__", {})
1230
+ allowed_config_keys = set(config_fields.keys())
1231
+ config_payload = {k: v for k,
1232
+ v in cfg.items() if k in allowed_config_keys}
1233
+ config_payload.update({
1234
+ "model_nme": model_name,
1235
+ "resp_nme": cfg["target"],
1236
+ "weight_nme": cfg["weight"],
1237
+ "factor_nmes": feature_list,
1238
+ "task_type": task_type,
1239
+ "binary_resp_nme": binary_target,
1240
+ "cate_list": categorical_features,
1241
+ "prop_test": val_ratio,
1242
+ "rand_seed": rand_seed,
1243
+ "epochs": epochs,
1244
+ "use_gpu": use_gpu,
1245
+ "use_resn_data_parallel": use_resn_dp,
1246
+ "use_ft_data_parallel": use_ft_dp,
1247
+ "use_gnn_data_parallel": use_gnn_dp,
1248
+ "use_resn_ddp": use_resn_ddp,
1249
+ "use_ft_ddp": use_ft_ddp,
1250
+ "use_gnn_ddp": use_gnn_ddp,
1251
+ "output_dir": output_dir,
1252
+ "xgb_max_depth_max": xgb_max_depth_max,
1253
+ "xgb_n_estimators_max": xgb_n_estimators_max,
1254
+ "resn_weight_decay": cfg.get("resn_weight_decay"),
1255
+ "final_ensemble": bool(cfg.get("final_ensemble", False)),
1256
+ "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
1257
+ "final_refit": bool(cfg.get("final_refit", True)),
1258
+ "optuna_storage": optuna_storage,
1259
+ "optuna_study_prefix": optuna_study_prefix,
1260
+ "best_params_files": best_params_files,
1261
+ "gnn_use_approx_knn": gnn_use_ann,
1262
+ "gnn_approx_knn_threshold": gnn_threshold,
1263
+ "gnn_graph_cache": gnn_graph_cache,
1264
+ "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
1265
+ "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
1266
+ "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
1267
+ "region_province_col": region_province_col,
1268
+ "region_city_col": region_city_col,
1269
+ "region_effect_alpha": region_effect_alpha,
1270
+ "geo_feature_nmes": geo_feature_nmes,
1271
+ "geo_token_hidden_dim": geo_token_hidden_dim,
1272
+ "geo_token_layers": geo_token_layers,
1273
+ "geo_token_dropout": geo_token_dropout,
1274
+ "geo_token_k_neighbors": geo_token_k_neighbors,
1275
+ "geo_token_learning_rate": geo_token_learning_rate,
1276
+ "geo_token_epochs": geo_token_epochs,
1277
+ "ft_role": ft_role,
1278
+ "ft_feature_prefix": ft_feature_prefix,
1279
+ "ft_num_numeric_tokens": ft_num_numeric_tokens,
1280
+ "reuse_best_params": reuse_best_params,
1281
+ "bo_sample_limit": bo_sample_limit,
1282
+ "cache_predictions": cache_predictions,
1283
+ "prediction_cache_dir": prediction_cache_dir,
1284
+ "prediction_cache_format": prediction_cache_format,
1285
+ "cv_strategy": cv_strategy or split_strategy,
1286
+ "cv_group_col": cv_group_col or split_group_col,
1287
+ "cv_time_col": cv_time_col or split_time_col,
1288
+ "cv_time_ascending": cv_time_ascending,
1289
+ "cv_splits": cv_splits,
1290
+ "ft_oof_folds": ft_oof_folds,
1291
+ "ft_oof_strategy": ft_oof_strategy,
1292
+ "ft_oof_shuffle": ft_oof_shuffle,
1293
+ "save_preprocess": save_preprocess,
1294
+ "preprocess_artifact_path": preprocess_artifact_path,
1295
+ "plot_path_style": plot_path_style or "nested",
1296
+ })
1297
+ config_payload = {
1298
+ k: v for k, v in config_payload.items() if v is not None}
1299
+ config = ropt.BayesOptConfig(**config_payload)
1300
+ model = ropt.BayesOptModel(train_df, test_df, config=config)
1301
+
1302
+ if plot_requested:
1303
+ plot_cfg = cfg.get("plot", {})
1304
+ legacy_lift_flags = {
1305
+ "glm": cfg.get("plot_lift_glm", False),
1306
+ "xgb": cfg.get("plot_lift_xgb", False),
1307
+ "resn": cfg.get("plot_lift_resn", False),
1308
+ "ft": cfg.get("plot_lift_ft", False),
1309
+ }
1310
+ plot_enabled = plot_cfg.get(
1311
+ "enable", any(legacy_lift_flags.values()))
1312
+ if plot_enabled and plot_cfg.get("pre_oneway", False) and plot_cfg.get("oneway", True):
1313
+ n_bins = int(plot_cfg.get("n_bins", 10))
1314
+ model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/pre")
1315
+
1316
+ if "all" in args.model_keys:
1317
+ requested_keys = ["glm", "xgb", "resn", "ft", "gnn"]
1318
+ else:
1319
+ requested_keys = args.model_keys
1320
+ requested_keys = dedupe_preserve_order(requested_keys)
1321
+
1322
+ if ft_role != "model":
1323
+ requested_keys = [k for k in requested_keys if k != "ft"]
1324
+ if not requested_keys:
1325
+ stack_keys = args.stack_model_keys or cfg.get(
1326
+ "stack_model_keys")
1327
+ if stack_keys:
1328
+ if "all" in stack_keys:
1329
+ requested_keys = ["glm", "xgb", "resn", "gnn"]
1330
+ else:
1331
+ requested_keys = [k for k in stack_keys if k != "ft"]
1332
+ requested_keys = dedupe_preserve_order(requested_keys)
1333
+ if dist_active and ddp_enabled:
1334
+ ft_trainer = model.trainers.get("ft")
1335
+ if ft_trainer is None:
1336
+ raise ValueError("FT trainer is not available.")
1337
+ ft_trainer_uses_ddp = bool(
1338
+ getattr(ft_trainer, "enable_distributed_optuna", False))
1339
+ if not ft_trainer_uses_ddp:
1340
+ raise ValueError(
1341
+ "FT embedding under torchrun requires enabling FT DDP (use --use-ft-ddp or set use_ft_ddp=true)."
1342
+ )
1343
+ missing = [key for key in requested_keys if key not in model.trainers]
1344
+ if missing:
1345
+ raise ValueError(
1346
+ f"Trainer(s) {missing} not available for {model_name}")
1347
+
1348
+ executed_keys: List[str] = []
1349
+ if ft_role != "model":
1350
+ if dist_active and not ddp_enabled:
1351
+ _ddp_barrier("start_ft_embedding")
1352
+ if dist_rank != 0:
1353
+ _ddp_barrier("finish_ft_embedding")
1354
+ continue
1355
+ print(
1356
+ f"Optimizing ft as {ft_role} for {model_name} (max_evals={args.max_evals})")
1357
+ model.optimize_model("ft", max_evals=args.max_evals)
1358
+ model.trainers["ft"].save()
1359
+ if getattr(ropt, "torch", None) is not None and ropt.torch.cuda.is_available():
1360
+ ropt.free_cuda()
1361
+ if dist_active and not ddp_enabled:
1362
+ _ddp_barrier("finish_ft_embedding")
1363
+ for key in requested_keys:
1364
+ trainer = model.trainers[key]
1365
+ trainer_uses_ddp = bool(
1366
+ getattr(trainer, "enable_distributed_optuna", False))
1367
+ if dist_active and not trainer_uses_ddp:
1368
+ if dist_rank != 0:
1369
+ print(
1370
+ f"[Rank {dist_rank}] Skip {model_name}/{key} because trainer is not DDP-enabled."
1371
+ )
1372
+ _ddp_barrier(f"start_non_ddp_{model_name}_{key}")
1373
+ if dist_rank != 0:
1374
+ _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1375
+ continue
1376
+
1377
+ print(
1378
+ f"Optimizing {key} for {model_name} (max_evals={args.max_evals})")
1379
+ model.optimize_model(key, max_evals=args.max_evals)
1380
+ model.trainers[key].save()
1381
+ _plot_loss_curve_for_trainer(model_name, model.trainers[key])
1382
+ if key in PYTORCH_TRAINERS:
1383
+ ropt.free_cuda()
1384
+ if dist_active and not trainer_uses_ddp:
1385
+ _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1386
+ executed_keys.append(key)
1387
+
1388
+ if not executed_keys:
1389
+ continue
1390
+
1391
+ results[model_name] = model
1392
+ trained_keys_by_model[model_name] = executed_keys
1393
+ if report_enabled and is_main_process:
1394
+ psi_report_df = _compute_psi_report(
1395
+ model,
1396
+ features=psi_features,
1397
+ bins=psi_bins,
1398
+ strategy=str(psi_strategy),
1399
+ )
1400
+ for key in executed_keys:
1401
+ _evaluate_and_report(
1402
+ model,
1403
+ model_name=model_name,
1404
+ model_key=key,
1405
+ cfg=cfg,
1406
+ data_path=data_path,
1407
+ data_fingerprint=data_fingerprint,
1408
+ report_output_dir=report_output_dir,
1409
+ report_group_cols=report_group_cols,
1410
+ report_time_col=report_time_col,
1411
+ report_time_freq=str(report_time_freq),
1412
+ report_time_ascending=bool(report_time_ascending),
1413
+ psi_report_df=psi_report_df,
1414
+ calibration_cfg=calibration_cfg,
1415
+ threshold_cfg=threshold_cfg,
1416
+ bootstrap_cfg=bootstrap_cfg,
1417
+ register_model=register_model,
1418
+ registry_path=registry_path,
1419
+ registry_tags=registry_tags,
1420
+ registry_status=registry_status,
1421
+ run_id=run_id,
1422
+ config_sha=config_sha,
1423
+ )
1424
+
1425
+ if not plot_requested:
1426
+ return
1427
+
1428
+ for name, model in results.items():
1429
+ _plot_curves_for_model(
1430
+ model,
1431
+ trained_keys_by_model.get(name, []),
1432
+ cfg,
1433
+ )
1434
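Inside `train_from_config` above, the raw JSON config is filtered against `BayesOptConfig.__dataclass_fields__` so unknown keys never reach the dataclass constructor, and `None` values are dropped so dataclass defaults apply. The same pattern in isolation, using a made-up stand-in config class:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DemoConfig:  # stand-in for BayesOptConfig; fields are illustrative only
    model_nme: str = "demo"
    epochs: int = 100
    rand_seed: Optional[int] = None

cfg = {"model_nme": "bi_freq", "epochs": 20, "learning_rate": 0.1, "rand_seed": None}

allowed = set(DemoConfig.__dataclass_fields__)                   # keys the dataclass accepts
payload = {k: v for k, v in cfg.items() if k in allowed}         # drop unknown keys ("learning_rate")
payload = {k: v for k, v in payload.items() if v is not None}    # let defaults fill None values

config = DemoConfig(**payload)
print(config)  # DemoConfig(model_nme='bi_freq', epochs=20, rand_seed=None)
```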
+
1435
+
1436
+ def main() -> None:
1437
+ if configure_run_logging:
1438
+ configure_run_logging(prefix="bayesopt_entry")
1439
+ args = _parse_args()
1440
+ train_from_config(args)
1441
+
1442
+
1443
+ if __name__ == "__main__":
1444
+ main()