visualtorch 1.0.0__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visualtorch-1.0.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO +25 -15
- {visualtorch-1.0.0 → visualtorch-1.2.0}/README.md +23 -14
- {visualtorch-1.0.0 → visualtorch-1.2.0}/setup.py +1 -1
- visualtorch-1.2.0/tests/test_connectors.py +187 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_flow.py +132 -17
- {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_graph.py +65 -39
- {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_lenet_style.py +56 -5
- {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_regression_issues.py +47 -1
- {visualtorch-1.0.0 → visualtorch-1.2.0}/tests/test_render.py +44 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/__init__.py +4 -2
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/_volumetric_layout.py +7 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/backend.py +7 -3
- visualtorch-1.2.0/visualtorch/connectors.py +206 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/flow.py +55 -22
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/graph.py +57 -9
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/lenet_style.py +73 -29
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/render.py +14 -4
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/__init__.py +2 -2
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/layer_utils.py +2 -2
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/recorder.py +40 -17
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/traced_layer.py +6 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch/utils/utils.py +136 -15
- {visualtorch-1.0.0 → visualtorch-1.2.0/visualtorch.egg-info}/PKG-INFO +25 -15
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/SOURCES.txt +1 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/requires.txt +1 -0
- visualtorch-1.0.0/visualtorch/connectors.py +0 -90
- {visualtorch-1.0.0 → visualtorch-1.2.0}/LICENSE +0 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/pyproject.toml +0 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/setup.cfg +0 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-1.0.0 → visualtorch-1.2.0}/visualtorch.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 1.0.0
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -30,6 +30,7 @@ Requires-Dist: sphinx-copybutton; extra == "dev"
|
|
|
30
30
|
Requires-Dist: sphinx_design; extra == "dev"
|
|
31
31
|
Requires-Dist: sphinx_gallery; extra == "dev"
|
|
32
32
|
Requires-Dist: matplotlib; extra == "dev"
|
|
33
|
+
Requires-Dist: torchvision; extra == "dev"
|
|
33
34
|
Requires-Dist: pre-commit; extra == "dev"
|
|
34
35
|
Requires-Dist: pytest; extra == "dev"
|
|
35
36
|
Dynamic: author
|
|
@@ -53,19 +54,22 @@ Dynamic: summary
|
|
|
53
54
|
|
|
54
55
|
</div>
|
|
55
56
|
|
|
56
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models.
|
|
57
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
|
|
57
58
|
|
|
58
|
-
**Note:**
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
59
|
+
**Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
|
|
60
|
+
|
|
61
|
+
**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
|
|
62
|
+
limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
|
|
63
|
+
execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
|
|
64
|
+
value crosses some threshold) only show whichever branch the traced dummy input happened to take.
|
|
65
|
+
Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
|
|
66
|
+
head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
|
|
67
|
+
tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
|
|
68
|
+
the first. Downstream connections are correct either way. Contributions are welcome!
|
|
65
69
|
|
|
66
70
|
<div align="center">
|
|
67
71
|
|
|
68
|
-

|
|
72
|
+

|
|
69
73
|
|
|
70
74
|
</div>
|
|
71
75
|
|
|
@@ -79,6 +83,12 @@ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/u
|
|
|
79
83
|
|
|
80
84
|
See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
|
|
81
85
|
|
|
86
|
+
## Used in Research
|
|
87
|
+
|
|
88
|
+
VisualTorch has been used in published research, including works published in Nature, IEEE, and MDPI.
|
|
89
|
+
|
|
90
|
+
See the [Research Showcase page](https://visualtorch.readthedocs.io/en/latest/markdown/showcase/index.html) for the full list.
|
|
91
|
+
|
|
82
92
|
## Examples
|
|
83
93
|
|
|
84
94
|
See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
|
|
@@ -91,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi
|
|
|
91
101
|
|
|
92
102
|
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
|
|
93
103
|
|
|
94
|
-
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz),
|
|
104
|
+
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
|
|
95
105
|
|
|
96
106
|
## Citation
|
|
97
107
|
|
|
98
108
|
Please cite this project in your publications if it helps your research.
|
|
99
109
|
|
|
100
|
-
**Note:** the paper below describes
|
|
101
|
-
since
|
|
102
|
-
for the current API) - the DOI
|
|
103
|
-
|
|
110
|
+
**Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
|
|
111
|
+
since been substantially refactored, including breaking API changes (see the
|
|
112
|
+
[documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
|
|
113
|
+
always resolves to what was actually reviewed and published.
|
|
104
114
|
|
|
105
115
|
```bibtex
|
|
106
116
|
@article{Hendria2024,
|
|
@@ -5,19 +5,22 @@
|
|
|
5
5
|
|
|
6
6
|
</div>
|
|
7
7
|
|
|
8
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models.
|
|
8
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
|
|
9
9
|
|
|
10
|
-
**Note:**
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
**Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
|
|
11
|
+
|
|
12
|
+
**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
|
|
13
|
+
limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
|
|
14
|
+
execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
|
|
15
|
+
value crosses some threshold) only show whichever branch the traced dummy input happened to take.
|
|
16
|
+
Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
|
|
17
|
+
head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
|
|
18
|
+
tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
|
|
19
|
+
the first. Downstream connections are correct either way. Contributions are welcome!
|
|
17
20
|
|
|
18
21
|
<div align="center">
|
|
19
22
|
|
|
20
|
-

|
|
23
|
+

|
|
21
24
|
|
|
22
25
|
</div>
|
|
23
26
|
|
|
@@ -31,6 +34,12 @@ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/u
|
|
|
31
34
|
|
|
32
35
|
See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
|
|
33
36
|
|
|
37
|
+
## Used in Research
|
|
38
|
+
|
|
39
|
+
VisualTorch has been used in published research, including works published in Nature, IEEE, and MDPI.
|
|
40
|
+
|
|
41
|
+
See the [Research Showcase page](https://visualtorch.readthedocs.io/en/latest/markdown/showcase/index.html) for the full list.
|
|
42
|
+
|
|
34
43
|
## Examples
|
|
35
44
|
|
|
36
45
|
See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
|
|
@@ -43,16 +52,16 @@ Please feel free to send a pull request to contribute to this project by followi
|
|
|
43
52
|
|
|
44
53
|
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
|
|
45
54
|
|
|
46
|
-
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz),
|
|
55
|
+
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
|
|
47
56
|
|
|
48
57
|
## Citation
|
|
49
58
|
|
|
50
59
|
Please cite this project in your publications if it helps your research.
|
|
51
60
|
|
|
52
|
-
**Note:** the paper below describes
|
|
53
|
-
since
|
|
54
|
-
for the current API) - the DOI
|
|
55
|
-
|
|
61
|
+
**Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
|
|
62
|
+
since been substantially refactored, including breaking API changes (see the
|
|
63
|
+
[documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
|
|
64
|
+
always resolves to what was actually reviewed and published.
|
|
56
65
|
|
|
57
66
|
```bibtex
|
|
58
67
|
@article{Hendria2024,
|
|
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
|
|
|
21
21
|
|
|
22
22
|
setuptools.setup(
|
|
23
23
|
name="visualtorch",
|
|
24
|
-
version="1.
|
|
24
|
+
version="1.2.0",
|
|
25
25
|
author="Willy Fitra Hendria",
|
|
26
26
|
author_email="willyfitrahendria@gmail.com",
|
|
27
27
|
description="Architecture visualization of Torch models",
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Tests for the frontend-agnostic connector-routing helpers."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from visualtorch.connectors import _segment_intersects_rect, compute_skip_levels
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _no_bbox(_node_id: str) -> tuple[float, float, float, float] | None:
|
|
10
|
+
return None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# ---- compute_skip_levels: span/content gating ----
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_compute_skip_levels_ignores_span_le_1_edges() -> None:
|
|
17
|
+
"""An edge with column span <= 1 should never be assigned a detour level."""
|
|
18
|
+
id_to_column = {"a": 0, "b": 1}
|
|
19
|
+
|
|
20
|
+
edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: True, _no_bbox)
|
|
21
|
+
|
|
22
|
+
assert edge_to_level == {}
|
|
23
|
+
assert num_levels == 0
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_compute_skip_levels_assigns_distinct_levels_to_overlapping_skips() -> None:
|
|
27
|
+
"""Two skip edges whose column spans genuinely overlap must get different levels."""
|
|
28
|
+
id_to_column = {"a": 0, "b": 3, "c": 2, "d": 5}
|
|
29
|
+
edges = [("a", "b"), ("c", "d")]
|
|
30
|
+
|
|
31
|
+
edge_to_level, num_levels = compute_skip_levels(edges, id_to_column, lambda *_: True, _no_bbox)
|
|
32
|
+
|
|
33
|
+
assert num_levels == 2
|
|
34
|
+
assert edge_to_level[("a", "b")] != edge_to_level[("c", "d")]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_compute_skip_levels_allows_touching_intervals_to_share_a_level() -> None:
|
|
38
|
+
"""Two skip edges that only touch at a shared column boundary can share a level."""
|
|
39
|
+
id_to_column = {"a": 0, "b": 3, "c": 3, "d": 5}
|
|
40
|
+
edges = [("a", "b"), ("c", "d")]
|
|
41
|
+
|
|
42
|
+
edge_to_level, num_levels = compute_skip_levels(edges, id_to_column, lambda *_: True, _no_bbox)
|
|
43
|
+
|
|
44
|
+
assert num_levels == 1
|
|
45
|
+
assert edge_to_level[("a", "b")] == edge_to_level[("c", "d")] == 0
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_compute_skip_levels_ignores_edges_with_no_content() -> None:
|
|
49
|
+
"""A skip edge that `edge_has_content` reports as empty shouldn't consume a level."""
|
|
50
|
+
id_to_column = {"a": 0, "b": 3}
|
|
51
|
+
|
|
52
|
+
edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: False, _no_bbox)
|
|
53
|
+
|
|
54
|
+
assert edge_to_level == {}
|
|
55
|
+
assert num_levels == 0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ---- compute_skip_levels: collision-awareness ----
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_compute_skip_levels_keeps_level_when_intervening_box_collides() -> None:
|
|
62
|
+
"""A span>1 edge whose straight line genuinely crosses an intervening same-row box."""
|
|
63
|
+
id_to_column = {"a": 0, "mid": 1, "b": 2}
|
|
64
|
+
bboxes = {
|
|
65
|
+
"a": (0.0, 0.0, 10.0, 10.0),
|
|
66
|
+
"mid": (20.0, 0.0, 30.0, 10.0), # same row (y=0..10) as a and b - directly in the path
|
|
67
|
+
"b": (40.0, 0.0, 50.0, 10.0),
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
edge_to_level, num_levels = compute_skip_levels(
|
|
71
|
+
[("a", "b")],
|
|
72
|
+
id_to_column,
|
|
73
|
+
lambda *_: True,
|
|
74
|
+
lambda node_id: bboxes[node_id],
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
assert num_levels == 1
|
|
78
|
+
assert edge_to_level[("a", "b")] == 0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_compute_skip_levels_drops_level_when_intervening_box_is_a_different_row() -> None:
|
|
82
|
+
"""A span>1 edge whose straight line passes clear of an intervening box in a different row."""
|
|
83
|
+
id_to_column = {"a": 0, "mid": 1, "b": 2}
|
|
84
|
+
bboxes = {
|
|
85
|
+
"a": (0.0, 0.0, 10.0, 10.0),
|
|
86
|
+
"mid": (20.0, 100.0, 30.0, 110.0), # a different row entirely - well below the line
|
|
87
|
+
"b": (40.0, 0.0, 50.0, 10.0),
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
edge_to_level, num_levels = compute_skip_levels(
|
|
91
|
+
[("a", "b")],
|
|
92
|
+
id_to_column,
|
|
93
|
+
lambda *_: True,
|
|
94
|
+
lambda node_id: bboxes[node_id],
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
assert edge_to_level == {}
|
|
98
|
+
assert num_levels == 0
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_compute_skip_levels_missing_bbox_conservatively_assumes_collision() -> None:
|
|
102
|
+
"""If an intervening id's bbox can't be resolved, assume a collision rather than guessing clear."""
|
|
103
|
+
id_to_column = {"a": 0, "mid": 1, "b": 2}
|
|
104
|
+
bboxes = {"a": (0.0, 0.0, 10.0, 10.0), "b": (40.0, 0.0, 50.0, 10.0)}
|
|
105
|
+
|
|
106
|
+
edge_to_level, num_levels = compute_skip_levels(
|
|
107
|
+
[("a", "b")],
|
|
108
|
+
id_to_column,
|
|
109
|
+
lambda *_: True,
|
|
110
|
+
lambda node_id: bboxes.get(node_id),
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
assert num_levels == 1
|
|
114
|
+
assert edge_to_level[("a", "b")] == 0
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_compute_skip_levels_missing_endpoint_bbox_conservatively_assumes_collision() -> None:
|
|
118
|
+
"""If the edge's own start/end bbox can't be resolved, also assume a collision."""
|
|
119
|
+
id_to_column = {"a": 0, "b": 2}
|
|
120
|
+
|
|
121
|
+
edge_to_level, num_levels = compute_skip_levels([("a", "b")], id_to_column, lambda *_: True, _no_bbox)
|
|
122
|
+
|
|
123
|
+
assert num_levels == 1
|
|
124
|
+
assert edge_to_level[("a", "b")] == 0
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# ---- _segment_intersects_rect ----
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_segment_intersects_rect_segment_fully_inside() -> None:
|
|
131
|
+
"""A segment entirely within the rect counts as intersecting."""
|
|
132
|
+
assert _segment_intersects_rect(2, 2, 8, 8, 0, 0, 10, 10) is True
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def test_segment_intersects_rect_segment_crosses_through_middle() -> None:
|
|
136
|
+
"""A segment passing straight through the rect's interior intersects."""
|
|
137
|
+
assert _segment_intersects_rect(-5, 5, 15, 5, 0, 0, 10, 10) is True
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_segment_intersects_rect_segment_passes_above() -> None:
|
|
141
|
+
"""A segment entirely above the rect doesn't intersect."""
|
|
142
|
+
assert _segment_intersects_rect(-5, -20, 15, -20, 0, 0, 10, 10) is False
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_segment_intersects_rect_segment_passes_below() -> None:
|
|
146
|
+
"""A segment entirely below the rect doesn't intersect."""
|
|
147
|
+
assert _segment_intersects_rect(-5, 50, 15, 50, 0, 0, 10, 10) is False
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def test_segment_intersects_rect_touches_one_corner() -> None:
|
|
151
|
+
"""Touching exactly one corner counts as intersecting (conservative)."""
|
|
152
|
+
assert _segment_intersects_rect(-5, -5, 0, 0, 0, 0, 10, 10) is True
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def test_segment_intersects_rect_touches_one_edge() -> None:
|
|
156
|
+
"""Touching exactly one edge counts as intersecting (conservative)."""
|
|
157
|
+
assert _segment_intersects_rect(-5, 5, 0, 5, 0, 0, 10, 10) is True
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def test_segment_intersects_rect_horizontal_segment_clear() -> None:
|
|
161
|
+
"""A horizontal segment clear of the rect doesn't intersect."""
|
|
162
|
+
assert _segment_intersects_rect(-5, -5, 15, -5, 0, 0, 10, 10) is False
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def test_segment_intersects_rect_horizontal_segment_through_rect() -> None:
|
|
166
|
+
"""A horizontal segment passing through the rect intersects."""
|
|
167
|
+
assert _segment_intersects_rect(-5, 5, 15, 5, 0, 0, 10, 10) is True
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def test_segment_intersects_rect_vertical_segment_through_rect() -> None:
|
|
171
|
+
"""A vertical segment passing through the rect intersects."""
|
|
172
|
+
assert _segment_intersects_rect(5, -5, 5, 15, 0, 0, 10, 10) is True
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def test_segment_intersects_rect_vertical_segment_clear() -> None:
|
|
176
|
+
"""A vertical segment clear of the rect doesn't intersect."""
|
|
177
|
+
assert _segment_intersects_rect(50, -5, 50, 15, 0, 0, 10, 10) is False
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def test_segment_intersects_rect_degenerate_point_inside() -> None:
|
|
181
|
+
"""A zero-length segment (a point) inside the rect counts as intersecting."""
|
|
182
|
+
assert _segment_intersects_rect(5, 5, 5, 5, 0, 0, 10, 10) is True
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_segment_intersects_rect_degenerate_point_outside() -> None:
|
|
186
|
+
"""A zero-length segment (a point) outside the rect doesn't intersect."""
|
|
187
|
+
assert _segment_intersects_rect(50, 50, 50, 50, 0, 0, 10, 10) is False
|
|
@@ -174,16 +174,16 @@ def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:
|
|
|
174
174
|
|
|
175
175
|
|
|
176
176
|
@pytest.mark.parametrize("orientation", ["x", "y", "z"])
|
|
177
|
-
def
|
|
177
|
+
def test_flow_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
|
|
178
178
|
"""Test flow view on a model with a 1D output, for every supported orientation."""
|
|
179
|
-
img = flow_view(classifier_model, input_shape=(1, 3, 16, 16),
|
|
179
|
+
img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
|
|
180
180
|
assert img is not None
|
|
181
181
|
|
|
182
182
|
|
|
183
|
-
def
|
|
184
|
-
"""An unsupported
|
|
183
|
+
def test_flow_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
|
|
184
|
+
"""An unsupported low_dim_orientation should raise a clear ValueError."""
|
|
185
185
|
with pytest.raises(ValueError, match="unsupported orientation"):
|
|
186
|
-
flow_view(classifier_model, input_shape=(1, 3, 16, 16),
|
|
186
|
+
flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")
|
|
187
187
|
|
|
188
188
|
|
|
189
189
|
def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
|
|
@@ -232,6 +232,13 @@ def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: n
|
|
|
232
232
|
`extract_architecture`/`layout_columns` backend - confirmed via a `git worktree` comparison
|
|
233
233
|
at the time of the rewrite (see the graph_view equivalent test for why size, not an exact
|
|
234
234
|
pixel hash: aggdraw's anti-aliasing isn't portable across platforms, but layout math is).
|
|
235
|
+
|
|
236
|
+
Updated after `show_input` defaulted to True (an intentional, deliberate visual change - a
|
|
237
|
+
single-consumer input box is now shown by default, unlike flow_view's original look), so
|
|
238
|
+
these sizes are no longer literally pre-refactor but reflect the new intended default.
|
|
239
|
+
|
|
240
|
+
"legend" was updated again after the synthetic input class was renamed from `InputDummyLayer`
|
|
241
|
+
to `Input`, changing that legend patch's text width by 1px.
|
|
235
242
|
"""
|
|
236
243
|
cases = {
|
|
237
244
|
"default": flow_view(sequential_model, input_shape=(1, 3, 32, 32)),
|
|
@@ -242,12 +249,12 @@ def test_flow_view_output_size_matches_pre_refactor_baseline(sequential_model: n
|
|
|
242
249
|
"no_funnel": flow_view(sequential_model, input_shape=(1, 3, 32, 32), draw_funnel=False),
|
|
243
250
|
}
|
|
244
251
|
expected_sizes = {
|
|
245
|
-
"default": (
|
|
246
|
-
"no_volume": (
|
|
247
|
-
"show_dimension": (
|
|
248
|
-
"legend": (
|
|
249
|
-
"type_ignore": (
|
|
250
|
-
"no_funnel": (
|
|
252
|
+
"default": (144, 42),
|
|
253
|
+
"no_volume": (136, 32),
|
|
254
|
+
"show_dimension": (353, 59),
|
|
255
|
+
"legend": (148, 136),
|
|
256
|
+
"type_ignore": (102, 42),
|
|
257
|
+
"no_funnel": (144, 42),
|
|
251
258
|
}
|
|
252
259
|
|
|
253
260
|
for name, img in cases.items():
|
|
@@ -440,10 +447,13 @@ def test_flow_view_funnels_survive_large_de_differences_between_layers() -> None
|
|
|
440
447
|
# pre-rewrite main (1ee630e) via a git-worktree comparison for this exact model. The buggy
|
|
441
448
|
# intermediate version rendered thousands fewer non-background pixels here (missing funnel
|
|
442
449
|
# segments), so a wide but real tolerance still catches a regression of that class.
|
|
443
|
-
|
|
450
|
+
#
|
|
451
|
+
# Updated after `show_input` defaulted to True - the input box is now shown, adding a
|
|
452
|
+
# consistent amount of extra canvas/content on top of the original baseline above.
|
|
453
|
+
assert img.size == (171, 364)
|
|
444
454
|
non_bg = _non_background_pixel_count(img)
|
|
445
455
|
error_msg = f"non-background pixel count {non_bg} outside expected range - funnel likely broken"
|
|
446
|
-
assert
|
|
456
|
+
assert 27000 <= non_bg <= 31000, error_msg
|
|
447
457
|
|
|
448
458
|
|
|
449
459
|
def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
|
|
@@ -451,7 +461,7 @@ def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
|
|
|
451
461
|
hiding one would make it ambiguous which arrow originates from which named input.
|
|
452
462
|
""" # noqa: D205
|
|
453
463
|
from visualtorch.backend import extract_architecture
|
|
454
|
-
from visualtorch.utils.layer_utils import
|
|
464
|
+
from visualtorch.utils.layer_utils import Input
|
|
455
465
|
|
|
456
466
|
class TwoInputNet(nn.Module):
|
|
457
467
|
def __init__(self) -> None:
|
|
@@ -464,10 +474,115 @@ def test_flow_view_shows_all_input_boxes_for_multi_input_model() -> None:
|
|
|
464
474
|
return self.head(torch.cat([self.a(x), self.b(y)], dim=1))
|
|
465
475
|
|
|
466
476
|
architecture = extract_architecture(TwoInputNet(), ((1, 4), (1, 4)))
|
|
467
|
-
input_labels = {
|
|
468
|
-
layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, InputDummyLayer)
|
|
469
|
-
}
|
|
477
|
+
input_labels = {layer.module.name() for layer in architecture.columns[0] if isinstance(layer.module, Input)}
|
|
470
478
|
assert input_labels == {"input_0", "input_1"}
|
|
471
479
|
|
|
472
480
|
img = flow_view(TwoInputNet(), input_shape=((1, 4), (1, 4)))
|
|
473
481
|
assert img is not None
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def test_flow_view_mismatched_depth_siamese_branches_needs_no_detour() -> None:
|
|
485
|
+
"""Sibling branches of different depths merging shouldn't trigger a routed detour."""
|
|
486
|
+
|
|
487
|
+
class SiameseNet(nn.Module):
|
|
488
|
+
def __init__(self) -> None:
|
|
489
|
+
super().__init__()
|
|
490
|
+
self.image_branch = nn.Sequential(
|
|
491
|
+
nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
|
|
492
|
+
nn.ReLU(),
|
|
493
|
+
nn.AdaptiveAvgPool2d(1),
|
|
494
|
+
nn.Flatten(),
|
|
495
|
+
)
|
|
496
|
+
self.vector_branch = nn.Sequential(
|
|
497
|
+
nn.Linear(10, 8),
|
|
498
|
+
nn.ReLU(),
|
|
499
|
+
)
|
|
500
|
+
self.head = nn.Linear(16, 4)
|
|
501
|
+
|
|
502
|
+
def forward(self, image: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
|
503
|
+
"""Run each branch on its own input tensor, then concatenate and project."""
|
|
504
|
+
image_features = self.image_branch(image)
|
|
505
|
+
vector_features = self.vector_branch(vector)
|
|
506
|
+
merged = torch.cat([image_features, vector_features], dim=1)
|
|
507
|
+
return self.head(merged)
|
|
508
|
+
|
|
509
|
+
class SiameseNetDepthMatched(nn.Module):
|
|
510
|
+
def __init__(self) -> None:
|
|
511
|
+
super().__init__()
|
|
512
|
+
self.image_branch = nn.Sequential(
|
|
513
|
+
nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
|
|
514
|
+
nn.ReLU(),
|
|
515
|
+
nn.AdaptiveAvgPool2d(1),
|
|
516
|
+
nn.Flatten(),
|
|
517
|
+
)
|
|
518
|
+
self.vector_branch = nn.Sequential(
|
|
519
|
+
nn.Linear(10, 8),
|
|
520
|
+
nn.ReLU(),
|
|
521
|
+
nn.Linear(8, 8),
|
|
522
|
+
nn.ReLU(),
|
|
523
|
+
)
|
|
524
|
+
self.head = nn.Linear(16, 4)
|
|
525
|
+
|
|
526
|
+
def forward(self, image: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
|
527
|
+
"""Run each branch on its own input tensor, then concatenate and project."""
|
|
528
|
+
image_features = self.image_branch(image)
|
|
529
|
+
vector_features = self.vector_branch(vector)
|
|
530
|
+
merged = torch.cat([image_features, vector_features], dim=1)
|
|
531
|
+
return self.head(merged)
|
|
532
|
+
|
|
533
|
+
input_shape = ((1, 3, 16, 16), (1, 10))
|
|
534
|
+
img_mismatched = flow_view(SiameseNet(), input_shape=input_shape)
|
|
535
|
+
img_matched = flow_view(SiameseNetDepthMatched(), input_shape=input_shape)
|
|
536
|
+
|
|
537
|
+
assert img_mismatched.size[1] == img_matched.size[1]
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def test_flow_view_low_dim_orientation_affects_2d_shapes() -> None:
|
|
541
|
+
"""A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
|
|
542
|
+
low_dim_orientation too, not just genuine 1D shapes.
|
|
543
|
+
""" # noqa: D205
|
|
544
|
+
|
|
545
|
+
class SequenceClassifier(nn.Module):
|
|
546
|
+
def __init__(self, hidden_size: int) -> None:
|
|
547
|
+
super().__init__()
|
|
548
|
+
self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
|
|
549
|
+
|
|
550
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
551
|
+
out, _ = self.lstm(x)
|
|
552
|
+
return out
|
|
553
|
+
|
|
554
|
+
model = SequenceClassifier(hidden_size=64)
|
|
555
|
+
input_shape = (1, 5, 8)
|
|
556
|
+
|
|
557
|
+
sizes = {
|
|
558
|
+
orientation: flow_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
|
|
559
|
+
for orientation in ("x", "y", "z")
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def test_flow_view_2d_shape_seq_len_is_discarded() -> None:
|
|
566
|
+
"""The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
|
|
567
|
+
only the feature-like dim (e.g. hidden_size) should.
|
|
568
|
+
""" # noqa: D205
|
|
569
|
+
|
|
570
|
+
class SequenceClassifier(nn.Module):
|
|
571
|
+
def __init__(self, hidden_size: int) -> None:
|
|
572
|
+
super().__init__()
|
|
573
|
+
self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
|
|
574
|
+
|
|
575
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
576
|
+
out, _ = self.lstm(x)
|
|
577
|
+
return out
|
|
578
|
+
|
|
579
|
+
model = SequenceClassifier(hidden_size=64)
|
|
580
|
+
img_short_seq = flow_view(model, input_shape=(1, 5, 8))
|
|
581
|
+
img_long_seq = flow_view(model, input_shape=(1, 50, 8))
|
|
582
|
+
|
|
583
|
+
assert img_short_seq.tobytes() == img_long_seq.tobytes()
|
|
584
|
+
|
|
585
|
+
model_bigger_hidden = SequenceClassifier(hidden_size=256)
|
|
586
|
+
img_bigger_hidden = flow_view(model_bigger_hidden, input_shape=(1, 5, 8))
|
|
587
|
+
|
|
588
|
+
assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()
|