lucid-dl 2.12.0__tar.gz → 2.12.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lucid_dl-2.12.0/lucid_dl.egg-info → lucid_dl-2.12.1}/PKG-INFO +5 -1
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/README.md +4 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/__init__.py +2 -2
- lucid_dl-2.12.1/lucid/_tensor/__init__.py +11 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_tensor/base.py +2 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_tensor/tensor.py +192 -3
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_util/__init__.py +14 -5
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_util/func.py +73 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/__init__.py +1 -0
- lucid_dl-2.12.1/lucid/models/seqclf/__init__.py +1 -0
- lucid_dl-2.12.1/lucid/models/seqclf/bert.py +31 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/embedding.py +19 -16
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_util.py +40 -8
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/attention.py +58 -6
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/rnn.py +133 -21
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/sparse.py +16 -1
- {lucid_dl-2.12.0 → lucid_dl-2.12.1/lucid_dl.egg-info}/PKG-INFO +5 -1
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/SOURCES.txt +2 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/setup.py +1 -1
- lucid_dl-2.12.0/lucid/_tensor/__init__.py +0 -1
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/LICENSE +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/core.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/metal.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/bfunc.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/gfunc.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/ufunc.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/func.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/autograd/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/_base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/_util.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/_base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/cifar.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/mnist.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/einops/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/einops/_func.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/error.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/linalg/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/linalg/_func.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/alex.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/coatnet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/convnext.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/crossvit.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/cspnet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/cvt.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/dense.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/efficient.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/efficientformer.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_next.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_res.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/lenet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/maxvit.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/mobile.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/pvt.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnest.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnext.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/senet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/sknet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/swin.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/vgg.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/vit.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/xception.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/zfnet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/ddpm.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/ncsn.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/vae.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/detr.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/efficientdet.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/fast_rcnn.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/faster_rcnn.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/rcnn.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/util.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/seq2seq/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/seq2seq/transformer.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/utils.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/activation.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/attention.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/conv.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/loss.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/norm.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/pool.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_activation.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_attention.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_conv.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_drop.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_linear.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_loss.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_norm.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_pool.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_spatial.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/fused.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/init/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/init/_dist.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/module.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/activation.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/conv.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/drop.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/einops.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/linear.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/loss.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/norm.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/pool.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/transformer.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/vision.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/parameter.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/_grad.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/rnn.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/_base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/ada.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/adam.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/prop.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/sgd.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/port.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/random/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/random/_func.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/_base.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/image.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/types.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/visual/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/visual/mermaid.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/weights/__init__.py +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/weights/__init__.pyi +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/dependency_links.txt +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/requires.txt +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/top_level.txt +0 -0
- {lucid_dl-2.12.0 → lucid_dl-2.12.1}/setup.cfg +0 -0
PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lucid-dl
-Version: 2.12.0
+Version: 2.12.1
 Summary: Lumerico's Comprehensive Interface for Deep Learning
 Home-page: https://github.com/ChanLumerico/lucid
 Author: ChanLumerico
```
PKG-INFO

```diff
@@ -48,6 +48,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
 ### 🔥 What's New
 
+- New Tensor utility function added: `lucid.Tensor.expand`
+
+- Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
+
 - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph
 
 - Added additional `nn.Module` hooks for richer introspection during training:
```
README.md

```diff
@@ -20,6 +20,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 
 ### 🔥 What's New
 
+- New Tensor utility function added: `lucid.Tensor.expand`
+
+- Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
+
 - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph
 
 - Added additional `nn.Module` hooks for richer introspection during training:
```
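At a glance, the two new items read roughly like the sketch below. This is an assumed usage example pieced together from the hunks further down, not something taken from the package docs.

```python
import lucid

# Hedged sketch of the new features (assumed API, untested).
x = lucid.Tensor([[1.0], [2.0], [3.0]])   # shape (3, 1)
y = x.expand(3, 4)                        # broadcast the size-1 axis to 4

ids = lucid.LongTensor([0, 1, 2])         # dtype pinned to Int64
w = lucid.DoubleTensor([[0.1], [0.2]])    # dtype pinned to Float64
```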
lucid/__init__.py

```diff
@@ -25,7 +25,7 @@ import json
 import math
 import numpy as np
 
-from lucid._tensor import
+from lucid._tensor import *
 from lucid._func import *
 from lucid._util import *
 
```
lucid/__init__.py

```diff
@@ -308,7 +308,7 @@ def register_model(func: _ModuleReturnFunc) -> _ModuleReturnFunc:
 
 
 def _conv_view_limit_mb() -> int:
-    from lucid._kernel import conv as _conv_kernel
+    from lucid.nn._kernel import conv as _conv_kernel
 
     return _conv_kernel.get_conv_view_limit_mb()
 
```
lucid/_tensor/base.py

```diff
@@ -108,6 +108,8 @@ class _TensorBase:
 
     def broadcast_to(self, shape: _ShapeLike) -> Self: ...
 
+    def expand(self, *sizes: int | _ShapeLike) -> Self: ...
+
     def chunk(self, chunks: int, axis: int = 0) -> tuple[Self, ...]: ...
 
     def swapaxes(self, axis1: int, axis2: int) -> Self: ...
```
lucid/_tensor/tensor.py

```diff
@@ -1,4 +1,15 @@
-from typing import
+from typing import (
+    Callable,
+    Iterator,
+    Optional,
+    Self,
+    SupportsIndex,
+    Any,
+    overload,
+    Generic,
+    TypeVar,
+    ClassVar,
+)
 from types import NoneType
 from collections import deque
 
```
lucid/_tensor/tensor.py

```diff
@@ -22,15 +33,32 @@ from lucid._backend.core import BackwardOperation, Operation, noop
 from lucid._backend.metal import mx, parse_mlx_indexing, check_metal_availability
 
 
+__all__ = [
+    "Tensor",
+    "FloatTensor",
+    "DoubleTensor",
+    "HalfTensor",
+    "CharTensor",
+    "ShortTensor",
+    "IntTensor",
+    "LongTensor",
+    "BoolTensor",
+]
+
+
+DType = TypeVar("DType", bound=Numeric | bool)
+
 _HookType = Callable[["Tensor", _NumPyArray | _MLXArray], None]
 
 _dtype_map = {int: types.Int64, float: types.Float64, complex: types.Complex64}
 
 
-class Tensor(_TensorBase, _TensorInplace):
+class Tensor(Generic[DType], _TensorBase, _TensorInplace):
+    _fixed_dtype: ClassVar[Numeric | None] = None
+
     def __init__(
         self,
-        data: _ArrayOrScalar
+        data: _ArrayOrScalar,
         requires_grad: bool = False,
         keep_grad: bool = False,
         dtype: _BuiltinNumeric | Numeric | None = None,
```
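The `Generic[DType]` parameter lets annotations carry the element type, which the dtype-pinned subclasses added later in this file build on. A hedged, illustration-only sketch (the function and its name are hypothetical; at runtime the subscripted form is purely a typing aid):

```python
from lucid import types
from lucid._tensor import Tensor

# Hypothetical signature mirroring the `class LongTensor(Tensor[types.Int64])`
# pattern introduced in this diff; it changes nothing at runtime.
def gather_rows(ids: Tensor[types.Int64], table: Tensor[types.Float32]) -> Tensor[types.Float32]:
    ...
```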
lucid/_tensor/tensor.py

```diff
@@ -39,6 +67,9 @@ class Tensor(_TensorBase, _TensorInplace):
         self._is_free = False
         self._is_bool_tensor = False
 
+        if self._fixed_dtype is not None:
+            dtype = self._fixed_dtype
+
         if dtype is bool:
             self._is_bool_tensor = True
             dtype = None
```
lucid/_tensor/tensor.py

```diff
@@ -285,6 +316,12 @@ class Tensor(_TensorBase, _TensorInplace):
             dtype = device_or_dtype
         return self.astype(dtype)
 
+    def cpu(self) -> Self:
+        return self.to(device="cpu")
+
+    def gpu(self) -> Self:
+        return self.to(device="gpu")
+
     def is_cpu(self) -> bool:
         return self.device == "cpu"
 
```
lucid/_tensor/tensor.py

```diff
@@ -480,3 +517,155 @@ class Tensor(_TensorBase, _TensorInplace):
 
     def bool(self) -> Self:
         return self.astype(bool)
+
+
+class LongTensor(Tensor[types.Int64]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Int64
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Int64,
+            device=device,
+        )
+
+
+class IntTensor(Tensor[types.Int32]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Int32
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Int32,
+            device=device,
+        )
+
+
+class ShortTensor(Tensor[types.Int16]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Int16
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Int16,
+            device=device,
+        )
+
+
+class CharTensor(Tensor[types.Int8]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Int8
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Int8,
+            device=device,
+        )
+
+
+class HalfTensor(Tensor[types.Float16]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Float16
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Float16,
+            device=device,
+        )
+
+
+class FloatTensor(Tensor[types.Float32]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Float32
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Float32,
+            device=device,
+        )
+
+
+class DoubleTensor(Tensor[types.Float64]):
+    _fixed_dtype: ClassVar[Numeric | None] = types.Float64
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=types.Float64,
+            device=device,
+        )
+
+
+class BoolTensor(Tensor[bool]):
+    _fixed_dtype: ClassVar[Numeric | None] = None
+
+    def __init__(
+        self,
+        data: _ArrayOrScalar,
+        requires_grad: bool = False,
+        keep_grad: bool = False,
+        device: _DeviceType = "cpu",
+    ) -> None:
+        super().__init__(
+            data=data,
+            requires_grad=requires_grad,
+            keep_grad=keep_grad,
+            dtype=bool,
+            device=device,
+        )
```
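All of these subclasses follow the same shape: `_fixed_dtype` overrides whatever `dtype` the constructor would otherwise receive (see the `__init__` hunk above). A rough, untested sketch of the resulting behaviour:

```python
import lucid

# Hedged sketch (assumed behaviour, not from the package docs).
t = lucid.IntTensor([1.5, 2.5, 3.5])   # stored as Int32; the float input is cast
b = lucid.BoolTensor([0, 1, 2])        # routed through Tensor's dtype=bool path

print(t.dtype, b.dtype)                # assuming Tensor exposes a dtype attribute
```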
lucid/_util/__init__.py

```diff
@@ -9,11 +9,11 @@ from lucid._util import func
 # fmt: off
 __all__ = [
     "reshape", "squeeze", "unsqueeze", "expand_dims", "ravel", "stack", "hstack",
-    "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
-    "split", "tril", "triu", "broadcast_to", "
-    "unbind", "sort", "nonzero", "unique", "topk", "argsort",
-    "histogram", "histogram2d", "where", "nonzero", "argmin",
-    "diagonal",
+    "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
+    "split", "tril", "triu", "broadcast_to", "expand", "chunk", "masked_fill",
+    "roll", "unbind", "sort", "nonzero", "unique", "topk", "argsort",
+    "histogramdd", "histogram", "histogram2d", "where", "nonzero", "argmin",
+    "argmax", "diagonal",
 ]
 # fmt: on
 
```
lucid/_util/__init__.py

```diff
@@ -106,6 +106,14 @@ def broadcast_to(a: Tensor, /, shape: _ShapeLike) -> Tensor:
     return func.broadcast_to(shape)(a)
 
 
+def expand(a: Tensor, /, *sizes: int | _ShapeLike) -> Tensor:
+    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
+        shape = sizes[0]
+    else:
+        shape = sizes
+    return func.expand(shape)(a)
+
+
 def chunk(a: Tensor, /, chunks: int, axis: int = 0) -> tuple[Tensor, ...]:
     return func.chunk(chunks, axis)(a)
 
```
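The wrapper accepts either separate sizes or a single shape tuple, and `-1` keeps an existing dimension. Assuming the semantics of the `expand` operation shown later in `lucid/_util/func.py`, the calls below should all be valid:

```python
import lucid

# Hedged sketch of the accepted call forms (assumed, untested).
x = lucid.Tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)

a = x.expand(3, 4)                       # varargs sizes
b = x.expand((3, 4))                     # one shape tuple
c = x.expand(2, -1, 4)                   # -1 keeps 3; a leading axis of size 2 is added

print(a.shape, b.shape, c.shape)         # expected: (3, 4) (3, 4) (2, 3, 4)
```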
lucid/_util/__init__.py

```diff
@@ -257,6 +265,7 @@ Tensor.split = split
 Tensor.tril = tril
 Tensor.triu = triu
 Tensor.broadcast_to = broadcast_to
+Tensor.expand = expand
 Tensor.chunk = chunk
 Tensor.masked_fill = masked_fill
 Tensor.roll = roll
```
lucid/_util/func.py

```diff
@@ -605,6 +605,79 @@ class broadcast_to(Operation):
         return self.result.grad.reshape(self.original_shape)
 
 
+class expand(Operation):
+    def __init__(self, shape: _ShapeLike) -> None:
+        super().__init__()
+        self.shape = shape
+
+    def _resolve_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
+        shape = tuple(int(dim) for dim in self.shape)
+        if len(shape) == 0:
+            raise ValueError("expand() expects at least one dimension.")
+
+        if len(shape) < len(input_shape):
+            raise ValueError(
+                "expand() cannot shrink the number of dimensions from "
+                f"{len(input_shape)} to {len(shape)}."
+            )
+
+        ndim_diff = len(shape) - len(input_shape)
+        padded_input = (1,) * ndim_diff + input_shape
+
+        resolved: list[int] = []
+        for axis, (target_dim, input_dim) in enumerate(zip(shape, padded_input)):
+            if target_dim == -1:
+                if axis < ndim_diff:
+                    raise ValueError(
+                        "expand() cannot use -1 in a leading, "
+                        "non-existing dimension."
+                    )
+                target_dim = input_dim
+
+            elif target_dim < -1:
+                raise ValueError("expand() size must be >= -1.")
+
+            if input_dim == target_dim:
+                resolved.append(target_dim)
+            elif input_dim == 1 and target_dim >= 0:
+                resolved.append(target_dim)
+            else:
+                raise ValueError(
+                    "expand() cannot expand dimension "
+                    f"{axis} from {input_dim} to {target_dim}."
+                )
+
+        return tuple(resolved)
+
+    @unary_func_op()
+    def cpu(self, a: Tensor) -> _FuncOpReturnType:
+        self.original_shape = a.shape
+        self.expanded_shape = self._resolve_shape(a.shape)
+
+        self.result = Tensor(np.broadcast_to(a.data, self.expanded_shape))
+        return self.result, self.__grad__
+
+    @unary_func_op(device="gpu")
+    def gpu(self, a: Tensor) -> _FuncOpReturnType:
+        self.original_shape = a.shape
+        self.expanded_shape = self._resolve_shape(a.shape)
+
+        self.result = Tensor(mx.broadcast_to(a.data, self.expanded_shape))
+        return self.result, self.__grad__
+
+    def __grad__(self) -> _GradType:
+        input_shape = self.original_shape
+        ndim_diff = len(self.expanded_shape) - len(input_shape)
+        if ndim_diff > 0:
+            input_shape = (1,) * ndim_diff + input_shape
+
+        for axis, (in_dim, out_dim) in enumerate(zip(input_shape, self.expanded_shape)):
+            if in_dim == 1 and out_dim > 1:
+                self.result.grad = self.result.grad.sum(axis=axis, keepdims=True)
+
+        return self.result.grad.reshape(self.original_shape)
+
+
 class chunk(Operation):
     def __init__(self, chunks: int, axis: int) -> None:
         super().__init__()
```
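The backward rule in `__grad__` is the usual gradient of broadcasting: sum the upstream gradient over every axis that was expanded from size 1, then reshape back to the original shape. A self-contained NumPy sketch of the same arithmetic:

```python
import numpy as np

original = np.random.rand(3, 1)
expanded_shape = (2, 3, 4)
upstream = np.ones(expanded_shape)            # stand-in for dL/d(expanded output)

grad = upstream
padded = (1,) * (len(expanded_shape) - original.ndim) + original.shape
for axis, (in_dim, out_dim) in enumerate(zip(padded, expanded_shape)):
    if in_dim == 1 and out_dim > 1:
        grad = grad.sum(axis=axis, keepdims=True)  # collapse each expanded axis

grad = grad.reshape(original.shape)
print(grad)   # every entry is 8.0: each source element fed 2 * 4 expanded copies
```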
lucid/models/seqclf/__init__.py (new file)

```diff
@@ -0,0 +1 @@
+from .bert import *
```
lucid/models/seqclf/bert.py (new file)

```diff
@@ -0,0 +1,31 @@
+import lucid
+import lucid.nn as nn
+import lucid.nn.functional as F
+
+from lucid._tensor import Tensor
+
+
+class _BertEmbeddings(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        hidden_size: int,
+        pad_token_id: int,
+        max_position_embeddings: int,
+        type_vocab_size: int,
+        layer_norm_eps: float,
+        hidden_dropout_prob: float,
+    ) -> None:
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, pad_token_id)
+        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
+        self.token_type_embeddings = nn.Embedding(type_vocab_size)
+
+        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+        self.position_ids: nn.Buffer
+        self.register_buffer(
+            "position_ids", nn.Buffer(lucid.arange(max_position_embeddings))
+        )
+        # TODO: Implement `lucid.Tensor.expand`
```
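The trailing TODO presumably marks where the position-id buffer gets broadcast over the batch, which is exactly what the new `Tensor.expand` enables. A rough, untested sketch of that step; the slicing and the exact call are assumptions, not package code:

```python
import lucid

# Hypothetical illustration of the TODO (not from the package).
max_position_embeddings = 512
position_ids = lucid.arange(max_position_embeddings).astype(lucid.Int)

batch_size, seq_len = 8, 128
pos = position_ids[:seq_len].expand(batch_size, seq_len)   # (8, 128), one row per batch item
```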
lucid/nn/_kernel/embedding.py

```diff
@@ -1,4 +1,4 @@
-import
+from functools import partial
 from types import ModuleType
 
 import numpy as np
```
lucid/nn/_kernel/embedding.py

```diff
@@ -7,49 +7,44 @@ from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
 from lucid._backend.metal import mx
 from lucid._tensor import Tensor
 
-from lucid.types import _DeviceType, _TensorData
-
-
-def _as_int_array(arr, lib_: ModuleType) -> _TensorData:
-    if lib_ is np:
-        return arr.astype(np.int64)
-    return arr.astype(mx.int32)
-
 
 class embedding_kernel(Operation):
-    def __init__(self) -> None:
+    def __init__(self, padding_idx: int = -1) -> None:
         super().__init__()
+        self.padding_idx = int(padding_idx)
         self._indices = None
         self._num_embeddings = None
 
     def clear(self) -> None:
         super().clear()
+        self.padding_idx = -1
         self._indices = None
         self._num_embeddings = None
 
     @func_op(n_in=2, n_ret=1)
     def cpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
-        return self._forward(indices, weight, lib_=np
+        return self._forward(indices, weight, lib_=np)
 
     @func_op(n_in=2, n_ret=1, device="gpu")
     def gpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
-        return self._forward(indices, weight, lib_=mx
+        return self._forward(indices, weight, lib_=mx)
 
     def _forward(
-        self, indices: Tensor, weight: Tensor, lib_: ModuleType
+        self, indices: Tensor, weight: Tensor, lib_: ModuleType
     ) -> _FuncOpReturnType:
-        idx =
+        idx = indices.data
         out = weight.data[idx]
 
         self._indices = idx
         self._num_embeddings = int(weight.shape[0])
 
-        self.result = Tensor(out
-        return self.result,
+        self.result = Tensor(out)
+        return self.result, partial(self.__grad__, lib_=lib_)
 
     def __grad__(self, lib_: ModuleType) -> _GradType:
         if self.result is None or self.result.grad is None:
             raise RuntimeError("embedding backward called before forward.")
+
         if self._indices is None or self._num_embeddings is None:
             raise RuntimeError("embedding cached data missing.")
 
```
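The forward pass is a plain gather (fancy indexing into the weight matrix), and `functools.partial` now bakes the array library into the backward callable so `__grad__` knows which backend produced the data. A NumPy-only sketch of both ideas:

```python
import numpy as np
from functools import partial

weight = np.arange(12.0).reshape(4, 3)   # (num_embeddings, embed_dim)
idx = np.array([1, 3, 1])

out = weight[idx]                        # (3, 3): rows 1, 3 and 1 again
grad_fn = partial(lambda lib_: lib_, lib_=np)   # toy stand-in for partial(self.__grad__, lib_=lib_)

print(out.shape, grad_fn() is np)        # (3, 3) True
```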
lucid/nn/_kernel/embedding.py

```diff
@@ -58,15 +53,23 @@ class embedding_kernel(Operation):
         grad_flat = grad_out.reshape(idx.shape[0], -1)
 
         if lib_ is np:
+            if self.padding_idx >= 0:
+                keep = idx != self.padding_idx
+                idx = idx[keep]
+                grad_flat = grad_flat[keep]
+
             grad_w = np.zeros(
                 (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
             )
             np.add.at(grad_w, idx, grad_flat)
+
         else:
             grad_w = mx.zeros(
                 (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
             )
             for i in range(idx.shape[0]):
+                if self.padding_idx >= 0 and int(idx[i]) == self.padding_idx:
+                    continue
                 grad_w = grad_w.at[idx[i]].add(grad_flat[i])
 
         return None, grad_w
```
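On the CPU path the gradient w.r.t. the weight is a scatter-add over the used rows, and entries equal to `padding_idx` are filtered out first so the padding row never accumulates gradient. A standalone NumPy sketch of the same idea:

```python
import numpy as np

num_embeddings, dim, padding_idx = 4, 3, 0
idx = np.array([1, 0, 1, 2])             # 0 plays the role of the padding index
grad_flat = np.ones((4, dim))

keep = idx != padding_idx
grad_w = np.zeros((num_embeddings, dim))
np.add.at(grad_w, idx[keep], grad_flat[keep])

print(grad_w)   # row 0 stays zero, row 1 accumulates twice, row 2 once
```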
lucid/nn/functional/_util.py

```diff
@@ -1,3 +1,5 @@
+import numpy as np
+
 import lucid
 import lucid.nn.functional
 
```
lucid/nn/functional/_util.py

```diff
@@ -5,6 +7,7 @@ from lucid._tensor import Tensor
 from lucid.types import _Scalar, Numeric
 
 from lucid.nn._kernel.embedding import embedding_kernel
+from lucid._backend.metal import mx
 
 
 def _interpolate_bilinear(
```
lucid/nn/functional/_util.py

```diff
@@ -131,17 +134,46 @@ def embedding(
     max_norm: float | None = None,
     norm_type: float = 2.0,
 ) -> Tensor:
+    num_embeddings = int(weight.shape[0])
+    if padding_idx is None:
+        pad = -1
+    else:
+        pad = int(padding_idx)
+        if pad < 0:
+            pad += num_embeddings
+        if pad < 0 or pad >= num_embeddings:
+            raise IndexError("padding_idx out of range.")
+
     indices = input_.astype(lucid.Int)
-
-
-    if
-
-        output *= 1 - mask[..., None]
+    idx_data = indices.data
+
+    if (idx_data < 0).any() or (idx_data >= num_embeddings).any():
+        raise IndexError("embedding indices out of range.")
 
     if max_norm is not None:
-
-
-
+        lib_ = np if weight.is_cpu() else mx
+        flat = idx_data.reshape(-1)
+
+        w = weight.data[flat]
+        if norm_type <= 0:
+            raise ValueError("norm_type must be positive.")
+
+        norms = (lib_.abs(w) ** norm_type).sum(axis=1) ** (1.0 / norm_type)
+        scale = lib_.minimum(1.0, max_norm / (norms + (norms == 0)))
+
+        if pad >= 0:
+            mask = flat == pad
+            mask_f = mask.astype(scale.dtype)
+            scale = scale * (1 - mask_f) + mask_f
+
+        weight.data[flat] = w * scale[:, None]
+
+    op = embedding_kernel(padding_idx=pad)
+    output = op(indices, weight)
+
+    if pad >= 0:
+        mask = input_.data == pad
+        output *= 1 - mask[..., None]
 
     return output
 
```
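The `max_norm` branch rescales each looked-up row so its p-norm does not exceed `max_norm`, while the padding row (if any) is forced to keep a scale of 1. The same arithmetic in plain NumPy:

```python
import numpy as np

max_norm, norm_type, pad = 1.0, 2.0, 2
w = np.array([[3.0, 4.0], [0.3, 0.4], [5.0, 12.0]])   # rows gathered by index
flat = np.array([0, 1, 2])                            # the indices that produced them

norms = (np.abs(w) ** norm_type).sum(axis=1) ** (1.0 / norm_type)   # [5.0, 0.5, 13.0]
scale = np.minimum(1.0, max_norm / (norms + (norms == 0)))

mask_f = (flat == pad).astype(scale.dtype)
scale = scale * (1 - mask_f) + mask_f                 # padding row keeps scale 1.0

print(w * scale[:, None])   # row 0 is rescaled to norm 1, rows 1 and 2 are untouched
```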