visualtorch 1.1.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {visualtorch-1.1.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO +16 -15
  2. {visualtorch-1.1.0 → visualtorch-1.2.0}/README.md +15 -14
  3. {visualtorch-1.1.0 → visualtorch-1.2.0}/setup.py +1 -1
  4. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_flow.py +56 -5
  5. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_lenet_style.py +56 -5
  6. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_regression_issues.py +47 -1
  7. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_render.py +44 -0
  8. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/__init__.py +2 -0
  9. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/backend.py +5 -1
  10. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/flow.py +28 -17
  11. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/graph.py +9 -3
  12. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/lenet_style.py +37 -18
  13. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/render.py +8 -4
  14. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/recorder.py +40 -17
  15. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/traced_layer.py +6 -0
  16. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/utils.py +123 -15
  17. {visualtorch-1.1.0 → visualtorch-1.2.0/visualtorch.egg-info}/PKG-INFO +16 -15
  18. {visualtorch-1.1.0 → visualtorch-1.2.0}/LICENSE +0 -0
  19. {visualtorch-1.1.0 → visualtorch-1.2.0}/pyproject.toml +0 -0
  20. {visualtorch-1.1.0 → visualtorch-1.2.0}/setup.cfg +0 -0
  21. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_connectors.py +0 -0
  22. {visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_graph.py +0 -0
  23. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/_volumetric_layout.py +0 -0
  24. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/connectors.py +0 -0
  25. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/__init__.py +0 -0
  26. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/utils/layer_utils.py +0 -0
  27. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/SOURCES.txt +0 -0
  28. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/dependency_links.txt +0 -0
  29. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/requires.txt +0 -0
  30. {visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch.egg-info/top_level.txt +0 -0
{visualtorch-1.1.0/visualtorch.egg-info → visualtorch-1.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: visualtorch
- Version: 1.1.0
+ Version: 1.2.0
  Summary: Architecture visualization of Torch models
  Home-page: https://github.com/willyfh/visualtorch
  Author: Willy Fitra Hendria
@@ -54,21 +54,22 @@ Dynamic: summary
 
  </div>
 
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
 
  **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
 
- **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
- limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
- execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
- value crosses some threshold) only show whichever branch the traced dummy input happened to take;
- (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
- only has its first tensor's shape reflected in that node's size/label - its downstream connections
- are still correct either way. Contributions are welcome!
+ **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
+ limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
+ execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take.
+ Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
+ head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
+ tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
+ the first. Downstream connections are correct either way. Contributions are welcome!
 
  <div align="center">
 
- ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
+ ![VisualTorch Examples](https://raw.githubusercontent.com/willyfh/visualtorch/e6ad79751e0f7412b1074beb45f9baeccd1419e4/docs/source/_static/images/banners/readme-examples.png)
 
  </div>
 
@@ -100,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi
 
  This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
 
- Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
 
  ## Citation
 
  Please cite this project in your publications if it helps your research.
 
- **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
- since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
- for the current API) - the DOI always resolves to what was actually reviewed and published, so
- it isn't updated to match.
+ **Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
+ since been substantially refactored, including breaking API changes (see the
+ [documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
+ always resolves to what was actually reviewed and published.
 
  ```bibtex
  @article{Hendria2024,
{visualtorch-1.1.0 → visualtorch-1.2.0}/README.md
@@ -5,21 +5,22 @@
 
  </div>
 
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
 
  **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
 
- **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
- limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
- execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
- value crosses some threshold) only show whichever branch the traced dummy input happened to take;
- (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
- only has its first tensor's shape reflected in that node's size/label - its downstream connections
- are still correct either way. Contributions are welcome!
+ **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
+ limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
+ execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take.
+ Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
+ head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
+ tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
+ the first. Downstream connections are correct either way. Contributions are welcome!
 
  <div align="center">
 
- ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
+ ![VisualTorch Examples](https://raw.githubusercontent.com/willyfh/visualtorch/e6ad79751e0f7412b1074beb45f9baeccd1419e4/docs/source/_static/images/banners/readme-examples.png)
 
  </div>
 
@@ -51,16 +52,16 @@ Please feel free to send a pull request to contribute to this project by followi
 
  This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
 
- Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
 
  ## Citation
 
  Please cite this project in your publications if it helps your research.
 
- **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
- since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
- for the current API) - the DOI always resolves to what was actually reviewed and published, so
- it isn't updated to match.
+ **Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
+ since been substantially refactored, including breaking API changes (see the
+ [documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
+ always resolves to what was actually reviewed and published.
 
  ```bibtex
  @article{Hendria2024,
{visualtorch-1.1.0 → visualtorch-1.2.0}/setup.py
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
 
  setuptools.setup(
  name="visualtorch",
- version="1.1.0",
+ version="1.2.0",
  author="Willy Fitra Hendria",
  author_email="willyfitrahendria@gmail.com",
  description="Architecture visualization of Torch models",
{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_flow.py
@@ -174,16 +174,16 @@ def test_rnn_model_flow_view_runs(rnn_model: nn.Module) -> None:
 
 
  @pytest.mark.parametrize("orientation", ["x", "y", "z"])
- def test_flow_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
+ def test_flow_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
  """Test flow view on a model with a 1D output, for every supported orientation."""
- img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation=orientation)
+ img = flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
  assert img is not None
 
 
- def test_flow_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
- """An unsupported one_dim_orientation should raise a clear ValueError."""
+ def test_flow_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
+ """An unsupported low_dim_orientation should raise a clear ValueError."""
  with pytest.raises(ValueError, match="unsupported orientation"):
- flow_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation="bad")
+ flow_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")
 
 
  def test_flow_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
@@ -535,3 +535,54 @@ def test_flow_view_mismatched_depth_siamese_branches_needs_no_detour() -> None:
  img_matched = flow_view(SiameseNetDepthMatched(), input_shape=input_shape)
 
  assert img_mismatched.size[1] == img_matched.size[1]
+
+
+ def test_flow_view_low_dim_orientation_affects_2d_shapes() -> None:
+ """A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
+ low_dim_orientation too, not just genuine 1D shapes.
+ """ # noqa: D205
+
+ class SequenceClassifier(nn.Module):
+ def __init__(self, hidden_size: int) -> None:
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out, _ = self.lstm(x)
+ return out
+
+ model = SequenceClassifier(hidden_size=64)
+ input_shape = (1, 5, 8)
+
+ sizes = {
+ orientation: flow_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
+ for orientation in ("x", "y", "z")
+ }
+
+ assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
+
+
+ def test_flow_view_2d_shape_seq_len_is_discarded() -> None:
+ """The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
+ only the feature-like dim (e.g. hidden_size) should.
+ """ # noqa: D205
+
+ class SequenceClassifier(nn.Module):
+ def __init__(self, hidden_size: int) -> None:
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out, _ = self.lstm(x)
+ return out
+
+ model = SequenceClassifier(hidden_size=64)
+ img_short_seq = flow_view(model, input_shape=(1, 5, 8))
+ img_long_seq = flow_view(model, input_shape=(1, 50, 8))
+
+ assert img_short_seq.tobytes() == img_long_seq.tobytes()
+
+ model_bigger_hidden = SequenceClassifier(hidden_size=256)
+ img_bigger_hidden = flow_view(model_bigger_hidden, input_shape=(1, 5, 8))
+
+ assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()
{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_lenet_style.py
@@ -174,16 +174,16 @@ def test_rnn_model_lenet_view_runs(rnn_model: nn.Module) -> None:
 
 
  @pytest.mark.parametrize("orientation", ["x", "y", "z"])
- def test_lenet_view_one_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
+ def test_lenet_view_low_dim_orientation(classifier_model: nn.Module, orientation: str) -> None:
  """Test lenet view on a model with a 1D output, for every supported orientation."""
- img = lenet_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation=orientation)
+ img = lenet_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation=orientation)
  assert img is not None
 
 
- def test_lenet_view_invalid_one_dim_orientation_raises(classifier_model: nn.Module) -> None:
- """An unsupported one_dim_orientation should raise a clear ValueError."""
+ def test_lenet_view_invalid_low_dim_orientation_raises(classifier_model: nn.Module) -> None:
+ """An unsupported low_dim_orientation should raise a clear ValueError."""
  with pytest.raises(ValueError, match="unsupported orientation"):
- lenet_view(classifier_model, input_shape=(1, 3, 16, 16), one_dim_orientation="bad")
+ lenet_view(classifier_model, input_shape=(1, 3, 16, 16), low_dim_orientation="bad")
 
 
  def test_lenet_view_with_type_ignore(sequential_model: nn.Sequential) -> None:
@@ -328,3 +328,54 @@ def test_lenet_view_funnels_survive_large_de_differences_between_layers() -> Non
  non_bg = int((np.array(img.convert("RGB")) != 255).any(axis=2).sum())
  error_msg = f"non-background pixel count {non_bg} outside expected range - funnel likely broken"
  assert 110000 <= non_bg <= 145000, error_msg
+
+
+ def test_lenet_view_low_dim_orientation_affects_2d_shapes() -> None:
+ """A 2D shape (e.g. an RNN's (seq_len, hidden_size)) should now respond to
+ low_dim_orientation too, not just genuine 1D shapes.
+ """ # noqa: D205
+
+ class SequenceClassifier(nn.Module):
+ def __init__(self, hidden_size: int) -> None:
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out, _ = self.lstm(x)
+ return out
+
+ model = SequenceClassifier(hidden_size=64)
+ input_shape = (1, 5, 8)
+
+ sizes = {
+ orientation: lenet_view(model, input_shape=input_shape, low_dim_orientation=orientation).size
+ for orientation in ("x", "y", "z")
+ }
+
+ assert len(set(sizes.values())) == 3, f"expected all 3 orientations to differ, got {sizes}"
+
+
+ def test_lenet_view_2d_shape_seq_len_is_discarded() -> None:
+ """The positional-like dim (e.g. seq_len) of a 2D shape shouldn't affect box size -
+ only the feature-like dim (e.g. hidden_size) should.
+ """ # noqa: D205
+
+ class SequenceClassifier(nn.Module):
+ def __init__(self, hidden_size: int) -> None:
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out, _ = self.lstm(x)
+ return out
+
+ model = SequenceClassifier(hidden_size=64)
+ img_short_seq = lenet_view(model, input_shape=(1, 5, 8))
+ img_long_seq = lenet_view(model, input_shape=(1, 50, 8))
+
+ assert img_short_seq.tobytes() == img_long_seq.tobytes()
+
+ model_bigger_hidden = SequenceClassifier(hidden_size=256)
+ img_bigger_hidden = lenet_view(model_bigger_hidden, input_shape=(1, 5, 8))
+
+ assert img_short_seq.tobytes() != img_bigger_hidden.tobytes()
{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_regression_issues.py
@@ -11,8 +11,9 @@ import torch
  from torch import nn
  from visualtorch.backend import extract_architecture
  from visualtorch.flow import flow_view
+ from visualtorch.graph import graph_view
  from visualtorch.lenet_style import lenet_view
- from visualtorch.utils.utils import self_multiply
+ from visualtorch.utils.utils import format_shape_label, self_multiply
 
 
  @pytest.fixture()
@@ -112,3 +113,48 @@ def test_flow_view_recurrent_sequence_length_does_not_inflate_diagram_height(rec
  long_img = flow_view(model, input_shape=(1, 200, 10))
 
  assert long_img.height == short_img.height
+
+
+ @pytest.fixture()
+ def lstm_using_hidden_state_model() -> nn.Module:
+ """A model that consumes `nn.LSTM`'s hidden state (h_n), not its sequence output.
+
+ `nn.LSTM.forward()` returns `(output, (h_n, c_n))` - three tensors, not one. Before the fix,
+ only `output`'s shape was ever recorded, even when (as here) the model actually uses `h_n`
+ instead - silently dropping the shape of the tensor that matters, with nothing shown to
+ indicate more than one tensor even exists.
+ """
+
+ class Model(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
+ self.fc = nn.Linear(20, 5)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ _output, (h_n, _c_n) = self.lstm(x)
+ return self.fc(h_n.squeeze(0))
+
+ return Model()
+
+
+ def test_extract_architecture_records_every_output_tensor_shape(lstm_using_hidden_state_model: nn.Module) -> None:
+ """A multi-output leaf module's TracedLayer should record all three of its output shapes."""
+ architecture = extract_architecture(lstm_using_hidden_state_model, (1, 7, 10))
+ lstm_layer = next(layer for column in architecture.columns for layer in column if isinstance(layer.module, nn.LSTM))
+
+ assert lstm_layer.output_shape == (1, 7, 20)
+ assert lstm_layer.extra_output_shapes == ((1, 1, 20), (1, 1, 20))
+
+
+ def test_format_shape_label_appends_extra_shapes() -> None:
+ """format_shape_label should append every extra shape, and omit the `+` entirely when there are none."""
+ assert format_shape_label((1, 7, 20), ()) == "(1, 7, 20)"
+ assert format_shape_label((1, 7, 20), ((1, 1, 20), (1, 1, 20))) == "(1, 7, 20) + (1, 1, 20) + (1, 1, 20)"
+
+
+ @pytest.mark.parametrize("view", [graph_view, flow_view, lenet_view])
+ def test_show_dimension_includes_every_output_shape(lstm_using_hidden_state_model: nn.Module, view: object) -> None:
+ """show_dimension=True shouldn't crash, and should still work, for a multi-output leaf module."""
+ img = view(lstm_using_hidden_state_model, input_shape=(1, 7, 10), show_dimension=True) # type: ignore[operator]
+ assert img is not None
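The fixture docstring and `test_extract_architecture_records_every_output_tensor_shape` together define the new recording rule. For a leaf module returning `(output, (h_n, c_n))`, the first tensor's shape becomes `output_shape` and the remaining shapes become `extra_output_shapes`. The `recorder.py` changes that implement this rule are not part of this section. What follows is only a minimal sketch of one way to split a nested module output along those lines. `split_output_shapes`, `_flatten_shapes`, and `FakeTensor` are hypothetical names. The sketch assumes that anything with a `.shape` attribute is a tensor and that only tuples and lists need to be walked.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only the .shape attribute is read."""

    shape: tuple[int, ...]


def _flatten_shapes(output: Any) -> list[tuple[int, ...]]:
    """Depth-first, left-to-right walk collecting the shape of every tensor-like leaf."""
    if hasattr(output, "shape"):
        return [tuple(output.shape)]
    if isinstance(output, (tuple, list)):
        shapes: list[tuple[int, ...]] = []
        for item in output:
            shapes.extend(_flatten_shapes(item))
        return shapes
    # Anything else (None, ints, and dicts in this sketch) contributes no shape.
    return []


def split_output_shapes(output: Any) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
    """Hypothetical: first tensor's shape -> output_shape, the rest -> extra_output_shapes."""
    shapes = _flatten_shapes(output)
    if not shapes:
        raise ValueError("module output contains no tensor-like values")
    return shapes[0], tuple(shapes[1:])


# Mimics nn.LSTM(input_size=10, hidden_size=20, batch_first=True) on a (1, 7, 10) input:
# output is (batch, seq, hidden); h_n and c_n are (num_layers, batch, hidden).
lstm_like = (FakeTensor((1, 7, 20)), (FakeTensor((1, 1, 20)), FakeTensor((1, 1, 20))))
print(split_output_shapes(lstm_like))  # ((1, 7, 20), ((1, 1, 20), (1, 1, 20)))
```

The depth-first, left-to-right order keeps `output` first and `h_n` before `c_n`, which matches the tuple the test expects. A module that returns a dict would need an extra branch, which this sketch leaves out.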
{visualtorch-1.1.0 → visualtorch-1.2.0}/tests/test_render.py
@@ -3,11 +3,15 @@
  # Copyright (C) 2024 Willy Fitra Hendria
  # SPDX-License-Identifier: MIT
 
+ from collections import defaultdict
+
  import pytest
  import torch
  from torch import nn
  from visualtorch import render
  from visualtorch.backend import extract_architecture
+ from visualtorch.utils.layer_utils import Input
+ from visualtorch.utils.utils import PALETTES
 
 
  @pytest.fixture()
@@ -148,3 +152,43 @@ def test_render_handles_unused_input_tensor() -> None:
 
  img = render(PartiallyUnusedNet(), input_shape=((1, 10), (1, 5)), style="graph")
  assert img is not None
+
+
+ @pytest.mark.parametrize("palette", sorted(PALETTES))
+ def test_render_runs_for_every_named_palette(sequential_model: nn.Sequential, palette: str) -> None:
+ """Every named palette should render without error - catches any malformed hex color."""
+ img = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette=palette)
+ assert img is not None
+
+
+ def test_render_rejects_unsupported_palette(sequential_model: nn.Sequential) -> None:
+ """An unrecognized palette name should raise a clear error, not silently fall back."""
+ with pytest.raises(ValueError, match="Unsupported palette"):
+ render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette="bogus")
+
+
+ def test_render_palette_changes_fallback_colors(sequential_model: nn.Sequential) -> None:
+ """A different palette should actually change the colors of unmapped layer types."""
+ default = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph")
+ dracula = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", palette="dracula")
+
+ assert default.tobytes() != dracula.tobytes()
+
+
+ def test_render_color_map_overrides_palette(sequential_model: nn.Sequential) -> None:
+ """An explicit color_map entry should still win over the palette fallback."""
+ color_map: dict = defaultdict(dict)
+ color_map[Input]["fill"] = "#abcdef"
+ color_map[nn.Conv2d]["fill"] = "#123456"
+ color_map[nn.ReLU]["fill"] = "#654321"
+
+ okabe_ito = render(sequential_model, input_shape=(1, 3, 16, 16), style="graph", color_map=color_map)
+ dracula = render(
+ sequential_model,
+ input_shape=(1, 3, 16, 16),
+ style="graph",
+ color_map=color_map,
+ palette="dracula",
+ )
+
+ assert okabe_ito.tobytes() == dracula.tobytes()
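The four new tests pin down the palette contract:

- every name in `PALETTES` renders;
- an unknown name raises a `ValueError` matching "Unsupported palette";
- a different palette changes the fallback colors;
- an explicit `color_map` fill always wins.

The real `resolve_palette` and `ColorWheel` live in `visualtorch/utils/utils.py`, whose diff is outside this section. The sketch below therefore only illustrates the precedence rule. `assign_fills` is a hypothetical helper, and the `PALETTES` contents are assumptions. The `okabe_ito` entry uses the commonly published Okabe-Ito hex codes, and the `grayscale` ramp is made up.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping
from itertools import cycle

# Illustrative entries only; the real PALETTES in visualtorch/utils/utils.py may differ.
PALETTES: dict[str, list[str]] = {
    # Commonly published Okabe-Ito colorblind-safe hex codes.
    "okabe_ito": ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#000000"],
    # Made-up grayscale ramp for the example.
    "grayscale": ["#222222", "#555555", "#888888", "#BBBBBB"],
}


def resolve_palette(name: str) -> list[str]:
    """Hypothetical: look up a named palette and refuse unknown names instead of falling back."""
    if name not in PALETTES:
        supported = ", ".join(sorted(PALETTES))
        raise ValueError(f"Unsupported palette {name!r}; expected one of: {supported}")
    return list(PALETTES[name])


def assign_fills(
    layer_types: Iterable[type],
    color_map: Mapping[type, Mapping[str, str]] | None = None,
    palette: str = "okabe_ito",
) -> dict[type, str]:
    """Hypothetical: an explicit color_map fill wins; other types take the next palette color."""
    colors = cycle(resolve_palette(palette))  # validates the name even if every type is overridden
    color_map = color_map or {}
    fills: dict[type, str] = {}
    for layer_type in layer_types:
        if layer_type in fills:
            continue  # one color per layer type across the whole diagram
        override = color_map.get(layer_type, {}).get("fill")
        fills[layer_type] = override if override is not None else next(colors)
    return fills
```

With built-in types standing in for layer classes, compare `assign_fills([int, str, float], {str: {"fill": "#123456"}})` with the same call using `palette="grayscale"`. Both give `str` the same fill, while `int` and `float` change. That is the image-level property that `test_render_color_map_overrides_palette` and `test_render_palette_changes_fallback_colors` check.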
{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/__init__.py
@@ -11,6 +11,7 @@ from visualtorch.render import (
  render,
  )
  from visualtorch.utils.layer_utils import Input
+ from visualtorch.utils.utils import PALETTES
 
  __all__ = [
  "render",
@@ -19,4 +20,5 @@ __all__ = [
  "FlowStyleOptions",
  "LenetStyleOptions",
  "Input",
+ "PALETTES",
  ]
{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/backend.py
@@ -67,7 +67,10 @@ def extract_architecture(model: nn.Module, input_shape: InputShape) -> Architect
  """
  input_shapes = validate_input_shape(input_shape)
 
- id_to_module, id_to_output_shape, edges, input_ids = trace_module_graph(model, input_shapes)
+ id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, input_ids = trace_module_graph(
+ model,
+ input_shapes,
+ )
 
  nodes = list(id_to_module.keys())
  id_to_index = {node_id: idx for idx, node_id in enumerate(nodes)}
@@ -114,6 +117,7 @@ def extract_architecture(model: nn.Module, input_shape: InputShape) -> Architect
  module=id_to_module[node_id],
  output_shape=id_to_output_shape[node_id],
  node_id=node_id,
+ extra_output_shapes=id_to_extra_output_shapes.get(node_id, ()),
  )
  columns[depth[node_id]].append(wrapper)
 
{visualtorch-1.1.0 → visualtorch-1.2.0}/visualtorch/flow.py
@@ -22,8 +22,10 @@ from .utils.utils import (
  ColorWheel,
  ImageDraw,
  InputShape,
+ format_shape_label,
  get_rgba_tuple,
  linear_layout,
+ resolve_palette,
  self_multiply,
  vertical_image_concat,
  )
@@ -41,7 +43,8 @@ def flow_view(
  scale_xy: float = 1,
  type_ignore: list | None = None,
  color_map: dict | None = None,
- one_dim_orientation: str = "z",
+ palette: str = "okabe_ito",
+ low_dim_orientation: str = "z",
  background_fill: str | tuple[int, ...] = "white",
  draw_volume: bool = True,
  padding: int = 10,
@@ -75,7 +78,13 @@ def flow_view(
  type_ignore (list, optional): List of layer types in the torch model to ignore during drawing.
  color_map (dict, optional): Dictionary defining fill and outline colors for each layer by class type.
  Will fallback to default values for unspecified classes.
- one_dim_orientation (str, optional): Axis on which one-dim layers should be drawn. E.g., 'x', 'y', or 'z'.
+ palette (str, optional): Named color palette used as the fallback for any layer type not
+ given an explicit override via `color_map`. One of `"okabe_ito"` (default,
+ colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
+ `"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
+ low_dim_orientation (str, optional): Axis on which a layer without real spatial/channel
+ structure (a 1D shape, or a 2D shape like an RNN/attention layer's
+ `(seq_len, hidden_size)`) should be drawn. One of `'x'`, `'y'`, or `'z'`.
  background_fill (str or tuple, optional): Background color for the image. A string or a tuple (R, G, B, A).
  draw_volume (bool, optional): Flag to switch between 3D volumetric view and 2D box view.
  padding (int, optional): Distance in pixels before the first and after the last layer.
@@ -132,9 +141,9 @@ def flow_view(
  filtered_columns = [column for column in filtered_columns if column]
 
  layer_types: list[type] = []
- color_wheel = ColorWheel()
+ color_wheel = ColorWheel(colors=resolve_palette(palette))
  make_box = _box_factory(
- one_dim_orientation,
+ low_dim_orientation,
  scale_xy,
  min_xy,
  max_xy,
@@ -232,7 +241,7 @@ def flow_view(
 
 
  def _box_factory(
- one_dim_orientation: str,
+ low_dim_orientation: str,
  scale_xy: float,
  min_xy: int,
  max_xy: int,
@@ -251,19 +260,20 @@ def _box_factory(
  def make_box(layer: TracedLayer) -> Box:
  shape = layer.output_shape[1:] # drop batch size
 
- if len(shape) == 1:
- if one_dim_orientation in ("x", "y", "z"):
- shape = (1,) * "cxyz".index(one_dim_orientation) + shape
+ if len(shape) in (1, 2):
+ # Neither a 1D nor a 2D shape has real spatial/channel structure - there's nothing
+ # to distinguish "channel" from "spatial" the way a genuine (C, H, W) feature map
+ # does. Take the last value (for 2D, e.g. an RNN/attention layer's
+ # (seq_len, hidden_size), this is the feature/channel-like one, matching PyTorch's
+ # (..., seq, feature) convention for sequence data; for 1D it's the only value) and
+ # let the user place it on whichever axis they choose, same as any 1D value - the
+ # positional-like dim, if any, is discarded either way.
+ value = shape[-1]
+ if low_dim_orientation in ("x", "y", "z"):
+ shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
  else:
- error_msg = f"unsupported orientation: {one_dim_orientation}"
+ error_msg = f"unsupported orientation: {low_dim_orientation}"
  raise ValueError(error_msg)
- elif len(shape) == 2:
- # A 2D non-batch shape (e.g. (seq_len, hidden_size) from an RNN/attention layer)
- # isn't a CNN feature map missing a channel dim - there's no channel axis at all.
- # Box's "3D" skew (de, below) is driven by shape[1], so a dummy 1 goes there
- # instead of either real dim, keeping the two real dims on the box's actual width
- # and height instead of one of them inflating the skew for a long sequence.
- shape = (shape[0], 1, shape[1])
 
  ori_shape = shape
  shape = shape + (1,) * (4 - len(shape)) # expand 4D.
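The rewritten branch in `make_box` can be checked without drawing anything. The sketch below restates only the shape arithmetic visible in this hunk: `shape[-1]`, the `"cxyz".index(...)` prefix, and the 4D expansion. `low_dim_box_shape` is a hypothetical helper. Whatever `_box_factory` does next with `scale_xy`, `min_xy`, `max_xy`, and `draw_volume` is deliberately left out.

```python
from __future__ import annotations


def low_dim_box_shape(output_shape: tuple[int, ...], low_dim_orientation: str = "z") -> tuple[int, ...]:
    """Hypothetical helper restating the 1D/2D branch of make_box, then padding to 4D."""
    shape = tuple(output_shape[1:])  # drop batch size
    if len(shape) in (1, 2):
        # Keep only the last, feature-like value; for 2D the positional dim (seq_len) is dropped.
        value = shape[-1]
        if low_dim_orientation in ("x", "y", "z"):
            # "cxyz".index() yields 1, 2 or 3 leading 1s, so the value lands on the chosen axis.
            shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
        else:
            raise ValueError(f"unsupported orientation: {low_dim_orientation}")
    return shape + (1,) * (4 - len(shape))  # expand 4D, as in the hunk


for orientation in ("x", "y", "z"):
    # x -> (1, 64, 1, 1), y -> (1, 1, 64, 1), z -> (1, 1, 1, 64)
    print(orientation, low_dim_box_shape((1, 5, 64), orientation))
```

A `(1, 5, 64)` LSTM output and a `(1, 50, 64)` one both reduce to the single value `64` before any drawing. That is why `test_flow_view_2d_shape_seq_len_is_discarded` and its LeNet counterpart can expect byte-identical images. The removed code, by contrast, mapped `(seq_len, hidden_size)` to `(seq_len, 1, hidden_size)` and so kept `seq_len` in the box.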
@@ -278,6 +288,7 @@ def _box_factory(
 
  box = Box()
  box.output_shape = tuple(ori_shape)
+ box.extra_output_shapes = layer.extra_output_shapes
  box.de = int(x / 3) if draw_volume else 0
 
  box.x1 = 0
@@ -476,7 +487,7 @@ def _draw_legend(
476
487
 
477
488
  def _column_label_and_center(column: list[VolumetricBox]) -> tuple[str, float]:
478
489
  """A column's shape label (joined across branches) and its shared x-center."""
479
- label = " / ".join(str(box.output_shape) for box in column)
490
+ label = " / ".join(format_shape_label(box.output_shape, box.extra_output_shapes) for box in column)
480
491
  center_x = (column[0].x1 + column[0].x2) / 2
481
492
  return label, center_x
482
493
 
@@ -15,7 +15,7 @@ from PIL import Image, ImageFont
15
15
  from .backend import extract_architecture
16
16
  from .connectors import compute_skip_levels, draw_connector
17
17
  from .utils.traced_layer import TracedLayer
18
- from .utils.utils import Box, Circle, ColorWheel, Ellipses, ImageDraw, InputShape
18
+ from .utils.utils import Box, Circle, ColorWheel, Ellipses, ImageDraw, InputShape, format_shape_label, resolve_palette
19
19
 
20
20
 
21
21
  def graph_view(
@@ -23,6 +23,7 @@ def graph_view(
23
23
  input_shape: InputShape,
24
24
  to_file: str | None = None,
25
25
  color_map: dict[Any, Any] | None = None,
26
+ palette: str = "okabe_ito",
26
27
  node_size: int = 50,
27
28
  background_fill: str | tuple[int, ...] = "white",
28
29
  padding: int = 10,
@@ -51,6 +52,10 @@ def graph_view(
51
52
  will disable writing.
52
53
  color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
53
54
  to default values for not specified classes.
55
+ palette (str, optional): Named color palette used as the fallback for any layer type not
56
+ given an explicit override via `color_map`. One of `"okabe_ito"` (default,
57
+ colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
58
+ `"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
54
59
  node_size (int, optional): Size in pixels each node will have.
55
60
  background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
56
61
  padding (int, optional): Distance in pixels before the first and after the last layer.
@@ -111,7 +116,7 @@ def graph_view(
111
116
  _color_map,
112
117
  opacity,
113
118
  layer_spacing,
114
- ColorWheel(),
119
+ ColorWheel(colors=resolve_palette(palette)),
115
120
  )
116
121
 
117
122
  # An edge whose endpoint was just dropped above (a hidden input's own edges) can no longer
@@ -409,7 +414,8 @@ def _create_architecture(
409
414
 
410
415
  id_to_node_list_map[layer.node_id] = layer_nodes
411
416
  nodes.extend(layer_nodes)
412
- column_labels.append((str(layer.output_shape), current_x + node_size / 2, current_y))
417
+ label = format_shape_label(layer.output_shape, layer.extra_output_shapes)
418
+ column_labels.append((label, current_x + node_size / 2, current_y))
413
419
  current_y += 2 * node_size
414
420
 
415
421
  layer_y.append(current_y - node_spacing - 2 * node_size)
@@ -17,7 +17,16 @@ from .backend import Architecture, extract_architecture
17
17
  from .connectors import compute_skip_levels, draw_connector
18
18
  from .utils.layer_utils import Input
19
19
  from .utils.traced_layer import TracedLayer
20
- from .utils.utils import ColorWheel, ImageDraw, InputShape, StackedBox, get_rgba_tuple, self_multiply
20
+ from .utils.utils import (
21
+ ColorWheel,
22
+ ImageDraw,
23
+ InputShape,
24
+ StackedBox,
25
+ format_shape_label,
26
+ get_rgba_tuple,
27
+ resolve_palette,
28
+ self_multiply,
29
+ )
21
30
 
22
31
  _LABEL_ROW_HEIGHT = 100
23
32
 
@@ -33,7 +42,8 @@ def lenet_view(
33
42
  scale_xy: float = 1,
34
43
  type_ignore: list | None = None,
35
44
  color_map: dict | None = None,
36
- one_dim_orientation: str = "z",
45
+ palette: str = "okabe_ito",
46
+ low_dim_orientation: str = "z",
37
47
  background_fill: str | tuple[int, ...] = "white",
38
48
  padding: int = 10,
39
49
  spacing: int = 10,
@@ -69,7 +79,13 @@ def lenet_view(
69
79
  type_ignore (list, optional): List of layer types in the torch model to ignore during drawing.
70
80
  color_map (dict, optional): Dictionary defining fill and outline colors for each layer by class type.
71
81
  Will fallback to default values for unspecified classes.
72
- one_dim_orientation (str, optional): Axis on which one-dim layers should be drawn. E.g., 'x', 'y', or 'z'.
82
+ palette (str, optional): Named color palette used as the fallback for any layer type not
83
+ given an explicit override via `color_map`. One of `"okabe_ito"` (default,
84
+ colorblind-safe), `"tol_bright"`, `"tol_muted"`, `"tab10"`, `"grayscale"`, `"nord"`,
85
+ `"dracula"`, `"gruvbox"`, `"solarized"`, `"material"`, `"catppuccin"`.
86
+ low_dim_orientation (str, optional): Axis on which a layer without real spatial/channel
87
+ structure (a 1D shape, or a 2D shape like an RNN/attention layer's
88
+ `(seq_len, hidden_size)`) should be drawn. One of `'x'`, `'y'`, or `'z'`.
73
89
  background_fill (str or tuple, optional): Background color for the image. A string or a tuple (R, G, B, A).
74
90
  padding (int, optional): Distance in pixels before the first and after the last layer.
75
91
  spacing (int, optional): Spacing in pixels between two layers.
@@ -122,7 +138,7 @@ def lenet_view(
122
138
 
123
139
  layer_types: list[type] = []
124
140
  make_box = _box_factory(
125
- one_dim_orientation,
141
+ low_dim_orientation,
126
142
  scale_xy,
127
143
  min_xy,
128
144
  max_xy,
@@ -134,7 +150,7 @@ def lenet_view(
134
150
  opacity,
135
151
  offset_z,
136
152
  layer_types,
137
- ColorWheel(),
153
+ ColorWheel(colors=resolve_palette(palette)),
138
154
  )
139
155
  column_layout = layout_columns(
140
156
  filtered_columns,
@@ -218,7 +234,7 @@ def _right_extent_for(box: VolumetricBox) -> float:
218
234
 
219
235
 
220
236
  def _box_factory(
221
- one_dim_orientation: str,
237
+ low_dim_orientation: str,
222
238
  scale_xy: float,
223
239
  min_xy: int,
224
240
  max_xy: int,
@@ -237,19 +253,20 @@ def _box_factory(
237
253
  def make_box(layer: TracedLayer) -> StackedBox:
238
254
  shape = layer.output_shape[1:] # drop batch size
239
255
 
240
- if len(shape) == 1:
241
- if one_dim_orientation in ("x", "y", "z"):
242
- shape = (1,) * "cxyz".index(one_dim_orientation) + shape
256
+ if len(shape) in (1, 2):
257
+ # Neither a 1D nor a 2D shape has real spatial/channel structure - there's nothing
258
+ # to distinguish "channel" from "spatial" the way a genuine (C, H, W) feature map
259
+ # does. Take the last value (for 2D, e.g. an RNN/attention layer's
260
+ # (seq_len, hidden_size), this is the feature/channel-like one, matching PyTorch's
261
+ # (..., seq, feature) convention for sequence data; for 1D it's the only value) and
262
+ # let the user place it on whichever axis they choose, same as any 1D value - the
263
+ # positional-like dim, if any, is discarded either way.
264
+ value = shape[-1]
265
+ if low_dim_orientation in ("x", "y", "z"):
266
+ shape = (1,) * "cxyz".index(low_dim_orientation) + (value,)
243
267
  else:
244
- error_msg = f"unsupported orientation: {one_dim_orientation}"
268
+ error_msg = f"unsupported orientation: {low_dim_orientation}"
245
269
  raise ValueError(error_msg)
246
- elif len(shape) == 2:
247
- # A 2D non-batch shape (e.g. (seq_len, hidden_size) from an RNN/attention layer)
248
- # isn't a CNN feature map missing a channel dim - there's no channel axis at all.
249
- # StackedBox's slice count (de, below) is driven by shape[0], so a dummy 1 goes
250
- # there instead of either real dim, keeping the two real dims on the box's actual
251
- # width and height instead of one of them being drawn as that many stacked slices.
252
- shape = (1, *shape)
253
270
 
254
271
  ori_shape = shape
255
272
  shape = shape + (1,) * (4 - len(shape)) # expand 4D.
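The practical effect of this LeNet-style change on a 2D `(seq_len, hidden_size)` output is easiest to see by running the removed and added branches side by side. This is a minimal sketch; both function names are hypothetical, and only the shape logic visible in the diff is modelled (sizing and `StackedBox` construction are omitted).

```python
def lenet_shape_v1_1(shape: tuple[int, ...], orientation: str = "z") -> tuple[int, ...]:
    """Shape logic of the removed 1.1.0 lines (hypothetical reconstruction)."""
    if len(shape) == 1:
        if orientation not in ("x", "y", "z"):
            raise ValueError(f"unsupported orientation: {orientation}")
        shape = (1,) * "cxyz".index(orientation) + shape
    elif len(shape) == 2:
        # 1.1.0: dummy 1 in the slice-count slot, both real dims kept, orientation ignored.
        shape = (1, *shape)
    return shape + (1,) * (4 - len(shape))


def lenet_shape_v1_2(shape: tuple[int, ...], orientation: str = "z") -> tuple[int, ...]:
    """Shape logic of the added 1.2.0 lines (hypothetical reconstruction)."""
    if len(shape) in (1, 2):
        if orientation not in ("x", "y", "z"):
            raise ValueError(f"unsupported orientation: {orientation}")
        # 1.2.0: keep only the last (feature-like) value and place it on the chosen axis.
        shape = (1,) * "cxyz".index(orientation) + (shape[-1],)
    return shape + (1,) * (4 - len(shape))


# An RNN-style output with seq_len=50, hidden_size=256:
print(lenet_shape_v1_1((50, 256)))  # (1, 50, 256, 1)
print(lenet_shape_v1_2((50, 256)))  # (1, 1, 1, 256)
```

1D shapes behave identically in both versions. For 2D shapes, 1.2.0 now honours `low_dim_orientation`, drops `seq_len` from the drawing, and, unlike 1.1.0, raises `ValueError` for an invalid orientation.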
@@ -266,6 +283,7 @@ def _box_factory(
266
283
  box.offset_z = offset_z
267
284
  box.label = layer.module.name() if isinstance(layer.module, Input) else layer_type.__name__
268
285
  box.output_shape = tuple(ori_shape)
286
+ box.extra_output_shapes = layer.extra_output_shapes
269
287
  box.de = z
270
288
 
271
289
  box.x1 = 0
@@ -405,6 +423,7 @@ def _draw_labels(
405
423
  for box in column:
406
424
  loc_x = box.x1 + (box.x2 - box.x1) // 4
407
425
  label = getattr(box, "label", type(box).__name__)
408
- draw_text.text((loc_x, img.height - 50), f"{label} {box.output_shape}", font=font, fill=font_color)
426
+ shape_label = format_shape_label(box.output_shape, box.extra_output_shapes)
427
+ draw_text.text((loc_x, img.height - 50), f"{label} {shape_label}", font=font, fill=font_color)
409
428
 
410
429
  return Image.alpha_composite(img, text_img)
@@ -32,6 +32,7 @@ class CommonOptions:
32
32
 
33
33
  to_file: str | None = None
34
34
  color_map: dict[Any, Any] | None = None
35
+ palette: str = "okabe_ito"
35
36
  background_fill: str | tuple[int, ...] = "white"
36
37
  padding: int = 10
37
38
  opacity: int = 255
@@ -66,7 +67,7 @@ class FlowStyleOptions:
66
67
  scale_z: float = 0.1
67
68
  scale_xy: float = 1
68
69
  type_ignore: list[type] | None = None
69
- one_dim_orientation: str = "z"
70
+ low_dim_orientation: str = "z"
70
71
  draw_volume: bool = True
71
72
  spacing: int = 10
72
73
  draw_funnel: bool = True
@@ -86,7 +87,7 @@ class LenetStyleOptions:
86
87
  scale_z: float = 1
87
88
  scale_xy: float = 1
88
89
  type_ignore: list[type] | None = None
89
- one_dim_orientation: str = "z"
90
+ low_dim_orientation: str = "z"
90
91
  spacing: int = 10
91
92
  draw_funnel: bool = True
92
93
  shade_step: int = 10
@@ -107,6 +108,7 @@ def _render_graph(
107
108
  input_shape,
108
109
  to_file=common.to_file,
109
110
  color_map=common.color_map,
111
+ palette=common.palette,
110
112
  node_size=options.node_size,
111
113
  background_fill=common.background_fill,
112
114
  padding=common.padding,
@@ -143,7 +145,8 @@ def _render_flow(
143
145
  scale_xy=options.scale_xy,
144
146
  type_ignore=options.type_ignore,
145
147
  color_map=common.color_map,
146
- one_dim_orientation=options.one_dim_orientation,
148
+ palette=common.palette,
149
+ low_dim_orientation=options.low_dim_orientation,
147
150
  background_fill=common.background_fill,
148
151
  draw_volume=options.draw_volume,
149
152
  padding=common.padding,
@@ -177,7 +180,8 @@ def _render_lenet(
177
180
  scale_xy=options.scale_xy,
178
181
  type_ignore=options.type_ignore,
179
182
  color_map=common.color_map,
180
- one_dim_orientation=options.one_dim_orientation,
183
+ palette=common.palette,
184
+ low_dim_orientation=options.low_dim_orientation,
181
185
  background_fill=common.background_fill,
182
186
  padding=common.padding,
183
187
  spacing=options.spacing,
@@ -106,28 +106,34 @@ def _wrap_and_stamp(obj: Any, ids: set[str]) -> Any:
106
106
  return obj
107
107
 
108
108
 
109
- def _first_tensor_shape(obj: Any) -> tuple[int, ...]:
110
- """Recursively find the shape of the first tensor inside obj, or () if there isn't one."""
109
+ def _all_tensor_shapes(obj: Any) -> list[tuple[int, ...]]:
110
+ """Recursively collect the shape of every tensor inside obj, in encounter order.
111
+
112
+ A module's return value isn't always a single tensor - `nn.LSTM` returns
113
+ `(output, (h_n, c_n))`, `nn.MultiheadAttention` returns `(attn_output, attn_weights)`, and a
114
+ custom module can return any tuple/list/dict of tensors. Collecting all of them (rather than
115
+ just the first) is what lets a caller show every output shape instead of silently dropping
116
+ every tensor after the first.
117
+ """
111
118
  if isinstance(obj, torch.Tensor):
112
- return tuple(obj.shape)
119
+ return [tuple(obj.shape)]
113
120
  if isinstance(obj, Mapping):
121
+ shapes = []
114
122
  for value in obj.values():
115
- shape = _first_tensor_shape(value)
116
- if shape:
117
- return shape
118
- return ()
123
+ shapes.extend(_all_tensor_shapes(value))
124
+ return shapes
119
125
  if isinstance(obj, list | tuple):
126
+ shapes = []
120
127
  for value in obj:
121
- shape = _first_tensor_shape(value)
122
- if shape:
123
- return shape
124
- return ()
125
- return ()
128
+ shapes.extend(_all_tensor_shapes(value))
129
+ return shapes
130
+ return []
126
131
 
127
132
 
128
133
  def _wrapped_module_call(
129
134
  id_to_module: dict[str, nn.Module],
130
135
  id_to_output_shape: dict[str, tuple[int, ...]],
136
+ id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
131
137
  edges: list[tuple[str, str]],
132
138
  call_counts: dict[int, int],
133
139
  ) -> Any:
@@ -154,8 +160,11 @@ def _wrapped_module_call(
154
160
  call_counts[base_id] = call_index + 1
155
161
  node_id = f"{base_id}#{call_index}"
156
162
 
163
+ shapes = _all_tensor_shapes(out)
157
164
  id_to_module[node_id] = mod
158
- id_to_output_shape[node_id] = _first_tensor_shape(out)
165
+ id_to_output_shape[node_id] = shapes[0] if shapes else ()
166
+ if len(shapes) > 1:
167
+ id_to_extra_output_shapes[node_id] = tuple(shapes[1:])
159
168
  edges.extend((producer_id, node_id) for producer_id in producer_ids)
160
169
  out = _wrap_and_stamp(out, {node_id})
161
170
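The recording change comes down to collecting every tensor shape in encounter order, then splitting off the first one. The sketch below mirrors `_all_tensor_shapes` and the `shapes[0] if shapes else ()` split. To stay runnable without PyTorch, it uses a hypothetical `FakeTensor` stand-in (only `.shape` is modelled) in place of `torch.Tensor`. `split_primary_and_extras` is also a hypothetical helper name.

```python
from collections.abc import Mapping
from typing import Any


class FakeTensor:
    """Stand-in for torch.Tensor: only the `.shape` attribute is modelled."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape


def collect_tensor_shapes(obj: Any) -> list[tuple[int, ...]]:
    """Depth-first walk over tensors, mappings, lists and tuples, in encounter order."""
    if isinstance(obj, FakeTensor):
        return [tuple(obj.shape)]
    shapes: list[tuple[int, ...]] = []
    if isinstance(obj, Mapping):
        for value in obj.values():
            shapes.extend(collect_tensor_shapes(value))
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            shapes.extend(collect_tensor_shapes(value))
    return shapes  # non-tensor leaves contribute nothing


def split_primary_and_extras(out: Any) -> tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]:
    """First shape drives box sizing; the rest are only used for the label."""
    shapes = collect_tensor_shapes(out)
    primary = shapes[0] if shapes else ()
    return primary, tuple(shapes[1:])


# Shaped like nn.LSTM(batch_first=True) output for batch 2, seq 5, hidden 20.
lstm_like = (FakeTensor((2, 5, 20)), (FakeTensor((1, 2, 20)), FakeTensor((1, 2, 20))))
print(split_primary_and_extras(lstm_like))
# ((2, 5, 20), ((1, 2, 20), (1, 2, 20)))
```

The real recorder only stores an extras entry when there is more than one shape; returning an empty tuple for the single-output case, as here, matches the `= ()` default on `TracedLayer.extra_output_shapes`.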
 
@@ -171,11 +180,13 @@ class Recorder:
171
180
  self,
172
181
  id_to_module: dict[str, nn.Module],
173
182
  id_to_output_shape: dict[str, tuple[int, ...]],
183
+ id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
174
184
  edges: list[tuple[str, str]],
175
185
  call_counts: dict[int, int],
176
186
  ) -> None:
177
187
  self._id_to_module = id_to_module
178
188
  self._id_to_output_shape = id_to_output_shape
189
+ self._id_to_extra_output_shapes = id_to_extra_output_shapes
179
190
  self._edges = edges
180
191
  self._call_counts = call_counts
181
192
 
@@ -184,6 +195,7 @@ class Recorder:
184
195
  nn.Module.__call__ = _wrapped_module_call( # type: ignore[method-assign]
185
196
  self._id_to_module,
186
197
  self._id_to_output_shape,
198
+ self._id_to_extra_output_shapes,
187
199
  self._edges,
188
200
  self._call_counts,
189
201
  )
@@ -197,7 +209,13 @@ class Recorder:
197
209
  def trace_module_graph(
198
210
  model: nn.Module,
199
211
  input_shapes: tuple[tuple[int, ...], ...],
200
- ) -> tuple[dict[str, nn.Module], dict[str, tuple[int, ...]], list[tuple[str, str]], list[str]]:
212
+ ) -> tuple[
213
+ dict[str, nn.Module],
214
+ dict[str, tuple[int, ...]],
215
+ dict[str, tuple[tuple[int, ...], ...]],
216
+ list[tuple[str, str]],
217
+ list[str],
218
+ ]:
201
219
  """Trace a forward pass to recover the leaf-module call graph.
202
220
 
203
221
  Args:
@@ -210,7 +228,11 @@ def trace_module_graph(
210
228
  tuple: A tuple containing:
211
229
  - id_to_module (dict): Mapping from node id to the leaf module. A leaf module
212
230
  called more than once gets one entry per call (`f"{id(module)}#{call_index}"`).
213
- - id_to_output_shape (dict): Mapping from node id to that module's output shape.
231
+ - id_to_output_shape (dict): Mapping from node id to that module's *first* output
232
+ tensor's shape (the one used for box sizing).
233
+ - id_to_extra_output_shapes (dict): Mapping from node id to the shapes of any
234
+ *additional* output tensors beyond the first (e.g. `nn.LSTM`'s `(h_n, c_n)`) -
235
+ only present for nodes that return more than one tensor.
214
236
  - edges (list): `(producer_node_id, consumer_node_id)` pairs, in call order.
215
237
  - input_ids (list): One synthetic node id per input tensor, in the same order as
216
238
  `input_shapes`.
@@ -223,10 +245,11 @@ def trace_module_graph(
223
245
 
224
246
  id_to_module: dict[str, nn.Module] = {}
225
247
  id_to_output_shape: dict[str, tuple[int, ...]] = {}
248
+ id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]] = {}
226
249
  edges: list[tuple[str, str]] = []
227
250
  call_counts: dict[int, int] = {}
228
251
 
229
- with Recorder(id_to_module, id_to_output_shape, edges, call_counts):
252
+ with Recorder(id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, call_counts):
230
253
  if isinstance(model, nn.ModuleList):
231
254
  # nn.ModuleList has no forward() of its own - it's a plain container, not meant to
232
255
  # be called directly - so drive it the same way a user would: chain each child call.
@@ -244,4 +267,4 @@ def trace_module_graph(
244
267
  model(*dummy_inputs)
245
268
 
246
269
  input_ids = [f"{INPUT_NODE_ID}#{i}" for i in range(len(input_shapes))]
247
- return id_to_module, id_to_output_shape, edges, input_ids
270
+ return id_to_module, id_to_output_shape, id_to_extra_output_shapes, edges, input_ids
@@ -20,3 +20,9 @@ class TracedLayer:
20
20
  module: nn.Module
21
21
  output_shape: tuple[int, ...]
22
22
  node_id: str
23
+ extra_output_shapes: tuple[tuple[int, ...], ...] = ()
24
+ """Shapes of any additional output tensors beyond `output_shape` (e.g. `nn.LSTM`'s hidden and
25
+ cell state, returned alongside its main sequence output) - empty for a module that returns
26
+ just one tensor. `output_shape` alone still drives box sizing; this is only ever read to
27
+ extend the `show_dimension` label so those extra tensors aren't silently unaccounted for.
28
+ """
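The recorder only writes an entry into `id_to_extra_output_shapes` for nodes that return more than one tensor. A consumer therefore has to treat a missing key as "no extras", which is also what the `= ()` default on `TracedLayer.extra_output_shapes` expresses. How `backend.py` actually builds `TracedLayer` objects is not shown in this excerpt, so the sketch below is an assumption: `TracedLayerSketch` and `build_layers` are hypothetical, and a layer-name string stands in for the `nn.Module`.

```python
from dataclasses import dataclass


@dataclass
class TracedLayerSketch:
    name: str  # stand-in for the traced nn.Module
    output_shape: tuple[int, ...]
    node_id: str
    extra_output_shapes: tuple[tuple[int, ...], ...] = ()  # empty for single-output layers


def build_layers(
    id_to_name: dict[str, str],
    id_to_output_shape: dict[str, tuple[int, ...]],
    id_to_extra_output_shapes: dict[str, tuple[tuple[int, ...], ...]],
) -> list[TracedLayerSketch]:
    """Join the recorder's dicts; the extras dict is sparse, so default to ()."""
    return [
        TracedLayerSketch(
            name=id_to_name[node_id],
            output_shape=id_to_output_shape[node_id],
            node_id=node_id,
            extra_output_shapes=id_to_extra_output_shapes.get(node_id, ()),
        )
        for node_id in id_to_name
    ]


layers = build_layers(
    {"lstm#0": "LSTM", "fc#0": "Linear"},
    {"lstm#0": (2, 5, 20), "fc#0": (2, 10)},
    {"lstm#0": ((1, 2, 20), (1, 2, 20))},  # only the multi-output node has an entry
)
print([layer.extra_output_shapes for layer in layers])
# [((1, 2, 20), (1, 2, 20)), ()]
```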
@@ -66,6 +66,7 @@ class StackedBox(Shape):
66
66
  offset_z: int
67
67
  label: str
68
68
  output_shape: tuple
69
+ extra_output_shapes: tuple[tuple[int, ...], ...] = ()
69
70
 
70
71
  def draw(self, draw: ImageDraw) -> None:
71
72
  """Draw box shape."""
@@ -114,6 +115,7 @@ class Box(Shape):
114
115
  de: int
115
116
  shade: int
116
117
  output_shape: tuple[int, ...]
118
+ extra_output_shapes: tuple[tuple[int, ...], ...] = ()
117
119
 
118
120
  def draw(self, draw: ImageDraw) -> None:
119
121
  """Draw box shape."""
@@ -225,26 +227,116 @@ class Ellipses(Shape):
225
227
  )
226
228
 
227
229
 
230
+ PALETTES: dict[str, list[str]] = {
231
+ # Okabe-Ito: a colorblind-safe palette (Okabe & Ito, 2008) widely recommended for
232
+ # scientific visualization, e.g. in Nature's figure guidelines.
233
+ "okabe_ito": ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"],
234
+ # Paul Tol's "bright" qualitative scheme - colorblind-safe, higher-contrast alternative.
235
+ "tol_bright": ["#4477AA", "#EE6677", "#228833", "#CCBB44", "#66CCEE", "#AA3377", "#BBBBBB"],
236
+ # Paul Tol's "muted" qualitative scheme - colorblind-safe, softer aesthetic.
237
+ "tol_muted": [
238
+ "#CC6677",
239
+ "#332288",
240
+ "#DDCC77",
241
+ "#117733",
242
+ "#88CCEE",
243
+ "#882255",
244
+ "#44AA99",
245
+ "#999933",
246
+ "#AA4499",
247
+ ],
248
+ # matplotlib's default color cycle.
249
+ "tab10": [
250
+ "#1f77b4",
251
+ "#ff7f0e",
252
+ "#2ca02c",
253
+ "#d62728",
254
+ "#9467bd",
255
+ "#8c564b",
256
+ "#e377c2",
257
+ "#7f7f7f",
258
+ "#bcbd22",
259
+ "#17becf",
260
+ ],
261
+ # Evenly-spaced grays for print/monochrome-safe figures.
262
+ "grayscale": ["#404040", "#595959", "#737373", "#8c8c8c", "#a6a6a6", "#bfbfbf", "#d9d9d9"],
263
+ # Nord's Aurora + Frost accent colors.
264
+ "nord": [
265
+ "#bf616a",
266
+ "#d08770",
267
+ "#ebcb8b",
268
+ "#a3be8c",
269
+ "#b48ead",
270
+ "#8fbcbb",
271
+ "#88c0d0",
272
+ "#81a1c1",
273
+ "#5e81ac",
274
+ ],
275
+ # Dracula theme's accent colors.
276
+ "dracula": ["#FF5555", "#FFB86C", "#F1FA8C", "#50FA7B", "#8BE9FD", "#BD93F9", "#FF79C6"],
277
+ # Gruvbox's bright color variants.
278
+ "gruvbox": ["#fb4934", "#b8bb26", "#fabd2f", "#83a598", "#d3869b", "#8ec07c", "#fe8019"],
279
+ # Solarized's accent colors.
280
+ "solarized": [
281
+ "#b58900",
282
+ "#cb4b16",
283
+ "#dc322f",
284
+ "#d33682",
285
+ "#6c71c4",
286
+ "#268bd2",
287
+ "#2aa198",
288
+ "#859900",
289
+ ],
290
+ # Material Design's 500-weight color spread.
291
+ "material": [
292
+ "#f44336",
293
+ "#e91e63",
294
+ "#9c27b0",
295
+ "#3f51b5",
296
+ "#2196f3",
297
+ "#009688",
298
+ "#4caf50",
299
+ "#ffc107",
300
+ "#ff5722",
301
+ ],
302
+ # Catppuccin's Mocha flavor accent colors.
303
+ "catppuccin": [
304
+ "#f38ba8",
305
+ "#fab387",
306
+ "#f9e2af",
307
+ "#a6e3a1",
308
+ "#94e2d5",
309
+ "#89dceb",
310
+ "#89b4fa",
311
+ "#b4befe",
312
+ "#cba6f7",
313
+ "#f5c2e7",
314
+ ],
315
+ }
316
+
317
+
318
+ def resolve_palette(name: str) -> list[str]:
319
+ """Resolve a named palette to its list of hex colors.
320
+
321
+ Args:
322
+ name (str): One of the keys in `PALETTES`.
323
+
324
+ Returns:
325
+ list[str]: The palette's hex color strings.
326
+ """
327
+ if name not in PALETTES:
328
+ supported = ", ".join(sorted(PALETTES))
329
+ error_msg = f"Unsupported palette {name!r}. Supported palettes: {supported}."
330
+ raise ValueError(error_msg)
331
+ return PALETTES[name]
332
+
333
+
228
334
  class ColorWheel:
229
335
  """Default colors for the shapes."""
230
336
 
231
337
  def __init__(self, colors: list | None = None) -> None:
232
338
  self._cache: dict[type, Any] = {}
233
- # Okabe-Ito: a colorblind-safe palette (Okabe & Ito, 2008) widely recommended for
234
- # scientific visualization, e.g. in Nature's figure guidelines.
235
- self.colors = (
236
- colors
237
- if colors is not None
238
- else [
239
- "#E69F00", # orange
240
- "#56B4E9", # sky blue
241
- "#009E73", # bluish green
242
- "#F0E442", # yellow
243
- "#0072B2", # blue
244
- "#D55E00", # vermillion
245
- "#CC79A7", # reddish purple
246
- ]
247
- )
339
+ self.colors = colors if colors is not None else PALETTES["okabe_ito"]
248
340
 
249
341
  def get_color(self, class_type: type) -> tuple | None:
250
342
  """Return the cached color if it exists; otherwise get one from the list and store it in the cache."""
@@ -344,6 +436,22 @@ def self_multiply(tensor_tuple: tuple | list) -> int | float:
344
436
  return s
345
437
 
346
438
 
439
+ def format_shape_label(output_shape: tuple[int, ...], extra_output_shapes: tuple[tuple[int, ...], ...]) -> str:
440
+ """Format an output shape for display, appending any extra output shapes if present.
441
+
442
+ A module that returns more than one meaningful tensor (e.g. `nn.LSTM`'s `(output, (h_n,
443
+ c_n))`) would otherwise only ever show `output_shape` (the first tensor found, also the one
444
+ driving box size) with no indication the other tensors exist at all. `+` is used rather than
445
+ `visualtorch`'s existing `/` convention (already used to join sibling branches within one
446
+ column) so this doesn't read as an alternative/branch - these are all real outputs of this
447
+ same node, not one of several options.
448
+ """
449
+ label = str(output_shape)
450
+ if extra_output_shapes:
451
+ label += " + " + " + ".join(str(shape) for shape in extra_output_shapes)
452
+ return label
453
+
454
+
347
455
  def vertical_image_concat(
348
456
  im1: Image,
349
457
  im2: Image,
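To show how the new `+` separator combines with the existing `/` branch separator, the snippet below copies `format_shape_label` as added above and pairs it with a hypothetical `column_label` helper that joins sibling boxes the way flow.py's `_column_label_and_center` does. The LSTM shapes correspond to `nn.LSTM(batch_first=True)` with batch 2, sequence length 5 and hidden size 20.

```python
def format_shape_label(output_shape: tuple[int, ...], extra_output_shapes: tuple[tuple[int, ...], ...]) -> str:
    """Copy of the function added above: extra outputs are appended with ' + '."""
    label = str(output_shape)
    if extra_output_shapes:
        label += " + " + " + ".join(str(shape) for shape in extra_output_shapes)
    return label


def column_label(boxes: list[tuple[tuple[int, ...], tuple[tuple[int, ...], ...]]]) -> str:
    """Hypothetical: joins sibling branches in one column with ' / '."""
    return " / ".join(format_shape_label(shape, extras) for shape, extras in boxes)


# LSTM output, h_n, c_n:
print(format_shape_label((2, 5, 20), ((1, 2, 20), (1, 2, 20))))
# (2, 5, 20) + (1, 2, 20) + (1, 2, 20)

# Two sibling branches, one with an extra output and one without:
print(column_label([((2, 5, 20), ((1, 2, 20),)), ((2, 10), ())]))
# (2, 5, 20) + (1, 2, 20) / (2, 10)
```

Within one label, `+` groups outputs of the same node and `/` separates different nodes, which is the distinction the docstring argues for.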
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: visualtorch
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -54,21 +54,22 @@ Dynamic: summary
54
54
 
55
55
  </div>
56
56
 
57
- **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
57
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating flow-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. Its original visual styles were inspired by [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview); since then, it has grown its own unified tracing backend and architecture-handling logic well beyond its origins.
58
58
 
59
59
  **Note:** `1.0+` is a major release with breaking API changes, but with significantly better features and algorithms - upgrading is recommended. For the old API, use `0.2.5` or older.
60
60
 
61
- **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has two inherent
62
- limitations shared by any tracing-based approach (not bugs, and not fixable without full symbolic
63
- execution): (1) models with **data-dependent control flow** (e.g. a branch only taken if a tensor
64
- value crosses some threshold) only show whichever branch the traced dummy input happened to take;
65
- (2) a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task head)
66
- only has its first tensor's shape reflected in that node's size/label - its downstream connections
67
- are still correct either way. Contributions are welcome!
61
+ **Limitation:** VisualTorch traces a real forward pass to build the diagram, which has an inherent
62
+ limitation shared by any tracing-based approach (not a bug, and not fixable without full symbolic
63
+ execution): models with **data-dependent control flow** (e.g. a branch only taken if a tensor
64
+ value crosses some threshold) only show whichever branch the traced dummy input happened to take.
65
+ Separately, a layer that returns **multiple meaningful output tensors** (e.g. a custom multi-task
66
+ head, or `nn.LSTM`'s `(output, (h_n, c_n))`) still has its node's size based on only its first
67
+ tensor; with `show_dimension=True`, every output tensor's shape is shown in the label, not just
68
+ the first. Downstream connections are correct either way. Contributions are welcome!
68
69
 
69
70
  <div align="center">
70
71
 
71
- ![VisualTorch Examples](docs/source/_static/images/banners/readme-examples.png)
72
+ ![VisualTorch Examples](https://raw.githubusercontent.com/willyfh/visualtorch/e6ad79751e0f7412b1074beb45f9baeccd1419e4/docs/source/_static/images/banners/readme-examples.png)
72
73
 
73
74
  </div>
74
75
 
@@ -100,16 +101,16 @@ Please feel free to send a pull request to contribute to this project by followi
100
101
 
101
102
  This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
102
103
 
103
- Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
104
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), [pytorch-summary](https://github.com/sksq96/pytorch-summary), and [torchview](https://github.com/mert-kurttutan/torchview), all of which are also licensed under the MIT license.
104
105
 
105
106
  ## Citation
106
107
 
107
108
  Please cite this project in your publications if it helps your research.
108
109
 
109
- **Note:** the paper below describes the API as of its publication date (2024). VisualTorch has
110
- since had breaking API changes (see the [documentation](https://visualtorch.readthedocs.io/en/latest/)
111
- for the current API) - the DOI always resolves to what was actually reviewed and published, so
112
- it isn't updated to match.
110
+ **Note:** the paper below describes VisualTorch as of its publication date (2024). The project has
111
+ since been substantially refactored, including breaking API changes (see the
112
+ [documentation](https://visualtorch.readthedocs.io/en/latest/) for the current API) - the DOI
113
+ always resolves to what was actually reviewed and published.
113
114
 
114
115
  ```bibtex
115
116
  @article{Hendria2024,
File without changes
File without changes
File without changes