neuronview 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,194 @@
1
+ Metadata-Version: 2.4
2
+ Name: neuronview
3
+ Version: 0.1.0
4
+ Summary: A lightweight PyTorch activation heatmap visualizer
5
+ License-Expression: MIT
6
+ Project-URL: Homepage, https://github.com/Turtle-dev3/neuronview
7
+ Project-URL: Issues, https://github.com/Turtle-dev3/neuronview/issues
8
+ Keywords: pytorch,deep-learning,visualization,activations,heatmap
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Education
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Topic :: Scientific/Engineering :: Visualization
15
+ Requires-Python: >=3.9
16
+ Description-Content-Type: text/markdown
17
+ Requires-Dist: torch>=1.9.0
18
+ Requires-Dist: matplotlib>=3.4.0
19
+ Requires-Dist: numpy>=1.20.0
20
+ Provides-Extra: dev
21
+ Requires-Dist: pytest>=7.0; extra == "dev"
22
+ Requires-Dist: torchvision>=0.10.0; extra == "dev"
23
+
24
+ # neuronview
25
+
26
+ A lightweight PyTorch library for visualizing neural network activations as heatmaps. Hook into any layer of your model and see what it "sees" — with one line of code.
27
+
28
+ ## Installation
29
+
30
+ ```bash
31
+ # From source (recommended during development)
32
+ git clone https://github.com/Turtle-dev3/neuronview.git
33
+ cd neuronview
34
+ pip install -e ".[dev]"
35
+ ```
36
+
37
+ ## Quick Start
38
+
39
+ ```python
40
+ import torch
41
+ import torchvision.models as models
42
+ from neuronview import Inspector
43
+
44
+ # 1. Load any PyTorch model
45
+ model = models.resnet18(pretrained=True)
46
+
47
+ # 2. Create an Inspector and pick a layer to watch
48
+ inspector = Inspector(model)
49
+ inspector.watch("layer2.0.conv1")
50
+
51
+ # 3. Run a forward pass with your input
52
+ image = torch.randn(1, 3, 224, 224) # replace with a real image
53
+ inspector.run(image)
54
+
55
+ # 4. Visualize!
56
+ inspector.heatmap()
57
+ ```
58
+
59
+ ## Features
60
+
61
+ ### Discover layers
62
+
63
+ Not sure which layers your model has? List them all:
64
+
65
+ ```python
66
+ from neuronview import list_layers
67
+ import torchvision.models as models
68
+
69
+ model = models.resnet18()
70
+ for name in list_layers(model):
71
+ print(name)
72
+ # conv1
73
+ # bn1
74
+ # relu
75
+ # maxpool
76
+ # layer1.0.conv1
77
+ # layer1.0.bn1
78
+ # ...
79
+ ```
80
+
81
+ ### Watch multiple layers
82
+
83
+ You can hook into several layers at once using method chaining:
84
+
85
+ ```python
86
+ inspector = Inspector(model)
87
+ inspector.watch("layer1.0.conv1").watch("layer3.0.conv1")
88
+ inspector.run(image)
89
+
90
+ # Specify which layer to visualize
91
+ inspector.heatmap(layer_name="layer1.0.conv1")
92
+ inspector.heatmap(layer_name="layer3.0.conv1")
93
+ ```
94
+
95
+ ### View a specific channel
96
+
97
+ Each convolutional layer has multiple channels (filters). See what an individual channel detects:
98
+
99
+ ```python
100
+ inspector.heatmap(channel=5) # show channel 5
101
+ inspector.heatmap(channel=0, cmap="hot") # different colormap
102
+ ```
103
+
104
+ ### Overlay on the original image
105
+
106
+ See which part of your input image activated the layer most:
107
+
108
+ ```python
109
+ inspector.heatmap_overlay(
110
+ original_image=image,
111
+ alpha=0.6,
112
+ cmap="jet",
113
+ )
114
+ ```
115
+
116
+ ### Save figures
117
+
118
+ ```python
119
+ inspector.heatmap(save_path="activation_map.png")
120
+ ```
121
+
122
+ ### Clean up
123
+
124
+ ```python
125
+ inspector.clear() # clear stored activations (keep hooks)
126
+ inspector.unwatch("conv1") # remove a specific hook
127
+ inspector.unwatch() # remove all hooks
128
+ ```
129
+
130
+ ## API Reference
131
+
132
+ ### `Inspector(model)`
133
+
134
+ The main class. Wraps a PyTorch model and manages forward hooks.
135
+
136
+ | Method | Description |
137
+ |--------|-------------|
138
+ | `.watch(layer_name)` | Hook into a layer by its dot-path name. Returns `self` for chaining. |
139
+ | `.run(x)` | Run a forward pass and capture activations. Returns model output. |
140
+ | `.get_activations(layer_name=None)` | Get the raw activation tensor. If only one layer is watched, `layer_name` can be omitted. |
141
+ | `.heatmap(layer_name=None, channel=None, **kwargs)` | Render a heatmap. Pass `cmap`, `figsize`, `title`, `save_path`. |
142
+ | `.heatmap_overlay(original_image, layer_name=None, channel=None, alpha=0.5)` | Overlay heatmap on the input image. |
143
+ | `.layers()` | List all hookable layer names in the model. |
144
+ | `.unwatch(layer_name=None)` | Remove hooks (all if no name given). |
145
+ | `.clear()` | Clear stored activations without removing hooks. |
146
+
147
+ ### `list_layers(model, include_containers=False)`
148
+
149
+ Standalone function to list all hookable layers in a model.
150
+
151
+ ### `heatmap(activation, channel=None, cmap="viridis", ...)`
152
+
153
+ Standalone function — render any activation tensor as a heatmap.
154
+
155
+ ### `heatmap_overlay(activation, original_image, channel=None, alpha=0.5, ...)`
156
+
157
+ Standalone function — overlay an activation heatmap on an image.
158
+
159
+ ## Running Tests
160
+
161
+ ```bash
162
+ pip install -e ".[dev]"
163
+ pytest
164
+ ```
165
+
166
+ ## Project Structure
167
+
168
+ ```
169
+ neuronview/
170
+ ├── neuronview/
171
+ │ ├── __init__.py # Public API exports
172
+ │ ├── inspector.py # Inspector class (hooks + activation capture)
173
+ │ ├── visualize.py # Heatmap rendering with matplotlib
174
+ │ └── utils.py # Layer listing + lookup helpers
175
+ ├── tests/
176
+ │ ├── test_inspector.py
177
+ │ └── test_visualize.py
178
+ ├── pyproject.toml # Package metadata + dependencies
179
+ └── README.md
180
+ ```
181
+
182
+ ## How It Works
183
+
184
+ The core mechanism is **PyTorch forward hooks**. When you call `inspector.watch("layer2.0.conv1")`, neuronview:
185
+
186
+ 1. Walks the model's module tree to find that layer
187
+ 2. Registers a callback (`register_forward_hook`) on it
188
+ 3. When `inspector.run(x)` triggers a forward pass, the callback fires and captures the layer's output tensor
189
+ 4. The captured tensor is detached from the autograd graph and moved to CPU
190
+ 5. `heatmap()` averages across channels (or picks one) and renders with matplotlib
191
+
192
+ ## License
193
+
194
+ MIT
@@ -0,0 +1,171 @@
1
+ # neuronview
2
+
3
+ A lightweight PyTorch library for visualizing neural network activations as heatmaps. Hook into any layer of your model and see what it "sees" — with one line of code.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ # From source (recommended during development)
9
+ git clone https://github.com/Turtle-dev3/neuronview.git
10
+ cd neuronview
11
+ pip install -e ".[dev]"
12
+ ```
13
+
14
+ ## Quick Start
15
+
16
+ ```python
17
+ import torch
18
+ import torchvision.models as models
19
+ from neuronview import Inspector
20
+
21
+ # 1. Load any PyTorch model
22
+ model = models.resnet18(pretrained=True)
23
+
24
+ # 2. Create an Inspector and pick a layer to watch
25
+ inspector = Inspector(model)
26
+ inspector.watch("layer2.0.conv1")
27
+
28
+ # 3. Run a forward pass with your input
29
+ image = torch.randn(1, 3, 224, 224) # replace with a real image
30
+ inspector.run(image)
31
+
32
+ # 4. Visualize!
33
+ inspector.heatmap()
34
+ ```
35
+
36
+ ## Features
37
+
38
+ ### Discover layers
39
+
40
+ Not sure which layers your model has? List them all:
41
+
42
+ ```python
43
+ from neuronview import list_layers
44
+ import torchvision.models as models
45
+
46
+ model = models.resnet18()
47
+ for name in list_layers(model):
48
+ print(name)
49
+ # conv1
50
+ # bn1
51
+ # relu
52
+ # maxpool
53
+ # layer1.0.conv1
54
+ # layer1.0.bn1
55
+ # ...
56
+ ```
57
+
58
+ ### Watch multiple layers
59
+
60
+ You can hook into several layers at once using method chaining:
61
+
62
+ ```python
63
+ inspector = Inspector(model)
64
+ inspector.watch("layer1.0.conv1").watch("layer3.0.conv1")
65
+ inspector.run(image)
66
+
67
+ # Specify which layer to visualize
68
+ inspector.heatmap(layer_name="layer1.0.conv1")
69
+ inspector.heatmap(layer_name="layer3.0.conv1")
70
+ ```
71
+
72
+ ### View a specific channel
73
+
74
+ Each convolutional layer has multiple channels (filters). See what an individual channel detects:
75
+
76
+ ```python
77
+ inspector.heatmap(channel=5) # show channel 5
78
+ inspector.heatmap(channel=0, cmap="hot") # different colormap
79
+ ```
80
+
81
+ ### Overlay on the original image
82
+
83
+ See which part of your input image activated the layer most:
84
+
85
+ ```python
86
+ inspector.heatmap_overlay(
87
+ original_image=image,
88
+ alpha=0.6,
89
+ cmap="jet",
90
+ )
91
+ ```
92
+
93
+ ### Save figures
94
+
95
+ ```python
96
+ inspector.heatmap(save_path="activation_map.png")
97
+ ```
98
+
99
+ ### Clean up
100
+
101
+ ```python
102
+ inspector.clear() # clear stored activations (keep hooks)
103
+ inspector.unwatch("conv1") # remove a specific hook
104
+ inspector.unwatch() # remove all hooks
105
+ ```
106
+
107
+ ## API Reference
108
+
109
+ ### `Inspector(model)`
110
+
111
+ The main class. Wraps a PyTorch model and manages forward hooks.
112
+
113
+ | Method | Description |
114
+ |--------|-------------|
115
+ | `.watch(layer_name)` | Hook into a layer by its dot-path name. Returns `self` for chaining. |
116
+ | `.run(x)` | Run a forward pass and capture activations. Returns model output. |
117
+ | `.get_activations(layer_name=None)` | Get the raw activation tensor. If only one layer is watched, `layer_name` can be omitted. |
118
+ | `.heatmap(layer_name=None, channel=None, **kwargs)` | Render a heatmap. Pass `cmap`, `figsize`, `title`, `save_path`. |
119
+ | `.heatmap_overlay(original_image, layer_name=None, channel=None, alpha=0.5)` | Overlay heatmap on the input image. |
120
+ | `.layers()` | List all hookable layer names in the model. |
121
+ | `.unwatch(layer_name=None)` | Remove hooks (all if no name given). |
122
+ | `.clear()` | Clear stored activations without removing hooks. |
123
+
124
+ ### `list_layers(model, include_containers=False)`
125
+
126
+ Standalone function to list all hookable layers in a model.
127
+
128
+ ### `heatmap(activation, channel=None, cmap="viridis", ...)`
129
+
130
+ Standalone function — render any activation tensor as a heatmap.
131
+
132
+ ### `heatmap_overlay(activation, original_image, channel=None, alpha=0.5, ...)`
133
+
134
+ Standalone function — overlay an activation heatmap on an image.
135
+
136
+ ## Running Tests
137
+
138
+ ```bash
139
+ pip install -e ".[dev]"
140
+ pytest
141
+ ```
142
+
143
+ ## Project Structure
144
+
145
+ ```
146
+ neuronview/
147
+ ├── neuronview/
148
+ │ ├── __init__.py # Public API exports
149
+ │ ├── inspector.py # Inspector class (hooks + activation capture)
150
+ │ ├── visualize.py # Heatmap rendering with matplotlib
151
+ │ └── utils.py # Layer listing + lookup helpers
152
+ ├── tests/
153
+ │ ├── test_inspector.py
154
+ │ └── test_visualize.py
155
+ ├── pyproject.toml # Package metadata + dependencies
156
+ └── README.md
157
+ ```
158
+
159
+ ## How It Works
160
+
161
+ The core mechanism is **PyTorch forward hooks**. When you call `inspector.watch("layer2.0.conv1")`, neuronview:
162
+
163
+ 1. Walks the model's module tree to find that layer
164
+ 2. Registers a callback (`register_forward_hook`) on it
165
+ 3. When `inspector.run(x)` triggers a forward pass, the callback fires and captures the layer's output tensor
166
+ 4. The captured tensor is detached from the autograd graph and moved to CPU
167
+ 5. `heatmap()` averages across channels (or picks one) and renders with matplotlib
168
+
169
+ ## License
170
+
171
+ MIT
@@ -0,0 +1,20 @@
1
+ """
2
+ neuronview — A lightweight PyTorch activation heatmap visualizer.
3
+
4
+ Inspect what your neural network sees by hooking into any layer
5
+ and visualizing its activations as heatmaps.
6
+
7
+ Basic usage:
8
+ >>> from neuronview import Inspector
9
+ >>> inspector = Inspector(model)
10
+ >>> inspector.watch("layer2.0.conv1")
11
+ >>> inspector.run(input_tensor)
12
+ >>> inspector.heatmap()
13
+ """
14
+
15
+ from neuronview.inspector import Inspector
16
+ from neuronview.visualize import heatmap, heatmap_overlay
17
+ from neuronview.utils import list_layers
18
+
19
+ __version__ = "0.1.0"
20
+ __all__ = ["Inspector", "heatmap", "heatmap_overlay", "list_layers"]
@@ -0,0 +1,253 @@
1
+ """
2
+ inspector.py — The core of neuronview.
3
+
4
+ This module contains the Inspector class, which uses PyTorch's forward hook
5
+ mechanism to capture activations from any layer in a model.
6
+
7
+ HOW FORWARD HOOKS WORK:
8
+ PyTorch lets you register a callback on any nn.Module. Every time a
9
+ forward pass runs through that module, your callback receives:
10
+ - module: the layer itself
11
+ - input: the tensor(s) going INTO the layer
12
+ - output: the tensor(s) coming OUT of the layer (the "activations")
13
+
14
+ We grab `output`, detach it from the computation graph (so it doesn't
15
+ affect gradients), and store it. That's it — that's the whole trick.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from neuronview.visualize import heatmap, heatmap_overlay
26
+ from neuronview.utils import list_layers, _get_layer_by_name
27
+
28
+
29
class Inspector:
    """Attach to a PyTorch model and capture activations from watched layers.

    The Inspector registers forward hooks on named sub-modules. Each hook
    stores the layer's output tensor (detached and moved to CPU) so it can
    be retrieved or visualized after a forward pass.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to inspect. It is switched to eval mode.

    Example
    -------
    >>> import torchvision.models as models
    >>> model = models.resnet18(pretrained=True)
    >>> inspector = Inspector(model)
    >>> inspector.watch("layer2.0.conv1")
    >>> inspector.run(my_image_tensor)
    >>> inspector.heatmap()
    """

    def __init__(self, model: nn.Module) -> None:
        self.model = model
        self.model.eval()  # always inspect in eval mode

        # Maps layer name -> captured activation tensor (detached, on CPU)
        self._activations: dict[str, torch.Tensor] = {}

        # Maps layer name -> hook handle (so we can remove hooks later).
        # FIX: register_forward_hook returns a RemovableHandle; the previous
        # annotation referenced a nonexistent "RemovableHook" class.
        self._hooks: dict[str, torch.utils.hooks.RemovableHandle] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def watch(self, layer_name: str) -> "Inspector":
        """Register a forward hook on the named layer.

        Parameters
        ----------
        layer_name : str
            Dot-separated path to the layer, e.g. "layer2.0.conv1".
            Use ``Inspector.layers()`` to discover valid names.

        Returns
        -------
        Inspector
            Returns self so you can chain calls:
            ``inspector.watch("layer1").watch("layer2")``
        """
        if layer_name in self._hooks:
            # Already watching; registering twice would fire the hook twice.
            return self

        layer = _get_layer_by_name(self.model, layer_name)

        def hook_fn(module, input, output, name=layer_name):
            """Capture the output tensor when this layer runs."""
            # .detach() prevents this from being part of the gradient graph
            # .cpu() moves it off GPU so we don't hold GPU memory
            if isinstance(output, torch.Tensor):
                self._activations[name] = output.detach().cpu()
            elif isinstance(output, tuple):
                # Some layers return tuples; grab the first tensor
                self._activations[name] = output[0].detach().cpu()

        handle = layer.register_forward_hook(hook_fn)
        self._hooks[layer_name] = handle
        return self

    def run(self, x: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the model and capture activations.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. For image models, typically shape (1, C, H, W).

        Returns
        -------
        torch.Tensor
            The model's output (predictions).

        Raises
        ------
        RuntimeError
            If no layers are currently being watched.
        """
        if not self._hooks:
            raise RuntimeError(
                "No layers are being watched. "
                "Call inspector.watch('layer_name') first."
            )

        # no_grad: we only want activations, not gradients — saves memory.
        with torch.no_grad():
            output = self.model(x)
        return output

    def get_activations(self, layer_name: Optional[str] = None) -> torch.Tensor:
        """Retrieve captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Which layer's activations to return. If None and only one layer
            is being watched, returns that layer's activations.

        Returns
        -------
        torch.Tensor
            The activation tensor, typically shape (batch, channels, H, W).

        Raises
        ------
        ValueError
            If ``layer_name`` is None but multiple layers are watched.
        KeyError
            If no activation was captured for ``layer_name``.
        """
        if layer_name is None:
            if len(self._activations) == 1:
                return next(iter(self._activations.values()))
            raise ValueError(
                f"Multiple layers watched: {list(self._activations.keys())}. "
                "Specify which one with layer_name=."
            )

        if layer_name not in self._activations:
            raise KeyError(
                f"No activations for '{layer_name}'. "
                "Did you call inspector.run() after inspector.watch()?"
            )
        return self._activations[layer_name]

    def heatmap(
        self,
        layer_name: Optional[str] = None,
        channel: Optional[int] = None,
        **kwargs,
    ):
        """Show a heatmap of the captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Which layer to visualize. Can be omitted if only one is watched.
        channel : int, optional
            Specific channel to show. If None, averages across all channels.
        **kwargs
            Passed to matplotlib (cmap, figsize, etc.)

        Returns
        -------
        matplotlib.figure.Figure
        """
        act = self.get_activations(layer_name)
        return heatmap(act, channel=channel, **kwargs)

    def heatmap_overlay(
        self,
        original_image: torch.Tensor,
        layer_name: Optional[str] = None,
        channel: Optional[int] = None,
        alpha: float = 0.5,
        **kwargs,
    ):
        """Overlay the activation heatmap on the original image.

        Parameters
        ----------
        original_image : torch.Tensor
            The input image tensor, shape (C, H, W) or (1, C, H, W).
        layer_name : str, optional
            Which layer to visualize.
        channel : int, optional
            Specific channel. If None, averages across all channels.
        alpha : float
            Transparency of the heatmap overlay (0 = invisible, 1 = opaque).

        Returns
        -------
        matplotlib.figure.Figure
        """
        act = self.get_activations(layer_name)
        return heatmap_overlay(
            act, original_image, channel=channel, alpha=alpha, **kwargs
        )

    def layers(self) -> list[str]:
        """List all hookable layers in the model.

        Returns
        -------
        list[str]
            Layer names you can pass to ``watch()``.
        """
        return list_layers(self.model)

    def unwatch(self, layer_name: Optional[str] = None) -> "Inspector":
        """Remove hooks and free captured activations.

        Parameters
        ----------
        layer_name : str, optional
            Specific layer to unwatch. If None, removes ALL hooks.

        Returns
        -------
        Inspector
            Returns self for chaining.
        """
        if layer_name is None:
            for handle in self._hooks.values():
                handle.remove()
            self._hooks.clear()
            self._activations.clear()
        else:
            if layer_name in self._hooks:
                self._hooks[layer_name].remove()
                del self._hooks[layer_name]
            self._activations.pop(layer_name, None)
        return self

    def clear(self) -> "Inspector":
        """Clear stored activations without removing hooks.

        Useful between forward passes when you want fresh data.
        """
        self._activations.clear()
        return self

    # ------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        watched = list(self._hooks.keys())
        captured = list(self._activations.keys())
        return (
            f"Inspector(model={self.model.__class__.__name__}, "
            f"watching={watched}, captured={captured})"
        )

    def __del__(self) -> None:
        """Best-effort hook cleanup when the Inspector is garbage-collected.

        FIX: __del__ must never raise. During interpreter shutdown (or if
        __init__ failed before ``_hooks`` was assigned) attributes and
        modules may already be torn down, so swallow any error here.
        """
        try:
            self.unwatch()
        except Exception:
            pass