ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
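
Most of the churn in this release comes from flattening `modelling/core/bayesopt` into `modelling/bayesopt`, renaming `production/predict.py` to `production/inference.py`, and moving the shared loss helpers under `ins_pricing/utils`. A minimal migration sketch for downstream imports follows; the module paths are taken from the rename entries above, but the exact public symbols each module exports are not visible in this diff and are assumptions.

# Illustrative only: paths inferred from the renames above; symbols are assumptions.

# ins-pricing 0.4.4
# from ins_pricing.modelling.core.bayesopt import core
# from ins_pricing.production import predict

# ins-pricing 0.5.0
from ins_pricing.modelling.bayesopt import core    # core/bayesopt -> bayesopt
from ins_pricing.production import inference       # predict.py -> inference.py
from ins_pricing.utils import losses               # losses.py now lives under ins_pricing/utils/
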
@@ -1,375 +1,375 @@
 from __future__ import annotations

 import argparse
 import json
 import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Sequence, Tuple

 try:
-    from .cli_common import resolve_dir_path, resolve_path  # type: ignore
+    from ins_pricing.cli.utils.cli_common import resolve_dir_path, resolve_path  # type: ignore
 except Exception:  # pragma: no cover
     from cli_common import resolve_dir_path, resolve_path  # type: ignore


 def resolve_config_path(raw: str, script_dir: Path) -> Path:
     candidate = Path(raw)
     if candidate.exists():
         return candidate.resolve()
     candidate2 = (script_dir / raw)
     if candidate2.exists():
         return candidate2.resolve()
     raise FileNotFoundError(
         f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
     )


 def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
     cfg = json.loads(path.read_text(encoding="utf-8"))
     missing = [key for key in required_keys if key not in cfg]
     if missing:
         raise ValueError(f"Missing required keys in {path}: {missing}")
     return cfg


 def set_env(env_overrides: Dict[str, Any]) -> None:
     """Apply environment variables from config.json.

     Notes (DDP/Optuna hang debugging):
     - You can add these keys into config.json's `env` to debug distributed hangs:
       - `TORCH_DISTRIBUTED_DEBUG=DETAIL`
       - `NCCL_DEBUG=INFO`
       - `BAYESOPT_DDP_BARRIER_DEBUG=1`
       - `BAYESOPT_DDP_BARRIER_TIMEOUT=300`
       - `BAYESOPT_CUDA_SYNC=1` (optional; can slow down)
       - `BAYESOPT_CUDA_IPC_COLLECT=1` (optional; can slow down)
     - This function uses `os.environ.setdefault`, so a value already set in the
       shell will take precedence over config.json.
     """
     for key, value in (env_overrides or {}).items():
         os.environ.setdefault(str(key), str(value))


 def _looks_like_url(value: str) -> bool:
     value = str(value)
     return "://" in value


 def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
     """Resolve relative paths against the config.json directory.

     Fields handled:
     - data_dir / output_dir / optuna_storage / gnn_graph_cache
     - best_params_files (dict: model_key -> path)
     """
     base_dir = config_path.parent
     out = dict(cfg)

     for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
                 "prediction_cache_dir", "report_output_dir", "registry_path"):
         if key in out and isinstance(out.get(key), str):
             resolved = resolve_path(out.get(key), base_dir)
             if resolved is not None:
                 out[key] = str(resolved)

     storage = out.get("optuna_storage")
     if isinstance(storage, str) and storage.strip():
         if not _looks_like_url(storage):
             resolved = resolve_path(storage, base_dir)
             if resolved is not None:
                 out["optuna_storage"] = str(resolved)

     best_files = out.get("best_params_files")
     if isinstance(best_files, dict):
         resolved_map: Dict[str, str] = {}
         for mk, path_str in best_files.items():
             if not isinstance(path_str, str):
                 continue
             resolved = resolve_path(path_str, base_dir)
             resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
         out["best_params_files"] = resolved_map

     return out


 def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
     if value is None:
         return {}
     if isinstance(value, dict):
         return {str(k): v for k, v in value.items()}
     if isinstance(value, str):
         path = resolve_path(value, base_dir)
         if path is None or not path.exists():
             raise FileNotFoundError(f"dtype_map not found: {value}")
         payload = json.loads(path.read_text(encoding="utf-8"))
         if not isinstance(payload, dict):
             raise ValueError(f"dtype_map must be a dict: {path}")
         return {str(k): v for k, v in payload.items()}
     raise ValueError("dtype_map must be a dict or JSON path.")


 def resolve_data_config(
     cfg: Dict[str, Any],
     config_path: Path,
     *,
     create_data_dir: bool = False,
 ) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
     base_dir = config_path.parent
     data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
     if data_dir is None:
         raise ValueError("data_dir is required in config.json.")
     data_format = cfg.get("data_format", "csv")
     data_path_template = cfg.get("data_path_template")
     dtype_map = resolve_dtype_map(cfg.get("dtype_map"), base_dir)
     return data_dir, data_format, data_path_template, dtype_map


 def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
     parser.add_argument(
         "--config-json",
         required=True,
         help=help_text,
     )


 def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
     parser.add_argument(
         "--output-dir",
         default=None,
         help=help_text,
     )


 def resolve_model_path_value(
     value: Any,
     *,
     model_name: str,
     base_dir: Path,
     data_dir: Optional[Path] = None,
 ) -> Optional[Path]:
     if value is None:
         return None
     if isinstance(value, dict):
         value = value.get(model_name)
         if value is None:
             return None
     path_str = str(value)
     try:
         path_str = path_str.format(model_name=model_name)
     except Exception:
         pass
     if data_dir is not None and not Path(path_str).is_absolute():
         candidate = data_dir / path_str
         if candidate.exists():
             return candidate.resolve()
     resolved = resolve_path(path_str, base_dir)
     if resolved is None:
         return None
     return resolved


 def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
     if not value:
         return None
     path_str = str(value)
     resolved = resolve_path(path_str, base_dir)
     return resolved if resolved is not None else Path(path_str)


 def resolve_explain_save_dir(
     save_root: Optional[Path],
     *,
     result_dir: Optional[Any],
 ) -> Path:
     if save_root is not None:
         return Path(save_root)
     if result_dir is None:
         raise ValueError("result_dir is required when explain save_root is not set.")
     return Path(result_dir) / "explain"


 def resolve_explain_output_overrides(
     explain_cfg: Dict[str, Any],
     *,
     model_name: str,
     base_dir: Path,
 ) -> Dict[str, Optional[Path]]:
     return {
         "model_dir": resolve_model_path_value(
             explain_cfg.get("model_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
         "result_dir": resolve_model_path_value(
             explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
         "plot_dir": resolve_model_path_value(
             explain_cfg.get("plot_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
     }


 def resolve_and_load_config(
     raw: str,
     script_dir: Path,
     required_keys: Sequence[str],
     *,
     apply_env: bool = True,
 ) -> Tuple[Path, Dict[str, Any]]:
     config_path = resolve_config_path(raw, script_dir)
     cfg = load_config_json(config_path, required_keys=required_keys)
     cfg = normalize_config_paths(cfg, config_path)
     if apply_env:
         set_env(cfg.get("env", {}))
     return config_path, cfg


 def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     def _as_list(value: Any) -> list[str]:
         if value is None:
             return []
         if isinstance(value, (list, tuple, set)):
             return [str(item) for item in value]
         return [str(value)]

     report_output_dir = cfg.get("report_output_dir")
     report_group_cols = _as_list(cfg.get("report_group_cols"))
     if not report_group_cols:
         report_group_cols = None
     report_time_col = cfg.get("report_time_col")
     report_time_freq = cfg.get("report_time_freq", "M")
     report_time_ascending = bool(cfg.get("report_time_ascending", True))
     psi_bins = cfg.get("psi_bins", 10)
     psi_strategy = cfg.get("psi_strategy", "quantile")
     psi_features = _as_list(cfg.get("psi_features"))
     if not psi_features:
         psi_features = None
     calibration_cfg = cfg.get("calibration", {}) or {}
     threshold_cfg = cfg.get("threshold", {}) or {}
     bootstrap_cfg = cfg.get("bootstrap", {}) or {}
     register_model = bool(cfg.get("register_model", False))
     registry_path = cfg.get("registry_path")
     registry_tags = cfg.get("registry_tags", {})
     registry_status = cfg.get("registry_status", "candidate")
     data_fingerprint_max_bytes = int(
         cfg.get("data_fingerprint_max_bytes", 10_485_760))
     calibration_enabled = bool(
         calibration_cfg.get("enable", False) or calibration_cfg.get("method")
     )
     threshold_enabled = bool(
         threshold_cfg.get("enable", False)
         or threshold_cfg.get("value") is not None
         or threshold_cfg.get("metric")
     )
     bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
     report_enabled = any([
         bool(report_output_dir),
         register_model,
         bool(report_group_cols),
         bool(report_time_col),
         bool(psi_features),
         calibration_enabled,
         threshold_enabled,
         bootstrap_enabled,
     ])
     return {
         "report_output_dir": report_output_dir,
         "report_group_cols": report_group_cols,
         "report_time_col": report_time_col,
         "report_time_freq": report_time_freq,
         "report_time_ascending": report_time_ascending,
         "psi_bins": psi_bins,
         "psi_strategy": psi_strategy,
         "psi_features": psi_features,
         "calibration_cfg": calibration_cfg,
         "threshold_cfg": threshold_cfg,
         "bootstrap_cfg": bootstrap_cfg,
         "register_model": register_model,
         "registry_path": registry_path,
         "registry_tags": registry_tags,
         "registry_status": registry_status,
         "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
         "report_enabled": report_enabled,
     }


 def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     prop_test = cfg.get("prop_test", 0.25)
     holdout_ratio = cfg.get("holdout_ratio", prop_test)
     if holdout_ratio is None:
         holdout_ratio = prop_test
     val_ratio = cfg.get("val_ratio", prop_test)
     if val_ratio is None:
         val_ratio = prop_test
     split_strategy = str(cfg.get("split_strategy", "random")).strip().lower()
     split_group_col = cfg.get("split_group_col")
     split_time_col = cfg.get("split_time_col")
     split_time_ascending = bool(cfg.get("split_time_ascending", True))
     cv_strategy = cfg.get("cv_strategy")
     cv_group_col = cfg.get("cv_group_col")
     cv_time_col = cfg.get("cv_time_col")
     cv_time_ascending = cfg.get("cv_time_ascending", split_time_ascending)
     cv_splits = cfg.get("cv_splits")
     ft_oof_folds = cfg.get("ft_oof_folds")
     ft_oof_strategy = cfg.get("ft_oof_strategy")
     ft_oof_shuffle = cfg.get("ft_oof_shuffle", True)
     return {
         "prop_test": prop_test,
         "holdout_ratio": holdout_ratio,
         "val_ratio": val_ratio,
         "split_strategy": split_strategy,
         "split_group_col": split_group_col,
         "split_time_col": split_time_col,
         "split_time_ascending": split_time_ascending,
         "cv_strategy": cv_strategy,
         "cv_group_col": cv_group_col,
         "cv_time_col": cv_time_col,
         "cv_time_ascending": cv_time_ascending,
         "cv_splits": cv_splits,
         "ft_oof_folds": ft_oof_folds,
         "ft_oof_strategy": ft_oof_strategy,
         "ft_oof_shuffle": ft_oof_shuffle,
     }


 def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return {
         "save_preprocess": bool(cfg.get("save_preprocess", False)),
         "preprocess_artifact_path": cfg.get("preprocess_artifact_path"),
         "rand_seed": cfg.get("rand_seed", 13),
         "epochs": cfg.get("epochs", 50),
         "plot_path_style": cfg.get("plot_path_style"),
         "reuse_best_params": bool(cfg.get("reuse_best_params", False)),
         "xgb_max_depth_max": int(cfg.get("xgb_max_depth_max", 25)),
         "xgb_n_estimators_max": int(cfg.get("xgb_n_estimators_max", 500)),
         "optuna_storage": cfg.get("optuna_storage"),
         "optuna_study_prefix": cfg.get("optuna_study_prefix"),
         "best_params_files": cfg.get("best_params_files"),
         "bo_sample_limit": cfg.get("bo_sample_limit"),
         "cache_predictions": bool(cfg.get("cache_predictions", False)),
         "prediction_cache_dir": cfg.get("prediction_cache_dir"),
         "prediction_cache_format": cfg.get("prediction_cache_format", "parquet"),
         "ddp_min_rows": cfg.get("ddp_min_rows", 50000),
     }


 def resolve_output_dirs(
     cfg: Dict[str, Any],
     config_path: Path,
     *,
     output_override: Optional[str] = None,
 ) -> Dict[str, Optional[str]]:
     output_root = resolve_dir_path(
         output_override or cfg.get("output_dir"),
         config_path.parent,
     )
     return {
         "output_dir": str(output_root) if output_root is not None else None,
     }
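
For context on how the config helpers diffed above are meant to be driven, a hypothetical end-to-end sketch follows. The keys mirror those read by the module (`data_dir`, `output_dir`, `optuna_storage`, `best_params_files`, `env`); the file name, values, and the `ins_pricing.cli.utils.cli_config` import path (inferred from the file list) are illustrative assumptions, not taken from the package's documentation.

# Hypothetical usage sketch for the 0.5.0 helpers shown above; values are made up.
import json
from pathlib import Path

from ins_pricing.cli.utils.cli_config import resolve_and_load_config

config = {
    "data_dir": "./data",                      # relative paths resolve against config.json's folder
    "data_format": "csv",
    "output_dir": "./outputs",
    "optuna_storage": "sqlite:///optuna.db",   # URL-like values ("://") are left untouched
    "best_params_files": {"xgb": "best_params/xgb.json"},
    "env": {                                   # applied via os.environ.setdefault,
        "TORCH_DISTRIBUTED_DEBUG": "DETAIL",   # so values already set in the shell win
        "NCCL_DEBUG": "INFO",
    },
}
Path("config.json").write_text(json.dumps(config, indent=2), encoding="utf-8")

config_path, cfg = resolve_and_load_config(
    "config.json",
    Path.cwd(),
    required_keys=("data_dir", "output_dir"),
)
print(cfg["data_dir"])  # resolved against the config's directory by normalize_config_paths
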