lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- lucid/__init__.py +4 -2
- lucid/_backend/core.py +89 -9
- lucid/_backend/metal.py +5 -1
- lucid/_func/__init__.py +162 -0
- lucid/_tensor/{tensor_ops.py → base.py} +64 -0
- lucid/_tensor/tensor.py +63 -19
- lucid/autograd/__init__.py +4 -1
- lucid/datasets/mnist.py +135 -6
- lucid/models/imggen/__init__.py +1 -0
- lucid/models/imggen/ncsn.py +402 -0
- lucid/nn/_kernel/__init__.py +1 -0
- lucid/nn/_kernel/activation.py +188 -0
- lucid/nn/_kernel/attention.py +125 -0
- lucid/{_backend → nn/_kernel}/conv.py +4 -13
- lucid/nn/_kernel/embedding.py +72 -0
- lucid/nn/_kernel/loss.py +416 -0
- lucid/nn/_kernel/norm.py +365 -0
- lucid/{_backend → nn/_kernel}/pool.py +7 -27
- lucid/nn/functional/__init__.py +4 -0
- lucid/nn/functional/_activation.py +19 -13
- lucid/nn/functional/_attention.py +9 -0
- lucid/nn/functional/_conv.py +5 -16
- lucid/nn/functional/_loss.py +31 -32
- lucid/nn/functional/_norm.py +60 -69
- lucid/nn/functional/_pool.py +7 -7
- lucid/nn/functional/_util.py +5 -1
- lucid/nn/init/_dist.py +1 -0
- lucid/types.py +24 -2
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/METADATA +7 -5
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/RECORD +33 -26
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/WHEEL +1 -1
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/top_level.txt +0 -0
lucid/nn/_kernel/activation.py (new file)

```diff
@@ -0,0 +1,188 @@
+import functools
+from types import ModuleType
+
+import numpy as np
+
+from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+from lucid._backend.metal import mx
+from lucid._tensor import Tensor
+from lucid.types import _DeviceType
+
+
+def _norm_axis(axis: int, ndim: int) -> int:
+    return axis if axis >= 0 else axis + ndim
+
+
+class softmax_kernel(Operation):
+    def __init__(self, axis: int = -1) -> None:
+        super().__init__()
+        self.axis = axis
+        self._axis = None
+        self._y = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._axis = None
+        self._y = None
+
+    @func_op(n_in=1, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=np, device="cpu")
+
+    @func_op(n_in=1, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=mx, device="gpu")
+
+    def _forward(
+        self, a: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        axis = _norm_axis(self.axis, a.ndim)
+        max_val = lib_.max(a.data, axis=axis, keepdims=True)
+        exp_x = lib_.exp(a.data - max_val)
+        sum_exp = lib_.sum(exp_x, axis=axis, keepdims=True)
+        y = exp_x / sum_exp
+
+        self._axis = axis
+        self._y = y
+
+        self.result = Tensor(y, device=device)
+        return self.result, functools.partial(self.__grad__, lib_=lib_)
+
+    def __grad__(self, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("softmax backward called before forward.")
+        if self._y is None or self._axis is None:
+            raise RuntimeError("softmax cached data missing.")
+
+        dy = self.result.grad
+        y = self._y
+        axis = self._axis
+
+        dot = lib_.sum(dy * y, axis=axis, keepdims=True)
+        dx = y * (dy - dot)
+        return dx
+
+
+class sigmoid_kernel(Operation):
+    def __init__(self) -> None:
+        super().__init__()
+        self._y = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._y = None
+
+    @func_op(n_in=1, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=np, device="cpu")
+
+    @func_op(n_in=1, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=mx, device="gpu")
+
+    def _forward(
+        self, a: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        y = 1.0 / (1.0 + lib_.exp(-a.data))
+        self._y = y
+        self.result = Tensor(y, device=device)
+        return self.result, functools.partial(self.__grad__)
+
+    def __grad__(self) -> _GradType:
+        if self.result is None or self.result.grad is None or self._y is None:
+            raise RuntimeError("sigmoid backward called before forward.")
+
+        dy = self.result.grad
+        y = self._y
+
+        dx = dy * y * (1 - y)
+        return dx
+
+
+class gelu_kernel(Operation):
+    def __init__(self) -> None:
+        super().__init__()
+        self._x = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._x = None
+
+    @func_op(n_in=1, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=np, device="cpu")
+
+    @func_op(n_in=1, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=mx, device="gpu")
+
+    def _forward(
+        self, a: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        self._x = a.data
+        c = lib_.sqrt(2.0 / lib_.pi)
+        y = 0.5 * a.data * (1.0 + lib_.tanh(c * (a.data + 0.044715 * (a.data**3))))
+
+        self.result = Tensor(y, device=device)
+        return self.result, functools.partial(self.__grad__, lib_=lib_)
+
+    def __grad__(self, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None or self._x is None:
+            raise RuntimeError("gelu backward called before forward.")
+
+        x = self._x
+        dy = self.result.grad
+        c = lib_.sqrt(2.0 / lib_.pi)
+        t = c * (x + 0.044715 * x**3)
+        dt = c * (1 + 3 * 0.044715 * x**2)
+        sech2 = 1.0 / lib_.cosh(t) ** 2
+
+        dx = 0.5 * (1 + lib_.tanh(t)) + 0.5 * x * sech2 * dt
+        return dy * dx
+
+
+class silu_kernel(Operation):
+    def __init__(self) -> None:
+        super().__init__()
+        self._x = None
+        self._sig = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._x = None
+        self._sig = None
+
+    @func_op(n_in=1, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=np, device="cpu")
+
+    @func_op(n_in=1, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, lib_=mx, device="gpu")
+
+    def _forward(
+        self, a: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        self._x = a.data
+        sig = 1.0 / (1.0 + lib_.exp(-a.data))
+        self._sig = sig
+        y = a.data * sig
+
+        self.result = Tensor(y, device=device)
+        return self.result, functools.partial(self.__grad__)
+
+    def __grad__(self) -> _GradType:
+        if (
+            self.result is None
+            or self.result.grad is None
+            or self._x is None
+            or self._sig is None
+        ):
+            raise RuntimeError("silu backward called before forward.")
+
+        dy = self.result.grad
+        sig = self._sig
+        x = self._x
+
+        dx = dy * (sig + x * sig * (1 - sig))
+        return dx
```
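The softmax backward in `softmax_kernel.__grad__` does not materialize the Jacobian; it applies the vector identity `dx = y * (dy - sum(dy * y))` along the softmax axis. A minimal standalone NumPy check of that identity (a sketch, not part of the package; a 1-D vector stands in for one slice along the axis):

```python
# Standalone check: the softmax VJP used above vs. an explicit Jacobian product.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)
dy = rng.normal(size=5)

# Numerically stable softmax, as in softmax_kernel._forward.
y = np.exp(x - x.max())
y /= y.sum()

# Kernel-style backward: dx = y * (dy - <dy, y>).
dx_kernel = y * (dy - np.sum(dy * y))

# Explicit Jacobian J = diag(y) - y y^T (symmetric), so dx = J @ dy.
J = np.diag(y) - np.outer(y, y)
dx_explicit = J @ dy

assert np.allclose(dx_kernel, dx_explicit)
```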
lucid/nn/_kernel/attention.py (new file)

```diff
@@ -0,0 +1,125 @@
+import functools
+import math
+from types import ModuleType
+
+import numpy as np
+
+from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+from lucid._backend.metal import mx
+from lucid._tensor import Tensor
+
+from lucid.types import _DeviceType, _TensorData
+
+
+def _make_causal_mask(lib_: ModuleType, L: int, S: int, dtype: object) -> _TensorData:
+    triu = getattr(lib_, "triu", None)
+    ones = getattr(lib_, "ones", None)
+    if triu is None or ones is None:
+        mask = np.triu(np.ones((L, S), dtype=np.float32), k=1)
+        if lib_ is mx:
+            mask = mx.array(mask)
+    else:
+        mask = triu(ones((L, S), dtype=dtype), k=1)
+    return mask * (-1e12)
+
+
+class scaled_dot_product_attention_kernel(Operation):
+    def __init__(
+        self,
+        attn_mask: Tensor | None = None,
+        is_causal: bool = False,
+        scale: float | None = None,
+    ) -> None:
+        super().__init__()
+        self.attn_mask = attn_mask
+        self.is_causal = bool(is_causal)
+        self.scale = scale
+
+        self._q = None
+        self._k = None
+        self._v = None
+        self._attn = None
+        self._scale = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._q = None
+        self._k = None
+        self._v = None
+        self._attn = None
+        self._scale = None
+
+    @func_op(n_in=3, n_ret=1)
+    def cpu(self, q: Tensor, k: Tensor, v: Tensor) -> _FuncOpReturnType:
+        return self._forward(q, k, v, lib_=np, device="cpu")
+
+    @func_op(n_in=3, n_ret=1, device="gpu")
+    def gpu(self, q: Tensor, k: Tensor, v: Tensor) -> _FuncOpReturnType:
+        return self._forward(q, k, v, lib_=mx, device="gpu")
+
+    def _forward(
+        self, q: Tensor, k: Tensor, v: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        qd = q.data
+        kd = k.data
+        vd = v.data
+
+        scale = self.scale
+        if scale is None:
+            scale = 1.0 / math.sqrt(q.shape[-1])
+
+        kt = lib_.swapaxes(kd, -1, -2)
+        scores = lib_.matmul(qd, kt) * scale
+
+        if self.is_causal:
+            L = q.shape[-2]
+            S = k.shape[-2]
+            scores = scores + _make_causal_mask(lib_, L, S, dtype=scores.dtype)
+
+        if self.attn_mask is not None:
+            scores = scores + self.attn_mask.data
+
+        max_val = lib_.max(scores, axis=-1, keepdims=True)
+        exp_x = lib_.exp(scores - max_val)
+        sum_exp = lib_.sum(exp_x, axis=-1, keepdims=True)
+        attn = exp_x / sum_exp
+
+        out = lib_.matmul(attn, vd)
+
+        self._q = qd
+        self._k = kd
+        self._v = vd
+        self._attn = attn
+        self._scale = scale
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, lib_=lib_)
+
+    def __grad__(self, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("attention backward called before forward.")
+
+        if self._attn is None or self._q is None or self._k is None or self._v is None:
+            raise RuntimeError("attention cached data missing.")
+
+        dy = self.result.grad
+        attn = self._attn
+        qd = self._q
+        kd = self._k
+        vd = self._v
+        scale = self._scale if self._scale is not None else 1.0
+
+        attn_t = lib_.swapaxes(attn, -1, -2)
+        dV = lib_.matmul(attn_t, dy)
+
+        v_t = lib_.swapaxes(vd, -1, -2)
+        dA = lib_.matmul(dy, v_t)
+
+        dot = lib_.sum(dA * attn, axis=-1, keepdims=True)
+        dS = attn * (dA - dot)
+
+        dS = dS * scale
+        dQ = lib_.matmul(dS, kd)
+        dK = lib_.matmul(lib_.swapaxes(dS, -1, -2), qd)
+
+        return dQ, dK, dV
```
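The attention backward reuses the cached attention weights and applies the softmax VJP row-wise: `dV = attn.T @ dy`, `dA = dy @ v.T`, `dS = scale * attn * (dA - rowsum(dA * attn))`, `dQ = dS @ k`, `dK = dS.T @ q`. A standalone NumPy sketch (not part of the package; single head, 2-D tensors, no causal or additive mask) that re-derives those formulas and cross-checks dQ against a central finite difference:

```python
# Standalone check of the scaled-dot-product-attention backward formulas.
import numpy as np

rng = np.random.default_rng(0)
L, S, D = 3, 4, 5
q, k, v = rng.normal(size=(L, D)), rng.normal(size=(S, D)), rng.normal(size=(S, D))
dy = rng.normal(size=(L, D))          # upstream gradient w.r.t. the output
scale = 1.0 / np.sqrt(D)

def forward(q):
    # Stable softmax over scaled scores, then weighted sum of values.
    s = q @ k.T * scale
    a = np.exp(s - s.max(axis=-1, keepdims=True))
    a /= a.sum(axis=-1, keepdims=True)
    return a @ v, a

out, attn = forward(q)

# Same algebra as scaled_dot_product_attention_kernel.__grad__.
dV = attn.T @ dy
dA = dy @ v.T
dS = attn * (dA - np.sum(dA * attn, axis=-1, keepdims=True)) * scale
dQ, dK = dS @ k, dS.T @ q

# Central finite difference on dQ for the scalar loss sum(out * dy).
eps, num = 1e-6, np.zeros_like(q)
for i in range(L):
    for j in range(D):
        qp, qm = q.copy(), q.copy()
        qp[i, j] += eps
        qm[i, j] -= eps
        num[i, j] = np.sum((forward(qp)[0] - forward(qm)[0]) * dy) / (2 * eps)

assert np.allclose(dQ, num, atol=1e-5)
```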
lucid/{_backend → nn/_kernel}/conv.py

```diff
@@ -9,7 +9,7 @@ import numpy as np
 from lucid._tensor import Tensor
 from lucid._backend.core import (
     Operation,
-
+    func_op,
     _FuncOpReturnType,
     _GradType,
 )
@@ -451,7 +451,7 @@ def _conv_backward_input(
     return grad_input
 
 
-class conv_nd(Operation):
+class conv_nd_kernel(Operation):
     def __init__(
         self,
         stride: int | tuple[int, ...] | list[int],
@@ -481,7 +481,7 @@ class conv_nd(Operation):
 
         return stride, padding, dilation
 
-    @
+    @func_op(n_in=2, n_ret=1)
     def cpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
         _validate_conv_shapes(a, b, self.groups)
         stride, padding, dilation = self._normalize(b)
@@ -490,7 +490,7 @@ class conv_nd(Operation):
         self.result = Tensor(out)
         return self.result, partial(self.__grad__, a=a, b=b, lib_=np)
 
-    @
+    @func_op(n_in=2, n_ret=1, device="gpu")
     def gpu(self, a: Tensor, b: Tensor) -> _FuncOpReturnType:
         _validate_conv_shapes(a, b, self.groups)
         stride, padding, dilation = self._normalize(b)
@@ -537,12 +537,3 @@ class conv_nd(Operation):
         macs_per_out = C_in_g * _prod(kernel_size)
         out_elems = N * C_out * _prod(tuple(out_dims))
         return out_elems * macs_per_out
-
-
-def conv_nd_op(
-    stride: int | tuple[int, ...] | list[int],
-    padding: int | tuple[int, ...] | list[int],
-    dilation: int | tuple[int, ...] | list[int],
-    groups: int,
-) -> conv_nd:
-    return conv_nd(stride, padding, dilation, groups)
```
lucid/nn/_kernel/embedding.py (new file)

```diff
@@ -0,0 +1,72 @@
+import functools
+from types import ModuleType
+
+import numpy as np
+
+from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+from lucid._backend.metal import mx
+from lucid._tensor import Tensor
+
+from lucid.types import _DeviceType, _TensorData
+
+
+def _as_int_array(arr, lib_: ModuleType) -> _TensorData:
+    if lib_ is np:
+        return arr.astype(np.int64)
+    return arr.astype(mx.int32)
+
+
+class embedding_kernel(Operation):
+    def __init__(self) -> None:
+        super().__init__()
+        self._indices = None
+        self._num_embeddings = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._indices = None
+        self._num_embeddings = None
+
+    @func_op(n_in=2, n_ret=1)
+    def cpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
+        return self._forward(indices, weight, lib_=np, device="cpu")
+
+    @func_op(n_in=2, n_ret=1, device="gpu")
+    def gpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
+        return self._forward(indices, weight, lib_=mx, device="gpu")
+
+    def _forward(
+        self, indices: Tensor, weight: Tensor, lib_: ModuleType, device: _DeviceType
+    ) -> _FuncOpReturnType:
+        idx = _as_int_array(indices.data, lib_)
+        out = weight.data[idx]
+
+        self._indices = idx
+        self._num_embeddings = int(weight.shape[0])
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, lib_=lib_)
+
+    def __grad__(self, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("embedding backward called before forward.")
+        if self._indices is None or self._num_embeddings is None:
+            raise RuntimeError("embedding cached data missing.")
+
+        grad_out = self.result.grad
+        idx = self._indices.reshape(-1)
+        grad_flat = grad_out.reshape(idx.shape[0], -1)
+
+        if lib_ is np:
+            grad_w = np.zeros(
+                (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
+            )
+            np.add.at(grad_w, idx, grad_flat)
+        else:
+            grad_w = mx.zeros(
+                (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
+            )
+            for i in range(idx.shape[0]):
+                grad_w = grad_w.at[idx[i]].add(grad_flat[i])
+
+        return None, grad_w
```
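On the CPU path, the embedding backward accumulates rows of the flattened output gradient into a weight-sized buffer with `np.add.at`. A small standalone sketch (not part of the package) of why the unbuffered scatter-add is needed instead of `grad_w[idx] += grad_flat` when an index repeats:

```python
# Standalone illustration: repeated indices must accumulate, not overwrite.
import numpy as np

num_embeddings, dim = 4, 3
idx = np.array([0, 2, 0])             # token 0 appears twice in the batch
grad_flat = np.ones((3, dim))

buffered = np.zeros((num_embeddings, dim))
buffered[idx] += grad_flat            # fancy-index +=: row 0 receives only one update
print(buffered[0])                    # [1. 1. 1.]

scattered = np.zeros((num_embeddings, dim))
np.add.at(scattered, idx, grad_flat)  # unbuffered scatter-add, as in embedding_kernel
print(scattered[0])                   # [2. 2. 2.]
```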