explainiverse 0.1.1a1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- explainiverse/__init__.py +45 -1
- explainiverse/adapters/__init__.py +9 -0
- explainiverse/adapters/base_adapter.py +25 -25
- explainiverse/adapters/sklearn_adapter.py +32 -32
- explainiverse/core/__init__.py +22 -0
- explainiverse/core/explainer.py +31 -31
- explainiverse/core/explanation.py +24 -24
- explainiverse/core/registry.py +545 -0
- explainiverse/engine/__init__.py +8 -0
- explainiverse/engine/suite.py +142 -142
- explainiverse/evaluation/__init__.py +8 -0
- explainiverse/evaluation/metrics.py +232 -232
- explainiverse/explainers/__init__.py +38 -0
- explainiverse/explainers/attribution/__init__.py +9 -0
- explainiverse/explainers/attribution/lime_wrapper.py +90 -63
- explainiverse/explainers/attribution/shap_wrapper.py +89 -66
- explainiverse/explainers/counterfactual/__init__.py +8 -0
- explainiverse/explainers/counterfactual/dice_wrapper.py +302 -0
- explainiverse/explainers/global_explainers/__init__.py +23 -0
- explainiverse/explainers/global_explainers/ale.py +191 -0
- explainiverse/explainers/global_explainers/partial_dependence.py +192 -0
- explainiverse/explainers/global_explainers/permutation_importance.py +123 -0
- explainiverse/explainers/global_explainers/sage.py +164 -0
- explainiverse/explainers/rule_based/__init__.py +8 -0
- explainiverse/explainers/rule_based/anchors_wrapper.py +350 -0
- explainiverse-0.2.0.dist-info/METADATA +264 -0
- explainiverse-0.2.0.dist-info/RECORD +29 -0
- explainiverse-0.1.1a1.dist-info/METADATA +0 -128
- explainiverse-0.1.1a1.dist-info/RECORD +0 -19
- {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/LICENSE +0 -0
- {explainiverse-0.1.1a1.dist-info → explainiverse-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# src/explainiverse/explainers/global_explainers/partial_dependence.py
|
|
2
|
+
"""
|
|
3
|
+
Partial Dependence Plot (PDP) Explainer.
|
|
4
|
+
|
|
5
|
+
Shows the marginal effect of one or two features on the predicted outcome,
|
|
6
|
+
averaging over the values of all other features.
|
|
7
|
+
|
|
8
|
+
Reference:
|
|
9
|
+
Friedman, J.H. (2001). Greedy function approximation: A gradient boosting machine.
|
|
10
|
+
Annals of Statistics, 29(5), 1189-1232.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from typing import List, Optional, Union, Tuple
|
|
15
|
+
from explainiverse.core.explainer import BaseExplainer
|
|
16
|
+
from explainiverse.core.explanation import Explanation
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PartialDependenceExplainer(BaseExplainer):
    """
    Partial Dependence Plot (PDP) explainer.

    For each feature (or pair of features) of interest, every reference row
    has that feature overwritten with each value on a grid, and the model's
    predictions are averaged over the rows. The resulting curve (or surface)
    shows the marginal relationship between the feature(s) and the outcome.

    Reference:
        Friedman, J.H. (2001). Greedy function approximation: A gradient
        boosting machine. Annals of Statistics, 29(5), 1189-1232.

    Attributes:
        model: Model adapter with .predict() method
        X: Training/reference data
        feature_names: List of feature names
        grid_resolution: Number of points in the PDP grid
        percentile_range: Range of percentiles to use for grid (default: 5-95)
    """

    def __init__(
        self,
        model,
        X: np.ndarray,
        feature_names: List[str],
        grid_resolution: int = 50,
        percentile_range: Tuple[float, float] = (5, 95)
    ):
        """
        Initialize the PDP explainer.

        Args:
            model: Model adapter with .predict() method
            X: Reference dataset (n_samples, n_features)
            feature_names: List of feature names
            grid_resolution: Number of grid points for each feature
            percentile_range: Tuple of (min_percentile, max_percentile) for grid
        """
        super().__init__(model)
        self.X = np.array(X)
        self.feature_names = list(feature_names)
        self.grid_resolution = grid_resolution
        self.percentile_range = percentile_range

    def _get_feature_idx(self, feature: Union[int, str]) -> int:
        """Resolve a feature given as a name or an index to a column index."""
        if not isinstance(feature, str):
            return feature
        return self.feature_names.index(feature)

    def _create_grid(self, feature_idx: int) -> np.ndarray:
        """Build an evenly spaced grid spanning the feature's percentile range."""
        column = self.X[:, feature_idx]
        low = np.percentile(column, self.percentile_range[0])
        high = np.percentile(column, self.percentile_range[1])
        return np.linspace(low, high, self.grid_resolution)

    def _mean_prediction(self, X_synthetic: np.ndarray, target_class: int) -> float:
        """Mean model output over X_synthetic; picks the target_class column for 2-D outputs."""
        preds = self.model.predict(X_synthetic)
        if preds.ndim == 2:
            return np.mean(preds[:, target_class])
        return np.mean(preds)

    def _compute_pdp_1d(self, feature_idx: int, target_class: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute 1D partial dependence for a single feature.

        Args:
            feature_idx: Index of the feature
            target_class: Class index for which to compute PDP

        Returns:
            Tuple of (grid_values, pdp_values)
        """
        grid = self._create_grid(feature_idx)
        averaged = []
        for grid_value in grid:
            synthetic = self.X.copy()
            synthetic[:, feature_idx] = grid_value
            averaged.append(self._mean_prediction(synthetic, target_class))
        return grid, np.array(averaged)

    def _compute_pdp_2d(
        self,
        feature_idx1: int,
        feature_idx2: int,
        target_class: int = 1
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute 2D partial dependence for feature interaction.

        Args:
            feature_idx1: Index of first feature
            feature_idx2: Index of second feature
            target_class: Class index for which to compute PDP

        Returns:
            Tuple of (grid1, grid2, pdp_values_2d)
        """
        grid1 = self._create_grid(feature_idx1)
        grid2 = self._create_grid(feature_idx2)
        surface = np.zeros((len(grid1), len(grid2)))

        for i, value1 in enumerate(grid1):
            for j, value2 in enumerate(grid2):
                synthetic = self.X.copy()
                synthetic[:, feature_idx1] = value1
                synthetic[:, feature_idx2] = value2
                surface[i, j] = self._mean_prediction(synthetic, target_class)

        return grid1, grid2, surface

    def explain(
        self,
        features: List[Union[int, str, Tuple[int, int]]],
        target_class: int = 1,
        **kwargs
    ) -> Explanation:
        """
        Compute partial dependence for specified features.

        Args:
            features: List of feature indices/names or tuples for interactions
            target_class: Class index for which to compute PDP

        Returns:
            Explanation object with PDP values and grids
        """
        curves = {}
        grids = {}

        for feature in features:
            if isinstance(feature, tuple):
                # Pairwise interaction: a 2-D PDP surface.
                idx_a = self._get_feature_idx(feature[0])
                idx_b = self._get_feature_idx(feature[1])

                grid_a, grid_b, surface = self._compute_pdp_2d(idx_a, idx_b, target_class)

                key = f"{self.feature_names[idx_a]}_x_{self.feature_names[idx_b]}"
                curves[key] = surface.tolist()
                grids[key] = {"grid1": grid_a.tolist(), "grid2": grid_b.tolist()}
            else:
                # Single feature: a 1-D PDP curve.
                idx = self._get_feature_idx(feature)
                grid, curve = self._compute_pdp_1d(idx, target_class)

                key = self.feature_names[idx]
                curves[key] = curve.tolist()
                grids[key] = grid.tolist()

        return Explanation(
            explainer_name="PartialDependence",
            target_class=f"class_{target_class}",
            explanation_data={
                "pdp_values": curves,
                "grid_values": grids,
                "features_analyzed": [str(f) for f in features],
                "interaction": any(isinstance(f, tuple) for f in features)
            }
        )
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# src/explainiverse/explainers/global_explainers/permutation_importance.py
|
|
2
|
+
"""
|
|
3
|
+
Permutation Feature Importance Explainer.
|
|
4
|
+
|
|
5
|
+
Measures feature importance by measuring the decrease in model performance
|
|
6
|
+
when a feature's values are randomly shuffled.
|
|
7
|
+
|
|
8
|
+
Reference:
|
|
9
|
+
Breiman, L. (2001). Random Forests. Machine Learning, 45(1), 5-32.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from typing import List, Optional, Callable
|
|
14
|
+
from sklearn.metrics import accuracy_score
|
|
15
|
+
|
|
16
|
+
from explainiverse.core.explainer import BaseExplainer
|
|
17
|
+
from explainiverse.core.explanation import Explanation
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PermutationImportanceExplainer(BaseExplainer):
    """
    Global explainer based on permutation feature importance.

    Measures how much the model's performance decreases when each feature
    is randomly shuffled, breaking the relationship between the feature
    and the target.

    Reference:
        Breiman, L. (2001). Random Forests. Machine Learning, 45(1), 5-32.

    Attributes:
        model: Model adapter with .predict() method
        X: Feature matrix for evaluation
        y: True labels
        feature_names: List of feature names
        n_repeats: Number of times to permute each feature
        scoring_fn: Function to compute score (higher is better)
        random_state: Random seed for reproducibility
    """

    def __init__(
        self,
        model,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str],
        n_repeats: int = 10,
        scoring_fn: Optional[Callable] = None,
        random_state: int = 42
    ):
        """
        Initialize the Permutation Importance explainer.

        Args:
            model: Model adapter with .predict() method
            X: Feature matrix (n_samples, n_features)
            y: True labels (n_samples,)
            feature_names: List of feature names
            n_repeats: Number of permutation repeats per feature
            scoring_fn: Custom scoring function (default: accuracy).
                Higher return values must mean better performance.
            random_state: Random seed

        Raises:
            ValueError: If X and y disagree in length, or feature_names does
                not match the number of columns in X.
        """
        super().__init__(model)
        self.X = np.array(X)
        self.y = np.array(y)
        # Defensive copy, consistent with the other global explainers
        # (PartialDependenceExplainer, SAGEExplainer) which use list(...).
        self.feature_names = list(feature_names)

        # Fail fast on shape mismatches instead of raising an obscure
        # indexing error deep inside explain().
        if self.X.shape[0] != self.y.shape[0]:
            raise ValueError(
                f"X and y have inconsistent lengths: "
                f"{self.X.shape[0]} vs {self.y.shape[0]}"
            )
        if self.X.ndim == 2 and self.X.shape[1] != len(self.feature_names):
            raise ValueError(
                f"feature_names has {len(self.feature_names)} entries "
                f"but X has {self.X.shape[1]} columns"
            )

        self.n_repeats = n_repeats
        # Explicit None check (rather than `or`): makes the fallback intent
        # unambiguous and never misfires on unusual callables.
        self.scoring_fn = self._default_scorer if scoring_fn is None else scoring_fn
        self.random_state = random_state
        self.rng = np.random.RandomState(random_state)

    def _default_scorer(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Default scoring function: accuracy for classification."""
        # 2-D predictions are treated as per-class scores; take the argmax label.
        if y_pred.ndim == 2:
            y_pred = np.argmax(y_pred, axis=1)
        return accuracy_score(y_true, y_pred)

    def _compute_baseline_score(self) -> float:
        """Compute model performance on unperturbed data."""
        predictions = self.model.predict(self.X)
        return self.scoring_fn(self.y, predictions)

    def _permute_feature(self, X: np.ndarray, feature_idx: int) -> np.ndarray:
        """Create a copy of X with one feature column permuted in place."""
        X_permuted = X.copy()
        # Column slice of the copy is a view, so the shuffle stays local to it.
        self.rng.shuffle(X_permuted[:, feature_idx])
        return X_permuted

    def explain(self, **kwargs) -> Explanation:
        """
        Compute permutation feature importance.

        Returns:
            Explanation object with:
                - feature_attributions: dict of {feature_name: importance}
                - std: dict of {feature_name: std across repeats}
                - baseline_score: original model score
        """
        baseline_score = self._compute_baseline_score()

        importances = {}
        stds = {}

        for idx, fname in enumerate(self.feature_names):
            scores = []

            for _ in range(self.n_repeats):
                X_permuted = self._permute_feature(self.X, idx)
                predictions = self.model.predict(X_permuted)
                scores.append(self.scoring_fn(self.y, predictions))

            # Importance = drop in score when feature is permuted.
            importances[fname] = float(baseline_score - np.mean(scores))
            stds[fname] = float(np.std(scores))

        return Explanation(
            explainer_name="PermutationImportance",
            target_class="global",
            explanation_data={
                "feature_attributions": importances,
                "std": stds,
                # float() so a NumPy-scalar-returning custom scorer does not
                # leak a non-serializable type into the payload.
                "baseline_score": float(baseline_score)
            }
        )
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# src/explainiverse/explainers/global_explainers/sage.py
|
|
2
|
+
"""
|
|
3
|
+
SAGE (Shapley Additive Global importancE) Explainer.
|
|
4
|
+
|
|
5
|
+
SAGE extends SHAP to provide global feature importance by computing
|
|
6
|
+
the expected Shapley value across all samples. This gives a theoretically
|
|
7
|
+
grounded global importance measure.
|
|
8
|
+
|
|
9
|
+
Reference:
|
|
10
|
+
Covert, I., Lundberg, S., & Lee, S.I. (2020). Understanding Global Feature
|
|
11
|
+
Contributions with Additive Importance Measures. NeurIPS 2020.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from typing import List, Optional, Callable
|
|
16
|
+
from sklearn.metrics import accuracy_score, mean_squared_error
|
|
17
|
+
from explainiverse.core.explainer import BaseExplainer
|
|
18
|
+
from explainiverse.core.explanation import Explanation
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SAGEExplainer(BaseExplainer):
    """
    SAGE: Shapley Additive Global importancE.

    Computes global feature importance using Shapley values, averaging
    contributions across all samples. Unlike permutation importance,
    SAGE accounts for feature interactions.

    Attributes:
        model: Model adapter with .predict() method
        X: Feature matrix
        y: Target values
        feature_names: List of feature names
        n_permutations: Number of permutation samples for approximation
        loss_fn: Loss function (default: accuracy for classification)
    """

    def __init__(
        self,
        model,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str],
        n_permutations: int = 100,
        loss_fn: Optional[Callable] = None,
        task: str = "classification",
        random_state: int = 42
    ):
        """
        Initialize the SAGE explainer.

        Args:
            model: Model adapter with .predict() method
            X: Feature matrix (n_samples, n_features)
            y: Target values (n_samples,)
            feature_names: List of feature names
            n_permutations: Number of permutations for approximation
            loss_fn: Custom loss function (lower is better)
            task: "classification" or "regression"
            random_state: Random seed
        """
        super().__init__(model)
        self.X = np.array(X)
        self.y = np.array(y)
        self.feature_names = list(feature_names)
        self.n_permutations = n_permutations
        self.task = task
        self.random_state = random_state
        # Single RandomState shared by all masking draws: results are
        # reproducible for a fixed seed but depend on rng call order.
        self.rng = np.random.RandomState(random_state)

        if loss_fn is None:
            if task == "classification":
                # 1 - accuracy as loss; 2-D predictions are treated as
                # per-class scores and reduced via argmax.
                # NOTE(review): assumes model.predict returns labels or an
                # (n_samples, n_classes) array — confirm adapter contract.
                self.loss_fn = lambda y_true, y_pred: 1.0 - accuracy_score(
                    y_true, np.argmax(y_pred, axis=1) if y_pred.ndim == 2 else y_pred
                )
            else:
                self.loss_fn = lambda y_true, y_pred: mean_squared_error(y_true, y_pred)
        else:
            self.loss_fn = loss_fn

    def _compute_loss(self, X_masked: np.ndarray) -> float:
        """Compute loss on masked data."""
        predictions = self.model.predict(X_masked)
        return self.loss_fn(self.y, predictions)

    def _marginal_contribution(
        self,
        feature_idx: int,
        feature_order: List[int],
        position: int
    ) -> float:
        """
        Compute marginal contribution of a feature given a feature ordering.

        The marginal contribution is the change in loss when adding the feature
        to the set of features that come before it in the ordering.
        """
        n_samples, n_features = self.X.shape

        # Features before this one in the ordering
        features_before = set(feature_order[:position])
        # "With" set = the preceding features plus the feature under study.
        features_with = features_before | {feature_idx}

        # Create masked versions
        X_without = self.X.copy()
        X_with = self.X.copy()

        # Mask features NOT in the respective sets by replacing with random samples
        # (column-wise shuffling approximates marginalizing over the feature's
        # marginal distribution; feature correlations are not preserved).
        # NOTE(review): X_without and X_with draw independent permutations even
        # for features masked in both, which adds sampling variance to the
        # loss difference below.
        for j in range(n_features):
            if j not in features_before:
                # Shuffle this feature (marginalizing out)
                shuffle_idx = self.rng.permutation(n_samples)
                X_without[:, j] = self.X[shuffle_idx, j]

            if j not in features_with:
                shuffle_idx = self.rng.permutation(n_samples)
                X_with[:, j] = self.X[shuffle_idx, j]

        # Two full-dataset model evaluations per call.
        loss_without = self._compute_loss(X_without)
        loss_with = self._compute_loss(X_with)

        # Marginal contribution = reduction in loss
        return loss_without - loss_with

    def explain(self, **kwargs) -> Explanation:
        """
        Compute SAGE values for all features.

        Uses permutation sampling to approximate the Shapley values.

        Cost: O(n_permutations * n_features) calls to _marginal_contribution,
        each making two model.predict calls over the full dataset.

        Returns:
            Explanation object with global feature importance (SAGE values)
        """
        n_features = len(self.feature_names)
        sage_values = np.zeros(n_features)

        for _ in range(self.n_permutations):
            # Random feature ordering
            order = self.rng.permutation(n_features).tolist()

            # Accumulate each feature's marginal contribution at its
            # position in this ordering.
            for position, feature_idx in enumerate(order):
                contribution = self._marginal_contribution(
                    feature_idx, order, position
                )
                sage_values[feature_idx] += contribution

        # Average over permutations
        sage_values /= self.n_permutations

        # Create attribution dict
        attributions = {
            fname: float(sage_values[i])
            for i, fname in enumerate(self.feature_names)
        }

        return Explanation(
            explainer_name="SAGE",
            target_class="global",
            explanation_data={
                "feature_attributions": attributions,
                "n_permutations": self.n_permutations,
                "task": self.task
            }
        )
|