explainiverse 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {explainiverse-0.3.0 → explainiverse-0.4.0}/PKG-INFO +1 -1
  2. {explainiverse-0.3.0 → explainiverse-0.4.0}/pyproject.toml +1 -1
  3. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/__init__.py +1 -1
  4. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/registry.py +22 -0
  5. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/__init__.py +8 -0
  6. explainiverse-0.4.0/src/explainiverse/explainers/example_based/__init__.py +18 -0
  7. explainiverse-0.4.0/src/explainiverse/explainers/example_based/protodash.py +826 -0
  8. {explainiverse-0.3.0 → explainiverse-0.4.0}/LICENSE +0 -0
  9. {explainiverse-0.3.0 → explainiverse-0.4.0}/README.md +0 -0
  10. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/__init__.py +0 -0
  11. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/base_adapter.py +0 -0
  12. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/pytorch_adapter.py +0 -0
  13. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
  14. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/__init__.py +0 -0
  15. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/explainer.py +0 -0
  16. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/core/explanation.py +0 -0
  17. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/engine/__init__.py +0 -0
  18. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/engine/suite.py +0 -0
  19. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/__init__.py +0 -0
  20. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/_utils.py +0 -0
  21. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/faithfulness.py +0 -0
  22. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/metrics.py +0 -0
  23. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/evaluation/stability.py +0 -0
  24. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/__init__.py +0 -0
  25. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
  26. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
  27. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
  28. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
  29. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
  30. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
  31. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
  32. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
  33. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
  34. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
  35. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/__init__.py +0 -0
  36. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/deeplift.py +0 -0
  37. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/gradcam.py +0 -0
  38. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/gradient/integrated_gradients.py +0 -0
  39. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
  40. {explainiverse-0.3.0 → explainiverse-0.4.0}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: explainiverse
- Version: 0.3.0
+ Version: 0.4.0
  Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
  Home-page: https://github.com/jemsbhai/explainiverse
  License: MIT
pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "explainiverse"
- version = "0.3.0"
+ version = "0.4.0"
  description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
  authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
  license = "MIT"
src/explainiverse/__init__.py
@@ -33,7 +33,7 @@ from explainiverse.adapters.sklearn_adapter import SklearnAdapter
  from explainiverse.adapters import TORCH_AVAILABLE
  from explainiverse.engine.suite import ExplanationSuite

- __version__ = "0.2.5"
+ __version__ = "0.4.0"

  __all__ = [
      # Core
src/explainiverse/core/registry.py
@@ -372,6 +372,7 @@ def _create_default_registry() -> ExplainerRegistry:
      from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
      from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
      from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+     from explainiverse.explainers.example_based.protodash import ProtoDashExplainer

      registry = ExplainerRegistry()

@@ -604,6 +605,27 @@ def _create_default_registry() -> ExplainerRegistry:
          )
      )

+     # =========================================================================
+     # Example-Based Explainers
+     # =========================================================================
+
+     # Register ProtoDash
+     registry.register(
+         name="protodash",
+         explainer_class=ProtoDashExplainer,
+         meta=ExplainerMeta(
+             scope="local",
+             model_types=["any"],
+             data_types=["tabular"],
+             task_types=["classification", "regression"],
+             description="ProtoDash - prototype selection with importance weights for example-based explanations",
+             paper_reference="Gurumoorthy et al., 2019 - 'Efficient Data Representation by Selecting Prototypes' (ICDM)",
+             complexity="O(n_prototypes * n_samples^2)",
+             requires_training_data=True,
+             supports_batching=True
+         )
+     )
+
      return registry

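The registration above follows the same ExplainerMeta contract used for the other built-in explainers. As a rough sketch of that contract, a third-party explainer could be added to the same registry as follows; "my_method" and MyExplainer are placeholder names, not part of this release, and the metadata values are illustrative only:

    # Hypothetical registration mirroring the ProtoDash entry above.
    # MyExplainer and all field values below are placeholders.
    registry.register(
        name="my_method",
        explainer_class=MyExplainer,
        meta=ExplainerMeta(
            scope="local",
            model_types=["any"],
            data_types=["tabular"],
            task_types=["classification"],
            description="Placeholder entry showing the registration fields",
            paper_reference="n/a",
            complexity="O(n)",
            requires_training_data=False,
            supports_batching=False
        )
    )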
 
src/explainiverse/explainers/__init__.py
@@ -9,12 +9,17 @@ Local Explainers (instance-level):
  - Anchors: High-precision rule-based explanations
  - Counterfactual: Diverse counterfactual explanations
  - Integrated Gradients: Gradient-based attributions for neural networks
+ - DeepLIFT: Reference-based attributions for neural networks
+ - DeepSHAP: DeepLIFT combined with SHAP for neural networks

  Global Explainers (model-level):
  - Permutation Importance: Feature importance via permutation
  - Partial Dependence: Marginal feature effects (PDP)
  - ALE: Accumulated Local Effects (unbiased for correlated features)
  - SAGE: Shapley Additive Global importancE
+
+ Example-Based Explainers:
+ - ProtoDash: Prototype selection with importance weights
  """

  from explainiverse.explainers.attribution.lime_wrapper import LimeExplainer
@@ -29,6 +34,7 @@ from explainiverse.explainers.global_explainers.sage import SAGEExplainer
  from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
  from explainiverse.explainers.gradient.gradcam import GradCAMExplainer
  from explainiverse.explainers.gradient.deeplift import DeepLIFTExplainer, DeepLIFTShapExplainer
+ from explainiverse.explainers.example_based.protodash import ProtoDashExplainer

  __all__ = [
      # Local explainers
@@ -46,4 +52,6 @@ __all__ = [
      "PartialDependenceExplainer",
      "ALEExplainer",
      "SAGEExplainer",
+     # Example-based explainers
+     "ProtoDashExplainer",
  ]
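With the re-export above, ProtoDashExplainer is importable directly from explainiverse.explainers. A minimal dataset-summarization sketch, assuming X_train is a NumPy feature matrix (the random data below is purely illustrative):

    import numpy as np
    from explainiverse.explainers import ProtoDashExplainer

    X_train = np.random.rand(200, 4)  # stand-in for real tabular training data

    explainer = ProtoDashExplainer(n_prototypes=10, kernel="rbf")
    result = explainer.find_prototypes(X_train)

    # Indices of the selected prototypes and their importance weights
    print(result.explanation_data["prototype_indices"])
    print(result.explanation_data["weights"])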
src/explainiverse/explainers/example_based/__init__.py (new file)
@@ -0,0 +1,18 @@
+ # src/explainiverse/explainers/example_based/__init__.py
+ """
+ Example-based explanation methods.
+
+ These methods explain models by identifying representative examples
+ from the training data, rather than computing feature attributions.
+
+ Methods:
+     - ProtoDash: Select prototypical examples with importance weights
+     - (Future) Influence Functions: Identify training examples that most affect predictions
+     - (Future) MMD-Critic: Find prototypes and criticisms
+ """
+
+ from explainiverse.explainers.example_based.protodash import ProtoDashExplainer
+
+ __all__ = [
+     "ProtoDashExplainer",
+ ]
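For instance-level use, the second example in the protodash module docstring looks roughly like the following; adapter, X_train, and test_instance are assumed to exist already (for example, a SklearnAdapter around a fitted model plus NumPy arrays):

    from explainiverse.explainers.example_based import ProtoDashExplainer

    # adapter: an explainiverse model adapter (assumption); X_train, test_instance: NumPy arrays
    explainer = ProtoDashExplainer(model=adapter, n_prototypes=5)
    explanation = explainer.explain(test_instance, X_reference=X_train)

    # Most similar reference examples, their weights, and kernel similarities
    print(explanation.explanation_data["prototype_indices"])
    print(explanation.explanation_data["weights"])
    print(explanation.explanation_data["similarity_scores"])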
src/explainiverse/explainers/example_based/protodash.py (new file)
@@ -0,0 +1,826 @@
+ # src/explainiverse/explainers/example_based/protodash.py
+ """
+ ProtoDash - Prototype Selection with Importance Weights.
+
+ ProtoDash selects a small set of prototypical examples from a dataset
+ that best represent the data distribution or explain model predictions.
+ Each prototype is assigned an importance weight indicating its contribution.
+
+ The algorithm minimizes the Maximum Mean Discrepancy (MMD) between:
+     - The weighted combination of selected prototypes
+     - The target distribution (full dataset or specific instances)
+
+ Key Features:
+     - Works with any model type (or no model at all for data summarization)
+     - Provides interpretable weights for each prototype
+     - Supports multiple kernel functions (RBF, linear, cosine)
+     - Can explain individual predictions or summarize entire datasets
+     - Class-conditional prototype selection
+
+ Use Cases:
+     1. Dataset Summarization: "These 10 examples represent the entire dataset"
+     2. Prediction Explanation: "This prediction is similar to examples A, B, C"
+     3. Model Debugging: "The model relies heavily on these training examples"
+     4. Data Compression: Reduce dataset while preserving distribution
+
+ Reference:
+     Gurumoorthy, K.S., Dhurandhar, A., Cecchi, G., & Aggarwal, C. (2019).
+     "Efficient Data Representation by Selecting Prototypes with Importance Weights"
+     IEEE International Conference on Data Mining (ICDM).
+
+ Also based on:
+     Kim, B., Khanna, R., & Koyejo, O. (2016).
+     "Examples are not Enough, Learn to Criticize! Criticism for Interpretability"
+     NeurIPS 2016.
+
+ Example:
+     from explainiverse.explainers.example_based import ProtoDashExplainer
+
+     # Dataset summarization
+     explainer = ProtoDashExplainer(n_prototypes=10, kernel="rbf")
+     result = explainer.find_prototypes(X_train)
+     print(f"Prototype indices: {result.explanation_data['prototype_indices']}")
+     print(f"Weights: {result.explanation_data['weights']}")
+
+     # Explaining a prediction
+     explainer = ProtoDashExplainer(model=adapter, n_prototypes=5)
+     explanation = explainer.explain(test_instance, X_reference=X_train)
+ """
+
+ import numpy as np
+ from typing import List, Optional, Union, Callable, Tuple, Dict
+ from scipy.spatial.distance import cdist
+ from scipy.optimize import minimize
+
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.core.explanation import Explanation
+
+
+ class ProtoDashExplainer(BaseExplainer):
+     """
+     ProtoDash explainer for prototype-based explanations.
+
+     Selects representative examples (prototypes) from a reference dataset
+     that best explain a target distribution or individual predictions.
+     Each prototype is assigned an importance weight.
+
+     The algorithm greedily selects prototypes that minimize the Maximum
+     Mean Discrepancy (MMD) between the weighted prototype set and the
+     target, then optimizes the weights.
+
+     Attributes:
+         model: Optional model adapter (for prediction-based explanations)
+         n_prototypes: Number of prototypes to select
+         kernel: Kernel function type ("rbf", "linear", "cosine")
+         kernel_width: Width parameter for RBF kernel (auto-computed if None)
+         epsilon: Small constant for numerical stability
+
+     Example:
+         >>> explainer = ProtoDashExplainer(n_prototypes=5, kernel="rbf")
+         >>> result = explainer.find_prototypes(X_train)
+         >>> prototypes = X_train[result.explanation_data['prototype_indices']]
+     """
+
+     def __init__(
+         self,
+         model=None,
+         n_prototypes: int = 10,
+         kernel: str = "rbf",
+         kernel_width: Optional[float] = None,
+         epsilon: float = 1e-10,
+         optimize_weights: bool = True,
+         random_state: Optional[int] = None,
+         force_n_prototypes: bool = True
+     ):
+         """
+         Initialize the ProtoDash explainer.
+
+         Args:
+             model: Optional model adapter. If provided, can use model
+                 predictions in the kernel computation for explanation.
+             n_prototypes: Number of prototypes to select (default: 10).
+             kernel: Kernel function type:
+                 - "rbf": Radial Basis Function (Gaussian) kernel
+                 - "linear": Linear kernel (dot product)
+                 - "cosine": Cosine similarity kernel
+             kernel_width: Width (sigma) for RBF kernel. If None, uses
+                 median heuristic based on pairwise distances.
+             epsilon: Small constant for numerical stability (default: 1e-10).
+             optimize_weights: If True, optimize weights after greedy selection.
+                 If False, use weights from greedy selection only.
+             random_state: Random seed for reproducibility.
+             force_n_prototypes: If True (default), always select exactly
+                 n_prototypes (or all available if fewer).
+                 If False, may stop early when gain becomes
+                 negative (original ProtoDash behavior).
+         """
+         super().__init__(model)
+
+         self.n_prototypes = n_prototypes
+         self.kernel = kernel.lower()
+         self.kernel_width = kernel_width
+         self.epsilon = epsilon
+         self.optimize_weights = optimize_weights
+         self.random_state = random_state
+         self.force_n_prototypes = force_n_prototypes
+
+         if self.kernel not in ["rbf", "linear", "cosine"]:
+             raise ValueError(
+                 f"Unknown kernel '{kernel}'. Supported: 'rbf', 'linear', 'cosine'"
+             )
+
+         # Cache for kernel matrix
+         self._kernel_matrix_cache = None
+         self._reference_data_hash = None
+
+     def _compute_kernel_width(self, X: np.ndarray) -> float:
+         """
+         Compute kernel width using median heuristic.
+
+         The median heuristic sets sigma = median of pairwise distances,
+         which is a common rule of thumb for RBF kernels.
+
+         Args:
+             X: Data matrix of shape (n_samples, n_features)
+
+         Returns:
+             Kernel width (sigma) value
+         """
+         # Subsample for efficiency if dataset is large
+         n_samples = X.shape[0]
+         if n_samples > 1000:
+             if self.random_state is not None:
+                 np.random.seed(self.random_state)
+             indices = np.random.choice(n_samples, size=1000, replace=False)
+             X_sample = X[indices]
+         else:
+             X_sample = X
+
+         # Compute pairwise distances
+         distances = cdist(X_sample, X_sample, metric='euclidean')
+
+         # Get median of non-zero distances
+         mask = distances > 0
+         if np.any(mask):
+             median_dist = np.median(distances[mask])
+         else:
+             median_dist = 1.0
+
+         return max(median_dist, self.epsilon)
+
+     def _compute_kernel(
+         self,
+         X: np.ndarray,
+         Y: Optional[np.ndarray] = None,
+         kernel_width: Optional[float] = None
+     ) -> np.ndarray:
+         """
+         Compute kernel matrix between X and Y.
+
+         Args:
+             X: First data matrix of shape (n_samples_X, n_features)
+             Y: Second data matrix of shape (n_samples_Y, n_features).
+                 If None, computes K(X, X).
+             kernel_width: Override kernel width for RBF kernel.
+
+         Returns:
+             Kernel matrix of shape (n_samples_X, n_samples_Y)
+         """
+         if Y is None:
+             Y = X
+
+         if self.kernel == "rbf":
+             sigma = kernel_width or self.kernel_width
+             if sigma is None:
+                 sigma = self._compute_kernel_width(X)
+
+             # K(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
+             sq_dists = cdist(X, Y, metric='sqeuclidean')
+             K = np.exp(-sq_dists / (2 * sigma ** 2))
+
+         elif self.kernel == "linear":
+             # K(x, y) = x · y
+             K = X @ Y.T
+
+         elif self.kernel == "cosine":
+             # K(x, y) = (x · y) / (||x|| * ||y||)
+             X_norm = X / (np.linalg.norm(X, axis=1, keepdims=True) + self.epsilon)
+             Y_norm = Y / (np.linalg.norm(Y, axis=1, keepdims=True) + self.epsilon)
+             K = X_norm @ Y_norm.T
+
+         else:
+             raise ValueError(f"Unknown kernel: {self.kernel}")
+
+         return K
+
+     def _greedy_prototype_selection(
+         self,
+         K_ref_ref: np.ndarray,
+         K_ref_target: np.ndarray,
+         n_prototypes: int,
+         force_n_prototypes: bool = True
+     ) -> Tuple[List[int], np.ndarray]:
+         """
+         ProtoDash greedy prototype selection with iterative weight optimization.
+
+         Implements the algorithm from:
+         Gurumoorthy et al., 2019 - "Efficient Data Representation by Selecting
+         Prototypes with Importance Weights" (ICDM)
+
+         The algorithm solves:
+             min_w (1/2) w^T K w - w^T μ
+             s.t. w >= 0
+
+         where μ_j = mean(K(x_j, target_points)) is the mean kernel similarity
+         of candidate j to all target points.
+
+         At each iteration:
+         1. Compute gradient gain for each unselected candidate
+         2. Select the candidate with maximum positive gain
+         3. Re-optimize weights over all selected prototypes
+
+         Args:
+             K_ref_ref: Kernel matrix K(reference, reference) of shape (n_ref, n_ref)
+             K_ref_target: Kernel matrix K(reference, target) of shape (n_ref, n_target)
+             n_prototypes: Number of prototypes to select
+             force_n_prototypes: If True, always select n_prototypes even if gain
+                 becomes negative. If False, stop when no positive gain.
+
+         Returns:
+             Tuple of (prototype_indices, weights)
+         """
+         n_ref = K_ref_ref.shape[0]
+
+         # μ_j = mean kernel similarity of candidate j to target distribution
+         # This is the linear term in the QP objective
+         mu = K_ref_target.mean(axis=1)
+
+         # Track selected prototypes and their optimized weights
+         selected_indices = []
+         # Full weight vector (sparse, only selected indices are non-zero)
+         weights = np.zeros(n_ref)
+
+         for iteration in range(min(n_prototypes, n_ref)):
+             # Compute gradient gain for each candidate
+             # For the objective L(w) = (1/2) w^T K w - w^T μ
+             # Gradient: ∇L = K w - μ
+             # Gain for adding point j (currently w_j = 0): gain_j = μ_j - (Kw)_j
+             # We want to maximize gain, which means minimizing the objective
+
+             gradient = K_ref_ref @ weights - mu  # ∇L
+             gains = -gradient  # gain = μ - Kw (negative gradient = descent direction)
+
+             # Mask already selected indices
+             gains_masked = gains.copy()
+             gains_masked[selected_indices] = -np.inf
+
+             # Select candidate with maximum gain
+             best_idx = np.argmax(gains_masked)
+             best_gain = gains_masked[best_idx]
+
+             # Early stopping check (only if not forcing n_prototypes)
+             if not force_n_prototypes and best_gain <= self.epsilon:
+                 break
+
+             selected_indices.append(best_idx)
+
+             # Re-optimize weights over all selected prototypes
+             # Solve: min_w (1/2) w^T K_ss w - w^T μ_s, s.t. w >= 0
+             # where K_ss is kernel matrix restricted to selected indices
+             # and μ_s is mu restricted to selected indices
+
+             selected_arr = np.array(selected_indices)
+             K_selected = K_ref_ref[np.ix_(selected_arr, selected_arr)]
+             mu_selected = mu[selected_arr]
+
+             # Optimize weights for selected prototypes
+             w_selected = self._optimize_weights_qp(K_selected, mu_selected)
+
+             # Update full weight vector
+             weights = np.zeros(n_ref)
+             weights[selected_arr] = w_selected
+
+         # Return only the selected indices and their weights
+         if len(selected_indices) == 0:
+             return [], np.array([])
+
+         final_weights = weights[np.array(selected_indices)]
+         return selected_indices, final_weights
+
+     def _optimize_weights_qp(
+         self,
+         K: np.ndarray,
+         mu: np.ndarray,
+         normalize: bool = False
+     ) -> np.ndarray:
+         """
+         Optimize prototype weights via constrained quadratic programming.
+
+         Solves:
+             min_w (1/2) w^T K w - w^T μ
+             s.t. w >= 0
+             (optional) sum(w) = 1
+
+         Uses scipy.optimize.minimize with SLSQP method.
+
+         Args:
+             K: Kernel matrix between selected prototypes (m x m)
+             mu: Mean kernel similarity to target for each prototype (m,)
+             normalize: If True, constrain weights to sum to 1
+
+         Returns:
+             Optimized non-negative weights
+         """
+         m = K.shape[0]
+
+         if m == 0:
+             return np.array([])
+
+         if m == 1:
+             # Single prototype: optimal weight is μ/K if K > 0
+             if K[0, 0] > self.epsilon:
+                 w = max(mu[0] / K[0, 0], 0)
+             else:
+                 w = 1.0
+             return np.array([w]) if not normalize else np.array([1.0])
+
+         # Add small regularization for numerical stability
+         K_reg = K + self.epsilon * np.eye(m)
+
+         # Objective: (1/2) w^T K w - w^T μ
+         def objective(w):
+             return 0.5 * w @ K_reg @ w - w @ mu
+
+         def gradient(w):
+             return K_reg @ w - mu
+
+         # Initial guess: equal weights
+         w0 = np.ones(m) / m
+
+         # Bounds: w >= 0
+         bounds = [(0, None) for _ in range(m)]
+
+         # Constraints
+         constraints = []
+         if normalize:
+             constraints.append({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0})
+
+         # Optimize
+         result = minimize(
+             objective,
+             w0,
+             method='SLSQP',
+             jac=gradient,
+             bounds=bounds,
+             constraints=constraints,
+             options={'maxiter': 500, 'ftol': 1e-12}
+         )
+
+         weights = result.x
+
+         # Ensure non-negativity (numerical cleanup)
+         weights = np.maximum(weights, 0)
+
+         return weights
+
+     def _optimize_weights(
+         self,
+         K_proto_proto: np.ndarray,
+         K_proto_target: np.ndarray,
+         initial_weights: np.ndarray
+     ) -> np.ndarray:
+         """
+         Final weight optimization for selected prototypes.
+
+         This is called after greedy selection to do a final refinement
+         of weights, optionally with normalization for interpretability.
+
+         Solves the same QP as _optimize_weights_qp but uses the
+         mean kernel to target as the linear term.
+
+         Args:
+             K_proto_proto: Kernel matrix between prototypes (m x m)
+             K_proto_target: Kernel matrix from prototypes to target (m x n_target)
+             initial_weights: Initial weights from greedy selection
+
+         Returns:
+             Optimized weights (non-negative, optionally normalized)
+         """
+         n_proto = K_proto_proto.shape[0]
+
+         if n_proto == 0:
+             return np.array([])
+
+         if n_proto == 1:
+             return np.array([1.0])  # Single prototype gets weight 1
+
+         # Target: mean kernel to target points
+         mu = K_proto_target.mean(axis=1)
+
+         # Use the QP solver
+         weights = self._optimize_weights_qp(K_proto_proto, mu, normalize=False)
+
+         # Normalize for interpretability (weights sum to 1)
+         weight_sum = weights.sum()
+         if weight_sum > self.epsilon:
+             weights = weights / weight_sum
+         else:
+             # Fallback to equal weights if optimization failed
+             weights = np.ones(n_proto) / n_proto
+
+         return weights
+
+     def find_prototypes(
+         self,
+         X: np.ndarray,
+         y: Optional[np.ndarray] = None,
+         target_class: Optional[int] = None,
+         feature_names: Optional[List[str]] = None,
+         return_mmd: bool = False
+     ) -> Explanation:
+         """
+         Find prototypes that summarize a dataset.
+
+         Selects a small set of examples from X that best represent
+         the data distribution. If y is provided, can select prototypes
+         for a specific class.
+
+         Args:
+             X: Data matrix of shape (n_samples, n_features).
+             y: Optional labels. If provided with target_class, selects
+                 prototypes only from that class.
+             target_class: If provided with y, only consider examples
+                 from this class as candidates.
+             feature_names: Optional list of feature names.
+             return_mmd: If True, include MMD score in explanation.
+
+         Returns:
+             Explanation object containing:
+             - prototype_indices: Indices of selected prototypes in X
+             - weights: Importance weight for each prototype
+             - prototypes: The actual prototype data points
+             - mmd_score: (optional) Final MMD between prototypes and data
+         """
+         X = np.asarray(X, dtype=np.float64)
+
+         if X.ndim == 1:
+             X = X.reshape(1, -1)
+
+         n_samples, n_features = X.shape
+
+         # Filter by class if specified
+         if y is not None and target_class is not None:
+             y = np.asarray(y)
+             class_mask = (y == target_class)
+             X_candidates = X[class_mask]
+             original_indices = np.where(class_mask)[0]
+         else:
+             X_candidates = X
+             original_indices = np.arange(n_samples)
+
+         n_candidates = X_candidates.shape[0]
+         n_proto = min(self.n_prototypes, n_candidates)
+
+         if n_proto == 0:
+             raise ValueError("No candidate examples available for prototype selection.")
+
+         # Auto-compute kernel width if needed
+         if self.kernel == "rbf" and self.kernel_width is None:
+             self.kernel_width = self._compute_kernel_width(X_candidates)
+
+         # Compute kernel matrices
+         # K(candidates, candidates) for prototype selection
+         # K(candidates, X) for representing the full distribution
+         K_cand_cand = self._compute_kernel(X_candidates, X_candidates)
+         K_cand_all = self._compute_kernel(X_candidates, X)
+
+         # Greedy prototype selection
+         local_indices, greedy_weights = self._greedy_prototype_selection(
+             K_cand_cand, K_cand_all, n_proto, self.force_n_prototypes
+         )
+
+         # Convert to original indices
+         prototype_indices = [int(original_indices[i]) for i in local_indices]
+
+         # Optimize weights if requested
+         if self.optimize_weights and len(local_indices) > 1:
+             # Get kernel matrices for selected prototypes
+             proto_local_idx = np.array(local_indices)
+             K_proto_proto = K_cand_cand[np.ix_(proto_local_idx, proto_local_idx)]
+             K_proto_all = K_cand_all[proto_local_idx, :]
+
+             weights = self._optimize_weights(K_proto_proto, K_proto_all, greedy_weights)
+         else:
+             # Normalize greedy weights for interpretability
+             weights = greedy_weights.copy()
+             weight_sum = weights.sum()
+             if weight_sum > self.epsilon:
+                 weights = weights / weight_sum
+             elif len(weights) > 0:
+                 weights = np.ones(len(weights)) / len(weights)
+
+         # Build explanation data
+         explanation_data = {
+             "prototype_indices": prototype_indices,
+             "weights": weights.tolist(),
+             "prototypes": X[prototype_indices].tolist(),
+             "n_prototypes": len(prototype_indices),
+             "kernel": self.kernel,
+             "kernel_width": self.kernel_width if self.kernel == "rbf" else None,
+         }
+
+         if feature_names:
+             explanation_data["feature_names"] = feature_names
+
+         # Compute MMD if requested
+         if return_mmd:
+             proto_idx_local = np.array(local_indices)
+             K_pp = K_cand_cand[np.ix_(proto_idx_local, proto_idx_local)]
+             K_pa = K_cand_all[proto_idx_local, :]
+             K_aa = self._compute_kernel(X, X)
+
+             # MMD^2 = w^T K_pp w - 2 * w^T K_pa.mean() + K_aa.mean()
+             w = np.array(weights)
+             mmd_sq = w @ K_pp @ w - 2 * w @ K_pa.mean(axis=1) + K_aa.mean()
+             mmd = np.sqrt(max(mmd_sq, 0))
+
+             explanation_data["mmd_score"] = float(mmd)
+
+         # Determine label
+         if target_class is not None:
+             label_name = f"class_{target_class}"
+         else:
+             label_name = "dataset"
+
+         return Explanation(
+             explainer_name="ProtoDash",
+             target_class=label_name,
+             explanation_data=explanation_data
+         )
+
+     def explain(
+         self,
+         instance: np.ndarray,
+         X_reference: np.ndarray,
+         feature_names: Optional[List[str]] = None,
+         use_predictions: bool = False,
+         return_similarity: bool = True
+     ) -> Explanation:
+         """
+         Explain a prediction by finding similar prototypes.
+
+         Finds prototypes from the reference set that are most similar
+         to the given instance, providing a "this is like..." explanation.
+
+         Args:
+             instance: Instance to explain (1D array of shape n_features).
+             X_reference: Reference dataset to select prototypes from
+                 (shape: n_samples, n_features).
+             feature_names: Optional list of feature names.
+             use_predictions: If True and model is provided, include model
+                 predictions in the similarity computation.
+             return_similarity: If True, include similarity scores.
+
+         Returns:
+             Explanation object containing prototype indices and weights.
+         """
+         instance = np.asarray(instance, dtype=np.float64).flatten()
+         X_reference = np.asarray(X_reference, dtype=np.float64)
+
+         if X_reference.ndim == 1:
+             X_reference = X_reference.reshape(1, -1)
+
+         n_ref, n_features = X_reference.shape
+         n_proto = min(self.n_prototypes, n_ref)
+
+         # Auto-compute kernel width if needed
+         if self.kernel == "rbf" and self.kernel_width is None:
+             self.kernel_width = self._compute_kernel_width(X_reference)
+
+         # If using predictions and model is available, augment features
+         if use_predictions and self.model is not None:
+             # Get predictions for instance and reference
+             instance_pred = self.model.predict(instance.reshape(1, -1)).flatten()
+             ref_preds = self.model.predict(X_reference)
+
+             # Augment features with predictions
+             instance_aug = np.concatenate([instance, instance_pred])
+             X_ref_aug = np.hstack([X_reference, ref_preds])
+         else:
+             instance_aug = instance
+             X_ref_aug = X_reference
+
+         # Compute kernel matrices
+         # K(reference, reference) for prototype selection
+         # K(reference, instance) as target
+         K_ref_ref = self._compute_kernel(X_ref_aug, X_ref_aug)
+         K_ref_instance = self._compute_kernel(X_ref_aug, instance_aug.reshape(1, -1))
+
+         # Greedy prototype selection
+         prototype_indices, greedy_weights = self._greedy_prototype_selection(
+             K_ref_ref, K_ref_instance, n_proto, self.force_n_prototypes
+         )
+
+         # Optimize weights
+         if self.optimize_weights and len(prototype_indices) > 1:
+             proto_idx = np.array(prototype_indices)
+             K_proto_proto = K_ref_ref[np.ix_(proto_idx, proto_idx)]
+             K_proto_instance = K_ref_instance[proto_idx, :]
+
+             weights = self._optimize_weights(K_proto_proto, K_proto_instance, greedy_weights)
+         else:
+             # Normalize greedy weights for interpretability
+             weights = greedy_weights.copy()
+             weight_sum = weights.sum()
+             if weight_sum > self.epsilon:
+                 weights = weights / weight_sum
+             elif len(weights) > 0:
+                 weights = np.ones(len(weights)) / len(weights)
+
+         # Build explanation data
+         explanation_data = {
+             "prototype_indices": [int(i) for i in prototype_indices],
+             "weights": weights.tolist(),
+             "prototypes": X_reference[prototype_indices].tolist(),
+             "n_prototypes": len(prototype_indices),
+             "kernel": self.kernel,
+             "kernel_width": self.kernel_width if self.kernel == "rbf" else None,
+             "instance": instance.tolist(),
+         }
+
+         if feature_names:
+             explanation_data["feature_names"] = feature_names
+
+         # Add similarity scores
+         if return_similarity:
+             K_instance_proto = self._compute_kernel(
+                 instance.reshape(1, -1),
+                 X_reference[prototype_indices]
+             ).flatten()
+             explanation_data["similarity_scores"] = K_instance_proto.tolist()
+
+         # Add model predictions if available
+         if self.model is not None:
+             instance_pred = self.model.predict(instance.reshape(1, -1))
+             proto_preds = self.model.predict(X_reference[prototype_indices])
+
+             explanation_data["instance_prediction"] = instance_pred.tolist()
+             explanation_data["prototype_predictions"] = proto_preds.tolist()
+
+         return Explanation(
+             explainer_name="ProtoDash",
+             target_class="instance_explanation",
+             explanation_data=explanation_data
+         )
+
+     def explain_batch(
+         self,
+         X: np.ndarray,
+         X_reference: np.ndarray,
+         feature_names: Optional[List[str]] = None
+     ) -> List[Explanation]:
+         """
+         Explain multiple instances.
+
+         Args:
+             X: Instances to explain (n_instances, n_features).
+             X_reference: Reference dataset for prototype selection.
+             feature_names: Optional feature names.
+
+         Returns:
+             List of Explanation objects, one per instance.
+         """
+         X = np.asarray(X, dtype=np.float64)
+         if X.ndim == 1:
+             X = X.reshape(1, -1)
+
+         return [
+             self.explain(X[i], X_reference, feature_names)
+             for i in range(X.shape[0])
+         ]
+
+     def find_criticisms(
+         self,
+         X: np.ndarray,
+         prototype_indices: List[int],
+         n_criticisms: int = 5,
+         feature_names: Optional[List[str]] = None
+     ) -> Explanation:
+         """
+         Find criticisms - examples not well-represented by prototypes.
+
+         Criticisms are data points that are furthest from the prototype
+         representation, highlighting unusual or edge-case examples.
+
+         This implements the criticism selection from MMD-Critic (Kim et al., 2016).
+
+         Args:
+             X: Full dataset.
+             prototype_indices: Indices of already-selected prototypes.
+             n_criticisms: Number of criticisms to find.
+             feature_names: Optional feature names.
+
+         Returns:
+             Explanation with criticism indices and their "unusualness" scores.
+         """
+         X = np.asarray(X, dtype=np.float64)
+         n_samples = X.shape[0]
+
+         prototype_indices = list(prototype_indices)
+         n_crit = min(n_criticisms, n_samples - len(prototype_indices))
+
+         if n_crit <= 0:
+             return Explanation(
+                 explainer_name="ProtoDash_Criticisms",
+                 target_class="criticisms",
+                 explanation_data={
+                     "criticism_indices": [],
+                     "unusualness_scores": [],
+                     "criticisms": []
+                 }
+             )
+
+         # Auto-compute kernel width if needed
+         if self.kernel == "rbf" and self.kernel_width is None:
+             self.kernel_width = self._compute_kernel_width(X)
+
+         # Compute kernel from all points to prototypes
+         X_proto = X[prototype_indices]
+         K_all_proto = self._compute_kernel(X, X_proto)
+
+         # For each point, compute its "witness function" value
+         # High values = well-represented by prototypes
+         # Low values = not well-represented (criticisms)
+
+         # Mean kernel distance to prototypes
+         mean_sim_to_protos = K_all_proto.mean(axis=1)
+
+         # Mean kernel value to all other points (density estimate)
+         K_all_all = self._compute_kernel(X, X)
+         mean_sim_to_all = K_all_all.mean(axis=1)
+
+         # Unusualness = difference between expected similarity and prototype similarity
+         # Points with high unusualness are criticisms
+         unusualness = mean_sim_to_all - mean_sim_to_protos
+
+         # Exclude prototypes from consideration
+         unusualness[prototype_indices] = -np.inf
+
+         # Select top criticisms
+         criticism_indices = np.argsort(unusualness)[-n_crit:][::-1].tolist()
+         criticism_scores = unusualness[criticism_indices].tolist()
+
+         return Explanation(
+             explainer_name="ProtoDash_Criticisms",
+             target_class="criticisms",
+             explanation_data={
+                 "criticism_indices": criticism_indices,
+                 "unusualness_scores": criticism_scores,
+                 "criticisms": X[criticism_indices].tolist(),
+                 "n_criticisms": len(criticism_indices),
+                 "feature_names": feature_names
+             }
+         )
+
+     def get_prototype_summary(
+         self,
+         X: np.ndarray,
+         y: Optional[np.ndarray] = None,
+         feature_names: Optional[List[str]] = None,
+         include_criticisms: bool = True,
+         n_criticisms: int = 5
+     ) -> Dict:
+         """
+         Generate a complete prototype-based summary of a dataset.
+
+         Combines prototype selection with optional criticisms for a
+         complete data summary.
+
+         Args:
+             X: Dataset to summarize.
+             y: Optional labels.
+             feature_names: Optional feature names.
+             include_criticisms: Whether to also find criticisms.
+             n_criticisms: Number of criticisms if including them.
+
+         Returns:
+             Dictionary with prototypes, weights, and optionally criticisms.
+         """
+         # Find prototypes
+         proto_exp = self.find_prototypes(X, y, feature_names=feature_names, return_mmd=True)
+
+         result = {
+             "prototypes": proto_exp.explanation_data,
+         }
+
+         # Find criticisms if requested
+         if include_criticisms:
+             crit_exp = self.find_criticisms(
+                 X,
+                 proto_exp.explanation_data["prototype_indices"],
+                 n_criticisms,
+                 feature_names
+             )
+             result["criticisms"] = crit_exp.explanation_data
+
+         return result
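To make the selection objective concrete: _greedy_prototype_selection minimizes (1/2) w^T K w - w^T μ with w >= 0, at each step picking the candidate with the largest negative-gradient entry μ_j - (Kw)_j and then refitting the weights. A self-contained sketch of one such loop in plain NumPy, simplified relative to the class above (fixed RBF width and a clipped linear solve instead of the SLSQP QP):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 3))

    # RBF kernel with a fixed width (the class uses a median heuristic instead)
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-sq_dists / 2.0)

    mu = K.mean(axis=1)        # mean similarity of each candidate to the whole dataset
    w = np.zeros(len(X))
    selected = []

    for _ in range(5):
        gains = mu - K @ w     # negative gradient of (1/2) w^T K w - w^T mu
        gains[selected] = -np.inf
        selected.append(int(np.argmax(gains)))

        # Simplified weight refit on the selected set; the package solves a
        # non-negative QP with SLSQP here, this just clips a linear solve.
        K_s = K[np.ix_(selected, selected)]
        w_s = np.linalg.solve(K_s + 1e-10 * np.eye(len(selected)), mu[selected])
        w = np.zeros(len(X))
        w[selected] = np.clip(w_s, 0.0, None)

    print("prototype indices:", selected)
    print("weights:", w[selected])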