lucid-dl 2.8.5.tar.gz → 2.10.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/PKG-INFO +8 -6
  2. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/README.md +7 -5
  3. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/__init__.py +44 -5
  4. lucid_dl-2.10.0/lucid/_backend/conv.py +548 -0
  5. lucid_dl-2.10.0/lucid/_backend/core.py +259 -0
  6. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_backend/metal.py +38 -1
  7. lucid_dl-2.10.0/lucid/_backend/pool.py +368 -0
  8. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/bfunc.py +45 -45
  9. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/ufunc.py +79 -79
  10. lucid_dl-2.10.0/lucid/_fusion/__init__.py +4 -0
  11. lucid_dl-2.10.0/lucid/_fusion/base.py +120 -0
  12. lucid_dl-2.10.0/lucid/_fusion/func.py +80 -0
  13. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/tensor.py +102 -15
  14. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_util/func.py +62 -63
  15. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/einops/_func.py +10 -10
  16. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/linalg/_func.py +29 -29
  17. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_conv.py +15 -62
  18. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_norm.py +4 -4
  19. lucid_dl-2.10.0/lucid/nn/functional/_pool.py +141 -0
  20. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/fused.py +1 -1
  21. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/init/_dist.py +25 -11
  22. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/norm.py +3 -3
  23. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/parameter.py +12 -2
  24. lucid_dl-2.10.0/lucid/optim/ada.py +189 -0
  25. lucid_dl-2.10.0/lucid/optim/adam.py +317 -0
  26. lucid_dl-2.10.0/lucid/optim/prop.py +156 -0
  27. lucid_dl-2.10.0/lucid/optim/sgd.py +147 -0
  28. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/types.py +27 -1
  29. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/PKG-INFO +8 -6
  30. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/SOURCES.txt +5 -0
  31. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/setup.py +1 -1
  32. lucid_dl-2.8.5/lucid/_backend/core.py +0 -170
  33. lucid_dl-2.8.5/lucid/nn/functional/_pool.py +0 -275
  34. lucid_dl-2.8.5/lucid/optim/ada.py +0 -179
  35. lucid_dl-2.8.5/lucid/optim/adam.py +0 -304
  36. lucid_dl-2.8.5/lucid/optim/prop.py +0 -144
  37. lucid_dl-2.8.5/lucid/optim/sgd.py +0 -139
  38. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/LICENSE +0 -0
  39. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_backend/__init__.py +0 -0
  40. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/__init__.py +0 -0
  41. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/gfunc.py +0 -0
  42. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/__init__.py +0 -0
  43. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/tensor_ops.py +0 -0
  44. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_util/__init__.py +0 -0
  45. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/__init__.py +0 -0
  46. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/_base.py +0 -0
  47. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/_util.py +0 -0
  48. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/__init__.py +0 -0
  49. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/_base.py +0 -0
  50. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/cifar.py +0 -0
  51. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/mnist.py +0 -0
  52. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/einops/__init__.py +0 -0
  53. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/error.py +0 -0
  54. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/linalg/__init__.py +0 -0
  55. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/__init__.py +0 -0
  56. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/__init__.py +0 -0
  57. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/alex.py +0 -0
  58. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/coatnet.py +0 -0
  59. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/convnext.py +0 -0
  60. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/crossvit.py +0 -0
  61. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/cspnet.py +0 -0
  62. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/cvt.py +0 -0
  63. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/dense.py +0 -0
  64. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/efficient.py +0 -0
  65. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/efficientformer.py +0 -0
  66. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception.py +0 -0
  67. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_next.py +0 -0
  68. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_res.py +0 -0
  69. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/lenet.py +0 -0
  70. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/maxvit.py +0 -0
  71. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/mobile.py +0 -0
  72. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/pvt.py +0 -0
  73. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnest.py +0 -0
  74. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnet.py +0 -0
  75. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnext.py +0 -0
  76. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/senet.py +0 -0
  77. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/sknet.py +0 -0
  78. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/swin.py +0 -0
  79. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/vgg.py +0 -0
  80. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/vit.py +0 -0
  81. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/xception.py +0 -0
  82. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/zfnet.py +0 -0
  83. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/__init__.py +0 -0
  84. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/ddpm.py +0 -0
  85. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/vae.py +0 -0
  86. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/__init__.py +0 -0
  87. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/detr.py +0 -0
  88. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/efficientdet.py +0 -0
  89. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/fast_rcnn.py +0 -0
  90. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/faster_rcnn.py +0 -0
  91. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/rcnn.py +0 -0
  92. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/util.py +0 -0
  93. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/__init__.py +0 -0
  94. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  95. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  96. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  97. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  98. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/seq2seq/__init__.py +0 -0
  99. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/seq2seq/transformer.py +0 -0
  100. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/util.py +0 -0
  101. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/__init__.py +0 -0
  102. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/__init__.py +0 -0
  103. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_activation.py +0 -0
  104. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_attention.py +0 -0
  105. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_drop.py +0 -0
  106. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_linear.py +0 -0
  107. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_loss.py +0 -0
  108. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_spatial.py +0 -0
  109. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_util.py +0 -0
  110. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/init/__init__.py +0 -0
  111. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/module.py +0 -0
  112. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/__init__.py +0 -0
  113. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/activation.py +0 -0
  114. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/attention.py +0 -0
  115. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/conv.py +0 -0
  116. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/drop.py +0 -0
  117. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/einops.py +0 -0
  118. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/linear.py +0 -0
  119. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/loss.py +0 -0
  120. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/pool.py +0 -0
  121. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/rnn.py +0 -0
  122. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/sparse.py +0 -0
  123. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/transformer.py +0 -0
  124. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/vision.py +0 -0
  125. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/util.py +0 -0
  126. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/__init__.py +0 -0
  127. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/_base.py +0 -0
  128. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/__init__.py +0 -0
  129. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_base.py +0 -0
  130. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  131. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/port.py +0 -0
  132. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/random/__init__.py +0 -0
  133. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/random/_func.py +0 -0
  134. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/__init__.py +0 -0
  135. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/_base.py +0 -0
  136. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/image.py +0 -0
  137. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/visual/__init__.py +0 -0
  138. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/visual/graph.py +0 -0
  139. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/weights/__init__.py +0 -0
  140. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/weights/__init__.pyi +0 -0
  141. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/dependency_links.txt +0 -0
  142. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/requires.txt +0 -0
  143. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/top_level.txt +0 -0
  144. {lucid_dl-2.8.5 → lucid_dl-2.10.0}/setup.cfg +0 -0
{lucid_dl-2.8.5 → lucid_dl-2.10.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lucid-dl
-Version: 2.8.5
+Version: 2.10.0
 Summary: Lumerico's Comprehensive Interface for Deep Learning
 Home-page: https://github.com/ChanLumerico/lucid
 Author: ChanLumerico
@@ -33,7 +33,7 @@ Dynamic: summary
 ![PyPI - Total Downloads](https://img.shields.io/badge/total%20downloads-34.0k-yellow.svg)
 ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/ChanLumerico/lucid.svg)
 ![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)
-![Lines of Code](https://img.shields.io/badge/lines%20of%20code-26.9k-purple.svg)
+![Lines of Code](https://img.shields.io/badge/lines%20of%20code-27.7k-purple.svg)
 
 **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
 
@@ -44,15 +44,17 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
 #### Other Languages
 
-[🇰🇷 README.md in Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
+[🇰🇷 Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
 
 ### 🔥 What's New
 
 - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
 
-- Added new neural module category `nn.rnn`, including:
-
-  `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
+- Introduced **Backward Fusion** for CPU execution:
+  - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
+  - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
+  - Uses heuristic thresholds to avoid fusion overhead on small tensors
+  - Disabled by default on GPU paths to ensure stable performance
 
 ## 🔧 How to Install
 
{lucid_dl-2.8.5 → lucid_dl-2.10.0}/README.md
@@ -5,7 +5,7 @@
 ![PyPI - Total Downloads](https://img.shields.io/badge/total%20downloads-34.0k-yellow.svg)
 ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/ChanLumerico/lucid.svg)
 ![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)
-![Lines of Code](https://img.shields.io/badge/lines%20of%20code-26.9k-purple.svg)
+![Lines of Code](https://img.shields.io/badge/lines%20of%20code-27.7k-purple.svg)
 
 **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
 
@@ -16,15 +16,17 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
 #### Other Languages
 
-[🇰🇷 README.md in Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
+[🇰🇷 Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
 
 ### 🔥 What's New
 
 - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
 
-- Added new neural module category `nn.rnn`, including:
-
-  `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
+- Introduced **Backward Fusion** for CPU execution:
+  - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
+  - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
+  - Uses heuristic thresholds to avoid fusion overhead on small tensors
+  - Disabled by default on GPU paths to ensure stable performance
 
 ## 🔧 How to Install
 
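A note on the Backward Fusion entry above: the idea is that when adjacent nodes in the autodiff graph compose to something trivial (the README cites `log∘exp`, double negation, and view-like ops such as reshape/squeeze), the backward pass can skip the intermediate gradient computations entirely. The sketch below is a toy illustration of why fusing `log∘exp` is a pure win; it is not Lucid's implementation, which lives in the new `lucid/_fusion` package and appears to be gated by the `ENABLE_FUSION` flag imported in `lucid/__init__.py` in the next hunk (its exact interface is not shown in this diff).

```python
# Toy illustration only: why a fused backward for log(exp(x)) saves work.
# The real pattern matching lives in lucid/_fusion and differs from this.
import numpy as np

def unfused_backward(x, upstream):
    # Backprop through log(exp(x)) op by op: two intermediate gradients.
    y = np.exp(x)              # forward intermediate kept for backward
    grad_log = upstream / y    # gradient through log
    grad_exp = grad_log * y    # gradient through exp
    return grad_exp

def fused_backward(x, upstream):
    # log∘exp is the identity, so the fused backward is a passthrough:
    # no intermediate arrays, no extra graph nodes.
    return upstream

x = np.random.randn(4)
g = np.ones_like(x)
assert np.allclose(unfused_backward(x, g), fused_backward(x, g))
```

The "heuristic thresholds" bullet fits this picture: for very small tensors the pattern matching itself can cost more than the skipped arithmetic, so fusion is presumably only applied above a size cutoff.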
{lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/__init__.py
@@ -15,7 +15,7 @@ algorithms and operations without the complexity of high-level frameworks.
 
 from contextlib import contextmanager, AbstractContextManager
 from typing import Any, Generator, SupportsIndex, Callable, Self, Optional, Type
-from types import TracebackType
+from types import TracebackType, ModuleType
 from functools import wraps
 from pathlib import Path
 
@@ -50,6 +50,8 @@ import lucid.einops as einops
 import lucid.nn as nn
 import lucid.types as types
 
+from lucid._fusion import ENABLE_FUSION
+
 
 _grad_enabled: bool = True
 _flops_enabled: bool = False
@@ -177,11 +179,18 @@ def _set_tensor_grad(
 
 
 def _check_is_tensor(
-    any: Tensor | _ArrayOrScalar, device: _DeviceType = "cpu"
+    any: Tensor | _ArrayOrScalar,
+    device: _DeviceType = "cpu",
+    dtype: _BuiltinNumeric | Numeric | None = None,
 ) -> Tensor:
-    if not isinstance(any, Tensor):
-        return Tensor(any, device=device)
-    return any
+    if isinstance(any, Tensor):
+        return any
+
+    is_scalar = not isinstance(any, (_NumPyArray, _MLXArray, list, tuple))
+    if dtype is not None and is_scalar:
+        return Tensor(any, device=device, dtype=dtype)
+
+    return Tensor(any, device=device)
 
 
 def _match_grad_shape(
@@ -293,3 +302,33 @@ def register_model(func: _ModuleReturnFunc) -> _ModuleReturnFunc:
         return model
 
     return wrapper
+
+
+def _conv_view_limit_mb() -> int:
+    from lucid._backend import conv as _conv_backend
+
+    return _conv_backend.get_conv_view_limit_mb()
+
+
+def __getattr__(name: str) -> Any:
+    if name == "CONV_VIEW_LIMIT_MB":
+        return _conv_view_limit_mb()
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+def __dir__() -> list[str]:
+    return sorted(list(globals().keys()) + ["CONV_VIEW_LIMIT_MB"])
+
+
+class _LucidModule(ModuleType):
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "CONV_VIEW_LIMIT_MB":
+            raise AttributeError(
+                "CONV_VIEW_LIMIT_MB is read-only; set LUCID_CONV_VIEW_LIMIT_MB "
+                "before importing lucid."
+            )
+        super().__setattr__(name, value)
+
+
+if not isinstance(sys.modules[__name__], _LucidModule):
+    sys.modules[__name__].__class__ = _LucidModule
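The tail of this hunk turns `lucid.CONV_VIEW_LIMIT_MB` into a lazily computed, read-only module attribute: the module-level `__getattr__` forwards to the new conv backend, and the custom `_LucidModule` class rejects assignment, pointing at the `LUCID_CONV_VIEW_LIMIT_MB` environment variable as the supported override. A minimal usage sketch under that reading (the value 128 is purely illustrative):

```python
# Sketch of how the read-only limit appears to be meant to be used.
import os

# The override must be set before the conv backend is first imported,
# since the limit is resolved once at import time.
os.environ["LUCID_CONV_VIEW_LIMIT_MB"] = "128"

import lucid

print(lucid.CONV_VIEW_LIMIT_MB)  # -> 128, served by the module-level __getattr__

try:
    lucid.CONV_VIEW_LIMIT_MB = 512   # rejected by _LucidModule.__setattr__
except AttributeError as err:
    print(err)                       # points back to the environment variable
```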
lucid_dl-2.10.0/lucid/_backend/conv.py (new file)
@@ -0,0 +1,548 @@
+from functools import partial
+from types import ModuleType
+from typing import TypeAlias
+import itertools
+import os
+
+import numpy as np
+
+from lucid._tensor import Tensor
+from lucid._backend.core import (
+    Operation,
+    binary_func_op,
+    _FuncOpReturnType,
+    _GradType,
+)
+from lucid._backend.metal import mx
+
+from lucid.types import _NumPyArray, _MLXArray
+
+
+_Array: TypeAlias = _NumPyArray | _MLXArray
+_Shape: TypeAlias = tuple[int, ...]
+_Stride: TypeAlias = tuple[int, ...]
+_Padding: TypeAlias = tuple[int, ...]
+_Dilation: TypeAlias = tuple[int, ...]
+
+
+def _load_view_limit_bytes() -> int:
+    env = os.getenv("LUCID_CONV_VIEW_LIMIT_MB")
+    if env is None:
+        return _default_view_limit_bytes()
+    try:
+        value = int(env)
+    except ValueError:
+        return _default_view_limit_bytes()
+    return value * 1024 * 1024
+
+
+def _sysconf_value(name: str) -> int | None:
+    try:
+        value = int(os.sysconf(name))
+    except (ValueError, AttributeError, OSError):
+        return None
+    if value <= 0:
+        return None
+    return value
+
+
+def _get_total_memory_bytes() -> int | None:
+    page_size = _sysconf_value("SC_PAGE_SIZE") or _sysconf_value("SC_PAGESIZE")
+    phys_pages = _sysconf_value("SC_PHYS_PAGES")
+    if page_size and phys_pages:
+        return page_size * phys_pages
+    try:
+        import ctypes
+
+        class MEMORYSTATUSEX(ctypes.Structure):
+            _fields_ = [
+                ("dwLength", ctypes.c_ulong),
+                ("dwMemoryLoad", ctypes.c_ulong),
+                ("ullTotalPhys", ctypes.c_ulonglong),
+                ("ullAvailPhys", ctypes.c_ulonglong),
+                ("ullTotalPageFile", ctypes.c_ulonglong),
+                ("ullAvailPageFile", ctypes.c_ulonglong),
+                ("ullTotalVirtual", ctypes.c_ulonglong),
+                ("ullAvailVirtual", ctypes.c_ulonglong),
+                ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
+            ]
+
+        stat = MEMORYSTATUSEX()
+        stat.dwLength = ctypes.sizeof(MEMORYSTATUSEX)
+        if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(stat)):
+            return int(stat.ullTotalPhys)
+
+    except Exception:
+        return None
+
+
+def _round_to_step(value: int, step: int) -> int:
+    return ((value + step // 2) // step) * step
+
+
+def _default_view_limit_bytes() -> int:
+    total = _get_total_memory_bytes()
+    if not total:
+        return 256 * 1024 * 1024
+
+    mb = 1024 * 1024
+    min_bytes = 64 * mb
+    max_bytes = 1024 * mb
+    step = 64 * mb
+
+    target = (total * 15) // 1000
+    target = max(min_bytes, min(max_bytes, target))
+    target = _round_to_step(target, step)
+    return max(min_bytes, min(max_bytes, target))
+
+
+_CONV_VIEW_LIMIT_BYTES = _load_view_limit_bytes()
+
+
+def get_conv_view_limit_mb() -> int:
+    return int(_CONV_VIEW_LIMIT_BYTES // (1024 * 1024))
+
+
+def _dtype_itemsize(data: _Array) -> int:
+    dtype = getattr(data, "dtype", None)
+    if dtype is None:
+        return 0
+    try:
+        return int(np.dtype(dtype).itemsize)
+    except TypeError:
+        return int(getattr(dtype, "size", 0) or 0)
+
+
+def _prod(shape: _Shape) -> int:
+    total = 1
+    for v in shape:
+        total *= int(v)
+    return total
+
+
+def _view_exceeds_limit(data: _Array, out_dims: _Shape, kernel_size: _Shape) -> bool:
+    if _CONV_VIEW_LIMIT_BYTES == 0:
+        return True
+    if _CONV_VIEW_LIMIT_BYTES < 0:
+        return False
+    itemsize = _dtype_itemsize(data)
+    if itemsize == 0:
+        return False
+
+    view_elems = data.shape[0] * data.shape[1] * _prod(out_dims) * _prod(kernel_size)
+    view_bytes = view_elems * itemsize
+
+    return view_bytes > _CONV_VIEW_LIMIT_BYTES
+
+
+def _to_tuple(value: int | tuple[int, ...] | list[int], dim: int, name: str) -> _Shape:
+    if isinstance(value, int):
+        return (value,) * dim
+
+    if isinstance(value, (tuple, list)):
+        if len(value) == 1:
+            return (int(value[0]),) * dim
+        if len(value) != dim:
+            raise ValueError(f"{name} must have length {dim}, got {len(value)}.")
+        return tuple(int(v) for v in value)
+
+    raise TypeError(f"{name} must be int or sequence, got {type(value).__name__}.")
+
+
+def _conv_out_dims(
+    input_spatial: _Shape,
+    kernel_size: _Shape,
+    stride: _Stride,
+    padding: _Padding,
+    dilation: _Dilation,
+) -> list[int]:
+    out_dims = []
+    for i in range(len(kernel_size)):
+        eff = dilation[i] * (kernel_size[i] - 1) + 1
+        o = (input_spatial[i] + 2 * padding[i] - eff) // stride[i] + 1
+        if o <= 0:
+            raise ValueError(f"Non-positive output dim for axis {i}: {o}")
+        out_dims.append(o)
+
+    return out_dims
+
+
+def _validate_conv_shapes(input_: Tensor, weight: Tensor, groups: int) -> None:
+    if input_.ndim != weight.ndim:
+        raise ValueError("Input and weight must have the same number of dimensions.")
+    if input_.ndim < 3:
+        raise ValueError("Input and weight must have at least 3 dimensions.")
+    if groups <= 0:
+        raise ValueError("groups must be a positive integer.")
+
+    C_in = input_.shape[1]
+    C_out = weight.shape[0]
+    C_in_g = weight.shape[1]
+
+    if C_out % groups != 0 or C_in_g * groups != C_in:
+        raise ValueError("Inconsistent channel/group configuration.")
+
+
+def _pad_input(lib_: ModuleType, data: _Array, padding: _Padding) -> _Array:
+    if not any(padding):
+        return data
+
+    pad_width = ((0, 0), (0, 0)) + tuple((p, p) for p in padding)
+    return lib_.pad(data, pad_width)
+
+
+def _as_strided(
+    lib_: ModuleType, data: _Array, shape: _Shape, strides: _Shape
+) -> _Array | None:
+    if lib_ is np:
+        return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+    as_strided = getattr(lib_, "as_strided", None)
+    if as_strided is None:
+        return None
+
+    try:
+        return as_strided(data, shape=shape, strides=strides)
+    except TypeError:
+        return as_strided(data, shape, strides)
+
+
+def _make_input_view(
+    lib_: ModuleType,
+    data: _Array,
+    out_dims: _Shape,
+    kernel_size: _Shape,
+    stride: _Stride,
+    dilation: _Dilation,
+) -> _Array | None:
+    if not hasattr(data, "strides"):
+        return None
+    strides = data.strides
+    if strides is None:
+        return None
+
+    spatial_strides = strides[2:]
+    view_strides = (
+        strides[0],
+        strides[1],
+        *[spatial_strides[i] * stride[i] for i in range(len(kernel_size))],
+        *[spatial_strides[i] * dilation[i] for i in range(len(kernel_size))],
+    )
+    view_shape = (data.shape[0], data.shape[1], *out_dims, *kernel_size)
+
+    return _as_strided(lib_, data, view_shape, view_strides)
+
+
+def _conv_from_view(
+    lib_: ModuleType, x_view: _Array, weight: _Array, out_dims: _Shape, groups: int
+) -> _Array:
+    D = len(out_dims)
+    C_out = weight.shape[0]
+    C_in_g = weight.shape[1]
+    C_out_g = C_out // groups
+
+    axes_x = [1] + list(range(2 + D, 2 + 2 * D))
+    axes_w = [1] + list(range(2, 2 + D))
+    perm = [0, D + 1] + list(range(1, D + 1))
+
+    outputs = []
+    for g in range(groups):
+        x_g = x_view[:, g * C_in_g : (g + 1) * C_in_g, ...]
+        w_g = weight[g * C_out_g : (g + 1) * C_out_g, ...]
+
+        out = lib_.tensordot(x_g, w_g, axes=(axes_x, axes_w))
+        out = lib_.transpose(out, axes=perm)
+        outputs.append(out)
+
+    if len(outputs) == 1:
+        return outputs[0]
+
+    return lib_.concatenate(outputs, axis=1)
+
+
+def _conv_fallback(
+    lib_: ModuleType,
+    input_: _Array,
+    weight: _Array,
+    stride: _Stride,
+    padding: _Padding,
+    dilation: _Dilation,
+    groups: int,
+    out_dims: _Shape,
+) -> _Array:
+    D = len(out_dims)
+    kernel_size = weight.shape[2:]
+    C_out = weight.shape[0]
+    C_in_g = weight.shape[1]
+    C_out_g = C_out // groups
+
+    x = _pad_input(lib_, input_, padding)
+
+    outputs = []
+    for g in range(groups):
+        x_g = x[:, g * C_in_g : (g + 1) * C_in_g]
+        w_g = weight[g * C_out_g : (g + 1) * C_out_g]
+
+        out_g = None
+        for k_idx in itertools.product(*[range(k) for k in kernel_size]):
+            slices = [slice(None), slice(None)]
+
+            for d in range(D):
+                start = k_idx[d] * dilation[d]
+                end = start + stride[d] * out_dims[d]
+                slices.append(slice(start, end, stride[d]))
+
+            x_slice = x_g[tuple(slices)]
+            w_slice = w_g[(slice(None), slice(None)) + k_idx]
+
+            contrib = lib_.tensordot(x_slice, w_slice, axes=([1], [1]))
+            perm = [0, contrib.ndim - 1] + list(range(1, contrib.ndim - 1))
+
+            contrib = lib_.transpose(contrib, axes=perm)
+            out_g = contrib if out_g is None else out_g + contrib
+
+        outputs.append(out_g)
+
+    if len(outputs) == 1:
+        return outputs[0]
+
+    return lib_.concatenate(outputs, axis=1)
+
+
+def _conv_forward(
+    lib_: ModuleType,
+    input_: _Array,
+    weight: _Array,
+    stride: _Stride,
+    padding: _Padding,
+    dilation: _Dilation,
+    groups: int,
+) -> _Array:
+    input_spatial = input_.shape[2:]
+    kernel_size = weight.shape[2:]
+    out_dims = tuple(
+        _conv_out_dims(input_spatial, kernel_size, stride, padding, dilation)
+    )
+
+    if _view_exceeds_limit(input_, out_dims, kernel_size):
+        return _conv_fallback(
+            lib_, input_, weight, stride, padding, dilation, groups, out_dims
+        )
+
+    x = _pad_input(lib_, input_, padding)
+    x_view = _make_input_view(lib_, x, out_dims, kernel_size, stride, dilation)
+    if x_view is None:
+        return _conv_fallback(
+            lib_, input_, weight, stride, padding, dilation, groups, out_dims
+        )
+
+    return _conv_from_view(lib_, x_view, weight, out_dims, groups)
+
+
+def _conv_backward_weight(
+    lib_: ModuleType,
+    grad_out: _Array,
+    x_pad: _Array,
+    weight: _Array,
+    stride: _Stride,
+    dilation: _Dilation,
+    groups: int,
+) -> _Array:
+    weight_shape = weight.shape
+    D = len(weight_shape) - 2
+    out_dims = grad_out.shape[2:]
+    kernel_size = weight.shape[2:]
+    C_out = weight_shape[0]
+    C_in_g = weight_shape[1]
+    C_out_g = C_out // groups
+
+    x_view = _make_input_view(lib_, x_pad, out_dims, kernel_size, stride, dilation)
+    if x_view is not None and _view_exceeds_limit(x_pad, out_dims, kernel_size):
+        x_view = None
+    axes_out = [0] + list(range(2, 2 + D))
+    axes_x = [0] + list(range(2, 2 + D))
+
+    grad_parts = []
+    for g in range(groups):
+        grad_out_g = grad_out[:, g * C_out_g : (g + 1) * C_out_g, ...]
+
+        if x_view is None:
+            x_g = x_pad[:, g * C_in_g : (g + 1) * C_in_g]
+            grad_w = lib_.zeros((C_out_g, C_in_g, *kernel_size), dtype=weight.dtype)
+
+            for k_idx in itertools.product(*[range(k) for k in kernel_size]):
+                slices = [slice(None), slice(None)]
+
+                for d in range(D):
+                    start = k_idx[d] * dilation[d]
+                    end = start + stride[d] * out_dims[d]
+                    slices.append(slice(start, end, stride[d]))
+
+                x_slice = x_g[tuple(slices)]
+                w_grad = lib_.tensordot(grad_out_g, x_slice, axes=(axes_out, axes_x))
+
+                if lib_ is np:
+                    grad_w[(slice(None), slice(None)) + k_idx] = w_grad
+                else:
+                    grad_w = grad_w.at[(slice(None), slice(None)) + k_idx].add(w_grad)
+            grad_parts.append(grad_w)
+
+        else:
+            x_view_g = x_view[:, g * C_in_g : (g + 1) * C_in_g, ...]
+            grad_w = lib_.tensordot(grad_out_g, x_view_g, axes=(axes_out, axes_x))
+            grad_parts.append(grad_w)
+
+    if len(grad_parts) == 1:
+        return grad_parts[0]
+
+    return lib_.concatenate(grad_parts, axis=0)
+
+
+def _conv_backward_input(
+    lib_: ModuleType,
+    grad_out: _Array,
+    weight: _Array,
+    x_pad: _Array,
+    stride: _Stride,
+    padding: _Padding,
+    dilation: _Dilation,
+    groups: int,
+) -> _Array:
+    kernel_size = weight.shape[2:]
+    D = len(kernel_size)
+    out_dims = grad_out.shape[2:]
+
+    C_out = weight.shape[0]
+    C_in_g = weight.shape[1]
+    C_out_g = C_out // groups
+
+    grad_input = lib_.zeros_like(x_pad)
+
+    for g in range(groups):
+        grad_out_g = grad_out[:, g * C_out_g : (g + 1) * C_out_g, ...]
+        w_g = weight[g * C_out_g : (g + 1) * C_out_g]
+        ch_slice = slice(g * C_in_g, (g + 1) * C_in_g)
+
+        for k_idx in itertools.product(*[range(k) for k in kernel_size]):
+            w_slice = w_g[(slice(None), slice(None)) + k_idx]
+            contrib = lib_.tensordot(grad_out_g, w_slice, axes=([1], [0]))
+
+            perm = [0, contrib.ndim - 1] + list(range(1, contrib.ndim - 1))
+            contrib = lib_.transpose(contrib, axes=perm)
+
+            slices = [slice(None), ch_slice]
+            for d in range(D):
+                start = k_idx[d] * dilation[d]
+                end = start + stride[d] * out_dims[d]
+                slices.append(slice(start, end, stride[d]))
+
+            if lib_ is np:
+                grad_input[tuple(slices)] += contrib
+            else:
+                grad_input = grad_input.at[tuple(slices)].add(contrib)
+
+    if any(padding):
+        crop = [slice(None), slice(None)]
+        for p in padding:
+            end = -p if p != 0 else None
+            crop.append(slice(p, end))
+        return grad_input[tuple(crop)]
+
+    return grad_input
+
+
+class conv_nd(Operation):
+    def __init__(
+        self,
+        stride: int | tuple[int, ...] | list[int],
+        padding: int | tuple[int, ...] | list[int],
+        dilation: int | tuple[int, ...] | list[int],
+        groups: int,
+    ) -> None:
+        super().__init__()
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+        self._stride: _Stride | None = None
+        self._padding: _Padding | None = None
+        self._dilation: _Dilation | None = None
+
+    def _normalize(self, weight: Tensor) -> tuple[_Stride, _Padding, _Dilation]:
+        D = weight.ndim - 2
+        stride = _to_tuple(self.stride, D, "stride")
+        padding = _to_tuple(self.padding, D, "padding")
+        dilation = _to_tuple(self.dilation, D, "dilation")
+
+        self._stride = stride
+        self._padding = padding
+        self._dilation = dilation
+
+        return stride, padding, dilation
+
+    @binary_func_op()
+    def cpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
+        _validate_conv_shapes(a, b, self.groups)
+        stride, padding, dilation = self._normalize(b)
+        out = _conv_forward(np, a.data, b.data, stride, padding, dilation, self.groups)
+
+        self.result = Tensor(out)
+        return self.result, partial(self.__grad__, a=a, b=b, lib_=np)
+
+    @binary_func_op(device="gpu")
+    def gpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
+        _validate_conv_shapes(a, b, self.groups)
+        stride, padding, dilation = self._normalize(b)
+        out = _conv_forward(mx, a.data, b.data, stride, padding, dilation, self.groups)
+
+        self.result = Tensor(out)
+        return self.result, partial(self.__grad__, a=a, b=b, lib_=mx)
+
+    def __grad__(self, a: Tensor, b: Tensor, lib_: ModuleType) -> _GradType:
+        stride = self._stride
+        padding = self._padding
+        dilation = self._dilation
+
+        if stride is None or padding is None or dilation is None:
+            raise RuntimeError("conv_nd backward called before forward.")
+
+        x_pad = _pad_input(lib_, a.data, padding)
+        grad_out = self.result.grad
+
+        grad_input = _conv_backward_input(
+            lib_, grad_out, b.data, x_pad, stride, padding, dilation, self.groups
+        )
+        grad_weight = _conv_backward_weight(
+            lib_, grad_out, x_pad, b.data, stride, dilation, self.groups
+        )
+
+        return grad_input, grad_weight
+
+    def __flops__(self, a: Tensor, b: Tensor) -> int:
+        stride = self._stride
+        padding = self._padding
+        dilation = self._dilation
+        if stride is None or padding is None or dilation is None:
+            stride, padding, dilation = self._normalize(b)
+
+        N = int(a.shape[0])
+        C_out = int(b.shape[0])
+        C_in_g = int(b.shape[1])
+        kernel_size = tuple(int(v) for v in b.shape[2:])
+        out_dims = _conv_out_dims(
+            tuple(int(v) for v in a.shape[2:]), kernel_size, stride, padding, dilation
+        )
+
+        macs_per_out = C_in_g * _prod(kernel_size)
+        out_elems = N * C_out * _prod(tuple(out_dims))
+        return out_elems * macs_per_out
+
+
+def conv_nd_op(
+    stride: int | tuple[int, ...] | list[int],
+    padding: int | tuple[int, ...] | list[int],
+    dilation: int | tuple[int, ...] | list[int],
+    groups: int,
+) -> conv_nd:
+    return conv_nd(stride, padding, dilation, groups)
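To make the thresholding in `_view_exceeds_limit` and the output-size arithmetic in `_conv_out_dims` concrete, here is a small standalone sketch that mirrors the same formulas (it reimplements the arithmetic rather than importing the package, and the shapes are made up). For reference, on a machine with 16 GiB of RAM the default limit from `_default_view_limit_bytes` works out to 256 MiB: 1.5% of total memory, rounded to a 64 MiB step and clamped to the 64 MiB to 1 GiB range.

```python
# Standalone mirror of the arithmetic above; not an import of lucid itself.

def conv_out_dim(n: int, k: int, s: int, p: int, d: int) -> int:
    # Same formula as _conv_out_dims, per spatial axis.
    eff = d * (k - 1) + 1                 # effective (dilated) kernel extent
    return (n + 2 * p - eff) // s + 1

# Example: 2D conv, float32 input of shape (N=32, C_in=128, 56, 56),
# 3x3 kernel, stride 1, padding 1, dilation 1.
N, C_in, H, W = 32, 128, 56, 56
kH = kW = 3
oH = conv_out_dim(H, kH, 1, 1, 1)         # 56
oW = conv_out_dim(W, kW, 1, 1, 1)         # 56

# Nominal size of the strided view (N, C_in, oH, oW, kH, kW), as estimated
# in _view_exceeds_limit: element count times itemsize.
view_bytes = N * C_in * oH * oW * kH * kW * 4
print(view_bytes / 2**20)                 # 441.0 MiB
```

With these numbers the 441 MiB view estimate exceeds a 256 MiB limit, so `_conv_forward` would skip `_make_input_view` and take `_conv_fallback`, which loops over kernel offsets and accumulates partial `tensordot` results; that keeps peak memory near the padded input at the cost of more, smaller contractions.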