ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1312 @@
1
+ """
2
+ CLI entry point generated from BayesOpt_AutoPricing.ipynb so the workflow can
3
+ run non‑interactively (e.g., via torchrun).
4
+
5
+ Example:
6
+ python -m torch.distributed.run --standalone --nproc_per_node=2 \\
7
+ ins_pricing/cli/BayesOpt_entry.py \\
8
+ --config-json ins_pricing/examples/modelling/config_template.json \\
9
+ --model-keys ft --max-evals 50 --use-ft-ddp
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from pathlib import Path
15
+ import sys
16
+
17
+ if __package__ in {None, ""}:
18
+ repo_root = Path(__file__).resolve().parents[2]
19
+ if str(repo_root) not in sys.path:
20
+ sys.path.insert(0, str(repo_root))
21
+
22
+ import argparse
23
+ import hashlib
24
+ import json
25
+ import os
26
+ from datetime import datetime
27
+ from typing import Any, Dict, List, Optional
28
+
29
+ import numpy as np
30
+ import pandas as pd
31
+
32
+ try:
33
+ from .. import bayesopt as ropt # type: ignore
34
+ from .utils.cli_common import ( # type: ignore
35
+ PLOT_MODEL_LABELS,
36
+ PYTORCH_TRAINERS,
37
+ build_model_names,
38
+ dedupe_preserve_order,
39
+ load_dataset,
40
+ parse_model_pairs,
41
+ resolve_data_path,
42
+ resolve_path,
43
+ fingerprint_file,
44
+ coerce_dataset_types,
45
+ split_train_test,
46
+ )
47
+ from .utils.cli_config import ( # type: ignore
48
+ add_config_json_arg,
49
+ add_output_dir_arg,
50
+ resolve_and_load_config,
51
+ resolve_data_config,
52
+ resolve_report_config,
53
+ resolve_split_config,
54
+ resolve_runtime_config,
55
+ resolve_output_dirs,
56
+ )
57
+ except Exception: # pragma: no cover
58
+ try:
59
+ import bayesopt as ropt # type: ignore
60
+ from utils.cli_common import ( # type: ignore
61
+ PLOT_MODEL_LABELS,
62
+ PYTORCH_TRAINERS,
63
+ build_model_names,
64
+ dedupe_preserve_order,
65
+ load_dataset,
66
+ parse_model_pairs,
67
+ resolve_data_path,
68
+ resolve_path,
69
+ fingerprint_file,
70
+ coerce_dataset_types,
71
+ split_train_test,
72
+ )
73
+ from utils.cli_config import ( # type: ignore
74
+ add_config_json_arg,
75
+ add_output_dir_arg,
76
+ resolve_and_load_config,
77
+ resolve_data_config,
78
+ resolve_report_config,
79
+ resolve_split_config,
80
+ resolve_runtime_config,
81
+ resolve_output_dirs,
82
+ )
83
+ except Exception:
84
+ try:
85
+ import ins_pricing.modelling.core.bayesopt as ropt # type: ignore
86
+ from ins_pricing.cli.utils.cli_common import ( # type: ignore
87
+ PLOT_MODEL_LABELS,
88
+ PYTORCH_TRAINERS,
89
+ build_model_names,
90
+ dedupe_preserve_order,
91
+ load_dataset,
92
+ parse_model_pairs,
93
+ resolve_data_path,
94
+ resolve_path,
95
+ fingerprint_file,
96
+ coerce_dataset_types,
97
+ split_train_test,
98
+ )
99
+ from ins_pricing.cli.utils.cli_config import ( # type: ignore
100
+ add_config_json_arg,
101
+ add_output_dir_arg,
102
+ resolve_and_load_config,
103
+ resolve_data_config,
104
+ resolve_report_config,
105
+ resolve_split_config,
106
+ resolve_runtime_config,
107
+ resolve_output_dirs,
108
+ )
109
+ except Exception:
110
+ import BayesOpt as ropt # type: ignore
111
+ from utils.cli_common import ( # type: ignore
112
+ PLOT_MODEL_LABELS,
113
+ PYTORCH_TRAINERS,
114
+ build_model_names,
115
+ dedupe_preserve_order,
116
+ load_dataset,
117
+ parse_model_pairs,
118
+ resolve_data_path,
119
+ resolve_path,
120
+ fingerprint_file,
121
+ coerce_dataset_types,
122
+ split_train_test,
123
+ )
124
+ from utils.cli_config import ( # type: ignore
125
+ add_config_json_arg,
126
+ add_output_dir_arg,
127
+ resolve_and_load_config,
128
+ resolve_data_config,
129
+ resolve_report_config,
130
+ resolve_split_config,
131
+ resolve_runtime_config,
132
+ resolve_output_dirs,
133
+ )
134
+
135
+ import matplotlib
136
+
137
+ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
138
+ matplotlib.use("Agg")
139
+ import matplotlib.pyplot as plt
140
+
141
+ try:
142
+ from .utils.run_logging import configure_run_logging # type: ignore
143
+ except Exception: # pragma: no cover
144
+ try:
145
+ from utils.run_logging import configure_run_logging # type: ignore
146
+ except Exception: # pragma: no cover
147
+ configure_run_logging = None # type: ignore
148
+
149
+ try:
150
+ from ..modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
151
+ except Exception: # pragma: no cover
152
+ try:
153
+ from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
154
+ except Exception: # pragma: no cover
155
+ plot_loss_curve_common = None
156
+
157
+ try:
158
+ from ..modelling.core.evaluation import ( # type: ignore
159
+ bootstrap_ci,
160
+ calibrate_predictions,
161
+ metrics_report as eval_metrics_report,
162
+ select_threshold,
163
+ )
164
+ from ..governance.registry import ModelArtifact, ModelRegistry # type: ignore
165
+ from ..production import psi_report as drift_psi_report # type: ignore
166
+ from ..production.monitoring import group_metrics # type: ignore
167
+ from ..reporting.report_builder import ReportPayload, write_report # type: ignore
168
+ except Exception: # pragma: no cover
169
+ try:
170
+ from ins_pricing.modelling.core.evaluation import ( # type: ignore
171
+ bootstrap_ci,
172
+ calibrate_predictions,
173
+ metrics_report as eval_metrics_report,
174
+ select_threshold,
175
+ )
176
+ from ins_pricing.governance.registry import ( # type: ignore
177
+ ModelArtifact,
178
+ ModelRegistry,
179
+ )
180
+ from ins_pricing.production import psi_report as drift_psi_report # type: ignore
181
+ from ins_pricing.production.monitoring import group_metrics # type: ignore
182
+ from ins_pricing.reporting.report_builder import ( # type: ignore
183
+ ReportPayload,
184
+ write_report,
185
+ )
186
+ except Exception: # pragma: no cover
187
+ try:
188
+ from evaluation import ( # type: ignore
189
+ bootstrap_ci,
190
+ calibrate_predictions,
191
+ metrics_report as eval_metrics_report,
192
+ select_threshold,
193
+ )
194
+ from ins_pricing.governance.registry import ( # type: ignore
195
+ ModelArtifact,
196
+ ModelRegistry,
197
+ )
198
+ from ins_pricing.production import psi_report as drift_psi_report # type: ignore
199
+ from ins_pricing.production.monitoring import group_metrics # type: ignore
200
+ from ins_pricing.reporting.report_builder import ( # type: ignore
201
+ ReportPayload,
202
+ write_report,
203
+ )
204
+ except Exception: # pragma: no cover
205
+ bootstrap_ci = None # type: ignore
206
+ calibrate_predictions = None # type: ignore
207
+ eval_metrics_report = None # type: ignore
208
+ select_threshold = None # type: ignore
209
+ drift_psi_report = None # type: ignore
210
+ group_metrics = None # type: ignore
211
+ ReportPayload = None # type: ignore
212
+ write_report = None # type: ignore
213
+ ModelRegistry = None # type: ignore
214
+ ModelArtifact = None # type: ignore
215
+
216
+
217
def _parse_args() -> argparse.Namespace:
    """Parse CLI options for the notebook-derived batch trainer.

    Flag names, wording, choices and defaults are fixed; returns the parsed
    :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(
        description="Batch trainer generated from BayesOpt_AutoPricing notebook."
    )
    add_config_json_arg(
        parser,
        help_text="Path to the JSON config describing datasets and feature columns.",
    )
    trainer_choices = ["glm", "xgb", "resn", "ft", "gnn", "all"]
    parser.add_argument(
        "--model-keys",
        nargs="+",
        default=["ft"],
        choices=trainer_choices,
        help="Space-separated list of trainers to run (e.g., --model-keys glm xgb). Include 'all' to run every trainer.",
    )
    parser.add_argument(
        "--stack-model-keys",
        nargs="+",
        default=None,
        choices=trainer_choices,
        help=(
            "Only used when ft_role != 'model' (FT runs as feature generator). "
            "When provided (or when config defines stack_model_keys), these trainers run after FT features "
            "are generated. Use 'all' to run every non-FT trainer."
        ),
    )
    parser.add_argument(
        "--max-evals",
        type=int,
        default=50,
        help="Optuna trial count per dataset.",
    )
    # Boolean switches for DataParallel / DistributedDataParallel and ANN use.
    for flag, text in (
        ("--use-resn-ddp", "Force ResNet trainer to use DistributedDataParallel."),
        ("--use-ft-ddp", "Force FT-Transformer trainer to use DistributedDataParallel."),
        ("--use-resn-dp", "Enable ResNet DataParallel fall-back regardless of config."),
        ("--use-ft-dp", "Enable FT-Transformer DataParallel fall-back regardless of config."),
        ("--use-gnn-dp", "Enable GNN DataParallel fall-back regardless of config."),
        ("--use-gnn-ddp", "Force GNN trainer to use DistributedDataParallel."),
        ("--gnn-no-ann", "Disable approximate k-NN for GNN graph construction and use exact search."),
    ):
        parser.add_argument(flag, action="store_true", help=text)
    parser.add_argument(
        "--gnn-ann-threshold",
        type=int,
        default=None,
        help="Row threshold above which approximate k-NN is preferred (overrides config).",
    )
    parser.add_argument(
        "--gnn-graph-cache",
        default=None,
        help="Optional path to persist/load cached adjacency matrix for GNN.",
    )
    parser.add_argument(
        "--gnn-max-gpu-nodes",
        type=int,
        default=None,
        help="Overrides the maximum node count allowed for GPU k-NN graph construction.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-ratio",
        type=float,
        default=None,
        help="Overrides the fraction of free GPU memory the k-NN builder may consume.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-overhead",
        type=float,
        default=None,
        help="Overrides the temporary GPU memory overhead multiplier for k-NN estimation.",
    )
    add_output_dir_arg(
        parser,
        help_text="Override output root for models/results/plots.",
    )
    parser.add_argument(
        "--plot-curves",
        action="store_true",
        help="Enable lift/diagnostic plots after training (config file may also request plotting).",
    )
    parser.add_argument(
        "--ft-as-feature",
        action="store_true",
        help="Alias for --ft-role embedding (keep tuning, export embeddings; skip FT plots/SHAP).",
    )
    parser.add_argument(
        "--ft-role",
        default=None,
        choices=["model", "embedding", "unsupervised_embedding"],
        help="How to use FT: model (default), embedding (export pooling embeddings), or unsupervised_embedding.",
    )
    parser.add_argument(
        "--ft-feature-prefix",
        default="ft_feat",
        help="Prefix used for generated FT features (columns: pred_<prefix>_0.. or pred_<prefix>).",
    )
    parser.add_argument(
        "--reuse-best-params",
        action="store_true",
        help="Skip Optuna and reuse best_params saved in Results/versions or bestparams CSV when available.",
    )
    return parser.parse_args()
344
+
345
+
346
+ def _plot_curves_for_model(model: ropt.BayesOptModel, trained_keys: List[str], cfg: Dict) -> None:
347
+ plot_cfg = cfg.get("plot", {})
348
+ legacy_lift_flags = {
349
+ "glm": cfg.get("plot_lift_glm", False),
350
+ "xgb": cfg.get("plot_lift_xgb", False),
351
+ "resn": cfg.get("plot_lift_resn", False),
352
+ "ft": cfg.get("plot_lift_ft", False),
353
+ }
354
+ plot_enabled = plot_cfg.get("enable", any(legacy_lift_flags.values()))
355
+ if not plot_enabled:
356
+ return
357
+
358
+ n_bins = int(plot_cfg.get("n_bins", 10))
359
+ oneway_enabled = plot_cfg.get("oneway", True)
360
+
361
+ available_models = dedupe_preserve_order(
362
+ [m for m in trained_keys if m in PLOT_MODEL_LABELS]
363
+ )
364
+
365
+ lift_models = plot_cfg.get("lift_models")
366
+ if lift_models is None:
367
+ lift_models = [
368
+ m for m, enabled in legacy_lift_flags.items() if enabled]
369
+ if not lift_models:
370
+ lift_models = available_models
371
+ lift_models = dedupe_preserve_order(
372
+ [m for m in lift_models if m in available_models]
373
+ )
374
+
375
+ if oneway_enabled:
376
+ oneway_pred = bool(plot_cfg.get("oneway_pred", False))
377
+ oneway_pred_models = plot_cfg.get("oneway_pred_models")
378
+ pred_plotted = False
379
+ if oneway_pred:
380
+ if oneway_pred_models is None:
381
+ oneway_pred_models = lift_models or available_models
382
+ oneway_pred_models = dedupe_preserve_order(
383
+ [m for m in oneway_pred_models if m in available_models]
384
+ )
385
+ for model_key in oneway_pred_models:
386
+ label, pred_nme = PLOT_MODEL_LABELS[model_key]
387
+ if pred_nme not in model.train_data.columns:
388
+ print(
389
+ f"[Oneway] Missing prediction column '{pred_nme}'; skip.",
390
+ flush=True,
391
+ )
392
+ continue
393
+ model.plot_oneway(
394
+ n_bins=n_bins,
395
+ pred_col=pred_nme,
396
+ pred_label=label,
397
+ plot_subdir="oneway/post",
398
+ )
399
+ pred_plotted = True
400
+ if not oneway_pred or not pred_plotted:
401
+ model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/post")
402
+
403
+ if not available_models:
404
+ return
405
+
406
+ for model_key in lift_models:
407
+ label, pred_nme = PLOT_MODEL_LABELS[model_key]
408
+ model.plot_lift(model_label=label, pred_nme=pred_nme, n_bins=n_bins)
409
+
410
+ if not plot_cfg.get("double_lift", True) or len(available_models) < 2:
411
+ return
412
+
413
+ raw_pairs = plot_cfg.get("double_lift_pairs")
414
+ if raw_pairs:
415
+ pairs = [
416
+ (a, b)
417
+ for a, b in parse_model_pairs(raw_pairs)
418
+ if a in available_models and b in available_models and a != b
419
+ ]
420
+ else:
421
+ pairs = [(a, b) for i, a in enumerate(available_models)
422
+ for b in available_models[i + 1:]]
423
+
424
+ for first, second in pairs:
425
+ model.plot_dlift([first, second], n_bins=n_bins)
426
+
427
+
428
+ def _plot_loss_curve_for_trainer(model_name: str, trainer) -> None:
429
+ model_obj = getattr(trainer, "model", None)
430
+ history = None
431
+ if model_obj is not None:
432
+ history = getattr(model_obj, "training_history", None)
433
+ if not history:
434
+ history = getattr(trainer, "training_history", None)
435
+ if not history:
436
+ return
437
+ train_hist = list(history.get("train") or [])
438
+ val_hist = list(history.get("val") or [])
439
+ if not train_hist and not val_hist:
440
+ return
441
+ try:
442
+ plot_dir = trainer.output.plot_path(
443
+ f"{model_name}/loss/loss_{model_name}_{trainer.model_name_prefix}.png"
444
+ )
445
+ except Exception:
446
+ default_dir = Path("plot") / model_name / "loss"
447
+ default_dir.mkdir(parents=True, exist_ok=True)
448
+ plot_dir = str(
449
+ default_dir / f"loss_{model_name}_{trainer.model_name_prefix}.png")
450
+ if plot_loss_curve_common is not None:
451
+ plot_loss_curve_common(
452
+ history=history,
453
+ title=f"{trainer.model_name_prefix} Loss Curve ({model_name})",
454
+ save_path=plot_dir,
455
+ show=False,
456
+ )
457
+ else:
458
+ epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
459
+ fig, ax = plt.subplots(figsize=(8, 4))
460
+ if train_hist:
461
+ ax.plot(range(1, len(train_hist) + 1),
462
+ train_hist, label="Train Loss", color="tab:blue")
463
+ if val_hist:
464
+ ax.plot(range(1, len(val_hist) + 1),
465
+ val_hist, label="Validation Loss", color="tab:orange")
466
+ ax.set_xlabel("Epoch")
467
+ ax.set_ylabel("Weighted Loss")
468
+ ax.set_title(
469
+ f"{trainer.model_name_prefix} Loss Curve ({model_name})")
470
+ ax.grid(True, linestyle="--", alpha=0.3)
471
+ ax.legend()
472
+ plt.tight_layout()
473
+ plt.savefig(plot_dir, dpi=300)
474
+ plt.close(fig)
475
+ print(
476
+ f"[Plot] Saved loss curve for {model_name}/{trainer.label} -> {plot_dir}")
477
+
478
+
479
+ def _sample_arrays(
480
+ y_true: np.ndarray,
481
+ y_pred: np.ndarray,
482
+ *,
483
+ max_rows: Optional[int],
484
+ seed: Optional[int],
485
+ ) -> tuple[np.ndarray, np.ndarray]:
486
+ if max_rows is None or max_rows <= 0:
487
+ return y_true, y_pred
488
+ n = len(y_true)
489
+ if n <= max_rows:
490
+ return y_true, y_pred
491
+ rng = np.random.default_rng(seed)
492
+ idx = rng.choice(n, size=int(max_rows), replace=False)
493
+ return y_true[idx], y_pred[idx]
494
+
495
+
496
def _compute_psi_report(
    model: ropt.BayesOptModel,
    *,
    features: Optional[List[str]],
    bins: int,
    strategy: str,
) -> Optional[pd.DataFrame]:
    """Build a PSI drift report between the model's train and test splits.

    Returns ``None`` when the drift helper is unavailable, when no candidate
    feature exists in both splits, or when the computation raises.
    """
    if drift_psi_report is None:
        return None
    candidates = features or list(getattr(model, "factor_nmes", []))
    usable = [
        col
        for col in candidates
        if col in model.train_data.columns and col in model.test_data.columns
    ]
    if not usable:
        return None
    try:
        return drift_psi_report(
            model.train_data[usable],
            model.test_data[usable],
            features=usable,
            bins=int(bins),
            strategy=str(strategy),
        )
    except Exception as exc:
        # Drift reporting is best-effort; log and keep the run alive.
        print(f"[Report] PSI computation failed: {exc}")
        return None
521
+
522
+
523
+ def _evaluate_and_report(
524
+ model: ropt.BayesOptModel,
525
+ *,
526
+ model_name: str,
527
+ model_key: str,
528
+ cfg: Dict[str, Any],
529
+ data_path: Path,
530
+ data_fingerprint: Dict[str, Any],
531
+ report_output_dir: Optional[str],
532
+ report_group_cols: Optional[List[str]],
533
+ report_time_col: Optional[str],
534
+ report_time_freq: str,
535
+ report_time_ascending: bool,
536
+ psi_report_df: Optional[pd.DataFrame],
537
+ calibration_cfg: Dict[str, Any],
538
+ threshold_cfg: Dict[str, Any],
539
+ bootstrap_cfg: Dict[str, Any],
540
+ register_model: bool,
541
+ registry_path: Optional[str],
542
+ registry_tags: Dict[str, Any],
543
+ registry_status: str,
544
+ run_id: str,
545
+ config_sha: str,
546
+ ) -> None:
547
+ if eval_metrics_report is None:
548
+ print("[Report] Skip evaluation: metrics module unavailable.")
549
+ return
550
+
551
+ pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
552
+ if pred_col not in model.test_data.columns:
553
+ print(
554
+ f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
555
+ return
556
+
557
+ weight_col = getattr(model, "weight_nme", None)
558
+ y_true_train = model.train_data[model.resp_nme].to_numpy(
559
+ dtype=float, copy=False)
560
+ y_true_test = model.test_data[model.resp_nme].to_numpy(
561
+ dtype=float, copy=False)
562
+ y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
563
+ y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
564
+ weight_train = (
565
+ model.train_data[weight_col].to_numpy(dtype=float, copy=False)
566
+ if weight_col and weight_col in model.train_data.columns
567
+ else None
568
+ )
569
+ weight_test = (
570
+ model.test_data[weight_col].to_numpy(dtype=float, copy=False)
571
+ if weight_col and weight_col in model.test_data.columns
572
+ else None
573
+ )
574
+
575
+ task_type = str(cfg.get("task_type", getattr(
576
+ model, "task_type", "regression")))
577
+ if task_type == "classification":
578
+ y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
579
+ y_pred_test = np.clip(y_pred_test, 0.0, 1.0)
580
+
581
+ calibration_info: Optional[Dict[str, Any]] = None
582
+ threshold_info: Optional[Dict[str, Any]] = None
583
+ y_pred_train_eval = y_pred_train
584
+ y_pred_test_eval = y_pred_test
585
+
586
+ if task_type == "classification":
587
+ cal_cfg = dict(calibration_cfg or {})
588
+ cal_enabled = bool(cal_cfg.get("enable", False)
589
+ or cal_cfg.get("method"))
590
+ if cal_enabled and calibrate_predictions is not None:
591
+ method = cal_cfg.get("method", "sigmoid")
592
+ max_rows = cal_cfg.get("max_rows")
593
+ seed = cal_cfg.get("seed")
594
+ y_cal, p_cal = _sample_arrays(
595
+ y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
596
+ try:
597
+ calibrator = calibrate_predictions(y_cal, p_cal, method=method)
598
+ y_pred_train_eval = calibrator.predict(y_pred_train)
599
+ y_pred_test_eval = calibrator.predict(y_pred_test)
600
+ calibration_info = {
601
+ "method": calibrator.method, "max_rows": max_rows}
602
+ except Exception as exc:
603
+ print(
604
+ f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
605
+
606
+ thr_cfg = dict(threshold_cfg or {})
607
+ thr_enabled = bool(
608
+ thr_cfg.get("enable", False)
609
+ or thr_cfg.get("metric")
610
+ or thr_cfg.get("value") is not None
611
+ )
612
+ threshold_value = 0.5
613
+ if thr_cfg.get("value") is not None:
614
+ threshold_value = float(thr_cfg["value"])
615
+ threshold_info = {"threshold": threshold_value, "source": "fixed"}
616
+ elif thr_enabled and select_threshold is not None:
617
+ max_rows = thr_cfg.get("max_rows")
618
+ seed = thr_cfg.get("seed")
619
+ y_thr, p_thr = _sample_arrays(
620
+ y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
621
+ threshold_info = select_threshold(
622
+ y_thr,
623
+ p_thr,
624
+ metric=thr_cfg.get("metric", "f1"),
625
+ min_positive_rate=thr_cfg.get("min_positive_rate"),
626
+ grid=thr_cfg.get("grid", 99),
627
+ )
628
+ threshold_value = float(threshold_info.get("threshold", 0.5))
629
+ else:
630
+ threshold_value = 0.5
631
+ metrics = eval_metrics_report(
632
+ y_true_test,
633
+ y_pred_test_eval,
634
+ task_type=task_type,
635
+ threshold=threshold_value,
636
+ )
637
+ precision = float(metrics.get("precision", 0.0))
638
+ recall = float(metrics.get("recall", 0.0))
639
+ f1 = 0.0 if (precision + recall) == 0 else 2 * \
640
+ precision * recall / (precision + recall)
641
+ metrics["f1"] = float(f1)
642
+ metrics["threshold"] = float(threshold_value)
643
+ else:
644
+ metrics = eval_metrics_report(
645
+ y_true_test,
646
+ y_pred_test_eval,
647
+ task_type=task_type,
648
+ weight=weight_test,
649
+ )
650
+
651
+ bootstrap_results: Dict[str, Dict[str, float]] = {}
652
+ if bootstrap_cfg and bool(bootstrap_cfg.get("enable", False)) and bootstrap_ci is not None:
653
+ metric_names = bootstrap_cfg.get("metrics") or list(metrics.keys())
654
+ n_samples = int(bootstrap_cfg.get("n_samples", 200))
655
+ ci = float(bootstrap_cfg.get("ci", 0.95))
656
+ seed = bootstrap_cfg.get("seed")
657
+
658
+ def _metric_fn(y_true, y_pred, weight=None):
659
+ vals = eval_metrics_report(
660
+ y_true,
661
+ y_pred,
662
+ task_type=task_type,
663
+ weight=weight,
664
+ threshold=metrics.get("threshold", 0.5),
665
+ )
666
+ if task_type == "classification":
667
+ prec = float(vals.get("precision", 0.0))
668
+ rec = float(vals.get("recall", 0.0))
669
+ vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * \
670
+ prec * rec / (prec + rec)
671
+ return vals
672
+
673
+ for name in metric_names:
674
+ if name not in metrics:
675
+ continue
676
+ ci_result = bootstrap_ci(
677
+ lambda y_t, y_p, w=None: float(
678
+ _metric_fn(y_t, y_p, w).get(name, 0.0)),
679
+ y_true_test,
680
+ y_pred_test_eval,
681
+ weight=weight_test,
682
+ n_samples=n_samples,
683
+ ci=ci,
684
+ seed=seed,
685
+ )
686
+ bootstrap_results[str(name)] = ci_result
687
+
688
+ validation_table = None
689
+ if report_group_cols and group_metrics is not None:
690
+ available_groups = [
691
+ col for col in report_group_cols if col in model.test_data.columns
692
+ ]
693
+ if available_groups:
694
+ try:
695
+ validation_table = group_metrics(
696
+ model.test_data,
697
+ actual_col=model.resp_nme,
698
+ pred_col=pred_col,
699
+ group_cols=available_groups,
700
+ weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
701
+ )
702
+ counts = (
703
+ model.test_data.groupby(available_groups, dropna=False)
704
+ .size()
705
+ .reset_index(name="count")
706
+ )
707
+ validation_table = validation_table.merge(
708
+ counts, on=available_groups, how="left")
709
+ except Exception as exc:
710
+ print(
711
+ f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
712
+
713
+ risk_trend = None
714
+ if report_time_col and group_metrics is not None:
715
+ if report_time_col in model.test_data.columns:
716
+ try:
717
+ time_df = model.test_data.copy()
718
+ time_series = pd.to_datetime(
719
+ time_df[report_time_col], errors="coerce")
720
+ time_df = time_df.loc[time_series.notna()].copy()
721
+ if not time_df.empty:
722
+ time_df["_time_bucket"] = (
723
+ pd.to_datetime(
724
+ time_df[report_time_col], errors="coerce")
725
+ .dt.to_period(report_time_freq)
726
+ .dt.to_timestamp()
727
+ )
728
+ risk_trend = group_metrics(
729
+ time_df,
730
+ actual_col=model.resp_nme,
731
+ pred_col=pred_col,
732
+ group_cols=["_time_bucket"],
733
+ weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
734
+ )
735
+ counts = (
736
+ time_df.groupby("_time_bucket", dropna=False)
737
+ .size()
738
+ .reset_index(name="count")
739
+ )
740
+ risk_trend = risk_trend.merge(
741
+ counts, on="_time_bucket", how="left")
742
+ risk_trend = risk_trend.sort_values(
743
+ "_time_bucket", ascending=bool(report_time_ascending)
744
+ ).reset_index(drop=True)
745
+ risk_trend = risk_trend.rename(
746
+ columns={"_time_bucket": report_time_col})
747
+ except Exception as exc:
748
+ print(
749
+ f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
750
+
751
+ report_root = (
752
+ Path(report_output_dir)
753
+ if report_output_dir
754
+ else Path(model.output_manager.result_dir) / "reports"
755
+ )
756
+ report_root.mkdir(parents=True, exist_ok=True)
757
+
758
+ version = f"{model_key}_{run_id}"
759
+ metrics_payload = {
760
+ "model_name": model_name,
761
+ "model_key": model_key,
762
+ "model_version": version,
763
+ "metrics": metrics,
764
+ "threshold": threshold_info,
765
+ "calibration": calibration_info,
766
+ "bootstrap": bootstrap_results,
767
+ "data_path": str(data_path),
768
+ "data_fingerprint": data_fingerprint,
769
+ "config_sha256": config_sha,
770
+ "pred_col": pred_col,
771
+ "task_type": task_type,
772
+ }
773
+ metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
774
+ metrics_path.write_text(
775
+ json.dumps(metrics_payload, indent=2, ensure_ascii=True),
776
+ encoding="utf-8",
777
+ )
778
+
779
+ report_path = None
780
+ if ReportPayload is not None and write_report is not None:
781
+ notes_lines = [
782
+ f"- Config SHA256: {config_sha}",
783
+ f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
784
+ ]
785
+ if calibration_info:
786
+ notes_lines.append(
787
+ f"- Calibration: {calibration_info.get('method')}"
788
+ )
789
+ if threshold_info:
790
+ notes_lines.append(
791
+ f"- Threshold selection: {threshold_info}"
792
+ )
793
+ if bootstrap_results:
794
+ notes_lines.append("- Bootstrap: see metrics JSON for CI")
795
+ extra_notes = "\n".join(notes_lines)
796
+ payload = ReportPayload(
797
+ model_name=f"{model_name}/{model_key}",
798
+ model_version=version,
799
+ metrics={k: float(v) for k, v in metrics.items()},
800
+ risk_trend=risk_trend,
801
+ drift_report=psi_report_df,
802
+ validation_table=validation_table,
803
+ extra_notes=extra_notes,
804
+ )
805
+ report_path = write_report(
806
+ payload,
807
+ report_root / f"{model_name}_{model_key}_report.md",
808
+ )
809
+
810
+ if register_model and ModelRegistry is not None and ModelArtifact is not None:
811
+ registry = ModelRegistry(
812
+ registry_path
813
+ if registry_path
814
+ else Path(model.output_manager.result_dir) / "model_registry.json"
815
+ )
816
+ tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
817
+ tags.update({
818
+ "model_key": str(model_key),
819
+ "task_type": str(task_type),
820
+ "data_path": str(data_path),
821
+ "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
822
+ "data_size": str(data_fingerprint.get("size", "")),
823
+ "data_mtime": str(data_fingerprint.get("mtime", "")),
824
+ "config_sha256": str(config_sha),
825
+ })
826
+ artifacts = []
827
+ trainer = model.trainers.get(model_key)
828
+ if trainer is not None:
829
+ try:
830
+ model_path = trainer.output.model_path(
831
+ trainer._get_model_filename())
832
+ if os.path.exists(model_path):
833
+ artifacts.append(ModelArtifact(
834
+ path=model_path, description="trained model"))
835
+ except Exception:
836
+ pass
837
+ if report_path is not None:
838
+ artifacts.append(ModelArtifact(
839
+ path=str(report_path), description="model report"))
840
+ if metrics_path.exists():
841
+ artifacts.append(ModelArtifact(
842
+ path=str(metrics_path), description="metrics json"))
843
+ if bool(cfg.get("save_preprocess", False)):
844
+ artifact_path = cfg.get("preprocess_artifact_path")
845
+ if artifact_path:
846
+ preprocess_path = Path(str(artifact_path))
847
+ if not preprocess_path.is_absolute():
848
+ preprocess_path = Path(
849
+ model.output_manager.result_dir) / preprocess_path
850
+ else:
851
+ preprocess_path = Path(model.output_manager.result_path(
852
+ f"{model.model_nme}_preprocess.json"
853
+ ))
854
+ if preprocess_path.exists():
855
+ artifacts.append(
856
+ ModelArtifact(path=str(preprocess_path),
857
+ description="preprocess artifacts")
858
+ )
859
+ if bool(cfg.get("cache_predictions", False)):
860
+ cache_dir = cfg.get("prediction_cache_dir")
861
+ if cache_dir:
862
+ pred_root = Path(str(cache_dir))
863
+ if not pred_root.is_absolute():
864
+ pred_root = Path(
865
+ model.output_manager.result_dir) / pred_root
866
+ else:
867
+ pred_root = Path(
868
+ model.output_manager.result_dir) / "predictions"
869
+ ext = "csv" if str(
870
+ cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
871
+ for split_label in ("train", "test"):
872
+ pred_path = pred_root / \
873
+ f"{model_name}_{model_key}_{split_label}.{ext}"
874
+ if pred_path.exists():
875
+ artifacts.append(
876
+ ModelArtifact(path=str(pred_path),
877
+ description=f"predictions {split_label}")
878
+ )
879
+ registry.register(
880
+ name=str(model_name),
881
+ version=version,
882
+ metrics={k: float(v) for k, v in metrics.items()},
883
+ tags=tags,
884
+ artifacts=artifacts,
885
+ status=str(registry_status or "candidate"),
886
+ notes=f"model_key={model_key}",
887
+ )
888
+
889
+
890
def train_from_config(args: argparse.Namespace) -> None:
    """Load a JSON training config, then train/optimize the requested models.

    For each dataset named by ``model_list`` x ``model_categories`` this:
      1. resolves and fingerprints the data file,
      2. splits train/test according to the split config,
      3. builds a ``ropt.BayesOptModel`` and runs Bayesian optimization for
         each requested trainer key (glm/xgb/resn/ft/gnn),
      4. optionally writes reports/registry entries and plots curves.

    Under ``torchrun`` (WORLD_SIZE > 1) non-DDP trainers run on rank 0 only,
    with explicit barriers so other ranks stay in lockstep.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments (see ``_parse_args``); CLI flags generally
        override or OR-combine with the JSON config values.
    """
    # Config file is resolved relative to the package root (one level above
    # this script's directory).
    script_dir = Path(__file__).resolve().parents[1]
    config_path, cfg = resolve_and_load_config(
        args.config_json,
        script_dir,
        required_keys=["data_dir", "model_list",
                       "model_categories", "target", "weight"],
    )
    plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
    # Hash the raw config bytes so reports/registry entries can be traced
    # back to the exact configuration used.
    config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
    # consider datetime.now(timezone.utc) in a future change.
    run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

    def _safe_int_env(key: str, default: int) -> int:
        # Read an integer environment variable, falling back to `default`
        # when the variable is unset or not a valid int.
        try:
            return int(os.environ.get(key, default))
        except (TypeError, ValueError):
            return default

    # torchrun/DDP topology: WORLD_SIZE/RANK are set by the launcher.
    dist_world_size = _safe_int_env("WORLD_SIZE", 1)
    dist_rank = _safe_int_env("RANK", 0)
    dist_active = dist_world_size > 1
    # Rank 0 (or a non-distributed run) performs reporting/registry work.
    is_main_process = (not dist_active) or dist_rank == 0

    def _ddp_barrier(reason: str) -> None:
        # Synchronize all ranks at a named point. Silently no-ops when not
        # distributed or when torch.distributed is unavailable; initializes
        # the process group lazily via DistributedUtils.setup_ddp() if needed.
        if not dist_active:
            return
        torch_mod = getattr(ropt, "torch", None)
        dist_mod = getattr(torch_mod, "distributed", None)
        if dist_mod is None:
            return
        try:
            if not getattr(dist_mod, "is_available", lambda: False)():
                return
            if not dist_mod.is_initialized():
                ddp_ok, _, _, _ = ropt.DistributedUtils.setup_ddp()
                if not ddp_ok or not dist_mod.is_initialized():
                    return
            dist_mod.barrier()
        except Exception as exc:
            # A failed barrier would desynchronize ranks; log and re-raise.
            print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
            raise

    # --- Resolve the various config sections into locals -------------------
    data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
        cfg,
        config_path,
        create_data_dir=True,
    )
    runtime_cfg = resolve_runtime_config(cfg)
    ddp_min_rows = runtime_cfg["ddp_min_rows"]
    bo_sample_limit = runtime_cfg["bo_sample_limit"]
    cache_predictions = runtime_cfg["cache_predictions"]
    prediction_cache_dir = runtime_cfg["prediction_cache_dir"]
    prediction_cache_format = runtime_cfg["prediction_cache_format"]
    report_cfg = resolve_report_config(cfg)
    report_output_dir = report_cfg["report_output_dir"]
    report_group_cols = report_cfg["report_group_cols"]
    report_time_col = report_cfg["report_time_col"]
    report_time_freq = report_cfg["report_time_freq"]
    report_time_ascending = report_cfg["report_time_ascending"]
    psi_bins = report_cfg["psi_bins"]
    psi_strategy = report_cfg["psi_strategy"]
    psi_features = report_cfg["psi_features"]
    calibration_cfg = report_cfg["calibration_cfg"]
    threshold_cfg = report_cfg["threshold_cfg"]
    bootstrap_cfg = report_cfg["bootstrap_cfg"]
    register_model = report_cfg["register_model"]
    registry_path = report_cfg["registry_path"]
    registry_tags = report_cfg["registry_tags"]
    registry_status = report_cfg["registry_status"]
    data_fingerprint_max_bytes = report_cfg["data_fingerprint_max_bytes"]
    report_enabled = report_cfg["report_enabled"]

    split_cfg = resolve_split_config(cfg)
    # NOTE(review): `prop_test` is never used below (the model receives
    # `prop_test=val_ratio`); kept for parity with the split config keys.
    prop_test = split_cfg["prop_test"]
    holdout_ratio = split_cfg["holdout_ratio"]
    val_ratio = split_cfg["val_ratio"]
    split_strategy = split_cfg["split_strategy"]
    split_group_col = split_cfg["split_group_col"]
    split_time_col = split_cfg["split_time_col"]
    split_time_ascending = split_cfg["split_time_ascending"]
    cv_strategy = split_cfg["cv_strategy"]
    cv_group_col = split_cfg["cv_group_col"]
    cv_time_col = split_cfg["cv_time_col"]
    cv_time_ascending = split_cfg["cv_time_ascending"]
    cv_splits = split_cfg["cv_splits"]
    ft_oof_folds = split_cfg["ft_oof_folds"]
    ft_oof_strategy = split_cfg["ft_oof_strategy"]
    ft_oof_shuffle = split_cfg["ft_oof_shuffle"]
    save_preprocess = runtime_cfg["save_preprocess"]
    preprocess_artifact_path = runtime_cfg["preprocess_artifact_path"]
    rand_seed = runtime_cfg["rand_seed"]
    epochs = runtime_cfg["epochs"]
    output_cfg = resolve_output_dirs(
        cfg,
        config_path,
        output_override=args.output_dir,
    )
    output_dir = output_cfg["output_dir"]
    # CLI flag OR config flag: either source may enable best-param reuse.
    reuse_best_params = bool(
        args.reuse_best_params or runtime_cfg["reuse_best_params"])
    xgb_max_depth_max = runtime_cfg["xgb_max_depth_max"]
    xgb_n_estimators_max = runtime_cfg["xgb_n_estimators_max"]
    optuna_storage = runtime_cfg["optuna_storage"]
    optuna_study_prefix = runtime_cfg["optuna_study_prefix"]
    best_params_files = runtime_cfg["best_params_files"]
    plot_path_style = runtime_cfg["plot_path_style"]

    model_names = build_model_names(
        cfg["model_list"], cfg["model_categories"])
    if not model_names:
        raise ValueError(
            "No model names generated from model_list/model_categories.")

    # Collected across datasets so curves can be plotted after training.
    results: Dict[str, ropt.BayesOptModel] = {}
    trained_keys_by_model: Dict[str, List[str]] = {}

    for model_name in model_names:
        # Per-dataset training loop: load data, split train/test, and train requested models.
        data_path = resolve_data_path(
            data_dir,
            model_name,
            data_format=data_format,
            path_template=data_path_template,
        )
        if not data_path.exists():
            raise FileNotFoundError(f"Missing dataset: {data_path}")
        # Only the main process pays the cost of hashing the data file, and
        # only when reporting is enabled.
        data_fingerprint = {"path": str(data_path)}
        if report_enabled and is_main_process:
            data_fingerprint = fingerprint_file(
                data_path,
                max_bytes=data_fingerprint_max_bytes,
            )

        print(f"\n=== Processing model {model_name} ===")
        raw = load_dataset(
            data_path,
            data_format=data_format,
            dtype_map=dtype_map,
            low_memory=False,
        )
        raw = coerce_dataset_types(raw)

        train_df, test_df = split_train_test(
            raw,
            holdout_ratio=holdout_ratio,
            strategy=split_strategy,
            group_col=split_group_col,
            time_col=split_time_col,
            time_ascending=split_time_ascending,
            rand_seed=rand_seed,
            reset_index_mode="time_group",
            ratio_label="holdout_ratio",
        )

        # Parallelism flags: CLI flags OR config values; DDP additionally
        # requires the dataset to be large enough (ddp_min_rows).
        use_resn_dp = args.use_resn_dp or cfg.get(
            "use_resn_data_parallel", False)
        use_ft_dp = args.use_ft_dp or cfg.get("use_ft_data_parallel", True)
        dataset_rows = len(raw)
        ddp_enabled = bool(dist_active and (dataset_rows >= int(ddp_min_rows)))
        use_resn_ddp = (args.use_resn_ddp or cfg.get(
            "use_resn_ddp", False)) and ddp_enabled
        use_ft_ddp = (args.use_ft_ddp or cfg.get(
            "use_ft_ddp", False)) and ddp_enabled
        use_gnn_dp = args.use_gnn_dp or cfg.get("use_gnn_data_parallel", False)
        use_gnn_ddp = (args.use_gnn_ddp or cfg.get(
            "use_gnn_ddp", False)) and ddp_enabled
        # GNN approximate-kNN options: CLI values take precedence over config.
        gnn_use_ann = cfg.get("gnn_use_approx_knn", True)
        if args.gnn_no_ann:
            gnn_use_ann = False
        gnn_threshold = args.gnn_ann_threshold if args.gnn_ann_threshold is not None else cfg.get(
            "gnn_approx_knn_threshold", 50000)
        gnn_graph_cache = args.gnn_graph_cache or cfg.get("gnn_graph_cache")
        if isinstance(gnn_graph_cache, str) and gnn_graph_cache.strip():
            # Resolve a relative cache path against the config file's folder.
            resolved_cache = resolve_path(gnn_graph_cache, config_path.parent)
            if resolved_cache is not None:
                gnn_graph_cache = str(resolved_cache)
        gnn_max_gpu_nodes = args.gnn_max_gpu_nodes if args.gnn_max_gpu_nodes is not None else cfg.get(
            "gnn_max_gpu_knn_nodes", 200000)
        gnn_gpu_mem_ratio = args.gnn_gpu_mem_ratio if args.gnn_gpu_mem_ratio is not None else cfg.get(
            "gnn_knn_gpu_mem_ratio", 0.9)
        gnn_gpu_mem_overhead = args.gnn_gpu_mem_overhead if args.gnn_gpu_mem_overhead is not None else cfg.get(
            "gnn_knn_gpu_mem_overhead", 2.0)

        # Target/feature configuration; "binary_resp_nme" is the legacy key.
        binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
        task_type = str(cfg.get("task_type", "regression"))
        feature_list = cfg.get("feature_list")
        categorical_features = cfg.get("categorical_features")
        use_gpu = bool(cfg.get("use_gpu", True))
        region_province_col = cfg.get("region_province_col")
        region_city_col = cfg.get("region_city_col")
        region_effect_alpha = cfg.get("region_effect_alpha")
        geo_feature_nmes = cfg.get("geo_feature_nmes")
        geo_token_hidden_dim = cfg.get("geo_token_hidden_dim")
        geo_token_layers = cfg.get("geo_token_layers")
        geo_token_dropout = cfg.get("geo_token_dropout")
        geo_token_k_neighbors = cfg.get("geo_token_k_neighbors")
        geo_token_learning_rate = cfg.get("geo_token_learning_rate")
        geo_token_epochs = cfg.get("geo_token_epochs")

        ft_role = args.ft_role or cfg.get("ft_role", "model")
        if args.ft_as_feature and args.ft_role is None:
            # Keep legacy behavior as a convenience alias only when the config
            # didn't already request a non-default FT role.
            if str(cfg.get("ft_role", "model")) == "model":
                ft_role = "embedding"
        ft_feature_prefix = str(
            cfg.get("ft_feature_prefix", args.ft_feature_prefix))
        ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")

        model = ropt.BayesOptModel(
            train_df,
            test_df,
            model_name,
            cfg["target"],
            cfg["weight"],
            feature_list,
            task_type=task_type,
            binary_resp_nme=binary_target,
            cate_list=categorical_features,
            prop_test=val_ratio,
            rand_seed=rand_seed,
            epochs=epochs,
            use_gpu=use_gpu,
            use_resn_data_parallel=use_resn_dp,
            use_ft_data_parallel=use_ft_dp,
            use_resn_ddp=use_resn_ddp,
            use_ft_ddp=use_ft_ddp,
            use_gnn_data_parallel=use_gnn_dp,
            use_gnn_ddp=use_gnn_ddp,
            output_dir=output_dir,
            xgb_max_depth_max=xgb_max_depth_max,
            xgb_n_estimators_max=xgb_n_estimators_max,
            resn_weight_decay=cfg.get("resn_weight_decay"),
            final_ensemble=bool(cfg.get("final_ensemble", False)),
            final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
            final_refit=bool(cfg.get("final_refit", True)),
            optuna_storage=optuna_storage,
            optuna_study_prefix=optuna_study_prefix,
            best_params_files=best_params_files,
            gnn_use_approx_knn=gnn_use_ann,
            gnn_approx_knn_threshold=gnn_threshold,
            gnn_graph_cache=gnn_graph_cache,
            gnn_max_gpu_knn_nodes=gnn_max_gpu_nodes,
            gnn_knn_gpu_mem_ratio=gnn_gpu_mem_ratio,
            gnn_knn_gpu_mem_overhead=gnn_gpu_mem_overhead,
            region_province_col=region_province_col,
            region_city_col=region_city_col,
            region_effect_alpha=region_effect_alpha,
            geo_feature_nmes=geo_feature_nmes,
            geo_token_hidden_dim=geo_token_hidden_dim,
            geo_token_layers=geo_token_layers,
            geo_token_dropout=geo_token_dropout,
            geo_token_k_neighbors=geo_token_k_neighbors,
            geo_token_learning_rate=geo_token_learning_rate,
            geo_token_epochs=geo_token_epochs,
            ft_role=ft_role,
            ft_feature_prefix=ft_feature_prefix,
            ft_num_numeric_tokens=ft_num_numeric_tokens,
            infer_categorical_max_unique=int(
                cfg.get("infer_categorical_max_unique", 50)),
            infer_categorical_max_ratio=float(
                cfg.get("infer_categorical_max_ratio", 0.05)),
            reuse_best_params=reuse_best_params,
            bo_sample_limit=bo_sample_limit,
            cache_predictions=cache_predictions,
            prediction_cache_dir=prediction_cache_dir,
            prediction_cache_format=prediction_cache_format,
            cv_strategy=cv_strategy or split_strategy,
            cv_group_col=cv_group_col or split_group_col,
            cv_time_col=cv_time_col or split_time_col,
            cv_time_ascending=cv_time_ascending,
            cv_splits=cv_splits,
            ft_oof_folds=ft_oof_folds,
            ft_oof_strategy=ft_oof_strategy,
            ft_oof_shuffle=ft_oof_shuffle,
            save_preprocess=save_preprocess,
            preprocess_artifact_path=preprocess_artifact_path,
            plot_path_style=plot_path_style,
        )

        if plot_requested:
            # Optional pre-training one-way plots. Legacy per-model lift
            # flags act as the fallback default for plot.enable.
            plot_cfg = cfg.get("plot", {})
            legacy_lift_flags = {
                "glm": cfg.get("plot_lift_glm", False),
                "xgb": cfg.get("plot_lift_xgb", False),
                "resn": cfg.get("plot_lift_resn", False),
                "ft": cfg.get("plot_lift_ft", False),
            }
            plot_enabled = plot_cfg.get(
                "enable", any(legacy_lift_flags.values()))
            if plot_enabled and plot_cfg.get("pre_oneway", False) and plot_cfg.get("oneway", True):
                n_bins = int(plot_cfg.get("n_bins", 10))
                model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/pre")

        # Expand the "all" alias and drop duplicate keys (order preserved).
        if "all" in args.model_keys:
            requested_keys = ["glm", "xgb", "resn", "ft", "gnn"]
        else:
            requested_keys = args.model_keys
        requested_keys = dedupe_preserve_order(requested_keys)

        if ft_role != "model":
            # FT is used as an embedding/feature producer, not a standalone
            # model, so remove it from the directly-trained keys.
            requested_keys = [k for k in requested_keys if k != "ft"]
            if not requested_keys:
                # Fall back to the configured stacking keys when nothing
                # else was requested.
                stack_keys = args.stack_model_keys or cfg.get(
                    "stack_model_keys")
                if stack_keys:
                    if "all" in stack_keys:
                        requested_keys = ["glm", "xgb", "resn", "gnn"]
                    else:
                        requested_keys = [k for k in stack_keys if k != "ft"]
                    requested_keys = dedupe_preserve_order(requested_keys)
            if dist_active and ddp_enabled:
                # Under torchrun all ranks will run the FT embedding step,
                # which is only safe when the FT trainer is DDP-aware.
                ft_trainer = model.trainers.get("ft")
                if ft_trainer is None:
                    raise ValueError("FT trainer is not available.")
                ft_trainer_uses_ddp = bool(
                    getattr(ft_trainer, "enable_distributed_optuna", False))
                if not ft_trainer_uses_ddp:
                    raise ValueError(
                        "FT embedding under torchrun requires enabling FT DDP (use --use-ft-ddp or set use_ft_ddp=true)."
                    )
        missing = [key for key in requested_keys if key not in model.trainers]
        if missing:
            raise ValueError(
                f"Trainer(s) {missing} not available for {model_name}")

        executed_keys: List[str] = []
        if ft_role != "model":
            # FT embedding phase. When DDP is not enabled for this dataset,
            # only rank 0 runs it; other ranks wait at the barriers and skip
            # to the next dataset.
            if dist_active and not ddp_enabled:
                _ddp_barrier("start_ft_embedding")
                if dist_rank != 0:
                    _ddp_barrier("finish_ft_embedding")
                    continue
            print(
                f"Optimizing ft as {ft_role} for {model_name} (max_evals={args.max_evals})")
            model.optimize_model("ft", max_evals=args.max_evals)
            model.trainers["ft"].save()
            if getattr(ropt, "torch", None) is not None and ropt.torch.cuda.is_available():
                ropt.free_cuda()
            if dist_active and not ddp_enabled:
                _ddp_barrier("finish_ft_embedding")
        for key in requested_keys:
            trainer = model.trainers[key]
            trainer_uses_ddp = bool(
                getattr(trainer, "enable_distributed_optuna", False))
            if dist_active and not trainer_uses_ddp:
                # Non-DDP trainer under torchrun: rank 0 trains while the
                # other ranks announce the skip and wait at the barriers.
                if dist_rank != 0:
                    print(
                        f"[Rank {dist_rank}] Skip {model_name}/{key} because trainer is not DDP-enabled."
                    )
                _ddp_barrier(f"start_non_ddp_{model_name}_{key}")
                if dist_rank != 0:
                    _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
                    continue

            print(
                f"Optimizing {key} for {model_name} (max_evals={args.max_evals})")
            model.optimize_model(key, max_evals=args.max_evals)
            model.trainers[key].save()
            _plot_loss_curve_for_trainer(model_name, model.trainers[key])
            if key in PYTORCH_TRAINERS:
                # Release GPU memory between torch-based trainers.
                ropt.free_cuda()
            if dist_active and not trainer_uses_ddp:
                _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
            executed_keys.append(key)

        if not executed_keys:
            # Nothing trained on this rank for this dataset (e.g. a non-main
            # rank skipped everything) -- no reporting or plotting to do.
            continue

        results[model_name] = model
        trained_keys_by_model[model_name] = executed_keys
        if report_enabled and is_main_process:
            # PSI/drift report is computed once per dataset, then reused by
            # every trained model key in the evaluation step.
            psi_report_df = _compute_psi_report(
                model,
                features=psi_features,
                bins=psi_bins,
                strategy=str(psi_strategy),
            )
            for key in executed_keys:
                _evaluate_and_report(
                    model,
                    model_name=model_name,
                    model_key=key,
                    cfg=cfg,
                    data_path=data_path,
                    data_fingerprint=data_fingerprint,
                    report_output_dir=report_output_dir,
                    report_group_cols=report_group_cols,
                    report_time_col=report_time_col,
                    report_time_freq=str(report_time_freq),
                    report_time_ascending=bool(report_time_ascending),
                    psi_report_df=psi_report_df,
                    calibration_cfg=calibration_cfg,
                    threshold_cfg=threshold_cfg,
                    bootstrap_cfg=bootstrap_cfg,
                    register_model=register_model,
                    registry_path=registry_path,
                    registry_tags=registry_tags,
                    registry_status=registry_status,
                    run_id=run_id,
                    config_sha=config_sha,
                )

    if not plot_requested:
        return

    # Post-training curve plotting for every successfully trained dataset.
    for name, model in results.items():
        _plot_curves_for_model(
            model,
            trained_keys_by_model.get(name, []),
            cfg,
        )
1303
+
1304
def main() -> None:
    """CLI entry point: configure run logging, parse arguments, train."""
    # configure_run_logging may be a falsy placeholder when the optional
    # logging helper could not be imported; only call it when present.
    if configure_run_logging:
        configure_run_logging(prefix="bayesopt_entry")
    train_from_config(_parse_args())
1309
+
1310
+
1311
# Allow running this module directly as a script (e.g. `python -m ...`).
if __name__ == "__main__":
    main()