visualtorch 0.2.4__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {visualtorch-0.2.4 → visualtorch-1.0.0}/PKG-INFO +60 -7
  2. visualtorch-0.2.4/visualtorch.egg-info/PKG-INFO → visualtorch-1.0.0/README.md +28 -23
  3. {visualtorch-0.2.4 → visualtorch-1.0.0}/pyproject.toml +4 -0
  4. {visualtorch-0.2.4 → visualtorch-1.0.0}/setup.py +1 -1
  5. visualtorch-1.0.0/tests/test_flow.py +473 -0
  6. visualtorch-1.0.0/tests/test_graph.py +435 -0
  7. visualtorch-1.0.0/tests/test_lenet_style.py +330 -0
  8. visualtorch-1.0.0/tests/test_regression_issues.py +114 -0
  9. visualtorch-1.0.0/tests/test_render.py +150 -0
  10. visualtorch-1.0.0/visualtorch/__init__.py +22 -0
  11. visualtorch-1.0.0/visualtorch/_volumetric_layout.py +142 -0
  12. visualtorch-1.0.0/visualtorch/backend.py +169 -0
  13. visualtorch-1.0.0/visualtorch/connectors.py +90 -0
  14. visualtorch-1.0.0/visualtorch/flow.py +538 -0
  15. visualtorch-1.0.0/visualtorch/graph.py +377 -0
  16. visualtorch-1.0.0/visualtorch/lenet_style.py +385 -0
  17. visualtorch-1.0.0/visualtorch/render.py +236 -0
  18. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch/utils/__init__.py +1 -15
  19. visualtorch-1.0.0/visualtorch/utils/layer_utils.py +18 -0
  20. visualtorch-1.0.0/visualtorch/utils/recorder.py +247 -0
  21. visualtorch-1.0.0/visualtorch/utils/traced_layer.py +22 -0
  22. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch/utils/utils.py +69 -19
  23. visualtorch-1.0.0/visualtorch.egg-info/PKG-INFO +118 -0
  24. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/SOURCES.txt +12 -1
  25. visualtorch-0.2.4/README.md +0 -46
  26. visualtorch-0.2.4/visualtorch/__init__.py +0 -10
  27. visualtorch-0.2.4/visualtorch/graph.py +0 -246
  28. visualtorch-0.2.4/visualtorch/layered.py +0 -367
  29. visualtorch-0.2.4/visualtorch/lenet_style.py +0 -326
  30. visualtorch-0.2.4/visualtorch/utils/layer_utils.py +0 -246
  31. {visualtorch-0.2.4 → visualtorch-1.0.0}/LICENSE +0 -0
  32. {visualtorch-0.2.4 → visualtorch-1.0.0}/setup.cfg +0 -0
  33. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/dependency_links.txt +0 -0
  34. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/requires.txt +0 -0
  35. {visualtorch-0.2.4 → visualtorch-1.0.0}/visualtorch.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: visualtorch
3
- Version: 0.2.4
3
+ Version: 1.0.0
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -14,8 +14,37 @@ Classifier: License :: OSI Approved :: MIT License
14
14
  Classifier: Operating System :: OS Independent
15
15
  Requires-Python: >=3.10
16
16
  Description-Content-Type: text/markdown
17
- Provides-Extra: dev
18
17
  License-File: LICENSE
18
+ Requires-Dist: pillow>=10.0.0
19
+ Requires-Dist: numpy>=1.18.1
20
+ Requires-Dist: aggdraw>=1.3.11
21
+ Requires-Dist: torch>=2.0.0
22
+ Provides-Extra: dev
23
+ Requires-Dist: myst-parser; extra == "dev"
24
+ Requires-Dist: nbsphinx; extra == "dev"
25
+ Requires-Dist: pandoc; extra == "dev"
26
+ Requires-Dist: sphinx<7.0; extra == "dev"
27
+ Requires-Dist: sphinx_autodoc_typehints; extra == "dev"
28
+ Requires-Dist: sphinx_book_theme; extra == "dev"
29
+ Requires-Dist: sphinx-copybutton; extra == "dev"
30
+ Requires-Dist: sphinx_design; extra == "dev"
31
+ Requires-Dist: sphinx_gallery; extra == "dev"
32
+ Requires-Dist: matplotlib; extra == "dev"
33
+ Requires-Dist: pre-commit; extra == "dev"
34
+ Requires-Dist: pytest; extra == "dev"
35
+ Dynamic: author
36
+ Dynamic: author-email
37
+ Dynamic: classifier
38
+ Dynamic: description
39
+ Dynamic: description-content-type
40
+ Dynamic: home-page
41
+ Dynamic: keywords
42
+ Dynamic: license
43
+ Dynamic: license-file
44
+ Dynamic: provides-extra
45
+ Dynamic: requires-dist
46
+ Dynamic: requires-python
47
+ Dynamic: summary
19
48
 
20
49
  <div align="center">
21
50
  <h1>🔥 VisualTorch 🔥</h1>
@@ -24,13 +53,19 @@ License-File: LICENSE
24
53
 
25
54
  </div>
26
55
 
27
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
56
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
28
57
 
29
- **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
58
+ **Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
59
+ limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
60
+ execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
61
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take;
62
+ (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
63
+ only has its first tensor's shape reflected in that node's size/label, although its downstream
64
+ connections are still drawn correctly. Contributions are welcome!
30
65
 
31
66
  <div align="center">
32
67
 
33
- ![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
68
+ ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
34
69
 
35
70
  </div>
36
71
 
@@ -62,4 +97,22 @@ Originally, this project was based on the [visualkeras](https://github.com/paulg
62
97
 
63
98
  Please cite this project in your publications if it helps your research.
64
99
 
65
- [A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
100
+ **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
101
+ since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
102
+ for the current API). The DOI always resolves to the version that was actually reviewed and
103
+ published, so the paper itself is not updated to match.
104
+
105
+ ```bibtex
106
+ @article{Hendria2024,
107
+ doi = {10.21105/joss.06678},
108
+ url = {https://doi.org/10.21105/joss.06678},
109
+ year = {2024},
110
+ publisher = {The Open Journal},
111
+ volume = {9},
112
+ number = {102},
113
+ pages = {6678},
114
+ author = {Willy Fitra Hendria and Paul Gavrikov},
115
+ title = {VisualTorch: Streamlining Visualization for PyTorch Neural Network Architectures},
116
+ journal = {Journal of Open Source Software}
117
+ }
118
+ ```
@@ -1,22 +1,3 @@
1
- Metadata-Version: 2.1
2
- Name: visualtorch
3
- Version: 0.2.4
4
- Summary: Architecture visualization of Torch models
5
- Home-page: https://github.com/willyfh/visualtorch
6
- Author: Willy Fitra Hendria
7
- Author-email: willyfitrahendria@gmail.com
8
- License: MIT
9
- Keywords: visualize architecture,torch visualization,visualtorch
10
- Classifier: Programming Language :: Python :: 3.10
11
- Classifier: Programming Language :: Python :: 3.11
12
- Classifier: Programming Language :: Python :: 3.12
13
- Classifier: License :: OSI Approved :: MIT License
14
- Classifier: Operating System :: OS Independent
15
- Requires-Python: >=3.10
16
- Description-Content-Type: text/markdown
17
- Provides-Extra: dev
18
- License-File: LICENSE
19
-
20
1
  <div align="center">
21
2
  <h1>🔥 VisualTorch 🔥</h1>
22
3
 
@@ -24,13 +5,19 @@ License-File: LICENSE
24
5
 
25
6
  </div>
26
7
 
27
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
8
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
28
9
 
29
- **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
10
+ **Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
11
+ limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
12
+ execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
13
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take;
14
+ (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
15
+ only has its first tensor's shape reflected in that node's size/label, although its downstream
16
+ connections are still drawn correctly. Contributions are welcome!
30
17
 
31
18
  <div align="center">
32
19
 
33
- ![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
20
+ ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
34
21
 
35
22
  </div>
36
23
 
@@ -62,4 +49,22 @@ Originally, this project was based on the [visualkeras](https://github.com/paulg
62
49
 
63
50
  Please cite this project in your publications if it helps your research.
64
51
 
65
- [A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
52
+ **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
53
+ since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
54
+ for the current API). The DOI always resolves to the version that was actually reviewed and
55
+ published, so the paper itself is not updated to match.
56
+
57
+ ```bibtex
58
+ @article{Hendria2024,
59
+ doi = {10.21105/joss.06678},
60
+ url = {https://doi.org/10.21105/joss.06678},
61
+ year = {2024},
62
+ publisher = {The Open Journal},
63
+ volume = {9},
64
+ number = {102},
65
+ pages = {6678},
66
+ author = {Willy Fitra Hendria and Paul Gavrikov},
67
+ title = {VisualTorch: Streamlining Visualization for PyTorch Neural Network Architectures},
68
+ journal = {Journal of Open Source Software}
69
+ }
70
+ ```
@@ -136,6 +136,10 @@ max-complexity = 15
136
136
 
137
137
  [tool.ruff.per-file-ignores]
138
138
  "tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
139
+ # RecorderTensor overrides torch.Tensor.__torch_function__, whose own signature is untyped
140
+ # (Any throughout); the recursive tensor-structure helpers here genuinely accept arbitrary
141
+ # nested types (tensors, tuples, dicts, non-tensor values) for the same reason.
142
+ "visualtorch/utils/recorder.py" = ["ANN401", "ANN102"]
139
143
 
140
144
  [tool.ruff.pydocstyle]
141
145
  convention = "google"
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
21
21
 
22
22
  setuptools.setup(
23
23
  name="visualtorch",
24
- version="0.2.4",
24
+ version="1.0.0",
25
25
  author="Willy Fitra Hendria",
26
26
  author_email="willyfitrahendria@gmail.com",
27
27
  description="Architecture visualization of Torch models",
@@ -0,0 +1,473 @@
1
+ """Tests for flow view."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pytest
10
+ import torch
11
+ import torch.nn.functional as func
12
+ from PIL import Image
13
+ from torch import nn
14
+ from visualtorch.flow import flow_view
15
+
16
+
17
@pytest.fixture()
def sequential_model() -> nn.Sequential:
    """Build a small conv -> relu -> conv -> relu -> max-pool ``nn.Sequential``."""
    layers = [
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(64, 128, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]
    return nn.Sequential(*layers)


@pytest.fixture()
def module_list_model() -> nn.ModuleList:
    """Build the same five-layer stack as ``sequential_model``, held in an ``nn.ModuleList``."""
    layers = [
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(64, 128, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]
    return nn.ModuleList(layers)
41
+
42
+
43
@pytest.fixture()
def custom_model() -> nn.Module:
    """Define a custom ``nn.Module`` whose forward pass mixes layers and functional ops."""

    class CustomModel(nn.Module):
        """A simple custom cnn model."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass: two conv + relu stages followed by 2x2 max pooling."""
            x = func.relu(self.conv1(x))
            x = func.relu(self.conv2(x))
            return func.max_pool2d(x, 2, 2)

    # Create an instance of the custom model
    return CustomModel()
63
+
64
+
65
@pytest.fixture()
def lstm_model() -> nn.Module:
    """Build a two-layer, batch-first LSTM wrapper (input 10, hidden 20)."""

    class LSTMModel(nn.Module):
        """Wrap ``nn.LSTM`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            # nn.LSTM returns (output, state); only the output tensor is kept.
            return self.lstm(x)[0]

    return LSTMModel(input_size=10, hidden_size=20, num_layers=2)


@pytest.fixture()
def gru_model() -> nn.Module:
    """Build a two-layer, batch-first GRU wrapper (input 10, hidden 20)."""

    class GRUModel(nn.Module):
        """Wrap ``nn.GRU`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            return self.gru(x)[0]

    return GRUModel(input_size=10, hidden_size=20, num_layers=2)


@pytest.fixture()
def rnn_model() -> nn.Module:
    """Build a two-layer, batch-first plain RNN wrapper (input 10, hidden 20)."""

    class RNNModel(nn.Module):
        """Wrap ``nn.RNN`` and return only its output sequence, dropping the hidden state."""

        def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
            super().__init__()
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass."""
            return self.rnn(x)[0]

    return RNNModel(input_size=10, hidden_size=20, num_layers=2)
121
+
122
+
123
@pytest.fixture()
def classifier_model() -> nn.Module:
    """Build a small CNN whose output is 1D per sample, like classification logits."""

    class ClassifierModel(nn.Module):
        """Conv -> global average pool -> linear head, ending with a 1D output."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, 1, 1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass: pool the 8 conv channels to 1x1, flatten, and project to 5 values."""
            pooled = self.pool(self.conv(x))
            return self.fc(torch.flatten(pooled, 1))

    return ClassifierModel()
144
+
145
+
146
# These smoke tests now assert on the returned image, matching every other flow_view test in
# this file, instead of discarding it: a silent ``None`` return would otherwise pass unnoticed.


def test_sequential_model_flow_view_runs(sequential_model: nn.Sequential) -> None:
    """Test flow view on sequential model."""
    img = flow_view(sequential_model, input_shape=(1, 3, 224, 224))
    assert img is not None


def test_module_list_model_flow_view_runs(module_list_model: nn.ModuleList) -> None:
    """Test flow view on module list model."""
    img = flow_view(module_list_model, input_shape=(1, 3, 224, 224))
    assert img is not None


def test_custom_model_flow_view_runs(custom_model: nn.Module) -> None:
    """Test flow view on custom model."""
    img = flow_view(custom_model, input_shape=(1, 3, 224, 224))
    assert img is not None


def test_lstm_model_flow_view_runs(lstm_model: nn.Module) -> None:
    """Test flow view on lstm model."""
    img = flow_view(lstm_model, input_shape=(1, 10, 10))
    assert img is not None


def test_gru_model_flow_view_runs(gru_model: nn.Module) -> None:
    """Test flow view on gru model."""
    img = flow_view(gru_model, input_shape=(1, 10, 10))
    assert img is not None


def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:
    """Test flow view on plain rnn model."""
    img = flow_view(rnn_model, input_shape=(1, 10, 10))
    assert img is not None
174
+
175
+
176
@pytest.mark.parametrize("orientation", ["x", "y", "z"])
def test_flow_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
    """Render a model with a 1D output under each supported ``one_dim_orientation``."""
    rendered = flow_view(
        classifier_model,
        input_shape=(1, 3, 16, 16),
        one_dim_orientation=orientation,
    )
    assert rendered is not None


def test_flow_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
    """Reject an unsupported ``one_dim_orientation`` with a descriptive ValueError."""
    bad_orientation = "bad"
    with pytest.raises(ValueError, match="unsupported orientation"):
        flow_view(
            classifier_model,
            input_shape=(1, 3, 16, 16),
            one_dim_orientation=bad_orientation,
        )
187
+
188
+
189
def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
    """Skip layers whose type is listed in ``type_ignore`` without raising."""
    ignored_types = [nn.ReLU]
    rendered = flow_view(sequential_model, input_shape=(1, 3, 224, 224), type_ignore=ignored_types)
    assert rendered is not None


def test_flow_view_with_legend(sequential_model: nn.Sequential) -> None:
    """Append a legend when ``legend=True`` without raising."""
    rendered = flow_view(sequential_model, input_shape=(1, 3, 224, 224), legend=True)
    assert rendered is not None


def test_flow_view_writes_to_file(sequential_model: nn.Sequential, tmp_path: Path) -> None:
    """Save a readable, non-empty image to disk when ``to_file`` is given."""
    target = tmp_path / "flow.png"
    flow_view(sequential_model, input_shape=(1, 3, 224, 224), to_file=str(target))

    assert target.exists()
    with Image.open(target) as reloaded:
        width, height = reloaded.size
    assert width > 0
    assert height > 0


def test_flow_view_with_show_dimension(sequential_model: nn.Sequential) -> None:
    """Print each layer's shape via ``show_dimension=True`` without clipping or crashing."""
    rendered = flow_view(sequential_model, input_shape=(1, 3, 224, 224), show_dimension=True)
    assert rendered is not None


def test_flow_view_show_dimension_with_legend(sequential_model: nn.Sequential) -> None:
    """Allow ``show_dimension`` and ``legend`` to be enabled together."""
    rendered = flow_view(
        sequential_model,
        input_shape=(1, 3, 224, 224),
        show_dimension=True,
        legend=True,
    )
    assert rendered is not None
226
+
227
+
228
def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: nn.Sequential) -> None:
    """Lock in flow_view's canvas size across the backend/_volumetric_layout rewrite.

    Sizes captured from `main` (0b349a3) before `register_hook` was replaced with the shared
    `extract_architecture`/`layout_columns` backend - confirmed via a `git worktree` comparison
    at the time of the rewrite (see the graph_view equivalent test for why size, not an exact
    pixel hash: aggdraw's anti-aliasing isn't portable across platforms, but layout math is).
    """
    input_shape = (1, 3, 32, 32)
    cases = {
        "default": flow_view(sequential_model, input_shape=input_shape),
        "no_volume": flow_view(sequential_model, input_shape=input_shape, draw_volume=False),
        "show_dimension": flow_view(sequential_model, input_shape=input_shape, show_dimension=True),
        "legend": flow_view(sequential_model, input_shape=input_shape, legend=True),
        "type_ignore": flow_view(sequential_model, input_shape=input_shape, type_ignore=[nn.ReLU]),
        "no_funnel": flow_view(sequential_model, input_shape=input_shape, draw_funnel=False),
    }
    expected_sizes = {
        "default": (124, 42),
        "no_volume": (116, 32),
        "show_dimension": (303, 59),
        "legend": (124, 136),
        "type_ignore": (82, 42),
        "no_funnel": (124, 42),
    }

    # Keep the two tables in sync. Without this check, a baseline whose case was deleted would
    # be silently skipped by the loop below instead of failing.
    assert cases.keys() == expected_sizes.keys(), "rendered cases and expected sizes are out of sync"

    for name, img in cases.items():
        assert img.size == expected_sizes[name], f"{name} canvas size changed"
255
+
256
+
257
@pytest.fixture()
def residual_model() -> nn.Module:
    """A residual block whose shortcut is the model's own raw input (the most common pattern)."""

    class ResidualBlock(nn.Module):
        """Conv-BN-ReLU-Conv-BN with an identity shortcut from the block input, then ReLU."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with a skip connection straight from the raw input."""
            residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
            return self.relu(residual + x)

    return ResidualBlock(channels=8)


@pytest.fixture()
def hidden_skip_model() -> nn.Module:
    """A residual block whose shortcut originates from a hidden layer, not the raw input."""

    class ResidualBlock(nn.Module):
        """Linear stem, two linear layers skipped over by the stem output, then a linear head."""

        def __init__(self) -> None:
            super().__init__()
            self.stem = nn.Linear(4, 4)
            self.fc1 = nn.Linear(4, 4)
            self.fc2 = nn.Linear(4, 4)
            self.out = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with a skip connection around fc1/fc2."""
            shortcut = self.stem(x)
            return self.out(self.fc2(self.fc1(shortcut)) + shortcut)

    return ResidualBlock()
301
+
302
+
303
def test_flow_view_residual_model_runs(residual_model: nn.Module) -> None:
    """Render a model whose skip connection starts at the raw input without crashing."""
    rendered = flow_view(residual_model, input_shape=(1, 8, 16, 16))
    assert rendered is not None


def test_flow_view_hidden_skip_model_runs(hidden_skip_model: nn.Module) -> None:
    """Render a model whose skip connection starts at a hidden layer without crashing."""
    rendered = flow_view(hidden_skip_model, input_shape=(1, 4))
    assert rendered is not None
313
+
314
+
315
def test_flow_view_residual_model_routes_above_diagram(residual_model: nn.Module) -> None:
    """A skip connection from the raw input should reserve extra vertical space, not vanish."""

    class PlainChain(nn.Module):
        """The same layers as residual_model, but without the skip connection."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with no skip connection."""
            hidden = self.relu(self.bn1(self.conv1(x)))
            return self.relu(self.bn2(self.conv2(hidden)))

    shape = (1, 8, 16, 16)
    # Only the canvas height is compared: the detour for the skip edge must add vertical room.
    height_with_skip = flow_view(residual_model, input_shape=shape).size[1]
    height_without_skip = flow_view(PlainChain(channels=8), input_shape=shape).size[1]

    assert height_with_skip > height_without_skip


def test_flow_view_hidden_skip_model_routes_above_diagram(hidden_skip_model: nn.Module) -> None:
    """A skip connection from a hidden layer should also reserve extra vertical space."""

    class PlainChain(nn.Module):
        """The same layers as hidden_skip_model, chained without the skip connection."""

        def __init__(self) -> None:
            super().__init__()
            self.stem = nn.Linear(4, 4)
            self.fc1 = nn.Linear(4, 4)
            self.fc2 = nn.Linear(4, 4)
            self.out = nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with no skip connection."""
            hidden = self.fc2(self.fc1(self.stem(x)))
            return self.out(hidden)

    shape = (1, 4)
    height_with_skip = flow_view(hidden_skip_model, input_shape=shape).size[1]
    height_without_skip = flow_view(PlainChain(), input_shape=shape).size[1]

    assert height_with_skip > height_without_skip
358
+
359
+
360
def test_flow_view_deep_repeated_residual_blocks_stays_reasonably_sized() -> None:
    """Back-to-back, non-overlapping residual blocks should share one detour level, not stack per block."""

    class ResBlock(nn.Module):
        """Two 3x3 convs with an identity shortcut, followed by ReLU."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.relu = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with a skip connection around conv1/conv2."""
            return self.relu(self.conv2(self.conv1(x)) + x)

    class DeepModel(nn.Module):
        """``n_blocks`` ResBlocks applied one after another."""

        def __init__(self, channels: int, n_blocks: int) -> None:
            super().__init__()
            self.blocks = nn.ModuleList([ResBlock(channels) for _ in range(n_blocks)])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass through every block in sequence."""
            for block in self.blocks:
                x = block(x)
            return x

    # Render the 2-block model first, then the 6-block one, and compare canvas heights.
    heights = [flow_view(DeepModel(4, n_blocks), input_shape=(1, 4, 8, 8)).size[1] for n_blocks in (2, 6)]

    assert heights[0] == heights[1]
391
+
392
+
393
def _non_background_pixel_count(img: Image.Image) -> int:
    """Return how many pixels of ``img`` differ from pure white once converted to RGB.

    A pixel counts as non-background when any of its R, G or B channels is not 255.
    """
    rgb = np.asarray(img.convert("RGB"))
    is_foreground = np.any(rgb != 255, axis=2)
    return int(np.count_nonzero(is_foreground))
395
+
396
+
397
def test_flow_view_funnels_survive_large_de_differences_between_layers() -> None:
    """A funnel between two layers with very different 3D depth (`de`) must stay visible.

    Regression test for a real bug: drawing every connector first and every box second (instead
    of interleaving them column by column) let each box's opaque fill blot out large parts of
    its own incoming funnel whenever neighboring layers have a very different `de` - which
    barely showed on the small, near-constant-`de` models used to verify the flow_view
    rewrite, but was highly visible on a real CNN (found by the user manually comparing
    ReadTheDocs' `plot_basic_custom` example before/after the rewrite). Canvas *size* alone
    doesn't catch this class of bug (box positions are unaffected, only which pixels get
    painted), so this asserts on rendered content instead.
    """

    class SimpleCNN(nn.Module):
        """Architecturally the same model as docs/examples/flow/plot_basic_custom.py."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.fc1 = nn.Linear(64 * 28 * 28, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass with three shrinking conv/pool stages."""
            # Each stage is conv -> relu -> 2x2 max-pool, in the same op order as the example.
            for conv in (self.conv1, self.conv2, self.conv3):
                x = func.max_pool2d(func.relu(conv(x)), 2, 2)
            x = x.view(x.size(0), -1)
            return self.fc2(func.relu(self.fc1(x)))

    rendered = flow_view(SimpleCNN(), input_shape=(1, 3, 224, 224), legend=True)

    # Locked in from the current (fixed) implementation - confirmed pixel-identical to
    # pre-rewrite main (1ee630e) via a git-worktree comparison for this exact model. The buggy
    # intermediate version rendered thousands fewer non-background pixels here (missing funnel
    # segments), so a wide but real tolerance still catches a regression of that class.
    assert rendered.size == (153, 336)
    foreground = _non_background_pixel_count(rendered)
    error_msg = f"non-background pixel count {foreground} outside expected range - funnel likely broken"
    assert 21000 <= foreground <= 24000, error_msg
447
+
448
+
449
def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
    """Keep every input box when a model takes two or more separate inputs.

    Unlike the single-input case, flow_view must not hide any of the input boxes - hiding one
    would make it ambiguous which arrow originates from which named input.
    """
    from visualtorch.backend import extract_architecture
    from visualtorch.utils.layer_utils import InputDummyLayer

    class TwoInputNet(nn.Module):
        """One linear branch per input, concatenated and fed into a shared linear head."""

        def __init__(self) -> None:
            super().__init__()
            self.a = nn.Linear(4, 4)
            self.b = nn.Linear(4, 4)
            self.head = nn.Linear(8, 2)

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            """Forward pass: project each input, concatenate along dim 1, then apply the head."""
            return self.head(torch.cat([self.a(x), self.b(y)], dim=1))

    architecture = extract_architecture(TwoInputNet(), ((1, 4), (1, 4)))
    # The input placeholder layers are looked up in the first column of the extracted layout.
    input_labels = {
        layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, InputDummyLayer)
    }
    assert input_labels == {"input_0", "input_1"}

    img = flow_view(TwoInputNet(), input_shape=((1, 4), (1, 4)))
    assert img is not None