visualtorch 1.0.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {visualtorch-1.0.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO +25 -15
  2. {visualtorch-1.0.0 → visualtorch-1.2.0}/README.md +23 -14
  3. {visualtorch-1.0.0 → visualtorch-1.2.0}/setup.py +1 -1
  4. visualtorch-1.2.0/tests/test_connectors.py +187 -0
  5. {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_flow.py +132 -17
  6. {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_graph.py +65 -39
  7. {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_lenet_style.py +56 -5
  8. {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_regression_issues.py +47 -1
  9. {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_render.py +44 -0
  10. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/__init__.py +4 -2
  11. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/_volumetric_layout.py +7 -0
  12. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/backend.py +7 -3
  13. visualtorch-1.2.0/visualtorch/connectors.py +206 -0
  14. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/flow.py +55 -22
  15. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/graph.py +57 -9
  16. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/lenet_style.py +73 -29
  17. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/render.py +14 -4
  18. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/__init__.py +2 -2
  19. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/layer_utils.py +2 -2
  20. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/recorder.py +40 -17
  21. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/traced_layer.py +6 -0
  22. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/utils.py +136 -15
  23. {visualtorch-1.0.0 → visualtorch-1.2.0/visualtorch.egg-info}/PKG-INFO +25 -15
  24. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/SOURCES.txt +1 -0
  25. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/requires.txt +1 -0
  26. visualtorch-1.0.0/visualtorch/connectors.py +0 -90
  27. {visualtorch-1.0.0 → visualtorch-1.2.0}/LICENSE +0 -0
  28. {visualtorch-1.0.0 → visualtorch-1.2.0}/pyproject.toml +0 -0
  29. {visualtorch-1.0.0 → visualtorch-1.2.0}/setup.cfg +0 -0
  30. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/dependency_links.txt +0 -0
  31. {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: visualtorch
3
- Version: 1.0.0
3
+ Version: 1.2.0
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -30,6 +30,7 @@ Requires-Dist: sphinx-copybutton; extra == "dev"
30
30
  Requires-Dist: sphinx_design; extra == "dev"
31
31
  Requires-Dist: sphinx_gallery; extra == "dev"
32
32
  Requires-Dist: matplotlib; extra == "dev"
33
+ Requires-Dist: torchvision; extra == "dev"
33
34
  Requires-Dist: pre-commit; extra == "dev"
34
35
  Requires-Dist: pytest; extra == "dev"
35
36
  Dynamic: author
@@ -53,19 +54,22 @@ Dynamic: summary
53
54
 
54
55
  </div>
55
56
 
56
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
57
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
57
58
 
58
- **Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
59
- limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
60
- execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
61
- value crosses some threshold) only show whichever branch the traced dummy input happened to take;
62
- (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
63
- only has its first tensor's shape reflected in that node's size/label - its downstream connections
64
- are still correct either way. Contributions are welcome!
59
+ **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
60
+
61
+ **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
62
+ limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
63
+ execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
64
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take.
65
+ Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
66
+ head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
67
+ tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
68
+ the first. Downstream connections are correct either way. Contributions are welcome!
65
69
 
66
70
  <div align="center">
67
71
 
68
- ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
72
+ ![VisualTorch Examples](https://raw.githubusercontent.com/willyfh/visualtorch/e6ad79751e0f7412b1074beb45f9baeccd1419e4/docs/source/_static/images/banners/readme-examples.png)
69
73
 
70
74
  </div>
71
75
 
@@ -79,6 +83,12 @@ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/u
79
83
 
80
84
  See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
81
85
 
86
+ ## Used in Research
87
+
88
+ VisualTorch has been used in published research, including works published in Nature, IEEE, and MDPI.
89
+
90
+ See the [Research Showcase page](https://visualtorch.readthedocs.io/en/latest/markdown/showcase/index.html) for the full list.
91
+
82
92
  ## Examples
83
93
 
84
94
  See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
@@ -91,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi
91
101
 
92
102
  This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
93
103
 
94
- Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
104
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
95
105
 
96
106
  ## Citation
97
107
 
98
108
  Please cite this project in your publications if it helps your research.
99
109
 
100
- **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
101
- since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
102
- for the current API) - the DOI always resolves to what was actually reviewed and published, so
103
- it isn't updated to match.
110
+ **Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
111
+ since been substantially refactored, including breaking API changes (see the
112
+ [documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
113
+ always resolves to what was actually reviewed and published.
104
114
 
105
115
  ```bibtex
106
116
  @article{Hendria2024,
@@ -5,19 +5,22 @@
5
5
 
6
6
  </div>
7
7
 
8
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
8
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
9
9
 
10
- **Note:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
11
- limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
12
- execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
13
- value crosses some threshold) only show whichever branch the traced dummy input happened to take;
14
- (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
15
- only has its first tensor's shape reflected in that node's size/label - its downstream connections
16
- are still correct either way. Contributions are welcome!
10
+ **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
11
+
12
+ **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
13
+ limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
14
+ execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
15
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take.
16
+ Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
17
+ head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
18
+ tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
19
+ the first. Downstream connections are correct either way. Contributions are welcome!
17
20
 
18
21
  <div align="center">
19
22
 
20
- ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
23
+ ![VisualTorch Examples](https://raw.githubusercontent.com/willyfh/visualtorch/e6ad79751e0f7412b1074beb45f9baeccd1419e4/docs/source/_static/images/banners/readme-examples.png)
21
24
 
22
25
  </div>
23
26
 
@@ -31,6 +34,12 @@ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/u
31
34
 
32
35
  See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
33
36
 
37
+ ## Used in Research
38
+
39
+ VisualTorch has been used in published research, including works published in Nature, IEEE, and MDPI.
40
+
41
+ See the [Research Showcase page](https://visualtorch.readthedocs.io/en/latest/markdown/showcase/index.html) for the full list.
42
+
34
43
  ## Examples
35
44
 
36
45
  See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
@@ -43,16 +52,16 @@ Please feel free to send a pull request to contribute to this project by followi
43
52
 
44
53
  This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
45
54
 
46
- Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
55
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
47
56
 
48
57
  ## Citation
49
58
 
50
59
  Please cite this project in your publications if it helps your research.
51
60
 
52
- **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
53
- since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
54
- for the current API) - the DOI always resolves to what was actually reviewed and published, so
55
- it isn't updated to match.
61
+ **Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
62
+ since been substantially refactored, including breaking API changes (see the
63
+ [documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
64
+ always resolves to what was actually reviewed and published.
56
65
 
57
66
  ```bibtex
58
67
  @article{Hendria2024,
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
21
21
 
22
22
  setuptools.setup(
23
23
  name="visualtorch",
24
- version="1.0.0",
24
+ version="1.2.0",
25
25
  author="Willy Fitra Hendria",
26
26
  author_email="willyfitrahendria@gmail.com",
27
27
  description="Architecture visualization of Torch models",
@@ -0,0 +1,187 @@
1
+ """Tests for the frontend-agnostic connector-routing helpers."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from visualtorch.connectors import _segment_intersects_rect, compute_skip_levels
7
+
8
+
9
+ def _no_bbox(_node_id: str) -> tuple[float, float, float, float] | None:
10
+ return None
11
+
12
+
13
+ # ---- compute_skip_levels: span/content gating ----
14
+
15
+
16
+ def test_compute_skip_levels_ignores_span_le_1_edges() -> None:
17
+ """An edge with column span <= 1 should never be assigned a detour level."""
18
+ id_to_column = {"a": 0, "b": 1}
19
+
20
+ edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: True, _no_bbox)
21
+
22
+ assert edge_to_level == {}
23
+ assert num_levels == 0
24
+
25
+
26
+ def test_compute_skip_levels_assigns_distinct_levels_to_overlapping_skips() -> None:
27
+ """Two skip edges whose column spans genuinely overlap must get different levels."""
28
+ id_to_column = {"a": 0, "b": 3, "c": 2, "d": 5}
29
+ edges = [("a", "b"), ("c", "d")]
30
+
31
+ edge_to_level, num_levels = compute_skip_levels(edges, id_to_column, lambda *_: True, _no_bbox)
32
+
33
+ assert num_levels == 2
34
+ assert edge_to_level[("a", "b")] != edge_to_level[("c", "d")]
35
+
36
+
37
+ def test_compute_skip_levels_allows_touching_intervals_to_share_a_level() -> None:
38
+ """Two skip edges that only touch at a shared column boundary can share a level."""
39
+ id_to_column = {"a": 0, "b": 3, "c": 3, "d": 5}
40
+ edges = [("a", "b"), ("c", "d")]
41
+
42
+ edge_to_level, num_levels = compute_skip_levels(edges, id_to_column, lambda *_: True, _no_bbox)
43
+
44
+ assert num_levels == 1
45
+ assert edge_to_level[("a", "b")] == edge_to_level[("c", "d")] == 0
46
+
47
+
48
+ def test_compute_skip_levels_ignores_edges_with_no_content() -> None:
49
+ """A skip edge that `edge_has_content` reports as empty shouldn't consume a level."""
50
+ id_to_column = {"a": 0, "b": 3}
51
+
52
+ edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: False, _no_bbox)
53
+
54
+ assert edge_to_level == {}
55
+ assert num_levels == 0
56
+
57
+
58
+ # ---- compute_skip_levels: collision-awareness ----
59
+
60
+
61
+ def test_compute_skip_levels_keeps_level_when_intervening_box_collides() -> None:
62
+ """A span>1 edge whose straight line genuinely crosses an intervening same-row box."""
63
+ id_to_column = {"a": 0, "mid": 1, "b": 2}
64
+ bboxes = {
65
+ "a": (0.0, 0.0, 10.0, 10.0),
66
+ "mid": (20.0, 0.0, 30.0, 10.0), # same row (y=0..10) as a and b - directly in the path
67
+ "b": (40.0, 0.0, 50.0, 10.0),
68
+ }
69
+
70
+ edge_to_level, num_levels = compute_skip_levels(
71
+ [("a", "b")],
72
+ id_to_column,
73
+ lambda *_: True,
74
+ lambda node_id: bboxes[node_id],
75
+ )
76
+
77
+ assert num_levels == 1
78
+ assert edge_to_level[("a", "b")] == 0
79
+
80
+
81
+ def test_compute_skip_levels_drops_level_when_intervening_box_is_a_different_row() -> None:
82
+ """A span>1 edge whose straight line passes clear of an intervening box in a different row."""
83
+ id_to_column = {"a": 0, "mid": 1, "b": 2}
84
+ bboxes = {
85
+ "a": (0.0, 0.0, 10.0, 10.0),
86
+ "mid": (20.0, 100.0, 30.0, 110.0), # a different row entirely - well below the line
87
+ "b": (40.0, 0.0, 50.0, 10.0),
88
+ }
89
+
90
+ edge_to_level, num_levels = compute_skip_levels(
91
+ [("a", "b")],
92
+ id_to_column,
93
+ lambda *_: True,
94
+ lambda node_id: bboxes[node_id],
95
+ )
96
+
97
+ assert edge_to_level == {}
98
+ assert num_levels == 0
99
+
100
+
101
+ def test_compute_skip_levels_missing_bbox_conservatively_assumes_collision() -> None:
102
+ """If an intervening id's bbox can't be resolved, assume a collision rather than guessing clear."""
103
+ id_to_column = {"a": 0, "mid": 1, "b": 2}
104
+ bboxes = {"a": (0.0, 0.0, 10.0, 10.0), "b": (40.0, 0.0, 50.0, 10.0)}
105
+
106
+ edge_to_level, num_levels = compute_skip_levels(
107
+ [("a", "b")],
108
+ id_to_column,
109
+ lambda *_: True,
110
+ lambda node_id: bboxes.get(node_id),
111
+ )
112
+
113
+ assert num_levels == 1
114
+ assert edge_to_level[("a", "b")] == 0
115
+
116
+
117
+ def test_compute_skip_levels_missing_endpoint_bbox_conservatively_assumes_collision() -> None:
118
+ """If the edge's own start/end bbox can't be resolved, also assume a collision."""
119
+ id_to_column = {"a": 0, "b": 2}
120
+
121
+ edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: True, _no_bbox)
122
+
123
+ assert num_levels == 1
124
+ assert edge_to_level[("a", "b")] == 0
125
+
126
+
127
+ # ---- _segment_intersects_rect ----
128
+
129
+
130
+ def test_segment_intersects_rect_segment_fully_inside() -> None:
131
+ """A segment entirely within the rect counts as intersecting."""
132
+ assert _segment_intersects_rect(2, 2, 8, 8, 0, 0, 10, 10) is True
133
+
134
+
135
+ def test_segment_intersects_rect_segment_crosses_through_middle() -> None:
136
+ """A segment passing straight through the rect's interior intersects."""
137
+ assert _segment_intersects_rect(-5, 5, 15, 5, 0, 0, 10, 10) is True
138
+
139
+
140
+ def test_segment_intersects_rect_segment_passes_above() -> None:
141
+ """A segment entirely above the rect doesn't intersect."""
142
+ assert _segment_intersects_rect(-5, -20, 15, -20, 0, 0, 10, 10) is False
143
+
144
+
145
+ def test_segment_intersects_rect_segment_passes_below() -> None:
146
+ """A segment entirely below the rect doesn't intersect."""
147
+ assert _segment_intersects_rect(-5, 50, 15, 50, 0, 0, 10, 10) is False
148
+
149
+
150
+ def test_segment_intersects_rect_touches_one_corner() -> None:
151
+ """Touching exactly one corner counts as intersecting (conservative)."""
152
+ assert _segment_intersects_rect(-5, -5, 0, 0, 0, 0, 10, 10) is True
153
+
154
+
155
+ def test_segment_intersects_rect_touches_one_edge() -> None:
156
+ """Touching exactly one edge counts as intersecting (conservative)."""
157
+ assert _segment_intersects_rect(-5, 5, 0, 5, 0, 0, 10, 10) is True
158
+
159
+
160
+ def test_segment_intersects_rect_horizontal_segment_clear() -> None:
161
+ """A horizontal segment clear of the rect doesn't intersect."""
162
+ assert _segment_intersects_rect(-5, -5, 15, -5, 0, 0, 10, 10) is False
163
+
164
+
165
+ def test_segment_intersects_rect_horizontal_segment_through_rect() -> None:
166
+ """A horizontal segment passing through the rect intersects."""
167
+ assert _segment_intersects_rect(-5, 5, 15, 5, 0, 0, 10, 10) is True
168
+
169
+
170
+ def test_segment_intersects_rect_vertical_segment_through_rect() -> None:
171
+ """A vertical segment passing through the rect intersects."""
172
+ assert _segment_intersects_rect(5, -5, 5, 15, 0, 0, 10, 10) is True
173
+
174
+
175
+ def test_segment_intersects_rect_vertical_segment_clear() -> None:
176
+ """A vertical segment clear of the rect doesn't intersect."""
177
+ assert _segment_intersects_rect(50, -5, 50, 15, 0, 0, 10, 10) is False
178
+
179
+
180
+ def test_segment_intersects_rect_degenerate_point_inside() -> None:
181
+ """A zero-length segment (a point) inside the rect counts as intersecting."""
182
+ assert _segment_intersects_rect(5, 5, 5, 5, 0, 0, 10, 10) is True
183
+
184
+
185
+ def test_segment_intersects_rect_degenerate_point_outside() -> None:
186
+ """A zero-length segment (a point) outside the rect doesn't intersect."""
187
+ assert _segment_intersects_rect(50, 50, 50, 50, 0, 0, 10, 10) is False
@@ -174,16 +174,16 @@ def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:
174
174
 
175
175
 
176
176
  @pytest.mark.parametrize("orientation", ["x", "y", "z"])
177
- def test_flow_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
177
+ def test_flow_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
178
178
  """Test flow view on a model with a 1D output, for every supported orientation."""
179
- img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation=orientation)
179
+ img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
180
180
  assert img is not None
181
181
 
182
182
 
183
- def test_flow_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
184
- """An unsupported one_dim_orientation should raise a clear ValueError."""
183
+ def test_flow_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
184
+ """An unsupported low_dim_orientation should raise a clear ValueError."""
185
185
  with pytest.raises(ValueError, match="unsupported orientation"):
186
- flow_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation="bad")
186
+ flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")
187
187
 
188
188
 
189
189
  def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
@@ -232,6 +232,13 @@ def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: n
232
232
  `extract_architecture`/`layout_columns` backend - confirmed via a `git worktree` comparison
233
233
  at the time of the rewrite (see the graph_view equivalent test for why size, not an exact
234
234
  pixel hash: aggdraw's anti-aliasing isn't portable across platforms, but layout math is).
235
+
236
+ Updated after `show_input` defaulted to True (an intentional, deliberate visual change - a
237
+ single-consumer input box is now shown by default, unlike flow_view's original look), so
238
+ these sizes are no longer literally pre-refactor but reflect the new intended default.
239
+
240
+ "legend" was updated again after the synthetic input class was renamed from `InputDummyLayer`
241
+ to `Input`, changing that legend patch's text width by 1px.
235
242
  """
236
243
  cases = {
237
244
  "default": flow_view(sequential_model, input_shape=(1, 3, 32, 32)),
@@ -242,12 +249,12 @@ def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: n
242
249
  "no_funnel": flow_view(sequential_model, input_shape=(1, 3, 32, 32), draw_funnel=False),
243
250
  }
244
251
  expected_sizes = {
245
- "default": (124, 42),
246
- "no_volume": (116, 32),
247
- "show_dimension": (303, 59),
248
- "legend": (124, 136),
249
- "type_ignore": (82, 42),
250
- "no_funnel": (124, 42),
252
+ "default": (144, 42),
253
+ "no_volume": (136, 32),
254
+ "show_dimension": (353, 59),
255
+ "legend": (148, 136),
256
+ "type_ignore": (102, 42),
257
+ "no_funnel": (144, 42),
251
258
  }
252
259
 
253
260
  for name, img in cases.items():
@@ -440,10 +447,13 @@ def test_flow_view_funnels_survive_large_de_differences_between_layers() -> None
440
447
  # pre-rewrite main (1ee630e) via a git-worktree comparison for this exact model. The buggy
441
448
  # intermediate version rendered thousands fewer non-background pixels here (missing funnel
442
449
  # segments), so a wide but real tolerance still catches a regression of that class.
443
- assert img.size == (153, 336)
450
+ #
451
+ # Updated after `show_input` defaulted to True - the input box is now shown, adding a
452
+ # consistent amount of extra canvas/content on top of the original baseline above.
453
+ assert img.size == (171, 364)
444
454
  non_bg = _non_background_pixel_count(img)
445
455
  error_msg = f"non-background pixel count {non_bg} outside expected range - funnel likely broken"
446
- assert 21000 <= non_bg <= 24000, error_msg
456
+ assert 27000 <= non_bg <= 31000, error_msg
447
457
 
448
458
 
449
459
  def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
@@ -451,7 +461,7 @@ def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
451
461
  hiding one would make it ambiguous which arrow originates from which named input.
452
462
  """ # noqa: D205
453
463
  from visualtorch.backend import extract_architecture
454
- from visualtorch.utils.layer_utils import InputDummyLayer
464
+ from visualtorch.utils.layer_utils import Input
455
465
 
456
466
  class TwoInputNet(nn.Module):
457
467
  def __init__(self) -> None:
@@ -464,10 +474,115 @@ def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
464
474
  return self.head(torch.cat([self.a(x), self.b(y)], dim=1))
465
475
 
466
476
  architecture = extract_architecture(TwoInputNet(), ((1, 4), (1, 4)))
467
- input_labels = {
468
- layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, InputDummyLayer)
469
- }
477
+ input_labels = {layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, Input)}
470
478
  assert input_labels == {"input_0", "input_1"}
471
479
 
472
480
  img = flow_view(TwoInputNet(), input_shape=((1, 4), (1, 4)))
473
481
  assert img is not None
482
+
483
+
484
+ def test_flow_view_mismatched_depth_siamese_branches_needs_no_detour() -> None:
485
+ """Sibling branches of different depths merging shouldn't trigger a routed detour."""
486
+
487
+ class SiameseNet(nn.Module):
488
+ def __init__(self) -> None:
489
+ super().__init__()
490
+ self.image_branch = nn.Sequential(
491
+ nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
492
+ nn.ReLU(),
493
+ nn.AdaptiveAvgPool2d(1),
494
+ nn.Flatten(),
495
+ )
496
+ self.vector_branch = nn.Sequential(
497
+ nn.Linear(10, 8),
498
+ nn.ReLU(),
499
+ )
500
+ self.head = nn.Linear(16, 4)
501
+
502
+ def forward(self, image: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
503
+ """Run each branch on its own input tensor, then concatenate and project."""
504
+ image_features = self.image_branch(image)
505
+ vector_features = self.vector_branch(vector)
506
+ merged = torch.cat([image_features, vector_features], dim=1)
507
+ return self.head(merged)
508
+
509
+ class SiameseNetDepthMatched(nn.Module):
510
+ def __init__(self) -> None:
511
+ super().__init__()
512
+ self.image_branch = nn.Sequential(
513
+ nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
514
+ nn.ReLU(),
515
+ nn.AdaptiveAvgPool2d(1),
516
+ nn.Flatten(),
517
+ )
518
+ self.vector_branch = nn.Sequential(
519
+ nn.Linear(10, 8),
520
+ nn.ReLU(),
521
+ nn.Linear(8, 8),
522
+ nn.ReLU(),
523
+ )
524
+ self.head = nn.Linear(16, 4)
525
+
526
+ def forward(self, image: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
527
+ """Run each branch on its own input tensor, then concatenate and project."""
528
+ image_features = self.image_branch(image)
529
+ vector_features = self.vector_branch(vector)
530
+ merged = torch.cat([image_features, vector_features], dim=1)
531
+ return self.head(merged)
532
+
533
+ input_shape = ((1, 3, 16, 16), (1, 10))
534
+ img_mismatched = flow_view(SiameseNet(), input_shape=input_shape)
535
+ img_matched = flow_view(SiameseNetDepthMatched(), input_shape=input_shape)
536
+
537
+ assert img_mismatched.size[1] == img_matched.size[1]
538
+
539
+
540
+ def test_flow_view_low_dim_orientation_affects_2d_shapes() -> None:
541
+ """A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
542
+ low_dim_orientation too, not just genuine 1D shapes.
543
+ """ # noqa: D205
544
+
545
+ class SequenceClassifier(nn.Module):
546
+ def __init__(self, hidden_size: int) -> None:
547
+ super().__init__()
548
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
549
+
550
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
551
+ out, _ = self.lstm(x)
552
+ return out
553
+
554
+ model = SequenceClassifier(hidden_size=64)
555
+ input_shape = (1, 5, 8)
556
+
557
+ sizes = {
558
+ orientation: flow_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
559
+ for orientation in ("x", "y", "z")
560
+ }
561
+
562
+ assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
563
+
564
+
565
+ def test_flow_view_2d_shape_seq_len_is_discarded() -> None:
566
+ """The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
567
+ only the feature-like dim (e.g. hidden_size) should.
568
+ """ # noqa: D205
569
+
570
+ class SequenceClassifier(nn.Module):
571
+ def __init__(self, hidden_size: int) -> None:
572
+ super().__init__()
573
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
574
+
575
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
576
+ out, _ = self.lstm(x)
577
+ return out
578
+
579
+ model = SequenceClassifier(hidden_size=64)
580
+ img_short_seq = flow_view(model, input_shape=(1, 5, 8))
581
+ img_long_seq = flow_view(model, input_shape=(1, 50, 8))
582
+
583
+ assert img_short_seq.tobytes() == img_long_seq.tobytes()
584
+
585
+ model_bigger_hidden = SequenceClassifier(hidden_size=256)
586
+ img_bigger_hidden = flow_view(model_bigger_hidden, input_shape=(1, 5, 8))
587
+
588
+ assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()