ins-pricing 0.2.9-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/CHANGELOG.md +93 -0
- ins_pricing/README.md +11 -0
- ins_pricing/cli/bayesopt_entry_runner.py +626 -499
- ins_pricing/cli/utils/evaluation_context.py +320 -0
- ins_pricing/cli/utils/import_resolver.py +350 -0
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
- ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
- ins_pricing/modelling/core/bayesopt/core.py +153 -94
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +118 -31
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +294 -139
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
- ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
- ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +587 -0
- ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
- ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/METADATA +162 -149
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/RECORD +26 -13
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/top_level.txt +0 -0
@@ -29,108 +29,60 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 import pandas as pd
 
-try:
-    [... content of removed lines 33-85 is not shown in this diff view ...]
-    from ins_pricing.cli.utils.cli_common import (  # type: ignore
-        PLOT_MODEL_LABELS,
-        PYTORCH_TRAINERS,
-        build_model_names,
-        dedupe_preserve_order,
-        load_dataset,
-        parse_model_pairs,
-        resolve_data_path,
-        resolve_path,
-        fingerprint_file,
-        coerce_dataset_types,
-        split_train_test,
-    )
-    from ins_pricing.cli.utils.cli_config import (  # type: ignore
-        add_config_json_arg,
-        add_output_dir_arg,
-        resolve_and_load_config,
-        resolve_data_config,
-        resolve_report_config,
-        resolve_split_config,
-        resolve_runtime_config,
-        resolve_output_dirs,
-    )
-except Exception:
-    import BayesOpt as ropt  # type: ignore
-    from utils.cli_common import (  # type: ignore
-        PLOT_MODEL_LABELS,
-        PYTORCH_TRAINERS,
-        build_model_names,
-        dedupe_preserve_order,
-        load_dataset,
-        parse_model_pairs,
-        resolve_data_path,
-        resolve_path,
-        fingerprint_file,
-        coerce_dataset_types,
-        split_train_test,
-    )
-    from utils.cli_config import (  # type: ignore
-        add_config_json_arg,
-        add_output_dir_arg,
-        resolve_and_load_config,
-        resolve_data_config,
-        resolve_report_config,
-        resolve_split_config,
-        resolve_runtime_config,
-        resolve_output_dirs,
-    )
+# Use unified import resolver to eliminate nested try/except chains
+from .utils.import_resolver import resolve_imports, setup_sys_path
+from .utils.evaluation_context import (
+    EvaluationContext,
+    TrainingContext,
+    ModelIdentity,
+    DataFingerprint,
+    CalibrationConfig,
+    ThresholdConfig,
+    BootstrapConfig,
+    ReportConfig,
+    RegistryConfig,
+)
+
+# Resolve all imports from a single location
+setup_sys_path()
+_imports = resolve_imports()
+
+ropt = _imports.bayesopt
+PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
+PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
+build_model_names = _imports.build_model_names
+dedupe_preserve_order = _imports.dedupe_preserve_order
+load_dataset = _imports.load_dataset
+parse_model_pairs = _imports.parse_model_pairs
+resolve_data_path = _imports.resolve_data_path
+resolve_path = _imports.resolve_path
+fingerprint_file = _imports.fingerprint_file
+coerce_dataset_types = _imports.coerce_dataset_types
+split_train_test = _imports.split_train_test
+
+add_config_json_arg = _imports.add_config_json_arg
+add_output_dir_arg = _imports.add_output_dir_arg
+resolve_and_load_config = _imports.resolve_and_load_config
+resolve_data_config = _imports.resolve_data_config
+resolve_report_config = _imports.resolve_report_config
+resolve_split_config = _imports.resolve_split_config
+resolve_runtime_config = _imports.resolve_runtime_config
+resolve_output_dirs = _imports.resolve_output_dirs
+
+bootstrap_ci = _imports.bootstrap_ci
+calibrate_predictions = _imports.calibrate_predictions
+eval_metrics_report = _imports.metrics_report
+select_threshold = _imports.select_threshold
+
+ModelArtifact = _imports.ModelArtifact
+ModelRegistry = _imports.ModelRegistry
+drift_psi_report = _imports.drift_psi_report
+group_metrics = _imports.group_metrics
+ReportPayload = _imports.ReportPayload
+write_report = _imports.write_report
+
+configure_run_logging = _imports.configure_run_logging
+plot_loss_curve_common = _imports.plot_loss_curve
 
 import matplotlib
 
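The hunk above replaces roughly a hundred lines of nested try/except imports with a single `setup_sys_path()` / `resolve_imports()` call. The resolver itself lives in the new `ins_pricing/cli/utils/import_resolver.py` (+350 lines in this release) and is not shown in this diff, so the following is only a minimal sketch of the pattern under stated assumptions: the module paths, helper names, and fallback order below are illustrative, not the package's actual implementation.

```python
# Sketch only: the general shape of a unified import resolver such as the one
# the runner now calls. Module paths and helper names here are assumptions.
import importlib
import sys
from pathlib import Path
from types import SimpleNamespace


def setup_sys_path() -> None:
    """Make the package root importable when the CLI runs as a plain script."""
    pkg_root = Path(__file__).resolve().parent  # hypothetical layout
    if str(pkg_root) not in sys.path:
        sys.path.insert(0, str(pkg_root))


def _first_import(*module_names: str):
    """Return the first module that imports cleanly, else None."""
    for name in module_names:
        try:
            return importlib.import_module(name)
        except Exception:
            continue
    return None


def resolve_imports() -> SimpleNamespace:
    """Resolve every optional dependency once; callers feature-check for None."""
    evaluation = _first_import("ins_pricing.modelling.core.evaluation", "evaluation")
    return SimpleNamespace(
        bayesopt=_first_import("ins_pricing.modelling.core.bayesopt.core", "BayesOpt"),
        metrics_report=getattr(evaluation, "metrics_report", None),
        calibrate_predictions=getattr(evaluation, "calibrate_predictions", None),
        # ...one attribute per symbol the runner re-exports...
    )
```

Because every symbol is resolved once into a single namespace, the runner can keep its existing `x is None` feature checks without any import-time branching.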
@@ -138,81 +90,6 @@ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPL
     matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 
-try:
-    from .utils.run_logging import configure_run_logging  # type: ignore
-except Exception:  # pragma: no cover
-    try:
-        from utils.run_logging import configure_run_logging  # type: ignore
-    except Exception:  # pragma: no cover
-        configure_run_logging = None  # type: ignore
-
-try:
-    from ..modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
-except Exception:  # pragma: no cover
-    try:
-        from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
-    except Exception:  # pragma: no cover
-        plot_loss_curve_common = None
-
-try:
-    from ..modelling.core.evaluation import (  # type: ignore
-        bootstrap_ci,
-        calibrate_predictions,
-        metrics_report as eval_metrics_report,
-        select_threshold,
-    )
-    from ..governance.registry import ModelArtifact, ModelRegistry  # type: ignore
-    from ..production import psi_report as drift_psi_report  # type: ignore
-    from ..production.monitoring import group_metrics  # type: ignore
-    from ..reporting.report_builder import ReportPayload, write_report  # type: ignore
-except Exception:  # pragma: no cover
-    try:
-        from ins_pricing.modelling.core.evaluation import (  # type: ignore
-            bootstrap_ci,
-            calibrate_predictions,
-            metrics_report as eval_metrics_report,
-            select_threshold,
-        )
-        from ins_pricing.governance.registry import (  # type: ignore
-            ModelArtifact,
-            ModelRegistry,
-        )
-        from ins_pricing.production import psi_report as drift_psi_report  # type: ignore
-        from ins_pricing.production.monitoring import group_metrics  # type: ignore
-        from ins_pricing.reporting.report_builder import (  # type: ignore
-            ReportPayload,
-            write_report,
-        )
-    except Exception:  # pragma: no cover
-        try:
-            from evaluation import (  # type: ignore
-                bootstrap_ci,
-                calibrate_predictions,
-                metrics_report as eval_metrics_report,
-                select_threshold,
-            )
-            from ins_pricing.governance.registry import (  # type: ignore
-                ModelArtifact,
-                ModelRegistry,
-            )
-            from ins_pricing.production import psi_report as drift_psi_report  # type: ignore
-            from ins_pricing.production.monitoring import group_metrics  # type: ignore
-            from ins_pricing.reporting.report_builder import (  # type: ignore
-                ReportPayload,
-                write_report,
-            )
-        except Exception:  # pragma: no cover
-            bootstrap_ci = None  # type: ignore
-            calibrate_predictions = None  # type: ignore
-            eval_metrics_report = None  # type: ignore
-            select_threshold = None  # type: ignore
-            drift_psi_report = None  # type: ignore
-            group_metrics = None  # type: ignore
-            ReportPayload = None  # type: ignore
-            write_report = None  # type: ignore
-            ModelRegistry = None  # type: ignore
-            ModelArtifact = None  # type: ignore
-
 
 def _parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
@@ -520,6 +397,444 @@ def _compute_psi_report(
     return None
 
 
+# --- Refactored helper functions for _evaluate_and_report ---
+
+
+def _apply_calibration(
+    y_true_train: np.ndarray,
+    y_pred_train: np.ndarray,
+    y_pred_test: np.ndarray,
+    calibration_cfg: Dict[str, Any],
+    model_name: str,
+    model_key: str,
+) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
+    """Apply calibration to predictions for classification tasks.
+
+    Returns:
+        Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
+    """
+    cal_cfg = dict(calibration_cfg or {})
+    cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))
+
+    if not cal_enabled or calibrate_predictions is None:
+        return y_pred_train, y_pred_test, None
+
+    method = cal_cfg.get("method", "sigmoid")
+    max_rows = cal_cfg.get("max_rows")
+    seed = cal_cfg.get("seed")
+    y_cal, p_cal = _sample_arrays(
+        y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
+
+    try:
+        calibrator = calibrate_predictions(y_cal, p_cal, method=method)
+        calibrated_train = calibrator.predict(y_pred_train)
+        calibrated_test = calibrator.predict(y_pred_test)
+        calibration_info = {"method": calibrator.method, "max_rows": max_rows}
+        return calibrated_train, calibrated_test, calibration_info
+    except Exception as exc:
+        print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
+        return y_pred_train, y_pred_test, None
+
+
+def _select_classification_threshold(
+    y_true_train: np.ndarray,
+    y_pred_train_eval: np.ndarray,
+    threshold_cfg: Dict[str, Any],
+) -> tuple[float, Optional[Dict[str, Any]]]:
+    """Select threshold for classification predictions.
+
+    Returns:
+        Tuple of (threshold_value, threshold_info)
+    """
+    thr_cfg = dict(threshold_cfg or {})
+    thr_enabled = bool(
+        thr_cfg.get("enable", False)
+        or thr_cfg.get("metric")
+        or thr_cfg.get("value") is not None
+    )
+
+    if thr_cfg.get("value") is not None:
+        threshold_value = float(thr_cfg["value"])
+        return threshold_value, {"threshold": threshold_value, "source": "fixed"}
+
+    if thr_enabled and select_threshold is not None:
+        max_rows = thr_cfg.get("max_rows")
+        seed = thr_cfg.get("seed")
+        y_thr, p_thr = _sample_arrays(
+            y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
+        threshold_info = select_threshold(
+            y_thr,
+            p_thr,
+            metric=thr_cfg.get("metric", "f1"),
+            min_positive_rate=thr_cfg.get("min_positive_rate"),
+            grid=thr_cfg.get("grid", 99),
+        )
+        return float(threshold_info.get("threshold", 0.5)), threshold_info
+
+    return 0.5, None
+
+
+def _compute_classification_metrics(
+    y_true_test: np.ndarray,
+    y_pred_test_eval: np.ndarray,
+    threshold_value: float,
+) -> Dict[str, Any]:
+    """Compute metrics for classification task."""
+    metrics = eval_metrics_report(
+        y_true_test,
+        y_pred_test_eval,
+        task_type="classification",
+        threshold=threshold_value,
+    )
+    precision = float(metrics.get("precision", 0.0))
+    recall = float(metrics.get("recall", 0.0))
+    f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
+    metrics["f1"] = float(f1)
+    metrics["threshold"] = float(threshold_value)
+    return metrics
+
+
+def _compute_bootstrap_ci(
+    y_true_test: np.ndarray,
+    y_pred_test_eval: np.ndarray,
+    weight_test: Optional[np.ndarray],
+    metrics: Dict[str, Any],
+    bootstrap_cfg: Dict[str, Any],
+    task_type: str,
+) -> Dict[str, Dict[str, float]]:
+    """Compute bootstrap confidence intervals for metrics."""
+    if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
+        return {}
+
+    metric_names = bootstrap_cfg.get("metrics")
+    if not metric_names:
+        metric_names = [name for name in metrics.keys() if name != "threshold"]
+    n_samples = int(bootstrap_cfg.get("n_samples", 200))
+    ci = float(bootstrap_cfg.get("ci", 0.95))
+    seed = bootstrap_cfg.get("seed")
+
+    def _metric_fn(y_true, y_pred, weight=None):
+        vals = eval_metrics_report(
+            y_true,
+            y_pred,
+            task_type=task_type,
+            weight=weight,
+            threshold=metrics.get("threshold", 0.5),
+        )
+        if task_type == "classification":
+            prec = float(vals.get("precision", 0.0))
+            rec = float(vals.get("recall", 0.0))
+            vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+        return vals
+
+    bootstrap_results: Dict[str, Dict[str, float]] = {}
+    for name in metric_names:
+        if name not in metrics:
+            continue
+        ci_result = bootstrap_ci(
+            lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
+            y_true_test,
+            y_pred_test_eval,
+            weight=weight_test,
+            n_samples=n_samples,
+            ci=ci,
+            seed=seed,
+        )
+        bootstrap_results[str(name)] = ci_result
+
+    return bootstrap_results
+
+
+def _compute_validation_table(
+    model: ropt.BayesOptModel,
+    pred_col: str,
+    report_group_cols: Optional[List[str]],
+    weight_col: Optional[str],
+    model_name: str,
+    model_key: str,
+) -> Optional[pd.DataFrame]:
+    """Compute grouped validation metrics table."""
+    if not report_group_cols or group_metrics is None:
+        return None
+
+    available_groups = [
+        col for col in report_group_cols if col in model.test_data.columns
+    ]
+    if not available_groups:
+        return None
+
+    try:
+        validation_table = group_metrics(
+            model.test_data,
+            actual_col=model.resp_nme,
+            pred_col=pred_col,
+            group_cols=available_groups,
+            weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
+        )
+        counts = (
+            model.test_data.groupby(available_groups, dropna=False)
+            .size()
+            .reset_index(name="count")
+        )
+        return validation_table.merge(counts, on=available_groups, how="left")
+    except Exception as exc:
+        print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
+        return None
+
+
+def _compute_risk_trend(
+    model: ropt.BayesOptModel,
+    pred_col: str,
+    report_time_col: Optional[str],
+    report_time_freq: str,
+    report_time_ascending: bool,
+    weight_col: Optional[str],
+    model_name: str,
+    model_key: str,
+) -> Optional[pd.DataFrame]:
+    """Compute time-series risk trend metrics."""
+    if not report_time_col or group_metrics is None:
+        return None
+
+    if report_time_col not in model.test_data.columns:
+        return None
+
+    try:
+        time_df = model.test_data.copy()
+        time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
+        time_df = time_df.loc[time_series.notna()].copy()
+
+        if time_df.empty:
+            return None
+
+        time_df["_time_bucket"] = (
+            pd.to_datetime(time_df[report_time_col], errors="coerce")
+            .dt.to_period(report_time_freq)
+            .dt.to_timestamp()
+        )
+        risk_trend = group_metrics(
+            time_df,
+            actual_col=model.resp_nme,
+            pred_col=pred_col,
+            group_cols=["_time_bucket"],
+            weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
+        )
+        counts = (
+            time_df.groupby("_time_bucket", dropna=False)
+            .size()
+            .reset_index(name="count")
+        )
+        risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
+        risk_trend = risk_trend.sort_values(
+            "_time_bucket", ascending=bool(report_time_ascending)
+        ).reset_index(drop=True)
+        return risk_trend.rename(columns={"_time_bucket": report_time_col})
+    except Exception as exc:
+        print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
+        return None
+
+
+def _write_metrics_json(
+    report_root: Path,
+    model_name: str,
+    model_key: str,
+    version: str,
+    metrics: Dict[str, Any],
+    threshold_info: Optional[Dict[str, Any]],
+    calibration_info: Optional[Dict[str, Any]],
+    bootstrap_results: Dict[str, Dict[str, float]],
+    data_path: Path,
+    data_fingerprint: Dict[str, Any],
+    config_sha: str,
+    pred_col: str,
+    task_type: str,
+) -> Path:
+    """Write metrics to JSON file and return the path."""
+    metrics_payload = {
+        "model_name": model_name,
+        "model_key": model_key,
+        "model_version": version,
+        "metrics": metrics,
+        "threshold": threshold_info,
+        "calibration": calibration_info,
+        "bootstrap": bootstrap_results,
+        "data_path": str(data_path),
+        "data_fingerprint": data_fingerprint,
+        "config_sha256": config_sha,
+        "pred_col": pred_col,
+        "task_type": task_type,
+    }
+    metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
+    metrics_path.write_text(
+        json.dumps(metrics_payload, indent=2, ensure_ascii=True),
+        encoding="utf-8",
+    )
+    return metrics_path
+
+
+def _write_model_report(
+    report_root: Path,
+    model_name: str,
+    model_key: str,
+    version: str,
+    metrics: Dict[str, Any],
+    risk_trend: Optional[pd.DataFrame],
+    psi_report_df: Optional[pd.DataFrame],
+    validation_table: Optional[pd.DataFrame],
+    calibration_info: Optional[Dict[str, Any]],
+    threshold_info: Optional[Dict[str, Any]],
+    bootstrap_results: Dict[str, Dict[str, float]],
+    config_sha: str,
+    data_fingerprint: Dict[str, Any],
+) -> Optional[Path]:
+    """Write model report and return the path."""
+    if ReportPayload is None or write_report is None:
+        return None
+
+    notes_lines = [
+        f"- Config SHA256: {config_sha}",
+        f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
+    ]
+    if calibration_info:
+        notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
+    if threshold_info:
+        notes_lines.append(f"- Threshold selection: {threshold_info}")
+    if bootstrap_results:
+        notes_lines.append("- Bootstrap: see metrics JSON for CI")
+
+    payload = ReportPayload(
+        model_name=f"{model_name}/{model_key}",
+        model_version=version,
+        metrics={k: float(v) for k, v in metrics.items()},
+        risk_trend=risk_trend,
+        drift_report=psi_report_df,
+        validation_table=validation_table,
+        extra_notes="\n".join(notes_lines),
+    )
+    return write_report(
+        payload,
+        report_root / f"{model_name}_{model_key}_report.md",
+    )
+
+
+def _register_model_to_registry(
+    model: ropt.BayesOptModel,
+    model_name: str,
+    model_key: str,
+    version: str,
+    metrics: Dict[str, Any],
+    task_type: str,
+    data_path: Path,
+    data_fingerprint: Dict[str, Any],
+    config_sha: str,
+    registry_path: Optional[str],
+    registry_tags: Dict[str, Any],
+    registry_status: str,
+    report_path: Optional[Path],
+    metrics_path: Path,
+    cfg: Dict[str, Any],
+) -> None:
+    """Register model artifacts to the model registry."""
+    if ModelRegistry is None or ModelArtifact is None:
+        return
+
+    registry = ModelRegistry(
+        registry_path
+        if registry_path
+        else Path(model.output_manager.result_dir) / "model_registry.json"
+    )
+
+    tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
+    tags.update({
+        "model_key": str(model_key),
+        "task_type": str(task_type),
+        "data_path": str(data_path),
+        "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
+        "data_size": str(data_fingerprint.get("size", "")),
+        "data_mtime": str(data_fingerprint.get("mtime", "")),
+        "config_sha256": str(config_sha),
+    })
+
+    artifacts = _collect_model_artifacts(
+        model, model_name, model_key, report_path, metrics_path, cfg
+    )
+
+    registry.register(
+        name=str(model_name),
+        version=version,
+        metrics={k: float(v) for k, v in metrics.items()},
+        tags=tags,
+        artifacts=artifacts,
+        status=str(registry_status or "candidate"),
+        notes=f"model_key={model_key}",
+    )
+
+
+def _collect_model_artifacts(
+    model: ropt.BayesOptModel,
+    model_name: str,
+    model_key: str,
+    report_path: Optional[Path],
+    metrics_path: Path,
+    cfg: Dict[str, Any],
+) -> List:
+    """Collect all model artifacts for registry."""
+    artifacts = []
+
+    # Trained model artifact
+    trainer = model.trainers.get(model_key)
+    if trainer is not None:
+        try:
+            model_path = trainer.output.model_path(trainer._get_model_filename())
+            if os.path.exists(model_path):
+                artifacts.append(ModelArtifact(path=model_path, description="trained model"))
+        except Exception:
+            pass
+
+    # Report artifact
+    if report_path is not None:
+        artifacts.append(ModelArtifact(path=str(report_path), description="model report"))
+
+    # Metrics JSON artifact
+    if metrics_path.exists():
+        artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))
+
+    # Preprocess artifacts
+    if bool(cfg.get("save_preprocess", False)):
+        artifact_path = cfg.get("preprocess_artifact_path")
+        if artifact_path:
+            preprocess_path = Path(str(artifact_path))
+            if not preprocess_path.is_absolute():
+                preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
+        else:
+            preprocess_path = Path(model.output_manager.result_path(
+                f"{model.model_nme}_preprocess.json"
+            ))
+        if preprocess_path.exists():
+            artifacts.append(
+                ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
+            )
+
+    # Prediction cache artifacts
+    if bool(cfg.get("cache_predictions", False)):
+        cache_dir = cfg.get("prediction_cache_dir")
+        if cache_dir:
+            pred_root = Path(str(cache_dir))
+            if not pred_root.is_absolute():
+                pred_root = Path(model.output_manager.result_dir) / pred_root
+        else:
+            pred_root = Path(model.output_manager.result_dir) / "predictions"
+        ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
+        for split_label in ("train", "test"):
+            pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
+            if pred_path.exists():
+                artifacts.append(
+                    ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
+                )
+
+    return artifacts
+
+
 def _evaluate_and_report(
     model: ropt.BayesOptModel,
     *,
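The helpers above preserve the metric arithmetic of the old inline code: `_compute_classification_metrics` derives F1 as the harmonic mean of the precision and recall returned by `metrics_report` (guarding the zero denominator), and `_select_classification_threshold` falls back to 0.5 when neither a fixed value nor a selection metric is configured. A small self-contained check of that arithmetic, independent of the package's own `metrics_report`:

```python
import numpy as np


def f1_from_precision_recall(precision: float, recall: float) -> float:
    # Same guard as _compute_classification_metrics: F1 is 0.0 when both inputs are 0.
    return 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)


# Toy probabilities thresholded at the 0.5 default used when selection is disabled.
y_true = np.array([1, 0, 1, 1, 0, 0])
y_prob = np.array([0.9, 0.4, 0.3, 0.8, 0.6, 0.1])
y_hat = (y_prob >= 0.5).astype(int)

tp = int(((y_hat == 1) & (y_true == 1)).sum())
fp = int(((y_hat == 1) & (y_true == 0)).sum())
fn = int(((y_hat == 0) & (y_true == 1)).sum())
precision = tp / (tp + fp) if (tp + fp) else 0.0   # 2/3
recall = tp / (tp + fn) if (tp + fn) else 0.0      # 2/3
print(round(f1_from_precision_recall(precision, recall), 4))  # 0.6667
```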
@@ -544,374 +859,164 @@ def _evaluate_and_report(
     run_id: str,
     config_sha: str,
 ) -> None:
+    """Evaluate model predictions and generate reports.
+
+    This function orchestrates the evaluation pipeline:
+    1. Extract predictions and ground truth
+    2. Apply calibration (for classification)
+    3. Select threshold (for classification)
+    4. Compute metrics
+    5. Compute bootstrap confidence intervals
+    6. Generate validation tables and risk trends
+    7. Write reports and register model
+    """
     if eval_metrics_report is None:
         print("[Report] Skip evaluation: metrics module unavailable.")
         return
 
     pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
     if pred_col not in model.test_data.columns:
-        print(
-            f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
+        print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
         return
 
+    # Extract predictions and weights
     weight_col = getattr(model, "weight_nme", None)
-    y_true_train = model.train_data[model.resp_nme].to_numpy(
-    [... old line 559 not shown in this diff view ...]
-    y_true_test = model.test_data[model.resp_nme].to_numpy(
-        dtype=float, copy=False)
+    y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
+    y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
     y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
     y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
-    weight_train = (
-        model.train_data[weight_col].to_numpy(dtype=float, copy=False)
-        if weight_col and weight_col in model.train_data.columns
-        else None
-    )
     weight_test = (
         model.test_data[weight_col].to_numpy(dtype=float, copy=False)
         if weight_col and weight_col in model.test_data.columns
         else None
     )
 
-    task_type = str(cfg.get("task_type", getattr(
-    [... old line 576 not shown in this diff view ...]
+    task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))
+
+    # Process based on task type
     if task_type == "classification":
         y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
         y_pred_test = np.clip(y_pred_test, 0.0, 1.0)
 
-    [... old lines 581-582 not shown in this diff view ...]
-    y_pred_train_eval = y_pred_train
-    y_pred_test_eval = y_pred_test
-
-    if task_type == "classification":
-        cal_cfg = dict(calibration_cfg or {})
-        cal_enabled = bool(cal_cfg.get("enable", False)
-                           or cal_cfg.get("method"))
-        if cal_enabled and calibrate_predictions is not None:
-            method = cal_cfg.get("method", "sigmoid")
-            max_rows = cal_cfg.get("max_rows")
-            seed = cal_cfg.get("seed")
-            y_cal, p_cal = _sample_arrays(
-                y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
-            try:
-                calibrator = calibrate_predictions(y_cal, p_cal, method=method)
-                y_pred_train_eval = calibrator.predict(y_pred_train)
-                y_pred_test_eval = calibrator.predict(y_pred_test)
-                calibration_info = {
-                    "method": calibrator.method, "max_rows": max_rows}
-            except Exception as exc:
-                print(
-                    f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
-
-        thr_cfg = dict(threshold_cfg or {})
-        thr_enabled = bool(
-            thr_cfg.get("enable", False)
-            or thr_cfg.get("metric")
-            or thr_cfg.get("value") is not None
+        y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
+            y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
         )
-        threshold_value =
-        [... old line 613 not shown in this diff view ...]
-            threshold_value = float(thr_cfg["value"])
-            threshold_info = {"threshold": threshold_value, "source": "fixed"}
-        elif thr_enabled and select_threshold is not None:
-            max_rows = thr_cfg.get("max_rows")
-            seed = thr_cfg.get("seed")
-            y_thr, p_thr = _sample_arrays(
-                y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
-            threshold_info = select_threshold(
-                y_thr,
-                p_thr,
-                metric=thr_cfg.get("metric", "f1"),
-                min_positive_rate=thr_cfg.get("min_positive_rate"),
-                grid=thr_cfg.get("grid", 99),
-            )
-            threshold_value = float(threshold_info.get("threshold", 0.5))
-        else:
-            threshold_value = 0.5
-        metrics = eval_metrics_report(
-            y_true_test,
-            y_pred_test_eval,
-            task_type=task_type,
-            threshold=threshold_value,
+        threshold_value, threshold_info = _select_classification_threshold(
+            y_true_train, y_pred_train_eval, threshold_cfg
         )
-        [... old line 637 not shown in this diff view ...]
-        recall = float(metrics.get("recall", 0.0))
-        f1 = 0.0 if (precision + recall) == 0 else 2 * \
-            precision * recall / (precision + recall)
-        metrics["f1"] = float(f1)
-        metrics["threshold"] = float(threshold_value)
+        metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
     else:
+        y_pred_test_eval = y_pred_test
+        calibration_info = None
+        threshold_info = None
         metrics = eval_metrics_report(
-            y_true_test,
-            y_pred_test_eval,
-            task_type=task_type,
-            weight=weight_test,
+            y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
         )
 
-    [... old lines 651-654 not shown in this diff view ...]
-        ci = float(bootstrap_cfg.get("ci", 0.95))
-        seed = bootstrap_cfg.get("seed")
-
-        def _metric_fn(y_true, y_pred, weight=None):
-            vals = eval_metrics_report(
-                y_true,
-                y_pred,
-                task_type=task_type,
-                weight=weight,
-                threshold=metrics.get("threshold", 0.5),
-            )
-            if task_type == "classification":
-                prec = float(vals.get("precision", 0.0))
-                rec = float(vals.get("recall", 0.0))
-                vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * \
-                    prec * rec / (prec + rec)
-            return vals
-
-        for name in metric_names:
-            if name not in metrics:
-                continue
-            ci_result = bootstrap_ci(
-                lambda y_t, y_p, w=None: float(
-                    _metric_fn(y_t, y_p, w).get(name, 0.0)),
-                y_true_test,
-                y_pred_test_eval,
-                weight=weight_test,
-                n_samples=n_samples,
-                ci=ci,
-                seed=seed,
-            )
-            bootstrap_results[str(name)] = ci_result
+    # Compute bootstrap confidence intervals
+    bootstrap_results = _compute_bootstrap_ci(
+        y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
+    )
 
-    [... old lines 688-695 not shown in this diff view ...]
-                model.test_data,
-                actual_col=model.resp_nme,
-                pred_col=pred_col,
-                group_cols=available_groups,
-                weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
-            )
-            counts = (
-                model.test_data.groupby(available_groups, dropna=False)
-                .size()
-                .reset_index(name="count")
-            )
-            validation_table = validation_table.merge(
-                counts, on=available_groups, how="left")
-        except Exception as exc:
-            print(
-                f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
-
-    risk_trend = None
-    if report_time_col and group_metrics is not None:
-        if report_time_col in model.test_data.columns:
-            try:
-                time_df = model.test_data.copy()
-                time_series = pd.to_datetime(
-                    time_df[report_time_col], errors="coerce")
-                time_df = time_df.loc[time_series.notna()].copy()
-                if not time_df.empty:
-                    time_df["_time_bucket"] = (
-                        pd.to_datetime(
-                            time_df[report_time_col], errors="coerce")
-                        .dt.to_period(report_time_freq)
-                        .dt.to_timestamp()
-                    )
-                    risk_trend = group_metrics(
-                        time_df,
-                        actual_col=model.resp_nme,
-                        pred_col=pred_col,
-                        group_cols=["_time_bucket"],
-                        weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
-                    )
-                    counts = (
-                        time_df.groupby("_time_bucket", dropna=False)
-                        .size()
-                        .reset_index(name="count")
-                    )
-                    risk_trend = risk_trend.merge(
-                        counts, on="_time_bucket", how="left")
-                    risk_trend = risk_trend.sort_values(
-                        "_time_bucket", ascending=bool(report_time_ascending)
-                    ).reset_index(drop=True)
-                    risk_trend = risk_trend.rename(
-                        columns={"_time_bucket": report_time_col})
-            except Exception as exc:
-                print(
-                    f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
+    # Compute validation table and risk trend
+    validation_table = _compute_validation_table(
+        model, pred_col, report_group_cols, weight_col, model_name, model_key
+    )
+    risk_trend = _compute_risk_trend(
+        model, pred_col, report_time_col, report_time_freq,
+        report_time_ascending, weight_col, model_name, model_key
+    )
 
+    # Setup output directory
     report_root = (
         Path(report_output_dir)
         if report_output_dir
         else Path(model.output_manager.result_dir) / "reports"
     )
     report_root.mkdir(parents=True, exist_ok=True)
-
     version = f"{model_key}_{run_id}"
-    [... old lines 759-764 not shown in this diff view ...]
-        "calibration": calibration_info,
-        "bootstrap": bootstrap_results,
-        "data_path": str(data_path),
-        "data_fingerprint": data_fingerprint,
-        "config_sha256": config_sha,
-        "pred_col": pred_col,
-        "task_type": task_type,
-    }
-    metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
-    metrics_path.write_text(
-        json.dumps(metrics_payload, indent=2, ensure_ascii=True),
-        encoding="utf-8",
+
+    # Write metrics JSON
+    metrics_path = _write_metrics_json(
+        report_root, model_name, model_key, version, metrics,
+        threshold_info, calibration_info, bootstrap_results,
+        data_path, data_fingerprint, config_sha, pred_col, task_type
     )
 
-    [... old lines 779-785 not shown in this diff view ...]
-        notes_lines.append(
-            f"- Calibration: {calibration_info.get('method')}"
-        )
-    if threshold_info:
-        notes_lines.append(
-            f"- Threshold selection: {threshold_info}"
-        )
-    if bootstrap_results:
-        notes_lines.append("- Bootstrap: see metrics JSON for CI")
-    extra_notes = "\n".join(notes_lines)
-    payload = ReportPayload(
-        model_name=f"{model_name}/{model_key}",
-        model_version=version,
-        metrics={k: float(v) for k, v in metrics.items()},
-        risk_trend=risk_trend,
-        drift_report=psi_report_df,
-        validation_table=validation_table,
-        extra_notes=extra_notes,
-    )
-    report_path = write_report(
-        payload,
-        report_root / f"{model_name}_{model_key}_report.md",
-    )
+    # Write model report
+    report_path = _write_model_report(
+        report_root, model_name, model_key, version, metrics,
+        risk_trend, psi_report_df, validation_table,
+        calibration_info, threshold_info, bootstrap_results,
+        config_sha, data_fingerprint
+    )
 
-    [... old lines 810-815 not shown in this diff view ...]
-        tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
-        tags.update({
-            "model_key": str(model_key),
-            "task_type": str(task_type),
-            "data_path": str(data_path),
-            "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
-            "data_size": str(data_fingerprint.get("size", "")),
-            "data_mtime": str(data_fingerprint.get("mtime", "")),
-            "config_sha256": str(config_sha),
-        })
-        artifacts = []
-        trainer = model.trainers.get(model_key)
-        if trainer is not None:
-            try:
-                model_path = trainer.output.model_path(
-                    trainer._get_model_filename())
-                if os.path.exists(model_path):
-                    artifacts.append(ModelArtifact(
-                        path=model_path, description="trained model"))
-            except Exception:
-                pass
-        if report_path is not None:
-            artifacts.append(ModelArtifact(
-                path=str(report_path), description="model report"))
-        if metrics_path.exists():
-            artifacts.append(ModelArtifact(
-                path=str(metrics_path), description="metrics json"))
-        if bool(cfg.get("save_preprocess", False)):
-            artifact_path = cfg.get("preprocess_artifact_path")
-            if artifact_path:
-                preprocess_path = Path(str(artifact_path))
-                if not preprocess_path.is_absolute():
-                    preprocess_path = Path(
-                        model.output_manager.result_dir) / preprocess_path
-            else:
-                preprocess_path = Path(model.output_manager.result_path(
-                    f"{model.model_nme}_preprocess.json"
-                ))
-            if preprocess_path.exists():
-                artifacts.append(
-                    ModelArtifact(path=str(preprocess_path),
-                                  description="preprocess artifacts")
-                )
-        if bool(cfg.get("cache_predictions", False)):
-            cache_dir = cfg.get("prediction_cache_dir")
-            if cache_dir:
-                pred_root = Path(str(cache_dir))
-                if not pred_root.is_absolute():
-                    pred_root = Path(
-                        model.output_manager.result_dir) / pred_root
-            else:
-                pred_root = Path(
-                    model.output_manager.result_dir) / "predictions"
-            ext = "csv" if str(
-                cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
-            for split_label in ("train", "test"):
-                pred_path = pred_root / \
-                    f"{model_name}_{model_key}_{split_label}.{ext}"
-                if pred_path.exists():
-                    artifacts.append(
-                        ModelArtifact(path=str(pred_path),
-                                      description=f"predictions {split_label}")
-                    )
-        registry.register(
-            name=str(model_name),
-            version=version,
-            metrics={k: float(v) for k, v in metrics.items()},
-            tags=tags,
-            artifacts=artifacts,
-            status=str(registry_status or "candidate"),
-            notes=f"model_key={model_key}",
+    # Register model
+    if register_model:
+        _register_model_to_registry(
+            model, model_name, model_key, version, metrics, task_type,
+            data_path, data_fingerprint, config_sha, registry_path,
+            registry_tags, registry_status, report_path, metrics_path, cfg
         )
 
 
-def
-[... old lines 891-896 not shown in this diff view ...]
+def _evaluate_with_context(
+    model: ropt.BayesOptModel,
+    ctx: EvaluationContext,
+) -> None:
+    """Evaluate model predictions using context object.
+
+    This is a cleaner interface that uses the EvaluationContext dataclass
+    instead of 19+ individual parameters.
+    """
+    _evaluate_and_report(
+        model,
+        model_name=ctx.identity.model_name,
+        model_key=ctx.identity.model_key,
+        cfg=ctx.cfg,
+        data_path=ctx.data_path,
+        data_fingerprint=ctx.data_fingerprint.to_dict(),
+        report_output_dir=ctx.report.output_dir,
+        report_group_cols=ctx.report.group_cols,
+        report_time_col=ctx.report.time_col,
+        report_time_freq=ctx.report.time_freq,
+        report_time_ascending=ctx.report.time_ascending,
+        psi_report_df=ctx.psi_report_df,
+        calibration_cfg={
+            "enable": ctx.calibration.enable,
+            "method": ctx.calibration.method,
+            "max_rows": ctx.calibration.max_rows,
+            "seed": ctx.calibration.seed,
+        },
+        threshold_cfg={
+            "enable": ctx.threshold.enable,
+            "metric": ctx.threshold.metric,
+            "value": ctx.threshold.value,
+            "min_positive_rate": ctx.threshold.min_positive_rate,
+            "grid": ctx.threshold.grid,
+            "max_rows": ctx.threshold.max_rows,
+            "seed": ctx.threshold.seed,
+        },
+        bootstrap_cfg={
+            "enable": ctx.bootstrap.enable,
+            "metrics": ctx.bootstrap.metrics,
+            "n_samples": ctx.bootstrap.n_samples,
+            "ci": ctx.bootstrap.ci,
+            "seed": ctx.bootstrap.seed,
+        },
+        register_model=ctx.registry.register,
+        registry_path=ctx.registry.path,
+        registry_tags=ctx.registry.tags,
+        registry_status=ctx.registry.status,
+        run_id=ctx.run_id,
+        config_sha=ctx.config_sha,
     )
-    plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
-    config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
-    run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
 
-    def _safe_int_env(key: str, default: int) -> int:
-        try:
-            return int(os.environ.get(key, default))
-        except (TypeError, ValueError):
-            return default
-
-    dist_world_size = _safe_int_env("WORLD_SIZE", 1)
-    dist_rank = _safe_int_env("RANK", 0)
-    dist_active = dist_world_size > 1
-    is_main_process = (not dist_active) or dist_rank == 0
 
+def _create_ddp_barrier(dist_ctx: TrainingContext):
+    """Create a DDP barrier function for distributed training synchronization."""
     def _ddp_barrier(reason: str) -> None:
-        if not
+        if not dist_ctx.is_distributed:
             return
         torch_mod = getattr(ropt, "torch", None)
         dist_mod = getattr(torch_mod, "distributed", None)
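`_evaluate_with_context` above unpacks an `EvaluationContext` defined in the new `ins_pricing/cli/utils/evaluation_context.py` (+320 lines in this release, not shown in this diff). Judging only from the attributes the wrapper accesses, the dataclasses plausibly look like the sketch below; the field types and defaults are assumptions.

```python
# Sketch only: inferred from the attribute accesses in _evaluate_with_context.
# The actual dataclasses live in ins_pricing/cli/utils/evaluation_context.py;
# defaults and exact field types here are assumptions.
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional


@dataclass
class ModelIdentity:
    model_name: str
    model_key: str


@dataclass
class DataFingerprint:
    sha256_prefix: str = ""
    size: int = 0
    mtime: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class CalibrationConfig:
    enable: bool = False
    method: Optional[str] = None
    max_rows: Optional[int] = None
    seed: Optional[int] = None


@dataclass
class ThresholdConfig:
    enable: bool = False
    metric: Optional[str] = None
    value: Optional[float] = None
    min_positive_rate: Optional[float] = None
    grid: int = 99
    max_rows: Optional[int] = None
    seed: Optional[int] = None


@dataclass
class BootstrapConfig:
    enable: bool = False
    metrics: Optional[List[str]] = None
    n_samples: int = 200
    ci: float = 0.95
    seed: Optional[int] = None


@dataclass
class ReportConfig:
    output_dir: Optional[str] = None
    group_cols: Optional[List[str]] = None
    time_col: Optional[str] = None
    time_freq: str = "M"
    time_ascending: bool = True


@dataclass
class RegistryConfig:
    register: bool = False
    path: Optional[str] = None
    tags: Dict[str, Any] = field(default_factory=dict)
    status: str = "candidate"


@dataclass
class EvaluationContext:
    identity: ModelIdentity
    cfg: Dict[str, Any]
    data_path: Path
    data_fingerprint: DataFingerprint
    report: ReportConfig
    calibration: CalibrationConfig
    threshold: ThresholdConfig
    bootstrap: BootstrapConfig
    registry: RegistryConfig
    run_id: str
    config_sha: str
    psi_report_df: Any = None
```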
@@ -928,6 +1033,28 @@ def train_from_config(args: argparse.Namespace) -> None:
         except Exception as exc:
             print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
             raise
+    return _ddp_barrier
+
+
+def train_from_config(args: argparse.Namespace) -> None:
+    script_dir = Path(__file__).resolve().parents[1]
+    config_path, cfg = resolve_and_load_config(
+        args.config_json,
+        script_dir,
+        required_keys=["data_dir", "model_list",
+                       "model_categories", "target", "weight"],
+    )
+    plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
+    config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
+    run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
+
+    # Use TrainingContext for distributed training state
+    dist_ctx = TrainingContext.from_env()
+    dist_world_size = dist_ctx.world_size
+    dist_rank = dist_ctx.rank
+    dist_active = dist_ctx.is_distributed
+    is_main_process = dist_ctx.is_main_process
+    _ddp_barrier = _create_ddp_barrier(dist_ctx)
 
     data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
         cfg,
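`TrainingContext.from_env()` takes over the WORLD_SIZE/RANK parsing that the earlier hunk removed from `train_from_config`. A minimal sketch consistent with that removed logic (the real class also lives in `evaluation_context.py` and may carry more state):

```python
# Sketch only: reproduces the environment parsing removed from
# train_from_config; the real TrainingContext may expose more fields.
import os
from dataclasses import dataclass


def _safe_int_env(key: str, default: int) -> int:
    try:
        return int(os.environ.get(key, default))
    except (TypeError, ValueError):
        return default


@dataclass
class TrainingContext:
    world_size: int = 1
    rank: int = 0

    @classmethod
    def from_env(cls) -> "TrainingContext":
        return cls(world_size=_safe_int_env("WORLD_SIZE", 1),
                   rank=_safe_int_env("RANK", 0))

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        return (not self.is_distributed) or self.rank == 0
```

With the state wrapped in one object, `_create_ddp_barrier(dist_ctx)` can close over it instead of over loose module-level variables.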