neuronview 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuronview-0.1.0/PKG-INFO +194 -0
- neuronview-0.1.0/README.md +171 -0
- neuronview-0.1.0/neuronview/__init__.py +20 -0
- neuronview-0.1.0/neuronview/inspector.py +253 -0
- neuronview-0.1.0/neuronview/utils.py +105 -0
- neuronview-0.1.0/neuronview/visualize.py +199 -0
- neuronview-0.1.0/neuronview.egg-info/PKG-INFO +194 -0
- neuronview-0.1.0/neuronview.egg-info/SOURCES.txt +13 -0
- neuronview-0.1.0/neuronview.egg-info/dependency_links.txt +1 -0
- neuronview-0.1.0/neuronview.egg-info/requires.txt +7 -0
- neuronview-0.1.0/neuronview.egg-info/top_level.txt +1 -0
- neuronview-0.1.0/pyproject.toml +35 -0
- neuronview-0.1.0/setup.cfg +4 -0
- neuronview-0.1.0/tests/test_inspector.py +183 -0
- neuronview-0.1.0/tests/test_visualize.py +65 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: neuronview
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A lightweight PyTorch activation heatmap visualizer
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Project-URL: Homepage, https://github.com/Turtle-dev3/neuronview
|
|
7
|
+
Project-URL: Issues, https://github.com/Turtle-dev3/neuronview/issues
|
|
8
|
+
Keywords: pytorch,deep-learning,visualization,activations,heatmap
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Education
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
15
|
+
Requires-Python: >=3.9
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
Requires-Dist: torch>=1.9.0
|
|
18
|
+
Requires-Dist: matplotlib>=3.4.0
|
|
19
|
+
Requires-Dist: numpy>=1.20.0
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
22
|
+
Requires-Dist: torchvision>=0.10.0; extra == "dev"
|
|
23
|
+
|
|
24
|
+
# neuronview
|
|
25
|
+
|
|
26
|
+
A lightweight PyTorch library for visualizing neural network activations as heatmaps. Hook into any layer of your model and see what it "sees" — in just a few lines of code.
|
|
27
|
+
|
|
28
|
+
## Installation
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
# From source (recommended during development)
|
|
32
|
+
git clone https://github.com/Turtle-dev3/neuronview.git
|
|
33
|
+
cd neuronview
|
|
34
|
+
pip install -e ".[dev]"
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Quick Start
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import torch
|
|
41
|
+
import torchvision.models as models
|
|
42
|
+
from neuronview import Inspector
|
|
43
|
+
|
|
44
|
+
# 1. Load any PyTorch model
|
|
45
|
+
model = models.resnet18(pretrained=True)
|
|
46
|
+
|
|
47
|
+
# 2. Create an Inspector and pick a layer to watch
|
|
48
|
+
inspector = Inspector(model)
|
|
49
|
+
inspector.watch("layer2.0.conv1")
|
|
50
|
+
|
|
51
|
+
# 3. Run a forward pass with your input
|
|
52
|
+
image = torch.randn(1, 3, 224, 224) # replace with a real image
|
|
53
|
+
inspector.run(image)
|
|
54
|
+
|
|
55
|
+
# 4. Visualize!
|
|
56
|
+
inspector.heatmap()
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Features
|
|
60
|
+
|
|
61
|
+
### Discover layers
|
|
62
|
+
|
|
63
|
+
Not sure which layers your model has? List them all:
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from neuronview import list_layers
|
|
67
|
+
import torchvision.models as models
|
|
68
|
+
|
|
69
|
+
model = models.resnet18()
|
|
70
|
+
for name in list_layers(model):
|
|
71
|
+
print(name)
|
|
72
|
+
# conv1
|
|
73
|
+
# bn1
|
|
74
|
+
# relu
|
|
75
|
+
# maxpool
|
|
76
|
+
# layer1.0.conv1
|
|
77
|
+
# layer1.0.bn1
|
|
78
|
+
# ...
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Watch multiple layers
|
|
82
|
+
|
|
83
|
+
You can hook into several layers at once using method chaining:
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
inspector = Inspector(model)
|
|
87
|
+
inspector.watch("layer1.0.conv1").watch("layer3.0.conv1")
|
|
88
|
+
inspector.run(image)
|
|
89
|
+
|
|
90
|
+
# Specify which layer to visualize
|
|
91
|
+
inspector.heatmap(layer_name="layer1.0.conv1")
|
|
92
|
+
inspector.heatmap(layer_name="layer3.0.conv1")
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
### View a specific channel
|
|
96
|
+
|
|
97
|
+
Each convolutional layer has multiple channels (filters). See what an individual channel detects:
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
inspector.heatmap(channel=5) # show channel 5
|
|
101
|
+
inspector.heatmap(channel=0, cmap="hot") # different colormap
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Overlay on the original image
|
|
105
|
+
|
|
106
|
+
See which part of your input image activated the layer most:
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
inspector.heatmap_overlay(
|
|
110
|
+
original_image=image,
|
|
111
|
+
alpha=0.6,
|
|
112
|
+
cmap="jet",
|
|
113
|
+
)
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
### Save figures
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
inspector.heatmap(save_path="activation_map.png")
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
### Clean up
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
inspector.clear() # clear stored activations (keep hooks)
|
|
126
|
+
inspector.unwatch("layer1.0.conv1") # remove a specific hook
|
|
127
|
+
inspector.unwatch() # remove all hooks
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
## API Reference
|
|
131
|
+
|
|
132
|
+
### `Inspector(model)`
|
|
133
|
+
|
|
134
|
+
The main class. Wraps a PyTorch model and manages forward hooks.
|
|
135
|
+
|
|
136
|
+
| Method | Description |
|
|
137
|
+
|--------|-------------|
|
|
138
|
+
| `.watch(layer_name)` | Hook into a layer by its dot-path name. Returns `self` for chaining. |
|
|
139
|
+
| `.run(x)` | Run a forward pass and capture activations. Returns model output. |
|
|
140
|
+
| `.get_activations(layer_name=None)` | Get the raw activation tensor. If only one layer is watched, `layer_name` can be omitted. |
|
|
141
|
+
| `.heatmap(layer_name=None, channel=None, **kwargs)` | Render a heatmap. Pass `cmap`, `figsize`, `title`, `save_path`. |
|
|
142
|
+
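| `.heatmap_overlay(original_image, layer_name=None, channel=None, alpha=0.5, **kwargs)` | Overlay heatmap on the input image. Extra keyword arguments (e.g. `cmap`) are forwarded to the standalone `heatmap_overlay()`. |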
|
|
143
|
+
| `.layers()` | List all hookable layer names in the model. |
|
|
144
|
+
| `.unwatch(layer_name=None)` | Remove hooks (all if no name given). |
|
|
145
|
+
| `.clear()` | Clear stored activations without removing hooks. |
|
|
146
|
+
|
|
147
|
+
### `list_layers(model, include_containers=False)`
|
|
148
|
+
|
|
149
|
+
Standalone function to list all hookable layers in a model.
|
|
150
|
+
|
|
151
|
+
### `heatmap(activation, channel=None, cmap="viridis", ...)`
|
|
152
|
+
|
|
153
|
+
Standalone function — render any activation tensor as a heatmap.
|
|
154
|
+
|
|
155
|
+
### `heatmap_overlay(activation, original_image, channel=None, alpha=0.5, ...)`
|
|
156
|
+
|
|
157
|
+
Standalone function — overlay an activation heatmap on an image.
|
|
158
|
+
|
|
159
|
+
## Running Tests
|
|
160
|
+
|
|
161
|
+
```bash
|
|
162
|
+
pip install -e ".[dev]"
|
|
163
|
+
pytest
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
## Project Structure
|
|
167
|
+
|
|
168
|
+
```
|
|
169
|
+
neuronview/
|
|
170
|
+
├── neuronview/
|
|
171
|
+
│ ├── __init__.py # Public API exports
|
|
172
|
+
│ ├── inspector.py # Inspector class (hooks + activation capture)
|
|
173
|
+
│ ├── visualize.py # Heatmap rendering with matplotlib
|
|
174
|
+
│ └── utils.py # Layer listing + lookup helpers
|
|
175
|
+
├── tests/
|
|
176
|
+
│ ├── test_inspector.py
|
|
177
|
+
│ └── test_visualize.py
|
|
178
|
+
├── pyproject.toml # Package metadata + dependencies
|
|
179
|
+
└── README.md
|
|
180
|
+
```
|
|
181
|
+
|
|
182
|
+
## How It Works
|
|
183
|
+
|
|
184
|
+
The core mechanism is **PyTorch forward hooks**. When you call `inspector.watch("layer2.0.conv1")`, neuronview:
|
|
185
|
+
|
|
186
|
+
1. Walks the model's module tree to find that layer
|
|
187
|
+
2. Registers a callback (`register_forward_hook`) on it
|
|
188
|
+
3. When `inspector.run(x)` triggers a forward pass, the callback fires and captures the layer's output tensor
|
|
189
|
+
4. The captured tensor is detached from the autograd graph and moved to CPU
|
|
190
|
+
5. `heatmap()` averages across channels (or picks one) and renders with matplotlib
|
|
191
|
+
|
|
192
|
+
## License
|
|
193
|
+
|
|
194
|
+
MIT
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# neuronview
|
|
2
|
+
|
|
3
|
+
A lightweight PyTorch library for visualizing neural network activations as heatmaps. Hook into any layer of your model and see what it "sees" — in just a few lines of code.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
# From source (recommended during development)
|
|
9
|
+
git clone https://github.com/Turtle-dev3/neuronview.git
|
|
10
|
+
cd neuronview
|
|
11
|
+
pip install -e ".[dev]"
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Quick Start
|
|
15
|
+
|
|
16
|
+
```python
|
|
17
|
+
import torch
|
|
18
|
+
import torchvision.models as models
|
|
19
|
+
from neuronview import Inspector
|
|
20
|
+
|
|
21
|
+
# 1. Load any PyTorch model
|
|
22
|
+
model = models.resnet18(pretrained=True)
|
|
23
|
+
|
|
24
|
+
# 2. Create an Inspector and pick a layer to watch
|
|
25
|
+
inspector = Inspector(model)
|
|
26
|
+
inspector.watch("layer2.0.conv1")
|
|
27
|
+
|
|
28
|
+
# 3. Run a forward pass with your input
|
|
29
|
+
image = torch.randn(1, 3, 224, 224) # replace with a real image
|
|
30
|
+
inspector.run(image)
|
|
31
|
+
|
|
32
|
+
# 4. Visualize!
|
|
33
|
+
inspector.heatmap()
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Features
|
|
37
|
+
|
|
38
|
+
### Discover layers
|
|
39
|
+
|
|
40
|
+
Not sure which layers your model has? List them all:
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
from neuronview import list_layers
|
|
44
|
+
import torchvision.models as models
|
|
45
|
+
|
|
46
|
+
model = models.resnet18()
|
|
47
|
+
for name in list_layers(model):
|
|
48
|
+
print(name)
|
|
49
|
+
# conv1
|
|
50
|
+
# bn1
|
|
51
|
+
# relu
|
|
52
|
+
# maxpool
|
|
53
|
+
# layer1.0.conv1
|
|
54
|
+
# layer1.0.bn1
|
|
55
|
+
# ...
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
### Watch multiple layers
|
|
59
|
+
|
|
60
|
+
You can hook into several layers at once using method chaining:
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
inspector = Inspector(model)
|
|
64
|
+
inspector.watch("layer1.0.conv1").watch("layer3.0.conv1")
|
|
65
|
+
inspector.run(image)
|
|
66
|
+
|
|
67
|
+
# Specify which layer to visualize
|
|
68
|
+
inspector.heatmap(layer_name="layer1.0.conv1")
|
|
69
|
+
inspector.heatmap(layer_name="layer3.0.conv1")
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
### View a specific channel
|
|
73
|
+
|
|
74
|
+
Each convolutional layer has multiple channels (filters). See what an individual channel detects:
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
inspector.heatmap(channel=5) # show channel 5
|
|
78
|
+
inspector.heatmap(channel=0, cmap="hot") # different colormap
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Overlay on the original image
|
|
82
|
+
|
|
83
|
+
See which part of your input image activated the layer most:
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
inspector.heatmap_overlay(
|
|
87
|
+
original_image=image,
|
|
88
|
+
alpha=0.6,
|
|
89
|
+
cmap="jet",
|
|
90
|
+
)
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### Save figures
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
inspector.heatmap(save_path="activation_map.png")
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### Clean up
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
inspector.clear() # clear stored activations (keep hooks)
|
|
103
|
+
inspector.unwatch("layer1.0.conv1") # remove a specific hook
|
|
104
|
+
inspector.unwatch() # remove all hooks
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
## API Reference
|
|
108
|
+
|
|
109
|
+
### `Inspector(model)`
|
|
110
|
+
|
|
111
|
+
The main class. Wraps a PyTorch model and manages forward hooks.
|
|
112
|
+
|
|
113
|
+
| Method | Description |
|
|
114
|
+
|--------|-------------|
|
|
115
|
+
| `.watch(layer_name)` | Hook into a layer by its dot-path name. Returns `self` for chaining. |
|
|
116
|
+
| `.run(x)` | Run a forward pass and capture activations. Returns model output. |
|
|
117
|
+
| `.get_activations(layer_name=None)` | Get the raw activation tensor. If only one layer is watched, `layer_name` can be omitted. |
|
|
118
|
+
| `.heatmap(layer_name=None, channel=None, **kwargs)` | Render a heatmap. Pass `cmap`, `figsize`, `title`, `save_path`. |
|
|
119
|
+
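| `.heatmap_overlay(original_image, layer_name=None, channel=None, alpha=0.5, **kwargs)` | Overlay heatmap on the input image. Extra keyword arguments (e.g. `cmap`) are forwarded to the standalone `heatmap_overlay()`. |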
|
|
120
|
+
| `.layers()` | List all hookable layer names in the model. |
|
|
121
|
+
| `.unwatch(layer_name=None)` | Remove hooks (all if no name given). |
|
|
122
|
+
| `.clear()` | Clear stored activations without removing hooks. |
|
|
123
|
+
|
|
124
|
+
### `list_layers(model, include_containers=False)`
|
|
125
|
+
|
|
126
|
+
Standalone function to list all hookable layers in a model.
|
|
127
|
+
|
|
128
|
+
### `heatmap(activation, channel=None, cmap="viridis", ...)`
|
|
129
|
+
|
|
130
|
+
Standalone function — render any activation tensor as a heatmap.
|
|
131
|
+
|
|
132
|
+
### `heatmap_overlay(activation, original_image, channel=None, alpha=0.5, ...)`
|
|
133
|
+
|
|
134
|
+
Standalone function — overlay an activation heatmap on an image.
|
|
135
|
+
|
|
136
|
+
## Running Tests
|
|
137
|
+
|
|
138
|
+
```bash
|
|
139
|
+
pip install -e ".[dev]"
|
|
140
|
+
pytest
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
## Project Structure
|
|
144
|
+
|
|
145
|
+
```
|
|
146
|
+
neuronview/
|
|
147
|
+
├── neuronview/
|
|
148
|
+
│ ├── __init__.py # Public API exports
|
|
149
|
+
│ ├── inspector.py # Inspector class (hooks + activation capture)
|
|
150
|
+
│ ├── visualize.py # Heatmap rendering with matplotlib
|
|
151
|
+
│ └── utils.py # Layer listing + lookup helpers
|
|
152
|
+
├── tests/
|
|
153
|
+
│ ├── test_inspector.py
|
|
154
|
+
│ └── test_visualize.py
|
|
155
|
+
├── pyproject.toml # Package metadata + dependencies
|
|
156
|
+
└── README.md
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
## How It Works
|
|
160
|
+
|
|
161
|
+
The core mechanism is **PyTorch forward hooks**. When you call `inspector.watch("layer2.0.conv1")`, neuronview:
|
|
162
|
+
|
|
163
|
+
1. Walks the model's module tree to find that layer
|
|
164
|
+
2. Registers a callback (`register_forward_hook`) on it
|
|
165
|
+
3. When `inspector.run(x)` triggers a forward pass, the callback fires and captures the layer's output tensor
|
|
166
|
+
4. The captured tensor is detached from the autograd graph and moved to CPU
|
|
167
|
+
5. `heatmap()` averages across channels (or picks one) and renders with matplotlib
|
|
168
|
+
|
|
169
|
+
## License
|
|
170
|
+
|
|
171
|
+
MIT
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""
|
|
2
|
+
neuronview — A lightweight PyTorch activation heatmap visualizer.
|
|
3
|
+
|
|
4
|
+
Inspect what your neural network sees by hooking into any layer
|
|
5
|
+
and visualizing its activations as heatmaps.
|
|
6
|
+
|
|
7
|
+
Basic usage:
|
|
8
|
+
>>> from neuronview import Inspector
|
|
9
|
+
>>> inspector = Inspector(model)
|
|
10
|
+
>>> inspector.watch("layer2.0.conv1")
|
|
11
|
+
>>> inspector.run(input_tensor)
|
|
12
|
+
>>> inspector.heatmap()
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from neuronview.inspector import Inspector
|
|
16
|
+
from neuronview.visualize import heatmap, heatmap_overlay
|
|
17
|
+
from neuronview.utils import list_layers
|
|
18
|
+
|
|
19
|
+
# Package version string (matches the ``Version`` field of the distribution
# metadata).
__version__ = "0.1.0"

# Names exported by ``from neuronview import *`` — the public API surface.
__all__ = [
    "Inspector",
    "heatmap",
    "heatmap_overlay",
    "list_layers",
]
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""
|
|
2
|
+
inspector.py — The core of neuronview.
|
|
3
|
+
|
|
4
|
+
This module contains the Inspector class, which uses PyTorch's forward hook
|
|
5
|
+
mechanism to capture activations from any layer in a model.
|
|
6
|
+
|
|
7
|
+
HOW FORWARD HOOKS WORK:
|
|
8
|
+
PyTorch lets you register a callback on any nn.Module. Every time a
|
|
9
|
+
forward pass runs through that module, your callback receives:
|
|
10
|
+
- module: the layer itself
|
|
11
|
+
- input: the tensor(s) going INTO the layer
|
|
12
|
+
- output: the tensor(s) coming OUT of the layer (the "activations")
|
|
13
|
+
|
|
14
|
+
We grab `output`, detach it from the computation graph (so it doesn't
|
|
15
|
+
affect gradients), and store it. That's it — that's the whole trick.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Optional
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
import torch.nn as nn
|
|
24
|
+
|
|
25
|
+
from neuronview.visualize import heatmap, heatmap_overlay
|
|
26
|
+
from neuronview.utils import list_layers, _get_layer_by_name
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Inspector:
    """Attach to a PyTorch model and capture activations from watched layers.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to inspect. It is put into eval mode when the
        Inspector is constructed.

    Example
    -------
    >>> import torchvision.models as models
    >>> model = models.resnet18(pretrained=True)
    >>> inspector = Inspector(model)
    >>> inspector.watch("layer2.0.conv1")
    >>> inspector.run(my_image_tensor)
    >>> inspector.heatmap()
    """

    def __init__(self, model: nn.Module) -> None:
        self.model = model
        self.model.eval()  # always inspect in eval mode

        # Maps layer name -> captured activation tensor.
        # NOTE: the hook closures hold a reference to this exact dict object
        # (see ``_make_hook``), so it must only be mutated in place, never
        # rebound to a new dict.
        self._activations: dict[str, torch.Tensor] = {}

        # Maps layer name -> hook handle (so we can remove hooks later)
        self._hooks: dict[str, torch.utils.hooks.RemovableHandle] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def watch(self, layer_name: str) -> "Inspector":
        """Register a forward hook on the named layer.

        Parameters
        ----------
        layer_name : str
            Dot-separated path to the layer, e.g. "layer2.0.conv1".
            Use ``Inspector.layers()`` to discover valid names.

        Returns
        -------
        Inspector
            Returns self so you can chain calls:
            ``inspector.watch("layer1").watch("layer2")``
        """
        if layer_name in self._hooks:
            return self  # already watching

        layer = _get_layer_by_name(self.model, layer_name)
        handle = layer.register_forward_hook(
            self._make_hook(self._activations, layer_name)
        )
        self._hooks[layer_name] = handle
        return self

    def run(self, x: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the model and capture activations.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. For image models, typically shape (1, C, H, W).

        Returns
        -------
        torch.Tensor
            The model's output (predictions).

        Raises
        ------
        RuntimeError
            If no layer is being watched.
        """
        if not self._hooks:
            raise RuntimeError(
                "No layers are being watched. "
                "Call inspector.watch('layer_name') first."
            )

        # Gradients are never needed for inspection.
        with torch.no_grad():
            output = self.model(x)
        return output

    def get_activations(self, layer_name: Optional[str] = None) -> torch.Tensor:
        """Retrieve captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Which layer's activations to return. If None and exactly one
            layer has captured activations, returns that layer's activations.

        Returns
        -------
        torch.Tensor
            The activation tensor, typically shape (batch, channels, H, W).

        Raises
        ------
        ValueError
            If ``layer_name`` is None and zero or several layers have
            captured activations.
        KeyError
            If ``layer_name`` has no captured activations.
        """
        if layer_name is None:
            if len(self._activations) == 1:
                return next(iter(self._activations.values()))
            if not self._activations:
                raise ValueError(
                    "No activations captured yet. "
                    "Call inspector.watch('layer_name') and then "
                    "inspector.run(x) first."
                )
            raise ValueError(
                f"Multiple layers watched: {list(self._activations.keys())}. "
                "Specify which one with layer_name=."
            )

        if layer_name not in self._activations:
            raise KeyError(
                f"No activations for '{layer_name}'. "
                "Did you call inspector.run() after inspector.watch()?"
            )
        return self._activations[layer_name]

    def heatmap(
        self,
        layer_name: Optional[str] = None,
        channel: Optional[int] = None,
        **kwargs,
    ):
        """Show a heatmap of the captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Which layer to visualize. Can be omitted if only one is captured.
        channel : int, optional
            Specific channel to show. If None, averages across all channels.
        **kwargs
            Passed to matplotlib (cmap, figsize, etc.)

        Returns
        -------
        matplotlib.figure.Figure
        """
        act = self.get_activations(layer_name)
        return heatmap(act, channel=channel, **kwargs)

    def heatmap_overlay(
        self,
        original_image: torch.Tensor,
        layer_name: Optional[str] = None,
        channel: Optional[int] = None,
        alpha: float = 0.5,
        **kwargs,
    ):
        """Overlay the activation heatmap on the original image.

        Parameters
        ----------
        original_image : torch.Tensor
            The input image tensor, shape (C, H, W) or (1, C, H, W).
        layer_name : str, optional
            Which layer to visualize.
        channel : int, optional
            Specific channel. If None, averages across all channels.
        alpha : float
            Transparency of the heatmap overlay (0 = invisible, 1 = opaque).
        **kwargs
            Forwarded to the standalone ``heatmap_overlay`` function.

        Returns
        -------
        matplotlib.figure.Figure
        """
        act = self.get_activations(layer_name)
        return heatmap_overlay(
            act, original_image, channel=channel, alpha=alpha, **kwargs
        )

    def layers(self) -> list[str]:
        """List all hookable layers in the model.

        Returns
        -------
        list[str]
            Layer names you can pass to ``watch()``.
        """
        return list_layers(self.model)

    def unwatch(self, layer_name: Optional[str] = None) -> "Inspector":
        """Remove hooks and free captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Specific layer to unwatch. If None, removes ALL hooks.
            Unknown names are ignored.

        Returns
        -------
        Inspector
            Returns self for chaining.
        """
        if layer_name is None:
            for handle in self._hooks.values():
                handle.remove()
            self._hooks.clear()
            self._activations.clear()
        else:
            if layer_name in self._hooks:
                self._hooks[layer_name].remove()
                del self._hooks[layer_name]
                self._activations.pop(layer_name, None)
        return self

    def clear(self) -> "Inspector":
        """Clear stored activations without removing hooks.

        Useful between forward passes when you want fresh data.
        """
        self._activations.clear()
        return self

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_hook(store: dict[str, torch.Tensor], name: str):
        """Build a forward hook that stores ``name``'s output in ``store``.

        The hook deliberately closes over the activations dict, not the
        Inspector. A closure over ``self`` would create the reference cycle
        model -> hook -> Inspector -> model. That cycle keeps a discarded
        Inspector alive for as long as the model is alive, so ``__del__``
        would never remove its hooks and they would keep capturing tensors.
        """

        def hook_fn(module, inputs, output):
            # Some layers return tuples/lists; use the first element.
            if isinstance(output, (tuple, list)):
                output = output[0] if output else None
            # .detach() keeps the stored tensor out of the gradient graph;
            # .cpu() moves it off the GPU so we don't hold device memory.
            if isinstance(output, torch.Tensor):
                store[name] = output.detach().cpu()

        return hook_fn

    # ------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        watched = list(self._hooks.keys())
        captured = list(self._activations.keys())
        return (
            f"Inspector(model={self.model.__class__.__name__}, "
            f"watching={watched}, captured={captured})"
        )

    def __del__(self) -> None:
        """Remove any remaining hooks when the Inspector is garbage-collected.

        Guarded so that a partially constructed instance (``__init__`` raised
        before ``_hooks`` was assigned) does not raise from the finalizer.
        """
        if getattr(self, "_hooks", None):
            self.unwatch()
|