visualtorch 0.2.3__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: visualtorch
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
@@ -21,7 +21,7 @@ def _read_requirements(file: str) -> list:
21
21
 
22
22
  setuptools.setup(
23
23
  name="visualtorch",
24
- version="0.2.3",
24
+ version="0.2.4",
25
25
  author="Willy Fitra Hendria",
26
26
  author_email="willyfitrahendria@gmail.com",
27
27
  description="Architecture visualization of Torch models",
@@ -185,8 +185,13 @@ def register_hook(
185
185
  m_key = "%s-%i" % (class_name, module_idx + 1)
186
186
  layers[m_key] = OrderedDict()
187
187
  layers[m_key]["module"] = module
188
- if isinstance(out, list | tuple):
189
- layers[m_key]["output_shape"] = tuple((-1,) + o.size()[1:] for o in out)
188
+ if isinstance(out, tuple):
189
+ if hasattr(out[0], "size"):
190
+ layers[m_key]["output_shape"] = out[0].size()
191
+ else:
192
+ layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
193
+ elif isinstance(out, list):
194
+ layers[m_key]["output_shape"] = tuple(o.size() for o in out if hasattr(o, "size"))
190
195
  else:
191
196
  layers[m_key]["output_shape"] = out.size()
192
197
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: visualtorch
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: Architecture visualization of Torch models
5
5
  Home-page: https://github.com/willyfh/visualtorch
6
6
  Author: Willy Fitra Hendria
File without changes
File without changes
File without changes
File without changes