explainiverse 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {explainiverse-0.3.0 → explainiverse-0.4.0}/PKG-INFO +1 -1
- {explainiverse-0.3.0 → explainiverse-0.4.0}/pyproject.toml +1 -1
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/__init__.py +1 -1
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/registry.py +22 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/__init__.py +8 -0
- explainiverse-0.4.0/src/explainiverse/explainers/example_based/__init__.py +18 -0
- explainiverse-0.4.0/src/explainiverse/explainers/example_based/protodash.py +826 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/LICENSE +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/README.md +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/base_adapter.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/pytorch_adapter.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/explainer.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/explanation.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/engine/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/engine/suite.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/_utils.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/faithfulness.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/metrics.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/stability.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/integrated_gradients.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
- {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
{explainiverse-0.3.0 → explainiverse-0.4.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: explainiverse
-Version: 0.3.0
+Version: 0.4.0
 Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
 Home-page: https://github.com/jemsbhai/explainiverse
 License: MIT
{explainiverse-0.3.0 → explainiverse-0.4.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "explainiverse"
-version = "0.3.0"
+version = "0.4.0"
 description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
 authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
 license = "MIT"
{explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/registry.py
@@ -372,6 +372,7 @@ def _create_default_registry() -> ExplainerRegistry:
     from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
     from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
     from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+    from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
 
     registry = ExplainerRegistry()
 
@@ -604,6 +605,27 @@ def _create_default_registry() -> ExplainerRegistry:
         )
     )
 
+    # =========================================================================
+    # Example-Based Explainers
+    # =========================================================================
+
+    # Register ProtoDash
+    registry.register(
+        name="protodash",
+        explainer_class=ProtoDashExplainer,
+        meta=ExplainerMeta(
+            scope="local",
+            model_types=["any"],
+            data_types=["tabular"],
+            task_types=["classification", "regression"],
+            description="ProtoDash - prototype selection with importance weights for example-based explanations",
+            paper_reference="Gurumoorthy et al., 2019 - 'Efficient Data Representation by Selecting Prototypes' (ICDM)",
+            complexity="O(n_prototypes * n_samples^2)",
+            requires_training_data=True,
+            supports_batching=True
+        )
+    )
+
     return registry
 
 
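For orientation, the sketch below replays the registration above in user code; it is illustrative only and is not part of the diff. It assumes ExplainerRegistry and ExplainerMeta can be imported from explainiverse.core.registry, the module this hunk modifies (the diff does not show their import path), and it reuses only the constructor arguments visible in the hunk.

    # Hedged sketch: registering ProtoDash on a fresh registry, mirroring the hunk above.
    # Assumption: ExplainerRegistry and ExplainerMeta live in explainiverse.core.registry.
    from explainiverse.core.registry import ExplainerRegistry, ExplainerMeta
    from explainiverse.explainers.example_based.protodash import ProtoDashExplainer

    registry = ExplainerRegistry()
    registry.register(
        name="protodash",
        explainer_class=ProtoDashExplainer,
        meta=ExplainerMeta(
            scope="local",
            model_types=["any"],
            data_types=["tabular"],
            task_types=["classification", "regression"],
            description="ProtoDash - prototype selection with importance weights",
            paper_reference="Gurumoorthy et al., 2019 - 'Efficient Data Representation by Selecting Prototypes' (ICDM)",
            complexity="O(n_prototypes * n_samples^2)",
            requires_training_data=True,
            supports_batching=True,
        ),
    )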
{explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/__init__.py
@@ -9,12 +9,17 @@ Local Explainers (instance-level):
 - Anchors: High-precision rule-based explanations
 - Counterfactual: Diverse counterfactual explanations
 - Integrated Gradients: Gradient-based attributions for neural networks
+- DeepLIFT: Reference-based attributions for neural networks
+- DeepSHAP: DeepLIFT combined with SHAP for neural networks
 
 Global Explainers (model-level):
 - Permutation Importance: Feature importance via permutation
 - Partial Dependence: Marginal feature effects (PDP)
 - ALE: Accumulated Local Effects (unbiased for correlated features)
 - SAGE: Shapley Additive Global importancE
+
+Example-Based Explainers:
+- ProtoDash: Prototype selection with importance weights
 """
 
 from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
@@ -29,6 +34,7 @@ from explainiverse.explainers.global_explainers.sage import SAGEExplainer
 from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
 from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
 from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
 
 __all__ = [
     # Local explainers
@@ -46,4 +52,6 @@ __all__ = [
     "PartialDependenceExplainer",
     "ALEExplainer",
     "SAGEExplainer",
+    # Example-based explainers
+    "ProtoDashExplainer",
 ]
explainiverse-0.4.0/src/explainiverse/explainers/example_based/__init__.py (new file)
@@ -0,0 +1,18 @@
+# src/explainiverse/explainers/example_based/__init__.py
+"""
+Example-based explanation methods.
+
+These methods explain models by identifying representative examples
+from the training data, rather than computing feature attributions.
+
+Methods:
+- ProtoDash: Select prototypical examples with importance weights
+- (Future) Influence Functions: Identify training examples that most affect predictions
+- (Future) MMD-Critic: Find prototypes and criticisms
+"""
+
+from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
+
+__all__ = [
+    "ProtoDashExplainer",
+]
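A short orientation note (not part of the diff): together with the explainers/__init__.py hunk above, this new subpackage gives two equivalent import paths for the class added in 0.4.0. A minimal sketch:

    # Either import path is introduced by this release (see the hunks above).
    from explainiverse.explainers.example_based import ProtoDashExplainer
    # or, via the package-level re-export added to explainers/__init__.py:
    # from explainiverse.explainers import ProtoDashExplainer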
explainiverse-0.4.0/src/explainiverse/explainers/example_based/protodash.py (new file)
@@ -0,0 +1,826 @@
+# src/explainiverse/explainers/example_based/protodash.py
+"""
+ProtoDash - Prototype Selection with Importance Weights.
+
+ProtoDash selects a small set of prototypical examples from a dataset
+that best represent the data distribution or explain model predictions.
+Each prototype is assigned an importance weight indicating its contribution.
+
+The algorithm minimizes the Maximum Mean Discrepancy (MMD) between:
+- The weighted combination of selected prototypes
+- The target distribution (full dataset or specific instances)
+
+Key Features:
+- Works with any model type (or no model at all for data summarization)
+- Provides interpretable weights for each prototype
+- Supports multiple kernel functions (RBF, linear, cosine)
+- Can explain individual predictions or summarize entire datasets
+- Class-conditional prototype selection
+
+Use Cases:
+1. Dataset Summarization: "These 10 examples represent the entire dataset"
+2. Prediction Explanation: "This prediction is similar to examples A, B, C"
+3. Model Debugging: "The model relies heavily on these training examples"
+4. Data Compression: Reduce dataset while preserving distribution
+
+Reference:
+    Gurumoorthy, K.S., Dhurandhar, A., Cecchi, G., & Aggarwal, C. (2019).
+    "Efficient Data Representation by Selecting Prototypes with Importance Weights"
+    IEEE International Conference on Data Mining (ICDM).
+
+Also based on:
+    Kim, B., Khanna, R., & Koyejo, O. (2016).
+    "Examples are not Enough, Learn to Criticize! Criticism for Interpretability"
+    NeurIPS 2016.
+
+Example:
+    from explainiverse.explainers.example_based import ProtoDashExplainer
+
+    # Dataset summarization
+    explainer = ProtoDashExplainer(n_prototypes=10, kernel="rbf")
+    result = explainer.find_prototypes(X_train)
+    print(f"Prototype indices: {result.explanation_data['prototype_indices']}")
+    print(f"Weights: {result.explanation_data['weights']}")
+
+    # Explaining a prediction
+    explainer = ProtoDashExplainer(model=adapter, n_prototypes=5)
+    explanation = explainer.explain(test_instance, X_reference=X_train)
+"""
+
+import numpy as np
+from typing import List, Optional, Union, Callable, Tuple, Dict
+from scipy.spatial.distance import cdist
+from scipy.optimize import minimize
+
+from explainiverse.core.explainer import BaseExplainer
+from explainiverse.core.explanation import Explanation
+
+
+class ProtoDashExplainer(BaseExplainer):
+    """
+    ProtoDash explainer for prototype-based explanations.
+
+    Selects representative examples (prototypes) from a reference dataset
+    that best explain a target distribution or individual predictions.
+    Each prototype is assigned an importance weight.
+
+    The algorithm greedily selects prototypes that minimize the Maximum
+    Mean Discrepancy (MMD) between the weighted prototype set and the
+    target, then optimizes the weights.
+
+    Attributes:
+        model: Optional model adapter (for prediction-based explanations)
+        n_prototypes: Number of prototypes to select
+        kernel: Kernel function type ("rbf", "linear", "cosine")
+        kernel_width: Width parameter for RBF kernel (auto-computed if None)
+        epsilon: Small constant for numerical stability
+
+    Example:
+        >>> explainer = ProtoDashExplainer(n_prototypes=5, kernel="rbf")
+        >>> result = explainer.find_prototypes(X_train)
+        >>> prototypes = X_train[result.explanation_data['prototype_indices']]
+    """
+
+    def __init__(
+        self,
+        model=None,
+        n_prototypes: int = 10,
+        kernel: str = "rbf",
+        kernel_width: Optional[float] = None,
+        epsilon: float = 1e-10,
+        optimize_weights: bool = True,
+        random_state: Optional[int] = None,
+        force_n_prototypes: bool = True
+    ):
+        """
+        Initialize the ProtoDash explainer.
+
+        Args:
+            model: Optional model adapter. If provided, can use model
+                predictions in the kernel computation for explanation.
+            n_prototypes: Number of prototypes to select (default: 10).
+            kernel: Kernel function type:
+                - "rbf": Radial Basis Function (Gaussian) kernel
+                - "linear": Linear kernel (dot product)
+                - "cosine": Cosine similarity kernel
+            kernel_width: Width (sigma) for RBF kernel. If None, uses
+                median heuristic based on pairwise distances.
+            epsilon: Small constant for numerical stability (default: 1e-10).
+            optimize_weights: If True, optimize weights after greedy selection.
+                If False, use weights from greedy selection only.
+            random_state: Random seed for reproducibility.
+            force_n_prototypes: If True (default), always select exactly
+                n_prototypes (or all available if fewer).
+                If False, may stop early when gain becomes
+                negative (original ProtoDash behavior).
+        """
+        super().__init__(model)
+
+        self.n_prototypes = n_prototypes
+        self.kernel = kernel.lower()
+        self.kernel_width = kernel_width
+        self.epsilon = epsilon
+        self.optimize_weights = optimize_weights
+        self.random_state = random_state
+        self.force_n_prototypes = force_n_prototypes
+
+        if self.kernel not in ["rbf", "linear", "cosine"]:
+            raise ValueError(
+                f"Unknown kernel '{kernel}'. Supported: 'rbf', 'linear', 'cosine'"
+            )
+
+        # Cache for kernel matrix
+        self._kernel_matrix_cache = None
+        self._reference_data_hash = None
+
+    def _compute_kernel_width(self, X: np.ndarray) -> float:
+        """
+        Compute kernel width using median heuristic.
+
+        The median heuristic sets sigma = median of pairwise distances,
+        which is a common rule of thumb for RBF kernels.
+
+        Args:
+            X: Data matrix of shape (n_samples, n_features)
+
+        Returns:
+            Kernel width (sigma) value
+        """
+        # Subsample for efficiency if dataset is large
+        n_samples = X.shape[0]
+        if n_samples > 1000:
+            if self.random_state is not None:
+                np.random.seed(self.random_state)
+            indices = np.random.choice(n_samples, size=1000, replace=False)
+            X_sample = X[indices]
+        else:
+            X_sample = X
+
+        # Compute pairwise distances
+        distances = cdist(X_sample, X_sample, metric='euclidean')
+
+        # Get median of non-zero distances
+        mask = distances > 0
+        if np.any(mask):
+            median_dist = np.median(distances[mask])
+        else:
+            median_dist = 1.0
+
+        return max(median_dist, self.epsilon)
+
+    def _compute_kernel(
+        self,
+        X: np.ndarray,
+        Y: Optional[np.ndarray] = None,
+        kernel_width: Optional[float] = None
+    ) -> np.ndarray:
+        """
+        Compute kernel matrix between X and Y.
+
+        Args:
+            X: First data matrix of shape (n_samples_X, n_features)
+            Y: Second data matrix of shape (n_samples_Y, n_features).
+                If None, computes K(X, X).
+            kernel_width: Override kernel width for RBF kernel.
+
+        Returns:
+            Kernel matrix of shape (n_samples_X, n_samples_Y)
+        """
+        if Y is None:
+            Y = X
+
+        if self.kernel == "rbf":
+            sigma = kernel_width or self.kernel_width
+            if sigma is None:
+                sigma = self._compute_kernel_width(X)
+
+            # K(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
+            sq_dists = cdist(X, Y, metric='sqeuclidean')
+            K = np.exp(-sq_dists / (2 * sigma ** 2))
+
+        elif self.kernel == "linear":
+            # K(x, y) = x · y
+            K = X @ Y.T
+
+        elif self.kernel == "cosine":
+            # K(x, y) = (x · y) / (||x|| * ||y||)
+            X_norm = X / (np.linalg.norm(X, axis=1, keepdims=True) + self.epsilon)
+            Y_norm = Y / (np.linalg.norm(Y, axis=1, keepdims=True) + self.epsilon)
+            K = X_norm @ Y_norm.T
+
+        else:
+            raise ValueError(f"Unknown kernel: {self.kernel}")
+
+        return K
+
+    def _greedy_prototype_selection(
+        self,
+        K_ref_ref: np.ndarray,
+        K_ref_target: np.ndarray,
+        n_prototypes: int,
+        force_n_prototypes: bool = True
+    ) -> Tuple[List[int], np.ndarray]:
+        """
+        ProtoDash greedy prototype selection with iterative weight optimization.
+
+        Implements the algorithm from:
+        Gurumoorthy et al., 2019 - "Efficient Data Representation by Selecting
+        Prototypes with Importance Weights" (ICDM)
+
+        The algorithm solves:
+            min_w (1/2) w^T K w - w^T μ
+            s.t. w >= 0
+
+        where μ_j = mean(K(x_j, target_points)) is the mean kernel similarity
+        of candidate j to all target points.
+
+        At each iteration:
+        1. Compute gradient gain for each unselected candidate
+        2. Select the candidate with maximum positive gain
+        3. Re-optimize weights over all selected prototypes
+
+        Args:
+            K_ref_ref: Kernel matrix K(reference, reference) of shape (n_ref, n_ref)
+            K_ref_target: Kernel matrix K(reference, target) of shape (n_ref, n_target)
+            n_prototypes: Number of prototypes to select
+            force_n_prototypes: If True, always select n_prototypes even if gain
+                becomes negative. If False, stop when no positive gain.
+
+        Returns:
+            Tuple of (prototype_indices, weights)
+        """
+        n_ref = K_ref_ref.shape[0]
+
+        # μ_j = mean kernel similarity of candidate j to target distribution
+        # This is the linear term in the QP objective
+        mu = K_ref_target.mean(axis=1)
+
+        # Track selected prototypes and their optimized weights
+        selected_indices = []
+        # Full weight vector (sparse, only selected indices are non-zero)
+        weights = np.zeros(n_ref)
+
+        for iteration in range(min(n_prototypes, n_ref)):
+            # Compute gradient gain for each candidate
+            # For the objective L(w) = (1/2) w^T K w - w^T μ
+            # Gradient: ∇L = K w - μ
+            # Gain for adding point j (currently w_j = 0): gain_j = μ_j - (Kw)_j
+            # We want to maximize gain, which means minimizing the objective
+
+            gradient = K_ref_ref @ weights - mu  # ∇L
+            gains = -gradient  # gain = μ - Kw (negative gradient = descent direction)
+
+            # Mask already selected indices
+            gains_masked = gains.copy()
+            gains_masked[selected_indices] = -np.inf
+
+            # Select candidate with maximum gain
+            best_idx = np.argmax(gains_masked)
+            best_gain = gains_masked[best_idx]
+
+            # Early stopping check (only if not forcing n_prototypes)
+            if not force_n_prototypes and best_gain <= self.epsilon:
+                break
+
+            selected_indices.append(best_idx)
+
+            # Re-optimize weights over all selected prototypes
+            # Solve: min_w (1/2) w^T K_ss w - w^T μ_s, s.t. w >= 0
+            # where K_ss is kernel matrix restricted to selected indices
+            # and μ_s is mu restricted to selected indices
+
+            selected_arr = np.array(selected_indices)
+            K_selected = K_ref_ref[np.ix_(selected_arr, selected_arr)]
+            mu_selected = mu[selected_arr]
+
+            # Optimize weights for selected prototypes
+            w_selected = self._optimize_weights_qp(K_selected, mu_selected)
+
+            # Update full weight vector
+            weights = np.zeros(n_ref)
+            weights[selected_arr] = w_selected
+
+        # Return only the selected indices and their weights
+        if len(selected_indices) == 0:
+            return [], np.array([])
+
+        final_weights = weights[np.array(selected_indices)]
+        return selected_indices, final_weights
+
+    def _optimize_weights_qp(
+        self,
+        K: np.ndarray,
+        mu: np.ndarray,
+        normalize: bool = False
+    ) -> np.ndarray:
+        """
+        Optimize prototype weights via constrained quadratic programming.
+
+        Solves:
+            min_w (1/2) w^T K w - w^T μ
+            s.t. w >= 0
+            (optional) sum(w) = 1
+
+        Uses scipy.optimize.minimize with SLSQP method.
+
+        Args:
+            K: Kernel matrix between selected prototypes (m x m)
+            mu: Mean kernel similarity to target for each prototype (m,)
+            normalize: If True, constrain weights to sum to 1
+
+        Returns:
+            Optimized non-negative weights
+        """
+        m = K.shape[0]
+
+        if m == 0:
+            return np.array([])
+
+        if m == 1:
+            # Single prototype: optimal weight is μ/K if K > 0
+            if K[0, 0] > self.epsilon:
+                w = max(mu[0] / K[0, 0], 0)
+            else:
+                w = 1.0
+            return np.array([w]) if not normalize else np.array([1.0])
+
+        # Add small regularization for numerical stability
+        K_reg = K + self.epsilon * np.eye(m)
+
+        # Objective: (1/2) w^T K w - w^T μ
+        def objective(w):
+            return 0.5 * w @ K_reg @ w - w @ mu
+
+        def gradient(w):
+            return K_reg @ w - mu
+
+        # Initial guess: equal weights
+        w0 = np.ones(m) / m
+
+        # Bounds: w >= 0
+        bounds = [(0, None) for _ in range(m)]
+
+        # Constraints
+        constraints = []
+        if normalize:
+            constraints.append({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0})
+
+        # Optimize
+        result = minimize(
+            objective,
+            w0,
+            method='SLSQP',
+            jac=gradient,
+            bounds=bounds,
+            constraints=constraints,
+            options={'maxiter': 500, 'ftol': 1e-12}
+        )
+
+        weights = result.x
+
+        # Ensure non-negativity (numerical cleanup)
+        weights = np.maximum(weights, 0)
+
+        return weights
+
+    def _optimize_weights(
+        self,
+        K_proto_proto: np.ndarray,
+        K_proto_target: np.ndarray,
+        initial_weights: np.ndarray
+    ) -> np.ndarray:
+        """
+        Final weight optimization for selected prototypes.
+
+        This is called after greedy selection to do a final refinement
+        of weights, optionally with normalization for interpretability.
+
+        Solves the same QP as _optimize_weights_qp but uses the
+        mean kernel to target as the linear term.
+
+        Args:
+            K_proto_proto: Kernel matrix between prototypes (m x m)
+            K_proto_target: Kernel matrix from prototypes to target (m x n_target)
+            initial_weights: Initial weights from greedy selection
+
+        Returns:
+            Optimized weights (non-negative, optionally normalized)
+        """
+        n_proto = K_proto_proto.shape[0]
+
+        if n_proto == 0:
+            return np.array([])
+
+        if n_proto == 1:
+            return np.array([1.0])  # Single prototype gets weight 1
+
+        # Target: mean kernel to target points
+        mu = K_proto_target.mean(axis=1)
+
+        # Use the QP solver
+        weights = self._optimize_weights_qp(K_proto_proto, mu, normalize=False)
+
+        # Normalize for interpretability (weights sum to 1)
+        weight_sum = weights.sum()
+        if weight_sum > self.epsilon:
+            weights = weights / weight_sum
+        else:
+            # Fallback to equal weights if optimization failed
+            weights = np.ones(n_proto) / n_proto
+
+        return weights
+
+    def find_prototypes(
+        self,
+        X: np.ndarray,
+        y: Optional[np.ndarray] = None,
+        target_class: Optional[int] = None,
+        feature_names: Optional[List[str]] = None,
+        return_mmd: bool = False
+    ) -> Explanation:
+        """
+        Find prototypes that summarize a dataset.
+
+        Selects a small set of examples from X that best represent
+        the data distribution. If y is provided, can select prototypes
+        for a specific class.
+
+        Args:
+            X: Data matrix of shape (n_samples, n_features).
+            y: Optional labels. If provided with target_class, selects
+                prototypes only from that class.
+            target_class: If provided with y, only consider examples
+                from this class as candidates.
+            feature_names: Optional list of feature names.
+            return_mmd: If True, include MMD score in explanation.
+
+        Returns:
+            Explanation object containing:
+            - prototype_indices: Indices of selected prototypes in X
+            - weights: Importance weight for each prototype
+            - prototypes: The actual prototype data points
+            - mmd_score: (optional) Final MMD between prototypes and data
+        """
+        X = np.asarray(X, dtype=np.float64)
+
+        if X.ndim == 1:
+            X = X.reshape(1, -1)
+
+        n_samples, n_features = X.shape
+
+        # Filter by class if specified
+        if y is not None and target_class is not None:
+            y = np.asarray(y)
+            class_mask = (y == target_class)
+            X_candidates = X[class_mask]
+            original_indices = np.where(class_mask)[0]
+        else:
+            X_candidates = X
+            original_indices = np.arange(n_samples)
+
+        n_candidates = X_candidates.shape[0]
+        n_proto = min(self.n_prototypes, n_candidates)
+
+        if n_proto == 0:
+            raise ValueError("No candidate examples available for prototype selection.")
+
+        # Auto-compute kernel width if needed
+        if self.kernel == "rbf" and self.kernel_width is None:
+            self.kernel_width = self._compute_kernel_width(X_candidates)
+
+        # Compute kernel matrices
+        # K(candidates, candidates) for prototype selection
+        # K(candidates, X) for representing the full distribution
+        K_cand_cand = self._compute_kernel(X_candidates, X_candidates)
+        K_cand_all = self._compute_kernel(X_candidates, X)
+
+        # Greedy prototype selection
+        local_indices, greedy_weights = self._greedy_prototype_selection(
+            K_cand_cand, K_cand_all, n_proto, self.force_n_prototypes
+        )
+
+        # Convert to original indices
+        prototype_indices = [int(original_indices[i]) for i in local_indices]
+
+        # Optimize weights if requested
+        if self.optimize_weights and len(local_indices) > 1:
+            # Get kernel matrices for selected prototypes
+            proto_local_idx = np.array(local_indices)
+            K_proto_proto = K_cand_cand[np.ix_(proto_local_idx, proto_local_idx)]
+            K_proto_all = K_cand_all[proto_local_idx, :]
+
+            weights = self._optimize_weights(K_proto_proto, K_proto_all, greedy_weights)
+        else:
+            # Normalize greedy weights for interpretability
+            weights = greedy_weights.copy()
+            weight_sum = weights.sum()
+            if weight_sum > self.epsilon:
+                weights = weights / weight_sum
+            elif len(weights) > 0:
+                weights = np.ones(len(weights)) / len(weights)
+
+        # Build explanation data
+        explanation_data = {
+            "prototype_indices": prototype_indices,
+            "weights": weights.tolist(),
+            "prototypes": X[prototype_indices].tolist(),
+            "n_prototypes": len(prototype_indices),
+            "kernel": self.kernel,
+            "kernel_width": self.kernel_width if self.kernel == "rbf" else None,
+        }
+
+        if feature_names:
+            explanation_data["feature_names"] = feature_names
+
+        # Compute MMD if requested
+        if return_mmd:
+            proto_idx_local = np.array(local_indices)
+            K_pp = K_cand_cand[np.ix_(proto_idx_local, proto_idx_local)]
+            K_pa = K_cand_all[proto_idx_local, :]
+            K_aa = self._compute_kernel(X, X)
+
+            # MMD^2 = w^T K_pp w - 2 * w^T K_pa.mean() + K_aa.mean()
+            w = np.array(weights)
+            mmd_sq = w @ K_pp @ w - 2 * w @ K_pa.mean(axis=1) + K_aa.mean()
+            mmd = np.sqrt(max(mmd_sq, 0))
+
+            explanation_data["mmd_score"] = float(mmd)
+
+        # Determine label
+        if target_class is not None:
+            label_name = f"class_{target_class}"
+        else:
+            label_name = "dataset"
+
+        return Explanation(
+            explainer_name="ProtoDash",
+            target_class=label_name,
+            explanation_data=explanation_data
+        )
+
+    def explain(
+        self,
+        instance: np.ndarray,
+        X_reference: np.ndarray,
+        feature_names: Optional[List[str]] = None,
+        use_predictions: bool = False,
+        return_similarity: bool = True
+    ) -> Explanation:
+        """
+        Explain a prediction by finding similar prototypes.
+
+        Finds prototypes from the reference set that are most similar
+        to the given instance, providing a "this is like..." explanation.
+
+        Args:
+            instance: Instance to explain (1D array of shape n_features).
+            X_reference: Reference dataset to select prototypes from
+                (shape: n_samples, n_features).
+            feature_names: Optional list of feature names.
+            use_predictions: If True and model is provided, include model
+                predictions in the similarity computation.
+            return_similarity: If True, include similarity scores.
+
+        Returns:
+            Explanation object containing prototype indices and weights.
+        """
+        instance = np.asarray(instance, dtype=np.float64).flatten()
+        X_reference = np.asarray(X_reference, dtype=np.float64)
+
+        if X_reference.ndim == 1:
+            X_reference = X_reference.reshape(1, -1)
+
+        n_ref, n_features = X_reference.shape
+        n_proto = min(self.n_prototypes, n_ref)
+
+        # Auto-compute kernel width if needed
+        if self.kernel == "rbf" and self.kernel_width is None:
+            self.kernel_width = self._compute_kernel_width(X_reference)
+
+        # If using predictions and model is available, augment features
+        if use_predictions and self.model is not None:
+            # Get predictions for instance and reference
+            instance_pred = self.model.predict(instance.reshape(1, -1)).flatten()
+            ref_preds = self.model.predict(X_reference)
+
+            # Augment features with predictions
+            instance_aug = np.concatenate([instance, instance_pred])
+            X_ref_aug = np.hstack([X_reference, ref_preds])
+        else:
+            instance_aug = instance
+            X_ref_aug = X_reference
+
+        # Compute kernel matrices
+        # K(reference, reference) for prototype selection
+        # K(reference, instance) as target
+        K_ref_ref = self._compute_kernel(X_ref_aug, X_ref_aug)
+        K_ref_instance = self._compute_kernel(X_ref_aug, instance_aug.reshape(1, -1))
+
+        # Greedy prototype selection
+        prototype_indices, greedy_weights = self._greedy_prototype_selection(
+            K_ref_ref, K_ref_instance, n_proto, self.force_n_prototypes
+        )
+
+        # Optimize weights
+        if self.optimize_weights and len(prototype_indices) > 1:
+            proto_idx = np.array(prototype_indices)
+            K_proto_proto = K_ref_ref[np.ix_(proto_idx, proto_idx)]
+            K_proto_instance = K_ref_instance[proto_idx, :]
+
+            weights = self._optimize_weights(K_proto_proto, K_proto_instance, greedy_weights)
+        else:
+            # Normalize greedy weights for interpretability
+            weights = greedy_weights.copy()
+            weight_sum = weights.sum()
+            if weight_sum > self.epsilon:
+                weights = weights / weight_sum
+            elif len(weights) > 0:
+                weights = np.ones(len(weights)) / len(weights)
+
+        # Build explanation data
+        explanation_data = {
+            "prototype_indices": [int(i) for i in prototype_indices],
+            "weights": weights.tolist(),
+            "prototypes": X_reference[prototype_indices].tolist(),
+            "n_prototypes": len(prototype_indices),
+            "kernel": self.kernel,
+            "kernel_width": self.kernel_width if self.kernel == "rbf" else None,
+            "instance": instance.tolist(),
+        }
+
+        if feature_names:
+            explanation_data["feature_names"] = feature_names
+
+        # Add similarity scores
+        if return_similarity:
+            K_instance_proto = self._compute_kernel(
+                instance.reshape(1, -1),
+                X_reference[prototype_indices]
+            ).flatten()
+            explanation_data["similarity_scores"] = K_instance_proto.tolist()
+
+        # Add model predictions if available
+        if self.model is not None:
+            instance_pred = self.model.predict(instance.reshape(1, -1))
+            proto_preds = self.model.predict(X_reference[prototype_indices])
+
+            explanation_data["instance_prediction"] = instance_pred.tolist()
+            explanation_data["prototype_predictions"] = proto_preds.tolist()
+
+        return Explanation(
+            explainer_name="ProtoDash",
+            target_class="instance_explanation",
+            explanation_data=explanation_data
+        )
+
+    def explain_batch(
+        self,
+        X: np.ndarray,
+        X_reference: np.ndarray,
+        feature_names: Optional[List[str]] = None
+    ) -> List[Explanation]:
+        """
+        Explain multiple instances.
+
+        Args:
+            X: Instances to explain (n_instances, n_features).
+            X_reference: Reference dataset for prototype selection.
+            feature_names: Optional feature names.
+
+        Returns:
+            List of Explanation objects, one per instance.
+        """
+        X = np.asarray(X, dtype=np.float64)
+        if X.ndim == 1:
+            X = X.reshape(1, -1)
+
+        return [
+            self.explain(X[i], X_reference, feature_names)
+            for i in range(X.shape[0])
+        ]
+
+    def find_criticisms(
+        self,
+        X: np.ndarray,
+        prototype_indices: List[int],
+        n_criticisms: int = 5,
+        feature_names: Optional[List[str]] = None
+    ) -> Explanation:
+        """
+        Find criticisms - examples not well-represented by prototypes.
+
+        Criticisms are data points that are furthest from the prototype
+        representation, highlighting unusual or edge-case examples.
+
+        This implements the criticism selection from MMD-Critic (Kim et al., 2016).
+
+        Args:
+            X: Full dataset.
+            prototype_indices: Indices of already-selected prototypes.
+            n_criticisms: Number of criticisms to find.
+            feature_names: Optional feature names.
+
+        Returns:
+            Explanation with criticism indices and their "unusualness" scores.
+        """
+        X = np.asarray(X, dtype=np.float64)
+        n_samples = X.shape[0]
+
+        prototype_indices = list(prototype_indices)
+        n_crit = min(n_criticisms, n_samples - len(prototype_indices))
+
+        if n_crit <= 0:
+            return Explanation(
+                explainer_name="ProtoDash_Criticisms",
+                target_class="criticisms",
+                explanation_data={
+                    "criticism_indices": [],
+                    "unusualness_scores": [],
+                    "criticisms": []
+                }
+            )
+
+        # Auto-compute kernel width if needed
+        if self.kernel == "rbf" and self.kernel_width is None:
+            self.kernel_width = self._compute_kernel_width(X)
+
+        # Compute kernel from all points to prototypes
+        X_proto = X[prototype_indices]
+        K_all_proto = self._compute_kernel(X, X_proto)
+
+        # For each point, compute its "witness function" value
+        # High values = well-represented by prototypes
+        # Low values = not well-represented (criticisms)
+
+        # Mean kernel distance to prototypes
+        mean_sim_to_protos = K_all_proto.mean(axis=1)
+
+        # Mean kernel value to all other points (density estimate)
+        K_all_all = self._compute_kernel(X, X)
+        mean_sim_to_all = K_all_all.mean(axis=1)
+
+        # Unusualness = difference between expected similarity and prototype similarity
+        # Points with high unusualness are criticisms
+        unusualness = mean_sim_to_all - mean_sim_to_protos
+
+        # Exclude prototypes from consideration
+        unusualness[prototype_indices] = -np.inf
+
+        # Select top criticisms
+        criticism_indices = np.argsort(unusualness)[-n_crit:][::-1].tolist()
+        criticism_scores = unusualness[criticism_indices].tolist()
+
+        return Explanation(
+            explainer_name="ProtoDash_Criticisms",
+            target_class="criticisms",
+            explanation_data={
+                "criticism_indices": criticism_indices,
+                "unusualness_scores": criticism_scores,
+                "criticisms": X[criticism_indices].tolist(),
+                "n_criticisms": len(criticism_indices),
+                "feature_names": feature_names
+            }
+        )
+
+    def get_prototype_summary(
+        self,
+        X: np.ndarray,
+        y: Optional[np.ndarray] = None,
+        feature_names: Optional[List[str]] = None,
+        include_criticisms: bool = True,
+        n_criticisms: int = 5
+    ) -> Dict:
+        """
+        Generate a complete prototype-based summary of a dataset.
+
+        Combines prototype selection with optional criticisms for a
+        complete data summary.
+
+        Args:
+            X: Dataset to summarize.
+            y: Optional labels.
+            feature_names: Optional feature names.
+            include_criticisms: Whether to also find criticisms.
+            n_criticisms: Number of criticisms if including them.
+
+        Returns:
+            Dictionary with prototypes, weights, and optionally criticisms.
+        """
+        # Find prototypes
+        proto_exp = self.find_prototypes(X, y, feature_names=feature_names, return_mmd=True)
+
+        result = {
+            "prototypes": proto_exp.explanation_data,
+        }
+
+        # Find criticisms if requested
+        if include_criticisms:
+            crit_exp = self.find_criticisms(
+                X,
+                proto_exp.explanation_data["prototype_indices"],
+                n_criticisms,
+                feature_names
+            )
+            result["criticisms"] = crit_exp.explanation_data
+
+        return result
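To tie the new file together, here is a hedged, self-contained usage sketch built only from the signatures and docstrings shown in the diff above. It assumes explainiverse 0.4.0 (with numpy and scipy) is installed exactly as packaged here; the synthetic data, the variable names, and the printed keys chosen are illustrative, not part of the package.

    # Hedged end-to-end sketch of the example-based API added in 0.4.0.
    import numpy as np
    from explainiverse.explainers.example_based import ProtoDashExplainer

    rng = np.random.default_rng(0)
    X_train = rng.normal(size=(200, 4))    # toy tabular dataset (illustrative)
    test_instance = rng.normal(size=4)     # one instance to explain

    explainer = ProtoDashExplainer(n_prototypes=5, kernel="rbf", random_state=0)

    # 1. Dataset summarization: which 5 rows best represent X_train, with what weights?
    summary = explainer.find_prototypes(X_train, return_mmd=True)
    print(summary.explanation_data["prototype_indices"])
    print(summary.explanation_data["weights"])
    print(summary.explanation_data["mmd_score"])

    # 2. Instance explanation: "this prediction is like these reference examples".
    local = explainer.explain(test_instance, X_reference=X_train)
    print(local.explanation_data["prototype_indices"])
    print(local.explanation_data["similarity_scores"])

    # 3. Criticisms (MMD-Critic style): rows the selected prototypes represent poorly.
    crit = explainer.find_criticisms(
        X_train, summary.explanation_data["prototype_indices"], n_criticisms=3
    )
    print(crit.explanation_data["criticism_indices"])

    # 4. One-call summary combining prototypes and criticisms.
    report = explainer.get_prototype_summary(X_train, include_criticisms=True)
    print(report["prototypes"]["n_prototypes"])

No model adapter is passed above; per the module docstring, a model is optional and is only used when predictions should influence the similarity computation (model=adapter, use_predictions=True).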
All remaining files listed above are unchanged between 0.3.0 and 0.4.0 (renamed only by the version directory).