ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +9 -6
- ins_pricing/__init__.py +3 -11
- ins_pricing/cli/BayesOpt_entry.py +24 -0
- ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
- ins_pricing/cli/Explain_Run.py +25 -0
- ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
- ins_pricing/cli/Pricing_Run.py +25 -0
- ins_pricing/cli/__init__.py +1 -0
- ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
- ins_pricing/cli/utils/__init__.py +1 -0
- ins_pricing/cli/utils/cli_common.py +320 -0
- ins_pricing/cli/utils/cli_config.py +375 -0
- ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
- {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
- ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
- ins_pricing/docs/modelling/README.md +34 -0
- ins_pricing/modelling/__init__.py +57 -6
- ins_pricing/modelling/core/__init__.py +1 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
- ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
- ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
- ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
- ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
- ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
- ins_pricing/modelling/core/evaluation.py +115 -0
- ins_pricing/production/__init__.py +4 -0
- ins_pricing/production/preprocess.py +71 -0
- ins_pricing/setup.py +10 -5
- {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
- ins_pricing-0.2.0.dist-info/RECORD +125 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
- ins_pricing/modelling/BayesOpt_entry.py +0 -633
- ins_pricing/modelling/Explain_Run.py +0 -36
- ins_pricing/modelling/Pricing_Run.py +0 -36
- ins_pricing/modelling/README.md +0 -33
- ins_pricing/modelling/bayesopt/models.py +0 -2196
- ins_pricing/modelling/bayesopt/trainers.py +0 -2446
- ins_pricing/modelling/cli_common.py +0 -136
- ins_pricing/modelling/tests/test_plotting.py +0 -63
- ins_pricing/modelling/watchdog_run.py +0 -211
- ins_pricing-0.1.11.dist-info/RECORD +0 -169
- ins_pricing_gemini/__init__.py +0 -23
- ins_pricing_gemini/governance/__init__.py +0 -20
- ins_pricing_gemini/governance/approval.py +0 -93
- ins_pricing_gemini/governance/audit.py +0 -37
- ins_pricing_gemini/governance/registry.py +0 -99
- ins_pricing_gemini/governance/release.py +0 -159
- ins_pricing_gemini/modelling/Explain_Run.py +0 -36
- ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
- ins_pricing_gemini/modelling/__init__.py +0 -151
- ins_pricing_gemini/modelling/cli_common.py +0 -141
- ins_pricing_gemini/modelling/config.py +0 -249
- ins_pricing_gemini/modelling/config_preprocess.py +0 -254
- ins_pricing_gemini/modelling/core.py +0 -741
- ins_pricing_gemini/modelling/data_container.py +0 -42
- ins_pricing_gemini/modelling/explain/__init__.py +0 -55
- ins_pricing_gemini/modelling/explain/gradients.py +0 -334
- ins_pricing_gemini/modelling/explain/metrics.py +0 -176
- ins_pricing_gemini/modelling/explain/permutation.py +0 -155
- ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
- ins_pricing_gemini/modelling/features.py +0 -215
- ins_pricing_gemini/modelling/model_manager.py +0 -148
- ins_pricing_gemini/modelling/model_plotting.py +0 -463
- ins_pricing_gemini/modelling/models.py +0 -2203
- ins_pricing_gemini/modelling/notebook_utils.py +0 -294
- ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
- ins_pricing_gemini/modelling/plotting/common.py +0 -63
- ins_pricing_gemini/modelling/plotting/curves.py +0 -572
- ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
- ins_pricing_gemini/modelling/plotting/geo.py +0 -362
- ins_pricing_gemini/modelling/plotting/importance.py +0 -121
- ins_pricing_gemini/modelling/run_logging.py +0 -133
- ins_pricing_gemini/modelling/tests/conftest.py +0 -8
- ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
- ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
- ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
- ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
- ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
- ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
- ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
- ins_pricing_gemini/modelling/trainers.py +0 -2447
- ins_pricing_gemini/modelling/utils.py +0 -1020
- ins_pricing_gemini/pricing/__init__.py +0 -27
- ins_pricing_gemini/pricing/calibration.py +0 -39
- ins_pricing_gemini/pricing/data_quality.py +0 -117
- ins_pricing_gemini/pricing/exposure.py +0 -85
- ins_pricing_gemini/pricing/factors.py +0 -91
- ins_pricing_gemini/pricing/monitoring.py +0 -99
- ins_pricing_gemini/pricing/rate_table.py +0 -78
- ins_pricing_gemini/production/__init__.py +0 -21
- ins_pricing_gemini/production/drift.py +0 -30
- ins_pricing_gemini/production/monitoring.py +0 -143
- ins_pricing_gemini/production/scoring.py +0 -40
- ins_pricing_gemini/reporting/__init__.py +0 -11
- ins_pricing_gemini/reporting/report_builder.py +0 -72
- ins_pricing_gemini/reporting/scheduler.py +0 -45
- ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
- ins_pricing_gemini/scripts/Explain_entry.py +0 -545
- ins_pricing_gemini/scripts/__init__.py +0 -1
- ins_pricing_gemini/scripts/train.py +0 -568
- ins_pricing_gemini/setup.py +0 -55
- ins_pricing_gemini/smoke_test.py +0 -28
- /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
- /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
- /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
- /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
- {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
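
The moves above relocate the CLI tooling from `ins_pricing.modelling` into `ins_pricing.cli` (with shared helpers under `ins_pricing.cli.utils`), push the BayesOpt internals under `ins_pricing.modelling.core`, and drop the parallel `ins_pricing_gemini` tree. A minimal sketch of what the new import paths would look like, assuming the packages are importable exactly as laid out in the file list (not verified against the released wheel):

```python
# Hypothetical import paths implied by the 0.2.0 layout above.
# 0.1.11 code importing ins_pricing.modelling.cli_common or
# ins_pricing.modelling.notebook_utils would need updating.
from ins_pricing.cli.utils import cli_common, cli_config   # shared CLI helpers
from ins_pricing.cli import bayesopt_entry_runner           # new entry runner
from ins_pricing.modelling.core.bayesopt import utils       # relocated BayesOpt utils
```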

ins_pricing/cli/utils/__init__.py

@@ -0,0 +1 @@
+"""Shared CLI utilities for ins_pricing modelling."""

ins_pricing/cli/utils/cli_common.py

@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import hashlib
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import pandas as pd
+from sklearn.model_selection import GroupShuffleSplit, train_test_split
+
+
+PLOT_MODEL_LABELS: Dict[str, Tuple[str, str]] = {
+    "glm": ("GLM", "pred_glm"),
+    "xgb": ("Xgboost", "pred_xgb"),
+    "resn": ("ResNet", "pred_resn"),
+    "ft": ("FTTransformer", "pred_ft"),
+    "gnn": ("GNN", "pred_gnn"),
+}
+
+PYTORCH_TRAINERS = {"resn", "ft", "gnn"}
+
+
+def dedupe_preserve_order(items: Iterable[str]) -> List[str]:
+    seen = set()
+    unique: List[str] = []
+    for item in items:
+        if item not in seen:
+            unique.append(item)
+            seen.add(item)
+    return unique
+
+
+def build_model_names(prefixes: Sequence[str], suffixes: Sequence[str]) -> List[str]:
+    names: List[str] = []
+    for suffix in suffixes:
+        names.extend(f"{prefix}_{suffix}" for prefix in prefixes)
+    return names
+
+
+def parse_model_pairs(raw_pairs: List) -> List[Tuple[str, str]]:
+    pairs: List[Tuple[str, str]] = []
+    for pair in raw_pairs:
+        if isinstance(pair, (list, tuple)) and len(pair) == 2:
+            pairs.append((str(pair[0]), str(pair[1])))
+        elif isinstance(pair, str):
+            parts = [p.strip() for p in pair.split(",") if p.strip()]
+            if len(parts) == 2:
+                pairs.append((parts[0], parts[1]))
+    return pairs
+
+
+def resolve_path(value: Optional[str], base_dir: Path) -> Optional[Path]:
+    if value is None:
+        return None
+    if not isinstance(value, str) or not value.strip():
+        return None
+    p = Path(value)
+    if p.is_absolute():
+        return p
+    return (base_dir / p).resolve()
+
+
+def resolve_dir_path(
+    value: Optional[Union[str, Path]],
+    base_dir: Path,
+    *,
+    create: bool = False,
+) -> Optional[Path]:
+    if value is None:
+        return None
+    if isinstance(value, Path):
+        path = value if value.is_absolute() else (base_dir / value).resolve()
+    else:
+        value_str = str(value)
+        if not value_str.strip():
+            return None
+        path = resolve_path(value_str, base_dir)
+        if path is None:
+            path = Path(value_str)
+        if not path.is_absolute():
+            path = (base_dir / path).resolve()
+    if create:
+        path.mkdir(parents=True, exist_ok=True)
+    return path
+
+
+def _infer_format_from_path(path: Path) -> str:
+    suffix = path.suffix.lower()
+    if suffix in {".parquet", ".pq"}:
+        return "parquet"
+    if suffix in {".feather", ".ft"}:
+        return "feather"
+    return "csv"
+
+
+def resolve_data_path(
+    data_dir: Path,
+    model_name: str,
+    *,
+    data_format: str = "csv",
+    path_template: Optional[str] = None,
+) -> Path:
+    fmt = str(data_format or "csv").strip().lower()
+    template = path_template or "{model_name}.{ext}"
+    if fmt == "auto":
+        candidates = [
+            data_dir / template.format(model_name=model_name, ext="parquet"),
+            data_dir / template.format(model_name=model_name, ext="feather"),
+            data_dir / template.format(model_name=model_name, ext="csv"),
+        ]
+        for cand in candidates:
+            if cand.exists():
+                return cand.resolve()
+        return candidates[-1].resolve()
+    ext = "csv" if fmt in {"csv"} else fmt
+    return (data_dir / template.format(model_name=model_name, ext=ext)).resolve()
+
+
+def load_dataset(
+    path: Path,
+    *,
+    data_format: str = "auto",
+    dtype_map: Optional[Dict[str, Any]] = None,
+    low_memory: bool = False,
+) -> pd.DataFrame:
+    fmt = str(data_format or "auto").strip().lower()
+    if fmt == "auto":
+        fmt = _infer_format_from_path(path)
+    if fmt == "parquet":
+        df = pd.read_parquet(path)
+    elif fmt == "feather":
+        df = pd.read_feather(path)
+    elif fmt == "csv":
+        df = pd.read_csv(path, low_memory=low_memory, dtype=dtype_map or None)
+    else:
+        raise ValueError(f"Unsupported data_format: {data_format}")
+    if dtype_map:
+        for col, dtype in dtype_map.items():
+            if col in df.columns:
+                df[col] = df[col].astype(dtype)
+    return df
+
+
+def coerce_dataset_types(raw: pd.DataFrame) -> pd.DataFrame:
+    data = raw.copy()
+    for col in data.columns:
+        s = data[col]
+        if pd.api.types.is_numeric_dtype(s):
+            data[col] = pd.to_numeric(s, errors="coerce").fillna(0)
+        else:
+            data[col] = s.astype("object").fillna("<NA>")
+    return data
+
+
+def split_train_test(
+    df: pd.DataFrame,
+    *,
+    holdout_ratio: float,
+    strategy: str = "random",
+    group_col: Optional[str] = None,
+    time_col: Optional[str] = None,
+    time_ascending: bool = True,
+    rand_seed: Optional[int] = None,
+    reset_index_mode: str = "none",
+    ratio_label: str = "holdout_ratio",
+    include_strategy_in_ratio_error: bool = False,
+    validate_ratio: bool = True,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    strategy = str(strategy or "random").strip().lower()
+    holdout_ratio = float(holdout_ratio)
+    if include_strategy_in_ratio_error and strategy in {"time", "timeseries", "temporal", "group", "grouped"}:
+        strategy_label = "time" if strategy in {"time", "timeseries", "temporal"} else "group"
+        ratio_error = (
+            f"{ratio_label} must be in (0, 1) for {strategy_label} split; got {holdout_ratio}."
+        )
+    else:
+        ratio_error = f"{ratio_label} must be in (0, 1); got {holdout_ratio}."
+
+    if strategy in {"time", "timeseries", "temporal"}:
+        if not time_col:
+            raise ValueError("split_time_col is required for time split_strategy.")
+        if time_col not in df.columns:
+            raise KeyError(f"split_time_col '{time_col}' not in dataset columns.")
+        if validate_ratio and not (0.0 < holdout_ratio < 1.0):
+            raise ValueError(ratio_error)
+        ordered = df.sort_values(time_col, ascending=bool(time_ascending))
+        cutoff = int(len(ordered) * (1.0 - holdout_ratio))
+        if cutoff <= 0 or cutoff >= len(ordered):
+            raise ValueError(
+                f"{ratio_label}={holdout_ratio} leaves no data for train/test split.")
+        train_df = ordered.iloc[:cutoff]
+        test_df = ordered.iloc[cutoff:]
+    elif strategy in {"group", "grouped"}:
+        if not group_col:
+            raise ValueError("split_group_col is required for group split_strategy.")
+        if group_col not in df.columns:
+            raise KeyError(f"split_group_col '{group_col}' not in dataset columns.")
+        if validate_ratio and not (0.0 < holdout_ratio < 1.0):
+            raise ValueError(ratio_error)
+        splitter = GroupShuffleSplit(
+            n_splits=1,
+            test_size=holdout_ratio,
+            random_state=rand_seed,
+        )
+        train_idx, test_idx = next(splitter.split(df, groups=df[group_col]))
+        train_df = df.iloc[train_idx]
+        test_df = df.iloc[test_idx]
+    else:
+        train_df, test_df = train_test_split(
+            df, test_size=holdout_ratio, random_state=rand_seed
+        )
+
+    if reset_index_mode == "always" or (
+        reset_index_mode == "time_group"
+        and strategy in {"time", "timeseries", "temporal", "group", "grouped"}
+    ):
+        train_df = train_df.reset_index(drop=True)
+        test_df = test_df.reset_index(drop=True)
+    return train_df, test_df
+
+
+def fingerprint_file(path: Path, *, max_bytes: int = 10_485_760) -> Dict[str, Any]:
+    path = Path(path)
+    stat = path.stat()
+    h = hashlib.sha256()
+    remaining = int(max_bytes)
+    with path.open("rb") as fh:
+        while remaining > 0:
+            chunk = fh.read(min(1024 * 1024, remaining))
+            if not chunk:
+                break
+            h.update(chunk)
+            remaining -= len(chunk)
+    return {
+        "path": str(path),
+        "size": int(stat.st_size),
+        "mtime": int(stat.st_mtime),
+        "sha256_prefix": h.hexdigest(),
+        "max_bytes": int(max_bytes),
+    }
+
+
+def _load_cli_config():
+    try:
+        from . import cli_config as _cli_config  # type: ignore
+    except Exception:  # pragma: no cover
+        import cli_config as _cli_config  # type: ignore
+    return _cli_config
+
+
+def resolve_config_path(raw: str, script_dir: Path) -> Path:
+    return _load_cli_config().resolve_config_path(raw, script_dir)
+
+
+def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
+    return _load_cli_config().load_config_json(path, required_keys)
+
+
+def set_env(env_overrides: Dict[str, Any]) -> None:
+    _load_cli_config().set_env(env_overrides)
+
+
+def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
+    return _load_cli_config().normalize_config_paths(cfg, config_path)
+
+
+def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
+    return _load_cli_config().resolve_dtype_map(value, base_dir)
+
+
+def resolve_data_config(
+    cfg: Dict[str, Any],
+    config_path: Path,
+    *,
+    create_data_dir: bool = False,
+) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
+    return _load_cli_config().resolve_data_config(
+        cfg,
+        config_path,
+        create_data_dir=create_data_dir,
+    )
+
+
+def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    return _load_cli_config().resolve_report_config(cfg)
+
+
+def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    return _load_cli_config().resolve_split_config(cfg)
+
+
+def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    return _load_cli_config().resolve_runtime_config(cfg)
+
+
+def resolve_output_dirs(
+    cfg: Dict[str, Any],
+    config_path: Path,
+    *,
+    output_override: Optional[str] = None,
+) -> Dict[str, Optional[str]]:
+    return _load_cli_config().resolve_output_dirs(
+        cfg,
+        config_path,
+        output_override=output_override,
+    )
+
+
+def resolve_and_load_config(
+    raw: str,
+    script_dir: Path,
+    required_keys: Sequence[str],
+    *,
+    apply_env: bool = True,
+) -> Tuple[Path, Dict[str, Any]]:
+    return _load_cli_config().resolve_and_load_config(
+        raw,
+        script_dir,
+        required_keys,
+        apply_env=apply_env,
+    )

ins_pricing/cli/utils/cli_config.py

@@ -0,0 +1,375 @@
+from __future__ import annotations
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+try:
+    from .cli_common import resolve_dir_path, resolve_path  # type: ignore
+except Exception:  # pragma: no cover
+    from cli_common import resolve_dir_path, resolve_path  # type: ignore
+
+
+def resolve_config_path(raw: str, script_dir: Path) -> Path:
+    candidate = Path(raw)
+    if candidate.exists():
+        return candidate.resolve()
+    candidate2 = (script_dir / raw)
+    if candidate2.exists():
+        return candidate2.resolve()
+    raise FileNotFoundError(
+        f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
+    )
+
+
+def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
+    cfg = json.loads(path.read_text(encoding="utf-8"))
+    missing = [key for key in required_keys if key not in cfg]
+    if missing:
+        raise ValueError(f"Missing required keys in {path}: {missing}")
+    return cfg
+
+
+def set_env(env_overrides: Dict[str, Any]) -> None:
+    """Apply environment variables from config.json.
+
+    Notes (DDP/Optuna hang debugging):
+    - You can add these keys into config.json's `env` to debug distributed hangs:
+      - `TORCH_DISTRIBUTED_DEBUG=DETAIL`
+      - `NCCL_DEBUG=INFO`
+      - `BAYESOPT_DDP_BARRIER_DEBUG=1`
+      - `BAYESOPT_DDP_BARRIER_TIMEOUT=300`
+      - `BAYESOPT_CUDA_SYNC=1` (optional; can slow down)
+      - `BAYESOPT_CUDA_IPC_COLLECT=1` (optional; can slow down)
+    - This function uses `os.environ.setdefault`, so a value already set in the
+      shell will take precedence over config.json.
+    """
+    for key, value in (env_overrides or {}).items():
+        os.environ.setdefault(str(key), str(value))
+
+
+def _looks_like_url(value: str) -> bool:
+    value = str(value)
+    return "://" in value
+
+
+def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
+    """Resolve relative paths against the config.json directory.
+
+    Fields handled:
+    - data_dir / output_dir / optuna_storage / gnn_graph_cache
+    - best_params_files (dict: model_key -> path)
+    """
+    base_dir = config_path.parent
+    out = dict(cfg)
+
+    for key in ("data_dir", "output_dir", "gnn_graph_cache", "preprocess_artifact_path",
+                "prediction_cache_dir", "report_output_dir", "registry_path"):
+        if key in out and isinstance(out.get(key), str):
+            resolved = resolve_path(out.get(key), base_dir)
+            if resolved is not None:
+                out[key] = str(resolved)
+
+    storage = out.get("optuna_storage")
+    if isinstance(storage, str) and storage.strip():
+        if not _looks_like_url(storage):
+            resolved = resolve_path(storage, base_dir)
+            if resolved is not None:
+                out["optuna_storage"] = str(resolved)
+
+    best_files = out.get("best_params_files")
+    if isinstance(best_files, dict):
+        resolved_map: Dict[str, str] = {}
+        for mk, path_str in best_files.items():
+            if not isinstance(path_str, str):
+                continue
+            resolved = resolve_path(path_str, base_dir)
+            resolved_map[str(mk)] = str(resolved) if resolved is not None else str(path_str)
+        out["best_params_files"] = resolved_map
+
+    return out
+
+
+def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
+    if value is None:
+        return {}
+    if isinstance(value, dict):
+        return {str(k): v for k, v in value.items()}
+    if isinstance(value, str):
+        path = resolve_path(value, base_dir)
+        if path is None or not path.exists():
+            raise FileNotFoundError(f"dtype_map not found: {value}")
+        payload = json.loads(path.read_text(encoding="utf-8"))
+        if not isinstance(payload, dict):
+            raise ValueError(f"dtype_map must be a dict: {path}")
+        return {str(k): v for k, v in payload.items()}
+    raise ValueError("dtype_map must be a dict or JSON path.")
+
+
+def resolve_data_config(
+    cfg: Dict[str, Any],
+    config_path: Path,
+    *,
+    create_data_dir: bool = False,
+) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
+    base_dir = config_path.parent
+    data_dir = resolve_dir_path(cfg.get("data_dir"), base_dir, create=create_data_dir)
+    if data_dir is None:
+        raise ValueError("data_dir is required in config.json.")
+    data_format = cfg.get("data_format", "csv")
+    data_path_template = cfg.get("data_path_template")
+    dtype_map = resolve_dtype_map(cfg.get("dtype_map"), base_dir)
+    return data_dir, data_format, data_path_template, dtype_map
+
+
+def add_config_json_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
+    parser.add_argument(
+        "--config-json",
+        required=True,
+        help=help_text,
+    )
+
+
+def add_output_dir_arg(parser: argparse.ArgumentParser, *, help_text: str) -> None:
+    parser.add_argument(
+        "--output-dir",
+        default=None,
+        help=help_text,
+    )
+
+
+def resolve_model_path_value(
+    value: Any,
+    *,
+    model_name: str,
+    base_dir: Path,
+    data_dir: Optional[Path] = None,
+) -> Optional[Path]:
+    if value is None:
+        return None
+    if isinstance(value, dict):
+        value = value.get(model_name)
+        if value is None:
+            return None
+    path_str = str(value)
+    try:
+        path_str = path_str.format(model_name=model_name)
+    except Exception:
+        pass
+    if data_dir is not None and not Path(path_str).is_absolute():
+        candidate = data_dir / path_str
+        if candidate.exists():
+            return candidate.resolve()
+    resolved = resolve_path(path_str, base_dir)
+    if resolved is None:
+        return None
+    return resolved
+
+
+def resolve_explain_save_root(value: Any, base_dir: Path) -> Optional[Path]:
+    if not value:
+        return None
+    path_str = str(value)
+    resolved = resolve_path(path_str, base_dir)
+    return resolved if resolved is not None else Path(path_str)
+
+
+def resolve_explain_save_dir(
+    save_root: Optional[Path],
+    *,
+    result_dir: Optional[Any],
+) -> Path:
+    if save_root is not None:
+        return Path(save_root)
+    if result_dir is None:
+        raise ValueError("result_dir is required when explain save_root is not set.")
+    return Path(result_dir) / "explain"
+
+
+def resolve_explain_output_overrides(
+    explain_cfg: Dict[str, Any],
+    *,
+    model_name: str,
+    base_dir: Path,
+) -> Dict[str, Optional[Path]]:
+    return {
+        "model_dir": resolve_model_path_value(
+            explain_cfg.get("model_dir"),
+            model_name=model_name,
+            base_dir=base_dir,
+            data_dir=None,
+        ),
+        "result_dir": resolve_model_path_value(
+            explain_cfg.get("result_dir") or explain_cfg.get("results_dir"),
+            model_name=model_name,
+            base_dir=base_dir,
+            data_dir=None,
+        ),
+        "plot_dir": resolve_model_path_value(
+            explain_cfg.get("plot_dir"),
+            model_name=model_name,
+            base_dir=base_dir,
+            data_dir=None,
+        ),
+    }
+
+
+def resolve_and_load_config(
+    raw: str,
+    script_dir: Path,
+    required_keys: Sequence[str],
+    *,
+    apply_env: bool = True,
+) -> Tuple[Path, Dict[str, Any]]:
+    config_path = resolve_config_path(raw, script_dir)
+    cfg = load_config_json(config_path, required_keys=required_keys)
+    cfg = normalize_config_paths(cfg, config_path)
+    if apply_env:
+        set_env(cfg.get("env", {}))
+    return config_path, cfg
+
+
+def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    def _as_list(value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, (list, tuple, set)):
+            return [str(item) for item in value]
+        return [str(value)]
+
+    report_output_dir = cfg.get("report_output_dir")
+    report_group_cols = _as_list(cfg.get("report_group_cols"))
+    if not report_group_cols:
+        report_group_cols = None
+    report_time_col = cfg.get("report_time_col")
+    report_time_freq = cfg.get("report_time_freq", "M")
+    report_time_ascending = bool(cfg.get("report_time_ascending", True))
+    psi_bins = cfg.get("psi_bins", 10)
+    psi_strategy = cfg.get("psi_strategy", "quantile")
+    psi_features = _as_list(cfg.get("psi_features"))
+    if not psi_features:
+        psi_features = None
+    calibration_cfg = cfg.get("calibration", {}) or {}
+    threshold_cfg = cfg.get("threshold", {}) or {}
+    bootstrap_cfg = cfg.get("bootstrap", {}) or {}
+    register_model = bool(cfg.get("register_model", False))
+    registry_path = cfg.get("registry_path")
+    registry_tags = cfg.get("registry_tags", {})
+    registry_status = cfg.get("registry_status", "candidate")
+    data_fingerprint_max_bytes = int(
+        cfg.get("data_fingerprint_max_bytes", 10_485_760))
+    calibration_enabled = bool(
+        calibration_cfg.get("enable", False) or calibration_cfg.get("method")
+    )
+    threshold_enabled = bool(
+        threshold_cfg.get("enable", False)
+        or threshold_cfg.get("value") is not None
+        or threshold_cfg.get("metric")
+    )
+    bootstrap_enabled = bool(bootstrap_cfg.get("enable", False))
+    report_enabled = any([
+        bool(report_output_dir),
+        register_model,
+        bool(report_group_cols),
+        bool(report_time_col),
+        bool(psi_features),
+        calibration_enabled,
+        threshold_enabled,
+        bootstrap_enabled,
+    ])
+    return {
+        "report_output_dir": report_output_dir,
+        "report_group_cols": report_group_cols,
+        "report_time_col": report_time_col,
+        "report_time_freq": report_time_freq,
+        "report_time_ascending": report_time_ascending,
+        "psi_bins": psi_bins,
+        "psi_strategy": psi_strategy,
+        "psi_features": psi_features,
+        "calibration_cfg": calibration_cfg,
+        "threshold_cfg": threshold_cfg,
+        "bootstrap_cfg": bootstrap_cfg,
+        "register_model": register_model,
+        "registry_path": registry_path,
+        "registry_tags": registry_tags,
+        "registry_status": registry_status,
+        "data_fingerprint_max_bytes": data_fingerprint_max_bytes,
+        "report_enabled": report_enabled,
+    }
+
+
+def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    prop_test = cfg.get("prop_test", 0.25)
+    holdout_ratio = cfg.get("holdout_ratio", prop_test)
+    if holdout_ratio is None:
+        holdout_ratio = prop_test
+    val_ratio = cfg.get("val_ratio", prop_test)
+    if val_ratio is None:
+        val_ratio = prop_test
+    split_strategy = str(cfg.get("split_strategy", "random")).strip().lower()
+    split_group_col = cfg.get("split_group_col")
+    split_time_col = cfg.get("split_time_col")
+    split_time_ascending = bool(cfg.get("split_time_ascending", True))
+    cv_strategy = cfg.get("cv_strategy")
+    cv_group_col = cfg.get("cv_group_col")
+    cv_time_col = cfg.get("cv_time_col")
+    cv_time_ascending = cfg.get("cv_time_ascending", split_time_ascending)
+    cv_splits = cfg.get("cv_splits")
+    ft_oof_folds = cfg.get("ft_oof_folds")
+    ft_oof_strategy = cfg.get("ft_oof_strategy")
+    ft_oof_shuffle = cfg.get("ft_oof_shuffle", True)
+    return {
+        "prop_test": prop_test,
+        "holdout_ratio": holdout_ratio,
+        "val_ratio": val_ratio,
+        "split_strategy": split_strategy,
+        "split_group_col": split_group_col,
+        "split_time_col": split_time_col,
+        "split_time_ascending": split_time_ascending,
+        "cv_strategy": cv_strategy,
+        "cv_group_col": cv_group_col,
+        "cv_time_col": cv_time_col,
+        "cv_time_ascending": cv_time_ascending,
+        "cv_splits": cv_splits,
+        "ft_oof_folds": ft_oof_folds,
+        "ft_oof_strategy": ft_oof_strategy,
+        "ft_oof_shuffle": ft_oof_shuffle,
+    }
+
+
+def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    return {
+        "save_preprocess": bool(cfg.get("save_preprocess", False)),
+        "preprocess_artifact_path": cfg.get("preprocess_artifact_path"),
+        "rand_seed": cfg.get("rand_seed", 13),
+        "epochs": cfg.get("epochs", 50),
+        "plot_path_style": cfg.get("plot_path_style"),
+        "reuse_best_params": bool(cfg.get("reuse_best_params", False)),
+        "xgb_max_depth_max": int(cfg.get("xgb_max_depth_max", 25)),
+        "xgb_n_estimators_max": int(cfg.get("xgb_n_estimators_max", 500)),
+        "optuna_storage": cfg.get("optuna_storage"),
+        "optuna_study_prefix": cfg.get("optuna_study_prefix"),
+        "best_params_files": cfg.get("best_params_files"),
+        "bo_sample_limit": cfg.get("bo_sample_limit"),
+        "cache_predictions": bool(cfg.get("cache_predictions", False)),
+        "prediction_cache_dir": cfg.get("prediction_cache_dir"),
+        "prediction_cache_format": cfg.get("prediction_cache_format", "parquet"),
+        "ddp_min_rows": cfg.get("ddp_min_rows", 50000),
+    }
+
+
+def resolve_output_dirs(
+    cfg: Dict[str, Any],
+    config_path: Path,
+    *,
+    output_override: Optional[str] = None,
+) -> Dict[str, Optional[str]]:
+    output_root = resolve_dir_path(
+        output_override or cfg.get("output_dir"),
+        config_path.parent,
+    )
+    return {
+        "output_dir": str(output_root) if output_root is not None else None,
+    }