ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1 @@
1
+ """Shared CLI utilities for ins_pricing modelling."""
@@ -0,0 +1,320 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+
7
+ import pandas as pd
8
+ from sklearn.model_selection import GroupShuffleSplit, train_test_split
9
+
10
+
11
# Maps each short model key to a (display label, prediction-column name) pair.
# NOTE(review): presumably consumed by plotting/reporting helpers elsewhere in
# the package — confirm with callers before renaming keys or columns.
PLOT_MODEL_LABELS: Dict[str, Tuple[str, str]] = {
    "glm": ("GLM", "pred_glm"),
    "xgb": ("Xgboost", "pred_xgb"),
    "resn": ("ResNet", "pred_resn"),
    "ft": ("FTTransformer", "pred_ft"),
    "gnn": ("GNN", "pred_gnn"),
}

# Model keys whose trainers are PyTorch-based (subset of PLOT_MODEL_LABELS keys).
PYTORCH_TRAINERS = {"resn", "ft", "gnn"}
20
+
21
+
22
def dedupe_preserve_order(items: Iterable[str]) -> List[str]:
    """Return *items* with duplicates removed, keeping first occurrences.

    Relies on dict preserving insertion order, so the relative order of the
    surviving elements matches their first appearance in the input.
    """
    return list(dict.fromkeys(items))
30
+
31
+
32
def build_model_names(prefixes: Sequence[str], suffixes: Sequence[str]) -> List[str]:
    """Cross-join *prefixes* and *suffixes* into "<prefix>_<suffix>" names.

    Suffix is the slow-moving key: all prefixes for the first suffix come
    before any name using the second suffix.
    """
    return [f"{prefix}_{suffix}" for suffix in suffixes for prefix in prefixes]
37
+
38
+
39
def parse_model_pairs(raw_pairs: Sequence[Any]) -> List[Tuple[str, str]]:
    """Normalize a heterogeneous list of model-pair specs into (str, str) tuples.

    Accepted entry forms:
    - a 2-element list/tuple -> both elements stringified;
    - a string "a, b" -> split on commas, whitespace stripped.

    Malformed entries (wrong length, empty parts, other types) are silently
    skipped, matching the lenient config-parsing style of this module.

    Fix: the parameter was annotated with a bare, unparameterized ``List``;
    it is now typed as ``Sequence[Any]`` (backward compatible — any list
    still works).
    """
    pairs: List[Tuple[str, str]] = []
    for entry in raw_pairs:
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            pairs.append((str(entry[0]), str(entry[1])))
        elif isinstance(entry, str):
            parts = [piece.strip() for piece in entry.split(",") if piece.strip()]
            if len(parts) == 2:
                pairs.append((parts[0], parts[1]))
    return pairs
49
+
50
+
51
def resolve_path(value: Optional[str], base_dir: Path) -> Optional[Path]:
    """Turn a config string into a Path, rebasing relative values onto *base_dir*.

    Returns None for None, non-string, or blank input. Absolute paths are
    returned untouched; relative ones are joined to *base_dir* and resolved.
    """
    if not isinstance(value, str) or not value.strip():
        return None
    candidate = Path(value)
    if candidate.is_absolute():
        return candidate
    return (base_dir / candidate).resolve()
60
+
61
+
62
def resolve_dir_path(
    value: Optional[Union[str, Path]],
    base_dir: Path,
    *,
    create: bool = False,
) -> Optional[Path]:
    """Resolve *value* into a directory Path, optionally creating it.

    Returns None for None or blank string input. Absolute inputs are kept
    as-is; relative ones are joined to *base_dir* and resolved. When
    *create* is True the directory (and parents) is created idempotently.

    Fix: the original had an unreachable fallback — ``resolve_path`` can
    never return None for a non-blank string, so the ``if path is None``
    branch was dead code. Both the Path and string branches now share one
    resolution step, with unchanged behavior.
    """
    if value is None:
        return None
    if not isinstance(value, Path):
        text = str(value)
        if not text.strip():
            return None
        value = Path(text)
    path = value if value.is_absolute() else (base_dir / value).resolve()
    if create:
        path.mkdir(parents=True, exist_ok=True)
    return path
84
+
85
+
86
+ def _infer_format_from_path(path: Path) -> str:
87
+ suffix = path.suffix.lower()
88
+ if suffix in {".parquet", ".pq"}:
89
+ return "parquet"
90
+ if suffix in {".feather", ".ft"}:
91
+ return "feather"
92
+ return "csv"
93
+
94
+
95
def resolve_data_path(
    data_dir: Path,
    model_name: str,
    *,
    data_format: str = "csv",
    path_template: Optional[str] = None,
) -> Path:
    """Build the dataset path for *model_name* under *data_dir*.

    *path_template* defaults to "{model_name}.{ext}". With format "auto",
    parquet then feather then csv are probed in order and the first existing
    file wins; when none exists the csv candidate is returned anyway.
    """
    fmt = str(data_format or "csv").strip().lower()
    template = path_template or "{model_name}.{ext}"

    def _candidate(ext: str) -> Path:
        return data_dir / template.format(model_name=model_name, ext=ext)

    if fmt == "auto":
        probes = [_candidate(ext) for ext in ("parquet", "feather", "csv")]
        chosen = next((p for p in probes if p.exists()), probes[-1])
        return chosen.resolve()
    # For any explicit format the format name doubles as the extension.
    return _candidate(fmt).resolve()
116
+
117
+
118
def load_dataset(
    path: Path,
    *,
    data_format: str = "auto",
    dtype_map: Optional[Dict[str, Any]] = None,
    low_memory: bool = False,
) -> pd.DataFrame:
    """Read a dataset from *path* as parquet, feather, or csv.

    With format "auto" the format is inferred from the file suffix
    (.parquet/.pq, .feather/.ft, anything else -> csv). After loading,
    *dtype_map* entries are applied to any matching columns.

    Raises ValueError for an unrecognized explicit format.
    """
    fmt = str(data_format or "auto").strip().lower()
    if fmt == "auto":
        # Inlined suffix inference (same table as _infer_format_from_path).
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            fmt = "parquet"
        elif suffix in {".feather", ".ft"}:
            fmt = "feather"
        else:
            fmt = "csv"
    if fmt == "parquet":
        frame = pd.read_parquet(path)
    elif fmt == "feather":
        frame = pd.read_feather(path)
    elif fmt == "csv":
        frame = pd.read_csv(path, low_memory=low_memory, dtype=dtype_map or None)
    else:
        raise ValueError(f"Unsupported data_format: {data_format}")
    # Enforce requested dtypes even for formats that ignore the dtype= hint.
    for col, dtype in (dtype_map or {}).items():
        if col in frame.columns:
            frame[col] = frame[col].astype(dtype)
    return frame
141
+
142
+
143
def coerce_dataset_types(raw: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *raw* with NaNs filled per column kind.

    Numeric columns: coerced via to_numeric and missing values become 0.
    All other columns: cast to object dtype with missing values as "<NA>".
    The input frame is never mutated.
    """
    frame = raw.copy()
    for name in frame.columns:
        column = frame[name]
        if pd.api.types.is_numeric_dtype(column):
            frame[name] = pd.to_numeric(column, errors="coerce").fillna(0)
        else:
            frame[name] = column.astype("object").fillna("<NA>")
    return frame
152
+
153
+
154
def split_train_test(
    df: pd.DataFrame,
    *,
    holdout_ratio: float,
    strategy: str = "random",
    group_col: Optional[str] = None,
    time_col: Optional[str] = None,
    time_ascending: bool = True,
    rand_seed: Optional[int] = None,
    reset_index_mode: str = "none",
    ratio_label: str = "holdout_ratio",
    include_strategy_in_ratio_error: bool = False,
    validate_ratio: bool = True,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into (train, test) frames.

    Strategies:
    - "time"/"timeseries"/"temporal": sort by *time_col*, cut the trailing
      *holdout_ratio* fraction off as test (no shuffling).
    - "group"/"grouped": GroupShuffleSplit on *group_col*, so no group
      appears in both partitions.
    - anything else: plain random train_test_split.

    reset_index_mode: "always" resets both indices; "time_group" resets only
    for time/group strategies; any other value keeps original indices.

    Raises ValueError (missing/invalid ratio or column settings) or KeyError
    (named column absent from *df*).
    """
    strategy = str(strategy or "random").strip().lower()
    holdout_ratio = float(holdout_ratio)
    # Build the ratio error message up front; its wording optionally names the
    # strategy so callers can surface a more specific config error.
    if include_strategy_in_ratio_error and strategy in {"time", "timeseries", "temporal", "group", "grouped"}:
        strategy_label = "time" if strategy in {"time", "timeseries", "temporal"} else "group"
        ratio_error = (
            f"{ratio_label} must be in (0, 1) for {strategy_label} split; got {holdout_ratio}."
        )
    else:
        ratio_error = f"{ratio_label} must be in (0, 1); got {holdout_ratio}."

    if strategy in {"time", "timeseries", "temporal"}:
        # Column checks come before the ratio check — error precedence is
        # deliberate and relied on by callers' error reporting.
        if not time_col:
            raise ValueError("split_time_col is required for time split_strategy.")
        if time_col not in df.columns:
            raise KeyError(f"split_time_col '{time_col}' not in dataset columns.")
        if validate_ratio and not (0.0 < holdout_ratio < 1.0):
            raise ValueError(ratio_error)
        ordered = df.sort_values(time_col, ascending=bool(time_ascending))
        # The last holdout_ratio fraction (in time order) becomes the test set.
        cutoff = int(len(ordered) * (1.0 - holdout_ratio))
        if cutoff <= 0 or cutoff >= len(ordered):
            # Degenerate cut (can happen even with a valid ratio on tiny frames,
            # or when validate_ratio=False let an extreme ratio through).
            raise ValueError(
                f"{ratio_label}={holdout_ratio} leaves no data for train/test split.")
        train_df = ordered.iloc[:cutoff]
        test_df = ordered.iloc[cutoff:]
    elif strategy in {"group", "grouped"}:
        if not group_col:
            raise ValueError("split_group_col is required for group split_strategy.")
        if group_col not in df.columns:
            raise KeyError(f"split_group_col '{group_col}' not in dataset columns.")
        if validate_ratio and not (0.0 < holdout_ratio < 1.0):
            raise ValueError(ratio_error)
        splitter = GroupShuffleSplit(
            n_splits=1,
            test_size=holdout_ratio,
            random_state=rand_seed,
        )
        train_idx, test_idx = next(splitter.split(df, groups=df[group_col]))
        train_df = df.iloc[train_idx]
        test_df = df.iloc[test_idx]
    else:
        # Random fallback: note the ratio is NOT validated here; sklearn's own
        # test_size validation applies instead.
        train_df, test_df = train_test_split(
            df, test_size=holdout_ratio, random_state=rand_seed
        )

    if reset_index_mode == "always" or (
        reset_index_mode == "time_group"
        and strategy in {"time", "timeseries", "temporal", "group", "grouped"}
    ):
        train_df = train_df.reset_index(drop=True)
        test_df = test_df.reset_index(drop=True)
    return train_df, test_df
219
+
220
+
221
def fingerprint_file(path: Path, *, max_bytes: int = 10_485_760) -> Dict[str, Any]:
    """Produce a lightweight identity record for the file at *path*.

    Hashes at most *max_bytes* bytes (hence the "sha256_prefix" key name) in
    1 MiB chunks, and records size and mtime from a single stat call.
    """
    path = Path(path)
    info = path.stat()
    digest = hashlib.sha256()
    budget = int(max_bytes)
    with path.open("rb") as stream:
        while budget > 0:
            block = stream.read(min(1024 * 1024, budget))
            if not block:
                break
            digest.update(block)
            budget -= len(block)
    return {
        "path": str(path),
        "size": int(info.st_size),
        "mtime": int(info.st_mtime),
        "sha256_prefix": digest.hexdigest(),
        "max_bytes": int(max_bytes),
    }
240
+
241
+
242
def _load_cli_config():
    """Return the sibling ``cli_config`` module.

    Prefers the package-relative import and falls back to a top-level import
    so this file also works when executed as a loose script.
    """
    try:
        from . import cli_config as module  # type: ignore
    except Exception:  # pragma: no cover
        import cli_config as module  # type: ignore
    return module


def resolve_config_path(raw: str, script_dir: Path) -> Path:
    """Delegate to :func:`cli_config.resolve_config_path`."""
    return _load_cli_config().resolve_config_path(raw, script_dir)


def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.load_config_json`."""
    return _load_cli_config().load_config_json(path, required_keys)


def set_env(env_overrides: Dict[str, Any]) -> None:
    """Delegate to :func:`cli_config.set_env`."""
    _load_cli_config().set_env(env_overrides)


def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.normalize_config_paths`."""
    return _load_cli_config().normalize_config_paths(cfg, config_path)


def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.resolve_dtype_map`."""
    return _load_cli_config().resolve_dtype_map(value, base_dir)


def resolve_data_config(
    cfg: Dict[str, Any],
    config_path: Path,
    *,
    create_data_dir: bool = False,
) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
    """Delegate to :func:`cli_config.resolve_data_config`."""
    module = _load_cli_config()
    return module.resolve_data_config(
        cfg,
        config_path,
        create_data_dir=create_data_dir,
    )


def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.resolve_report_config`."""
    return _load_cli_config().resolve_report_config(cfg)


def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.resolve_split_config`."""
    return _load_cli_config().resolve_split_config(cfg)


def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Delegate to :func:`cli_config.resolve_runtime_config`."""
    return _load_cli_config().resolve_runtime_config(cfg)


def resolve_output_dirs(
    cfg: Dict[str, Any],
    config_path: Path,
    *,
    output_override: Optional[str] = None,
) -> Dict[str, Optional[str]]:
    """Delegate to :func:`cli_config.resolve_output_dirs`."""
    module = _load_cli_config()
    return module.resolve_output_dirs(
        cfg,
        config_path,
        output_override=output_override,
    )


def resolve_and_load_config(
    raw: str,
    script_dir: Path,
    required_keys: Sequence[str],
    *,
    apply_env: bool = True,
) -> Tuple[Path, Dict[str, Any]]:
    """Delegate to :func:`cli_config.resolve_and_load_config`."""
    module = _load_cli_config()
    return module.resolve_and_load_config(
        raw,
        script_dir,
        required_keys,
        apply_env=apply_env,
    )
@@ -0,0 +1,375 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Sequence, Tuple
8
+
9
+ try:
10
+ from .cli_common import resolve_dir_path, resolve_path # type: ignore
11
+ except Exception: # pragma: no cover
12
+ from cli_common import resolve_dir_path, resolve_path # type: ignore
13
+
14
+
15
def resolve_config_path(raw: str, script_dir: Path) -> Path:
    """Locate a config file given either as-is or relative to *script_dir*.

    Raises FileNotFoundError listing both attempted locations.
    """
    for candidate in (Path(raw), script_dir / raw):
        if candidate.exists():
            return candidate.resolve()
    candidate2 = script_dir / raw
    raise FileNotFoundError(
        f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
    )
25
+
26
+
27
def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
    """Parse a UTF-8 JSON config file and verify *required_keys* are present.

    Raises ValueError naming every missing key.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    missing = [key for key in required_keys if key not in payload]
    if missing:
        raise ValueError(f"Missing required keys in {path}: {missing}")
    return payload
33
+
34
+
35
def set_env(env_overrides: Dict[str, Any]) -> None:
    """Apply environment variables from config.json.

    Keys and values are stringified and applied with
    ``os.environ.setdefault``, so anything already exported in the shell
    takes precedence over config.json.

    Notes (DDP/Optuna hang debugging) — useful `env` keys include:
    ``TORCH_DISTRIBUTED_DEBUG=DETAIL``, ``NCCL_DEBUG=INFO``,
    ``BAYESOPT_DDP_BARRIER_DEBUG=1``, ``BAYESOPT_DDP_BARRIER_TIMEOUT=300``,
    and optionally ``BAYESOPT_CUDA_SYNC=1`` / ``BAYESOPT_CUDA_IPC_COLLECT=1``
    (the last two can slow things down).
    """
    overrides = env_overrides or {}
    for name, val in overrides.items():
        os.environ.setdefault(str(name), str(val))
51
+
52
+
53
+ def _looks_like_url(value: str) -> bool:
54
+ value = str(value)
55
+ return "://" in value
56
+
57
+
58
def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
    """Resolve relative paths against the config.json directory.

    Fields handled:
    - data_dir / output_dir / optuna_storage / gnn_graph_cache
      (plus preprocess_artifact_path, prediction_cache_dir,
      report_output_dir, registry_path)
    - best_params_files (dict: model_key -> path)

    Returns a shallow copy; the input mapping is not mutated.
    """
    base_dir = config_path.parent
    out = dict(cfg)

    # Plain string path fields: rebase relative values onto the config dir.
    for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
                "prediction_cache_dir", "report_output_dir", "registry_path"):
        if key in out and isinstance(out.get(key), str):
            resolved = resolve_path(out.get(key), base_dir)
            if resolved is not None:
                out[key] = str(resolved)

    # optuna_storage may be a URL (e.g. sqlite:///...); only rebase bare
    # filesystem paths, never values containing a scheme separator.
    storage = out.get("optuna_storage")
    if isinstance(storage, str) and storage.strip():
        if not _looks_like_url(storage):
            resolved = resolve_path(storage, base_dir)
            if resolved is not None:
                out["optuna_storage"] = str(resolved)

    # best_params_files: model_key -> path. Non-string values are dropped;
    # keys are stringified; unresolvable paths are kept verbatim.
    best_files = out.get("best_params_files")
    if isinstance(best_files, dict):
        resolved_map: Dict[str, str] = {}
        for mk, path_str in best_files.items():
            if not isinstance(path_str, str):
                continue
            resolved = resolve_path(path_str, base_dir)
            resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
        out["best_params_files"] = resolved_map

    return out
93
+
94
+
95
def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
    """Normalize a dtype-map setting into a ``{column: dtype}`` dict.

    Accepts None (-> {}), an inline dict (keys stringified), or a path to a
    JSON file containing a dict. Raises FileNotFoundError for a missing
    file and ValueError for any other shape.
    """
    if value is None:
        return {}
    if isinstance(value, dict):
        return {str(key): dtype for key, dtype in value.items()}
    if isinstance(value, str):
        json_path = resolve_path(value, base_dir)
        if json_path is None or not json_path.exists():
            raise FileNotFoundError(f"dtype_map not found: {value}")
        payload = json.loads(json_path.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            raise ValueError(f"dtype_map must be a dict: {json_path}")
        return {str(key): dtype for key, dtype in payload.items()}
    raise ValueError("dtype_map must be a dict or JSON path.")
109
+
110
+
111
def resolve_data_config(
    cfg: Dict[str, Any],
    config_path: Path,
    *,
    create_data_dir: bool = False,
) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
    """Extract dataset settings from *cfg*.

    Returns ``(data_dir, data_format, data_path_template, dtype_map)``.
    Relative paths are resolved against the config file's directory; a
    missing ``data_dir`` raises ValueError.
    """
    base_dir = config_path.parent
    data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
    if data_dir is None:
        raise ValueError("data_dir is required in config.json.")
    return (
        data_dir,
        cfg.get("data_format", "csv"),
        cfg.get("data_path_template"),
        resolve_dtype_map(cfg.get("dtype_map"), base_dir),
    )
125
+
126
+
127
def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
    """Register the mandatory ``--config-json`` option on *parser*."""
    parser.add_argument("--config-json", required=True, help=help_text)


def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
    """Register the optional ``--output-dir`` option (default None) on *parser*."""
    parser.add_argument("--output-dir", default=None, help=help_text)
141
+
142
+
143
def resolve_model_path_value(
    value: Any,
    *,
    model_name: str,
    base_dir: Path,
    data_dir: Optional[Path] = None,
) -> Optional[Path]:
    """Resolve a per-model path setting into an absolute Path, or None.

    *value* may be a single path-like or a dict keyed by model name; it may
    contain a ``{model_name}`` placeholder. Relative paths are first probed
    under *data_dir* (if given and the file exists) and otherwise resolved
    against *base_dir*.
    """
    if isinstance(value, dict):
        value = value.get(model_name)
    if value is None:
        return None
    text = str(value)
    try:
        text = text.format(model_name=model_name)
    except Exception:
        # Literal braces or unknown placeholders: keep the raw string.
        pass
    if data_dir is not None and not Path(text).is_absolute():
        local = data_dir / text
        if local.exists():
            return local.resolve()
    return resolve_path(text, base_dir)
169
+
170
+
171
def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
    """Resolve the explain output root, or None for a falsy setting.

    Falls back to a literal Path when resolve_path declines the value.
    """
    if not value:
        return None
    text = str(value)
    resolved = resolve_path(text, base_dir)
    return resolved if resolved is not None else Path(text)
177
+
178
+
179
def resolve_explain_save_dir(
    save_root: Optional[Path],
    *,
    result_dir: Optional[Any],
) -> Path:
    """Pick the directory for explain artifacts.

    An explicit *save_root* wins; otherwise ``<result_dir>/explain`` is
    used. Raises ValueError when neither is available.
    """
    if save_root is None:
        if result_dir is None:
            raise ValueError("result_dir is required when explain save_root is not set.")
        return Path(result_dir) / "explain"
    return Path(save_root)
189
+
190
+
191
def resolve_explain_output_overrides(
    explain_cfg: Dict[str, Any],
    *,
    model_name: str,
    base_dir: Path,
) -> Dict[str, Optional[Path]]:
    """Resolve the explain-stage output-dir overrides for one model.

    Each of model_dir / result_dir / plot_dir is passed through
    resolve_model_path_value; ``results_dir`` is accepted as a legacy
    alias for ``result_dir``.
    """
    raw_values = {
        "model_dir": explain_cfg.get("model_dir"),
        "result_dir": explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
        "plot_dir": explain_cfg.get("plot_dir"),
    }
    return {
        key: resolve_model_path_value(
            raw,
            model_name=model_name,
            base_dir=base_dir,
            data_dir=None,
        )
        for key, raw in raw_values.items()
    }
217
+
218
+
219
def resolve_and_load_config(
    raw: str,
    script_dir: Path,
    required_keys: Sequence[str],
    *,
    apply_env: bool = True,
) -> Tuple[Path, Dict[str, Any]]:
    """Locate, load, path-normalize and (optionally) env-apply a config file.

    Returns ``(resolved_config_path, normalized_config_dict)``.
    """
    config_path = resolve_config_path(raw, script_dir)
    loaded = load_config_json(config_path, required_keys=required_keys)
    normalized = normalize_config_paths(loaded, config_path)
    if apply_env:
        set_env(normalized.get("env", {}))
    return config_path, normalized
232
+
233
+
234
def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Collect reporting/monitoring settings from *cfg* into one flat dict.

    ``report_enabled`` is True when any reporting feature is requested:
    an output dir, model registration, grouping/time columns, PSI
    features, or an enabled calibration/threshold/bootstrap section.
    """
    def _as_list(value: Any) -> list[str]:
        # Normalize scalar-or-collection config values to a list of strings.
        if value is None:
            return []
        if isinstance(value, (list, tuple, set)):
            return [str(item) for item in value]
        return [str(value)]

    report_output_dir = cfg.get("report_output_dir")
    report_group_cols = _as_list(cfg.get("report_group_cols"))
    if not report_group_cols:
        report_group_cols = None  # empty list collapses to None
    report_time_col = cfg.get("report_time_col")
    report_time_freq = cfg.get("report_time_freq", "M")
    report_time_ascending = bool(cfg.get("report_time_ascending", True))
    psi_bins = cfg.get("psi_bins", 10)
    psi_strategy = cfg.get("psi_strategy", "quantile")
    psi_features = _as_list(cfg.get("psi_features"))
    if not psi_features:
        psi_features = None
    # `or {}` guards against explicit JSON nulls for the sub-sections.
    calibration_cfg = cfg.get("calibration", {}) or {}
    threshold_cfg = cfg.get("threshold", {}) or {}
    bootstrap_cfg = cfg.get("bootstrap", {}) or {}
    register_model = bool(cfg.get("register_model", False))
    registry_path = cfg.get("registry_path")
    registry_tags = cfg.get("registry_tags", {})
    registry_status = cfg.get("registry_status", "candidate")
    data_fingerprint_max_bytes = int(
        cfg.get("data_fingerprint_max_bytes", 10_485_760))
    # A sub-feature counts as enabled when flagged explicitly or implied by
    # a method/value/metric setting.
    calibration_enabled = bool(
        calibration_cfg.get("enable", False) or calibration_cfg.get("method")
    )
    threshold_enabled = bool(
        threshold_cfg.get("enable", False)
        or threshold_cfg.get("value") is not None
        or threshold_cfg.get("metric")
    )
    bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
    report_enabled = any([
        bool(report_output_dir),
        register_model,
        bool(report_group_cols),
        bool(report_time_col),
        bool(psi_features),
        calibration_enabled,
        threshold_enabled,
        bootstrap_enabled,
    ])
    return {
        "report_output_dir": report_output_dir,
        "report_group_cols": report_group_cols,
        "report_time_col": report_time_col,
        "report_time_freq": report_time_freq,
        "report_time_ascending": report_time_ascending,
        "psi_bins": psi_bins,
        "psi_strategy": psi_strategy,
        "psi_features": psi_features,
        "calibration_cfg": calibration_cfg,
        "threshold_cfg": threshold_cfg,
        "bootstrap_cfg": bootstrap_cfg,
        "register_model": register_model,
        "registry_path": registry_path,
        "registry_tags": registry_tags,
        "registry_status": registry_status,
        "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
        "report_enabled": report_enabled,
    }
301
+
302
+
303
def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Collect train/validation split settings from *cfg*.

    ``holdout_ratio`` and ``val_ratio`` default to ``prop_test`` (itself
    defaulting to 0.25), with explicit nulls also falling back to it.
    """
    prop_test = cfg.get("prop_test", 0.25)

    def _ratio(key: str) -> Any:
        # Treat both a missing key and an explicit null as "use prop_test".
        val = cfg.get(key, prop_test)
        return prop_test if val is None else val

    split_time_ascending = bool(cfg.get("split_time_ascending", True))
    return {
        "prop_test": prop_test,
        "holdout_ratio": _ratio("holdout_ratio"),
        "val_ratio": _ratio("val_ratio"),
        "split_strategy": str(cfg.get("split_strategy", "random")).strip().lower(),
        "split_group_col": cfg.get("split_group_col"),
        "split_time_col": cfg.get("split_time_col"),
        "split_time_ascending": split_time_ascending,
        "cv_strategy": cfg.get("cv_strategy"),
        "cv_group_col": cfg.get("cv_group_col"),
        "cv_time_col": cfg.get("cv_time_col"),
        "cv_time_ascending": cfg.get("cv_time_ascending", split_time_ascending),
        "cv_splits": cfg.get("cv_splits"),
        "ft_oof_folds": cfg.get("ft_oof_folds"),
        "ft_oof_strategy": cfg.get("ft_oof_strategy"),
        "ft_oof_shuffle": cfg.get("ft_oof_shuffle", True),
    }
340
+
341
+
342
def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Collect runtime/tuning settings from *cfg* with their defaults applied."""
    get = cfg.get
    out: Dict[str, Any] = {}
    out["save_preprocess"] = bool(get("save_preprocess", False))
    out["preprocess_artifact_path"] = get("preprocess_artifact_path")
    out["rand_seed"] = get("rand_seed", 13)
    out["epochs"] = get("epochs", 50)
    out["plot_path_style"] = get("plot_path_style")
    out["reuse_best_params"] = bool(get("reuse_best_params", False))
    # XGBoost search-space ceilings are forced to int.
    out["xgb_max_depth_max"] = int(get("xgb_max_depth_max", 25))
    out["xgb_n_estimators_max"] = int(get("xgb_n_estimators_max", 500))
    out["optuna_storage"] = get("optuna_storage")
    out["optuna_study_prefix"] = get("optuna_study_prefix")
    out["best_params_files"] = get("best_params_files")
    out["bo_sample_limit"] = get("bo_sample_limit")
    out["cache_predictions"] = bool(get("cache_predictions", False))
    out["prediction_cache_dir"] = get("prediction_cache_dir")
    out["prediction_cache_format"] = get("prediction_cache_format", "parquet")
    out["ddp_min_rows"] = get("ddp_min_rows", 50000)
    return out
361
+
362
+
363
def resolve_output_dirs(
    cfg: Dict[str, Any],
    config_path: Path,
    *,
    output_override: Optional[str] = None,
) -> Dict[str, Optional[str]]:
    """Resolve the output directory, preferring a CLI override over config.

    Relative values are resolved against the config file's directory; the
    result is returned as ``{"output_dir": str | None}``.
    """
    chosen = output_override or cfg.get("output_dir")
    root = resolve_dir_path(chosen, config_path.parent)
    return {"output_dir": None if root is None else str(root)}