shiftshap 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shiftshap-0.2.0/LICENSE +21 -0
- shiftshap-0.2.0/PKG-INFO +126 -0
- shiftshap-0.2.0/README.md +99 -0
- shiftshap-0.2.0/pyproject.toml +44 -0
- shiftshap-0.2.0/setup.cfg +4 -0
- shiftshap-0.2.0/src/shiftshap/__init__.py +32 -0
- shiftshap-0.2.0/src/shiftshap/core.py +261 -0
- shiftshap-0.2.0/src/shiftshap/metrics.py +201 -0
- shiftshap-0.2.0/src/shiftshap.egg-info/PKG-INFO +126 -0
- shiftshap-0.2.0/src/shiftshap.egg-info/SOURCES.txt +12 -0
- shiftshap-0.2.0/src/shiftshap.egg-info/dependency_links.txt +1 -0
- shiftshap-0.2.0/src/shiftshap.egg-info/requires.txt +9 -0
- shiftshap-0.2.0/src/shiftshap.egg-info/top_level.txt +1 -0
- shiftshap-0.2.0/tests/test_core.py +168 -0
shiftshap-0.2.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Mayowa Samuel Olokun
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
shiftshap-0.2.0/PKG-INFO
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: shiftshap
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Monitor whether your model's SHAP explanations still hold as data drifts.
|
|
5
|
+
Author: Mayowa Samuel Olokun
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/OWNER/shiftshap
|
|
8
|
+
Project-URL: Issues, https://github.com/OWNER/shiftshap/issues
|
|
9
|
+
Keywords: shap,explainability,xai,drift,distribution-shift,model-monitoring,machine-learning
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Requires-Python: >=3.9
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: numpy>=1.21
|
|
20
|
+
Requires-Dist: pandas>=1.3
|
|
21
|
+
Provides-Extra: plot
|
|
22
|
+
Requires-Dist: matplotlib>=3.4; extra == "plot"
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
25
|
+
Requires-Dist: matplotlib>=3.4; extra == "dev"
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# shiftshap
|
|
29
|
+
|
|
30
|
+
**Monitor whether your model's SHAP explanations still hold as your data drifts.**
|
|
31
|
+
|
|
32
|
+
`shiftshap` answers a question every team running a model in production eventually
|
|
33
|
+
asks: *are my model's explanations still trustworthy?* Models live for months or
|
|
34
|
+
years — they get retrained, upstream pipelines change, and feature distributions
|
|
35
|
+
shift. When that happens, the model's reasoning quietly changes with it. The
|
|
36
|
+
feature that drove your predictions last quarter may not be the one driving them
|
|
37
|
+
today.
|
|
38
|
+
|
|
39
|
+
SHAP is excellent at explaining a model **at a single point in time**, but it has
|
|
40
|
+
no built-in way to tell you how those explanations have **changed** between two
|
|
41
|
+
points. Today people work around this by pickling explanation objects and writing
|
|
42
|
+
their own comparison scripts, or by using data-drift tools that know nothing about
|
|
43
|
+
SHAP's structure. `shiftshap` fills that gap.
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## Install
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pip install shiftshap # core
|
|
51
|
+
pip install shiftshap[plot] # + matplotlib for the drift chart
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Quickstart
|
|
55
|
+
|
|
56
|
+
If you already use SHAP, you already have everything you need. Take your SHAP
|
|
57
|
+
values from two periods — training vs. production, or last month vs. this month —
|
|
58
|
+
and pass them in:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import shiftshap
|
|
62
|
+
|
|
63
|
+
report = shiftshap.compare(reference_shap, current_shap)
|
|
64
|
+
|
|
65
|
+
print(report.summary())
|
|
66
|
+
# 2 of 5 feature(s) show HIGH explanation drift (0 medium).
|
|
67
|
+
# Top driver changed from 'income' to 'balance'.
|
|
68
|
+
# Overall rank stability (Spearman): 0.70.
|
|
69
|
+
|
|
70
|
+
print(report.details()) # plain-English narrative of the biggest movers
|
|
71
|
+
report.to_frame() # full per-feature table
|
|
72
|
+
report.plot() # rank-drift bump chart
|
|
73
|
+
shiftshap.metric_definitions() # what every metric means, in plain words
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
|
|
77
|
+
DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays)
|
|
78
|
+
is supported — by default classes are aggregated, or pass `class_index=k` to
|
|
79
|
+
focus on one class. The two periods don't need the same number of samples.
|
|
80
|
+
|
|
81
|
+
### Robust by design
|
|
82
|
+
|
|
83
|
+
`shiftshap` is built to survive real, messy production data. It handles NaNs
|
|
84
|
+
(ignored with a note), zero-variance features, tiny samples (with an explicit
|
|
85
|
+
"results unreliable" warning rather than false alarms), and multi-class outputs —
|
|
86
|
+
and it fails with clear, actionable errors on genuinely broken input (infinities,
|
|
87
|
+
mismatched feature counts, empty arrays) instead of cryptic stack traces.
|
|
88
|
+
|
|
89
|
+
## What it tells you
|
|
90
|
+
|
|
91
|
+
For every feature, `shiftshap` reports:
|
|
92
|
+
|
|
93
|
+
- **Importance drift** — how the mean absolute SHAP value changed between periods.
|
|
94
|
+
- **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
|
|
95
|
+
shift metric, with the accepted `0.2` threshold flagging significant drift.
|
|
96
|
+
- **Rank drift** — whether your most important features reordered, plus an overall
|
|
97
|
+
Spearman rank-stability score.
|
|
98
|
+
- **Severity** — a `high` / `medium` / `low` label per feature, so the output is
|
|
99
|
+
actionable at a glance.
|
|
100
|
+
|
|
101
|
+
And a **bump chart** showing how the feature-importance ranking shifted:
|
|
102
|
+
|
|
103
|
+

|
|
104
|
+
|
|
105
|
+
## Why it matters
|
|
106
|
+
|
|
107
|
+
An explanation that has silently drifted is worse than no explanation — it gives
|
|
108
|
+
false confidence. In regulated settings (finance, insurance, healthcare) teams are
|
|
109
|
+
increasingly required to show that model explanations remain valid over time.
|
|
110
|
+
`shiftshap` turns that check into two lines of code.
|
|
111
|
+
|
|
112
|
+
## Roadmap
|
|
113
|
+
|
|
114
|
+
`v0.2` deliberately does one thing well: compare two snapshots of SHAP explanations
|
|
115
|
+
for tabular models. Planned next:
|
|
116
|
+
|
|
117
|
+
- Persistent explanation store for many time-points (not just two).
|
|
118
|
+
- Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
|
|
119
|
+
- Alerting hooks for monitoring pipelines.
|
|
120
|
+
- Support for image and text explanations.
|
|
121
|
+
|
|
122
|
+
Contributions and issues welcome.
|
|
123
|
+
|
|
124
|
+
## License
|
|
125
|
+
|
|
126
|
+
MIT
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# shiftshap
|
|
2
|
+
|
|
3
|
+
**Monitor whether your model's SHAP explanations still hold as your data drifts.**
|
|
4
|
+
|
|
5
|
+
`shiftshap` answers a question every team running a model in production eventually
|
|
6
|
+
asks: *are my model's explanations still trustworthy?* Models live for months or
|
|
7
|
+
years — they get retrained, upstream pipelines change, and feature distributions
|
|
8
|
+
shift. When that happens, the model's reasoning quietly changes with it. The
|
|
9
|
+
feature that drove your predictions last quarter may not be the one driving them
|
|
10
|
+
today.
|
|
11
|
+
|
|
12
|
+
SHAP is excellent at explaining a model **at a single point in time**, but it has
|
|
13
|
+
no built-in way to tell you how those explanations have **changed** between two
|
|
14
|
+
points. Today people work around this by pickling explanation objects and writing
|
|
15
|
+
their own comparison scripts, or by using data-drift tools that know nothing about
|
|
16
|
+
SHAP's structure. `shiftshap` fills that gap.
|
|
17
|
+
|
|
18
|
+
---
|
|
19
|
+
|
|
20
|
+
## Install
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
pip install shiftshap # core
|
|
24
|
+
pip install shiftshap[plot] # + matplotlib for the drift chart
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Quickstart
|
|
28
|
+
|
|
29
|
+
If you already use SHAP, you already have everything you need. Take your SHAP
|
|
30
|
+
values from two periods — training vs. production, or last month vs. this month —
|
|
31
|
+
and pass them in:
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
import shiftshap
|
|
35
|
+
|
|
36
|
+
report = shiftshap.compare(reference_shap, current_shap)
|
|
37
|
+
|
|
38
|
+
print(report.summary())
|
|
39
|
+
# 2 of 5 feature(s) show HIGH explanation drift (0 medium).
|
|
40
|
+
# Top driver changed from 'income' to 'balance'.
|
|
41
|
+
# Overall rank stability (Spearman): 0.70.
|
|
42
|
+
|
|
43
|
+
print(report.details()) # plain-English narrative of the biggest movers
|
|
44
|
+
report.to_frame() # full per-feature table
|
|
45
|
+
report.plot() # rank-drift bump chart
|
|
46
|
+
shiftshap.metric_definitions() # what every metric means, in plain words
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
|
|
50
|
+
DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays)
|
|
51
|
+
is supported — by default classes are aggregated, or pass `class_index=k` to
|
|
52
|
+
focus on one class. The two periods don't need the same number of samples.
|
|
53
|
+
|
|
54
|
+
### Robust by design
|
|
55
|
+
|
|
56
|
+
`shiftshap` is built to survive real, messy production data. It handles NaNs
|
|
57
|
+
(ignored with a note), zero-variance features, tiny samples (with an explicit
|
|
58
|
+
"results unreliable" warning rather than false alarms), and multi-class outputs —
|
|
59
|
+
and it fails with clear, actionable errors on genuinely broken input (infinities,
|
|
60
|
+
mismatched feature counts, empty arrays) instead of cryptic stack traces.
|
|
61
|
+
|
|
62
|
+
## What it tells you
|
|
63
|
+
|
|
64
|
+
For every feature, `shiftshap` reports:
|
|
65
|
+
|
|
66
|
+
- **Importance drift** — how the mean absolute SHAP value changed between periods.
|
|
67
|
+
- **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
|
|
68
|
+
shift metric, with the accepted `0.2` threshold flagging significant drift.
|
|
69
|
+
- **Rank drift** — whether your most important features reordered, plus an overall
|
|
70
|
+
Spearman rank-stability score.
|
|
71
|
+
- **Severity** — a `high` / `medium` / `low` label per feature, so the output is
|
|
72
|
+
actionable at a glance.
|
|
73
|
+
|
|
74
|
+
And a **bump chart** showing how the feature-importance ranking shifted:
|
|
75
|
+
|
|
76
|
+

|
|
77
|
+
|
|
78
|
+
## Why it matters
|
|
79
|
+
|
|
80
|
+
An explanation that has silently drifted is worse than no explanation — it gives
|
|
81
|
+
false confidence. In regulated settings (finance, insurance, healthcare) teams are
|
|
82
|
+
increasingly required to show that model explanations remain valid over time.
|
|
83
|
+
`shiftshap` turns that check into two lines of code.
|
|
84
|
+
|
|
85
|
+
## Roadmap
|
|
86
|
+
|
|
87
|
+
`v0.2` deliberately does one thing well: compare two snapshots of SHAP explanations
|
|
88
|
+
for tabular models. Planned next:
|
|
89
|
+
|
|
90
|
+
- Persistent explanation store for many time-points (not just two).
|
|
91
|
+
- Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
|
|
92
|
+
- Alerting hooks for monitoring pipelines.
|
|
93
|
+
- Support for image and text explanations.
|
|
94
|
+
|
|
95
|
+
Contributions and issues welcome.
|
|
96
|
+
|
|
97
|
+
## License
|
|
98
|
+
|
|
99
|
+
MIT
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=64", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "shiftshap"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "Monitor whether your model's SHAP explanations still hold as data drifts."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [{ name = "Mayowa Samuel Olokun" }]
|
|
13
|
+
keywords = [
|
|
14
|
+
"shap",
|
|
15
|
+
"explainability",
|
|
16
|
+
"xai",
|
|
17
|
+
"drift",
|
|
18
|
+
"distribution-shift",
|
|
19
|
+
"model-monitoring",
|
|
20
|
+
"machine-learning",
|
|
21
|
+
]
|
|
22
|
+
classifiers = [
|
|
23
|
+
"Development Status :: 3 - Alpha",
|
|
24
|
+
"Intended Audience :: Science/Research",
|
|
25
|
+
"Intended Audience :: Developers",
|
|
26
|
+
"License :: OSI Approved :: MIT License",
|
|
27
|
+
"Programming Language :: Python :: 3",
|
|
28
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
29
|
+
]
|
|
30
|
+
dependencies = [
|
|
31
|
+
"numpy>=1.21",
|
|
32
|
+
"pandas>=1.3",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[project.optional-dependencies]
|
|
36
|
+
plot = ["matplotlib>=3.4"]
|
|
37
|
+
dev = ["pytest>=7.0", "matplotlib>=3.4"]
|
|
38
|
+
|
|
39
|
+
[project.urls]
|
|
40
|
+
Homepage = "https://github.com/OWNER/shiftshap"
|
|
41
|
+
Issues = "https://github.com/OWNER/shiftshap/issues"
|
|
42
|
+
|
|
43
|
+
[tool.setuptools.packages.find]
|
|
44
|
+
where = ["src"]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""shiftshap -- monitor whether your model's SHAP explanations still hold as data drifts.
|
|
2
|
+
|
|
3
|
+
Quickstart
|
|
4
|
+
----------
|
|
5
|
+
>>> import shiftshap
|
|
6
|
+
>>> report = shiftshap.compare(reference_shap_values, current_shap_values)
|
|
7
|
+
>>> print(report.summary()) # one-line verdict
|
|
8
|
+
>>> print(report.details()) # plain-English narrative of the biggest movers
|
|
9
|
+
>>> report.to_frame() # full per-feature table
|
|
10
|
+
>>> report.plot() # rank-drift bump chart (needs matplotlib)
|
|
11
|
+
>>> shiftshap.metric_definitions() # what every column means
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from .core import DriftReport, compare, metric_definitions
|
|
15
|
+
from .metrics import (
|
|
16
|
+
METRIC_DEFINITIONS,
|
|
17
|
+
mean_abs_importance,
|
|
18
|
+
population_stability_index,
|
|
19
|
+
spearman_rank_correlation,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"compare",
|
|
24
|
+
"DriftReport",
|
|
25
|
+
"metric_definitions",
|
|
26
|
+
"METRIC_DEFINITIONS",
|
|
27
|
+
"population_stability_index",
|
|
28
|
+
"mean_abs_importance",
|
|
29
|
+
"spearman_rank_correlation",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
__version__ = "0.2.0"
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""The main entry point: ``shiftshap.compare``.
|
|
2
|
+
|
|
3
|
+
Give it SHAP explanations from two points in time (or two data batches) and it
|
|
4
|
+
tells you which features' importance has drifted, by how much, how severely,
|
|
5
|
+
and whether your most important features have reordered.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
from typing import Optional, Sequence
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
from .metrics import (
|
|
17
|
+
METRIC_DEFINITIONS,
|
|
18
|
+
choose_bins,
|
|
19
|
+
importance_ranks,
|
|
20
|
+
mean_abs_importance,
|
|
21
|
+
population_stability_index,
|
|
22
|
+
severity_label,
|
|
23
|
+
spearman_rank_correlation,
|
|
24
|
+
to_shap_array,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def metric_definitions() -> dict:
    """Return a fresh dict of plain-language definitions for every reported metric.

    The result is a shallow copy of ``METRIC_DEFINITIONS``, so callers may
    mutate it without affecting the module-level table.
    """
    definitions = {}
    definitions.update(METRIC_DEFINITIONS)
    return definitions
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _resolve_feature_names(feature_names, reference_shap, current_shap, n_features):
    """Work out the feature labels used in the drift table.

    Resolution order:

    1. ``feature_names`` when it is not ``None``;
    2. the first non-empty ``feature_names`` attribute found on
       ``reference_shap``, then on ``current_shap``;
    3. generated labels ``feature_0 .. feature_{n_features - 1}``.

    Every label is converted with ``str``.

    Raises
    ------
    ValueError
        If the number of resolved names differs from ``n_features``.
    """
    names = feature_names
    if names is None:
        for source in (reference_shap, current_shap):
            candidate = getattr(source, "feature_names", None)
            # Test for None/emptiness explicitly. Chaining the two lookups with
            # ``or`` truth-tests the object, which raises "truth value of an
            # array is ambiguous" when the names are a NumPy array.
            if candidate is not None and len(candidate) > 0:
                names = candidate
                break
    if names is None:
        names = [f"feature_{i}" for i in range(n_features)]

    names = [str(n) for n in names]
    if len(names) != n_features:
        raise ValueError(
            f"Got {len(names)} feature names but SHAP values have "
            f"{n_features} features."
        )
    return names
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class DriftReport:
    """Result of a ``compare`` call. Print it, tabulate it, or plot it.

    Parameters
    ----------
    frame
        Per-feature drift table. The methods of this class read the columns
        ``feature``, ``reference_importance``, ``current_importance``,
        ``importance_change``, ``psi``, ``reference_rank``, ``current_rank``,
        ``rank_change`` and ``severity``.
    rank_correlation
        Overall rank-stability score shown in ``summary``.
    psi_thresholds
        ``(medium, high)`` PSI cut-offs, echoed by ``details``.
    notes
        Optional list of caveats appended to ``details``.
    """

    def __init__(self, frame: pd.DataFrame, rank_correlation: float,
                 psi_thresholds: tuple[float, float], notes: Optional[list] = None):
        # Stored sorted by PSI, largest first, so head() yields the top movers.
        self._frame = frame.sort_values("psi", ascending=False).reset_index(drop=True)
        self.rank_correlation = rank_correlation
        self.psi_thresholds = psi_thresholds
        # Copy so that later mutation of the caller's list cannot alter the report.
        self.notes = list(notes) if notes else []

    def to_frame(self) -> pd.DataFrame:
        """Full per-feature drift table (a copy), sorted by PSI, largest first."""
        return self._frame.copy()

    @property
    def n_features(self) -> int:
        """Number of rows (features) in the report."""
        return len(self._frame)

    @property
    def n_high(self) -> int:
        """Number of features whose severity is ``"high"``."""
        return int((self._frame["severity"] == "high").sum())

    @property
    def n_medium(self) -> int:
        """Number of features whose severity is ``"medium"``."""
        return int((self._frame["severity"] == "medium").sum())

    def summary(self) -> str:
        """A short, human-readable verdict: three sentences on a single line.

        Covers the high/medium severity counts, whether the rank-1 feature
        changed between periods, and the overall rank correlation.
        """
        total = self.n_features
        lines = [
            f"{self.n_high} of {total} feature(s) show HIGH explanation drift "
            f"({self.n_medium} medium)."
        ]
        ref_top = self._frame.sort_values("reference_rank").iloc[0]
        cur_top = self._frame.sort_values("current_rank").iloc[0]
        if ref_top["feature"] != cur_top["feature"]:
            lines.append(
                f"Top driver changed from '{ref_top['feature']}' to "
                f"'{cur_top['feature']}'."
            )
        else:
            lines.append(f"Top driver unchanged ('{cur_top['feature']}').")
        lines.append(f"Overall rank stability (Spearman): {self.rank_correlation:.2f}.")
        return " ".join(lines)

    def details(self, top_n: int = 5) -> str:
        """A plain-English narrative of the biggest movers -- explicit for users.

        Parameters
        ----------
        top_n
            How many features (highest PSI first) to describe. Values below
            zero are treated as zero; values above the feature count are capped.

        Returns
        -------
        str
            Multi-line text: the summary, the PSI thresholds, one bullet per
            listed feature, and any notes attached to the report.
        """
        low, high = self.psi_thresholds
        # Clamp so a negative top_n cannot yield a negative header count or
        # trigger DataFrame.head's "all but the last k rows" behaviour.
        shown = max(0, min(int(top_n), self.n_features))
        out = [self.summary(), ""]
        out.append(f"PSI thresholds: medium >= {low}, high >= {high}.")
        out.append(f"Top {shown} features by drift:")
        for _, r in self._frame.head(shown).iterrows():
            change = r["importance_change"]
            if change > 0:
                direction = "gained influence"
            elif change < 0:
                direction = "lost influence"
            else:
                # An exactly unchanged importance is neither a gain nor a loss.
                direction = "no change in influence"
            move = ""
            if r["rank_change"] > 0:
                move = f", climbed {int(r['rank_change'])} rank(s)"
            elif r["rank_change"] < 0:
                move = f", fell {int(-r['rank_change'])} rank(s)"
            out.append(
                f" - {r['feature']}: PSI={r['psi']:.3f} ({r['severity']}); "
                f"{direction} "
                f"({r['reference_importance']:.3f} -> {r['current_importance']:.3f}){move}."
            )
        if self.notes:
            out.append("")
            out.append("Notes:")
            out.extend(f" - {n}" for n in self.notes)
        return "\n".join(out)

    def __repr__(self) -> str:
        return f"<DriftReport: {self.summary()}>"

    def plot(self, ax=None, top_n: Optional[int] = None):
        """Bump chart of feature-importance rank: reference vs current.

        Parameters
        ----------
        ax
            Existing matplotlib Axes to draw on; a new figure is created when
            omitted.
        top_n
            If given, only the ``top_n`` features with the smallest reference
            rank are drawn.

        Needs matplotlib (optional). Returns the Axes.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "Plotting needs matplotlib. Install with: pip install shiftshap[plot]"
            ) from exc

        frame = self._frame.copy()
        if top_n is not None:
            keep = frame.nsmallest(top_n, "reference_rank")["feature"]
            frame = frame[frame["feature"].isin(keep)]

        if ax is None:
            # Figure height grows with the number of features drawn.
            _, ax = plt.subplots(figsize=(6, max(3, 0.4 * len(frame) + 1)))

        colours = {"high": "#d1495b", "medium": "#edae49", "low": "#9aa5b1"}
        for _, row in frame.iterrows():
            ax.plot(
                [0, 1], [row["reference_rank"], row["current_rank"]],
                marker="o", linewidth=2.2,
                color=colours.get(row["severity"], "#9aa5b1"),
            )
            ax.annotate(f" {row['feature']}", (1, row["current_rank"]),
                        va="center", fontsize=9, color="#2b2b2b")

        ax.set_xticks([0, 1])
        ax.set_xticklabels(["reference", "current"])
        ax.set_ylabel("importance rank (1 = most important)")
        # Rank 1 is the most important feature, so put it at the top.
        ax.invert_yaxis()
        # Extra room on the right for the feature labels.
        ax.set_xlim(-0.15, 1.55)
        ax.set_title("SHAP importance rank drift")
        ax.grid(axis="y", linestyle=":", alpha=0.4)
        return ax
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def compare(
    reference_shap,
    current_shap,
    feature_names: Optional[Sequence[str]] = None,
    bins: int = 10,
    psi_thresholds: tuple[float, float] = (0.1, 0.2),
    class_index: Optional[int] = None,
) -> DriftReport:
    """Compare two sets of SHAP explanations and report drift.

    Parameters
    ----------
    reference_shap, current_shap
        SHAP values for the two periods: a ``shap.Explanation``, NumPy array,
        list, or DataFrame of shape ``(n_samples, n_features)``. 3D multi-class
        SHAP is supported (see ``class_index``). Sample counts may differ; the
        feature count must match.
    feature_names
        Optional; else taken from a ``feature_names`` attribute on either
        input, or defaulted to ``feature_0``, ``feature_1``, ...
    bins
        Requested number of quantile bins for PSI. The count actually used is
        ``choose_bins(n_reference_samples, bins)``, which shrinks it for small
        reference sets; a note is recorded when that happens.
    psi_thresholds
        ``(medium, high)`` PSI cut-offs for severity.
    class_index
        For multi-class SHAP: pick a single class. Default aggregates classes.

    Returns
    -------
    DriftReport

    Raises
    ------
    ValueError
        If ``psi_thresholds`` is not a 2-item tuple/list with
        ``medium <= high``, if the two inputs have different feature counts,
        if the number of feature names does not match, or if
        ``to_shap_array`` rejects an input.

    Warns
    -----
    UserWarning
        When either input contains NaN values, and when the smaller period has
        fewer than 30 samples. The same messages are stored in the report's
        ``notes``.
    """
    if not (isinstance(psi_thresholds, (tuple, list)) and len(psi_thresholds) == 2
            and psi_thresholds[0] <= psi_thresholds[1]):
        raise ValueError("psi_thresholds must be (medium, high) with medium <= high.")

    ref = to_shap_array(reference_shap, class_index=class_index)
    cur = to_shap_array(current_shap, class_index=class_index)

    if ref.shape[1] != cur.shape[1]:
        raise ValueError(
            f"Feature count mismatch: reference has {ref.shape[1]} features, "
            f"current has {cur.shape[1]}."
        )

    n_features = ref.shape[1]
    names = _resolve_feature_names(feature_names, reference_shap, current_shap, n_features)

    notes = []
    if np.isnan(ref).any() or np.isnan(cur).any():
        # Count rows (samples) holding at least one NaN, across both periods.
        n_bad = int(np.isnan(ref).any(axis=1).sum() + np.isnan(cur).any(axis=1).sum())
        notes.append(
            f"{n_bad} sample(s) contained NaN SHAP values and were ignored "
            f"per-feature where needed."
        )
        warnings.warn(notes[-1], stacklevel=2)

    eff_bins = choose_bins(ref.shape[0], bins)
    if eff_bins < bins:
        notes.append(
            f"Reference has {ref.shape[0]} samples; PSI bins reduced from "
            f"{bins} to {eff_bins} to stay meaningful."
        )

    min_n = min(ref.shape[0], cur.shape[0])
    if min_n < 30:
        msg = (
            f"Only {min_n} sample(s) in the smaller period. PSI/drift results "
            f"are unreliable below ~30 samples and may show false alarms -- "
            f"treat severities with caution."
        )
        notes.append(msg)
        warnings.warn(msg, stacklevel=2)

    ref_imp = mean_abs_importance(ref)
    cur_imp = mean_abs_importance(cur)
    ref_rank = importance_ranks(ref_imp)
    cur_rank = importance_ranks(cur_imp)

    # Use the effective bin count so the computation agrees with the
    # "bins reduced" note recorded above.
    psi = np.array([
        population_stability_index(ref[:, j], cur[:, j], bins=eff_bins)
        for j in range(n_features)
    ])
    severity = [severity_label(p, psi_thresholds) for p in psi]

    frame = pd.DataFrame({
        "feature": names,
        "reference_importance": ref_imp,
        "current_importance": cur_imp,
        "importance_change": cur_imp - ref_imp,
        "psi": psi,
        "reference_rank": ref_rank,
        "current_rank": cur_rank,
        # Positive = the feature moved towards rank 1.
        "rank_change": ref_rank - cur_rank,
        "severity": severity,
    })

    rank_corr = spearman_rank_correlation(ref_rank, cur_rank)
    return DriftReport(frame, rank_correlation=rank_corr,
                       psi_thresholds=tuple(psi_thresholds), notes=notes)
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Core drift metrics for SHAP explanations.
|
|
2
|
+
|
|
3
|
+
Dependency-light (NumPy only). Every function does one small, well-defined,
|
|
4
|
+
testable job. All metrics are documented in ``METRIC_DEFINITIONS`` so the
|
|
5
|
+
output is never a mystery number to the user.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
# Plain-language definitions of every metric shiftshap reports.
|
|
16
|
+
# Exposed to users via shiftshap.metric_definitions().
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
# Maps each metric name to a one-paragraph, plain-language explanation.
# Each multi-line value is a single string built by implicit concatenation of
# adjacent literals, so the trailing space inside each fragment is significant.
METRIC_DEFINITIONS = {
    "reference_importance": (
        "Global importance of the feature in the REFERENCE period: the mean "
        "absolute SHAP value across all reference samples. Higher = the feature "
        "influenced predictions more."
    ),
    "current_importance": (
        "Global importance of the feature in the CURRENT period (mean absolute "
        "SHAP value across all current samples)."
    ),
    "importance_change": (
        "current_importance minus reference_importance. Negative = the feature "
        "lost influence; positive = it gained influence."
    ),
    "psi": (
        "Population Stability Index of the feature's SHAP-value distribution "
        "between the two periods. 0 = no shift. Rule of thumb: <0.1 no "
        "significant shift, 0.1-0.2 moderate, >=0.2 significant."
    ),
    "reference_rank": "Importance rank in the reference period (1 = most important).",
    "current_rank": "Importance rank in the current period (1 = most important).",
    "rank_change": (
        "reference_rank minus current_rank. Positive = the feature climbed in "
        "importance; negative = it fell."
    ),
    "severity": (
        "Overall drift severity for the feature (low / medium / high), based on "
        "its PSI against the configured thresholds."
    ),
    "rank_correlation": (
        "Spearman correlation of the whole feature-importance ranking between "
        "periods. 1.0 = identical ordering, 0 = unrelated, -1.0 = reversed."
    ),
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def to_shap_array(x, class_index: int | None = None) -> np.ndarray:
    """Coerce a SHAP input to a 2D float array ``(n_samples, n_features)``.

    Supported inputs:
      * any object exposing ``.values`` (e.g. a ``shap.Explanation``),
      * a NumPy array, list, or pandas object holding SHAP values,
      * 1D data (treated as one feature), 2D data, or 3D multi-class data.

    For 3D input of shape ``(n_samples, n_features, n_classes)``:
      * ``class_index=None`` (default) collapses the class axis to the mean
        absolute value per (sample, feature);
      * ``class_index=k`` keeps only class ``k``.

    Raises
    ------
    ValueError
        For empty input, more than 3 (or 0) dimensions, an out-of-range
        ``class_index``, or any infinite value.
    """
    raw = getattr(x, "values", x)  # duck-type shap.Explanation

    # pandas containers know how to turn themselves into an ndarray
    if hasattr(raw, "to_numpy"):
        raw = raw.to_numpy()

    data = np.asarray(raw, dtype=float)

    if data.size == 0:
        raise ValueError("SHAP input is empty (0 elements).")

    rank = data.ndim
    if rank not in (1, 2, 3):
        raise ValueError(
            f"Expected SHAP values with 1, 2 or 3 dimensions, got {data.ndim}D "
            f"array of shape {data.shape}."
        )

    if rank == 1:
        # A flat vector is a single feature observed over many samples.
        data = data.reshape(-1, 1)
    elif rank == 3:
        if class_index is None:
            # Collapse the class axis to a per (sample, feature) magnitude.
            data = np.mean(np.abs(data), axis=2)
        else:
            n_classes = data.shape[2]
            if not 0 <= class_index < n_classes:
                raise ValueError(
                    f"class_index={class_index} out of range for "
                    f"{n_classes} classes."
                )
            data = data[:, :, class_index]

    if np.isinf(data).any():
        raise ValueError(
            "SHAP input contains infinite values. Please clean these before "
            "comparing (inf usually signals an upstream bug)."
        )
    return data
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def drop_nan_pair(ref: np.ndarray, cur: np.ndarray) -> tuple[np.ndarray, np.ndarray, int]:
    """Return ref/cur 1D vectors with NaNs removed, plus the count dropped.

    Each vector is filtered independently, so the two outputs may end up with
    different lengths.

    Returns
    -------
    (ref_clean, cur_clean, n_dropped)
        ``n_dropped`` is the total number of NaN entries removed from both
        inputs combined.
    """
    ref = np.asarray(ref, dtype=float)
    cur = np.asarray(cur, dtype=float)
    ref_nan = np.isnan(ref)
    cur_nan = np.isnan(cur)
    # Report the real number of removed entries instead of a constant 0, as
    # the function's contract promises.
    n_dropped = int(ref_nan.sum() + cur_nan.sum())
    return ref[~ref_nan], cur[~cur_nan], n_dropped
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def mean_abs_importance(shap_arr: np.ndarray) -> np.ndarray:
    """Per-feature global importance: the mean of ``|SHAP|`` down axis 0.

    NaN entries are skipped via ``np.nanmean``. A column made up entirely of
    NaN would otherwise produce NaN (plus a RuntimeWarning); the warning is
    silenced and the NaN result is mapped to an importance of 0.0.
    """
    with warnings.catch_warnings():
        # nanmean warns about all-NaN slices; that case is handled just below.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        magnitudes = np.abs(shap_arr)
        per_feature = np.nanmean(magnitudes, axis=0)
    cleaned = np.nan_to_num(per_feature, nan=0.0)
    return cleaned
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def importance_ranks(importance: np.ndarray) -> np.ndarray:
    """Convert an importance vector into 1-based ranks (rank 1 = largest).

    The sort is stable, so features with equal importance keep their original
    relative order and still receive distinct consecutive ranks.
    """
    descending = np.argsort(-importance, kind="stable")
    n_features = len(descending)
    positions = np.arange(1, n_features + 1)
    ranks = np.empty(n_features, dtype=int)
    # Scatter: the feature sitting at sorted position p receives rank p.
    ranks[descending] = positions
    return ranks
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def choose_bins(n_ref: int, requested_bins: int) -> int:
    """Cap the requested bin count so small references are not over-binned.

    Fewer than 10 reference samples always yields 2 bins. Otherwise the count
    is limited to ``n_ref // 5`` (about five reference samples per bin) and is
    never allowed to fall below 2.
    """
    if n_ref < 10:
        return 2
    affordable = n_ref // 5
    capped = affordable if affordable < requested_bins else requested_bins
    return capped if capped > 2 else 2
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def population_stability_index(
    reference: np.ndarray,
    current: np.ndarray,
    bins: int = 10,
    epsilon: float = 1e-6,
) -> float:
    """Compute the Population Stability Index of ``current`` vs ``reference``.

    NaN entries are removed from both inputs first. Bin edges are the unique
    quantiles of the reference (bin count chosen by ``choose_bins``), with the
    outermost edges widened to -inf/+inf so every value lands in a bin. Bin
    proportions are floored at ``epsilon`` before taking the log ratio.

    Returns 0.0 when the reference has fewer than 2 values, the current sample
    is empty, or the reference yields fewer than 3 distinct edges.
    """
    ref_vals = np.asarray(reference, dtype=float)
    cur_vals = np.asarray(current, dtype=float)
    ref_vals = ref_vals[np.logical_not(np.isnan(ref_vals))]
    cur_vals = cur_vals[np.logical_not(np.isnan(cur_vals))]

    too_little_data = len(ref_vals) < 2 or len(cur_vals) < 1
    if too_little_data:
        return 0.0

    n_bins = choose_bins(len(ref_vals), bins)
    probabilities = np.linspace(0.0, 1.0, n_bins + 1)
    cut_points = np.unique(np.quantile(ref_vals, probabilities))
    if len(cut_points) < 3:
        # Near-constant reference: the bins collapse, PSI is not meaningful.
        return 0.0
    cut_points[0] = -np.inf
    cut_points[-1] = np.inf

    proportions = []
    for sample in (ref_vals, cur_vals):
        counts, _ = np.histogram(sample, bins=cut_points)
        share = counts / max(counts.sum(), 1)
        # Floor at epsilon so empty bins do not produce log(0) or a zero divisor.
        proportions.append(np.clip(share, epsilon, None))
    ref_prop, cur_prop = proportions

    return float(np.sum((cur_prop - ref_prop) * np.log(cur_prop / ref_prop)))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def spearman_rank_correlation(rank_a: np.ndarray, rank_b: np.ndarray) -> float:
    """Spearman correlation between two rank vectors of distinct ranks.

    Uses the closed form ``1 - 6 * sum(d**2) / (n * (n**2 - 1))``, which is
    exact only when neither vector contains tied ranks.

    Parameters
    ----------
    rank_a, rank_b:
        Rank vectors of identical shape.

    Returns
    -------
    float
        1.0 for identical orderings, -1.0 for fully reversed ones. Vectors
        with fewer than 2 entries return 1.0.

    Raises
    ------
    ValueError
        If the two vectors do not have the same shape.
    """
    rank_a = np.asarray(rank_a, dtype=float)
    rank_b = np.asarray(rank_b, dtype=float)
    # Without this check a length-1 vector would silently broadcast against a
    # longer one and yield a meaningless coefficient.
    if rank_a.shape != rank_b.shape:
        raise ValueError(
            f"Rank vectors must have the same shape, got {rank_a.shape} "
            f"and {rank_b.shape}."
        )
    n = len(rank_a)
    if n < 2:
        return 1.0
    d_squared = np.sum((rank_a - rank_b) ** 2)
    return float(1.0 - (6.0 * d_squared) / (n * (n**2 - 1)))
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def severity_label(psi: float, thresholds: tuple[float, float] = (0.1, 0.2)) -> str:
    """Translate a PSI value into ``"low"``, ``"medium"`` or ``"high"``.

    ``thresholds`` is ``(medium_cutoff, high_cutoff)``; both bounds are
    inclusive, and the high cut-off is tested first.
    """
    medium_cutoff, high_cutoff = thresholds
    if psi >= high_cutoff:
        label = "high"
    elif psi >= medium_cutoff:
        label = "medium"
    else:
        label = "low"
    return label
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: shiftshap
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Monitor whether your model's SHAP explanations still hold as data drifts.
|
|
5
|
+
Author: Mayowa Samuel Olokun
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/OWNER/shiftshap
|
|
8
|
+
Project-URL: Issues, https://github.com/OWNER/shiftshap/issues
|
|
9
|
+
Keywords: shap,explainability,xai,drift,distribution-shift,model-monitoring,machine-learning
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Requires-Python: >=3.9
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: numpy>=1.21
|
|
20
|
+
Requires-Dist: pandas>=1.3
|
|
21
|
+
Provides-Extra: plot
|
|
22
|
+
Requires-Dist: matplotlib>=3.4; extra == "plot"
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
25
|
+
Requires-Dist: matplotlib>=3.4; extra == "dev"
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# shiftshap
|
|
29
|
+
|
|
30
|
+
**Monitor whether your model's SHAP explanations still hold as your data drifts.**
|
|
31
|
+
|
|
32
|
+
`shiftshap` answers a question every team running a model in production eventually
|
|
33
|
+
asks: *are my model's explanations still trustworthy?* Models live for months or
|
|
34
|
+
years — they get retrained, upstream pipelines change, and feature distributions
|
|
35
|
+
shift. When that happens, the model's reasoning quietly changes with it. The
|
|
36
|
+
feature that drove your predictions last quarter may not be the one driving them
|
|
37
|
+
today.
|
|
38
|
+
|
|
39
|
+
SHAP is excellent at explaining a model **at a single point in time**, but it has
|
|
40
|
+
no built-in way to tell you how those explanations have **changed** between two
|
|
41
|
+
points. Today people work around this by pickling explanation objects and writing
|
|
42
|
+
their own comparison scripts, or by using data-drift tools that know nothing about
|
|
43
|
+
SHAP's structure. `shiftshap` fills that gap.
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## Install
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pip install shiftshap # core
|
|
51
|
+
pip install "shiftshap[plot]" # + matplotlib for the drift chart (quotes keep zsh from globbing the brackets)
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Quickstart
|
|
55
|
+
|
|
56
|
+
If you already use SHAP, you already have everything you need. Take your SHAP
|
|
57
|
+
values from two periods — training vs. production, or last month vs. this month —
|
|
58
|
+
and pass them in:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import shiftshap
|
|
62
|
+
|
|
63
|
+
report = shiftshap.compare(reference_shap, current_shap)
|
|
64
|
+
|
|
65
|
+
print(report.summary())
|
|
66
|
+
# 2 of 5 features show HIGH explanation drift (0 medium).
|
|
67
|
+
# Top driver changed from 'income' to 'balance'.
|
|
68
|
+
# Overall rank stability (Spearman): 0.70.
|
|
69
|
+
|
|
70
|
+
print(report.details()) # plain-English narrative of the biggest movers
|
|
71
|
+
report.to_frame() # full per-feature table
|
|
72
|
+
report.plot() # rank-drift bump chart
|
|
73
|
+
shiftshap.metric_definitions() # what every metric means, in plain words
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Inputs can be `shap.Explanation` objects, NumPy arrays, lists, or pandas
|
|
77
|
+
DataFrames of shape `(n_samples, n_features)`. **Multi-class SHAP** (3D arrays)
|
|
78
|
+
is supported — by default classes are aggregated, or pass `class_index=k` to
|
|
79
|
+
focus on one class. The two periods don't need the same number of samples.
|
|
80
|
+
|
|
81
|
+
### Robust by design
|
|
82
|
+
|
|
83
|
+
`shiftshap` is built to survive real, messy production data. It handles NaNs
|
|
84
|
+
(ignored with a note), zero-variance features, tiny samples (with an explicit
|
|
85
|
+
"results unreliable" warning rather than false alarms), and multi-class outputs —
|
|
86
|
+
and it fails with clear, actionable errors on genuinely broken input (infinities,
|
|
87
|
+
mismatched feature counts, empty arrays) instead of cryptic stack traces.
|
|
88
|
+
|
|
89
|
+
## What it tells you
|
|
90
|
+
|
|
91
|
+
For every feature, `shiftshap` reports:
|
|
92
|
+
|
|
93
|
+
- **Importance drift** — how the mean absolute SHAP value changed between periods.
|
|
94
|
+
- **PSI** (Population Stability Index) on the SHAP distribution — the industry-standard
|
|
95
|
+
shift metric, with the accepted `0.2` threshold flagging significant drift.
|
|
96
|
+
- **Rank drift** — whether your most important features reordered, plus an overall
|
|
97
|
+
Spearman rank-stability score.
|
|
98
|
+
- **Severity** — a `high` / `medium` / `low` label per feature, so the output is
|
|
99
|
+
actionable at a glance.
|
|
100
|
+
|
|
101
|
+
And a **bump chart** showing how the feature-importance ranking shifted:
|
|
102
|
+
|
|
103
|
+

|
|
104
|
+
|
|
105
|
+
## Why it matters
|
|
106
|
+
|
|
107
|
+
An explanation that has silently drifted is worse than no explanation — it gives
|
|
108
|
+
false confidence. In regulated settings (finance, insurance, healthcare) teams are
|
|
109
|
+
increasingly required to show that model explanations remain valid over time.
|
|
110
|
+
`shiftshap` turns that check into two lines of code.
|
|
111
|
+
|
|
112
|
+
## Roadmap
|
|
113
|
+
|
|
114
|
+
The current release deliberately does one thing well: compare two snapshots of SHAP explanations
|
|
115
|
+
for tabular models. Planned next:
|
|
116
|
+
|
|
117
|
+
- Persistent explanation store for many time-points (not just two).
|
|
118
|
+
- Research-grade **faithfulness-under-shift** metrics beyond distributional PSI.
|
|
119
|
+
- Alerting hooks for monitoring pipelines.
|
|
120
|
+
- Support for image and text explanations.
|
|
121
|
+
|
|
122
|
+
Contributions and issues welcome.
|
|
123
|
+
|
|
124
|
+
## License
|
|
125
|
+
|
|
126
|
+
MIT
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/shiftshap/__init__.py
|
|
5
|
+
src/shiftshap/core.py
|
|
6
|
+
src/shiftshap/metrics.py
|
|
7
|
+
src/shiftshap.egg-info/PKG-INFO
|
|
8
|
+
src/shiftshap.egg-info/SOURCES.txt
|
|
9
|
+
src/shiftshap.egg-info/dependency_links.txt
|
|
10
|
+
src/shiftshap.egg-info/requires.txt
|
|
11
|
+
src/shiftshap.egg-info/top_level.txt
|
|
12
|
+
tests/test_core.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
shiftshap
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Tests for shiftshap. Run with: pytest"""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
import shiftshap
|
|
7
|
+
from shiftshap.metrics import (
|
|
8
|
+
importance_ranks,
|
|
9
|
+
mean_abs_importance,
|
|
10
|
+
population_stability_index,
|
|
11
|
+
severity_label,
|
|
12
|
+
spearman_rank_correlation,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_psi_zero_for_identical_distributions():
    """A sample compared against itself must give a PSI of (numerically) zero."""
    sample = np.random.default_rng(0).normal(size=5000)
    psi = population_stability_index(sample, sample)
    assert psi == pytest.approx(0.0, abs=1e-9)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_psi_positive_for_shifted_distribution():
    """Moving the mean from 0 to 3 (unit variance) must push PSI above 0.2."""
    generator = np.random.default_rng(1)
    baseline = generator.normal(0, 1, size=5000)
    shifted = generator.normal(3, 1, size=5000)  # clearly shifted
    psi = population_stability_index(baseline, shifted)
    assert psi > 0.2
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_importance_and_ranks():
    """The column with the largest |SHAP| is the arg-max and receives rank 1."""
    shap_values = np.array(
        [
            [0.1, -2.0, 0.5],
            [0.2, 2.0, -0.5],
        ]
    )
    importance = mean_abs_importance(shap_values)
    # Column 1 holds the largest magnitudes in both rows.
    assert np.argmax(importance) == 1
    assert importance_ranks(importance)[1] == 1
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_spearman_bounds():
    """Identical rankings score +1; a fully reversed ranking scores -1."""
    ranks = np.array([1, 2, 3, 4])
    reversed_ranks = ranks[::-1]
    assert spearman_rank_correlation(ranks, ranks) == pytest.approx(1.0)
    assert spearman_rank_correlation(ranks, reversed_ranks) == pytest.approx(-1.0)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_severity_thresholds():
    """With default cut-offs, 0.05 / 0.15 / 0.30 map to low / medium / high."""
    expected = {0.05: "low", 0.15: "medium", 0.30: "high"}
    for psi, label in expected.items():
        assert severity_label(psi) == label
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_compare_end_to_end_detects_drift():
    """End-to-end: swapping the dominant feature must be reported as drift."""
    rng = np.random.default_rng(42)
    n_samples = 2000

    def draw(scales):
        # One rng call per column, left to right, so the draw order is fixed.
        return np.column_stack([rng.normal(0, scale, n_samples) for scale in scales])

    reference = draw((2.0, 0.5, 0.2))  # feature 'a' dominates
    current = draw((0.2, 0.5, 2.0))  # feature 'c' takes over
    report = shiftshap.compare(reference, current, feature_names=["a", "b", "c"])
    table = report.to_frame()

    def top_feature(rank_column):
        return table.sort_values(rank_column).iloc[0]["feature"]

    # The top driver should have flipped from 'a' to 'c'.
    assert top_feature("reference_rank") == "a"
    assert top_feature("current_rank") == "c"
    # The importance ordering is reversed, so the rank correlation is negative.
    assert report.rank_correlation < 0
    assert report.n_high >= 1
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_feature_count_mismatch_raises():
    """Inputs with 3 vs 4 feature columns must be rejected with ValueError."""
    three_features = np.zeros((10, 3))
    four_features = np.zeros((10, 4))
    with pytest.raises(ValueError):
        shiftshap.compare(three_features, four_features)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_accepts_explanation_like_object():
    """Objects exposing ``values`` and ``feature_names`` work as inputs."""

    class FakeExplanation:
        """Minimal stand-in carrying only the two attributes set below."""

        def __init__(self, values, feature_names):
            self.values = values
            self.feature_names = feature_names

    rng = np.random.default_rng(7)
    ref = FakeExplanation(rng.normal(size=(100, 2)), ["x", "y"])
    cur = FakeExplanation(rng.normal(size=(100, 2)), ["x", "y"])
    report = shiftshap.compare(ref, cur)
    # Row order of the frame is not assumed, so compare order-insensitively.
    # Sorting (rather than building a set) also catches duplicated rows, and
    # replaces the old redundant "list == ... or set == ..." disjunction.
    assert sorted(report.to_frame()["feature"]) == ["x", "y"]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ---------------------------------------------------------------------------
|
|
98
|
+
# v0.2 hardening tests: multi-class, messy inputs, clear errors, clarity
|
|
99
|
+
# ---------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
def test_multiclass_3d_aggregates():
    """3D input of shape (200, 4, 3) collapses to a report over 4 features."""
    rng = np.random.default_rng(0)
    shape = (200, 4, 3)
    ref = rng.normal(size=shape)
    cur = rng.normal(size=shape)
    report = shiftshap.compare(ref, cur)
    assert report.n_features == 4  # collapsed across classes
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_multiclass_class_index_selects_one_class():
    """Passing ``class_index=1`` on 3D input still yields 4 features."""
    rng = np.random.default_rng(0)
    shape = (200, 4, 3)
    ref = rng.normal(size=shape)
    cur = rng.normal(size=shape)
    report = shiftshap.compare(ref, cur, class_index=1)
    assert report.n_features == 4
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_bad_class_index_raises():
    """``class_index=9`` on a 3-class input must raise ValueError."""
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(50, 4, 3))
    cur = rng.normal(size=(50, 4, 3))
    with pytest.raises(ValueError):
        shiftshap.compare(ref, cur, class_index=9)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_inf_input_raises_clearly():
    """A single +inf entry must raise ValueError mentioning "infinite"."""
    contaminated = np.zeros((10, 3))
    contaminated[0, 0] = np.inf
    clean = np.zeros((10, 3))
    with pytest.raises(ValueError, match="infinite"):
        shiftshap.compare(contaminated, clean)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def test_empty_input_raises():
    """Zero-element inputs must raise ValueError mentioning "empty"."""
    empty_ref = np.array([])
    empty_cur = np.array([])
    with pytest.raises(ValueError, match="empty"):
        shiftshap.compare(empty_ref, empty_cur)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_nan_is_handled_with_note():
    """NaNs in the reference trigger a UserWarning and a note mentioning NaN."""
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(200, 3))
    ref[::10, 1] = np.nan  # every 10th row of column 1
    cur = rng.normal(size=(200, 3))
    with pytest.warns(UserWarning):
        report = shiftshap.compare(ref, cur)
    assert any("NaN" in note for note in report.notes)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_small_sample_warns():
    """Five-row inputs must warn "unreliable" and record it in the notes."""
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(5, 3))
    cur = rng.normal(size=(5, 3))
    with pytest.warns(UserWarning, match="unreliable"):
        report = shiftshap.compare(ref, cur)
    assert any("unreliable" in note for note in report.notes)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def test_dataframe_input_accepted():
    """pandas DataFrames are accepted and report one feature per column."""
    import pandas as pd

    rng = np.random.default_rng(0)
    column_names = ["a", "b", "c"]
    ref = pd.DataFrame(rng.normal(size=(100, 3)), columns=column_names)
    cur = pd.DataFrame(rng.normal(size=(100, 3)), columns=column_names)
    report = shiftshap.compare(ref, cur)
    assert report.n_features == 3
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def test_bad_thresholds_raise():
    """``psi_thresholds=(0.3, 0.1)`` (medium above high) must raise ValueError."""
    ref = np.zeros((50, 2))
    cur = np.zeros((50, 2))
    with pytest.raises(ValueError):
        shiftshap.compare(ref, cur, psi_thresholds=(0.3, 0.1))
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def test_details_and_definitions_are_strings():
    """``details()`` returns a str and ``metric_definitions()`` has a str "psi" entry."""
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(100, 3))
    cur = rng.normal(size=(100, 3))
    report = shiftshap.compare(ref, cur)
    assert isinstance(report.details(), str)

    definitions = shiftshap.metric_definitions()
    assert "psi" in definitions
    assert isinstance(definitions["psi"], str)
|