ins-pricing 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +60 -0
- ins_pricing/__init__.py +102 -0
- ins_pricing/governance/README.md +18 -0
- ins_pricing/governance/__init__.py +20 -0
- ins_pricing/governance/approval.py +93 -0
- ins_pricing/governance/audit.py +37 -0
- ins_pricing/governance/registry.py +99 -0
- ins_pricing/governance/release.py +159 -0
- ins_pricing/modelling/BayesOpt.py +146 -0
- ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
- ins_pricing/modelling/BayesOpt_entry.py +575 -0
- ins_pricing/modelling/BayesOpt_incremental.py +731 -0
- ins_pricing/modelling/Explain_Run.py +36 -0
- ins_pricing/modelling/Explain_entry.py +539 -0
- ins_pricing/modelling/Pricing_Run.py +36 -0
- ins_pricing/modelling/README.md +33 -0
- ins_pricing/modelling/__init__.py +44 -0
- ins_pricing/modelling/bayesopt/__init__.py +98 -0
- ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
- ins_pricing/modelling/bayesopt/core.py +1476 -0
- ins_pricing/modelling/bayesopt/models.py +2196 -0
- ins_pricing/modelling/bayesopt/trainers.py +2446 -0
- ins_pricing/modelling/bayesopt/utils.py +1021 -0
- ins_pricing/modelling/cli_common.py +136 -0
- ins_pricing/modelling/explain/__init__.py +55 -0
- ins_pricing/modelling/explain/gradients.py +334 -0
- ins_pricing/modelling/explain/metrics.py +176 -0
- ins_pricing/modelling/explain/permutation.py +155 -0
- ins_pricing/modelling/explain/shap_utils.py +146 -0
- ins_pricing/modelling/notebook_utils.py +284 -0
- ins_pricing/modelling/plotting/__init__.py +45 -0
- ins_pricing/modelling/plotting/common.py +63 -0
- ins_pricing/modelling/plotting/curves.py +572 -0
- ins_pricing/modelling/plotting/diagnostics.py +139 -0
- ins_pricing/modelling/plotting/geo.py +362 -0
- ins_pricing/modelling/plotting/importance.py +121 -0
- ins_pricing/modelling/run_logging.py +133 -0
- ins_pricing/modelling/tests/conftest.py +8 -0
- ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing/modelling/tests/test_explain.py +56 -0
- ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing/modelling/tests/test_plotting.py +63 -0
- ins_pricing/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing/modelling/watchdog_run.py +211 -0
- ins_pricing/pricing/README.md +44 -0
- ins_pricing/pricing/__init__.py +27 -0
- ins_pricing/pricing/calibration.py +39 -0
- ins_pricing/pricing/data_quality.py +117 -0
- ins_pricing/pricing/exposure.py +85 -0
- ins_pricing/pricing/factors.py +91 -0
- ins_pricing/pricing/monitoring.py +99 -0
- ins_pricing/pricing/rate_table.py +78 -0
- ins_pricing/production/__init__.py +21 -0
- ins_pricing/production/drift.py +30 -0
- ins_pricing/production/monitoring.py +143 -0
- ins_pricing/production/scoring.py +40 -0
- ins_pricing/reporting/README.md +20 -0
- ins_pricing/reporting/__init__.py +11 -0
- ins_pricing/reporting/report_builder.py +72 -0
- ins_pricing/reporting/scheduler.py +45 -0
- ins_pricing/setup.py +41 -0
- ins_pricing v2/__init__.py +23 -0
- ins_pricing v2/governance/__init__.py +20 -0
- ins_pricing v2/governance/approval.py +93 -0
- ins_pricing v2/governance/audit.py +37 -0
- ins_pricing v2/governance/registry.py +99 -0
- ins_pricing v2/governance/release.py +159 -0
- ins_pricing v2/modelling/Explain_Run.py +36 -0
- ins_pricing v2/modelling/Pricing_Run.py +36 -0
- ins_pricing v2/modelling/__init__.py +151 -0
- ins_pricing v2/modelling/cli_common.py +141 -0
- ins_pricing v2/modelling/config.py +249 -0
- ins_pricing v2/modelling/config_preprocess.py +254 -0
- ins_pricing v2/modelling/core.py +741 -0
- ins_pricing v2/modelling/data_container.py +42 -0
- ins_pricing v2/modelling/explain/__init__.py +55 -0
- ins_pricing v2/modelling/explain/gradients.py +334 -0
- ins_pricing v2/modelling/explain/metrics.py +176 -0
- ins_pricing v2/modelling/explain/permutation.py +155 -0
- ins_pricing v2/modelling/explain/shap_utils.py +146 -0
- ins_pricing v2/modelling/features.py +215 -0
- ins_pricing v2/modelling/model_manager.py +148 -0
- ins_pricing v2/modelling/model_plotting.py +463 -0
- ins_pricing v2/modelling/models.py +2203 -0
- ins_pricing v2/modelling/notebook_utils.py +294 -0
- ins_pricing v2/modelling/plotting/__init__.py +45 -0
- ins_pricing v2/modelling/plotting/common.py +63 -0
- ins_pricing v2/modelling/plotting/curves.py +572 -0
- ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
- ins_pricing v2/modelling/plotting/geo.py +362 -0
- ins_pricing v2/modelling/plotting/importance.py +121 -0
- ins_pricing v2/modelling/run_logging.py +133 -0
- ins_pricing v2/modelling/tests/conftest.py +8 -0
- ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing v2/modelling/tests/test_explain.py +56 -0
- ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing v2/modelling/tests/test_plotting.py +63 -0
- ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing v2/modelling/trainers.py +2447 -0
- ins_pricing v2/modelling/utils.py +1020 -0
- ins_pricing v2/modelling/watchdog_run.py +211 -0
- ins_pricing v2/pricing/__init__.py +27 -0
- ins_pricing v2/pricing/calibration.py +39 -0
- ins_pricing v2/pricing/data_quality.py +117 -0
- ins_pricing v2/pricing/exposure.py +85 -0
- ins_pricing v2/pricing/factors.py +91 -0
- ins_pricing v2/pricing/monitoring.py +99 -0
- ins_pricing v2/pricing/rate_table.py +78 -0
- ins_pricing v2/production/__init__.py +21 -0
- ins_pricing v2/production/drift.py +30 -0
- ins_pricing v2/production/monitoring.py +143 -0
- ins_pricing v2/production/scoring.py +40 -0
- ins_pricing v2/reporting/__init__.py +11 -0
- ins_pricing v2/reporting/report_builder.py +72 -0
- ins_pricing v2/reporting/scheduler.py +45 -0
- ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
- ins_pricing v2/scripts/Explain_entry.py +545 -0
- ins_pricing v2/scripts/__init__.py +1 -0
- ins_pricing v2/scripts/train.py +568 -0
- ins_pricing v2/setup.py +55 -0
- ins_pricing v2/smoke_test.py +28 -0
- ins_pricing-0.1.6.dist-info/METADATA +78 -0
- ins_pricing-0.1.6.dist-info/RECORD +169 -0
- ins_pricing-0.1.6.dist-info/WHEEL +5 -0
- ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
- user_packages/__init__.py +105 -0
- user_packages legacy/BayesOpt.py +5659 -0
- user_packages legacy/BayesOpt_entry.py +513 -0
- user_packages legacy/BayesOpt_incremental.py +685 -0
- user_packages legacy/Pricing_Run.py +36 -0
- user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
- user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
- user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
- user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
- user_packages legacy/Try/BayesOpt legacy.py +3280 -0
- user_packages legacy/Try/BayesOpt.py +838 -0
- user_packages legacy/Try/BayesOptAll.py +1569 -0
- user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
- user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
- user_packages legacy/Try/BayesOptSearch.py +830 -0
- user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
- user_packages legacy/Try/BayesOptV1.py +1911 -0
- user_packages legacy/Try/BayesOptV10.py +2973 -0
- user_packages legacy/Try/BayesOptV11.py +3001 -0
- user_packages legacy/Try/BayesOptV12.py +3001 -0
- user_packages legacy/Try/BayesOptV2.py +2065 -0
- user_packages legacy/Try/BayesOptV3.py +2209 -0
- user_packages legacy/Try/BayesOptV4.py +2342 -0
- user_packages legacy/Try/BayesOptV5.py +2372 -0
- user_packages legacy/Try/BayesOptV6.py +2759 -0
- user_packages legacy/Try/BayesOptV7.py +2832 -0
- user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
- user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
- user_packages legacy/Try/BayesOptV9.py +2927 -0
- user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
- user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
- user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
- user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
- user_packages legacy/Try/xgbbayesopt.py +523 -0
- user_packages legacy/__init__.py +19 -0
- user_packages legacy/cli_common.py +124 -0
- user_packages legacy/notebook_utils.py +228 -0
- user_packages legacy/watchdog_run.py +202 -0
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Mapping, Optional, Sequence, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from .common import EPS, PlotStyle, finalize_figure, plt
|
|
9
|
+
|
|
10
|
+
try: # optional dependency guard
|
|
11
|
+
from sklearn.metrics import (
|
|
12
|
+
auc,
|
|
13
|
+
average_precision_score,
|
|
14
|
+
precision_recall_curve,
|
|
15
|
+
roc_curve,
|
|
16
|
+
)
|
|
17
|
+
from sklearn.calibration import calibration_curve
|
|
18
|
+
except Exception: # pragma: no cover - handled at call time
|
|
19
|
+
auc = None
|
|
20
|
+
average_precision_score = None
|
|
21
|
+
precision_recall_curve = None
|
|
22
|
+
roc_curve = None
|
|
23
|
+
calibration_curve = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _require_sklearn(func_name: str) -> None:
    """Fail fast when the optional scikit-learn import did not succeed."""
    sklearn_missing = roc_curve is None or auc is None
    if sklearn_missing:
        raise RuntimeError(f"{func_name} requires scikit-learn to be installed.")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _to_1d(values: Sequence[float], name: str) -> np.ndarray:
|
|
32
|
+
arr = np.asarray(values, dtype=float).reshape(-1)
|
|
33
|
+
if arr.size == 0:
|
|
34
|
+
raise ValueError(f"{name} is empty.")
|
|
35
|
+
return arr
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _align_arrays(
|
|
39
|
+
pred: Sequence[float],
|
|
40
|
+
actual: Sequence[float],
|
|
41
|
+
weight: Optional[Sequence[float]] = None,
|
|
42
|
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
43
|
+
pred_arr = _to_1d(pred, "pred")
|
|
44
|
+
actual_arr = _to_1d(actual, "actual")
|
|
45
|
+
if len(pred_arr) != len(actual_arr):
|
|
46
|
+
raise ValueError("pred and actual must have the same length.")
|
|
47
|
+
if weight is None:
|
|
48
|
+
weight_arr = np.ones_like(pred_arr, dtype=float)
|
|
49
|
+
else:
|
|
50
|
+
weight_arr = _to_1d(weight, "weight")
|
|
51
|
+
if len(weight_arr) != len(pred_arr):
|
|
52
|
+
raise ValueError("weight must have the same length as pred.")
|
|
53
|
+
|
|
54
|
+
mask = np.isfinite(pred_arr) & np.isfinite(actual_arr) & np.isfinite(weight_arr)
|
|
55
|
+
pred_arr = pred_arr[mask]
|
|
56
|
+
actual_arr = actual_arr[mask]
|
|
57
|
+
weight_arr = weight_arr[mask]
|
|
58
|
+
return pred_arr, actual_arr, weight_arr
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _bin_by_weight(
    data: pd.DataFrame,
    *,
    sort_col: str,
    weight_col: str,
    n_bins: int,
) -> pd.DataFrame:
    """Split rows into roughly equal-weight bins after sorting by *sort_col*.

    Returns the per-bin sums of all numeric columns, indexed by bin id.
    When the total weight is effectively zero everything collapses into
    a single bin 0.
    """
    bins = max(1, int(n_bins))
    ordered = data.sort_values(by=sort_col, ascending=True).copy()
    total = float(ordered[weight_col].sum())
    if total <= EPS:
        ordered.loc[:, "bins"] = 0
    else:
        ordered.loc[:, "cum_weight"] = ordered[weight_col].cumsum()
        ordered.loc[:, "bins"] = np.floor(
            ordered["cum_weight"] * float(bins) / total
        )
        # The last row lands exactly on the upper edge; fold it into the
        # final bin instead of creating a one-row extra bin.
        ordered.loc[ordered["bins"] == bins, "bins"] = bins - 1
    return ordered.groupby(["bins"], observed=True).sum(numeric_only=True)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def lift_table(
    pred: Sequence[float],
    actual: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    *,
    n_bins: int = 10,
    pred_weighted: bool = False,
    actual_weighted: bool = True,
) -> pd.DataFrame:
    """Build an equal-weight lift table for a single model.

    pred/actual should be 1d arrays. If pred_weighted/actual_weighted is True,
    the value is already multiplied by weight and will not be re-weighted.
    Returns one row per bin with weighted-average columns ``exp_v``
    (predicted) and ``act_v`` (actual).
    """
    pred_arr, actual_arr, weight_arr = _align_arrays(pred, actual, weight)
    weight_safe = np.maximum(weight_arr, EPS)

    # Sort key is always the unweighted prediction; the summed columns
    # are the weighted ones.
    pred_raw = pred_arr / weight_safe if pred_weighted else pred_arr
    w_pred = pred_arr if pred_weighted else pred_arr * weight_arr
    w_act = actual_arr if actual_weighted else actual_arr * weight_arr

    lift_df = pd.DataFrame(
        {
            "pred_sort": pred_raw,
            "w_pred": w_pred,
            "act": w_act,
            "weight": weight_arr,
        }
    )
    plot_data = _bin_by_weight(
        lift_df, sort_col="pred_sort", weight_col="weight", n_bins=n_bins
    )
    denom = np.maximum(plot_data["weight"], EPS)
    plot_data["exp_v"] = plot_data["w_pred"] / denom
    plot_data["act_v"] = plot_data["act"] / denom
    plot_data.reset_index(inplace=True)
    return plot_data
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def plot_lift_curve(
    pred: Sequence[float],
    actual: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    *,
    n_bins: int = 10,
    title: str = "Lift Chart",
    pred_label: str = "Predicted",
    act_label: str = "Actual",
    weight_label: str = "Earned Exposure",
    pred_weighted: bool = False,
    actual_weighted: bool = True,
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot a lift chart: per-bin actual vs. predicted plus a weight bar.

    Bins come from :func:`lift_table` (equal-weight bins sorted by the
    prediction).  When *ax* is supplied the chart is drawn onto it and
    the figure is NOT finalized; otherwise a new figure is created and
    handed to ``finalize_figure`` (which handles save/show).
    """
    style = style or PlotStyle()
    plot_data = lift_table(
        pred,
        actual,
        weight,
        n_bins=n_bins,
        pred_weighted=pred_weighted,
        actual_weighted=actual_weighted,
    )

    # Only finalize (save/show) figures this function created itself.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.plot(plot_data.index, plot_data["act_v"], label=act_label, color="red")
    ax.plot(plot_data.index, plot_data["exp_v"], label=pred_label, color="blue")
    ax.set_title(title, fontsize=style.title_size)
    ax.set_xticks(plot_data.index)
    ax.set_xticklabels(plot_data.index, rotation=90, fontsize=style.tick_size)
    ax.tick_params(axis="y", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="upper left", fontsize=style.legend_size, frameon=False)
    ax.margins(0.05)

    # Per-bin weight as bars on a twin axis so the scales stay independent.
    ax2 = ax.twinx()
    ax2.bar(
        plot_data.index,
        plot_data["weight"],
        alpha=0.5,
        color=style.weight_color,
        label=weight_label,
    )
    ax2.tick_params(axis="y", labelsize=style.tick_size)
    ax2.legend(loc="upper right", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def double_lift_table(
    pred1: Sequence[float],
    pred2: Sequence[float],
    actual: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    *,
    n_bins: int = 10,
    pred1_weighted: bool = False,
    pred2_weighted: bool = False,
    actual_weighted: bool = True,
) -> pd.DataFrame:
    """Build a double-lift table comparing two models.

    Rows are sorted by the ratio ``pred1 / pred2`` and grouped into
    ``n_bins`` equal-weight bins.  Per-bin predictions are normalised by
    the binned actuals, so ``act_v`` is identically 1 and serves as the
    baseline while ``exp_v1``/``exp_v2`` show each model's ratio to actual.

    If a ``*_weighted`` flag is True, that series is assumed to be already
    multiplied by weight and is divided back out before sorting.
    Raises ``ValueError`` on empty or length-mismatched inputs.
    """
    pred1_arr = _to_1d(pred1, "pred1")
    pred2_arr = _to_1d(pred2, "pred2")
    actual_arr = _to_1d(actual, "actual")
    if not (len(pred1_arr) == len(pred2_arr) == len(actual_arr)):
        raise ValueError("pred1, pred2 and actual must have the same length.")
    if weight is None:
        weight_arr = np.ones_like(pred1_arr, dtype=float)
    else:
        weight_arr = _to_1d(weight, "weight")
        if len(weight_arr) != len(pred1_arr):
            raise ValueError("weight must have the same length as pred.")

    # Drop non-finite rows with ONE joint mask.  Filtering pred1 and pred2
    # through separate _align_arrays calls (as before) could produce
    # different masks and therefore arrays of different lengths.
    mask = (
        np.isfinite(pred1_arr)
        & np.isfinite(pred2_arr)
        & np.isfinite(actual_arr)
        & np.isfinite(weight_arr)
    )
    pred1_arr = pred1_arr[mask]
    pred2_arr = pred2_arr[mask]
    actual_arr = actual_arr[mask]
    weight_arr = weight_arr[mask]

    weight_safe = np.maximum(weight_arr, EPS)
    pred1_raw = pred1_arr / weight_safe if pred1_weighted else pred1_arr
    pred2_raw = pred2_arr / weight_safe if pred2_weighted else pred2_arr

    w_pred1 = pred1_raw * weight_arr
    w_pred2 = pred2_raw * weight_arr
    w_act = actual_arr if actual_weighted else actual_arr * weight_arr

    lift_df = pd.DataFrame(
        {
            # Sort key: model ratio, guarded against division by ~0.
            "diff_ly": pred1_raw / np.maximum(pred2_raw, EPS),
            "pred1": w_pred1,
            "pred2": w_pred2,
            "act": w_act,
            "weight": weight_arr,
        }
    )
    plot_data = _bin_by_weight(
        lift_df, sort_col="diff_ly", weight_col="weight", n_bins=n_bins
    )
    denom = np.maximum(plot_data["act"], EPS)
    plot_data["exp_v1"] = plot_data["pred1"] / denom
    plot_data["exp_v2"] = plot_data["pred2"] / denom
    plot_data["act_v"] = plot_data["act"] / denom
    plot_data.reset_index(inplace=True)
    return plot_data
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def plot_double_lift_curve(
    pred1: Sequence[float],
    pred2: Sequence[float],
    actual: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    *,
    n_bins: int = 10,
    title: str = "Double Lift Chart",
    label1: str = "Model 1",
    label2: str = "Model 2",
    act_label: str = "Actual",
    weight_label: str = "Earned Exposure",
    pred1_weighted: bool = False,
    pred2_weighted: bool = False,
    actual_weighted: bool = True,
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot a double lift chart comparing two models.

    Bins come from :func:`double_lift_table` (equal-weight bins sorted by
    the pred1/pred2 ratio); values are normalised by binned actuals, so
    the "Actual" line is the baseline.  When *ax* is supplied the chart
    is drawn onto it and not finalized; otherwise a new figure is created
    and handed to ``finalize_figure``.
    """
    style = style or PlotStyle()
    plot_data = double_lift_table(
        pred1,
        pred2,
        actual,
        weight,
        n_bins=n_bins,
        pred1_weighted=pred1_weighted,
        pred2_weighted=pred2_weighted,
        actual_weighted=actual_weighted,
    )

    # Only finalize (save/show) figures this function created itself.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.plot(plot_data.index, plot_data["act_v"], label=act_label, color="red")
    ax.plot(plot_data.index, plot_data["exp_v1"], label=label1, color="blue")
    ax.plot(plot_data.index, plot_data["exp_v2"], label=label2, color="black")
    ax.set_title(title, fontsize=style.title_size)
    ax.set_xticks(plot_data.index)
    ax.set_xticklabels(plot_data.index, rotation=90, fontsize=style.tick_size)
    ax.set_xlabel(f"{label1} / {label2}", fontsize=style.label_size)
    ax.tick_params(axis="y", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="upper left", fontsize=style.legend_size, frameon=False)
    ax.margins(0.1)

    # Per-bin weight as bars on a twin axis so the scales stay independent.
    ax2 = ax.twinx()
    ax2.bar(
        plot_data.index,
        plot_data["weight"],
        alpha=0.5,
        color=style.weight_color,
        label=weight_label,
    )
    ax2.tick_params(axis="y", labelsize=style.tick_size)
    ax2.legend(loc="upper right", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def plot_roc_curves(
    y_true: Sequence[float],
    scores: Mapping[str, Sequence[float]],
    *,
    weight: Optional[Sequence[float]] = None,
    title: str = "ROC Curve",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot one ROC curve per entry of *scores*, annotated with AUC.

    ``sample_weight`` is forwarded to scikit-learn when supported; older
    versions that reject the keyword fall back to unweighted curves.
    """
    _require_sklearn("plot_roc_curves")
    style = style or PlotStyle()

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    for series_idx, (label, score) in enumerate(scores.items()):
        score_arr, true_arr, wgt_arr = _align_arrays(score, y_true, weight)
        try:
            fpr, tpr, _ = roc_curve(true_arr, score_arr, sample_weight=wgt_arr)
        except TypeError:
            # sklearn version without sample_weight support for roc_curve.
            fpr, tpr, _ = roc_curve(true_arr, score_arr)
        auc_val = auc(fpr, tpr)
        series_color = style.palette[series_idx % len(style.palette)]
        ax.plot(fpr, tpr, color=series_color, label=f"{label} (AUC={auc_val:.3f})")

    # Chance diagonal for reference.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    ax.set_xlabel("False Positive Rate", fontsize=style.label_size)
    ax.set_ylabel("True Positive Rate", fontsize=style.label_size)
    ax.set_title(title, fontsize=style.title_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="lower right", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def plot_pr_curves(
    y_true: Sequence[float],
    scores: Mapping[str, Sequence[float]],
    *,
    weight: Optional[Sequence[float]] = None,
    title: str = "Precision-Recall Curve",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot one precision-recall curve per entry of *scores*, annotated with AP.

    ``sample_weight`` is forwarded to scikit-learn when supported; versions
    that reject the keyword fall back to unweighted curves.
    """
    if precision_recall_curve is None or average_precision_score is None:
        raise RuntimeError("plot_pr_curves requires scikit-learn to be installed.")
    style = style or PlotStyle()

    # Only finalize (save/show) figures this function created itself.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    for idx, (label, score) in enumerate(scores.items()):
        s_arr, y_arr, w_arr = _align_arrays(score, y_true, weight)
        try:
            precision, recall, _ = precision_recall_curve(
                y_arr, s_arr, sample_weight=w_arr
            )
            ap = average_precision_score(y_arr, s_arr, sample_weight=w_arr)
        except TypeError:
            # Older sklearn without sample_weight support here.
            precision, recall, _ = precision_recall_curve(y_arr, s_arr)
            ap = average_precision_score(y_arr, s_arr)
        color = style.palette[idx % len(style.palette)]
        ax.plot(recall, precision, color=color, label=f"{label} (AP={ap:.3f})")

    ax.set_xlabel("Recall", fontsize=style.label_size)
    ax.set_ylabel("Precision", fontsize=style.label_size)
    ax.set_title(title, fontsize=style.title_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="lower left", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def plot_ks_curve(
    y_true: Sequence[float],
    score: Sequence[float],
    *,
    weight: Optional[Sequence[float]] = None,
    title: str = "KS Curve",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot TPR, FPR, and their difference (KS statistic) against threshold.

    ``sample_weight`` is forwarded to scikit-learn when supported; versions
    that reject the keyword fall back to unweighted curves.
    """
    _require_sklearn("plot_ks_curve")
    style = style or PlotStyle()

    s_arr, y_arr, w_arr = _align_arrays(score, y_true, weight)
    try:
        fpr, tpr, thresholds = roc_curve(y_arr, s_arr, sample_weight=w_arr)
    except TypeError:
        fpr, tpr, thresholds = roc_curve(y_arr, s_arr)

    # scikit-learn >= 1.2 prepends thresholds[0] = np.inf so the curve
    # starts at (0, 0); an infinite x-value distorts the threshold axis,
    # so drop any non-finite threshold points before plotting.
    finite = np.isfinite(thresholds)
    if not finite.all() and finite.any():
        fpr = fpr[finite]
        tpr = tpr[finite]
        thresholds = thresholds[finite]

    ks_vals = tpr - fpr
    ks_idx = int(np.argmax(ks_vals))
    ks_val = float(ks_vals[ks_idx])

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.plot(thresholds, tpr, label="TPR", color=style.palette[0])
    ax.plot(thresholds, fpr, label="FPR", color=style.palette[1])
    ax.plot(thresholds, ks_vals, label=f"KS={ks_val:.3f}", color=style.palette[3])
    ax.set_title(title, fontsize=style.title_size)
    ax.set_xlabel("Threshold", fontsize=style.label_size)
    ax.set_ylabel("Rate", fontsize=style.label_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="best", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def plot_calibration_curve(
    y_true: Sequence[float],
    score: Sequence[float],
    *,
    weight: Optional[Sequence[float]] = None,
    n_bins: int = 10,
    title: str = "Calibration Curve",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot mean observed vs. mean predicted per quantile bin.

    Uses sklearn's ``calibration_curve`` with quantile binning (at least
    two bins).  ``sample_weight`` is forwarded when the installed sklearn
    supports it; otherwise the curve is computed unweighted.
    """
    if calibration_curve is None:
        raise RuntimeError("plot_calibration_curve requires scikit-learn to be installed.")
    style = style or PlotStyle()

    s_arr, y_arr, w_arr = _align_arrays(score, y_true, weight)
    try:
        prob_true, prob_pred = calibration_curve(
            y_arr,
            s_arr,
            n_bins=max(2, int(n_bins)),
            strategy="quantile",
            sample_weight=w_arr,
        )
    except TypeError:
        # Older sklearn: calibration_curve has no sample_weight parameter.
        prob_true, prob_pred = calibration_curve(
            y_arr,
            s_arr,
            n_bins=max(2, int(n_bins)),
            strategy="quantile",
        )

    # Only finalize (save/show) figures this function created itself.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.plot(prob_pred, prob_true, marker="o", label="Observed")
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1, label="Ideal")
    ax.set_xlabel("Mean Predicted", fontsize=style.label_size)
    ax.set_ylabel("Mean Observed", fontsize=style.label_size)
    ax.set_title(title, fontsize=style.title_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="best", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def plot_conversion_lift(
    pred: Sequence[float],
    actual_binary: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    *,
    n_bins: int = 20,
    title: str = "Conversion Lift",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> plt.Figure:
    """Plot weighted conversion rate per score bin against the overall rate.

    Rows are sorted by score and cut into roughly equal-weight bins via
    the cumulative weight; *actual_binary* is presumably a 0/1 outcome
    indicator (not enforced here).  When *ax* is supplied the chart is
    drawn onto it and not finalized; otherwise a new figure is created
    and handed to ``finalize_figure``.
    """
    style = style or PlotStyle()
    pred_arr, actual_arr, weight_arr = _align_arrays(pred, actual_binary, weight)

    data = pd.DataFrame(
        {
            "pred": pred_arr,
            "actual": actual_arr,
            "weight": weight_arr,
        }
    )
    data = data.sort_values(by="pred", ascending=True).copy()
    data["cum_weight"] = data["weight"].cumsum()
    total_weight = float(data["weight"].sum())

    if total_weight > EPS:
        # Equal-width cut on cumulative weight == roughly equal-weight bins.
        data["bin"] = pd.cut(
            data["cum_weight"],
            bins=max(2, int(n_bins)),
            labels=False,
            right=False,
        )
    else:
        # Degenerate case: no weight at all collapses into a single bin.
        data["bin"] = 0

    data["weighted_actual"] = data["actual"] * data["weight"]
    lift_agg = data.groupby("bin", observed=True).agg(
        total_weight=("weight", "sum"),
        weighted_actual=("weighted_actual", "sum"),
    )
    lift_agg = lift_agg.reset_index()
    # Guard the per-bin denominator against empty/zero-weight bins.
    lift_agg["conversion_rate"] = lift_agg["weighted_actual"] / np.maximum(
        lift_agg["total_weight"], EPS
    )

    overall_rate = float(lift_agg["weighted_actual"].sum()) / max(total_weight, EPS)

    # Only finalize (save/show) figures this function created itself.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.axhline(
        y=overall_rate,
        color="gray",
        linestyle="--",
        label=f"Overall ({overall_rate:.2%})",
    )
    ax.plot(
        lift_agg["bin"],
        lift_agg["conversion_rate"],
        marker="o",
        linestyle="-",
        label="Actual Rate",
    )
    ax.set_title(title, fontsize=style.title_size)
    ax.set_xlabel("Score Bin", fontsize=style.label_size)
    ax.set_ylabel("Conversion Rate", fontsize=style.label_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="best", fontsize=style.legend_size, frameon=False)

    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Mapping, Optional, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from .common import EPS, PlotStyle, finalize_figure, plt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def plot_loss_curve(
    *,
    history: Optional[Mapping[str, Sequence[float]]] = None,
    train: Optional[Sequence[float]] = None,
    val: Optional[Sequence[float]] = None,
    title: str = "Loss vs. Epoch",
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> Optional[plt.Figure]:
    """Plot per-epoch training / validation loss curves.

    Parameters
    ----------
    history:
        Optional mapping with ``"train"`` / ``"val"`` keys; used to fill
        ``train`` / ``val`` when those are not given explicitly.
    train, val:
        Per-epoch loss sequences (epoch numbering starts at 1).
    title:
        Axes title.
    ax:
        Existing axes to draw on; a new figure is created when omitted.
    show, save_path, style:
        Forwarded to ``finalize_figure`` when this function owns the figure.

    Returns
    -------
    The matplotlib figure, or ``None`` when there is nothing to plot.
    """
    style = style or PlotStyle()
    if history is not None:
        if train is None:
            train = history.get("train")
        if val is None:
            val = history.get("val")

    # Use explicit None checks instead of `list(train or [])`: numpy arrays
    # (a common container for loss histories) raise "truth value is
    # ambiguous" when used in a boolean context.
    train_hist = list(train) if train is not None else []
    val_hist = list(val) if val is not None else []
    if not train_hist and not val_hist:
        return None

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    if train_hist:
        ax.plot(
            range(1, len(train_hist) + 1),
            train_hist,
            label="Train Loss",
            color="tab:blue",
        )
    if val_hist:
        ax.plot(
            range(1, len(val_hist) + 1),
            val_hist,
            label="Validation Loss",
            color="tab:orange",
        )

    ax.set_xlabel("Epoch", fontsize=style.label_size)
    ax.set_ylabel("Weighted Loss", fontsize=style.label_size)
    ax.set_title(title, fontsize=style.title_size)
    ax.tick_params(axis="both", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)
    ax.legend(loc="best", fontsize=style.legend_size, frameon=False)

    # Only finalize (save/show/close) figures created here; when the caller
    # supplied `ax`, layout and saving remain their responsibility.
    if created_fig:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def plot_oneway(
    df: pd.DataFrame,
    *,
    feature: str,
    weight_col: str,
    target_col: str,
    n_bins: int = 10,
    is_categorical: bool = False,
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    show: bool = False,
    save_path: Optional[str] = None,
    style: Optional[PlotStyle] = None,
) -> Optional[plt.Figure]:
    """One-way analysis plot: actual rate (line) against exposure (bars).

    ``target_col`` and ``weight_col`` are summed per bucket of ``feature``
    (quantile bins for numeric features, raw categories otherwise); the
    actual value per bucket is ``sum(target) / sum(weight)``.

    Raises
    ------
    KeyError
        If ``feature``, ``weight_col`` or ``target_col`` is missing from
        ``df``.
    """
    for arg_name, column in (
        ("feature", feature),
        ("weight_col", weight_col),
        ("target_col", target_col),
    ):
        if column not in df.columns:
            raise KeyError(f"{arg_name} '{column}' not found in data.")

    style = style or PlotStyle()
    title = title or f"Analysis of {feature}"

    if is_categorical:
        bucket_col = feature
        bucketed_df = df
    else:
        # Prefer quantile bins; fall back to equal-width bins when the
        # distribution has too few distinct quantile edges.
        bucket_col = f"{feature}_bins"
        numeric_vals = pd.to_numeric(df[feature], errors="coerce")
        try:
            buckets = pd.qcut(numeric_vals, n_bins, duplicates="drop")
        except ValueError:
            buckets = pd.cut(
                numeric_vals, bins=max(1, int(n_bins)), duplicates="drop"
            )
        bucketed_df = df.assign(**{bucket_col: buckets})

    agg = (
        bucketed_df.groupby([bucket_col], observed=True)
        .sum(numeric_only=True)
        .reset_index()
    )

    # Clamp the denominator so zero-weight buckets cannot divide by zero.
    safe_weight = np.maximum(agg[weight_col].to_numpy(dtype=float), EPS)
    agg["act_v"] = agg[target_col].to_numpy(dtype=float) / safe_weight

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=style.figsize)
    else:
        fig = ax.figure

    ax.plot(agg.index, agg["act_v"], label="Actual", color="red")
    ax.set_title(title, fontsize=style.title_size)
    ax.set_xticks(agg.index)
    tick_labels = agg[bucket_col].astype(str).tolist()
    # Shrink x tick labels when there are many buckets so they stay legible.
    label_font = 3 if len(tick_labels) > 50 else style.tick_size
    ax.set_xticklabels(tick_labels, rotation=90, fontsize=label_font)
    ax.tick_params(axis="y", labelsize=style.tick_size)
    if style.grid:
        ax.grid(True, linestyle=style.grid_style, alpha=style.grid_alpha)

    # Exposure weights go on a secondary y-axis behind the actual-rate line.
    weight_axis = ax.twinx()
    weight_axis.bar(
        agg.index,
        agg[weight_col],
        alpha=0.5,
        color=style.weight_color,
    )
    weight_axis.tick_params(axis="y", labelsize=style.tick_size)

    if owns_figure:
        finalize_figure(fig, save_path=save_path, show=show, style=style)

    return fig
|