lucid-dl 2.8.5__tar.gz → 2.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/PKG-INFO +8 -6
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/README.md +7 -5
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/__init__.py +44 -5
- lucid_dl-2.10.0/lucid/_backend/conv.py +548 -0
- lucid_dl-2.10.0/lucid/_backend/core.py +259 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_backend/metal.py +38 -1
- lucid_dl-2.10.0/lucid/_backend/pool.py +368 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/bfunc.py +45 -45
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/ufunc.py +79 -79
- lucid_dl-2.10.0/lucid/_fusion/__init__.py +4 -0
- lucid_dl-2.10.0/lucid/_fusion/base.py +120 -0
- lucid_dl-2.10.0/lucid/_fusion/func.py +80 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/tensor.py +102 -15
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_util/func.py +62 -63
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/einops/_func.py +10 -10
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/linalg/_func.py +29 -29
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_conv.py +15 -62
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_norm.py +4 -4
- lucid_dl-2.10.0/lucid/nn/functional/_pool.py +141 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/fused.py +1 -1
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/init/_dist.py +25 -11
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/norm.py +3 -3
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/parameter.py +12 -2
- lucid_dl-2.10.0/lucid/optim/ada.py +189 -0
- lucid_dl-2.10.0/lucid/optim/adam.py +317 -0
- lucid_dl-2.10.0/lucid/optim/prop.py +156 -0
- lucid_dl-2.10.0/lucid/optim/sgd.py +147 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/types.py +27 -1
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/PKG-INFO +8 -6
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/SOURCES.txt +5 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/setup.py +1 -1
- lucid_dl-2.8.5/lucid/_backend/core.py +0 -170
- lucid_dl-2.8.5/lucid/nn/functional/_pool.py +0 -275
- lucid_dl-2.8.5/lucid/optim/ada.py +0 -179
- lucid_dl-2.8.5/lucid/optim/adam.py +0 -304
- lucid_dl-2.8.5/lucid/optim/prop.py +0 -144
- lucid_dl-2.8.5/lucid/optim/sgd.py +0 -139
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/LICENSE +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_backend/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_func/gfunc.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_tensor/tensor_ops.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/_util/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/_base.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/data/_util.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/_base.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/cifar.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/datasets/mnist.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/einops/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/error.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/linalg/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/alex.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/coatnet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/convnext.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/crossvit.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/cspnet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/cvt.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/dense.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/efficient.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/efficientformer.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_next.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/inception_res.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/lenet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/maxvit.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/mobile.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/pvt.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnest.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/resnext.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/senet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/sknet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/swin.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/vgg.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/vit.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/xception.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imgclf/zfnet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/ddpm.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/imggen/vae.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/detr.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/efficientdet.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/fast_rcnn.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/faster_rcnn.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/rcnn.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/util.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/seq2seq/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/seq2seq/transformer.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/models/util.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_activation.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_attention.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_drop.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_linear.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_loss.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_spatial.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/functional/_util.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/init/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/module.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/activation.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/attention.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/conv.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/drop.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/einops.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/linear.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/loss.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/pool.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/rnn.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/sparse.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/transformer.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/modules/vision.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/nn/util.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/_base.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_base.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/port.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/random/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/random/_func.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/_base.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/transforms/image.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/visual/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/visual/graph.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/weights/__init__.py +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid/weights/__init__.pyi +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/dependency_links.txt +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/requires.txt +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/lucid_dl.egg-info/top_level.txt +0 -0
- {lucid_dl-2.8.5 → lucid_dl-2.10.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: lucid-dl
|
|
3
|
-
Version: 2.8.5
|
|
3
|
+
Version: 2.10.0
|
|
4
4
|
Summary: Lumerico's Comprehensive Interface for Deep Learning
|
|
5
5
|
Home-page: https://github.com/ChanLumerico/lucid
|
|
6
6
|
Author: ChanLumerico
|
|
@@ -33,7 +33,7 @@ Dynamic: summary
|
|
|
33
33
|

|
|
34
34
|

|
|
35
35
|

|
|
36
|
-

|
|
37
37
|
|
|
38
38
|
**Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
|
|
39
39
|
|
|
@@ -44,15 +44,17 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
44
44
|
|
|
45
45
|
#### Other Languages
|
|
46
46
|
|
|
47
|
-
[🇰🇷
|
|
47
|
+
[🇰🇷 Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
|
|
48
48
|
|
|
49
49
|
### 🔥 What's New
|
|
50
50
|
|
|
51
51
|
- Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
|
|
52
52
|
|
|
53
|
-
-
|
|
54
|
-
|
|
55
|
-
|
|
53
|
+
- Introduced **Backward Fusion** for CPU execution:
|
|
54
|
+
- Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
|
|
55
|
+
- Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
|
|
56
|
+
- Uses heuristic thresholds to avoid fusion overhead on small tensors
|
|
57
|
+
- Disabled by default on GPU paths to ensure stable performance
|
|
56
58
|
|
|
57
59
|
## 🔧 How to Install
|
|
58
60
|
|
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|

|
|
6
6
|

|
|
7
7
|

|
|
8
|
-

|
|
9
9
|
|
|
10
10
|
**Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
|
|
11
11
|
|
|
@@ -16,15 +16,17 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
16
16
|
|
|
17
17
|
#### Other Languages
|
|
18
18
|
|
|
19
|
-
[🇰🇷
|
|
19
|
+
[🇰🇷 Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
|
|
20
20
|
|
|
21
21
|
### 🔥 What's New
|
|
22
22
|
|
|
23
23
|
- Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
|
|
24
24
|
|
|
25
|
-
-
|
|
26
|
-
|
|
27
|
-
|
|
25
|
+
- Introduced **Backward Fusion** for CPU execution:
|
|
26
|
+
- Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
|
|
27
|
+
- Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
|
|
28
|
+
- Uses heuristic thresholds to avoid fusion overhead on small tensors
|
|
29
|
+
- Disabled by default on GPU paths to ensure stable performance
|
|
28
30
|
|
|
29
31
|
## 🔧 How to Install
|
|
30
32
|
|
|
@@ -15,7 +15,7 @@ algorithms and operations without the complexity of high-level frameworks.
|
|
|
15
15
|
|
|
16
16
|
from contextlib import contextmanager, AbstractContextManager
|
|
17
17
|
from typing import Any, Generator, SupportsIndex, Callable, Self, Optional, Type
|
|
18
|
-
from types import TracebackType
|
|
18
|
+
from types import TracebackType, ModuleType
|
|
19
19
|
from functools import wraps
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
|
|
@@ -50,6 +50,8 @@ import lucid.einops as einops
|
|
|
50
50
|
import lucid.nn as nn
|
|
51
51
|
import lucid.types as types
|
|
52
52
|
|
|
53
|
+
from lucid._fusion import ENABLE_FUSION
|
|
54
|
+
|
|
53
55
|
|
|
54
56
|
_grad_enabled: bool = True
|
|
55
57
|
_flops_enabled: bool = False
|
|
@@ -177,11 +179,18 @@ def _set_tensor_grad(
|
|
|
177
179
|
|
|
178
180
|
|
|
179
181
|
def _check_is_tensor(
|
|
180
|
-
any: Tensor | _ArrayOrScalar,
|
|
182
|
+
any: Tensor | _ArrayOrScalar,
|
|
183
|
+
device: _DeviceType = "cpu",
|
|
184
|
+
dtype: _BuiltinNumeric | Numeric | None = None,
|
|
181
185
|
) -> Tensor:
|
|
182
|
-
if
|
|
183
|
-
return
|
|
184
|
-
|
|
186
|
+
if isinstance(any, Tensor):
|
|
187
|
+
return any
|
|
188
|
+
|
|
189
|
+
is_scalar = not isinstance(any, (_NumPyArray, _MLXArray, list, tuple))
|
|
190
|
+
if dtype is not None and is_scalar:
|
|
191
|
+
return Tensor(any, device=device, dtype=dtype)
|
|
192
|
+
|
|
193
|
+
return Tensor(any, device=device)
|
|
185
194
|
|
|
186
195
|
|
|
187
196
|
def _match_grad_shape(
|
|
@@ -293,3 +302,33 @@ def register_model(func: _ModuleReturnFunc) -> _ModuleReturnFunc:
|
|
|
293
302
|
return model
|
|
294
303
|
|
|
295
304
|
return wrapper
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _conv_view_limit_mb() -> int:
    """Return the value reported by the conv backend's ``get_conv_view_limit_mb``.

    The backend module is imported at call time, inside the function body.
    """
    from lucid._backend import conv as backend_conv

    limit_mb = backend_conv.get_conv_view_limit_mb()
    return limit_mb
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def __getattr__(name: str) -> Any:
    """Module-level attribute hook.

    Serves ``CONV_VIEW_LIMIT_MB`` by calling ``_conv_view_limit_mb()`` on each
    access; every other name raises ``AttributeError``.
    """
    if name != "CONV_VIEW_LIMIT_MB":
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    return _conv_view_limit_mb()
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def __dir__() -> list[str]:
    """Return the sorted module global names plus ``CONV_VIEW_LIMIT_MB``."""
    names = [*globals().keys(), "CONV_VIEW_LIMIT_MB"]
    names.sort()
    return names
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class _LucidModule(ModuleType):
    """Module type that refuses assignment to ``CONV_VIEW_LIMIT_MB``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Set ``name`` normally unless it is the read-only ``CONV_VIEW_LIMIT_MB``.

        Raises:
            AttributeError: when ``name`` is ``"CONV_VIEW_LIMIT_MB"``.
        """
        if name != "CONV_VIEW_LIMIT_MB":
            super().__setattr__(name, value)
            return

        raise AttributeError(
            "CONV_VIEW_LIMIT_MB is read-only; set LUCID_CONV_VIEW_LIMIT_MB "
            "before importing lucid."
        )
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
if not isinstance(sys.modules[__name__], _LucidModule):
|
|
334
|
+
sys.modules[__name__].__class__ = _LucidModule
|
|
@@ -0,0 +1,548 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from types import ModuleType
|
|
3
|
+
from typing import TypeAlias
|
|
4
|
+
import itertools
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from lucid._tensor import Tensor
|
|
10
|
+
from lucid._backend.core import (
|
|
11
|
+
Operation,
|
|
12
|
+
binary_func_op,
|
|
13
|
+
_FuncOpReturnType,
|
|
14
|
+
_GradType,
|
|
15
|
+
)
|
|
16
|
+
from lucid._backend.metal import mx
|
|
17
|
+
|
|
18
|
+
from lucid.types import _NumPyArray, _MLXArray
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
_Array: TypeAlias = _NumPyArray | _MLXArray
|
|
22
|
+
_Shape: TypeAlias = tuple[int, ...]
|
|
23
|
+
_Stride: TypeAlias = tuple[int, ...]
|
|
24
|
+
_Padding: TypeAlias = tuple[int, ...]
|
|
25
|
+
_Dilation: TypeAlias = tuple[int, ...]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _load_view_limit_bytes() -> int:
    """Resolve the conv view-size limit in bytes.

    Reads the ``LUCID_CONV_VIEW_LIMIT_MB`` environment variable and scales it
    by 1024 * 1024. When the variable is unset or cannot be parsed by ``int``,
    the result of ``_default_view_limit_bytes()`` is returned instead.
    """
    raw = os.getenv("LUCID_CONV_VIEW_LIMIT_MB")
    if raw is not None:
        try:
            megabytes = int(raw)
        except ValueError:
            # Unparseable override: fall through to the computed default.
            pass
        else:
            return megabytes * 1024 * 1024

    return _default_view_limit_bytes()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _sysconf_value(name: str) -> int | None:
    """Query ``os.sysconf(name)`` and return it only when it is a positive int.

    Returns ``None`` when the lookup raises ``ValueError``, ``AttributeError``
    or ``OSError``, or when the reported value is zero or negative.
    """
    try:
        reported = int(os.sysconf(name))
    except (ValueError, AttributeError, OSError):
        return None

    return reported if reported > 0 else None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_total_memory_bytes() -> int | None:
    """Best-effort lookup of total physical memory in bytes.

    Tries ``page size * physical page count`` from ``_sysconf_value`` first.
    If either value is unavailable, it asks ``kernel32.GlobalMemoryStatusEx``
    through ``ctypes.windll``. Any exception on that second path, or a falsy
    return from the call, yields ``None``.
    """
    bytes_per_page = _sysconf_value("SC_PAGE_SIZE") or _sysconf_value("SC_PAGESIZE")
    page_count = _sysconf_value("SC_PHYS_PAGES")
    if bytes_per_page and page_count:
        return bytes_per_page * page_count

    try:
        import ctypes

        class MEMORYSTATUSEX(ctypes.Structure):
            _fields_ = [
                ("dwLength", ctypes.c_ulong),
                ("dwMemoryLoad", ctypes.c_ulong),
                ("ullTotalPhys", ctypes.c_ulonglong),
                ("ullAvailPhys", ctypes.c_ulonglong),
                ("ullTotalPageFile", ctypes.c_ulonglong),
                ("ullAvailPageFile", ctypes.c_ulonglong),
                ("ullTotalVirtual", ctypes.c_ulonglong),
                ("ullAvailVirtual", ctypes.c_ulonglong),
                ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
            ]

        status = MEMORYSTATUSEX()
        # The structure carries its own size in dwLength before the call.
        status.dwLength = ctypes.sizeof(MEMORYSTATUSEX)
        succeeded = ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status))
        if succeeded:
            return int(status.ullTotalPhys)
    except Exception:
        # Deliberate best-effort: a failed probe just means "unknown".
        pass

    return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _round_to_step(value: int, step: int) -> int:
    """Snap ``value`` to a multiple of ``step`` by adding ``step // 2`` and flooring."""
    half_step = step // 2
    whole_steps = (value + half_step) // step
    return whole_steps * step
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _default_view_limit_bytes() -> int:
    """Compute the default view limit from total system memory.

    Returns 256 MiB when ``_get_total_memory_bytes()`` gives nothing usable.
    Otherwise takes 15/1000 of total memory, clamps it to [64 MiB, 1024 MiB],
    rounds it to a 64 MiB step, and clamps again.
    """
    mib = 1024 * 1024
    total_memory = _get_total_memory_bytes()
    if not total_memory:
        return 256 * mib

    floor_bytes = 64 * mib
    ceiling_bytes = 1024 * mib
    granularity = 64 * mib

    wanted = (total_memory * 15) // 1000
    wanted = min(ceiling_bytes, wanted)
    wanted = max(floor_bytes, wanted)
    wanted = _round_to_step(wanted, granularity)
    return max(floor_bytes, min(ceiling_bytes, wanted))
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
_CONV_VIEW_LIMIT_BYTES = _load_view_limit_bytes()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_conv_view_limit_mb() -> int:
    """Return ``_CONV_VIEW_LIMIT_BYTES`` floor-divided by 1024 * 1024."""
    bytes_per_mb = 1024 * 1024
    return int(_CONV_VIEW_LIMIT_BYTES // bytes_per_mb)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _dtype_itemsize(data: _Array) -> int:
    """Return the per-element byte size of ``data``, or 0 when it is unknown.

    Uses ``np.dtype(data.dtype).itemsize`` when NumPy can interpret the dtype.
    If NumPy raises ``TypeError``, falls back to the dtype's ``size``
    attribute, treating a missing or falsy value as 0.
    """
    kind = getattr(data, "dtype", None)
    if kind is None:
        return 0

    try:
        return int(np.dtype(kind).itemsize)
    except TypeError:
        fallback = getattr(kind, "size", 0)
        return int(fallback) if fallback else 0
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _prod(shape: _Shape) -> int:
    """Return the product of all entries of ``shape`` as a builtin ``int``.

    Each entry is passed through ``int`` before multiplying. An empty shape
    yields 1.
    """
    # Function-scope import keeps this block self-contained.
    import math

    return math.prod(int(extent) for extent in shape)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _view_exceeds_limit(data: _Array, out_dims: _Shape, kernel_size: _Shape) -> bool:
    """Decide whether the window view for ``data`` is over the configured limit.

    A limit of 0 always reports "exceeds"; a negative limit never does. If the
    element size of ``data`` is unknown (0), the view is treated as within the
    limit. Otherwise the view size is ``shape[0] * shape[1] * prod(out_dims) *
    prod(kernel_size)`` elements times the element size.
    """
    limit = _CONV_VIEW_LIMIT_BYTES
    if limit == 0:
        return True
    if limit < 0:
        return False

    elem_bytes = _dtype_itemsize(data)
    if elem_bytes == 0:
        return False

    leading = data.shape[0] * data.shape[1]
    elem_count = leading * _prod(out_dims) * _prod(kernel_size)
    return elem_count * elem_bytes > _CONV_VIEW_LIMIT_BYTES
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _to_tuple(value: int | tuple[int, ...] | list[int], dim: int, name: str) -> _Shape:
    """Normalise a per-axis conv setting to a tuple of ``dim`` builtin ints.

    Args:
        value: A single integer (builtin or NumPy integer scalar), or a
            tuple/list of integers. A one-element sequence is broadcast to
            all ``dim`` axes.
        dim: Number of spatial axes the result must cover.
        name: Parameter name used in error messages.

    Raises:
        ValueError: if a sequence has a length other than 1 or ``dim``.
        TypeError: if ``value`` is neither an integer nor a tuple/list.
    """
    # NumPy integer scalars are not ``int`` subclasses, so they are matched
    # explicitly and converted to builtin ints.
    if isinstance(value, (int, np.integer)):
        return (int(value),) * dim

    if isinstance(value, (tuple, list)):
        if len(value) == 1:
            return (int(value[0]),) * dim
        if len(value) != dim:
            raise ValueError(f"{name} must have length {dim}, got {len(value)}.")
        return tuple(int(v) for v in value)

    raise TypeError(f"{name} must be int or sequence, got {type(value).__name__}.")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _conv_out_dims(
    input_spatial: _Shape,
    kernel_size: _Shape,
    stride: _Stride,
    padding: _Padding,
    dilation: _Dilation,
) -> list[int]:
    """Compute the output size of each spatial axis of a convolution.

    For every axis the size is
    ``(input + 2 * padding - (dilation * (kernel - 1) + 1)) // stride + 1``.

    Raises:
        ValueError: if a stride is not positive, or if any resulting output
            size is zero or negative.
    """
    out_dims = []
    for axis, kernel in enumerate(kernel_size):
        step = stride[axis]
        if step <= 0:
            # Fail with a clear message rather than a ZeroDivisionError (or a
            # meaningless size for negative strides) in the division below.
            raise ValueError(f"stride must be positive for axis {axis}, got {step}")

        effective_kernel = dilation[axis] * (kernel - 1) + 1
        size = (input_spatial[axis] + 2 * padding[axis] - effective_kernel) // step + 1
        if size <= 0:
            raise ValueError(f"Non-positive output dim for axis {axis}: {size}")
        out_dims.append(size)

    return out_dims
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _validate_conv_shapes(input_: Tensor, weight: Tensor, groups: int) -> None:
    """Check that input, weight and ``groups`` are mutually consistent.

    Raises:
        ValueError: if the ranks differ, the rank is below 3, ``groups`` is
            not positive, ``weight.shape[0]`` is not divisible by ``groups``,
            or ``weight.shape[1] * groups`` differs from ``input_.shape[1]``.
    """
    rank = input_.ndim
    if rank != weight.ndim:
        raise ValueError("Input and weight must have the same number of dimensions.")
    if rank < 3:
        raise ValueError("Input and weight must have at least 3 dimensions.")
    if groups <= 0:
        raise ValueError("groups must be a positive integer.")

    in_channels = input_.shape[1]
    out_channels, in_channels_per_group = weight.shape[0], weight.shape[1]

    divisible = out_channels % groups == 0
    channels_match = in_channels_per_group * groups == in_channels
    if not (divisible and channels_match):
        raise ValueError("Inconsistent channel/group configuration.")
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _pad_input(lib_: ModuleType, data: _Array, padding: _Padding) -> _Array:
    """Pad the trailing axes of ``data`` symmetrically via ``lib_.pad``.

    The first two axes get ``(0, 0)``; each entry ``p`` of ``padding`` becomes
    ``(p, p)`` for the next axis. When every padding entry is falsy, ``data``
    is returned unchanged (the same object).
    """
    if any(padding):
        spatial_pad = tuple((amount, amount) for amount in padding)
        return lib_.pad(data, ((0, 0), (0, 0)) + spatial_pad)

    return data
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _as_strided(
    lib_: ModuleType, data: _Array, shape: _Shape, strides: _Shape
) -> _Array | None:
    """Build a strided view of ``data`` with the given backend.

    For NumPy this uses ``np.lib.stride_tricks.as_strided``. For any other
    backend it looks up ``lib_.as_strided``: the result is ``None`` if that
    attribute is missing. Otherwise the call is made with keyword arguments
    and retried positionally if the keyword form raises ``TypeError``.
    """
    if lib_ is np:
        return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)

    strided_fn = getattr(lib_, "as_strided", None)
    if strided_fn is None:
        return None

    try:
        view = strided_fn(data, shape=shape, strides=strides)
    except TypeError:
        view = strided_fn(data, shape, strides)
    return view
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _make_input_view(
    lib_: ModuleType,
    data: _Array,
    out_dims: _Shape,
    kernel_size: _Shape,
    stride: _Stride,
    dilation: _Dilation,
) -> _Array | None:
    """Create a window view of ``data`` shaped ``(d0, d1, *out_dims, *kernel_size)``.

    Returns ``None`` when ``data`` has no usable ``strides`` attribute, or
    when ``_as_strided`` itself returns ``None``. The output axes step by
    ``spatial_stride * stride`` and the kernel axes by
    ``spatial_stride * dilation``.
    """
    base_strides = getattr(data, "strides", None)
    if base_strides is None:
        return None

    n_spatial = len(kernel_size)
    per_axis = base_strides[2:]
    window_steps = [per_axis[i] * stride[i] for i in range(n_spatial)]
    tap_steps = [per_axis[i] * dilation[i] for i in range(n_spatial)]

    view_strides = (base_strides[0], base_strides[1], *window_steps, *tap_steps)
    view_shape = (data.shape[0], data.shape[1], *out_dims, *kernel_size)
    return _as_strided(lib_, data, view_shape, view_strides)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _conv_from_view(
    lib_: ModuleType, x_view: _Array, weight: _Array, out_dims: _Shape, groups: int
) -> _Array:
    """Contract a window view with the weights, one channel group at a time.

    For each group the view's channel axis and trailing kernel axes are
    contracted against the weight's axis 1 and kernel axes with
    ``lib_.tensordot``. The new trailing axis is then moved to position 1.
    Group results are concatenated along axis 1; a single group is returned
    as is.
    """
    n_spatial = len(out_dims)
    in_per_group = weight.shape[1]
    out_per_group = weight.shape[0] // groups

    contract_x = [1, *range(2 + n_spatial, 2 + 2 * n_spatial)]
    contract_w = [1, *range(2, 2 + n_spatial)]
    channels_to_front = [0, n_spatial + 1, *range(1, n_spatial + 1)]

    per_group = []
    for g in range(groups):
        in_lo, out_lo = g * in_per_group, g * out_per_group
        view_g = x_view[:, in_lo : in_lo + in_per_group, ...]
        weight_g = weight[out_lo : out_lo + out_per_group, ...]

        contracted = lib_.tensordot(view_g, weight_g, axes=(contract_x, contract_w))
        per_group.append(lib_.transpose(contracted, axes=channels_to_front))

    return per_group[0] if len(per_group) == 1 else lib_.concatenate(per_group, axis=1)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _conv_fallback(
    lib_: ModuleType,
    input_: _Array,
    weight: _Array,
    stride: _Stride,
    padding: _Padding,
    dilation: _Dilation,
    groups: int,
    out_dims: _Shape,
) -> _Array:
    """Convolve without a window view by summing one kernel tap at a time.

    The input is padded with ``_pad_input``. For each group and each kernel
    offset, a strided slice of the padded input is contracted with the
    matching weight slice over the channel axis, moved to channels-first
    layout and added to the group's running sum. Group results are
    concatenated along axis 1.
    """
    n_spatial = len(out_dims)
    kernel_size = weight.shape[2:]
    in_per_group = weight.shape[1]
    out_per_group = weight.shape[0] // groups

    padded = _pad_input(lib_, input_, padding)
    keep_leading = (slice(None), slice(None))

    per_group = []
    for g in range(groups):
        x_g = padded[:, g * in_per_group : (g + 1) * in_per_group]
        w_g = weight[g * out_per_group : (g + 1) * out_per_group]

        running = None
        for tap in itertools.product(*[range(k) for k in kernel_size]):
            window = list(keep_leading)
            for axis in range(n_spatial):
                first = tap[axis] * dilation[axis]
                stop = first + stride[axis] * out_dims[axis]
                window.append(slice(first, stop, stride[axis]))

            term = lib_.tensordot(
                x_g[tuple(window)], w_g[keep_leading + tap], axes=([1], [1])
            )
            # tensordot leaves the output-channel axis last; move it to axis 1.
            order = [0, term.ndim - 1, *range(1, term.ndim - 1)]
            term = lib_.transpose(term, axes=order)

            running = term if running is None else running + term

        per_group.append(running)

    return per_group[0] if len(per_group) == 1 else lib_.concatenate(per_group, axis=1)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _conv_forward(
    lib_: ModuleType,
    input_: _Array,
    weight: _Array,
    stride: _Stride,
    padding: _Padding,
    dilation: _Dilation,
    groups: int,
) -> _Array:
    """Run the forward convolution, preferring the window-view path.

    ``_conv_fallback`` is used when ``_view_exceeds_limit`` says the view would
    be too large, or when no view can be built for the padded input. Otherwise
    the result comes from ``_conv_from_view``.
    """
    kernel_size = weight.shape[2:]
    out_dims = tuple(
        _conv_out_dims(input_.shape[2:], kernel_size, stride, padding, dilation)
    )

    view = None
    if not _view_exceeds_limit(input_, out_dims, kernel_size):
        padded = _pad_input(lib_, input_, padding)
        view = _make_input_view(lib_, padded, out_dims, kernel_size, stride, dilation)

    if view is None:
        # The fallback receives the unpadded input and pads it itself.
        return _conv_fallback(
            lib_, input_, weight, stride, padding, dilation, groups, out_dims
        )

    return _conv_from_view(lib_, view, weight, out_dims, groups)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _conv_backward_weight(
    lib_: ModuleType,
    grad_out: _Array,
    x_pad: _Array,
    weight: _Array,
    stride: _Stride,
    dilation: _Dilation,
    groups: int,
) -> _Array:
    """Compute the gradient of the convolution with respect to ``weight``.

    When a window view of ``x_pad`` is available and within the size limit,
    each group's gradient is one ``tensordot`` of the output gradient with the
    view over the batch and output-position axes. Otherwise the gradient is
    filled in one kernel tap at a time from strided slices of ``x_pad``.
    Group gradients are concatenated along axis 0.
    """
    kernel_size = weight.shape[2:]
    n_spatial = len(weight.shape) - 2
    out_dims = grad_out.shape[2:]
    in_per_group = weight.shape[1]
    out_per_group = weight.shape[0] // groups

    view = _make_input_view(lib_, x_pad, out_dims, kernel_size, stride, dilation)
    if view is not None and _view_exceeds_limit(x_pad, out_dims, kernel_size):
        view = None

    reduce_out = [0, *range(2, 2 + n_spatial)]
    reduce_x = [0, *range(2, 2 + n_spatial)]
    keep_leading = (slice(None), slice(None))

    per_group = []
    for g in range(groups):
        in_channels = slice(g * in_per_group, (g + 1) * in_per_group)
        g_out = grad_out[:, g * out_per_group : (g + 1) * out_per_group, ...]

        if view is not None:
            view_g = view[:, in_channels, ...]
            per_group.append(
                lib_.tensordot(g_out, view_g, axes=(reduce_out, reduce_x))
            )
            continue

        x_g = x_pad[:, in_channels]
        grad_w = lib_.zeros((out_per_group, in_per_group, *kernel_size), dtype=weight.dtype)

        for tap in itertools.product(*[range(k) for k in kernel_size]):
            window = list(keep_leading)
            for axis in range(n_spatial):
                first = tap[axis] * dilation[axis]
                stop = first + stride[axis] * out_dims[axis]
                window.append(slice(first, stop, stride[axis]))

            tap_grad = lib_.tensordot(
                g_out, x_g[tuple(window)], axes=(reduce_out, reduce_x)
            )

            if lib_ is np:
                grad_w[keep_leading + tap] = tap_grad
            else:
                # Non-NumPy backend: use the functional ``.at[...].add`` update.
                grad_w = grad_w.at[keep_leading + tap].add(tap_grad)

        per_group.append(grad_w)

    if len(per_group) == 1:
        return per_group[0]
    return lib_.concatenate(per_group, axis=0)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _conv_backward_input(
    lib_: ModuleType,
    grad_out: _Array,
    weight: _Array,
    x_pad: _Array,
    stride: _Stride,
    padding: _Padding,
    dilation: _Dilation,
    groups: int,
) -> _Array:
    """Compute the gradient of the convolution with respect to the input.

    Starts from zeros shaped like ``x_pad``. For each group and kernel tap,
    the output gradient is contracted with that tap's weights over the
    output-channel axis and accumulated into the strided slice of the padded
    gradient that the tap read from. If any padding was applied, the padded
    border is cropped off before returning.
    """
    kernel_size = weight.shape[2:]
    n_spatial = len(kernel_size)
    out_dims = grad_out.shape[2:]
    in_per_group = weight.shape[1]
    out_per_group = weight.shape[0] // groups

    grad_input = lib_.zeros_like(x_pad)

    for g in range(groups):
        out_lo = g * out_per_group
        g_out = grad_out[:, out_lo : out_lo + out_per_group, ...]
        w_g = weight[out_lo : out_lo + out_per_group]
        in_channels = slice(g * in_per_group, (g + 1) * in_per_group)

        for tap in itertools.product(*[range(k) for k in kernel_size]):
            tap_weights = w_g[(slice(None), slice(None)) + tap]
            term = lib_.tensordot(g_out, tap_weights, axes=([1], [0]))

            # tensordot leaves the input-channel axis last; move it to axis 1.
            order = [0, term.ndim - 1, *range(1, term.ndim - 1)]
            term = lib_.transpose(term, axes=order)

            target = [slice(None), in_channels]
            for axis in range(n_spatial):
                first = tap[axis] * dilation[axis]
                stop = first + stride[axis] * out_dims[axis]
                target.append(slice(first, stop, stride[axis]))

            if lib_ is np:
                grad_input[tuple(target)] += term
            else:
                # Non-NumPy backend: use the functional ``.at[...].add`` update.
                grad_input = grad_input.at[tuple(target)].add(term)

    if not any(padding):
        return grad_input

    crop = [slice(None), slice(None)]
    for amount in padding:
        # A zero pad on an axis keeps that axis whole.
        crop.append(slice(amount, -amount if amount != 0 else None))
    return grad_input[tuple(crop)]
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
class conv_nd(Operation):
    """Grouped N-d convolution operation with CPU and GPU forward paths.

    The constructor stores the raw ``stride`` / ``padding`` / ``dilation``
    values. Their per-axis tuple forms are derived from the weight's rank on
    the forward pass and cached in ``_stride`` / ``_padding`` / ``_dilation``
    for use by the backward pass.
    """

    def __init__(
        self,
        stride: int | tuple[int, ...] | list[int],
        padding: int | tuple[int, ...] | list[int],
        dilation: int | tuple[int, ...] | list[int],
        groups: int,
    ) -> None:
        super().__init__()
        # Raw hyper-parameters exactly as supplied by the caller.
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Normalized tuples; remain None until `_normalize` has run.
        self._stride: _Stride | None = None
        self._padding: _Padding | None = None
        self._dilation: _Dilation | None = None

    def _normalize(self, weight: Tensor) -> tuple[_Stride, _Padding, _Dilation]:
        """Expand the raw hyper-parameters to one entry per spatial axis.

        The number of spatial axes is ``weight.ndim - 2``. The three results
        are cached on the instance and returned as
        ``(stride, padding, dilation)``.
        """
        n_spatial = weight.ndim - 2
        # Convert all three before caching any of them.
        normalized = (
            _to_tuple(self.stride, n_spatial, "stride"),
            _to_tuple(self.padding, n_spatial, "padding"),
            _to_tuple(self.dilation, n_spatial, "dilation"),
        )
        self._stride, self._padding, self._dilation = normalized
        return normalized

    @binary_func_op()
    def cpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
        """Run the forward convolution with NumPy and bind the backward closure."""
        _validate_conv_shapes(a, b, self.groups)
        hyper = self._normalize(b)

        self.result = Tensor(_conv_forward(np, a.data, b.data, *hyper, self.groups))
        return self.result, partial(self.__grad__, a=a, b=b, lib_=np)

    @binary_func_op(device="gpu")
    def gpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
        """Run the forward convolution with ``mx`` and bind the backward closure."""
        _validate_conv_shapes(a, b, self.groups)
        hyper = self._normalize(b)

        self.result = Tensor(_conv_forward(mx, a.data, b.data, *hyper, self.groups))
        return self.result, partial(self.__grad__, a=a, b=b, lib_=mx)

    def __grad__(self, a: Tensor, b: Tensor, lib_: ModuleType) -> _GradType:
        """Return ``(grad_input, grad_weight)`` for the last forward pass.

        Raises:
            RuntimeError: If the normalized hyper-parameters have not been
                cached yet, i.e. no forward pass has run.
        """
        hyper = (self._stride, self._padding, self._dilation)
        if any(value is None for value in hyper):
            raise RuntimeError("conv_nd backward called before forward.")
        stride, padding, dilation = hyper

        # Rebuild the padded input from `a`; both gradient helpers consume it.
        x_pad = _pad_input(lib_, a.data, padding)
        grad_out = self.result.grad

        grad_input = _conv_backward_input(
            lib_, grad_out, b.data, x_pad, stride, padding, dilation, self.groups
        )
        grad_weight = _conv_backward_weight(
            lib_, grad_out, x_pad, b.data, stride, dilation, self.groups
        )
        return grad_input, grad_weight

    def __flops__(self, a: Tensor, b: Tensor) -> int:
        """Return output element count times multiply-accumulates per element.

        Uses the cached normalized hyper-parameters when available and
        otherwise derives (and caches) them from ``b``.
        """
        hyper = (self._stride, self._padding, self._dilation)
        if any(value is None for value in hyper):
            hyper = self._normalize(b)
        stride, padding, dilation = hyper

        batch = int(a.shape[0])
        out_channels = int(b.shape[0])
        in_channels_per_group = int(b.shape[1])
        kernel_size = tuple(int(k) for k in b.shape[2:])
        in_dims = tuple(int(s) for s in a.shape[2:])
        out_dims = _conv_out_dims(in_dims, kernel_size, stride, padding, dilation)

        per_output = in_channels_per_group * _prod(kernel_size)
        n_outputs = batch * out_channels * _prod(tuple(out_dims))
        return n_outputs * per_output
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def conv_nd_op(
    stride: int | tuple[int, ...] | list[int],
    padding: int | tuple[int, ...] | list[int],
    dilation: int | tuple[int, ...] | list[int],
    groups: int,
) -> conv_nd:
    """Build a fresh ``conv_nd`` operation from the given hyper-parameters."""
    return conv_nd(
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
|