visualtorch 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- visualtorch-0.2.2/PKG-INFO +64 -0
- visualtorch-0.2.2/README.md +46 -0
- visualtorch-0.2.2/pyproject.toml +167 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/setup.py +10 -2
- visualtorch-0.2.2/visualtorch/__init__.py +10 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch/graph.py +141 -91
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch/layered.py +172 -119
- visualtorch-0.2.2/visualtorch/lenet_style.py +326 -0
- visualtorch-0.2.2/visualtorch/utils/__init__.py +44 -0
- visualtorch-0.2.2/visualtorch/utils/layer_utils.py +240 -0
- {visualtorch-0.2.0/visualtorch → visualtorch-0.2.2/visualtorch/utils}/utils.py +121 -94
- visualtorch-0.2.2/visualtorch.egg-info/PKG-INFO +64 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch.egg-info/SOURCES.txt +6 -3
- visualtorch-0.2.0/PKG-INFO +0 -216
- visualtorch-0.2.0/README.md +0 -198
- visualtorch-0.2.0/visualtorch/__init__.py +0 -4
- visualtorch-0.2.0/visualtorch/layer_utils.py +0 -161
- visualtorch-0.2.0/visualtorch.egg-info/PKG-INFO +0 -216
- {visualtorch-0.2.0 → visualtorch-0.2.2}/LICENSE +0 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/setup.cfg +0 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch.egg-info/requires.txt +0 -0
- {visualtorch-0.2.0 → visualtorch-0.2.2}/visualtorch.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: visualtorch
|
|
3
|
+
Version: 0.2.2
|
|
4
|
+
Summary: Architecture visualization of Torch models
|
|
5
|
+
Home-page: https://github.com/willyfh/visualtorch
|
|
6
|
+
Author: Willy Fitra Hendria
|
|
7
|
+
Author-email: willyfitrahendria@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: visualize architecture,torch visualization,visualtorch
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
|
|
19
|
+
<div align="center">
|
|
20
|
+
<h1>🔥 VisualTorch 🔥</h1>
|
|
21
|
+
|
|
22
|
+
[]() []() [](https://pepy.tech/project/visualtorch) [](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml) [](https://visualtorch.readthedocs.io/en/latest/?badge=latest)
|
|
23
|
+
|
|
24
|
+
</div>
|
|
25
|
+
|
|
26
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
27
|
+
|
|
28
|
+
**Note:** VisualTorch may not yet support complex models, but contributions are welcome!
|
|
29
|
+
|
|
30
|
+
<div align="center">
|
|
31
|
+
|
|
32
|
+

|
|
33
|
+
|
|
34
|
+
</div>
|
|
35
|
+
|
|
36
|
+
## Documentation
|
|
37
|
+
|
|
38
|
+
Online documentation is available at [visualtorch.readthedocs.io](https://visualtorch.readthedocs.io/en/latest/).
|
|
39
|
+
|
|
40
|
+
The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html), [API references](https://visualtorch.readthedocs.io/en/latest/markdown/api_references/index.html), and other useful information.
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
|
|
45
|
+
|
|
46
|
+
## Examples
|
|
47
|
+
|
|
48
|
+
See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
|
|
49
|
+
|
|
50
|
+
## Contributing
|
|
51
|
+
|
|
52
|
+
Please feel free to send a pull request to contribute to this project.
|
|
53
|
+
|
|
54
|
+
## License
|
|
55
|
+
|
|
56
|
+
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/update-readme/LICENSE).
|
|
57
|
+
|
|
58
|
+
Originally, this project was based on [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz) and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
|
|
59
|
+
|
|
60
|
+
## Citation
|
|
61
|
+
|
|
62
|
+
Please cite this project in your publications if it helps your research.
|
|
63
|
+
|
|
64
|
+
[A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
<div align="center">
|
|
2
|
+
<h1>🔥 VisualTorch 🔥</h1>
|
|
3
|
+
|
|
4
|
+
[]() []() [](https://pepy.tech/project/visualtorch) [](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml) [](https://visualtorch.readthedocs.io/en/latest/?badge=latest)
|
|
5
|
+
|
|
6
|
+
</div>
|
|
7
|
+
|
|
8
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
9
|
+
|
|
10
|
+
**Note:** VisualTorch may not yet support complex models, but contributions are welcome!
|
|
11
|
+
|
|
12
|
+
<div align="center">
|
|
13
|
+
|
|
14
|
+

|
|
15
|
+
|
|
16
|
+
</div>
|
|
17
|
+
|
|
18
|
+
## Documentation
|
|
19
|
+
|
|
20
|
+
Online documentation is available at [visualtorch.readthedocs.io](https://visualtorch.readthedocs.io/en/latest/).
|
|
21
|
+
|
|
22
|
+
The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html), [API references](https://visualtorch.readthedocs.io/en/latest/markdown/api_references/index.html), and other useful information.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
|
|
27
|
+
|
|
28
|
+
## Examples
|
|
29
|
+
|
|
30
|
+
See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
|
|
31
|
+
|
|
32
|
+
## Contributing
|
|
33
|
+
|
|
34
|
+
Please feel free to send a pull request to contribute to this project.
|
|
35
|
+
|
|
36
|
+
## License
|
|
37
|
+
|
|
38
|
+
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/update-readme/LICENSE).
|
|
39
|
+
|
|
40
|
+
Originally, this project was based on [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz) and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
|
|
41
|
+
|
|
42
|
+
## Citation
|
|
43
|
+
|
|
44
|
+
Please cite this project in your publications if it helps your research.
|
|
45
|
+
|
|
46
|
+
[A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
2
|
+
# SETUP CONFIGURATION. #
|
|
3
|
+
[build-system]
|
|
4
|
+
requires = ["setuptools>=42", "wheel"]
|
|
5
|
+
build-backend = "setuptools.build_meta"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
9
|
+
# RUFF CONFIGURATION #
|
|
10
|
+
[tool.ruff]
|
|
11
|
+
# Enable rules
|
|
12
|
+
select = [
|
|
13
|
+
"F", # Pyflakes (`F`)
|
|
14
|
+
"E", # pycodestyle error (`E`)
|
|
15
|
+
"W", # pycodestyle warning (`W`)
|
|
16
|
+
"C90", # mccabe (`C90`)
|
|
17
|
+
"I", # isort (`I`)
|
|
18
|
+
"N", # pep8-naming (`N`)
|
|
19
|
+
"D", # pydocstyle (`D`)
|
|
20
|
+
"UP", # pyupgrade (`UP`)
|
|
21
|
+
"YTT", # flake8-2020 (`YTT`)
|
|
22
|
+
"ANN", # flake8-annotations (`ANN`)
|
|
23
|
+
"S", # flake8-bandit (`S`)
|
|
24
|
+
"BLE", # flake8-blind-except (`BLE`)
|
|
25
|
+
"FBT", # flake8-boolean-trap (`FBT`)
|
|
26
|
+
"B", # flake8-bugbear (`B`)
|
|
27
|
+
"A", # flake8-builtins (`A`)
|
|
28
|
+
"COM", # flake8-commas (`COM`)
|
|
29
|
+
"CPY", # flake8-copyright (`CPY`)
|
|
30
|
+
"C4", # flake8-comprehensions (`C4`)
|
|
31
|
+
"DTZ", # flake8-datetimez (`DTZ`)
|
|
32
|
+
"T10", # flake8-debugger (`T10`)
|
|
33
|
+
"EM", # flake8-errmsg (`EM`)
|
|
34
|
+
"FA", # flake8-future-annotations (`FA`)
|
|
35
|
+
"ISC", # flake8-implicit-str-concat (`ISC`)
|
|
36
|
+
"ICN", # flake8-import-conventions (`ICN`)
|
|
37
|
+
"PIE", # flake8-pie (`PIE`)
|
|
38
|
+
"PT", # flake8-pytest-style (`PT`)
|
|
39
|
+
"RSE", # flake8-raise (`RSE`)
|
|
40
|
+
"RET", # flake8-return (`RET`)
|
|
41
|
+
"SLF", # flake8-self (`SLF`)
|
|
42
|
+
"SIM", # flake8-simplify (`SIM`)
|
|
43
|
+
"TID", # flake8-tidy-imports (`TID`)
|
|
44
|
+
"TCH", # flake8-type-checking (`TCH`)
|
|
45
|
+
"INT", # flake8-gettext (`INT`)
|
|
46
|
+
"ARG", # flake8-unused-arguments (`ARG`)
|
|
47
|
+
"PTH", # flake8-use-pathlib (`PTH`)
|
|
48
|
+
"TD", # flake8-todos (`TD`)
|
|
49
|
+
"FIX", # flake8-fixme (`FIX`)
|
|
50
|
+
"ERA", # eradicate (`ERA`)
|
|
51
|
+
"PD", # pandas-vet (`PD`)
|
|
52
|
+
"PGH", # pygrep-hooks (`PGH`)
|
|
53
|
+
"PL", # pylint (`PL`)
|
|
54
|
+
"TRY", # tryceratops (`TRY`)
|
|
55
|
+
"FLY", # flynt (`FLY`)
|
|
56
|
+
"NPY", # NumPy-specific rules (`NPY`)
|
|
57
|
+
"PERF", # Perflint (`PERF`)
|
|
58
|
+
"RUF", # Ruff-specific rules (`RUF`)
|
|
59
|
+
# "FURB", # refurb (`FURB`) - ERROR: Unknown rule selector: `FURB`
|
|
60
|
+
# "LOG", # flake8-logging (`LOG`) - ERROR: Unknown rule selector: `LOG`
|
|
61
|
+
]
|
|
62
|
+
|
|
63
|
+
ignore = [
|
|
64
|
+
# pydocstyle
|
|
65
|
+
"D107", # Missing docstring in __init__
|
|
66
|
+
"D415", # First line should end with a period, question mark, or exclamation point
|
|
67
|
+
|
|
68
|
+
# pylint
|
|
69
|
+
"PLR0913", # Too many arguments in function definition
|
|
70
|
+
"PLR2004", # consider replacing with a constant variable
|
|
71
|
+
"PLR0912", # Too many branches
|
|
72
|
+
"PLR0915", # Too many statements
|
|
73
|
+
|
|
74
|
+
# flake8-annotations
|
|
75
|
+
"ANN101", # Missing-type-self
|
|
76
|
+
"ANN002", # Missing type annotation for *args
|
|
77
|
+
"ANN003", # Missing type annotation for **kwargs
|
|
78
|
+
|
|
79
|
+
# flake8-bandit (`S`)
|
|
80
|
+
"S101", # Use of assert detected.
|
|
81
|
+
|
|
82
|
+
# flake8-boolean-trap (`FBT`)
|
|
83
|
+
"FBT001", # Boolean positional arg in function definition
|
|
84
|
+
"FBT002", # Boolean default value in function definition
|
|
85
|
+
|
|
86
|
+
# flake8-datetimez (`DTZ`)
|
|
87
|
+
"DTZ005", # The use of `datetime.datetime.now()` without `tz` argument is not allowed
|
|
88
|
+
|
|
89
|
+
# flake8-fixme (`FIX`)
|
|
90
|
+
"FIX002", # Line contains TODO, consider resolving the issue
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
# Allow autofix for all enabled rules (when `--fix` is provided).
|
|
94
|
+
fixable = ["ALL"]
|
|
95
|
+
unfixable = []
|
|
96
|
+
|
|
97
|
+
# Exclude a variety of commonly ignored directories.
|
|
98
|
+
exclude = [
|
|
99
|
+
".bzr",
|
|
100
|
+
".direnv",
|
|
101
|
+
".eggs",
|
|
102
|
+
".git",
|
|
103
|
+
".hg",
|
|
104
|
+
".mypy_cache",
|
|
105
|
+
".nox",
|
|
106
|
+
".pants.d",
|
|
107
|
+
".pytype",
|
|
108
|
+
".ruff_cache",
|
|
109
|
+
".svn",
|
|
110
|
+
".tox",
|
|
111
|
+
".venv",
|
|
112
|
+
"__pypackages__",
|
|
113
|
+
"_build",
|
|
114
|
+
"buck-out",
|
|
115
|
+
"build",
|
|
116
|
+
"dist",
|
|
117
|
+
"node_modules",
|
|
118
|
+
"venv",
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
# Allow lines up to 120 characters (wider than Black's default of 88).
|
|
122
|
+
line-length = 120
|
|
123
|
+
|
|
124
|
+
# Allow unused variables when underscore-prefixed.
|
|
125
|
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
|
126
|
+
|
|
127
|
+
# Assume Python 3.10.
|
|
128
|
+
target-version = "py310"
|
|
129
|
+
|
|
130
|
+
# Treat the "visualtorch" and "tests" directories as first-party source roots (used when sorting imports).
|
|
131
|
+
src = ["visualtorch", "tests"]
|
|
132
|
+
|
|
133
|
+
[tool.ruff.mccabe]
|
|
134
|
+
# Allow a McCabe complexity of up to 15 (raised from the default of 10).
|
|
135
|
+
max-complexity = 15
|
|
136
|
+
|
|
137
|
+
[tool.ruff.per-file-ignores]
|
|
138
|
+
"tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
|
|
139
|
+
|
|
140
|
+
[tool.ruff.pydocstyle]
|
|
141
|
+
convention = "google"
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
145
|
+
# MYPY CONFIGURATION. #
|
|
146
|
+
[tool.mypy]
|
|
147
|
+
ignore_missing_imports = true
|
|
148
|
+
show_error_codes = true
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
[[tool.mypy.overrides]]
|
|
152
|
+
module = ["torch.*"]
|
|
153
|
+
follow_imports = "skip"
|
|
154
|
+
follow_imports_for_stubs = true
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
158
|
+
# PYTEST CONFIGURATION #
|
|
159
|
+
[tool.pytest.ini_options]
|
|
160
|
+
addopts = [
|
|
161
|
+
"--strict-markers",
|
|
162
|
+
"--strict-config",
|
|
163
|
+
"--showlocals",
|
|
164
|
+
"-ra",
|
|
165
|
+
]
|
|
166
|
+
testpaths = "tests"
|
|
167
|
+
pythonpath = "visualtorch"
|
|
@@ -1,11 +1,19 @@
|
|
|
1
|
+
"""Setup file for visualtorch."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
1
8
|
import setuptools
|
|
2
9
|
|
|
3
|
-
|
|
10
|
+
file_path = Path("README.md")
# Read explicitly as UTF-8: the README contains non-ASCII characters (e.g. the emoji in
# its title), and relying on the platform's default encoding (e.g. cp1252 on Windows)
# would garble or reject them in the published long description.
with file_path.open("r", encoding="utf-8") as fh:
    long_description = fh.read()
|
|
5
13
|
|
|
6
14
|
setuptools.setup(
|
|
7
15
|
name="visualtorch",
|
|
8
|
-
version="0.2.
|
|
16
|
+
version="0.2.2",
|
|
9
17
|
author="Willy Fitra Hendria",
|
|
10
18
|
author_email="willyfitrahendria@gmail.com",
|
|
11
19
|
description="Architecture visualization of Torch models",
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Modules for pytorch model visualization."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from visualtorch.graph import graph_view
|
|
7
|
+
from visualtorch.layered import layered_view
|
|
8
|
+
from visualtorch.lenet_style import lenet_view
|
|
9
|
+
|
|
10
|
+
__all__ = ["layered_view", "graph_view", "lenet_view"]
|
|
@@ -1,32 +1,38 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""Graph View module for pytorch model visualization."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
|
+
# SPDX-License-Identifier: MIT
|
|
5
|
+
|
|
6
|
+
from collections import defaultdict
|
|
3
7
|
from math import ceil
|
|
4
|
-
from
|
|
5
|
-
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import aggdraw
|
|
6
11
|
import numpy as np
|
|
7
|
-
from typing import Optional, Dict, Any, Tuple, List
|
|
8
12
|
import torch
|
|
13
|
+
from PIL import Image
|
|
14
|
+
|
|
15
|
+
from .utils.layer_utils import TARGET_OPS, add_input_dummy_layer, model_to_adj_matrix
|
|
16
|
+
from .utils.utils import Box, Circle, Ellipses, get_keys_by_value
|
|
9
17
|
|
|
10
18
|
|
|
11
19
|
def graph_view(
|
|
12
20
|
model: torch.nn.Module,
|
|
13
|
-
input_shape:
|
|
14
|
-
to_file:
|
|
15
|
-
color_map:
|
|
21
|
+
input_shape: tuple[int, ...],
|
|
22
|
+
to_file: str | None = None,
|
|
23
|
+
color_map: dict[Any, Any] | None = None,
|
|
16
24
|
node_size: int = 50,
|
|
17
|
-
background_fill:
|
|
25
|
+
background_fill: str | tuple[int, ...] = "white",
|
|
18
26
|
padding: int = 10,
|
|
19
27
|
layer_spacing: int = 250,
|
|
20
28
|
node_spacing: int = 10,
|
|
21
|
-
connector_fill:
|
|
29
|
+
connector_fill: str | tuple[int, ...] = "gray",
|
|
22
30
|
connector_width: int = 1,
|
|
23
31
|
ellipsize_after: int = 10,
|
|
24
|
-
inout_as_tensor: bool = True,
|
|
25
32
|
show_neurons: bool = True,
|
|
33
|
+
opacity: int = 255,
|
|
26
34
|
) -> Image.Image:
|
|
27
|
-
"""
|
|
28
|
-
Generates an architecture visualization for a given linear PyTorch model (i.e., one input and output tensor for each
|
|
29
|
-
layer) in a graph style.
|
|
35
|
+
"""Generates an architecture visualization for a given linear PyTorch model in a graph style.
|
|
30
36
|
|
|
31
37
|
Args:
|
|
32
38
|
model (torch.nn.Module): A PyTorch model that will be visualized.
|
|
@@ -34,8 +40,8 @@ def graph_view(
|
|
|
34
40
|
to_file (str, optional): Path to the file to write the created image to. If the image does not exist yet,
|
|
35
41
|
it will be created, else overwritten. Image type is inferred from the file ending. Providing None
|
|
36
42
|
will disable writing.
|
|
37
|
-
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
|
|
38
|
-
values for not specified classes.
|
|
43
|
+
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
|
|
44
|
+
to default values for not specified classes.
|
|
39
45
|
node_size (int, optional): Size in pixels each node will have.
|
|
40
46
|
background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
|
|
41
47
|
padding (int, optional): Distance in pixels before the first and after the last layer.
|
|
@@ -45,103 +51,59 @@ def graph_view(
|
|
|
45
51
|
connector_width (int, optional): Line-width of the connectors in pixels.
|
|
46
52
|
ellipsize_after (int, optional): Maximum number of neurons per layer to draw. If a layer is exceeding this,
|
|
47
53
|
the remaining neurons will be drawn as ellipses.
|
|
48
|
-
inout_as_tensor (bool, optional): If True there will be one input and output node for each tensor, else the
|
|
49
|
-
tensor will be flattened and one node for each scalar will be created (e.g., a (10, 10) shape will be
|
|
50
|
-
represented by 100 nodes).
|
|
51
54
|
show_neurons (bool, optional): If True a node for each neuron in supported layers is created (constrained by
|
|
52
55
|
ellipsize_after), else each layer is represented by a node.
|
|
56
|
+
opacity (int, optional): Transparency of the color (0 ~ 255).
|
|
53
57
|
|
|
54
58
|
Returns:
|
|
55
59
|
Image.Image: Generated architecture image.
|
|
56
60
|
"""
|
|
57
|
-
|
|
58
|
-
if color_map is None:
|
|
59
|
-
|
|
61
|
+
_color_map: dict = {}
|
|
62
|
+
if color_map is not None:
|
|
63
|
+
_color_map = defaultdict(dict, color_map)
|
|
60
64
|
|
|
61
65
|
# Iterate over the model to compute bounds and generate boxes
|
|
62
66
|
|
|
63
|
-
layers: List[Any] = list()
|
|
64
|
-
layer_y = list()
|
|
65
|
-
|
|
66
67
|
# Attach helper layers
|
|
67
68
|
|
|
68
69
|
id_to_num_mapping, adj_matrix, model_layers = model_to_adj_matrix(
|
|
69
|
-
model,
|
|
70
|
+
model,
|
|
71
|
+
input_shape,
|
|
70
72
|
)
|
|
71
73
|
|
|
72
74
|
# Add fake input layers
|
|
73
75
|
|
|
74
76
|
id_to_num_mapping, adj_matrix, model_layers = add_input_dummy_layer(
|
|
75
|
-
input_shape,
|
|
77
|
+
input_shape,
|
|
78
|
+
id_to_num_mapping,
|
|
79
|
+
adj_matrix,
|
|
80
|
+
model_layers,
|
|
76
81
|
)
|
|
77
82
|
|
|
78
83
|
# Create architecture
|
|
79
84
|
|
|
80
85
|
current_x = padding # + input_label_size[0] + text_padding
|
|
81
86
|
|
|
82
|
-
id_to_node_list_map =
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
is_box = False
|
|
94
|
-
units = layer._saved_bias_sym_sizes_opt[0]
|
|
95
|
-
elif hasattr(layer, "_saved_mat2_sym_sizes"):
|
|
96
|
-
is_box = False
|
|
97
|
-
units = layer._saved_mat2_sym_sizes[1]
|
|
98
|
-
elif hasattr(layer, "units"): # for dummy input layer
|
|
99
|
-
is_box = False
|
|
100
|
-
units = layer.units
|
|
101
|
-
|
|
102
|
-
n = min(units, ellipsize_after)
|
|
103
|
-
layer_nodes = list()
|
|
104
|
-
|
|
105
|
-
for i in range(n):
|
|
106
|
-
scale = 1
|
|
107
|
-
c: Box | Circle | Ellipses
|
|
108
|
-
if not is_box:
|
|
109
|
-
if i != ellipsize_after - 2:
|
|
110
|
-
c = Circle()
|
|
111
|
-
else:
|
|
112
|
-
c = Ellipses()
|
|
113
|
-
else:
|
|
114
|
-
c = Box()
|
|
115
|
-
scale = 3
|
|
116
|
-
|
|
117
|
-
c.x1 = current_x
|
|
118
|
-
c.y1 = current_y
|
|
119
|
-
c.x2 = c.x1 + node_size
|
|
120
|
-
c.y2 = c.y1 + node_size * scale
|
|
121
|
-
|
|
122
|
-
current_y = c.y2 + node_spacing
|
|
123
|
-
|
|
124
|
-
c.fill = color_map.get(type(layer), {}).get("fill", "#ADD8E6")
|
|
125
|
-
c.outline = color_map.get(type(layer), {}).get("outline", "black")
|
|
126
|
-
|
|
127
|
-
layer_nodes.append(c)
|
|
128
|
-
|
|
129
|
-
id_to_node_list_map[str(id(layer))] = layer_nodes
|
|
130
|
-
nodes.extend(layer_nodes)
|
|
131
|
-
current_y += 2 * node_size
|
|
132
|
-
|
|
133
|
-
layer_y.append(current_y - node_spacing - 2 * node_size)
|
|
134
|
-
layers.append(nodes)
|
|
135
|
-
current_x += node_size + layer_spacing
|
|
87
|
+
layers, layer_y, id_to_node_list_map = _create_architecture(
|
|
88
|
+
model_layers,
|
|
89
|
+
current_x,
|
|
90
|
+
show_neurons,
|
|
91
|
+
ellipsize_after,
|
|
92
|
+
node_size,
|
|
93
|
+
node_spacing,
|
|
94
|
+
_color_map,
|
|
95
|
+
opacity,
|
|
96
|
+
layer_spacing,
|
|
97
|
+
)
|
|
136
98
|
|
|
137
99
|
# Generate image
|
|
138
100
|
|
|
139
|
-
img_width = (
|
|
140
|
-
len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
|
|
141
|
-
)
|
|
101
|
+
img_width = len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
|
|
142
102
|
img_height = max(*layer_y) + 2 * padding
|
|
143
103
|
img = Image.new(
|
|
144
|
-
"RGBA",
|
|
104
|
+
"RGBA",
|
|
105
|
+
(int(ceil(img_width)), int(ceil(img_height))),
|
|
106
|
+
background_fill,
|
|
145
107
|
)
|
|
146
108
|
|
|
147
109
|
draw = aggdraw.Draw(img)
|
|
@@ -154,7 +116,7 @@ def graph_view(
|
|
|
154
116
|
node.y1 += y_off
|
|
155
117
|
node.y2 += y_off
|
|
156
118
|
|
|
157
|
-
for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
|
|
119
|
+
for start_idx, end_idx in zip(*np.where(adj_matrix > 0), strict=False):
|
|
158
120
|
start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
|
|
159
121
|
end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))
|
|
160
122
|
|
|
@@ -162,10 +124,11 @@ def graph_view(
|
|
|
162
124
|
end_layer_list = id_to_node_list_map[end_id]
|
|
163
125
|
|
|
164
126
|
# draw connectors
|
|
165
|
-
for
|
|
127
|
+
for start_node in start_layer_list:
|
|
166
128
|
for end_node in end_layer_list:
|
|
167
129
|
if not isinstance(start_node, Ellipses) and not isinstance(
|
|
168
|
-
end_node,
|
|
130
|
+
end_node,
|
|
131
|
+
Ellipses,
|
|
169
132
|
):
|
|
170
133
|
_draw_connector(
|
|
171
134
|
draw,
|
|
@@ -175,8 +138,8 @@ def graph_view(
|
|
|
175
138
|
width=connector_width,
|
|
176
139
|
)
|
|
177
140
|
|
|
178
|
-
for
|
|
179
|
-
for
|
|
141
|
+
for layer in layers:
|
|
142
|
+
for node in layer:
|
|
180
143
|
node.draw(draw)
|
|
181
144
|
|
|
182
145
|
draw.flush()
|
|
@@ -187,10 +150,97 @@ def graph_view(
|
|
|
187
150
|
return img
|
|
188
151
|
|
|
189
152
|
|
|
190
|
-
def _draw_connector(
|
|
153
|
+
def _draw_connector(
    draw: aggdraw.Draw,
    start_node: Box | Circle | Ellipses,
    end_node: Box | Circle | Ellipses,
    color: str | tuple[int, ...],
    width: int,
) -> None:
    """Draw a straight line linking two nodes.

    The line runs from the vertical midpoint of ``start_node``'s right edge (``x2``)
    to the vertical midpoint of ``end_node``'s left edge (``x1``).

    Args:
        draw: aggdraw drawing surface that receives the line.
        start_node: Node the connector leaves from.
        end_node: Node the connector arrives at.
        color: Line color, passed straight to ``aggdraw.Pen``.
        width: Line width in pixels.
    """
    start_mid_y = start_node.y1 + (start_node.y2 - start_node.y1) / 2
    end_mid_y = end_node.y1 + (end_node.y2 - end_node.y1) / 2
    line_coords = [start_node.x2, start_mid_y, end_node.x1, end_mid_y]
    draw.line(line_coords, aggdraw.Pen(color, width))
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _retrieve_isbox_units(layer: torch.autograd.Function, show_neurons: bool) -> tuple[bool, int]:
    """Decide whether a layer is drawn as a single box and how many neuron nodes it gets.

    When ``show_neurons`` is enabled, the neuron count is read from the first of these
    attributes that the layer exposes:

    * ``_saved_bias_sym_sizes_opt`` -> its first entry,
    * ``_saved_mat2_sym_sizes`` -> its second entry,
    * ``units`` -> used as-is (set on the dummy input layer).

    Args:
        layer: Layer object taken from the model's adjacency structure. Presumably an
            autograd graph node or the dummy input layer; TODO confirm against
            ``model_to_adj_matrix`` / ``add_input_dummy_layer``.
        show_neurons: If False, every layer is drawn as a single box.

    Returns:
        tuple[bool, int]: ``(is_box, units)``. ``(True, 1)`` means "draw one box";
        ``(False, units)`` means "draw ``units`` neuron nodes".
    """
    if not show_neurons:
        return True, 1
    if hasattr(layer, "_saved_bias_sym_sizes_opt"):
        return False, layer._saved_bias_sym_sizes_opt[0]  # noqa: SLF001
    if hasattr(layer, "_saved_mat2_sym_sizes"):
        return False, layer._saved_mat2_sym_sizes[1]  # noqa: SLF001
    if hasattr(layer, "units"):  # for dummy input layer
        return False, layer.units
    # No recognizable size information: fall back to a single box.
    return True, 1
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _create_architecture(
    model_layers: list[list],
    current_x: int,
    show_neurons: bool,
    ellipsize_after: int,
    node_size: int,
    node_spacing: int,
    color_map: dict[Any, Any],
    opacity: int,
    layer_spacing: int,
) -> tuple[list, list, dict]:
    """Create the drawable nodes for every column of layers.

    Each entry of ``model_layers`` becomes one column, laid out left to right. The layers
    inside an entry are stacked top to bottom within that column. A layer is drawn either
    as one tall box or as up to ``ellipsize_after`` neuron circles (as decided by
    ``_retrieve_isbox_units``). When a layer has more neurons than ``ellipsize_after``,
    its second-to-last drawn node is an ``Ellipses`` marker standing in for the omitted
    neurons. Layers that fit are drawn in full.

    Args:
        model_layers: Layers grouped into columns, in left-to-right order.
        current_x: x coordinate (pixels) of the left edge of the first column.
        show_neurons: Draw neuron circles for supported layers instead of one box each.
        ellipsize_after: Maximum number of nodes drawn for a single layer.
        node_size: Width (and base height) of each node in pixels.
        node_spacing: Vertical gap in pixels between consecutive nodes.
        color_map: Looked up with ``TARGET_OPS[layer.name()]``. Each value may provide
            ``"fill"`` (default ``"#ADD8E6"``) and ``"outline"`` (default ``"black"``).
        opacity: Alpha value passed to ``set_fill`` for every node.
        layer_spacing: Horizontal gap in pixels between consecutive columns.

    Returns:
        tuple[list, list, dict]: ``(layers, layer_y, id_to_node_list_map)``.
        ``layers`` holds the nodes of each column. ``layer_y`` holds, per column, the final
        stacking height minus the trailing gap. ``id_to_node_list_map`` maps
        ``str(id(layer))`` to that layer's nodes.
    """
    id_to_node_list_map: dict[str, list] = {}
    layers: list[list] = []
    layer_y: list = []
    for layer_list in model_layers:
        current_y = 0
        nodes: list = []
        layer: Any
        for layer in layer_list:
            is_box, units = _retrieve_isbox_units(layer, show_neurons)

            n = min(units, ellipsize_after)
            # Only substitute an ellipsis marker when neurons are actually omitted;
            # otherwise a layer that exactly fits would lose a visible neuron.
            is_truncated = units > ellipsize_after
            layer_nodes = []

            for i in range(n):
                scale = 1
                c: Box | Circle | Ellipses
                if is_box:
                    c = Box()
                    scale = 3  # boxes are drawn three node-heights tall
                elif is_truncated and i == ellipsize_after - 2:
                    c = Ellipses()
                else:
                    c = Circle()

                c.x1 = current_x
                c.y1 = current_y
                c.x2 = c.x1 + node_size
                c.y2 = c.y1 + node_size * scale

                current_y = c.y2 + node_spacing

                c.set_fill(
                    color_map.get(TARGET_OPS[layer.name()], {}).get("fill", "#ADD8E6"),
                    opacity,
                )
                c.outline = color_map.get(TARGET_OPS[layer.name()], {}).get(
                    "outline",
                    "black",
                )

                layer_nodes.append(c)

            id_to_node_list_map[str(id(layer))] = layer_nodes
            nodes.extend(layer_nodes)
            # Extra vertical gap separating consecutive layers within the same column.
            current_y += 2 * node_size

        # Drop the trailing node gap and layer gap added after the last layer.
        layer_y.append(current_y - node_spacing - 2 * node_size)
        layers.append(nodes)
        current_x += node_size + layer_spacing
    return layers, layer_y, id_to_node_list_map
|