theseusplot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
theseusplot/_ship.py ADDED
@@ -0,0 +1,1224 @@
1
+ """ShipOfTheseus public object."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ import warnings
7
+ from collections.abc import Sequence
8
+ from typing import Any, cast
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from numpy.typing import NDArray
13
+ from pandas.api.types import is_numeric_dtype, is_string_dtype
14
+
15
+ from theseusplot._config import ContinuousConfig, continuous_config
16
+
17
# Name of the synthetic column that mirrors the outcome values on each
# prepared data frame (see ShipOfTheseus._prepare_input_data).
_OUTCOME_COLUMN = ".outcome"
# Label substituted for missing categorical/string/object values.
_MISSING_LABEL = "(Missing)"
# A break vector needs at least two edges to define a single bin.
_MIN_BREAK_COUNT = 2
# Hex colors used by the plotting helpers (ggplot2-like default palette).
_REFITTED_COLOR = "#00BFC4"
_ORIGINAL_SIZE_COLOR = "#7CAE00"
_REFITTED_SIZE_COLOR = "#C77CFF"
_POSITIVE_COLOR = "#F8766D"
_NEGATIVE_COLOR = "#00BFC4"
25
+
26
+
27
+ class ShipOfTheseus:
28
+ """Container for data and methods used to create Theseus plots."""
29
+
30
+ def __init__(
31
+ self,
32
+ data1: pd.DataFrame,
33
+ data2: pd.DataFrame,
34
+ outcome: str,
35
+ labels: Sequence[str],
36
+ y_label: str | None,
37
+ digits: int,
38
+ text_size: float,
39
+ ) -> None:
40
+ self._validate_inputs(data1, data2, outcome, labels)
41
+
42
+ self._data1 = self._prepare_input_data(data1, outcome)
43
+ self._data2 = self._prepare_input_data(data2, outcome)
44
+ self.outcome = outcome
45
+ self.labels = (labels[0], labels[1])
46
+ self.y_label = y_label
47
+ self.digits = digits
48
+ self.text_size = text_size
49
+ self._cache: dict[tuple[Any, ...], Any] = {}
50
+
51
+ @staticmethod
52
+ def _validate_inputs(
53
+ data1: pd.DataFrame,
54
+ data2: pd.DataFrame,
55
+ outcome: str,
56
+ labels: Sequence[str],
57
+ ) -> None:
58
+ if not isinstance(data1, pd.DataFrame):
59
+ msg = "data1 must be a pandas DataFrame."
60
+ raise TypeError(msg)
61
+ if not isinstance(data2, pd.DataFrame):
62
+ msg = "data2 must be a pandas DataFrame."
63
+ raise TypeError(msg)
64
+ if outcome not in data1.columns:
65
+ msg = f"outcome column {outcome!r} is missing from data1."
66
+ raise ValueError(msg)
67
+ if outcome not in data2.columns:
68
+ msg = f"outcome column {outcome!r} is missing from data2."
69
+ raise ValueError(msg)
70
+ if len(labels) != 2:
71
+ msg = "labels must contain exactly two values."
72
+ raise ValueError(msg)
73
+ if data1[outcome].isna().any() or data2[outcome].isna().any():
74
+ msg = "outcome values must not contain missing values."
75
+ raise ValueError(msg)
76
+
77
+ @staticmethod
78
+ def _prepare_input_data(data: pd.DataFrame, outcome: str) -> pd.DataFrame:
79
+ prepared = data.copy(deep=True)
80
+ for column in prepared.columns:
81
+ if column == outcome:
82
+ continue
83
+ prepared[column] = ShipOfTheseus._fill_missing_categories(
84
+ prepared[column],
85
+ )
86
+ prepared[_OUTCOME_COLUMN] = prepared[outcome]
87
+ return prepared
88
+
89
+ @staticmethod
90
+ def _fill_missing_categories(series: pd.Series) -> pd.Series:
91
+ if isinstance(series.dtype, pd.CategoricalDtype):
92
+ result = series.copy()
93
+ if _MISSING_LABEL not in result.cat.categories:
94
+ result = result.cat.add_categories([_MISSING_LABEL])
95
+ return result.fillna(_MISSING_LABEL)
96
+ if is_string_dtype(series.dtype) or series.dtype == object:
97
+ return series.fillna(_MISSING_LABEL)
98
+ return series
99
+
100
    def table(
        self,
        column: str,
        n: int | float = float("inf"),
        continuous: ContinuousConfig | None = None,
    ) -> pd.DataFrame:
        """Generate a contribution table for a column.

        Returns one row per item (group level) with its scaled contribution,
        group sizes ``n1``/``n2``, outcome sums ``x1``/``x2`` and rates.  When
        there are more than *n* rows, the rows with the smallest absolute
        contribution are collapsed into a single summary row.
        """

        self._validate_column(column)
        limit = self._normalize_n(n)
        continuous_config_value = continuous or continuous_config()

        data_contrib = self._compute_contribution(column, continuous_config_value)
        data_info = self._compute_info(column, continuous_config_value)
        result = data_contrib.merge(data_info, on="items", how="left")

        # Categorical items keep their declared order; everything else is
        # ordered by descending absolute contribution.
        is_factor = isinstance(result["items"].dtype, pd.CategoricalDtype)
        if is_factor:
            result = result.sort_values("items", kind="stable").reset_index(drop=True)
            levels = list(result["items"].cat.categories)
        else:
            result = self._sort_by_abs_contrib(result)
            levels = []

        n_items = len(result)
        if n_items > limit:
            # Keep the (limit - 1) strongest rows; fold the rest into one row
            # so the output still has exactly `limit` rows.
            other_count = n_items - limit + 1
            sorted_result = self._sort_by_abs_contrib(result)
            result_head = sorted_result.head(limit - 1).copy()
            result_head["items"] = result_head["items"].astype(str)

            result_tail = sorted_result.tail(other_count).copy()
            other_label = f"Sum of {other_count} other attributes"
            result_other = pd.DataFrame(
                [
                    {
                        "items": other_label,
                        "contrib": result_tail["contrib"].sum(),
                        "n1": int(result_tail["n1"].sum()),
                        "n2": int(result_tail["n2"].sum()),
                        "x1": result_tail["x1"].sum(),
                        "x2": result_tail["x2"].sum(),
                    },
                ],
            )
            # Rates are recomputed from the pooled sums (NaN when the pooled
            # count is zero).
            result_other["rate1"] = self._safe_rate(
                result_other.loc[0, "x1"],
                result_other.loc[0, "n1"],
            )
            result_other["rate2"] = self._safe_rate(
                result_other.loc[0, "x2"],
                result_other.loc[0, "n2"],
            )
            result = pd.concat([result_head, result_other], ignore_index=True)
            if is_factor:
                # Restore categorical ordering with the summary row appended
                # as the last level.
                result["items"] = pd.Categorical(
                    result["items"],
                    categories=[*levels, other_label],
                    ordered=True,
                )
                result = result.sort_values("items", kind="stable").reset_index(
                    drop=True,
                )

        result = result.rename(columns={"items": column})
        return result[
            [column, "contrib", "n1", "n2", "x1", "x2", "rate1", "rate2"]
        ]
168
+
169
    def plot(
        self,
        column: str,
        n: int = 10,
        main_item: str | None = None,
        bar_max_value: float | None = None,
        levels: Sequence[str] | None = None,
        continuous: ContinuousConfig | None = None,
        ax: Any | None = None,
        figsize: tuple[float, float] | None = None,
    ) -> Any:
        """Generate a Theseus plot for a column.

        Builds the contribution table (at most *n* rows), derives the
        waterfall geometry starting from the first data set's score, scales
        the background group-size bars (optionally anchored to *main_item* or
        *bar_max_value*), and renders everything onto *ax* (or a new figure).
        Returns whatever ``_draw_plot`` returns, i.e. ``(fig, ax)``.
        """

        continuous_config_value = continuous or continuous_config()
        score1, _ = self._compute_scores(column)
        table = self.table(column, n=n, continuous=continuous_config_value)
        # Order/filter rows for display.
        plot_data = self._plot_contribution_data(
            table=table,
            column=column,
            levels=levels,
        )
        # Long-format group sizes, scaled to line up with the reference bar.
        size_data = self._plot_size_data(
            table=plot_data,
            column=column,
            main_item=main_item,
            bar_max_value=bar_max_value,
        )
        waterfall = self._waterfall_data(
            items=plot_data[column].astype(str).tolist(),
            contributions=plot_data["contrib"].astype(float).tolist(),
            start=score1,
        )
        return self._draw_plot(
            waterfall=waterfall,
            size_data=size_data,
            column=column,
            ax=ax,
            figsize=figsize,
        )
208
+
209
    def plot_flip(
        self,
        column: str,
        n: int = 10,
        main_item: str | None = None,
        bar_max_value: float | None = None,
        levels: Sequence[str] | None = None,
        continuous: ContinuousConfig | None = None,
        ax: Any | None = None,
        figsize: tuple[float, float] | None = None,
    ) -> Any:
        """Generate a horizontally oriented Theseus plot for a column.

        The waterfall runs from the second data set back to the first, so
        contributions are negated, the start score is the second data set's
        score, and the start/end labels are swapped relative to :meth:`plot`.
        Returns ``(fig, ax)``.
        """

        continuous_config_value = continuous or continuous_config()
        _, score2 = self._compute_scores(column)
        table = self.table(column, n=n, continuous=continuous_config_value)
        table = table.copy()
        # Walk the score difference in the opposite direction.
        table["contrib"] = -table["contrib"]
        plot_data = self._plot_flip_contribution_data(
            table=table,
            column=column,
            levels=levels,
        )
        size_data = self._plot_size_data(
            table=plot_data,
            column=column,
            main_item=main_item,
            bar_max_value=bar_max_value,
        )
        waterfall = self._waterfall_data(
            items=plot_data[column].astype(str).tolist(),
            contributions=plot_data["contrib"].astype(float).tolist(),
            start=score2,
            start_label=self.labels[1],
            end_label=self.labels[0],
        )
        return self._draw_plot_flip(
            waterfall=waterfall,
            size_data=size_data,
            column=column,
            ax=ax,
            figsize=figsize,
        )
252
+
253
+ def _compute_scores(self, column: str) -> tuple[float, float]:
254
+ key = ("scores", column)
255
+ if key not in self._cache:
256
+ self._cache[key] = (
257
+ float(self._data1[_OUTCOME_COLUMN].mean()),
258
+ float(self._data2[_OUTCOME_COLUMN].mean()),
259
+ )
260
+ return cast(tuple[float, float], self._cache[key])
261
+
262
    def _to_factor(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Return both data sets with numeric *column* binned into categories.

        The break vector is derived once (see ``_continuous_breaks``) so both
        data sets share identical bins.  Results are cached per
        ``(column, continuous)``; copies are returned so callers cannot
        mutate the cached frames.
        """
        key = ("to_factor", column, continuous)
        if key not in self._cache:
            breaks = self._continuous_breaks(column, continuous)
            data1 = self._data1.copy()
            data2 = self._data2.copy()
            # _cut_to_categorical (defined elsewhere in this module) is
            # presumably mapping values into the shared bins.
            data1[column] = self._cut_to_categorical(data1[column], breaks)
            data2[column] = self._cut_to_categorical(data2[column], breaks)
            self._cache[key] = (data1, data2)
        data1_cached, data2_cached = cast(
            tuple[pd.DataFrame, pd.DataFrame],
            self._cache[key],
        )
        return data1_cached.copy(), data2_cached.copy()
280
+
281
    def _compute_contribution(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> pd.DataFrame:
        """Compute per-item contributions to the overall score difference.

        Each group level is swapped between the two data sets in both
        directions, the resulting score shifts are averaged per item, and the
        averages are rescaled so they sum to ``score2 - score1``.  Cached per
        ``(column, continuous)``; a copy is returned.
        """
        key = ("contribution", column, continuous)
        if key in self._cache:
            cached = cast(pd.DataFrame, self._cache[key])
            return cached.copy()

        data1, data2 = self._data_for_column(column, continuous)
        grouped1 = self._summarize_by_column(data1, column)
        grouped2 = self._summarize_by_column(data2, column)

        score1, score2 = self._compute_scores(column)
        amounts = self._compute_replacement_amounts(
            original=grouped1,
            refitted=grouped2,
            original_score=score1,
            refitted_score=score2,
        )
        result = self._average_replacement_amounts(
            amounts=amounts,
            original_items=grouped1["items"],
            refitted_items=grouped2["items"],
        )
        # Force the contributions to account exactly for the observed overall
        # score difference.
        result["contrib"] = self._scale_contributions(
            result["contrib"],
            overall_diff=score2 - score1,
        )

        self._cache[key] = result.copy()
        return result
314
+
315
+ @classmethod
316
+ def _compute_replacement_amounts(
317
+ cls,
318
+ original: pd.DataFrame,
319
+ refitted: pd.DataFrame,
320
+ original_score: float,
321
+ refitted_score: float,
322
+ ) -> pd.DataFrame:
323
+ rows: list[dict[str, Any]] = []
324
+
325
+ for item in refitted["items"]:
326
+ replaced = cls._replace_or_append_group(original, refitted, item)
327
+ rows.append(
328
+ {
329
+ "items": item,
330
+ "amount": cls._score_from_summary(replaced) - original_score,
331
+ },
332
+ )
333
+
334
+ for item in original["items"]:
335
+ replaced = cls._replace_or_append_group(refitted, original, item)
336
+ rows.append(
337
+ {
338
+ "items": item,
339
+ "amount": refitted_score - cls._score_from_summary(replaced),
340
+ },
341
+ )
342
+
343
+ return pd.DataFrame(rows, columns=["items", "amount"])
344
+
345
+ @classmethod
346
+ def _average_replacement_amounts(
347
+ cls,
348
+ amounts: pd.DataFrame,
349
+ original_items: pd.Series,
350
+ refitted_items: pd.Series,
351
+ ) -> pd.DataFrame:
352
+ amounts = amounts.copy()
353
+ dtype = cls._combined_categorical_dtype(original_items, refitted_items)
354
+ if dtype is not None:
355
+ amounts["items"] = pd.Categorical(
356
+ amounts["items"],
357
+ categories=dtype.categories,
358
+ ordered=dtype.ordered,
359
+ )
360
+
361
+ return (
362
+ amounts.groupby("items", observed=True, sort=False)["amount"]
363
+ .mean()
364
+ .reset_index()
365
+ .rename(columns={"amount": "contrib"})
366
+ )
367
+
368
+ @staticmethod
369
+ def _scale_contributions(
370
+ contributions: pd.Series,
371
+ overall_diff: float,
372
+ ) -> pd.Series:
373
+ raw_total = float(contributions.sum())
374
+ if np.isclose(raw_total, 0.0):
375
+ scaled = 0.0 if np.isclose(overall_diff, 0.0) else np.nan
376
+ return pd.Series(scaled, index=contributions.index, dtype=float)
377
+ return cast(pd.Series, overall_diff * contributions / raw_total)
378
+
379
    def _compute_info(self, column: str, continuous: ContinuousConfig) -> pd.DataFrame:
        """Per-item counts, outcome sums and rates for both data sets.

        Produces ``n1``/``n2`` (group sizes), ``x1``/``x2`` (outcome sums) and
        ``rate1``/``rate2``.  Items absent from one data set get zero counts
        and sums while that side's rate stays NaN.  Cached per
        ``(column, continuous)``; a copy is returned.
        """
        key = ("info", column, continuous)
        if key in self._cache:
            cached = cast(pd.DataFrame, self._cache[key])
            return cached.copy()

        data1, data2 = self._data_for_column(column, continuous)
        data1_info = self._summarize_info(data1, column, suffix="1")
        data2_info = self._summarize_info(data2, column, suffix="2")
        result = data1_info.merge(data2_info, on="items", how="outer", sort=False)

        # Restore a merged categorical dtype — the outer merge loses it.
        dtype = self._combined_categorical_dtype(
            data1_info["items"],
            data2_info["items"],
        )
        if dtype is not None:
            result["items"] = pd.Categorical(
                result["items"],
                categories=dtype.categories,
                ordered=dtype.ordered,
            )

        # Missing groups contribute nothing: zero counts and outcome sums.
        for column_name in ("n1", "n2"):
            result[column_name] = result[column_name].fillna(0).astype(int)
        for column_name in ("x1", "x2"):
            result[column_name] = result[column_name].fillna(0)

        result = result[["items", "n1", "n2", "x1", "x2", "rate1", "rate2"]]
        self._cache[key] = result.copy()
        return result
409
+
410
    def _compute_size(
        self,
        column: str,
        target: Sequence[str],
        continuous: ContinuousConfig,
    ) -> pd.DataFrame:
        """Long-format group sizes for *target* items in both data sets.

        Items named in *target* but absent from the grouped counts
        (presumably synthetic rows such as the collapsed "other attributes"
        row — confirm against callers) are counted as everything outside
        *target*.
        """
        data1, data2 = self._data_for_column(column, continuous)
        target_items = [str(item) for item in target]

        data1_size = self._count_target_items(
            data=data1,
            column=column,
            target=target_items,
            label=self.labels[0],
        )
        data2_size = self._count_target_items(
            data=data2,
            column=column,
            target=target_items,
            label=self.labels[1],
        )
        # Targets that never appeared in either grouped count.
        item_names = set(data1_size["items"].astype(str)).union(
            data2_size["items"].astype(str),
        )
        other_names = [item for item in target_items if item not in item_names]

        if not other_names:
            return pd.concat([data1_size, data2_size], ignore_index=True)

        rows = [data1_size, data2_size]
        for item in other_names:
            # Count the complement of *target* under each missing item's name.
            rows.append(
                self._count_other_items(
                    data=data1,
                    column=column,
                    target=target_items,
                    item=item,
                    label=self.labels[0],
                ),
            )
            rows.append(
                self._count_other_items(
                    data=data2,
                    column=column,
                    target=target_items,
                    item=item,
                    label=self.labels[1],
                ),
            )
        return pd.concat(rows, ignore_index=True)
460
+
461
+ def _plot_contribution_data(
462
+ self,
463
+ table: pd.DataFrame,
464
+ column: str,
465
+ levels: Sequence[str] | None,
466
+ ) -> pd.DataFrame:
467
+ data = table[[column, "contrib", "n1", "n2"]].copy()
468
+ is_factor = isinstance(data[column].dtype, pd.CategoricalDtype)
469
+ if is_factor:
470
+ data = data.sort_values(column, kind="stable")
471
+ else:
472
+ data = data.sort_values("contrib", kind="stable")
473
+
474
+ if levels is not None:
475
+ level_frame = pd.DataFrame({column: [str(level) for level in levels]})
476
+ data[column] = data[column].astype(str)
477
+ data = level_frame.merge(data, on=column, how="inner", sort=False)
478
+
479
+ return data.reset_index(drop=True)
480
+
481
+ def _plot_flip_contribution_data(
482
+ self,
483
+ table: pd.DataFrame,
484
+ column: str,
485
+ levels: Sequence[str] | None,
486
+ ) -> pd.DataFrame:
487
+ data = table[[column, "contrib", "n1", "n2"]].copy()
488
+ is_factor = isinstance(data[column].dtype, pd.CategoricalDtype)
489
+ if is_factor:
490
+ data = data.sort_values(column, ascending=False, kind="stable")
491
+ else:
492
+ data = data.sort_values("contrib", kind="stable")
493
+
494
+ if levels is not None:
495
+ level_frame = pd.DataFrame(
496
+ {column: [str(level) for level in reversed(levels)]},
497
+ )
498
+ data[column] = data[column].astype(str)
499
+ data = level_frame.merge(data, on=column, how="inner", sort=False)
500
+
501
+ return data.reset_index(drop=True)
502
+
503
+ def _plot_size_data(
504
+ self,
505
+ table: pd.DataFrame,
506
+ column: str,
507
+ main_item: str | None,
508
+ bar_max_value: float | None,
509
+ ) -> pd.DataFrame:
510
+ size_data = pd.concat(
511
+ [
512
+ pd.DataFrame(
513
+ {
514
+ "items": table[column].astype(str),
515
+ "n": table["n1"].astype(float),
516
+ "type": self.labels[0],
517
+ },
518
+ ),
519
+ pd.DataFrame(
520
+ {
521
+ "items": table[column].astype(str),
522
+ "n": table["n2"].astype(float),
523
+ "type": self.labels[1],
524
+ },
525
+ ),
526
+ ],
527
+ ignore_index=True,
528
+ )
529
+ if table.empty:
530
+ size_data["scaled_n"] = 0.0
531
+ return size_data
532
+
533
+ max_amount, n_max = self._size_scale_reference(
534
+ table=table,
535
+ column=column,
536
+ main_item=main_item,
537
+ bar_max_value=bar_max_value,
538
+ )
539
+ if np.isclose(n_max, 0.0) or np.isclose(max_amount, 0.0):
540
+ size_data["scaled_n"] = 0.0
541
+ else:
542
+ size_data["scaled_n"] = size_data["n"] / n_max * max_amount
543
+ return size_data
544
+
545
+ def _size_scale_reference(
546
+ self,
547
+ table: pd.DataFrame,
548
+ column: str,
549
+ main_item: str | None,
550
+ bar_max_value: float | None,
551
+ ) -> tuple[float, float]:
552
+ data = table.copy()
553
+ data["_max_n"] = data[["n1", "n2"]].max(axis=1).astype(float)
554
+ data["_abs_contrib"] = (data["contrib"].astype(float) * 100).abs()
555
+
556
+ if main_item is not None:
557
+ row = data[data[column].astype(str) == str(main_item)]
558
+ if row.empty:
559
+ msg = f"main_item {main_item!r} is not present in the plot data."
560
+ raise ValueError(msg)
561
+ return float(row["_abs_contrib"].iloc[0]), float(row["_max_n"].iloc[0])
562
+
563
+ if bar_max_value is not None:
564
+ return abs(float(bar_max_value)), float(data["_max_n"].max())
565
+
566
+ row = data.sort_values("_abs_contrib", ascending=False).head(1)
567
+ return float(row["_abs_contrib"].iloc[0]), float(row["_max_n"].iloc[0])
568
+
569
+ def _waterfall_data(
570
+ self,
571
+ items: Sequence[str],
572
+ contributions: Sequence[float],
573
+ start: float,
574
+ start_label: str | None = None,
575
+ end_label: str | None = None,
576
+ ) -> pd.DataFrame:
577
+ rows: list[dict[str, Any]] = []
578
+ cumulative = start * 100
579
+ start_label = self.labels[0] if start_label is None else start_label
580
+ end_label = self.labels[1] if end_label is None else end_label
581
+ rows.append(
582
+ {
583
+ "items": start_label,
584
+ "bottom": 0.0,
585
+ "height": cumulative,
586
+ "amount": cumulative,
587
+ "cumulative": cumulative,
588
+ "kind": "total",
589
+ },
590
+ )
591
+
592
+ for item, contribution in zip(items, contributions, strict=True):
593
+ amount = contribution * 100
594
+ bottom = cumulative if amount >= 0 else cumulative + amount
595
+ cumulative += amount
596
+ rows.append(
597
+ {
598
+ "items": item,
599
+ "bottom": bottom,
600
+ "height": abs(amount),
601
+ "amount": amount,
602
+ "cumulative": cumulative,
603
+ "kind": "contribution",
604
+ },
605
+ )
606
+
607
+ rows.append(
608
+ {
609
+ "items": end_label,
610
+ "bottom": 0.0,
611
+ "height": cumulative,
612
+ "amount": cumulative,
613
+ "cumulative": cumulative,
614
+ "kind": "total",
615
+ },
616
+ )
617
+ return pd.DataFrame(rows)
618
+
619
    def _draw_plot(
        self,
        waterfall: pd.DataFrame,
        size_data: pd.DataFrame,
        column: str,
        ax: Any | None,
        figsize: tuple[float, float] | None,
    ) -> Any:
        """Render the vertical Theseus plot onto *ax* (or a fresh figure).

        Draw order: translucent group-size bars behind, waterfall bars in
        front, connector lines between them.  Returns ``(fig, ax)``.
        """
        plt = self._load_pyplot()
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize or self._default_figsize(waterfall))
        else:
            fig = ax.figure

        positions = cast(
            NDArray[np.int64],
            np.arange(len(waterfall), dtype=np.int64),
        )
        # Size bars are looked up by item label rather than positional index.
        item_to_position = {
            str(item): index for index, item in enumerate(waterfall["items"])
        }

        self._draw_size_bars(ax=ax, size_data=size_data, positions=item_to_position)
        self._draw_waterfall_bars(ax=ax, waterfall=waterfall, positions=positions)
        self._draw_connectors(ax=ax, waterfall=waterfall, positions=positions)

        # Baseline at zero so hanging (negative) bars read correctly.
        ax.axhline(0, color="#333333", linewidth=0.8)
        ax.set_xticks(positions)
        ax.set_xticklabels(waterfall["items"].astype(str), rotation=45, ha="right")
        ax.set_ylabel(self.y_label or "")
        ax.set_xlabel("")
        ax.set_title(column)
        ax.margins(x=0.02)
        fig.tight_layout()
        return fig, ax
654
+
655
    def _draw_plot_flip(
        self,
        waterfall: pd.DataFrame,
        size_data: pd.DataFrame,
        column: str,
        ax: Any | None,
        figsize: tuple[float, float] | None,
    ) -> Any:
        """Render the horizontal Theseus plot onto *ax* (or a fresh figure).

        Mirror image of ``_draw_plot``: bars run along x, items along y, and
        the y-axis is inverted so the first waterfall row appears on top.
        Returns ``(fig, ax)``.
        """
        plt = self._load_pyplot()
        if ax is None:
            fig, ax = plt.subplots(
                figsize=figsize or self._default_flip_figsize(waterfall),
            )
        else:
            fig = ax.figure

        positions = cast(
            NDArray[np.int64],
            np.arange(len(waterfall), dtype=np.int64),
        )
        # Size bars are looked up by item label rather than positional index.
        item_to_position = {
            str(item): index for index, item in enumerate(waterfall["items"])
        }

        self._draw_size_bars_horizontal(
            ax=ax,
            size_data=size_data,
            positions=item_to_position,
        )
        self._draw_waterfall_bars_horizontal(
            ax=ax,
            waterfall=waterfall,
            positions=positions,
        )
        self._draw_connectors_horizontal(
            ax=ax,
            waterfall=waterfall,
            positions=positions,
        )

        # Baseline at zero so hanging (negative) bars read correctly.
        ax.axvline(0, color="#333333", linewidth=0.8)
        ax.set_yticks(positions)
        ax.set_yticklabels(waterfall["items"].astype(str))
        ax.invert_yaxis()
        ax.set_xlabel(self.y_label or "")
        ax.set_ylabel("")
        ax.set_title(column)
        ax.margins(y=0.02)
        fig.tight_layout()
        return fig, ax
705
+
706
+ @staticmethod
707
+ def _load_pyplot() -> Any:
708
+ try:
709
+ return cast(Any, importlib.import_module("matplotlib.pyplot"))
710
+ except ModuleNotFoundError as exc:
711
+ msg = "matplotlib is required to use plotting methods."
712
+ raise ModuleNotFoundError(msg) from exc
713
+
714
    def _draw_size_bars(
        self,
        ax: Any,
        size_data: pd.DataFrame,
        positions: dict[str, int],
    ) -> None:
        """Draw translucent per-group size bars behind the waterfall bars.

        Each data set's bar is offset to its own side of the item position and
        colored with the group-specific palette color.  Rows whose item is not
        in *positions* (e.g. absent from the waterfall) are skipped.
        """
        width = 0.22
        # Offset each data set to its own side of the item tick.
        offsets = {self.labels[0]: -width / 1.5, self.labels[1]: width / 1.5}
        colors = {
            self.labels[0]: _ORIGINAL_SIZE_COLOR,
            self.labels[1]: _REFITTED_SIZE_COLOR,
        }
        for _, row in size_data.iterrows():
            item = str(row["items"])
            if item not in positions:
                continue
            group = str(row["type"])
            ax.bar(
                positions[item] + offsets[group],
                row["scaled_n"],
                width=width,
                color=colors[group],
                alpha=0.35,
                linewidth=0,
                zorder=1,  # behind connectors (2) and waterfall bars (3)
            )
740
+
741
    def _draw_size_bars_horizontal(
        self,
        ax: Any,
        size_data: pd.DataFrame,
        positions: dict[str, int],
    ) -> None:
        """Horizontal counterpart of ``_draw_size_bars`` (uses ``barh``).

        Each data set's bar is offset to its own side of the item position and
        colored with the group-specific palette color.  Rows whose item is not
        in *positions* are skipped.
        """
        height = 0.22
        # Offset each data set to its own side of the item tick.
        offsets = {self.labels[0]: -height / 1.5, self.labels[1]: height / 1.5}
        colors = {
            self.labels[0]: _ORIGINAL_SIZE_COLOR,
            self.labels[1]: _REFITTED_SIZE_COLOR,
        }
        for _, row in size_data.iterrows():
            item = str(row["items"])
            if item not in positions:
                continue
            group = str(row["type"])
            ax.barh(
                positions[item] + offsets[group],
                row["scaled_n"],
                height=height,
                color=colors[group],
                alpha=0.35,
                linewidth=0,
                zorder=1,  # behind connectors (2) and waterfall bars (3)
            )
767
+
768
    def _draw_waterfall_bars(
        self,
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        """Draw the waterfall bars and annotate each with its amount.

        Total bars use the refitted color; contribution bars are colored by
        the sign of their amount.  Labels sit above positive bars and below
        negative ones.
        """
        colors = [
            _REFITTED_COLOR
            if row["kind"] == "total"
            else (_POSITIVE_COLOR if row["amount"] >= 0 else _NEGATIVE_COLOR)
            for _, row in waterfall.iterrows()
        ]
        ax.bar(
            positions,
            waterfall["height"],
            bottom=waterfall["bottom"],
            width=0.62,
            color=colors,
            edgecolor="#333333",
            linewidth=0.6,
            zorder=3,
        )
        for position, (_, row) in zip(positions, waterfall.iterrows(), strict=True):
            value = self._format_plot_value(float(row["amount"]))
            # Default placement: label just above the bar's top edge.
            y = float(row["bottom"]) + float(row["height"])
            va = "bottom"
            if float(row["amount"]) < 0:
                # Negative bars hang downward: label below the bottom edge.
                y = float(row["bottom"])
                va = "top"
            ax.text(
                position,
                y,
                value,
                ha="center",
                va=va,
                fontsize=9 * self.text_size,
                zorder=4,
            )
806
+
807
    def _draw_waterfall_bars_horizontal(
        self,
        ax: Any,
        waterfall: pd.DataFrame,
        positions: NDArray[np.int64],
    ) -> None:
        """Horizontal counterpart of ``_draw_waterfall_bars`` (uses ``barh``).

        Total bars use the refitted color; contribution bars are colored by
        the sign of their amount.  Labels sit right of positive bars and left
        of negative ones.
        """
        colors = [
            _REFITTED_COLOR
            if row["kind"] == "total"
            else (_POSITIVE_COLOR if row["amount"] >= 0 else _NEGATIVE_COLOR)
            for _, row in waterfall.iterrows()
        ]
        ax.barh(
            positions,
            waterfall["height"],
            left=waterfall["bottom"],
            height=0.62,
            color=colors,
            edgecolor="#333333",
            linewidth=0.6,
            zorder=3,
        )
        for position, (_, row) in zip(positions, waterfall.iterrows(), strict=True):
            value = self._format_plot_value(float(row["amount"]))
            # Default placement: label just past the bar's far edge.
            x = float(row["bottom"]) + float(row["height"])
            ha = "left"
            if float(row["amount"]) < 0:
                # Negative bars extend leftward: label before the near edge.
                x = float(row["bottom"])
                ha = "right"
            ax.text(
                x,
                position,
                value,
                ha=ha,
                va="center",
                fontsize=9 * self.text_size,
                zorder=4,
            )
845
+
846
+ @staticmethod
847
+ def _draw_connectors(
848
+ ax: Any,
849
+ waterfall: pd.DataFrame,
850
+ positions: NDArray[np.int64],
851
+ ) -> None:
852
+ for index in range(len(waterfall) - 2):
853
+ y = float(waterfall.loc[index, "cumulative"])
854
+ ax.plot(
855
+ [positions[index] + 0.31, positions[index + 1] - 0.31],
856
+ [y, y],
857
+ color="#666666",
858
+ linewidth=0.8,
859
+ zorder=2,
860
+ )
861
+
862
+ @staticmethod
863
+ def _draw_connectors_horizontal(
864
+ ax: Any,
865
+ waterfall: pd.DataFrame,
866
+ positions: NDArray[np.int64],
867
+ ) -> None:
868
+ for index in range(len(waterfall) - 2):
869
+ x = float(waterfall.loc[index, "cumulative"])
870
+ ax.plot(
871
+ [x, x],
872
+ [positions[index] + 0.31, positions[index + 1] - 0.31],
873
+ color="#666666",
874
+ linewidth=0.8,
875
+ zorder=2,
876
+ )
877
+
878
+ @staticmethod
879
+ def _default_figsize(waterfall: pd.DataFrame) -> tuple[float, float]:
880
+ return max(6.0, len(waterfall) * 0.75), 4.5
881
+
882
+ @staticmethod
883
+ def _default_flip_figsize(waterfall: pd.DataFrame) -> tuple[float, float]:
884
+ return 6.5, max(4.5, len(waterfall) * 0.45)
885
+
886
+ def _format_plot_value(self, value: float) -> str:
887
+ rounded = round(value, self.digits)
888
+ return f"{rounded:g}"
889
+
890
+ @staticmethod
891
+ def _count_target_items(
892
+ data: pd.DataFrame,
893
+ column: str,
894
+ target: Sequence[str],
895
+ label: str,
896
+ ) -> pd.DataFrame:
897
+ counts = (
898
+ data[data[column].astype(str).isin(target)]
899
+ .groupby(column, observed=True, sort=False)
900
+ .size()
901
+ .reset_index(name="n")
902
+ .rename(columns={column: "items"})
903
+ )
904
+ counts["items"] = counts["items"].astype(str)
905
+ counts["type"] = label
906
+ return counts[["items", "n", "type"]]
907
+
908
+ @staticmethod
909
+ def _count_other_items(
910
+ data: pd.DataFrame,
911
+ column: str,
912
+ target: Sequence[str],
913
+ item: str,
914
+ label: str,
915
+ ) -> pd.DataFrame:
916
+ count = int((~data[column].astype(str).isin(target)).sum())
917
+ return pd.DataFrame([{"items": item, "n": count, "type": label}])
918
+
919
+ def _validate_column(self, column: str) -> None:
920
+ if column not in self._data1.columns:
921
+ msg = f"column {column!r} is missing from data1."
922
+ raise ValueError(msg)
923
+ if column not in self._data2.columns:
924
+ msg = f"column {column!r} is missing from data2."
925
+ raise ValueError(msg)
926
+
927
+ @staticmethod
928
+ def _normalize_n(n: int | float) -> int | float:
929
+ if n == float("inf"):
930
+ return n
931
+ limit = int(n)
932
+ if limit < 1:
933
+ msg = "n must be at least 1 or infinity."
934
+ raise ValueError(msg)
935
+ return limit
936
+
937
+ @staticmethod
938
+ def _sort_by_abs_contrib(data: pd.DataFrame) -> pd.DataFrame:
939
+ return (
940
+ data.assign(_abs_contrib=data["contrib"].abs())
941
+ .sort_values("_abs_contrib", ascending=False, kind="stable")
942
+ .drop(columns="_abs_contrib")
943
+ .reset_index(drop=True)
944
+ )
945
+
946
+ @staticmethod
947
+ def _safe_rate(x: float, n: int) -> float:
948
+ return float(x) / n if n else float("nan")
949
+
950
+ def _data_for_column(
951
+ self,
952
+ column: str,
953
+ continuous: ContinuousConfig,
954
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
955
+ if is_numeric_dtype(self._data1[column]):
956
+ return self._to_factor(column, continuous)
957
+ return self._data1, self._data2
958
+
959
+ @staticmethod
960
+ def _summarize_by_column(data: pd.DataFrame, column: str) -> pd.DataFrame:
961
+ grouped = (
962
+ data.groupby(column, observed=True, sort=False)[_OUTCOME_COLUMN]
963
+ .agg(y="sum", n="size")
964
+ .reset_index()
965
+ .rename(columns={column: "items"})
966
+ )
967
+ grouped["rate"] = grouped["y"] / grouped["n"]
968
+ return grouped
969
+
970
+ @staticmethod
971
+ def _summarize_info(
972
+ data: pd.DataFrame,
973
+ column: str,
974
+ suffix: str,
975
+ ) -> pd.DataFrame:
976
+ grouped = (
977
+ data.groupby(column, observed=True, sort=False)[_OUTCOME_COLUMN]
978
+ .agg(**{f"x{suffix}": "sum", f"n{suffix}": "size"})
979
+ .reset_index()
980
+ .rename(columns={column: "items"})
981
+ )
982
+ grouped[f"rate{suffix}"] = grouped[f"x{suffix}"] / grouped[f"n{suffix}"]
983
+ return grouped
984
+
985
+ @staticmethod
986
+ def _replace_or_append_group(
987
+ base: pd.DataFrame,
988
+ replacement: pd.DataFrame,
989
+ item: Any,
990
+ ) -> pd.DataFrame:
991
+ result = base.copy()
992
+ replacement_row = replacement[replacement["items"] == item]
993
+ mask = result["items"] == item
994
+ if mask.any():
995
+ result.loc[mask, ["y", "n", "rate"]] = replacement_row[
996
+ ["y", "n", "rate"]
997
+ ].to_numpy()
998
+ return result
999
+ return pd.concat([result, replacement_row], ignore_index=True)
1000
+
1001
+ @staticmethod
1002
+ def _score_from_summary(data: pd.DataFrame) -> float:
1003
+ return float(data["y"].sum() / data["n"].sum())
1004
+
1005
+ @staticmethod
1006
+ def _combined_categorical_dtype(
1007
+ values1: pd.Series,
1008
+ values2: pd.Series,
1009
+ ) -> pd.CategoricalDtype | None:
1010
+ dtype1 = values1.dtype
1011
+ dtype2 = values2.dtype
1012
+ if not isinstance(dtype1, pd.CategoricalDtype):
1013
+ return None
1014
+
1015
+ categories = list(dtype1.categories)
1016
+ if isinstance(dtype2, pd.CategoricalDtype):
1017
+ categories.extend(
1018
+ category
1019
+ for category in dtype2.categories
1020
+ if category not in categories
1021
+ )
1022
+ else:
1023
+ categories.extend(value for value in values2 if value not in categories)
1024
+ return pd.CategoricalDtype(categories=categories, ordered=dtype1.ordered)
1025
+
1026
    def _continuous_breaks(
        self,
        column: str,
        continuous: ContinuousConfig,
    ) -> NDArray[np.float64]:
        """Derive the bin edges used to discretize a numeric *column*.

        Explicit ``continuous.breaks`` win outright.  Otherwise edges come
        from the pooled values of both data sets according to
        ``continuous.split``: equal-width (``"width"``), equal-count
        (``"count"``), or a rate-driven merge strategy (any other value).
        When values contain missing data one bin is reserved for them, so the
        numeric bin count shrinks by one.
        """
        if continuous.breaks is not None:
            return self._validate_breaks(
                np.asarray(continuous.breaks, dtype=float),
            )

        values = pd.concat(
            [self._data1[column], self._data2[column]],
            ignore_index=True,
        ).astype(float)
        non_missing = values.dropna()
        if non_missing.empty:
            msg = f"column {column!r} must contain at least one non-missing value."
            raise ValueError(msg)

        break_num = continuous.n

        if continuous.split == "width":
            # Missing values consume one of the requested bins.
            if values.isna().any():
                break_num -= 1
            self._validate_break_num(break_num, continuous)
            breaks = np.linspace(
                non_missing.min(),
                non_missing.max(),
                break_num + 1,
            )
        elif continuous.split == "count":
            self._validate_break_num(
                break_num - int(values.isna().any()),
                continuous,
            )
            # _compute_breaks (defined elsewhere in this module) is presumably
            # a quantile-style edge finder.
            breaks = self._compute_breaks(values, break_num)
        else:
            self._validate_break_num(
                break_num - int(values.isna().any()),
                continuous,
            )
            # Rate-driven split: over-segment (x20), then repeatedly merge the
            # pair of adjacent bins whose outcome rates differ the least.
            breaks = np.unique(self._compute_breaks(values, break_num * 20))
            data1 = self._data1[[column, _OUTCOME_COLUMN]].dropna(subset=[column])
            data2 = self._data2[[column, _OUTCOME_COLUMN]].dropna(subset=[column])
            while len(breaks) > break_num + 1:
                diff = self._adjacent_rate_diff(data1, data2, column, breaks)
                remove_at = int(np.nanargmin(diff)) + 1
                breaks = np.delete(breaks, remove_at)

        if continuous.pretty:
            # Rounding edges to "pretty" values can collapse neighbours.
            pretty = self._pretty_breaks(breaks)
            if len(np.unique(pretty)) < len(pretty):
                warnings.warn(
                    "Prettying breaks reduced the number of breaks. "
                    "Try pretty=False.",
                    stacklevel=2,
                )
                pretty = np.unique(pretty)
            breaks = pretty
        return self._validate_breaks(np.asarray(breaks, dtype=float))
1086
+
1087
+ @staticmethod
1088
+ def _validate_break_num(
1089
+ break_num: int,
1090
+ continuous: ContinuousConfig,
1091
+ ) -> None:
1092
+ if break_num < 1:
1093
+ msg = (
1094
+ "continuous.n must leave at least one numeric bin. "
1095
+ "Increase n or remove missing values."
1096
+ )
1097
+ raise ValueError(msg)
1098
+ if continuous.split == "rate" and break_num < 2:
1099
+ msg = "split='rate' requires at least two numeric bins."
1100
+ raise ValueError(msg)
1101
+
1102
+ @staticmethod
1103
+ def _validate_breaks(breaks: NDArray[np.float64]) -> NDArray[np.float64]:
1104
+ if len(breaks) < _MIN_BREAK_COUNT:
1105
+ msg = "continuous breaks must contain at least two values."
1106
+ raise ValueError(msg)
1107
+ if np.isnan(breaks).any():
1108
+ msg = "continuous breaks must not contain NaN."
1109
+ raise ValueError(msg)
1110
+ if not np.all(np.diff(breaks) > 0):
1111
+ msg = "continuous breaks must be strictly increasing."
1112
+ raise ValueError(msg)
1113
+ return breaks
1114
+
1115
+ @staticmethod
1116
+ def _compute_breaks(
1117
+ values: pd.Series,
1118
+ break_num: int,
1119
+ ) -> NDArray[np.float64]:
1120
+ if values.isna().any():
1121
+ break_num -= 1
1122
+ probs = np.linspace(0, 1, break_num + 1)
1123
+ return np.asarray(values.quantile(probs).to_numpy(), dtype=float)
1124
+
1125
+ @staticmethod
1126
+ def _pretty_breaks(
1127
+ breaks: NDArray[np.float64],
1128
+ ) -> NDArray[np.float64]:
1129
+ result = []
1130
+ for value in breaks:
1131
+ digits = 0 if value == 0 else np.floor(np.log10(abs(value))) + 1
1132
+ base = 10 ** (digits - 2)
1133
+ rounded = np.floor(value / base) if value < 0 else np.ceil(value / base)
1134
+ result.append(rounded * base)
1135
+ return np.asarray(result, dtype=float)
1136
+
1137
@staticmethod
def _adjacent_rate_diff(
    data1: pd.DataFrame,
    data2: pd.DataFrame,
    column: str,
    breaks: NDArray[np.float64],
) -> NDArray[np.float64]:
    """Combined outcome-rate gap between neighbouring bins.

    Each dataset is binned by *breaks*, the mean outcome per bin is
    computed, and absolute differences between adjacent bins are taken.
    The two gap series are merged as ``sqrt(d1**2 + d2**2)``, yielding
    one value per pair of neighbouring bins.
    """
    def rate_gaps(frame: pd.DataFrame, label: str) -> pd.Series:
        binned = pd.cut(frame[column], bins=breaks, include_lowest=True)
        per_bin_rate = frame.groupby(binned, observed=True, sort=False)[
            _OUTCOME_COLUMN
        ].mean()
        return per_bin_rate.diff(-1).abs().rename(label)

    # Align the two gap series on their shared interval index.
    paired = pd.concat(
        [rate_gaps(data1, "diff1"), rate_gaps(data2, "diff2")],
        axis=1,
    )
    combined = np.sqrt(paired["diff1"] ** 2 + paired["diff2"] ** 2)
    return cast(NDArray[np.float64], np.asarray(combined, dtype=np.float64))
1159
+
1160
@staticmethod
def _cut_to_categorical(
    series: pd.Series,
    breaks: NDArray[np.float64],
) -> pd.Series:
    """Bin *series* by *breaks*, mapping NaNs to the missing category."""
    interval_labels = ShipOfTheseus._cut_labels(breaks)
    binned = pd.cut(
        series,
        bins=breaks,
        include_lowest=True,
        labels=interval_labels,
    )
    # pd.cut leaves out-of-range and NaN values as NaN; fold them into
    # an explicit missing category instead.
    if _MISSING_LABEL not in binned.cat.categories:
        binned = binned.cat.add_categories([_MISSING_LABEL])
    return binned.fillna(_MISSING_LABEL)
1175
+
1176
@staticmethod
def _cut_labels(breaks: NDArray[np.float64]) -> list[str]:
    """Interval labels for consecutive edges: ``[a,b]`` then ``(b,c]``."""
    fmt = ShipOfTheseus._format_break
    edge_pairs = zip(breaks[:-1], breaks[1:], strict=True)
    return [
        # Only the first interval is closed on the left.
        f"{'[' if index == 0 else '('}{fmt(lower)},{fmt(upper)}]"
        for index, (lower, upper) in enumerate(edge_pairs)
    ]
1187
+
1188
+ @staticmethod
1189
+ def _format_break(value: float) -> str:
1190
+ if np.isclose(value, 0.0):
1191
+ return "0"
1192
+ if np.isclose(value, round(value)):
1193
+ return str(int(round(value)))
1194
+ return f"{value:.15g}"
1195
+
1196
+ @staticmethod
1197
+ def _raise_not_implemented() -> None:
1198
+ msg = (
1199
+ "TheseusPlot calculation and plotting logic has not been "
1200
+ "implemented yet."
1201
+ )
1202
+ raise NotImplementedError(msg)
1203
+
1204
+
1205
def create_ship(
    data1: pd.DataFrame,
    data2: pd.DataFrame,
    y: str = "y",
    labels: Sequence[str] = ("Original", "Refitted"),
    y_label: str | None = None,
    digits: int = 3,
    text_size: float = 1.0,
) -> ShipOfTheseus:
    """Create a :class:`ShipOfTheseus` from two datasets.

    Args:
        data1: First dataset containing the outcome column *y*.
        data2: Second dataset containing the outcome column *y*.
        y: Name of the outcome column (forwarded as ``outcome``).
        labels: Two display labels, one per dataset, in order.
        y_label: Optional label for the outcome axis.
        digits: Forwarded to the ship unchanged.
        text_size: Forwarded to the ship unchanged.

    Returns:
        The constructed :class:`ShipOfTheseus` instance.
    """
    ship = ShipOfTheseus(
        data1=data1,
        data2=data2,
        outcome=y,
        labels=labels,
        y_label=y_label,
        digits=digits,
        text_size=text_size,
    )
    return ship