visualtorch 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,64 @@
1
+ Metadata-Version: 2.1
2
+ Name: visualtorch
3
+ Version: 0.2.2
4
+ Summary: Architecture visualization of Torch models
5
+ Home-page: https://github.com/willyfh/visualtorch
6
+ Author: Willy Fitra Hendria
7
+ Author-email: willyfitrahendria@gmail.com
8
+ License: MIT
9
+ Keywords: visualize architecture,torch visualization,visualtorch
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: OS Independent
15
+ Requires-Python: >=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+
19
+ <div align="center">
20
+ <h1>🔥 VisualTorch 🔥</h1>
21
+
22
+ [![python](https://img.shields.io/badge/python-3.10%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/visualtorch?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/visualtorch) [![Run Tests](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml/badge.svg)](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml) [![Documentation Status](https://readthedocs.org/projects/visualtorch/badge/?version=latest)](https://visualtorch.readthedocs.io/en/latest/?badge=latest)
23
+
24
+ </div>
25
+
26
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
27
+
28
+ **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
29
+
30
+ <div align="center">
31
+
32
+ ![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
33
+
34
+ </div>
35
+
36
+ ## Documentation
37
+
38
+ Online documentation is available at [visualtorch.readthedocs.io](https://visualtorch.readthedocs.io/en/latest/).
39
+
40
+ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html), [API references](https://visualtorch.readthedocs.io/en/latest/markdown/api_references/index.html), and other useful information.
41
+
42
+ ## Installation
43
+
44
+ See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
45
+
46
+ ## Examples
47
+
48
+ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
49
+
50
+ ## Contributing
51
+
52
+ Please feel free to send a pull request to contribute to this project.
53
+
54
+ ## License
55
+
56
+ This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
57
+
58
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
59
+
60
+ ## Citation
61
+
62
+ Please cite this project in your publications if it helps your research.
63
+
64
+ [A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
@@ -0,0 +1,46 @@
1
+ <div align="center">
2
+ <h1>🔥 VisualTorch 🔥</h1>
3
+
4
+ [![python](https://img.shields.io/badge/python-3.10%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/visualtorch?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/visualtorch) [![Run Tests](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml/badge.svg)](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml) [![Documentation Status](https://readthedocs.org/projects/visualtorch/badge/?version=latest)](https://visualtorch.readthedocs.io/en/latest/?badge=latest)
5
+
6
+ </div>
7
+
8
+ **VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
9
+
10
+ **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
11
+
12
+ <div align="center">
13
+
14
+ ![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
15
+
16
+ </div>
17
+
18
+ ## Documentation
19
+
20
+ Online documentation is available at [visualtorch.readthedocs.io](https://visualtorch.readthedocs.io/en/latest/).
21
+
22
+ The docs include [usage examples](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html), [API references](https://visualtorch.readthedocs.io/en/latest/markdown/api_references/index.html), and other useful information.
23
+
24
+ ## Installation
25
+
26
+ See the [Installation page](https://visualtorch.readthedocs.io/en/latest/markdown/get_started/installation.html).
27
+
28
+ ## Examples
29
+
30
+ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage_examples/index.html).
31
+
32
+ ## Contributing
33
+
34
+ Please feel free to send a pull request to contribute to this project.
35
+
36
+ ## License
37
+
38
+ This project is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
39
+
40
+ Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
41
+
42
+ ## Citation
43
+
44
+ Please cite this project in your publications if it helps your research.
45
+
46
+ [A ready-made citation entry](https://visualtorch.readthedocs.io/en/latest/index.html#citation) is available.
@@ -0,0 +1,167 @@
1
+ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
2
+ # SETUP CONFIGURATION. #
3
+ [build-system]
4
+ requires = ["setuptools>=42", "wheel"]
5
+ build-backend = "setuptools.build_meta"
6
+
7
+
8
+ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
9
+ # RUFF CONFIGURATION #
10
+ [tool.ruff]
11
+ # Enable rules
12
+ select = [
13
+ "F", # Pyflakes (`F`)
14
+ "E", # pycodestyle error (`E`)
15
+ "W", # pycodestyle warning (`W`)
16
+ "C90", # mccabe (`C90`)
17
+ "I", # isort (`I`)
18
+ "N", # pep8-naming (`N`)
19
+ "D", # pydocstyle (`D`)
20
+ "UP", # pyupgrade (`UP`)
21
+ "YTT", # flake8-2020 (`YTT`)
22
+ "ANN", # flake8-annotations (`ANN`)
23
+ "S", # flake8-bandit (`S`)
24
+ "BLE", # flake8-blind-except (`BLE`)
25
+ "FBT", # flake8-boolean-trap (`FBT`)
26
+ "B", # flake8-bugbear (`B`)
27
+ "A", # flake8-builtins (`A`)
28
+ "COM", # flake8-commas (`COM`)
29
+ "CPY", # flake8-copyright (`CPY`)
30
+ "C4", # flake8-comprehensions (`C4`)
31
+ "DTZ", # flake8-datetimez (`DTZ`)
32
+ "T10", # flake8-debugger (`T10`)
33
+ "EM", # flake8-errmsg (`EM`)
34
+ "FA", # flake8-future-annotations (`FA`)
35
+ "ISC", # flake8-implicit-str-concat (`ISC`)
36
+ "ICN", # flake8-import-conventions (`ICN`)
37
+ "PIE", # flake8-pie (`PIE`)
38
+ "PT", # flake8-pytest-style (`PT`)
39
+ "RSE", # flake8-raise (`RSE`)
40
+ "RET", # flake8-return (`RET`)
41
+ "SLF", # flake8-self (`SLF`)
42
+ "SIM", # flake8-simplify (`SIM`)
43
+ "TID", # flake8-tidy-imports (`TID`)
44
+ "TCH", # flake8-type-checking (`TCH`)
45
+ "INT", # flake8-gettext (`INT`)
46
+ "ARG", # flake8-unused-arguments (`ARG`)
47
+ "PTH", # flake8-use-pathlib (`PTH`)
48
+ "TD", # flake8-todos (`TD`)
49
+ "FIX", # flake8-fixme (`FIX`)
50
+ "ERA", # eradicate (`ERA`)
51
+ "PD", # pandas-vet (`PD`)
52
+ "PGH", # pygrep-hooks (`PGH`)
53
+ "PL", # pylint (`PL`)
54
+ "TRY", # tryceratops (`TRY`)
55
+ "FLY", # flynt (`FLY`)
56
+ "NPY", # NumPy-specific rules (`NPY`)
57
+ "PERF", # Perflint (`PERF`)
58
+ "RUF", # Ruff-specific rules (`RUF`)
59
+ # "FURB", # refurb (`FURB`) - ERROR: Unknown rule selector: `FURB`
60
+ # "LOG", # flake8-logging (`LOG`) - ERROR: Unknown rule selector: `LOG`
61
+ ]
62
+
63
+ ignore = [
64
+ # pydocstyle
65
+ "D107", # Missing docstring in __init__
66
+ "D415", # First line should end with a period, question mark, or exclamation point
67
+
68
+ # pylint
69
+ "PLR0913", # Too many arguments to function call
70
+ "PLR2004", # consider replacing with a constant variable
71
+ "PLR0912", # Too many branches
72
+ "PLR0915", # Too many statements
73
+
74
+ # flake8-annotations
75
+ "ANN101", # Missing-type-self
76
+ "ANN002", # Missing type annotation for *args
77
+ "ANN003", # Missing type annotation for **kwargs
78
+
79
+ # flake8-bandit (`S`)
80
+ "S101", # Use of assert detected.
81
+
82
+ # flake8-boolean-trap (`FBT`)
83
+ "FBT001", # Boolean positional arg in function definition
84
+ "FBT002", # Boolean default value in function definition
85
+
86
+ # flake8-datetimez (`DTZ`)
87
+ "DTZ005", # The use of `datetime.datetime.now()` without `tz` argument is not allowed
88
+
89
+ # flake8-fixme (`FIX`)
90
+ "FIX002", # Line contains TODO, consider resolving the issue
91
+ ]
92
+
93
+ # Allow autofix for all enabled rules (when `--fix` is provided).
94
+ fixable = ["ALL"]
95
+ unfixable = []
96
+
97
+ # Exclude a variety of commonly ignored directories.
98
+ exclude = [
99
+ ".bzr",
100
+ ".direnv",
101
+ ".eggs",
102
+ ".git",
103
+ ".hg",
104
+ ".mypy_cache",
105
+ ".nox",
106
+ ".pants.d",
107
+ ".pytype",
108
+ ".ruff_cache",
109
+ ".svn",
110
+ ".tox",
111
+ ".venv",
112
+ "__pypackages__",
113
+ "_build",
114
+ "buck-out",
115
+ "build",
116
+ "dist",
117
+ "node_modules",
118
+ "venv",
119
+ ]
120
+
121
+ # Maximum line length (wider than Black's default of 88).
122
+ line-length = 120
123
+
124
+ # Allow unused variables when underscore-prefixed.
125
+ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
126
+
127
+ # Assume Python 3.10.
128
+ target-version = "py310"
129
+
130
+ # Treat the "visualtorch" and "tests" directories as first-party source roots when sorting imports.
131
+ src = ["visualtorch", "tests"]
132
+
133
+ [tool.ruff.mccabe]
134
+ # Maximum McCabe complexity per function (Ruff's default is 10; this project allows 15).
135
+ max-complexity = 15
136
+
137
+ [tool.ruff.per-file-ignores]
138
+ "tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
139
+
140
+ [tool.ruff.pydocstyle]
141
+ convention = "google"
142
+
143
+
144
+ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
145
+ # MYPY CONFIGURATION. #
146
+ [tool.mypy]
147
+ ignore_missing_imports = true
148
+ show_error_codes = true
149
+
150
+
151
+ [[tool.mypy.overrides]]
152
+ module = ["torch.*"]
153
+ follow_imports = "skip"
154
+ follow_imports_for_stubs = true
155
+
156
+
157
+ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
158
+ # PYTEST CONFIGURATION #
159
+ [tool.pytest.ini_options]
160
+ addopts = [
161
+ "--strict-markers",
162
+ "--strict-config",
163
+ "--showlocals",
164
+ "-ra",
165
+ ]
166
+ testpaths = "tests"
167
+ pythonpath = "visualtorch"
@@ -1,11 +1,19 @@
1
+ """Setup file for visualtorch."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from pathlib import Path
7
+
1
8
  import setuptools
2
9
 
3
file_path = Path("README.md")
# Read explicitly as UTF-8: the README contains non-ASCII text (the emoji in its title), and
# relying on the platform's default encoding (e.g. cp1252 on Windows) would garble it in the
# published long description.
with file_path.open("r", encoding="utf-8") as fh:
    long_description = fh.read()
5
13
 
6
14
  setuptools.setup(
7
15
  name="visualtorch",
8
- version="0.2.0",
16
+ version="0.2.2",
9
17
  author="Willy Fitra Hendria",
10
18
  author_email="willyfitrahendria@gmail.com",
11
19
  description="Architecture visualization of Torch models",
@@ -0,0 +1,10 @@
1
+ """Modules for pytorch model visualization."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from visualtorch.graph import graph_view
7
+ from visualtorch.layered import layered_view
8
+ from visualtorch.lenet_style import lenet_view
9
+
10
+ __all__ = ["layered_view", "graph_view", "lenet_view"]
@@ -1,32 +1,38 @@
1
- import aggdraw
2
- from PIL import Image
1
+ """Graph View module for pytorch model visualization."""
2
+
3
+ # Copyright (C) 2024 Willy Fitra Hendria
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from collections import defaultdict
3
7
  from math import ceil
4
- from .layer_utils import model_to_adj_matrix, add_input_dummy_layer
5
- from .utils import Circle, Ellipses, get_keys_by_value, Box
8
+ from typing import Any
9
+
10
+ import aggdraw
6
11
  import numpy as np
7
- from typing import Optional, Dict, Any, Tuple, List
8
12
  import torch
13
+ from PIL import Image
14
+
15
+ from .utils.layer_utils import TARGET_OPS, add_input_dummy_layer, model_to_adj_matrix
16
+ from .utils.utils import Box, Circle, Ellipses, get_keys_by_value
9
17
 
10
18
 
11
19
  def graph_view(
12
20
  model: torch.nn.Module,
13
- input_shape: Tuple[int, ...],
14
- to_file: Optional[str] = None,
15
- color_map: Optional[Dict[Any, Any]] = None,
21
+ input_shape: tuple[int, ...],
22
+ to_file: str | None = None,
23
+ color_map: dict[Any, Any] | None = None,
16
24
  node_size: int = 50,
17
- background_fill: Any = "white",
25
+ background_fill: str | tuple[int, ...] = "white",
18
26
  padding: int = 10,
19
27
  layer_spacing: int = 250,
20
28
  node_spacing: int = 10,
21
- connector_fill: Any = "gray",
29
+ connector_fill: str | tuple[int, ...] = "gray",
22
30
  connector_width: int = 1,
23
31
  ellipsize_after: int = 10,
24
- inout_as_tensor: bool = True,
25
32
  show_neurons: bool = True,
33
+ opacity: int = 255,
26
34
  ) -> Image.Image:
27
- """
28
- Generates an architecture visualization for a given linear PyTorch model (i.e., one input and output tensor for each
29
- layer) in a graph style.
35
+ """Generates an architecture visualization for a given linear PyTorch model in a graph style.
30
36
 
31
37
  Args:
32
38
  model (torch.nn.Module): A PyTorch model that will be visualized.
@@ -34,8 +40,8 @@ def graph_view(
34
40
  to_file (str, optional): Path to the file to write the created image to. If the image does not exist yet,
35
41
  it will be created, else overwritten. Image type is inferred from the file ending. Providing None
36
42
  will disable writing.
37
- color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback to default
38
- values for not specified classes.
43
+ color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
44
+ to default values for not specified classes.
39
45
  node_size (int, optional): Size in pixels each node will have.
40
46
  background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
41
47
  padding (int, optional): Distance in pixels before the first and after the last layer.
@@ -45,103 +51,59 @@ def graph_view(
45
51
  connector_width (int, optional): Line-width of the connectors in pixels.
46
52
  ellipsize_after (int, optional): Maximum number of neurons per layer to draw. If a layer is exceeding this,
47
53
  the remaining neurons will be drawn as ellipses.
48
- inout_as_tensor (bool, optional): If True there will be one input and output node for each tensor, else the
49
- tensor will be flattened and one node for each scalar will be created (e.g., a (10, 10) shape will be
50
- represented by 100 nodes).
51
54
  show_neurons (bool, optional): If True a node for each neuron in supported layers is created (constrained by
52
55
  ellipsize_after), else each layer is represented by a node.
56
+ opacity (int, optional): Transparency of the color (0 ~ 255).
53
57
 
54
58
  Returns:
55
59
  Image.Image: Generated architecture image.
56
60
  """
57
-
58
- if color_map is None:
59
- color_map = dict()
61
+ _color_map: dict = {}
62
+ if color_map is not None:
63
+ _color_map = defaultdict(dict, color_map)
60
64
 
61
65
  # Iterate over the model to compute bounds and generate boxes
62
66
 
63
- layers: List[Any] = list()
64
- layer_y = list()
65
-
66
67
  # Attach helper layers
67
68
 
68
69
  id_to_num_mapping, adj_matrix, model_layers = model_to_adj_matrix(
69
- model, input_shape
70
+ model,
71
+ input_shape,
70
72
  )
71
73
 
72
74
  # Add fake input layers
73
75
 
74
76
  id_to_num_mapping, adj_matrix, model_layers = add_input_dummy_layer(
75
- input_shape, id_to_num_mapping, adj_matrix, model_layers
77
+ input_shape,
78
+ id_to_num_mapping,
79
+ adj_matrix,
80
+ model_layers,
76
81
  )
77
82
 
78
83
  # Create architecture
79
84
 
80
85
  current_x = padding # + input_label_size[0] + text_padding
81
86
 
82
- id_to_node_list_map = dict()
83
-
84
- for index, layer_list in enumerate(model_layers):
85
- current_y = 0
86
- nodes = []
87
- for layer in layer_list:
88
- is_box = True
89
- units = 1
90
-
91
- if show_neurons:
92
- if hasattr(layer, "_saved_bias_sym_sizes_opt"):
93
- is_box = False
94
- units = layer._saved_bias_sym_sizes_opt[0]
95
- elif hasattr(layer, "_saved_mat2_sym_sizes"):
96
- is_box = False
97
- units = layer._saved_mat2_sym_sizes[1]
98
- elif hasattr(layer, "units"): # for dummy input layer
99
- is_box = False
100
- units = layer.units
101
-
102
- n = min(units, ellipsize_after)
103
- layer_nodes = list()
104
-
105
- for i in range(n):
106
- scale = 1
107
- c: Box | Circle | Ellipses
108
- if not is_box:
109
- if i != ellipsize_after - 2:
110
- c = Circle()
111
- else:
112
- c = Ellipses()
113
- else:
114
- c = Box()
115
- scale = 3
116
-
117
- c.x1 = current_x
118
- c.y1 = current_y
119
- c.x2 = c.x1 + node_size
120
- c.y2 = c.y1 + node_size * scale
121
-
122
- current_y = c.y2 + node_spacing
123
-
124
- c.fill = color_map.get(type(layer), {}).get("fill", "#ADD8E6")
125
- c.outline = color_map.get(type(layer), {}).get("outline", "black")
126
-
127
- layer_nodes.append(c)
128
-
129
- id_to_node_list_map[str(id(layer))] = layer_nodes
130
- nodes.extend(layer_nodes)
131
- current_y += 2 * node_size
132
-
133
- layer_y.append(current_y - node_spacing - 2 * node_size)
134
- layers.append(nodes)
135
- current_x += node_size + layer_spacing
87
+ layers, layer_y, id_to_node_list_map = _create_architecture(
88
+ model_layers,
89
+ current_x,
90
+ show_neurons,
91
+ ellipsize_after,
92
+ node_size,
93
+ node_spacing,
94
+ _color_map,
95
+ opacity,
96
+ layer_spacing,
97
+ )
136
98
 
137
99
  # Generate image
138
100
 
139
- img_width = (
140
- len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
141
- )
101
+ img_width = len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
142
102
  img_height = max(*layer_y) + 2 * padding
143
103
  img = Image.new(
144
- "RGBA", (int(ceil(img_width)), int(ceil(img_height))), background_fill
104
+ "RGBA",
105
+ (int(ceil(img_width)), int(ceil(img_height))),
106
+ background_fill,
145
107
  )
146
108
 
147
109
  draw = aggdraw.Draw(img)
@@ -154,7 +116,7 @@ def graph_view(
154
116
  node.y1 += y_off
155
117
  node.y2 += y_off
156
118
 
157
- for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
119
+ for start_idx, end_idx in zip(*np.where(adj_matrix > 0), strict=False):
158
120
  start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
159
121
  end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))
160
122
 
@@ -162,10 +124,11 @@ def graph_view(
162
124
  end_layer_list = id_to_node_list_map[end_id]
163
125
 
164
126
  # draw connectors
165
- for start_node_idx, start_node in enumerate(start_layer_list):
127
+ for start_node in start_layer_list:
166
128
  for end_node in end_layer_list:
167
129
  if not isinstance(start_node, Ellipses) and not isinstance(
168
- end_node, Ellipses
130
+ end_node,
131
+ Ellipses,
169
132
  ):
170
133
  _draw_connector(
171
134
  draw,
@@ -175,8 +138,8 @@ def graph_view(
175
138
  width=connector_width,
176
139
  )
177
140
 
178
- for i, layer in enumerate(layers):
179
- for node_index, node in enumerate(layer):
141
+ for layer in layers:
142
+ for node in layer:
180
143
  node.draw(draw)
181
144
 
182
145
  draw.flush()
@@ -187,10 +150,97 @@ def graph_view(
187
150
  return img
188
151
 
189
152
 
190
def _draw_connector(
    draw: aggdraw.Draw,
    start_node: Box | Circle | Ellipses,
    end_node: Box | Circle | Ellipses,
    color: str | tuple[int, ...],
    width: int,
) -> None:
    """Draw a straight connector line between two nodes.

    The line starts at the vertical midpoint of ``start_node``'s right edge (``x2``) and
    ends at the vertical midpoint of ``end_node``'s left edge (``x1``).

    Args:
        draw: The ``aggdraw`` drawing surface the line is drawn on.
        start_node: Node the connector leaves from.
        end_node: Node the connector arrives at.
        color: Line color, either a color string or a tuple of ints.
        width: Line width in pixels.
    """
    line_pen = aggdraw.Pen(color, width)

    start_point = (start_node.x2, start_node.y1 + (start_node.y2 - start_node.y1) / 2)
    end_point = (end_node.x1, end_node.y1 + (end_node.y2 - end_node.y1) / 2)

    draw.line([*start_point, *end_point], line_pen)
167
+
168
+
169
def _retrieve_isbox_units(layer: torch.autograd.Function, show_neurons: bool) -> tuple[bool, int]:
    """Decide whether a layer is drawn as a single box, and how many nodes it needs.

    Args:
        layer: A layer of the traced model, or the dummy input layer. Only the attributes
            ``_saved_bias_sym_sizes_opt``, ``_saved_mat2_sym_sizes`` and ``units`` are
            inspected, in that order of priority.
        show_neurons: If False, every layer is drawn as a single box.

    Returns:
        tuple[bool, int]: ``(is_box, units)``. When ``is_box`` is True the layer is drawn
        as one box and ``units`` is 1; otherwise ``units`` is the number of neurons to draw.
    """
    if not show_neurons:
        return True, 1
    if hasattr(layer, "_saved_bias_sym_sizes_opt"):
        # First dimension of the saved bias size.
        return False, layer._saved_bias_sym_sizes_opt[0]  # noqa: SLF001
    if hasattr(layer, "_saved_mat2_sym_sizes"):
        # Second dimension of the saved ``mat2`` operand size.
        return False, layer._saved_mat2_sym_sizes[1]  # noqa: SLF001
    if hasattr(layer, "units"):  # for dummy input layer
        return False, layer.units
    return True, 1
185
+
186
+
187
def _create_architecture(
    model_layers: list[list],
    current_x: int,
    show_neurons: bool,
    ellipsize_after: int,
    node_size: int,
    node_spacing: int,
    color_map: dict[Any, Any],
    opacity: int,
    layer_spacing: int,
) -> tuple[list, list, dict]:
    """Create the drawable nodes of the architecture, one column per entry of ``model_layers``.

    Columns are laid out left to right starting at ``current_x``; within a column, the nodes
    of each layer are stacked top to bottom starting at y = 0.

    Args:
        model_layers: Layers grouped by column; each inner list holds the layers drawn in
            the same column.
        current_x: x coordinate of the first column.
        show_neurons: If True, layers with a detectable unit count are drawn as one circle
            per unit (see ``_retrieve_isbox_units``); otherwise each layer is a single box.
        ellipsize_after: Maximum number of nodes drawn per layer. When a layer has more units
            than this, the second-to-last drawn node is an ellipsis standing in for the
            hidden neurons.
        node_size: Width of each node, and height of a circle node, in pixels.
        node_spacing: Vertical gap between consecutive nodes of a layer.
        color_map: Maps ``TARGET_OPS[layer.name()]`` to a dict with optional ``"fill"`` and
            ``"outline"`` colors; missing entries fall back to ``"#ADD8E6"`` / ``"black"``.
        opacity: Opacity passed to each node's ``set_fill`` together with the fill color.
        layer_spacing: Horizontal gap between columns.

    Returns:
        tuple: ``(layers, layer_y, id_to_node_list_map)`` where ``layers`` holds the nodes
        of each column, ``layer_y`` the y coordinate of the bottom of each column's last node,
        and ``id_to_node_list_map`` maps ``str(id(layer))`` to that layer's nodes.
    """
    id_to_node_list_map: dict[str, list] = {}
    layers: list[list] = []
    layer_y: list = []
    for layer_list in model_layers:
        current_y = 0
        nodes = []
        layer: Any
        for layer in layer_list:
            is_box, units = _retrieve_isbox_units(layer, show_neurons)

            n = min(units, ellipsize_after)
            # Only draw an ellipsis when neurons are actually hidden. Without this guard, a
            # layer with exactly ``ellipsize_after`` (or one fewer) units lost a neuron to an
            # ellipsis even though every neuron fits.
            is_truncated = units > ellipsize_after
            layer_nodes = []

            for i in range(n):
                scale = 1
                c: Box | Circle | Ellipses
                if is_box:
                    c = Box()
                    scale = 3  # boxes are three node-sizes tall
                elif is_truncated and i == ellipsize_after - 2:
                    # The second-to-last slot represents the hidden neurons; the last slot
                    # still shows a regular neuron.
                    c = Ellipses()
                else:
                    c = Circle()

                c.x1 = current_x
                c.y1 = current_y
                c.x2 = c.x1 + node_size
                c.y2 = c.y1 + node_size * scale

                current_y = c.y2 + node_spacing

                # Look the layer's colors up once instead of once per attribute.
                layer_colors = color_map.get(TARGET_OPS[layer.name()], {})
                c.set_fill(layer_colors.get("fill", "#ADD8E6"), opacity)
                c.outline = layer_colors.get("outline", "black")

                layer_nodes.append(c)

            id_to_node_list_map[str(id(layer))] = layer_nodes
            nodes.extend(layer_nodes)
            # Extra vertical gap between layers that share a column.
            current_y += 2 * node_size

        # Undo the trailing node gap and layer gap: this is the bottom edge of the column's
        # last node (when its last layer drew at least one node).
        layer_y.append(current_y - node_spacing - 2 * node_size)
        layers.append(nodes)
        current_x += node_size + layer_spacing
    return layers, layer_y, id_to_node_list_map