theseusplot-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- theseusplot/__init__.py +11 -0
- theseusplot/_config.py +43 -0
- theseusplot/_ship.py +1224 -0
- theseusplot/py.typed +1 -0
- theseusplot-0.1.0.dist-info/METADATA +315 -0
- theseusplot-0.1.0.dist-info/RECORD +7 -0
- theseusplot-0.1.0.dist-info/WHEEL +4 -0
theseusplot/_ship.py
ADDED
@@ -0,0 +1,1224 @@
"""ShipOfTheseus public object."""

from __future__ import annotations

import importlib
import warnings
from collections.abc import Sequence
from typing import Any, cast

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from pandas.api.types import is_numeric_dtype, is_string_dtype

from theseusplot._config import ContinuousConfig, continuous_config

_OUTCOME_COLUMN = ".outcome"
_MISSING_LABEL = "(Missing)"
_MIN_BREAK_COUNT = 2
_REFITTED_COLOR = "#00BFC4"
_ORIGINAL_SIZE_COLOR = "#7CAE00"
_REFITTED_SIZE_COLOR = "#C77CFF"
_POSITIVE_COLOR = "#F8766D"
_NEGATIVE_COLOR = "#00BFC4"


class ShipOfTheseus:
    """Container for data and methods used to create Theseus plots."""

    def __init__(
        self,
        data1: pd.DataFrame,
        data2: pd.DataFrame,
        outcome: str,
        labels: Sequence[str],
        y_label: str | None,
        digits: int,
        text_size: float,
    ) -> None:
        self._validate_inputs(data1, data2, outcome, labels)

        self._data1 = self._prepare_input_data(data1, outcome)
        self._data2 = self._prepare_input_data(data2, outcome)
        self.outcome = outcome
        self.labels = (labels[0], labels[1])
        self.y_label = y_label
        self.digits = digits
        self.text_size = text_size
        self._cache: dict[tuple[Any, ...], Any] = {}

    @staticmethod
    def _validate_inputs(
        data1: pd.DataFrame,
        data2: pd.DataFrame,
        outcome: str,
        labels: Sequence[str],
    ) -> None:
        if not isinstance(data1, pd.DataFrame):
            msg = "data1 must be a pandas DataFrame."
            raise TypeError(msg)
        if not isinstance(data2, pd.DataFrame):
            msg = "data2 must be a pandas DataFrame."
            raise TypeError(msg)
        if outcome not in data1.columns:
            msg = f"outcome column {outcome!r} is missing from data1."
            raise ValueError(msg)
        if outcome not in data2.columns:
            msg = f"outcome column {outcome!r} is missing from data2."
            raise ValueError(msg)
        if len(labels) != 2:
            msg = "labels must contain exactly two values."
            raise ValueError(msg)
        if data1[outcome].isna().any() or data2[outcome].isna().any():
            msg = "outcome values must not contain missing values."
            raise ValueError(msg)

    @staticmethod
    def _prepare_input_data(data: pd.DataFrame, outcome: str) -> pd.DataFrame:
        prepared = data.copy(deep=True)
        for column in prepared.columns:
            if column == outcome:
                continue
            prepared[column] = ShipOfTheseus._fill_missing_categories(
                prepared[column],
            )
        prepared[_OUTCOME_COLUMN] = prepared[outcome]
        return prepared

    @staticmethod
    def _fill_missing_categories(series: pd.Series) -> pd.Series:
        if isinstance(series.dtype, pd.CategoricalDtype):
            result = series.copy()
            if _MISSING_LABEL not in result.cat.categories:
                result = result.cat.add_categories([_MISSING_LABEL])
            return result.fillna(_MISSING_LABEL)
        if is_string_dtype(series.dtype) or series.dtype == object:
            return series.fillna(_MISSING_LABEL)
        return series

    def table(
        self,
        column: str,
        n: int | float = float("inf"),
        continuous: ContinuousConfig | None = None,
    ) -> pd.DataFrame:
        """Generate a contribution table for a column."""

        self._validate_column(column)
        limit = self._normalize_n(n)
        continuous_config_value = continuous or continuous_config()

        data_contrib = self._compute_contribution(column, continuous_config_value)
        data_info = self._compute_info(column, continuous_config_value)
        result = data_contrib.merge(data_info, on="items", how="left")

        is_factor = isinstance(result["items"].dtype, pd.CategoricalDtype)
        if is_factor:
            result = result.sort_values("items", kind="stable").reset_index(drop=True)
            levels = list(result["items"].cat.categories)
        else:
            result = self._sort_by_abs_contrib(result)
            levels = []

        n_items = len(result)
        if n_items > limit:
            other_count = n_items - limit + 1
            sorted_result = self._sort_by_abs_contrib(result)
            result_head = sorted_result.head(limit - 1).copy()
            result_head["items"] = result_head["items"].astype(str)

            result_tail = sorted_result.tail(other_count).copy()
            other_label = f"Sum of {other_count} other attributes"
            result_other = pd.DataFrame(
                [
                    {
                        "items": other_label,
                        "contrib": result_tail["contrib"].sum(),
                        "n1": int(result_tail["n1"].sum()),
                        "n2": int(result_tail["n2"].sum()),
                        "x1": result_tail["x1"].sum(),
                        "x2": result_tail["x2"].sum(),
                    },
                ],
            )
            result_other["rate1"] = self._safe_rate(
                result_other.loc[0, "x1"],
                result_other.loc[0, "n1"],
            )
            result_other["rate2"] = self._safe_rate(
                result_other.loc[0, "x2"],
                result_other.loc[0, "n2"],
            )
            result = pd.concat([result_head, result_other], ignore_index=True)
            if is_factor:
                result["items"] = pd.Categorical(
                    result["items"],
                    categories=[*levels, other_label],
                    ordered=True,
                )
                result = result.sort_values("items", kind="stable").reset_index(
                    drop=True,
                )

        result = result.rename(columns={"items": column})
        return result[
            [column, "contrib", "n1", "n2", "x1", "x2", "rate1", "rate2"]
        ]

    def plot(
        self,
        column: str,
        n: int = 10,
        main_item: str | None = None,
        bar_max_value: float | None = None,
        levels: Sequence[str] | None = None,
        continuous: ContinuousConfig | None = None,
        ax: Any | None = None,
        figsize: tuple[float, float] | None = None,
    ) -> Any:
        """Generate a Theseus plot for a column."""

        continuous_config_value = continuous or continuous_config()
        score1, _ = self._compute_scores(column)
        table = self.table(column, n=n, continuous=continuous_config_value)
        plot_data = self._plot_contribution_data(
            table=table,
            column=column,
            levels=levels,
        )
        size_data = self._plot_size_data(
            table=plot_data,
            column=column,
            main_item=main_item,
            bar_max_value=bar_max_value,
        )
        waterfall = self._waterfall_data(
            items=plot_data[column].astype(str).tolist(),
            contributions=plot_data["contrib"].astype(float).tolist(),
            start=score1,
        )
        return self._draw_plot(
            waterfall=waterfall,
            size_data=size_data,
            column=column,
            ax=ax,
            figsize=figsize,
        )

    def plot_flip(
        self,
        column: str,
        n: int = 10,
        main_item: str | None = None,
        bar_max_value: float | None = None,
        levels: Sequence[str] | None = None,
        continuous: ContinuousConfig | None = None,
        ax: Any | None = None,
        figsize: tuple[float, float] | None = None,
    ) -> Any:
        """Generate a horizontally oriented Theseus plot for a column."""

        continuous_config_value = continuous or continuous_config()
        _, score2 = self._compute_scores(column)
        table = self.table(column, n=n, continuous=continuous_config_value)
        table = table.copy()
        table["contrib"] = -table["contrib"]
        plot_data = self._plot_flip_contribution_data(
            table=table,
            column=column,
            levels=levels,
        )
        size_data = self._plot_size_data(
            table=plot_data,
            column=column,
            main_item=main_item,
            bar_max_value=bar_max_value,
        )
        waterfall = self._waterfall_data(
            items=plot_data[column].astype(str).tolist(),
            contributions=plot_data["contrib"].astype(float).tolist(),
            start=score2,
            start_label=self.labels[1],
            end_label=self.labels[0],
        )
        return self._draw_plot_flip(
            waterfall=waterfall,
            size_data=size_data,
            column=column,
            ax=ax,
            figsize=figsize,
        )

    def _compute_scores(self, column: str) -> tuple[float, float]:
        key = ("scores", column)
        if key not in self._cache:
            self._cache[key] = (
                float(self._data1[_OUTCOME_COLUMN].mean()),
                float(self._data2[_OUTCOME_COLUMN].mean()),
            )
        return cast(tuple[float, float], self._cache[key])

    def _to_factor(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        key = ("to_factor", column, continuous)
        if key not in self._cache:
            breaks = self._continuous_breaks(column, continuous)
            data1 = self._data1.copy()
            data2 = self._data2.copy()
            data1[column] = self._cut_to_categorical(data1[column], breaks)
            data2[column] = self._cut_to_categorical(data2[column], breaks)
            self._cache[key] = (data1, data2)
        data1_cached, data2_cached = cast(
            tuple[pd.DataFrame, pd.DataFrame],
            self._cache[key],
        )
        return data1_cached.copy(), data2_cached.copy()

    def _compute_contribution(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> pd.DataFrame:
        key = ("contribution", column, continuous)
        if key in self._cache:
            cached = cast(pd.DataFrame, self._cache[key])
            return cached.copy()

        data1, data2 = self._data_for_column(column, continuous)
        grouped1 = self._summarize_by_column(data1, column)
        grouped2 = self._summarize_by_column(data2, column)

        score1, score2 = self._compute_scores(column)
        amounts = self._compute_replacement_amounts(
            original=grouped1,
            refitted=grouped2,
            original_score=score1,
            refitted_score=score2,
        )
        result = self._average_replacement_amounts(
            amounts=amounts,
            original_items=grouped1["items"],
            refitted_items=grouped2["items"],
        )
        result["contrib"] = self._scale_contributions(
            result["contrib"],
            overall_diff=score2 - score1,
        )

        self._cache[key] = result.copy()
        return result

    @classmethod
    def _compute_replacement_amounts(
        cls,
        original: pd.DataFrame,
        refitted: pd.DataFrame,
        original_score: float,
        refitted_score: float,
    ) -> pd.DataFrame:
        rows: list[dict[str, Any]] = []

        for item in refitted["items"]:
            replaced = cls._replace_or_append_group(original, refitted, item)
            rows.append(
                {
                    "items": item,
                    "amount": cls._score_from_summary(replaced) - original_score,
                },
            )

        for item in original["items"]:
            replaced = cls._replace_or_append_group(refitted, original, item)
            rows.append(
                {
                    "items": item,
                    "amount": refitted_score - cls._score_from_summary(replaced),
                },
            )

        return pd.DataFrame(rows, columns=["items", "amount"])

    @classmethod
    def _average_replacement_amounts(
        cls,
        amounts: pd.DataFrame,
        original_items: pd.Series,
        refitted_items: pd.Series,
    ) -> pd.DataFrame:
        amounts = amounts.copy()
        dtype = cls._combined_categorical_dtype(original_items, refitted_items)
        if dtype is not None:
            amounts["items"] = pd.Categorical(
                amounts["items"],
                categories=dtype.categories,
                ordered=dtype.ordered,
            )

        return (
            amounts.groupby("items", observed=True, sort=False)["amount"]
            .mean()
            .reset_index()
            .rename(columns={"amount": "contrib"})
        )

    @staticmethod
    def _scale_contributions(
        contributions: pd.Series,
        overall_diff: float,
    ) -> pd.Series:
        raw_total = float(contributions.sum())
        if np.isclose(raw_total, 0.0):
            scaled = 0.0 if np.isclose(overall_diff, 0.0) else np.nan
            return pd.Series(scaled, index=contributions.index, dtype=float)
        return cast(pd.Series, overall_diff * contributions / raw_total)

    def _compute_info(self, column: str, continuous: ContinuousConfig) -> pd.DataFrame:
        key = ("info", column, continuous)
        if key in self._cache:
            cached = cast(pd.DataFrame, self._cache[key])
            return cached.copy()

        data1, data2 = self._data_for_column(column, continuous)
        data1_info = self._summarize_info(data1, column, suffix="1")
        data2_info = self._summarize_info(data2, column, suffix="2")
        result = data1_info.merge(data2_info, on="items", how="outer", sort=False)

        dtype = self._combined_categorical_dtype(
            data1_info["items"],
            data2_info["items"],
        )
        if dtype is not None:
            result["items"] = pd.Categorical(
                result["items"],
                categories=dtype.categories,
                ordered=dtype.ordered,
            )

        for column_name in ("n1", "n2"):
            result[column_name] = result[column_name].fillna(0).astype(int)
        for column_name in ("x1", "x2"):
            result[column_name] = result[column_name].fillna(0)

        result = result[["items", "n1", "n2", "x1", "x2", "rate1", "rate2"]]
        self._cache[key] = result.copy()
        return result

    def _compute_size(
        self,
        column: str,
        target: Sequence[str],
        continuous: ContinuousConfig,
    ) -> pd.DataFrame:
        data1, data2 = self._data_for_column(column, continuous)
        target_items = [str(item) for item in target]

        data1_size = self._count_target_items(
            data=data1,
            column=column,
            target=target_items,
            label=self.labels[0],
        )
        data2_size = self._count_target_items(
            data=data2,
            column=column,
            target=target_items,
            label=self.labels[1],
        )
        item_names = set(data1_size["items"].astype(str)).union(
            data2_size["items"].astype(str),
        )
        other_names = [item for item in target_items if item not in item_names]

        if not other_names:
            return pd.concat([data1_size, data2_size], ignore_index=True)

        rows = [data1_size, data2_size]
        for item in other_names:
            rows.append(
                self._count_other_items(
                    data=data1,
                    column=column,
                    target=target_items,
                    item=item,
                    label=self.labels[0],
                ),
            )
            rows.append(
                self._count_other_items(
                    data=data2,
                    column=column,
                    target=target_items,
                    item=item,
                    label=self.labels[1],
                ),
            )
        return pd.concat(rows, ignore_index=True)

    def _plot_contribution_data(
        self,
        table: pd.DataFrame,
        column: str,
        levels: Sequence[str] | None,
    ) -> pd.DataFrame:
        data = table[[column, "contrib", "n1", "n2"]].copy()
        is_factor = isinstance(data[column].dtype, pd.CategoricalDtype)
        if is_factor:
            data = data.sort_values(column, kind="stable")
        else:
            data = data.sort_values("contrib", kind="stable")

        if levels is not None:
            level_frame = pd.DataFrame({column: [str(level) for level in levels]})
            data[column] = data[column].astype(str)
            data = level_frame.merge(data, on=column, how="inner", sort=False)

        return data.reset_index(drop=True)

    def _plot_flip_contribution_data(
        self,
        table: pd.DataFrame,
        column: str,
        levels: Sequence[str] | None,
    ) -> pd.DataFrame:
        data = table[[column, "contrib", "n1", "n2"]].copy()
        is_factor = isinstance(data[column].dtype, pd.CategoricalDtype)
        if is_factor:
            data = data.sort_values(column, ascending=False, kind="stable")
        else:
            data = data.sort_values("contrib", kind="stable")

        if levels is not None:
            level_frame = pd.DataFrame(
                {column: [str(level) for level in reversed(levels)]},
            )
            data[column] = data[column].astype(str)
            data = level_frame.merge(data, on=column, how="inner", sort=False)

        return data.reset_index(drop=True)

    def _plot_size_data(
        self,
        table: pd.DataFrame,
        column: str,
        main_item: str | None,
        bar_max_value: float | None,
    ) -> pd.DataFrame:
        size_data = pd.concat(
            [
                pd.DataFrame(
                    {
                        "items": table[column].astype(str),
                        "n": table["n1"].astype(float),
                        "type": self.labels[0],
                    },
                ),
                pd.DataFrame(
                    {
                        "items": table[column].astype(str),
                        "n": table["n2"].astype(float),
                        "type": self.labels[1],
                    },
                ),
            ],
            ignore_index=True,
        )
        if table.empty:
            size_data["scaled_n"] = 0.0
            return size_data

        max_amount, n_max = self._size_scale_reference(
            table=table,
            column=column,
            main_item=main_item,
            bar_max_value=bar_max_value,
        )
        if np.isclose(n_max, 0.0) or np.isclose(max_amount, 0.0):
            size_data["scaled_n"] = 0.0
        else:
            size_data["scaled_n"] = size_data["n"] / n_max * max_amount
        return size_data

    def _size_scale_reference(
        self,
        table: pd.DataFrame,
        column: str,
        main_item: str | None,
        bar_max_value: float | None,
    ) -> tuple[float, float]:
        data = table.copy()
        data["_max_n"] = data[["n1", "n2"]].max(axis=1).astype(float)
        data["_abs_contrib"] = (data["contrib"].astype(float) * 100).abs()

        if main_item is not None:
            row = data[data[column].astype(str) == str(main_item)]
            if row.empty:
                msg = f"main_item {main_item!r} is not present in the plot data."
                raise ValueError(msg)
            return float(row["_abs_contrib"].iloc[0]), float(row["_max_n"].iloc[0])

        if bar_max_value is not None:
            return abs(float(bar_max_value)), float(data["_max_n"].max())

        row = data.sort_values("_abs_contrib", ascending=False).head(1)
        return float(row["_abs_contrib"].iloc[0]), float(row["_max_n"].iloc[0])

    def _waterfall_data(
        self,
        items: Sequence[str],
        contributions: Sequence[float],
        start: float,
        start_label: str | None = None,
        end_label: str | None = None,
    ) -> pd.DataFrame:
        rows: list[dict[str, Any]] = []
        cumulative = start * 100
        start_label = self.labels[0] if start_label is None else start_label
        end_label = self.labels[1] if end_label is None else end_label
        rows.append(
            {
                "items": start_label,
                "bottom": 0.0,
                "height": cumulative,
                "amount": cumulative,
                "cumulative": cumulative,
                "kind": "total",
            },
        )

        for item, contribution in zip(items, contributions, strict=True):
            amount = contribution * 100
            bottom = cumulative if amount >= 0 else cumulative + amount
            cumulative += amount
            rows.append(
                {
                    "items": item,
                    "bottom": bottom,
                    "height": abs(amount),
                    "amount": amount,
                    "cumulative": cumulative,
                    "kind": "contribution",
                },
            )

        rows.append(
            {
                "items": end_label,
                "bottom": 0.0,
                "height": cumulative,
                "amount": cumulative,
                "cumulative": cumulative,
                "kind": "total",
            },
        )
        return pd.DataFrame(rows)

    def _draw_plot(
        self,
        waterfall: pd.DataFrame,
        size_data: pd.DataFrame,
        column: str,
        ax: Any | None,
        figsize: tuple[float, float] | None,
    ) -> Any:
        plt = self._load_pyplot()
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize or self._default_figsize(waterfall))
        else:
            fig = ax.figure

        positions = cast(
            NDArray[np.int64],
            np.arange(len(waterfall), dtype=np.int64),
        )
        item_to_position = {
            str(item): index for index, item in enumerate(waterfall["items"])
        }

        self._draw_size_bars(ax=ax, size_data=size_data, positions=item_to_position)
        self._draw_waterfall_bars(ax=ax, waterfall=waterfall, positions=positions)
        self._draw_connectors(ax=ax, waterfall=waterfall, positions=positions)

        ax.axhline(0, color="#333333", linewidth=0.8)
        ax.set_xticks(positions)
        ax.set_xticklabels(waterfall["items"].astype(str), rotation=45, ha="right")
        ax.set_ylabel(self.y_label or "")
        ax.set_xlabel("")
        ax.set_title(column)
        ax.margins(x=0.02)
        fig.tight_layout()
        return fig, ax

    def _draw_plot_flip(
        self,
        waterfall: pd.DataFrame,
        size_data: pd.DataFrame,
        column: str,
        ax: Any | None,
        figsize: tuple[float, float] | None,
    ) -> Any:
        plt = self._load_pyplot()
        if ax is None:
            fig, ax = plt.subplots(
                figsize=figsize or self._default_flip_figsize(waterfall),
            )
        else:
            fig = ax.figure

        positions = cast(
            NDArray[np.int64],
            np.arange(len(waterfall), dtype=np.int64),
        )
        item_to_position = {
            str(item): index for index, item in enumerate(waterfall["items"])
        }

        self._draw_size_bars_horizontal(
            ax=ax,
            size_data=size_data,
            positions=item_to_position,
        )
        self._draw_waterfall_bars_horizontal(
            ax=ax,
            waterfall=waterfall,
            positions=positions,
        )
        self._draw_connectors_horizontal(
            ax=ax,
            waterfall=waterfall,
            positions=positions,
        )

        ax.axvline(0, color="#333333", linewidth=0.8)
        ax.set_yticks(positions)
        ax.set_yticklabels(waterfall["items"].astype(str))
        ax.invert_yaxis()
        ax.set_xlabel(self.y_label or "")
        ax.set_ylabel("")
        ax.set_title(column)
        ax.margins(y=0.02)
        fig.tight_layout()
        return fig, ax

    @staticmethod
    def _load_pyplot() -> Any:
        try:
            return cast(Any, importlib.import_module("matplotlib.pyplot"))
        except ModuleNotFoundError as exc:
            msg = "matplotlib is required to use plotting methods."
            raise ModuleNotFoundError(msg) from exc

    def _draw_size_bars(
        self,
        ax: Any,
        size_data: pd.DataFrame,
        positions: dict[str, int],
    ) -> None:
        width = 0.22
        offsets = {self.labels[0]: -width / 1.5, self.labels[1]: width / 1.5}
        colors = {
            self.labels[0]: _ORIGINAL_SIZE_COLOR,
            self.labels[1]: _REFITTED_SIZE_COLOR,
        }
        for _, row in size_data.iterrows():
            item = str(row["items"])
            if item not in positions:
                continue
            group = str(row["type"])
            ax.bar(
                positions[item] + offsets[group],
                row["scaled_n"],
                width=width,
                color=colors[group],
                alpha=0.35,
                linewidth=0,
                zorder=1,
            )

    def _draw_size_bars_horizontal(
        self,
        ax: Any,
        size_data: pd.DataFrame,
        positions: dict[str, int],
    ) -> None:
        height = 0.22
        offsets = {self.labels[0]: -height / 1.5, self.labels[1]: height / 1.5}
        colors = {
            self.labels[0]: _ORIGINAL_SIZE_COLOR,
            self.labels[1]: _REFITTED_SIZE_COLOR,
        }
        for _, row in size_data.iterrows():
            item = str(row["items"])
            if item not in positions:
                continue
            group = str(row["type"])
            ax.barh(
                positions[item] + offsets[group],
                row["scaled_n"],
                height=height,
                color=colors[group],
                alpha=0.35,
                linewidth=0,
                zorder=1,
            )

    def _draw_waterfall_bars(
        self,
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        colors = [
            _REFITTED_COLOR
            if row["kind"] == "total"
            else (_POSITIVE_COLOR if row["amount"] >= 0 else _NEGATIVE_COLOR)
            for _, row in waterfall.iterrows()
        ]
        ax.bar(
            positions,
            waterfall["height"],
            bottom=waterfall["bottom"],
            width=0.62,
            color=colors,
            edgecolor="#333333",
            linewidth=0.6,
            zorder=3,
        )
        for position, (_, row) in zip(positions, waterfall.iterrows(), strict=True):
            value = self._format_plot_value(float(row["amount"]))
            y = float(row["bottom"]) + float(row["height"])
            va = "bottom"
            if float(row["amount"]) < 0:
                y = float(row["bottom"])
                va = "top"
            ax.text(
                position,
                y,
                value,
                ha="center",
                va=va,
                fontsize=9 * self.text_size,
                zorder=4,
            )

    def _draw_waterfall_bars_horizontal(
        self,
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        colors = [
            _REFITTED_COLOR
            if row["kind"] == "total"
            else (_POSITIVE_COLOR if row["amount"] >= 0 else _NEGATIVE_COLOR)
            for _, row in waterfall.iterrows()
        ]
        ax.barh(
            positions,
            waterfall["height"],
            left=waterfall["bottom"],
            height=0.62,
            color=colors,
            edgecolor="#333333",
            linewidth=0.6,
            zorder=3,
        )
        for position, (_, row) in zip(positions, waterfall.iterrows(), strict=True):
            value = self._format_plot_value(float(row["amount"]))
            x = float(row["bottom"]) + float(row["height"])
            ha = "left"
            if float(row["amount"]) < 0:
                x = float(row["bottom"])
                ha = "right"
            ax.text(
                x,
                position,
                value,
                ha=ha,
                va="center",
                fontsize=9 * self.text_size,
                zorder=4,
            )

    @staticmethod
    def _draw_connectors(
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        for index in range(len(waterfall) - 2):
            y = float(waterfall.loc[index, "cumulative"])
            ax.plot(
                [positions[index] + 0.31, positions[index + 1] - 0.31],
                [y, y],
                color="#666666",
                linewidth=0.8,
                zorder=2,
            )

    @staticmethod
    def _draw_connectors_horizontal(
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        for index in range(len(waterfall) - 2):
            x = float(waterfall.loc[index, "cumulative"])
            ax.plot(
                [x, x],
                [positions[index] + 0.31, positions[index + 1] - 0.31],
                color="#666666",
                linewidth=0.8,
                zorder=2,
            )

    @staticmethod
    def _default_figsize(waterfall: pd.DataFrame) -> tuple[float, float]:
        return max(6.0, len(waterfall) * 0.75), 4.5

    @staticmethod
    def _default_flip_figsize(waterfall: pd.DataFrame) -> tuple[float, float]:
        return 6.5, max(4.5, len(waterfall) * 0.45)

    def _format_plot_value(self, value: float) -> str:
        rounded = round(value, self.digits)
        return f"{rounded:g}"

    @staticmethod
    def _count_target_items(
        data: pd.DataFrame,
        column: str,
        target: Sequence[str],
        label: str,
    ) -> pd.DataFrame:
        counts = (
            data[data[column].astype(str).isin(target)]
            .groupby(column, observed=True, sort=False)
            .size()
            .reset_index(name="n")
            .rename(columns={column: "items"})
        )
        counts["items"] = counts["items"].astype(str)
        counts["type"] = label
        return counts[["items", "n", "type"]]

    @staticmethod
    def _count_other_items(
        data: pd.DataFrame,
        column: str,
        target: Sequence[str],
        item: str,
        label: str,
    ) -> pd.DataFrame:
        count = int((~data[column].astype(str).isin(target)).sum())
        return pd.DataFrame([{"items": item, "n": count, "type": label}])

    def _validate_column(self, column: str) -> None:
        if column not in self._data1.columns:
            msg = f"column {column!r} is missing from data1."
            raise ValueError(msg)
        if column not in self._data2.columns:
            msg = f"column {column!r} is missing from data2."
            raise ValueError(msg)

    @staticmethod
    def _normalize_n(n: int | float) -> int | float:
        if n == float("inf"):
            return n
        limit = int(n)
        if limit < 1:
            msg = "n must be at least 1 or infinity."
            raise ValueError(msg)
        return limit

    @staticmethod
    def _sort_by_abs_contrib(data: pd.DataFrame) -> pd.DataFrame:
        return (
            data.assign(_abs_contrib=data["contrib"].abs())
            .sort_values("_abs_contrib", ascending=False, kind="stable")
            .drop(columns="_abs_contrib")
            .reset_index(drop=True)
        )

    @staticmethod
    def _safe_rate(x: float, n: int) -> float:
        return float(x) / n if n else float("nan")

    def _data_for_column(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        if is_numeric_dtype(self._data1[column]):
            return self._to_factor(column, continuous)
        return self._data1, self._data2

    @staticmethod
    def _summarize_by_column(data: pd.DataFrame, column: str) -> pd.DataFrame:
        grouped = (
            data.groupby(column, observed=True, sort=False)[_OUTCOME_COLUMN]
            .agg(y="sum", n="size")
            .reset_index()
            .rename(columns={column: "items"})
        )
        grouped["rate"] = grouped["y"] / grouped["n"]
        return grouped

    @staticmethod
    def _summarize_info(
        data: pd.DataFrame,
        column: str,
        suffix: str,
    ) -> pd.DataFrame:
        grouped = (
            data.groupby(column, observed=True, sort=False)[_OUTCOME_COLUMN]
            .agg(**{f"x{suffix}": "sum", f"n{suffix}": "size"})
            .reset_index()
            .rename(columns={column: "items"})
        )
        grouped[f"rate{suffix}"] = grouped[f"x{suffix}"] / grouped[f"n{suffix}"]
        return grouped

    @staticmethod
    def _replace_or_append_group(
        base: pd.DataFrame,
        replacement: pd.DataFrame,
        item: Any,
    ) -> pd.DataFrame:
        result = base.copy()
        replacement_row = replacement[replacement["items"] == item]
        mask = result["items"] == item
        if mask.any():
            result.loc[mask, ["y", "n", "rate"]] = replacement_row[
                ["y", "n", "rate"]
            ].to_numpy()
            return result
        return pd.concat([result, replacement_row], ignore_index=True)

    @staticmethod
    def _score_from_summary(data: pd.DataFrame) -> float:
        return float(data["y"].sum() / data["n"].sum())

    @staticmethod
    def _combined_categorical_dtype(
        values1: pd.Series,
        values2: pd.Series,
    ) -> pd.CategoricalDtype | None:
        dtype1 = values1.dtype
        dtype2 = values2.dtype
        if not isinstance(dtype1, pd.CategoricalDtype):
            return None

        categories = list(dtype1.categories)
        if isinstance(dtype2, pd.CategoricalDtype):
            categories.extend(
                category
                for category in dtype2.categories
                if category not in categories
            )
        else:
            categories.extend(value for value in values2 if value not in categories)
        return pd.CategoricalDtype(categories=categories, ordered=dtype1.ordered)

    def _continuous_breaks(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> NDArray[np.float64]:
        if continuous.breaks is not None:
            return self._validate_breaks(
                np.asarray(continuous.breaks, dtype=float),
            )

        values = pd.concat(
            [self._data1[column], self._data2[column]],
            ignore_index=True,
        ).astype(float)
        non_missing = values.dropna()
        if non_missing.empty:
            msg = f"column {column!r} must contain at least one non-missing value."
            raise ValueError(msg)

        break_num = continuous.n

        if continuous.split == "width":
            if values.isna().any():
                break_num -= 1
            self._validate_break_num(break_num, continuous)
            breaks = np.linspace(
                non_missing.min(),
                non_missing.max(),
                break_num + 1,
            )
        elif continuous.split == "count":
            self._validate_break_num(
                break_num - int(values.isna().any()),
                continuous,
            )
            breaks = self._compute_breaks(values, break_num)
        else:
            self._validate_break_num(
                break_num - int(values.isna().any()),
                continuous,
            )
            breaks = np.unique(self._compute_breaks(values, break_num * 20))
            data1 = self._data1[[column, _OUTCOME_COLUMN]].dropna(subset=[column])
            data2 = self._data2[[column, _OUTCOME_COLUMN]].dropna(subset=[column])
            while len(breaks) > break_num + 1:
                diff = self._adjacent_rate_diff(data1, data2, column, breaks)
                remove_at = int(np.nanargmin(diff)) + 1
                breaks = np.delete(breaks, remove_at)

        if continuous.pretty:
            pretty = self._pretty_breaks(breaks)
            if len(np.unique(pretty)) < len(pretty):
                warnings.warn(
                    "Prettying breaks reduced the number of breaks. "
                    "Try pretty=False.",
                    stacklevel=2,
                )
                pretty = np.unique(pretty)
            breaks = pretty
        return self._validate_breaks(np.asarray(breaks, dtype=float))

    @staticmethod
    def _validate_break_num(
        break_num: int,
        continuous: ContinuousConfig,
    ) -> None:
        if break_num < 1:
            msg = (
                "continuous.n must leave at least one numeric bin. "
                "Increase n or remove missing values."
            )
            raise ValueError(msg)
        if continuous.split == "rate" and break_num < 2:
            msg = "split='rate' requires at least two numeric bins."
            raise ValueError(msg)

    @staticmethod
    def _validate_breaks(breaks: NDArray[np.float64]) -> NDArray[np.float64]:
        if len(breaks) < _MIN_BREAK_COUNT:
            msg = "continuous breaks must contain at least two values."
            raise ValueError(msg)
        if np.isnan(breaks).any():
            msg = "continuous breaks must not contain NaN."
            raise ValueError(msg)
        if not np.all(np.diff(breaks) > 0):
            msg = "continuous breaks must be strictly increasing."
            raise ValueError(msg)
        return breaks

    @staticmethod
    def _compute_breaks(
        values: pd.Series,
        break_num: int,
    ) -> NDArray[np.float64]:
        if values.isna().any():
            break_num -= 1
        probs = np.linspace(0, 1, break_num + 1)
        return np.asarray(values.quantile(probs).to_numpy(), dtype=float)

    @staticmethod
    def _pretty_breaks(
        breaks: NDArray[np.float64],
    ) -> NDArray[np.float64]:
        result = []
        for value in breaks:
            digits = 0 if value == 0 else np.floor(np.log10(abs(value))) + 1
            base = 10 ** (digits - 2)
            rounded = np.floor(value / base) if value < 0 else np.ceil(value / base)
            result.append(rounded * base)
        return np.asarray(result, dtype=float)

    @staticmethod
    def _adjacent_rate_diff(
        data1: pd.DataFrame,
        data2: pd.DataFrame,
        column: str,
        breaks: NDArray[np.float64],
    ) -> NDArray[np.float64]:
        def summarize(data: pd.DataFrame, name: str) -> pd.Series:
            bins = pd.cut(data[column], bins=breaks, include_lowest=True)
            return (
                data.groupby(bins, observed=True, sort=False)[_OUTCOME_COLUMN]
                .mean()
                .diff(-1)
                .abs()
                .rename(name)
            )

        diff1 = summarize(data1, "diff1")
        diff2 = summarize(data2, "diff2")
        merged = pd.concat([diff1, diff2], axis=1)
        diff = np.sqrt(merged["diff1"] ** 2 + merged["diff2"] ** 2)
        return cast(NDArray[np.float64], np.asarray(diff, dtype=np.float64))

    @staticmethod
    def _cut_to_categorical(
        series: pd.Series,
        breaks: NDArray[np.float64],
    ) -> pd.Series:
        categories = ShipOfTheseus._cut_labels(breaks)
        values = pd.cut(
            series,
            bins=breaks,
            include_lowest=True,
            labels=categories,
        )
        if _MISSING_LABEL not in values.cat.categories:
            values = values.cat.add_categories([_MISSING_LABEL])
        return values.fillna(_MISSING_LABEL)

    @staticmethod
    def _cut_labels(breaks: NDArray[np.float64]) -> list[str]:
        labels = []
        pairs = zip(breaks[:-1], breaks[1:], strict=True)
        for index, (left, right) in enumerate(pairs):
            left_bracket = "[" if index == 0 else "("
            labels.append(
                f"{left_bracket}{ShipOfTheseus._format_break(left)},"
                f"{ShipOfTheseus._format_break(right)}]",
            )
        return labels

    @staticmethod
    def _format_break(value: float) -> str:
        if np.isclose(value, 0.0):
            return "0"
        if np.isclose(value, round(value)):
            return str(int(round(value)))
        return f"{value:.15g}"

    @staticmethod
    def _raise_not_implemented() -> None:
        msg = (
            "TheseusPlot calculation and plotting logic has not been "
            "implemented yet."
        )
        raise NotImplementedError(msg)


def create_ship(
    data1: pd.DataFrame,
    data2: pd.DataFrame,
    y: str = "y",
    labels: Sequence[str] = ("Original", "Refitted"),
    y_label: str | None = None,
    digits: int = 3,
    text_size: float = 1.0,
) -> ShipOfTheseus:
    """Create a ShipOfTheseus object."""

    return ShipOfTheseus(
        data1=data1,
        data2=data2,
        outcome=y,
        labels=labels,
        y_label=y_label,
        digits=digits,
        text_size=text_size,
    )