ins-pricing 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +52 -50
- ins_pricing/cli/BayesOpt_incremental.py +39 -105
- ins_pricing/cli/Explain_Run.py +31 -23
- ins_pricing/cli/Explain_entry.py +532 -579
- ins_pricing/cli/Pricing_Run.py +31 -23
- ins_pricing/cli/bayesopt_entry_runner.py +11 -9
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +375 -375
- ins_pricing/cli/utils/import_resolver.py +382 -365
- ins_pricing/cli/utils/notebook_utils.py +340 -340
- ins_pricing/cli/watchdog_run.py +209 -201
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +116 -83
- ins_pricing/utils/device.py +255 -255
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
- ins_pricing-0.5.0.dist-info/RECORD +131 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
ins_pricing/cli/Pricing_Run.py
CHANGED
|
@@ -1,25 +1,33 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Optional
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
5
|
+
import importlib.util
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
if __package__ in {None, ""}:
|
|
9
|
+
if importlib.util.find_spec("ins_pricing") is None:
|
|
10
|
+
repo_root = Path(__file__).resolve().parents[2]
|
|
11
|
+
if str(repo_root) not in sys.path:
|
|
12
|
+
sys.path.insert(0, str(repo_root))
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from ins_pricing.cli.utils.notebook_utils import run_from_config, run_from_config_cli # type: ignore
|
|
16
|
+
except Exception: # pragma: no cover
|
|
17
|
+
from utils.notebook_utils import run_from_config, run_from_config_cli # type: ignore
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def run(config_json: str | Path) -> None:
|
|
21
|
+
"""Unified entry point: run entry/incremental/watchdog/DDP based on config.json runner."""
|
|
22
|
+
run_from_config(config_json)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def main(argv: Optional[list[str]] = None) -> None:
|
|
26
|
+
run_from_config_cli(
|
|
27
|
+
"Pricing_Run: run BayesOpt by config.json (entry/incremental/watchdog/DDP).",
|
|
28
|
+
argv,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if __name__ == "__main__":
|
|
33
|
+
main()
|
|
@@ -11,13 +11,15 @@ Example:
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
-
from pathlib import Path
|
|
15
|
-
import
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
if
|
|
20
|
-
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
import importlib.util
|
|
16
|
+
import sys
|
|
17
|
+
|
|
18
|
+
if __package__ in {None, ""}:
|
|
19
|
+
if importlib.util.find_spec("ins_pricing") is None:
|
|
20
|
+
repo_root = Path(__file__).resolve().parents[2]
|
|
21
|
+
if str(repo_root) not in sys.path:
|
|
22
|
+
sys.path.insert(0, str(repo_root))
|
|
21
23
|
|
|
22
24
|
import argparse
|
|
23
25
|
import hashlib
|
|
@@ -30,8 +32,8 @@ import numpy as np
|
|
|
30
32
|
import pandas as pd
|
|
31
33
|
|
|
32
34
|
# Use unified import resolver to eliminate nested try/except chains
|
|
33
|
-
from .utils.import_resolver import resolve_imports, setup_sys_path
|
|
34
|
-
from .utils.evaluation_context import (
|
|
35
|
+
from ins_pricing.cli.utils.import_resolver import resolve_imports, setup_sys_path
|
|
36
|
+
from ins_pricing.cli.utils.evaluation_context import (
|
|
35
37
|
EvaluationContext,
|
|
36
38
|
TrainingContext,
|
|
37
39
|
ModelIdentity,
|
|
@@ -1,256 +1,256 @@
|
|
|
1
|
-
"""CLI common utilities.
|
|
2
|
-
|
|
3
|
-
This module re-exports shared utilities from ins_pricing.utils and provides
|
|
4
|
-
CLI-specific functionality for configuration loading and train/test splitting.
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
from __future__ import annotations
|
|
8
|
-
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from typing import Any, Dict, Optional, Sequence, Tuple
|
|
11
|
-
|
|
12
|
-
import pandas as pd
|
|
13
|
-
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
|
14
|
-
|
|
15
|
-
# Re-export shared utilities for backward compatibility
|
|
16
|
-
from ins_pricing.utils.paths import (
|
|
17
|
-
PLOT_MODEL_LABELS,
|
|
18
|
-
PYTORCH_TRAINERS,
|
|
19
|
-
dedupe_preserve_order,
|
|
20
|
-
build_model_names,
|
|
21
|
-
parse_model_pairs,
|
|
22
|
-
resolve_path,
|
|
23
|
-
resolve_dir_path,
|
|
24
|
-
resolve_data_path,
|
|
25
|
-
load_dataset,
|
|
26
|
-
coerce_dataset_types,
|
|
27
|
-
fingerprint_file,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
__all__ = [
|
|
31
|
-
# From shared utils
|
|
32
|
-
"PLOT_MODEL_LABELS",
|
|
33
|
-
"PYTORCH_TRAINERS",
|
|
34
|
-
"dedupe_preserve_order",
|
|
35
|
-
"build_model_names",
|
|
36
|
-
"parse_model_pairs",
|
|
37
|
-
"resolve_path",
|
|
38
|
-
"resolve_dir_path",
|
|
39
|
-
"resolve_data_path",
|
|
40
|
-
"load_dataset",
|
|
41
|
-
"coerce_dataset_types",
|
|
42
|
-
"fingerprint_file",
|
|
43
|
-
# CLI-specific
|
|
44
|
-
"split_train_test",
|
|
45
|
-
"resolve_config_path",
|
|
46
|
-
"load_config_json",
|
|
47
|
-
"set_env",
|
|
48
|
-
"normalize_config_paths",
|
|
49
|
-
"resolve_dtype_map",
|
|
50
|
-
"resolve_data_config",
|
|
51
|
-
"resolve_report_config",
|
|
52
|
-
"resolve_split_config",
|
|
53
|
-
"resolve_runtime_config",
|
|
54
|
-
"resolve_output_dirs",
|
|
55
|
-
"resolve_and_load_config",
|
|
56
|
-
]
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
# =============================================================================
|
|
60
|
-
# CLI-specific: Train/Test Splitting
|
|
61
|
-
# =============================================================================
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def split_train_test(
|
|
65
|
-
df: pd.DataFrame,
|
|
66
|
-
*,
|
|
67
|
-
holdout_ratio: float,
|
|
68
|
-
strategy: str = "random",
|
|
69
|
-
group_col: Optional[str] = None,
|
|
70
|
-
time_col: Optional[str] = None,
|
|
71
|
-
time_ascending: bool = True,
|
|
72
|
-
rand_seed: Optional[int] = None,
|
|
73
|
-
reset_index_mode: str = "none",
|
|
74
|
-
ratio_label: str = "holdout_ratio",
|
|
75
|
-
include_strategy_in_ratio_error: bool = False,
|
|
76
|
-
validate_ratio: bool = True,
|
|
77
|
-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
|
78
|
-
"""Split a DataFrame into train and test sets.
|
|
79
|
-
|
|
80
|
-
Args:
|
|
81
|
-
df: Input DataFrame
|
|
82
|
-
holdout_ratio: Proportion for test set (0.0-1.0)
|
|
83
|
-
strategy: Split strategy ('random', 'time', 'group')
|
|
84
|
-
group_col: Column name for group-based splitting
|
|
85
|
-
time_col: Column name for time-based splitting
|
|
86
|
-
time_ascending: Sort order for time-based splitting
|
|
87
|
-
rand_seed: Random seed for reproducibility
|
|
88
|
-
reset_index_mode: When to reset index ('none', 'always', 'time_group')
|
|
89
|
-
ratio_label: Label for ratio in error messages
|
|
90
|
-
include_strategy_in_ratio_error: Include strategy in error messages
|
|
91
|
-
validate_ratio: Whether to validate ratio bounds
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
Tuple of (train_df, test_df)
|
|
95
|
-
"""
|
|
96
|
-
strategy = str(strategy or "random").strip().lower()
|
|
97
|
-
holdout_ratio = float(holdout_ratio)
|
|
98
|
-
|
|
99
|
-
if include_strategy_in_ratio_error and strategy in {
|
|
100
|
-
"time",
|
|
101
|
-
"timeseries",
|
|
102
|
-
"temporal",
|
|
103
|
-
"group",
|
|
104
|
-
"grouped",
|
|
105
|
-
}:
|
|
106
|
-
strategy_label = (
|
|
107
|
-
"time" if strategy in {"time", "timeseries", "temporal"} else "group"
|
|
108
|
-
)
|
|
109
|
-
ratio_error = (
|
|
110
|
-
f"{ratio_label} must be in (0, 1) for {strategy_label} split; "
|
|
111
|
-
f"got {holdout_ratio}."
|
|
112
|
-
)
|
|
113
|
-
else:
|
|
114
|
-
ratio_error = f"{ratio_label} must be in (0, 1); got {holdout_ratio}."
|
|
115
|
-
|
|
116
|
-
if strategy in {"time", "timeseries", "temporal"}:
|
|
117
|
-
if not time_col:
|
|
118
|
-
raise ValueError("split_time_col is required for time split_strategy.")
|
|
119
|
-
if time_col not in df.columns:
|
|
120
|
-
raise KeyError(f"split_time_col '{time_col}' not in dataset columns.")
|
|
121
|
-
if validate_ratio and not (0.0 < holdout_ratio < 1.0):
|
|
122
|
-
raise ValueError(ratio_error)
|
|
123
|
-
ordered = df.sort_values(time_col, ascending=bool(time_ascending))
|
|
124
|
-
cutoff = int(len(ordered) * (1.0 - holdout_ratio))
|
|
125
|
-
if cutoff <= 0 or cutoff >= len(ordered):
|
|
126
|
-
raise ValueError(
|
|
127
|
-
f"{ratio_label}={holdout_ratio} leaves no data for train/test split."
|
|
128
|
-
)
|
|
129
|
-
train_df = ordered.iloc[:cutoff]
|
|
130
|
-
test_df = ordered.iloc[cutoff:]
|
|
131
|
-
elif strategy in {"group", "grouped"}:
|
|
132
|
-
if not group_col:
|
|
133
|
-
raise ValueError("split_group_col is required for group split_strategy.")
|
|
134
|
-
if group_col not in df.columns:
|
|
135
|
-
raise KeyError(f"split_group_col '{group_col}' not in dataset columns.")
|
|
136
|
-
if validate_ratio and not (0.0 < holdout_ratio < 1.0):
|
|
137
|
-
raise ValueError(ratio_error)
|
|
138
|
-
splitter = GroupShuffleSplit(
|
|
139
|
-
n_splits=1,
|
|
140
|
-
test_size=holdout_ratio,
|
|
141
|
-
random_state=rand_seed,
|
|
142
|
-
)
|
|
143
|
-
train_idx, test_idx = next(splitter.split(df, groups=df[group_col]))
|
|
144
|
-
train_df = df.iloc[train_idx]
|
|
145
|
-
test_df = df.iloc[test_idx]
|
|
146
|
-
else:
|
|
147
|
-
train_df, test_df = train_test_split(
|
|
148
|
-
df, test_size=holdout_ratio, random_state=rand_seed
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
if reset_index_mode == "always" or (
|
|
152
|
-
reset_index_mode == "time_group"
|
|
153
|
-
and strategy in {"time", "timeseries", "temporal", "group", "grouped"}
|
|
154
|
-
):
|
|
155
|
-
train_df = train_df.reset_index(drop=True)
|
|
156
|
-
test_df = test_df.reset_index(drop=True)
|
|
157
|
-
|
|
158
|
-
return train_df, test_df
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
# =============================================================================
|
|
162
|
-
# CLI-specific: Configuration Loading (delegated to cli_config)
|
|
163
|
-
# =============================================================================
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
def _load_cli_config():
|
|
167
|
-
"""Load the cli_config module."""
|
|
168
|
-
try:
|
|
169
|
-
from . import cli_config as _cli_config
|
|
170
|
-
except Exception:
|
|
171
|
-
import cli_config as _cli_config
|
|
172
|
-
return _cli_config
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def resolve_config_path(raw: str, script_dir: Path) -> Path:
|
|
176
|
-
"""Resolve a configuration file path."""
|
|
177
|
-
return _load_cli_config().resolve_config_path(raw, script_dir)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
|
|
181
|
-
"""Load and validate a JSON configuration file."""
|
|
182
|
-
return _load_cli_config().load_config_json(path, required_keys)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
def set_env(env_overrides: Dict[str, Any]) -> None:
|
|
186
|
-
"""Set environment variables from configuration."""
|
|
187
|
-
_load_cli_config().set_env(env_overrides)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
|
|
191
|
-
"""Normalize paths in configuration relative to config file location."""
|
|
192
|
-
return _load_cli_config().normalize_config_paths(cfg, config_path)
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
|
|
196
|
-
"""Resolve dtype mapping from configuration."""
|
|
197
|
-
return _load_cli_config().resolve_dtype_map(value, base_dir)
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def resolve_data_config(
|
|
201
|
-
cfg: Dict[str, Any],
|
|
202
|
-
config_path: Path,
|
|
203
|
-
*,
|
|
204
|
-
create_data_dir: bool = False,
|
|
205
|
-
) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
|
|
206
|
-
"""Resolve data configuration from config file."""
|
|
207
|
-
return _load_cli_config().resolve_data_config(
|
|
208
|
-
cfg,
|
|
209
|
-
config_path,
|
|
210
|
-
create_data_dir=create_data_dir,
|
|
211
|
-
)
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
215
|
-
"""Resolve report configuration."""
|
|
216
|
-
return _load_cli_config().resolve_report_config(cfg)
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
220
|
-
"""Resolve train/test split configuration."""
|
|
221
|
-
return _load_cli_config().resolve_split_config(cfg)
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
225
|
-
"""Resolve runtime configuration."""
|
|
226
|
-
return _load_cli_config().resolve_runtime_config(cfg)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def resolve_output_dirs(
|
|
230
|
-
cfg: Dict[str, Any],
|
|
231
|
-
config_path: Path,
|
|
232
|
-
*,
|
|
233
|
-
output_override: Optional[str] = None,
|
|
234
|
-
) -> Dict[str, Optional[str]]:
|
|
235
|
-
"""Resolve output directory paths."""
|
|
236
|
-
return _load_cli_config().resolve_output_dirs(
|
|
237
|
-
cfg,
|
|
238
|
-
config_path,
|
|
239
|
-
output_override=output_override,
|
|
240
|
-
)
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
def resolve_and_load_config(
|
|
244
|
-
raw: str,
|
|
245
|
-
script_dir: Path,
|
|
246
|
-
required_keys: Sequence[str],
|
|
247
|
-
*,
|
|
248
|
-
apply_env: bool = True,
|
|
249
|
-
) -> Tuple[Path, Dict[str, Any]]:
|
|
250
|
-
"""Resolve and load a configuration file."""
|
|
251
|
-
return _load_cli_config().resolve_and_load_config(
|
|
252
|
-
raw,
|
|
253
|
-
script_dir,
|
|
254
|
-
required_keys,
|
|
255
|
-
apply_env=apply_env,
|
|
256
|
-
)
|
|
1
|
+
"""CLI common utilities.
|
|
2
|
+
|
|
3
|
+
This module re-exports shared utilities from ins_pricing.utils and provides
|
|
4
|
+
CLI-specific functionality for configuration loading and train/test splitting.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, Optional, Sequence, Tuple
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
|
14
|
+
|
|
15
|
+
# Re-export shared utilities for backward compatibility
|
|
16
|
+
from ins_pricing.utils.paths import (
|
|
17
|
+
PLOT_MODEL_LABELS,
|
|
18
|
+
PYTORCH_TRAINERS,
|
|
19
|
+
dedupe_preserve_order,
|
|
20
|
+
build_model_names,
|
|
21
|
+
parse_model_pairs,
|
|
22
|
+
resolve_path,
|
|
23
|
+
resolve_dir_path,
|
|
24
|
+
resolve_data_path,
|
|
25
|
+
load_dataset,
|
|
26
|
+
coerce_dataset_types,
|
|
27
|
+
fingerprint_file,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
# From shared utils
|
|
32
|
+
"PLOT_MODEL_LABELS",
|
|
33
|
+
"PYTORCH_TRAINERS",
|
|
34
|
+
"dedupe_preserve_order",
|
|
35
|
+
"build_model_names",
|
|
36
|
+
"parse_model_pairs",
|
|
37
|
+
"resolve_path",
|
|
38
|
+
"resolve_dir_path",
|
|
39
|
+
"resolve_data_path",
|
|
40
|
+
"load_dataset",
|
|
41
|
+
"coerce_dataset_types",
|
|
42
|
+
"fingerprint_file",
|
|
43
|
+
# CLI-specific
|
|
44
|
+
"split_train_test",
|
|
45
|
+
"resolve_config_path",
|
|
46
|
+
"load_config_json",
|
|
47
|
+
"set_env",
|
|
48
|
+
"normalize_config_paths",
|
|
49
|
+
"resolve_dtype_map",
|
|
50
|
+
"resolve_data_config",
|
|
51
|
+
"resolve_report_config",
|
|
52
|
+
"resolve_split_config",
|
|
53
|
+
"resolve_runtime_config",
|
|
54
|
+
"resolve_output_dirs",
|
|
55
|
+
"resolve_and_load_config",
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# =============================================================================
|
|
60
|
+
# CLI-specific: Train/Test Splitting
|
|
61
|
+
# =============================================================================
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def split_train_test(
|
|
65
|
+
df: pd.DataFrame,
|
|
66
|
+
*,
|
|
67
|
+
holdout_ratio: float,
|
|
68
|
+
strategy: str = "random",
|
|
69
|
+
group_col: Optional[str] = None,
|
|
70
|
+
time_col: Optional[str] = None,
|
|
71
|
+
time_ascending: bool = True,
|
|
72
|
+
rand_seed: Optional[int] = None,
|
|
73
|
+
reset_index_mode: str = "none",
|
|
74
|
+
ratio_label: str = "holdout_ratio",
|
|
75
|
+
include_strategy_in_ratio_error: bool = False,
|
|
76
|
+
validate_ratio: bool = True,
|
|
77
|
+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
|
78
|
+
"""Split a DataFrame into train and test sets.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
df: Input DataFrame
|
|
82
|
+
holdout_ratio: Proportion for test set (0.0-1.0)
|
|
83
|
+
strategy: Split strategy ('random', 'time', 'group')
|
|
84
|
+
group_col: Column name for group-based splitting
|
|
85
|
+
time_col: Column name for time-based splitting
|
|
86
|
+
time_ascending: Sort order for time-based splitting
|
|
87
|
+
rand_seed: Random seed for reproducibility
|
|
88
|
+
reset_index_mode: When to reset index ('none', 'always', 'time_group')
|
|
89
|
+
ratio_label: Label for ratio in error messages
|
|
90
|
+
include_strategy_in_ratio_error: Include strategy in error messages
|
|
91
|
+
validate_ratio: Whether to validate ratio bounds
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
Tuple of (train_df, test_df)
|
|
95
|
+
"""
|
|
96
|
+
strategy = str(strategy or "random").strip().lower()
|
|
97
|
+
holdout_ratio = float(holdout_ratio)
|
|
98
|
+
|
|
99
|
+
if include_strategy_in_ratio_error and strategy in {
|
|
100
|
+
"time",
|
|
101
|
+
"timeseries",
|
|
102
|
+
"temporal",
|
|
103
|
+
"group",
|
|
104
|
+
"grouped",
|
|
105
|
+
}:
|
|
106
|
+
strategy_label = (
|
|
107
|
+
"time" if strategy in {"time", "timeseries", "temporal"} else "group"
|
|
108
|
+
)
|
|
109
|
+
ratio_error = (
|
|
110
|
+
f"{ratio_label} must be in (0, 1) for {strategy_label} split; "
|
|
111
|
+
f"got {holdout_ratio}."
|
|
112
|
+
)
|
|
113
|
+
else:
|
|
114
|
+
ratio_error = f"{ratio_label} must be in (0, 1); got {holdout_ratio}."
|
|
115
|
+
|
|
116
|
+
if strategy in {"time", "timeseries", "temporal"}:
|
|
117
|
+
if not time_col:
|
|
118
|
+
raise ValueError("split_time_col is required for time split_strategy.")
|
|
119
|
+
if time_col not in df.columns:
|
|
120
|
+
raise KeyError(f"split_time_col '{time_col}' not in dataset columns.")
|
|
121
|
+
if validate_ratio and not (0.0 < holdout_ratio < 1.0):
|
|
122
|
+
raise ValueError(ratio_error)
|
|
123
|
+
ordered = df.sort_values(time_col, ascending=bool(time_ascending))
|
|
124
|
+
cutoff = int(len(ordered) * (1.0 - holdout_ratio))
|
|
125
|
+
if cutoff <= 0 or cutoff >= len(ordered):
|
|
126
|
+
raise ValueError(
|
|
127
|
+
f"{ratio_label}={holdout_ratio} leaves no data for train/test split."
|
|
128
|
+
)
|
|
129
|
+
train_df = ordered.iloc[:cutoff]
|
|
130
|
+
test_df = ordered.iloc[cutoff:]
|
|
131
|
+
elif strategy in {"group", "grouped"}:
|
|
132
|
+
if not group_col:
|
|
133
|
+
raise ValueError("split_group_col is required for group split_strategy.")
|
|
134
|
+
if group_col not in df.columns:
|
|
135
|
+
raise KeyError(f"split_group_col '{group_col}' not in dataset columns.")
|
|
136
|
+
if validate_ratio and not (0.0 < holdout_ratio < 1.0):
|
|
137
|
+
raise ValueError(ratio_error)
|
|
138
|
+
splitter = GroupShuffleSplit(
|
|
139
|
+
n_splits=1,
|
|
140
|
+
test_size=holdout_ratio,
|
|
141
|
+
random_state=rand_seed,
|
|
142
|
+
)
|
|
143
|
+
train_idx, test_idx = next(splitter.split(df, groups=df[group_col]))
|
|
144
|
+
train_df = df.iloc[train_idx]
|
|
145
|
+
test_df = df.iloc[test_idx]
|
|
146
|
+
else:
|
|
147
|
+
train_df, test_df = train_test_split(
|
|
148
|
+
df, test_size=holdout_ratio, random_state=rand_seed
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
if reset_index_mode == "always" or (
|
|
152
|
+
reset_index_mode == "time_group"
|
|
153
|
+
and strategy in {"time", "timeseries", "temporal", "group", "grouped"}
|
|
154
|
+
):
|
|
155
|
+
train_df = train_df.reset_index(drop=True)
|
|
156
|
+
test_df = test_df.reset_index(drop=True)
|
|
157
|
+
|
|
158
|
+
return train_df, test_df
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# =============================================================================
|
|
162
|
+
# CLI-specific: Configuration Loading (delegated to cli_config)
|
|
163
|
+
# =============================================================================
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _load_cli_config():
|
|
167
|
+
"""Load the cli_config module."""
|
|
168
|
+
try:
|
|
169
|
+
from ins_pricing.cli.utils import cli_config as _cli_config
|
|
170
|
+
except Exception:
|
|
171
|
+
import cli_config as _cli_config
|
|
172
|
+
return _cli_config
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def resolve_config_path(raw: str, script_dir: Path) -> Path:
|
|
176
|
+
"""Resolve a configuration file path."""
|
|
177
|
+
return _load_cli_config().resolve_config_path(raw, script_dir)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
|
|
181
|
+
"""Load and validate a JSON configuration file."""
|
|
182
|
+
return _load_cli_config().load_config_json(path, required_keys)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def set_env(env_overrides: Dict[str, Any]) -> None:
|
|
186
|
+
"""Set environment variables from configuration."""
|
|
187
|
+
_load_cli_config().set_env(env_overrides)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
|
|
191
|
+
"""Normalize paths in configuration relative to config file location."""
|
|
192
|
+
return _load_cli_config().normalize_config_paths(cfg, config_path)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def resolve_dtype_map(value: Any, base_dir: Path) -> Dict[str, Any]:
|
|
196
|
+
"""Resolve dtype mapping from configuration."""
|
|
197
|
+
return _load_cli_config().resolve_dtype_map(value, base_dir)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def resolve_data_config(
|
|
201
|
+
cfg: Dict[str, Any],
|
|
202
|
+
config_path: Path,
|
|
203
|
+
*,
|
|
204
|
+
create_data_dir: bool = False,
|
|
205
|
+
) -> Tuple[Path, str, Optional[str], Dict[str, Any]]:
|
|
206
|
+
"""Resolve data configuration from config file."""
|
|
207
|
+
return _load_cli_config().resolve_data_config(
|
|
208
|
+
cfg,
|
|
209
|
+
config_path,
|
|
210
|
+
create_data_dir=create_data_dir,
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def resolve_report_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
215
|
+
"""Resolve report configuration."""
|
|
216
|
+
return _load_cli_config().resolve_report_config(cfg)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def resolve_split_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
220
|
+
"""Resolve train/test split configuration."""
|
|
221
|
+
return _load_cli_config().resolve_split_config(cfg)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def resolve_runtime_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
225
|
+
"""Resolve runtime configuration."""
|
|
226
|
+
return _load_cli_config().resolve_runtime_config(cfg)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def resolve_output_dirs(
|
|
230
|
+
cfg: Dict[str, Any],
|
|
231
|
+
config_path: Path,
|
|
232
|
+
*,
|
|
233
|
+
output_override: Optional[str] = None,
|
|
234
|
+
) -> Dict[str, Optional[str]]:
|
|
235
|
+
"""Resolve output directory paths."""
|
|
236
|
+
return _load_cli_config().resolve_output_dirs(
|
|
237
|
+
cfg,
|
|
238
|
+
config_path,
|
|
239
|
+
output_override=output_override,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def resolve_and_load_config(
|
|
244
|
+
raw: str,
|
|
245
|
+
script_dir: Path,
|
|
246
|
+
required_keys: Sequence[str],
|
|
247
|
+
*,
|
|
248
|
+
apply_env: bool = True,
|
|
249
|
+
) -> Tuple[Path, Dict[str, Any]]:
|
|
250
|
+
"""Resolve and load a configuration file."""
|
|
251
|
+
return _load_cli_config().resolve_and_load_config(
|
|
252
|
+
raw,
|
|
253
|
+
script_dir,
|
|
254
|
+
required_keys,
|
|
255
|
+
apply_env=apply_env,
|
|
256
|
+
)
|