visualtorch 0.2.1__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visualtorch-0.2.1 → visualtorch-0.2.3}/LICENSE +2 -0
- {visualtorch-0.2.1/visualtorch.egg-info → visualtorch-0.2.3}/PKG-INFO +6 -5
- {visualtorch-0.2.1 → visualtorch-0.2.3}/README.md +4 -4
- visualtorch-0.2.3/pyproject.toml +167 -0
- {visualtorch-0.2.1 → visualtorch-0.2.3}/setup.py +15 -8
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch/__init__.py +3 -2
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch/graph.py +136 -96
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch/layered.py +168 -119
- visualtorch-0.2.3/visualtorch/lenet_style.py +326 -0
- visualtorch-0.2.3/visualtorch/utils/__init__.py +44 -0
- visualtorch-0.2.3/visualtorch/utils/layer_utils.py +241 -0
- {visualtorch-0.2.1/visualtorch → visualtorch-0.2.3/visualtorch/utils}/utils.py +117 -94
- {visualtorch-0.2.1 → visualtorch-0.2.3/visualtorch.egg-info}/PKG-INFO +6 -5
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch.egg-info/SOURCES.txt +6 -3
- visualtorch-0.2.3/visualtorch.egg-info/requires.txt +18 -0
- visualtorch-0.2.1/visualtorch/layer_utils.py +0 -175
- visualtorch-0.2.1/visualtorch.egg-info/requires.txt +0 -4
- {visualtorch-0.2.1 → visualtorch-0.2.3}/setup.cfg +0 -0
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch.egg-info/dependency_links.txt +0 -0
- {visualtorch-0.2.1 → visualtorch-0.2.3}/visualtorch.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: visualtorch
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: Architecture visualization of Torch models
|
|
5
5
|
Home-page: https://github.com/willyfh/visualtorch
|
|
6
6
|
Author: Willy Fitra Hendria
|
|
@@ -14,6 +14,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
14
14
|
Classifier: Operating System :: OS Independent
|
|
15
15
|
Requires-Python: >=3.10
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
|
+
Provides-Extra: dev
|
|
17
18
|
License-File: LICENSE
|
|
18
19
|
|
|
19
20
|
<div align="center">
|
|
@@ -23,13 +24,13 @@ License-File: LICENSE
|
|
|
23
24
|
|
|
24
25
|
</div>
|
|
25
26
|
|
|
26
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and
|
|
27
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
27
28
|
|
|
28
29
|
**Note:** VisualTorch may not yet support complex models, but contributions are welcome!
|
|
29
30
|
|
|
30
31
|
<div align="center">
|
|
31
32
|
|
|
32
|
-

|
|
33
34
|
|
|
34
35
|
</div>
|
|
35
36
|
|
|
@@ -49,11 +50,11 @@ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage
|
|
|
49
50
|
|
|
50
51
|
## Contributing
|
|
51
52
|
|
|
52
|
-
Please feel free to send a pull request to contribute to this project.
|
|
53
|
+
Please feel free to send a pull request to contribute to this project by following this [guideline](https://github.com/willyfh/visualtorch/blob/main/CONTRIBUTING.md).
|
|
53
54
|
|
|
54
55
|
## License
|
|
55
56
|
|
|
56
|
-
This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/
|
|
57
|
+
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
|
|
57
58
|
|
|
58
59
|
Originally, this project was based on [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz) and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
|
|
59
60
|
|
|
@@ -5,13 +5,13 @@
|
|
|
5
5
|
|
|
6
6
|
</div>
|
|
7
7
|
|
|
8
|
-
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and
|
|
8
|
+
**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
|
|
9
9
|
|
|
10
10
|
**Note:** VisualTorch may not yet support complex models, but contributions are welcome!
|
|
11
11
|
|
|
12
12
|
<div align="center">
|
|
13
13
|
|
|
14
|
-

|
|
15
15
|
|
|
16
16
|
</div>
|
|
17
17
|
|
|
@@ -31,11 +31,11 @@ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage
|
|
|
31
31
|
|
|
32
32
|
## Contributing
|
|
33
33
|
|
|
34
|
-
Please feel free to send a pull request to contribute to this project.
|
|
34
|
+
Please feel free to send a pull request to contribute to this project by following this [guideline](https://github.com/willyfh/visualtorch/blob/main/CONTRIBUTING.md).
|
|
35
35
|
|
|
36
36
|
## License
|
|
37
37
|
|
|
38
|
-
This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/
|
|
38
|
+
This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
|
|
39
39
|
|
|
40
40
|
Originally, this project was based on [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz) and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
|
|
41
41
|
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
2
|
+
# SETUP CONFIGURATION. #
|
|
3
|
+
[build-system]
|
|
4
|
+
requires = ["setuptools>=42", "wheel"]
|
|
5
|
+
build-backend = "setuptools.build_meta"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
9
|
+
# RUFF CONFIGURATION #
|
|
10
|
+
[tool.ruff]
|
|
11
|
+
# Enable rules
|
|
12
|
+
select = [
|
|
13
|
+
"F", # Pyflakes (`F`)
|
|
14
|
+
"E", # pycodestyle error (`E`)
|
|
15
|
+
"W", # pycodestyle warning (`W`)
|
|
16
|
+
"C90", # mccabe (`C90`)
|
|
17
|
+
"I", # isort (`I`)
|
|
18
|
+
"N", # pep8-naming (`N`)
|
|
19
|
+
"D", # pydocstyle (`D`)
|
|
20
|
+
"UP", # pyupgrade (`UP`)
|
|
21
|
+
"YTT", # flake8-2020 (`YTT`)
|
|
22
|
+
"ANN", # flake8-annotations (`ANN`)
|
|
23
|
+
"S", # flake8-bandit (`S`)
|
|
24
|
+
"BLE", # flake8-blind-except (`BLE`)
|
|
25
|
+
"FBT", # flake8-boolean-trap (`FBT`)
|
|
26
|
+
"B", # flake8-bugbear (`B`)
|
|
27
|
+
"A", # flake8-builtins (`A`)
|
|
28
|
+
"COM", # flake8-commas (`COM`)
|
|
29
|
+
"CPY", # flake8-copyright (`CPY`)
|
|
30
|
+
"C4", # flake8-comprehensions (`C4`)
|
|
31
|
+
"DTZ", # flake8-datatimez (`DTZ`)
|
|
32
|
+
"T10", # flake8-debugger (`T10`)
|
|
33
|
+
"EM", # flake8-errmsg (`EM`)
|
|
34
|
+
"FA", # flake8-future-annotations (`FA`)
|
|
35
|
+
"ISC", # flake8-implicit-str-concat (`ISC`)
|
|
36
|
+
"ICN", # flake8-import-conventions (`ICN`)
|
|
37
|
+
"PIE", # flake8-pie (`PIE`)
|
|
38
|
+
"PT", # flake8-pytest-style (`PT`)
|
|
39
|
+
"RSE", # flake8-raise (`RSE`)
|
|
40
|
+
"RET", # flake8-return (`RET`)
|
|
41
|
+
"SLF", # flake8-self (`SLF`)
|
|
42
|
+
"SIM", # flake8-simplify (`SIM`)
|
|
43
|
+
"TID", # flake8-tidy-imports (`TID`)
|
|
44
|
+
"TCH", # flake8-type-checking (`TCH`)
|
|
45
|
+
"INT", # flake8-gettext (`INT`)
|
|
46
|
+
"ARG", # flake8-unsused-arguments (`ARG`)
|
|
47
|
+
"PTH", # flake8-use-pathlib (`PTH`)
|
|
48
|
+
"TD", # flake8-todos (`TD`)
|
|
49
|
+
"FIX", # flake8-fixme (`FIX`)
|
|
50
|
+
"ERA", # eradicate (`ERA`)
|
|
51
|
+
"PD", # pandas-vet (`PD`)
|
|
52
|
+
"PGH", # pygrep-hooks (`PGH`)
|
|
53
|
+
"PL", # pylint (`PL`)
|
|
54
|
+
"TRY", # tryceratos (`TRY`)
|
|
55
|
+
"FLY", # flynt (`FLY`)
|
|
56
|
+
"NPY", # NumPy-specific rules (`NPY`)
|
|
57
|
+
"PERF", # Perflint (`PERF`)
|
|
58
|
+
"RUF", # Ruff-specific rules (`RUF`)
|
|
59
|
+
# "FURB", # refurb (`FURB`) - ERROR: Unknown rule selector: `FURB`
|
|
60
|
+
# "LOG", # flake8-logging (`LOG`) - ERROR: Unknown rule selector: `LOG`
|
|
61
|
+
]
|
|
62
|
+
|
|
63
|
+
ignore = [
|
|
64
|
+
# pydocstyle
|
|
65
|
+
"D107", # Missing docstring in __init__
|
|
66
|
+
"D415", # First line should end with a period, question mark, or exclamation point
|
|
67
|
+
|
|
68
|
+
# pylint
|
|
69
|
+
"PLR0913", # Too many arguments to function call
|
|
70
|
+
"PLR2004", # consider replacing with a constant variable
|
|
71
|
+
"PLR0912", # Too many branches
|
|
72
|
+
"PLR0915", # Too many statements
|
|
73
|
+
|
|
74
|
+
# flake8-annotations
|
|
75
|
+
"ANN101", # Missing-type-self
|
|
76
|
+
"ANN002", # Missing type annotation for *args
|
|
77
|
+
"ANN003", # Missing type annotation for **kwargs
|
|
78
|
+
|
|
79
|
+
# flake8-bandit (`S`)
|
|
80
|
+
"S101", # Use of assert detected.
|
|
81
|
+
|
|
82
|
+
# flake8-boolean-trap (`FBT`)
|
|
83
|
+
"FBT001", # Boolean positional arg in function definition
|
|
84
|
+
"FBT002", # Boolean default value in function definition
|
|
85
|
+
|
|
86
|
+
# flake8-datetimez (`DTZ`)
|
|
87
|
+
"DTZ005", # The use of `datetime.datetime.now()` without `tz` argument is not allowed
|
|
88
|
+
|
|
89
|
+
# flake8-fixme (`FIX`)
|
|
90
|
+
"FIX002", # Line contains TODO, consider resolving the issue
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
# Allow autofix for all enabled rules (when `--fix` is provided).
|
|
94
|
+
fixable = ["ALL"]
|
|
95
|
+
unfixable = []
|
|
96
|
+
|
|
97
|
+
# Exclude a variety of commonly ignored directories.
|
|
98
|
+
exclude = [
|
|
99
|
+
".bzr",
|
|
100
|
+
".direnv",
|
|
101
|
+
".eggs",
|
|
102
|
+
".git",
|
|
103
|
+
".hg",
|
|
104
|
+
".mypy_cache",
|
|
105
|
+
".nox",
|
|
106
|
+
".pants.d",
|
|
107
|
+
".pytype",
|
|
108
|
+
".ruff_cache",
|
|
109
|
+
".svn",
|
|
110
|
+
".tox",
|
|
111
|
+
".venv",
|
|
112
|
+
"__pypackages__",
|
|
113
|
+
"_build",
|
|
114
|
+
"buck-out",
|
|
115
|
+
"build",
|
|
116
|
+
"dist",
|
|
117
|
+
"node_modules",
|
|
118
|
+
"venv",
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
# Maximum line length (Black's default is 88; this project allows 120).
|
|
122
|
+
line-length = 120
|
|
123
|
+
|
|
124
|
+
# Allow unused variables when underscore-prefixed.
|
|
125
|
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
|
126
|
+
|
|
127
|
+
# Assume Python 3.10.
|
|
128
|
+
target-version = "py310"
|
|
129
|
+
|
|
130
|
+
# Treat the "visualtorch" and "tests" directories as first-party source roots (used for import sorting).
|
|
131
|
+
src = ["visualtorch", "tests"]
|
|
132
|
+
|
|
133
|
+
[tool.ruff.mccabe]
|
|
134
|
+
# Maximum McCabe complexity per function (ruff's default is 10; relaxed to 15 here).
|
|
135
|
+
max-complexity = 15
|
|
136
|
+
|
|
137
|
+
[tool.ruff.per-file-ignores]
|
|
138
|
+
"tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
|
|
139
|
+
|
|
140
|
+
[tool.ruff.pydocstyle]
|
|
141
|
+
convention = "google"
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
145
|
+
# MYPY CONFIGURATION. #
|
|
146
|
+
[tool.mypy]
|
|
147
|
+
ignore_missing_imports = true
|
|
148
|
+
show_error_codes = true
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
[[tool.mypy.overrides]]
|
|
152
|
+
module = ["torch.*"]
|
|
153
|
+
follow_imports = "skip"
|
|
154
|
+
follow_imports_for_stubs = true
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
|
158
|
+
# PYTEST CONFIGURATION #
|
|
159
|
+
[tool.pytest.ini_options]
|
|
160
|
+
addopts = [
|
|
161
|
+
"--strict-markers",
|
|
162
|
+
"--strict-config",
|
|
163
|
+
"--showlocals",
|
|
164
|
+
"-ra",
|
|
165
|
+
]
|
|
166
|
+
testpaths = "tests"
|
|
167
|
+
pythonpath = "visualtorch"
|
|
@@ -3,14 +3,25 @@
|
|
|
3
3
|
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
4
|
# SPDX-License-Identifier: MIT
|
|
5
5
|
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
6
8
|
import setuptools
|
|
7
9
|
|
|
8
|
-
|
|
10
|
+
# Read the project README to use as the package's long description.
# An explicit encoding keeps the build independent of the platform's locale
# default (e.g. cp1252 on Windows would fail on non-ASCII UTF-8 content).
file_path = Path("README.md")
with file_path.open("r", encoding="utf-8") as fh:
    long_description = fh.read()
|
|
10
13
|
|
|
14
|
+
|
|
15
|
+
def _read_requirements(file: str) -> list:
|
|
16
|
+
file_path = Path(file)
|
|
17
|
+
with file_path.open("r") as fh:
|
|
18
|
+
reqs = fh.read()
|
|
19
|
+
return reqs.strip().split("\n")
|
|
20
|
+
|
|
21
|
+
|
|
11
22
|
setuptools.setup(
|
|
12
23
|
name="visualtorch",
|
|
13
|
-
version="0.2.
|
|
24
|
+
version="0.2.3",
|
|
14
25
|
author="Willy Fitra Hendria",
|
|
15
26
|
author_email="willyfitrahendria@gmail.com",
|
|
16
27
|
description="Architecture visualization of Torch models",
|
|
@@ -26,12 +37,8 @@ setuptools.setup(
|
|
|
26
37
|
"License :: OSI Approved :: MIT License",
|
|
27
38
|
"Operating System :: OS Independent",
|
|
28
39
|
],
|
|
29
|
-
install_requires=
|
|
30
|
-
|
|
31
|
-
"numpy>=1.18.1",
|
|
32
|
-
"aggdraw>=1.3.11",
|
|
33
|
-
"torch>=2.0.0",
|
|
34
|
-
],
|
|
40
|
+
install_requires=_read_requirements("requirements.txt"),
|
|
41
|
+
extras_require={"dev": _read_requirements("docs/requirements.txt") + _read_requirements("dev-requirements.txt")},
|
|
35
42
|
python_requires=">=3.10",
|
|
36
43
|
license="MIT",
|
|
37
44
|
license_files=("LICENSE",),
|
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
4
|
# SPDX-License-Identifier: MIT
|
|
5
5
|
|
|
6
|
-
from visualtorch.layered import layered_view
|
|
7
6
|
from visualtorch.graph import graph_view
|
|
7
|
+
from visualtorch.layered import layered_view
|
|
8
|
+
from visualtorch.lenet_style import lenet_view
|
|
8
9
|
|
|
9
|
-
__all__ = ["layered_view", "graph_view"]
|
|
10
|
+
__all__ = ["layered_view", "graph_view", "lenet_view"]
|
|
@@ -1,37 +1,39 @@
|
|
|
1
1
|
"""Graph View module for pytorch model visualization."""
|
|
2
2
|
|
|
3
|
+
# Copyright (C) 2020 Paul Gavrikov
|
|
3
4
|
# Copyright (C) 2024 Willy Fitra Hendria
|
|
4
5
|
# SPDX-License-Identifier: MIT
|
|
5
6
|
|
|
6
|
-
import
|
|
7
|
-
from PIL import Image
|
|
7
|
+
from collections import defaultdict
|
|
8
8
|
from math import ceil
|
|
9
|
-
from
|
|
10
|
-
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import aggdraw
|
|
11
12
|
import numpy as np
|
|
12
|
-
from typing import Optional, Dict, Any, Tuple, List
|
|
13
13
|
import torch
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
from .utils.layer_utils import TARGET_OPS, add_input_dummy_layer, model_to_adj_matrix
|
|
17
|
+
from .utils.utils import Box, Circle, Ellipses, get_keys_by_value
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
def graph_view(
|
|
17
21
|
model: torch.nn.Module,
|
|
18
|
-
input_shape:
|
|
19
|
-
to_file:
|
|
20
|
-
color_map:
|
|
22
|
+
input_shape: tuple[int, ...],
|
|
23
|
+
to_file: str | None = None,
|
|
24
|
+
color_map: dict[Any, Any] | None = None,
|
|
21
25
|
node_size: int = 50,
|
|
22
|
-
background_fill:
|
|
26
|
+
background_fill: str | tuple[int, ...] = "white",
|
|
23
27
|
padding: int = 10,
|
|
24
28
|
layer_spacing: int = 250,
|
|
25
29
|
node_spacing: int = 10,
|
|
26
|
-
connector_fill:
|
|
30
|
+
connector_fill: str | tuple[int, ...] = "gray",
|
|
27
31
|
connector_width: int = 1,
|
|
28
32
|
ellipsize_after: int = 10,
|
|
29
|
-
inout_as_tensor: bool = True,
|
|
30
33
|
show_neurons: bool = True,
|
|
34
|
+
opacity: int = 255,
|
|
31
35
|
) -> Image.Image:
|
|
32
|
-
"""
|
|
33
|
-
Generates an architecture visualization for a given linear PyTorch model (i.e., one input and output tensor for each
|
|
34
|
-
layer) in a graph style.
|
|
36
|
+
"""Generates an architecture visualization for a given linear PyTorch model in a graph style.
|
|
35
37
|
|
|
36
38
|
Args:
|
|
37
39
|
model (torch.nn.Module): A PyTorch model that will be visualized.
|
|
@@ -39,8 +41,8 @@ def graph_view(
|
|
|
39
41
|
to_file (str, optional): Path to the file to write the created image to. If the image does not exist yet,
|
|
40
42
|
it will be created, else overwritten. Image type is inferred from the file ending. Providing None
|
|
41
43
|
will disable writing.
|
|
42
|
-
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
|
|
43
|
-
values for not specified classes.
|
|
44
|
+
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
|
|
45
|
+
to default values for not specified classes.
|
|
44
46
|
node_size (int, optional): Size in pixels each node will have.
|
|
45
47
|
background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
|
|
46
48
|
padding (int, optional): Distance in pixels before the first and after the last layer.
|
|
@@ -50,108 +52,59 @@ def graph_view(
|
|
|
50
52
|
connector_width (int, optional): Line-width of the connectors in pixels.
|
|
51
53
|
ellipsize_after (int, optional): Maximum number of neurons per layer to draw. If a layer is exceeding this,
|
|
52
54
|
the remaining neurons will be drawn as ellipses.
|
|
53
|
-
inout_as_tensor (bool, optional): If True there will be one input and output node for each tensor, else the
|
|
54
|
-
tensor will be flattened and one node for each scalar will be created (e.g., a (10, 10) shape will be
|
|
55
|
-
represented by 100 nodes).
|
|
56
55
|
show_neurons (bool, optional): If True a node for each neuron in supported layers is created (constrained by
|
|
57
56
|
ellipsize_after), else each layer is represented by a node.
|
|
57
|
+
opacity (int, optional): Transparency of the color (0 ~ 255).
|
|
58
58
|
|
|
59
59
|
Returns:
|
|
60
60
|
Image.Image: Generated architecture image.
|
|
61
61
|
"""
|
|
62
|
-
|
|
63
|
-
if color_map is None:
|
|
64
|
-
|
|
62
|
+
_color_map: dict = {}
|
|
63
|
+
if color_map is not None:
|
|
64
|
+
_color_map = defaultdict(dict, color_map)
|
|
65
65
|
|
|
66
66
|
# Iterate over the model to compute bounds and generate boxes
|
|
67
67
|
|
|
68
|
-
layers: List[Any] = list()
|
|
69
|
-
layer_y = list()
|
|
70
|
-
|
|
71
68
|
# Attach helper layers
|
|
72
69
|
|
|
73
70
|
id_to_num_mapping, adj_matrix, model_layers = model_to_adj_matrix(
|
|
74
|
-
model,
|
|
71
|
+
model,
|
|
72
|
+
input_shape,
|
|
75
73
|
)
|
|
76
74
|
|
|
77
75
|
# Add fake input layers
|
|
78
76
|
|
|
79
77
|
id_to_num_mapping, adj_matrix, model_layers = add_input_dummy_layer(
|
|
80
|
-
input_shape,
|
|
78
|
+
input_shape,
|
|
79
|
+
id_to_num_mapping,
|
|
80
|
+
adj_matrix,
|
|
81
|
+
model_layers,
|
|
81
82
|
)
|
|
82
83
|
|
|
83
84
|
# Create architecture
|
|
84
85
|
|
|
85
86
|
current_x = padding # + input_label_size[0] + text_padding
|
|
86
87
|
|
|
87
|
-
id_to_node_list_map =
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
if hasattr(layer, "_saved_bias_sym_sizes_opt"):
|
|
99
|
-
is_box = False
|
|
100
|
-
units = layer._saved_bias_sym_sizes_opt[0]
|
|
101
|
-
elif hasattr(layer, "_saved_mat2_sym_sizes"):
|
|
102
|
-
is_box = False
|
|
103
|
-
units = layer._saved_mat2_sym_sizes[1]
|
|
104
|
-
elif hasattr(layer, "units"): # for dummy input layer
|
|
105
|
-
is_box = False
|
|
106
|
-
units = layer.units
|
|
107
|
-
|
|
108
|
-
n = min(units, ellipsize_after)
|
|
109
|
-
layer_nodes = list()
|
|
110
|
-
|
|
111
|
-
for i in range(n):
|
|
112
|
-
scale = 1
|
|
113
|
-
c: Box | Circle | Ellipses
|
|
114
|
-
if not is_box:
|
|
115
|
-
if i != ellipsize_after - 2:
|
|
116
|
-
c = Circle()
|
|
117
|
-
else:
|
|
118
|
-
c = Ellipses()
|
|
119
|
-
else:
|
|
120
|
-
c = Box()
|
|
121
|
-
scale = 3
|
|
122
|
-
|
|
123
|
-
c.x1 = current_x
|
|
124
|
-
c.y1 = current_y
|
|
125
|
-
c.x2 = c.x1 + node_size
|
|
126
|
-
c.y2 = c.y1 + node_size * scale
|
|
127
|
-
|
|
128
|
-
current_y = c.y2 + node_spacing
|
|
129
|
-
|
|
130
|
-
c.fill = color_map.get(TARGET_OPS[layer.name()], {}).get(
|
|
131
|
-
"fill", "#ADD8E6"
|
|
132
|
-
)
|
|
133
|
-
c.outline = color_map.get(TARGET_OPS[layer.name()], {}).get(
|
|
134
|
-
"outline", "black"
|
|
135
|
-
)
|
|
136
|
-
|
|
137
|
-
layer_nodes.append(c)
|
|
138
|
-
|
|
139
|
-
id_to_node_list_map[str(id(layer))] = layer_nodes
|
|
140
|
-
nodes.extend(layer_nodes)
|
|
141
|
-
current_y += 2 * node_size
|
|
142
|
-
|
|
143
|
-
layer_y.append(current_y - node_spacing - 2 * node_size)
|
|
144
|
-
layers.append(nodes)
|
|
145
|
-
current_x += node_size + layer_spacing
|
|
88
|
+
layers, layer_y, id_to_node_list_map = _create_architecture(
|
|
89
|
+
model_layers,
|
|
90
|
+
current_x,
|
|
91
|
+
show_neurons,
|
|
92
|
+
ellipsize_after,
|
|
93
|
+
node_size,
|
|
94
|
+
node_spacing,
|
|
95
|
+
_color_map,
|
|
96
|
+
opacity,
|
|
97
|
+
layer_spacing,
|
|
98
|
+
)
|
|
146
99
|
|
|
147
100
|
# Generate image
|
|
148
101
|
|
|
149
|
-
img_width = (
|
|
150
|
-
len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
|
|
151
|
-
)
|
|
102
|
+
img_width = len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
|
|
152
103
|
img_height = max(*layer_y) + 2 * padding
|
|
153
104
|
img = Image.new(
|
|
154
|
-
"RGBA",
|
|
105
|
+
"RGBA",
|
|
106
|
+
(int(ceil(img_width)), int(ceil(img_height))),
|
|
107
|
+
background_fill,
|
|
155
108
|
)
|
|
156
109
|
|
|
157
110
|
draw = aggdraw.Draw(img)
|
|
@@ -164,7 +117,7 @@ def graph_view(
|
|
|
164
117
|
node.y1 += y_off
|
|
165
118
|
node.y2 += y_off
|
|
166
119
|
|
|
167
|
-
for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
|
|
120
|
+
for start_idx, end_idx in zip(*np.where(adj_matrix > 0), strict=False):
|
|
168
121
|
start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
|
|
169
122
|
end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))
|
|
170
123
|
|
|
@@ -172,10 +125,11 @@ def graph_view(
|
|
|
172
125
|
end_layer_list = id_to_node_list_map[end_id]
|
|
173
126
|
|
|
174
127
|
# draw connectors
|
|
175
|
-
for
|
|
128
|
+
for start_node in start_layer_list:
|
|
176
129
|
for end_node in end_layer_list:
|
|
177
130
|
if not isinstance(start_node, Ellipses) and not isinstance(
|
|
178
|
-
end_node,
|
|
131
|
+
end_node,
|
|
132
|
+
Ellipses,
|
|
179
133
|
):
|
|
180
134
|
_draw_connector(
|
|
181
135
|
draw,
|
|
@@ -185,8 +139,8 @@ def graph_view(
|
|
|
185
139
|
width=connector_width,
|
|
186
140
|
)
|
|
187
141
|
|
|
188
|
-
for
|
|
189
|
-
for
|
|
142
|
+
for layer in layers:
|
|
143
|
+
for node in layer:
|
|
190
144
|
node.draw(draw)
|
|
191
145
|
|
|
192
146
|
draw.flush()
|
|
@@ -197,10 +151,96 @@ def graph_view(
|
|
|
197
151
|
return img
|
|
198
152
|
|
|
199
153
|
|
|
200
|
-
def _draw_connector(
|
|
154
|
+
def _draw_connector(
    draw: aggdraw.Draw,
    start_node: Box | Circle | Ellipses,
    end_node: Box | Circle | Ellipses,
    color: str | tuple[int, ...],
    width: int,
) -> None:
    """Draw a straight connector line from one node to another.

    The line leaves ``start_node`` at its right edge (``x2``) and enters
    ``end_node`` at its left edge (``x1``); on both sides it is anchored at
    the vertical middle of the node.

    Args:
        draw: aggdraw drawing context the line is rendered on.
        start_node: Node the connector starts from.
        end_node: Node the connector ends at.
        color: Pen color, either a color string or a color tuple.
        width: Pen width in pixels.
    """
    from_x = start_node.x2
    to_x = end_node.x1
    from_y = start_node.y1 + (start_node.y2 - start_node.y1) / 2
    to_y = end_node.y1 + (end_node.y2 - end_node.y1) / 2
    line_pen = aggdraw.Pen(color, width)
    draw.line([from_x, from_y, to_x, to_y], line_pen)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _retrieve_isbox_units(layer: torch.autograd.Function, show_neurons: bool) -> tuple[bool, int]:
|
|
171
|
+
"""Return the number of units and the flag whether to visualize using a box or not."""
|
|
172
|
+
is_box = True
|
|
173
|
+
units = 1
|
|
174
|
+
if show_neurons:
|
|
175
|
+
if hasattr(layer, "_saved_bias_sym_sizes_opt"):
|
|
176
|
+
is_box = False
|
|
177
|
+
units = layer._saved_bias_sym_sizes_opt[0] # noqa: SLF001
|
|
178
|
+
elif hasattr(layer, "_saved_mat2_sym_sizes"):
|
|
179
|
+
is_box = False
|
|
180
|
+
units = layer._saved_mat2_sym_sizes[1] # noqa: SLF001
|
|
181
|
+
elif hasattr(layer, "units"): # for dummy input layer
|
|
182
|
+
is_box = False
|
|
183
|
+
units = layer.units
|
|
184
|
+
return is_box, units
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _create_architecture(
    model_layers: list[list],
    current_x: int,
    show_neurons: bool,
    ellipsize_after: int,
    node_size: int,
    node_spacing: int,
    color_map: dict[Any, Any],
    opacity: int,
    layer_spacing: int,
) -> tuple[list, list, dict]:
    """Create the drawable nodes for every layer of the architecture.

    Each entry of ``model_layers`` is laid out as one column starting at
    ``current_x``; the layers inside a column are stacked top to bottom. A layer
    drawn as neurons gets one Circle per unit; when it has more than
    ``ellipsize_after`` units, only ``ellipsize_after`` nodes are created and the
    second-to-last one is an Ellipses placeholder. Any other layer becomes a single
    Box three node-sizes tall.

    Args:
        model_layers: Columns of layer objects (each must provide ``name()``).
        current_x: x coordinate of the first column, in pixels.
        show_neurons: Whether supported layers are drawn as individual neurons.
        ellipsize_after: Maximum number of nodes drawn per layer.
        node_size: Width of every node and height of a neuron node, in pixels.
        node_spacing: Vertical gap between consecutive nodes, in pixels.
        color_map: Mapping from op type (via ``TARGET_OPS``) to a dict with optional
            "fill" and "outline" entries.
        opacity: Alpha applied to the fill color (0 ~ 255).
        layer_spacing: Horizontal gap between consecutive columns, in pixels.

    Returns:
        tuple[list, list, dict]: ``(layers, layer_y, id_to_node_list_map)``, where
        ``layers`` holds the node list of each column, ``layer_y`` the bottom y
        coordinate of each column's last node, and ``id_to_node_list_map`` maps
        ``str(id(layer))`` to that layer's nodes.
    """
    id_to_node_list_map = {}
    layers = []
    layer_y = []
    for layer_list in model_layers:
        current_y = 0
        nodes = []
        layer: Any
        for layer in layer_list:
            is_box, units = _retrieve_isbox_units(layer, show_neurons)

            n = min(units, ellipsize_after)
            # Only use an Ellipses placeholder when neurons were actually dropped.
            # Without this guard, a layer with exactly ``ellipsize_after`` or
            # ``ellipsize_after - 1`` units would have a real neuron replaced by the
            # placeholder (and lose its connectors, which are not drawn for Ellipses).
            is_truncated = units > ellipsize_after
            # The style depends only on the layer's op type, so look it up once per layer.
            layer_style = color_map.get(TARGET_OPS[layer.name()], {})
            layer_nodes = []

            for i in range(n):
                scale = 1
                c: Box | Circle | Ellipses
                if not is_box:
                    c = Ellipses() if is_truncated and i == ellipsize_after - 2 else Circle()
                else:
                    # Non-neuron layers are drawn as one box, three nodes tall.
                    c = Box()
                    scale = 3

                c.x1 = current_x
                c.y1 = current_y
                c.x2 = c.x1 + node_size
                c.y2 = c.y1 + node_size * scale

                current_y = c.y2 + node_spacing

                c.set_fill(layer_style.get("fill", "#ADD8E6"), opacity)
                c.outline = layer_style.get("outline", "black")

                layer_nodes.append(c)

            id_to_node_list_map[str(id(layer))] = layer_nodes
            nodes.extend(layer_nodes)
            # Extra vertical gap between consecutive layers of the same column.
            current_y += 2 * node_size

        # Undo the trailing node/layer gaps: for a non-empty column this is the
        # bottom edge of its last node.
        layer_y.append(current_y - node_spacing - 2 * node_size)
        layers.append(nodes)
        current_x += node_size + layer_spacing
    return layers, layer_y, id_to_node_list_map
|