ins-pricing: ins_pricing-0.4.5-py3-none-any.whl → ins_pricing-0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
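The dominant change in this release is structural: the modelling/core/bayesopt tree is flattened to modelling/bayesopt (items 25-54), production/predict.py becomes production/inference.py (item 64), and several bayesopt/utils modules shrink to stubs in favour of package-level ins_pricing/utils modules (io.py, losses.py, metrics.py). A minimal compatibility sketch for downstream code, assuming the old paths are gone in 0.5.1 (whether any re-export shims remain is not visible in this diff):

# Hedged sketch: module paths are inferred from the renames listed above.
try:  # 0.5.1 layout
    from ins_pricing.modelling.bayesopt import core as bayesopt_core
    from ins_pricing.production import inference as production_api
except ImportError:  # 0.4.5 layout
    from ins_pricing.modelling.core.bayesopt import core as bayesopt_core
    from ins_pricing.production import predict as production_api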
ins_pricing/modelling/core/bayesopt/utils/io_utils.py
@@ -1,126 +0,0 @@
-"""File I/O and parameter loading utilities.
-
-This module contains:
-- IOUtils class for loading parameters from JSON/CSV/TSV files
-- csv_to_dict() for CSV file handling
-- File path sanitization utilities
-"""
-
-from __future__ import annotations
-
-import csv
-import json
-from pathlib import Path
-from typing import Any, Dict, List
-
-import pandas as pd
-
-
-class IOUtils:
-    """File and path utilities for model parameters and configs."""
-
-    @staticmethod
-    def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-        """Load CSV file as list of dictionaries.
-
-        Args:
-            file_path: Path to CSV file
-
-        Returns:
-            List of dictionaries, one per row
-        """
-        with open(file_path, mode='r', encoding='utf-8') as file:
-            reader = csv.DictReader(file)
-            return [
-                dict(filter(lambda item: item[0] != '', row.items()))
-                for row in reader
-            ]
-
-    @staticmethod
-    def ensure_parent_dir(file_path: str) -> None:
-        """Create parent directories when missing.
-
-        Args:
-            file_path: Path to file whose parent directory should be created
-        """
-        directory = Path(file_path).parent
-        if directory and not directory.exists():
-            directory.mkdir(parents=True, exist_ok=True)
-
-    @staticmethod
-    def _sanitize_params_dict(params: Dict[str, Any]) -> Dict[str, Any]:
-        """Filter index-like columns such as "Unnamed: 0" from pandas I/O.
-
-        Args:
-            params: Parameter dictionary
-
-        Returns:
-            Sanitized parameter dictionary
-        """
-        return {
-            k: v
-            for k, v in (params or {}).items()
-            if k and not str(k).startswith("Unnamed")
-        }
-
-    @staticmethod
-    def load_params_file(path: str) -> Dict[str, Any]:
-        """Load parameter dict from JSON/CSV/TSV files.
-
-        Supported formats:
-        - JSON: accept dict or {"best_params": {...}} wrapper
-        - CSV/TSV: read the first row as params
-
-        Args:
-            path: Path to parameter file
-
-        Returns:
-            Parameter dictionary
-
-        Raises:
-            FileNotFoundError: If file doesn't exist
-            ValueError: If file format is unsupported or invalid
-        """
-        file_path = Path(path).expanduser().resolve()
-        if not file_path.exists():
-            raise FileNotFoundError(f"params file not found: {file_path}")
-
-        suffix = file_path.suffix.lower()
-
-        if suffix == ".json":
-            payload = json.loads(file_path.read_text(
-                encoding="utf-8", errors="replace"))
-            if isinstance(payload, dict) and "best_params" in payload:
-                payload = payload.get("best_params") or {}
-            if not isinstance(payload, dict):
-                raise ValueError(
-                    f"Invalid JSON params file (expect dict): {file_path}")
-            return IOUtils._sanitize_params_dict(dict(payload))
-
-        if suffix in (".csv", ".tsv"):
-            df = pd.read_csv(file_path, sep="\t" if suffix == ".tsv" else ",")
-            if df.empty:
-                raise ValueError(f"Empty params file: {file_path}")
-            params = df.iloc[0].to_dict()
-            return IOUtils._sanitize_params_dict(params)
-
-        raise ValueError(
-            f"Unsupported params file type '{suffix}': {file_path}")
-
-
-# Backward compatibility function wrapper
-def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
-    """Load CSV file as list of dictionaries (legacy function).
-
-    Args:
-        file_path: Path to CSV file
-
-    Returns:
-        List of dictionaries, one per row
-    """
-    return IOUtils.csv_to_dict(file_path)
-
-
-def ensure_parent_dir(file_path: str) -> None:
-    """Create parent directories when missing (legacy function)."""
-    IOUtils.ensure_parent_dir(file_path)
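The io_utils.py deleted above is replaced in 0.5.1 by a 7-line stub at ins_pricing/modelling/bayesopt/utils/io_utils.py plus a new ins_pricing/utils/io.py (items 47 and 72 in the file list), which suggests the implementation moved to package-level utils with a re-export left behind. A usage sketch of the loader as the deleted code defines it; the 0.5.1 import path is an inference, and params.json is a hypothetical file:

import json
from pathlib import Path

try:  # inferred 0.5.1 home of these helpers; not shown in this hunk
    from ins_pricing.utils.io import IOUtils
except ImportError:  # the 0.4.5 module deleted above
    from ins_pricing.modelling.core.bayesopt.utils.io_utils import IOUtils

params_path = Path("params.json")  # hypothetical example file
params_path.write_text(
    json.dumps({"best_params": {"max_depth": 6, "Unnamed: 0": 0}}),
    encoding="utf-8",
)

# load_params_file unwraps the {"best_params": ...} wrapper and drops
# index-like "Unnamed: ..." keys, so this prints {'max_depth': 6}.
print(IOUtils.load_params_file(str(params_path)))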
ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py
@@ -1,555 +0,0 @@
-"""Metrics computation, GPU management, device selection, CV utilities, and plotting.
-
-This module contains:
-- get_logger() for package-level logging
-- MetricFactory for consistent metric computation
-- GPUMemoryManager for CUDA memory management (imported from package utils)
-- DeviceManager for device selection and model placement (imported from package utils)
-- CVStrategyResolver for cross-validation strategy selection
-- PlotUtils for lift chart plotting
-- Backward compatibility wrappers for plotting functions
-"""
-
-from __future__ import annotations
-
-import gc
-import logging
-import os
-from contextlib import contextmanager
-from functools import lru_cache
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, mean_tweedie_deviance
-from sklearn.model_selection import KFold, GroupKFold, TimeSeriesSplit, StratifiedKFold
-
-# Try to import plotting dependencies
-try:
-    import matplotlib
-    if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
-        matplotlib.use("Agg")
-    import matplotlib.pyplot as plt
-    _MPL_IMPORT_ERROR: Optional[BaseException] = None
-except Exception as exc:
-    matplotlib = None
-    plt = None
-    _MPL_IMPORT_ERROR = exc
-
-try:
-    from ....plotting import curves as plot_curves_common
-except Exception:
-    try:
-        from ins_pricing.plotting import curves as plot_curves_common
-    except Exception:
-        plot_curves_common = None
-
-from .constants import EPS
-
-# Import DeviceManager and GPUMemoryManager from package-level utils
-# (Eliminates ~230 lines of code duplication)
-from ins_pricing.utils import DeviceManager, GPUMemoryManager
-from .io_utils import IOUtils
-
-
-# =============================================================================
-# Logging System
-# =============================================================================
-
-@lru_cache(maxsize=1)
-def _get_package_logger() -> logging.Logger:
-    """Get or create the package-level logger with consistent formatting."""
-    logger = logging.getLogger("ins_pricing")
-    if not logger.handlers:
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter(
-            "[%(levelname)s][%(name)s] %(message)s"
-        )
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-        # Default to INFO, can be changed via environment variable
-        level = os.environ.get("INS_PRICING_LOG_LEVEL", "INFO").upper()
-        logger.setLevel(getattr(logging, level, logging.INFO))
-    return logger
-
-
-def get_logger(name: str = "ins_pricing") -> logging.Logger:
-    """Get a logger with the given name, inheriting package-level settings.
-
-    Args:
-        name: Logger name, typically module name like 'ins_pricing.trainer'
-
-    Returns:
-        Configured logger instance
-
-    Example:
-        >>> logger = get_logger("ins_pricing.trainer.ft")
-        >>> logger.info("Training started")
-    """
-    _get_package_logger()
-    return logging.getLogger(name)
-
-
-# =============================================================================
-# Metric Computation Factory
-# =============================================================================
-
-class MetricFactory:
-    """Factory for computing evaluation metrics consistently across all trainers.
-
-    This class centralizes metric computation logic that was previously duplicated
-    across FTTrainer, ResNetTrainer, GNNTrainer, XGBTrainer, and GLMTrainer.
-
-    Example:
-        >>> factory = MetricFactory(task_type='regression', tweedie_power=1.5)
-        >>> score = factory.compute(y_true, y_pred, sample_weight)
-    """
-
-    def __init__(
-        self,
-        task_type: str = "regression",
-        tweedie_power: float = 1.5,
-        loss_name: str = "tweedie",
-        clip_min: float = 1e-8,
-        clip_max: float = 1 - 1e-8,
-    ):
-        """Initialize the metric factory.
-
-        Args:
-            task_type: Either 'regression' or 'classification'
-            tweedie_power: Power parameter for Tweedie deviance (1.0-2.0)
-            loss_name: Regression loss name ('tweedie', 'poisson', 'gamma', 'mse', 'mae')
-            clip_min: Minimum value for clipping predictions
-            clip_max: Maximum value for clipping predictions (for classification)
-        """
-        self.task_type = task_type
-        self.tweedie_power = tweedie_power
-        self.loss_name = loss_name
-        self.clip_min = clip_min
-        self.clip_max = clip_max
-
-    def compute(
-        self,
-        y_true: np.ndarray,
-        y_pred: np.ndarray,
-        sample_weight: Optional[np.ndarray] = None,
-    ) -> float:
-        """Compute the appropriate metric based on task type.
-
-        Args:
-            y_true: Ground truth values
-            y_pred: Predicted values
-            sample_weight: Optional sample weights
-
-        Returns:
-            Computed metric value (lower is better)
-        """
-        y_pred = np.asarray(y_pred)
-        y_true = np.asarray(y_true)
-
-        if self.task_type == "classification":
-            y_pred_clipped = np.clip(y_pred, self.clip_min, self.clip_max)
-            return float(log_loss(y_true, y_pred_clipped, sample_weight=sample_weight))
-
-        loss_name = str(self.loss_name or "tweedie").strip().lower()
-        if loss_name in {"mse", "mae"}:
-            if loss_name == "mse":
-                return float(mean_squared_error(
-                    y_true, y_pred, sample_weight=sample_weight))
-            return float(mean_absolute_error(
-                y_true, y_pred, sample_weight=sample_weight))
-
-        y_pred_safe = np.maximum(y_pred, self.clip_min)
-        power = self.tweedie_power
-        if loss_name == "poisson":
-            power = 1.0
-        elif loss_name == "gamma":
-            power = 2.0
-        return float(mean_tweedie_deviance(
-            y_true,
-            y_pred_safe,
-            sample_weight=sample_weight,
-            power=power,
-        ))
-
-    def update_power(self, power: float) -> None:
-        """Update the Tweedie power parameter.
-
-        Args:
-            power: New power value (1.0-2.0)
-        """
-        self.tweedie_power = power
-
-
-# =============================================================================
-# GPU Memory Manager and Device Manager
-# =============================================================================
-# NOTE: These classes are imported from ins_pricing.utils (see top of file)
-# This eliminates ~230 lines of duplicate code while maintaining backward compatibility
-
-
-# =============================================================================
-# Cross-Validation Strategy Resolver
-# =============================================================================
-
-class CVStrategyResolver:
-    """Resolver for cross-validation splitting strategies.
-
-    This class consolidates CV strategy resolution logic that was previously
-    duplicated across trainer_base.py and trainer_ft.py.
-
-    Supported strategies:
-    - 'random': Standard KFold
-    - 'stratified': StratifiedKFold (for classification)
-    - 'group': GroupKFold (requires group column)
-    - 'time': TimeSeriesSplit (requires time column)
-
-    Example:
-        >>> resolver = CVStrategyResolver(
-        ...     strategy='group',
-        ...     n_splits=5,
-        ...     group_col='policy_id',
-        ...     data=train_df,
-        ... )
-        >>> splitter, groups = resolver.get_splitter()
-        >>> for train_idx, val_idx in splitter.split(X, y, groups):
-        ...     pass
-    """
-
-    VALID_STRATEGIES = {"random", "stratified", "group", "grouped", "time", "timeseries", "temporal"}
-
-    def __init__(
-        self,
-        strategy: str = "random",
-        n_splits: int = 5,
-        shuffle: bool = True,
-        random_state: Optional[int] = None,
-        group_col: Optional[str] = None,
-        time_col: Optional[str] = None,
-        time_ascending: bool = True,
-        data: Optional[pd.DataFrame] = None,
-    ):
-        """Initialize the CV strategy resolver.
-
-        Args:
-            strategy: CV strategy name
-            n_splits: Number of CV folds
-            shuffle: Whether to shuffle for random/stratified
-            random_state: Random seed for reproducibility
-            group_col: Column name for group-based splitting
-            time_col: Column name for time-based splitting
-            time_ascending: Sort order for time-based splitting
-            data: DataFrame containing group/time columns
-        """
-        self.strategy = strategy.strip().lower()
-        self.n_splits = max(2, int(n_splits))
-        self.shuffle = shuffle
-        self.random_state = random_state
-        self.group_col = group_col
-        self.time_col = time_col
-        self.time_ascending = time_ascending
-        self.data = data
-
-        if self.strategy not in self.VALID_STRATEGIES:
-            raise ValueError(
-                f"Invalid strategy '{strategy}'. "
-                f"Valid options: {sorted(self.VALID_STRATEGIES)}"
-            )
-
-    def get_splitter(self) -> Tuple[Any, Optional[pd.Series]]:
-        """Get the appropriate splitter and groups.
-
-        Returns:
-            Tuple of (splitter, groups) where groups may be None
-
-        Raises:
-            ValueError: If required columns are missing
-        """
-        if self.strategy in {"group", "grouped"}:
-            return self._get_group_splitter()
-        elif self.strategy in {"time", "timeseries", "temporal"}:
-            return self._get_time_splitter()
-        elif self.strategy == "stratified":
-            return self._get_stratified_splitter()
-        else:
-            return self._get_random_splitter()
-
-    def _get_random_splitter(self) -> Tuple[KFold, None]:
-        """Get a random KFold splitter."""
-        splitter = KFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_stratified_splitter(self) -> Tuple[StratifiedKFold, None]:
-        """Get a stratified KFold splitter."""
-        splitter = StratifiedKFold(
-            n_splits=self.n_splits,
-            shuffle=self.shuffle,
-            random_state=self.random_state if self.shuffle else None,
-        )
-        return splitter, None
-
-    def _get_group_splitter(self) -> Tuple[GroupKFold, pd.Series]:
-        """Get a group-based KFold splitter."""
-        if not self.group_col:
-            raise ValueError("group_col is required for group strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for group strategy")
-        if self.group_col not in self.data.columns:
-            raise KeyError(f"group_col '{self.group_col}' not found in data")
-
-        groups = self.data[self.group_col]
-        splitter = GroupKFold(n_splits=self.n_splits)
-        return splitter, groups
-
-    def _get_time_splitter(self) -> Tuple[Any, None]:
-        """Get a time-series splitter."""
-        if not self.time_col:
-            raise ValueError("time_col is required for time strategy")
-        if self.data is None:
-            raise ValueError("data DataFrame is required for time strategy")
-        if self.time_col not in self.data.columns:
-            raise KeyError(f"time_col '{self.time_col}' not found in data")
-
-        splitter = TimeSeriesSplit(n_splits=self.n_splits)
-
-        # Create an ordered wrapper that sorts by time column
-        order_index = self.data[self.time_col].sort_values(
-            ascending=self.time_ascending
-        ).index
-        order = self.data.index.get_indexer(order_index)
-
-        return _OrderedSplitter(splitter, order), None
-
-
-class _OrderedSplitter:
-    """Wrapper for splitters that need to respect a specific ordering."""
-
-    def __init__(self, base_splitter, order: np.ndarray):
-        self.base_splitter = base_splitter
-        self.order = order
-
-    def split(self, X, y=None, groups=None):
-        """Split with ordering applied."""
-        n = len(X)
-        X_ordered = np.arange(n)[self.order]
-        for train_idx, val_idx in self.base_splitter.split(X_ordered):
-            yield self.order[train_idx], self.order[val_idx]
-
-    def get_n_splits(self, X=None, y=None, groups=None):
-        return self.base_splitter.get_n_splits()
-
-
-# =============================================================================
-# Plot Utils
-# =============================================================================
-
-def _plot_skip(label: str) -> None:
-    """Print message when plot is skipped due to missing matplotlib."""
-    if _MPL_IMPORT_ERROR is not None:
-        print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
-    else:
-        print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
-
-
-class PlotUtils:
-    """Plotting utilities for lift charts."""
-
-    @staticmethod
-    def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
-        """Split data into bins by cumulative weight."""
-        data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
-        data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
-        w_sum = data_sorted[wgt_nme].sum()
-        if w_sum <= EPS:
-            data_sorted['bins'] = 0
-        else:
-            data_sorted['bins'] = np.floor(
-                data_sorted['cum_weight'] * float(n_bins) / w_sum
-            )
-        data_sorted.loc[(data_sorted['bins'] == n_bins),
-                        'bins'] = n_bins - 1
-        return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
-
-    @staticmethod
-    def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
-        """Plot lift chart on given axes."""
-        ax.plot(plot_data.index, plot_data['act_v'],
-                label=act_label, color='red')
-        ax.plot(plot_data.index, plot_data['exp_v'],
-                label=pred_label, color='blue')
-        ax.set_title(title, fontsize=8)
-        ax.set_xticks(plot_data.index)
-        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-        ax.tick_params(axis='y', labelsize=6)
-        ax.legend(loc='upper left', fontsize=5, frameon=False)
-        ax.margins(0.05)
-        ax2 = ax.twinx()
-        ax2.bar(plot_data.index, plot_data['weight'],
-                alpha=0.5, color='seagreen',
-                label=weight_label)
-        ax2.tick_params(axis='y', labelsize=6)
-        ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-    @staticmethod
-    def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
-        """Plot double lift chart on given axes."""
-        ax.plot(plot_data.index, plot_data['act_v'],
-                label=act_label, color='red')
-        ax.plot(plot_data.index, plot_data['exp_v1'],
-                label=label1, color='blue')
-        ax.plot(plot_data.index, plot_data['exp_v2'],
-                label=label2, color='black')
-        ax.set_title(title, fontsize=8)
-        ax.set_xticks(plot_data.index)
-        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
-        ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
-        ax.tick_params(axis='y', labelsize=6)
-        ax.legend(loc='upper left', fontsize=5, frameon=False)
-        ax.margins(0.1)
-        ax2 = ax.twinx()
-        ax2.bar(plot_data.index, plot_data['weight'],
-                alpha=0.5, color='seagreen',
-                label=weight_label)
-        ax2.tick_params(axis='y', labelsize=6)
-        ax2.legend(loc='upper right', fontsize=5, frameon=False)
-
-    @staticmethod
-    def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                       weight_list, tgt_nme, n_bins: int = 10,
-                       fig_nme: str = 'Lift Chart'):
-        """Plot lift chart for model predictions."""
-        if plot_curves_common is not None:
-            save_path = os.path.join(
-                os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-            plot_curves_common.plot_lift_curve(
-                pred_model,
-                w_act_list,
-                weight_list,
-                n_bins=n_bins,
-                title=f'Lift Chart of {tgt_nme}',
-                pred_label='Predicted',
-                act_label='Actual',
-                weight_label='Earned Exposure',
-                pred_weighted=False,
-                actual_weighted=True,
-                save_path=save_path,
-                show=False,
-            )
-            return
-        if plt is None:
-            _plot_skip("lift plot")
-            return
-        lift_data = pd.DataFrame({
-            'pred': pred_model,
-            'w_pred': w_pred_list,
-            'act': w_act_list,
-            'weight': weight_list
-        })
-        plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
-        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
-        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
-        plot_data.reset_index(inplace=True)
-
-        fig = plt.figure(figsize=(7, 5))
-        ax = fig.add_subplot(111)
-        PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
-        plt.subplots_adjust(wspace=0.3)
-
-        save_path = os.path.join(
-            os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
-        IOUtils.ensure_parent_dir(save_path)
-        plt.savefig(save_path, dpi=300)
-        plt.close(fig)
-
-    @staticmethod
-    def plot_dlift_list(pred_model_1, pred_model_2,
-                        model_nme_1, model_nme_2,
-                        tgt_nme,
-                        w_list, w_act_list, n_bins: int = 10,
-                        fig_nme: str = 'Double Lift Chart'):
-        """Plot double lift chart comparing two models."""
-        if plot_curves_common is not None:
-            save_path = os.path.join(
-                os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-            plot_curves_common.plot_double_lift_curve(
-                pred_model_1,
-                pred_model_2,
-                w_act_list,
-                w_list,
-                n_bins=n_bins,
-                title=f'Double Lift Chart of {tgt_nme}',
-                label1=model_nme_1,
-                label2=model_nme_2,
-                pred1_weighted=False,
-                pred2_weighted=False,
-                actual_weighted=True,
-                save_path=save_path,
-                show=False,
-            )
-            return
-        if plt is None:
-            _plot_skip("double lift plot")
-            return
-        lift_data = pd.DataFrame({
-            'pred1': pred_model_1,
-            'pred2': pred_model_2,
-            'act': w_act_list,
-            'weight': w_list
-        })
-        lift_data['diff_ly'] = lift_data['pred1'] / lift_data['pred2']
-        lift_data['w_pred1'] = lift_data['pred1'] * lift_data['weight']
-        lift_data['w_pred2'] = lift_data['pred2'] * lift_data['weight']
-        plot_data = PlotUtils.split_data(
-            lift_data, 'diff_ly', 'weight', n_bins)
-        plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
-        plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
-        plot_data['act_v'] = plot_data['act']/plot_data['act']
-        plot_data.reset_index(inplace=True)
-
-        fig = plt.figure(figsize=(7, 5))
-        ax = fig.add_subplot(111)
-        PlotUtils.plot_dlift_ax(
-            ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
-        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
-
-        save_path = os.path.join(
-            os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
-        IOUtils.ensure_parent_dir(save_path)
-        plt.savefig(save_path, dpi=300)
-        plt.close(fig)
-
-
-# =============================================================================
-# Backward Compatibility Wrappers
-# =============================================================================
-
-def split_data(data, col_nme, wgt_nme, n_bins=10):
-    """Legacy function wrapper for PlotUtils.split_data()."""
-    return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
-
-
-def plot_lift_list(pred_model, w_pred_list, w_act_list,
-                   weight_list, tgt_nme, n_bins=10,
-                   fig_nme='Lift Chart'):
-    """Legacy function wrapper for PlotUtils.plot_lift_list()."""
-    return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
-                                    weight_list, tgt_nme, n_bins, fig_nme)
-
-
-def plot_dlift_list(pred_model_1, pred_model_2,
-                    model_nme_1, model_nme_2,
-                    tgt_nme,
-                    w_list, w_act_list, n_bins=10,
-                    fig_nme='Double Lift Chart'):
-    """Legacy function wrapper for PlotUtils.plot_dlift_list()."""
-    return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
-                                     model_nme_1, model_nme_2,
-                                     tgt_nme, w_list, w_act_list,
-                                     n_bins, fig_nme)
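Per its docstring, most of what this deleted module centralized now lives elsewhere: metric computation appears to move to ins_pricing/utils/metrics.py (+158 in the file list) and logging to ins_pricing/utils/logging.py. For reference, a standalone sketch of the loss-to-power dispatch rule in the deleted MetricFactory.compute, written against sklearn directly; it mirrors the removed code and is not the 0.5.1 replacement API:

import numpy as np
from sklearn.metrics import mean_tweedie_deviance

def tweedie_metric(y_true, y_pred, loss_name="tweedie", tweedie_power=1.5,
                   clip_min=1e-8, sample_weight=None):
    # Named losses map onto Tweedie powers, as the deleted compute() did:
    # 'poisson' -> 1.0, 'gamma' -> 2.0, anything else -> tweedie_power.
    power = {"poisson": 1.0, "gamma": 2.0}.get(loss_name, tweedie_power)
    # Clip predictions away from zero before computing the deviance.
    y_pred_safe = np.maximum(np.asarray(y_pred, dtype=float), clip_min)
    return float(mean_tweedie_deviance(
        np.asarray(y_true, dtype=float), y_pred_safe,
        sample_weight=sample_weight, power=power))

y_true = np.array([0.0, 2.0, 1.0])
y_pred = np.array([0.5, 1.5, 1.0])
print(tweedie_metric(y_true, y_pred, loss_name="poisson"))  # Poisson deviance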