ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
ins_pricing/cli/utils/cli_config.py
@@ -1,345 +1,353 @@
- from __future__ import annotations
-
- import argparse
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, Optional, Sequence, Tuple
-
- try:
-     from .cli_common import resolve_dir_path, resolve_path  # type: ignore
- except Exception:  # pragma: no cover
-     from cli_common import resolve_dir_path, resolve_path  # type: ignore
-
-
- def resolve_config_path(raw: str, script_dir: Path) -> Path:
-     candidate = Path(raw)
-     if candidate.exists():
-         return candidate.resolve()
-     candidate2 = (script_dir / raw)
-     if candidate2.exists():
-         return candidate2.resolve()
-     raise FileNotFoundError(
-         f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
-     )
-
-
- def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
-     cfg = json.loads(path.read_text(encoding="utf-8"))
-     missing = [key for key in required_keys if key not in cfg]
-     if missing:
-         raise ValueError(f"Missing required keys in {path}: {missing}")
-     return cfg
-
-
- def set_env(env_overrides: Dict[str, Any]) -> None:
-     """Apply environment variables from config.json.
-
-     Notes (DDP/Optuna hang debugging):
-     - You can add these keys into config.json's `env` to debug distributed hangs:
-       - `TORCH_DISTRIBUTED_DEBUG=DETAIL`
-       - `NCCL_DEBUG=INFO`
-       - `BAYESOPT_DDP_BARRIER_DEBUG=1`
-       - `BAYESOPT_DDP_BARRIER_TIMEOUT=300`
-       - `BAYESOPT_CUDA_SYNC=1` (optional; can slow down)
-       - `BAYESOPT_CUDA_IPC_COLLECT=1` (optional; can slow down)
-     - This function uses `os.environ.setdefault`, so a value already set in the
-       shell will take precedence over config.json.
-     """
-     for key, value in (env_overrides or {}).items():
-         os.environ.setdefault(str(key), str(value))
-
-
- def _looks_like_url(value: str) -> bool:
-     value = str(value)
-     return "://" in value
-
-
- def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
-     """Resolve relative paths against the config.json directory.
-
-     Fields handled:
-     - data_dir / output_dir / optuna_storage / gnn_graph_cache
-     - best_params_files (dict: model_key -> path)
-     """
-     base_dir = config_path.parent
-     out = dict(cfg)
-
-     for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
-                 "prediction_cache_dir", "report_output_dir", "registry_path"):
-         if key in out and isinstance(out.get(key), str):
-             resolved = resolve_path(out.get(key), base_dir)
-             if resolved is not None:
-                 out[key] = str(resolved)
-
-     storage = out.get("optuna_storage")
-     if isinstance(storage, str) and storage.strip():
-         if not _looks_like_url(storage):
-             resolved = resolve_path(storage, base_dir)
-             if resolved is not None:
-                 out["optuna_storage"] = str(resolved)
-
-     best_files = out.get("best_params_files")
-     if isinstance(best_files, dict):
-         resolved_map: Dict[str, str] = {}
-         for mk, path_str in best_files.items():
-             if not isinstance(path_str, str):
-                 continue
-             resolved = resolve_path(path_str, base_dir)
-             resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
-         out["best_params_files"] = resolved_map
-
-     return out
-
-
- def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
-     if value is None:
-         return {}
-     if isinstance(value, dict):
-         return {str(k): v for k, v in value.items()}
-     if isinstance(value, str):
-         path = resolve_path(value, base_dir)
-         if path is None or not path.exists():
-             raise FileNotFoundError(f"dtype_map not found: {value}")
-         payload = json.loads(path.read_text(encoding="utf-8"))
-         if not isinstance(payload, dict):
-             raise ValueError(f"dtype_map must be a dict: {path}")
-         return {str(k): v for k, v in payload.items()}
-     raise ValueError("dtype_map must be a dict or JSON path.")
-
-
- def resolve_data_config(
-     cfg: Dict[str, Any],
-     config_path: Path,
-     *,
-     create_data_dir: bool = False,
- ) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
-     base_dir = config_path.parent
-     data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
-     if data_dir is None:
-         raise ValueError("data_dir is required in config.json.")
-     data_format = cfg.get("data_format", "csv")
-     data_path_template = cfg.get("data_path_template")
-     dtype_map = resolve_dtype_map(cfg.get("dtype_map"), base_dir)
-     return data_dir, data_format, data_path_template, dtype_map
-
-
- def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
-     parser.add_argument(
-         "--config-json",
-         required=True,
-         help=help_text,
-     )
-
-
- def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
-     parser.add_argument(
-         "--output-dir",
-         default=None,
-         help=help_text,
-     )
-
-
- def resolve_model_path_value(
-     value: Any,
-     *,
-     model_name: str,
-     base_dir: Path,
-     data_dir: Optional[Path] = None,
- ) -> Optional[Path]:
-     if value is None:
-         return None
-     if isinstance(value, dict):
-         value = value.get(model_name)
-         if value is None:
-             return None
-     path_str = str(value)
-     try:
-         path_str = path_str.format(model_name=model_name)
-     except Exception:
-         pass
-     if data_dir is not None and not Path(path_str).is_absolute():
-         candidate = data_dir / path_str
-         if candidate.exists():
-             return candidate.resolve()
-     resolved = resolve_path(path_str, base_dir)
-     if resolved is None:
-         return None
-     return resolved
-
-
- def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
-     if not value:
-         return None
-     path_str = str(value)
-     resolved = resolve_path(path_str, base_dir)
-     return resolved if resolved is not None else Path(path_str)
-
-
- def resolve_explain_save_dir(
-     save_root: Optional[Path],
-     *,
-     result_dir: Optional[Any],
- ) -> Path:
-     if save_root is not None:
-         return Path(save_root)
-     if result_dir is None:
-         raise ValueError("result_dir is required when explain save_root is not set.")
-     return Path(result_dir) / "explain"
-
-
- def resolve_explain_output_overrides(
-     explain_cfg: Dict[str, Any],
-     *,
-     model_name: str,
-     base_dir: Path,
- ) -> Dict[str, Optional[Path]]:
-     return {
-         "model_dir": resolve_model_path_value(
-             explain_cfg.get("model_dir"),
-             model_name=model_name,
-             base_dir=base_dir,
-             data_dir=None,
-         ),
-         "result_dir": resolve_model_path_value(
-             explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
-             model_name=model_name,
-             base_dir=base_dir,
-             data_dir=None,
-         ),
-         "plot_dir": resolve_model_path_value(
-             explain_cfg.get("plot_dir"),
-             model_name=model_name,
-             base_dir=base_dir,
-             data_dir=None,
-         ),
-     }
-
-
- def resolve_and_load_config(
-     raw: str,
-     script_dir: Path,
-     required_keys: Sequence[str],
-     *,
-     apply_env: bool = True,
- ) -> Tuple[Path, Dict[str, Any]]:
-     config_path = resolve_config_path(raw, script_dir)
-     cfg = load_config_json(config_path, required_keys=required_keys)
-     cfg = normalize_config_paths(cfg, config_path)
-     if apply_env:
-         set_env(cfg.get("env", {}))
-     return config_path, cfg
-
-
- def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
-     def _as_list(value: Any) -> list[str]:
-         if value is None:
-             return []
-         if isinstance(value, (list, tuple, set)):
-             return [str(item) for item in value]
-         return [str(value)]
-
-     report_output_dir = cfg.get("report_output_dir")
-     report_group_cols = _as_list(cfg.get("report_group_cols"))
-     if not report_group_cols:
-         report_group_cols = None
-     report_time_col = cfg.get("report_time_col")
-     report_time_freq = cfg.get("report_time_freq", "M")
-     report_time_ascending = bool(cfg.get("report_time_ascending", True))
-     psi_bins = cfg.get("psi_bins", 10)
-     psi_strategy = cfg.get("psi_strategy", "quantile")
-     psi_features = _as_list(cfg.get("psi_features"))
-     if not psi_features:
-         psi_features = None
-     calibration_cfg = cfg.get("calibration", {}) or {}
-     threshold_cfg = cfg.get("threshold", {}) or {}
-     bootstrap_cfg = cfg.get("bootstrap", {}) or {}
-     register_model = bool(cfg.get("register_model", False))
-     registry_path = cfg.get("registry_path")
-     registry_tags = cfg.get("registry_tags", {})
-     registry_status = cfg.get("registry_status", "candidate")
-     data_fingerprint_max_bytes = int(
-         cfg.get("data_fingerprint_max_bytes", 10_485_760))
-     calibration_enabled = bool(
-         calibration_cfg.get("enable", False) or calibration_cfg.get("method")
-     )
-     threshold_enabled = bool(
-         threshold_cfg.get("enable", False)
-         or threshold_cfg.get("value") is not None
-         or threshold_cfg.get("metric")
-     )
-     bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
-     report_enabled = any([
-         bool(report_output_dir),
-         register_model,
-         bool(report_group_cols),
-         bool(report_time_col),
-         bool(psi_features),
-         calibration_enabled,
-         threshold_enabled,
-         bootstrap_enabled,
-     ])
-     return {
-         "report_output_dir": report_output_dir,
-         "report_group_cols": report_group_cols,
-         "report_time_col": report_time_col,
-         "report_time_freq": report_time_freq,
-         "report_time_ascending": report_time_ascending,
-         "psi_bins": psi_bins,
-         "psi_strategy": psi_strategy,
-         "psi_features": psi_features,
-         "calibration_cfg": calibration_cfg,
-         "threshold_cfg": threshold_cfg,
-         "bootstrap_cfg": bootstrap_cfg,
-         "register_model": register_model,
-         "registry_path": registry_path,
-         "registry_tags": registry_tags,
-         "registry_status": registry_status,
-         "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
-         "report_enabled": report_enabled,
-     }
-
-
- def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
-     prop_test = cfg.get("prop_test", 0.25)
-     holdout_ratio = cfg.get("holdout_ratio", prop_test)
-     if holdout_ratio is None:
-         holdout_ratio = prop_test
-     val_ratio = cfg.get("val_ratio", prop_test)
-     if val_ratio is None:
-         val_ratio = prop_test
-     split_strategy = str(cfg.get("split_strategy", "random")).strip().lower()
-     split_group_col = cfg.get("split_group_col")
-     split_time_col = cfg.get("split_time_col")
-     split_time_ascending = bool(cfg.get("split_time_ascending", True))
-     cv_strategy = cfg.get("cv_strategy")
-     cv_group_col = cfg.get("cv_group_col")
-     cv_time_col = cfg.get("cv_time_col")
-     cv_time_ascending = cfg.get("cv_time_ascending", split_time_ascending)
-     cv_splits = cfg.get("cv_splits")
-     ft_oof_folds = cfg.get("ft_oof_folds")
-     ft_oof_strategy = cfg.get("ft_oof_strategy")
-     ft_oof_shuffle = cfg.get("ft_oof_shuffle", True)
-     return {
-         "prop_test": prop_test,
-         "holdout_ratio": holdout_ratio,
-         "val_ratio": val_ratio,
-         "split_strategy": split_strategy,
-         "split_group_col": split_group_col,
-         "split_time_col": split_time_col,
-         "split_time_ascending": split_time_ascending,
-         "cv_strategy": cv_strategy,
-         "cv_group_col": cv_group_col,
-         "cv_time_col": cv_time_col,
-         "cv_time_ascending": cv_time_ascending,
-         "cv_splits": cv_splits,
-         "ft_oof_folds": ft_oof_folds,
-         "ft_oof_strategy": ft_oof_strategy,
-         "ft_oof_shuffle": ft_oof_shuffle,
-     }
-
-
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Sequence, Tuple
+
+ try:
+     from ins_pricing.cli.utils.cli_common import resolve_dir_path, resolve_path  # type: ignore
+ except Exception:  # pragma: no cover
+     from cli_common import resolve_dir_path, resolve_path  # type: ignore
+
+
+ def resolve_config_path(raw: str, script_dir: Path) -> Path:
+     candidate = Path(raw)
+     if candidate.exists():
+         return candidate.resolve()
+     candidate2 = (script_dir / raw)
+     if candidate2.exists():
+         return candidate2.resolve()
+     raise FileNotFoundError(
+         f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
+     )
+
+
+ def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
+     cfg = json.loads(path.read_text(encoding="utf-8"))
+     missing = [key for key in required_keys if key not in cfg]
+     if missing:
+         raise ValueError(f"Missing required keys in {path}: {missing}")
+     return cfg
+
+
+ def set_env(env_overrides: Dict[str, Any]) -> None:
+     """Apply environment variables from config.json.
+
+     Notes (DDP/Optuna hang debugging):
+     - You can add these keys into config.json's `env` to debug distributed hangs:
+       - `TORCH_DISTRIBUTED_DEBUG=DETAIL`
+       - `NCCL_DEBUG=INFO`
+       - `BAYESOPT_DDP_BARRIER_DEBUG=1`
+       - `BAYESOPT_DDP_BARRIER_TIMEOUT=300`
+       - `BAYESOPT_CUDA_SYNC=1` (optional; can slow down)
+       - `BAYESOPT_CUDA_IPC_COLLECT=1` (optional; can slow down)
+     - This function uses `os.environ.setdefault`, so a value already set in the
+       shell will take precedence over config.json.
+     """
+     for key, value in (env_overrides or {}).items():
+         os.environ.setdefault(str(key), str(value))
+
+
+ def _looks_like_url(value: str) -> bool:
+     value = str(value)
+     return "://" in value
+
+
+ def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
+     """Resolve relative paths against the config.json directory.
+
+     Fields handled:
+     - data_dir / output_dir / optuna_storage / gnn_graph_cache
+     - best_params_files (dict: model_key -> path)
+     """
+     base_dir = config_path.parent
+     out = dict(cfg)
+
+     for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
+                 "prediction_cache_dir", "report_output_dir", "registry_path"):
+         if key in out and isinstance(out.get(key), str):
+             resolved = resolve_path(out.get(key), base_dir)
+             if resolved is not None:
+                 out[key] = str(resolved)
+
+     storage = out.get("optuna_storage")
+     if isinstance(storage, str) and storage.strip():
+         if not _looks_like_url(storage):
+             resolved = resolve_path(storage, base_dir)
+             if resolved is not None:
+                 out["optuna_storage"] = str(resolved)
+
+     best_files = out.get("best_params_files")
+     if isinstance(best_files, dict):
+         resolved_map: Dict[str, str] = {}
+         for mk, path_str in best_files.items():
+             if not isinstance(path_str, str):
+                 continue
+             resolved = resolve_path(path_str, base_dir)
+             resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
+         out["best_params_files"] = resolved_map
+
+     return out
+
+
+ def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
+     if value is None:
+         return {}
+     if isinstance(value, dict):
+         return {str(k): v for k, v in value.items()}
+     if isinstance(value, str):
+         path = resolve_path(value, base_dir)
+         if path is None or not path.exists():
+             raise FileNotFoundError(f"dtype_map not found: {value}")
+         payload = json.loads(path.read_text(encoding="utf-8"))
+         if not isinstance(payload, dict):
+             raise ValueError(f"dtype_map must be a dict: {path}")
+         return {str(k): v for k, v in payload.items()}
+     raise ValueError("dtype_map must be a dict or JSON path.")
+
+
+ def resolve_data_config(
+     cfg: Dict[str, Any],
+     config_path: Path,
+     *,
+     create_data_dir: bool = False,
+ ) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
+     base_dir = config_path.parent
+     data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
+     if data_dir is None:
+         raise ValueError("data_dir is required in config.json.")
+     data_format = cfg.get("data_format", "csv")
+     data_path_template = cfg.get("data_path_template")
+     dtype_map = resolve_dtype_map(cfg.get("dtype_map"), base_dir)
+     return data_dir, data_format, data_path_template, dtype_map
+
+
+ def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
+     parser.add_argument(
+         "--config-json",
+         required=True,
+         help=help_text,
+     )
+
+
+ def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
+     parser.add_argument(
+         "--output-dir",
+         default=None,
+         help=help_text,
+     )
+
+
+ def resolve_model_path_value(
+     value: Any,
+     *,
+     model_name: str,
+     base_dir: Path,
+     data_dir: Optional[Path] = None,
+ ) -> Optional[Path]:
+     if value is None:
+         return None
+     if isinstance(value, dict):
+         value = value.get(model_name)
+         if value is None:
+             return None
+     path_str = str(value)
+     try:
+         path_str = path_str.format(model_name=model_name)
+     except Exception:
+         pass
+     if data_dir is not None and not Path(path_str).is_absolute():
+         candidate = data_dir / path_str
+         if candidate.exists():
+             return candidate.resolve()
+     resolved = resolve_path(path_str, base_dir)
+     if resolved is None:
+         return None
+     return resolved
+
+
+ def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
+     if not value:
+         return None
+     path_str = str(value)
+     resolved = resolve_path(path_str, base_dir)
+     return resolved if resolved is not None else Path(path_str)
+
+
+ def resolve_explain_save_dir(
+     save_root: Optional[Path],
+     *,
+     result_dir: Optional[Any],
+ ) -> Path:
+     if save_root is not None:
+         return Path(save_root)
+     if result_dir is None:
+         raise ValueError("result_dir is required when explain save_root is not set.")
+     return Path(result_dir) / "explain"
+
+
+ def resolve_explain_output_overrides(
+     explain_cfg: Dict[str, Any],
+     *,
+     model_name: str,
+     base_dir: Path,
+ ) -> Dict[str, Optional[Path]]:
+     return {
+         "model_dir": resolve_model_path_value(
+             explain_cfg.get("model_dir"),
+             model_name=model_name,
+             base_dir=base_dir,
+             data_dir=None,
+         ),
+         "result_dir": resolve_model_path_value(
+             explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
+             model_name=model_name,
+             base_dir=base_dir,
+             data_dir=None,
+         ),
+         "plot_dir": resolve_model_path_value(
+             explain_cfg.get("plot_dir"),
+             model_name=model_name,
+             base_dir=base_dir,
+             data_dir=None,
+         ),
+     }
+
+
+ def resolve_and_load_config(
+     raw: str,
+     script_dir: Path,
+     required_keys: Sequence[str],
+     *,
+     apply_env: bool = True,
+ ) -> Tuple[Path, Dict[str, Any]]:
+     config_path = resolve_config_path(raw, script_dir)
+     cfg = load_config_json(config_path, required_keys=required_keys)
+     cfg = normalize_config_paths(cfg, config_path)
+     if apply_env:
+         set_env(cfg.get("env", {}))
+     return config_path, cfg
+
+
+ def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+     def _as_list(value: Any) -> list[str]:
+         if value is None:
+             return []
+         if isinstance(value, (list, tuple, set)):
+             return [str(item) for item in value]
+         return [str(value)]
+
+     report_output_dir = cfg.get("report_output_dir")
+     report_group_cols = _as_list(cfg.get("report_group_cols"))
+     if not report_group_cols:
+         report_group_cols = None
+     report_time_col = cfg.get("report_time_col")
+     report_time_freq = cfg.get("report_time_freq", "M")
+     report_time_ascending = bool(cfg.get("report_time_ascending", True))
+     psi_bins = cfg.get("psi_bins", 10)
+     psi_strategy = cfg.get("psi_strategy", "quantile")
+     psi_features = _as_list(cfg.get("psi_features"))
+     if not psi_features:
+         psi_features = None
+     calibration_cfg = cfg.get("calibration", {}) or {}
+     threshold_cfg = cfg.get("threshold", {}) or {}
+     bootstrap_cfg = cfg.get("bootstrap", {}) or {}
+     register_model = bool(cfg.get("register_model", False))
+     registry_path = cfg.get("registry_path")
+     registry_tags = cfg.get("registry_tags", {})
+     registry_status = cfg.get("registry_status", "candidate")
+     data_fingerprint_max_bytes = int(
+         cfg.get("data_fingerprint_max_bytes", 10_485_760))
+     calibration_enabled = bool(
+         calibration_cfg.get("enable", False) or calibration_cfg.get("method")
+     )
+     threshold_enabled = bool(
+         threshold_cfg.get("enable", False)
+         or threshold_cfg.get("value") is not None
+         or threshold_cfg.get("metric")
+     )
+     bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
+     report_enabled = any([
+         bool(report_output_dir),
+         register_model,
+         bool(report_group_cols),
+         bool(report_time_col),
+         bool(psi_features),
+         calibration_enabled,
+         threshold_enabled,
+         bootstrap_enabled,
+     ])
+     return {
+         "report_output_dir": report_output_dir,
+         "report_group_cols": report_group_cols,
+         "report_time_col": report_time_col,
+         "report_time_freq": report_time_freq,
+         "report_time_ascending": report_time_ascending,
+         "psi_bins": psi_bins,
+         "psi_strategy": psi_strategy,
+         "psi_features": psi_features,
+         "calibration_cfg": calibration_cfg,
+         "threshold_cfg": threshold_cfg,
+         "bootstrap_cfg": bootstrap_cfg,
+         "register_model": register_model,
+         "registry_path": registry_path,
+         "registry_tags": registry_tags,
+         "registry_status": registry_status,
+         "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
+         "report_enabled": report_enabled,
+     }
+
+
+ def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+     prop_test = cfg.get("prop_test", 0.25)
+     holdout_ratio = cfg.get("holdout_ratio", prop_test)
+     if holdout_ratio is None:
+         holdout_ratio = prop_test
+     val_ratio = cfg.get("val_ratio", prop_test)
+     if val_ratio is None:
+         val_ratio = prop_test
+     split_strategy = str(cfg.get("split_strategy", "random")).strip().lower()
+     split_group_col = cfg.get("split_group_col")
+     split_time_col = cfg.get("split_time_col")
+     split_time_ascending = bool(cfg.get("split_time_ascending", True))
+     cv_strategy = cfg.get("cv_strategy")
+     cv_group_col = cfg.get("cv_group_col")
+     cv_time_col = cfg.get("cv_time_col")
+     cv_time_ascending = cfg.get("cv_time_ascending", split_time_ascending)
+     cv_splits = cfg.get("cv_splits")
+     ft_oof_folds = cfg.get("ft_oof_folds")
+     ft_oof_strategy = cfg.get("ft_oof_strategy")
+     ft_oof_shuffle = cfg.get("ft_oof_shuffle", True)
+     return {
+         "prop_test": prop_test,
+         "holdout_ratio": holdout_ratio,
+         "val_ratio": val_ratio,
+         "split_strategy": split_strategy,
+         "split_group_col": split_group_col,
+         "split_time_col": split_time_col,
+         "split_time_ascending": split_time_ascending,
+         "cv_strategy": cv_strategy,
+         "cv_group_col": cv_group_col,
+         "cv_time_col": cv_time_col,
+         "cv_time_ascending": cv_time_ascending,
+         "cv_splits": cv_splits,
+         "ft_oof_folds": ft_oof_folds,
+         "ft_oof_strategy": ft_oof_strategy,
+         "ft_oof_shuffle": ft_oof_shuffle,
+     }
+
+
  def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+     xgb_gpu_id = cfg.get("xgb_gpu_id")
+     if isinstance(xgb_gpu_id, str) and xgb_gpu_id.strip() == "":
+         xgb_gpu_id = None
+     if xgb_gpu_id is not None:
+         try:
+             xgb_gpu_id = int(xgb_gpu_id)
+         except (TypeError, ValueError):
+             xgb_gpu_id = None
      return {
          "save_preprocess": bool(cfg.get("save_preprocess", False)),
          "preprocess_artifact_path": cfg.get("preprocess_artifact_path"),
@@ -349,27 +357,38 @@ def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
          "reuse_best_params": bool(cfg.get("reuse_best_params", False)),
          "xgb_max_depth_max": int(cfg.get("xgb_max_depth_max", 25)),
          "xgb_n_estimators_max": int(cfg.get("xgb_n_estimators_max", 500)),
+         "xgb_gpu_id": xgb_gpu_id,
+         "xgb_cleanup_per_fold": bool(cfg.get("xgb_cleanup_per_fold", False)),
+         "xgb_cleanup_synchronize": bool(cfg.get("xgb_cleanup_synchronize", False)),
+         "xgb_use_dmatrix": bool(cfg.get("xgb_use_dmatrix", True)),
+         "ft_cleanup_per_fold": bool(cfg.get("ft_cleanup_per_fold", False)),
+         "ft_cleanup_synchronize": bool(cfg.get("ft_cleanup_synchronize", False)),
+         "resn_cleanup_per_fold": bool(cfg.get("resn_cleanup_per_fold", False)),
+         "resn_cleanup_synchronize": bool(cfg.get("resn_cleanup_synchronize", False)),
+         "gnn_cleanup_per_fold": bool(cfg.get("gnn_cleanup_per_fold", False)),
+         "gnn_cleanup_synchronize": bool(cfg.get("gnn_cleanup_synchronize", False)),
+         "optuna_cleanup_synchronize": bool(cfg.get("optuna_cleanup_synchronize", False)),
          "optuna_storage": cfg.get("optuna_storage"),
          "optuna_study_prefix": cfg.get("optuna_study_prefix"),
          "best_params_files": cfg.get("best_params_files"),
          "bo_sample_limit": cfg.get("bo_sample_limit"),
          "cache_predictions": bool(cfg.get("cache_predictions", False)),
-         "prediction_cache_dir": cfg.get("prediction_cache_dir"),
-         "prediction_cache_format": cfg.get("prediction_cache_format", "parquet"),
-         "ddp_min_rows": cfg.get("ddp_min_rows", 50000),
-     }
-
-
- def resolve_output_dirs(
-     cfg: Dict[str, Any],
-     config_path: Path,
-     *,
-     output_override: Optional[str] = None,
- ) -> Dict[str, Optional[str]]:
-     output_root = resolve_dir_path(
-         output_override or cfg.get("output_dir"),
-         config_path.parent,
-     )
-     return {
-         "output_dir": str(output_root) if output_root is not None else None,
-     }
+         "prediction_cache_dir": cfg.get("prediction_cache_dir"),
+         "prediction_cache_format": cfg.get("prediction_cache_format", "parquet"),
+         "ddp_min_rows": cfg.get("ddp_min_rows", 50000),
+     }
+
+
+ def resolve_output_dirs(
+     cfg: Dict[str, Any],
+     config_path: Path,
+     *,
+     output_override: Optional[str] = None,
+ ) -> Dict[str, Optional[str]]:
+     output_root = resolve_dir_path(
+         output_override or cfg.get("output_dir"),
+         config_path.parent,
+     )
+     return {
+         "output_dir": str(output_root) if output_root is not None else None,
+     }
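
For orientation, the helpers in ins_pricing/cli/utils/cli_config.py shown above are normally combined by the CLI entry points (BayesOpt_entry.py, Explain_entry.py, etc.). The snippet below is a minimal, illustrative sketch of how a caller might wire them together; it is not taken from the package, and the config file name, required keys, and key values are assumptions made for the example.

    # Illustrative sketch only; "config.json" and the required keys are assumed.
    from pathlib import Path

    from ins_pricing.cli.utils.cli_config import (
        resolve_and_load_config,
        resolve_data_config,
        resolve_runtime_config,
        resolve_split_config,
    )

    # Locate config.json next to the current script, validate required keys,
    # normalise relative paths against the config directory, and apply its
    # `env` block via os.environ.setdefault (shell values win).
    config_path, cfg = resolve_and_load_config(
        "config.json",
        Path(__file__).resolve().parent,
        required_keys=["data_dir", "output_dir"],
    )

    # Resolve data location/format and the dtype map (inline dict or JSON path).
    data_dir, data_format, path_template, dtype_map = resolve_data_config(cfg, config_path)

    split_cfg = resolve_split_config(cfg)      # holdout/val ratios, CV and FT-OOF settings
    runtime_cfg = resolve_runtime_config(cfg)  # in 0.5.1 also surfaces xgb_gpu_id and the *_cleanup_* flags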