lucid-dl 2.11.3__tar.gz → 2.11.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147) hide show
  1. {lucid_dl-2.11.3/lucid_dl.egg-info → lucid_dl-2.11.5}/PKG-INFO +3 -13
  2. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/README.md +2 -12
  3. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/module.py +55 -21
  4. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/types.py +58 -0
  5. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/visual/__init__.py +0 -1
  6. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/visual/mermaid.py +188 -2
  7. {lucid_dl-2.11.3 → lucid_dl-2.11.5/lucid_dl.egg-info}/PKG-INFO +3 -13
  8. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/SOURCES.txt +0 -1
  9. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/setup.py +1 -1
  10. lucid_dl-2.11.3/lucid/visual/graph.py +0 -141
  11. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/LICENSE +0 -0
  12. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/__init__.py +0 -0
  13. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/__init__.py +0 -0
  14. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/core.py +0 -0
  15. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_backend/metal.py +0 -0
  16. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/__init__.py +0 -0
  17. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/bfunc.py +0 -0
  18. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/gfunc.py +0 -0
  19. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_func/ufunc.py +0 -0
  20. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/__init__.py +0 -0
  21. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/base.py +0 -0
  22. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_fusion/func.py +0 -0
  23. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/__init__.py +0 -0
  24. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/base.py +0 -0
  25. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_tensor/tensor.py +0 -0
  26. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_util/__init__.py +0 -0
  27. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/_util/func.py +0 -0
  28. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/autograd/__init__.py +0 -0
  29. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/__init__.py +0 -0
  30. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/_base.py +0 -0
  31. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/data/_util.py +0 -0
  32. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/__init__.py +0 -0
  33. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/_base.py +0 -0
  34. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/cifar.py +0 -0
  35. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/datasets/mnist.py +0 -0
  36. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/einops/__init__.py +0 -0
  37. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/einops/_func.py +0 -0
  38. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/error.py +0 -0
  39. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/linalg/__init__.py +0 -0
  40. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/linalg/_func.py +0 -0
  41. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/__init__.py +0 -0
  42. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/__init__.py +0 -0
  43. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/alex.py +0 -0
  44. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/coatnet.py +0 -0
  45. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/convnext.py +0 -0
  46. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/crossvit.py +0 -0
  47. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/cspnet.py +0 -0
  48. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/cvt.py +0 -0
  49. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/dense.py +0 -0
  50. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/efficient.py +0 -0
  51. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/efficientformer.py +0 -0
  52. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception.py +0 -0
  53. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception_next.py +0 -0
  54. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/inception_res.py +0 -0
  55. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/lenet.py +0 -0
  56. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/maxvit.py +0 -0
  57. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/mobile.py +0 -0
  58. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/pvt.py +0 -0
  59. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnest.py +0 -0
  60. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnet.py +0 -0
  61. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/resnext.py +0 -0
  62. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/senet.py +0 -0
  63. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/sknet.py +0 -0
  64. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/swin.py +0 -0
  65. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/vgg.py +0 -0
  66. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/vit.py +0 -0
  67. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/xception.py +0 -0
  68. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imgclf/zfnet.py +0 -0
  69. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/__init__.py +0 -0
  70. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/ddpm.py +0 -0
  71. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/ncsn.py +0 -0
  72. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/imggen/vae.py +0 -0
  73. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/__init__.py +0 -0
  74. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/detr.py +0 -0
  75. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/efficientdet.py +0 -0
  76. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/fast_rcnn.py +0 -0
  77. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/faster_rcnn.py +0 -0
  78. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/rcnn.py +0 -0
  79. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/util.py +0 -0
  80. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/__init__.py +0 -0
  81. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  82. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  83. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  84. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  85. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/seq2seq/__init__.py +0 -0
  86. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/seq2seq/transformer.py +0 -0
  87. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/models/util.py +0 -0
  88. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/__init__.py +0 -0
  89. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/__init__.py +0 -0
  90. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/activation.py +0 -0
  91. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/attention.py +0 -0
  92. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/conv.py +0 -0
  93. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/embedding.py +0 -0
  94. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/loss.py +0 -0
  95. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/norm.py +0 -0
  96. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/_kernel/pool.py +0 -0
  97. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/__init__.py +0 -0
  98. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_activation.py +0 -0
  99. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_attention.py +0 -0
  100. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_conv.py +0 -0
  101. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_drop.py +0 -0
  102. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_linear.py +0 -0
  103. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_loss.py +0 -0
  104. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_norm.py +0 -0
  105. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_pool.py +0 -0
  106. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_spatial.py +0 -0
  107. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/functional/_util.py +0 -0
  108. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/fused.py +0 -0
  109. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/init/__init__.py +0 -0
  110. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/init/_dist.py +0 -0
  111. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/__init__.py +0 -0
  112. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/activation.py +0 -0
  113. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/attention.py +0 -0
  114. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/conv.py +0 -0
  115. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/drop.py +0 -0
  116. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/einops.py +0 -0
  117. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/linear.py +0 -0
  118. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/loss.py +0 -0
  119. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/norm.py +0 -0
  120. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/pool.py +0 -0
  121. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/rnn.py +0 -0
  122. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/sparse.py +0 -0
  123. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/transformer.py +0 -0
  124. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/modules/vision.py +0 -0
  125. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/parameter.py +0 -0
  126. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/nn/util.py +0 -0
  127. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/__init__.py +0 -0
  128. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/_base.py +0 -0
  129. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/ada.py +0 -0
  130. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/adam.py +0 -0
  131. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/__init__.py +0 -0
  132. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/_base.py +0 -0
  133. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  134. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/prop.py +0 -0
  135. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/optim/sgd.py +0 -0
  136. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/port.py +0 -0
  137. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/random/__init__.py +0 -0
  138. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/random/_func.py +0 -0
  139. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/__init__.py +0 -0
  140. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/_base.py +0 -0
  141. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/transforms/image.py +0 -0
  142. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/weights/__init__.py +0 -0
  143. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid/weights/__init__.pyi +0 -0
  144. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/dependency_links.txt +0 -0
  145. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/requires.txt +0 -0
  146. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/lucid_dl.egg-info/top_level.txt +0 -0
  147. {lucid_dl-2.11.3 → lucid_dl-2.11.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lucid-dl
3
- Version: 2.11.3
3
+ Version: 2.11.5
4
4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
5
5
  Home-page: https://github.com/ChanLumerico/lucid
6
6
  Author: ChanLumerico
@@ -48,19 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
48
48
 
49
49
  ### 🔥 What's New
50
50
 
51
- - Added additional `nn.Module` hooks for richer introspection during training:
52
-
53
- ```python
54
- def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
55
-
56
- def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
51
+ - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of a given tensor's computational graph
57
52
 
58
- def register_backward_hook(self, hook: Callable)
59
-
60
- def register_full_backward_pre_hook(self, hook: Callable)
61
-
62
- def register_full_backward_hook(self, hook: Callable)
63
- ```
53
+ - Added additional `nn.Module` hooks for richer introspection during training:
64
54
 
65
55
  ## 🔧 How to Install
66
56
 
@@ -20,19 +20,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
20
20
 
21
21
  ### 🔥 What's New
22
22
 
23
- - Added additional `nn.Module` hooks for richer introspection during training:
24
-
25
- ```python
26
- def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
27
-
28
- def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
23
+ - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of a given tensor's computational graph
29
24
 
30
- def register_backward_hook(self, hook: Callable)
31
-
32
- def register_full_backward_pre_hook(self, hook: Callable)
33
-
34
- def register_full_backward_hook(self, hook: Callable)
35
- ```
25
+ - Added additional `nn.Module` hooks for richer introspection during training:
36
26
 
37
27
  ## 🔧 How to Install
38
28
 
@@ -13,7 +13,22 @@ from typing import (
13
13
  from collections import OrderedDict
14
14
 
15
15
  from lucid._tensor import Tensor
16
- from lucid.types import _ArrayOrScalar, _NumPyArray, _DeviceType
16
+ from lucid.types import (
17
+ _ArrayOrScalar,
18
+ _BackwardHook,
19
+ _DeviceType,
20
+ _ForwardHook,
21
+ _ForwardHookKwargs,
22
+ _ForwardPreHook,
23
+ _ForwardPreHookKwargs,
24
+ _FullBackwardHook,
25
+ _FullBackwardPreHook,
26
+ _LoadStateDictPostHook,
27
+ _LoadStateDictPreHook,
28
+ _NumPyArray,
29
+ _StateDictHook,
30
+ _StateDictPreHook,
31
+ )
17
32
 
18
33
  import lucid.nn as nn
19
34
 
@@ -30,26 +45,6 @@ __all__ = [
30
45
  ]
31
46
 
32
47
 
33
- _ForwardPreHook = Callable[["Module", tuple[Any, ...]], tuple[Any, ...] | None]
34
- _ForwardPreHookKwargs = Callable[
35
- ["Module", tuple[Any, ...], dict[str, Any]],
36
- tuple[tuple[Any, ...], dict[str, Any]] | None,
37
- ]
38
- _ForwardHook = Callable[["Module", tuple[Any, ...], Any], Any | None]
39
- _ForwardHookKwargs = Callable[
40
- ["Module", tuple[Any, ...], dict[str, Any], Any], Any | None
41
- ]
42
-
43
- _BackwardHook = Callable[[Tensor, _NumPyArray], None]
44
- _FullBackwardPreHook = Callable[
45
- ["Module", tuple[_NumPyArray | None, ...]], tuple[_NumPyArray | None, ...] | None
46
- ]
47
- _FullBackwardHook = Callable[
48
- ["Module", tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
49
- tuple[_NumPyArray | None, ...] | None,
50
- ]
51
-
52
-
53
48
  class Module:
54
49
  _registry_map: dict[Type, OrderedDict[str, Any]] = {}
55
50
  _alt_name: str = ""
@@ -70,10 +65,17 @@ class Module:
70
65
  tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
71
66
  ] = []
72
67
  self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
68
+
73
69
  self._backward_hooks: list[_BackwardHook] = []
74
70
  self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
75
71
  self._full_backward_hooks: list[_FullBackwardHook] = []
76
72
 
73
+ self._state_dict_pre_hooks: list[_StateDictPreHook] = []
74
+ self._state_dict_hooks: list[_StateDictHook] = []
75
+
76
+ self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
77
+ self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []
78
+
77
79
  self._state_dict_pass_attr = set()
78
80
 
79
81
  def __setattr__(self, name: str, value: Any) -> None:
@@ -155,6 +157,26 @@ class Module:
155
157
  self._full_backward_hooks.append(hook)
156
158
  return lambda: self._full_backward_hooks.remove(hook)
157
159
 
160
+ def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
161
+ self._state_dict_pre_hooks.append(hook)
162
+ return lambda: self._state_dict_pre_hooks.remove(hook)
163
+
164
+ def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
165
+ self._state_dict_hooks.append(hook)
166
+ return lambda: self._state_dict_hooks.remove(hook)
167
+
168
+ def register_load_state_dict_pre_hook(
169
+ self, hook: _LoadStateDictPreHook
170
+ ) -> Callable:
171
+ self._load_state_dict_pre_hooks.append(hook)
172
+ return lambda: self._load_state_dict_pre_hooks.remove(hook)
173
+
174
+ def register_load_state_dict_post_hook(
175
+ self, hook: _LoadStateDictPostHook
176
+ ) -> Callable:
177
+ self._load_state_dict_post_hooks.append(hook)
178
+ return lambda: self._load_state_dict_post_hooks.remove(hook)
179
+
158
180
  def reset_parameters(self) -> None:
159
181
  for param in self.parameters():
160
182
  param.zero()
@@ -231,6 +253,9 @@ class Module:
231
253
  prefix: str = "",
232
254
  keep_vars: bool = False,
233
255
  ) -> OrderedDict:
256
+ for hook in self._state_dict_pre_hooks:
257
+ hook(self, prefix, keep_vars)
258
+
234
259
  if destination is None:
235
260
  destination = OrderedDict()
236
261
 
@@ -249,9 +274,15 @@ class Module:
249
274
  if key in self._state_dict_pass_attr:
250
275
  del destination[key]
251
276
 
277
+ for hook in self._state_dict_hooks:
278
+ hook(self, destination, prefix, keep_vars)
279
+
252
280
  return destination
253
281
 
254
282
  def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
283
+ for hook in self._load_state_dict_pre_hooks:
284
+ hook(self, state_dict, strict)
285
+
255
286
  own_state = self.state_dict(keep_vars=True)
256
287
 
257
288
  missing_keys = set(own_state.keys()) - set(state_dict.keys())
@@ -277,6 +308,9 @@ class Module:
277
308
  elif strict:
278
309
  raise KeyError(f"Unexpected key '{key}' in state_dict.")
279
310
 
311
+ for hook in self._load_state_dict_post_hooks:
312
+ hook(self, missing_keys, unexpected_keys, strict)
313
+
280
314
  def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
281
315
  for hook, with_kwargs in self._forward_pre_hooks:
282
316
  if with_kwargs:
@@ -6,8 +6,10 @@ from typing import (
6
6
  Sequence,
7
7
  Literal,
8
8
  TypeAlias,
9
+ TYPE_CHECKING,
9
10
  runtime_checkable,
10
11
  )
12
+ from collections import OrderedDict
11
13
  import re
12
14
 
13
15
  import numpy as np
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
76
78
  ) -> None: ...
77
79
 
78
80
 
81
+ @runtime_checkable
82
+ class _ModuleHookable(Protocol):
83
+ def register_forward_pre_hook(
84
+ self, hook: Callable, *, with_kwargs: bool = False
85
+ ) -> Callable: ...
86
+
87
+ def register_forward_hook(
88
+ self, hook: Callable, *, with_kwargs: bool = False
89
+ ) -> Callable: ...
90
+
91
+ def register_backward_hook(self, hook: Callable) -> Callable: ...
92
+
93
+ def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
94
+
95
+ def register_full_backward_hook(self, hook: Callable) -> Callable: ...
96
+
97
+ def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
98
+
99
+ def register_state_dict_hook(self, hook: Callable) -> Callable: ...
100
+
101
+ def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
102
+
103
+ def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
104
+
105
+
106
+ _ForwardPreHook: TypeAlias = Callable[
107
+ [_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
108
+ ]
109
+ _ForwardPreHookKwargs: TypeAlias = Callable[
110
+ [_ModuleHookable, tuple[Any, ...], dict[str, Any]],
111
+ tuple[tuple[Any, ...], dict[str, Any]] | None,
112
+ ]
113
+ _ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
114
+ _ForwardHookKwargs: TypeAlias = Callable[
115
+ [_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
116
+ ]
117
+
118
+ _BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
119
+ _FullBackwardPreHook: TypeAlias = Callable[
120
+ [_ModuleHookable, tuple[_NumPyArray | None, ...]],
121
+ tuple[_NumPyArray | None, ...] | None,
122
+ ]
123
+ _FullBackwardHook: TypeAlias = Callable[
124
+ [_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
125
+ tuple[_NumPyArray | None, ...] | None,
126
+ ]
127
+
128
+ _StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
129
+ _StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
130
+
131
+ _LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
132
+ _LoadStateDictPostHook: TypeAlias = Callable[
133
+ [_ModuleHookable, set[str], set[str], bool], None
134
+ ]
135
+
136
+
79
137
  class Numeric:
80
138
  def __init__(
81
139
  self, base_dtype: type[int | float | complex], bits: int | None
@@ -1,2 +1 @@
1
- from .graph import *
2
1
  from .mermaid import *
@@ -9,7 +9,7 @@ from lucid._tensor import Tensor
9
9
  from lucid.types import _ShapeLike
10
10
 
11
11
 
12
- __all__ = ["build_mermaid_chart"]
12
+ __all__ = ["build_tensor_mermaid_chart", "build_module_mermaid_chart"]
13
13
 
14
14
 
15
15
  _NN_MODULES_PREFIX = "lucid.nn.modules."
@@ -255,7 +255,7 @@ def _collapse_repeated_children(
255
255
  return out
256
256
 
257
257
 
258
- def build_mermaid_chart(
258
+ def build_module_mermaid_chart(
259
259
  module: nn.Module,
260
260
  input_shape: _ShapeLike | list[_ShapeLike] | None = None,
261
261
  inputs: Iterable[Tensor] | Tensor | None = None,
@@ -751,6 +751,192 @@ def build_mermaid_chart(
751
751
  return text
752
752
 
753
753
 
754
+ def build_mermaid_chart(
755
+ module: nn.Module,
756
+ input_shape: _ShapeLike | list[_ShapeLike] | None = None,
757
+ inputs: Iterable[Tensor] | Tensor | None = None,
758
+ depth: int = 2,
759
+ direction: str = "LR",
760
+ include_io: bool = True,
761
+ show_params: bool = False,
762
+ return_lines: bool = False,
763
+ copy_to_clipboard: bool = False,
764
+ compact: bool = False,
765
+ use_class_defs: bool = False,
766
+ end_semicolons: bool = True,
767
+ edge_mode: Literal["dataflow", "execution"] = "execution",
768
+ collapse_repeats: bool = True,
769
+ repeat_min: int = 2,
770
+ color_by_subpackage: bool = True,
771
+ container_name_from_attr: bool = True,
772
+ edge_stroke_width: float = 2.0,
773
+ emphasize_model_title: bool = True,
774
+ model_title_font_px: int = 20,
775
+ show_shapes: bool = False,
776
+ hide_subpackages: Iterable[str] = (),
777
+ hide_module_names: Iterable[str] = (),
778
+ dash_multi_input_edges: bool = True,
779
+ subgraph_fill: str = "#000000",
780
+ subgraph_fill_opacity: float = 0.05,
781
+ subgraph_stroke: str = "#000000",
782
+ subgraph_stroke_opacity: float = 0.75,
783
+ force_text_color: str | None = None,
784
+ edge_curve: str = "natural",
785
+ node_spacing: int = 50,
786
+ rank_spacing: int = 50,
787
+ **forward_kwargs,
788
+ ) -> str | list[str]:
789
+ return build_module_mermaid_chart(
790
+ module,
791
+ input_shape=input_shape,
792
+ inputs=inputs,
793
+ depth=depth,
794
+ direction=direction,
795
+ include_io=include_io,
796
+ show_params=show_params,
797
+ return_lines=return_lines,
798
+ copy_to_clipboard=copy_to_clipboard,
799
+ compact=compact,
800
+ use_class_defs=use_class_defs,
801
+ end_semicolons=end_semicolons,
802
+ edge_mode=edge_mode,
803
+ collapse_repeats=collapse_repeats,
804
+ repeat_min=repeat_min,
805
+ color_by_subpackage=color_by_subpackage,
806
+ container_name_from_attr=container_name_from_attr,
807
+ edge_stroke_width=edge_stroke_width,
808
+ emphasize_model_title=emphasize_model_title,
809
+ model_title_font_px=model_title_font_px,
810
+ show_shapes=show_shapes,
811
+ hide_subpackages=hide_subpackages,
812
+ hide_module_names=hide_module_names,
813
+ dash_multi_input_edges=dash_multi_input_edges,
814
+ subgraph_fill=subgraph_fill,
815
+ subgraph_fill_opacity=subgraph_fill_opacity,
816
+ subgraph_stroke=subgraph_stroke,
817
+ subgraph_stroke_opacity=subgraph_stroke_opacity,
818
+ force_text_color=force_text_color,
819
+ edge_curve=edge_curve,
820
+ node_spacing=node_spacing,
821
+ rank_spacing=rank_spacing,
822
+ **forward_kwargs,
823
+ )
824
+
825
+
826
+ def build_tensor_mermaid_chart(
827
+ tensor: Tensor,
828
+ horizontal: bool = False,
829
+ title: str | None = None,
830
+ start_id: int | None = None,
831
+ end_semicolons: bool = True,
832
+ copy_to_clipboard: bool = False,
833
+ use_class_defs: bool = True,
834
+ op_fill: str = "lightgreen",
835
+ param_fill: str = "plum",
836
+ result_fill: str = "lightcoral",
837
+ leaf_fill: str = "lightgray",
838
+ grad_fill: str = "lightblue",
839
+ start_fill: str = "gold",
840
+ stroke_color: str = "#666",
841
+ stroke_width_px: int = 1,
842
+ ) -> str:
843
+ direction = "LR" if horizontal else "TD"
844
+ lines: list[str] = [f"flowchart {direction}"]
845
+ if title:
846
+ lines.append(f"%% {title}")
847
+
848
+ result_id: int = id(tensor)
849
+ visited: set[int] = set()
850
+ nodes_to_draw: list[Tensor] = []
851
+
852
+ def dfs(t: Tensor) -> None:
853
+ if id(t) in visited:
854
+ return
855
+ visited.add(id(t))
856
+ for p in t._prev:
857
+ dfs(p)
858
+ nodes_to_draw.append(t)
859
+
860
+ def tensor_node_id(t: Tensor) -> str:
861
+ return f"t_{id(t)}"
862
+
863
+ def op_node_id(op: object) -> str:
864
+ return f"op_{id(op)}"
865
+
866
+ def add_node(node_id: str, label: str, kind: str) -> None:
867
+ if node_id in defined_nodes:
868
+ return
869
+ defined_nodes.add(node_id)
870
+ if kind == "op":
871
+ lines.append(f'{node_id}(("{label}"))')
872
+ else:
873
+ lines.append(f'{node_id}["{label}"]')
874
+
875
+ dfs(tensor)
876
+
877
+ defined_nodes: set[str] = set()
878
+ edge_lines: list[str] = []
879
+ class_lines: list[str] = []
880
+
881
+ for t in nodes_to_draw:
882
+ t_id = tensor_node_id(t)
883
+
884
+ if not t.is_leaf and t._op is not None:
885
+ op_id = op_node_id(t._op)
886
+ op_label = type(t._op).__name__
887
+ add_node(op_id, op_label, "op")
888
+ edge_lines.append(f"{op_id} --> {t_id}")
889
+ class_lines.append(f"class {op_id} op")
890
+ for inp in t._prev:
891
+ edge_lines.append(f"{tensor_node_id(inp)} --> {op_id}")
892
+
893
+ shape_label = str(t.shape) if t.ndim > 0 else str(t.item())
894
+ add_node(t_id, shape_label, "tensor")
895
+
896
+ if start_id is not None and id(t) == start_id:
897
+ class_lines.append(f"class {t_id} start")
898
+ elif isinstance(t, nn.Parameter):
899
+ class_lines.append(f"class {t_id} param")
900
+ elif id(t) == result_id:
901
+ class_lines.append(f"class {t_id} result")
902
+ elif not t.requires_grad:
903
+ class_lines.append(f"class {t_id} leaf")
904
+ else:
905
+ class_lines.append(f"class {t_id} grad")
906
+
907
+ lines.extend(edge_lines)
908
+ if use_class_defs:
909
+ lines.append(
910
+ f"classDef op fill:{op_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
911
+ )
912
+ lines.append(
913
+ f"classDef param fill:{param_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
914
+ )
915
+ lines.append(
916
+ f"classDef result fill:{result_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
917
+ )
918
+ lines.append(
919
+ f"classDef leaf fill:{leaf_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
920
+ )
921
+ lines.append(
922
+ f"classDef grad fill:{grad_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
923
+ )
924
+ lines.append(
925
+ f"classDef start fill:{start_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
926
+ )
927
+ lines.extend(class_lines)
928
+
929
+ if end_semicolons:
930
+ lines = [
931
+ f"{line};" if line and not line.endswith(";") else line for line in lines
932
+ ]
933
+
934
+ text = "\n".join(lines)
935
+ if copy_to_clipboard:
936
+ _copy_to_clipboard(text)
937
+ return text
938
+
939
+
754
940
  def _copy_to_clipboard(text: str) -> None:
755
941
  import os
756
942
  import shutil
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lucid-dl
3
- Version: 2.11.3
3
+ Version: 2.11.5
4
4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
5
5
  Home-page: https://github.com/ChanLumerico/lucid
6
6
  Author: ChanLumerico
@@ -48,19 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
48
48
 
49
49
  ### 🔥 What's New
50
50
 
51
- - Added additional `nn.Module` hooks for richer introspection during training:
52
-
53
- ```python
54
- def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
55
-
56
- def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
51
+ - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of a given tensor's computational graph
57
52
 
58
- def register_backward_hook(self, hook: Callable)
59
-
60
- def register_full_backward_pre_hook(self, hook: Callable)
61
-
62
- def register_full_backward_hook(self, hook: Callable)
63
- ```
53
+ - Added additional `nn.Module` hooks for richer introspection during training:
64
54
 
65
55
  ## 🔧 How to Install
66
56
 
@@ -134,7 +134,6 @@ lucid/transforms/__init__.py
134
134
  lucid/transforms/_base.py
135
135
  lucid/transforms/image.py
136
136
  lucid/visual/__init__.py
137
- lucid/visual/graph.py
138
137
  lucid/visual/mermaid.py
139
138
  lucid/weights/__init__.py
140
139
  lucid/weights/__init__.pyi
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setuptools.setup(
7
7
  name="lucid-dl",
8
- version="2.11.3",
8
+ version="2.11.5",
9
9
  author="ChanLumerico",
10
10
  author_email="greensox284@gmail.com",
11
11
  description="Lumerico's Comprehensive Interface for Deep Learning",
@@ -1,141 +0,0 @@
1
- from typing import Union
2
- from warnings import deprecated
3
-
4
- import networkx as nx
5
- import matplotlib.pyplot as plt
6
-
7
- import lucid.nn as nn
8
- from lucid._tensor import Tensor
9
-
10
-
11
- __all__ = ["draw_tensor_graph"]
12
-
13
-
14
- @deprecated("This feature will be re-written with Mermaid in future relases.")
15
- def draw_tensor_graph(
16
- tensor: Tensor,
17
- horizontal: bool = False,
18
- title: Union[str, None] = None,
19
- start_id: Union[int, None] = None,
20
- ) -> plt.Figure:
21
- G: nx.DiGraph = nx.DiGraph()
22
- result_id: int = id(tensor)
23
-
24
- visited: set[int] = set()
25
- nodes_to_draw: list[Tensor] = []
26
-
27
- def dfs(t: Tensor) -> None:
28
- if id(t) in visited:
29
- return
30
- visited.add(id(t))
31
- for p in t._prev:
32
- dfs(p)
33
- nodes_to_draw.append(t)
34
-
35
- dfs(tensor)
36
-
37
- for t in nodes_to_draw:
38
- if not t.is_leaf and t._op is not None:
39
- op_id: int = id(t._op)
40
- op_label: str = type(t._op).__name__
41
- G.add_node(op_id, label=op_label, shape="circle", color="lightgreen")
42
- G.add_edge(op_id, id(t))
43
- for inp in t._prev:
44
- G.add_edge(id(inp), op_id)
45
-
46
- shape_label: str = str(t.shape) if t.ndim > 0 else str(t.item())
47
- if isinstance(t, nn.Parameter):
48
- color: str = "plum"
49
- else:
50
- color = (
51
- "lightcoral"
52
- if id(t) == result_id
53
- else "lightgray" if not t.requires_grad else "lightblue"
54
- )
55
- if start_id is not None and id(t) == start_id:
56
- color = "gold"
57
-
58
- G.add_node(id(t), label=shape_label, shape="rectangle", color=color)
59
-
60
- def grid_layout(
61
- G: nx.DiGraph, horizontal: bool = False
62
- ) -> tuple[dict, tuple, float, int]:
63
- levels: dict[int, int] = {}
64
- for node in nx.topological_sort(G):
65
- preds = list(G.predecessors(node))
66
- levels[node] = 0 if not preds else max(levels[p] for p in preds) + 1
67
-
68
- level_nodes: dict[int, list[int]] = {}
69
- for node, level in levels.items():
70
- level_nodes.setdefault(level, []).append(node)
71
-
72
- def autoscale(
73
- level_nodes: dict[int, list[int]],
74
- horizontal: bool = False,
75
- base_size: float = 0.5,
76
- base_nodesize: int = 500,
77
- ) -> tuple[tuple[float, float], float, int]:
78
- num_levels: int = len(level_nodes)
79
- max_width: int = max(len(nodes) for nodes in level_nodes.values())
80
- node_count: int = sum(len(nodes) for nodes in level_nodes.values())
81
-
82
- if horizontal:
83
- fig_w: float = min(32, max(4.0, base_size * num_levels))
84
- fig_h: float = min(32, max(4.0, base_size * max_width))
85
- else:
86
- fig_w = min(32, max(4.0, base_size * max_width))
87
- fig_h = min(32, max(4.0, base_size * num_levels))
88
-
89
- nodesize: float = (
90
- base_nodesize
91
- if node_count <= 100
92
- else base_nodesize * (100 / node_count)
93
- )
94
- fontsize: int = max(5, min(8, int(80 / node_count)))
95
- return (fig_w, fig_h), nodesize, fontsize
96
-
97
- figsize, nodesize, fontsize = autoscale(level_nodes, horizontal)
98
- pos: dict[int, tuple[float, float]] = {}
99
- for level, nodes in level_nodes.items():
100
- for i, node in enumerate(nodes):
101
- pos[node] = (
102
- (level * 2.5, -i * 2.0) if horizontal else (i * 2.5, -level * 2.0)
103
- )
104
- return pos, figsize, nodesize, fontsize
105
-
106
- labels: dict[int, str] = nx.get_node_attributes(G, "label")
107
- colors: dict[int, str] = nx.get_node_attributes(G, "color")
108
- shapes: dict[int, str] = nx.get_node_attributes(G, "shape")
109
- pos, figsize, nodesize, fontsize = grid_layout(G, horizontal)
110
-
111
- fig, ax = plt.subplots(figsize=figsize)
112
-
113
- rect_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "rectangle"]
114
- circ_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "circle"]
115
- rect_colors: list[str] = [colors[n] for n in rect_nodes]
116
-
117
- nx.draw_networkx_nodes(
118
- G,
119
- pos,
120
- nodelist=rect_nodes,
121
- node_color=rect_colors,
122
- node_size=nodesize,
123
- node_shape="s",
124
- ax=ax,
125
- )
126
- nx.draw_networkx_nodes(
127
- G,
128
- pos,
129
- nodelist=circ_nodes,
130
- node_color="lightgreen",
131
- node_size=nodesize,
132
- node_shape="o",
133
- ax=ax,
134
- )
135
- nx.draw_networkx_edges(G, pos, width=0.5, arrows=True, edge_color="gray", ax=ax)
136
- nx.draw_networkx_labels(G, pos, labels=labels, font_size=fontsize, ax=ax)
137
-
138
- ax.axis("off")
139
- ax.set_title(title if title is not None else "")
140
-
141
- return fig
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes