ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
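Nearly all of the churn above is a package reorganisation: `modelling/core/bayesopt/` is flattened to `modelling/bayesopt/`, `production/predict.py` is renamed to `production/inference.py` (with `test_predict.py` replaced by `test_inference.py`), and shared helpers such as `losses.py` move up into `ins_pricing/utils/`. Downstream imports against the old paths will break on upgrade. As a rough illustration of the impact (the compatibility shim below is our own sketch, not something the package ships; only the module paths are taken from the rename entries above):

```python
# Hypothetical upgrade shim for code that must run against both versions.
# The try/except pattern is an assumption; only the paths come from the diff.
try:
    # 0.5.x layout: modelling/core/bayesopt flattened to modelling/bayesopt
    from ins_pricing.modelling.bayesopt import core as bayesopt_core
    from ins_pricing.production import inference  # predict.py -> inference.py
except ImportError:
    # 0.4.x layout
    from ins_pricing.modelling.core.bayesopt import core as bayesopt_core
    from ins_pricing.production import predict as inference
```

The two hunks below are the only file bodies shown in full: the deleted `core/bayesopt` utility modules, whose contents appear to have been redistributed in 0.5.1 (judging by the new files under `ins_pricing/utils/` and the small shims left at the new `modelling/bayesopt/utils/` paths).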
```diff
--- a/ins_pricing/modelling/core/bayesopt/utils/io_utils.py
+++ /dev/null
@@ -1,126 +0,0 @@
-"""File I/O and parameter loading utilities.
-
-This module contains:
-- IOUtils class for loading parameters from JSON/CSV/TSV files
-- csv_to_dict() for CSV file handling
-- File path sanitization utilities
-"""
-
-from __future__ import annotations
-
-import csv
-import json
-from pathlib import Path
-from typing import Any, Dict, List
-
-import pandas as pd
-
-
-class IOUtils:
-    """File and path utilities for model parameters and configs."""
-
-    @staticmethod
-    def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-        """Load CSV file as list of dictionaries.
-
-        Args:
-            file_path: Path to CSV file
-
-        Returns:
-            List of dictionaries, one per row
-        """
-        with open(file_path, mode='r', encoding='utf-8') as file:
-            reader = csv.DictReader(file)
-            return [
-                dict(filter(lambda item: item[0] != '', row.items()))
-                for row in reader
-            ]
-
-    @staticmethod
-    def ensure_parent_dir(file_path: str) -> None:
-        """Create parent directories when missing.
-
-        Args:
-            file_path: Path to file whose parent directory should be created
-        """
-        directory = Path(file_path).parent
-        if directory and not directory.exists():
-            directory.mkdir(parents=True, exist_ok=True)
-
-    @staticmethod
-    def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
-        """Filter index-like columns such as "Unnamed: 0" from pandas I/O.
-
-        Args:
-            params: Parameter dictionary
-
-        Returns:
-            Sanitized parameter dictionary
-        """
-        return {
-            k: v
-            for k, v in (params or {}).items()
-            if k and not str(k).startswith("Unnamed")
-        }
-
-    @staticmethod
-    def load_params_file(path: str) -> Dict[str, Any]:
-        """Load parameter dict from JSON/CSV/TSV files.
-
-        Supported formats:
-        - JSON: accept dict or {"best_params": {...}} wrapper
-        - CSV/TSV: read the first row as params
-
-        Args:
-            path: Path to parameter file
-
-        Returns:
-            Parameter dictionary
-
-        Raises:
-            FileNotFoundError: If file doesn't exist
-            ValueError: If file format is unsupported or invalid
-        """
-        file_path = Path(path).expanduser().resolve()
-        if not file_path.exists():
-            raise FileNotFoundError(f"params file not found: {file_path}")
-
-        suffix = file_path.suffix.lower()
-
-        if suffix == ".json":
-            payload = json.loads(file_path.read_text(
-                encoding="utf-8", errors="replace"))
-            if isinstance(payload, dict) and "best_params" in payload:
-                payload = payload.get("best_params") or {}
-            if not isinstance(payload, dict):
-                raise ValueError(
-                    f"Invalid JSON params file (expect dict): {file_path}")
-            return IOUtils._sanitize_params_dict(dict(payload))
-
-        if suffix in (".csv", ".tsv"):
-            df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
-            if df.empty:
-                raise ValueError(f"Empty params file: {file_path}")
-            params = df.iloc[0].to_dict()
-            return IOUtils._sanitize_params_dict(params)
-
-        raise ValueError(
-            f"Unsupported params file type '{suffix}': {file_path}")
-
-
-# Backward compatibility function wrapper
-def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-    """Load CSV file as list of dictionaries (legacy function).
-
-    Args:
-        file_path: Path to CSV file
-
-    Returns:
-        List of dictionaries, one per row
-    """
-    return IOUtils.csv_to_dict(file_path)
-
-
-def ensure_parent_dir(file_path: str) -> None:
-    """Create parent directories when missing (legacy function)."""
-    IOUtils.ensure_parent_dir(file_path)
```
```diff
--- a/ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py
+++ /dev/null
@@ -1,555 +0,0 @@
-"""Metrics computation, GPU management, device selection, CV utilities, and plotting.
-
-This module contains:
-- get_logger() for package-level logging
-- MetricFactory for consistent metric computation
-- GPUMemoryManager for CUDA memory management (imported from package utils)
-- DeviceManager for device selection and model placement (imported from package utils)
-- CVStrategyResolver for cross-validation strategy selection
-- PlotUtils for lift chart plotting
-- Backward compatibility wrappers for plotting functions
-"""
-
-from __future__ import annotations
-
-import gc
-import logging
-import os
-from contextlib import contextmanager
-from functools import lru_cache
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, mean_tweedie_deviance
-from sklearn.model_selection import KFold, GroupKFold, TimeSeriesSplit, StratifiedKFold
-
-# Try to import plotting dependencies
-try:
-    import matplotlib
-    if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
-        matplotlib.use("Agg")
-    import matplotlib.pyplot as plt
-    _MPL_IMPORT_ERROR: Optional[BaseException] = None
-except Exception as exc:
-    matplotlib = None
-    plt = None
-    _MPL_IMPORT_ERROR = exc
-
-try:
-    from ....plotting import curves as plot_curves_common
-except Exception:
-    try:
-        from ins_pricing.plotting import curves as plot_curves_common
-    except Exception:
-        plot_curves_common = None
-
-from .constants import EPS
-
-# Import DeviceManager and GPUMemoryManager from package-level utils
-# (Eliminates ~230 lines of code duplication)
-from ins_pricing.utils import DeviceManager, GPUMemoryManager
-from .io_utils import IOUtils
-
-
-# =============================================================================
-# Logging System
-# =============================================================================
-
-@lru_cache(maxsize=1)
-def _get_package_logger() -> logging.Logger:
-    """Get or create the package-level logger with consistent formatting."""
-    logger = logging.getLogger("ins_pricing")
-    if not logger.handlers:
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter(
-            "[%(levelname)s][%(name)s] %(message)s"
-        )
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-        # Default to INFO, can be changed via environment variable
-        level = os.environ.get("INS_PRICING_LOG_LEVEL", "INFO").upper()
-        logger.setLevel(getattr(logging, level, logging.INFO))
-    return logger
-
-
-def get_logger(name: str = "ins_pricing") -> logging.Logger:
-    """Get a logger with the given name, inheriting package-level settings.
-
-    Args:
-        name: Logger name, typically module name like 'ins_pricing.trainer'
-
-    Returns:
-        Configured logger instance
-
-    Example:
-        >>> logger = get_logger("ins_pricing.trainer.ft")
-        >>> logger.info("Training started")
-    """
-    _get_package_logger()
-    return logging.getLogger(name)
-
-
-# =============================================================================
-# Metric Computation Factory
-# =============================================================================
-
-class MetricFactory:
-    """Factory for computing evaluation metrics consistently across all trainers.
-
-    This class centralizes metric computation logic that was previously duplicated
-    across FTTrainer, ResNetTrainer, GNNTrainer, XGBTrainer, and GLMTrainer.
-
-    Example:
-        >>> factory = MetricFactory(task_type='regression', tweedie_power=1.5)
-        >>> score = factory.compute(y_true, y_pred, sample_weight)
-    """
-
-    def __init__(
-        self,
-        task_type: str = "regression",
-        tweedie_power: float = 1.5,
-        loss_name: str = "tweedie",
-        clip_min: float = 1e-8,
-        clip_max: float = 1 - 1e-8,
-    ):
-        """Initialize the metric factory.
-
-        Args:
-            task_type: Either 'regression' or 'classification'
-            tweedie_power: Power parameter for Tweedie deviance (1.0-2.0)
-            loss_name: Regression loss name ('tweedie', 'poisson', 'gamma', 'mse', 'mae')
-            clip_min: Minimum value for clipping predictions
-            clip_max: Maximum value for clipping predictions (for classification)
-        """
-        self.task_type = task_type
-        self.tweedie_power = tweedie_power
-        self.loss_name = loss_name
-        self.clip_min = clip_min
-        self.clip_max = clip_max
-
-    def compute(
-        self,
-        y_true: np.ndarray,
-        y_pred: np.ndarray,
-        sample_weight: Optional[np.ndarray] = None,
-    ) -> float:
-        """Compute the appropriate metric based on task type.
-
-        Args:
-            y_true: Ground truth values
-            y_pred: Predicted values
-            sample_weight: Optional sample weights
-
-        Returns:
-            Computed metric value (lower is better)
-        """
-        y_pred = np.asarray(y_pred)
-        y_true = np.asarray(y_true)
-
-        if self.task_type == "classification":
-            y_pred_clipped = np.clip(y_pred, self.clip_min, self.clip_max)
-            return float(log_loss(y_true, y_pred_clipped, sample_weight=sample_weight))
-
-        loss_name = str(self.loss_name or "tweedie").strip().lower()
-        if loss_name in {"mse", "mae"}:
-            if loss_name == "mse":
-                return float(mean_squared_error(
-                    y_true, y_pred, sample_weight=sample_weight))
-            return float(mean_absolute_error(
-                y_true, y_pred, sample_weight=sample_weight))
-
-        y_pred_safe = np.maximum(y_pred, self.clip_min)
-        power = self.tweedie_power
-        if loss_name == "poisson":
-            power = 1.0
-        elif loss_name == "gamma":
-            power = 2.0
-        return float(mean_tweedie_deviance(
-            y_true,
-            y_pred_safe,
-            sample_weight=sample_weight,
-            power=power,
-        ))
-
-    def update_power(self, power: float) -> None:
-        """Update the Tweedie power parameter.
-
-        Args:
-            power: New power value (1.0-2.0)
-        """
-        self.tweedie_power = power
-
-
-# =============================================================================
-# GPU Memory Manager and Device Manager
-# =============================================================================
-# NOTE: These classes are imported from ins_pricing.utils (see top of file)
-# This eliminates ~230 lines of duplicate code while maintaining backward compatibility
-
-
-# =============================================================================
-# Cross-Validation Strategy Resolver
-# =============================================================================
-
-class CVStrategyResolver:
-    """Resolver for cross-validation splitting strategies.
-
-    This class consolidates CV strategy resolution logic that was previously
-    duplicated across trainer_base.py and trainer_ft.py.
-
-    Supported strategies:
-    - 'random': Standard KFold
-    - 'stratified': StratifiedKFold (for classification)
-    - 'group': GroupKFold (requires group column)
-    - 'time': TimeSeriesSplit (requires time column)
-
-    Example:
-        >>> resolver = CVStrategyResolver(
-        ...     strategy='group',
-        ...     n_splits=5,
-        ...     group_col='policy_id',
-        ...     data=train_df,
-        ... )
-        >>> splitter, groups = resolver.get_splitter()
-        >>> for train_idx, val_idx in splitter.split(X, y, groups):
-        ...     pass
-    """
-
-    VALID_STRATEGIES = {"random", "stratified", "group", "grouped", "time", "timeseries", "temporal"}
-
-    def __init__(
-        self,
-        strategy: str = "random",
-        n_splits: int = 5,
-        shuffle: bool = True,
-        random_state: Optional[int] = None,
-        group_col: Optional[str] = None,
-        time_col: Optional[str] = None,
-        time_ascending: bool = True,
-        data: Optional[pd.DataFrame] = None,
-    ):
-        """Initialize the CV strategy resolver.
-
-        Args:
-            strategy: CV strategy name
-            n_splits: Number of CV folds
-            shuffle: Whether to shuffle for random/stratified
-            random_state: Random seed for reproducibility
-            group_col: Column name for group-based splitting
-            time_col: Column name for time-based splitting
-            time_ascending: Sort order for time-based splitting
-            data: DataFrame containing group/time columns
-        """
-        self.strategy = strategy.strip().lower()
-        self.n_splits = max(2, int(n_splits))
-        self.shuffle = shuffle
-        self.random_state = random_state
-        self.group_col = group_col
-        self.time_col = time_col
-        self.time_ascending = time_ascending
-        self.data = data
-
-        if self.strategy not in self.VALID_STRATEGIES:
-            raise ValueError(
-                f"Invalid strategy '{strategy}'. "
-                f"Valid options: {sorted(self.VALID_STRATEGIES)}"
-            )
-
-    def get_splitter(self) -> Tuple[Any, Optional[pd.Series]]:
-        """Get the appropriate splitter and groups.
-
-        Returns:
-            Tuple of (splitter, groups) where groups may be None
-
-        Raises:
-            ValueError: If required columns are missing
-        """
-        if self.strategy in {"group", "grouped"}:
-            return self._get_group_splitter()
-        elif self.strategy in {"time", "timeseries", "temporal"}:
-            return self._get_time_splitter()
-        elif self.strategy == "stratified":
-            return self._get_stratified_splitter()
-        else:
-            return self._get_random_splitter()
-
-    def _get_random_splitter(self) -> Tuple[KFold, None]:
-        """Get a random KFold splitter."""
-        splitter = KFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_stratified_splitter(self) -> Tuple[StratifiedKFold, None]:
-        """Get a stratified KFold splitter."""
-        splitter = StratifiedKFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_group_splitter(self) -> Tuple[GroupKFold, pd.Series]:
-        """Get a group-based KFold splitter."""
-        if not self.group_col:
-            raise ValueError("group_col is required for group strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for group strategy")
-        if self.group_col not in self.data.columns:
-            raise KeyError(f"group_col '{self.group_col}' not found in data")
-
-        groups = self.data[self.group_col]
-        splitter = GroupKFold(n_splits=self.n_splits)
-        return splitter, groups
-
-    def _get_time_splitter(self) -> Tuple[Any, None]:
-        """Get a time-series splitter."""
-        if not self.time_col:
-            raise ValueError("time_col is required for time strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for time strategy")
-        if self.time_col not in self.data.columns:
-            raise KeyError(f"time_col '{self.time_col}' not found in data")
-
-        splitter = TimeSeriesSplit(n_splits=self.n_splits)
-
-        # Create an ordered wrapper that sorts by time column
-        order_index = self.data[self.time_col].sort_values(
-            ascending=self.time_ascending
-        ).index
-        order = self.data.index.get_indexer(order_index)
-
-        return _OrderedSplitter(splitter, order), None
-
-
-class _OrderedSplitter:
-    """Wrapper for splitters that need to respect a specific ordering."""
-
-    def __init__(self, base_splitter, order: np.ndarray):
-        self.base_splitter = base_splitter
-        self.order = order
-
-    def split(self, X, y=None, groups=None):
-        """Split with ordering applied."""
-        n = len(X)
-        X_ordered = np.arange(n)[self.order]
-        for train_idx, val_idx in self.base_splitter.split(X_ordered):
-            yield self.order[train_idx], self.order[val_idx]
-
-    def get_n_splits(self, X=None, y=None, groups=None):
-        return self.base_splitter.get_n_splits()
-
-
-# =============================================================================
-# Plot Utils
-# =============================================================================
-
-def _plot_skip(label: str) -> None:
-    """Print message when plot is skipped due to missing matplotlib."""
-    if _MPL_IMPORT_ERROR is not None:
-        print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
-    else:
-        print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
-
-
-class PlotUtils:
-    """Plotting utilities for lift charts."""
-
-    @staticmethod
-    def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
-        """Split data into bins by cumulative weight."""
-        data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
-        data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
-        w_sum = data_sorted[wgt_nme].sum()
-        if w_sum <= EPS:
-            data_sorted['bins'] = 0
-        else:
-            data_sorted['bins'] = np.floor(
-                data_sorted['cum_weight'] * float(n_bins) / w_sum
-            )
-        data_sorted.loc[(data_sorted['bins'] == n_bins),
-                        'bins'] = n_bins - 1
-        return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
-
-    @staticmethod
-    def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
-        """Plot lift chart on given axes."""
-        ax.plot(plot_data.index, plot_data['act_v'],
-                label=act_label, color='red')
-        ax.plot(plot_data.index, plot_data['exp_v'],
-                label=pred_label, color='blue')
-        ax.set_title(title, fontsize=8)
-        ax.set_xticks(plot_data.index)
-        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-        ax.tick_params(axis='y', labelsize=6)
-        ax.legend(loc='upper left', fontsize=5, frameon=False)
-        ax.margins(0.05)
-        ax2 = ax.twinx()
-        ax2.bar(plot_data.index, plot_data['weight'],
-                alpha=0.5, color='seagreen',
-                label=weight_label)
-        ax2.tick_params(axis='y', labelsize=6)
-        ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-    @staticmethod
-    def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
-        """Plot double lift chart on given axes."""
-        ax.plot(plot_data.index, plot_data['act_v'],
-                label=act_label, color='red')
-        ax.plot(plot_data.index, plot_data['exp_v1'],
-                label=label1, color='blue')
-        ax.plot(plot_data.index, plot_data['exp_v2'],
-                label=label2, color='black')
-        ax.set_title(title, fontsize=8)
-        ax.set_xticks(plot_data.index)
-        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-        ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
-        ax.tick_params(axis='y', labelsize=6)
-        ax.legend(loc='upper left', fontsize=5, frameon=False)
-        ax.margins(0.1)
-        ax2 = ax.twinx()
-        ax2.bar(plot_data.index, plot_data['weight'],
-                alpha=0.5, color='seagreen',
-                label=weight_label)
-        ax2.tick_params(axis='y', labelsize=6)
-        ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-    @staticmethod
-    def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                       weight_list, tgt_nme, n_bins: int = 10,
-                       fig_nme: str = 'Lift Chart'):
-        """Plot lift chart for model predictions."""
-        if plot_curves_common is not None:
-            save_path = os.path.join(
-                os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-            plot_curves_common.plot_lift_curve(
-                pred_model,
-                w_act_list,
-                weight_list,
-                n_bins=n_bins,
-                title=f'Lift Chart of {tgt_nme}',
-                pred_label='Predicted',
-                act_label='Actual',
-                weight_label='Earned Exposure',
-                pred_weighted=False,
-                actual_weighted=True,
-                save_path=save_path,
-                show=False,
-            )
-            return
-        if plt is None:
-            _plot_skip("lift plot")
-            return
-        lift_data = pd.DataFrame({
-            'pred': pred_model,
-            'w_pred': w_pred_list,
-            'act': w_act_list,
-            'weight': weight_list
-        })
-        plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
-        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
-        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
-        plot_data.reset_index(inplace=True)
-
-        fig = plt.figure(figsize=(7, 5))
-        ax = fig.add_subplot(111)
-        PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
-        plt.subplots_adjust(wspace=0.3)
-
-        save_path = os.path.join(
-            os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-        IOUtils.ensure_parent_dir(save_path)
-        plt.savefig(save_path, dpi=300)
-        plt.close(fig)
-
-    @staticmethod
-    def plot_dlift_list(pred_model_1, pred_model_2,
-                        model_nme_1, model_nme_2,
-                        tgt_nme,
-                        w_list, w_act_list, n_bins: int = 10,
-                        fig_nme: str = 'Double Lift Chart'):
-        """Plot double lift chart comparing two models."""
-        if plot_curves_common is not None:
-            save_path = os.path.join(
-                os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-            plot_curves_common.plot_double_lift_curve(
-                pred_model_1,
-                pred_model_2,
-                w_act_list,
-                w_list,
-                n_bins=n_bins,
-                title=f'Double Lift Chart of {tgt_nme}',
-                label1=model_nme_1,
-                label2=model_nme_2,
-                pred1_weighted=False,
-                pred2_weighted=False,
-                actual_weighted=True,
-                save_path=save_path,
-                show=False,
-            )
-            return
-        if plt is None:
-            _plot_skip("double lift plot")
-            return
-        lift_data = pd.DataFrame({
-            'pred1': pred_model_1,
-            'pred2': pred_model_2,
-            'act': w_act_list,
-            'weight': w_list
-        })
-        lift_data['diff_ly'] = lift_data['pred1'] / lift_data['pred2']
-        lift_data['w_pred1'] = lift_data['pred1'] * lift_data['weight']
-        lift_data['w_pred2'] = lift_data['pred2'] * lift_data['weight']
-        plot_data = PlotUtils.split_data(
-            lift_data, 'diff_ly', 'weight', n_bins)
-        plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
-        plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
-        plot_data['act_v'] = plot_data['act']/plot_data['act']
-        plot_data.reset_index(inplace=True)
-
-        fig = plt.figure(figsize=(7, 5))
-        ax = fig.add_subplot(111)
-        PlotUtils.plot_dlift_ax(
-            ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
-        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
-
-        save_path = os.path.join(
-            os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-        IOUtils.ensure_parent_dir(save_path)
-        plt.savefig(save_path, dpi=300)
-        plt.close(fig)
-
-
-# =============================================================================
-# Backward Compatibility Wrappers
-# =============================================================================
-
-def split_data(data, col_nme, wgt_nme, n_bins=10):
-    """Legacy function wrapper for PlotUtils.split_data()."""
-    return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
-
-
-def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                   weight_list, tgt_nme, n_bins=10,
-                   fig_nme='Lift Chart'):
-    """Legacy function wrapper for PlotUtils.plot_lift_list()."""
-    return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
-                                    weight_list, tgt_nme, n_bins, fig_nme)
-
-
-def plot_dlift_list(pred_model_1, pred_model_2,
-                    model_nme_1, model_nme_2,
-                    tgt_nme,
-                    w_list, w_act_list, n_bins=10,
-                    fig_nme='Double Lift Chart'):
-    """Legacy function wrapper for PlotUtils.plot_dlift_list()."""
-    return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
-                                     model_nme_1, model_nme_2,
-                                     tgt_nme, w_list, w_act_list,
-                                     n_bins, fig_nme)
```