lucid-dl 2.11.3__tar.gz → 2.11.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lucid_dl-2.11.3/lucid_dl.egg-info → lucid_dl-2.11.5}/PKG-INFO +3 -13
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/README.md +2 -12
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/module.py +55 -21
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/types.py +58 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/visual/__init__.py +0 -1
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/visual/mermaid.py +188 -2
- {lucid_dl-2.11.3 → lucid_dl-2.11.5/lucid_dl.egg-info}/PKG-INFO +3 -13
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/SOURCES.txt +0 -1
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/setup.py +1 -1
- lucid_dl-2.11.3/lucid/visual/graph.py +0 -141
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/LICENSE +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/core.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/metal.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/bfunc.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/gfunc.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/ufunc.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/func.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/tensor.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_util/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_util/func.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/autograd/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/_base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/_util.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/_base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/cifar.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/mnist.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/einops/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/einops/_func.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/error.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/linalg/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/linalg/_func.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/alex.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/coatnet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/convnext.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/crossvit.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/cspnet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/cvt.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/dense.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/efficient.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/efficientformer.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception_next.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception_res.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/lenet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/maxvit.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/mobile.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/pvt.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnest.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnext.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/senet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/sknet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/swin.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/vgg.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/vit.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/xception.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/zfnet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/ddpm.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/ncsn.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/vae.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/detr.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/efficientdet.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/fast_rcnn.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/faster_rcnn.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/rcnn.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/util.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/seq2seq/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/seq2seq/transformer.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/util.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/activation.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/attention.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/conv.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/embedding.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/loss.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/norm.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/pool.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_activation.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_attention.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_conv.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_drop.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_linear.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_loss.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_norm.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_pool.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_spatial.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_util.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/fused.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/init/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/init/_dist.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/activation.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/attention.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/conv.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/drop.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/einops.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/linear.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/loss.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/norm.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/pool.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/rnn.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/sparse.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/transformer.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/vision.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/parameter.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/util.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/_base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/ada.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/adam.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/_base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/prop.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/sgd.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/port.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/random/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/random/_func.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/_base.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/image.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/weights/__init__.py +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/weights/__init__.pyi +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/dependency_links.txt +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/requires.txt +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/top_level.txt +0 -0
- {lucid_dl-2.11.3 → lucid_dl-2.11.5}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: lucid-dl
|
|
3
|
-
Version: 2.11.3
|
|
3
|
+
Version: 2.11.5
|
|
4
4
|
Summary: Lumerico's Comprehensive Interface for Deep Learning
|
|
5
5
|
Home-page: https://github.com/ChanLumerico/lucid
|
|
6
6
|
Author: ChanLumerico
|
|
@@ -48,19 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
48
48
|
|
|
49
49
|
### 🔥 What's New
|
|
50
50
|
|
|
51
|
-
- Added
|
|
52
|
-
|
|
53
|
-
```python
|
|
54
|
-
def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
55
|
-
|
|
56
|
-
def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
51
|
+
- Added new visual tool: `lucid.visual.build_tensor_mermaid_chart`, which builds a Mermaid chart of a given tensor's computational graph
|
|
57
52
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def register_full_backward_pre_hook(self, hook: Callable)
|
|
61
|
-
|
|
62
|
-
def register_full_backward_hook(self, hook: Callable)
|
|
63
|
-
```
|
|
53
|
+
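- Added additional `nn.Module` hooks for richer introspection during training: `register_state_dict_pre_hook`, `register_state_dict_hook`, `register_load_state_dict_pre_hook`, and `register_load_state_dict_post_hook`.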
|
|
64
54
|
|
|
65
55
|
## 🔧 How to Install
|
|
66
56
|
|
|
@@ -20,19 +20,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
20
20
|
|
|
21
21
|
### 🔥 What's New
|
|
22
22
|
|
|
23
|
-
- Added
|
|
24
|
-
|
|
25
|
-
```python
|
|
26
|
-
def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
27
|
-
|
|
28
|
-
def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
23
|
+
- Added new visual tool: `lucid.visual.build_tensor_mermaid_chart`, which builds a Mermaid chart of a given tensor's computational graph
|
|
29
24
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def register_full_backward_pre_hook(self, hook: Callable)
|
|
33
|
-
|
|
34
|
-
def register_full_backward_hook(self, hook: Callable)
|
|
35
|
-
```
|
|
25
|
+
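- Added additional `nn.Module` hooks for richer introspection during training: `register_state_dict_pre_hook`, `register_state_dict_hook`, `register_load_state_dict_pre_hook`, and `register_load_state_dict_post_hook`.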
|
|
36
26
|
|
|
37
27
|
## 🔧 How to Install
|
|
38
28
|
|
|
@@ -13,7 +13,22 @@ from typing import (
|
|
|
13
13
|
from collections import OrderedDict
|
|
14
14
|
|
|
15
15
|
from lucid._tensor import Tensor
|
|
16
|
-
from lucid.types import
|
|
16
|
+
from lucid.types import (
|
|
17
|
+
_ArrayOrScalar,
|
|
18
|
+
_BackwardHook,
|
|
19
|
+
_DeviceType,
|
|
20
|
+
_ForwardHook,
|
|
21
|
+
_ForwardHookKwargs,
|
|
22
|
+
_ForwardPreHook,
|
|
23
|
+
_ForwardPreHookKwargs,
|
|
24
|
+
_FullBackwardHook,
|
|
25
|
+
_FullBackwardPreHook,
|
|
26
|
+
_LoadStateDictPostHook,
|
|
27
|
+
_LoadStateDictPreHook,
|
|
28
|
+
_NumPyArray,
|
|
29
|
+
_StateDictHook,
|
|
30
|
+
_StateDictPreHook,
|
|
31
|
+
)
|
|
17
32
|
|
|
18
33
|
import lucid.nn as nn
|
|
19
34
|
|
|
@@ -30,26 +45,6 @@ __all__ = [
|
|
|
30
45
|
]
|
|
31
46
|
|
|
32
47
|
|
|
33
|
-
_ForwardPreHook = Callable[["Module", tuple[Any, ...]], tuple[Any, ...] | None]
|
|
34
|
-
_ForwardPreHookKwargs = Callable[
|
|
35
|
-
["Module", tuple[Any, ...], dict[str, Any]],
|
|
36
|
-
tuple[tuple[Any, ...], dict[str, Any]] | None,
|
|
37
|
-
]
|
|
38
|
-
_ForwardHook = Callable[["Module", tuple[Any, ...], Any], Any | None]
|
|
39
|
-
_ForwardHookKwargs = Callable[
|
|
40
|
-
["Module", tuple[Any, ...], dict[str, Any], Any], Any | None
|
|
41
|
-
]
|
|
42
|
-
|
|
43
|
-
_BackwardHook = Callable[[Tensor, _NumPyArray], None]
|
|
44
|
-
_FullBackwardPreHook = Callable[
|
|
45
|
-
["Module", tuple[_NumPyArray | None, ...]], tuple[_NumPyArray | None, ...] | None
|
|
46
|
-
]
|
|
47
|
-
_FullBackwardHook = Callable[
|
|
48
|
-
["Module", tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
|
|
49
|
-
tuple[_NumPyArray | None, ...] | None,
|
|
50
|
-
]
|
|
51
|
-
|
|
52
|
-
|
|
53
48
|
class Module:
|
|
54
49
|
_registry_map: dict[Type, OrderedDict[str, Any]] = {}
|
|
55
50
|
_alt_name: str = ""
|
|
@@ -70,10 +65,17 @@ class Module:
|
|
|
70
65
|
tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
|
|
71
66
|
] = []
|
|
72
67
|
self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
|
|
68
|
+
|
|
73
69
|
self._backward_hooks: list[_BackwardHook] = []
|
|
74
70
|
self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
|
|
75
71
|
self._full_backward_hooks: list[_FullBackwardHook] = []
|
|
76
72
|
|
|
73
|
+
self._state_dict_pre_hooks: list[_StateDictPreHook] = []
|
|
74
|
+
self._state_dict_hooks: list[_StateDictHook] = []
|
|
75
|
+
|
|
76
|
+
self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
|
|
77
|
+
self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []
|
|
78
|
+
|
|
77
79
|
self._state_dict_pass_attr = set()
|
|
78
80
|
|
|
79
81
|
def __setattr__(self, name: str, value: Any) -> None:
|
|
@@ -155,6 +157,26 @@ class Module:
|
|
|
155
157
|
self._full_backward_hooks.append(hook)
|
|
156
158
|
return lambda: self._full_backward_hooks.remove(hook)
|
|
157
159
|
|
|
160
|
+
def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
|
|
161
|
+
self._state_dict_pre_hooks.append(hook)
|
|
162
|
+
return lambda: self._state_dict_pre_hooks.remove(hook)
|
|
163
|
+
|
|
164
|
+
def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
|
|
165
|
+
self._state_dict_hooks.append(hook)
|
|
166
|
+
return lambda: self._state_dict_hooks.remove(hook)
|
|
167
|
+
|
|
168
|
+
def register_load_state_dict_pre_hook(
|
|
169
|
+
self, hook: _LoadStateDictPreHook
|
|
170
|
+
) -> Callable:
|
|
171
|
+
self._load_state_dict_pre_hooks.append(hook)
|
|
172
|
+
return lambda: self._load_state_dict_pre_hooks.remove(hook)
|
|
173
|
+
|
|
174
|
+
def register_load_state_dict_post_hook(
|
|
175
|
+
self, hook: _LoadStateDictPostHook
|
|
176
|
+
) -> Callable:
|
|
177
|
+
self._load_state_dict_post_hooks.append(hook)
|
|
178
|
+
return lambda: self._load_state_dict_post_hooks.remove(hook)
|
|
179
|
+
|
|
158
180
|
def reset_parameters(self) -> None:
|
|
159
181
|
for param in self.parameters():
|
|
160
182
|
param.zero()
|
|
@@ -231,6 +253,9 @@ class Module:
|
|
|
231
253
|
prefix: str = "",
|
|
232
254
|
keep_vars: bool = False,
|
|
233
255
|
) -> OrderedDict:
|
|
256
|
+
for hook in self._state_dict_pre_hooks:
|
|
257
|
+
hook(self, prefix, keep_vars)
|
|
258
|
+
|
|
234
259
|
if destination is None:
|
|
235
260
|
destination = OrderedDict()
|
|
236
261
|
|
|
@@ -249,9 +274,15 @@ class Module:
|
|
|
249
274
|
if key in self._state_dict_pass_attr:
|
|
250
275
|
del destination[key]
|
|
251
276
|
|
|
277
|
+
for hook in self._state_dict_hooks:
|
|
278
|
+
hook(self, destination, prefix, keep_vars)
|
|
279
|
+
|
|
252
280
|
return destination
|
|
253
281
|
|
|
254
282
|
def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
|
|
283
|
+
for hook in self._load_state_dict_pre_hooks:
|
|
284
|
+
hook(self, state_dict, strict)
|
|
285
|
+
|
|
255
286
|
own_state = self.state_dict(keep_vars=True)
|
|
256
287
|
|
|
257
288
|
missing_keys = set(own_state.keys()) - set(state_dict.keys())
|
|
@@ -277,6 +308,9 @@ class Module:
|
|
|
277
308
|
elif strict:
|
|
278
309
|
raise KeyError(f"Unexpected key '{key}' in state_dict.")
|
|
279
310
|
|
|
311
|
+
for hook in self._load_state_dict_post_hooks:
|
|
312
|
+
hook(self, missing_keys, unexpected_keys, strict)
|
|
313
|
+
|
|
280
314
|
def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
|
|
281
315
|
for hook, with_kwargs in self._forward_pre_hooks:
|
|
282
316
|
if with_kwargs:
|
|
@@ -6,8 +6,10 @@ from typing import (
|
|
|
6
6
|
Sequence,
|
|
7
7
|
Literal,
|
|
8
8
|
TypeAlias,
|
|
9
|
+
TYPE_CHECKING,
|
|
9
10
|
runtime_checkable,
|
|
10
11
|
)
|
|
12
|
+
from collections import OrderedDict
|
|
11
13
|
import re
|
|
12
14
|
|
|
13
15
|
import numpy as np
|
|
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
|
|
|
76
78
|
) -> None: ...
|
|
77
79
|
|
|
78
80
|
|
|
81
|
+
@runtime_checkable
|
|
82
|
+
class _ModuleHookable(Protocol):
|
|
83
|
+
def register_forward_pre_hook(
|
|
84
|
+
self, hook: Callable, *, with_kwargs: bool = False
|
|
85
|
+
) -> Callable: ...
|
|
86
|
+
|
|
87
|
+
def register_forward_hook(
|
|
88
|
+
self, hook: Callable, *, with_kwargs: bool = False
|
|
89
|
+
) -> Callable: ...
|
|
90
|
+
|
|
91
|
+
def register_backward_hook(self, hook: Callable) -> Callable: ...
|
|
92
|
+
|
|
93
|
+
def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
|
|
94
|
+
|
|
95
|
+
def register_full_backward_hook(self, hook: Callable) -> Callable: ...
|
|
96
|
+
|
|
97
|
+
def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
|
|
98
|
+
|
|
99
|
+
def register_state_dict_hook(self, hook: Callable) -> Callable: ...
|
|
100
|
+
|
|
101
|
+
def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
|
|
102
|
+
|
|
103
|
+
def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
_ForwardPreHook: TypeAlias = Callable[
|
|
107
|
+
[_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
|
|
108
|
+
]
|
|
109
|
+
_ForwardPreHookKwargs: TypeAlias = Callable[
|
|
110
|
+
[_ModuleHookable, tuple[Any, ...], dict[str, Any]],
|
|
111
|
+
tuple[tuple[Any, ...], dict[str, Any]] | None,
|
|
112
|
+
]
|
|
113
|
+
_ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
|
|
114
|
+
_ForwardHookKwargs: TypeAlias = Callable[
|
|
115
|
+
[_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
|
|
116
|
+
]
|
|
117
|
+
|
|
118
|
+
_BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
|
|
119
|
+
_FullBackwardPreHook: TypeAlias = Callable[
|
|
120
|
+
[_ModuleHookable, tuple[_NumPyArray | None, ...]],
|
|
121
|
+
tuple[_NumPyArray | None, ...] | None,
|
|
122
|
+
]
|
|
123
|
+
_FullBackwardHook: TypeAlias = Callable[
|
|
124
|
+
[_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
|
|
125
|
+
tuple[_NumPyArray | None, ...] | None,
|
|
126
|
+
]
|
|
127
|
+
|
|
128
|
+
_StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
|
|
129
|
+
_StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
|
|
130
|
+
|
|
131
|
+
_LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
|
|
132
|
+
_LoadStateDictPostHook: TypeAlias = Callable[
|
|
133
|
+
[_ModuleHookable, set[str], set[str], bool], None
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
|
|
79
137
|
class Numeric:
|
|
80
138
|
def __init__(
|
|
81
139
|
self, base_dtype: type[int | float | complex], bits: int | None
|
|
@@ -9,7 +9,7 @@ from lucid._tensor import Tensor
|
|
|
9
9
|
from lucid.types import _ShapeLike
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
__all__ = ["
|
|
12
|
+
__all__ = ["build_tensor_mermaid_chart", "build_module_mermaid_chart"]
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
_NN_MODULES_PREFIX = "lucid.nn.modules."
|
|
@@ -255,7 +255,7 @@ def _collapse_repeated_children(
|
|
|
255
255
|
return out
|
|
256
256
|
|
|
257
257
|
|
|
258
|
-
def
|
|
258
|
+
def build_module_mermaid_chart(
|
|
259
259
|
module: nn.Module,
|
|
260
260
|
input_shape: _ShapeLike | list[_ShapeLike] | None = None,
|
|
261
261
|
inputs: Iterable[Tensor] | Tensor | None = None,
|
|
@@ -751,6 +751,192 @@ def build_mermaid_chart(
|
|
|
751
751
|
return text
|
|
752
752
|
|
|
753
753
|
|
|
754
|
+
def build_mermaid_chart(
|
|
755
|
+
module: nn.Module,
|
|
756
|
+
input_shape: _ShapeLike | list[_ShapeLike] | None = None,
|
|
757
|
+
inputs: Iterable[Tensor] | Tensor | None = None,
|
|
758
|
+
depth: int = 2,
|
|
759
|
+
direction: str = "LR",
|
|
760
|
+
include_io: bool = True,
|
|
761
|
+
show_params: bool = False,
|
|
762
|
+
return_lines: bool = False,
|
|
763
|
+
copy_to_clipboard: bool = False,
|
|
764
|
+
compact: bool = False,
|
|
765
|
+
use_class_defs: bool = False,
|
|
766
|
+
end_semicolons: bool = True,
|
|
767
|
+
edge_mode: Literal["dataflow", "execution"] = "execution",
|
|
768
|
+
collapse_repeats: bool = True,
|
|
769
|
+
repeat_min: int = 2,
|
|
770
|
+
color_by_subpackage: bool = True,
|
|
771
|
+
container_name_from_attr: bool = True,
|
|
772
|
+
edge_stroke_width: float = 2.0,
|
|
773
|
+
emphasize_model_title: bool = True,
|
|
774
|
+
model_title_font_px: int = 20,
|
|
775
|
+
show_shapes: bool = False,
|
|
776
|
+
hide_subpackages: Iterable[str] = (),
|
|
777
|
+
hide_module_names: Iterable[str] = (),
|
|
778
|
+
dash_multi_input_edges: bool = True,
|
|
779
|
+
subgraph_fill: str = "#000000",
|
|
780
|
+
subgraph_fill_opacity: float = 0.05,
|
|
781
|
+
subgraph_stroke: str = "#000000",
|
|
782
|
+
subgraph_stroke_opacity: float = 0.75,
|
|
783
|
+
force_text_color: str | None = None,
|
|
784
|
+
edge_curve: str = "natural",
|
|
785
|
+
node_spacing: int = 50,
|
|
786
|
+
rank_spacing: int = 50,
|
|
787
|
+
**forward_kwargs,
|
|
788
|
+
) -> str | list[str]:
|
|
789
|
+
return build_module_mermaid_chart(
|
|
790
|
+
module,
|
|
791
|
+
input_shape=input_shape,
|
|
792
|
+
inputs=inputs,
|
|
793
|
+
depth=depth,
|
|
794
|
+
direction=direction,
|
|
795
|
+
include_io=include_io,
|
|
796
|
+
show_params=show_params,
|
|
797
|
+
return_lines=return_lines,
|
|
798
|
+
copy_to_clipboard=copy_to_clipboard,
|
|
799
|
+
compact=compact,
|
|
800
|
+
use_class_defs=use_class_defs,
|
|
801
|
+
end_semicolons=end_semicolons,
|
|
802
|
+
edge_mode=edge_mode,
|
|
803
|
+
collapse_repeats=collapse_repeats,
|
|
804
|
+
repeat_min=repeat_min,
|
|
805
|
+
color_by_subpackage=color_by_subpackage,
|
|
806
|
+
container_name_from_attr=container_name_from_attr,
|
|
807
|
+
edge_stroke_width=edge_stroke_width,
|
|
808
|
+
emphasize_model_title=emphasize_model_title,
|
|
809
|
+
model_title_font_px=model_title_font_px,
|
|
810
|
+
show_shapes=show_shapes,
|
|
811
|
+
hide_subpackages=hide_subpackages,
|
|
812
|
+
hide_module_names=hide_module_names,
|
|
813
|
+
dash_multi_input_edges=dash_multi_input_edges,
|
|
814
|
+
subgraph_fill=subgraph_fill,
|
|
815
|
+
subgraph_fill_opacity=subgraph_fill_opacity,
|
|
816
|
+
subgraph_stroke=subgraph_stroke,
|
|
817
|
+
subgraph_stroke_opacity=subgraph_stroke_opacity,
|
|
818
|
+
force_text_color=force_text_color,
|
|
819
|
+
edge_curve=edge_curve,
|
|
820
|
+
node_spacing=node_spacing,
|
|
821
|
+
rank_spacing=rank_spacing,
|
|
822
|
+
**forward_kwargs,
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
def build_tensor_mermaid_chart(
|
|
827
|
+
tensor: Tensor,
|
|
828
|
+
horizontal: bool = False,
|
|
829
|
+
title: str | None = None,
|
|
830
|
+
start_id: int | None = None,
|
|
831
|
+
end_semicolons: bool = True,
|
|
832
|
+
copy_to_clipboard: bool = False,
|
|
833
|
+
use_class_defs: bool = True,
|
|
834
|
+
op_fill: str = "lightgreen",
|
|
835
|
+
param_fill: str = "plum",
|
|
836
|
+
result_fill: str = "lightcoral",
|
|
837
|
+
leaf_fill: str = "lightgray",
|
|
838
|
+
grad_fill: str = "lightblue",
|
|
839
|
+
start_fill: str = "gold",
|
|
840
|
+
stroke_color: str = "#666",
|
|
841
|
+
stroke_width_px: int = 1,
|
|
842
|
+
) -> str:
|
|
843
|
+
direction = "LR" if horizontal else "TD"
|
|
844
|
+
lines: list[str] = [f"flowchart {direction}"]
|
|
845
|
+
if title:
|
|
846
|
+
lines.append(f"%% {title}")
|
|
847
|
+
|
|
848
|
+
result_id: int = id(tensor)
|
|
849
|
+
visited: set[int] = set()
|
|
850
|
+
nodes_to_draw: list[Tensor] = []
|
|
851
|
+
|
|
852
|
+
def dfs(t: Tensor) -> None:
|
|
853
|
+
if id(t) in visited:
|
|
854
|
+
return
|
|
855
|
+
visited.add(id(t))
|
|
856
|
+
for p in t._prev:
|
|
857
|
+
dfs(p)
|
|
858
|
+
nodes_to_draw.append(t)
|
|
859
|
+
|
|
860
|
+
def tensor_node_id(t: Tensor) -> str:
|
|
861
|
+
return f"t_{id(t)}"
|
|
862
|
+
|
|
863
|
+
def op_node_id(op: object) -> str:
|
|
864
|
+
return f"op_{id(op)}"
|
|
865
|
+
|
|
866
|
+
def add_node(node_id: str, label: str, kind: str) -> None:
|
|
867
|
+
if node_id in defined_nodes:
|
|
868
|
+
return
|
|
869
|
+
defined_nodes.add(node_id)
|
|
870
|
+
if kind == "op":
|
|
871
|
+
lines.append(f'{node_id}(("{label}"))')
|
|
872
|
+
else:
|
|
873
|
+
lines.append(f'{node_id}["{label}"]')
|
|
874
|
+
|
|
875
|
+
dfs(tensor)
|
|
876
|
+
|
|
877
|
+
defined_nodes: set[str] = set()
|
|
878
|
+
edge_lines: list[str] = []
|
|
879
|
+
class_lines: list[str] = []
|
|
880
|
+
|
|
881
|
+
for t in nodes_to_draw:
|
|
882
|
+
t_id = tensor_node_id(t)
|
|
883
|
+
|
|
884
|
+
if not t.is_leaf and t._op is not None:
|
|
885
|
+
op_id = op_node_id(t._op)
|
|
886
|
+
op_label = type(t._op).__name__
|
|
887
|
+
add_node(op_id, op_label, "op")
|
|
888
|
+
edge_lines.append(f"{op_id} --> {t_id}")
|
|
889
|
+
class_lines.append(f"class {op_id} op")
|
|
890
|
+
for inp in t._prev:
|
|
891
|
+
edge_lines.append(f"{tensor_node_id(inp)} --> {op_id}")
|
|
892
|
+
|
|
893
|
+
shape_label = str(t.shape) if t.ndim > 0 else str(t.item())
|
|
894
|
+
add_node(t_id, shape_label, "tensor")
|
|
895
|
+
|
|
896
|
+
if start_id is not None and id(t) == start_id:
|
|
897
|
+
class_lines.append(f"class {t_id} start")
|
|
898
|
+
elif isinstance(t, nn.Parameter):
|
|
899
|
+
class_lines.append(f"class {t_id} param")
|
|
900
|
+
elif id(t) == result_id:
|
|
901
|
+
class_lines.append(f"class {t_id} result")
|
|
902
|
+
elif not t.requires_grad:
|
|
903
|
+
class_lines.append(f"class {t_id} leaf")
|
|
904
|
+
else:
|
|
905
|
+
class_lines.append(f"class {t_id} grad")
|
|
906
|
+
|
|
907
|
+
lines.extend(edge_lines)
|
|
908
|
+
if use_class_defs:
|
|
909
|
+
lines.append(
|
|
910
|
+
f"classDef op fill:{op_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
911
|
+
)
|
|
912
|
+
lines.append(
|
|
913
|
+
f"classDef param fill:{param_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
914
|
+
)
|
|
915
|
+
lines.append(
|
|
916
|
+
f"classDef result fill:{result_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
917
|
+
)
|
|
918
|
+
lines.append(
|
|
919
|
+
f"classDef leaf fill:{leaf_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
920
|
+
)
|
|
921
|
+
lines.append(
|
|
922
|
+
f"classDef grad fill:{grad_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
923
|
+
)
|
|
924
|
+
lines.append(
|
|
925
|
+
f"classDef start fill:{start_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
|
|
926
|
+
)
|
|
927
|
+
lines.extend(class_lines)
|
|
928
|
+
|
|
929
|
+
if end_semicolons:
|
|
930
|
+
lines = [
|
|
931
|
+
f"{line};" if line and not line.endswith(";") else line for line in lines
|
|
932
|
+
]
|
|
933
|
+
|
|
934
|
+
text = "\n".join(lines)
|
|
935
|
+
if copy_to_clipboard:
|
|
936
|
+
_copy_to_clipboard(text)
|
|
937
|
+
return text
|
|
938
|
+
|
|
939
|
+
|
|
754
940
|
def _copy_to_clipboard(text: str) -> None:
|
|
755
941
|
import os
|
|
756
942
|
import shutil
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: lucid-dl
|
|
3
|
-
Version: 2.11.3
|
|
3
|
+
Version: 2.11.5
|
|
4
4
|
Summary: Lumerico's Comprehensive Interface for Deep Learning
|
|
5
5
|
Home-page: https://github.com/ChanLumerico/lucid
|
|
6
6
|
Author: ChanLumerico
|
|
@@ -48,19 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
48
48
|
|
|
49
49
|
### 🔥 What's New
|
|
50
50
|
|
|
51
|
-
- Added
|
|
52
|
-
|
|
53
|
-
```python
|
|
54
|
-
def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
55
|
-
|
|
56
|
-
def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
51
|
+
- Added new visual tool: `lucid.visual.build_tensor_mermaid_chart`, which builds a Mermaid chart of a given tensor's computational graph
|
|
57
52
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def register_full_backward_pre_hook(self, hook: Callable)
|
|
61
|
-
|
|
62
|
-
def register_full_backward_hook(self, hook: Callable)
|
|
63
|
-
```
|
|
53
|
+
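- Added additional `nn.Module` hooks for richer introspection during training: `register_state_dict_pre_hook`, `register_state_dict_hook`, `register_load_state_dict_pre_hook`, and `register_load_state_dict_post_hook`.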
|
|
64
54
|
|
|
65
55
|
## 🔧 How to Install
|
|
66
56
|
|
|
@@ -1,141 +0,0 @@
|
|
|
1
|
-
from typing import Union
|
|
2
|
-
from warnings import deprecated
|
|
3
|
-
|
|
4
|
-
import networkx as nx
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
|
|
7
|
-
import lucid.nn as nn
|
|
8
|
-
from lucid._tensor import Tensor
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
__all__ = ["draw_tensor_graph"]
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@deprecated("This feature will be re-written with Mermaid in future relases.")
|
|
15
|
-
def draw_tensor_graph(
|
|
16
|
-
tensor: Tensor,
|
|
17
|
-
horizontal: bool = False,
|
|
18
|
-
title: Union[str, None] = None,
|
|
19
|
-
start_id: Union[int, None] = None,
|
|
20
|
-
) -> plt.Figure:
|
|
21
|
-
G: nx.DiGraph = nx.DiGraph()
|
|
22
|
-
result_id: int = id(tensor)
|
|
23
|
-
|
|
24
|
-
visited: set[int] = set()
|
|
25
|
-
nodes_to_draw: list[Tensor] = []
|
|
26
|
-
|
|
27
|
-
def dfs(t: Tensor) -> None:
|
|
28
|
-
if id(t) in visited:
|
|
29
|
-
return
|
|
30
|
-
visited.add(id(t))
|
|
31
|
-
for p in t._prev:
|
|
32
|
-
dfs(p)
|
|
33
|
-
nodes_to_draw.append(t)
|
|
34
|
-
|
|
35
|
-
dfs(tensor)
|
|
36
|
-
|
|
37
|
-
for t in nodes_to_draw:
|
|
38
|
-
if not t.is_leaf and t._op is not None:
|
|
39
|
-
op_id: int = id(t._op)
|
|
40
|
-
op_label: str = type(t._op).__name__
|
|
41
|
-
G.add_node(op_id, label=op_label, shape="circle", color="lightgreen")
|
|
42
|
-
G.add_edge(op_id, id(t))
|
|
43
|
-
for inp in t._prev:
|
|
44
|
-
G.add_edge(id(inp), op_id)
|
|
45
|
-
|
|
46
|
-
shape_label: str = str(t.shape) if t.ndim > 0 else str(t.item())
|
|
47
|
-
if isinstance(t, nn.Parameter):
|
|
48
|
-
color: str = "plum"
|
|
49
|
-
else:
|
|
50
|
-
color = (
|
|
51
|
-
"lightcoral"
|
|
52
|
-
if id(t) == result_id
|
|
53
|
-
else "lightgray" if not t.requires_grad else "lightblue"
|
|
54
|
-
)
|
|
55
|
-
if start_id is not None and id(t) == start_id:
|
|
56
|
-
color = "gold"
|
|
57
|
-
|
|
58
|
-
G.add_node(id(t), label=shape_label, shape="rectangle", color=color)
|
|
59
|
-
|
|
60
|
-
def grid_layout(
|
|
61
|
-
G: nx.DiGraph, horizontal: bool = False
|
|
62
|
-
) -> tuple[dict, tuple, float, int]:
|
|
63
|
-
levels: dict[int, int] = {}
|
|
64
|
-
for node in nx.topological_sort(G):
|
|
65
|
-
preds = list(G.predecessors(node))
|
|
66
|
-
levels[node] = 0 if not preds else max(levels[p] for p in preds) + 1
|
|
67
|
-
|
|
68
|
-
level_nodes: dict[int, list[int]] = {}
|
|
69
|
-
for node, level in levels.items():
|
|
70
|
-
level_nodes.setdefault(level, []).append(node)
|
|
71
|
-
|
|
72
|
-
def autoscale(
|
|
73
|
-
level_nodes: dict[int, list[int]],
|
|
74
|
-
horizontal: bool = False,
|
|
75
|
-
base_size: float = 0.5,
|
|
76
|
-
base_nodesize: int = 500,
|
|
77
|
-
) -> tuple[tuple[float, float], float, int]:
|
|
78
|
-
num_levels: int = len(level_nodes)
|
|
79
|
-
max_width: int = max(len(nodes) for nodes in level_nodes.values())
|
|
80
|
-
node_count: int = sum(len(nodes) for nodes in level_nodes.values())
|
|
81
|
-
|
|
82
|
-
if horizontal:
|
|
83
|
-
fig_w: float = min(32, max(4.0, base_size * num_levels))
|
|
84
|
-
fig_h: float = min(32, max(4.0, base_size * max_width))
|
|
85
|
-
else:
|
|
86
|
-
fig_w = min(32, max(4.0, base_size * max_width))
|
|
87
|
-
fig_h = min(32, max(4.0, base_size * num_levels))
|
|
88
|
-
|
|
89
|
-
nodesize: float = (
|
|
90
|
-
base_nodesize
|
|
91
|
-
if node_count <= 100
|
|
92
|
-
else base_nodesize * (100 / node_count)
|
|
93
|
-
)
|
|
94
|
-
fontsize: int = max(5, min(8, int(80 / node_count)))
|
|
95
|
-
return (fig_w, fig_h), nodesize, fontsize
|
|
96
|
-
|
|
97
|
-
figsize, nodesize, fontsize = autoscale(level_nodes, horizontal)
|
|
98
|
-
pos: dict[int, tuple[float, float]] = {}
|
|
99
|
-
for level, nodes in level_nodes.items():
|
|
100
|
-
for i, node in enumerate(nodes):
|
|
101
|
-
pos[node] = (
|
|
102
|
-
(level * 2.5, -i * 2.0) if horizontal else (i * 2.5, -level * 2.0)
|
|
103
|
-
)
|
|
104
|
-
return pos, figsize, nodesize, fontsize
|
|
105
|
-
|
|
106
|
-
labels: dict[int, str] = nx.get_node_attributes(G, "label")
|
|
107
|
-
colors: dict[int, str] = nx.get_node_attributes(G, "color")
|
|
108
|
-
shapes: dict[int, str] = nx.get_node_attributes(G, "shape")
|
|
109
|
-
pos, figsize, nodesize, fontsize = grid_layout(G, horizontal)
|
|
110
|
-
|
|
111
|
-
fig, ax = plt.subplots(figsize=figsize)
|
|
112
|
-
|
|
113
|
-
rect_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "rectangle"]
|
|
114
|
-
circ_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "circle"]
|
|
115
|
-
rect_colors: list[str] = [colors[n] for n in rect_nodes]
|
|
116
|
-
|
|
117
|
-
nx.draw_networkx_nodes(
|
|
118
|
-
G,
|
|
119
|
-
pos,
|
|
120
|
-
nodelist=rect_nodes,
|
|
121
|
-
node_color=rect_colors,
|
|
122
|
-
node_size=nodesize,
|
|
123
|
-
node_shape="s",
|
|
124
|
-
ax=ax,
|
|
125
|
-
)
|
|
126
|
-
nx.draw_networkx_nodes(
|
|
127
|
-
G,
|
|
128
|
-
pos,
|
|
129
|
-
nodelist=circ_nodes,
|
|
130
|
-
node_color="lightgreen",
|
|
131
|
-
node_size=nodesize,
|
|
132
|
-
node_shape="o",
|
|
133
|
-
ax=ax,
|
|
134
|
-
)
|
|
135
|
-
nx.draw_networkx_edges(G, pos, width=0.5, arrows=True, edge_color="gray", ax=ax)
|
|
136
|
-
nx.draw_networkx_labels(G, pos, labels=labels, font_size=fontsize, ax=ax)
|
|
137
|
-
|
|
138
|
-
ax.axis("off")
|
|
139
|
-
ax.set_title(title if title is not None else "")
|
|
140
|
-
|
|
141
|
-
return fig
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|