visualtorch 0.2.3__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visualtorch-0.2.3/visualtorch.egg-info → visualtorch-0.2.5}/PKG-INFO +32 -3
- {visualtorch-0.2.3 → visualtorch-0.2.5}/setup.py +1 -1
- visualtorch-0.2.5/tests/test_graph.py +98 -0
- visualtorch-0.2.5/tests/test_layered.py +165 -0
- visualtorch-0.2.5/tests/test_lenet_style.py +159 -0
- visualtorch-0.2.5/tests/test_regression_issues.py +73 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/layered.py +4 -1
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/lenet_style.py +4 -1
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/layer_utils.py +11 -3
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/utils.py +29 -8
- {visualtorch-0.2.3 → visualtorch-0.2.5/visualtorch.egg-info}/PKG-INFO +32 -3
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/SOURCES.txt +4 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/LICENSE +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/README.md +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/pyproject.toml +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/setup.cfg +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/__init__.py +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/graph.py +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch/utils/__init__.py +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/requires.txt +0 -0
- {visualtorch-0.2.3 → visualtorch-0.2.5}/visualtorch.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
14
14
|
Classifier: Operating System :: OS Independent
|
|
15
15
|
Requires-Python: >=3.10
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
|
-
Provides-Extra: dev
|
|
18
17
|
License-File: LICENSE
|
|
18
|
+
Requires-Dist: pillow>=10.0.0
|
|
19
|
+
Requires-Dist: numpy>=1.18.1
|
|
20
|
+
Requires-Dist: aggdraw>=1.3.11
|
|
21
|
+
Requires-Dist: torch>=2.0.0
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: myst-parser; extra == "dev"
|
|
24
|
+
Requires-Dist: nbsphinx; extra == "dev"
|
|
25
|
+
Requires-Dist: pandoc; extra == "dev"
|
|
26
|
+
Requires-Dist: sphinx<7.0; extra == "dev"
|
|
27
|
+
Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
|
|
28
|
+
Requires-Dist: sphinx_book_theme; extra == "dev"
|
|
29
|
+
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
30
|
+
Requires-Dist: sphinx_design; extra == "dev"
|
|
31
|
+
Requires-Dist: sphinx_gallery; extra == "dev"
|
|
32
|
+
Requires-Dist: matplotlib; extra == "dev"
|
|
33
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest; extra == "dev"
|
|
35
|
+
Dynamic: author
|
|
36
|
+
Dynamic: author-email
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: keywords
|
|
42
|
+
Dynamic: license
|
|
43
|
+
Dynamic: license-file
|
|
44
|
+
Dynamic: provides-extra
|
|
45
|
+
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
47
|
+
Dynamic: summary
|
|
19
48
|
|
|
20
49
|
<div align="center">
|
|
21
50
|
<h1>🔥 VisualTorch 🔥</h1>
|
|
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
|
|
|
21
21
|
|
|
22
22
|
setuptools.setup(
|
|
23
23
|
name="visualtorch",
|
|
24
|
-
version="0.2.
|
|
24
|
+
version="0.2.5",
|
|
25
25
|
author="Willy Fitra Hendria",
|
|
26
26
|
author_email="willyfitrahendria@gmail.com",
|
|
27
27
|
description="Architecture visualization of Torch models",
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Tests for graph view."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
from PIL import Image
|
|
11
|
+
from torch import nn
|
|
12
|
+
from visualtorch import graph_view
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture()
def dense_model() -> nn.Module:
    """Build a small fully connected network (4 -> 8 -> 8 -> 4 -> 2)."""

    class SimpleDense(nn.Module):
        """Stack of three hidden Linear layers followed by a Linear output layer."""

        def __init__(self) -> None:
            super().__init__()
            # Attribute names are kept as-is: they become the registered submodule names.
            self.h0 = nn.Linear(4, 8)
            self.h1 = nn.Linear(8, 8)
            self.h2 = nn.Linear(8, 4)
            self.out = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Apply h0, h1 and h2 in order, then the output layer."""
            for hidden in (self.h0, self.h1, self.h2):
                x = hidden(x)
            return self.out(x)

    return SimpleDense()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.fixture()
def conv_model() -> nn.Module:
    """Build a two-layer conv stack, exercising the Conv2d/ConvolutionBackward0 path."""
    layers = [
        nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
        nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
    ]
    return nn.Sequential(*layers)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@pytest.fixture()
def wide_dense_model() -> nn.Module:
    """Build a dense model whose 20-unit hidden layer is wider than an ellipsize_after of 10."""

    class WideDense(nn.Module):
        """Dense model with 20 hidden units, enough to trigger ellipsis drawing."""

        def __init__(self) -> None:
            super().__init__()
            self.h0 = nn.Linear(4, 20)
            self.out = nn.Linear(20, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Project through the wide hidden layer, then the output layer."""
            hidden = self.h0(x)
            return self.out(hidden)

    return WideDense()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def test_dense_model_graph_view_runs(dense_model: nn.Module) -> None:
    """graph_view should render a dense model and return an image."""
    img = graph_view(dense_model, input_shape=(1, 4))
    # Consistent with the other smoke tests: verify something was actually returned.
    assert img is not None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_conv_model_graph_view_runs(conv_model: nn.Module) -> None:
    """Conv2d layers, and not only Linear ones, must be supported by graph_view."""
    assert graph_view(conv_model, input_shape=(1, 3, 16, 16)) is not None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_graph_view_ellipsizes_wide_layers(wide_dense_model: nn.Module) -> None:
    """Layers wider than ellipsize_after must be drawn with an ellipsis instead of crashing."""
    rendered = graph_view(
        wide_dense_model,
        input_shape=(1, 4),
        ellipsize_after=10,
    )
    assert rendered is not None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_graph_view_show_neurons_false(wide_dense_model: nn.Module) -> None:
    """With show_neurons=False, graph_view should draw one node per layer rather than per neuron."""
    rendered = graph_view(
        wide_dense_model,
        input_shape=(1, 4),
        show_neurons=False,
    )
    assert rendered is not None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def test_graph_view_writes_to_file(dense_model: nn.Module, tmp_path: Path) -> None:
    """Passing to_file must persist an image that PIL can open again."""
    target = tmp_path / "graph.png"
    graph_view(dense_model, input_shape=(1, 4), to_file=str(target))

    assert target.exists()
    with Image.open(target) as reloaded:
        width, height = reloaded.size
        assert width > 0
        assert height > 0
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Tests for layered view."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as func
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from torch import nn
|
|
13
|
+
from visualtorch import layered_view
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture()
def sequential_model() -> nn.Sequential:
    """Build an ``nn.Sequential`` conv/ReLU/conv/ReLU/max-pool stack for testing."""
    layers = [
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.fixture()
def module_list_model() -> nn.ModuleList:
    """Build the same conv/ReLU/max-pool stack, but held in an ``nn.ModuleList``."""
    layers = [
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.ModuleList(layers)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.fixture()
def custom_model() -> nn.Module:
    """Define a custom (non-Sequential) CNN that applies ReLU and pooling functionally."""

    class CustomModel(nn.Module):
        """A simple custom cnn model."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            # ReLU and max-pooling go through torch.nn.functional, so they are not
            # registered submodules of this model.
            x = func.relu(self.conv1(x))
            x = func.relu(self.conv2(x))
            return func.max_pool2d(x, 2, 2)

    # Create an instance of the custom model
    return CustomModel()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture()
def lstm_model() -> nn.Module:
    """Build a two-layer, batch-first LSTM (input size 10, hidden size 20) for testing."""

    class LSTMModel(nn.Module):
        """Thin wrapper around ``nn.LSTM`` that returns only the per-step outputs."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the LSTM and drop the final hidden/cell state tuple."""
            return self.lstm(x)[0]

    return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@pytest.fixture()
def classifier_model() -> nn.Module:
    """Build a CNN whose per-sample output is 1D (5 classification logits)."""

    class ClassifierModel(nn.Module):
        """Conv -> global average pool -> flatten -> linear head."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, 1, 1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Pool the conv features to 1x1, flatten them, and classify."""
            pooled = self.pool(self.conv(x))
            return self.fc(torch.flatten(pooled, 1))

    return ClassifierModel()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def test_sequential_model_layered_view_runs(sequential_model: nn.Sequential) -> None:
    """layered_view should render an nn.Sequential model and return an image."""
    img = layered_view(sequential_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_module_list_model_layered_view_runs(module_list_model: nn.ModuleList) -> None:
    """layered_view should render an nn.ModuleList model and return an image."""
    img = layered_view(module_list_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_custom_model_layered_view_runs(custom_model: nn.Module) -> None:
    """layered_view should render a custom nn.Module subclass and return an image."""
    img = layered_view(custom_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_lstm_model_layered_view_runs(lstm_model: nn.Module) -> None:
    """layered_view should render an LSTM model and return an image."""
    img = layered_view(lstm_model, input_shape=(1, 10, 10))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@pytest.mark.parametrize("orientation", ["x", "y", "z"])
def test_layered_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
    """Render a model with a 1D output once for each supported orientation."""
    rendered = layered_view(
        classifier_model,
        input_shape=(1, 3, 16, 16),
        one_dim_orientation=orientation,
    )
    assert rendered is not None
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_layered_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
    """An unknown one_dim_orientation value must be rejected with a descriptive ValueError."""
    invalid_orientation = "bad"
    with pytest.raises(ValueError, match="unsupported orientation"):
        layered_view(
            classifier_model,
            input_shape=(1, 3, 16, 16),
            one_dim_orientation=invalid_orientation,
        )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_layered_view_with_type_and_index_ignore(sequential_model: nn.Sequential) -> None:
    """Layers excluded through type_ignore or index_ignore must be skipped without raising."""
    ignore_options = {"type_ignore": [nn.ReLU], "index_ignore": [0]}
    rendered = layered_view(sequential_model, input_shape=(1, 3, 224, 224), **ignore_options)
    assert rendered is not None
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_layered_view_with_legend(sequential_model: nn.Sequential) -> None:
    """Requesting a legend with legend=True must not break rendering."""
    rendered = layered_view(
        sequential_model,
        input_shape=(1, 3, 224, 224),
        legend=True,
    )
    assert rendered is not None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_layered_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
    """Passing to_file must persist an image that PIL can open again."""
    target = tmp_path / "layered.png"
    layered_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(target))

    assert target.exists()
    with Image.open(target) as reloaded:
        width, height = reloaded.size
        assert width > 0
        assert height > 0
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""Tests for lenet view."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as func
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from torch import nn
|
|
13
|
+
from visualtorch import lenet_view
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture()
def sequential_model() -> nn.Sequential:
    """Build an ``nn.Sequential`` conv/ReLU/conv/ReLU/max-pool stack for testing."""
    layers = [
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.fixture()
def module_list_model() -> nn.ModuleList:
    """Build the same conv/ReLU/max-pool stack, but held in an ``nn.ModuleList``."""
    layers = [
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.ModuleList(layers)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.fixture()
def custom_model() -> nn.Module:
    """Define a custom (non-Sequential) CNN that applies ReLU and pooling functionally."""

    class CustomModel(nn.Module):
        """A simple custom cnn model."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            # ReLU and max-pooling go through torch.nn.functional, so they are not
            # registered submodules of this model.
            x = func.relu(self.conv1(x))
            x = func.relu(self.conv2(x))
            return func.max_pool2d(x, 2, 2)

    # Create an instance of the custom model
    return CustomModel()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture()
def lstm_model() -> nn.Module:
    """Build a two-layer, batch-first LSTM (input size 10, hidden size 20) for testing."""

    class LSTMModel(nn.Module):
        """Thin wrapper around ``nn.LSTM`` that returns only the per-step outputs."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the LSTM and drop the final hidden/cell state tuple."""
            return self.lstm(x)[0]

    return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@pytest.fixture()
def classifier_model() -> nn.Module:
    """Build a CNN whose per-sample output is 1D (5 classification logits)."""

    class ClassifierModel(nn.Module):
        """Conv -> global average pool -> flatten -> linear head."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, 1, 1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Pool the conv features to 1x1, flatten them, and classify."""
            pooled = self.pool(self.conv(x))
            return self.fc(torch.flatten(pooled, 1))

    return ClassifierModel()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def test_sequential_model_lenet_view_runs(sequential_model: nn.Sequential) -> None:
    """lenet_view should render an nn.Sequential model and return an image."""
    img = lenet_view(sequential_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_module_list_model_lenet_view_runs(module_list_model: nn.ModuleList) -> None:
    """lenet_view should render an nn.ModuleList model and return an image."""
    img = lenet_view(module_list_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_custom_model_lenet_view_runs(custom_model: nn.Module) -> None:
    """lenet_view should render a custom nn.Module subclass and return an image."""
    img = lenet_view(custom_model, input_shape=(1, 3, 224, 224))
    # Consistent with the other tests: verify something was actually returned.
    assert img is not None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_lstm_model_lenet_view_runs(lstm_model: nn.Module) -> None:
    """lenet_view should render an LSTM model and return an image."""
    # Renamed from test_lstm_model_layered_view_runs: this module tests lenet_view, not layered_view.
    img = lenet_view(lstm_model, input_shape=(1, 10, 10))
    assert img is not None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@pytest.mark.parametrize("orientation", ["x", "y", "z"])
def test_lenet_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
    """Render a model with a 1D output once for each supported orientation."""
    rendered = lenet_view(
        classifier_model,
        input_shape=(1, 3, 16, 16),
        one_dim_orientation=orientation,
    )
    assert rendered is not None
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_lenet_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
    """An unknown one_dim_orientation value must be rejected with a descriptive ValueError."""
    invalid_orientation = "bad"
    with pytest.raises(ValueError, match="unsupported orientation"):
        lenet_view(
            classifier_model,
            input_shape=(1, 3, 16, 16),
            one_dim_orientation=invalid_orientation,
        )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_lenet_view_with_type_and_index_ignore(sequential_model: nn.Sequential) -> None:
    """Layers excluded through type_ignore or index_ignore must be skipped without raising."""
    ignore_options = {"type_ignore": [nn.ReLU], "index_ignore": [0]}
    rendered = lenet_view(sequential_model, input_shape=(1, 3, 224, 224), **ignore_options)
    assert rendered is not None
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_lenet_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
    """Passing to_file must persist an image that PIL can open again."""
    target = tmp_path / "lenet.png"
    lenet_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(target))

    assert target.exists()
    with Image.open(target) as reloaded:
        width, height = reloaded.size
        assert width > 0
        assert height > 0
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Regression tests for shape-handling crashes reported in open GitHub issues.
|
|
2
|
+
|
|
3
|
+
See https://github.com/willyfh/visualtorch/issues/63, /68, /69.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
7
|
+
# SPDX-License-Identifier: MIT
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
import torch
|
|
11
|
+
from torch import nn
|
|
12
|
+
from visualtorch import layered_view, lenet_view
|
|
13
|
+
from visualtorch.utils.utils import self_multiply
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture()
def multi_output_container_model() -> nn.Module:
    """Build a model whose inner block is a non-Sequential container returning several tensors.

    It mimics timm's ``FeatureListNet`` from issue #69. That is a container module that is
    neither ``nn.Sequential`` nor ``nn.ModuleList`` and that returns several
    differently-shaped feature maps.
    """

    class FeaturePyramidBlock(nn.Module):
        """Three stride-2 conv stages; every intermediate feature map is returned."""

        def __init__(self) -> None:
            super().__init__()
            self.stage1 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
            self.stage2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
            self.stage3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

        def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
            """Return the outputs of stage1, stage2 and stage3, in that order."""
            features = []
            for stage in (self.stage1, self.stage2, self.stage3):
                x = stage(x)
                features.append(x)
            return features

    class MultiScaleNet(nn.Module):
        """Wraps the pyramid block and keeps only its deepest feature map."""

        def __init__(self) -> None:
            super().__init__()
            self.features = FeaturePyramidBlock()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return the last feature map produced by the pyramid block."""
            *_, deepest = self.features(x)
            return deepest

    return MultiScaleNet()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_layered_view_multi_output_container(multi_output_container_model: nn.Module) -> None:
    """A container module emitting multiple tensors must not crash layered_view."""
    assert layered_view(multi_output_container_model, input_shape=(1, 3, 32, 32)) is not None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_lenet_view_multi_output_container(multi_output_container_model: nn.Module) -> None:
    """A container module emitting multiple tensors must not crash lenet_view."""
    assert lenet_view(multi_output_container_model, input_shape=(1, 3, 32, 32)) is not None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_self_multiply_handles_nested_shape() -> None:
    """self_multiply should always reduce to a scalar, even if an element is itself a shape."""
    nested_shape = (1, torch.Size([4, 8]))
    result = self_multiply(nested_shape)
    assert isinstance(result, int)
    # The nested torch.Size contributes the product of its own dims: 1 * (4 * 8).
    assert result == 32
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def test_lenet_view_rejects_multi_input_shape_with_clear_error() -> None:
    """Passing a tuple of per-tensor shapes (a multi-input model) must fail with a clear ValueError."""
    model = nn.Linear(10, 5)
    two_input_shapes = ((1, 10), (1, 10))
    with pytest.raises(ValueError, match="single"):
        lenet_view(model, input_shape=two_input_shapes)
|
|
@@ -22,6 +22,7 @@ from .utils.utils import (
|
|
|
22
22
|
get_rgba_tuple,
|
|
23
23
|
linear_layout,
|
|
24
24
|
self_multiply,
|
|
25
|
+
validate_input_shape,
|
|
25
26
|
vertical_image_concat,
|
|
26
27
|
)
|
|
27
28
|
|
|
@@ -85,6 +86,8 @@ def layered_view(
|
|
|
85
86
|
"""
|
|
86
87
|
# Iterate over the model to compute bounds and generate boxes
|
|
87
88
|
|
|
89
|
+
validate_input_shape(input_shape)
|
|
90
|
+
|
|
88
91
|
x_off = -1
|
|
89
92
|
|
|
90
93
|
img_height = 0
|
|
@@ -291,7 +294,7 @@ def _create_architecture(
|
|
|
291
294
|
layer = layers[key]["module"]
|
|
292
295
|
shape = layers[key]["output_shape"]
|
|
293
296
|
# Do no render the SpacingDummyLayer, just increase the pointer
|
|
294
|
-
if type(layer)
|
|
297
|
+
if type(layer) is SpacingDummyLayer:
|
|
295
298
|
current_z += layer.spacing
|
|
296
299
|
continue
|
|
297
300
|
|
|
@@ -19,6 +19,7 @@ from .utils.utils import (
|
|
|
19
19
|
StackedBox,
|
|
20
20
|
get_rgba_tuple,
|
|
21
21
|
self_multiply,
|
|
22
|
+
validate_input_shape,
|
|
22
23
|
)
|
|
23
24
|
|
|
24
25
|
|
|
@@ -82,6 +83,8 @@ def lenet_view(
|
|
|
82
83
|
"""
|
|
83
84
|
# Iterate over the model to compute bounds and generate boxes
|
|
84
85
|
|
|
86
|
+
validate_input_shape(input_shape)
|
|
87
|
+
|
|
85
88
|
x_off = -1
|
|
86
89
|
|
|
87
90
|
img_height = 0
|
|
@@ -247,7 +250,7 @@ def _create_architecture(
|
|
|
247
250
|
layer = layers[key]["module"]
|
|
248
251
|
shape = layers[key]["output_shape"]
|
|
249
252
|
# Do no render the SpacingDummyLayer, just increase the pointer
|
|
250
|
-
if type(layer)
|
|
253
|
+
if type(layer) is SpacingDummyLayer:
|
|
251
254
|
current_x += layer.spacing
|
|
252
255
|
continue
|
|
253
256
|
|
|
@@ -185,12 +185,20 @@ def register_hook(
|
|
|
185
185
|
m_key = "%s-%i" % (class_name, module_idx + 1)
|
|
186
186
|
layers[m_key] = OrderedDict()
|
|
187
187
|
layers[m_key]["module"] = module
|
|
188
|
-
if isinstance(out,
|
|
189
|
-
|
|
188
|
+
if isinstance(out, tuple | list):
|
|
189
|
+
if len(out) > 0 and hasattr(out[0], "size"):
|
|
190
|
+
layers[m_key]["output_shape"] = out[0].size()
|
|
191
|
+
else:
|
|
192
|
+
layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
|
|
190
193
|
else:
|
|
191
194
|
layers[m_key]["output_shape"] = out.size()
|
|
192
195
|
|
|
193
|
-
|
|
196
|
+
# Only hook leaf modules (no children). Container modules - whether nn.Sequential,
|
|
197
|
+
# nn.ModuleList, or a custom container such as timm's FeatureListNet - would otherwise be
|
|
198
|
+
# captured as if they were a single layer, with their multi-tensor output mistaken for one
|
|
199
|
+
# layer's output shape.
|
|
200
|
+
is_leaf = len(list(module.children())) == 0
|
|
201
|
+
if is_leaf and module is not model:
|
|
194
202
|
hooks.append(module.register_forward_hook(hook))
|
|
195
203
|
|
|
196
204
|
|
|
@@ -261,23 +261,44 @@ def get_keys_by_value(d: dict, v: int) -> Generator:
|
|
|
261
261
|
yield key
|
|
262
262
|
|
|
263
263
|
|
|
264
|
-
def
|
|
264
|
+
def validate_input_shape(input_shape: tuple) -> None:
    """Validate that input_shape describes a single input tensor.

    Every dimension must be an integral number. Besides the built-in ``int``, other integral
    scalar types are accepted too, such as NumPy integer scalars that appear when a shape is
    computed with NumPy. Previously these were rejected even though they are valid sizes.

    Args:
        input_shape (tuple): The shape to validate, e.g. ``(1, 3, 224, 224)``. A
            ``torch.Size`` is accepted because it is a ``tuple`` subclass.

    Raises:
        ValueError: If input_shape is not a single flat tuple of integers. One example is a
            tuple of per-tensor shapes passed for a model that takes multiple separate
            input tensors.
    """
    import numbers  # local import keeps this helper self-contained

    is_flat_int_tuple = isinstance(input_shape, tuple) and all(
        isinstance(dim, numbers.Integral) for dim in input_shape
    )
    if not is_flat_int_tuple:
        error_msg = (
            "input_shape must be a single tuple of ints, e.g. (1, 3, 224, 224). "
            f"Got {input_shape!r} instead. Visualizing models that take multiple separate "
            "input tensors is not supported yet."
        )
        raise ValueError(error_msg)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def self_multiply(tensor_tuple: tuple | list) -> int | float:
|
|
265
284
|
"""Multiplies all elements in the tuple together.
|
|
266
285
|
|
|
286
|
+
Elements that are themselves a tuple/list (e.g. a nested torch.Size, which can end up here
|
|
287
|
+
when a layer's captured output shape wasn't a plain flat shape) are flattened by multiplying
|
|
288
|
+
their own elements together first, so the result is always a scalar.
|
|
289
|
+
|
|
267
290
|
Args:
|
|
268
|
-
tensor_tuple (tuple): A tuple containing tensors.
|
|
291
|
+
tensor_tuple (tuple or list): A tuple containing tensors.
|
|
269
292
|
|
|
270
293
|
Returns:
|
|
271
294
|
int or float: The result of multiplying all elements together.
|
|
272
295
|
"""
|
|
273
|
-
tensor_list =
|
|
274
|
-
if None in tensor_list:
|
|
275
|
-
tensor_list.remove(None)
|
|
296
|
+
tensor_list = [v for v in tensor_tuple if v is not None]
|
|
276
297
|
if len(tensor_list) == 0:
|
|
277
298
|
return 0
|
|
278
|
-
s =
|
|
279
|
-
for
|
|
280
|
-
s *=
|
|
299
|
+
s = 1
|
|
300
|
+
for v in tensor_list:
|
|
301
|
+
s *= self_multiply(v) if isinstance(v, tuple | list) else v
|
|
281
302
|
return s
|
|
282
303
|
|
|
283
304
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
14
14
|
Classifier: Operating System :: OS Independent
|
|
15
15
|
Requires-Python: >=3.10
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
|
-
Provides-Extra: dev
|
|
18
17
|
License-File: LICENSE
|
|
18
|
+
Requires-Dist: pillow>=10.0.0
|
|
19
|
+
Requires-Dist: numpy>=1.18.1
|
|
20
|
+
Requires-Dist: aggdraw>=1.3.11
|
|
21
|
+
Requires-Dist: torch>=2.0.0
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: myst-parser; extra == "dev"
|
|
24
|
+
Requires-Dist: nbsphinx; extra == "dev"
|
|
25
|
+
Requires-Dist: pandoc; extra == "dev"
|
|
26
|
+
Requires-Dist: sphinx<7.0; extra == "dev"
|
|
27
|
+
Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
|
|
28
|
+
Requires-Dist: sphinx_book_theme; extra == "dev"
|
|
29
|
+
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
30
|
+
Requires-Dist: sphinx_design; extra == "dev"
|
|
31
|
+
Requires-Dist: sphinx_gallery; extra == "dev"
|
|
32
|
+
Requires-Dist: matplotlib; extra == "dev"
|
|
33
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest; extra == "dev"
|
|
35
|
+
Dynamic: author
|
|
36
|
+
Dynamic: author-email
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: keywords
|
|
42
|
+
Dynamic: license
|
|
43
|
+
Dynamic: license-file
|
|
44
|
+
Dynamic: provides-extra
|
|
45
|
+
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
47
|
+
Dynamic: summary
|
|
19
48
|
|
|
20
49
|
<div align="center">
|
|
21
50
|
<h1>🔥 VisualTorch 🔥</h1>
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|