visualtorch 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
 MIT License
 
+Copyright (c) 2020 Paul Gavrikov
+
 Copyright (c) 2024 Willy Fitra Hendria
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: visualtorch
-Version: 0.2.1
+Version: 0.2.3
 Summary: Architecture visualization of Torch models
 Home-page: https://github.com/willyfh/visualtorch
 Author: Willy Fitra Hendria
@@ -14,6 +14,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
+Provides-Extra: dev
 License-File: LICENSE
 
 <div align="center">
@@ -23,13 +24,13 @@ License-File: LICENSE
 
 </div>
 
-**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
+**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
 
 **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
 
 <div align="center">
 
-![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/7e2c35ea-d34d-4b92-b414-285bccb8576a)
+![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
 
 </div>
 
@@ -49,11 +50,11 @@ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage
 
 ## Contributing
 
-Please feel free to send a pull request to contribute to this project.
+Please feel free to send a pull request to contribute to this project by following this [guideline](https://github.com/willyfh/visualtorch/blob/main/CONTRIBUTING.md).
 
 ## License
 
-This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/update-readme/LICENSE).
+This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
 
 Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
 
@@ -5,13 +5,13 @@
 
 </div>
 
-**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
+**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style, graph-style, and LeNet-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
 
 **Note:** VisualTorch may not yet support complex models, but contributions are welcome!
 
 <div align="center">
 
-![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/7e2c35ea-d34d-4b92-b414-285bccb8576a)
+![VisualTorch Examples](https://github.com/willyfh/visualtorch/assets/5786636/398c3356-4de0-446b-a30b-d8ebe532d2c2)
 
 </div>
 
@@ -31,11 +31,11 @@ See the [Usage Examples page](https://visualtorch.readthedocs.io/en/latest/usage
 
 ## Contributing
 
-Please feel free to send a pull request to contribute to this project.
+Please feel free to send a pull request to contribute to this project by following this [guideline](https://github.com/willyfh/visualtorch/blob/main/CONTRIBUTING.md).
 
 ## License
 
-This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/update-readme/LICENSE).
+This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/main/LICENSE).
 
 Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.
 
@@ -0,0 +1,167 @@
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# SETUP CONFIGURATION. #
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# RUFF CONFIGURATION #
+[tool.ruff]
+# Enable rules
+select = [
+    "F", # Pyflakes (`F`)
+    "E", # pycodestyle error (`E`)
+    "W", # pycodestyle warning (`W`)
+    "C90", # mccabe (`C90`)
+    "I", # isort (`I`)
+    "N", # pep8-naming (`N`)
+    "D", # pydocstyle (`D`)
+    "UP", # pyupgrade (`UP`)
+    "YTT", # flake8-2020 (`YTT`)
+    "ANN", # flake8-annotations (`ANN`)
+    "S", # flake8-bandit (`S`)
+    "BLE", # flake8-blind-except (`BLE`)
+    "FBT", # flake8-boolean-trap (`FBT`)
+    "B", # flake8-bugbear (`B`)
+    "A", # flake8-builtins (`A`)
+    "COM", # flake8-commas (`COM`)
+    "CPY", # flake8-copyright (`CPY`)
+    "C4", # flake8-comprehensions (`C4`)
+    "DTZ", # flake8-datatimez (`DTZ`)
+    "T10", # flake8-debugger (`T10`)
+    "EM", # flake8-errmsg (`EM`)
+    "FA", # flake8-future-annotations (`FA`)
+    "ISC", # flake8-implicit-str-concat (`ISC`)
+    "ICN", # flake8-import-conventions (`ICN`)
+    "PIE", # flake8-pie (`PIE`)
+    "PT", # flake8-pytest-style (`PT`)
+    "RSE", # flake8-raise (`RSE`)
+    "RET", # flake8-return (`RET`)
+    "SLF", # flake8-self (`SLF`)
+    "SIM", # flake8-simplify (`SIM`)
+    "TID", # flake8-tidy-imports (`TID`)
+    "TCH", # flake8-type-checking (`TCH`)
+    "INT", # flake8-gettext (`INT`)
+    "ARG", # flake8-unsused-arguments (`ARG`)
+    "PTH", # flake8-use-pathlib (`PTH`)
+    "TD", # flake8-todos (`TD`)
+    "FIX", # flake8-fixme (`FIX`)
+    "ERA", # eradicate (`ERA`)
+    "PD", # pandas-vet (`PD`)
+    "PGH", # pygrep-hooks (`PGH`)
+    "PL", # pylint (`PL`)
+    "TRY", # tryceratos (`TRY`)
+    "FLY", # flynt (`FLY`)
+    "NPY", # NumPy-specific rules (`NPY`)
+    "PERF", # Perflint (`PERF`)
+    "RUF", # Ruff-specific rules (`RUF`)
+    # "FURB", # refurb (`FURB`) - ERROR: Unknown rule selector: `FURB`
+    # "LOG", # flake8-logging (`LOG`) - ERROR: Unknown rule selector: `LOG`
+]
+
+ignore = [
+    # pydocstyle
+    "D107", # Missing docstring in __init__
+    "D415", # First line should end with a period, question mark, or exclamation point
+
+    # pylint
+    "PLR0913", # Too many arguments to function call
+    "PLR2004", # consider replacing with a constant variable
+    "PLR0912", # Too many branches
+    "PLR0915", # Too many statements
+
+    # flake8-annotations
+    "ANN101", # Missing-type-self
+    "ANN002", # Missing type annotation for *args
+    "ANN003", # Missing type annotation for **kwargs
+
+    # flake8-bandit (`S`)
+    "S101", # Use of assert detected.
+
+    # flake8-boolean-trap (`FBT`)
+    "FBT001", # Boolean positional arg in function definition
+    "FBT002", # Boolean default value in function definition
+
+    # flake8-datatimez (`DTZ`)
+    "DTZ005", # The use of `datetime.datetime.now()` without `tz` argument is not allowed
+
+    # flake8-fixme (`FIX`)
+    "FIX002", # Line contains TODO, consider resolving the issue
+]
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+]
+
+# Same as Black.
+line-length = 120
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.10.
+target-version = "py310"
+
+# Allow imports relative to the "src" and "tests" directories.
+src = ["visualtorch", "tests"]
+
+[tool.ruff.mccabe]
+# Unlike Flake8, default to a complexity level of 10.
+max-complexity = 15
+
+[tool.ruff.per-file-ignores]
+"tests/nightly/tools/benchmarking/test_benchmarking.py" = ["E402"]
+
+[tool.ruff.pydocstyle]
+convention = "google"
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# MYPY CONFIGURATION. #
+[tool.mypy]
+ignore_missing_imports = true
+show_error_codes = true
+
+
+[[tool.mypy.overrides]]
+module = ["torch.*"]
+follow_imports = "skip"
+follow_imports_for_stubs = true
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# PYTEST CONFIGURATION #
+[tool.pytest.ini_options]
+addopts = [
+    "--strict-markers",
+    "--strict-config",
+    "--showlocals",
+    "-ra",
+]
+testpaths = "tests"
+pythonpath = "visualtorch"
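The `dummy-variable-rgx` entry above controls which unused local names Ruff treats as intentional placeholders. As a quick sanity check, the same pattern can be exercised with Python's `re` module. This is a sketch that assumes Ruff's regex engine reads this simple pattern the same way `re` does; the candidate names are arbitrary examples.

```python
import re

# Pattern copied verbatim from the `dummy-variable-rgx` setting above.
DUMMY_VARIABLE_RGX = re.compile(r"^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$")


def is_dummy_name(name: str) -> bool:
    """Return True if `name` matches the dummy-variable pattern (as evaluated by Python's `re`)."""
    return DUMMY_VARIABLE_RGX.match(name) is not None


if __name__ == "__main__":
    for candidate in ["_", "__", "_unused", "__tmp2", "unused", "_tmp_", "__init__"]:
        print(f"{candidate!r:>12} -> {is_dummy_name(candidate)}")
```

Under that assumption, an underscore-prefixed name is accepted only if it consists of underscores alone or ends in a letter or digit, so `_tmp_` and `__init__` are not treated as dummies.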
@@ -3,14 +3,25 @@
 # Copyright (C) 2024 Willy Fitra Hendria
 # SPDX-License-Identifier: MIT
 
+from pathlib import Path
+
 import setuptools
 
-with open("README.md", "r") as fh:
+file_path = Path("README.md")
+with file_path.open("r") as fh:
     long_description = fh.read()
 
+
+def _read_requirements(file: str) -> list:
+    file_path = Path(file)
+    with file_path.open("r") as fh:
+        reqs = fh.read()
+    return reqs.strip().split("\n")
+
+
 setuptools.setup(
     name="visualtorch",
-    version="0.2.1",
+    version="0.2.3",
     author="Willy Fitra Hendria",
     author_email="willyfitrahendria@gmail.com",
     description="Architecture visualization of Torch models",
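The new `_read_requirements` helper returns `reqs.strip().split("\n")`, so interior blank lines and `#` comment lines in a requirements file become list entries, and an empty file yields `[""]`. The sketch below contrasts that splitting rule with a stricter, hypothetical variant. `split_as_shipped`, `split_filtered`, and `read_requirements` are illustrative names that are not part of visualtorch, and whether setuptools would reject the extra entries is not checked here.

```python
from pathlib import Path


def split_as_shipped(text: str) -> list[str]:
    """Same splitting rule as `_read_requirements` in the setup.py diff above."""
    return text.strip().split("\n")


def split_filtered(text: str) -> list[str]:
    """Hypothetical stricter rule: skip blank lines and full-line `#` comments."""
    stripped = (line.strip() for line in text.splitlines())
    return [line for line in stripped if line and not line.startswith("#")]


def read_requirements(file: str, *, filtered: bool = True) -> list[str]:
    """Illustrative file reader that applies either rule (not part of visualtorch)."""
    text = Path(file).read_text()
    return split_filtered(text) if filtered else split_as_shipped(text)
```

For a file containing a comment line, `pillow>=10.0.0`, a blank line, and `numpy>=1.18.1`, the shipped rule keeps the comment and the empty string as entries, while the filtered rule returns only the two requirement specifiers.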
@@ -26,12 +37,8 @@ setuptools.setup(
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
-    install_requires=[
-        "pillow>=10.0.0",
-        "numpy>=1.18.1",
-        "aggdraw>=1.3.11",
-        "torch>=2.0.0",
-    ],
+    install_requires=_read_requirements("requirements.txt"),
+    extras_require={"dev": _read_requirements("docs/requirements.txt") + _read_requirements("dev-requirements.txt")},
     python_requires=">=3.10",
     license="MIT",
     license_files=("LICENSE",),
@@ -3,7 +3,8 @@
 # Copyright (C) 2024 Willy Fitra Hendria
 # SPDX-License-Identifier: MIT
 
-from visualtorch.layered import layered_view
 from visualtorch.graph import graph_view
+from visualtorch.layered import layered_view
+from visualtorch.lenet_style import lenet_view
 
-__all__ = ["layered_view", "graph_view"]
+__all__ = ["layered_view", "graph_view", "lenet_view"]
@@ -1,37 +1,39 @@
 """Graph View module for pytorch model visualization."""
 
+# Copyright (C) 2020 Paul Gavrikov
 # Copyright (C) 2024 Willy Fitra Hendria
 # SPDX-License-Identifier: MIT
 
-import aggdraw
-from PIL import Image
+from collections import defaultdict
 from math import ceil
-from .layer_utils import model_to_adj_matrix, add_input_dummy_layer, TARGET_OPS
-from .utils import Circle, Ellipses, get_keys_by_value, Box
+from typing import Any
+
+import aggdraw
 import numpy as np
-from typing import Optional, Dict, Any, Tuple, List
 import torch
+from PIL import Image
+
+from .utils.layer_utils import TARGET_OPS, add_input_dummy_layer, model_to_adj_matrix
+from .utils.utils import Box, Circle, Ellipses, get_keys_by_value
 
 
 def graph_view(
     model: torch.nn.Module,
-    input_shape: Tuple[int, ...],
-    to_file: Optional[str] = None,
-    color_map: Optional[Dict[Any, Any]] = None,
+    input_shape: tuple[int, ...],
+    to_file: str | None = None,
+    color_map: dict[Any, Any] | None = None,
     node_size: int = 50,
-    background_fill: Any = "white",
+    background_fill: str | tuple[int, ...] = "white",
     padding: int = 10,
     layer_spacing: int = 250,
     node_spacing: int = 10,
-    connector_fill: Any = "gray",
+    connector_fill: str | tuple[int, ...] = "gray",
     connector_width: int = 1,
     ellipsize_after: int = 10,
-    inout_as_tensor: bool = True,
     show_neurons: bool = True,
+    opacity: int = 255,
 ) -> Image.Image:
-    """
-    Generates an architecture visualization for a given linear PyTorch model (i.e., one input and output tensor for each
-    layer) in a graph style.
+    """Generates an architecture visualization for a given linear PyTorch model in a graph style.
 
     Args:
         model (torch.nn.Module): A PyTorch model that will be visualized.
@@ -39,8 +41,8 @@ def graph_view(
         to_file (str, optional): Path to the file to write the created image to. If the image does not exist yet,
             it will be created, else overwritten. Image type is inferred from the file ending. Providing None
             will disable writing.
-        color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback to default
-            values for not specified classes.
+        color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback
+            to default values for not specified classes.
         node_size (int, optional): Size in pixels each node will have.
         background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
         padding (int, optional): Distance in pixels before the first and after the last layer.
@@ -50,108 +52,59 @@ def graph_view(
         connector_width (int, optional): Line-width of the connectors in pixels.
         ellipsize_after (int, optional): Maximum number of neurons per layer to draw. If a layer is exceeding this,
             the remaining neurons will be drawn as ellipses.
-        inout_as_tensor (bool, optional): If True there will be one input and output node for each tensor, else the
-            tensor will be flattened and one node for each scalar will be created (e.g., a (10, 10) shape will be
-            represented by 100 nodes).
         show_neurons (bool, optional): If True a node for each neuron in supported layers is created (constrained by
             ellipsize_after), else each layer is represented by a node.
+        opacity (int, optional): Transparency of the color (0 ~ 255).
 
     Returns:
         Image.Image: Generated architecture image.
     """
-
-    if color_map is None:
-        color_map = dict()
+    _color_map: dict = {}
+    if color_map is not None:
+        _color_map = defaultdict(dict, color_map)
 
     # Iterate over the model to compute bounds and generate boxes
 
-    layers: List[Any] = list()
-    layer_y = list()
-
     # Attach helper layers
 
     id_to_num_mapping, adj_matrix, model_layers = model_to_adj_matrix(
-        model, input_shape
+        model,
+        input_shape,
     )
 
     # Add fake input layers
 
     id_to_num_mapping, adj_matrix, model_layers = add_input_dummy_layer(
-        input_shape, id_to_num_mapping, adj_matrix, model_layers
+        input_shape,
+        id_to_num_mapping,
+        adj_matrix,
+        model_layers,
     )
 
     # Create architecture
 
     current_x = padding  # + input_label_size[0] + text_padding
 
-    id_to_node_list_map = dict()
-
-    for index, layer_list in enumerate(model_layers):
-        current_y = 0
-        nodes = []
-        layer: Any
-        for layer in layer_list:
-            is_box = True
-            units = 1
-
-            if show_neurons:
-                if hasattr(layer, "_saved_bias_sym_sizes_opt"):
-                    is_box = False
-                    units = layer._saved_bias_sym_sizes_opt[0]
-                elif hasattr(layer, "_saved_mat2_sym_sizes"):
-                    is_box = False
-                    units = layer._saved_mat2_sym_sizes[1]
-                elif hasattr(layer, "units"):  # for dummy input layer
-                    is_box = False
-                    units = layer.units
-
-            n = min(units, ellipsize_after)
-            layer_nodes = list()
-
-            for i in range(n):
-                scale = 1
-                c: Box | Circle | Ellipses
-                if not is_box:
-                    if i != ellipsize_after - 2:
-                        c = Circle()
-                    else:
-                        c = Ellipses()
-                else:
-                    c = Box()
-                    scale = 3
-
-                c.x1 = current_x
-                c.y1 = current_y
-                c.x2 = c.x1 + node_size
-                c.y2 = c.y1 + node_size * scale
-
-                current_y = c.y2 + node_spacing
-
-                c.fill = color_map.get(TARGET_OPS[layer.name()], {}).get(
-                    "fill", "#ADD8E6"
-                )
-                c.outline = color_map.get(TARGET_OPS[layer.name()], {}).get(
-                    "outline", "black"
-                )
-
-                layer_nodes.append(c)
-
-            id_to_node_list_map[str(id(layer))] = layer_nodes
-            nodes.extend(layer_nodes)
-            current_y += 2 * node_size
-
-        layer_y.append(current_y - node_spacing - 2 * node_size)
-        layers.append(nodes)
-        current_x += node_size + layer_spacing
+    layers, layer_y, id_to_node_list_map = _create_architecture(
+        model_layers,
+        current_x,
+        show_neurons,
+        ellipsize_after,
+        node_size,
+        node_spacing,
+        _color_map,
+        opacity,
+        layer_spacing,
+    )
 
     # Generate image
 
-    img_width = (
-        len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
-    )
+    img_width = len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
     img_height = max(*layer_y) + 2 * padding
     img = Image.new(
-        "RGBA", (int(ceil(img_width)), int(ceil(img_height))), background_fill
+        "RGBA",
+        (int(ceil(img_width)), int(ceil(img_height))),
+        background_fill,
     )
 
     draw = aggdraw.Draw(img)
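Two computations in this hunk can be checked in isolation: the color lookup, which wraps a user-supplied `color_map` in `defaultdict(dict, ...)` and later reads it with `.get(key, {})`, and the canvas width, `len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding`. In the sketch below, `resolve_colors` and `canvas_width` are hypothetical helper names, and the layer keys (`"Linear"`, `"ReLU"`) are placeholders, since the real keys come from `TARGET_OPS`, which this diff does not show.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any

DEFAULT_FILL = "#ADD8E6"  # fallback fill used in the diff
DEFAULT_OUTLINE = "black"  # fallback outline used in the diff


def resolve_colors(color_map: dict[Any, Any] | None, layer_key: Any) -> tuple[Any, Any]:
    """Look up (fill, outline) for one layer the way the new graph_view code does."""
    _color_map: dict = {}
    if color_map is not None:
        _color_map = defaultdict(dict, color_map)
    entry = _color_map.get(layer_key, {})
    return entry.get("fill", DEFAULT_FILL), entry.get("outline", DEFAULT_OUTLINE)


def canvas_width(num_layers: int, node_size: int = 50, layer_spacing: int = 250, padding: int = 10) -> int:
    """Same formula as img_width in graph_view; defaults copied from its signature."""
    return num_layers * node_size + (num_layers - 1) * layer_spacing + 2 * padding
```

Because `dict.get` never calls a `defaultdict`'s default factory, the wrapper behaves like a plain dict for these `.get` calls; it would only make a difference for `_color_map[key]` style access. With the default `node_size`, `layer_spacing`, and `padding`, three layer columns give a width of 3 * 50 + 2 * 250 + 2 * 10 = 670 pixels.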
@@ -164,7 +117,7 @@ def graph_view(
             node.y1 += y_off
             node.y2 += y_off
 
-    for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
+    for start_idx, end_idx in zip(*np.where(adj_matrix > 0), strict=False):
         start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
         end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))
 
@@ -172,10 +125,11 @@ def graph_view(
         end_layer_list = id_to_node_list_map[end_id]
 
         # draw connectors
-        for start_node_idx, start_node in enumerate(start_layer_list):
+        for start_node in start_layer_list:
             for end_node in end_layer_list:
                 if not isinstance(start_node, Ellipses) and not isinstance(
-                    end_node, Ellipses
+                    end_node,
+                    Ellipses,
                 ):
                     _draw_connector(
                         draw,
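The drawing loop walks every nonzero entry of `adj_matrix` and joins each node of the start layer to each node of the end layer, skipping ellipsis placeholders. The switch to `zip(..., strict=False)` keeps the old behaviour and is most likely there to satisfy Ruff's `B905` check (the `B` rules are selected in the new `pyproject.toml`); since `np.where` on a 2-D array returns two index arrays of equal length, the flag should not change anything here. The sketch replaces NumPy with a pure-Python scan (NumPy also reports these indices in row-major order) and replaces visualtorch's `Box`/`Circle`/`Ellipses` with a hypothetical `Node` dataclass that only stores a bounding box; the endpoint arithmetic follows `_draw_connector`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for visualtorch's Box/Circle/Ellipses: just a bounding box."""

    x1: float
    y1: float
    x2: float
    y2: float
    is_ellipses: bool = False


def edges_from_adjacency(adj_matrix: list[list[int]]) -> list[tuple[int, int]]:
    """Pure-Python equivalent of zip(*np.where(adj_matrix > 0)) for a 2-D matrix, in row-major order."""
    return [
        (start_idx, end_idx)
        for start_idx, row in enumerate(adj_matrix)
        for end_idx, value in enumerate(row)
        if value > 0
    ]


def connector_segment(start: Node, end: Node) -> tuple[float, float, float, float]:
    """Right-edge midpoint of `start` to left-edge midpoint of `end`, as in _draw_connector."""
    y1 = start.y1 + (start.y2 - start.y1) / 2
    y2 = end.y1 + (end.y2 - end.y1) / 2
    return (start.x2, y1, end.x1, y2)


def connector_segments(start_layer: list[Node], end_layer: list[Node]) -> list[tuple[float, float, float, float]]:
    """One segment per start/end pair, skipping any pair that involves an ellipsis placeholder."""
    return [
        connector_segment(start, end)
        for start in start_layer
        for end in end_layer
        if not start.is_ellipses and not end.is_ellipses
    ]
```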
@@ -185,8 +139,8 @@ def graph_view(
                         width=connector_width,
                     )
 
-    for i, layer in enumerate(layers):
-        for node_index, node in enumerate(layer):
+    for layer in layers:
+        for node in layer:
             node.draw(draw)
 
     draw.flush()
@@ -197,10 +151,96 @@ def graph_view(
     return img
 
 
-def _draw_connector(draw, start_node, end_node, color, width):
+def _draw_connector(
+    draw: aggdraw.Draw,
+    start_node: Box | Circle | Ellipses,
+    end_node: Box | Circle | Ellipses,
+    color: str | tuple[int, ...],
+    width: int,
+) -> None:
+    """Draw the line connector between nodes."""
     pen = aggdraw.Pen(color, width)
     x1 = start_node.x2
     y1 = start_node.y1 + (start_node.y2 - start_node.y1) / 2
     x2 = end_node.x1
     y2 = end_node.y1 + (end_node.y2 - end_node.y1) / 2
     draw.line([x1, y1, x2, y2], pen)
+
+
+def _retrieve_isbox_units(layer: torch.autograd.Function, show_neurons: bool) -> tuple[bool, int]:
+    """Return the number of units and the flag whether to visualize using a box or not."""
+    is_box = True
+    units = 1
+    if show_neurons:
+        if hasattr(layer, "_saved_bias_sym_sizes_opt"):
+            is_box = False
+            units = layer._saved_bias_sym_sizes_opt[0]  # noqa: SLF001
+        elif hasattr(layer, "_saved_mat2_sym_sizes"):
+            is_box = False
+            units = layer._saved_mat2_sym_sizes[1]  # noqa: SLF001
+        elif hasattr(layer, "units"):  # for dummy input layer
+            is_box = False
+            units = layer.units
+    return is_box, units
+
+
+def _create_architecture(
+    model_layers: list[list],
+    current_x: int,
+    show_neurons: bool,
+    ellipsize_after: int,
+    node_size: int,
+    node_spacing: int,
+    color_map: dict[Any, Any],
+    opacity: int,
+    layer_spacing: int,
+) -> tuple[list, list, dict]:
+    """Create nodes of architecture for each layers."""
+    id_to_node_list_map = {}
+    layers = []
+    layer_y = []
+    for layer_list in model_layers:
+        current_y = 0
+        nodes = []
+        layer: Any
+        for layer in layer_list:
+            is_box, units = _retrieve_isbox_units(layer, show_neurons)
+
+            n = min(units, ellipsize_after)
+            layer_nodes = []
+
+            for i in range(n):
+                scale = 1
+                c: Box | Circle | Ellipses
+                if not is_box:
+                    c = Circle() if i != ellipsize_after - 2 else Ellipses()
+                else:
+                    c = Box()
+                    scale = 3
+
+                c.x1 = current_x
+                c.y1 = current_y
+                c.x2 = c.x1 + node_size
+                c.y2 = c.y1 + node_size * scale
+
+                current_y = c.y2 + node_spacing
+
+                c.set_fill(
+                    color_map.get(TARGET_OPS[layer.name()], {}).get("fill", "#ADD8E6"),
+                    opacity,
+                )
+                c.outline = color_map.get(TARGET_OPS[layer.name()], {}).get(
+                    "outline",
+                    "black",
+                )
+
+                layer_nodes.append(c)
+
+            id_to_node_list_map[str(id(layer))] = layer_nodes
+            nodes.extend(layer_nodes)
+            current_y += 2 * node_size
+
+        layer_y.append(current_y - node_spacing - 2 * node_size)
+        layers.append(nodes)
+        current_x += node_size + layer_spacing
+    return layers, layer_y, id_to_node_list_map
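`_create_architecture` stacks each layer's nodes vertically. A layer drawn as a box yields one node, three `node_size` units tall, because `_retrieve_isbox_units` leaves `units` at 1 in that case; a neuron layer yields `min(units, ellipsize_after)` circles, with the node at index `ellipsize_after - 2` replaced by an ellipsis. The sketch below reproduces that per-layer loop with plain `(kind, y1, y2)` tuples instead of visualtorch's node classes; `column_layout` is a hypothetical name, and its defaults are copied from the `graph_view` signature.

```python
from __future__ import annotations


def column_layout(
    units: int,
    *,
    is_box: bool,
    ellipsize_after: int = 10,
    node_size: int = 50,
    node_spacing: int = 10,
    start_y: int = 0,
) -> list[tuple[str, int, int]]:
    """Return (kind, y1, y2) for each node one layer contributes, following _create_architecture."""
    nodes: list[tuple[str, int, int]] = []
    current_y = start_y
    n = min(units, ellipsize_after)
    for i in range(n):
        if is_box:
            kind, scale = "box", 3  # boxes are three node_size units tall
        else:
            # Same test as the diff: only index ellipsize_after - 2 becomes an ellipsis.
            kind, scale = ("circle" if i != ellipsize_after - 2 else "ellipses"), 1
        y1 = current_y
        y2 = y1 + node_size * scale
        nodes.append((kind, y1, y2))
        current_y = y2 + node_spacing  # next node starts after the spacing gap
    return nodes
```

Read literally, the loop also draws an ellipsis for a layer with exactly `ellipsize_after - 1` or `ellipsize_after` units (9 or 10 with the defaults), even though no neuron is hidden in that case, while a layer with 20 units becomes 8 circles, an ellipsis, and one final circle.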