visualtorch 0.2.3.tar.gz → 0.2.5.tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {visualtorch-0.2.3/visualtorch.egg-info → visualtorch-0.2.5}/PKG-INFO +32 -3
  2. {visualtorch-0.2.3 → visualtorch-0.2.5}/setup.py +1 -1
  3. visualtorch-0.2.5/tests/test_graph.py +98 -0
  4. visualtorch-0.2.5/tests/test_layered.py +165 -0
  5. visualtorch-0.2.5/tests/test_lenet_style.py +159 -0
  6. visualtorch-0.2.5/tests/test_regression_issues.py +73 -0
  7. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/layered.py +4 -1
  8. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/lenet_style.py +4 -1
  9. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/layer_utils.py +11 -3
  10. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/utils.py +29 -8
  11. {visualtorch-0.2.3 → visualtorch-0.2.5/visualtorch.egg-info}/PKG-INFO +32 -3
  12. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/SOURCES.txt +4 -0
  13. {visualtorch-0.2.3 → visualtorch-0.2.5}/LICENSE +0 -0
  14. {visualtorch-0.2.3 → visualtorch-0.2.5}/README.md +0 -0
  15. {visualtorch-0.2.3 → visualtorch-0.2.5}/pyproject.toml +0 -0
  16. {visualtorch-0.2.3 → visualtorch-0.2.5}/setup.cfg +0 -0
  17. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/__init__.py +0 -0
  18. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/graph.py +0 -0
  19. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/__init__.py +0 -0
  20. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/dependency_links.txt +0 -0
  21. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/requires.txt +0 -0
  22. {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: visualtorch
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
14
14
  Classifier: Operating System :: OS Independent
15
15
  Requires-Python: >=3.10
16
16
  Description-Content-Type: text/markdown
17
- Provides-Extra: dev
18
17
  License-File: LICENSE
18
+ Requires-Dist: pillow>=10.0.0
19
+ Requires-Dist: numpy>=1.18.1
20
+ Requires-Dist: aggdraw>=1.3.11
21
+ Requires-Dist: torch>=2.0.0
22
+ Provides-Extra: dev
23
+ Requires-Dist: myst-parser; extra == "dev"
24
+ Requires-Dist: nbsphinx; extra == "dev"
25
+ Requires-Dist: pandoc; extra == "dev"
26
+ Requires-Dist: sphinx<7.0; extra == "dev"
27
+ Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
28
+ Requires-Dist: sphinx_book_theme; extra == "dev"
29
+ Requires-Dist: sphinx-copybutton; extra == "dev"
30
+ Requires-Dist: sphinx_design; extra == "dev"
31
+ Requires-Dist: sphinx_gallery; extra == "dev"
32
+ Requires-Dist: matplotlib; extra == "dev"
33
+ Requires-Dist: pre-commit; extra == "dev"
34
+ Requires-Dist: pytest; extra == "dev"
35
+ Dynamic: author
36
+ Dynamic: author-email
37
+ Dynamic: classifier
38
+ Dynamic: description
39
+ Dynamic: description-content-type
40
+ Dynamic: home-page
41
+ Dynamic: keywords
42
+ Dynamic: license
43
+ Dynamic: license-file
44
+ Dynamic: provides-extra
45
+ Dynamic: requires-dist
46
+ Dynamic: requires-python
47
+ Dynamic: summary
19
48
 
20
49
  <div align="center">
21
50
  <h1>🔥 VisualTorch 🔥</h1>
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
21
21
 
22
22
  setuptools.setup(
23
23
  name="visualtorch",
24
- version="0.2.3",
24
+ version="0.2.5",
25
25
  author="Willy Fitra Hendria",
26
26
  author_email="willyfitrahendria@gmail.com",
27
27
  description="Architecture visualization of Torch models",
@@ -0,0 +1,98 @@
1
+ """Tests for graph view."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from pathlib import Path
7
+
8
+ import pytest
9
+ import torch
10
+ from PIL import Image
11
+ from torch import nn
12
+ from visualtorch import graph_view
13
+
14
+
15
+ @pytest.fixture()
16
+ def dense_model() -> nn.Module:
17
+ """A simple dense model creation."""
18
+
19
+ class SimpleDense(nn.Module):
20
+ """Simple Dense Model."""
21
+
22
+ def __init__(self) -> None:
23
+ super().__init__()
24
+ self.h0 = nn.Linear(4, 8)
25
+ self.h1 = nn.Linear(8, 8)
26
+ self.h2 = nn.Linear(8, 4)
27
+ self.out = nn.Linear(4, 2)
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ """Define the forward pass."""
31
+ x = self.h0(x)
32
+ x = self.h1(x)
33
+ x = self.h2(x)
34
+ return self.out(x)
35
+
36
+ return SimpleDense()
37
+
38
+
39
+ @pytest.fixture()
40
+ def conv_model() -> nn.Module:
41
+ """A simple conv model, exercising the Conv2d/ConvolutionBackward0 path."""
42
+ return nn.Sequential(
43
+ nn.Conv2d(3, 8, 3, 1, 1),
44
+ nn.Conv2d(8, 16, 3, 1, 1),
45
+ )
46
+
47
+
48
+ @pytest.fixture()
49
+ def wide_dense_model() -> nn.Module:
50
+ """A dense model with a hidden layer wider than the default ellipsize_after threshold."""
51
+
52
+ class WideDense(nn.Module):
53
+ """A dense model with more than 10 hidden units, to trigger ellipsis drawing."""
54
+
55
+ def __init__(self) -> None:
56
+ super().__init__()
57
+ self.h0 = nn.Linear(4, 20)
58
+ self.out = nn.Linear(20, 2)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """Forward pass."""
62
+ return self.out(self.h0(x))
63
+
64
+ return WideDense()
65
+
66
+
67
+ def test_dense_model_graph_view_runs(dense_model: nn.Module) -> None:
68
+ """Test graph view using dense model."""
69
+ _ = graph_view(dense_model, input_shape=(1, 4))
70
+
71
+
72
+ def test_conv_model_graph_view_runs(conv_model: nn.Module) -> None:
73
+ """graph_view should support Conv2d layers, not just Linear."""
74
+ img = graph_view(conv_model, input_shape=(1, 3, 16, 16))
75
+ assert img is not None
76
+
77
+
78
+ def test_graph_view_ellipsizes_wide_layers(wide_dense_model: nn.Module) -> None:
79
+ """A hidden layer wider than ellipsize_after should draw an ellipsis, not crash."""
80
+ img = graph_view(wide_dense_model, input_shape=(1, 4), ellipsize_after=10)
81
+ assert img is not None
82
+
83
+
84
+ def test_graph_view_show_neurons_false(wide_dense_model: nn.Module) -> None:
85
+ """show_neurons=False should render one node per layer instead of per neuron."""
86
+ img = graph_view(wide_dense_model, input_shape=(1, 4), show_neurons=False)
87
+ assert img is not None
88
+
89
+
90
+ def test_graph_view_writes_to_file(dense_model: nn.Module, tmp_path: Path) -> None:
91
+ """to_file should save a readable image to disk."""
92
+ out_file = tmp_path / "graph.png"
93
+ graph_view(dense_model, input_shape=(1, 4), to_file=str(out_file))
94
+
95
+ assert out_file.exists()
96
+ with Image.open(out_file) as saved_img:
97
+ assert saved_img.size[0] > 0
98
+ assert saved_img.size[1] > 0
@@ -0,0 +1,165 @@
1
+ """Tests for layered view."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from pathlib import Path
7
+
8
+ import pytest
9
+ import torch
10
+ import torch.nn.functional as func
11
+ from PIL import Image
12
+ from torch import nn
13
+ from visualtorch import layered_view
14
+
15
+
16
+ @pytest.fixture()
17
+ def sequential_model() -> nn.Sequential:
18
+ """Define Sequential torch model for testing."""
19
+ return nn.Sequential(
20
+ nn.Conv2d(3, 64, 3, 1, 1),
21
+ nn.ReLU(),
22
+ nn.Conv2d(64, 128, 3, 1, 1),
23
+ nn.ReLU(),
24
+ nn.MaxPool2d(2, 2),
25
+ )
26
+
27
+
28
+ @pytest.fixture()
29
+ def module_list_model() -> nn.ModuleList:
30
+ """Define ModuleList-based torch model for testing."""
31
+ return nn.ModuleList(
32
+ [
33
+ nn.Conv2d(3, 64, 3, 1, 1),
34
+ nn.ReLU(),
35
+ nn.Conv2d(64, 128, 3, 1, 1),
36
+ nn.ReLU(),
37
+ nn.MaxPool2d(2, 2),
38
+ ],
39
+ )
40
+
41
+
42
+ @pytest.fixture()
43
+ def custom_model() -> nn.Module:
44
+ """Define the custom model."""
45
+
46
+ class CustomModel(nn.Module):
47
+ """A simple custom cnn model."""
48
+
49
+ def __init__(self) -> None:
50
+ super().__init__()
51
+ self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
52
+ self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ """Funcorward pass."""
56
+ x = func.relu(self.conv1(x))
57
+ x = func.relu(self.conv2(x))
58
+ return func.max_pool2d(x, 2, 2)
59
+
60
+ # Create an instance of the custom model
61
+ return CustomModel()
62
+
63
+
64
+ @pytest.fixture()
65
+ def lstm_model() -> nn.Module:
66
+ """Define a simple LSTM model for testing."""
67
+
68
+ class LSTMModel(nn.Module):
69
+ """A simple LSTM model."""
70
+
71
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
72
+ super().__init__()
73
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ """Forward pass."""
77
+ out, _ = self.lstm(x)
78
+ return out
79
+
80
+ # Create an instance of the LSTM model
81
+ return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
82
+
83
+
84
+ @pytest.fixture()
85
+ def classifier_model() -> nn.Module:
86
+ """Define a model ending in a 1D (per-sample) output, e.g. classification logits."""
87
+
88
+ class ClassifierModel(nn.Module):
89
+ """A cnn model that ends with a 1D output."""
90
+
91
+ def __init__(self) -> None:
92
+ super().__init__()
93
+ self.conv = nn.Conv2d(3, 8, 3, 1, 1)
94
+ self.pool = nn.AdaptiveAvgPool2d(1)
95
+ self.fc = nn.Linear(8, 5)
96
+
97
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
98
+ """Forward pass."""
99
+ x = self.conv(x)
100
+ x = self.pool(x)
101
+ x = torch.flatten(x, 1)
102
+ return self.fc(x)
103
+
104
+ return ClassifierModel()
105
+
106
+
107
+ def test_sequential_model_layered_view_runs(sequential_model: nn.Sequential) -> None:
108
+ """Test layered view on sequential model."""
109
+ _ = layered_view(sequential_model, input_shape=(1, 3, 224, 224))
110
+
111
+
112
+ def test_module_list_model_layered_view_runs(module_list_model: nn.ModuleList) -> None:
113
+ """Test layered view on module list model."""
114
+ _ = layered_view(module_list_model, input_shape=(1, 3, 224, 224))
115
+
116
+
117
+ def test_custom_model_layered_view_runs(custom_model: nn.Module) -> None:
118
+ """Test layered view on custom model."""
119
+ _ = layered_view(custom_model, input_shape=(1, 3, 224, 224))
120
+
121
+
122
+ def test_lstm_model_layered_view_runs(lstm_model: nn.Module) -> None:
123
+ """Test layered view on lstm model."""
124
+ _ = layered_view(lstm_model, input_shape=(1, 10, 10))
125
+
126
+
127
+ @pytest.mark.parametrize("orientation", ["x", "y", "z"])
128
+ def test_layered_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
129
+ """Test layered view on a model with a 1D output, for every supported orientation."""
130
+ img = layered_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation=orientation)
131
+ assert img is not None
132
+
133
+
134
+ def test_layered_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
135
+ """An unsupported one_dim_orientation should raise a clear ValueError."""
136
+ with pytest.raises(ValueError, match="unsupported orientation"):
137
+ layered_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation="bad")
138
+
139
+
140
+ def test_layered_view_with_type_and_index_ignore(sequential_model: nn.Sequential) -> None:
141
+ """Layers matched by type_ignore or index_ignore should be skipped without error."""
142
+ img = layered_view(
143
+ sequential_model,
144
+ input_shape=(1, 3, 224, 224),
145
+ type_ignore=[nn.ReLU],
146
+ index_ignore=[0],
147
+ )
148
+ assert img is not None
149
+
150
+
151
+ def test_layered_view_with_legend(sequential_model: nn.Sequential) -> None:
152
+ """legend=True should append a legend without error."""
153
+ img = layered_view(sequential_model, input_shape=(1, 3, 224, 224), legend=True)
154
+ assert img is not None
155
+
156
+
157
+ def test_layered_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
158
+ """to_file should save a readable image to disk."""
159
+ out_file = tmp_path / "layered.png"
160
+ layered_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(out_file))
161
+
162
+ assert out_file.exists()
163
+ with Image.open(out_file) as saved_img:
164
+ assert saved_img.size[0] > 0
165
+ assert saved_img.size[1] > 0
@@ -0,0 +1,159 @@
1
+ """Tests for lenet view."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from pathlib import Path
7
+
8
+ import pytest
9
+ import torch
10
+ import torch.nn.functional as func
11
+ from PIL import Image
12
+ from torch import nn
13
+ from visualtorch import lenet_view
14
+
15
+
16
+ @pytest.fixture()
17
+ def sequential_model() -> nn.Sequential:
18
+ """Define Sequential torch model for testing."""
19
+ return nn.Sequential(
20
+ nn.Conv2d(3, 64, 3, 1, 1),
21
+ nn.ReLU(),
22
+ nn.Conv2d(64, 128, 3, 1, 1),
23
+ nn.ReLU(),
24
+ nn.MaxPool2d(2, 2),
25
+ )
26
+
27
+
28
+ @pytest.fixture()
29
+ def module_list_model() -> nn.ModuleList:
30
+ """Define ModuleList-based torch model for testing."""
31
+ return nn.ModuleList(
32
+ [
33
+ nn.Conv2d(3, 64, 3, 1, 1),
34
+ nn.ReLU(),
35
+ nn.Conv2d(64, 128, 3, 1, 1),
36
+ nn.ReLU(),
37
+ nn.MaxPool2d(2, 2),
38
+ ],
39
+ )
40
+
41
+
42
+ @pytest.fixture()
43
+ def custom_model() -> nn.Module:
44
+ """Define the custom model."""
45
+
46
+ class CustomModel(nn.Module):
47
+ """A simple custom cnn model."""
48
+
49
+ def __init__(self) -> None:
50
+ super().__init__()
51
+ self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
52
+ self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ """Funcorward pass."""
56
+ x = func.relu(self.conv1(x))
57
+ x = func.relu(self.conv2(x))
58
+ return func.max_pool2d(x, 2, 2)
59
+
60
+ # Create an instance of the custom model
61
+ return CustomModel()
62
+
63
+
64
+ @pytest.fixture()
65
+ def lstm_model() -> nn.Module:
66
+ """Define a simple LSTM model for testing."""
67
+
68
+ class LSTMModel(nn.Module):
69
+ """A simple LSTM model."""
70
+
71
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
72
+ super().__init__()
73
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ """Forward pass."""
77
+ out, _ = self.lstm(x)
78
+ return out
79
+
80
+ # Create an instance of the LSTM model
81
+ return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
82
+
83
+
84
+ @pytest.fixture()
85
+ def classifier_model() -> nn.Module:
86
+ """Define a model ending in a 1D (per-sample) output, e.g. classification logits."""
87
+
88
+ class ClassifierModel(nn.Module):
89
+ """A cnn model that ends with a 1D output."""
90
+
91
+ def __init__(self) -> None:
92
+ super().__init__()
93
+ self.conv = nn.Conv2d(3, 8, 3, 1, 1)
94
+ self.pool = nn.AdaptiveAvgPool2d(1)
95
+ self.fc = nn.Linear(8, 5)
96
+
97
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
98
+ """Forward pass."""
99
+ x = self.conv(x)
100
+ x = self.pool(x)
101
+ x = torch.flatten(x, 1)
102
+ return self.fc(x)
103
+
104
+ return ClassifierModel()
105
+
106
+
107
+ def test_sequential_model_lenet_view_runs(sequential_model: nn.Sequential) -> None:
108
+ """Test lenet view on sequential model."""
109
+ _ = lenet_view(sequential_model, input_shape=(1, 3, 224, 224))
110
+
111
+
112
+ def test_module_list_model_lenet_view_runs(module_list_model: nn.ModuleList) -> None:
113
+ """Test lenet view on module list model."""
114
+ _ = lenet_view(module_list_model, input_shape=(1, 3, 224, 224))
115
+
116
+
117
+ def test_custom_model_lenet_view_runs(custom_model: nn.Module) -> None:
118
+ """Test lenet view on custom model."""
119
+ _ = lenet_view(custom_model, input_shape=(1, 3, 224, 224))
120
+
121
+
122
+ def test_lstm_model_layered_view_runs(lstm_model: nn.Module) -> None:
123
+ """Test layered view on lstm model."""
124
+ _ = lenet_view(lstm_model, input_shape=(1, 10, 10))
125
+
126
+
127
+ @pytest.mark.parametrize("orientation", ["x", "y", "z"])
128
+ def test_lenet_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
129
+ """Test lenet view on a model with a 1D output, for every supported orientation."""
130
+ img = lenet_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation=orientation)
131
+ assert img is not None
132
+
133
+
134
+ def test_lenet_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
135
+ """An unsupported one_dim_orientation should raise a clear ValueError."""
136
+ with pytest.raises(ValueError, match="unsupported orientation"):
137
+ lenet_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation="bad")
138
+
139
+
140
+ def test_lenet_view_with_type_and_index_ignore(sequential_model: nn.Sequential) -> None:
141
+ """Layers matched by type_ignore or index_ignore should be skipped without error."""
142
+ img = lenet_view(
143
+ sequential_model,
144
+ input_shape=(1, 3, 224, 224),
145
+ type_ignore=[nn.ReLU],
146
+ index_ignore=[0],
147
+ )
148
+ assert img is not None
149
+
150
+
151
+ def test_lenet_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
152
+ """to_file should save a readable image to disk."""
153
+ out_file = tmp_path / "lenet.png"
154
+ lenet_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(out_file))
155
+
156
+ assert out_file.exists()
157
+ with Image.open(out_file) as saved_img:
158
+ assert saved_img.size[0] > 0
159
+ assert saved_img.size[1] > 0
@@ -0,0 +1,73 @@
1
+ """Regression tests for shape-handling crashes reported in open GitHub issues.
2
+
3
+ See https://github.com/willyfh/visualtorch/issues/63, /68, /69.
4
+ """
5
+
6
+ # Copyright (C) 2024 Willy Fitra Hendria
7
+ # SPDX-License-Identifier: MIT
8
+
9
+ import pytest
10
+ import torch
11
+ from torch import nn
12
+ from visualtorch import layered_view, lenet_view
13
+ from visualtorch.utils.utils import self_multiply
14
+
15
+
16
+ @pytest.fixture()
17
+ def multi_output_container_model() -> nn.Module:
18
+ """A model whose inner block is a container (not Sequential/ModuleList) that returns multiple tensors.
19
+
20
+ This mirrors timm's ``FeatureListNet`` used in issue #69: a container module that isn't
21
+ ``nn.Sequential``/``nn.ModuleList`` and returns multiple differently-shaped tensors.
22
+ """
23
+
24
+ class FeaturePyramidBlock(nn.Module):
25
+ def __init__(self) -> None:
26
+ super().__init__()
27
+ self.stage1 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
28
+ self.stage2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
29
+ self.stage3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
30
+
31
+ def forward(self, x: torch.Tensor) -> list:
32
+ f1 = self.stage1(x)
33
+ f2 = self.stage2(f1)
34
+ f3 = self.stage3(f2)
35
+ return [f1, f2, f3]
36
+
37
+ class MultiScaleNet(nn.Module):
38
+ def __init__(self) -> None:
39
+ super().__init__()
40
+ self.features = FeaturePyramidBlock()
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ return self.features(x)[-1]
44
+
45
+ return MultiScaleNet()
46
+
47
+
48
+ def test_layered_view_multi_output_container(multi_output_container_model: nn.Module) -> None:
49
+ """layered_view should not crash on container modules that output multiple tensors."""
50
+ img = layered_view(multi_output_container_model, input_shape=(1, 3, 32, 32))
51
+ assert img is not None
52
+
53
+
54
+ def test_lenet_view_multi_output_container(multi_output_container_model: nn.Module) -> None:
55
+ """lenet_view should not crash on container modules that output multiple tensors."""
56
+ img = lenet_view(multi_output_container_model, input_shape=(1, 3, 32, 32))
57
+ assert img is not None
58
+
59
+
60
+ def test_self_multiply_handles_nested_shape() -> None:
61
+ """self_multiply should always reduce to a scalar, even if an element is itself a shape."""
62
+ nested_shape = (1, torch.Size([4, 8]))
63
+ result = self_multiply(nested_shape)
64
+ assert isinstance(result, int)
65
+
66
+
67
+ def test_lenet_view_rejects_multi_input_shape_with_clear_error() -> None:
68
+ """Multi-tensor-input models aren't supported yet; the failure should be a clear ValueError."""
69
+ model = nn.Linear(10, 5)
70
+ multi_input_shape = ((1, 10), (1, 10))
71
+
72
+ with pytest.raises(ValueError, match="single"):
73
+ lenet_view(model, input_shape=multi_input_shape)
@@ -22,6 +22,7 @@ from .utils.utils import (
22
22
  get_rgba_tuple,
23
23
  linear_layout,
24
24
  self_multiply,
25
+ validate_input_shape,
25
26
  vertical_image_concat,
26
27
  )
27
28
 
@@ -85,6 +86,8 @@ def layered_view(
85
86
  """
86
87
  # Iterate over the model to compute bounds and generate boxes
87
88
 
89
+ validate_input_shape(input_shape)
90
+
88
91
  x_off = -1
89
92
 
90
93
  img_height = 0
@@ -291,7 +294,7 @@ def _create_architecture(
291
294
  layer = layers[key]["module"]
292
295
  shape = layers[key]["output_shape"]
293
296
  # Do no render the SpacingDummyLayer, just increase the pointer
294
- if type(layer) == SpacingDummyLayer:
297
+ if type(layer) is SpacingDummyLayer:
295
298
  current_z += layer.spacing
296
299
  continue
297
300
 
@@ -19,6 +19,7 @@ from .utils.utils import (
19
19
  StackedBox,
20
20
  get_rgba_tuple,
21
21
  self_multiply,
22
+ validate_input_shape,
22
23
  )
23
24
 
24
25
 
@@ -82,6 +83,8 @@ def lenet_view(
82
83
  """
83
84
  # Iterate over the model to compute bounds and generate boxes
84
85
 
86
+ validate_input_shape(input_shape)
87
+
85
88
  x_off = -1
86
89
 
87
90
  img_height = 0
@@ -247,7 +250,7 @@ def _create_architecture(
247
250
  layer = layers[key]["module"]
248
251
  shape = layers[key]["output_shape"]
249
252
  # Do no render the SpacingDummyLayer, just increase the pointer
250
- if type(layer) == SpacingDummyLayer:
253
+ if type(layer) is SpacingDummyLayer:
251
254
  current_x += layer.spacing
252
255
  continue
253
256
 
@@ -185,12 +185,20 @@ def register_hook(
185
185
  m_key = "%s-%i" % (class_name, module_idx + 1)
186
186
  layers[m_key] = OrderedDict()
187
187
  layers[m_key]["module"] = module
188
- if isinstance(out, list | tuple):
189
- layers[m_key]["output_shape"] = tuple((-1,) + o.size()[1:] for o in out)
188
+ if isinstance(out, tuple | list):
189
+ if len(out) > 0 and hasattr(out[0], "size"):
190
+ layers[m_key]["output_shape"] = out[0].size()
191
+ else:
192
+ layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
190
193
  else:
191
194
  layers[m_key]["output_shape"] = out.size()
192
195
 
193
- if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and module is not model:
196
+ # Only hook leaf modules (no children). Container modules - whether nn.Sequential,
197
+ # nn.ModuleList, or a custom container such as timm's FeatureListNet - would otherwise be
198
+ # captured as if they were a single layer, with their multi-tensor output mistaken for one
199
+ # layer's output shape.
200
+ is_leaf = len(list(module.children())) == 0
201
+ if is_leaf and module is not model:
194
202
  hooks.append(module.register_forward_hook(hook))
195
203
 
196
204
 
@@ -261,23 +261,44 @@ def get_keys_by_value(d: dict, v: int) -> Generator:
261
261
  yield key
262
262
 
263
263
 
264
- def self_multiply(tensor_tuple: tuple) -> int | float:
264
+ def validate_input_shape(input_shape: tuple) -> None:
265
+ """Validate that input_shape describes a single input tensor.
266
+
267
+ Args:
268
+ input_shape (tuple): The shape to validate.
269
+
270
+ Raises:
271
+ ValueError: If input_shape is not a single flat tuple of ints, e.g. when a tuple of
272
+ per-tensor shapes was passed for a model that takes multiple separate input tensors.
273
+ """
274
+ if not isinstance(input_shape, tuple) or not all(isinstance(dim, int) for dim in input_shape):
275
+ error_msg = (
276
+ "input_shape must be a single tuple of ints, e.g. (1, 3, 224, 224). "
277
+ f"Got {input_shape!r} instead. Visualizing models that take multiple separate "
278
+ "input tensors is not supported yet."
279
+ )
280
+ raise ValueError(error_msg)
281
+
282
+
283
+ def self_multiply(tensor_tuple: tuple | list) -> int | float:
265
284
  """Multiplies all elements in the tuple together.
266
285
 
286
+ Elements that are themselves a tuple/list (e.g. a nested torch.Size, which can end up here
287
+ when a layer's captured output shape wasn't a plain flat shape) are flattened by multiplying
288
+ their own elements together first, so the result is always a scalar.
289
+
267
290
  Args:
268
- tensor_tuple (tuple): A tuple containing tensors.
291
+ tensor_tuple (tuple or list): A tuple containing tensors.
269
292
 
270
293
  Returns:
271
294
  int or float: The result of multiplying all elements together.
272
295
  """
273
- tensor_list = list(tensor_tuple)
274
- if None in tensor_list:
275
- tensor_list.remove(None)
296
+ tensor_list = [v for v in tensor_tuple if v is not None]
276
297
  if len(tensor_list) == 0:
277
298
  return 0
278
- s = tensor_list[0]
279
- for i in range(1, len(tensor_list)):
280
- s *= tensor_list[i]
299
+ s = 1
300
+ for v in tensor_list:
301
+ s *= self_multiply(v) if isinstance(v, tuple | list) else v
281
302
  return s
282
303
 
283
304
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: visualtorch
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
14
14
  Classifier: Operating System :: OS Independent
15
15
  Requires-Python: >=3.10
16
16
  Description-Content-Type: text/markdown
17
- Provides-Extra: dev
18
17
  License-File: LICENSE
18
+ Requires-Dist: pillow>=10.0.0
19
+ Requires-Dist: numpy>=1.18.1
20
+ Requires-Dist: aggdraw>=1.3.11
21
+ Requires-Dist: torch>=2.0.0
22
+ Provides-Extra: dev
23
+ Requires-Dist: myst-parser; extra == "dev"
24
+ Requires-Dist: nbsphinx; extra == "dev"
25
+ Requires-Dist: pandoc; extra == "dev"
26
+ Requires-Dist: sphinx<7.0; extra == "dev"
27
+ Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
28
+ Requires-Dist: sphinx_book_theme; extra == "dev"
29
+ Requires-Dist: sphinx-copybutton; extra == "dev"
30
+ Requires-Dist: sphinx_design; extra == "dev"
31
+ Requires-Dist: sphinx_gallery; extra == "dev"
32
+ Requires-Dist: matplotlib; extra == "dev"
33
+ Requires-Dist: pre-commit; extra == "dev"
34
+ Requires-Dist: pytest; extra == "dev"
35
+ Dynamic: author
36
+ Dynamic: author-email
37
+ Dynamic: classifier
38
+ Dynamic: description
39
+ Dynamic: description-content-type
40
+ Dynamic: home-page
41
+ Dynamic: keywords
42
+ Dynamic: license
43
+ Dynamic: license-file
44
+ Dynamic: provides-extra
45
+ Dynamic: requires-dist
46
+ Dynamic: requires-python
47
+ Dynamic: summary
19
48
 
20
49
  <div align="center">
21
50
  <h1>🔥 VisualTorch 🔥</h1>
@@ -2,6 +2,10 @@ LICENSE
2
2
  README.md
3
3
  pyproject.toml
4
4
  setup.py
5
+ tests/test_graph.py
6
+ tests/test_layered.py
7
+ tests/test_lenet_style.py
8
+ tests/test_regression_issues.py
5
9
  visualtorch/__init__.py
6
10
  visualtorch/graph.py
7
11
  visualtorch/layered.py
File without changes
File without changes
File without changes
File without changes