lucid-dl 2.9.0__tar.gz → 2.10.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/PKG-INFO +7 -5
  2. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/README.md +6 -4
  3. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/__init__.py +2 -0
  4. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/conv.py +23 -4
  5. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/core.py +104 -52
  6. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/pool.py +22 -4
  7. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_func/bfunc.py +45 -45
  8. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_func/ufunc.py +79 -79
  9. lucid_dl-2.10.0/lucid/_fusion/__init__.py +4 -0
  10. lucid_dl-2.10.0/lucid/_fusion/base.py +120 -0
  11. lucid_dl-2.10.0/lucid/_fusion/func.py +80 -0
  12. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_tensor/tensor.py +83 -17
  13. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_util/func.py +62 -63
  14. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/einops/_func.py +10 -10
  15. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/linalg/_func.py +29 -29
  16. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/fused.py +1 -1
  17. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/types.py +27 -1
  18. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid_dl.egg-info/PKG-INFO +7 -5
  19. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid_dl.egg-info/SOURCES.txt +3 -0
  20. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/setup.py +1 -1
  21. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/LICENSE +0 -0
  22. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/__init__.py +0 -0
  23. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/metal.py +0 -0
  24. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_func/__init__.py +0 -0
  25. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_func/gfunc.py +0 -0
  26. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_tensor/__init__.py +0 -0
  27. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_tensor/tensor_ops.py +0 -0
  28. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_util/__init__.py +0 -0
  29. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/data/__init__.py +0 -0
  30. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/data/_base.py +0 -0
  31. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/data/_util.py +0 -0
  32. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/datasets/__init__.py +0 -0
  33. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/datasets/_base.py +0 -0
  34. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/datasets/cifar.py +0 -0
  35. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/datasets/mnist.py +0 -0
  36. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/einops/__init__.py +0 -0
  37. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/error.py +0 -0
  38. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/linalg/__init__.py +0 -0
  39. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/__init__.py +0 -0
  40. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/__init__.py +0 -0
  41. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/alex.py +0 -0
  42. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/coatnet.py +0 -0
  43. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/convnext.py +0 -0
  44. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/crossvit.py +0 -0
  45. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/cspnet.py +0 -0
  46. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/cvt.py +0 -0
  47. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/dense.py +0 -0
  48. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/efficient.py +0 -0
  49. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/efficientformer.py +0 -0
  50. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/inception.py +0 -0
  51. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_next.py +0 -0
  52. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_res.py +0 -0
  53. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/lenet.py +0 -0
  54. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/maxvit.py +0 -0
  55. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/mobile.py +0 -0
  56. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/pvt.py +0 -0
  57. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/resnest.py +0 -0
  58. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/resnet.py +0 -0
  59. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/resnext.py +0 -0
  60. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/senet.py +0 -0
  61. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/sknet.py +0 -0
  62. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/swin.py +0 -0
  63. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/vgg.py +0 -0
  64. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/vit.py +0 -0
  65. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/xception.py +0 -0
  66. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imgclf/zfnet.py +0 -0
  67. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imggen/__init__.py +0 -0
  68. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imggen/ddpm.py +0 -0
  69. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/imggen/vae.py +0 -0
  70. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/__init__.py +0 -0
  71. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/detr.py +0 -0
  72. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/efficientdet.py +0 -0
  73. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/fast_rcnn.py +0 -0
  74. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/faster_rcnn.py +0 -0
  75. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/rcnn.py +0 -0
  76. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/util.py +0 -0
  77. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/__init__.py +0 -0
  78. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  79. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  80. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  81. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  82. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/seq2seq/__init__.py +0 -0
  83. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/seq2seq/transformer.py +0 -0
  84. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/models/util.py +0 -0
  85. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/__init__.py +0 -0
  86. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/__init__.py +0 -0
  87. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_activation.py +0 -0
  88. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_attention.py +0 -0
  89. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_conv.py +0 -0
  90. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_drop.py +0 -0
  91. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_linear.py +0 -0
  92. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_loss.py +0 -0
  93. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_norm.py +0 -0
  94. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_pool.py +0 -0
  95. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_spatial.py +0 -0
  96. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/functional/_util.py +0 -0
  97. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/init/__init__.py +0 -0
  98. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/init/_dist.py +0 -0
  99. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/module.py +0 -0
  100. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/__init__.py +0 -0
  101. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/activation.py +0 -0
  102. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/attention.py +0 -0
  103. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/conv.py +0 -0
  104. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/drop.py +0 -0
  105. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/einops.py +0 -0
  106. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/linear.py +0 -0
  107. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/loss.py +0 -0
  108. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/norm.py +0 -0
  109. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/pool.py +0 -0
  110. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/rnn.py +0 -0
  111. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/sparse.py +0 -0
  112. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/transformer.py +0 -0
  113. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/modules/vision.py +0 -0
  114. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/parameter.py +0 -0
  115. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/nn/util.py +0 -0
  116. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/__init__.py +0 -0
  117. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/_base.py +0 -0
  118. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/ada.py +0 -0
  119. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/adam.py +0 -0
  120. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/__init__.py +0 -0
  121. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_base.py +0 -0
  122. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  123. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/prop.py +0 -0
  124. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/optim/sgd.py +0 -0
  125. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/port.py +0 -0
  126. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/random/__init__.py +0 -0
  127. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/random/_func.py +0 -0
  128. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/transforms/__init__.py +0 -0
  129. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/transforms/_base.py +0 -0
  130. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/transforms/image.py +0 -0
  131. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/visual/__init__.py +0 -0
  132. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/visual/graph.py +0 -0
  133. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/weights/__init__.py +0 -0
  134. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/weights/__init__.pyi +0 -0
  135. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid_dl.egg-info/dependency_links.txt +0 -0
  136. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid_dl.egg-info/requires.txt +0 -0
  137. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid_dl.egg-info/top_level.txt +0 -0
  138. {lucid_dl-2.9.0 → lucid_dl-2.10.0}/setup.cfg +0 -0
{lucid_dl-2.9.0 → lucid_dl-2.10.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lucid-dl
- Version: 2.9.0
+ Version: 2.10.0
  Summary: Lumerico's Comprehensive Interface for Deep Learning
  Home-page: https://github.com/ChanLumerico/lucid
  Author: ChanLumerico
@@ -33,7 +33,7 @@ Dynamic: summary
  ![PyPI - Total Downloads](https://img.shields.io/badge/total%20downloads-34.0k-yellow.svg)
  ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/ChanLumerico/lucid.svg)
  ![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)
- ![Lines of Code](https://img.shields.io/badge/lines%20of%20code-26.9k-purple.svg)
+ ![Lines of Code](https://img.shields.io/badge/lines%20of%20code-27.7k-purple.svg)

  **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.

@@ -50,9 +50,11 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format

- - Added new neural module category `nn.rnn`, including:
-
-   `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
+ - Introduced **Backward Fusion** for CPU execution:
+   - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
+   - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
+   - Uses heuristic thresholds to avoid fusion overhead on small tensors
+   - Disabled by default on GPU paths to ensure stable performance

  ## 🔧 How to Install

{lucid_dl-2.9.0 → lucid_dl-2.10.0}/README.md
@@ -5,7 +5,7 @@
  ![PyPI - Total Downloads](https://img.shields.io/badge/total%20downloads-34.0k-yellow.svg)
  ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/ChanLumerico/lucid.svg)
  ![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)
- ![Lines of Code](https://img.shields.io/badge/lines%20of%20code-26.9k-purple.svg)
+ ![Lines of Code](https://img.shields.io/badge/lines%20of%20code-27.7k-purple.svg)

  **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.

@@ -22,9 +22,11 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format

- - Added new neural module category `nn.rnn`, including:
-
-   `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
+ - Introduced **Backward Fusion** for CPU execution:
+   - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
+   - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
+   - Uses heuristic thresholds to avoid fusion overhead on small tensors
+   - Disabled by default on GPU paths to ensure stable performance

  ## 🔧 How to Install

{lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/__init__.py
@@ -50,6 +50,8 @@ import lucid.einops as einops
  import lucid.nn as nn
  import lucid.types as types

+ from lucid._fusion import ENABLE_FUSION
+

  _grad_enabled: bool = True
  _flops_enabled: bool = False
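
The `ENABLE_FUSION` flag imported above gates the backward-fusion pass described in the release notes. To make the idea concrete, here is a toy, self-contained illustration of identity fusion; none of these names are lucid's, and it only shows why a `log∘exp` pair can backpropagate an unchanged gradient:

```python
import math

class Node:
    def __init__(self, value, parents=(), backward=None, kind="leaf"):
        self.value = value
        self.parents = parents
        self.backward = backward  # maps upstream grad -> grads w.r.t. parents
        self.kind = kind

def exp(x):
    out = math.exp(x.value)
    return Node(out, (x,), lambda g: (g * out,), kind="exp")

def log(x):
    return Node(math.log(x.value), (x,), lambda g: (g / x.value,), kind="log")

def fuse_log_exp(node):
    # log(exp(x)) contributes an identity gradient, so both closures can be
    # collapsed into a single pass-through node, analogous to the
    # "identity/unary fusion" the release notes describe.
    if node.kind == "log" and node.parents and node.parents[0].kind == "exp":
        inner = node.parents[0].parents[0]
        return Node(node.value, (inner,), lambda g: (g,), kind="fused_identity")
    return node

x = Node(2.0)
y = fuse_log_exp(log(exp(x)))
assert y.backward(1.0) == (1.0,)  # gradient passes through unchanged
```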
{lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/conv.py
@@ -8,10 +8,10 @@ import numpy as np

  from lucid._tensor import Tensor
  from lucid._backend.core import (
-     operation,
+     Operation,
      binary_func_op,
      _FuncOpReturnType,
-     _GradFuncType,
+     _GradType,
  )
  from lucid._backend.metal import mx

@@ -451,7 +451,7 @@ def _conv_backward_input(
      return grad_input


- class conv_nd(operation):
+ class conv_nd(Operation):
      def __init__(
          self,
          stride: int | tuple[int, ...] | list[int],
@@ -499,7 +499,7 @@ class conv_nd(operation):
          self.result = Tensor(out)
          return self.result, partial(self.__grad__, a=a, b=b, lib_=mx)

-     def __grad__(self, a: Tensor, b: Tensor, lib_: ModuleType) -> _GradFuncType:
+     def __grad__(self, a: Tensor, b: Tensor, lib_: ModuleType) -> _GradType:
          stride = self._stride
          padding = self._padding
          dilation = self._dilation
@@ -519,6 +519,25 @@ class conv_nd(operation):

      return grad_input, grad_weight

+     def __flops__(self, a: Tensor, b: Tensor) -> int:
+         stride = self._stride
+         padding = self._padding
+         dilation = self._dilation
+         if stride is None or padding is None or dilation is None:
+             stride, padding, dilation = self._normalize(b)
+
+         N = int(a.shape[0])
+         C_out = int(b.shape[0])
+         C_in_g = int(b.shape[1])
+         kernel_size = tuple(int(v) for v in b.shape[2:])
+         out_dims = _conv_out_dims(
+             tuple(int(v) for v in a.shape[2:]), kernel_size, stride, padding, dilation
+         )
+
+         macs_per_out = C_in_g * _prod(kernel_size)
+         out_elems = N * C_out * _prod(tuple(out_dims))
+         return out_elems * macs_per_out
+

  def conv_nd_op(
      stride: int | tuple[int, ...] | list[int],
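
The new `__flops__` hook above charges one multiply-accumulate per kernel element per output element (it counts MACs, not separate mul and add ops). A quick sanity check of that formula in plain Python, with hypothetical shapes and `groups = 1` so that `C_in_g` equals the full input channel count:

```python
from math import prod

# Hypothetical conv: batch 1, 64 -> 128 channels, 3x3 kernel, 56x56 output
# (e.g. stride 1, padding 1 on a 56x56 input).
N, C_in, C_out = 1, 64, 128
kernel_size = (3, 3)
out_dims = (56, 56)

macs_per_out = C_in * prod(kernel_size)  # 64 * 9 = 576
out_elems = N * C_out * prod(out_dims)   # 128 * 3136 = 401_408
print(out_elems * macs_per_out)          # 231_211_008, about 0.23 GMACs
```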
{lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/core.py
@@ -4,18 +4,24 @@ import functools
  import weakref

  import lucid
- import lucid.types as types
- from lucid.types import _DeviceType, _NumPyArray, _MLXArray, _BuiltinNumeric
+ from lucid.types import (
+     Numeric,
+     _DeviceType,
+     _NumPyArray,
+     _MLXArray,
+     _BuiltinNumeric,
+     _TensorLike,
+ )

- from lucid._tensor import Tensor
  from lucid._backend.metal import is_gpu_op


- _GradFuncType = Callable[[None], Tuple[_NumPyArray | _MLXArray, ...]]
+ _GradType = _NumPyArray | _MLXArray | Tuple[_NumPyArray | _MLXArray, ...]
+ _GradFuncType = Callable[[], _GradType]

- _ReturnGradFuncPair = Tuple[Tensor, _GradFuncType]

- _FuncOpReturnType = Tuple[_ReturnGradFuncPair, ...]
+ _ReturnGradFuncPair = Tuple[_TensorLike, _GradFuncType]
+ _FuncOpReturnType = _ReturnGradFuncPair | Tuple[_ReturnGradFuncPair, ...]


  def func_op(
@@ -24,13 +30,13 @@ def func_op(
      has_gradient: bool = True,
      device: _DeviceType = "cpu",
  ) -> Callable:
-     def decorator(func: Callable[..., _FuncOpReturnType]) -> Callable:
-         @functools.wraps(func)
-         def wrapper(op_self: operation, *args, **kwargs) -> Tuple[Tensor, ...]:
-             tensors: Tuple[Tensor, ...] = tuple()
+     def decorator(forward_func: Callable[..., _FuncOpReturnType]) -> Callable:
+         @functools.wraps(forward_func)
+         def wrapper(op_self: Operation, *args, **kwargs) -> Tuple[_TensorLike, ...]:
+             tensors: Tuple[_TensorLike, ...] = tuple()
              requires_grad = False
              is_free = True
-             dtype_hint: _BuiltinNumeric | types.Numeric | None = None
+             dtype_hint: _BuiltinNumeric | Numeric | None = None

              if n_in is None:
                  tensor_args = args
@@ -42,7 +48,7 @@ def func_op(
                  tensor_args = args[:n_in]

              for arg in tensor_args:
-                 if isinstance(arg, Tensor):
+                 if isinstance(arg, _TensorLike):
                      dtype_hint = arg.dtype
                      break

@@ -64,7 +70,7 @@ def func_op(

              non_tensor_args = args[n_in:] if n_in is not None else ()
              new_args = (*tensors, *non_tensor_args)
-             func_return_pairs = func(op_self, *new_args, **kwargs)
+             func_return_pairs = forward_func(op_self, *new_args, **kwargs)

              tensor_refs = tuple(weakref.ref(t) for t in tensors)

@@ -78,7 +84,7 @@ def func_op(
              if n_ret is None:
                  if not isinstance(func_return_pairs, tuple):
                      raise ValueError(
-                         f"{func.__name__} should return multiple '_ReturnGradFuncPair'."
+                         f"{forward_func.__name__} should return multiple '_ReturnGradFuncPair'."
                      )
                  num_returns = len(func_return_pairs)
              else:
@@ -87,45 +93,27 @@ def func_op(
              if num_returns == 1:
                  func_return_pairs: _FuncOpReturnType = (func_return_pairs,)

-             results: Tuple[Tensor, ...] = tuple()
-             for result, compute_grad in func_return_pairs:
+             results: Tuple[_TensorLike, ...] = tuple()
+             for result, grad_func in func_return_pairs:
                  result.requires_grad = requires_grad and has_gradient and grad_enabled
-                 if track_graph:
-                     result._op = op_self
                  result.to(device)
-                 if is_free:
-                     result.free()
-
+                 result.free() if is_free else ...
                  results += (result,)
+
                  if not track_graph:
                      continue
-
-                 def _backward_op(
-                     *, _func: Callable = compute_grad, _tensor_refs=tensor_refs
-                 ) -> None:
-                     grads = _func()
-                     if n_in == 1 or not isinstance(grads, tuple):
-                         grads = (grads,)
-
-                     live_tensors = tuple(ref() for ref in _tensor_refs)
-                     if any(t is None for t in live_tensors):
-                         return
-
-                     if len(grads) != len(live_tensors):
-                         raise ValueError(
-                             f"Expected {len(live_tensors)} gradients, got {len(grads)}."
-                         )
-
-                     for tensor, grad in zip(live_tensors, grads):
-                         new_grad = lucid._match_grad_shape(
-                             tensor.data, grad, device=device
-                         )
-                         lucid._set_tensor_grad(tensor, new_grad)
+                 result._op = op_self

                  if result.requires_grad or lucid.flops_enabled():
                      result._prev = list(tensors)
-                 result._backward_op = (
-                     _backward_op if result.requires_grad else lambda: None
+                 if not result.requires_grad:
+                     continue
+
+                 result._backward_op = BackwardOperation(
+                     forward_op_ref=weakref.ref(op_self),
+                     grad_func=grad_func,
+                     tensor_refs=tensor_refs,
+                     device=device,
                  )

              if track_graph:
@@ -161,11 +149,11 @@ def poly_func_op(has_gradient: bool = True, device: _DeviceType = "cpu") -> Call
      return func_op(None, 1, has_gradient=has_gradient, device=device)


- class operation(ABC):
+ class Operation(ABC):
      __fallback__: ClassVar[bool] = False

      def __init__(self) -> None:
-         self.result: Tensor | tuple[Tensor, ...] | None = None
+         self.result: _TensorLike | tuple[_TensorLike, ...] | None = None
          self._flops: int | None = None

      def clear(self) -> None:
@@ -177,11 +165,11 @@ class operation(ABC):
      @abstractmethod
      def gpu(self, *args, **kwargs) -> _FuncOpReturnType: ...

-     def __grad__(self, *args, **kwargs) -> _GradFuncType: ...
+     def __grad__(self, *args, **kwargs) -> _GradType: ...

-     def __grad_cpu__(self, *args, **kwargs) -> _GradFuncType: ...
+     def __grad_cpu__(self, *args, **kwargs) -> _GradType: ...

-     def __grad_gpu__(self, *args, **kwargs) -> _GradFuncType: ...
+     def __grad_gpu__(self, *args, **kwargs) -> _GradType: ...

      @property
      def flops(self) -> int:
@@ -196,12 +184,76 @@ class operation(ABC):
      def __flops__(self, *args, **kwargs) -> int:
          return 0

-     def __call__(self, *args, **kwargs) -> Tensor | tuple[Tensor, ...]:
+     def __call__(self, *args, **kwargs) -> _TensorLike | tuple[_TensorLike, ...]:
          if is_gpu_op(*args):
              return self.gpu(*args, **kwargs)
          return self.cpu(*args, **kwargs)


- def fallback(cls: type[operation]) -> type[operation]:
+ def fallback(cls: type[Operation]) -> type[Operation]:
      cls.__fallback__ = True
      return cls
+
+
+ class BackwardOperation:
+     def __init__(
+         self,
+         forward_op_ref: weakref.ref[Operation] | None,
+         grad_func: _GradFuncType | None,
+         tensor_refs: tuple[weakref.ref[_TensorLike]],
+         device: _DeviceType | None = "cpu",
+         custom_closure: Callable[[], None] | None = None,
+     ) -> None:
+         self.forward_op_ref = forward_op_ref
+         self.grad_func = grad_func
+         self.tensor_refs = tensor_refs
+         self.device = device
+
+         self.custom_closure = custom_closure
+         self.num_inputs = len(tensor_refs)
+
+         if self.grad_func is None and self.custom_closure is None:
+             raise ValueError("Either 'grad_func' or 'custom_closure' must be provided.")
+
+     def override_grad_func(self, new_grad_func: _GradFuncType) -> None:
+         if self.custom_closure is not None:
+             return
+         self.grad_func = new_grad_func
+
+     def override_tensor_refs(
+         self, new_tensor_refs: tuple[weakref.ref[_TensorLike]]
+     ) -> None:
+         self.tensor_refs = new_tensor_refs
+         self.num_inputs = len(new_tensor_refs)
+
+     def __call__(self) -> None:
+         if self.custom_closure is not None:
+             self.custom_closure()
+             return
+
+         if self.device is None and self.forward_op_ref is not None:
+             raise RuntimeError(
+                 "Only 'noop' BackwardOperation can be called without device."
+             )
+
+         grads = self.grad_func()
+         if self.num_inputs == 1 or not isinstance(grads, tuple):
+             grads = (grads,)
+
+         live_tensors = tuple(ref() for ref in self.tensor_refs)
+         if any(t is None for t in live_tensors):
+             return
+
+         if len(grads) != len(live_tensors):
+             raise ValueError(
+                 f"Expected {len(live_tensors)} gradients, got {len(grads)}."
+             )
+
+         for tensor, grad in zip(live_tensors, grads):
+             new_grad = lucid._match_grad_shape(tensor.data, grad, device=self.device)
+             lucid._set_tensor_grad(tensor, new_grad)
+
+
+ noop = BackwardOperation(
+     forward_op_ref=None, grad_func=lambda: (), tensor_refs=(), device=None
+ )
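
With this refactor, the backward step becomes a first-class `BackwardOperation` object rather than an anonymous closure, which is what lets the new `lucid._fusion` pass rewrite gradients in place via `override_grad_func` / `override_tensor_refs`. A toy re-creation of that pattern, standalone and not using lucid's real `Tensor`:

```python
import weakref

class Leaf:
    def __init__(self, data):
        self.data = data
        self.grad = 0.0

class ToyBackwardOp:
    """Backward step as an object, mirroring BackwardOperation's shape."""

    def __init__(self, grad_func, tensor_refs):
        self.grad_func = grad_func
        self.tensor_refs = tensor_refs

    def override_grad_func(self, new_grad_func):
        # The hook a fusion pass would use to splice in a cheaper gradient.
        self.grad_func = new_grad_func

    def __call__(self):
        grads = self.grad_func()
        for ref, g in zip(self.tensor_refs, grads):
            t = ref()
            if t is not None:  # the input may already have been freed
                t.grad += g

x = Leaf(3.0)
op = ToyBackwardOp(lambda: (2.0 * x.data,), (weakref.ref(x),))  # d(x**2)/dx
op()                                    # x.grad == 6.0
op.override_grad_func(lambda: (1.0,))   # pretend fusion replaced the closure
op()                                    # x.grad == 7.0
```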
{lucid_dl-2.9.0 → lucid_dl-2.10.0}/lucid/_backend/pool.py
@@ -7,10 +7,10 @@ import numpy as np

  from lucid._tensor import Tensor
  from lucid._backend.core import (
-     operation,
+     Operation,
      unary_func_op,
      _FuncOpReturnType,
-     _GradFuncType,
+     _GradType,
  )
  from lucid._backend.metal import mx
  from lucid.types import _NumPyArray, _MLXArray
@@ -211,7 +211,7 @@ def _pool_backward_max(
      return _crop_padding(grad_input_pad, padding)


- class pool_nd(operation):
+ class pool_nd(Operation):
      def __init__(
          self,
          kernel_size: int | tuple[int, ...] | list[int],
@@ -295,7 +295,7 @@ class pool_nd(operation):
          self.result = Tensor(out)
          return self.result, partial(self.__grad__, lib_=mx)

-     def __grad__(self, lib_: ModuleType) -> _GradFuncType:
+     def __grad__(self, lib_: ModuleType) -> _GradType:
          if (
              self._kernel_size is None
              or self._stride is None
@@ -333,6 +333,24 @@ class pool_nd(operation):
          )
          return grad_input

+     def __flops__(self, a: Tensor) -> int:
+         if self._kernel_size is None or self._out_dims is None:
+             kernel, stride, padding = self._normalize(a)
+             out_dims = _pool_out_dims(a.shape[2:], kernel, stride, padding)
+         else:
+             kernel = self._kernel_size
+             out_dims = self._out_dims
+
+         kernel_elems = _prod(kernel)
+         out_elems = int(a.shape[0]) * int(a.shape[1]) * _prod(out_dims)
+
+         if kernel_elems <= 0 or out_elems <= 0:
+             return 0
+
+         if self.mode == "avg":
+             return out_elems * kernel_elems
+         return out_elems * max(kernel_elems - 1, 0)
+

  def avg_pool_nd_op(
      kernel_size: int | tuple[int, ...] | list[int],
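
The pooling cost model above counts `kernel_elems` multiply-adds per output element for average pooling and `kernel_elems - 1` comparisons per output element for max pooling. Checking both branches on a hypothetical 2x2 pooling, stride 2, over a (N=1, C=64, 112, 112) input:

```python
from math import prod

# 2x2 pool with stride 2 on a 112x112 feature map gives a 56x56 output.
kernel = (2, 2)
out_dims = (56, 56)
out_elems = 1 * 64 * prod(out_dims)         # 200_704 output elements

avg_flops = out_elems * prod(kernel)        # 4 adds per output -> 802_816
max_flops = out_elems * (prod(kernel) - 1)  # 3 comparisons per output -> 602_112
print(avg_flops, max_flops)
```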