visualtorch 0.2.4__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visualtorch-0.2.4 → visualtorch-1.0.0}/PKG-INFO +60 -7
- visualtorch-0.2.4/visualtorch.egg-info/PKG-INFO → visualtorch-1.0.0/README.md +28 -23
- {visualtorch-0.2.4 → visualtorch-1.0.0}/pyproject.toml +4 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/setup.py +1 -1
- visualtorch-1.0.0/tests/test_flow.py +473 -0
- visualtorch-1.0.0/tests/test_graph.py +435 -0
- visualtorch-1.0.0/tests/test_lenet_style.py +330 -0
- visualtorch-1.0.0/tests/test_regression_issues.py +114 -0
- visualtorch-1.0.0/tests/test_render.py +150 -0
- visualtorch-1.0.0/visualtorch/__init__.py +22 -0
- visualtorch-1.0.0/visualtorch/_volumetric_layout.py +142 -0
- visualtorch-1.0.0/visualtorch/backend.py +169 -0
- visualtorch-1.0.0/visualtorch/connectors.py +90 -0
- visualtorch-1.0.0/visualtorch/flow.py +538 -0
- visualtorch-1.0.0/visualtorch/graph.py +377 -0
- visualtorch-1.0.0/visualtorch/lenet_style.py +385 -0
- visualtorch-1.0.0/visualtorch/render.py +236 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch/utils/__init__.py +1 -15
- visualtorch-1.0.0/visualtorch/utils/layer_utils.py +18 -0
- visualtorch-1.0.0/visualtorch/utils/recorder.py +247 -0
- visualtorch-1.0.0/visualtorch/utils/traced_layer.py +22 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch/utils/utils.py +69 -19
- visualtorch-1.0.0/visualtorch.egg-info/PKG-INFO +118 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/SOURCES.txt +12 -1
- visualtorch-0.2.4/README.md +0 -46
- visualtorch-0.2.4/visualtorch/__init__.py +0 -10
- visualtorch-0.2.4/visualtorch/graph.py +0 -246
- visualtorch-0.2.4/visualtorch/layered.py +0 -367
- visualtorch-0.2.4/visualtorch/lenet_style.py +0 -326
- visualtorch-0.2.4/visualtorch/utils/layer_utils.py +0 -246
- {visualtorch-0.2.4 → visualtorch-1.0.0}/LICENSE +0 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/setup.cfg +0 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/requires.txt +0 -0
- {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 0.2.4
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
14
14
|
Classifier: Operating System :: OS Independent
|
|
15
15
|
Requires-Python: >=3.10
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
|
-
Provides-Extra: dev
|
|
18
17
|
License-File: LICENSE
|
|
18
|
+
Requires-Dist: pillow>=10.0.0
|
|
19
|
+
Requires-Dist: numpy>=1.18.1
|
|
20
|
+
Requires-Dist: aggdraw>=1.3.11
|
|
21
|
+
Requires-Dist: torch>=2.0.0
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: myst-parser; extra == "dev"
|
|
24
|
+
Requires-Dist: nbsphinx; extra == "dev"
|
|
25
|
+
Requires-Dist: pandoc; extra == "dev"
|
|
26
|
+
Requires-Dist: sphinx<7.0; extra == "dev"
|
|
27
|
+
Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
|
|
28
|
+
Requires-Dist: sphinx_book_theme; extra == "dev"
|
|
29
|
+
Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
30
|
+
Requires-Dist: sphinx_design; extra == "dev"
|
|
31
|
+
Requires-Dist: sphinx_gallery; extra == "dev"
|
|
32
|
+
Requires-Dist: matplotlib; extra == "dev"
|
|
33
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest; extra == "dev"
|
|
35
|
+
Dynamic: author
|
|
36
|
+
Dynamic: author-email
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: keywords
|
|
42
|
+
Dynamic: license
|
|
43
|
+
Dynamic: license-file
|
|
44
|
+
Dynamic: provides-extra
|
|
45
|
+
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
47
|
+
Dynamic: summary
|
|
19
48
|
|
|
20
49
|
<div align="center">
|
|
21
50
|
<h1>🔥 VisualTorch 🔥</h1>
|
|
@@ -24,13 +53,19 @@ License-File: LICENSE
|
|
|
24
53
|
|
|
25
54
|
</div>
|
|
26
55
|
|
|
27
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating
|
|
56
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
28
57
|
|
|
29
|
-
**Note:** VisualTorch
|
|
58
|
+
**Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
|
|
59
|
+
limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
|
|
60
|
+
execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
|
|
61
|
+
value crosses some threshold) only show whichever branch the traced dummy input happened to take;
|
|
62
|
+
(2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
|
|
63
|
+
only has its first tensor's shape reflected in that node's size/label; its downstream connections
|
|
64
|
+
are still correct either way. Contributions are welcome!
|
|
30
65
|
|
|
31
66
|
<div align="center">
|
|
32
67
|
|
|
33
|
-

|
|
34
69
|
|
|
35
70
|
</div>
|
|
36
71
|
|
|
@@ -62,4 +97,22 @@ Originally, this project was based on the [visualkeras](https://github.com/paulg
|
|
|
62
97
|
|
|
63
98
|
Please cite this project in your publications if it helps your research.
|
|
64
99
|
|
|
65
|
-
|
|
100
|
+
**Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
|
|
101
|
+
since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
|
|
102
|
+
for the current API) - the DOI always resolves to what was actually reviewed and published, so
|
|
103
|
+
the paper isn't updated to match.
|
|
104
|
+
|
|
105
|
+
```bibtex
|
|
106
|
+
@article{Hendria2024,
|
|
107
|
+
doi = {10.21105/joss.06678},
|
|
108
|
+
url = {https://doi.org/10.21105/joss.06678},
|
|
109
|
+
year = {2024},
|
|
110
|
+
publisher = {The Open Journal},
|
|
111
|
+
volume = {9},
|
|
112
|
+
number = {102},
|
|
113
|
+
pages = {6678},
|
|
114
|
+
author = {Willy Fitra Hendria and Paul Gavrikov},
|
|
115
|
+
title = {VisualTorch: Streamlining Visualization for PyTorch Neural Network Architectures},
|
|
116
|
+
journal = {Journal of Open Source Software}
|
|
117
|
+
}
|
|
118
|
+
```
|
|
@@ -1,22 +1,3 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: visualtorch
|
|
3
|
-
Version: 0.2.4
|
|
4
|
-
Summary: Architecture visualization of Torch models
|
|
5
|
-
Home-page: https://github.com/willyfh/visualtorch
|
|
6
|
-
Author: Willy Fitra Hendria
|
|
7
|
-
Author-email: willyfitrahendria@gmail.com
|
|
8
|
-
License: MIT
|
|
9
|
-
Keywords: visualize architecture,torch visualization,visualtorch
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
-
Classifier: Operating System :: OS Independent
|
|
15
|
-
Requires-Python: >=3.10
|
|
16
|
-
Description-Content-Type: text/markdown
|
|
17
|
-
Provides-Extra: dev
|
|
18
|
-
License-File: LICENSE
|
|
19
|
-
|
|
20
1
|
<div align="center">
|
|
21
2
|
<h1>🔥 VisualTorch 🔥</h1>
|
|
22
3
|
|
|
@@ -24,13 +5,19 @@ License-File: LICENSE
|
|
|
24
5
|
|
|
25
6
|
</div>
|
|
26
7
|
|
|
27
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating
|
|
8
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
28
9
|
|
|
29
|
-
**Note:** VisualTorch
|
|
10
|
+
**Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
|
|
11
|
+
limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
|
|
12
|
+
execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
|
|
13
|
+
value crosses some threshold) only show whichever branch the traced dummy input happened to take;
|
|
14
|
+
(2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
|
|
15
|
+
only has its first tensor's shape reflected in that node's size/label; its downstream connections
|
|
16
|
+
are still correct either way. Contributions are welcome!
|
|
30
17
|
|
|
31
18
|
<div align="center">
|
|
32
19
|
|
|
33
|
-

|
|
34
21
|
|
|
35
22
|
</div>
|
|
36
23
|
|
|
@@ -62,4 +49,22 @@ Originally, this project was based on the [visualkeras](https://github.com/paulg
|
|
|
62
49
|
|
|
63
50
|
Please cite this project in your publications if it helps your research.
|
|
64
51
|
|
|
65
|
-
|
|
52
|
+
**Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
|
|
53
|
+
since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
|
|
54
|
+
for the current API) - the DOI always resolves to what was actually reviewed and published, so
|
|
55
|
+
the paper isn't updated to match.
|
|
56
|
+
|
|
57
|
+
```bibtex
|
|
58
|
+
@article{Hendria2024,
|
|
59
|
+
doi = {10.21105/joss.06678},
|
|
60
|
+
url = {https://doi.org/10.21105/joss.06678},
|
|
61
|
+
year = {2024},
|
|
62
|
+
publisher = {The Open Journal},
|
|
63
|
+
volume = {9},
|
|
64
|
+
number = {102},
|
|
65
|
+
pages = {6678},
|
|
66
|
+
author = {Willy Fitra Hendria and Paul Gavrikov},
|
|
67
|
+
title = {VisualTorch: Streamlining Visualization for PyTorch Neural Network Architectures},
|
|
68
|
+
journal = {Journal of Open Source Software}
|
|
69
|
+
}
|
|
70
|
+
```
|
|
@@ -136,6 +136,10 @@ max-complexity = 15
|
|
|
136
136
|
|
|
137
137
|
[tool.ruff.per-file-ignores]
|
|
138
138
|
"tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
|
|
139
|
+
# RecorderTensor overrides torch.Tensor.__torch_function__, whose own signature is untyped
|
|
140
|
+
# (Any throughout); the recursive tensor-structure helpers here genuinely accept arbitrary
|
|
141
|
+
# nested types (tensors, tuples, dicts, non-tensor values) for the same reason.
|
|
142
|
+
"visualtorch/utils/recorder.py" = ["ANN401", "ANN102"]
|
|
139
143
|
|
|
140
144
|
[tool.ruff.pydocstyle]
|
|
141
145
|
convention = "google"
|
|
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
|
|
|
21
21
|
|
|
22
22
|
setuptools.setup(
|
|
23
23
|
name="visualtorch",
|
|
24
|
-
version="0.
|
|
24
|
+
version="1.0.0",
|
|
25
25
|
author="Willy Fitra Hendria",
|
|
26
26
|
author_email="willyfitrahendria@gmail.com",
|
|
27
27
|
description="Architecture visualization of Torch models",
|
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
"""Tests for flow view."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytest
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn.functional as func
|
|
12
|
+
from PIL import Image
|
|
13
|
+
from torch import nn
|
|
14
|
+
from visualtorch.flow import flow_view
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.fixture()
def sequential_model() -> nn.Sequential:
    """Build a small conv -> relu -> conv -> relu -> max-pool ``nn.Sequential`` for the tests."""
    layers = [
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(64, 128, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]
    return nn.Sequential(*layers)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.fixture()
def module_list_model() -> nn.ModuleList:
    """Hold the same five conv/relu/pool layers as ``sequential_model`` in an ``nn.ModuleList``."""
    conv_stack = [
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(64, 128, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]
    return nn.ModuleList(conv_stack)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.fixture()
def custom_model() -> nn.Module:
    """Return a small custom (non-Sequential) CNN.

    Only the two convolutions are registered submodules; ReLU and max-pooling are applied
    through ``torch.nn.functional`` calls inside ``forward``.
    """

    class CustomModel(nn.Module):
        """A simple custom cnn model."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            x = func.relu(self.conv1(x))
            x = func.relu(self.conv2(x))
            return func.max_pool2d(x, 2, 2)

    # Create an instance of the custom model
    return CustomModel()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@pytest.fixture()
def lstm_model() -> nn.Module:
    """Return a two-layer, batch-first LSTM wrapper (input size 10, hidden size 20)."""

    class LSTMModel(nn.Module):
        """Wrap ``nn.LSTM`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the LSTM and keep the per-step outputs."""
            sequence_out, _hidden = self.lstm(x)
            return sequence_out

    return LSTMModel(input_size=10, hidden_size=20, num_layers=2)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.fixture()
def gru_model() -> nn.Module:
    """Return a two-layer, batch-first GRU wrapper (input size 10, hidden size 20)."""

    class GRUModel(nn.Module):
        """Wrap ``nn.GRU`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the GRU and keep the per-step outputs."""
            sequence_out, _hidden = self.gru(x)
            return sequence_out

    return GRUModel(input_size=10, hidden_size=20, num_layers=2)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@pytest.fixture()
def rnn_model() -> nn.Module:
    """Return a two-layer, batch-first plain (Elman) RNN wrapper (input size 10, hidden size 20)."""

    class RNNModel(nn.Module):
        """Wrap ``nn.RNN`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the RNN and keep the per-step outputs."""
            sequence_out, _hidden = self.rnn(x)
            return sequence_out

    return RNNModel(input_size=10, hidden_size=20, num_layers=2)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@pytest.fixture()
def classifier_model() -> nn.Module:
    """Return a CNN that ends in a flat per-sample output of 5 logits."""

    class ClassifierModel(nn.Module):
        """Conv -> global average pool -> flatten -> Linear(8, 5)."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, 1, 1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Pool the conv features down to one value per channel, then classify."""
            pooled = self.pool(self.conv(x))
            return self.fc(torch.flatten(pooled, 1))

    return ClassifierModel()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def test_sequential_model_flow_view_runs(sequential_model: nn.Sequential) -> None:
    """Smoke-test flow_view on the Sequential fixture; it only has to not raise."""
    image_batch_shape = (1, 3, 224, 224)
    flow_view(sequential_model, input_shape=image_batch_shape)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_module_list_model_flow_view_runs(module_list_model: nn.ModuleList) -> None:
    """Smoke-test flow_view on the ModuleList fixture; it only has to not raise."""
    image_batch_shape = (1, 3, 224, 224)
    flow_view(module_list_model, input_shape=image_batch_shape)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_custom_model_flow_view_runs(custom_model: nn.Module) -> None:
    """Smoke-test flow_view on the functional-style custom CNN fixture."""
    image_batch_shape = (1, 3, 224, 224)
    flow_view(custom_model, input_shape=image_batch_shape)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def test_lstm_model_flow_view_runs(lstm_model: nn.Module) -> None:
    """Smoke-test flow_view on the LSTM fixture with a (batch=1, seq=10, feat=10) input."""
    sequence_shape = (1, 10, 10)
    flow_view(lstm_model, input_shape=sequence_shape)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def test_gru_model_flow_view_runs(gru_model: nn.Module) -> None:
    """Smoke-test flow_view on the GRU fixture with a (batch=1, seq=10, feat=10) input."""
    sequence_shape = (1, 10, 10)
    flow_view(gru_model, input_shape=sequence_shape)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:
    """Smoke-test flow_view on the plain RNN fixture with a (batch=1, seq=10, feat=10) input."""
    sequence_shape = (1, 10, 10)
    flow_view(rnn_model, input_shape=sequence_shape)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@pytest.mark.parametrize("orientation", ["x", "y", "z"])
def test_flow_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
    """Render the 1D-output classifier once for each accepted ``one_dim_orientation``."""
    rendered = flow_view(
        classifier_model,
        input_shape=(1, 3, 16, 16),
        one_dim_orientation=orientation,
    )
    assert rendered is not None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_flow_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
    """Reject an unknown ``one_dim_orientation`` with a ValueError that names the problem."""
    bad_kwargs = {"input_shape": (1, 3, 16, 16), "one_dim_orientation": "bad"}
    with pytest.raises(ValueError, match="unsupported orientation"):
        flow_view(classifier_model, **bad_kwargs)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
    """Drop every ``nn.ReLU`` via ``type_ignore`` and still get an image back."""
    ignored_types = [nn.ReLU]
    rendered = flow_view(sequential_model, input_shape=(1, 3, 224, 224), type_ignore=ignored_types)
    assert rendered is not None
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def test_flow_view_with_legend(sequential_model: nn.Sequential) -> None:
    """Request a legend (``legend=True``) and check an image is still produced."""
    rendered = flow_view(
        sequential_model,
        input_shape=(1, 3, 224, 224),
        legend=True,
    )
    assert rendered is not None
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def test_flow_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
    """Saving through ``to_file`` should leave a non-empty, openable image on disk."""
    target = tmp_path / "flow.png"
    flow_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(target))

    assert target.exists()
    with Image.open(target) as reloaded:
        width, height = reloaded.size
        assert width > 0
        assert height > 0
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_flow_view_with_show_dimension(sequential_model: nn.Sequential) -> None:
    """Enable per-layer shape labels (``show_dimension=True``) and check rendering succeeds."""
    rendered = flow_view(
        sequential_model,
        input_shape=(1, 3, 224, 224),
        show_dimension=True,
    )
    assert rendered is not None
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_flow_view_show_dimension_with_legend(sequential_model: nn.Sequential) -> None:
    """Enable shape labels and the legend together; both options must combine cleanly."""
    options = {"show_dimension": True, "legend": True}
    rendered = flow_view(sequential_model, input_shape=(1, 3, 224, 224), **options)
    assert rendered is not None
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: nn.Sequential) -> None:
    """Pin flow_view's canvas size per option set across the backend/_volumetric_layout rewrite.

    The expected sizes were recorded on `main` (0b349a3), before `register_hook` gave way to the
    shared `extract_architecture`/`layout_columns` backend, and cross-checked with a
    `git worktree` comparison during the rewrite. Only the size is compared, not a pixel hash:
    aggdraw's anti-aliasing differs between platforms, whereas the layout arithmetic does not
    (the graph_view counterpart of this test explains the same choice).
    """
    batch_shape = (1, 3, 32, 32)
    # scenario name -> (extra flow_view keyword arguments, expected (width, height))
    scenarios = {
        "default": ({}, (124, 42)),
        "no_volume": ({"draw_volume": False}, (116, 32)),
        "show_dimension": ({"show_dimension": True}, (303, 59)),
        "legend": ({"legend": True}, (124, 136)),
        "type_ignore": ({"type_ignore": [nn.ReLU]}, (82, 42)),
        "no_funnel": ({"draw_funnel": False}, (124, 42)),
    }

    # Render every scenario first, then check them all.
    rendered = {
        name: flow_view(sequential_model, input_shape=batch_shape, **extra_kwargs)
        for name, (extra_kwargs, _size) in scenarios.items()
    }
    for name, (_kwargs, expected_size) in scenarios.items():
        assert rendered[name].size == expected_size, f"{name} canvas size changed"
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@pytest.fixture()
def residual_model() -> nn.Module:
    """Return a conv residual block whose shortcut adds the block's raw input back in."""

    class ResidualBlock(nn.Module):
        """Conv-BN-ReLU-Conv-BN, plus an identity shortcut, followed by a final ReLU."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Apply the conv branch, add the untouched input, then activate."""
            branch = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
            return self.relu(branch + x)

    return ResidualBlock(channels=8)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@pytest.fixture()
def hidden_skip_model() -> nn.Module:
    """Return an MLP whose skip connection starts at a hidden layer (``stem``), not the input."""

    class ResidualBlock(nn.Module):
        """stem -> (fc1 -> fc2) + stem shortcut -> out."""

        def __init__(self) -> None:
            super().__init__()
            self.stem = nn.Linear(4, 4)
            self.fc1 = nn.Linear(4, 4)
            self.fc2 = nn.Linear(4, 4)
            self.out = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Route the stem output both through fc1/fc2 and around them."""
            hidden = self.stem(x)
            merged = self.fc2(self.fc1(hidden)) + hidden
            return self.out(merged)

    return ResidualBlock()
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def test_flow_view_residual_model_runs(residual_model: nn.Module) -> None:
    """Render a block whose shortcut comes straight from the raw input without crashing."""
    rendered = flow_view(residual_model, input_shape=(1, 8, 16, 16))
    assert rendered is not None
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def test_flow_view_hidden_skip_model_runs(hidden_skip_model: nn.Module) -> None:
    """Render a model whose shortcut starts at a hidden layer without crashing."""
    rendered = flow_view(hidden_skip_model, input_shape=(1, 4))
    assert rendered is not None
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def test_flow_view_residual_model_routes_above_diagram(residual_model: nn.Module) -> None:
    """An input-level skip connection must add vertical room for its detour, not disappear."""

    class PlainChain(nn.Module):
        """residual_model's layers in the same order, with the shortcut removed."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run both conv stages back to back with no addition."""
            hidden = self.relu(self.bn1(self.conv1(x)))
            return self.relu(self.bn2(self.conv2(hidden)))

    batch_shape = (1, 8, 16, 16)
    height_with_skip = flow_view(residual_model, input_shape=batch_shape).size[1]
    height_without_skip = flow_view(PlainChain(channels=8), input_shape=batch_shape).size[1]

    assert height_with_skip > height_without_skip
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def test_flow_view_hidden_skip_model_routes_above_diagram(hidden_skip_model: nn.Module) -> None:
    """A skip connection that starts at a hidden layer must also add vertical room."""

    class PlainChain(nn.Module):
        """hidden_skip_model's four Linear layers chained with no shortcut."""

        def __init__(self) -> None:
            super().__init__()
            self.stem = nn.Linear(4, 4)
            self.fc1 = nn.Linear(4, 4)
            self.fc2 = nn.Linear(4, 4)
            self.out = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Apply stem, fc1, fc2 and out strictly in sequence."""
            for layer in (self.stem, self.fc1, self.fc2, self.out):
                x = layer(x)
            return x

    height_with_skip = flow_view(hidden_skip_model, input_shape=(1, 4)).size[1]
    height_without_skip = flow_view(PlainChain(), input_shape=(1, 4)).size[1]

    assert height_with_skip > height_without_skip
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def test_flow_view_deep_repeated_residual_blocks_stays_reasonably_sized() -> None:
    """Consecutive, non-overlapping residual blocks should reuse one detour level.

    If each block stacked its own detour level, the canvas would grow taller with every block;
    instead a 2-block and a 6-block model must come out exactly the same height.
    """

    class ResBlock(nn.Module):
        """Two convs with an identity shortcut around them, then ReLU."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.relu = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Add the block input back onto the conv1/conv2 result."""
            return self.relu(self.conv2(self.conv1(x)) + x)

    class DeepModel(nn.Module):
        """``n_blocks`` ResBlocks applied one after another."""

        def __init__(self, channels: int, n_blocks: int) -> None:
            super().__init__()
            self.blocks = nn.ModuleList([ResBlock(channels) for _ in range(n_blocks)])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Feed the input through every block in order."""
            features = x
            for res_block in self.blocks:
                features = res_block(features)
            return features

    short_height = flow_view(DeepModel(4, 2), input_shape=(1, 4, 8, 8)).size[1]
    long_height = flow_view(DeepModel(4, 6), input_shape=(1, 4, 8, 8)).size[1]

    assert short_height == long_height
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _non_background_pixel_count(img: Image.Image) -> int:
    """Count the pixels that are not pure white once the image is converted to RGB."""
    rgb = np.asarray(img.convert("RGB"))
    # A pixel is background only when all three channels are 255.
    is_background = (rgb == 255).all(axis=2)
    return int(np.count_nonzero(~is_background))
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def test_flow_view_funnels_survive_large_de_differences_between_layers() -> None:
    """Keep the funnel visible between adjacent layers whose 3D depth (`de`) differs a lot.

    Regression test for a real bug. An intermediate version of the flow_view rewrite drew every
    connector first and every box afterwards, instead of interleaving them column by column.
    Each box's opaque fill therefore hid much of its own incoming funnel whenever neighboring
    layers had very different `de`. The small, near-constant-`de` models used while verifying
    the rewrite barely showed it, but it was obvious on a real CNN: the user spotted it by
    comparing ReadTheDocs' `plot_basic_custom` example before and after the rewrite. Box
    positions were unaffected, so canvas *size* alone cannot catch this class of bug; the test
    checks the rendered pixel content instead.
    """

    class SimpleCNN(nn.Module):
        """The exact model from docs/examples/flow/plot_basic_custom.py."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.fc1 = nn.Linear(64 * 28 * 28, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Apply three conv/relu/pool stages, flatten, then the two-layer MLP head."""
            for conv in (self.conv1, self.conv2, self.conv3):
                x = conv(x)
                x = func.relu(x)
                x = func.max_pool2d(x, 2, 2)
            x = x.view(x.size(0), -1)
            x = func.relu(self.fc1(x))
            return self.fc2(x)

    rendered = flow_view(SimpleCNN(), input_shape=(1, 3, 224, 224), legend=True)

    # Baseline taken from the fixed implementation and confirmed pixel-identical to the
    # pre-rewrite main (1ee630e) for this model via a git-worktree comparison. The buggy
    # intermediate version produced thousands fewer non-background pixels (missing funnel
    # segments), so even this wide tolerance still flags a regression of that kind.
    assert rendered.size == (153, 336)
    non_bg = _non_background_pixel_count(rendered)
    error_msg = f"non-background pixel count {non_bg} outside expected range - funnel likely broken"
    assert 21000 <= non_bg <= 24000, error_msg
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
    """Keep every input box when a model takes two or more inputs.

    Unlike the single-input case, flow_view must not hide any of the separate input boxes:
    hiding one would make it ambiguous which arrow originates from which named input.
    """
    from visualtorch.backend import extract_architecture
    from visualtorch.utils.layer_utils import InputDummyLayer

    class TwoInputNet(nn.Module):
        """Two independent Linear branches whose outputs are concatenated into one head."""

        def __init__(self) -> None:
            super().__init__()
            self.a = nn.Linear(4, 4)
            self.b = nn.Linear(4, 4)
            self.head = nn.Linear(8, 2)

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            """Project each input separately, concatenate along dim 1, then apply the head."""
            return self.head(torch.cat([self.a(x), self.b(y)], dim=1))

    # NOTE(review): the "all inputs kept" check runs against the extracted architecture's first
    # column; the rendered image itself is only smoke-tested below.
    architecture = extract_architecture(TwoInputNet(), ((1, 4), (1, 4)))
    input_labels = {
        layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, InputDummyLayer)
    }
    assert input_labels == {"input_0", "input_1"}

    img = flow_view(TwoInputNet(), input_shape=((1, 4), (1, 4)))
    assert img is not None
|