shiftshap 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Mayowa Samuel Olokun
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.4
2
+ Name: shiftshap
3
+ Version: 0.2.0
4
+ Summary: Monitor whether your model's SHAP explanations still hold as data drifts.
5
+ Author: Mayowa Samuel Olokun
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/OWNER/shiftshap
8
+ Project-URL: Issues, https://github.com/OWNER/shiftshap/issues
9
+ Keywords: shap,explainability,xai,drift,distribution-shift,model-monitoring,machine-learning
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.9
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: numpy>=1.21
20
+ Requires-Dist: pandas>=1.3
21
+ Provides-Extra: plot
22
+ Requires-Dist: matplotlib>=3.4; extra == "plot"
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest>=7.0; extra == "dev"
25
+ Requires-Dist: matplotlib>=3.4; extra == "dev"
26
+ Dynamic: license-file
27
+
28
+ # shiftshap
29
+
30
+ **Monitor whether your model's SHAP explanations still hold as your data drifts.**
31
+
32
+ `shiftshap` answers a question every team running a model in production eventually
33
+ asks: *are my model's explanations still trustworthy?* Models live for months or
34
+ years — they get retrained, upstream pipelines change, and feature distributions
35
+ shift. When that happens, the model's reasoning quietly changes with it. The
36
+ feature that drove your predictions last quarter may not be the one driving them
37
+ today.
38
+
39
+ SHAP is excellent at explaining a model **at a single point in time**, but it has
40
+ no built-in way to tell you how those explanations have **changed** between two
41
+ points. Today people work around this by pickling explanation objects and writing
42
+ their own comparison scripts, or by using data-drift tools that know nothing about
43
+ SHAP's structure. `shiftshap` fills that gap.
44
+
45
+ ---
46
+
47
+ ## Install
48
+
49
+ ```bash
50
+ pip install shiftshap # core
51
+ pip install "shiftshap[plot]" # + matplotlib for the drift chart (quotes needed in zsh)
52
+ ```
53
+
54
+ ## Quickstart
55
+
56
+ If you already use SHAP, you already have everything you need. Take your SHAP
57
+ values from two periods — training vs. production, or last month vs. this month —
58
+ and pass them in:
59
+
60
+ ```python
61
+ import shiftshap
62
+
63
+ report = shiftshap.compare(reference_shap, current_shap)
64
+
65
+ print(report.summary())
66
+ # 2 of 5 feature(s) show HIGH explanation drift (0 medium).
67
+ # Top driver changed from 'income' to 'balance'.
68
+ # Overall rank stability (Spearman): 0.70.
69
+
70
+ print(report.details()) # plain-English narrative of the biggest movers
71
+ report.to_frame() # full per-feature table
72
+ report.plot() # rank-drift bump chart
73
+ shiftshap.metric_definitions() # what every metric means, in plain words
74
+ ```
75
+
76
+ Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
77
+ DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays)
78
+ is supported — by default classes are aggregated, or pass `class_index=k` to
79
+ focus on one class. The two periods don't need the same number of samples.
80
+
81
+ ### Robust by design
82
+
83
+ `shiftshap` is built to survive real, messy production data. It handles NaNs
84
+ (ignored with a note), zero-variance features, tiny samples (with an explicit
85
+ "results unreliable" warning rather than false alarms), and multi-class outputs —
86
+ and it fails with clear, actionable errors on genuinely broken input (infinities,
87
+ mismatched feature counts, empty arrays) instead of cryptic stack traces.
88
+
89
+ ## What it tells you
90
+
91
+ For every feature, `shiftshap` reports:
92
+
93
+ - **Importance drift** — how the mean absolute SHAP value changed between periods.
94
+ - **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
95
+ shift metric, with the accepted `0.2` threshold flagging significant drift.
96
+ - **Rank drift** — whether your most important features reordered, plus an overall
97
+ Spearman rank-stability score.
98
+ - **Severity** — a `high` / `medium` / `low` label per feature, so the output is
99
+ actionable at a glance.
100
+
101
+ And a **bump chart** showing how the feature-importance ranking shifted:
102
+
103
+ ![rank drift chart](examples/rank_drift.png)
104
+
105
+ ## Why it matters
106
+
107
+ An explanation that has silently drifted is worse than no explanation — it gives
108
+ false confidence. In regulated settings (finance, insurance, healthcare) teams are
109
+ increasingly required to show that model explanations remain valid over time.
110
+ `shiftshap` turns that check into two lines of code.
111
+
112
+ ## Roadmap
113
+
114
+ The current release (`v0.2`) deliberately does one thing well: compare two snapshots of SHAP explanations
115
+ for tabular models. Planned next:
116
+
117
+ - Persistent explanation store for many time-points (not just two).
118
+ - Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
119
+ - Alerting hooks for monitoring pipelines.
120
+ - Support for image and text explanations.
121
+
122
+ Contributions and issues welcome.
123
+
124
+ ## License
125
+
126
+ MIT
@@ -0,0 +1,99 @@
1
+ # shiftshap
2
+
3
+ **Monitor whether your model's SHAP explanations still hold as your data drifts.**
4
+
5
+ `shiftshap` answers a question every team running a model in production eventually
6
+ asks: *are my model's explanations still trustworthy?* Models live for months or
7
+ years — they get retrained, upstream pipelines change, and feature distributions
8
+ shift. When that happens, the model's reasoning quietly changes with it. The
9
+ feature that drove your predictions last quarter may not be the one driving them
10
+ today.
11
+
12
+ SHAP is excellent at explaining a model **at a single point in time**, but it has
13
+ no built-in way to tell you how those explanations have **changed** between two
14
+ points. Today people work around this by pickling explanation objects and writing
15
+ their own comparison scripts, or by using data-drift tools that know nothing about
16
+ SHAP's structure. `shiftshap` fills that gap.
17
+
18
+ ---
19
+
20
+ ## Install
21
+
22
+ ```bash
23
+ pip install shiftshap # core
24
+ pip install "shiftshap[plot]" # + matplotlib for the drift chart (quotes needed in zsh)
25
+ ```
26
+
27
+ ## Quickstart
28
+
29
+ If you already use SHAP, you already have everything you need. Take your SHAP
30
+ values from two periods — training vs. production, or last month vs. this month —
31
+ and pass them in:
32
+
33
+ ```python
34
+ import shiftshap
35
+
36
+ report = shiftshap.compare(reference_shap, current_shap)
37
+
38
+ print(report.summary())
39
+ # 2 of 5 feature(s) show HIGH explanation drift (0 medium).
40
+ # Top driver changed from 'income' to 'balance'.
41
+ # Overall rank stability (Spearman): 0.70.
42
+
43
+ print(report.details()) # plain-English narrative of the biggest movers
44
+ report.to_frame() # full per-feature table
45
+ report.plot() # rank-drift bump chart
46
+ shiftshap.metric_definitions() # what every metric means, in plain words
47
+ ```
48
+
49
+ Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
50
+ DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays)
51
+ is supported — by default classes are aggregated, or pass `class_index=k` to
52
+ focus on one class. The two periods don't need the same number of samples.
53
+
54
+ ### Robust by design
55
+
56
+ `shiftshap` is built to survive real, messy production data. It handles NaNs
57
+ (ignored with a note), zero-variance features, tiny samples (with an explicit
58
+ "results unreliable" warning rather than false alarms), and multi-class outputs —
59
+ and it fails with clear, actionable errors on genuinely broken input (infinities,
60
+ mismatched feature counts, empty arrays) instead of cryptic stack traces.
61
+
62
+ ## What it tells you
63
+
64
+ For every feature, `shiftshap` reports:
65
+
66
+ - **Importance drift** — how the mean absolute SHAP value changed between periods.
67
+ - **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
68
+ shift metric, with the accepted `0.2` threshold flagging significant drift.
69
+ - **Rank drift** — whether your most important features reordered, plus an overall
70
+ Spearman rank-stability score.
71
+ - **Severity** — a `high` / `medium` / `low` label per feature, so the output is
72
+ actionable at a glance.
73
+
74
+ And a **bump chart** showing how the feature-importance ranking shifted:
75
+
76
+ ![rank drift chart](examples/rank_drift.png)
77
+
78
+ ## Why it matters
79
+
80
+ An explanation that has silently drifted is worse than no explanation — it gives
81
+ false confidence. In regulated settings (finance, insurance, healthcare) teams are
82
+ increasingly required to show that model explanations remain valid over time.
83
+ `shiftshap` turns that check into two lines of code.
84
+
85
+ ## Roadmap
86
+
87
+ The current release (`v0.2`) deliberately does one thing well: compare two snapshots of SHAP explanations
88
+ for tabular models. Planned next:
89
+
90
+ - Persistent explanation store for many time-points (not just two).
91
+ - Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
92
+ - Alerting hooks for monitoring pipelines.
93
+ - Support for image and text explanations.
94
+
95
+ Contributions and issues welcome.
96
+
97
+ ## License
98
+
99
+ MIT
@@ -0,0 +1,44 @@
1
+ [build-system]
2
+ requires = ["setuptools>=64", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "shiftshap"
7
+ version = "0.2.0"
8
+ description = "Monitor whether your model's SHAP explanations still hold as data drifts."
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "Mayowa Samuel Olokun" }]
13
+ keywords = [
14
+ "shap",
15
+ "explainability",
16
+ "xai",
17
+ "drift",
18
+ "distribution-shift",
19
+ "model-monitoring",
20
+ "machine-learning",
21
+ ]
22
+ classifiers = [
23
+ "Development Status :: 3 - Alpha",
24
+ "Intended Audience :: Science/Research",
25
+ "Intended Audience :: Developers",
26
+ "License :: OSI Approved :: MIT License",
27
+ "Programming Language :: Python :: 3",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ ]
30
+ dependencies = [
31
+ "numpy>=1.21",
32
+ "pandas>=1.3",
33
+ ]
34
+
35
+ [project.optional-dependencies]
36
+ plot = ["matplotlib>=3.4"]
37
+ dev = ["pytest>=7.0", "matplotlib>=3.4"]
38
+
39
+ [project.urls]
40
+ Homepage = "https://github.com/OWNER/shiftshap"
41
+ Issues = "https://github.com/OWNER/shiftshap/issues"
42
+
43
+ [tool.setuptools.packages.find]
44
+ where = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,32 @@
1
+ """shiftshap -- monitor whether your model's SHAP explanations still hold as data drifts.
2
+
3
+ Quickstart
4
+ ----------
5
+ >>> import shiftshap
6
+ >>> report = shiftshap.compare(reference_shap_values, current_shap_values)
7
+ >>> print(report.summary()) # one-line verdict
8
+ >>> print(report.details()) # plain-English narrative of the biggest movers
9
+ >>> report.to_frame() # full per-feature table
10
+ >>> report.plot() # rank-drift bump chart (needs matplotlib)
11
+ >>> shiftshap.metric_definitions() # what every column means
12
+ """
13
+
14
+ from .core import DriftReport, compare, metric_definitions
15
+ from .metrics import (
16
+ METRIC_DEFINITIONS,
17
+ mean_abs_importance,
18
+ population_stability_index,
19
+ spearman_rank_correlation,
20
+ )
21
+
22
# Names re-exported as the package's public API (what ``from shiftshap import *``
# exposes). Lower-level helpers stay importable from ``shiftshap.metrics``.
__all__ = [
    "compare",
    "DriftReport",
    "metric_definitions",
    "METRIC_DEFINITIONS",
    "population_stability_index",
    "mean_abs_importance",
    "spearman_rank_correlation",
]

# Package version string; keep in step with ``version`` in pyproject.toml.
__version__ = "0.2.0"
@@ -0,0 +1,261 @@
1
+ """The main entry point: ``shiftshap.compare``.
2
+
3
+ Give it SHAP explanations from two points in time (or two data batches) and it
4
+ tells you which features' importance has drifted, by how much, how severely,
5
+ and whether your most important features have reordered.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from typing import Optional, Sequence
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ from .metrics import (
17
+ METRIC_DEFINITIONS,
18
+ choose_bins,
19
+ importance_ranks,
20
+ mean_abs_importance,
21
+ population_stability_index,
22
+ severity_label,
23
+ spearman_rank_correlation,
24
+ to_shap_array,
25
+ )
26
+
27
+
28
def metric_definitions() -> dict:
    """Plain-language definitions of every metric shiftshap reports.

    A fresh shallow copy of ``METRIC_DEFINITIONS`` is returned, so callers
    may edit the result without touching the package-level mapping.
    """
    definitions_copy = {**METRIC_DEFINITIONS}
    return definitions_copy
31
+
32
+
33
+ def _resolve_feature_names(feature_names, reference_shap, current_shap, n_features):
34
+ if feature_names is not None:
35
+ names = list(feature_names)
36
+ else:
37
+ names = (
38
+ getattr(reference_shap, "feature_names", None)
39
+ or getattr(current_shap, "feature_names", None)
40
+ )
41
+ if names is None:
42
+ names = [f"feature_{i}" for i in range(n_features)]
43
+ names = [str(n) for n in names]
44
+ if len(names) != n_features:
45
+ raise ValueError(
46
+ f"Got {len(names)} feature names but SHAP values have "
47
+ f"{n_features} features."
48
+ )
49
+ return names
50
+
51
+
52
class DriftReport:
    """Result of a ``compare`` call. Print it, tabulate it, or plot it.

    Parameters
    ----------
    frame
        Per-feature drift table. The methods below read these columns:
        ``feature``, ``reference_importance``, ``current_importance``,
        ``importance_change``, ``psi``, ``reference_rank``, ``current_rank``,
        ``rank_change`` and ``severity``.
    rank_correlation
        Spearman correlation of the importance ranking between the periods.
    psi_thresholds
        ``(medium, high)`` PSI cut-offs, echoed in :meth:`details`.
    notes
        Optional data-quality notes, listed at the end of :meth:`details`.
    """

    def __init__(self, frame: pd.DataFrame, rank_correlation: float,
                 psi_thresholds: tuple[float, float], notes: Optional[list] = None):
        # Stable sort: features with equal PSI keep their input order, so the
        # table and details() are deterministic.
        self._frame = (
            frame.sort_values("psi", ascending=False, kind="mergesort")
            .reset_index(drop=True)
        )
        self.rank_correlation = rank_correlation
        self.psi_thresholds = psi_thresholds
        # Copy so later mutation of the caller's list does not leak in.
        self.notes = list(notes) if notes else []

    def to_frame(self) -> pd.DataFrame:
        """Full per-feature drift table, sorted biggest-mover (highest PSI) first."""
        return self._frame.copy()

    @property
    def n_features(self) -> int:
        """Number of features in the report."""
        return len(self._frame)

    @property
    def n_high(self) -> int:
        """Number of features labelled ``"high"`` severity."""
        return int((self._frame["severity"] == "high").sum())

    @property
    def n_medium(self) -> int:
        """Number of features labelled ``"medium"`` severity."""
        return int((self._frame["severity"] == "medium").sum())

    def summary(self) -> str:
        """A short, human-readable summary (one to three sentences on one line)."""
        total = self.n_features
        if total == 0:
            # Nothing to rank; avoid IndexError from .iloc[0] below.
            return "No features to compare."
        lines = [
            f"{self.n_high} of {total} feature(s) show HIGH explanation drift "
            f"({self.n_medium} medium)."
        ]
        ref_top = self._frame.sort_values("reference_rank").iloc[0]
        cur_top = self._frame.sort_values("current_rank").iloc[0]
        if ref_top["feature"] != cur_top["feature"]:
            lines.append(
                f"Top driver changed from '{ref_top['feature']}' to "
                f"'{cur_top['feature']}'."
            )
        else:
            lines.append(f"Top driver unchanged ('{cur_top['feature']}').")
        lines.append(f"Overall rank stability (Spearman): {self.rank_correlation:.2f}.")
        return " ".join(lines)

    def details(self, top_n: int = 5) -> str:
        """A plain-English narrative of the ``top_n`` biggest movers, plus notes."""
        low, high = self.psi_thresholds
        out = [self.summary(), ""]
        out.append(f"PSI thresholds: medium >= {low}, high >= {high}.")
        out.append(f"Top {min(top_n, self.n_features)} features by drift:")
        for _, r in self._frame.head(top_n).iterrows():
            change = r["importance_change"]
            # A zero change is neither a gain nor a loss.
            if change > 0:
                trend = "gained influence"
            elif change < 0:
                trend = "lost influence"
            else:
                trend = "influence unchanged"
            move = ""
            # rank_change = reference_rank - current_rank; positive = climbed.
            if r["rank_change"] > 0:
                move = f", climbed {int(r['rank_change'])} rank(s)"
            elif r["rank_change"] < 0:
                move = f", fell {int(-r['rank_change'])} rank(s)"
            out.append(
                f" - {r['feature']}: PSI={r['psi']:.3f} ({r['severity']}); "
                f"{trend} "
                f"({r['reference_importance']:.3f} -> {r['current_importance']:.3f}){move}."
            )
        if self.notes:
            out.append("")
            out.append("Notes:")
            out.extend(f" - {n}" for n in self.notes)
        return "\n".join(out)

    def __repr__(self) -> str:
        return f"<DriftReport: {self.summary()}>"

    def plot(self, ax=None, top_n: Optional[int] = None):
        """Bump chart of feature-importance rank: reference vs current.

        Needs matplotlib (optional). ``top_n`` restricts the chart to the
        ``top_n`` features by reference rank. Returns the Axes.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "Plotting needs matplotlib. Install with: pip install shiftshap[plot]"
            ) from exc

        frame = self._frame.copy()
        if top_n is not None:
            keep = frame.nsmallest(top_n, "reference_rank")["feature"]
            frame = frame[frame["feature"].isin(keep)]

        if ax is None:
            _, ax = plt.subplots(figsize=(6, max(3, 0.4 * len(frame) + 1)))

        colours = {"high": "#d1495b", "medium": "#edae49", "low": "#9aa5b1"}
        for _, row in frame.iterrows():
            ax.plot(
                [0, 1], [row["reference_rank"], row["current_rank"]],
                marker="o", linewidth=2.2,
                color=colours.get(row["severity"], "#9aa5b1"),
            )
            ax.annotate(f" {row['feature']}", (1, row["current_rank"]),
                        va="center", fontsize=9, color="#2b2b2b")

        ax.set_xticks([0, 1])
        ax.set_xticklabels(["reference", "current"])
        ax.set_ylabel("importance rank (1 = most important)")
        ax.invert_yaxis()  # rank 1 at the top
        ax.set_xlim(-0.15, 1.55)  # leave room on the right for labels
        ax.set_title("SHAP importance rank drift")
        ax.grid(axis="y", linestyle=":", alpha=0.4)
        return ax
162
+
163
+
164
def compare(
    reference_shap,
    current_shap,
    feature_names: Optional[Sequence[str]] = None,
    bins: int = 10,
    psi_thresholds: tuple[float, float] = (0.1, 0.2),
    class_index: Optional[int] = None,
) -> DriftReport:
    """Measure how SHAP explanations drifted between two periods.

    Parameters
    ----------
    reference_shap, current_shap
        SHAP values for the baseline and for the period under test. Each may be
        a ``shap.Explanation``, NumPy array, list, or DataFrame of shape
        ``(n_samples, n_features)``. 3D multi-class input is accepted (see
        ``class_index``). Sample counts may differ between the two, but the
        number of features must be equal.
    feature_names
        Optional explicit names. Otherwise they are resolved from the inputs,
        falling back to ``feature_0 .. feature_{k-1}``.
    bins
        Requested quantile-bin count for PSI. It is reduced automatically when
        the reference sample is small.
    psi_thresholds
        ``(medium, high)`` PSI cut-offs used to assign severities.
    class_index
        For multi-class SHAP, the single class to analyse. ``None`` aggregates
        across classes.

    Returns
    -------
    DriftReport
        The per-feature table, the overall rank correlation, and any
        data-quality notes.

    Raises
    ------
    ValueError
        For malformed ``psi_thresholds``, mismatched feature counts, or
        invalid SHAP input.
    """
    thresholds_ok = (
        isinstance(psi_thresholds, (tuple, list))
        and len(psi_thresholds) == 2
        and psi_thresholds[0] <= psi_thresholds[1]
    )
    if not thresholds_ok:
        raise ValueError("psi_thresholds must be (medium, high) with medium <= high.")

    ref_arr = to_shap_array(reference_shap, class_index=class_index)
    cur_arr = to_shap_array(current_shap, class_index=class_index)

    # Both arrays are 2D (n_samples, n_features) after normalisation.
    n_ref_samples, n_features = ref_arr.shape
    n_cur_samples, n_cur_features = cur_arr.shape
    if n_features != n_cur_features:
        raise ValueError(
            f"Feature count mismatch: reference has {n_features} features, "
            f"current has {n_cur_features}."
        )

    names = _resolve_feature_names(feature_names, reference_shap, current_shap, n_features)

    notes = []

    # Rows containing at least one NaN, counted across both periods.
    rows_with_nan = int(
        np.isnan(ref_arr).any(axis=1).sum() + np.isnan(cur_arr).any(axis=1).sum()
    )
    if rows_with_nan:
        nan_msg = (
            f"{rows_with_nan} sample(s) contained NaN SHAP values and were ignored "
            "per-feature where needed."
        )
        notes.append(nan_msg)
        warnings.warn(nan_msg, stacklevel=2)

    effective_bins = choose_bins(n_ref_samples, bins)
    if effective_bins < bins:
        notes.append(
            f"Reference has {n_ref_samples} samples; PSI bins reduced from "
            f"{bins} to {effective_bins} to stay meaningful."
        )

    smallest = min(n_ref_samples, n_cur_samples)
    if smallest < 30:
        small_msg = (
            f"Only {smallest} sample(s) in the smaller period. PSI/drift results "
            f"are unreliable below ~30 samples and may show false alarms -- "
            f"treat severities with caution."
        )
        notes.append(small_msg)
        warnings.warn(small_msg, stacklevel=2)

    ref_importance = mean_abs_importance(ref_arr)
    cur_importance = mean_abs_importance(cur_arr)
    ref_ranks = importance_ranks(ref_importance)
    cur_ranks = importance_ranks(cur_importance)

    # One PSI per feature column (iterate columns via the transposes).
    psi_values = np.array([
        population_stability_index(ref_col, cur_col, bins=bins)
        for ref_col, cur_col in zip(ref_arr.T, cur_arr.T)
    ])
    severities = [severity_label(value, psi_thresholds) for value in psi_values]

    table = pd.DataFrame({
        "feature": names,
        "reference_importance": ref_importance,
        "current_importance": cur_importance,
        "importance_change": cur_importance - ref_importance,
        "psi": psi_values,
        "reference_rank": ref_ranks,
        "current_rank": cur_ranks,
        "rank_change": ref_ranks - cur_ranks,
        "severity": severities,
    })

    return DriftReport(
        table,
        rank_correlation=spearman_rank_correlation(ref_ranks, cur_ranks),
        psi_thresholds=tuple(psi_thresholds),
        notes=notes,
    )
@@ -0,0 +1,201 @@
1
+ """Core drift metrics for SHAP explanations.
2
+
3
+ Dependency-light (NumPy only). Every function does one small, well-defined,
4
+ testable job. All metrics are documented in ``METRIC_DEFINITIONS`` so the
5
+ output is never a mystery number to the user.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ import numpy as np
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Plain-language definitions of every metric shiftshap reports.
16
+ # Exposed to users via shiftshap.metric_definitions().
17
+ # ---------------------------------------------------------------------------
18
METRIC_DEFINITIONS = dict(
    # Global importance in each period, and its change.
    reference_importance=(
        "Global importance of the feature in the REFERENCE period: the mean "
        "absolute SHAP value across all reference samples. Higher = the feature "
        "influenced predictions more."
    ),
    current_importance=(
        "Global importance of the feature in the CURRENT period (mean absolute "
        "SHAP value across all current samples)."
    ),
    importance_change=(
        "current_importance minus reference_importance. Negative = the feature "
        "lost influence; positive = it gained influence."
    ),
    # Distributional shift of the SHAP values themselves.
    psi=(
        "Population Stability Index of the feature's SHAP-value distribution "
        "between the two periods. 0 = no shift. Rule of thumb: <0.1 no "
        "significant shift, 0.1-0.2 moderate, >=0.2 significant."
    ),
    # Ranking view.
    reference_rank="Importance rank in the reference period (1 = most important).",
    current_rank="Importance rank in the current period (1 = most important).",
    rank_change=(
        "reference_rank minus current_rank. Positive = the feature climbed in "
        "importance; negative = it fell."
    ),
    # Summary labels and whole-report statistics.
    severity=(
        "Overall drift severity for the feature (low / medium / high), based on "
        "its PSI against the configured thresholds."
    ),
    rank_correlation=(
        "Spearman correlation of the whole feature-importance ranking between "
        "periods. 1.0 = identical ordering, 0 = unrelated, -1.0 = reversed."
    ),
)
52
+
53
+
54
def to_shap_array(x, class_index: int | None = None) -> np.ndarray:
    """Normalise a SHAP input into a 2D array ``(n_samples, n_features)``.

    Accepts:
      * a ``shap.Explanation`` (uses its ``.values``),
      * a NumPy array / list / pandas DataFrame of SHAP values,
      * 1D (single feature), 2D, or 3D multi-class SHAP.

    Multi-class handling (3D input, shape ``(n_samples, n_features, n_classes)``):
      * ``class_index=None`` (default): aggregate across classes using the mean
        absolute value per (sample, feature). This gives a non-negative
        contribution magnitude whose distribution can be tracked.
      * ``class_index=k``: analyse only class ``k``.

    ``class_index`` has no meaning for 1D/2D input. It is ignored there and a
    ``UserWarning`` is emitted so the mistake is visible.

    Raises
    ------
    ValueError
        If the input is non-numeric or ragged, empty, has more than 3
        dimensions, ``class_index`` is out of range, or it contains infinities.
    """
    values = getattr(x, "values", x)  # duck-type shap.Explanation / DataFrame

    # Objects still exposing to_numpy (pandas containers) -> ndarray
    if hasattr(values, "to_numpy"):
        values = values.to_numpy()

    try:
        arr = np.asarray(values, dtype=float)
    except ValueError as exc:
        # Strings or ragged nested lists: replace numpy's terse message with
        # an actionable one.
        raise ValueError(
            "SHAP input could not be converted to a numeric array "
            f"({exc}). Expected numeric SHAP values of shape "
            "(n_samples, n_features) or (n_samples, n_features, n_classes)."
        ) from exc

    if arr.size == 0:
        raise ValueError("SHAP input is empty (0 elements).")

    if class_index is not None and arr.ndim in (1, 2):
        warnings.warn(
            f"class_index={class_index} ignored: SHAP input is {arr.ndim}D, "
            "not 3D multi-class.",
            stacklevel=2,
        )

    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)
    elif arr.ndim == 3:
        n_classes = arr.shape[2]
        if class_index is not None:
            if not 0 <= class_index < n_classes:
                raise ValueError(
                    f"class_index={class_index} out of range for "
                    f"{n_classes} classes."
                )
            arr = arr[:, :, class_index]
        else:
            # Aggregate across classes -> per (sample, feature) magnitude.
            arr = np.mean(np.abs(arr), axis=2)
    elif arr.ndim != 2:
        raise ValueError(
            f"Expected SHAP values with 1, 2 or 3 dimensions, got {arr.ndim}D "
            f"array of shape {arr.shape}."
        )

    if np.isinf(arr).any():
        raise ValueError(
            "SHAP input contains infinite values. Please clean these before "
            "comparing (inf usually signals an upstream bug)."
        )
    return arr
105
+
106
+
107
def drop_nan_pair(ref: np.ndarray, cur: np.ndarray) -> tuple[np.ndarray, np.ndarray, int]:
    """Strip NaNs from a pair of 1D vectors.

    Returns
    -------
    tuple
        ``(ref_clean, cur_clean, n_dropped)``, where ``n_dropped`` is the total
        number of NaN entries removed from both vectors together.
    """
    ref = np.asarray(ref)
    cur = np.asarray(cur)
    ref_nan = np.isnan(ref)
    cur_nan = np.isnan(cur)
    n_dropped = int(ref_nan.sum() + cur_nan.sum())
    return ref[~ref_nan], cur[~cur_nan], n_dropped
112
+
113
+
114
def mean_abs_importance(shap_arr: np.ndarray) -> np.ndarray:
    """Per-feature global importance: the mean of ``|SHAP|`` down each column.

    NaN entries are skipped (``np.nanmean``). A column made entirely of NaNs
    gets importance 0.0 (via ``np.nan_to_num``).
    """
    magnitudes = np.abs(shap_arr)
    with warnings.catch_warnings():
        # nanmean warns "Mean of empty slice" for all-NaN columns. Those are
        # mapped to 0.0 below, so the warning is just noise.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        column_means = np.nanmean(magnitudes, axis=0)
    return np.nan_to_num(column_means, nan=0.0)
124
+
125
+
126
def importance_ranks(importance: np.ndarray) -> np.ndarray:
    """Turn an importance vector into 1-based ranks (1 = most important).

    Ties are broken by position: of two equally important features, the
    earlier one gets the better rank (stable sort on the negated values).
    """
    by_importance = np.argsort(-importance, kind="stable")
    # Arg-sorting a permutation yields its inverse: each feature's position.
    return (np.argsort(by_importance, kind="stable") + 1).astype(int)
132
+
133
+
134
def choose_bins(n_ref: int, requested_bins: int) -> int:
    """Cap the PSI bin count so small reference samples still give sane bins.

    Fewer than 10 reference samples always gets 2 bins. Otherwise the result
    aims for at least ~5 reference samples per bin, never exceeds
    ``requested_bins``, and never drops below 2.
    """
    if n_ref < 10:
        return 2
    per_bin_cap = n_ref // 5
    return max(2, min(requested_bins, per_bin_cap))
141
+
142
+
143
def _point_mass_edges(distinct: np.ndarray):
    """Bin edges giving each distinct value its own bin, or ``None``.

    Builds ``[-inf, v1, next(v1), v2, next(v2), ..., inf]``, giving the bins:
    below ``v1``, exactly ``v1``, between ``v1`` and ``v2``, exactly ``v2``,
    ..., above the last value. Returns ``None`` when any value is non-finite.
    """
    if not np.all(np.isfinite(distinct)):
        return None
    just_above = np.nextafter(distinct, np.inf)
    inner = np.column_stack((distinct, just_above)).ravel()
    # np.unique sorts and drops duplicates (e.g. when next(v1) == v2).
    return np.unique(np.concatenate(([-np.inf], inner, [np.inf])))


def population_stability_index(
    reference: np.ndarray,
    current: np.ndarray,
    bins: int = 10,
    epsilon: float = 1e-6,
) -> float:
    """Population Stability Index between two 1D distributions (NaN-safe).

    Bin edges come from quantiles of the reference, with the outer edges
    opened to +/-inf.

    If those quantiles collapse to at most two distinct values, bins are
    built around the distinct values instead. This covers a constant
    reference, or a feature whose SHAP values are almost always 0. That way,
    a previously flat feature that starts varying is still reported as drift
    rather than PSI 0.

    Empty bins are floored at ``epsilon`` so the logarithm stays finite.

    Returns 0.0 when, after dropping NaNs, there are fewer than 2 reference
    values or no current values.
    """
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]

    if len(reference) < 2 or len(current) < 1:
        return 0.0

    bins = choose_bins(len(reference), bins)

    quantiles = np.linspace(0.0, 1.0, bins + 1)
    edges = np.unique(np.quantile(reference, quantiles))
    if len(edges) < 3:
        # Quantile bins degenerate; isolate the reference's point masses so
        # mass moving away from them is still measured.
        edges = _point_mass_edges(edges)
        if edges is None:
            return 0.0
    else:
        # Open the outer bins so current values outside the reference range count.
        edges[0], edges[-1] = -np.inf, np.inf

    ref_counts, _ = np.histogram(reference, bins=edges)
    cur_counts, _ = np.histogram(current, bins=edges)

    ref_prop = np.clip(ref_counts / max(ref_counts.sum(), 1), epsilon, None)
    cur_prop = np.clip(cur_counts / max(cur_counts.sum(), 1), epsilon, None)

    return float(np.sum((cur_prop - ref_prop) * np.log(cur_prop / ref_prop)))
178
+
179
+
180
def spearman_rank_correlation(rank_a: np.ndarray, rank_b: np.ndarray) -> float:
    """Spearman's rho for two rank vectors, assuming no tied ranks.

    Uses the closed form ``1 - 6 * sum(d^2) / (n * (n^2 - 1))``. A value of
    1.0 means identical ordering and -1.0 means fully reversed. With fewer
    than two features the correlation is defined as 1.0.
    """
    a = np.asarray(rank_a, dtype=float)
    b = np.asarray(rank_b, dtype=float)
    count = len(a)
    if count < 2:
        return 1.0
    diffs = a - b
    squared_total = np.sum(diffs ** 2)
    denominator = count * (count**2 - 1)
    return float(1.0 - (6.0 * squared_total) / denominator)
192
+
193
+
194
def severity_label(psi: float, thresholds: tuple[float, float] = (0.1, 0.2)) -> str:
    """Classify a PSI value as ``"low"``, ``"medium"`` or ``"high"``.

    ``thresholds`` is ``(medium, high)``. The high cut-off is tested first,
    then the medium one. Anything below both is ``"low"``.
    """
    medium_cut, high_cut = thresholds
    for cut, label in ((high_cut, "high"), (medium_cut, "medium")):
        if psi >= cut:
            return label
    return "low"
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.4
2
+ Name: shiftshap
3
+ Version: 0.2.0
4
+ Summary: Monitor whether your model's SHAP explanations still hold as data drifts.
5
+ Author: Mayowa Samuel Olokun
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/OWNER/shiftshap
8
+ Project-URL: Issues, https://github.com/OWNER/shiftshap/issues
9
+ Keywords: shap,explainability,xai,drift,distribution-shift,model-monitoring,machine-learning
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.9
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: numpy>=1.21
20
+ Requires-Dist: pandas>=1.3
21
+ Provides-Extra: plot
22
+ Requires-Dist: matplotlib>=3.4; extra == "plot"
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest>=7.0; extra == "dev"
25
+ Requires-Dist: matplotlib>=3.4; extra == "dev"
26
+ Dynamic: license-file
27
+
28
+ # shiftshap
29
+
30
+ **Monitor whether your model's SHAP explanations still hold as your data drifts.**
31
+
32
+ `shiftshap` answers a question every team running a model in production eventually
33
+ asks: *are my model's explanations still trustworthy?* Models live for months or
34
+ years — they get retrained, upstream pipelines change, and feature distributions
35
+ shift. When that happens, the model's reasoning quietly changes with it. The
36
+ feature that drove your predictions last quarter may not be the one driving them
37
+ today.
38
+
39
+ SHAP is excellent at explaining a model **at a single point in time**, but it has
40
+ no built-in way to tell you how those explanations have **changed** between two
41
+ points. Today people work around this by pickling explanation objects and writing
42
+ their own comparison scripts, or by using data-drift tools that know nothing about
43
+ SHAP's structure. `shiftshap` fills that gap.
44
+
45
+ ---
46
+
47
+ ## Install
48
+
49
+ ```bash
50
+ pip install shiftshap # core
51
+ pip install "shiftshap[plot]" # + matplotlib for the drift chart (quotes needed in zsh)
52
+ ```
53
+
54
+ ## Quickstart
55
+
56
+ If you already use SHAP, you already have everything you need. Take your SHAP
57
+ values from two periods — training vs. production, or last month vs. this month —
58
+ and pass them in:
59
+
60
+ ```python
61
+ import shiftshap
62
+
63
+ report = shiftshap.compare(reference_shap, current_shap)
64
+
65
+ print(report.summary())
66
+ # 2 of 5 features show HIGH explanation drift (0 medium).
67
+ # Top driver changed from 'income' to 'balance'.
68
+ # Overall rank stability (Spearman): 0.70.
69
+
70
+ print(report.details()) # plain-English narrative of the biggest movers
71
+ report.to_frame() # full per-feature table
72
+ report.plot() # rank-drift bump chart
73
+ shiftshap.metric_definitions() # what every metric means, in plain words
74
+ ```
75
+
76
+ Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
77
+ DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays of shape `(n_samples, n_features, n_classes)`)
78
+ is supported — by default classes are aggregated, or pass `class_index=k` to
79
+ focus on one class. The two periods don't need the same number of samples.
80
+
81
+ ### Robust by design
82
+
83
+ `shiftshap` is built to survive real, messy production data. It handles NaNs
84
+ (ignored with a note), zero-variance features, tiny samples (with an explicit
85
+ "results unreliable" warning rather than false alarms), and multi-class outputs —
86
+ and it fails with clear, actionable errors on genuinely broken input (infinities,
87
+ mismatched feature counts, empty arrays) instead of cryptic stack traces.
88
+
89
+ ## What it tells you
90
+
91
+ For every feature, `shiftshap` reports:
92
+
93
+ - **Importance drift** — how the mean absolute SHAP value changed between periods.
94
+ - **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
95
+ shift metric, with the conventional cut-offs of `0.1` (moderate) and `0.2` (significant drift).
96
+ - **Rank drift** — whether your most important features reordered, plus an overall
97
+ Spearman rank-stability score.
98
+ - **Severity** — a `high` / `medium` / `low` label per feature, so the output is
99
+ actionable at a glance.
100
+
101
+ And a **bump chart** showing how the feature-importance ranking shifted:
102
+
103
+ ![rank drift chart](examples/rank_drift.png)
104
+
105
+ ## Why it matters
106
+
107
+ An explanation that has silently drifted is worse than no explanation — it gives
108
+ false confidence. In regulated settings (finance, insurance, healthcare) teams are
109
+ increasingly required to show that model explanations remain valid over time.
110
+ `shiftshap` turns that check into two lines of code.
111
+
112
+ ## Roadmap
113
+
114
+ This release deliberately does one thing well: compare two snapshots of SHAP explanations
115
+ for tabular models. Planned next:
116
+
117
+ - Persistent explanation store for many time-points (not just two).
118
+ - Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
119
+ - Alerting hooks for monitoring pipelines.
120
+ - Support for image and text explanations.
121
+
122
+ Contributions and issues are welcome.
123
+
124
+ ## License
125
+
126
+ MIT
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/shiftshap/__init__.py
5
+ src/shiftshap/core.py
6
+ src/shiftshap/metrics.py
7
+ src/shiftshap.egg-info/PKG-INFO
8
+ src/shiftshap.egg-info/SOURCES.txt
9
+ src/shiftshap.egg-info/dependency_links.txt
10
+ src/shiftshap.egg-info/requires.txt
11
+ src/shiftshap.egg-info/top_level.txt
12
+ tests/test_core.py
@@ -0,0 +1,9 @@
1
+ numpy>=1.21
2
+ pandas>=1.3
3
+
4
+ [dev]
5
+ pytest>=7.0
6
+ matplotlib>=3.4
7
+
8
+ [plot]
9
+ matplotlib>=3.4
@@ -0,0 +1 @@
1
+ shiftshap
@@ -0,0 +1,168 @@
1
+ """Tests for shiftshap. Run with: pytest"""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ import shiftshap
7
+ from shiftshap.metrics import (
8
+ importance_ranks,
9
+ mean_abs_importance,
10
+ population_stability_index,
11
+ severity_label,
12
+ spearman_rank_correlation,
13
+ )
14
+
15
+
16
def test_psi_zero_for_identical_distributions():
    """Comparing a sample against itself must yield (numerically) zero PSI."""
    sample = np.random.default_rng(0).normal(size=5000)
    psi = population_stability_index(sample, sample)
    assert psi == pytest.approx(0.0, abs=1e-9)
20
+
21
+
22
def test_psi_positive_for_shifted_distribution():
    """A 3-standard-deviation mean shift must push PSI above 0.2."""
    generator = np.random.default_rng(1)
    reference_sample = generator.normal(0, 1, size=5000)
    shifted_sample = generator.normal(3, 1, size=5000)  # mean moved by 3 std devs
    assert population_stability_index(reference_sample, shifted_sample) > 0.2
27
+
28
+
29
def test_importance_and_ranks():
    """The column with the largest |SHAP| is the most important and ranked 1."""
    shap_values = np.array([
        [0.1, -2.0, 0.5],
        [0.2, 2.0, -0.5],
    ])
    importance = mean_abs_importance(shap_values)
    assert np.argmax(importance) == 1  # column 1 has the largest magnitudes
    assert importance_ranks(importance)[1] == 1
36
+
37
+
38
def test_spearman_bounds():
    """Identical orderings give +1; a fully reversed ordering gives -1."""
    ranks = np.array([1, 2, 3, 4])
    reversed_ranks = ranks[::-1]
    assert spearman_rank_correlation(ranks, ranks) == pytest.approx(1.0)
    assert spearman_rank_correlation(ranks, reversed_ranks) == pytest.approx(-1.0)
42
+
43
+
44
def test_severity_thresholds():
    """Default cut-offs split PSI values into low / medium / high."""
    expected_labels = {0.05: "low", 0.15: "medium", 0.30: "high"}
    for psi, label in expected_labels.items():
        assert severity_label(psi) == label
48
+
49
+
50
def test_compare_end_to_end_detects_drift():
    """Swapping which feature dominates must flip the top driver and flag drift."""
    rng = np.random.default_rng(42)
    n = 2000
    # Reference: feature 'a' dominates and 'c' is small.
    ref = np.column_stack([rng.normal(0, scale, n) for scale in (2.0, 0.5, 0.2)])
    # Current: 'a' and 'c' swap roles, so 'c' now dominates.
    cur = np.column_stack([rng.normal(0, scale, n) for scale in (0.2, 0.5, 2.0)])
    report = shiftshap.compare(ref, cur, feature_names=["a", "b", "c"])
    frame = report.to_frame()

    def top_feature(rank_column):
        # Row with the best (lowest) rank in the given column.
        return frame.sort_values(rank_column).iloc[0]["feature"]

    assert top_feature("reference_rank") == "a"
    assert top_feature("current_rank") == "c"
    # Rank ordering reversed -> strongly negative Spearman.
    assert report.rank_correlation < 0
    assert report.n_high >= 1
75
+
76
+
77
def test_feature_count_mismatch_raises():
    """Periods with different numbers of feature columns are rejected."""
    three_features = np.zeros((10, 3))
    four_features = np.zeros((10, 4))
    with pytest.raises(ValueError):
        shiftshap.compare(three_features, four_features)
80
+
81
+
82
def test_accepts_explanation_like_object():
    """Objects exposing ``.values`` and ``.feature_names`` are accepted as input."""

    class FakeExplanation:
        """Minimal duck-typed stand-in for ``shap.Explanation``."""

        def __init__(self, values, feature_names):
            self.values = values
            self.feature_names = feature_names

    rng = np.random.default_rng(7)
    ref = FakeExplanation(rng.normal(size=(100, 2)), ["x", "y"])
    cur = FakeExplanation(rng.normal(size=(100, 2)), ["x", "y"])
    report = shiftshap.compare(ref, cur)
    # Row order of the report is not part of this contract, so compare
    # order-insensitively. Sorting (rather than a set) also catches
    # duplicated or missing feature rows.
    assert sorted(report.to_frame()["feature"]) == ["x", "y"]
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # v0.2 hardening tests: multi-class, messy inputs, clear errors, clarity
99
+ # ---------------------------------------------------------------------------
100
+
101
def test_multiclass_3d_aggregates():
    """3D multi-class input is collapsed to one entry per feature by default."""
    rng = np.random.default_rng(0)
    shape = (200, 4, 3)
    reference = rng.normal(size=shape)
    current = rng.normal(size=shape)
    report = shiftshap.compare(reference, current)
    assert report.n_features == 4  # collapsed across classes
107
+
108
+
109
def test_multiclass_class_index_selects_one_class():
    """Passing ``class_index`` on 3D input still reports every feature."""
    rng = np.random.default_rng(0)
    reference = rng.normal(size=(200, 4, 3))
    current = rng.normal(size=(200, 4, 3))
    report = shiftshap.compare(reference, current, class_index=1)
    assert report.n_features == 4
115
+
116
+
117
def test_bad_class_index_raises():
    """A ``class_index`` beyond the 3 available classes raises ValueError."""
    rng = np.random.default_rng(0)
    reference = rng.normal(size=(50, 4, 3))
    current = rng.normal(size=(50, 4, 3))
    with pytest.raises(ValueError):
        shiftshap.compare(reference, current, class_index=9)
121
+
122
+
123
def test_inf_input_raises_clearly():
    """An infinite SHAP value raises ValueError whose message mentions 'infinite'."""
    x = np.zeros((10, 3))
    x[0, 0] = np.inf
    with pytest.raises(ValueError, match="infinite"):
        shiftshap.compare(x, np.zeros((10, 3)))
127
+
128
+
129
def test_empty_input_raises():
    """Empty inputs raise ValueError whose message mentions 'empty'."""
    reference, current = np.array([]), np.array([])
    with pytest.raises(ValueError, match="empty"):
        shiftshap.compare(reference, current)
132
+
133
+
134
def test_nan_is_handled_with_note():
    """NaNs are tolerated: a UserWarning is emitted and a NaN note is recorded."""
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(200, 3))
    ref[::10, 1] = np.nan  # every 10th row of feature 1 is missing
    with pytest.warns(UserWarning):
        report = shiftshap.compare(ref, rng.normal(size=(200, 3)))
    assert any("NaN" in note for note in report.notes)
140
+
141
+
142
def test_small_sample_warns():
    """Tiny samples trigger an 'unreliable' warning and a matching report note."""
    rng = np.random.default_rng(0)
    reference = rng.normal(size=(5, 3))
    current = rng.normal(size=(5, 3))
    with pytest.warns(UserWarning, match="unreliable"):
        report = shiftshap.compare(reference, current)
    assert any("unreliable" in note for note in report.notes)
147
+
148
+
149
def test_dataframe_input_accepted():
    """pandas DataFrames are accepted as SHAP inputs."""
    import pandas as pd

    rng = np.random.default_rng(0)
    columns = ["a", "b", "c"]
    # Reference is drawn first, then current, matching the original draw order.
    frames = [pd.DataFrame(rng.normal(size=(100, 3)), columns=columns) for _ in range(2)]
    report = shiftshap.compare(*frames)
    assert report.n_features == 3
156
+
157
+
158
def test_bad_thresholds_raise():
    """PSI thresholds whose first value exceeds the second are rejected."""
    inverted_thresholds = (0.3, 0.1)
    reference = np.zeros((50, 2))
    current = np.zeros((50, 2))
    with pytest.raises(ValueError):
        shiftshap.compare(reference, current, psi_thresholds=inverted_thresholds)
161
+
162
+
163
def test_details_and_definitions_are_strings():
    """``details()`` returns text and ``metric_definitions()`` documents 'psi'."""
    rng = np.random.default_rng(0)
    reference = rng.normal(size=(100, 3))
    current = rng.normal(size=(100, 3))
    report = shiftshap.compare(reference, current)
    assert isinstance(report.details(), str)
    definitions = shiftshap.metric_definitions()
    assert "psi" in definitions
    assert isinstance(definitions["psi"], str)