lucid-dl 2.11.2__tar.gz → 2.11.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/PKG-INFO +30 -21
  2. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/README.md +29 -20
  3. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/crossvit.py +1 -1
  4. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/efficientformer.py +2 -2
  5. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/maxvit.py +1 -1
  6. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/pvt.py +2 -2
  7. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imggen/vae.py +1 -1
  8. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/efficientdet.py +24 -8
  9. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/rcnn.py +1 -1
  10. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/util.py +5 -0
  11. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/module.py +142 -13
  12. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/types.py +58 -0
  13. lucid_dl-2.11.4/lucid/visual/__init__.py +2 -0
  14. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/visual/graph.py +3 -0
  15. lucid_dl-2.11.4/lucid/visual/mermaid.py +818 -0
  16. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid_dl.egg-info/PKG-INFO +30 -21
  17. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid_dl.egg-info/SOURCES.txt +1 -0
  18. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/setup.py +1 -1
  19. lucid_dl-2.11.2/lucid/visual/__init__.py +0 -1
  20. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/LICENSE +0 -0
  21. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/__init__.py +0 -0
  22. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_backend/__init__.py +0 -0
  23. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_backend/core.py +0 -0
  24. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_backend/metal.py +0 -0
  25. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_func/__init__.py +0 -0
  26. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_func/bfunc.py +0 -0
  27. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_func/gfunc.py +0 -0
  28. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_func/ufunc.py +0 -0
  29. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_fusion/__init__.py +0 -0
  30. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_fusion/base.py +0 -0
  31. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_fusion/func.py +0 -0
  32. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_tensor/__init__.py +0 -0
  33. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_tensor/base.py +0 -0
  34. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_tensor/tensor.py +0 -0
  35. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_util/__init__.py +0 -0
  36. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/_util/func.py +0 -0
  37. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/autograd/__init__.py +0 -0
  38. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/data/__init__.py +0 -0
  39. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/data/_base.py +0 -0
  40. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/data/_util.py +0 -0
  41. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/datasets/__init__.py +0 -0
  42. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/datasets/_base.py +0 -0
  43. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/datasets/cifar.py +0 -0
  44. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/datasets/mnist.py +0 -0
  45. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/einops/__init__.py +0 -0
  46. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/einops/_func.py +0 -0
  47. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/error.py +0 -0
  48. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/linalg/__init__.py +0 -0
  49. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/linalg/_func.py +0 -0
  50. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/__init__.py +0 -0
  51. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/__init__.py +0 -0
  52. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/alex.py +0 -0
  53. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/coatnet.py +0 -0
  54. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/convnext.py +0 -0
  55. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/cspnet.py +0 -0
  56. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/cvt.py +0 -0
  57. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/dense.py +0 -0
  58. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/efficient.py +0 -0
  59. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/inception.py +0 -0
  60. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/inception_next.py +0 -0
  61. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/inception_res.py +0 -0
  62. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/lenet.py +0 -0
  63. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/mobile.py +0 -0
  64. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/resnest.py +0 -0
  65. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/resnet.py +0 -0
  66. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/resnext.py +0 -0
  67. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/senet.py +0 -0
  68. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/sknet.py +0 -0
  69. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/swin.py +0 -0
  70. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/vgg.py +0 -0
  71. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/vit.py +0 -0
  72. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/xception.py +0 -0
  73. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/zfnet.py +0 -0
  74. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imggen/__init__.py +0 -0
  75. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imggen/ddpm.py +0 -0
  76. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imggen/ncsn.py +0 -0
  77. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/__init__.py +0 -0
  78. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/detr.py +0 -0
  79. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/fast_rcnn.py +0 -0
  80. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/faster_rcnn.py +0 -0
  81. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/yolo/__init__.py +0 -0
  82. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  83. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  84. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  85. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  86. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/seq2seq/__init__.py +0 -0
  87. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/seq2seq/transformer.py +0 -0
  88. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/util.py +0 -0
  89. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/__init__.py +0 -0
  90. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/__init__.py +0 -0
  91. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/activation.py +0 -0
  92. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/attention.py +0 -0
  93. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/conv.py +0 -0
  94. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/embedding.py +0 -0
  95. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/loss.py +0 -0
  96. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/norm.py +0 -0
  97. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/_kernel/pool.py +0 -0
  98. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/__init__.py +0 -0
  99. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_activation.py +0 -0
  100. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_attention.py +0 -0
  101. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_conv.py +0 -0
  102. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_drop.py +0 -0
  103. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_linear.py +0 -0
  104. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_loss.py +0 -0
  105. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_norm.py +0 -0
  106. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_pool.py +0 -0
  107. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_spatial.py +0 -0
  108. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/functional/_util.py +0 -0
  109. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/fused.py +0 -0
  110. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/init/__init__.py +0 -0
  111. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/init/_dist.py +0 -0
  112. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/__init__.py +0 -0
  113. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/activation.py +0 -0
  114. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/attention.py +0 -0
  115. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/conv.py +0 -0
  116. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/drop.py +0 -0
  117. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/einops.py +0 -0
  118. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/linear.py +0 -0
  119. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/loss.py +0 -0
  120. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/norm.py +0 -0
  121. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/pool.py +0 -0
  122. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/rnn.py +0 -0
  123. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/sparse.py +0 -0
  124. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/transformer.py +0 -0
  125. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/modules/vision.py +0 -0
  126. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/parameter.py +0 -0
  127. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/util.py +0 -0
  128. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/__init__.py +0 -0
  129. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/_base.py +0 -0
  130. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/ada.py +0 -0
  131. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/adam.py +0 -0
  132. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/lr_scheduler/__init__.py +0 -0
  133. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/lr_scheduler/_base.py +0 -0
  134. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  135. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/prop.py +0 -0
  136. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/optim/sgd.py +0 -0
  137. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/port.py +0 -0
  138. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/random/__init__.py +0 -0
  139. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/random/_func.py +0 -0
  140. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/transforms/__init__.py +0 -0
  141. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/transforms/_base.py +0 -0
  142. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/transforms/image.py +0 -0
  143. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/weights/__init__.py +0 -0
  144. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/weights/__init__.pyi +0 -0
  145. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid_dl.egg-info/dependency_links.txt +0 -0
  146. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid_dl.egg-info/requires.txt +0 -0
  147. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid_dl.egg-info/top_level.txt +0 -0
  148. {lucid_dl-2.11.2 → lucid_dl-2.11.4}/setup.cfg +0 -0
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lucid-dl
- Version: 2.11.2
+ Version: 2.11.4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
  Home-page: https://github.com/ChanLumerico/lucid
  Author: ChanLumerico
@@ -48,26 +48,35 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
  ### 🔥 What's New
 
- - Added various inplace tensor operations (e.g. `a.add_(b)`, `a.mul_(b)`)
-
- - Added **Noise Conditional Score Network (NCSN)** to `lucid.models.NCSN`
-
- - Branched a Stand-Alone Autograd Engine as `lucid.autograd`
-
-   - Provides a generalized API for computing gradients:
-
-     ```python
-     import lucid.autograd as autograd
-     x = lucid.Tensor([1., 2.], requires_grad=True)
-     y = (x ** 2).sum()
-     autograd.grad(y, x)  # ∂y/∂x
-     ```
-
- - Introduced **Backward Fusion** for CPU execution:
-   - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
-   - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
-   - Uses heuristic thresholds to avoid fusion overhead on small tensors
-   - Disabled by default on GPU paths to ensure stable performance
+ - Added additional `nn.Module` hooks for richer introspection during training:
+
+   ```python
+   def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
+   ```
+   ```python
+   def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
+   ```
+   ```python
+   def register_backward_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_full_backward_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_full_backward_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_state_dict_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_state_dict_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_load_state_dict_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_load_state_dict_post_hook(self, hook: Callable)
+   ```
 
  ## 🔧 How to Install
 
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/README.md

@@ -20,26 +20,35 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
  ### 🔥 What's New
 
- - Added various inplace tensor operations (e.g. `a.add_(b)`, `a.mul_(b)`)
-
- - Added **Noise Conditional Score Network (NCSN)** to `lucid.models.NCSN`
-
- - Branched a Stand-Alone Autograd Engine as `lucid.autograd`
-
-   - Provides a generalized API for computing gradients:
-
-     ```python
-     import lucid.autograd as autograd
-     x = lucid.Tensor([1., 2.], requires_grad=True)
-     y = (x ** 2).sum()
-     autograd.grad(y, x)  # ∂y/∂x
-     ```
-
- - Introduced **Backward Fusion** for CPU execution:
-   - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
-   - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
-   - Uses heuristic thresholds to avoid fusion overhead on small tensors
-   - Disabled by default on GPU paths to ensure stable performance
+ - Added additional `nn.Module` hooks for richer introspection during training:
+
+   ```python
+   def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
+   ```
+   ```python
+   def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
+   ```
+   ```python
+   def register_backward_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_full_backward_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_full_backward_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_state_dict_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_state_dict_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_load_state_dict_pre_hook(self, hook: Callable)
+   ```
+   ```python
+   def register_load_state_dict_post_hook(self, hook: Callable)
+   ```
 
  ## 🔧 How to Install
 
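For orientation, a minimal sketch of how the new hooks compose at runtime, based only on the registration API listed above and the `__call__` semantics shown later in this diff (`lucid/nn/module.py`); the `Doubler` module and hook functions are hypothetical:

```python
import lucid
import lucid.nn as nn


class Doubler(nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return x * 2


def shift_input(module, args):
    # forward pre-hook: returning a tuple replaces the positional args
    return (args[0] + 1,)


def log_output(module, args, output):
    # forward hook: returning None leaves the output unchanged
    print("forward produced:", output)


model = Doubler()
remove_pre = model.register_forward_pre_hook(shift_input)
remove_fwd = model.register_forward_hook(log_output)

y = model(lucid.Tensor([1.0, 2.0], requires_grad=True))  # computes (x + 1) * 2

# every register_* method returns a callable that unregisters the hook
remove_pre()
remove_fwd()
```

Returning `None` from a hook keeps the original values, mirroring the convention visible in the `__call__` implementation further down.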
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/crossvit.py

@@ -79,7 +79,7 @@ class _PatchEmbed(nn.Module):
  f"Input image size {(H, W)} does not match with {self.img_size}."
  )
 
- x = self.proj(x).flatten(axis=2).swapaxes(1, 2)
+ x = self.proj(x).flatten(start_axis=2).swapaxes(1, 2)
  return x
 
 
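The recurring one-line change in this release renames `flatten`'s keyword from `axis` to `start_axis`. A small sketch of the intended semantics, assuming `flatten(start_axis=k)` collapses all dimensions from `k` onward and that `Tensor` accepts array-likes (per `_ArrayOrScalar` in `lucid/types.py`); shapes are illustrative:

```python
import numpy as np
import lucid

# hypothetical (B, C, H, W) feature map
x = lucid.Tensor(np.zeros((2, 8, 4, 4)))

tokens = x.flatten(start_axis=2)  # collapse H and W: (2, 8, 16)
tokens = tokens.swapaxes(1, 2)    # (2, 16, 8), one row per spatial position
```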
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/efficientformer.py

@@ -80,7 +80,7 @@ class _Attention(nn.Module):
  y, x = lucid.meshgrid(
  lucid.arange(resolution[0]), lucid.arange(resolution[1]), indexing="ij"
  )
- pos = lucid.stack([y, x]).flatten(axis=1)
+ pos = lucid.stack([y, x]).flatten(start_axis=1)
  rel_pos = lucid.abs(pos[..., :, None] - pos[..., None, :])
  rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
 
@@ -159,7 +159,7 @@ class _Downsample(nn.Module):
 
  class _Flatten(nn.Module):
  def forward(self, x: Tensor) -> Tensor:
- x = x.flatten(axis=2).swapaxes(1, 2)
+ x = x.flatten(start_axis=2).swapaxes(1, 2)
  return x
 
 
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/maxvit.py

@@ -216,7 +216,7 @@ def _grid_reverse(
 
  def _get_relative_position_index(win_h: int, win_w: int) -> Tensor:
  coords = lucid.stack(lucid.meshgrid(lucid.arange(win_h), lucid.arange(win_w)))
- coords_flatten = lucid.flatten(coords, axis=1)
+ coords_flatten = lucid.flatten(coords, start_axis=1)
 
  relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  relative_coords = relative_coords.transpose((1, 2, 0))
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imgclf/pvt.py

@@ -328,7 +328,7 @@ class _DWConv(nn.Module):
  B, _, C = x.shape
  x = x.swapaxes(1, 2).reshape(B, C, H, W)
  x = self.dwconv(x)
- x = x.flatten(axis=2).swapaxes(1, 2)
+ x = x.flatten(start_axis=2).swapaxes(1, 2)
 
  return x
 
@@ -548,7 +548,7 @@ class _OverlapPatchEmbed(nn.Module):
  def forward(self, x: Tensor) -> tuple[Tensor, int, int]:
  x = self.proj(x)
  H, W = x.shape[2:]
- x = x.flatten(axis=2).swapaxes(1, 2)
+ x = x.flatten(start_axis=2).swapaxes(1, 2)
  x = self.norm(x)
 
  return x, H, W
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/imggen/vae.py

@@ -51,7 +51,7 @@ class VAE(nn.Module):
  h = x
  for encoder in self.encoders:
  h = encoder(h)
- mu, logvar = lucid.split(h, 2, axis=1)
+ mu, logvar = lucid.chunk(h, 2, axis=1)
  z = self.reparameterize(mu, logvar)
 
  mus.append(mu)
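The VAE fix swaps `lucid.split` for `lucid.chunk` when halving the encoder output into `mu` and `logvar`. A minimal sketch of the call adopted above, under the assumption that `chunk(h, n, axis)` divides the axis into `n` equal parts (the convention the two-way unpacking relies on):

```python
import numpy as np
import lucid

# hypothetical encoder output carrying 2 * latent_dim channels on axis 1
h = lucid.Tensor(np.zeros((4, 8, 1, 1)))

# two equal halves along axis 1: mu and logvar are each (4, 4, 1, 1)
mu, logvar = lucid.chunk(h, 2, axis=1)
```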
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/efficientdet.py

@@ -82,23 +82,37 @@ class _BiFPN(nn.Module):
  def _norm_weight(self, weight: Tensor) -> Tensor:
  return weight / (weight.sum(axis=0) + self.eps)
 
+ @staticmethod
+ def _resize_like(x: Tensor, ref: Tensor) -> Tensor:
+ if x.shape[2:] == ref.shape[2:]:
+ return x
+ return F.interpolate(x, size=ref.shape[2:], mode="nearest")
+
  def _forward_up(self, feats: tuple[Tensor]) -> tuple[Tensor]:
  p3_in, p4_in, p5_in, p6_in, p7_in = feats
 
  w1_p6_up = self._norm_weight(self.acts["6_w1"](self.weights["6_w1"]))
- p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self.ups["6"](p7_in)
+ p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self._resize_like(
+ self.ups["6"](p7_in), p6_in
+ )
  p6_up = self.convs["6_up"](p6_up_in)
 
  w1_p5_up = self._norm_weight(self.acts["5_w1"](self.weights["5_w1"]))
- p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self.ups["5"](p6_up)
+ p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self._resize_like(
+ self.ups["5"](p6_up), p5_in
+ )
  p5_up = self.convs["5_up"](p5_up_in)
 
  w1_p4_up = self._norm_weight(self.acts["4_w1"](self.weights["4_w1"]))
- p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self.ups["4"](p5_up)
+ p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self._resize_like(
+ self.ups["4"](p5_up), p4_in
+ )
  p4_up = self.convs["4_up"](p4_up_in)
 
  w1_p3_up = self._norm_weight(self.acts["3_w1"](self.weights["3_w1"]))
- p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self.ups["3"](p4_up)
+ p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self._resize_like(
+ self.ups["3"](p4_up), p3_in
+ )
  p3_out = self.convs["3_up"](p3_up_in)
 
  return p3_out, p4_up, p5_up, p6_up
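`_resize_like` guards against the off-by-one spatial mismatches BiFPN hits when the input resolution is not divisible by every pyramid stride: an upsampled or downsampled level can land one pixel away from the lateral feature it is fused with. A standalone sketch of the same pattern, reusing only the `F.interpolate` call from the diff and assuming `F` aliases `lucid.nn.functional`; shapes are hypothetical:

```python
import numpy as np
import lucid
import lucid.nn.functional as F  # assumed alias for the F used above

# e.g. a 7x7 level upsampled 2x gives 14x14, but the lateral feature is 15x15
x = lucid.Tensor(np.zeros((1, 64, 14, 14)))
ref = lucid.Tensor(np.zeros((1, 64, 15, 15)))

if x.shape[2:] != ref.shape[2:]:
    x = F.interpolate(x, size=ref.shape[2:], mode="nearest")

assert x.shape[2:] == ref.shape[2:]  # weighted fusion is now shape-safe
```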
@@ -113,7 +127,7 @@ class _BiFPN(nn.Module):
  p4_down_in = (
  w2_p4_down[0] * p4_in
  + w2_p4_down[1] * p4_up
- + w2_p4_down[2] * self.downs["4"](p3_out)
+ + w2_p4_down[2] * self._resize_like(self.downs["4"](p3_out), p4_in)
  )
  p4_out = self.convs["4_down"](p4_down_in)
 
@@ -121,7 +135,7 @@ class _BiFPN(nn.Module):
  p5_down_in = (
  w2_p5_down[0] * p5_in
  + w2_p5_down[1] * p5_up
- + w2_p5_down[2] * self.downs["5"](p4_out)
+ + w2_p5_down[2] * self._resize_like(self.downs["5"](p4_out), p5_in)
  )
  p5_out = self.convs["5_down"](p5_down_in)
 
@@ -129,12 +143,14 @@ class _BiFPN(nn.Module):
  p6_down_in = (
  w2_p6_down[0] * p6_in
  + w2_p6_down[1] * p6_up
- + w2_p6_down[2] * self.downs["6"](p5_out)
+ + w2_p6_down[2] * self._resize_like(self.downs["6"](p5_out), p6_in)
  )
  p6_out = self.convs["6_down"](p6_down_in)
 
  w2_p7_down = self._norm_weight(self.acts["7_w2"](self.weights["7_w2"]))
- p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self.downs["7"](p6_out)
+ p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self._resize_like(
+ self.downs["7"](p6_out), p7_in
+ )
  p7_out = self.convs["7_down"](p7_down_in)
 
  return p3_out, p4_out, p5_out, p6_out, p7_out
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/rcnn.py

@@ -120,7 +120,7 @@ class RCNN(nn.Module):
 
  if isinstance(feats, (tuple, list)):
  feats = feats[-1]
- feats = feats.flatten(axis=1)
+ feats = feats.flatten(start_axis=1)
 
  cls_scores = self.svm(feats)
  bbox_deltas = self.bbox_reg(feats)
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/models/objdet/util.py

@@ -283,6 +283,11 @@ class SelectiveSearch(nn.Module):
 
 
  def iou(boxes_a: Tensor, boxes_b: Tensor) -> Tensor:
+ if boxes_a.ndim == 1:
+ boxes_a = boxes_a.unsqueeze(0)
+ if boxes_b.ndim == 1:
+ boxes_b = boxes_b.unsqueeze(0)
+
  x1a, y1a, x2a, y2a = boxes_a.unbind(axis=1)
  x1b, y1b, x2b, y2b = boxes_b.unbind(axis=1)
 
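The new guard promotes 1-D inputs to shape `(1, 4)` so `iou` also accepts single boxes. A small worked sketch (import path inferred from the file list above): two 10×10 boxes overlapping on a 5×5 patch intersect in 25 units against a union of 100 + 100 − 25 = 175, so the IoU is 25/175 ≈ 0.143.

```python
import lucid
from lucid.models.objdet.util import iou  # path per the file list above

box_a = lucid.Tensor([0.0, 0.0, 10.0, 10.0])  # a single (x1, y1, x2, y2) box
box_b = lucid.Tensor([5.0, 5.0, 15.0, 15.0])

# both 1-D boxes are promoted to (1, 4), so the result is a (1, 1) IoU matrix
print(iou(box_a, box_b))  # ≈ 0.1429
```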
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/nn/module.py

@@ -13,7 +13,22 @@ from typing import (
  from collections import OrderedDict
 
  from lucid._tensor import Tensor
- from lucid.types import _ArrayOrScalar, _NumPyArray, _DeviceType
+ from lucid.types import (
+     _ArrayOrScalar,
+     _BackwardHook,
+     _DeviceType,
+     _ForwardHook,
+     _ForwardHookKwargs,
+     _ForwardPreHook,
+     _ForwardPreHookKwargs,
+     _FullBackwardHook,
+     _FullBackwardPreHook,
+     _LoadStateDictPostHook,
+     _LoadStateDictPreHook,
+     _NumPyArray,
+     _StateDictHook,
+     _StateDictPreHook,
+ )
 
  import lucid.nn as nn
 
@@ -29,9 +44,6 @@ __all__ = [
  "set_state_dict_pass_attr",
  ]
 
- _ForwardHookType = Callable[["Module", tuple[Tensor], tuple[Tensor]], None]
- _BackwardHookType = Callable[[Tensor, _NumPyArray], None]
-
 
  class Module:
  _registry_map: dict[Type, OrderedDict[str, Any]] = {}
@@ -49,8 +61,20 @@ class Module:
  self.training = True
  self.device: _DeviceType = "cpu"
 
- self._forward_hooks: list[_ForwardHookType] = []
- self._backward_hooks: list[_BackwardHookType] = []
+ self._forward_pre_hooks: list[
+     tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
+ ] = []
+ self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
+
+ self._backward_hooks: list[_BackwardHook] = []
+ self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
+ self._full_backward_hooks: list[_FullBackwardHook] = []
+
+ self._state_dict_pre_hooks: list[_StateDictPreHook] = []
+ self._state_dict_hooks: list[_StateDictHook] = []
+
+ self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
+ self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []
 
  self._state_dict_pass_attr = set()
 
@@ -106,14 +130,53 @@ class Module:
 
  self.__setattr__(name, buffer)
 
- def register_forward_hook(self, hook: _ForwardHookType) -> Callable:
-     self._forward_hooks.append(hook)
-     return lambda: self._forward_hooks.remove(hook)
-
- def register_backward_hook(self, hook: _BackwardHookType) -> Callable:
+ def register_forward_pre_hook(
+     self,
+     hook: _ForwardPreHook | _ForwardPreHookKwargs,
+     *,
+     with_kwargs: bool = False,
+ ) -> Callable:
+     self._forward_pre_hooks.append((hook, with_kwargs))
+     return lambda: self._forward_pre_hooks.remove((hook, with_kwargs))
+
+ def register_forward_hook(
+     self, hook: _ForwardHook | _ForwardHookKwargs, *, with_kwargs: bool = False
+ ) -> Callable:
+     self._forward_hooks.append((hook, with_kwargs))
+     return lambda: self._forward_hooks.remove((hook, with_kwargs))
+
+ def register_backward_hook(self, hook: _BackwardHook) -> Callable:
      self._backward_hooks.append(hook)
      return lambda: self._backward_hooks.remove(hook)
 
+ def register_full_backward_pre_hook(self, hook: _FullBackwardPreHook) -> Callable:
+     self._full_backward_pre_hooks.append(hook)
+     return lambda: self._full_backward_pre_hooks.remove(hook)
+
+ def register_full_backward_hook(self, hook: _FullBackwardHook) -> Callable:
+     self._full_backward_hooks.append(hook)
+     return lambda: self._full_backward_hooks.remove(hook)
+
+ def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
+     self._state_dict_pre_hooks.append(hook)
+     return lambda: self._state_dict_pre_hooks.remove(hook)
+
+ def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
+     self._state_dict_hooks.append(hook)
+     return lambda: self._state_dict_hooks.remove(hook)
+
+ def register_load_state_dict_pre_hook(
+     self, hook: _LoadStateDictPreHook
+ ) -> Callable:
+     self._load_state_dict_pre_hooks.append(hook)
+     return lambda: self._load_state_dict_pre_hooks.remove(hook)
+
+ def register_load_state_dict_post_hook(
+     self, hook: _LoadStateDictPostHook
+ ) -> Callable:
+     self._load_state_dict_post_hooks.append(hook)
+     return lambda: self._load_state_dict_post_hooks.remove(hook)
+
  def reset_parameters(self) -> None:
      for param in self.parameters():
          param.zero()
@@ -190,6 +253,9 @@ class Module:
  prefix: str = "",
  keep_vars: bool = False,
  ) -> OrderedDict:
+     for hook in self._state_dict_pre_hooks:
+         hook(self, prefix, keep_vars)
+
      if destination is None:
          destination = OrderedDict()
 
@@ -208,9 +274,15 @@ class Module:
      if key in self._state_dict_pass_attr:
          del destination[key]
 
+     for hook in self._state_dict_hooks:
+         hook(self, destination, prefix, keep_vars)
+
      return destination
 
  def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
+     for hook in self._load_state_dict_pre_hooks:
+         hook(self, state_dict, strict)
+
      own_state = self.state_dict(keep_vars=True)
 
      missing_keys = set(own_state.keys()) - set(state_dict.keys())
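Given the call sites above, a state-dict hook runs after the dict is assembled with `(module, destination, prefix, keep_vars)`, and a load pre-hook runs before key matching with `(module, state_dict, strict)`; return values are ignored, so both mutate their dict in place. A minimal sketch with hypothetical key names:

```python
import lucid.nn as nn


class Net(nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return x


def drop_ema_keys(module, destination, prefix, keep_vars):
    # runs after state_dict() is built; prune entries in place
    for key in [k for k in destination if k.endswith("_ema")]:
        del destination[key]


def rename_legacy(module, state_dict, strict):
    # runs before load_state_dict() matches keys; patch legacy names in place
    for key in [k for k in state_dict if k.startswith("old.")]:
        state_dict["new." + key[len("old."):]] = state_dict.pop(key)


model = Net()
model.register_state_dict_hook(drop_ema_keys)
model.register_load_state_dict_pre_hook(rename_legacy)
```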
@@ -236,15 +308,72 @@ class Module:
      elif strict:
          raise KeyError(f"Unexpected key '{key}' in state_dict.")
 
+     for hook in self._load_state_dict_post_hooks:
+         hook(self, missing_keys, unexpected_keys, strict)
+
  def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
+     for hook, with_kwargs in self._forward_pre_hooks:
+         if with_kwargs:
+             result = hook(self, args, kwargs)
+             if result is not None:
+                 args, kwargs = result
+         else:
+             result = hook(self, args)
+             if result is not None:
+                 args = result
+
      output = self.forward(*args, **kwargs)
-     for hook in self._forward_hooks:
-         hook(self, args, output)
+
+     for hook, with_kwargs in self._forward_hooks:
+         if with_kwargs:
+             result = hook(self, args, kwargs, output)
+         else:
+             result = hook(self, args, output)
+         if result is not None:
+             output = result
 
      if isinstance(output, Tensor) and self._backward_hooks:
          for hook in self._backward_hooks:
              output.register_hook(hook)
 
+     if self._full_backward_pre_hooks or self._full_backward_hooks:
+         outputs = output if isinstance(output, tuple) else (output,)
+         output_tensors = [out for out in outputs if isinstance(out, Tensor)]
+
+         if output_tensors:
+             grad_outputs: list[_NumPyArray | None] = [None] * len(output_tensors)
+             called = False
+
+             def _call_full_backward_hooks() -> None:
+                 nonlocal called, grad_outputs
+                 if called:
+                     return
+                 called = True
+
+                 grad_output_tuple = tuple(grad_outputs)
+                 for hook in self._full_backward_pre_hooks:
+                     result = hook(self, grad_output_tuple)
+                     if result is not None:
+                         grad_output_tuple = result
+
+                 grad_input_tuple = tuple(
+                     arg.grad if isinstance(arg, Tensor) else None for arg in args
+                 )
+                 for hook in self._full_backward_hooks:
+                     hook(self, grad_input_tuple, grad_output_tuple)
+
+             for idx, out in enumerate(output_tensors):
+
+                 def _make_hook(index: int) -> Callable:
+                     def _hook(_, grad: _NumPyArray) -> None:
+                         grad_outputs[index] = grad
+                         if all(g is not None for g in grad_outputs):
+                             _call_full_backward_hooks()
+
+                     return _hook
+
+                 out.register_hook(_make_hook(idx))
+
      return output
 
  def __repr__(self) -> str:
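The `__call__` above defers full-backward hooks until every tensor output has reported its gradient: each output gets a per-index tensor hook, the incoming gradients fill `grad_outputs`, and once all slots are populated the pre-hooks see the output gradients while the full hooks additionally see each tensor argument's `.grad`. A minimal sketch of observing gradients this way; `Square` and the hook are hypothetical, and `Tensor.backward()` is assumed to trigger backpropagation:

```python
import lucid
import lucid.nn as nn


class Square(nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return (x ** 2).sum()


def report(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. the module outputs;
    # grad_input: each tensor argument's .grad at the time the hook fires
    # (it may still be None if the input gradient has not propagated yet)
    print("grad_output:", grad_output)
    print("grad_input:", grad_input)


model = Square()
model.register_full_backward_hook(report)

x = lucid.Tensor([1.0, 2.0], requires_grad=True)
model(x).backward()  # the hook fires once the output's gradient arrives
```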
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/types.py

@@ -6,8 +6,10 @@ from typing import (
  Sequence,
  Literal,
  TypeAlias,
+ TYPE_CHECKING,
  runtime_checkable,
  )
+ from collections import OrderedDict
  import re
 
  import numpy as np
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
  ) -> None: ...
 
 
+ @runtime_checkable
+ class _ModuleHookable(Protocol):
+     def register_forward_pre_hook(
+         self, hook: Callable, *, with_kwargs: bool = False
+     ) -> Callable: ...
+
+     def register_forward_hook(
+         self, hook: Callable, *, with_kwargs: bool = False
+     ) -> Callable: ...
+
+     def register_backward_hook(self, hook: Callable) -> Callable: ...
+
+     def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
+
+     def register_full_backward_hook(self, hook: Callable) -> Callable: ...
+
+     def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
+
+     def register_state_dict_hook(self, hook: Callable) -> Callable: ...
+
+     def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
+
+     def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
+
+
+ _ForwardPreHook: TypeAlias = Callable[
+     [_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
+ ]
+ _ForwardPreHookKwargs: TypeAlias = Callable[
+     [_ModuleHookable, tuple[Any, ...], dict[str, Any]],
+     tuple[tuple[Any, ...], dict[str, Any]] | None,
+ ]
+ _ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
+ _ForwardHookKwargs: TypeAlias = Callable[
+     [_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
+ ]
+
+ _BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
+ _FullBackwardPreHook: TypeAlias = Callable[
+     [_ModuleHookable, tuple[_NumPyArray | None, ...]],
+     tuple[_NumPyArray | None, ...] | None,
+ ]
+ _FullBackwardHook: TypeAlias = Callable[
+     [_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
+     tuple[_NumPyArray | None, ...] | None,
+ ]
+
+ _StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
+ _StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
+
+ _LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
+ _LoadStateDictPostHook: TypeAlias = Callable[
+     [_ModuleHookable, set[str], set[str], bool], None
+ ]
+
+
  class Numeric:
  def __init__(
  self, base_dtype: type[int | float | complex], bits: int | None
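These aliases pin down the calling conventions structurally: `_ModuleHookable` is a `runtime_checkable` protocol satisfied by anything exposing the nine `register_*` methods, and each `Callable` alias documents a hook's argument order. Two illustrative functions conforming to the aliases (names hypothetical, annotations only):

```python
from typing import Any

# matches _ForwardHook: (module, args, output) -> replacement output or None
def forward_hook(module: Any, args: tuple[Any, ...], output: Any) -> Any | None:
    return None  # None keeps the original output


# matches _StateDictPreHook: (module, prefix, keep_vars) -> None
def state_dict_pre_hook(module: Any, prefix: str, keep_vars: bool) -> None:
    print(f"exporting state dict with prefix {prefix!r}")
```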
lucid_dl-2.11.4/lucid/visual/__init__.py (new file)

@@ -0,0 +1,2 @@
+ from .graph import *
+ from .mermaid import *
{lucid_dl-2.11.2 → lucid_dl-2.11.4}/lucid/visual/graph.py

@@ -1,4 +1,6 @@
  from typing import Union
+ from warnings import deprecated
+
  import networkx as nx
  import matplotlib.pyplot as plt
 
@@ -9,6 +11,7 @@ from lucid._tensor import Tensor
  __all__ = ["draw_tensor_graph"]
 
 
+ @deprecated("This feature will be re-written with Mermaid in future releases.")
  def draw_tensor_graph(
  tensor: Tensor,
  horizontal: bool = False,
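`warnings.deprecated` is the PEP 702 decorator added in Python 3.13, so this import path effectively assumes 3.13+ (older interpreters would need the `typing_extensions` backport). A minimal, lucid-independent sketch of its behavior; the function and message are hypothetical:

```python
from warnings import deprecated  # stdlib since Python 3.13 (PEP 702)


@deprecated("use the Mermaid-based renderer instead")
def old_renderer() -> None: ...


old_renderer()  # emits a DeprecationWarning carrying the message
```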