torch-anatomy 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Harshal Vilas Kale
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy...
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-anatomy
3
+ Version: 0.1.0
4
+ Summary: Layer-by-layer visualizer for PyTorch models
5
+ Author: Harshal Vilas Kale
6
+ License: MIT
7
+ License-File: LICENSE
8
+ Requires-Dist: torch
9
+ Requires-Dist: torchvision
10
+ Requires-Dist: matplotlib
11
+ Requires-Dist: numpy
12
+ Requires-Dist: Pillow
13
+ Requires-Dist: click
14
+ Dynamic: author
15
+ Dynamic: license
16
+ Dynamic: license-file
17
+ Dynamic: requires-dist
18
+ Dynamic: summary
@@ -0,0 +1,43 @@
1
+ # torch-anatomy
2
+
3
+ **Layer-by-layer visualizer for PyTorch models — Understand what each layer actually does.**
4
+
5
+ ![PyPI](https://img.shields.io/pypi/v/torch-anatomy)
6
+ ![License](https://img.shields.io/github/license/yourusername/torch-anatomy)
7
+
8
+ ## Install
9
+
10
+ ```bash
11
+ pip install torch-anatomy
12
+ ```
13
+
14
+ ## Usage
15
+
16
+ ```python
17
+ from torch_anatomy import visualize_layers
18
+ from torchvision import models
19
+
20
+ model = models.resnet18(pretrained=True)
21
+ visualize_layers(
22
+ model=model,
23
+ input_image='dog.jpg',
24
+ layers_to_show=['Conv2d', 'ReLU'],
25
+ channels_per_layer=6,
26
+ colormap='inferno',
27
+ show_colorbar=True
28
+ )
29
+ ```
30
+
31
+ Or from the command line:
32
+
33
+ ```bash
34
+ torch-anatomy --model resnet18 --image dog.jpg
35
+ ```
36
+
37
+ ## Features
38
+ - Plug-and-play for any PyTorch CNN
39
+ - Visualizes feature maps for key layers
40
+ - Customizable channels, colormap, and more
41
+
42
+ ## License
43
+ MIT
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,23 @@
1
+ from setuptools import setup, find_packages
2
+
3
from pathlib import Path

# Publish the README that ships next to this file as the project description.
# Without it the generated PKG-INFO carries only the one-line summary.
# Fall back to that summary if the README is absent so the build never fails
# just because the file is missing.
_SUMMARY = 'Layer-by-layer visualizer for PyTorch models'
_readme_path = Path(__file__).resolve().parent / 'README.md'
_long_description = (
    _readme_path.read_text(encoding='utf-8')
    if _readme_path.is_file()
    else _SUMMARY
)

setup(
    name='torch-anatomy',
    version='0.1.0',
    description=_SUMMARY,
    long_description=_long_description,
    long_description_content_type='text/markdown',
    author='Harshal Vilas Kale',
    packages=find_packages(),
    install_requires=[
        'torch',
        'torchvision',
        'matplotlib',
        'numpy',
        'Pillow',
        'click',
    ],
    entry_points={
        # Installs the `torch-anatomy` command, dispatching to cli.main.
        'console_scripts': [
            'torch-anatomy=torch_anatomy.cli:main',
        ],
    },
    license="MIT",
)
@@ -0,0 +1,2 @@
1
def test_placeholder():
    """Trivial test that always passes; it exercises no package code."""
    # NOTE(review): replace with real tests of visualize_layers / load_image.
    assert True
@@ -0,0 +1 @@
1
+ from .visualizer import visualize_layers
@@ -0,0 +1,20 @@
1
+ import click
2
+ import torch
3
+ from torchvision import models
4
+ from .visualizer import visualize_layers
5
+
6
def get_model(model_name):
    """Return a pretrained torchvision model for ``model_name``.

    Simple model loader for the CLI demo; only ``'resnet18'`` is known
    (expand as needed).

    Args:
        model_name: Name of the model to build.

    Returns:
        The torchvision ResNet-18 module with pretrained weights loaded.

    Raises:
        ValueError: If ``model_name`` is not a known model.
    """
    if model_name != 'resnet18':
        raise ValueError(f"Unknown model: {model_name}")

    # Newer torchvision deprecates ``pretrained=True`` in favour of explicit
    # weight enums. IMAGENET1K_V1 is the checkpoint the legacy flag maps to,
    # so the loaded weights stay the same.
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        return models.resnet18(weights=weights_enum.IMAGENET1K_V1)

    # Older torchvision without the weights API: keep the legacy call.
    return models.resnet18(pretrained=True)
11
+
12
@click.command()
@click.option('--model', required=True, help='Model name, e.g. resnet18')
@click.option(
    '--image',
    required=True,
    # Validate up front so a bad path fails before the model is loaded.
    type=click.Path(exists=True, dir_okay=False),
    help='Path to input image',
)
def main(model, image):
    """Visualize layer activations of MODEL's layers for an input image."""
    try:
        model_obj = get_model(model)
    except ValueError as exc:
        # Report an unknown model as a usage error instead of a traceback.
        raise click.BadParameter(str(exc), param_hint='--model') from exc
    visualize_layers(model=model_obj, input_image=image)


if __name__ == '__main__':
    main()
@@ -0,0 +1,23 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ import torchvision.transforms as T
5
+
6
def load_image(img, size=224):
    """
    Loads and preprocesses an image for a PyTorch model.

    Args:
        img: File path (``str`` or ``os.PathLike``), ``PIL.Image``, or
            ``np.ndarray``. Every input is converted to 3-channel RGB so the
            3-value mean/std normalization below always applies.
        size: Target side length; the image is resized to ``(size, size)``.

    Returns:
        A torch.Tensor of shape (1, 3, size, size).

    Raises:
        TypeError: If ``img`` is none of the supported types.
    """
    import os  # local import: this module has no top-level ``import os``

    if isinstance(img, (str, os.PathLike)):
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released as soon as the block exits.
        with Image.open(img) as opened:
            img = opened.convert('RGB')
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img).convert('RGB')
    elif isinstance(img, Image.Image):
        # Grayscale / RGBA / palette images would otherwise break Normalize.
        img = img.convert('RGB')
    else:
        raise TypeError(
            "img must be a file path, PIL.Image, or np.ndarray, "
            f"got {type(img).__name__}"
        )

    # Resize and normalize. The mean/std constants are the values commonly
    # used with ImageNet-pretrained torchvision models.
    transform = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # unsqueeze(0) adds the leading batch dimension.
    return transform(img).unsqueeze(0)
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torchvision
3
+ import matplotlib
4
+ matplotlib.use('Agg') # Set backend to non-interactive
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from PIL import Image
8
+ import os
9
+ from .utils import load_image
10
+
11
def get_activation(name, store=None):
    """Build a forward hook that records a layer's output under ``name``.

    Args:
        name: Key under which the detached output is stored.
        store: Optional dict to write into. When omitted, the module-level
            ``activations`` dict is used; it is looked up when the hook
            fires and created if it does not exist yet.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. It stores ``output.detach()`` and returns ``None``.
    """
    def hook(module, inputs, output):
        global activations
        if store is not None:
            target = store
        else:
            try:
                target = activations
            except NameError:
                # Hook used before anything created the shared dict.
                target = activations = {}

        # Some modules return a tuple/list; keep the first tensor-like item.
        if isinstance(output, (tuple, list)):
            output = next(
                (item for item in output if hasattr(item, 'detach')), None
            )
        # Nothing tensor-like to record: skip rather than crash the forward.
        if output is None or not hasattr(output, 'detach'):
            return
        target[name] = output.detach()
    return hook
16
+
17
def visualize_layers(
    model,
    input_image,
    layers_to_show=None,
    save_dir=None,
    show=True,
    channels_per_layer=4,
    colormap='viridis',
    show_colorbar=False
):
    """
    Visualize intermediate activations of a PyTorch CNN model layer-by-layer.

    One figure is drawn: the input image in the first row, then one row per
    captured layer showing its first ``channels_per_layer`` channels, each
    min-max normalized on its own.

    Args:
        model: PyTorch model (nn.Module). It is switched to eval mode.
        input_image: Path to image or PIL.Image or np.ndarray
        layers_to_show: Layer selectors (default: Conv2d, ReLU, MaxPool2d).
            A selector matches a module if it is a substring of
            ``str(type(module))`` or equals the module's name from
            ``named_modules()``. A single string is treated as one selector.
        save_dir: If provided, saves ``layer_visualizations.png`` there
        show: If True, calls ``plt.show()``
        channels_per_layer: Number of channels to show per layer (default: 4)
        colormap: Matplotlib colormap to use (default: 'viridis')
        show_colorbar: Whether to show colorbar (default: False)

    Returns:
        None.

    Raises:
        ValueError: If ``channels_per_layer`` is smaller than 1.

    Notes:
        Activations are collected in the module-level ``activations`` dict,
        which is reset on every call, keyed by module name. A module invoked
        several times in one forward pass keeps only its last output. Only
        4-D activations are plotted; others are reported and skipped.
    """
    global activations
    activations = {}

    if channels_per_layer < 1:
        raise ValueError("channels_per_layer must be at least 1")

    # Default layers to show if none specified
    if layers_to_show is None:
        layers_to_show = ['Conv2d', 'ReLU', 'MaxPool2d']
    elif isinstance(layers_to_show, str):
        # A bare string would otherwise be iterated character by character.
        layers_to_show = [layers_to_show]

    # Load the image before touching the model, so a bad path cannot leave
    # hooks registered on it.
    input_tensor = load_image(input_image)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        # Run the input on whatever device the model lives on.
        input_tensor = input_tensor.to(first_param.device)

    hooks = []
    try:
        for name, layer in model.named_modules():
            type_repr = str(type(layer))
            if name in layers_to_show or any(
                selector in type_repr for selector in layers_to_show
            ):
                hooks.append(layer.register_forward_hook(get_activation(name)))

        # Forward pass
        model.eval()
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Always detach the hooks, even if the forward pass failed.
        for hook in hooks:
            hook.remove()

    # Only (batch, channel, height, width) outputs can be drawn per channel.
    plottable = {
        name: act for name, act in activations.items()
        if getattr(act, 'ndim', 0) == 4
    }
    skipped = [name for name in activations if name not in plottable]
    if skipped:
        print(f"Skipping layers without 4-D activations: {', '.join(skipped)}")

    n_layers = len(plottable)
    if n_layers == 0:
        print("No matching layers found!")
        return

    # Prepare input image for display
    if isinstance(input_image, (str, os.PathLike)):
        img_disp = Image.open(input_image).convert('RGB')
    elif isinstance(input_image, np.ndarray):
        img_disp = Image.fromarray(input_image)
    else:
        img_disp = input_image

    # Grid: one column per channel, one row per layer plus the input row.
    n_cols = channels_per_layer
    n_rows = n_layers + 1

    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False
    )
    keep_open = False
    try:
        # Every cell starts blank; unused cells simply stay that way.
        for ax in axes.flat:
            ax.axis('off')

        # Plot input image
        input_ax = axes[0][0]
        input_im = input_ax.imshow(img_disp)
        input_ax.set_title('Input Image')
        if show_colorbar:
            fig.colorbar(input_im, ax=input_ax)

        # Plot each layer's activations (first N channels)
        for row_idx, (name, activation) in enumerate(plottable.items(), 1):
            for ch in range(min(n_cols, activation.shape[1])):
                ax = axes[row_idx][ch]
                act = activation[0, ch].cpu().numpy()
                # Min-max scale; the epsilon guards constant channels.
                act = (act - act.min()) / (act.max() - act.min() + 1e-8)
                im = ax.imshow(act, cmap=colormap)
                ax.set_title(
                    f'{name}\nChannel {ch} | Shape: {activation.shape}'
                )
                if show_colorbar:
                    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        fig.tight_layout()

        # Save if directory provided
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(os.path.join(save_dir, 'layer_visualizations.png'))
            print(f"Visualizations saved to {save_dir}/layer_visualizations.png")

        # Show plot
        if show:
            try:
                plt.show()
            except Exception as e:
                print(f"Could not display plot: {e}")
                if save_dir:
                    # Only claim a saved file when one was actually written.
                    print("But the visualization has been saved to the output directory!")
            # In interactive mode show() does not block, so the window must
            # stay open; otherwise the figure is no longer needed.
            keep_open = matplotlib.is_interactive()
    finally:
        if not keep_open:
            # Release the figure so repeated calls do not pile up figures.
            plt.close(fig)
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-anatomy
3
+ Version: 0.1.0
4
+ Summary: Layer-by-layer visualizer for PyTorch models
5
+ Author: Harshal Vilas Kale
6
+ License: MIT
7
+ License-File: LICENSE
8
+ Requires-Dist: torch
9
+ Requires-Dist: torchvision
10
+ Requires-Dist: matplotlib
11
+ Requires-Dist: numpy
12
+ Requires-Dist: Pillow
13
+ Requires-Dist: click
14
+ Dynamic: author
15
+ Dynamic: license
16
+ Dynamic: license-file
17
+ Dynamic: requires-dist
18
+ Dynamic: summary
@@ -0,0 +1,15 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ tests/test_visualizer.py
6
+ torch_anatomy/__init__.py
7
+ torch_anatomy/cli.py
8
+ torch_anatomy/utils.py
9
+ torch_anatomy/visualizer.py
10
+ torch_anatomy.egg-info/PKG-INFO
11
+ torch_anatomy.egg-info/SOURCES.txt
12
+ torch_anatomy.egg-info/dependency_links.txt
13
+ torch_anatomy.egg-info/entry_points.txt
14
+ torch_anatomy.egg-info/requires.txt
15
+ torch_anatomy.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ torch-anatomy = torch_anatomy.cli:main
@@ -0,0 +1,6 @@
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ numpy
5
+ Pillow
6
+ click
@@ -0,0 +1 @@
1
+ torch_anatomy