visualtorch 1.1.0__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visualtorch-1.1.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO +16 -15
- {visualtorch-1.1.0 → visualtorch-1.2.0}/README.md +15 -14
- {visualtorch-1.1.0 → visualtorch-1.2.0}/setup.py +1 -1
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_flow.py +56 -5
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_lenet_style.py +56 -5
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_regression_issues.py +47 -1
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_render.py +44 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/__init__.py +2 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/backend.py +5 -1
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/flow.py +28 -17
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/graph.py +9 -3
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/lenet_style.py +37 -18
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/render.py +8 -4
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/recorder.py +40 -17
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/traced_layer.py +6 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/utils.py +123 -15
- {visualtorch-1.1.0 → visualtorch-1.2.0/visualtorch.egg-info}/PKG-INFO +16 -15
- {visualtorch-1.1.0 → visualtorch-1.2.0}/LICENSE +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/pyproject.toml +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/setup.cfg +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_connectors.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_graph.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/_volumetric_layout.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/connectors.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/__init__.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/layer_utils.py +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/SOURCES.txt +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/requires.txt +0 -0
- {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/top_level.txt +0 -0

{visualtorch-1.1.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: visualtorch
-Version: 1.
+Version: 1.2.0
 Summary: Architecture visualization of Torch models
 Home-page: https://github.com/willyfh/visualtorch
 Author: Willy Fitra Hendria
@@ -54,21 +54,22 @@ Dynamic: summary

 </div>

-**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models.
+**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.

 **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.

-**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has
-
-execution):
-value crosses some threshold) only show whichever branch the traced dummy input happened to take
-
-
-
+**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
+limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
+execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
+value crosses some threshold) only show whichever branch the traced dummy input happened to take.
+Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
+head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
+tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
+the first. Downstream connections are correct either way. Contributions are welcome!

 <div align="center">

-
+

 </div>

@@ -100,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi

 This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).

-Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz),
+Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.

 ## Citation

 Please cite this project in your publications if it helps your research.

-**Note:** the paper below describes
-since
-for the current API) - the DOI
-
+**Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
+since been substantially refactored, including breaking API changes (see the
+[documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
+always resolves to what was actually reviewed and published.

 ```bibtex
 @article{Hendria2024,

{visualtorch-1.1.0 → visualtorch-1.2.0}/README.md
@@ -5,21 +5,22 @@

 </div>

-**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models.
+**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.

 **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.

-**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has
-
-execution):
-value crosses some threshold) only show whichever branch the traced dummy input happened to take
-
-
-
+**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
+limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
+execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
+value crosses some threshold) only show whichever branch the traced dummy input happened to take.
+Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
+head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
+tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
+the first. Downstream connections are correct either way. Contributions are welcome!

 <div align="center">

-
+

 </div>

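
The **Limitation** note says a traced diagram only contains the branch that the dummy input actually took. The toy sketch below shows why, without needing PyTorch. `Recorder`, `build_model`, and `trace` are hypothetical names. The recorder only imitates what forward hooks do, which is to log a layer when it runs. It is not VisualTorch's tracer.

```python
from collections.abc import Callable


class Recorder:
    """Toy stand-in for forward hooks: logs each "layer" that actually executes."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def layer(self, name: str, fn: Callable[[float], float]) -> Callable[[float], float]:
        def wrapped(x: float) -> float:
            self.calls.append(name)  # recorded only if this layer really runs
            return fn(x)

        return wrapped


def build_model(rec: Recorder) -> Callable[[float], float]:
    encoder = rec.layer("encoder", lambda x: 2.0 * x)
    positive_head = rec.layer("positive_head", lambda x: x + 1.0)
    negative_head = rec.layer("negative_head", lambda x: x - 1.0)

    def forward(x: float) -> float:
        h = encoder(x)
        # Data-dependent control flow: which head runs depends on the value of h.
        return positive_head(h) if h > 0 else negative_head(h)

    return forward


def trace(dummy_input: float) -> list[str]:
    """Run one forward pass and return the layers the recorder saw."""
    rec = Recorder()
    build_model(rec)(dummy_input)
    return rec.calls


print(trace(1.0))   # ['encoder', 'positive_head']
print(trace(-1.0))  # ['encoder', 'negative_head']
```

Each dummy input exercises only one head, so a diagram built from either trace would be missing the other head. Covering both branches would require symbolic execution instead of a single concrete forward pass.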
@@ -51,16 +52,16 @@ Please feel free to send a pull request to contribute to this project by followi

 This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).

-Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz),
+Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.

 ## Citation

 Please cite this project in your publications if it helps your research.

-**Note:** the paper below describes
-since
-for the current API) - the DOI
-
+**Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
+since been substantially refactored, including breaking API changes (see the
+[documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
+always resolves to what was actually reviewed and published.

 ```bibtex
 @article{Hendria2024,

{visualtorch-1.1.0 → visualtorch-1.2.0}/setup.py
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:

 setuptools.setup(
     name="visualtorch",
-    version="1.
+    version="1.2.0",
     author="Willy Fitra Hendria",
     author_email="willyfitrahendria@gmail.com",
     description="Architecture visualization of Torch models",

{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_flow.py
@@ -174,16 +174,16 @@ def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:


 @pytest.mark.parametrize("orientation", ["x", "y", "z"])
-def
+def test_flow_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
     """Test flow view on a model with a 1D output, for every supported orientation."""
-    img = flow_view(classifier_model, input_shape=(1, 3, 16, 16),
+    img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
     assert img is not None


-def
-    """An unsupported
+def test_flow_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
+    """An unsupported low_dim_orientation should raise a clear ValueError."""
     with pytest.raises(ValueError, match="unsupported orientation"):
-        flow_view(classifier_model, input_shape=(1, 3, 16, 16),
+        flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")


 def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
@@ -535,3 +535,54 @@ def test_flow_view_mismatched_depth_siamese_branches_needs_no_detour() -> None:
     img_matched = flow_view(SiameseNetDepthMatched(), input_shape=input_shape)

     assert img_mismatched.size[1] == img_matched.size[1]
+
+
+def test_flow_view_low_dim_orientation_affects_2d_shapes() -> None:
+    """A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
+    low_dim_orientation too, not just genuine 1D shapes.
+    """ # noqa: D205
+
+    class SequenceClassifier(nn.Module):
+        def __init__(self, hidden_size: int) -> None:
+            super().__init__()
+            self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            out, _ = self.lstm(x)
+            return out
+
+    model = SequenceClassifier(hidden_size=64)
+    input_shape = (1, 5, 8)
+
+    sizes = {
+        orientation: flow_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
+        for orientation in ("x", "y", "z")
+    }
+
+    assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
+
+
+def test_flow_view_2d_shape_seq_len_is_discarded() -> None:
+    """The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
+    only the feature-like dim (e.g. hidden_size) should.
+    """ # noqa: D205
+
+    class SequenceClassifier(nn.Module):
+        def __init__(self, hidden_size: int) -> None:
+            super().__init__()
+            self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            out, _ = self.lstm(x)
+            return out
+
+    model = SequenceClassifier(hidden_size=64)
+    img_short_seq = flow_view(model, input_shape=(1, 5, 8))
+    img_long_seq = flow_view(model, input_shape=(1, 50, 8))
+
+    assert img_short_seq.tobytes() == img_long_seq.tobytes()
+
+    model_bigger_hidden = SequenceClassifier(hidden_size=256)
+    img_bigger_hidden = flow_view(model_bigger_hidden, input_shape=(1, 5, 8))
+
+    assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()

{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_lenet_style.py
@@ -174,16 +174,16 @@ def test_rnn_model_lenet_view_runs(rnn_model: nn.Module) -> None:


 @pytest.mark.parametrize("orientation", ["x", "y", "z"])
-def
+def test_lenet_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
     """Test lenet view on a model with a 1D output, for every supported orientation."""
-    img = lenet_view(classifier_model, input_shape=(1, 3, 16, 16),
+    img = lenet_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
     assert img is not None


-def
-    """An unsupported
+def test_lenet_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
+    """An unsupported low_dim_orientation should raise a clear ValueError."""
     with pytest.raises(ValueError, match="unsupported orientation"):
-        lenet_view(classifier_model, input_shape=(1, 3, 16, 16),
+        lenet_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")


 def test_lenet_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
@@ -328,3 +328,54 @@ def test_lenet_view_funnels_survive_large_de_differences_between_layers() -> Non
     non_bg = int((np.array(img.convert("RGB")) != 255).any(axis=2).sum())
     error_msg = f"non-background pixel count {non_bg} outside expected range - funnel likely broken"
     assert 110000 <= non_bg <= 145000, error_msg
+
+
+def test_lenet_view_low_dim_orientation_affects_2d_shapes() -> None:
+    """A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
+    low_dim_orientation too, not just genuine 1D shapes.
+    """ # noqa: D205
+
+    class SequenceClassifier(nn.Module):
+        def __init__(self, hidden_size: int) -> None:
+            super().__init__()
+            self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            out, _ = self.lstm(x)
+            return out
+
+    model = SequenceClassifier(hidden_size=64)
+    input_shape = (1, 5, 8)
+
+    sizes = {
+        orientation: lenet_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
+        for orientation in ("x", "y", "z")
+    }
+
+    assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
+
+
+def test_lenet_view_2d_shape_seq_len_is_discarded() -> None:
+    """The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
+    only the feature-like dim (e.g. hidden_size) should.
+    """ # noqa: D205
+
+    class SequenceClassifier(nn.Module):
+        def __init__(self, hidden_size: int) -> None:
+            super().__init__()
+            self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            out, _ = self.lstm(x)
+            return out
+
+    model = SequenceClassifier(hidden_size=64)
+    img_short_seq = lenet_view(model, input_shape=(1, 5, 8))
+    img_long_seq = lenet_view(model, input_shape=(1, 50, 8))
+
+    assert img_short_seq.tobytes() == img_long_seq.tobytes()
+
+    model_bigger_hidden = SequenceClassifier(hidden_size=256)
+    img_bigger_hidden = lenet_view(model_bigger_hidden, input_shape=(1, 5, 8))
+
+    assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()

{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_regression_issues.py
@@ -11,8 +11,9 @@ import torch
 from torch import nn
 from visualtorch.backend import extract_architecture
 from visualtorch.flow import flow_view
+from visualtorch.graph import graph_view
 from visualtorch.lenet_style import lenet_view
-from visualtorch.utils.utils import self_multiply
+from visualtorch.utils.utils import format_shape_label, self_multiply


 @pytest.fixture()
@@ -112,3 +113,48 @@ def test_flow_view_recurrent_sequence_length_does_not_inflate_diagram_height(rec
     long_img = flow_view(model, input_shape=(1, 200, 10))

     assert long_img.height == short_img.height
+
+
+@pytest.fixture()
+def lstm_using_hidden_state_model() -> nn.Module:
+    """A model that consumes `nn.LSTM`'s hidden state (h_n), not its sequence output.
+
+    `nn.LSTM.forward()` returns `(output, (h_n, c_n))` - three tensors, not one. Before the fix,
+    only `output`'s shape was ever recorded, even when (as here) the model actually uses `h_n`
+    instead - silently dropping the shape of the tensor that matters, with nothing shown to
+    indicate more than one tensor even exists.
+    """
+
+    class Model(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
+            self.fc = nn.Linear(20, 5)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            _output, (h_n, _c_n) = self.lstm(x)
+            return self.fc(h_n.squeeze(0))
+
+    return Model()
+
+
+def test_extract_architecture_records_every_output_tensor_shape(lstm_using_hidden_state_model: nn.Module) -> None:
+    """A multi-output leaf module's TracedLayer should record all three of its output shapes."""
+    architecture = extract_architecture(lstm_using_hidden_state_model, (1, 7, 10))
+    lstm_layer = next(layer for column in architecture.columns for layer in column if isinstance(layer.module, nn.LSTM))
+
+    assert lstm_layer.output_shape == (1, 7, 20)
+    assert lstm_layer.extra_output_shapes == ((1, 1, 20), (1, 1, 20))
+
+
+def test_format_shape_label_appends_extra_shapes() -> None:
+    """format_shape_label should append every extra shape, and omit the `+` entirely when there are none."""
+    assert format_shape_label((1, 7, 20), ()) == "(1, 7, 20)"
+    assert format_shape_label((1, 7, 20), ((1, 1, 20), (1, 1, 20))) == "(1, 7, 20) + (1, 1, 20) + (1, 1, 20)"
+
+
+@pytest.mark.parametrize("view", [graph_view, flow_view, lenet_view])
+def test_show_dimension_includes_every_output_shape(lstm_using_hidden_state_model: nn.Module, view: object) -> None:
+    """show_dimension=True shouldn't crash, and should still work, for a multi-output leaf module."""
+    img = view(lstm_using_hidden_state_model, input_shape=(1, 7, 10), show_dimension=True) # type: ignore[operator]
+    assert img is not None
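
In this diff, the new `format_shape_label` helper is visible only through its test. The minimal sketch below satisfies those two assertions and also applies the `" / "` joining that `_column_label_and_center` in `flow.py` uses across parallel branches. The actual implementation in `visualtorch/utils/utils.py` is not shown here. `column_label` and the `()` default for `extra_output_shapes` are hypothetical additions.

```python
def format_shape_label(
    output_shape: tuple[int, ...],
    extra_output_shapes: tuple[tuple[int, ...], ...] = (),
) -> str:
    """Primary shape first, then every extra output shape, joined with ' + '."""
    # tuple(...) normalizes shape-like objects (e.g. torch.Size) to plain tuples before str().
    return " + ".join(str(tuple(shape)) for shape in (output_shape, *extra_output_shapes))


def column_label(column: list[tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]]) -> str:
    """Label for one diagram column: one entry per parallel branch, separated by ' / '."""
    return " / ".join(format_shape_label(primary, extras) for primary, extras in column)


print(format_shape_label((1, 7, 20), ((1, 1, 20), (1, 1, 20))))
# (1, 7, 20) + (1, 1, 20) + (1, 1, 20)
print(column_label([((1, 8), ()), ((1, 4), ((1, 2),))]))
# (1, 8) / (1, 4) + (1, 2)
```

Because the `+` separator appears only between shapes, a single-output layer keeps exactly the label it had before the change.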

{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_render.py
@@ -3,11 +3,15 @@
 # Copyright (C) 2024 Willy Fitra Hendria
 # SPDX-License-Identifier: MIT

+from collections import defaultdict
+
 import pytest
 import torch
 from torch import nn
 from visualtorch import render
 from visualtorch.backend import extract_architecture
+from visualtorch.utils.layer_utils import Input
+from visualtorch.utils.utils import PALETTES


 @pytest.fixture()
@@ -148,3 +152,43 @@ def test_render_handles_unused_input_tensor() -> None:

     img = render(PartiallyUnusedNet(), input_shape=((1, 10), (1, 5)), style="graph")
     assert img is not None
+
+
+@pytest.mark.parametrize("palette", sorted(PALETTES))
+def test_render_runs_for_every_named_palette(sequential_model: nn.Sequential, palette: str) -> None:
+    """Every named palette should render without error - catches any malformed hex color."""
+    img = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette=palette)
+    assert img is not None
+
+
+def test_render_rejects_unsupported_palette(sequential_model: nn.Sequential) -> None:
+    """An unrecognized palette name should raise a clear error, not silently fall back."""
+    with pytest.raises(ValueError, match="Unsupported palette"):
+        render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette="bogus")
+
+
+def test_render_palette_changes_fallback_colors(sequential_model: nn.Sequential) -> None:
+    """A different palette should actually change the colors of unmapped layer types."""
+    default = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph")
+    dracula = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette="dracula")
+
+    assert default.tobytes() != dracula.tobytes()
+
+
+def test_render_color_map_overrides_palette(sequential_model: nn.Sequential) -> None:
+    """An explicit color_map entry should still win over the palette fallback."""
+    color_map: dict = defaultdict(dict)
+    color_map[Input]["fill"] = "#abcdef"
+    color_map[nn.Conv2d]["fill"] = "#123456"
+    color_map[nn.ReLU]["fill"] = "#654321"
+
+    okabe_ito = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", color_map=color_map)
+    dracula = render(
+        sequential_model,
+        input_shape=(1, 3, 16, 16),
+        style="graph",
+        color_map=color_map,
+        palette="dracula",
+    )
+
+    assert okabe_ito.tobytes() == dracula.tobytes()
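
The new palette tests pin down three behaviours. Every named palette renders. An unknown name raises a `ValueError` mentioning "Unsupported palette". An explicit `color_map` entry beats the palette fallback. The sketch below reproduces that contract in plain Python under a few assumptions:

- The two palettes and their hex values are illustrative, because the real `PALETTES` contents are not part of this diff.
- Layer types are plain strings instead of classes.
- `assign_fills` is a hypothetical stand-in for how `ColorWheel` and `color_map` interact.

```python
from itertools import cycle

# Hypothetical palette table: these hex values are illustrative, not copied from
# visualtorch.utils.utils.PALETTES (whose contents this diff does not show).
PALETTES: dict[str, list[str]] = {
    "okabe_ito": ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"],
    "grayscale": ["#222222", "#555555", "#888888", "#BBBBBB"],
}


def resolve_palette(name: str) -> list[str]:
    """Return the named palette, raising instead of silently falling back."""
    if name not in PALETTES:
        supported = ", ".join(sorted(PALETTES))
        raise ValueError(f"Unsupported palette {name!r}; choose one of: {supported}")
    return PALETTES[name]


def assign_fills(layer_types: list[str], color_map: dict[str, dict[str, str]], palette: str) -> dict[str, str]:
    """An explicit color_map fill wins; every other layer type takes the next palette color."""
    wheel = cycle(resolve_palette(palette))
    fills: dict[str, str] = {}
    for layer_type in layer_types:
        if layer_type in fills:
            continue  # one color per layer type, however many layers share it
        override = color_map.get(layer_type, {}).get("fill")
        fills[layer_type] = override if override is not None else next(wheel)
    return fills


print(assign_fills(["Conv2d", "ReLU", "Conv2d", "Linear"], {"ReLU": {"fill": "#654321"}}, "okabe_ito"))
# {'Conv2d': '#E69F00', 'ReLU': '#654321', 'Linear': '#56B4E9'}
```

In this sketch, an overridden layer type does not consume a palette color, so the remaining types keep the palette's leading colors. Whether VisualTorch's `ColorWheel` behaves the same way is an assumption.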

{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/__init__.py
@@ -11,6 +11,7 @@ from visualtorch.render import (
     render,
 )
 from visualtorch.utils.layer_utils import Input
+from visualtorch.utils.utils import PALETTES

 __all__ = [
     "render",
@@ -19,4 +20,5 @@ __all__ = [
     "FlowStyleOptions",
     "LenetStyleOptions",
     "Input",
+    "PALETTES",
 ]

{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/backend.py
@@ -67,7 +67,10 @@ def extract_architecture(model: nn.Module, input_shape: InputShape) -> Architect
     """
     input_shapes = validate_input_shape(input_shape)

-    id_to_module, id_to_output_shape, edges, input_ids = trace_module_graph(
+    id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, input_ids = trace_module_graph(
+        model,
+        input_shapes,
+    )

     nodes = list(id_to_module.keys())
     id_to_index = {node_id: idx for idx, node_id in enumerate(nodes)}
@@ -114,6 +117,7 @@ def extract_architecture(model: nn.Module, input_shape: InputShape) -> Architect
             module=id_to_module[node_id],
             output_shape=id_to_output_shape[node_id],
             node_id=node_id,
+            extra_output_shapes=id_to_extra_output_shapes.get(node_id, ()),
         )
         columns[depth[node_id]].append(wrapper)

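
In this release, `extract_architecture` receives a third mapping, `id_to_extra_output_shapes`, from `trace_module_graph`. It falls back to `()` for modules with a single output. The `recorder.py` changes that populate this mapping are not shown here. The following is therefore only a sketch of one way to split a module's output, which may be nested, into a primary shape and a tuple of extra shapes. `FakeTensor`, `iter_output_shapes`, and `split_output_shapes` are hypothetical names. `FakeTensor` exposes only `.shape`, which keeps the example free of dependencies.

```python
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor; only the .shape attribute is used below."""

    shape: tuple[int, ...]


def iter_output_shapes(output: Any) -> Iterator[tuple[int, ...]]:
    """Depth-first walk over nested tuples/lists, yielding the shape of every tensor-like leaf."""
    if hasattr(output, "shape"):
        yield tuple(output.shape)
    elif isinstance(output, (tuple, list)):
        for item in output:
            yield from iter_output_shapes(item)
    # Anything else (None, ints, dicts, ...) is ignored in this sketch.


def split_output_shapes(output: Any) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
    """The first tensor's shape drives the node size; the rest become extra_output_shapes."""
    shapes = list(iter_output_shapes(output))
    if not shapes:
        raise ValueError("module returned no tensor-like outputs")
    return shapes[0], tuple(shapes[1:])


# Mimics nn.LSTM(10, 20, batch_first=True) on a (1, 7, 10) input: (output, (h_n, c_n)).
lstm_like = (FakeTensor((1, 7, 20)), (FakeTensor((1, 1, 20)), FakeTensor((1, 1, 20))))
print(split_output_shapes(lstm_like))
# ((1, 7, 20), ((1, 1, 20), (1, 1, 20)))
```

With PyTorch installed, you could instead pass the real output of `nn.LSTM(input_size=10, hidden_size=20, batch_first=True)` applied to `torch.zeros(1, 7, 10)`. That should give the same pair, because `tuple()` converts each `torch.Size` into a plain tuple. The result matches the shapes that the new regression test expects.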

{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/flow.py
@@ -22,8 +22,10 @@ from .utils.utils import (
     ColorWheel,
     ImageDraw,
     InputShape,
+    format_shape_label,
     get_rgba_tuple,
     linear_layout,
+    resolve_palette,
     self_multiply,
     vertical_image_concat,
 )
@@ -41,7 +43,8 @@ def flow_view(
     scale_xy: float = 1,
     type_ignore: list | None = None,
     color_map: dict | None = None,
-
+    palette: str = "okabe_ito",
+    low_dim_orientation: str = "z",
     background_fill: str | tuple[int, ...] = "white",
     draw_volume: bool = True,
     padding: int = 10,
@@ -75,7 +78,13 @@ def flow_view(
         type_ignore (list, optional): List of layer types in the torch model to ignore during drawing.
         color_map (dict, optional): Dictionary defining fill and outline colors for each layer by class type.
             Will fallback to default values for unspecified classes.
-
+        palette (str, optional): Named color palette used as the fallback for any layer type not
+            given an explicit override via `color_map`. One of `"okabe_ito"` (default,
+            colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
+            `"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
+        low_dim_orientation (str, optional): Axis on which a layer without real spatial/channel
+            structure (a 1D shape, or a 2D shape like an RNN/attention layer's
+            `(seq_len, hidden_size)`) should be drawn. One of `'x'`, `'y'`, or `'z'`.
         background_fill (str or tuple, optional): Background color for the image. A string or a tuple (R, G, B, A).
         draw_volume (bool, optional): Flag to switch between 3D volumetric view and 2D box view.
         padding (int, optional): Distance in pixels before the first and after the last layer.
@@ -132,9 +141,9 @@ def flow_view(
     filtered_columns = [column for column in filtered_columns if column]

     layer_types: list[type] = []
-    color_wheel = ColorWheel()
+    color_wheel = ColorWheel(colors=resolve_palette(palette))
     make_box = _box_factory(
-
+        low_dim_orientation,
         scale_xy,
         min_xy,
         max_xy,
@@ -232,7 +241,7 @@ def flow_view(


 def _box_factory(
-
+    low_dim_orientation: str,
     scale_xy: float,
     min_xy: int,
     max_xy: int,
@@ -251,19 +260,20 @@ def _box_factory(
     def make_box(layer: TracedLayer) -> Box:
         shape = layer.output_shape[1:] # drop batch size

-        if len(shape)
-
-
+        if len(shape) in (1, 2):
+            # Neither a 1D nor a 2D shape has real spatial/channel structure - there's nothing
+            # to distinguish "channel" from "spatial" the way a genuine (C, H, W) feature map
+            # does. Take the last value (for 2D, e.g. an RNN/attention layer's
+            # (seq_len, hidden_size), this is the feature/channel-like one, matching PyTorch's
+            # (..., seq, feature) convention for sequence data; for 1D it's the only value) and
+            # let the user place it on whichever axis they choose, same as any 1D value - the
+            # positional-like dim, if any, is discarded either way.
+            value = shape[-1]
+            if low_dim_orientation in ("x", "y", "z"):
+                shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
             else:
-                error_msg = f"unsupported orientation: {
+                error_msg = f"unsupported orientation: {low_dim_orientation}"
                 raise ValueError(error_msg)
-        elif len(shape) == 2:
-            # A 2D non-batch shape (e.g. (seq_len, hidden_size) from an RNN/attention layer)
-            # isn't a CNN feature map missing a channel dim - there's no channel axis at all.
-            # Box's "3D" skew (de, below) is driven by shape[1], so a dummy 1 goes there
-            # instead of either real dim, keeping the two real dims on the box's actual width
-            # and height instead of one of them inflating the skew for a long sequence.
-            shape = (shape[0], 1, shape[1])

         ori_shape = shape
         shape = shape + (1,) * (4 - len(shape)) # expand 4D.
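
The rewritten 1D/2D branch of `make_box` collapses a low-dimensional shape to its last value. It then places that value using `"cxyz".index(low_dim_orientation)`, so `'x'`, `'y'`, and `'z'` put the value in slots 1, 2, and 3 of the padded 4D shape. Judging by the string, slot 0 appears to be the channel slot. The standalone function below (`box_shape` is a hypothetical name) replays that branch and the pad-to-4D line on batch-free shapes. It shows why a 2D `(seq_len, hidden_size)` output now responds to the orientation setting and ignores `seq_len`, which is what `test_flow_view_2d_shape_seq_len_is_discarded` checks.

```python
def box_shape(shape: tuple[int, ...], low_dim_orientation: str = "z") -> tuple[int, ...]:
    """Replay make_box's 1D/2D branch and its pad-to-4D step for a batch-free shape."""
    if len(shape) in (1, 2):
        value = shape[-1]  # feature-like dim (e.g. hidden_size); any seq_len is dropped
        if low_dim_orientation not in ("x", "y", "z"):
            raise ValueError(f"unsupported orientation: {low_dim_orientation}")
        # "cxyz".index(...) is 1, 2 or 3, so the value lands in slot x, y or z.
        shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
    return shape + (1,) * (4 - len(shape))  # expand to 4D, as make_box does


for orientation in ("x", "y", "z"):
    print(orientation, box_shape((5, 64), orientation))
# x (1, 64, 1, 1)
# y (1, 1, 64, 1)
# z (1, 1, 1, 64)
```

Shapes with three or more dimensions skip the branch and are only padded, so `(3, 16, 16)` becomes `(3, 16, 16, 1)`.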
@@ -278,6 +288,7 @@ def _box_factory(

         box = Box()
         box.output_shape = tuple(ori_shape)
+        box.extra_output_shapes = layer.extra_output_shapes
         box.de = int(x / 3) if draw_volume else 0

         box.x1 = 0
@@ -476,7 +487,7 @@ def _draw_legend(

 def _column_label_and_center(column: list[VolumetricBox]) -> tuple[str, float]:
     """A column's shape label (joined across branches) and its shared x-center."""
-    label = " / ".join(
+    label = " / ".join(format_shape_label(box.output_shape, box.extra_output_shapes) for box in column)
     center_x = (column[0].x1 + column[0].x2) / 2
     return label, center_x


{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/graph.py
@@ -15,7 +15,7 @@ from PIL import Image, ImageFont
 from .backend import extract_architecture
 from .connectors import compute_skip_levels, draw_connector
 from .utils.traced_layer import TracedLayer
-from .utils.utils import Box, Circle, ColorWheel, Ellipses, ImageDraw, InputShape
+from .utils.utils import Box, Circle, ColorWheel, Ellipses, ImageDraw, InputShape, format_shape_label, resolve_palette


 def graph_view(
@@ -23,6 +23,7 @@ def graph_view(
|
|
|
23
23
|
input_shape: InputShape,
|
|
24
24
|
to_file: str | None = None,
|
|
25
25
|
color_map: dict[Any, Any] | None = None,
|
|
26
|
+
palette: str = "okabe_ito",
|
|
26
27
|
node_size: int = 50,
|
|
27
28
|
background_fill: str | tuple[int, ...] = "white",
|
|
28
29
|
padding: int = 10,
|
|
@@ -51,6 +52,10 @@ def graph_view(
|
|
|
51
52
|
will disable writing.
|
|
52
53
|
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fall back
|

53
54
|
to default values for unspecified classes.
|
|
55
|
+
palette (str, optional): Named color palette used as the fallback for any layer type not
|
|
56
|
+
given an explicit override via `color_map`. One of `"okabe_ito"` (default,
|
|
57
|
+
colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
|
|
58
|
+
`"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
|
|
54
59
|
node_size (int, optional): Size in pixels each node will have.
|
|
55
60
|
background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
|
|
56
61
|
padding (int, optional): Distance in pixels before the first and after the last layer.
|
|
@@ -111,7 +116,7 @@ def graph_view(
|
|
|
111
116
|
_color_map,
|
|
112
117
|
opacity,
|
|
113
118
|
layer_spacing,
|
|
114
|
-
ColorWheel(),
|
|
119
|
+
ColorWheel(colors=resolve_palette(palette)),
|
|
115
120
|
)
|
|
116
121
|
|
|
117
122
|
# An edge whose endpoint was just dropped above (a hidden input's own edges) can no longer
|
|
@@ -409,7 +414,8 @@ def _create_architecture(
|
|
|
409
414
|
|
|
410
415
|
id_to_node_list_map[layer.node_id] = layer_nodes
|
|
411
416
|
nodes.extend(layer_nodes)
|
|
412
|
-
|
|
417
|
+
label = format_shape_label(layer.output_shape, layer.extra_output_shapes)
|
|
418
|
+
column_labels.append((label, current_x + node_size / 2, current_y))
|
|
413
419
|
current_y += 2 * node_size
|
|
414
420
|
|
|
415
421
|
layer_y.append(current_y - node_spacing - 2 * node_size)
|
|
@@ -17,7 +17,16 @@ from .backend import Architecture, extract_architecture
|
|
|
17
17
|
from .connectors import compute_skip_levels, draw_connector
|
|
18
18
|
from .utils.layer_utils import Input
|
|
19
19
|
from .utils.traced_layer import TracedLayer
|
|
20
|
-
from .utils.utils import
|
|
20
|
+
from .utils.utils import (
|
|
21
|
+
ColorWheel,
|
|
22
|
+
ImageDraw,
|
|
23
|
+
InputShape,
|
|
24
|
+
StackedBox,
|
|
25
|
+
format_shape_label,
|
|
26
|
+
get_rgba_tuple,
|
|
27
|
+
resolve_palette,
|
|
28
|
+
self_multiply,
|
|
29
|
+
)
|
|
21
30
|
|
|
22
31
|
_LABEL_ROW_HEIGHT = 100
|
|
23
32
|
|
|
@@ -33,7 +42,8 @@ def lenet_view(
|
|
|
33
42
|
scale_xy: float = 1,
|
|
34
43
|
type_ignore: list | None = None,
|
|
35
44
|
color_map: dict | None = None,
|
|
36
|
-
|
|
45
|
+
palette: str = "okabe_ito",
|
|
46
|
+
low_dim_orientation: str = "z",
|
|
37
47
|
background_fill: str | tuple[int, ...] = "white",
|
|
38
48
|
padding: int = 10,
|
|
39
49
|
spacing: int = 10,
|
|
@@ -69,7 +79,13 @@ def lenet_view(
|
|
|
69
79
|
type_ignore (list, optional): List of layer types in the torch model to ignore during drawing.
|
|
70
80
|
color_map (dict, optional): Dictionary defining fill and outline colors for each layer by class type.
|
|
71
81
|
Will fall back to default values for unspecified classes.
|
|
72
|
-
|
|
82
|
+
palette (str, optional): Named color palette used as the fallback for any layer type not
|
|
83
|
+
given an explicit override via `color_map`. One of `"okabe_ito"` (default,
|
|
84
|
+
colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
|
|
85
|
+
`"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
|
|
86
|
+
low_dim_orientation (str, optional): Axis on which a layer without real spatial/channel
|
|
87
|
+
structure (a 1D shape, or a 2D shape like an RNN/attention layer's
|
|
88
|
+
`(seq_len, hidden_size)`) should be drawn. One of `'x'`, `'y'`, or `'z'` (default).
|
|
73
89
|
background_fill (str or tuple, optional): Background color for the image. A string or a tuple (R, G, B, A).
|
|
74
90
|
padding (int, optional): Distance in pixels before the first and after the last layer.
|
|
75
91
|
spacing (int, optional): Spacing in pixels between two layers.
|
|
@@ -122,7 +138,7 @@ def lenet_view(
|
|
|
122
138
|
|
|
123
139
|
layer_types: list[type] = []
|
|
124
140
|
make_box = _box_factory(
|
|
125
|
-
|
|
141
|
+
low_dim_orientation,
|
|
126
142
|
scale_xy,
|
|
127
143
|
min_xy,
|
|
128
144
|
max_xy,
|
|
@@ -134,7 +150,7 @@ def lenet_view(
|
|
|
134
150
|
opacity,
|
|
135
151
|
offset_z,
|
|
136
152
|
layer_types,
|
|
137
|
-
ColorWheel(),
|
|
153
|
+
ColorWheel(colors=resolve_palette(palette)),
|
|
138
154
|
)
|
|
139
155
|
column_layout = layout_columns(
|
|
140
156
|
filtered_columns,
|
|
@@ -218,7 +234,7 @@ def _right_extent_for(box: VolumetricBox) -> float:
|
|
|
218
234
|
|
|
219
235
|
|
|
220
236
|
def _box_factory(
|
|
221
|
-
|
|
237
|
+
low_dim_orientation: str,
|
|
222
238
|
scale_xy: float,
|
|
223
239
|
min_xy: int,
|
|
224
240
|
max_xy: int,
|
|
@@ -237,19 +253,20 @@ def _box_factory(
|
|
|
237
253
|
def make_box(layer: TracedLayer) -> StackedBox:
|
|
238
254
|
shape = layer.output_shape[1:] # drop batch size
|
|
239
255
|
|
|
240
|
-
if len(shape)
|
|
241
|
-
|
|
242
|
-
|
|
256
|
+
if len(shape) in (1, 2):
|
|
257
|
+
# Neither a 1D nor a 2D shape has real spatial/channel structure - there's nothing
|
|
258
|
+
# to distinguish "channel" from "spatial" the way a genuine (C, H, W) feature map
|
|
259
|
+
# does. Take the last value (for 2D, e.g. an RNN/attention layer's
|
|
260
|
+
# (seq_len, hidden_size), this is the feature/channel-like one, matching PyTorch's
|
|
261
|
+
# (..., seq, feature) convention for sequence data; for 1D it's the only value) and
|
|
262
|
+
# let the user place it on whichever axis they choose, same as any 1D value - the
|
|
263
|
+
# positional-like dim, if any, is discarded either way.
|
|
264
|
+
value = shape[-1]
|
|
265
|
+
if low_dim_orientation in ("x", "y", "z"):
|
|
266
|
+
shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
|
|
243
267
|
else:
|
|
244
|
-
error_msg = f"unsupported orientation: {
|
|
268
|
+
error_msg = f"unsupported orientation: {low_dim_orientation}"
|
|
245
269
|
raise ValueError(error_msg)
|
|
246
|
-
elif len(shape) == 2:
|
|
247
|
-
# A 2D non-batch shape (e.g. (seq_len, hidden_size) from an RNN/attention layer)
|
|
248
|
-
# isn't a CNN feature map missing a channel dim - there's no channel axis at all.
|
|
249
|
-
# StackedBox's slice count (de, below) is driven by shape[0], so a dummy 1 goes
|
|
250
|
-
# there instead of either real dim, keeping the two real dims on the box's actual
|
|
251
|
-
# width and height instead of one of them being drawn as that many stacked slices.
|
|
252
|
-
shape = (1, *shape)
|
|
253
270
|
|
|
254
271
|
ori_shape = shape
|
|
255
272
|
shape = shape + (1,) * (4 - len(shape)) # expand 4D.
|
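As a worked illustration of the 1D/2D branch in `make_box`, the sketch below reproduces only the shape arithmetic from this hunk: keep the last value, push it to position `"cxyz".index(orientation)` with leading 1s, then pad to four entries as the `expand 4D` line does. `low_dim_shape` is a hypothetical standalone helper, not part of visualtorch, and it ignores everything else `_box_factory` does (scaling, `ori_shape`, drawing).

```python
def low_dim_shape(shape: tuple[int, ...], orientation: str = "z") -> tuple[int, ...]:
    """Hypothetical mirror of the 1D/2D handling in make_box (batch dim already dropped)."""
    if len(shape) in (1, 2):
        # Keep only the last (feature-like) value; a 2D shape's first dim is discarded.
        value = shape[-1]
        if orientation in ("x", "y", "z"):
            # Leading 1s push `value` to the slot named by `orientation` in "cxyz".
            shape = (1,) * "cxyz".index(orientation) + (value,)
        else:
            error_msg = f"unsupported orientation: {orientation}"
            raise ValueError(error_msg)
    # Pad to four entries, as in `shape + (1,) * (4 - len(shape))`.
    return shape + (1,) * (4 - len(shape))


# An RNN/attention output of (seq_len=5, hidden_size=20):
print(low_dim_shape((5, 20), "z"))  # (1, 1, 1, 20)
print(low_dim_shape((5, 20), "x"))  # (1, 20, 1, 1)
# A plain 1D feature vector:
print(low_dim_shape((64,), "y"))  # (1, 1, 64, 1)
```

Note that the sequence length never reaches the padded tuple, which is the "positional-like dim ... is discarded" behaviour the comment describes; genuine 3D feature maps such as `(3, 32, 32)` skip the branch and are only padded.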
|
@@ -266,6 +283,7 @@ def _box_factory(
|
|
|
266
283
|
box.offset_z = offset_z
|
|
267
284
|
box.label = layer.module.name() if isinstance(layer.module, Input) else layer_type.__name__
|
|
268
285
|
box.output_shape = tuple(ori_shape)
|
|
286
|
+
box.extra_output_shapes = layer.extra_output_shapes
|
|
269
287
|
box.de = z
|
|
270
288
|
|
|
271
289
|
box.x1 = 0
|
|
@@ -405,6 +423,7 @@ def _draw_labels(
|
|
|
405
423
|
for box in column:
|
|
406
424
|
loc_x = box.x1 + (box.x2 - box.x1) // 4
|
|
407
425
|
label = getattr(box, "label", type(box).__name__)
|
|
408
|
-
|
|
426
|
+
shape_label = format_shape_label(box.output_shape, box.extra_output_shapes)
|
|
427
|
+
draw_text.text((loc_x, img.height - 50), f"{label} {shape_label}", font=font, fill=font_color)
|
|
409
428
|
|
|
410
429
|
return Image.alpha_composite(img, text_img)
|
|
@@ -32,6 +32,7 @@ class CommonOptions:
|
|
|
32
32
|
|
|
33
33
|
to_file: str | None = None
|
|
34
34
|
color_map: dict[Any, Any] | None = None
|
|
35
|
+
palette: str = "okabe_ito"
|
|
35
36
|
background_fill: str | tuple[int, ...] = "white"
|
|
36
37
|
padding: int = 10
|
|
37
38
|
opacity: int = 255
|
|
@@ -66,7 +67,7 @@ class FlowStyleOptions:
|
|
|
66
67
|
scale_z: float = 0.1
|
|
67
68
|
scale_xy: float = 1
|
|
68
69
|
type_ignore: list[type] | None = None
|
|
69
|
-
|
|
70
|
+
low_dim_orientation: str = "z"
|
|
70
71
|
draw_volume: bool = True
|
|
71
72
|
spacing: int = 10
|
|
72
73
|
draw_funnel: bool = True
|
|
@@ -86,7 +87,7 @@ class LenetStyleOptions:
|
|
|
86
87
|
scale_z: float = 1
|
|
87
88
|
scale_xy: float = 1
|
|
88
89
|
type_ignore: list[type] | None = None
|
|
89
|
-
|
|
90
|
+
low_dim_orientation: str = "z"
|
|
90
91
|
spacing: int = 10
|
|
91
92
|
draw_funnel: bool = True
|
|
92
93
|
shade_step: int = 10
|
|
@@ -107,6 +108,7 @@ def _render_graph(
|
|
|
107
108
|
input_shape,
|
|
108
109
|
to_file=common.to_file,
|
|
109
110
|
color_map=common.color_map,
|
|
111
|
+
palette=common.palette,
|
|
110
112
|
node_size=options.node_size,
|
|
111
113
|
background_fill=common.background_fill,
|
|
112
114
|
padding=common.padding,
|
|
@@ -143,7 +145,8 @@ def _render_flow(
|
|
|
143
145
|
scale_xy=options.scale_xy,
|
|
144
146
|
type_ignore=options.type_ignore,
|
|
145
147
|
color_map=common.color_map,
|
|
146
|
-
|
|
148
|
+
palette=common.palette,
|
|
149
|
+
low_dim_orientation=options.low_dim_orientation,
|
|
147
150
|
background_fill=common.background_fill,
|
|
148
151
|
draw_volume=options.draw_volume,
|
|
149
152
|
padding=common.padding,
|
|
@@ -177,7 +180,8 @@ def _render_lenet(
|
|
|
177
180
|
scale_xy=options.scale_xy,
|
|
178
181
|
type_ignore=options.type_ignore,
|
|
179
182
|
color_map=common.color_map,
|
|
180
|
-
|
|
183
|
+
palette=common.palette,
|
|
184
|
+
low_dim_orientation=options.low_dim_orientation,
|
|
181
185
|
background_fill=common.background_fill,
|
|
182
186
|
padding=common.padding,
|
|
183
187
|
spacing=options.spacing,
|
|
@@ -106,28 +106,34 @@ def _wrap_and_stamp(obj: Any, ids: set[str]) -> Any:
|
|
|
106
106
|
return obj
|
|
107
107
|
|
|
108
108
|
|
|
109
|
-
def
|
|
110
|
-
"""Recursively
|
|
109
|
+
def _all_tensor_shapes(obj: Any) -> list[tuple[int, ...]]:
|
|
110
|
+
"""Recursively collect the shape of every tensor inside obj, in encounter order.
|
|
111
|
+
|
|
112
|
+
A module's return value isn't always a single tensor - `nn.LSTM` returns
|
|
113
|
+
`(output, (h_n, c_n))`, `nn.MultiheadAttention` returns `(attn_output, attn_weights)`, and a
|
|
114
|
+
custom module can return any tuple/list/dict of tensors. Collecting all of them (rather than
|
|
115
|
+
just the first) is what lets a caller show every output shape instead of silently dropping
|
|
116
|
+
every tensor after the first.
|
|
117
|
+
"""
|
|
111
118
|
if isinstance(obj, torch.Tensor):
|
|
112
|
-
return tuple(obj.shape)
|
|
119
|
+
return [tuple(obj.shape)]
|
|
113
120
|
if isinstance(obj, Mapping):
|
|
121
|
+
shapes = []
|
|
114
122
|
for value in obj.values():
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
return shape
|
|
118
|
-
return ()
|
|
123
|
+
shapes.extend(_all_tensor_shapes(value))
|
|
124
|
+
return shapes
|
|
119
125
|
if isinstance(obj, list | tuple):
|
|
126
|
+
shapes = []
|
|
120
127
|
for value in obj:
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
return ()
|
|
125
|
-
return ()
|
|
128
|
+
shapes.extend(_all_tensor_shapes(value))
|
|
129
|
+
return shapes
|
|
130
|
+
return []
|
|
126
131
|
|
|
127
132
|
|
|
128
133
|
def _wrapped_module_call(
|
|
129
134
|
id_to_module: dict[str, nn.Module],
|
|
130
135
|
id_to_output_shape: dict[str, tuple[int, ...]],
|
|
136
|
+
id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
|
|
131
137
|
edges: list[tuple[str, str]],
|
|
132
138
|
call_counts: dict[int, int],
|
|
133
139
|
) -> Any:
|
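To make the traversal order of `_all_tensor_shapes` concrete, here is a torch-free sketch of the same recursion. It assumes PyTorch is not available, so any object with a `.shape` attribute stands in for the `isinstance(obj, torch.Tensor)` check, and `FakeTensor` is a hypothetical placeholder class rather than anything in visualtorch.

```python
from collections.abc import Mapping
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def all_tensor_shapes(obj: Any) -> list[tuple[int, ...]]:
    """Collect every tensor shape inside obj, depth-first, in encounter order."""
    if hasattr(obj, "shape"):  # stand-in for isinstance(obj, torch.Tensor)
        return [tuple(obj.shape)]
    if isinstance(obj, Mapping):
        shapes: list[tuple[int, ...]] = []
        for value in obj.values():
            shapes.extend(all_tensor_shapes(value))
        return shapes
    if isinstance(obj, (list, tuple)):
        shapes = []
        for value in obj:
            shapes.extend(all_tensor_shapes(value))
        return shapes
    return []  # non-tensor leaves (None, ints, strings) contribute nothing


# Shaped like nn.LSTM(10, 20, batch_first=True) on a (1, 5, 10) input: (output, (h_n, c_n)).
lstm_out = (FakeTensor(1, 5, 20), (FakeTensor(1, 1, 20), FakeTensor(1, 1, 20)))
print(all_tensor_shapes(lstm_out))  # [(1, 5, 20), (1, 1, 20), (1, 1, 20)]
print(all_tensor_shapes({"logits": FakeTensor(1, 10), "aux": None}))  # [(1, 10)]
```

Returning a list instead of the old single tuple is what lets the caller keep the first shape for sizing while still seeing every other tensor.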
|
@@ -154,8 +160,11 @@ def _wrapped_module_call(
|
|
|
154
160
|
call_counts[base_id] = call_index + 1
|
|
155
161
|
node_id = f"{base_id}#{call_index}"
|
|
156
162
|
|
|
163
|
+
shapes = _all_tensor_shapes(out)
|
|
157
164
|
id_to_module[node_id] = mod
|
|
158
|
-
id_to_output_shape[node_id] =
|
|
165
|
+
id_to_output_shape[node_id] = shapes[0] if shapes else ()
|
|
166
|
+
if len(shapes) > 1:
|
|
167
|
+
id_to_extra_output_shapes[node_id] = tuple(shapes[1:])
|
|
159
168
|
edges.extend((producer_id, node_id) for producer_id in producer_ids)
|
|
160
169
|
out = _wrap_and_stamp(out, {node_id})
|
|
161
170
|
|
|
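The hunk above splits that list into the primary shape used for box sizing and the extra shapes used only for labels, writing an extras entry only when there is more than one tensor. A minimal standalone sketch of that bookkeeping follows; `record_shapes` is a hypothetical helper, and the readable ids like `"lstm#0"` are made up, whereas the recorder itself uses `f"{id(module)}#{call_index}"`.

```python
def record_shapes(
    node_id: str,
    shapes: list[tuple[int, ...]],
    id_to_output_shape: dict[str, tuple[int, ...]],
    id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
) -> None:
    """Store the first shape for sizing; store the rest only if there are any."""
    # A module that returned no tensors at all gets an empty shape.
    id_to_output_shape[node_id] = shapes[0] if shapes else ()
    if len(shapes) > 1:
        id_to_extra_output_shapes[node_id] = tuple(shapes[1:])


primary: dict[str, tuple[int, ...]] = {}
extras: dict[str, tuple[tuple[int, ...], ...]] = {}
record_shapes("lstm#0", [(1, 5, 20), (1, 1, 20), (1, 1, 20)], primary, extras)
record_shapes("linear#0", [(1, 10)], primary, extras)
print(primary)  # {'lstm#0': (1, 5, 20), 'linear#0': (1, 10)}
print(extras)  # {'lstm#0': ((1, 1, 20), (1, 1, 20))}
# Single-output nodes have no extras entry, so lookups should default to ():
print(extras.get("linear#0", ()))  # ()
```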
@@ -171,11 +180,13 @@ class Recorder:
|
|
|
171
180
|
self,
|
|
172
181
|
id_to_module: dict[str, nn.Module],
|
|
173
182
|
id_to_output_shape: dict[str, tuple[int, ...]],
|
|
183
|
+
id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
|
|
174
184
|
edges: list[tuple[str, str]],
|
|
175
185
|
call_counts: dict[int, int],
|
|
176
186
|
) -> None:
|
|
177
187
|
self._id_to_module = id_to_module
|
|
178
188
|
self._id_to_output_shape = id_to_output_shape
|
|
189
|
+
self._id_to_extra_output_shapes = id_to_extra_output_shapes
|
|
179
190
|
self._edges = edges
|
|
180
191
|
self._call_counts = call_counts
|
|
181
192
|
|
|
@@ -184,6 +195,7 @@ class Recorder:
|
|
|
184
195
|
nn.Module.__call__ = _wrapped_module_call( # type: ignore[method-assign]
|
|
185
196
|
self._id_to_module,
|
|
186
197
|
self._id_to_output_shape,
|
|
198
|
+
self._id_to_extra_output_shapes,
|
|
187
199
|
self._edges,
|
|
188
200
|
self._call_counts,
|
|
189
201
|
)
|
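`Recorder` works by temporarily replacing `nn.Module.__call__` with a wrapper that closes over the shared dictionaries. Only the patching side of that swap is visible in this diff; the sketch below assumes the usual counterpart in which `__exit__` restores the saved original. It uses a hypothetical `Layer` class instead of `nn.Module` so it runs without PyTorch, and records only call order, not shapes or edges.

```python
from __future__ import annotations

from typing import Any, Callable


class Layer:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self, name: str, out: Any) -> None:
        self.name = name
        self.out = out

    def __call__(self, *args: Any) -> Any:
        return self.out


class PatchRecorder:
    """Swap Layer.__call__ for a recording wrapper, restoring it on exit."""

    def __init__(self, calls: list[str]) -> None:
        self._calls = calls
        self._original: Callable[..., Any] | None = None

    def __enter__(self) -> PatchRecorder:
        self._original = original = Layer.__call__
        calls = self._calls

        def wrapped(layer: Layer, *args: Any) -> Any:
            out = original(layer, *args)  # run the real call first
            calls.append(layer.name)  # then record that it happened
            return out

        # layer(x) looks __call__ up on the type, so patching the class affects every instance.
        Layer.__call__ = wrapped  # type: ignore[method-assign]
        return self

    def __exit__(self, *exc: object) -> None:
        # Restore even if the traced forward pass raised; returning None re-raises.
        Layer.__call__ = self._original  # type: ignore[method-assign]


calls: list[str] = []
a, b = Layer("a", 1), Layer("b", 2)
with PatchRecorder(calls):
    a(0)
    b(0)
a(0)  # outside the block: original __call__, not recorded
print(calls)  # ['a', 'b']
```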
|
@@ -197,7 +209,13 @@ class Recorder:
|
|
|
197
209
|
def trace_module_graph(
|
|
198
210
|
model: nn.Module,
|
|
199
211
|
input_shapes: tuple[tuple[int, ...], ...],
|
|
200
|
-
) -> tuple[
|
|
212
|
+
) -> tuple[
|
|
213
|
+
dict[str, nn.Module],
|
|
214
|
+
dict[str, tuple[int, ...]],
|
|
215
|
+
dict[str, tuple[tuple[int, ...], ...]],
|
|
216
|
+
list[tuple[str, str]],
|
|
217
|
+
list[str],
|
|
218
|
+
]:
|
|
201
219
|
"""Trace a forward pass to recover the leaf-module call graph.
|
|
202
220
|
|
|
203
221
|
Args:
|
|
@@ -210,7 +228,11 @@ def trace_module_graph(
|
|
|
210
228
|
tuple: A tuple containing:
|
|
211
229
|
- id_to_module (dict): Mapping from node id to the leaf module. A leaf module
|
|
212
230
|
called more than once gets one entry per call (`f"{id(module)}#{call_index}"`).
|
|
213
|
-
- id_to_output_shape (dict): Mapping from node id to that module's output
|
|
231
|
+
- id_to_output_shape (dict): Mapping from node id to that module's *first* output
|
|
232
|
+
tensor's shape (the one used for box sizing).
|
|
233
|
+
- id_to_extra_output_shapes (dict): Mapping from node id to the shapes of any
|
|
234
|
+
*additional* output tensors beyond the first (e.g. `nn.LSTM`'s `(h_n, c_n)`) -
|
|
235
|
+
only present for nodes that return more than one tensor.
|
|
214
236
|
- edges (list): `(producer_node_id, consumer_node_id)` pairs, in call order.
|
|
215
237
|
- input_ids (list): One synthetic node id per input tensor, in the same order as
|
|
216
238
|
`input_shapes`.
|
|
@@ -223,10 +245,11 @@ def trace_module_graph(
|
|
|
223
245
|
|
|
224
246
|
id_to_module: dict[str, nn.Module] = {}
|
|
225
247
|
id_to_output_shape: dict[str, tuple[int, ...]] = {}
|
|
248
|
+
id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]] = {}
|
|
226
249
|
edges: list[tuple[str, str]] = []
|
|
227
250
|
call_counts: dict[int, int] = {}
|
|
228
251
|
|
|
229
|
-
with Recorder(id_to_module, id_to_output_shape, edges, call_counts):
|
|
252
|
+
with Recorder(id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, call_counts):
|
|
230
253
|
if isinstance(model, nn.ModuleList):
|
|
231
254
|
# nn.ModuleList has no forward() of its own - it's a plain container, not meant to
|
|
232
255
|
# be called directly - so drive it the same way a user would: chain each child call.
|
|
@@ -244,4 +267,4 @@ def trace_module_graph(
|
|
|
244
267
|
model(*dummy_inputs)
|
|
245
268
|
|
|
246
269
|
input_ids = [f"{INPUT_NODE_ID}#{i}" for i in range(len(input_shapes))]
|
|
247
|
-
return id_to_module, id_to_output_shape, edges, input_ids
|
|
270
|
+
return id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, input_ids
|
|
@@ -20,3 +20,9 @@ class TracedLayer:
|
|
|
20
20
|
module: nn.Module
|
|
21
21
|
output_shape: tuple[int, ...]
|
|
22
22
|
node_id: str
|
|
23
|
+
extra_output_shapes: tuple[tuple[int, ...], ...] = ()
|
|
24
|
+
"""Shapes of any additional output tensors beyond `output_shape` (e.g. `nn.LSTM`'s hidden and
|
|
25
|
+
cell state, returned alongside its main sequence output) - empty for a module that returns
|
|
26
|
+
just one tensor. `output_shape` alone still drives box sizing; this is only ever read to
|
|
27
|
+
extend the `show_dimension` label so those extra tensors aren't silently unaccounted for.
|
|
28
|
+
"""
|
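Declaring `extra_output_shapes` last and giving it a default of `()` means existing constructions of `TracedLayer` with three fields keep working unchanged. A minimal sketch of that property, assuming `TracedLayer` is a dataclass (its decorator is outside this hunk); `TracedLayerSketch` is a hypothetical mirror and `object()` stands in for an `nn.Module`.

```python
from dataclasses import dataclass


@dataclass
class TracedLayerSketch:
    """Hypothetical mirror of the TracedLayer fields shown in this hunk."""

    module: object
    output_shape: tuple[int, ...]
    node_id: str
    # Defaulted and declared last, so three-argument calls stay valid.
    extra_output_shapes: tuple[tuple[int, ...], ...] = ()


old_style = TracedLayerSketch(object(), (1, 10), "linear#0")
lstm = TracedLayerSketch(object(), (1, 5, 20), "lstm#0", ((1, 1, 20), (1, 1, 20)))
print(old_style.extra_output_shapes)  # ()
print(lstm.extra_output_shapes)  # ((1, 1, 20), (1, 1, 20))
```

An immutable `()` is also a safe dataclass default, unlike a list, which `dataclasses` would reject.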
|
@@ -66,6 +66,7 @@ class StackedBox(Shape):
|
|
|
66
66
|
offset_z: int
|
|
67
67
|
label: str
|
|
68
68
|
output_shape: tuple
|
|
69
|
+
extra_output_shapes: tuple[tuple[int, ...], ...] = ()
|
|
69
70
|
|
|
70
71
|
def draw(self, draw: ImageDraw) -> None:
|
|
71
72
|
"""Draw box shape."""
|
|
@@ -114,6 +115,7 @@ class Box(Shape):
|
|
|
114
115
|
de: int
|
|
115
116
|
shade: int
|
|
116
117
|
output_shape: tuple[int, ...]
|
|
118
|
+
extra_output_shapes: tuple[tuple[int, ...], ...] = ()
|
|
117
119
|
|
|
118
120
|
def draw(self, draw: ImageDraw) -> None:
|
|
119
121
|
"""Draw box shape."""
|
|
@@ -225,26 +227,116 @@ class Ellipses(Shape):
|
|
|
225
227
|
)
|
|
226
228
|
|
|
227
229
|
|
|
230
|
+
PALETTES: dict[str, list[str]] = {
|
|
231
|
+
# Okabe-Ito: a colorblind-safe palette (Okabe & Ito, 2008) widely recommended for
|
|
232
|
+
# scientific visualization, e.g. in Nature's figure guidelines.
|
|
233
|
+
"okabe_ito": ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"],
|
|
234
|
+
# Paul Tol's "bright" qualitative scheme - colorblind-safe, higher-contrast alternative.
|
|
235
|
+
"tol_bright": ["#4477AA", "#EE6677", "#228833", "#CCBB44", "#66CCEE", "#AA3377", "#BBBBBB"],
|
|
236
|
+
# Paul Tol's "muted" qualitative scheme - colorblind-safe, softer aesthetic.
|
|
237
|
+
"tol_muted": [
|
|
238
|
+
"#CC6677",
|
|
239
|
+
"#332288",
|
|
240
|
+
"#DDCC77",
|
|
241
|
+
"#117733",
|
|
242
|
+
"#88CCEE",
|
|
243
|
+
"#882255",
|
|
244
|
+
"#44AA99",
|
|
245
|
+
"#999933",
|
|
246
|
+
"#AA4499",
|
|
247
|
+
],
|
|
248
|
+
# matplotlib's default color cycle.
|
|
249
|
+
"tab10": [
|
|
250
|
+
"#1f77b4",
|
|
251
|
+
"#ff7f0e",
|
|
252
|
+
"#2ca02c",
|
|
253
|
+
"#d62728",
|
|
254
|
+
"#9467bd",
|
|
255
|
+
"#8c564b",
|
|
256
|
+
"#e377c2",
|
|
257
|
+
"#7f7f7f",
|
|
258
|
+
"#bcbd22",
|
|
259
|
+
"#17becf",
|
|
260
|
+
],
|
|
261
|
+
# Evenly-spaced grays for print/monochrome-safe figures.
|
|
262
|
+
"grayscale": ["#404040", "#595959", "#737373", "#8c8c8c", "#a6a6a6", "#bfbfbf", "#d9d9d9"],
|
|
263
|
+
# Nord's Aurora + Frost accent colors.
|
|
264
|
+
"nord": [
|
|
265
|
+
"#bf616a",
|
|
266
|
+
"#d08770",
|
|
267
|
+
"#ebcb8b",
|
|
268
|
+
"#a3be8c",
|
|
269
|
+
"#b48ead",
|
|
270
|
+
"#8fbcbb",
|
|
271
|
+
"#88c0d0",
|
|
272
|
+
"#81a1c1",
|
|
273
|
+
"#5e81ac",
|
|
274
|
+
],
|
|
275
|
+
# Dracula theme's accent colors.
|
|
276
|
+
"dracula": ["#FF5555", "#FFB86C", "#F1FA8C", "#50FA7B", "#8BE9FD", "#BD93F9", "#FF79C6"],
|
|
277
|
+
# Gruvbox's bright color variants.
|
|
278
|
+
"gruvbox": ["#fb4934", "#b8bb26", "#fabd2f", "#83a598", "#d3869b", "#8ec07c", "#fe8019"],
|
|
279
|
+
# Solarized's accent colors.
|
|
280
|
+
"solarized": [
|
|
281
|
+
"#b58900",
|
|
282
|
+
"#cb4b16",
|
|
283
|
+
"#dc322f",
|
|
284
|
+
"#d33682",
|
|
285
|
+
"#6c71c4",
|
|
286
|
+
"#268bd2",
|
|
287
|
+
"#2aa198",
|
|
288
|
+
"#859900",
|
|
289
|
+
],
|
|
290
|
+
# Material Design's 500-weight color spread.
|
|
291
|
+
"material": [
|
|
292
|
+
"#f44336",
|
|
293
|
+
"#e91e63",
|
|
294
|
+
"#9c27b0",
|
|
295
|
+
"#3f51b5",
|
|
296
|
+
"#2196f3",
|
|
297
|
+
"#009688",
|
|
298
|
+
"#4caf50",
|
|
299
|
+
"#ffc107",
|
|
300
|
+
"#ff5722",
|
|
301
|
+
],
|
|
302
|
+
# Catppuccin's Mocha flavor accent colors.
|
|
303
|
+
"catppuccin": [
|
|
304
|
+
"#f38ba8",
|
|
305
|
+
"#fab387",
|
|
306
|
+
"#f9e2af",
|
|
307
|
+
"#a6e3a1",
|
|
308
|
+
"#94e2d5",
|
|
309
|
+
"#89dceb",
|
|
310
|
+
"#89b4fa",
|
|
311
|
+
"#b4befe",
|
|
312
|
+
"#cba6f7",
|
|
313
|
+
"#f5c2e7",
|
|
314
|
+
],
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def resolve_palette(name: str) -> list[str]:
|
|
319
|
+
"""Resolve a named palette to its list of hex colors.
|
|
320
|
+
|
|
321
|
+
Args:
|
|
322
|
+
name (str): One of the keys in `PALETTES`.
|
|
323
|
+
|
|
324
|
+
Returns:
|
|
325
|
+
list[str]: The palette's hex color strings.
|
|
326
|
+
"""
|
|
327
|
+
if name not in PALETTES:
|
|
328
|
+
supported = ", ".join(sorted(PALETTES))
|
|
329
|
+
error_msg = f"Unsupported palette {name!r}. Supported palettes: {supported}."
|
|
330
|
+
raise ValueError(error_msg)
|
|
331
|
+
return PALETTES[name]
|
|
332
|
+
|
|
333
|
+
|
|
228
334
|
class ColorWheel:
|
|
229
335
|
"""Default colors for the shapes."""
|
|
230
336
|
|
|
231
337
|
def __init__(self, colors: list | None = None) -> None:
|
|
232
338
|
self._cache: dict[type, Any] = {}
|
|
233
|
-
|
|
234
|
-
# scientific visualization, e.g. in Nature's figure guidelines.
|
|
235
|
-
self.colors = (
|
|
236
|
-
colors
|
|
237
|
-
if colors is not None
|
|
238
|
-
else [
|
|
239
|
-
"#E69F00", # orange
|
|
240
|
-
"#56B4E9", # sky blue
|
|
241
|
-
"#009E73", # bluish green
|
|
242
|
-
"#F0E442", # yellow
|
|
243
|
-
"#0072B2", # blue
|
|
244
|
-
"#D55E00", # vermillion
|
|
245
|
-
"#CC79A7", # reddish purple
|
|
246
|
-
]
|
|
247
|
-
)
|
|
339
|
+
self.colors = colors if colors is not None else PALETTES["okabe_ito"]
|
|
248
340
|
|
|
249
341
|
def get_color(self, class_type: type) -> tuple | None:
|
|
250
342
|
"""Return color from cache if exist, if not, get from the list and store it to the cache."""
|
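`resolve_palette` fails fast on an unknown name instead of silently falling back, so a typo in `palette=` surfaces as a `ValueError` listing the valid choices. The sketch below copies the function body from this diff but trims `PALETTES` to two of its eleven entries purely to keep the example short.

```python
PALETTES: dict[str, list[str]] = {
    # Two entries copied from the full table above; the real dict has eleven.
    "okabe_ito": ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"],
    "grayscale": ["#404040", "#595959", "#737373", "#8c8c8c", "#a6a6a6", "#bfbfbf", "#d9d9d9"],
}


def resolve_palette(name: str) -> list[str]:
    """Resolve a named palette to its list of hex colors."""
    if name not in PALETTES:
        supported = ", ".join(sorted(PALETTES))
        error_msg = f"Unsupported palette {name!r}. Supported palettes: {supported}."
        raise ValueError(error_msg)
    return PALETTES[name]


print(resolve_palette("okabe_ito")[0])  # #E69F00
try:
    resolve_palette("viridis")
except ValueError as err:
    print(err)  # Unsupported palette 'viridis'. Supported palettes: grayscale, okabe_ito.
```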
|
@@ -344,6 +436,22 @@ def self_multiply(tensor_tuple: tuple | list) -> int | float:
|
|
|
344
436
|
return s
|
|
345
437
|
|
|
346
438
|
|
|
439
|
+
def format_shape_label(output_shape: tuple[int, ...], extra_output_shapes: tuple[tuple[int, ...], ...]) -> str:
|
|
440
|
+
"""Format an output shape for display, appending any extra output shapes if present.
|
|
441
|
+
|
|
442
|
+
A module that returns more than one meaningful tensor (e.g. `nn.LSTM`'s `(output, (h_n,
|
|
443
|
+
c_n))`) would otherwise only ever show `output_shape` (the first tensor found, also the one
|
|
444
|
+
driving box size) with no indication the other tensors exist at all. `+` is used rather than
|
|
445
|
+
`visualtorch`'s existing `/` convention (already used to join sibling branches within one
|
|
446
|
+
column) so this doesn't read as an alternative/branch - these are all real outputs of this
|
|
447
|
+
same node, not one of several options.
|
|
448
|
+
"""
|
|
449
|
+
label = str(output_shape)
|
|
450
|
+
if extra_output_shapes:
|
|
451
|
+
label += " + " + " + ".join(str(shape) for shape in extra_output_shapes)
|
|
452
|
+
return label
|
|
453
|
+
|
|
454
|
+
|
|
347
455
|
def vertical_image_concat(
|
|
348
456
|
im1: Image,
|
|
349
457
|
im2: Image,
|
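To see the two joining conventions side by side (`+` for extra outputs of one node, `/` for sibling branches in one flow-style column), the sketch below copies `format_shape_label` from this diff and applies it the way `_column_label_and_center` and the lenet `_draw_labels` do. The shapes and the `LSTM` label are illustrative values, not output from a real trace.

```python
def format_shape_label(
    output_shape: tuple[int, ...], extra_output_shapes: tuple[tuple[int, ...], ...]
) -> str:
    """Format an output shape for display, appending any extra output shapes."""
    label = str(output_shape)
    if extra_output_shapes:
        label += " + " + " + ".join(str(shape) for shape in extra_output_shapes)
    return label


# A single-output layer shows just its own shape:
print(format_shape_label((1, 10), ()))  # (1, 10)
# An LSTM's output plus (h_n, c_n), prefixed like the lenet-style label:
print(f"LSTM {format_shape_label((1, 5, 20), ((1, 1, 20), (1, 1, 20)))}")
# LSTM (1, 5, 20) + (1, 1, 20) + (1, 1, 20)

# Sibling branches in one flow-style column are still joined with " / ":
column = [((1, 8, 16, 16), ()), ((1, 8, 16, 16), ())]
print(" / ".join(format_shape_label(s, e) for s, e in column))
# (1, 8, 16, 16) / (1, 8, 16, 16)
```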
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -54,21 +54,22 @@ Dynamic: summary
|
|
|
54
54
|
|
|
55
55
|
</div>
|
|
56
56
|
|
|
57
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models.
|
|
57
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
|
|
58
58
|
|
|
59
59
|
**Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
|
|
60
60
|
|
|
61
|
-
**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has
|
|
62
|
-
|
|
63
|
-
execution):
|
|
64
|
-
value crosses some threshold) only show whichever branch the traced dummy input happened to take
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
61
|
+
**Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
|
|
62
|
+
limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
|
|
63
|
+
execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
|
|
64
|
+
value crosses some threshold) only show whichever branch the traced dummy input happened to take.
|
|
65
|
+
Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
|
|
66
|
+
head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node sized by its first output
|

67
|
+
tensor only; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
|
|
68
|
+
the first. Downstream connections are correct either way. Contributions are welcome!
|
|
68
69
|
|
|
69
70
|
<div align="center">
|
|
70
71
|
|
|
71
|
-

|
|
72
|
+

|
|
72
73
|
|
|
73
74
|
</div>
|
|
74
75
|
|
|
@@ -100,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi
|
|
|
100
101
|
|
|
101
102
|
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
|
|
102
103
|
|
|
103
|
-
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz),
|
|
104
|
+
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
|
|
104
105
|
|
|
105
106
|
## Citation
|
|
106
107
|
|
|
107
108
|
Please cite this project in your publications if it helps your research.
|
|
108
109
|
|
|
109
|
-
**Note:** the paper below describes
|
|
110
|
-
since
|
|
111
|
-
for the current API) - the DOI
|
|
112
|
-
|
|
110
|
+
**Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
|
|
111
|
+
since been substantially refactored, including breaking API changes (see the
|
|
112
|
+
[documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
|
|
113
|
+
always resolves to what was actually reviewed and published.
|
|
113
114
|
|
|
114
115
|
```bibtex
|
|
115
116
|
@article{Hendria2024,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|