ins-pricing 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/bayesopt_entry_runner.py +626 -499
  4. ins_pricing/cli/utils/evaluation_context.py +320 -0
  5. ins_pricing/cli/utils/import_resolver.py +350 -0
  6. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  8. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  9. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  10. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  11. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  12. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +118 -31
  13. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +294 -139
  14. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  15. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  16. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  17. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  18. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  19. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +587 -0
  20. ins_pricing/modelling/core/bayesopt/utils.py +98 -1495
  21. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  22. ins_pricing/setup.py +1 -1
  23. ins_pricing-0.3.0.dist-info/METADATA +162 -0
  24. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/RECORD +26 -13
  25. ins_pricing-0.2.8.dist-info/METADATA +0 -51
  26. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/top_level.txt +0 -0
@@ -29,108 +29,60 @@ from typing import Any, Dict, List, Optional
29
29
  import numpy as np
30
30
  import pandas as pd
31
31
 
32
- try:
33
- from .. import bayesopt as ropt # type: ignore
34
- from .utils.cli_common import ( # type: ignore
35
- PLOT_MODEL_LABELS,
36
- PYTORCH_TRAINERS,
37
- build_model_names,
38
- dedupe_preserve_order,
39
- load_dataset,
40
- parse_model_pairs,
41
- resolve_data_path,
42
- resolve_path,
43
- fingerprint_file,
44
- coerce_dataset_types,
45
- split_train_test,
46
- )
47
- from .utils.cli_config import ( # type: ignore
48
- add_config_json_arg,
49
- add_output_dir_arg,
50
- resolve_and_load_config,
51
- resolve_data_config,
52
- resolve_report_config,
53
- resolve_split_config,
54
- resolve_runtime_config,
55
- resolve_output_dirs,
56
- )
57
- except Exception: # pragma: no cover
58
- try:
59
- import bayesopt as ropt # type: ignore
60
- from utils.cli_common import ( # type: ignore
61
- PLOT_MODEL_LABELS,
62
- PYTORCH_TRAINERS,
63
- build_model_names,
64
- dedupe_preserve_order,
65
- load_dataset,
66
- parse_model_pairs,
67
- resolve_data_path,
68
- resolve_path,
69
- fingerprint_file,
70
- coerce_dataset_types,
71
- split_train_test,
72
- )
73
- from utils.cli_config import ( # type: ignore
74
- add_config_json_arg,
75
- add_output_dir_arg,
76
- resolve_and_load_config,
77
- resolve_data_config,
78
- resolve_report_config,
79
- resolve_split_config,
80
- resolve_runtime_config,
81
- resolve_output_dirs,
82
- )
83
- except Exception:
84
- try:
85
- import ins_pricing.modelling.core.bayesopt as ropt # type: ignore
86
- from ins_pricing.cli.utils.cli_common import ( # type: ignore
87
- PLOT_MODEL_LABELS,
88
- PYTORCH_TRAINERS,
89
- build_model_names,
90
- dedupe_preserve_order,
91
- load_dataset,
92
- parse_model_pairs,
93
- resolve_data_path,
94
- resolve_path,
95
- fingerprint_file,
96
- coerce_dataset_types,
97
- split_train_test,
98
- )
99
- from ins_pricing.cli.utils.cli_config import ( # type: ignore
100
- add_config_json_arg,
101
- add_output_dir_arg,
102
- resolve_and_load_config,
103
- resolve_data_config,
104
- resolve_report_config,
105
- resolve_split_config,
106
- resolve_runtime_config,
107
- resolve_output_dirs,
108
- )
109
- except Exception:
110
- import BayesOpt as ropt # type: ignore
111
- from utils.cli_common import ( # type: ignore
112
- PLOT_MODEL_LABELS,
113
- PYTORCH_TRAINERS,
114
- build_model_names,
115
- dedupe_preserve_order,
116
- load_dataset,
117
- parse_model_pairs,
118
- resolve_data_path,
119
- resolve_path,
120
- fingerprint_file,
121
- coerce_dataset_types,
122
- split_train_test,
123
- )
124
- from utils.cli_config import ( # type: ignore
125
- add_config_json_arg,
126
- add_output_dir_arg,
127
- resolve_and_load_config,
128
- resolve_data_config,
129
- resolve_report_config,
130
- resolve_split_config,
131
- resolve_runtime_config,
132
- resolve_output_dirs,
133
- )
32
+ # Use unified import resolver to eliminate nested try/except chains
33
+ from .utils.import_resolver import resolve_imports, setup_sys_path
34
+ from .utils.evaluation_context import (
35
+ EvaluationContext,
36
+ TrainingContext,
37
+ ModelIdentity,
38
+ DataFingerprint,
39
+ CalibrationConfig,
40
+ ThresholdConfig,
41
+ BootstrapConfig,
42
+ ReportConfig,
43
+ RegistryConfig,
44
+ )
45
+
46
+ # Resolve all imports from a single location
47
+ setup_sys_path()
48
+ _imports = resolve_imports()
49
+
50
+ ropt = _imports.bayesopt
51
+ PLOT_MODEL_LABELS = _imports.PLOT_MODEL_LABELS
52
+ PYTORCH_TRAINERS = _imports.PYTORCH_TRAINERS
53
+ build_model_names = _imports.build_model_names
54
+ dedupe_preserve_order = _imports.dedupe_preserve_order
55
+ load_dataset = _imports.load_dataset
56
+ parse_model_pairs = _imports.parse_model_pairs
57
+ resolve_data_path = _imports.resolve_data_path
58
+ resolve_path = _imports.resolve_path
59
+ fingerprint_file = _imports.fingerprint_file
60
+ coerce_dataset_types = _imports.coerce_dataset_types
61
+ split_train_test = _imports.split_train_test
62
+
63
+ add_config_json_arg = _imports.add_config_json_arg
64
+ add_output_dir_arg = _imports.add_output_dir_arg
65
+ resolve_and_load_config = _imports.resolve_and_load_config
66
+ resolve_data_config = _imports.resolve_data_config
67
+ resolve_report_config = _imports.resolve_report_config
68
+ resolve_split_config = _imports.resolve_split_config
69
+ resolve_runtime_config = _imports.resolve_runtime_config
70
+ resolve_output_dirs = _imports.resolve_output_dirs
71
+
72
+ bootstrap_ci = _imports.bootstrap_ci
73
+ calibrate_predictions = _imports.calibrate_predictions
74
+ eval_metrics_report = _imports.metrics_report
75
+ select_threshold = _imports.select_threshold
76
+
77
+ ModelArtifact = _imports.ModelArtifact
78
+ ModelRegistry = _imports.ModelRegistry
79
+ drift_psi_report = _imports.drift_psi_report
80
+ group_metrics = _imports.group_metrics
81
+ ReportPayload = _imports.ReportPayload
82
+ write_report = _imports.write_report
83
+
84
+ configure_run_logging = _imports.configure_run_logging
85
+ plot_loss_curve_common = _imports.plot_loss_curve
134
86
 
135
87
  import matplotlib
136
88
 
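The hunk above replaces the three-level try/except import fallback with a single call into the new cli/utils/import_resolver.py, whose contents are not part of this diff. A minimal sketch of that pattern, under the assumption that the resolver walks a list of candidate module paths and maps optional symbols to None; the names and structure below are illustrative, not the package's actual API:

```python
# Illustrative sketch only: cli/utils/import_resolver.py is not shown in this
# diff, so everything below is an assumption about the pattern, not the
# package's real implementation.
from importlib import import_module
from types import SimpleNamespace
from typing import Any, Optional, Sequence


def _first_import(candidates: Sequence[str]) -> Optional[Any]:
    """Return the first candidate module that imports cleanly, else None."""
    for name in candidates:
        try:
            return import_module(name)
        except Exception:
            continue
    return None


def resolve_imports() -> SimpleNamespace:
    """Resolve every symbol the CLI needs onto one namespace object.

    Optional pieces resolve to None instead of raising, which is what lets
    the caller write flat assignments such as ``ropt = _imports.bayesopt``
    with no try/except at the call site.
    """
    ns = SimpleNamespace()
    ns.bayesopt = _first_import([
        "ins_pricing.modelling.core.bayesopt",
        "bayesopt",
    ])
    evaluation = _first_import(["ins_pricing.modelling.core.evaluation"])
    # getattr on None simply falls through to the default
    ns.bootstrap_ci = getattr(evaluation, "bootstrap_ci", None)
    ns.calibrate_predictions = getattr(evaluation, "calibrate_predictions", None)
    ns.metrics_report = getattr(evaluation, "metrics_report", None)
    ns.select_threshold = getattr(evaluation, "select_threshold", None)
    return ns
```

With a resolver of this shape, the flat assignments in the hunk above (ropt, PLOT_MODEL_LABELS, eval_metrics_report, ...) become the only place the CLI touches import mechanics.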
@@ -138,81 +90,6 @@ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPL
138
90
  matplotlib.use("Agg")
139
91
  import matplotlib.pyplot as plt
140
92
 
141
- try:
142
- from .utils.run_logging import configure_run_logging # type: ignore
143
- except Exception: # pragma: no cover
144
- try:
145
- from utils.run_logging import configure_run_logging # type: ignore
146
- except Exception: # pragma: no cover
147
- configure_run_logging = None # type: ignore
148
-
149
- try:
150
- from ..modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
151
- except Exception: # pragma: no cover
152
- try:
153
- from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
154
- except Exception: # pragma: no cover
155
- plot_loss_curve_common = None
156
-
157
- try:
158
- from ..modelling.core.evaluation import ( # type: ignore
159
- bootstrap_ci,
160
- calibrate_predictions,
161
- metrics_report as eval_metrics_report,
162
- select_threshold,
163
- )
164
- from ..governance.registry import ModelArtifact, ModelRegistry # type: ignore
165
- from ..production import psi_report as drift_psi_report # type: ignore
166
- from ..production.monitoring import group_metrics # type: ignore
167
- from ..reporting.report_builder import ReportPayload, write_report # type: ignore
168
- except Exception: # pragma: no cover
169
- try:
170
- from ins_pricing.modelling.core.evaluation import ( # type: ignore
171
- bootstrap_ci,
172
- calibrate_predictions,
173
- metrics_report as eval_metrics_report,
174
- select_threshold,
175
- )
176
- from ins_pricing.governance.registry import ( # type: ignore
177
- ModelArtifact,
178
- ModelRegistry,
179
- )
180
- from ins_pricing.production import psi_report as drift_psi_report # type: ignore
181
- from ins_pricing.production.monitoring import group_metrics # type: ignore
182
- from ins_pricing.reporting.report_builder import ( # type: ignore
183
- ReportPayload,
184
- write_report,
185
- )
186
- except Exception: # pragma: no cover
187
- try:
188
- from evaluation import ( # type: ignore
189
- bootstrap_ci,
190
- calibrate_predictions,
191
- metrics_report as eval_metrics_report,
192
- select_threshold,
193
- )
194
- from ins_pricing.governance.registry import ( # type: ignore
195
- ModelArtifact,
196
- ModelRegistry,
197
- )
198
- from ins_pricing.production import psi_report as drift_psi_report # type: ignore
199
- from ins_pricing.production.monitoring import group_metrics # type: ignore
200
- from ins_pricing.reporting.report_builder import ( # type: ignore
201
- ReportPayload,
202
- write_report,
203
- )
204
- except Exception: # pragma: no cover
205
- bootstrap_ci = None # type: ignore
206
- calibrate_predictions = None # type: ignore
207
- eval_metrics_report = None # type: ignore
208
- select_threshold = None # type: ignore
209
- drift_psi_report = None # type: ignore
210
- group_metrics = None # type: ignore
211
- ReportPayload = None # type: ignore
212
- write_report = None # type: ignore
213
- ModelRegistry = None # type: ignore
214
- ModelArtifact = None # type: ignore
215
-
216
93
 
217
94
  def _parse_args() -> argparse.Namespace:
218
95
  parser = argparse.ArgumentParser(
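The deleted fallbacks above used to set optional symbols (eval_metrics_report, write_report, plot_loss_curve_common, configure_run_logging, ...) to None when their modules could not be imported. The guards later in this file (for example the "if eval_metrics_report is None" check in _evaluate_and_report) are kept, so the resolver presumably preserves that None contract. A small illustrative helper, not package code, showing the guard pattern:

```python
# Illustrative only: metrics_fn stands in for the resolved eval_metrics_report
# and may be None when the optional evaluation module is unavailable.
from typing import Any, Callable, Optional


def metrics_or_skip(
    metrics_fn: Optional[Callable[..., Any]],
    y_true: Any,
    y_pred: Any,
    task_type: str = "regression",
) -> Optional[Any]:
    """Call the resolved metrics function if present; otherwise skip loudly."""
    if metrics_fn is None:  # optional dependency not importable
        print("[Report] Skip evaluation: metrics module unavailable.")
        return None
    return metrics_fn(y_true, y_pred, task_type=task_type)
```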
@@ -520,6 +397,444 @@ def _compute_psi_report(
520
397
  return None
521
398
 
522
399
 
400
+ # --- Refactored helper functions for _evaluate_and_report ---
401
+
402
+
403
+ def _apply_calibration(
404
+ y_true_train: np.ndarray,
405
+ y_pred_train: np.ndarray,
406
+ y_pred_test: np.ndarray,
407
+ calibration_cfg: Dict[str, Any],
408
+ model_name: str,
409
+ model_key: str,
410
+ ) -> tuple[np.ndarray, np.ndarray, Optional[Dict[str, Any]]]:
411
+ """Apply calibration to predictions for classification tasks.
412
+
413
+ Returns:
414
+ Tuple of (calibrated_train_preds, calibrated_test_preds, calibration_info)
415
+ """
416
+ cal_cfg = dict(calibration_cfg or {})
417
+ cal_enabled = bool(cal_cfg.get("enable", False) or cal_cfg.get("method"))
418
+
419
+ if not cal_enabled or calibrate_predictions is None:
420
+ return y_pred_train, y_pred_test, None
421
+
422
+ method = cal_cfg.get("method", "sigmoid")
423
+ max_rows = cal_cfg.get("max_rows")
424
+ seed = cal_cfg.get("seed")
425
+ y_cal, p_cal = _sample_arrays(
426
+ y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
427
+
428
+ try:
429
+ calibrator = calibrate_predictions(y_cal, p_cal, method=method)
430
+ calibrated_train = calibrator.predict(y_pred_train)
431
+ calibrated_test = calibrator.predict(y_pred_test)
432
+ calibration_info = {"method": calibrator.method, "max_rows": max_rows}
433
+ return calibrated_train, calibrated_test, calibration_info
434
+ except Exception as exc:
435
+ print(f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
436
+ return y_pred_train, y_pred_test, None
437
+
438
+
439
+ def _select_classification_threshold(
440
+ y_true_train: np.ndarray,
441
+ y_pred_train_eval: np.ndarray,
442
+ threshold_cfg: Dict[str, Any],
443
+ ) -> tuple[float, Optional[Dict[str, Any]]]:
444
+ """Select threshold for classification predictions.
445
+
446
+ Returns:
447
+ Tuple of (threshold_value, threshold_info)
448
+ """
449
+ thr_cfg = dict(threshold_cfg or {})
450
+ thr_enabled = bool(
451
+ thr_cfg.get("enable", False)
452
+ or thr_cfg.get("metric")
453
+ or thr_cfg.get("value") is not None
454
+ )
455
+
456
+ if thr_cfg.get("value") is not None:
457
+ threshold_value = float(thr_cfg["value"])
458
+ return threshold_value, {"threshold": threshold_value, "source": "fixed"}
459
+
460
+ if thr_enabled and select_threshold is not None:
461
+ max_rows = thr_cfg.get("max_rows")
462
+ seed = thr_cfg.get("seed")
463
+ y_thr, p_thr = _sample_arrays(
464
+ y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
465
+ threshold_info = select_threshold(
466
+ y_thr,
467
+ p_thr,
468
+ metric=thr_cfg.get("metric", "f1"),
469
+ min_positive_rate=thr_cfg.get("min_positive_rate"),
470
+ grid=thr_cfg.get("grid", 99),
471
+ )
472
+ return float(threshold_info.get("threshold", 0.5)), threshold_info
473
+
474
+ return 0.5, None
475
+
476
+
477
+ def _compute_classification_metrics(
478
+ y_true_test: np.ndarray,
479
+ y_pred_test_eval: np.ndarray,
480
+ threshold_value: float,
481
+ ) -> Dict[str, Any]:
482
+ """Compute metrics for classification task."""
483
+ metrics = eval_metrics_report(
484
+ y_true_test,
485
+ y_pred_test_eval,
486
+ task_type="classification",
487
+ threshold=threshold_value,
488
+ )
489
+ precision = float(metrics.get("precision", 0.0))
490
+ recall = float(metrics.get("recall", 0.0))
491
+ f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
492
+ metrics["f1"] = float(f1)
493
+ metrics["threshold"] = float(threshold_value)
494
+ return metrics
495
+
496
+
497
+ def _compute_bootstrap_ci(
498
+ y_true_test: np.ndarray,
499
+ y_pred_test_eval: np.ndarray,
500
+ weight_test: Optional[np.ndarray],
501
+ metrics: Dict[str, Any],
502
+ bootstrap_cfg: Dict[str, Any],
503
+ task_type: str,
504
+ ) -> Dict[str, Dict[str, float]]:
505
+ """Compute bootstrap confidence intervals for metrics."""
506
+ if not bootstrap_cfg or not bool(bootstrap_cfg.get("enable", False)) or bootstrap_ci is None:
507
+ return {}
508
+
509
+ metric_names = bootstrap_cfg.get("metrics")
510
+ if not metric_names:
511
+ metric_names = [name for name in metrics.keys() if name != "threshold"]
512
+ n_samples = int(bootstrap_cfg.get("n_samples", 200))
513
+ ci = float(bootstrap_cfg.get("ci", 0.95))
514
+ seed = bootstrap_cfg.get("seed")
515
+
516
+ def _metric_fn(y_true, y_pred, weight=None):
517
+ vals = eval_metrics_report(
518
+ y_true,
519
+ y_pred,
520
+ task_type=task_type,
521
+ weight=weight,
522
+ threshold=metrics.get("threshold", 0.5),
523
+ )
524
+ if task_type == "classification":
525
+ prec = float(vals.get("precision", 0.0))
526
+ rec = float(vals.get("recall", 0.0))
527
+ vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
528
+ return vals
529
+
530
+ bootstrap_results: Dict[str, Dict[str, float]] = {}
531
+ for name in metric_names:
532
+ if name not in metrics:
533
+ continue
534
+ ci_result = bootstrap_ci(
535
+ lambda y_t, y_p, w=None: float(_metric_fn(y_t, y_p, w).get(name, 0.0)),
536
+ y_true_test,
537
+ y_pred_test_eval,
538
+ weight=weight_test,
539
+ n_samples=n_samples,
540
+ ci=ci,
541
+ seed=seed,
542
+ )
543
+ bootstrap_results[str(name)] = ci_result
544
+
545
+ return bootstrap_results
546
+
547
+
548
+ def _compute_validation_table(
549
+ model: ropt.BayesOptModel,
550
+ pred_col: str,
551
+ report_group_cols: Optional[List[str]],
552
+ weight_col: Optional[str],
553
+ model_name: str,
554
+ model_key: str,
555
+ ) -> Optional[pd.DataFrame]:
556
+ """Compute grouped validation metrics table."""
557
+ if not report_group_cols or group_metrics is None:
558
+ return None
559
+
560
+ available_groups = [
561
+ col for col in report_group_cols if col in model.test_data.columns
562
+ ]
563
+ if not available_groups:
564
+ return None
565
+
566
+ try:
567
+ validation_table = group_metrics(
568
+ model.test_data,
569
+ actual_col=model.resp_nme,
570
+ pred_col=pred_col,
571
+ group_cols=available_groups,
572
+ weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
573
+ )
574
+ counts = (
575
+ model.test_data.groupby(available_groups, dropna=False)
576
+ .size()
577
+ .reset_index(name="count")
578
+ )
579
+ return validation_table.merge(counts, on=available_groups, how="left")
580
+ except Exception as exc:
581
+ print(f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
582
+ return None
583
+
584
+
585
+ def _compute_risk_trend(
586
+ model: ropt.BayesOptModel,
587
+ pred_col: str,
588
+ report_time_col: Optional[str],
589
+ report_time_freq: str,
590
+ report_time_ascending: bool,
591
+ weight_col: Optional[str],
592
+ model_name: str,
593
+ model_key: str,
594
+ ) -> Optional[pd.DataFrame]:
595
+ """Compute time-series risk trend metrics."""
596
+ if not report_time_col or group_metrics is None:
597
+ return None
598
+
599
+ if report_time_col not in model.test_data.columns:
600
+ return None
601
+
602
+ try:
603
+ time_df = model.test_data.copy()
604
+ time_series = pd.to_datetime(time_df[report_time_col], errors="coerce")
605
+ time_df = time_df.loc[time_series.notna()].copy()
606
+
607
+ if time_df.empty:
608
+ return None
609
+
610
+ time_df["_time_bucket"] = (
611
+ pd.to_datetime(time_df[report_time_col], errors="coerce")
612
+ .dt.to_period(report_time_freq)
613
+ .dt.to_timestamp()
614
+ )
615
+ risk_trend = group_metrics(
616
+ time_df,
617
+ actual_col=model.resp_nme,
618
+ pred_col=pred_col,
619
+ group_cols=["_time_bucket"],
620
+ weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
621
+ )
622
+ counts = (
623
+ time_df.groupby("_time_bucket", dropna=False)
624
+ .size()
625
+ .reset_index(name="count")
626
+ )
627
+ risk_trend = risk_trend.merge(counts, on="_time_bucket", how="left")
628
+ risk_trend = risk_trend.sort_values(
629
+ "_time_bucket", ascending=bool(report_time_ascending)
630
+ ).reset_index(drop=True)
631
+ return risk_trend.rename(columns={"_time_bucket": report_time_col})
632
+ except Exception as exc:
633
+ print(f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
634
+ return None
635
+
636
+
637
+ def _write_metrics_json(
638
+ report_root: Path,
639
+ model_name: str,
640
+ model_key: str,
641
+ version: str,
642
+ metrics: Dict[str, Any],
643
+ threshold_info: Optional[Dict[str, Any]],
644
+ calibration_info: Optional[Dict[str, Any]],
645
+ bootstrap_results: Dict[str, Dict[str, float]],
646
+ data_path: Path,
647
+ data_fingerprint: Dict[str, Any],
648
+ config_sha: str,
649
+ pred_col: str,
650
+ task_type: str,
651
+ ) -> Path:
652
+ """Write metrics to JSON file and return the path."""
653
+ metrics_payload = {
654
+ "model_name": model_name,
655
+ "model_key": model_key,
656
+ "model_version": version,
657
+ "metrics": metrics,
658
+ "threshold": threshold_info,
659
+ "calibration": calibration_info,
660
+ "bootstrap": bootstrap_results,
661
+ "data_path": str(data_path),
662
+ "data_fingerprint": data_fingerprint,
663
+ "config_sha256": config_sha,
664
+ "pred_col": pred_col,
665
+ "task_type": task_type,
666
+ }
667
+ metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
668
+ metrics_path.write_text(
669
+ json.dumps(metrics_payload, indent=2, ensure_ascii=True),
670
+ encoding="utf-8",
671
+ )
672
+ return metrics_path
673
+
674
+
675
+ def _write_model_report(
676
+ report_root: Path,
677
+ model_name: str,
678
+ model_key: str,
679
+ version: str,
680
+ metrics: Dict[str, Any],
681
+ risk_trend: Optional[pd.DataFrame],
682
+ psi_report_df: Optional[pd.DataFrame],
683
+ validation_table: Optional[pd.DataFrame],
684
+ calibration_info: Optional[Dict[str, Any]],
685
+ threshold_info: Optional[Dict[str, Any]],
686
+ bootstrap_results: Dict[str, Dict[str, float]],
687
+ config_sha: str,
688
+ data_fingerprint: Dict[str, Any],
689
+ ) -> Optional[Path]:
690
+ """Write model report and return the path."""
691
+ if ReportPayload is None or write_report is None:
692
+ return None
693
+
694
+ notes_lines = [
695
+ f"- Config SHA256: {config_sha}",
696
+ f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
697
+ ]
698
+ if calibration_info:
699
+ notes_lines.append(f"- Calibration: {calibration_info.get('method')}")
700
+ if threshold_info:
701
+ notes_lines.append(f"- Threshold selection: {threshold_info}")
702
+ if bootstrap_results:
703
+ notes_lines.append("- Bootstrap: see metrics JSON for CI")
704
+
705
+ payload = ReportPayload(
706
+ model_name=f"{model_name}/{model_key}",
707
+ model_version=version,
708
+ metrics={k: float(v) for k, v in metrics.items()},
709
+ risk_trend=risk_trend,
710
+ drift_report=psi_report_df,
711
+ validation_table=validation_table,
712
+ extra_notes="\n".join(notes_lines),
713
+ )
714
+ return write_report(
715
+ payload,
716
+ report_root / f"{model_name}_{model_key}_report.md",
717
+ )
718
+
719
+
720
+ def _register_model_to_registry(
721
+ model: ropt.BayesOptModel,
722
+ model_name: str,
723
+ model_key: str,
724
+ version: str,
725
+ metrics: Dict[str, Any],
726
+ task_type: str,
727
+ data_path: Path,
728
+ data_fingerprint: Dict[str, Any],
729
+ config_sha: str,
730
+ registry_path: Optional[str],
731
+ registry_tags: Dict[str, Any],
732
+ registry_status: str,
733
+ report_path: Optional[Path],
734
+ metrics_path: Path,
735
+ cfg: Dict[str, Any],
736
+ ) -> None:
737
+ """Register model artifacts to the model registry."""
738
+ if ModelRegistry is None or ModelArtifact is None:
739
+ return
740
+
741
+ registry = ModelRegistry(
742
+ registry_path
743
+ if registry_path
744
+ else Path(model.output_manager.result_dir) / "model_registry.json"
745
+ )
746
+
747
+ tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
748
+ tags.update({
749
+ "model_key": str(model_key),
750
+ "task_type": str(task_type),
751
+ "data_path": str(data_path),
752
+ "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
753
+ "data_size": str(data_fingerprint.get("size", "")),
754
+ "data_mtime": str(data_fingerprint.get("mtime", "")),
755
+ "config_sha256": str(config_sha),
756
+ })
757
+
758
+ artifacts = _collect_model_artifacts(
759
+ model, model_name, model_key, report_path, metrics_path, cfg
760
+ )
761
+
762
+ registry.register(
763
+ name=str(model_name),
764
+ version=version,
765
+ metrics={k: float(v) for k, v in metrics.items()},
766
+ tags=tags,
767
+ artifacts=artifacts,
768
+ status=str(registry_status or "candidate"),
769
+ notes=f"model_key={model_key}",
770
+ )
771
+
772
+
773
+ def _collect_model_artifacts(
774
+ model: ropt.BayesOptModel,
775
+ model_name: str,
776
+ model_key: str,
777
+ report_path: Optional[Path],
778
+ metrics_path: Path,
779
+ cfg: Dict[str, Any],
780
+ ) -> List:
781
+ """Collect all model artifacts for registry."""
782
+ artifacts = []
783
+
784
+ # Trained model artifact
785
+ trainer = model.trainers.get(model_key)
786
+ if trainer is not None:
787
+ try:
788
+ model_path = trainer.output.model_path(trainer._get_model_filename())
789
+ if os.path.exists(model_path):
790
+ artifacts.append(ModelArtifact(path=model_path, description="trained model"))
791
+ except Exception:
792
+ pass
793
+
794
+ # Report artifact
795
+ if report_path is not None:
796
+ artifacts.append(ModelArtifact(path=str(report_path), description="model report"))
797
+
798
+ # Metrics JSON artifact
799
+ if metrics_path.exists():
800
+ artifacts.append(ModelArtifact(path=str(metrics_path), description="metrics json"))
801
+
802
+ # Preprocess artifacts
803
+ if bool(cfg.get("save_preprocess", False)):
804
+ artifact_path = cfg.get("preprocess_artifact_path")
805
+ if artifact_path:
806
+ preprocess_path = Path(str(artifact_path))
807
+ if not preprocess_path.is_absolute():
808
+ preprocess_path = Path(model.output_manager.result_dir) / preprocess_path
809
+ else:
810
+ preprocess_path = Path(model.output_manager.result_path(
811
+ f"{model.model_nme}_preprocess.json"
812
+ ))
813
+ if preprocess_path.exists():
814
+ artifacts.append(
815
+ ModelArtifact(path=str(preprocess_path), description="preprocess artifacts")
816
+ )
817
+
818
+ # Prediction cache artifacts
819
+ if bool(cfg.get("cache_predictions", False)):
820
+ cache_dir = cfg.get("prediction_cache_dir")
821
+ if cache_dir:
822
+ pred_root = Path(str(cache_dir))
823
+ if not pred_root.is_absolute():
824
+ pred_root = Path(model.output_manager.result_dir) / pred_root
825
+ else:
826
+ pred_root = Path(model.output_manager.result_dir) / "predictions"
827
+ ext = "csv" if str(cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
828
+ for split_label in ("train", "test"):
829
+ pred_path = pred_root / f"{model_name}_{model_key}_{split_label}.{ext}"
830
+ if pred_path.exists():
831
+ artifacts.append(
832
+ ModelArtifact(path=str(pred_path), description=f"predictions {split_label}")
833
+ )
834
+
835
+ return artifacts
836
+
837
+
523
838
  def _evaluate_and_report(
524
839
  model: ropt.BayesOptModel,
525
840
  *,
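Among the helpers added above, _compute_bootstrap_ci wraps eval_metrics_report in a per-metric closure and delegates the resampling to bootstrap_ci. A self-contained toy version of that resample-and-quantile pattern; bootstrap_ci_sketch is a stand-in, not the package's bootstrap_ci, and the percentile method shown is an assumption:

```python
# Stand-alone toy, not package code: illustrates the pattern of resampling
# rows with replacement and taking quantiles of the recomputed metric.
import numpy as np


def bootstrap_ci_sketch(metric_fn, y_true, y_pred, n_samples=200, ci=0.95, seed=0):
    """Percentile bootstrap CI for a scalar metric of (y_true, y_pred)."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    stats = []
    for _ in range(n_samples):
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        stats.append(metric_fn(y_true[idx], y_pred[idx]))
    low, high = np.quantile(stats, [(1 - ci) / 2, 1 - (1 - ci) / 2])
    return {"low": float(low), "high": float(high)}


y = np.array([0, 1, 1, 0, 1, 0, 1, 1])
p = np.array([0.2, 0.8, 0.6, 0.3, 0.9, 0.4, 0.7, 0.5])
accuracy = lambda yt, yp: float(np.mean((yp >= 0.5) == yt))
print(bootstrap_ci_sketch(accuracy, y, p, n_samples=500))
```

In the package helper, the closure passed to bootstrap_ci additionally re-derives F1 from precision and recall for classification tasks, exactly as the added code above shows.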
@@ -544,374 +859,164 @@ def _evaluate_and_report(
544
859
  run_id: str,
545
860
  config_sha: str,
546
861
  ) -> None:
862
+ """Evaluate model predictions and generate reports.
863
+
864
+ This function orchestrates the evaluation pipeline:
865
+ 1. Extract predictions and ground truth
866
+ 2. Apply calibration (for classification)
867
+ 3. Select threshold (for classification)
868
+ 4. Compute metrics
869
+ 5. Compute bootstrap confidence intervals
870
+ 6. Generate validation tables and risk trends
871
+ 7. Write reports and register model
872
+ """
547
873
  if eval_metrics_report is None:
548
874
  print("[Report] Skip evaluation: metrics module unavailable.")
549
875
  return
550
876
 
551
877
  pred_col = PLOT_MODEL_LABELS.get(model_key, (None, f"pred_{model_key}"))[1]
552
878
  if pred_col not in model.test_data.columns:
553
- print(
554
- f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
879
+ print(f"[Report] Missing prediction column '{pred_col}' for {model_name}/{model_key}; skip.")
555
880
  return
556
881
 
882
+ # Extract predictions and weights
557
883
  weight_col = getattr(model, "weight_nme", None)
558
- y_true_train = model.train_data[model.resp_nme].to_numpy(
559
- dtype=float, copy=False)
560
- y_true_test = model.test_data[model.resp_nme].to_numpy(
561
- dtype=float, copy=False)
884
+ y_true_train = model.train_data[model.resp_nme].to_numpy(dtype=float, copy=False)
885
+ y_true_test = model.test_data[model.resp_nme].to_numpy(dtype=float, copy=False)
562
886
  y_pred_train = model.train_data[pred_col].to_numpy(dtype=float, copy=False)
563
887
  y_pred_test = model.test_data[pred_col].to_numpy(dtype=float, copy=False)
564
- weight_train = (
565
- model.train_data[weight_col].to_numpy(dtype=float, copy=False)
566
- if weight_col and weight_col in model.train_data.columns
567
- else None
568
- )
569
888
  weight_test = (
570
889
  model.test_data[weight_col].to_numpy(dtype=float, copy=False)
571
890
  if weight_col and weight_col in model.test_data.columns
572
891
  else None
573
892
  )
574
893
 
575
- task_type = str(cfg.get("task_type", getattr(
576
- model, "task_type", "regression")))
894
+ task_type = str(cfg.get("task_type", getattr(model, "task_type", "regression")))
895
+
896
+ # Process based on task type
577
897
  if task_type == "classification":
578
898
  y_pred_train = np.clip(y_pred_train, 0.0, 1.0)
579
899
  y_pred_test = np.clip(y_pred_test, 0.0, 1.0)
580
900
 
581
- calibration_info: Optional[Dict[str, Any]] = None
582
- threshold_info: Optional[Dict[str, Any]] = None
583
- y_pred_train_eval = y_pred_train
584
- y_pred_test_eval = y_pred_test
585
-
586
- if task_type == "classification":
587
- cal_cfg = dict(calibration_cfg or {})
588
- cal_enabled = bool(cal_cfg.get("enable", False)
589
- or cal_cfg.get("method"))
590
- if cal_enabled and calibrate_predictions is not None:
591
- method = cal_cfg.get("method", "sigmoid")
592
- max_rows = cal_cfg.get("max_rows")
593
- seed = cal_cfg.get("seed")
594
- y_cal, p_cal = _sample_arrays(
595
- y_true_train, y_pred_train, max_rows=max_rows, seed=seed)
596
- try:
597
- calibrator = calibrate_predictions(y_cal, p_cal, method=method)
598
- y_pred_train_eval = calibrator.predict(y_pred_train)
599
- y_pred_test_eval = calibrator.predict(y_pred_test)
600
- calibration_info = {
601
- "method": calibrator.method, "max_rows": max_rows}
602
- except Exception as exc:
603
- print(
604
- f"[Report] Calibration failed for {model_name}/{model_key}: {exc}")
605
-
606
- thr_cfg = dict(threshold_cfg or {})
607
- thr_enabled = bool(
608
- thr_cfg.get("enable", False)
609
- or thr_cfg.get("metric")
610
- or thr_cfg.get("value") is not None
901
+ y_pred_train_eval, y_pred_test_eval, calibration_info = _apply_calibration(
902
+ y_true_train, y_pred_train, y_pred_test, calibration_cfg, model_name, model_key
611
903
  )
612
- threshold_value = 0.5
613
- if thr_cfg.get("value") is not None:
614
- threshold_value = float(thr_cfg["value"])
615
- threshold_info = {"threshold": threshold_value, "source": "fixed"}
616
- elif thr_enabled and select_threshold is not None:
617
- max_rows = thr_cfg.get("max_rows")
618
- seed = thr_cfg.get("seed")
619
- y_thr, p_thr = _sample_arrays(
620
- y_true_train, y_pred_train_eval, max_rows=max_rows, seed=seed)
621
- threshold_info = select_threshold(
622
- y_thr,
623
- p_thr,
624
- metric=thr_cfg.get("metric", "f1"),
625
- min_positive_rate=thr_cfg.get("min_positive_rate"),
626
- grid=thr_cfg.get("grid", 99),
627
- )
628
- threshold_value = float(threshold_info.get("threshold", 0.5))
629
- else:
630
- threshold_value = 0.5
631
- metrics = eval_metrics_report(
632
- y_true_test,
633
- y_pred_test_eval,
634
- task_type=task_type,
635
- threshold=threshold_value,
904
+ threshold_value, threshold_info = _select_classification_threshold(
905
+ y_true_train, y_pred_train_eval, threshold_cfg
636
906
  )
637
- precision = float(metrics.get("precision", 0.0))
638
- recall = float(metrics.get("recall", 0.0))
639
- f1 = 0.0 if (precision + recall) == 0 else 2 * \
640
- precision * recall / (precision + recall)
641
- metrics["f1"] = float(f1)
642
- metrics["threshold"] = float(threshold_value)
907
+ metrics = _compute_classification_metrics(y_true_test, y_pred_test_eval, threshold_value)
643
908
  else:
909
+ y_pred_test_eval = y_pred_test
910
+ calibration_info = None
911
+ threshold_info = None
644
912
  metrics = eval_metrics_report(
645
- y_true_test,
646
- y_pred_test_eval,
647
- task_type=task_type,
648
- weight=weight_test,
913
+ y_true_test, y_pred_test_eval, task_type=task_type, weight=weight_test
649
914
  )
650
915
 
651
- bootstrap_results: Dict[str, Dict[str, float]] = {}
652
- if bootstrap_cfg and bool(bootstrap_cfg.get("enable", False)) and bootstrap_ci is not None:
653
- metric_names = bootstrap_cfg.get("metrics") or list(metrics.keys())
654
- n_samples = int(bootstrap_cfg.get("n_samples", 200))
655
- ci = float(bootstrap_cfg.get("ci", 0.95))
656
- seed = bootstrap_cfg.get("seed")
657
-
658
- def _metric_fn(y_true, y_pred, weight=None):
659
- vals = eval_metrics_report(
660
- y_true,
661
- y_pred,
662
- task_type=task_type,
663
- weight=weight,
664
- threshold=metrics.get("threshold", 0.5),
665
- )
666
- if task_type == "classification":
667
- prec = float(vals.get("precision", 0.0))
668
- rec = float(vals.get("recall", 0.0))
669
- vals["f1"] = 0.0 if (prec + rec) == 0 else 2 * \
670
- prec * rec / (prec + rec)
671
- return vals
672
-
673
- for name in metric_names:
674
- if name not in metrics:
675
- continue
676
- ci_result = bootstrap_ci(
677
- lambda y_t, y_p, w=None: float(
678
- _metric_fn(y_t, y_p, w).get(name, 0.0)),
679
- y_true_test,
680
- y_pred_test_eval,
681
- weight=weight_test,
682
- n_samples=n_samples,
683
- ci=ci,
684
- seed=seed,
685
- )
686
- bootstrap_results[str(name)] = ci_result
916
+ # Compute bootstrap confidence intervals
917
+ bootstrap_results = _compute_bootstrap_ci(
918
+ y_true_test, y_pred_test_eval, weight_test, metrics, bootstrap_cfg, task_type
919
+ )
687
920
 
688
- validation_table = None
689
- if report_group_cols and group_metrics is not None:
690
- available_groups = [
691
- col for col in report_group_cols if col in model.test_data.columns
692
- ]
693
- if available_groups:
694
- try:
695
- validation_table = group_metrics(
696
- model.test_data,
697
- actual_col=model.resp_nme,
698
- pred_col=pred_col,
699
- group_cols=available_groups,
700
- weight_col=weight_col if weight_col and weight_col in model.test_data.columns else None,
701
- )
702
- counts = (
703
- model.test_data.groupby(available_groups, dropna=False)
704
- .size()
705
- .reset_index(name="count")
706
- )
707
- validation_table = validation_table.merge(
708
- counts, on=available_groups, how="left")
709
- except Exception as exc:
710
- print(
711
- f"[Report] group_metrics failed for {model_name}/{model_key}: {exc}")
712
-
713
- risk_trend = None
714
- if report_time_col and group_metrics is not None:
715
- if report_time_col in model.test_data.columns:
716
- try:
717
- time_df = model.test_data.copy()
718
- time_series = pd.to_datetime(
719
- time_df[report_time_col], errors="coerce")
720
- time_df = time_df.loc[time_series.notna()].copy()
721
- if not time_df.empty:
722
- time_df["_time_bucket"] = (
723
- pd.to_datetime(
724
- time_df[report_time_col], errors="coerce")
725
- .dt.to_period(report_time_freq)
726
- .dt.to_timestamp()
727
- )
728
- risk_trend = group_metrics(
729
- time_df,
730
- actual_col=model.resp_nme,
731
- pred_col=pred_col,
732
- group_cols=["_time_bucket"],
733
- weight_col=weight_col if weight_col and weight_col in time_df.columns else None,
734
- )
735
- counts = (
736
- time_df.groupby("_time_bucket", dropna=False)
737
- .size()
738
- .reset_index(name="count")
739
- )
740
- risk_trend = risk_trend.merge(
741
- counts, on="_time_bucket", how="left")
742
- risk_trend = risk_trend.sort_values(
743
- "_time_bucket", ascending=bool(report_time_ascending)
744
- ).reset_index(drop=True)
745
- risk_trend = risk_trend.rename(
746
- columns={"_time_bucket": report_time_col})
747
- except Exception as exc:
748
- print(
749
- f"[Report] time metrics failed for {model_name}/{model_key}: {exc}")
921
+ # Compute validation table and risk trend
922
+ validation_table = _compute_validation_table(
923
+ model, pred_col, report_group_cols, weight_col, model_name, model_key
924
+ )
925
+ risk_trend = _compute_risk_trend(
926
+ model, pred_col, report_time_col, report_time_freq,
927
+ report_time_ascending, weight_col, model_name, model_key
928
+ )
750
929
 
930
+ # Setup output directory
751
931
  report_root = (
752
932
  Path(report_output_dir)
753
933
  if report_output_dir
754
934
  else Path(model.output_manager.result_dir) / "reports"
755
935
  )
756
936
  report_root.mkdir(parents=True, exist_ok=True)
757
-
758
937
  version = f"{model_key}_{run_id}"
759
- metrics_payload = {
760
- "model_name": model_name,
761
- "model_key": model_key,
762
- "model_version": version,
763
- "metrics": metrics,
764
- "threshold": threshold_info,
765
- "calibration": calibration_info,
766
- "bootstrap": bootstrap_results,
767
- "data_path": str(data_path),
768
- "data_fingerprint": data_fingerprint,
769
- "config_sha256": config_sha,
770
- "pred_col": pred_col,
771
- "task_type": task_type,
772
- }
773
- metrics_path = report_root / f"{model_name}_{model_key}_metrics.json"
774
- metrics_path.write_text(
775
- json.dumps(metrics_payload, indent=2, ensure_ascii=True),
776
- encoding="utf-8",
938
+
939
+ # Write metrics JSON
940
+ metrics_path = _write_metrics_json(
941
+ report_root, model_name, model_key, version, metrics,
942
+ threshold_info, calibration_info, bootstrap_results,
943
+ data_path, data_fingerprint, config_sha, pred_col, task_type
777
944
  )
778
945
 
779
- report_path = None
780
- if ReportPayload is not None and write_report is not None:
781
- notes_lines = [
782
- f"- Config SHA256: {config_sha}",
783
- f"- Data fingerprint: {data_fingerprint.get('sha256_prefix')}",
784
- ]
785
- if calibration_info:
786
- notes_lines.append(
787
- f"- Calibration: {calibration_info.get('method')}"
788
- )
789
- if threshold_info:
790
- notes_lines.append(
791
- f"- Threshold selection: {threshold_info}"
792
- )
793
- if bootstrap_results:
794
- notes_lines.append("- Bootstrap: see metrics JSON for CI")
795
- extra_notes = "\n".join(notes_lines)
796
- payload = ReportPayload(
797
- model_name=f"{model_name}/{model_key}",
798
- model_version=version,
799
- metrics={k: float(v) for k, v in metrics.items()},
800
- risk_trend=risk_trend,
801
- drift_report=psi_report_df,
802
- validation_table=validation_table,
803
- extra_notes=extra_notes,
804
- )
805
- report_path = write_report(
806
- payload,
807
- report_root / f"{model_name}_{model_key}_report.md",
808
- )
946
+ # Write model report
947
+ report_path = _write_model_report(
948
+ report_root, model_name, model_key, version, metrics,
949
+ risk_trend, psi_report_df, validation_table,
950
+ calibration_info, threshold_info, bootstrap_results,
951
+ config_sha, data_fingerprint
952
+ )
809
953
 
810
- if register_model and ModelRegistry is not None and ModelArtifact is not None:
811
- registry = ModelRegistry(
812
- registry_path
813
- if registry_path
814
- else Path(model.output_manager.result_dir) / "model_registry.json"
815
- )
816
- tags = {str(k): str(v) for k, v in (registry_tags or {}).items()}
817
- tags.update({
818
- "model_key": str(model_key),
819
- "task_type": str(task_type),
820
- "data_path": str(data_path),
821
- "data_sha256_prefix": str(data_fingerprint.get("sha256_prefix", "")),
822
- "data_size": str(data_fingerprint.get("size", "")),
823
- "data_mtime": str(data_fingerprint.get("mtime", "")),
824
- "config_sha256": str(config_sha),
825
- })
826
- artifacts = []
827
- trainer = model.trainers.get(model_key)
828
- if trainer is not None:
829
- try:
830
- model_path = trainer.output.model_path(
831
- trainer._get_model_filename())
832
- if os.path.exists(model_path):
833
- artifacts.append(ModelArtifact(
834
- path=model_path, description="trained model"))
835
- except Exception:
836
- pass
837
- if report_path is not None:
838
- artifacts.append(ModelArtifact(
839
- path=str(report_path), description="model report"))
840
- if metrics_path.exists():
841
- artifacts.append(ModelArtifact(
842
- path=str(metrics_path), description="metrics json"))
843
- if bool(cfg.get("save_preprocess", False)):
844
- artifact_path = cfg.get("preprocess_artifact_path")
845
- if artifact_path:
846
- preprocess_path = Path(str(artifact_path))
847
- if not preprocess_path.is_absolute():
848
- preprocess_path = Path(
849
- model.output_manager.result_dir) / preprocess_path
850
- else:
851
- preprocess_path = Path(model.output_manager.result_path(
852
- f"{model.model_nme}_preprocess.json"
853
- ))
854
- if preprocess_path.exists():
855
- artifacts.append(
856
- ModelArtifact(path=str(preprocess_path),
857
- description="preprocess artifacts")
858
- )
859
- if bool(cfg.get("cache_predictions", False)):
860
- cache_dir = cfg.get("prediction_cache_dir")
861
- if cache_dir:
862
- pred_root = Path(str(cache_dir))
863
- if not pred_root.is_absolute():
864
- pred_root = Path(
865
- model.output_manager.result_dir) / pred_root
866
- else:
867
- pred_root = Path(
868
- model.output_manager.result_dir) / "predictions"
869
- ext = "csv" if str(
870
- cfg.get("prediction_cache_format", "parquet")).lower() == "csv" else "parquet"
871
- for split_label in ("train", "test"):
872
- pred_path = pred_root / \
873
- f"{model_name}_{model_key}_{split_label}.{ext}"
874
- if pred_path.exists():
875
- artifacts.append(
876
- ModelArtifact(path=str(pred_path),
877
- description=f"predictions {split_label}")
878
- )
879
- registry.register(
880
- name=str(model_name),
881
- version=version,
882
- metrics={k: float(v) for k, v in metrics.items()},
883
- tags=tags,
884
- artifacts=artifacts,
885
- status=str(registry_status or "candidate"),
886
- notes=f"model_key={model_key}",
954
+ # Register model
955
+ if register_model:
956
+ _register_model_to_registry(
957
+ model, model_name, model_key, version, metrics, task_type,
958
+ data_path, data_fingerprint, config_sha, registry_path,
959
+ registry_tags, registry_status, report_path, metrics_path, cfg
887
960
  )
888
961
 
889
962
 
890
- def train_from_config(args: argparse.Namespace) -> None:
891
- script_dir = Path(__file__).resolve().parents[1]
892
- config_path, cfg = resolve_and_load_config(
893
- args.config_json,
894
- script_dir,
895
- required_keys=["data_dir", "model_list",
896
- "model_categories", "target", "weight"],
963
+ def _evaluate_with_context(
964
+ model: ropt.BayesOptModel,
965
+ ctx: EvaluationContext,
966
+ ) -> None:
967
+ """Evaluate model predictions using context object.
968
+
969
+ This is a cleaner interface that uses the EvaluationContext dataclass
970
+ instead of 19+ individual parameters.
971
+ """
972
+ _evaluate_and_report(
973
+ model,
974
+ model_name=ctx.identity.model_name,
975
+ model_key=ctx.identity.model_key,
976
+ cfg=ctx.cfg,
977
+ data_path=ctx.data_path,
978
+ data_fingerprint=ctx.data_fingerprint.to_dict(),
979
+ report_output_dir=ctx.report.output_dir,
980
+ report_group_cols=ctx.report.group_cols,
981
+ report_time_col=ctx.report.time_col,
982
+ report_time_freq=ctx.report.time_freq,
983
+ report_time_ascending=ctx.report.time_ascending,
984
+ psi_report_df=ctx.psi_report_df,
985
+ calibration_cfg={
986
+ "enable": ctx.calibration.enable,
987
+ "method": ctx.calibration.method,
988
+ "max_rows": ctx.calibration.max_rows,
989
+ "seed": ctx.calibration.seed,
990
+ },
991
+ threshold_cfg={
992
+ "enable": ctx.threshold.enable,
993
+ "metric": ctx.threshold.metric,
994
+ "value": ctx.threshold.value,
995
+ "min_positive_rate": ctx.threshold.min_positive_rate,
996
+ "grid": ctx.threshold.grid,
997
+ "max_rows": ctx.threshold.max_rows,
998
+ "seed": ctx.threshold.seed,
999
+ },
1000
+ bootstrap_cfg={
1001
+ "enable": ctx.bootstrap.enable,
1002
+ "metrics": ctx.bootstrap.metrics,
1003
+ "n_samples": ctx.bootstrap.n_samples,
1004
+ "ci": ctx.bootstrap.ci,
1005
+ "seed": ctx.bootstrap.seed,
1006
+ },
1007
+ register_model=ctx.registry.register,
1008
+ registry_path=ctx.registry.path,
1009
+ registry_tags=ctx.registry.tags,
1010
+ registry_status=ctx.registry.status,
1011
+ run_id=ctx.run_id,
1012
+ config_sha=ctx.config_sha,
897
1013
  )
898
- plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
899
- config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
900
- run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
901
1014
 
902
- def _safe_int_env(key: str, default: int) -> int:
903
- try:
904
- return int(os.environ.get(key, default))
905
- except (TypeError, ValueError):
906
- return default
907
-
908
- dist_world_size = _safe_int_env("WORLD_SIZE", 1)
909
- dist_rank = _safe_int_env("RANK", 0)
910
- dist_active = dist_world_size > 1
911
- is_main_process = (not dist_active) or dist_rank == 0
912
1015
 
1016
+ def _create_ddp_barrier(dist_ctx: TrainingContext):
1017
+ """Create a DDP barrier function for distributed training synchronization."""
913
1018
  def _ddp_barrier(reason: str) -> None:
914
- if not dist_active:
1019
+ if not dist_ctx.is_distributed:
915
1020
  return
916
1021
  torch_mod = getattr(ropt, "torch", None)
917
1022
  dist_mod = getattr(torch_mod, "distributed", None)
@@ -928,6 +1033,28 @@ def train_from_config(args: argparse.Namespace) -> None:
928
1033
  except Exception as exc:
929
1034
  print(f"[DDP] barrier failed during {reason}: {exc}", flush=True)
930
1035
  raise
1036
+ return _ddp_barrier
1037
+
1038
+
1039
+ def train_from_config(args: argparse.Namespace) -> None:
1040
+ script_dir = Path(__file__).resolve().parents[1]
1041
+ config_path, cfg = resolve_and_load_config(
1042
+ args.config_json,
1043
+ script_dir,
1044
+ required_keys=["data_dir", "model_list",
1045
+ "model_categories", "target", "weight"],
1046
+ )
1047
+ plot_requested = bool(args.plot_curves or cfg.get("plot_curves", False))
1048
+ config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest()
1049
+ run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
1050
+
1051
+ # Use TrainingContext for distributed training state
1052
+ dist_ctx = TrainingContext.from_env()
1053
+ dist_world_size = dist_ctx.world_size
1054
+ dist_rank = dist_ctx.rank
1055
+ dist_active = dist_ctx.is_distributed
1056
+ is_main_process = dist_ctx.is_main_process
1057
+ _ddp_barrier = _create_ddp_barrier(dist_ctx)
931
1058
 
932
1059
  data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
933
1060
  cfg,
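The final hunk swaps the inline WORLD_SIZE/RANK parsing for TrainingContext.from_env() and routes evaluation settings through EvaluationContext sub-objects. A hedged sketch of what cli/utils/evaluation_context.py plausibly defines, reconstructed from the call sites visible in this diff (TrainingContext.from_env(), ctx.calibration.enable, ctx.threshold.value, ctx.bootstrap.n_samples, ctx.registry.status, ...); defaults and any field not visible above are assumptions, not the package's actual definitions:

```python
# Sketch inferred from call sites in this diff; not the package source.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class TrainingContext:
    world_size: int = 1
    rank: int = 0

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        return (not self.is_distributed) or self.rank == 0

    @classmethod
    def from_env(cls) -> "TrainingContext":
        # Mirrors the _safe_int_env logic this hunk removes from train_from_config.
        def _safe_int(key: str, default: int) -> int:
            try:
                return int(os.environ.get(key, default))
            except (TypeError, ValueError):
                return default

        return cls(world_size=_safe_int("WORLD_SIZE", 1), rank=_safe_int("RANK", 0))


@dataclass
class CalibrationConfig:
    enable: bool = False
    method: Optional[str] = None
    max_rows: Optional[int] = None
    seed: Optional[int] = None


@dataclass
class ThresholdConfig:
    enable: bool = False
    metric: Optional[str] = None
    value: Optional[float] = None
    min_positive_rate: Optional[float] = None
    grid: int = 99
    max_rows: Optional[int] = None
    seed: Optional[int] = None


@dataclass
class BootstrapConfig:
    enable: bool = False
    metrics: Optional[List[str]] = None
    n_samples: int = 200
    ci: float = 0.95
    seed: Optional[int] = None


@dataclass
class RegistryConfig:
    register: bool = False
    path: Optional[str] = None
    tags: Dict[str, Any] = field(default_factory=dict)
    status: str = "candidate"
```

The _evaluate_with_context wrapper above simply unpacks these sub-objects back into the flat keyword arguments of _evaluate_and_report, which keeps the older 19-parameter entry point unchanged while offering the grouped interface.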