ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +39 -105
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +11 -9
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/__init__.py +10 -10
  15. ins_pricing/frontend/example_workflows.py +1 -1
  16. ins_pricing/governance/__init__.py +20 -20
  17. ins_pricing/governance/release.py +159 -159
  18. ins_pricing/modelling/__init__.py +147 -92
  19. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
  20. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  21. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
  22. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  29. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  36. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  37. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  38. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  39. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  40. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  42. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  43. ins_pricing/modelling/explain/__init__.py +55 -55
  44. ins_pricing/modelling/explain/metrics.py +27 -174
  45. ins_pricing/modelling/explain/permutation.py +237 -237
  46. ins_pricing/modelling/plotting/__init__.py +40 -36
  47. ins_pricing/modelling/plotting/compat.py +228 -0
  48. ins_pricing/modelling/plotting/curves.py +572 -572
  49. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  50. ins_pricing/modelling/plotting/geo.py +362 -362
  51. ins_pricing/modelling/plotting/importance.py +121 -121
  52. ins_pricing/pricing/__init__.py +27 -27
  53. ins_pricing/production/__init__.py +35 -25
  54. ins_pricing/production/{predict.py → inference.py} +140 -57
  55. ins_pricing/production/monitoring.py +8 -21
  56. ins_pricing/reporting/__init__.py +11 -11
  57. ins_pricing/setup.py +1 -1
  58. ins_pricing/tests/production/test_inference.py +90 -0
  59. ins_pricing/utils/__init__.py +116 -83
  60. ins_pricing/utils/device.py +255 -255
  61. ins_pricing/utils/features.py +53 -0
  62. ins_pricing/utils/io.py +72 -0
  63. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  64. ins_pricing/utils/metrics.py +158 -24
  65. ins_pricing/utils/numerics.py +76 -0
  66. ins_pricing/utils/paths.py +9 -1
  67. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
  68. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  69. ins_pricing/modelling/core/BayesOpt.py +0 -146
  70. ins_pricing/modelling/core/__init__.py +0 -1
  71. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  72. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  73. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  74. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  75. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  76. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  77. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  78. ins_pricing/tests/production/test_predict.py +0 -233
  79. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  80. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  81. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  82. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  83. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  84. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
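Most of the per-file changes above are moves: the bayesopt subpackage is promoted from ins_pricing/modelling/core/bayesopt to ins_pricing/modelling/bayesopt, and ins_pricing/production/predict.py becomes ins_pricing/production/inference.py. A minimal migration sketch for downstream imports follows; it assumes modules are addressed by the file paths listed above (public re-exports are not visible in this diff):

# Hypothetical migration sketch for downstream code; module paths follow the
# renames listed above, not a documented public API.

# ins-pricing 0.4.5
# from ins_pricing.modelling.core.bayesopt import core as bayesopt_core
# from ins_pricing.production import predict as production_predict

# ins-pricing 0.5.0
from ins_pricing.modelling.bayesopt import core as bayesopt_core
from ins_pricing.production import inference as production_inference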
ins_pricing/cli/utils/cli_config.py
@@ -1,375 +1,375 @@
 from __future__ import annotations

 import argparse
 import json
 import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Sequence, Tuple

 try:
-    from .cli_common import resolve_dir_path, resolve_path  # type: ignore
+    from ins_pricing.cli.utils.cli_common import resolve_dir_path, resolve_path  # type: ignore
 except Exception:  # pragma: no cover
     from cli_common import resolve_dir_path, resolve_path  # type: ignore


 def resolve_config_path(raw: str, script_dir: Path) -> Path:
     candidate = Path(raw)
     if candidate.exists():
         return candidate.resolve()
     candidate2 = (script_dir / raw)
     if candidate2.exists():
         return candidate2.resolve()
     raise FileNotFoundError(
         f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
     )


 def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
     cfg = json.loads(path.read_text(encoding="utf-8"))
     missing = [key for key in required_keys if key not in cfg]
     if missing:
         raise ValueError(f"Missing required keys in {path}: {missing}")
     return cfg


 def set_env(env_overrides: Dict[str, Any]) -> None:
     """Apply environment variables from config.json.

     Notes (DDP/Optuna hang debugging):
     - You can add these keys into config.json's `env` to debug distributed hangs:
       - `TORCH_DISTRIBUTED_DEBUG=DETAIL`
       - `NCCL_DEBUG=INFO`
       - `BAYESOPT_DDP_BARRIER_DEBUG=1`
       - `BAYESOPT_DDP_BARRIER_TIMEOUT=300`
       - `BAYESOPT_CUDA_SYNC=1` (optional; can slow down)
       - `BAYESOPT_CUDA_IPC_COLLECT=1` (optional; can slow down)
     - This function uses `os.environ.setdefault`, so a value already set in the
       shell will take precedence over config.json.
     """
     for key, value in (env_overrides or {}).items():
         os.environ.setdefault(str(key), str(value))


 def _looks_like_url(value: str) -> bool:
     value = str(value)
     return "://" in value


 def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
     """Resolve relative paths against the config.json directory.

     Fields handled:
     - data_dir / output_dir / optuna_storage / gnn_graph_cache
     - best_params_files (dict: model_key -> path)
     """
     base_dir = config_path.parent
     out = dict(cfg)

     for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
                 "prediction_cache_dir", "report_output_dir", "registry_path"):
         if key in out and isinstance(out.get(key), str):
             resolved = resolve_path(out.get(key), base_dir)
             if resolved is not None:
                 out[key] = str(resolved)

     storage = out.get("optuna_storage")
     if isinstance(storage, str) and storage.strip():
         if not _looks_like_url(storage):
             resolved = resolve_path(storage, base_dir)
             if resolved is not None:
                 out["optuna_storage"] = str(resolved)

     best_files = out.get("best_params_files")
     if isinstance(best_files, dict):
         resolved_map: Dict[str, str] = {}
         for mk, path_str in best_files.items():
             if not isinstance(path_str, str):
                 continue
             resolved = resolve_path(path_str, base_dir)
             resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
         out["best_params_files"] = resolved_map

     return out


 def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
     if value is None:
         return {}
     if isinstance(value, dict):
         return {str(k): v for k, v in value.items()}
     if isinstance(value, str):
         path = resolve_path(value, base_dir)
         if path is None or not path.exists():
             raise FileNotFoundError(f"dtype_map not found: {value}")
         payload = json.loads(path.read_text(encoding="utf-8"))
         if not isinstance(payload, dict):
             raise ValueError(f"dtype_map must be a dict: {path}")
         return {str(k): v for k, v in payload.items()}
     raise ValueError("dtype_map must be a dict or JSON path.")


 def resolve_data_config(
     cfg: Dict[str, Any],
     config_path: Path,
     *,
     create_data_dir: bool = False,
 ) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
     base_dir = config_path.parent
     data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
     if data_dir is None:
         raise ValueError("data_dir is required in config.json.")
     data_format = cfg.get("data_format", "csv")
     data_path_template = cfg.get("data_path_template")
     dtype_map = resolve_dtype_map(cfg.get("dtype_map"), base_dir)
     return data_dir, data_format, data_path_template, dtype_map


 def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
     parser.add_argument(
         "--config-json",
         required=True,
         help=help_text,
     )


 def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
     parser.add_argument(
         "--output-dir",
         default=None,
         help=help_text,
     )


 def resolve_model_path_value(
     value: Any,
     *,
     model_name: str,
     base_dir: Path,
     data_dir: Optional[Path] = None,
 ) -> Optional[Path]:
     if value is None:
         return None
     if isinstance(value, dict):
         value = value.get(model_name)
         if value is None:
             return None
     path_str = str(value)
     try:
         path_str = path_str.format(model_name=model_name)
     except Exception:
         pass
     if data_dir is not None and not Path(path_str).is_absolute():
         candidate = data_dir / path_str
         if candidate.exists():
             return candidate.resolve()
     resolved = resolve_path(path_str, base_dir)
     if resolved is None:
         return None
     return resolved


 def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
     if not value:
         return None
     path_str = str(value)
     resolved = resolve_path(path_str, base_dir)
     return resolved if resolved is not None else Path(path_str)


 def resolve_explain_save_dir(
     save_root: Optional[Path],
     *,
     result_dir: Optional[Any],
 ) -> Path:
     if save_root is not None:
         return Path(save_root)
     if result_dir is None:
         raise ValueError("result_dir is required when explain save_root is not set.")
     return Path(result_dir) / "explain"


 def resolve_explain_output_overrides(
     explain_cfg: Dict[str, Any],
     *,
     model_name: str,
     base_dir: Path,
 ) -> Dict[str, Optional[Path]]:
     return {
         "model_dir": resolve_model_path_value(
             explain_cfg.get("model_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
         "result_dir": resolve_model_path_value(
             explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
         "plot_dir": resolve_model_path_value(
             explain_cfg.get("plot_dir"),
             model_name=model_name,
             base_dir=base_dir,
             data_dir=None,
         ),
     }


 def resolve_and_load_config(
     raw: str,
     script_dir: Path,
     required_keys: Sequence[str],
     *,
     apply_env: bool = True,
 ) -> Tuple[Path, Dict[str, Any]]:
     config_path = resolve_config_path(raw, script_dir)
     cfg = load_config_json(config_path, required_keys=required_keys)
     cfg = normalize_config_paths(cfg, config_path)
     if apply_env:
         set_env(cfg.get("env", {}))
     return config_path, cfg


 def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     def _as_list(value: Any) -> list[str]:
         if value is None:
             return []
         if isinstance(value, (list, tuple, set)):
             return [str(item) for item in value]
         return [str(value)]

     report_output_dir = cfg.get("report_output_dir")
     report_group_cols = _as_list(cfg.get("report_group_cols"))
     if not report_group_cols:
         report_group_cols = None
     report_time_col = cfg.get("report_time_col")
     report_time_freq = cfg.get("report_time_freq", "M")
     report_time_ascending = bool(cfg.get("report_time_ascending", True))
     psi_bins = cfg.get("psi_bins", 10)
     psi_strategy = cfg.get("psi_strategy", "quantile")
     psi_features = _as_list(cfg.get("psi_features"))
     if not psi_features:
         psi_features = None
     calibration_cfg = cfg.get("calibration", {}) or {}
     threshold_cfg = cfg.get("threshold", {}) or {}
     bootstrap_cfg = cfg.get("bootstrap", {}) or {}
     register_model = bool(cfg.get("register_model", False))
     registry_path = cfg.get("registry_path")
     registry_tags = cfg.get("registry_tags", {})
     registry_status = cfg.get("registry_status", "candidate")
     data_fingerprint_max_bytes = int(
         cfg.get("data_fingerprint_max_bytes", 10_485_760))
     calibration_enabled = bool(
         calibration_cfg.get("enable", False) or calibration_cfg.get("method")
     )
     threshold_enabled = bool(
         threshold_cfg.get("enable", False)
         or threshold_cfg.get("value") is not None
         or threshold_cfg.get("metric")
     )
     bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
     report_enabled = any([
         bool(report_output_dir),
         register_model,
         bool(report_group_cols),
         bool(report_time_col),
         bool(psi_features),
         calibration_enabled,
         threshold_enabled,
         bootstrap_enabled,
     ])
     return {
         "report_output_dir": report_output_dir,
         "report_group_cols": report_group_cols,
         "report_time_col": report_time_col,
         "report_time_freq": report_time_freq,
         "report_time_ascending": report_time_ascending,
         "psi_bins": psi_bins,
         "psi_strategy": psi_strategy,
         "psi_features": psi_features,
         "calibration_cfg": calibration_cfg,
         "threshold_cfg": threshold_cfg,
         "bootstrap_cfg": bootstrap_cfg,
         "register_model": register_model,
         "registry_path": registry_path,
         "registry_tags": registry_tags,
         "registry_status": registry_status,
         "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
         "report_enabled": report_enabled,
     }


 def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     prop_test = cfg.get("prop_test", 0.25)
     holdout_ratio = cfg.get("holdout_ratio", prop_test)
     if holdout_ratio is None:
         holdout_ratio = prop_test
     val_ratio = cfg.get("val_ratio", prop_test)
     if val_ratio is None:
         val_ratio = prop_test
     split_strategy = str(cfg.get("split_strategy", "random")).strip().lower()
     split_group_col = cfg.get("split_group_col")
     split_time_col = cfg.get("split_time_col")
     split_time_ascending = bool(cfg.get("split_time_ascending", True))
     cv_strategy = cfg.get("cv_strategy")
     cv_group_col = cfg.get("cv_group_col")
     cv_time_col = cfg.get("cv_time_col")
     cv_time_ascending = cfg.get("cv_time_ascending", split_time_ascending)
     cv_splits = cfg.get("cv_splits")
     ft_oof_folds = cfg.get("ft_oof_folds")
     ft_oof_strategy = cfg.get("ft_oof_strategy")
     ft_oof_shuffle = cfg.get("ft_oof_shuffle", True)
     return {
         "prop_test": prop_test,
         "holdout_ratio": holdout_ratio,
         "val_ratio": val_ratio,
         "split_strategy": split_strategy,
         "split_group_col": split_group_col,
         "split_time_col": split_time_col,
         "split_time_ascending": split_time_ascending,
         "cv_strategy": cv_strategy,
         "cv_group_col": cv_group_col,
         "cv_time_col": cv_time_col,
         "cv_time_ascending": cv_time_ascending,
         "cv_splits": cv_splits,
         "ft_oof_folds": ft_oof_folds,
         "ft_oof_strategy": ft_oof_strategy,
         "ft_oof_shuffle": ft_oof_shuffle,
     }


 def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return {
         "save_preprocess": bool(cfg.get("save_preprocess", False)),
         "preprocess_artifact_path": cfg.get("preprocess_artifact_path"),
         "rand_seed": cfg.get("rand_seed", 13),
         "epochs": cfg.get("epochs", 50),
         "plot_path_style": cfg.get("plot_path_style"),
         "reuse_best_params": bool(cfg.get("reuse_best_params", False)),
         "xgb_max_depth_max": int(cfg.get("xgb_max_depth_max", 25)),
         "xgb_n_estimators_max": int(cfg.get("xgb_n_estimators_max", 500)),
         "optuna_storage": cfg.get("optuna_storage"),
         "optuna_study_prefix": cfg.get("optuna_study_prefix"),
         "best_params_files": cfg.get("best_params_files"),
         "bo_sample_limit": cfg.get("bo_sample_limit"),
         "cache_predictions": bool(cfg.get("cache_predictions", False)),
         "prediction_cache_dir": cfg.get("prediction_cache_dir"),
         "prediction_cache_format": cfg.get("prediction_cache_format", "parquet"),
         "ddp_min_rows": cfg.get("ddp_min_rows", 50000),
     }


 def resolve_output_dirs(
     cfg: Dict[str, Any],
     config_path: Path,
     *,
     output_override: Optional[str] = None,
 ) -> Dict[str, Optional[str]]:
     output_root = resolve_dir_path(
         output_override or cfg.get("output_dir"),
         config_path.parent,
     )
     return {
         "output_dir": str(output_root) if output_root is not None else None,
     }
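
The hunk above amounts to a single change: cli_config.py now imports cli_common by its absolute package path instead of a relative one; the helper functions themselves are unchanged. A rough usage sketch of those helpers follows; the config.json contents and directory names are illustrative assumptions, not values shipped with the package beyond the defaults visible in the code above.

# Illustrative sketch of driving the cli_config helpers shown above; the
# config.json written here is a made-up example, not part of the package.
import json
from pathlib import Path

from ins_pricing.cli.utils.cli_config import (
    resolve_and_load_config,
    resolve_data_config,
    resolve_runtime_config,
    resolve_split_config,
)

run_dir = Path("example_run")                      # hypothetical working directory
(run_dir / "data").mkdir(parents=True, exist_ok=True)
(run_dir / "config.json").write_text(json.dumps({
    "data_dir": "data",                            # resolved relative to config.json
    "data_format": "csv",
    "split_strategy": "random",
    "prop_test": 0.25,
    "env": {"NCCL_DEBUG": "INFO"},                 # applied via os.environ.setdefault
}), encoding="utf-8")

# resolve_and_load_config() resolves the config path, validates required keys,
# normalizes relative paths and applies the `env` block.
config_path, cfg = resolve_and_load_config(
    "config.json", run_dir, required_keys=["data_dir"],
)
data_dir, data_format, path_template, dtype_map = resolve_data_config(cfg, config_path)
split_cfg = resolve_split_config(cfg)       # prop_test / holdout_ratio / cv_* keys
runtime_cfg = resolve_runtime_config(cfg)   # rand_seed, epochs, optuna settings, ...
print(data_dir, data_format, split_cfg["split_strategy"], runtime_cfg["epochs"])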