ins-pricing 0.4.3-py3-none-any.whl → 0.4.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +66 -74
- ins_pricing/cli/BayesOpt_incremental.py +904 -904
- ins_pricing/cli/bayesopt_entry_runner.py +1442 -1442
- ins_pricing/frontend/README.md +573 -419
- ins_pricing/frontend/config_builder.py +1 -0
- ins_pricing/modelling/README.md +67 -0
- ins_pricing/modelling/core/bayesopt/README.md +59 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +12 -0
- ins_pricing/modelling/core/bayesopt/core.py +3 -1
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +830 -809
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.4.3.dist-info → ins_pricing-0.4.5.dist-info}/METADATA +182 -162
- {ins_pricing-0.4.3.dist-info → ins_pricing-0.4.5.dist-info}/RECORD +15 -22
- ins_pricing/CHANGELOG.md +0 -272
- ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
- ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
- ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
- ins_pricing/docs/modelling/README.md +0 -34
- ins_pricing/frontend/QUICKSTART.md +0 -152
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
- {ins_pricing-0.4.3.dist-info → ins_pricing-0.4.5.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.3.dist-info → ins_pricing-0.4.5.dist-info}/top_level.txt +0 -0
ins_pricing/cli/bayesopt_entry_runner.py: the single hunk `@@ -1,1442 +1,1442 @@` spans the whole module (every line changed between 0.4.3 and 0.4.5), so the removed 0.4.3 side reconstructs to the complete file below, shown in sections.

```python
"""
CLI entry point generated from BayesOpt_AutoPricing.ipynb so the workflow can
run non‑interactively (e.g., via torchrun).

Example:
    python -m torch.distributed.run --standalone --nproc_per_node=2 \\
        ins_pricing/cli/BayesOpt_entry.py \\
        --config-json
        --model-keys ft --max-evals 50 --use-ft-ddp
"""

from __future__ import annotations

from pathlib import Path
import sys

if __package__ in {None, ""}:
    repo_root = Path(__file__).resolve().parents[2]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

import argparse
import hashlib
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

# Use unified import resolver to eliminate nested try/except chains
from .utils.import_resolver import resolve_imports, setup_sys_path
from .utils.evaluation_context import (
    EvaluationContext,
    TrainingContext,
    ModelIdentity,
    DataFingerprint,
    CalibrationConfig,
    ThresholdConfig,
    BootstrapConfig,
    ReportConfig,
    RegistryConfig,
)

# Resolve all imports from a single location
setup_sys_path()
_imports = resolve_imports()

ropt = _imports.bayesopt
PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
build_model_names = _imports.build_model_names
dedupe_preserve_order = _imports.dedupe_preserve_order
load_dataset = _imports.load_dataset
parse_model_pairs = _imports.parse_model_pairs
resolve_data_path = _imports.resolve_data_path
resolve_path = _imports.resolve_path
fingerprint_file = _imports.fingerprint_file
coerce_dataset_types = _imports.coerce_dataset_types
split_train_test = _imports.split_train_test

add_config_json_arg = _imports.add_config_json_arg
add_output_dir_arg = _imports.add_output_dir_arg
resolve_and_load_config = _imports.resolve_and_load_config
resolve_data_config = _imports.resolve_data_config
resolve_report_config = _imports.resolve_report_config
resolve_split_config = _imports.resolve_split_config
resolve_runtime_config = _imports.resolve_runtime_config
resolve_output_dirs = _imports.resolve_output_dirs

bootstrap_ci = _imports.bootstrap_ci
calibrate_predictions = _imports.calibrate_predictions
eval_metrics_report = _imports.metrics_report
select_threshold = _imports.select_threshold

ModelArtifact = _imports.ModelArtifact
ModelRegistry = _imports.ModelRegistry
drift_psi_report = _imports.drift_psi_report
group_metrics = _imports.group_metrics
ReportPayload = _imports.ReportPayload
write_report = _imports.write_report

configure_run_logging = _imports.configure_run_logging
plot_loss_curve_common = _imports.plot_loss_curve

import matplotlib

if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
    matplotlib.use("Agg")
import matplotlib.pyplot as plt


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Batch trainer generated from BayesOpt_AutoPricing notebook."
    )
    add_config_json_arg(
        parser,
        help_text="Path to the JSON config describing datasets and feature columns.",
    )
    parser.add_argument(
        "--model-keys",
        nargs="+",
        default=["ft"],
        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
        help="Space-separated list of trainers to run (e.g., --model-keys glm xgb). Include 'all' to run every trainer.",
    )
    parser.add_argument(
        "--stack-model-keys",
        nargs="+",
        default=None,
        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
        help=(
            "Only used when ft_role != 'model' (FT runs as feature generator). "
            "When provided (or when config defines stack_model_keys), these trainers run after FT features "
            "are generated. Use 'all' to run every non-FT trainer."
        ),
    )
    parser.add_argument(
        "--max-evals",
        type=int,
        default=50,
        help="Optuna trial count per dataset.",
    )
    parser.add_argument(
        "--use-resn-ddp",
        action="store_true",
        help="Force ResNet trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--use-ft-ddp",
        action="store_true",
        help="Force FT-Transformer trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--use-resn-dp",
        action="store_true",
        help="Enable ResNet DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-ft-dp",
        action="store_true",
        help="Enable FT-Transformer DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-gnn-dp",
        action="store_true",
        help="Enable GNN DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-gnn-ddp",
        action="store_true",
        help="Force GNN trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--gnn-no-ann",
        action="store_true",
        help="Disable approximate k-NN for GNN graph construction and use exact search.",
    )
    parser.add_argument(
        "--gnn-ann-threshold",
        type=int,
        default=None,
        help="Row threshold above which approximate k-NN is preferred (overrides config).",
    )
    parser.add_argument(
        "--gnn-graph-cache",
        default=None,
        help="Optional path to persist/load cached adjacency matrix for GNN.",
    )
    parser.add_argument(
        "--gnn-max-gpu-nodes",
        type=int,
        default=None,
        help="Overrides the maximum node count allowed for GPU k-NN graph construction.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-ratio",
        type=float,
        default=None,
        help="Overrides the fraction of free GPU memory the k-NN builder may consume.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-overhead",
        type=float,
        default=None,
        help="Overrides the temporary GPU memory overhead multiplier for k-NN estimation.",
    )
    add_output_dir_arg(
        parser,
        help_text="Override output root for models/results/plots.",
    )
    parser.add_argument(
        "--plot-curves",
        action="store_true",
        help="Enable lift/diagnostic plots after training (config file may also request plotting).",
    )
    parser.add_argument(
        "--ft-as-feature",
        action="store_true",
        help="Alias for --ft-role embedding (keep tuning, export embeddings; skip FT plots/SHAP).",
    )
    parser.add_argument(
        "--ft-role",
        default=None,
        choices=["model", "embedding", "unsupervised_embedding"],
        help="How to use FT: model (default), embedding (export pooling embeddings), or unsupervised_embedding.",
    )
    parser.add_argument(
        "--ft-feature-prefix",
        default="ft_feat",
        help="Prefix used for generated FT features (columns: pred_<prefix>_0.. or pred_<prefix>).",
    )
    parser.add_argument(
        "--reuse-best-params",
        action="store_true",
        help="Skip Optuna and reuse best_params saved in Results/versions or bestparams CSV when available.",
    )
    return parser.parse_args()
```
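For orientation, a minimal sketch of driving this entry point programmatically. The five config keys are exactly the `required_keys` that `train_from_config` (further down) validates; the example values, the file name `config.json`, and the assumption that the module runs under `python -m` are illustrative, not taken from the package docs.

```python
# Hypothetical launch: write a minimal config, then run the CLI in a
# subprocess. Only the required keys are set; everything else uses defaults.
import json
import subprocess
import sys
from pathlib import Path

cfg = {
    "data_dir": "data",                    # where per-model datasets live
    "model_list": ["od"],                  # assumed example values
    "model_categories": ["private_car"],
    "target": "claim_amount",
    "weight": "exposure",
}
Path("config.json").write_text(json.dumps(cfg, indent=2), encoding="utf-8")

subprocess.run(
    [sys.executable, "-m", "ins_pricing.cli.bayesopt_entry_runner",
     "--config-json", "config.json", "--model-keys", "xgb", "--max-evals", "5"],
    check=True,
)
```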
```python
def _plot_curves_for_model(model: ropt.BayesOptModel, trained_keys: List[str], cfg: Dict) -> None:
    plot_cfg = cfg.get("plot", {})
    legacy_lift_flags = {
        "glm": cfg.get("plot_lift_glm", False),
        "xgb": cfg.get("plot_lift_xgb", False),
        "resn": cfg.get("plot_lift_resn", False),
        "ft": cfg.get("plot_lift_ft", False),
    }
    plot_enabled = plot_cfg.get("enable", any(legacy_lift_flags.values()))
    if not plot_enabled:
        return

    n_bins = int(plot_cfg.get("n_bins", 10))
    oneway_enabled = plot_cfg.get("oneway", True)

    available_models = dedupe_preserve_order(
        [m for m in trained_keys if m in PLOT_MODEL_LABELS]
    )

    lift_models = plot_cfg.get("lift_models")
    if lift_models is None:
        lift_models = [
            m for m, enabled in legacy_lift_flags.items() if enabled]
    if not lift_models:
        lift_models = available_models
    lift_models = dedupe_preserve_order(
        [m for m in lift_models if m in available_models]
    )

    if oneway_enabled:
        oneway_pred = bool(plot_cfg.get("oneway_pred", False))
        oneway_pred_models = plot_cfg.get("oneway_pred_models")
        pred_plotted = False
        if oneway_pred:
            if oneway_pred_models is None:
                oneway_pred_models = lift_models or available_models
            oneway_pred_models = dedupe_preserve_order(
                [m for m in oneway_pred_models if m in available_models]
            )
            for model_key in oneway_pred_models:
                label, pred_nme = PLOT_MODEL_LABELS[model_key]
                if pred_nme not in model.train_data.columns:
                    print(
                        f"[Oneway] Missing prediction column '{pred_nme}'; skip.",
                        flush=True,
                    )
                    continue
                model.plot_oneway(
                    n_bins=n_bins,
                    pred_col=pred_nme,
                    pred_label=label,
                    plot_subdir="oneway/post",
                )
                pred_plotted = True
        if not oneway_pred or not pred_plotted:
            model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/post")

    if not available_models:
        return

    for model_key in lift_models:
        label, pred_nme = PLOT_MODEL_LABELS[model_key]
        model.plot_lift(model_label=label, pred_nme=pred_nme, n_bins=n_bins)

    if not plot_cfg.get("double_lift", True) or len(available_models) < 2:
        return

    raw_pairs = plot_cfg.get("double_lift_pairs")
    if raw_pairs:
        pairs = [
            (a, b)
            for a, b in parse_model_pairs(raw_pairs)
            if a in available_models and b in available_models and a != b
        ]
    else:
        pairs = [(a, b) for i, a in enumerate(available_models)
                 for b in available_models[i + 1:]]

    for first, second in pairs:
        model.plot_dlift([first, second], n_bins=n_bins)
```
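When no explicit `double_lift_pairs` are configured, the fallback comprehension above enumerates every unordered pair of available models in first-seen order; a standalone check:

```python
# Same pair-generation fallback as in _plot_curves_for_model, in isolation.
# Equivalent to itertools.combinations(available_models, 2).
available_models = ["glm", "xgb", "resn", "ft"]

pairs = [(a, b) for i, a in enumerate(available_models)
         for b in available_models[i + 1:]]
print(pairs)
# [('glm', 'xgb'), ('glm', 'resn'), ('glm', 'ft'),
#  ('xgb', 'resn'), ('xgb', 'ft'), ('resn', 'ft')]
```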
```python
def _plot_loss_curve_for_trainer(model_name: str, trainer) -> None:
    model_obj = getattr(trainer, "model", None)
    history = None
    if model_obj is not None:
        history = getattr(model_obj, "training_history", None)
    if not history:
        history = getattr(trainer, "training_history", None)
    if not history:
        return
    train_hist = list(history.get("train") or [])
    val_hist = list(history.get("val") or [])
    if not train_hist and not val_hist:
        return
    try:
        plot_dir = trainer.output.plot_path(
            f"{model_name}/loss/loss_{model_name}_{trainer.model_name_prefix}.png"
        )
    except Exception:
        default_dir = Path("plot") / model_name / "loss"
        default_dir.mkdir(parents=True, exist_ok=True)
        plot_dir = str(
            default_dir / f"loss_{model_name}_{trainer.model_name_prefix}.png")
    if plot_loss_curve_common is not None:
        plot_loss_curve_common(
            history=history,
            title=f"{trainer.model_name_prefix} Loss Curve ({model_name})",
            save_path=plot_dir,
            show=False,
        )
    else:
        epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
        fig, ax = plt.subplots(figsize=(8, 4))
        if train_hist:
            ax.plot(range(1, len(train_hist) + 1),
                    train_hist, label="Train Loss", color="tab:blue")
        if val_hist:
            ax.plot(range(1, len(val_hist) + 1),
                    val_hist, label="Validation Loss", color="tab:orange")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Weighted Loss")
        ax.set_title(
            f"{trainer.model_name_prefix} Loss Curve ({model_name})")
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.legend()
        plt.tight_layout()
        plt.savefig(plot_dir, dpi=300)
        plt.close(fig)
    print(
        f"[Plot] Saved loss curve for {model_name}/{trainer.label} -> {plot_dir}")
```
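The helper is duck-typed: it only reads `.model`, `.training_history`, `.output.plot_path`, `.model_name_prefix`, and `.label`. A hypothetical stub exercising it without a real trainer (making `plot_path` raise drives the local `plot/<model>/loss` fallback):

```python
# Illustrative stub only; not part of the package. Assumes the module-level
# imports above have been resolved.
from types import SimpleNamespace

def _no_output(_rel: str) -> str:
    raise RuntimeError("no OutputManager in this sketch")

stub = SimpleNamespace(
    model=None,
    training_history={"train": [0.9, 0.6, 0.5], "val": [1.0, 0.7, 0.65]},
    output=SimpleNamespace(plot_path=_no_output),
    model_name_prefix="demo",
    label="demo-trainer",
)
_plot_loss_curve_for_trainer("toy_model", stub)  # writes plot/toy_model/loss/...
```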
```python
def _sample_arrays(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    max_rows: Optional[int],
    seed: Optional[int],
) -> tuple[np.ndarray, np.ndarray]:
    if max_rows is None or max_rows <= 0:
        return y_true, y_pred
    n = len(y_true)
    if n <= max_rows:
        return y_true, y_pred
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=int(max_rows), replace=False)
    return y_true[idx], y_pred[idx]
```
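`_sample_arrays` caps the row count fed into the expensive steps (calibration fitting, threshold search) by sampling without replacement; a fixed seed keeps runs reproducible, and `max_rows=None` or `<= 0` disables sampling entirely:

```python
import numpy as np

y = np.arange(10, dtype=float)
p = y / 10.0
y_s, p_s = _sample_arrays(y, p, max_rows=4, seed=0)
print(len(y_s), len(p_s))                                 # 4 4
print(_sample_arrays(y, p, max_rows=0, seed=0)[0] is y)   # True: disabled
```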
```python
def _compute_psi_report(
    model: ropt.BayesOptModel,
    *,
    features: Optional[List[str]],
    bins: int,
    strategy: str,
) -> Optional[pd.DataFrame]:
    if drift_psi_report is None:
        return None
    psi_features = features or list(getattr(model, "factor_nmes", []))
    psi_features = [
        f for f in psi_features if f in model.train_data.columns and f in model.test_data.columns]
    if not psi_features:
        return None
    try:
        return drift_psi_report(
            model.train_data[psi_features],
            model.test_data[psi_features],
            features=psi_features,
            bins=int(bins),
            strategy=str(strategy),
        )
    except Exception as exc:
        print(f"[Report] PSI computation failed: {exc}")
        return None


# --- Refactored helper functions for _evaluate_and_report ---
```
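`drift_psi_report`'s internals are not part of this diff. For orientation, the usual Population Stability Index definition is PSI = Σ (aᵢ − eᵢ)·ln(aᵢ/eᵢ) over shared bins; a hedged sketch with quantile bins (the package's actual binning and edge handling may differ):

```python
import numpy as np

def psi(expected: np.ndarray, actual: np.ndarray, bins: int = 10) -> float:
    """PSI over quantile bins of the expected (train) sample."""
    edges = np.unique(np.quantile(expected, np.linspace(0.0, 1.0, bins + 1)))
    e_cnt, _ = np.histogram(expected, bins=edges)
    a_cnt, _ = np.histogram(actual, bins=edges)
    e = np.clip(e_cnt / max(e_cnt.sum(), 1), 1e-6, None)  # floor avoids log(0)
    a = np.clip(a_cnt / max(a_cnt.sum(), 1), 1e-6, None)
    return float(np.sum((a - e) * np.log(a / e)))

rng = np.random.default_rng(0)
print(psi(rng.normal(size=5000), rng.normal(0.3, 1.0, size=5000)))  # drift > 0
```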
```python
def _apply_calibration(
    y_true_train: np.ndarray,
    y_pred_train: np.ndarray,
    y_pred_test: np.ndarray,
    calibration_cfg: Dict[str, Any],
    model_name: str,
    model_key: str,
) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
    """Apply calibration to predictions for classification tasks.

    Returns:
        Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
    """
    cal_cfg = dict(calibration_cfg or {})
    cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))

    if not cal_enabled or calibrate_predictions is None:
        return y_pred_train, y_pred_test, None

    method = cal_cfg.get("method", "sigmoid")
    max_rows = cal_cfg.get("max_rows")
    seed = cal_cfg.get("seed")
    y_cal, p_cal = _sample_arrays(
        y_true_train, y_pred_train, max_rows=max_rows, seed=seed)

    try:
        calibrator = calibrate_predictions(y_cal, p_cal, method=method)
        calibrated_train = calibrator.predict(y_pred_train)
        calibrated_test = calibrator.predict(y_pred_test)
        calibration_info = {"method": calibrator.method, "max_rows": max_rows}
        return calibrated_train, calibrated_test, calibration_info
    except Exception as exc:
        print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
        return y_pred_train, y_pred_test, None
```
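The fit-on-train, apply-to-both pattern above can be illustrated with scikit-learn's isotonic regression; the package's own `calibrate_predictions` wrapper (with its `"sigmoid"` default) is resolved through `_imports` and may be implemented differently:

```python
# Sketch of calibrate-then-apply with sklearn; illustrative only.
import numpy as np
from sklearn.isotonic import IsotonicRegression

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=2000).astype(float)
y_raw = np.clip(0.3 + 0.4 * y_true + rng.normal(0, 0.2, size=2000), 0.0, 1.0)

iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
iso.fit(y_raw, y_true)           # fit on (possibly sub-sampled) train preds
calibrated = iso.predict(y_raw)  # then apply to train and test alike
print(float(calibrated.min()), float(calibrated.max()))
```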
```python
def _select_classification_threshold(
    y_true_train: np.ndarray,
    y_pred_train_eval: np.ndarray,
    threshold_cfg: Dict[str, Any],
) -> tuple[float, Optional[Dict[str, Any]]]:
    """Select threshold for classification predictions.

    Returns:
        Tuple of (threshold_value, threshold_info)
    """
    thr_cfg = dict(threshold_cfg or {})
    thr_enabled = bool(
        thr_cfg.get("enable", False)
        or thr_cfg.get("metric")
        or thr_cfg.get("value") is not None
    )

    if thr_cfg.get("value") is not None:
        threshold_value = float(thr_cfg["value"])
        return threshold_value, {"threshold": threshold_value, "source": "fixed"}

    if thr_enabled and select_threshold is not None:
        max_rows = thr_cfg.get("max_rows")
        seed = thr_cfg.get("seed")
        y_thr, p_thr = _sample_arrays(
            y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
        threshold_info = select_threshold(
            y_thr,
            p_thr,
            metric=thr_cfg.get("metric", "f1"),
            min_positive_rate=thr_cfg.get("min_positive_rate"),
            grid=thr_cfg.get("grid", 99),
        )
        return float(threshold_info.get("threshold", 0.5)), threshold_info

    return 0.5, None
```
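A grid search for an F1-maximizing cut-off is the typical shape of such a `select_threshold` helper; this standalone sketch mirrors the `grid=99` default but omits the `min_positive_rate` filtering and any weighting the real implementation may apply:

```python
import numpy as np

def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray, grid: int = 99) -> dict:
    best = {"threshold": 0.5, "f1": -1.0}
    for t in np.linspace(0.01, 0.99, grid):
        pred = (y_prob >= t).astype(float)
        tp = float(np.sum((pred == 1) & (y_true == 1)))
        fp = float(np.sum((pred == 1) & (y_true == 0)))
        fn = float(np.sum((pred == 0) & (y_true == 1)))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0   # F1 = 2TP / (2TP + FP + FN)
        if f1 > best["f1"]:
            best = {"threshold": float(t), "f1": f1}
    return best
```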
```python
def _compute_classification_metrics(
    y_true_test: np.ndarray,
    y_pred_test_eval: np.ndarray,
    threshold_value: float,
) -> Dict[str, Any]:
    """Compute metrics for classification task."""
    metrics = eval_metrics_report(
        y_true_test,
        y_pred_test_eval,
        task_type="classification",
        threshold=threshold_value,
    )
    precision = float(metrics.get("precision", 0.0))
    recall = float(metrics.get("recall", 0.0))
    f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
    metrics["f1"] = float(f1)
    metrics["threshold"] = float(threshold_value)
    return metrics


def _compute_bootstrap_ci(
    y_true_test: np.ndarray,
    y_pred_test_eval: np.ndarray,
    weight_test: Optional[np.ndarray],
    metrics: Dict[str, Any],
    bootstrap_cfg: Dict[str, Any],
    task_type: str,
) -> Dict[str, Dict[str, float]]:
    """Compute bootstrap confidence intervals for metrics."""
    if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
        return {}

    metric_names = bootstrap_cfg.get("metrics")
    if not metric_names:
        metric_names = [name for name in metrics.keys() if name != "threshold"]
    n_samples = int(bootstrap_cfg.get("n_samples", 200))
    ci = float(bootstrap_cfg.get("ci", 0.95))
    seed = bootstrap_cfg.get("seed")

    def _metric_fn(y_true, y_pred, weight=None):
        vals = eval_metrics_report(
            y_true,
            y_pred,
            task_type=task_type,
            weight=weight,
            threshold=metrics.get("threshold", 0.5),
        )
        if task_type == "classification":
            prec = float(vals.get("precision", 0.0))
            rec = float(vals.get("recall", 0.0))
            vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
        return vals

    bootstrap_results: Dict[str, Dict[str, float]] = {}
    for name in metric_names:
        if name not in metrics:
            continue
        ci_result = bootstrap_ci(
            lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
            y_true_test,
            y_pred_test_eval,
            weight=weight_test,
            n_samples=n_samples,
            ci=ci,
            seed=seed,
        )
        bootstrap_results[str(name)] = ci_result

    return bootstrap_results
```
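`bootstrap_ci` is package code; the standard percentile bootstrap it presumably wraps looks like this (callable first, then arrays, matching the call above; weights omitted here for brevity):

```python
import numpy as np

def percentile_bootstrap(metric_fn, y_true, y_pred, n_samples=200, ci=0.95, seed=None):
    rng = np.random.default_rng(seed)
    n = len(y_true)
    stats = np.empty(n_samples)
    for b in range(n_samples):
        idx = rng.integers(0, n, size=n)   # resample rows with replacement
        stats[b] = metric_fn(y_true[idx], y_pred[idx])
    lo, hi = np.quantile(stats, [(1 - ci) / 2, 1 - (1 - ci) / 2])
    return {"lo": float(lo), "hi": float(hi)}

rng = np.random.default_rng(1)
y = rng.normal(size=500)
print(percentile_bootstrap(lambda a, b: float(np.mean(np.abs(a - b))), y, y + 0.1))
```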
```python
def _compute_validation_table(
    model: ropt.BayesOptModel,
    pred_col: str,
    report_group_cols: Optional[List[str]],
    weight_col: Optional[str],
    model_name: str,
    model_key: str,
) -> Optional[pd.DataFrame]:
    """Compute grouped validation metrics table."""
    if not report_group_cols or group_metrics is None:
        return None

    available_groups = [
        col for col in report_group_cols if col in model.test_data.columns
    ]
    if not available_groups:
        return None

    try:
        validation_table = group_metrics(
            model.test_data,
            actual_col=model.resp_nme,
            pred_col=pred_col,
            group_cols=available_groups,
            weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
        )
        counts = (
            model.test_data.groupby(available_groups, dropna=False)
            .size()
            .reset_index(name="count")
        )
        return validation_table.merge(counts, on=available_groups, how="left")
    except Exception as exc:
        print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
        return None
```
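The grouped-metrics plus row-count merge pattern, reduced to plain pandas (the package's `group_metrics` computes more than a per-group mean; the toy columns are made up):

```python
import pandas as pd

test_data = pd.DataFrame({
    "region": ["N", "N", "S", "S", "S"],
    "actual": [1.0, 0.0, 2.0, 1.0, 0.0],
    "pred":   [0.8, 0.2, 1.5, 1.1, 0.3],
})
table = test_data.groupby("region", dropna=False)[["actual", "pred"]].mean().reset_index()
counts = test_data.groupby("region", dropna=False).size().reset_index(name="count")
print(table.merge(counts, on="region", how="left"))
```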
```python
def _compute_risk_trend(
    model: ropt.BayesOptModel,
    pred_col: str,
    report_time_col: Optional[str],
    report_time_freq: str,
    report_time_ascending: bool,
    weight_col: Optional[str],
    model_name: str,
    model_key: str,
) -> Optional[pd.DataFrame]:
    """Compute time-series risk trend metrics."""
    if not report_time_col or group_metrics is None:
        return None

    if report_time_col not in model.test_data.columns:
        return None

    try:
        time_df = model.test_data.copy()
        time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
        time_df = time_df.loc[time_series.notna()].copy()

        if time_df.empty:
            return None

        time_df["_time_bucket"] = (
            pd.to_datetime(time_df[report_time_col], errors="coerce")
            .dt.to_period(report_time_freq)
            .dt.to_timestamp()
        )
        risk_trend = group_metrics(
            time_df,
            actual_col=model.resp_nme,
            pred_col=pred_col,
            group_cols=["_time_bucket"],
            weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
        )
        counts = (
            time_df.groupby("_time_bucket", dropna=False)
            .size()
            .reset_index(name="count")
        )
        risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
        risk_trend = risk_trend.sort_values(
            "_time_bucket", ascending=bool(report_time_ascending)
        ).reset_index(drop=True)
        return risk_trend.rename(columns={"_time_bucket": report_time_col})
    except Exception as exc:
        print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
        return None
```
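The time bucketing above (parse, drop unparseable rows, collapse timestamps to period starts so they group cleanly) in isolation; the column name `uw_date` is a made-up example:

```python
import pandas as pd

df = pd.DataFrame({"uw_date": ["2024-01-03", "2024-01-28", "2024-02-10", "bad"]})
ts = pd.to_datetime(df["uw_date"], errors="coerce")
df = df.loc[ts.notna()].copy()                      # "bad" row dropped
df["_time_bucket"] = (
    pd.to_datetime(df["uw_date"]).dt.to_period("M").dt.to_timestamp()
)
print(df["_time_bucket"].unique())                  # 2024-01-01, 2024-02-01
```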
```python
def _write_metrics_json(
    report_root: Path,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    threshold_info: Optional[Dict[str, Any]],
    calibration_info: Optional[Dict[str, Any]],
    bootstrap_results: Dict[str, Dict[str, float]],
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    config_sha: str,
    pred_col: str,
    task_type: str,
) -> Path:
    """Write metrics to JSON file and return the path."""
    metrics_payload = {
        "model_name": model_name,
        "model_key": model_key,
        "model_version": version,
        "metrics": metrics,
        "threshold": threshold_info,
        "calibration": calibration_info,
        "bootstrap": bootstrap_results,
        "data_path": str(data_path),
        "data_fingerprint": data_fingerprint,
        "config_sha256": config_sha,
        "pred_col": pred_col,
        "task_type": task_type,
    }
    metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
    metrics_path.write_text(
        json.dumps(metrics_payload, indent=2, ensure_ascii=True),
        encoding="utf-8",
    )
    return metrics_path


def _write_model_report(
    report_root: Path,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    risk_trend: Optional[pd.DataFrame],
    psi_report_df: Optional[pd.DataFrame],
    validation_table: Optional[pd.DataFrame],
    calibration_info: Optional[Dict[str, Any]],
    threshold_info: Optional[Dict[str, Any]],
    bootstrap_results: Dict[str, Dict[str, float]],
    config_sha: str,
    data_fingerprint: Dict[str, Any],
) -> Optional[Path]:
    """Write model report and return the path."""
    if ReportPayload is None or write_report is None:
        return None

    notes_lines = [
        f"- Config SHA256: {config_sha}",
        f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
    ]
    if calibration_info:
        notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
    if threshold_info:
        notes_lines.append(f"- Threshold selection: {threshold_info}")
    if bootstrap_results:
        notes_lines.append("- Bootstrap: see metrics JSON for CI")

    payload = ReportPayload(
        model_name=f"{model_name}/{model_key}",
        model_version=version,
        metrics={k: float(v) for k, v in metrics.items()},
        risk_trend=risk_trend,
        drift_report=psi_report_df,
        validation_table=validation_table,
        extra_notes="\n".join(notes_lines),
    )
    return write_report(
        payload,
        report_root / f"{model_name}_{model_key}_report.md",
    )


def _register_model_to_registry(
    model: ropt.BayesOptModel,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    task_type: str,
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    config_sha: str,
    registry_path: Optional[str],
    registry_tags: Dict[str, Any],
    registry_status: str,
    report_path: Optional[Path],
    metrics_path: Path,
    cfg: Dict[str, Any],
) -> None:
    """Register model artifacts to the model registry."""
    if ModelRegistry is None or ModelArtifact is None:
        return

    registry = ModelRegistry(
        registry_path
        if registry_path
        else Path(model.output_manager.result_dir) / "model_registry.json"
    )

    tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
    tags.update({
        "model_key": str(model_key),
        "task_type": str(task_type),
        "data_path": str(data_path),
        "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
        "data_size": str(data_fingerprint.get("size", "")),
        "data_mtime": str(data_fingerprint.get("mtime", "")),
        "config_sha256": str(config_sha),
    })

    artifacts = _collect_model_artifacts(
        model, model_name, model_key, report_path, metrics_path, cfg
    )

    registry.register(
        name=str(model_name),
        version=version,
        metrics={k: float(v) for k, v in metrics.items()},
        tags=tags,
        artifacts=artifacts,
        status=str(registry_status or "candidate"),
        notes=f"model_key={model_key}",
    )


def _collect_model_artifacts(
    model: ropt.BayesOptModel,
    model_name: str,
    model_key: str,
    report_path: Optional[Path],
    metrics_path: Path,
    cfg: Dict[str, Any],
) -> List:
    """Collect all model artifacts for registry."""
    artifacts = []

    # Trained model artifact
    trainer = model.trainers.get(model_key)
    if trainer is not None:
        try:
            model_path = trainer.output.model_path(trainer._get_model_filename())
            if os.path.exists(model_path):
                artifacts.append(ModelArtifact(path=model_path, description="trained model"))
        except Exception:
            pass

    # Report artifact
    if report_path is not None:
        artifacts.append(ModelArtifact(path=str(report_path), description="model report"))

    # Metrics JSON artifact
    if metrics_path.exists():
        artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))

    # Preprocess artifacts
    if bool(cfg.get("save_preprocess", False)):
        artifact_path = cfg.get("preprocess_artifact_path")
        if artifact_path:
            preprocess_path = Path(str(artifact_path))
            if not preprocess_path.is_absolute():
                preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
        else:
            preprocess_path = Path(model.output_manager.result_path(
                f"{model.model_nme}_preprocess.json"
            ))
        if preprocess_path.exists():
            artifacts.append(
                ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
            )

    # Prediction cache artifacts
    if bool(cfg.get("cache_predictions", False)):
        cache_dir = cfg.get("prediction_cache_dir")
        if cache_dir:
            pred_root = Path(str(cache_dir))
            if not pred_root.is_absolute():
                pred_root = Path(model.output_manager.result_dir) / pred_root
        else:
            pred_root = Path(model.output_manager.result_dir) / "predictions"
        ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
        for split_label in ("train", "test"):
            pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
            if pred_path.exists():
                artifacts.append(
                    ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
                )

    return artifacts


def _evaluate_and_report(
    model: ropt.BayesOptModel,
    *,
    model_name: str,
    model_key: str,
    cfg: Dict[str, Any],
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    report_output_dir: Optional[str],
    report_group_cols: Optional[List[str]],
    report_time_col: Optional[str],
    report_time_freq: str,
    report_time_ascending: bool,
    psi_report_df: Optional[pd.DataFrame],
    calibration_cfg: Dict[str, Any],
    threshold_cfg: Dict[str, Any],
    bootstrap_cfg: Dict[str, Any],
    register_model: bool,
    registry_path: Optional[str],
    registry_tags: Dict[str, Any],
    registry_status: str,
    run_id: str,
    config_sha: str,
) -> None:
    """Evaluate model predictions and generate reports.

    This function orchestrates the evaluation pipeline:
    1. Extract predictions and ground truth
    2. Apply calibration (for classification)
    3. Select threshold (for classification)
    4. Compute metrics
    5. Compute bootstrap confidence intervals
    6. Generate validation tables and risk trends
    7. Write reports and register model
    """
    if eval_metrics_report is None:
        print("[Report] Skip evaluation: metrics module unavailable.")
        return

    pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
    if pred_col not in model.test_data.columns:
        print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
        return

    # Extract predictions and weights
    weight_col = getattr(model, "weight_nme", None)
    y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
    y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
    y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
    y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
    weight_test = (
        model.test_data[weight_col].to_numpy(dtype=float, copy=False)
        if weight_col and weight_col in model.test_data.columns
        else None
    )

    task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))

    # Process based on task type
    if task_type == "classification":
        y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
        y_pred_test = np.clip(y_pred_test, 0.0, 1.0)

        y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
            y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
        )
        threshold_value, threshold_info = _select_classification_threshold(
            y_true_train, y_pred_train_eval, threshold_cfg
        )
        metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
    else:
        y_pred_test_eval = y_pred_test
        calibration_info = None
        threshold_info = None
        metrics = eval_metrics_report(
            y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
        )

    # Compute bootstrap confidence intervals
    bootstrap_results = _compute_bootstrap_ci(
        y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
    )

    # Compute validation table and risk trend
    validation_table = _compute_validation_table(
        model, pred_col, report_group_cols, weight_col, model_name, model_key
    )
    risk_trend = _compute_risk_trend(
        model, pred_col, report_time_col, report_time_freq,
        report_time_ascending, weight_col, model_name, model_key
    )

    # Setup output directory
    report_root = (
        Path(report_output_dir)
        if report_output_dir
        else Path(model.output_manager.result_dir) / "reports"
    )
    report_root.mkdir(parents=True, exist_ok=True)
    version = f"{model_key}_{run_id}"

    # Write metrics JSON
    metrics_path = _write_metrics_json(
        report_root, model_name, model_key, version, metrics,
        threshold_info, calibration_info, bootstrap_results,
        data_path, data_fingerprint, config_sha, pred_col, task_type
    )

    # Write model report
    report_path = _write_model_report(
        report_root, model_name, model_key, version, metrics,
        risk_trend, psi_report_df, validation_table,
        calibration_info, threshold_info, bootstrap_results,
        config_sha, data_fingerprint
    )

    # Register model
    if register_model:
        _register_model_to_registry(
            model, model_name, model_key, version, metrics, task_type,
            data_path, data_fingerprint, config_sha, registry_path,
            registry_tags, registry_status, report_path, metrics_path, cfg
        )


def _evaluate_with_context(
    model: ropt.BayesOptModel,
    ctx: EvaluationContext,
) -> None:
    """Evaluate model predictions using context object.

    This is a cleaner interface that uses the EvaluationContext dataclass
    instead of 19+ individual parameters.
    """
    _evaluate_and_report(
        model,
        model_name=ctx.identity.model_name,
        model_key=ctx.identity.model_key,
        cfg=ctx.cfg,
        data_path=ctx.data_path,
        data_fingerprint=ctx.data_fingerprint.to_dict(),
        report_output_dir=ctx.report.output_dir,
        report_group_cols=ctx.report.group_cols,
        report_time_col=ctx.report.time_col,
        report_time_freq=ctx.report.time_freq,
        report_time_ascending=ctx.report.time_ascending,
        psi_report_df=ctx.psi_report_df,
        calibration_cfg={
            "enable": ctx.calibration.enable,
            "method": ctx.calibration.method,
            "max_rows": ctx.calibration.max_rows,
            "seed": ctx.calibration.seed,
        },
        threshold_cfg={
            "enable": ctx.threshold.enable,
            "metric": ctx.threshold.metric,
            "value": ctx.threshold.value,
            "min_positive_rate": ctx.threshold.min_positive_rate,
            "grid": ctx.threshold.grid,
            "max_rows": ctx.threshold.max_rows,
            "seed": ctx.threshold.seed,
        },
        bootstrap_cfg={
            "enable": ctx.bootstrap.enable,
            "metrics": ctx.bootstrap.metrics,
            "n_samples": ctx.bootstrap.n_samples,
            "ci": ctx.bootstrap.ci,
            "seed": ctx.bootstrap.seed,
        },
        register_model=ctx.registry.register,
        registry_path=ctx.registry.path,
        registry_tags=ctx.registry.tags,
        registry_status=ctx.registry.status,
        run_id=ctx.run_id,
        config_sha=ctx.config_sha,
    )
```
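The context dataclasses live in `cli/utils/evaluation_context.py`, which is not part of this diff. The field layout implied by `_evaluate_with_context` can be sketched as standalone dataclasses; any name beyond the attributes referenced above is a guess:

```python
# Hypothetical reconstruction of two of the context dataclasses, inferred from
# the attribute accesses in _evaluate_with_context; the real definitions may
# differ in defaults and extra fields.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelIdentitySketch:
    model_name: str
    model_key: str

@dataclass
class CalibrationConfigSketch:
    enable: bool = False
    method: Optional[str] = None
    max_rows: Optional[int] = None
    seed: Optional[int] = None

# ThresholdConfig, BootstrapConfig, ReportConfig and RegistryConfig would
# follow the same pattern, one dataclass per cfg dict assembled above.
```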
```python
def _create_ddp_barrier(dist_ctx: TrainingContext):
    """Create a DDP barrier function for distributed training synchronization."""
    def _ddp_barrier(reason: str) -> None:
        if not dist_ctx.is_distributed:
            return
        torch_mod = getattr(ropt, "torch", None)
        dist_mod = getattr(torch_mod, "distributed", None)
        if dist_mod is None:
            return
        try:
            if not getattr(dist_mod, "is_available", lambda: False)():
                return
            if not dist_mod.is_initialized():
                ddp_ok, _, _, _ = ropt.DistributedUtils.setup_ddp()
                if not ddp_ok or not dist_mod.is_initialized():
                    return
            dist_mod.barrier()
        except Exception as exc:
            print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
            raise
    return _ddp_barrier
```
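The guard sequence above maps onto the public `torch.distributed` API (availability check, initialization check, then a collective barrier); a minimal equivalent for scripts launched with torchrun, without the package's lazy setup:

```python
import torch.distributed as dist

def barrier_if_distributed() -> None:
    # Every rank blocks here until all ranks in the process group arrive.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```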
```python
def train_from_config(args: argparse.Namespace) -> None:
    script_dir = Path(__file__).resolve().parents[1]
    config_path, cfg = resolve_and_load_config(
        args.config_json,
        script_dir,
        required_keys=["data_dir", "model_list",
                       "model_categories", "target", "weight"],
    )
    plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
    config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
    run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

    # Use TrainingContext for distributed training state
    dist_ctx = TrainingContext.from_env()
    dist_world_size = dist_ctx.world_size
    dist_rank = dist_ctx.rank
    dist_active = dist_ctx.is_distributed
    is_main_process = dist_ctx.is_main_process
    _ddp_barrier = _create_ddp_barrier(dist_ctx)

    data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
        cfg,
        config_path,
        create_data_dir=True,
    )
    runtime_cfg = resolve_runtime_config(cfg)
    ddp_min_rows = runtime_cfg["ddp_min_rows"]
    bo_sample_limit = runtime_cfg["bo_sample_limit"]
    cache_predictions = runtime_cfg["cache_predictions"]
    prediction_cache_dir = runtime_cfg["prediction_cache_dir"]
    prediction_cache_format = runtime_cfg["prediction_cache_format"]
    report_cfg = resolve_report_config(cfg)
    report_output_dir = report_cfg["report_output_dir"]
    report_group_cols = report_cfg["report_group_cols"]
    report_time_col = report_cfg["report_time_col"]
    report_time_freq = report_cfg["report_time_freq"]
    report_time_ascending = report_cfg["report_time_ascending"]
    psi_bins = report_cfg["psi_bins"]
    psi_strategy = report_cfg["psi_strategy"]
    psi_features = report_cfg["psi_features"]
    calibration_cfg = report_cfg["calibration_cfg"]
    threshold_cfg = report_cfg["threshold_cfg"]
    bootstrap_cfg = report_cfg["bootstrap_cfg"]
    register_model = report_cfg["register_model"]
    registry_path = report_cfg["registry_path"]
    registry_tags = report_cfg["registry_tags"]
    registry_status = report_cfg["registry_status"]
    data_fingerprint_max_bytes = report_cfg["data_fingerprint_max_bytes"]
    report_enabled = report_cfg["report_enabled"]

    split_cfg = resolve_split_config(cfg)
    prop_test = split_cfg["prop_test"]
    holdout_ratio = split_cfg["holdout_ratio"]
    val_ratio = split_cfg["val_ratio"]
    split_strategy = split_cfg["split_strategy"]
    split_group_col = split_cfg["split_group_col"]
    split_time_col = split_cfg["split_time_col"]
    split_time_ascending = split_cfg["split_time_ascending"]
    cv_strategy = split_cfg["cv_strategy"]
    cv_group_col = split_cfg["cv_group_col"]
    cv_time_col = split_cfg["cv_time_col"]
    cv_time_ascending = split_cfg["cv_time_ascending"]
    cv_splits = split_cfg["cv_splits"]
    ft_oof_folds = split_cfg["ft_oof_folds"]
    ft_oof_strategy = split_cfg["ft_oof_strategy"]
    ft_oof_shuffle = split_cfg["ft_oof_shuffle"]
    save_preprocess = runtime_cfg["save_preprocess"]
    preprocess_artifact_path = runtime_cfg["preprocess_artifact_path"]
    rand_seed = runtime_cfg["rand_seed"]
    epochs = runtime_cfg["epochs"]
    output_cfg = resolve_output_dirs(
        cfg,
        config_path,
        output_override=args.output_dir,
    )
    output_dir = output_cfg["output_dir"]
    reuse_best_params = bool(
        args.reuse_best_params or runtime_cfg["reuse_best_params"])
    xgb_max_depth_max = runtime_cfg["xgb_max_depth_max"]
    xgb_n_estimators_max = runtime_cfg["xgb_n_estimators_max"]
    optuna_storage = runtime_cfg["optuna_storage"]
    optuna_study_prefix = runtime_cfg["optuna_study_prefix"]
    best_params_files = runtime_cfg["best_params_files"]
    plot_path_style = runtime_cfg["plot_path_style"]

    model_names = build_model_names(
        cfg["model_list"], cfg["model_categories"])
    if not model_names:
        raise ValueError(
            "No model names generated from model_list/model_categories.")

    results: Dict[str, ropt.BayesOptModel] = {}
    trained_keys_by_model: Dict[str, List[str]] = {}

    for model_name in model_names:
        # Per-dataset training loop: load data, split train/test, and train requested models.
        data_path = resolve_data_path(
            data_dir,
            model_name,
            data_format=data_format,
            path_template=data_path_template,
        )
        if not data_path.exists():
            raise FileNotFoundError(f"Missing dataset: {data_path}")
        data_fingerprint = {"path": str(data_path)}
        if report_enabled and is_main_process:
            data_fingerprint = fingerprint_file(
                data_path,
                max_bytes=data_fingerprint_max_bytes,
            )

        print(f"\n=== Processing model {model_name} ===")
        raw = load_dataset(
            data_path,
            data_format=data_format,
            dtype_map=dtype_map,
            low_memory=False,
        )
        raw = coerce_dataset_types(raw)

        train_df, test_df = split_train_test(
            raw,
            holdout_ratio=holdout_ratio,
            strategy=split_strategy,
            group_col=split_group_col,
            time_col=split_time_col,
            time_ascending=split_time_ascending,
            rand_seed=rand_seed,
            reset_index_mode="time_group",
            ratio_label="holdout_ratio",
        )

        use_resn_dp = args.use_resn_dp or cfg.get(
            "use_resn_data_parallel", False)
        use_ft_dp = args.use_ft_dp or cfg.get("use_ft_data_parallel", True)
        dataset_rows = len(raw)
        ddp_enabled = bool(dist_active and (dataset_rows >= int(ddp_min_rows)))
        use_resn_ddp = (args.use_resn_ddp or cfg.get(
            "use_resn_ddp", False)) and ddp_enabled
        use_ft_ddp = (args.use_ft_ddp or cfg.get(
            "use_ft_ddp", False)) and ddp_enabled
        use_gnn_dp = args.use_gnn_dp or cfg.get("use_gnn_data_parallel", False)
        use_gnn_ddp = (args.use_gnn_ddp or cfg.get(
            "use_gnn_ddp", False)) and ddp_enabled
        gnn_use_ann = cfg.get("gnn_use_approx_knn", True)
        if args.gnn_no_ann:
            gnn_use_ann = False
        gnn_threshold = args.gnn_ann_threshold if args.gnn_ann_threshold is not None else cfg.get(
            "gnn_approx_knn_threshold", 50000)
        gnn_graph_cache = args.gnn_graph_cache or cfg.get("gnn_graph_cache")
        if isinstance(gnn_graph_cache, str) and gnn_graph_cache.strip():
            resolved_cache = resolve_path(gnn_graph_cache, config_path.parent)
            if resolved_cache is not None:
                gnn_graph_cache = str(resolved_cache)
        gnn_max_gpu_nodes = args.gnn_max_gpu_nodes if args.gnn_max_gpu_nodes is not None else cfg.get(
            "gnn_max_gpu_knn_nodes", 200000)
        gnn_gpu_mem_ratio = args.gnn_gpu_mem_ratio if args.gnn_gpu_mem_ratio is not None else cfg.get(
            "gnn_knn_gpu_mem_ratio", 0.9)
        gnn_gpu_mem_overhead = args.gnn_gpu_mem_overhead if args.gnn_gpu_mem_overhead is not None else cfg.get(
            "gnn_knn_gpu_mem_overhead", 2.0)

        binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
        task_type = str(cfg.get("task_type", "regression"))
        feature_list = cfg.get("feature_list")
        categorical_features = cfg.get("categorical_features")
        use_gpu = bool(cfg.get("use_gpu", True))
        region_province_col = cfg.get("region_province_col")
        region_city_col = cfg.get("region_city_col")
        region_effect_alpha = cfg.get("region_effect_alpha")
        geo_feature_nmes = cfg.get("geo_feature_nmes")
        geo_token_hidden_dim = cfg.get("geo_token_hidden_dim")
        geo_token_layers = cfg.get("geo_token_layers")
        geo_token_dropout = cfg.get("geo_token_dropout")
        geo_token_k_neighbors = cfg.get("geo_token_k_neighbors")
        geo_token_learning_rate = cfg.get("geo_token_learning_rate")
        geo_token_epochs = cfg.get("geo_token_epochs")

        ft_role = args.ft_role or cfg.get("ft_role", "model")
        if args.ft_as_feature and args.ft_role is None:
            # Keep legacy behavior as a convenience alias only when the config
            # didn't already request a non-default FT role.
            if str(cfg.get("ft_role", "model")) == "model":
                ft_role = "embedding"
        ft_feature_prefix = str(
            cfg.get("ft_feature_prefix", args.ft_feature_prefix))
        ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")

        config_fields = getattr(ropt.BayesOptConfig,
                                "__dataclass_fields__", {})
        allowed_config_keys = set(config_fields.keys())
        config_payload = {k: v for k,
                          v in cfg.items() if k in allowed_config_keys}
        config_payload.update({
            "model_nme": model_name,
            "resp_nme": cfg["target"],
            "weight_nme": cfg["weight"],
            "factor_nmes": feature_list,
            "task_type": task_type,
            "binary_resp_nme": binary_target,
            "cate_list": categorical_features,
            "prop_test": val_ratio,
            "rand_seed": rand_seed,
            "epochs": epochs,
            "use_gpu": use_gpu,
            "use_resn_data_parallel": use_resn_dp,
            "use_ft_data_parallel": use_ft_dp,
            "use_gnn_data_parallel": use_gnn_dp,
            "use_resn_ddp": use_resn_ddp,
            "use_ft_ddp": use_ft_ddp,
            "use_gnn_ddp": use_gnn_ddp,
            "output_dir": output_dir,
            "xgb_max_depth_max": xgb_max_depth_max,
            "xgb_n_estimators_max": xgb_n_estimators_max,
            "resn_weight_decay": cfg.get("resn_weight_decay"),
            "final_ensemble": bool(cfg.get("final_ensemble", False)),
            "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
            "final_refit": bool(cfg.get("final_refit", True)),
            "optuna_storage": optuna_storage,
            "optuna_study_prefix": optuna_study_prefix,
            "best_params_files": best_params_files,
            "gnn_use_approx_knn": gnn_use_ann,
            "gnn_approx_knn_threshold": gnn_threshold,
            "gnn_graph_cache": gnn_graph_cache,
            "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
            "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
            "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
            "region_province_col": region_province_col,
            "region_city_col": region_city_col,
            "region_effect_alpha": region_effect_alpha,
            "geo_feature_nmes": geo_feature_nmes,
            "geo_token_hidden_dim": geo_token_hidden_dim,
            "geo_token_layers": geo_token_layers,
            "geo_token_dropout": geo_token_dropout,
            "geo_token_k_neighbors": geo_token_k_neighbors,
            "geo_token_learning_rate": geo_token_learning_rate,
            "geo_token_epochs": geo_token_epochs,
            "ft_role": ft_role,
            "ft_feature_prefix": ft_feature_prefix,
            "ft_num_numeric_tokens": ft_num_numeric_tokens,
            "reuse_best_params": reuse_best_params,
            "bo_sample_limit": bo_sample_limit,
            "cache_predictions": cache_predictions,
            "prediction_cache_dir": prediction_cache_dir,
            "prediction_cache_format": prediction_cache_format,
            "cv_strategy": cv_strategy or split_strategy,
            "cv_group_col": cv_group_col or split_group_col,
            "cv_time_col": cv_time_col or split_time_col,
            "cv_time_ascending": cv_time_ascending,
            "cv_splits": cv_splits,
            "ft_oof_folds": ft_oof_folds,
            "ft_oof_strategy": ft_oof_strategy,
            "ft_oof_shuffle": ft_oof_shuffle,
            "save_preprocess": save_preprocess,
            "preprocess_artifact_path": preprocess_artifact_path,
            "plot_path_style": plot_path_style or "nested",
        })
        config_payload = {
            k: v for k, v in config_payload.items() if v is not None}
        config = ropt.BayesOptConfig(**config_payload)
        model = ropt.BayesOptModel(train_df, test_df, config=config)

        if plot_requested:
            plot_cfg = cfg.get("plot", {})
            legacy_lift_flags = {
                "glm": cfg.get("plot_lift_glm", False),
                "xgb": cfg.get("plot_lift_xgb", False),
                "resn": cfg.get("plot_lift_resn", False),
                "ft": cfg.get("plot_lift_ft", False),
            }
            plot_enabled = plot_cfg.get(
                "enable", any(legacy_lift_flags.values()))
            if plot_enabled and plot_cfg.get("pre_oneway", False) and plot_cfg.get("oneway", True):
                n_bins = int(plot_cfg.get("n_bins", 10))
                model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/pre")

        if "all" in args.model_keys:
            requested_keys = ["glm", "xgb", "resn", "ft", "gnn"]
        else:
            requested_keys = args.model_keys
        requested_keys = dedupe_preserve_order(requested_keys)

        if ft_role != "model":
            requested_keys = [k for k in requested_keys if k != "ft"]
            if not requested_keys:
                stack_keys = args.stack_model_keys or cfg.get(
                    "stack_model_keys")
                if stack_keys:
                    if "all" in stack_keys:
                        requested_keys = ["glm", "xgb", "resn", "gnn"]
                    else:
                        requested_keys = [k for k in stack_keys if k != "ft"]
                    requested_keys = dedupe_preserve_order(requested_keys)
            if dist_active and ddp_enabled:
                ft_trainer = model.trainers.get("ft")
                if ft_trainer is None:
                    raise ValueError("FT trainer is not available.")
                ft_trainer_uses_ddp = bool(
                    getattr(ft_trainer, "enable_distributed_optuna", False))
                if not ft_trainer_uses_ddp:
                    raise ValueError(
                        "FT embedding under torchrun requires enabling FT DDP (use --use-ft-ddp or set use_ft_ddp=true)."
                    )
        missing = [key for key in requested_keys if key not in model.trainers]
        if missing:
            raise ValueError(
                f"Trainer(s) {missing} not available for {model_name}")

        executed_keys: List[str] = []
        if ft_role != "model":
            if dist_active and not ddp_enabled:
                _ddp_barrier("start_ft_embedding")
                if dist_rank != 0:
                    _ddp_barrier("finish_ft_embedding")
                    continue
            print(
                f"Optimizing ft as {ft_role} for {model_name} (max_evals={args.max_evals})")
            model.optimize_model("ft", max_evals=args.max_evals)
            model.trainers["ft"].save()
            if getattr(ropt, "torch", None) is not None and ropt.torch.cuda.is_available():
                ropt.free_cuda()
            if dist_active and not ddp_enabled:
                _ddp_barrier("finish_ft_embedding")
        for key in requested_keys:
            trainer = model.trainers[key]
            trainer_uses_ddp = bool(
                getattr(trainer, "enable_distributed_optuna", False))
            if dist_active and not trainer_uses_ddp:
                if dist_rank != 0:
                    print(
                        f"[Rank {dist_rank}] Skip {model_name}/{key} because trainer is not DDP-enabled."
                    )
                _ddp_barrier(f"start_non_ddp_{model_name}_{key}")
                if dist_rank != 0:
                    _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
```
|
|
1373
|
-
continue
|
|
1374
|
-
|
|
1375
|
-
print(
|
|
1376
|
-
f"Optimizing {key} for {model_name} (max_evals={args.max_evals})")
|
|
1377
|
-
model.optimize_model(key, max_evals=args.max_evals)
|
|
1378
|
-
model.trainers[key].save()
|
|
1379
|
-
_plot_loss_curve_for_trainer(model_name, model.trainers[key])
|
|
1380
|
-
if key in PYTORCH_TRAINERS:
|
|
1381
|
-
ropt.free_cuda()
|
|
1382
|
-
if dist_active and not trainer_uses_ddp:
|
|
1383
|
-
_ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
|
|
1384
|
-
executed_keys.append(key)
|
|
1385
|
-
|
|
1386
|
-
if not executed_keys:
|
|
1387
|
-
continue
|
|
1388
|
-
|
|
1389
|
-
results[model_name] = model
|
|
1390
|
-
trained_keys_by_model[model_name] = executed_keys
|
|
1391
|
-
if report_enabled and is_main_process:
|
|
1392
|
-
psi_report_df = _compute_psi_report(
|
|
1393
|
-
model,
|
|
1394
|
-
features=psi_features,
|
|
1395
|
-
bins=psi_bins,
|
|
1396
|
-
strategy=str(psi_strategy),
|
|
1397
|
-
)
|
|
1398
|
-
for key in executed_keys:
|
|
1399
|
-
_evaluate_and_report(
|
|
1400
|
-
model,
|
|
1401
|
-
model_name=model_name,
|
|
1402
|
-
model_key=key,
|
|
1403
|
-
cfg=cfg,
|
|
1404
|
-
data_path=data_path,
|
|
1405
|
-
data_fingerprint=data_fingerprint,
|
|
1406
|
-
report_output_dir=report_output_dir,
|
|
1407
|
-
report_group_cols=report_group_cols,
|
|
1408
|
-
report_time_col=report_time_col,
|
|
1409
|
-
report_time_freq=str(report_time_freq),
|
|
1410
|
-
report_time_ascending=bool(report_time_ascending),
|
|
1411
|
-
psi_report_df=psi_report_df,
|
|
1412
|
-
calibration_cfg=calibration_cfg,
|
|
1413
|
-
threshold_cfg=threshold_cfg,
|
|
1414
|
-
bootstrap_cfg=bootstrap_cfg,
|
|
1415
|
-
register_model=register_model,
|
|
1416
|
-
registry_path=registry_path,
|
|
1417
|
-
registry_tags=registry_tags,
|
|
1418
|
-
registry_status=registry_status,
|
|
1419
|
-
run_id=run_id,
|
|
1420
|
-
config_sha=config_sha,
|
|
1421
|
-
)
|
|
1422
|
-
|
|
1423
|
-
if not plot_requested:
|
|
1424
|
-
return
|
|
1425
|
-
|
|
1426
|
-
for name, model in results.items():
|
|
1427
|
-
_plot_curves_for_model(
|
|
1428
|
-
model,
|
|
1429
|
-
trained_keys_by_model.get(name, []),
|
|
1430
|
-
cfg,
|
|
1431
|
-
)
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
def main() -> None:
|
|
1435
|
-
if configure_run_logging:
|
|
1436
|
-
configure_run_logging(prefix="bayesopt_entry")
|
|
1437
|
-
args = _parse_args()
|
|
1438
|
-
train_from_config(args)
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
if __name__ == "__main__":
|
|
1442
|
-
main()
|
|
1
|
+
"""
|
|
2
|
+
CLI entry point generated from BayesOpt_AutoPricing.ipynb so the workflow can
|
|
3
|
+
run non‑interactively (e.g., via torchrun).
|
|
4
|
+
|
|
5
|
+
Example:
|
|
6
|
+
python -m torch.distributed.run --standalone --nproc_per_node=2 \\
|
|
7
|
+
ins_pricing/cli/BayesOpt_entry.py \\
|
|
8
|
+
--config-json examples/config_template.json \\
|
|
9
|
+
--model-keys ft --max-evals 50 --use-ft-ddp
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
import sys
|
|
16
|
+
|
|
17
|
+
if __package__ in {None, ""}:
|
|
18
|
+
repo_root = Path(__file__).resolve().parents[2]
|
|
19
|
+
if str(repo_root) not in sys.path:
|
|
20
|
+
sys.path.insert(0, str(repo_root))
|
|
21
|
+
|
|
22
|
+
import argparse
|
|
23
|
+
import hashlib
|
|
24
|
+
import json
|
|
25
|
+
import os
|
|
26
|
+
from datetime import datetime
|
|
27
|
+
from typing import Any, Dict, List, Optional
|
|
28
|
+
|
|
29
|
+
import numpy as np
|
|
30
|
+
import pandas as pd
|
|
31
|
+
|
|
32
|
+
# Use unified import resolver to eliminate nested try/except chains
|
|
33
|
+
from .utils.import_resolver import resolve_imports, setup_sys_path
|
|
34
|
+
from .utils.evaluation_context import (
|
|
35
|
+
EvaluationContext,
|
|
36
|
+
TrainingContext,
|
|
37
|
+
ModelIdentity,
|
|
38
|
+
DataFingerprint,
|
|
39
|
+
CalibrationConfig,
|
|
40
|
+
ThresholdConfig,
|
|
41
|
+
BootstrapConfig,
|
|
42
|
+
ReportConfig,
|
|
43
|
+
RegistryConfig,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Resolve all imports from a single location
|
|
47
|
+
setup_sys_path()
|
|
48
|
+
_imports = resolve_imports()
|
|
49
|
+
|
|
50
|
+
ropt = _imports.bayesopt
|
|
51
|
+
PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
|
|
52
|
+
PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
|
|
53
|
+
build_model_names = _imports.build_model_names
|
|
54
|
+
dedupe_preserve_order = _imports.dedupe_preserve_order
|
|
55
|
+
load_dataset = _imports.load_dataset
|
|
56
|
+
parse_model_pairs = _imports.parse_model_pairs
|
|
57
|
+
resolve_data_path = _imports.resolve_data_path
|
|
58
|
+
resolve_path = _imports.resolve_path
|
|
59
|
+
fingerprint_file = _imports.fingerprint_file
|
|
60
|
+
coerce_dataset_types = _imports.coerce_dataset_types
|
|
61
|
+
split_train_test = _imports.split_train_test
|
|
62
|
+
|
|
63
|
+
add_config_json_arg = _imports.add_config_json_arg
|
|
64
|
+
add_output_dir_arg = _imports.add_output_dir_arg
|
|
65
|
+
resolve_and_load_config = _imports.resolve_and_load_config
|
|
66
|
+
resolve_data_config = _imports.resolve_data_config
|
|
67
|
+
resolve_report_config = _imports.resolve_report_config
|
|
68
|
+
resolve_split_config = _imports.resolve_split_config
|
|
69
|
+
resolve_runtime_config = _imports.resolve_runtime_config
|
|
70
|
+
resolve_output_dirs = _imports.resolve_output_dirs
|
|
71
|
+
|
|
72
|
+
bootstrap_ci = _imports.bootstrap_ci
|
|
73
|
+
calibrate_predictions = _imports.calibrate_predictions
|
|
74
|
+
eval_metrics_report = _imports.metrics_report
|
|
75
|
+
select_threshold = _imports.select_threshold
|
|
76
|
+
|
|
77
|
+
ModelArtifact = _imports.ModelArtifact
|
|
78
|
+
ModelRegistry = _imports.ModelRegistry
|
|
79
|
+
drift_psi_report = _imports.drift_psi_report
|
|
80
|
+
group_metrics = _imports.group_metrics
|
|
81
|
+
ReportPayload = _imports.ReportPayload
|
|
82
|
+
write_report = _imports.write_report
|
|
83
|
+
|
|
84
|
+
configure_run_logging = _imports.configure_run_logging
|
|
85
|
+
plot_loss_curve_common = _imports.plot_loss_curve
|
|
86
|
+
|
|
87
|
+
import matplotlib
|
|
88
|
+
|
|
89
|
+
if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
|
|
90
|
+
matplotlib.use("Agg")
|
|
91
|
+
import matplotlib.pyplot as plt
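
# Editor's note: the "Agg" selection above is deliberately placed before the
# pyplot import -- matplotlib.use() is most reliable when it runs before
# matplotlib.pyplot is first imported.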


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Batch trainer generated from BayesOpt_AutoPricing notebook."
    )
    add_config_json_arg(
        parser,
        help_text="Path to the JSON config describing datasets and feature columns.",
    )
    parser.add_argument(
        "--model-keys",
        nargs="+",
        default=["ft"],
        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
        help="Space-separated list of trainers to run (e.g., --model-keys glm xgb). Include 'all' to run every trainer.",
    )
    parser.add_argument(
        "--stack-model-keys",
        nargs="+",
        default=None,
        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
        help=(
            "Only used when ft_role != 'model' (FT runs as feature generator). "
            "When provided (or when config defines stack_model_keys), these trainers run after FT features "
            "are generated. Use 'all' to run every non-FT trainer."
        ),
    )
    parser.add_argument(
        "--max-evals",
        type=int,
        default=50,
        help="Optuna trial count per dataset.",
    )
    parser.add_argument(
        "--use-resn-ddp",
        action="store_true",
        help="Force ResNet trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--use-ft-ddp",
        action="store_true",
        help="Force FT-Transformer trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--use-resn-dp",
        action="store_true",
        help="Enable ResNet DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-ft-dp",
        action="store_true",
        help="Enable FT-Transformer DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-gnn-dp",
        action="store_true",
        help="Enable GNN DataParallel fall-back regardless of config.",
    )
    parser.add_argument(
        "--use-gnn-ddp",
        action="store_true",
        help="Force GNN trainer to use DistributedDataParallel.",
    )
    parser.add_argument(
        "--gnn-no-ann",
        action="store_true",
        help="Disable approximate k-NN for GNN graph construction and use exact search.",
    )
    parser.add_argument(
        "--gnn-ann-threshold",
        type=int,
        default=None,
        help="Row threshold above which approximate k-NN is preferred (overrides config).",
    )
    parser.add_argument(
        "--gnn-graph-cache",
        default=None,
        help="Optional path to persist/load cached adjacency matrix for GNN.",
    )
    parser.add_argument(
        "--gnn-max-gpu-nodes",
        type=int,
        default=None,
        help="Overrides the maximum node count allowed for GPU k-NN graph construction.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-ratio",
        type=float,
        default=None,
        help="Overrides the fraction of free GPU memory the k-NN builder may consume.",
    )
    parser.add_argument(
        "--gnn-gpu-mem-overhead",
        type=float,
        default=None,
        help="Overrides the temporary GPU memory overhead multiplier for k-NN estimation.",
    )
    add_output_dir_arg(
        parser,
        help_text="Override output root for models/results/plots.",
    )
    parser.add_argument(
        "--plot-curves",
        action="store_true",
        help="Enable lift/diagnostic plots after training (config file may also request plotting).",
    )
    parser.add_argument(
        "--ft-as-feature",
        action="store_true",
        help="Alias for --ft-role embedding (keep tuning, export embeddings; skip FT plots/SHAP).",
    )
    parser.add_argument(
        "--ft-role",
        default=None,
        choices=["model", "embedding", "unsupervised_embedding"],
        help="How to use FT: model (default), embedding (export pooling embeddings), or unsupervised_embedding.",
    )
    parser.add_argument(
        "--ft-feature-prefix",
        default="ft_feat",
        help="Prefix used for generated FT features (columns: pred_<prefix>_0.. or pred_<prefix>).",
    )
    parser.add_argument(
        "--reuse-best-params",
        action="store_true",
        help="Skip Optuna and reuse best_params saved in Results/versions or bestparams CSV when available.",
    )
    return parser.parse_args()


def _plot_curves_for_model(model: ropt.BayesOptModel, trained_keys: List[str], cfg: Dict) -> None:
    plot_cfg = cfg.get("plot", {})
    legacy_lift_flags = {
        "glm": cfg.get("plot_lift_glm", False),
        "xgb": cfg.get("plot_lift_xgb", False),
        "resn": cfg.get("plot_lift_resn", False),
        "ft": cfg.get("plot_lift_ft", False),
    }
    plot_enabled = plot_cfg.get("enable", any(legacy_lift_flags.values()))
    if not plot_enabled:
        return

    n_bins = int(plot_cfg.get("n_bins", 10))
    oneway_enabled = plot_cfg.get("oneway", True)

    available_models = dedupe_preserve_order(
        [m for m in trained_keys if m in PLOT_MODEL_LABELS]
    )

    lift_models = plot_cfg.get("lift_models")
    if lift_models is None:
        lift_models = [
            m for m, enabled in legacy_lift_flags.items() if enabled]
    if not lift_models:
        lift_models = available_models
    lift_models = dedupe_preserve_order(
        [m for m in lift_models if m in available_models]
    )

    if oneway_enabled:
        oneway_pred = bool(plot_cfg.get("oneway_pred", False))
        oneway_pred_models = plot_cfg.get("oneway_pred_models")
        pred_plotted = False
        if oneway_pred:
            if oneway_pred_models is None:
                oneway_pred_models = lift_models or available_models
            oneway_pred_models = dedupe_preserve_order(
                [m for m in oneway_pred_models if m in available_models]
            )
            for model_key in oneway_pred_models:
                label, pred_nme = PLOT_MODEL_LABELS[model_key]
                if pred_nme not in model.train_data.columns:
                    print(
                        f"[Oneway] Missing prediction column '{pred_nme}'; skip.",
                        flush=True,
                    )
                    continue
                model.plot_oneway(
                    n_bins=n_bins,
                    pred_col=pred_nme,
                    pred_label=label,
                    plot_subdir="oneway/post",
                )
                pred_plotted = True
        if not oneway_pred or not pred_plotted:
            model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/post")

    if not available_models:
        return

    for model_key in lift_models:
        label, pred_nme = PLOT_MODEL_LABELS[model_key]
        model.plot_lift(model_label=label, pred_nme=pred_nme, n_bins=n_bins)

    if not plot_cfg.get("double_lift", True) or len(available_models) < 2:
        return

    raw_pairs = plot_cfg.get("double_lift_pairs")
    if raw_pairs:
        pairs = [
            (a, b)
            for a, b in parse_model_pairs(raw_pairs)
            if a in available_models and b in available_models and a != b
        ]
    else:
        pairs = [(a, b) for i, a in enumerate(available_models)
                 for b in available_models[i + 1:]]

    for first, second in pairs:
        model.plot_dlift([first, second], n_bins=n_bins)
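

# Editor's illustration (not part of this module): the fallback branch above
# expands every unordered pair of trained models for double-lift plots. The
# same idiom in isolation:
def _demo_unordered_pairs(models: List[str]) -> List[tuple]:
    # e.g. ["glm", "xgb", "resn"] -> [("glm", "xgb"), ("glm", "resn"), ("xgb", "resn")]
    return [(a, b) for i, a in enumerate(models) for b in models[i + 1:]]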


def _plot_loss_curve_for_trainer(model_name: str, trainer) -> None:
    model_obj = getattr(trainer, "model", None)
    history = None
    if model_obj is not None:
        history = getattr(model_obj, "training_history", None)
    if not history:
        history = getattr(trainer, "training_history", None)
    if not history:
        return
    train_hist = list(history.get("train") or [])
    val_hist = list(history.get("val") or [])
    if not train_hist and not val_hist:
        return
    try:
        plot_dir = trainer.output.plot_path(
            f"{model_name}/loss/loss_{model_name}_{trainer.model_name_prefix}.png"
        )
    except Exception:
        default_dir = Path("plot") / model_name / "loss"
        default_dir.mkdir(parents=True, exist_ok=True)
        plot_dir = str(
            default_dir / f"loss_{model_name}_{trainer.model_name_prefix}.png")
    if plot_loss_curve_common is not None:
        plot_loss_curve_common(
            history=history,
            title=f"{trainer.model_name_prefix} Loss Curve ({model_name})",
            save_path=plot_dir,
            show=False,
        )
    else:
        epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
        fig, ax = plt.subplots(figsize=(8, 4))
        if train_hist:
            ax.plot(range(1, len(train_hist) + 1),
                    train_hist, label="Train Loss", color="tab:blue")
        if val_hist:
            ax.plot(range(1, len(val_hist) + 1),
                    val_hist, label="Validation Loss", color="tab:orange")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Weighted Loss")
        ax.set_title(
            f"{trainer.model_name_prefix} Loss Curve ({model_name})")
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.legend()
        plt.tight_layout()
        plt.savefig(plot_dir, dpi=300)
        plt.close(fig)
    print(
        f"[Plot] Saved loss curve for {model_name}/{trainer.label} -> {plot_dir}")


def _sample_arrays(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    max_rows: Optional[int],
    seed: Optional[int],
) -> tuple[np.ndarray, np.ndarray]:
    if max_rows is None or max_rows <= 0:
        return y_true, y_pred
    n = len(y_true)
    if n <= max_rows:
        return y_true, y_pred
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=int(max_rows), replace=False)
    return y_true[idx], y_pred[idx]
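

# Editor's illustration (not part of this module): _sample_arrays relies on
# numpy's seeded Generator for a reproducible subsample without replacement;
# the same idiom in isolation:
def _demo_subsample(n: int, max_rows: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)  # fixed seed -> deterministic indices
    return rng.choice(n, size=min(n, max_rows), replace=False)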


def _compute_psi_report(
    model: ropt.BayesOptModel,
    *,
    features: Optional[List[str]],
    bins: int,
    strategy: str,
) -> Optional[pd.DataFrame]:
    if drift_psi_report is None:
        return None
    psi_features = features or list(getattr(model, "factor_nmes", []))
    psi_features = [
        f for f in psi_features if f in model.train_data.columns and f in model.test_data.columns]
    if not psi_features:
        return None
    try:
        return drift_psi_report(
            model.train_data[psi_features],
            model.test_data[psi_features],
            features=psi_features,
            bins=int(bins),
            strategy=str(strategy),
        )
    except Exception as exc:
        print(f"[Report] PSI computation failed: {exc}")
        return None
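

# Editor's illustration (not part of this module): the population stability
# index drift_psi_report is expected to compute per feature. With train as the
# "expected" sample and test as the "actual" sample over shared bins b,
#     PSI = sum_b (actual_b - expected_b) * ln(actual_b / expected_b).
# A minimal sketch using quantile bins from the training sample:
def _demo_psi(train_vals: np.ndarray, test_vals: np.ndarray, bins: int = 10) -> float:
    edges = np.unique(np.quantile(train_vals, np.linspace(0.0, 1.0, bins + 1)))
    expected = np.histogram(train_vals, bins=edges)[0] / len(train_vals)
    actual = np.histogram(test_vals, bins=edges)[0] / len(test_vals)
    expected = np.clip(expected, 1e-6, None)  # guard log(0) on empty bins
    actual = np.clip(actual, 1e-6, None)
    return float(np.sum((actual - expected) * np.log(actual / expected)))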


# --- Refactored helper functions for _evaluate_and_report ---


def _apply_calibration(
    y_true_train: np.ndarray,
    y_pred_train: np.ndarray,
    y_pred_test: np.ndarray,
    calibration_cfg: Dict[str, Any],
    model_name: str,
    model_key: str,
) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
    """Apply calibration to predictions for classification tasks.

    Returns:
        Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
    """
    cal_cfg = dict(calibration_cfg or {})
    cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))

    if not cal_enabled or calibrate_predictions is None:
        return y_pred_train, y_pred_test, None

    method = cal_cfg.get("method", "sigmoid")
    max_rows = cal_cfg.get("max_rows")
    seed = cal_cfg.get("seed")
    y_cal, p_cal = _sample_arrays(
        y_true_train, y_pred_train, max_rows=max_rows, seed=seed)

    try:
        calibrator = calibrate_predictions(y_cal, p_cal, method=method)
        calibrated_train = calibrator.predict(y_pred_train)
        calibrated_test = calibrator.predict(y_pred_test)
        calibration_info = {"method": calibrator.method, "max_rows": max_rows}
        return calibrated_train, calibrated_test, calibration_info
    except Exception as exc:
        print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
        return y_pred_train, y_pred_test, None
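

# Editor's illustration (not part of this module): calibrate_predictions and
# the object it returns come from the package's import resolver, so the exact
# API above is assumed. A Platt-style calibrator with the same fit/predict
# shape can be sketched with scikit-learn (illustration-only dependency):
class _DemoSigmoidCalibrator:
    method = "sigmoid"

    def fit(self, y_true: np.ndarray, y_prob: np.ndarray) -> "_DemoSigmoidCalibrator":
        from sklearn.linear_model import LogisticRegression
        # A one-feature logistic regression on the raw scores is exactly
        # Platt scaling: p_cal = sigmoid(a * score + b).
        self._lr = LogisticRegression()
        self._lr.fit(y_prob.reshape(-1, 1), y_true.astype(int))
        return self

    def predict(self, y_prob: np.ndarray) -> np.ndarray:
        return self._lr.predict_proba(y_prob.reshape(-1, 1))[:, 1]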


def _select_classification_threshold(
    y_true_train: np.ndarray,
    y_pred_train_eval: np.ndarray,
    threshold_cfg: Dict[str, Any],
) -> tuple[float, Optional[Dict[str, Any]]]:
    """Select threshold for classification predictions.

    Returns:
        Tuple of (threshold_value, threshold_info)
    """
    thr_cfg = dict(threshold_cfg or {})
    thr_enabled = bool(
        thr_cfg.get("enable", False)
        or thr_cfg.get("metric")
        or thr_cfg.get("value") is not None
    )

    if thr_cfg.get("value") is not None:
        threshold_value = float(thr_cfg["value"])
        return threshold_value, {"threshold": threshold_value, "source": "fixed"}

    if thr_enabled and select_threshold is not None:
        max_rows = thr_cfg.get("max_rows")
        seed = thr_cfg.get("seed")
        y_thr, p_thr = _sample_arrays(
            y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
        threshold_info = select_threshold(
            y_thr,
            p_thr,
            metric=thr_cfg.get("metric", "f1"),
            min_positive_rate=thr_cfg.get("min_positive_rate"),
            grid=thr_cfg.get("grid", 99),
        )
        return float(threshold_info.get("threshold", 0.5)), threshold_info

    return 0.5, None
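

# Editor's illustration (not part of this module): select_threshold's exact
# behavior is assumed from its call site above; an F1-maximising sweep over an
# evenly spaced probability grid is one way to implement it:
def _demo_best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray,
                            grid: int = 99) -> Dict[str, float]:
    best_t, best_f1 = 0.5, -1.0
    for t in np.linspace(0.01, 0.99, grid):
        pred = y_prob >= t
        tp = float(np.sum(pred & (y_true == 1)))
        precision = tp / max(float(np.sum(pred)), 1.0)
        recall = tp / max(float(np.sum(y_true == 1)), 1.0)
        f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
        if f1 > best_f1:
            best_t, best_f1 = float(t), f1
    return {"threshold": best_t, "f1": best_f1}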


def _compute_classification_metrics(
    y_true_test: np.ndarray,
    y_pred_test_eval: np.ndarray,
    threshold_value: float,
) -> Dict[str, Any]:
    """Compute metrics for classification task."""
    metrics = eval_metrics_report(
        y_true_test,
        y_pred_test_eval,
        task_type="classification",
        threshold=threshold_value,
    )
    precision = float(metrics.get("precision", 0.0))
    recall = float(metrics.get("recall", 0.0))
    f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
    metrics["f1"] = float(f1)
    metrics["threshold"] = float(threshold_value)
    return metrics
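
# Editor's note: the harmonic-mean identity above, worked once -- with
# precision = 0.8 and recall = 0.5, f1 = 2 * 0.8 * 0.5 / (0.8 + 0.5)
# = 0.8 / 1.3 ≈ 0.615.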


def _compute_bootstrap_ci(
    y_true_test: np.ndarray,
    y_pred_test_eval: np.ndarray,
    weight_test: Optional[np.ndarray],
    metrics: Dict[str, Any],
    bootstrap_cfg: Dict[str, Any],
    task_type: str,
) -> Dict[str, Dict[str, float]]:
    """Compute bootstrap confidence intervals for metrics."""
    if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
        return {}

    metric_names = bootstrap_cfg.get("metrics")
    if not metric_names:
        metric_names = [name for name in metrics.keys() if name != "threshold"]
    n_samples = int(bootstrap_cfg.get("n_samples", 200))
    ci = float(bootstrap_cfg.get("ci", 0.95))
    seed = bootstrap_cfg.get("seed")

    def _metric_fn(y_true, y_pred, weight=None):
        vals = eval_metrics_report(
            y_true,
            y_pred,
            task_type=task_type,
            weight=weight,
            threshold=metrics.get("threshold", 0.5),
        )
        if task_type == "classification":
            prec = float(vals.get("precision", 0.0))
            rec = float(vals.get("recall", 0.0))
            vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
        return vals

    bootstrap_results: Dict[str, Dict[str, float]] = {}
    for name in metric_names:
        if name not in metrics:
            continue
        ci_result = bootstrap_ci(
            lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
            y_true_test,
            y_pred_test_eval,
            weight=weight_test,
            n_samples=n_samples,
            ci=ci,
            seed=seed,
        )
        bootstrap_results[str(name)] = ci_result

    return bootstrap_results
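

# Editor's illustration (not part of this module): bootstrap_ci is resolved
# from the package's imports; a plain percentile bootstrap with the same
# inputs (metric callable, arrays, resample count, CI width) would be:
def _demo_percentile_bootstrap(metric_fn, y_true: np.ndarray, y_pred: np.ndarray,
                               n_samples: int = 200, ci: float = 0.95,
                               seed: Optional[int] = None) -> Dict[str, float]:
    rng = np.random.default_rng(seed)
    n = len(y_true)
    stats = []
    for _ in range(n_samples):
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        stats.append(float(metric_fn(y_true[idx], y_pred[idx])))
    alpha = (1.0 - ci) / 2.0
    lo, hi = np.quantile(stats, [alpha, 1.0 - alpha])
    return {"lower": float(lo), "upper": float(hi), "mean": float(np.mean(stats))}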


def _compute_validation_table(
    model: ropt.BayesOptModel,
    pred_col: str,
    report_group_cols: Optional[List[str]],
    weight_col: Optional[str],
    model_name: str,
    model_key: str,
) -> Optional[pd.DataFrame]:
    """Compute grouped validation metrics table."""
    if not report_group_cols or group_metrics is None:
        return None

    available_groups = [
        col for col in report_group_cols if col in model.test_data.columns
    ]
    if not available_groups:
        return None

    try:
        validation_table = group_metrics(
            model.test_data,
            actual_col=model.resp_nme,
            pred_col=pred_col,
            group_cols=available_groups,
            weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
        )
        counts = (
            model.test_data.groupby(available_groups, dropna=False)
            .size()
            .reset_index(name="count")
        )
        return validation_table.merge(counts, on=available_groups, how="left")
    except Exception as exc:
        print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
        return None
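

# Editor's illustration (not part of this module): the count-merge idiom used
# above, in isolation -- group sizes from groupby().size() joined back onto a
# grouped-metrics frame:
#
#     df = pd.DataFrame({"region": ["N", "N", "S"], "err": [0.1, 0.3, 0.2]})
#     table = df.groupby("region")["err"].mean().reset_index()
#     counts = df.groupby("region").size().reset_index(name="count")
#     table.merge(counts, on="region", how="left")
#     #   region  err  count
#     # 0      N  0.2      2
#     # 1      S  0.2      1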


def _compute_risk_trend(
    model: ropt.BayesOptModel,
    pred_col: str,
    report_time_col: Optional[str],
    report_time_freq: str,
    report_time_ascending: bool,
    weight_col: Optional[str],
    model_name: str,
    model_key: str,
) -> Optional[pd.DataFrame]:
    """Compute time-series risk trend metrics."""
    if not report_time_col or group_metrics is None:
        return None

    if report_time_col not in model.test_data.columns:
        return None

    try:
        time_df = model.test_data.copy()
        time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
        time_df = time_df.loc[time_series.notna()].copy()

        if time_df.empty:
            return None

        time_df["_time_bucket"] = (
            pd.to_datetime(time_df[report_time_col], errors="coerce")
            .dt.to_period(report_time_freq)
            .dt.to_timestamp()
        )
        risk_trend = group_metrics(
            time_df,
            actual_col=model.resp_nme,
            pred_col=pred_col,
            group_cols=["_time_bucket"],
            weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
        )
        counts = (
            time_df.groupby("_time_bucket", dropna=False)
            .size()
            .reset_index(name="count")
        )
        risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
        risk_trend = risk_trend.sort_values(
            "_time_bucket", ascending=bool(report_time_ascending)
        ).reset_index(drop=True)
        return risk_trend.rename(columns={"_time_bucket": report_time_col})
    except Exception as exc:
        print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
        return None
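

# Editor's illustration (not part of this module): the pandas bucketing idiom
# used above -- dt.to_period(freq).dt.to_timestamp() floors each timestamp to
# the start of its period, e.g. with freq="M":
#
#     ts = pd.Series(pd.to_datetime(["2024-01-03", "2024-01-28", "2024-02-14"]))
#     ts.dt.to_period("M").dt.to_timestamp()
#     # -> 2024-01-01, 2024-01-01, 2024-02-01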


def _write_metrics_json(
    report_root: Path,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    threshold_info: Optional[Dict[str, Any]],
    calibration_info: Optional[Dict[str, Any]],
    bootstrap_results: Dict[str, Dict[str, float]],
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    config_sha: str,
    pred_col: str,
    task_type: str,
) -> Path:
    """Write metrics to JSON file and return the path."""
    metrics_payload = {
        "model_name": model_name,
        "model_key": model_key,
        "model_version": version,
        "metrics": metrics,
        "threshold": threshold_info,
        "calibration": calibration_info,
        "bootstrap": bootstrap_results,
        "data_path": str(data_path),
        "data_fingerprint": data_fingerprint,
        "config_sha256": config_sha,
        "pred_col": pred_col,
        "task_type": task_type,
    }
    metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
    metrics_path.write_text(
        json.dumps(metrics_payload, indent=2, ensure_ascii=True),
        encoding="utf-8",
    )
    return metrics_path


def _write_model_report(
    report_root: Path,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    risk_trend: Optional[pd.DataFrame],
    psi_report_df: Optional[pd.DataFrame],
    validation_table: Optional[pd.DataFrame],
    calibration_info: Optional[Dict[str, Any]],
    threshold_info: Optional[Dict[str, Any]],
    bootstrap_results: Dict[str, Dict[str, float]],
    config_sha: str,
    data_fingerprint: Dict[str, Any],
) -> Optional[Path]:
    """Write model report and return the path."""
    if ReportPayload is None or write_report is None:
        return None

    notes_lines = [
        f"- Config SHA256: {config_sha}",
        f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
    ]
    if calibration_info:
        notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
    if threshold_info:
        notes_lines.append(f"- Threshold selection: {threshold_info}")
    if bootstrap_results:
        notes_lines.append("- Bootstrap: see metrics JSON for CI")

    payload = ReportPayload(
        model_name=f"{model_name}/{model_key}",
        model_version=version,
        metrics={k: float(v) for k, v in metrics.items()},
        risk_trend=risk_trend,
        drift_report=psi_report_df,
        validation_table=validation_table,
        extra_notes="\n".join(notes_lines),
    )
    return write_report(
        payload,
        report_root / f"{model_name}_{model_key}_report.md",
    )


def _register_model_to_registry(
    model: ropt.BayesOptModel,
    model_name: str,
    model_key: str,
    version: str,
    metrics: Dict[str, Any],
    task_type: str,
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    config_sha: str,
    registry_path: Optional[str],
    registry_tags: Dict[str, Any],
    registry_status: str,
    report_path: Optional[Path],
    metrics_path: Path,
    cfg: Dict[str, Any],
) -> None:
    """Register model artifacts to the model registry."""
    if ModelRegistry is None or ModelArtifact is None:
        return

    registry = ModelRegistry(
        registry_path
        if registry_path
        else Path(model.output_manager.result_dir) / "model_registry.json"
    )

    tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
    tags.update({
        "model_key": str(model_key),
        "task_type": str(task_type),
        "data_path": str(data_path),
        "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
        "data_size": str(data_fingerprint.get("size", "")),
        "data_mtime": str(data_fingerprint.get("mtime", "")),
        "config_sha256": str(config_sha),
    })

    artifacts = _collect_model_artifacts(
        model, model_name, model_key, report_path, metrics_path, cfg
    )

    registry.register(
        name=str(model_name),
        version=version,
        metrics={k: float(v) for k, v in metrics.items()},
        tags=tags,
        artifacts=artifacts,
        status=str(registry_status or "candidate"),
        notes=f"model_key={model_key}",
    )


def _collect_model_artifacts(
    model: ropt.BayesOptModel,
    model_name: str,
    model_key: str,
    report_path: Optional[Path],
    metrics_path: Path,
    cfg: Dict[str, Any],
) -> List:
    """Collect all model artifacts for registry."""
    artifacts = []

    # Trained model artifact
    trainer = model.trainers.get(model_key)
    if trainer is not None:
        try:
            model_path = trainer.output.model_path(trainer._get_model_filename())
            if os.path.exists(model_path):
                artifacts.append(ModelArtifact(path=model_path, description="trained model"))
        except Exception:
            pass

    # Report artifact
    if report_path is not None:
        artifacts.append(ModelArtifact(path=str(report_path), description="model report"))

    # Metrics JSON artifact
    if metrics_path.exists():
        artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))

    # Preprocess artifacts
    if bool(cfg.get("save_preprocess", False)):
        artifact_path = cfg.get("preprocess_artifact_path")
        if artifact_path:
            preprocess_path = Path(str(artifact_path))
            if not preprocess_path.is_absolute():
                preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
        else:
            preprocess_path = Path(model.output_manager.result_path(
                f"{model.model_nme}_preprocess.json"
            ))
        if preprocess_path.exists():
            artifacts.append(
                ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
            )

    # Prediction cache artifacts
    if bool(cfg.get("cache_predictions", False)):
        cache_dir = cfg.get("prediction_cache_dir")
        if cache_dir:
            pred_root = Path(str(cache_dir))
            if not pred_root.is_absolute():
                pred_root = Path(model.output_manager.result_dir) / pred_root
        else:
            pred_root = Path(model.output_manager.result_dir) / "predictions"
        ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
        for split_label in ("train", "test"):
            pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
            if pred_path.exists():
                artifacts.append(
                    ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
                )

    return artifacts


def _evaluate_and_report(
    model: ropt.BayesOptModel,
    *,
    model_name: str,
    model_key: str,
    cfg: Dict[str, Any],
    data_path: Path,
    data_fingerprint: Dict[str, Any],
    report_output_dir: Optional[str],
    report_group_cols: Optional[List[str]],
    report_time_col: Optional[str],
    report_time_freq: str,
    report_time_ascending: bool,
    psi_report_df: Optional[pd.DataFrame],
    calibration_cfg: Dict[str, Any],
    threshold_cfg: Dict[str, Any],
    bootstrap_cfg: Dict[str, Any],
    register_model: bool,
    registry_path: Optional[str],
    registry_tags: Dict[str, Any],
    registry_status: str,
    run_id: str,
    config_sha: str,
) -> None:
    """Evaluate model predictions and generate reports.

    This function orchestrates the evaluation pipeline:
    1. Extract predictions and ground truth
    2. Apply calibration (for classification)
    3. Select threshold (for classification)
    4. Compute metrics
    5. Compute bootstrap confidence intervals
    6. Generate validation tables and risk trends
    7. Write reports and register model
    """
    if eval_metrics_report is None:
        print("[Report] Skip evaluation: metrics module unavailable.")
        return

    pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
    if pred_col not in model.test_data.columns:
        print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
        return

    # Extract predictions and weights
    weight_col = getattr(model, "weight_nme", None)
    y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
    y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
    y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
    y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
    weight_test = (
        model.test_data[weight_col].to_numpy(dtype=float, copy=False)
        if weight_col and weight_col in model.test_data.columns
        else None
    )

    task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))

    # Process based on task type
    if task_type == "classification":
        y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
        y_pred_test = np.clip(y_pred_test, 0.0, 1.0)

        y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
            y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
        )
        threshold_value, threshold_info = _select_classification_threshold(
            y_true_train, y_pred_train_eval, threshold_cfg
        )
        metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
    else:
        y_pred_test_eval = y_pred_test
        calibration_info = None
        threshold_info = None
        metrics = eval_metrics_report(
            y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
        )

    # Compute bootstrap confidence intervals
    bootstrap_results = _compute_bootstrap_ci(
        y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
    )

    # Compute validation table and risk trend
    validation_table = _compute_validation_table(
        model, pred_col, report_group_cols, weight_col, model_name, model_key
    )
    risk_trend = _compute_risk_trend(
        model, pred_col, report_time_col, report_time_freq,
        report_time_ascending, weight_col, model_name, model_key
    )

    # Setup output directory
    report_root = (
        Path(report_output_dir)
        if report_output_dir
        else Path(model.output_manager.result_dir) / "reports"
    )
    report_root.mkdir(parents=True, exist_ok=True)
    version = f"{model_key}_{run_id}"

    # Write metrics JSON
    metrics_path = _write_metrics_json(
        report_root, model_name, model_key, version, metrics,
        threshold_info, calibration_info, bootstrap_results,
        data_path, data_fingerprint, config_sha, pred_col, task_type
    )

    # Write model report
    report_path = _write_model_report(
        report_root, model_name, model_key, version, metrics,
        risk_trend, psi_report_df, validation_table,
        calibration_info, threshold_info, bootstrap_results,
        config_sha, data_fingerprint
    )

    # Register model
    if register_model:
        _register_model_to_registry(
            model, model_name, model_key, version, metrics, task_type,
            data_path, data_fingerprint, config_sha, registry_path,
            registry_tags, registry_status, report_path, metrics_path, cfg
        )


def _evaluate_with_context(
    model: ropt.BayesOptModel,
    ctx: EvaluationContext,
) -> None:
    """Evaluate model predictions using context object.

    This is a cleaner interface that uses the EvaluationContext dataclass
    instead of 19+ individual parameters.
    """
    _evaluate_and_report(
        model,
        model_name=ctx.identity.model_name,
        model_key=ctx.identity.model_key,
        cfg=ctx.cfg,
        data_path=ctx.data_path,
        data_fingerprint=ctx.data_fingerprint.to_dict(),
        report_output_dir=ctx.report.output_dir,
        report_group_cols=ctx.report.group_cols,
        report_time_col=ctx.report.time_col,
        report_time_freq=ctx.report.time_freq,
        report_time_ascending=ctx.report.time_ascending,
        psi_report_df=ctx.psi_report_df,
        calibration_cfg={
            "enable": ctx.calibration.enable,
            "method": ctx.calibration.method,
            "max_rows": ctx.calibration.max_rows,
            "seed": ctx.calibration.seed,
        },
        threshold_cfg={
            "enable": ctx.threshold.enable,
            "metric": ctx.threshold.metric,
            "value": ctx.threshold.value,
            "min_positive_rate": ctx.threshold.min_positive_rate,
            "grid": ctx.threshold.grid,
            "max_rows": ctx.threshold.max_rows,
            "seed": ctx.threshold.seed,
        },
        bootstrap_cfg={
            "enable": ctx.bootstrap.enable,
            "metrics": ctx.bootstrap.metrics,
            "n_samples": ctx.bootstrap.n_samples,
            "ci": ctx.bootstrap.ci,
            "seed": ctx.bootstrap.seed,
        },
        register_model=ctx.registry.register,
        registry_path=ctx.registry.path,
        registry_tags=ctx.registry.tags,
        registry_status=ctx.registry.status,
        run_id=ctx.run_id,
        config_sha=ctx.config_sha,
    )


def _create_ddp_barrier(dist_ctx: TrainingContext):
    """Create a DDP barrier function for distributed training synchronization."""
    def _ddp_barrier(reason: str) -> None:
        if not dist_ctx.is_distributed:
            return
        torch_mod = getattr(ropt, "torch", None)
        dist_mod = getattr(torch_mod, "distributed", None)
        if dist_mod is None:
            return
        try:
            if not getattr(dist_mod, "is_available", lambda: False)():
                return
            if not dist_mod.is_initialized():
                ddp_ok, _, _, _ = ropt.DistributedUtils.setup_ddp()
                if not ddp_ok or not dist_mod.is_initialized():
                    return
            dist_mod.barrier()
        except Exception as exc:
            print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
            raise
    return _ddp_barrier
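

# Editor's note (an assumption about torch.distributed semantics, not package
# code): dist.barrier() blocks every rank in the initialized process group
# until all ranks reach the same call, so the wrapper above acts as a named
# synchronization point. In isolation:
#
#     import torch.distributed as dist
#     if dist.is_available() and dist.is_initialized():
#         dist.barrier()  # all ranks block here until the slowest arrives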
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def train_from_config(args: argparse.Namespace) -> None:
|
|
1040
|
+
script_dir = Path(__file__).resolve().parents[1]
|
|
1041
|
+
config_path, cfg = resolve_and_load_config(
|
|
1042
|
+
args.config_json,
|
|
1043
|
+
script_dir,
|
|
1044
|
+
required_keys=["data_dir", "model_list",
|
|
1045
|
+
"model_categories", "target", "weight"],
|
|
1046
|
+
)
|
|
1047
|
+
plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
|
|
1048
|
+
config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
|
|
1049
|
+
run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
|
1050
|
+
|
|
1051
|
+
# Use TrainingContext for distributed training state
|
|
1052
|
+
dist_ctx = TrainingContext.from_env()
|
|
1053
|
+
dist_world_size = dist_ctx.world_size
|
|
1054
|
+
dist_rank = dist_ctx.rank
|
|
1055
|
+
dist_active = dist_ctx.is_distributed
|
|
1056
|
+
is_main_process = dist_ctx.is_main_process
|
|
1057
|
+
_ddp_barrier = _create_ddp_barrier(dist_ctx)
|
|
1058
|
+
|
|
1059
|
+
data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
|
|
1060
|
+
cfg,
|
|
1061
|
+
config_path,
|
|
1062
|
+
create_data_dir=True,
|
|
1063
|
+
)
|
|
1064
|
+
runtime_cfg = resolve_runtime_config(cfg)
|
|
1065
|
+
ddp_min_rows = runtime_cfg["ddp_min_rows"]
|
|
1066
|
+
bo_sample_limit = runtime_cfg["bo_sample_limit"]
|
|
1067
|
+
cache_predictions = runtime_cfg["cache_predictions"]
|
|
1068
|
+
prediction_cache_dir = runtime_cfg["prediction_cache_dir"]
|
|
1069
|
+
prediction_cache_format = runtime_cfg["prediction_cache_format"]
|
|
1070
|
+
report_cfg = resolve_report_config(cfg)
|
|
1071
|
+
report_output_dir = report_cfg["report_output_dir"]
|
|
1072
|
+
report_group_cols = report_cfg["report_group_cols"]
|
|
1073
|
+
report_time_col = report_cfg["report_time_col"]
|
|
1074
|
+
report_time_freq = report_cfg["report_time_freq"]
|
|
1075
|
+
report_time_ascending = report_cfg["report_time_ascending"]
|
|
1076
|
+
psi_bins = report_cfg["psi_bins"]
|
|
1077
|
+
psi_strategy = report_cfg["psi_strategy"]
|
|
1078
|
+
psi_features = report_cfg["psi_features"]
|
|
1079
|
+
calibration_cfg = report_cfg["calibration_cfg"]
|
|
1080
|
+
threshold_cfg = report_cfg["threshold_cfg"]
|
|
1081
|
+
bootstrap_cfg = report_cfg["bootstrap_cfg"]
|
|
1082
|
+
register_model = report_cfg["register_model"]
|
|
1083
|
+
registry_path = report_cfg["registry_path"]
|
|
1084
|
+
registry_tags = report_cfg["registry_tags"]
|
|
1085
|
+
registry_status = report_cfg["registry_status"]
|
|
1086
|
+
data_fingerprint_max_bytes = report_cfg["data_fingerprint_max_bytes"]
|
|
1087
|
+
report_enabled = report_cfg["report_enabled"]
|
|
1088
|
+
|
|
1089
|
+
split_cfg = resolve_split_config(cfg)
|
|
1090
|
+
prop_test = split_cfg["prop_test"]
|
|
1091
|
+
holdout_ratio = split_cfg["holdout_ratio"]
|
|
1092
|
+
val_ratio = split_cfg["val_ratio"]
|
|
1093
|
+
split_strategy = split_cfg["split_strategy"]
|
|
1094
|
+
split_group_col = split_cfg["split_group_col"]
|
|
1095
|
+
split_time_col = split_cfg["split_time_col"]
|
|
1096
|
+
split_time_ascending = split_cfg["split_time_ascending"]
|
|
1097
|
+
cv_strategy = split_cfg["cv_strategy"]
|
|
1098
|
+
cv_group_col = split_cfg["cv_group_col"]
|
|
1099
|
+
cv_time_col = split_cfg["cv_time_col"]
|
|
1100
|
+
cv_time_ascending = split_cfg["cv_time_ascending"]
|
|
1101
|
+
cv_splits = split_cfg["cv_splits"]
|
|
1102
|
+
ft_oof_folds = split_cfg["ft_oof_folds"]
|
|
1103
|
+
ft_oof_strategy = split_cfg["ft_oof_strategy"]
|
|
1104
|
+
ft_oof_shuffle = split_cfg["ft_oof_shuffle"]
|
|
1105
|
+
save_preprocess = runtime_cfg["save_preprocess"]
|
|
1106
|
+
preprocess_artifact_path = runtime_cfg["preprocess_artifact_path"]
|
|
1107
|
+
rand_seed = runtime_cfg["rand_seed"]
|
|
1108
|
+
epochs = runtime_cfg["epochs"]
|
|
1109
|
+
output_cfg = resolve_output_dirs(
|
|
1110
|
+
cfg,
|
|
1111
|
+
config_path,
|
|
1112
|
+
output_override=args.output_dir,
|
|
1113
|
+
)
|
|
1114
|
+
output_dir = output_cfg["output_dir"]
|
|
1115
|
+
reuse_best_params = bool(
|
|
1116
|
+
args.reuse_best_params or runtime_cfg["reuse_best_params"])
|
|
1117
|
+
xgb_max_depth_max = runtime_cfg["xgb_max_depth_max"]
|
|
1118
|
+
xgb_n_estimators_max = runtime_cfg["xgb_n_estimators_max"]
|
|
1119
|
+
optuna_storage = runtime_cfg["optuna_storage"]
|
|
1120
|
+
optuna_study_prefix = runtime_cfg["optuna_study_prefix"]
|
|
1121
|
+
best_params_files = runtime_cfg["best_params_files"]
|
|
1122
|
+
plot_path_style = runtime_cfg["plot_path_style"]
|
|
1123
|
+
|
|
1124
|
+
model_names = build_model_names(
|
|
1125
|
+
cfg["model_list"], cfg["model_categories"])
|
|
1126
|
+
if not model_names:
|
|
1127
|
+
raise ValueError(
|
|
1128
|
+
"No model names generated from model_list/model_categories.")
|
|
1129
|
+
|
|
1130
|
+
results: Dict[str, ropt.BayesOptModel] = {}
|
|
1131
|
+
trained_keys_by_model: Dict[str, List[str]] = {}
|
|
1132
|
+
|
|
1133
|
+
for model_name in model_names:
|
|
1134
|
+
# Per-dataset training loop: load data, split train/test, and train requested models.
|
|
1135
|
+
data_path = resolve_data_path(
|
|
1136
|
+
data_dir,
|
|
1137
|
+
model_name,
|
|
1138
|
+
data_format=data_format,
|
|
1139
|
+
path_template=data_path_template,
|
|
1140
|
+
)
|
|
1141
|
+
if not data_path.exists():
|
|
1142
|
+
raise FileNotFoundError(f"Missing dataset: {data_path}")
|
|
1143
|
+
data_fingerprint = {"path": str(data_path)}
|
|
1144
|
+
if report_enabled and is_main_process:
|
|
1145
|
+
data_fingerprint = fingerprint_file(
|
|
1146
|
+
data_path,
|
|
1147
|
+
max_bytes=data_fingerprint_max_bytes,
|
|
1148
|
+
)
|
|
1149
|
+
|
|
1150
|
+
print(f"\n=== Processing model {model_name} ===")
|
|
1151
|
+
raw = load_dataset(
|
|
1152
|
+
data_path,
|
|
1153
|
+
data_format=data_format,
|
|
1154
|
+
dtype_map=dtype_map,
|
|
1155
|
+
low_memory=False,
|
|
1156
|
+
)
|
|
1157
|
+
raw = coerce_dataset_types(raw)
|
|
1158
|
+
|
|
1159
|
+
train_df, test_df = split_train_test(
|
|
1160
|
+
raw,
|
|
1161
|
+
holdout_ratio=holdout_ratio,
|
|
1162
|
+
strategy=split_strategy,
|
|
1163
|
+
group_col=split_group_col,
|
|
1164
|
+
time_col=split_time_col,
|
|
1165
|
+
time_ascending=split_time_ascending,
|
|
1166
|
+
rand_seed=rand_seed,
|
|
1167
|
+
reset_index_mode="time_group",
|
|
1168
|
+
ratio_label="holdout_ratio",
|
|
1169
|
+
)
|
|
1170
|
+
|
|
1171
|
+
use_resn_dp = args.use_resn_dp or cfg.get(
|
|
1172
|
+
"use_resn_data_parallel", False)
|
|
1173
|
+
use_ft_dp = args.use_ft_dp or cfg.get("use_ft_data_parallel", True)
|
|
1174
|
+
dataset_rows = len(raw)
|
|
1175
|
+
ddp_enabled = bool(dist_active and (dataset_rows >= int(ddp_min_rows)))
|
|
1176
|
+
use_resn_ddp = (args.use_resn_ddp or cfg.get(
|
|
1177
|
+
"use_resn_ddp", False)) and ddp_enabled
|
|
1178
|
+
use_ft_ddp = (args.use_ft_ddp or cfg.get(
|
|
1179
|
+
"use_ft_ddp", False)) and ddp_enabled
|
|
1180
|
+
use_gnn_dp = args.use_gnn_dp or cfg.get("use_gnn_data_parallel", False)
|
|
1181
|
+
use_gnn_ddp = (args.use_gnn_ddp or cfg.get(
|
|
1182
|
+
"use_gnn_ddp", False)) and ddp_enabled
|
|
1183
|
+
gnn_use_ann = cfg.get("gnn_use_approx_knn", True)
|
|
1184
|
+
if args.gnn_no_ann:
|
|
1185
|
+
gnn_use_ann = False
|
|
1186
|
+
gnn_threshold = args.gnn_ann_threshold if args.gnn_ann_threshold is not None else cfg.get(
|
|
1187
|
+
"gnn_approx_knn_threshold", 50000)
|
|
1188
|
+
gnn_graph_cache = args.gnn_graph_cache or cfg.get("gnn_graph_cache")
|
|
1189
|
+
if isinstance(gnn_graph_cache, str) and gnn_graph_cache.strip():
|
|
1190
|
+
resolved_cache = resolve_path(gnn_graph_cache, config_path.parent)
|
|
1191
|
+
if resolved_cache is not None:
|
|
1192
|
+
gnn_graph_cache = str(resolved_cache)
|
|
1193
|
+
gnn_max_gpu_nodes = args.gnn_max_gpu_nodes if args.gnn_max_gpu_nodes is not None else cfg.get(
|
|
1194
|
+
"gnn_max_gpu_knn_nodes", 200000)
|
|
1195
|
+
gnn_gpu_mem_ratio = args.gnn_gpu_mem_ratio if args.gnn_gpu_mem_ratio is not None else cfg.get(
|
|
1196
|
+
"gnn_knn_gpu_mem_ratio", 0.9)
|
|
1197
|
+
gnn_gpu_mem_overhead = args.gnn_gpu_mem_overhead if args.gnn_gpu_mem_overhead is not None else cfg.get(
|
|
1198
|
+
"gnn_knn_gpu_mem_overhead", 2.0)
|
|
1199
|
+
|
|
1200 +         binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
1201 +         task_type = str(cfg.get("task_type", "regression"))
1202 +         feature_list = cfg.get("feature_list")
1203 +         categorical_features = cfg.get("categorical_features")
1204 +         use_gpu = bool(cfg.get("use_gpu", True))
1205 +         region_province_col = cfg.get("region_province_col")
1206 +         region_city_col = cfg.get("region_city_col")
1207 +         region_effect_alpha = cfg.get("region_effect_alpha")
1208 +         geo_feature_nmes = cfg.get("geo_feature_nmes")
1209 +         geo_token_hidden_dim = cfg.get("geo_token_hidden_dim")
1210 +         geo_token_layers = cfg.get("geo_token_layers")
1211 +         geo_token_dropout = cfg.get("geo_token_dropout")
1212 +         geo_token_k_neighbors = cfg.get("geo_token_k_neighbors")
1213 +         geo_token_learning_rate = cfg.get("geo_token_learning_rate")
1214 +         geo_token_epochs = cfg.get("geo_token_epochs")
1215 +
1216 +         ft_role = args.ft_role or cfg.get("ft_role", "model")
1217 +         if args.ft_as_feature and args.ft_role is None:
1218 +             # Keep legacy behavior as a convenience alias only when the config
1219 +             # didn't already request a non-default FT role.
1220 +             if str(cfg.get("ft_role", "model")) == "model":
1221 +                 ft_role = "embedding"
1222 +         ft_feature_prefix = str(
1223 +             cfg.get("ft_feature_prefix", args.ft_feature_prefix))
1224 +         ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")
1225 +
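When ft_role resolves to "embedding", the FT-Transformer is fitted first and its outputs are exposed to downstream models as columns named with ft_feature_prefix. A self-contained illustration of selecting such columns by prefix; the prefix value and column names here are assumed for illustration, not taken from the package:

    import pandas as pd

    ft_feature_prefix = "ft_emb_"  # assumed prefix, for illustration only
    df = pd.DataFrame({"age": [30, 41],
                       "ft_emb_0": [0.12, -0.30],
                       "ft_emb_1": [1.00, 0.40]})
    emb_cols = [c for c in df.columns if c.startswith(ft_feature_prefix)]
    assert emb_cols == ["ft_emb_0", "ft_emb_1"]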
1226 +         config_fields = getattr(ropt.BayesOptConfig,
1227 +                                 "__dataclass_fields__", {})
1228 +         allowed_config_keys = set(config_fields.keys())
1229 +         config_payload = {k: v for k,
1230 +                           v in cfg.items() if k in allowed_config_keys}
1231 +         config_payload.update({
1232 +             "model_nme": model_name,
1233 +             "resp_nme": cfg["target"],
1234 +             "weight_nme": cfg["weight"],
1235 +             "factor_nmes": feature_list,
1236 +             "task_type": task_type,
1237 +             "binary_resp_nme": binary_target,
1238 +             "cate_list": categorical_features,
1239 +             "prop_test": val_ratio,
1240 +             "rand_seed": rand_seed,
1241 +             "epochs": epochs,
1242 +             "use_gpu": use_gpu,
1243 +             "use_resn_data_parallel": use_resn_dp,
1244 +             "use_ft_data_parallel": use_ft_dp,
1245 +             "use_gnn_data_parallel": use_gnn_dp,
1246 +             "use_resn_ddp": use_resn_ddp,
1247 +             "use_ft_ddp": use_ft_ddp,
1248 +             "use_gnn_ddp": use_gnn_ddp,
1249 +             "output_dir": output_dir,
1250 +             "xgb_max_depth_max": xgb_max_depth_max,
1251 +             "xgb_n_estimators_max": xgb_n_estimators_max,
1252 +             "resn_weight_decay": cfg.get("resn_weight_decay"),
1253 +             "final_ensemble": bool(cfg.get("final_ensemble", False)),
1254 +             "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
1255 +             "final_refit": bool(cfg.get("final_refit", True)),
1256 +             "optuna_storage": optuna_storage,
1257 +             "optuna_study_prefix": optuna_study_prefix,
1258 +             "best_params_files": best_params_files,
1259 +             "gnn_use_approx_knn": gnn_use_ann,
1260 +             "gnn_approx_knn_threshold": gnn_threshold,
1261 +             "gnn_graph_cache": gnn_graph_cache,
1262 +             "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
1263 +             "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
1264 +             "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
1265 +             "region_province_col": region_province_col,
1266 +             "region_city_col": region_city_col,
1267 +             "region_effect_alpha": region_effect_alpha,
1268 +             "geo_feature_nmes": geo_feature_nmes,
1269 +             "geo_token_hidden_dim": geo_token_hidden_dim,
1270 +             "geo_token_layers": geo_token_layers,
1271 +             "geo_token_dropout": geo_token_dropout,
1272 +             "geo_token_k_neighbors": geo_token_k_neighbors,
1273 +             "geo_token_learning_rate": geo_token_learning_rate,
1274 +             "geo_token_epochs": geo_token_epochs,
1275 +             "ft_role": ft_role,
1276 +             "ft_feature_prefix": ft_feature_prefix,
1277 +             "ft_num_numeric_tokens": ft_num_numeric_tokens,
1278 +             "reuse_best_params": reuse_best_params,
1279 +             "bo_sample_limit": bo_sample_limit,
1280 +             "cache_predictions": cache_predictions,
1281 +             "prediction_cache_dir": prediction_cache_dir,
1282 +             "prediction_cache_format": prediction_cache_format,
1283 +             "cv_strategy": cv_strategy or split_strategy,
1284 +             "cv_group_col": cv_group_col or split_group_col,
1285 +             "cv_time_col": cv_time_col or split_time_col,
1286 +             "cv_time_ascending": cv_time_ascending,
1287 +             "cv_splits": cv_splits,
1288 +             "ft_oof_folds": ft_oof_folds,
1289 +             "ft_oof_strategy": ft_oof_strategy,
1290 +             "ft_oof_shuffle": ft_oof_shuffle,
1291 +             "save_preprocess": save_preprocess,
1292 +             "preprocess_artifact_path": preprocess_artifact_path,
1293 +             "plot_path_style": plot_path_style or "nested",
1294 +         })
1295 +         config_payload = {
1296 +             k: v for k, v in config_payload.items() if v is not None}
1297 +         config = ropt.BayesOptConfig(**config_payload)
1298 +         model = ropt.BayesOptModel(train_df, test_df, config=config)
1299 +
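The payload construction above relies on two guards: only keys that are real BayesOptConfig dataclass fields survive, and None values are dropped so the dataclass defaults apply instead of explicit Nones. A self-contained sketch of the same pattern; DemoConfig and build_payload are illustrative stand-ins:

    from dataclasses import dataclass, fields

    @dataclass
    class DemoConfig:
        model_nme: str = "demo"
        epochs: int = 10

    def build_payload(raw: dict, cls) -> dict:
        allowed = {f.name for f in fields(cls)}
        return {k: v for k, v in raw.items() if k in allowed and v is not None}

    cfg = DemoConfig(**build_payload(
        {"epochs": 20, "unknown_key": 1, "model_nme": None}, DemoConfig))
    assert cfg.epochs == 20 and cfg.model_nme == "demo"  # unknown and None keys dropped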
1300 +         if plot_requested:
1301 +             plot_cfg = cfg.get("plot", {})
1302 +             legacy_lift_flags = {
1303 +                 "glm": cfg.get("plot_lift_glm", False),
1304 +                 "xgb": cfg.get("plot_lift_xgb", False),
1305 +                 "resn": cfg.get("plot_lift_resn", False),
1306 +                 "ft": cfg.get("plot_lift_ft", False),
1307 +             }
1308 +             plot_enabled = plot_cfg.get(
1309 +                 "enable", any(legacy_lift_flags.values()))
1310 +             if plot_enabled and plot_cfg.get("pre_oneway", False) and plot_cfg.get("oneway", True):
1311 +                 n_bins = int(plot_cfg.get("n_bins", 10))
1312 +                 model.plot_oneway(n_bins=n_bins, plot_subdir="oneway/pre")
1313 +
1314 +         if "all" in args.model_keys:
1315 +             requested_keys = ["glm", "xgb", "resn", "ft", "gnn"]
1316 +         else:
1317 +             requested_keys = args.model_keys
1318 +         requested_keys = dedupe_preserve_order(requested_keys)
1319 +
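dedupe_preserve_order (imported earlier in this module) presumably removes duplicate keys while keeping first-occurrence order; a standard equivalent, shown only to document the assumed behavior:

    def dedupe_preserve_order(keys):
        seen = set()
        return [k for k in keys if not (k in seen or seen.add(k))]

    assert dedupe_preserve_order(["glm", "xgb", "glm", "ft"]) == ["glm", "xgb", "ft"]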
1320 +         if ft_role != "model":
1321 +             requested_keys = [k for k in requested_keys if k != "ft"]
1322 +             if not requested_keys:
1323 +                 stack_keys = args.stack_model_keys or cfg.get(
1324 +                     "stack_model_keys")
1325 +                 if stack_keys:
1326 +                     if "all" in stack_keys:
1327 +                         requested_keys = ["glm", "xgb", "resn", "gnn"]
1328 +                     else:
1329 +                         requested_keys = [k for k in stack_keys if k != "ft"]
1330 +                     requested_keys = dedupe_preserve_order(requested_keys)
1331 +             if dist_active and ddp_enabled:
1332 +                 ft_trainer = model.trainers.get("ft")
1333 +                 if ft_trainer is None:
1334 +                     raise ValueError("FT trainer is not available.")
1335 +                 ft_trainer_uses_ddp = bool(
1336 +                     getattr(ft_trainer, "enable_distributed_optuna", False))
1337 +                 if not ft_trainer_uses_ddp:
1338 +                     raise ValueError(
1339 +                         "FT embedding under torchrun requires enabling FT DDP (use --use-ft-ddp or set use_ft_ddp=true)."
1340 +                     )
1341 +         missing = [key for key in requested_keys if key not in model.trainers]
1342 +         if missing:
1343 +             raise ValueError(
1344 +                 f"Trainer(s) {missing} not available for {model_name}")
1345 +
1346 +         executed_keys: List[str] = []
1347 +         if ft_role != "model":
1348 +             if dist_active and not ddp_enabled:
1349 +                 _ddp_barrier("start_ft_embedding")
1350 +                 if dist_rank != 0:
1351 +                     _ddp_barrier("finish_ft_embedding")
1352 +                     continue
1353 +             print(
1354 +                 f"Optimizing ft as {ft_role} for {model_name} (max_evals={args.max_evals})")
1355 +             model.optimize_model("ft", max_evals=args.max_evals)
1356 +             model.trainers["ft"].save()
1357 +             if getattr(ropt, "torch", None) is not None and ropt.torch.cuda.is_available():
1358 +                 ropt.free_cuda()
1359 +             if dist_active and not ddp_enabled:
1360 +                 _ddp_barrier("finish_ft_embedding")
1361 +         for key in requested_keys:
1362 +             trainer = model.trainers[key]
1363 +             trainer_uses_ddp = bool(
1364 +                 getattr(trainer, "enable_distributed_optuna", False))
1365 +             if dist_active and not trainer_uses_ddp:
1366 +                 if dist_rank != 0:
1367 +                     print(
1368 +                         f"[Rank {dist_rank}] Skip {model_name}/{key} because trainer is not DDP-enabled."
1369 +                     )
1370 +                 _ddp_barrier(f"start_non_ddp_{model_name}_{key}")
1371 +                 if dist_rank != 0:
1372 +                     _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1373 +                     continue
1374 +
1375 +             print(
1376 +                 f"Optimizing {key} for {model_name} (max_evals={args.max_evals})")
1377 +             model.optimize_model(key, max_evals=args.max_evals)
1378 +             model.trainers[key].save()
1379 +             _plot_loss_curve_for_trainer(model_name, model.trainers[key])
1380 +             if key in PYTORCH_TRAINERS:
1381 +                 ropt.free_cuda()
1382 +             if dist_active and not trainer_uses_ddp:
1383 +                 _ddp_barrier(f"finish_non_ddp_{model_name}_{key}")
1384 +             executed_keys.append(key)
1385 +
1386 +         if not executed_keys:
1387 +             continue
1388 +
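The paired _ddp_barrier calls above fence a rank-0-only critical section: every rank reaches the start barrier, non-zero ranks move straight to the finish barrier and continue, and rank 0 does the work in between, so no rank deadlocks. A minimal sketch of that shape with torch.distributed, assuming the process group was already initialized by torchrun; run_on_rank0_only is illustrative, not the package's helper:

    import torch.distributed as dist

    def run_on_rank0_only(work) -> None:
        # Single-process fallback: just do the work.
        if not dist.is_initialized():
            work()
            return
        dist.barrier()              # every rank enters ("start" barrier)
        if dist.get_rank() == 0:
            work()                  # only rank 0 performs the section
        dist.barrier()              # every rank leaves ("finish" barrier)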
1389 +         results[model_name] = model
1390 +         trained_keys_by_model[model_name] = executed_keys
1391 +         if report_enabled and is_main_process:
1392 +             psi_report_df = _compute_psi_report(
1393 +                 model,
1394 +                 features=psi_features,
1395 +                 bins=psi_bins,
1396 +                 strategy=str(psi_strategy),
1397 +             )
1398 +             for key in executed_keys:
1399 +                 _evaluate_and_report(
1400 +                     model,
1401 +                     model_name=model_name,
1402 +                     model_key=key,
1403 +                     cfg=cfg,
1404 +                     data_path=data_path,
1405 +                     data_fingerprint=data_fingerprint,
1406 +                     report_output_dir=report_output_dir,
1407 +                     report_group_cols=report_group_cols,
1408 +                     report_time_col=report_time_col,
1409 +                     report_time_freq=str(report_time_freq),
1410 +                     report_time_ascending=bool(report_time_ascending),
1411 +                     psi_report_df=psi_report_df,
1412 +                     calibration_cfg=calibration_cfg,
1413 +                     threshold_cfg=threshold_cfg,
1414 +                     bootstrap_cfg=bootstrap_cfg,
1415 +                     register_model=register_model,
1416 +                     registry_path=registry_path,
1417 +                     registry_tags=registry_tags,
1418 +                     registry_status=registry_status,
1419 +                     run_id=run_id,
1420 +                     config_sha=config_sha,
1421 +                 )
1422 +
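_compute_psi_report feeds a drift check into the report. The standard Population Stability Index it presumably implements compares binned training vs. holdout distributions: PSI = sum_i (a_i - e_i) * ln(a_i / e_i). A self-contained sketch with quantile bins, clipped to avoid log(0); the psi function below is illustrative, not the package's implementation:

    import numpy as np

    def psi(expected: np.ndarray, actual: np.ndarray, bins: int = 10) -> float:
        # Bin edges come from the "expected" (training) sample;
        # np.unique keeps them strictly monotonic for np.histogram.
        edges = np.unique(np.quantile(expected, np.linspace(0.0, 1.0, bins + 1)))
        e_pct = np.histogram(expected, edges)[0] / len(expected)
        a_pct = np.histogram(actual, edges)[0] / len(actual)
        e_pct = np.clip(e_pct, 1e-6, None)
        a_pct = np.clip(a_pct, 1e-6, None)
        return float(np.sum((a_pct - e_pct) * np.log(a_pct / e_pct)))

By the usual convention, PSI below roughly 0.1 is read as stable and values above 0.25 flag material drift.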
1423 +     if not plot_requested:
1424 +         return
1425 +
1426 +     for name, model in results.items():
1427 +         _plot_curves_for_model(
1428 +             model,
1429 +             trained_keys_by_model.get(name, []),
1430 +             cfg,
1431 +         )
1432 +
1433 +
1434 + def main() -> None:
1435 +     if configure_run_logging:
1436 +         configure_run_logging(prefix="bayesopt_entry")
1437 +     args = _parse_args()
1438 +     train_from_config(args)
1439 +
1440 +
1441 + if __name__ == "__main__":
1442 +     main()