explainiverse 0.2.3__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {explainiverse-0.2.3 → explainiverse-0.2.4}/PKG-INFO +38 -8
  2. {explainiverse-0.2.3 → explainiverse-0.2.4}/README.md +37 -7
  3. {explainiverse-0.2.3 → explainiverse-0.2.4}/pyproject.toml +1 -1
  4. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/__init__.py +1 -1
  5. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/core/registry.py +18 -0
  6. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/__init__.py +2 -0
  7. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/gradient/__init__.py +2 -1
  8. explainiverse-0.2.4/src/explainiverse/explainers/gradient/gradcam.py +390 -0
  9. {explainiverse-0.2.3 → explainiverse-0.2.4}/LICENSE +0 -0
  10. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/adapters/__init__.py +0 -0
  11. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/adapters/base_adapter.py +0 -0
  12. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/adapters/pytorch_adapter.py +0 -0
  13. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/adapters/sklearn_adapter.py +0 -0
  14. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/core/__init__.py +0 -0
  15. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/core/explainer.py +0 -0
  16. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/core/explanation.py +0 -0
  17. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/engine/__init__.py +0 -0
  18. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/engine/suite.py +0 -0
  19. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/evaluation/__init__.py +0 -0
  20. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/evaluation/metrics.py +0 -0
  21. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/attribution/__init__.py +0 -0
  22. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/attribution/lime_wrapper.py +0 -0
  23. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/attribution/shap_wrapper.py +0 -0
  24. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/attribution/treeshap_wrapper.py +0 -0
  25. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/counterfactual/__init__.py +0 -0
  26. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/counterfactual/dice_wrapper.py +0 -0
  27. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/global_explainers/__init__.py +0 -0
  28. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/global_explainers/ale.py +0 -0
  29. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/global_explainers/partial_dependence.py +0 -0
  30. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/global_explainers/permutation_importance.py +0 -0
  31. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/global_explainers/sage.py +0 -0
  32. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/gradient/integrated_gradients.py +0 -0
  33. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/rule_based/__init__.py +0 -0
  34. {explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/rule_based/anchors_wrapper.py +0 -0
{explainiverse-0.2.3 → explainiverse-0.2.4}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: explainiverse
- Version: 0.2.3
+ Version: 0.2.4
  Summary: Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more
  Home-page: https://github.com/jemsbhai/explainiverse
  License: MIT
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
  # Explainiverse

  **Explainiverse** is a unified, extensible Python framework for Explainable AI (XAI).
- It provides a standardized interface for model-agnostic explainability with 10 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.
+ It provides a standardized interface for model-agnostic explainability with 11 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.

  ---

@@ -44,6 +44,7 @@ It provides a standardized interface for model-agnostic explainability with 10 s
  - **SHAP** - SHapley Additive exPlanations via KernelSHAP ([Lundberg & Lee, 2017](https://arxiv.org/abs/1705.07874))
  - **TreeSHAP** - Exact SHAP values for tree models, 10x+ faster ([Lundberg et al., 2018](https://arxiv.org/abs/1802.03888))
  - **Integrated Gradients** - Axiomatic attributions for neural networks ([Sundararajan et al., 2017](https://arxiv.org/abs/1703.01365))
+ - **GradCAM/GradCAM++** - Visual explanations for CNNs ([Selvaraju et al., 2017](https://arxiv.org/abs/1610.02391))
  - **Anchors** - High-precision rule-based explanations ([Ribeiro et al., 2018](https://ojs.aaai.org/index.php/AAAI/article/view/11491))
  - **Counterfactual** - DiCE-style diverse counterfactual explanations ([Mothilal et al., 2020](https://arxiv.org/abs/1905.07697))

@@ -110,7 +111,7 @@ adapter = SklearnAdapter(model, class_names=iris.target_names.tolist())

  # List available explainers
  print(default_registry.list_explainers())
- # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']
+ # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'gradcam', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']

  # Create and use an explainer
  explainer = default_registry.create(
@@ -131,9 +132,9 @@ print(explanation.explanation_data["feature_attributions"])
  local_tabular = default_registry.filter(scope="local", data_type="tabular")
  print(local_tabular) # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'anchors', 'counterfactual']

- # Find explainers optimized for tree models
- tree_explainers = default_registry.filter(model_type="tree")
- print(tree_explainers) # ['treeshap']
+ # Find explainers for images/CNNs
+ image_explainers = default_registry.filter(data_type="image")
+ print(image_explainers) # ['lime', 'integrated_gradients', 'gradcam']

  # Get recommendations
  recommendations = default_registry.recommend(
@@ -227,6 +228,35 @@ explanation = explainer.explain(X_test[0], return_convergence_delta=True)
  print(f"Convergence delta: {explanation.explanation_data['convergence_delta']}")
  ```

+ ### GradCAM for CNN Visual Explanations
+
+ ```python
+ from explainiverse.explainers import GradCAMExplainer
+ from explainiverse import PyTorchAdapter
+
+ # Wrap your CNN model
+ adapter = PyTorchAdapter(cnn_model, task="classification", class_names=class_names)
+
+ # Find the last convolutional layer
+ layers = adapter.list_layers()
+ target_layer = "layer4" # Adjust based on your model architecture
+
+ # Create GradCAM explainer
+ explainer = GradCAMExplainer(
+     model=adapter,
+     target_layer=target_layer,
+     class_names=class_names,
+     method="gradcam" # or "gradcam++" for improved version
+ )
+
+ # Explain an image prediction
+ explanation = explainer.explain(image) # image shape: (C, H, W) or (N, C, H, W)
+ heatmap = explanation.explanation_data["heatmap"]
+
+ # Create overlay visualization
+ overlay = explainer.get_overlay(original_image, heatmap, alpha=0.5)
+ ```
+
  ### Using Specific Explainers

  ```python
@@ -332,8 +362,8 @@ poetry run pytest tests/test_new_explainers.py -v
  - [x] Permutation Importance, PDP, ALE, SAGE
  - [x] Explainer Registry with filtering
  - [x] PyTorch Adapter ✅
- - [x] Integrated Gradients ✅ NEW
- - [ ] GradCAM for CNNs
+ - [x] Integrated Gradients ✅
+ - [x] GradCAM/GradCAM++ for CNNs ✅ NEW
  - [ ] TensorFlow adapter
  - [ ] Interactive visualization dashboard

{explainiverse-0.2.3 → explainiverse-0.2.4}/README.md
@@ -1,7 +1,7 @@
  # Explainiverse

  **Explainiverse** is a unified, extensible Python framework for Explainable AI (XAI).
- It provides a standardized interface for model-agnostic explainability with 10 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.
+ It provides a standardized interface for model-agnostic explainability with 11 state-of-the-art XAI methods, evaluation metrics, and a plugin registry for easy extensibility.

  ---

@@ -14,6 +14,7 @@ It provides a standardized interface for model-agnostic explainability with 10 s
  - **SHAP** - SHapley Additive exPlanations via KernelSHAP ([Lundberg & Lee, 2017](https://arxiv.org/abs/1705.07874))
  - **TreeSHAP** - Exact SHAP values for tree models, 10x+ faster ([Lundberg et al., 2018](https://arxiv.org/abs/1802.03888))
  - **Integrated Gradients** - Axiomatic attributions for neural networks ([Sundararajan et al., 2017](https://arxiv.org/abs/1703.01365))
+ - **GradCAM/GradCAM++** - Visual explanations for CNNs ([Selvaraju et al., 2017](https://arxiv.org/abs/1610.02391))
  - **Anchors** - High-precision rule-based explanations ([Ribeiro et al., 2018](https://ojs.aaai.org/index.php/AAAI/article/view/11491))
  - **Counterfactual** - DiCE-style diverse counterfactual explanations ([Mothilal et al., 2020](https://arxiv.org/abs/1905.07697))

@@ -80,7 +81,7 @@ adapter = SklearnAdapter(model, class_names=iris.target_names.tolist())

  # List available explainers
  print(default_registry.list_explainers())
- # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']
+ # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'gradcam', 'anchors', 'counterfactual', 'permutation_importance', 'partial_dependence', 'ale', 'sage']

  # Create and use an explainer
  explainer = default_registry.create(
@@ -101,9 +102,9 @@ print(explanation.explanation_data["feature_attributions"])
  local_tabular = default_registry.filter(scope="local", data_type="tabular")
  print(local_tabular) # ['lime', 'shap', 'treeshap', 'integrated_gradients', 'anchors', 'counterfactual']

- # Find explainers optimized for tree models
- tree_explainers = default_registry.filter(model_type="tree")
- print(tree_explainers) # ['treeshap']
+ # Find explainers for images/CNNs
+ image_explainers = default_registry.filter(data_type="image")
+ print(image_explainers) # ['lime', 'integrated_gradients', 'gradcam']

  # Get recommendations
  recommendations = default_registry.recommend(
@@ -197,6 +198,35 @@ explanation = explainer.explain(X_test[0], return_convergence_delta=True)
  print(f"Convergence delta: {explanation.explanation_data['convergence_delta']}")
  ```

+ ### GradCAM for CNN Visual Explanations
+
+ ```python
+ from explainiverse.explainers import GradCAMExplainer
+ from explainiverse import PyTorchAdapter
+
+ # Wrap your CNN model
+ adapter = PyTorchAdapter(cnn_model, task="classification", class_names=class_names)
+
+ # Find the last convolutional layer
+ layers = adapter.list_layers()
+ target_layer = "layer4" # Adjust based on your model architecture
+
+ # Create GradCAM explainer
+ explainer = GradCAMExplainer(
+     model=adapter,
+     target_layer=target_layer,
+     class_names=class_names,
+     method="gradcam" # or "gradcam++" for improved version
+ )
+
+ # Explain an image prediction
+ explanation = explainer.explain(image) # image shape: (C, H, W) or (N, C, H, W)
+ heatmap = explanation.explanation_data["heatmap"]
+
+ # Create overlay visualization
+ overlay = explainer.get_overlay(original_image, heatmap, alpha=0.5)
+ ```
+
  ### Using Specific Explainers

  ```python
@@ -302,8 +332,8 @@ poetry run pytest tests/test_new_explainers.py -v
  - [x] Permutation Importance, PDP, ALE, SAGE
  - [x] Explainer Registry with filtering
  - [x] PyTorch Adapter ✅
- - [x] Integrated Gradients ✅ NEW
- - [ ] GradCAM for CNNs
+ - [x] Integrated Gradients ✅
+ - [x] GradCAM/GradCAM++ for CNNs ✅ NEW
  - [ ] TensorFlow adapter
  - [ ] Interactive visualization dashboard

{explainiverse-0.2.3 → explainiverse-0.2.4}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "explainiverse"
- version = "0.2.3"
+ version = "0.2.4"
  description = "Unified, extensible explainability framework supporting LIME, SHAP, Anchors, Counterfactuals, PDP, ALE, SAGE, and more"
  authors = ["Muntaser Syed <jemsbhai@gmail.com>"]
  license = "MIT"
{explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/__init__.py
@@ -33,7 +33,7 @@ from explainiverse.adapters.sklearn_adapter import SklearnAdapter
  from explainiverse.adapters import TORCH_AVAILABLE
  from explainiverse.engine.suite import ExplanationSuite

- __version__ = "0.2.3"
+ __version__ = "0.2.4"

  __all__ = [
      # Core
{explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/core/registry.py
@@ -370,6 +370,7 @@ def _create_default_registry() -> ExplainerRegistry:
      from explainiverse.explainers.global_explainers.sage import SAGEExplainer
      from explainiverse.explainers.counterfactual.dice_wrapper import CounterfactualExplainer
      from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
+     from explainiverse.explainers.gradient.gradcam import GradCAMExplainer

      registry = ExplainerRegistry()

@@ -479,6 +480,23 @@ def _create_default_registry() -> ExplainerRegistry:
          )
      )

+     # Register GradCAM (for CNNs)
+     registry.register(
+         name="gradcam",
+         explainer_class=GradCAMExplainer,
+         meta=ExplainerMeta(
+             scope="local",
+             model_types=["neural"],
+             data_types=["image"],
+             task_types=["classification"],
+             description="GradCAM/GradCAM++ - visual explanations for CNNs via gradient-weighted activations (requires PyTorch)",
+             paper_reference="Selvaraju et al., 2017 - 'Grad-CAM: Visual Explanations from Deep Networks' (ICCV)",
+             complexity="O(forward_pass + backward_pass)",
+             requires_training_data=False,
+             supports_batching=True
+         )
+     )
+
      # =========================================================================
      # Global Explainers (model-level)
      # =========================================================================
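The metadata above is what feeds the registry's `filter` and `recommend` calls shown in the README hunks. As a minimal sketch of how the new entry is looked up and instantiated — the `filter` call and its output mirror the README snippet, while the keyword arguments forwarded by `default_registry.create` are an assumption based on the `GradCAMExplainer` constructor, since its exact signature is truncated in this diff:

```python
# Sketch only. Assumes `default_registry`, `adapter` (a PyTorchAdapter wrapping a CNN),
# and `class_names` are in scope as in the README examples above, and that
# default_registry.create(name, **kwargs) forwards kwargs to the explainer class.
print(default_registry.filter(data_type="image"))
# ['lime', 'integrated_gradients', 'gradcam']

explainer = default_registry.create(
    "gradcam",
    model=adapter,
    target_layer="layer4",      # last conv layer; adapter.list_layers() helps locate it
    class_names=class_names,
)
```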
{explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/__init__.py
@@ -27,6 +27,7 @@ from explainiverse.explainers.global_explainers.partial_dependence import Partia
  from explainiverse.explainers.global_explainers.ale import ALEExplainer
  from explainiverse.explainers.global_explainers.sage import SAGEExplainer
  from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
+ from explainiverse.explainers.gradient.gradcam import GradCAMExplainer

  __all__ = [
      # Local explainers
@@ -36,6 +37,7 @@ __all__ = [
      "AnchorsExplainer",
      "CounterfactualExplainer",
      "IntegratedGradientsExplainer",
+     "GradCAMExplainer",
      # Global explainers
      "PermutationImportanceExplainer",
      "PartialDependenceExplainer",
{explainiverse-0.2.3 → explainiverse-0.2.4}/src/explainiverse/explainers/gradient/__init__.py
@@ -7,5 +7,6 @@ typically via the PyTorchAdapter.
  """

  from explainiverse.explainers.gradient.integrated_gradients import IntegratedGradientsExplainer
+ from explainiverse.explainers.gradient.gradcam import GradCAMExplainer

- __all__ = ["IntegratedGradientsExplainer"]
+ __all__ = ["IntegratedGradientsExplainer", "GradCAMExplainer"]
explainiverse-0.2.4/src/explainiverse/explainers/gradient/gradcam.py
@@ -0,0 +1,390 @@
+ # src/explainiverse/explainers/gradient/gradcam.py
+ """
+ GradCAM and GradCAM++ - Visual Explanations for CNNs.
+
+ GradCAM produces visual explanations by highlighting important regions
+ in an image that contribute to the model's prediction. It uses gradients
+ flowing into the final convolutional layer to produce a coarse localization map.
+
+ GradCAM++ improves upon GradCAM by using a weighted combination of positive
+ partial derivatives, providing better localization for multiple instances
+ of the same class.
+
+ References:
+     GradCAM: Selvaraju et al., 2017 - "Grad-CAM: Visual Explanations from
+         Deep Networks via Gradient-based Localization"
+         https://arxiv.org/abs/1610.02391
+
+     GradCAM++: Chattopadhay et al., 2018 - "Grad-CAM++: Generalized Gradient-based
+         Visual Explanations for Deep Convolutional Networks"
+         https://arxiv.org/abs/1710.11063
+
+ Example:
+     from explainiverse.explainers.gradient import GradCAMExplainer
+     from explainiverse.adapters import PyTorchAdapter
+
+     # For a CNN model
+     adapter = PyTorchAdapter(cnn_model, task="classification")
+
+     explainer = GradCAMExplainer(
+         model=adapter,
+         target_layer="layer4", # Last conv layer
+         class_names=class_names
+     )
+
+     explanation = explainer.explain(image)
+     heatmap = explanation.explanation_data["heatmap"]
+ """
+
+ import numpy as np
+ from typing import List, Optional, Tuple, Union
+
+ from explainiverse.core.explainer import BaseExplainer
+ from explainiverse.core.explanation import Explanation
+
+
+ class GradCAMExplainer(BaseExplainer):
+     """
+     GradCAM and GradCAM++ explainer for CNNs.
+
+     Produces visual heatmaps showing which regions of an input image
+     are most important for the model's prediction.
+
+     Attributes:
+         model: PyTorchAdapter wrapping a CNN model
+         target_layer: Name of the convolutional layer to use
+         class_names: List of class names
+         method: "gradcam" or "gradcam++"
+     """
+
+     def __init__(
+         self,
+         model,
+         target_layer: str,
+         class_names: Optional[List[str]] = None,
+         method: str = "gradcam"
+     ):
+         """
+         Initialize the GradCAM explainer.
+
+         Args:
+             model: A PyTorchAdapter wrapping a CNN model.
+             target_layer: Name of the target convolutional layer.
+                 Usually the last conv layer before the classifier.
+                 Use adapter.list_layers() to see available layers.
+             class_names: List of class names for classification.
+             method: "gradcam" for standard GradCAM, "gradcam++" for improved version.
+         """
+         super().__init__(model)
+
+         # Validate model has layer access
+         if not hasattr(model, 'get_layer_gradients'):
+             raise TypeError(
+                 "Model adapter must have get_layer_gradients() method. "
+                 "Use PyTorchAdapter for PyTorch models."
+             )
+
+         self.target_layer = target_layer
+         self.class_names = list(class_names) if class_names else None
+         self.method = method.lower()
+
+         if self.method not in ["gradcam", "gradcam++"]:
+             raise ValueError(f"Method must be 'gradcam' or 'gradcam++', got '{method}'")
+
+     def _compute_gradcam(
+         self,
+         activations: np.ndarray,
+         gradients: np.ndarray
+     ) -> np.ndarray:
+         """
+         Compute standard GradCAM heatmap.
+
+         GradCAM = ReLU(sum_k(alpha_k * A^k))
+         where alpha_k = global_avg_pool(gradients for channel k)
+         """
+         # Global average pooling of gradients to get weights
+         # activations shape: (batch, channels, height, width)
+         # gradients shape: (batch, channels, height, width)
+
+         # For each channel, compute the average gradient (importance weight)
+         weights = np.mean(gradients, axis=(2, 3), keepdims=True) # (batch, channels, 1, 1)
+
+         # Weighted combination of activation maps
+         cam = np.sum(weights * activations, axis=1) # (batch, height, width)
+
+         # Apply ReLU (we only care about positive influence)
+         cam = np.maximum(cam, 0)
+
+         return cam
+
+     def _compute_gradcam_plusplus(
+         self,
+         activations: np.ndarray,
+         gradients: np.ndarray
+     ) -> np.ndarray:
+         """
+         Compute GradCAM++ heatmap.
+
+         GradCAM++ uses higher-order derivatives to weight the gradients,
+         providing better localization especially for multiple instances.
+         """
+         # First derivative
+         grad_2 = gradients ** 2
+         grad_3 = gradients ** 3
+
+         # Sum over spatial dimensions for denominator
+         sum_activations = np.sum(activations, axis=(2, 3), keepdims=True)
+
+         # Avoid division by zero
+         eps = 1e-8
+
+         # Alpha coefficients (pixel-wise weights)
+         alpha_num = grad_2
+         alpha_denom = 2 * grad_2 + sum_activations * grad_3 + eps
+         alpha = alpha_num / alpha_denom
+
+         # Set alpha to 0 where gradients are 0
+         alpha = np.where(gradients != 0, alpha, 0)
+
+         # Weights are sum of (alpha * ReLU(gradients))
+         weights = np.sum(alpha * np.maximum(gradients, 0), axis=(2, 3), keepdims=True)
+
+         # Weighted combination
+         cam = np.sum(weights * activations, axis=1)
+
+         # Apply ReLU
+         cam = np.maximum(cam, 0)
+
+         return cam
+
+     def _normalize_heatmap(self, heatmap: np.ndarray) -> np.ndarray:
+         """Normalize heatmap to [0, 1] range."""
+         heatmap = heatmap.squeeze()
+
+         min_val = heatmap.min()
+         max_val = heatmap.max()
+
+         if max_val - min_val > 1e-8:
+             heatmap = (heatmap - min_val) / (max_val - min_val)
+         else:
+             heatmap = np.zeros_like(heatmap)
+
+         return heatmap
+
+     def _resize_heatmap(
+         self,
+         heatmap: np.ndarray,
+         target_size: Tuple[int, int]
+     ) -> np.ndarray:
+         """
+         Resize heatmap to match input image size.
+
+         Uses simple bilinear-like interpolation without requiring scipy/cv2.
+         """
+         h, w = heatmap.shape
+         target_h, target_w = target_size
+
+         # Create coordinate grids
+         y_ratio = h / target_h
+         x_ratio = w / target_w
+
+         y_coords = np.arange(target_h) * y_ratio
+         x_coords = np.arange(target_w) * x_ratio
+
+         # Get integer indices and fractions
+         y_floor = np.floor(y_coords).astype(int)
+         x_floor = np.floor(x_coords).astype(int)
+
+         y_ceil = np.minimum(y_floor + 1, h - 1)
+         x_ceil = np.minimum(x_floor + 1, w - 1)
+
+         y_frac = y_coords - y_floor
+         x_frac = x_coords - x_floor
+
+         # Bilinear interpolation
+         resized = np.zeros((target_h, target_w))
+         for i in range(target_h):
+             for j in range(target_w):
+                 top_left = heatmap[y_floor[i], x_floor[j]]
+                 top_right = heatmap[y_floor[i], x_ceil[j]]
+                 bottom_left = heatmap[y_ceil[i], x_floor[j]]
+                 bottom_right = heatmap[y_ceil[i], x_ceil[j]]
+
+                 top = top_left * (1 - x_frac[j]) + top_right * x_frac[j]
+                 bottom = bottom_left * (1 - x_frac[j]) + bottom_right * x_frac[j]
+
+                 resized[i, j] = top * (1 - y_frac[i]) + bottom * y_frac[i]
+
+         return resized
+
+     def explain(
+         self,
+         image: np.ndarray,
+         target_class: Optional[int] = None,
+         resize_to_input: bool = True
+     ) -> Explanation:
+         """
+         Generate GradCAM explanation for an image.
+
+         Args:
+             image: Input image as numpy array. Expected shapes:
+                 - (C, H, W) for single image
+                 - (1, C, H, W) for batched single image
+                 - (H, W, C) will be transposed automatically
+             target_class: Class to explain. If None, uses predicted class.
+             resize_to_input: If True, resize heatmap to match input size.
+
+         Returns:
+             Explanation object with heatmap and metadata.
+         """
+         image = np.array(image, dtype=np.float32)
+
+         # Handle different input shapes
+         if image.ndim == 3:
+             # Could be (C, H, W) or (H, W, C)
+             if image.shape[0] in [1, 3, 4]: # Likely (C, H, W)
+                 image = image[np.newaxis, ...] # Add batch dim
+             else: # Likely (H, W, C)
+                 image = np.transpose(image, (2, 0, 1))[np.newaxis, ...]
+         elif image.ndim == 4:
+             pass # Already (N, C, H, W)
+         else:
+             raise ValueError(f"Expected 3D or 4D input, got shape {image.shape}")
+
+         input_size = (image.shape[2], image.shape[3]) # (H, W)
+
+         # Get activations and gradients for target layer
+         activations, gradients = self.model.get_layer_gradients(
+             image,
+             layer_name=self.target_layer,
+             target_class=target_class
+         )
+
+         # Ensure 4D: (batch, channels, height, width)
+         if activations.ndim == 2:
+             # Fully connected layer output, reshape
+             side = int(np.sqrt(activations.shape[1]))
+             activations = activations.reshape(1, 1, side, side)
+             gradients = gradients.reshape(1, 1, side, side)
+         elif activations.ndim == 3:
+             activations = activations[np.newaxis, ...]
+             gradients = gradients[np.newaxis, ...]
+
+         # Compute CAM based on method
+         if self.method == "gradcam":
+             cam = self._compute_gradcam(activations, gradients)
+         else: # gradcam++
+             cam = self._compute_gradcam_plusplus(activations, gradients)
+
+         # Normalize to [0, 1]
+         heatmap = self._normalize_heatmap(cam)
+
+         # Optionally resize to input size
+         if resize_to_input and heatmap.shape != input_size:
+             heatmap = self._resize_heatmap(heatmap, input_size)
+
+         # Determine target class info
+         if target_class is None:
+             predictions = self.model.predict(image)
+             target_class = int(np.argmax(predictions))
+
+         if self.class_names and target_class < len(self.class_names):
+             label_name = self.class_names[target_class]
+         else:
+             label_name = f"class_{target_class}"
+
+         return Explanation(
+             explainer_name=f"GradCAM" if self.method == "gradcam" else "GradCAM++",
+             target_class=label_name,
+             explanation_data={
+                 "heatmap": heatmap.tolist(),
+                 "heatmap_shape": list(heatmap.shape),
+                 "target_layer": self.target_layer,
+                 "method": self.method,
+                 "input_shape": list(image.shape)
+             }
+         )
+
+     def explain_batch(
+         self,
+         images: np.ndarray,
+         target_class: Optional[int] = None
+     ) -> List[Explanation]:
+         """
+         Generate explanations for multiple images.
+
+         Args:
+             images: Batch of images (N, C, H, W).
+             target_class: Target class for all images.
+
+         Returns:
+             List of Explanation objects.
+         """
+         images = np.array(images)
+
+         return [
+             self.explain(images[i], target_class=target_class)
+             for i in range(images.shape[0])
+         ]
+
+     def get_overlay(
+         self,
+         image: np.ndarray,
+         heatmap: np.ndarray,
+         alpha: float = 0.5,
+         colormap: str = "jet"
+     ) -> np.ndarray:
+         """
+         Create an overlay of the heatmap on the original image.
+
+         This is a simple implementation without matplotlib/cv2 dependencies.
+         For better visualizations, use the heatmap with your preferred
+         visualization library.
+
+         Args:
+             image: Original image (H, W, 3) in [0, 255] or [0, 1] range.
+             heatmap: GradCAM heatmap (H, W) in [0, 1] range.
+             alpha: Transparency of the heatmap overlay.
+             colormap: Color scheme (currently only "jet" supported).
+
+         Returns:
+             Overlaid image as numpy array (H, W, 3) in [0, 1] range.
+         """
+         image = np.array(image)
+         heatmap = np.array(heatmap)
+
+         # Normalize image to [0, 1]
+         if image.max() > 1:
+             image = image / 255.0
+
+         # Handle channel-first format
+         if image.ndim == 3 and image.shape[0] in [1, 3]:
+             image = np.transpose(image, (1, 2, 0))
+
+         # Simple jet colormap approximation
+         def jet_colormap(x):
+             """Simple jet colormap: blue -> cyan -> green -> yellow -> red"""
+             r = np.clip(1.5 - np.abs(4 * x - 3), 0, 1)
+             g = np.clip(1.5 - np.abs(4 * x - 2), 0, 1)
+             b = np.clip(1.5 - np.abs(4 * x - 1), 0, 1)
+             return np.stack([r, g, b], axis=-1)
+
+         # Apply colormap to heatmap
+         colored_heatmap = jet_colormap(heatmap)
+
+         # Ensure same size
+         if colored_heatmap.shape[:2] != image.shape[:2]:
+             colored_heatmap = self._resize_heatmap(
+                 colored_heatmap.mean(axis=-1),
+                 image.shape[:2]
+             )
+             colored_heatmap = jet_colormap(colored_heatmap)
+
+         # Blend
+         if image.ndim == 2:
+             image = np.stack([image] * 3, axis=-1)
+
+         overlay = (1 - alpha) * image + alpha * colored_heatmap
+         overlay = np.clip(overlay, 0, 1)
+
+         return overlay
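For reference, `_compute_gradcam` and `_compute_gradcam_plusplus` above are direct translations of the weighting schemes from the two cited papers, written in the papers' notation below (A^k is the activation map of channel k, y^c the score of class c). Note the code approximates the higher-order derivatives of Grad-CAM++ with powers of the first-order gradients (the `grad_2`/`grad_3` arrays), a common implementation shortcut:

```latex
% Grad-CAM: channel weights are global-average-pooled gradients
% (np.mean over the spatial axes in the code), combined and passed through a ReLU.
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A_{ij}^{k}},
\qquad
L^{c}_{\text{Grad-CAM}} = \mathrm{ReLU}\!\Big(\sum_k \alpha_k^c\, A^k\Big)

% Grad-CAM++: pixel-wise coefficients built from the squared and cubed gradient
% terms (grad_2, grad_3 in the code) weight only the positive gradients.
\alpha_{ij}^{kc} =
\frac{\big(\tfrac{\partial y^c}{\partial A_{ij}^{k}}\big)^{2}}
     {2\big(\tfrac{\partial y^c}{\partial A_{ij}^{k}}\big)^{2}
      + \sum_{a,b} A_{ab}^{k}\big(\tfrac{\partial y^c}{\partial A_{ij}^{k}}\big)^{3}},
\qquad
w_k^c = \sum_{i,j} \alpha_{ij}^{kc}\,\mathrm{ReLU}\!\Big(\tfrac{\partial y^c}{\partial A_{ij}^{k}}\Big),
\qquad
L^{c}_{\text{Grad-CAM++}} = \mathrm{ReLU}\!\Big(\sum_k w_k^c\, A^k\Big)
```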
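Since the `get_overlay` docstring above defers nicer rendering to "your preferred visualization library", here is a minimal sketch of displaying the result, assuming matplotlib is installed separately (it is not a dependency of this package) and that `explanation` and `overlay` come from the README GradCAM example:

```python
# Hypothetical follow-up to the README GradCAM example; matplotlib is an assumption.
import numpy as np
import matplotlib.pyplot as plt

heatmap = np.array(explanation.explanation_data["heatmap"])  # stored as a nested list

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(heatmap, cmap="jet")
axes[0].set_title("GradCAM heatmap")
axes[1].imshow(overlay)  # (H, W, 3) floats in [0, 1] from get_overlay
axes[1].set_title("Overlay")
for ax in axes:
    ax.axis("off")
plt.show()
```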