lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl
- lucid/__init__.py +4 -2
- lucid/_backend/core.py +89 -9
- lucid/_backend/metal.py +5 -1
- lucid/_func/__init__.py +162 -0
- lucid/_tensor/{tensor_ops.py → base.py} +64 -0
- lucid/_tensor/tensor.py +63 -19
- lucid/autograd/__init__.py +4 -1
- lucid/datasets/mnist.py +135 -6
- lucid/models/imggen/__init__.py +1 -0
- lucid/models/imggen/ncsn.py +402 -0
- lucid/nn/_kernel/__init__.py +1 -0
- lucid/nn/_kernel/activation.py +188 -0
- lucid/nn/_kernel/attention.py +125 -0
- lucid/{_backend → nn/_kernel}/conv.py +4 -13
- lucid/nn/_kernel/embedding.py +72 -0
- lucid/nn/_kernel/loss.py +416 -0
- lucid/nn/_kernel/norm.py +365 -0
- lucid/{_backend → nn/_kernel}/pool.py +7 -27
- lucid/nn/functional/__init__.py +4 -0
- lucid/nn/functional/_activation.py +19 -13
- lucid/nn/functional/_attention.py +9 -0
- lucid/nn/functional/_conv.py +5 -16
- lucid/nn/functional/_loss.py +31 -32
- lucid/nn/functional/_norm.py +60 -69
- lucid/nn/functional/_pool.py +7 -7
- lucid/nn/functional/_util.py +5 -1
- lucid/nn/init/_dist.py +1 -0
- lucid/types.py +24 -2
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/METADATA +7 -5
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/RECORD +33 -26
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/WHEEL +1 -1
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/top_level.txt +0 -0
lucid/nn/_kernel/norm.py
ADDED
@@ -0,0 +1,365 @@
+import functools
+from types import ModuleType
+from typing import Sequence
+
+import numpy as np
+
+from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+from lucid._backend.metal import mx
+
+from lucid._tensor import Tensor
+from lucid.types import _DeviceType
+
+
+def _norm_axes(ndim: int, normalized_shape: Sequence[int]) -> tuple[int, ...]:
+    return tuple(range(ndim - len(normalized_shape), ndim))
+
+
+def _broadcast_shape(ndim: int, normalized_shape: Sequence[int]) -> tuple[int, ...]:
+    return (1,) * (ndim - len(normalized_shape)) + tuple(normalized_shape)
+
+
+class layer_norm_kernel(Operation):
+    def __init__(
+        self,
+        normalized_shape: Sequence[int],
+        eps: float = 1e-5,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.normalized_shape = tuple(int(v) for v in normalized_shape)
+        self.eps = float(eps)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._norm_axes = None
+        self._n = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._norm_axes = None
+        self._n = None
+
+    @func_op(n_in=3, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=3, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        norm_axes = _norm_axes(a.ndim, self.normalized_shape)
+        n = int(np.prod(self.normalized_shape))
+        mean = lib_.mean(a.data, axis=norm_axes, keepdims=True)
+        var = lib_.var(a.data, axis=norm_axes, keepdims=True)
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (a.data - mean) * rstd
+
+        out = xhat
+        if self.has_weight:
+            out = out * w.data.reshape(_broadcast_shape(a.ndim, self.normalized_shape))
+        if self.has_bias:
+            out = out + b.data.reshape(_broadcast_shape(a.ndim, self.normalized_shape))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._norm_axes = norm_axes
+        self._n = n
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("layer_norm backward called before forward.")
+
+        if self._xhat is None or self._rstd is None or self._norm_axes is None:
+            raise RuntimeError("layer_norm cached data missing.")
+
+        dy = self.result.grad
+        xhat = self._xhat
+        rstd = self._rstd
+        norm_axes = self._norm_axes
+        n = self._n if self._n is not None else int(np.prod(self.normalized_shape))
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(
+                _broadcast_shape(a.ndim, self.normalized_shape)
+            )
+            dyw = dy * w_broadcast
+        else:
+            dyw = dy
+
+        sum1 = lib_.sum(dyw, axis=norm_axes, keepdims=True)
+        sum2 = lib_.sum(dyw * xhat, axis=norm_axes, keepdims=True)
+
+        dx = (1.0 / n) * rstd * (n * dyw - sum1 - xhat * sum2)
+
+        reduce_axes = tuple(range(0, a.ndim - len(self.normalized_shape)))
+        if reduce_axes:
+            dweight = lib_.sum(dy * xhat, axis=reduce_axes)
+            dbias = lib_.sum(dy, axis=reduce_axes)
+        else:
+            dweight = dy * xhat
+            dbias = dy
+
+        return dx, dweight, dbias
+
+
+class batch_norm_kernel(Operation):
+    def __init__(
+        self,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        training: bool = True,
+        has_running: bool = True,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.eps = float(eps)
+        self.momentum = float(momentum)
+        self.training = bool(training)
+        self.has_running = bool(has_running)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._axes = None
+        self._m = None
+        self._use_batch_stats = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._axes = None
+        self._m = None
+        self._use_batch_stats = None
+
+    @func_op(n_in=5, n_ret=1, device="cpu")
+    def cpu(
+        self, a: Tensor, running_mean: Tensor, running_var: Tensor, w: Tensor, b: Tensor
+    ) -> _FuncOpReturnType:
+        return self._forward(a, running_mean, running_var, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=5, n_ret=1, device="gpu")
+    def gpu(
+        self, a: Tensor, running_mean: Tensor, running_var: Tensor, w: Tensor, b: Tensor
+    ) -> _FuncOpReturnType:
+        return self._forward(a, running_mean, running_var, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        running_mean: Tensor,
+        running_var: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        axes = (0,) + tuple(range(2, a.ndim))
+        m = int(np.prod([a.shape[i] for i in axes]))
+        use_batch_stats = self.training or not self.has_running
+
+        if use_batch_stats:
+            mean = lib_.mean(a.data, axis=axes, keepdims=True)
+            var = lib_.var(a.data, axis=axes, keepdims=True)
+
+            if self.training and self.has_running:
+                rm = (
+                    self.momentum * mean.reshape(-1)
+                    + (1 - self.momentum) * running_mean.data
+                )
+                rv = (
+                    self.momentum * var.reshape(-1)
+                    + (1 - self.momentum) * running_var.data
+                )
+                running_mean.data = rm
+                running_var.data = rv
+
+        else:
+            mean = running_mean.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+            var = running_var.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (a.data - mean) * rstd
+
+        out = xhat
+        if self.has_weight:
+            out = out * w.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+        if self.has_bias:
+            out = out + b.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._axes = axes
+        self._m = m
+        self._use_batch_stats = use_batch_stats
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("batch_norm backward called before forward.")
+
+        if self._rstd is None or self._axes is None or self._m is None:
+            raise RuntimeError("batch_norm cached data missing.")
+
+        dy = self.result.grad
+        axes = self._axes
+        m = self._m
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+            dyw = dy * w_broadcast
+        else:
+            dyw = dy
+
+        if self._use_batch_stats:
+            xhat = self._xhat
+            rstd = self._rstd
+            sum1 = lib_.sum(dyw, axis=axes, keepdims=True)
+            sum2 = lib_.sum(dyw * xhat, axis=axes, keepdims=True)
+            dx = (1.0 / m) * rstd * (m * dyw - sum1 - xhat * sum2)
+        else:
+            rstd = self._rstd
+            dx = dyw * rstd
+
+        reduce_axes = (0,) + tuple(range(2, a.ndim))
+        dweight = lib_.sum(
+            dy * (self._xhat if self._xhat is not None else 1.0), axis=reduce_axes
+        )
+        dbias = lib_.sum(dy, axis=reduce_axes)
+
+        return dx, None, None, dweight, dbias
+
+
+class group_norm_kernel(Operation):
+    def __init__(
+        self,
+        num_groups: int,
+        eps: float = 1e-5,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.num_groups = int(num_groups)
+        self.eps = float(eps)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._group_shape = None
+        self._reduce_axes = None
+        self._m = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._group_shape = None
+        self._reduce_axes = None
+        self._m = None
+
+    @func_op(n_in=3, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=3, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        N, C, *spatial = a.shape
+        if C % self.num_groups != 0:
+            raise ValueError("num_groups must divide channels.")
+
+        group_size = C // self.num_groups
+        group_shape = (N, self.num_groups, group_size, *spatial)
+        x = a.data.reshape(group_shape)
+        reduce_axes = (2,) + tuple(range(3, x.ndim))
+        m = int(np.prod([x.shape[i] for i in reduce_axes]))
+
+        mean = lib_.mean(x, axis=reduce_axes, keepdims=True)
+        var = lib_.var(x, axis=reduce_axes, keepdims=True)
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (x - mean) * rstd
+
+        out = xhat.reshape(a.shape)
+        if self.has_weight:
+            out = out * w.data.reshape(1, C, *([1] * len(spatial)))
+        if self.has_bias:
+            out = out + b.data.reshape(1, C, *([1] * len(spatial)))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._group_shape = group_shape
+        self._reduce_axes = reduce_axes
+        self._m = m
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, b=b, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, b: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("group_norm backward called before forward.")
+
+        if (
+            self._xhat is None
+            or self._rstd is None
+            or self._group_shape is None
+            or self._reduce_axes is None
+            or self._m is None
+        ):
+            raise RuntimeError("group_norm cached data missing.")
+
+        dy = self.result.grad
+        N, C, *spatial = a.shape
+        dy_g = dy.reshape(self._group_shape)
+        xhat = self._xhat
+        rstd = self._rstd
+        axes = self._reduce_axes
+        m = self._m
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(1, C, *([1] * len(spatial)))
+            dyw = dy * w_broadcast
+            dyw_g = dyw.reshape(self._group_shape)
+        else:
+            dyw_g = dy_g
+
+        sum1 = lib_.sum(dyw_g, axis=axes, keepdims=True)
+        sum2 = lib_.sum(dyw_g * xhat, axis=axes, keepdims=True)
+        dx_g = (1.0 / m) * rstd * (m * dyw_g - sum1 - xhat * sum2)
+        dx = dx_g.reshape(a.shape)
+
+        reduce_axes = (0,) + tuple(range(2, a.ndim))
+        dweight = lib_.sum(dy * xhat.reshape(a.shape), axis=reduce_axes)
+        dbias = lib_.sum(dy, axis=reduce_axes)
+
+        return dx, dweight, dbias
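The matching functional wrapper, lucid/nn/functional/_norm.py (+60 -69 in the file list), is not reproduced in this section. Based on the wrapper pattern visible in the other functional files below, a call into the new kernel presumably looks like the following sketch; the wrapper name layer_norm_sketch and the choice of normalizing over the trailing axis are assumptions here, only the layer_norm_kernel signature is taken from the file above.

from lucid._tensor import Tensor
from lucid.nn._kernel.norm import layer_norm_kernel


def layer_norm_sketch(x: Tensor, weight: Tensor, bias: Tensor, eps: float = 1e-5) -> Tensor:
    # Normalize over the last dimension; the kernel caches xhat/rstd on forward
    # and returns (dx, dweight, dbias) from __grad__.
    op = layer_norm_kernel(
        normalized_shape=x.shape[-1:], eps=eps, has_weight=True, has_bias=True
    )
    return op(x, weight, bias)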
lucid/{_backend → nn/_kernel}/pool.py
RENAMED

@@ -8,7 +8,7 @@ import numpy as np
 from lucid._tensor import Tensor
 from lucid._backend.core import (
     Operation,
-
+    func_op,
     _FuncOpReturnType,
     _GradType,
 )

@@ -92,11 +92,7 @@ def _where(lib_: ModuleType, cond: _Array, x: _Array, y: _Array) -> _Array:
 
 
 def _pool_forward_sum(
-
-    x_pad: _Array,
-    out_dims: _Shape,
-    kernel_size: _Shape,
-    stride: _Shape,
+    x_pad: _Array, out_dims: _Shape, kernel_size: _Shape, stride: _Shape
 ) -> _Array:
     out = None
     for k_idx in itertools.product(*[range(k) for k in kernel_size]):

@@ -211,7 +207,7 @@ def _pool_backward_max(
     return _crop_padding(grad_input_pad, padding)
 
 
-class pool_nd(Operation):
+class pool_nd_kernel(Operation):
     def __init__(
        self,
        kernel_size: int | tuple[int, ...] | list[int],

@@ -259,7 +255,7 @@ class pool_nd(Operation):
 
         return kernel, stride, padding
 
-    @
+    @func_op(n_in=1, n_ret=1)
     def cpu(self, a: Tensor) -> _FuncOpReturnType:
         kernel, stride, padding = self._normalize(a)
         out_dims = _pool_out_dims(a.shape[2:], kernel, stride, padding)

@@ -268,7 +264,7 @@ class pool_nd(Operation):
 
         x_pad = _pad_input(np, a.data, padding)
         if self.mode == "avg":
-            out_sum = _pool_forward_sum(
+            out_sum = _pool_forward_sum(x_pad, out_dims, kernel, stride)
             out = out_sum / _prod(kernel)
         else:
             out, max_idx = _pool_forward_max(np, x_pad, out_dims, kernel, stride)

@@ -277,7 +273,7 @@ class pool_nd(Operation):
         self.result = Tensor(out)
         return self.result, partial(self.__grad__, lib_=np)
 
-    @
+    @func_op(n_in=1, n_ret=1, device="gpu")
     def gpu(self, a: Tensor) -> _FuncOpReturnType:
         kernel, stride, padding = self._normalize(a)
         out_dims = _pool_out_dims(a.shape[2:], kernel, stride, padding)

@@ -286,7 +282,7 @@ class pool_nd(Operation):
 
         x_pad = _pad_input(mx, a.data, padding)
         if self.mode == "avg":
-            out_sum = _pool_forward_sum(
+            out_sum = _pool_forward_sum(x_pad, out_dims, kernel, stride)
             out = out_sum / _prod(kernel)
         else:
             out, max_idx = _pool_forward_max(mx, x_pad, out_dims, kernel, stride)

@@ -350,19 +346,3 @@ class pool_nd(Operation):
         if self.mode == "avg":
             return out_elems * kernel_elems
         return out_elems * max(kernel_elems - 1, 0)
-
-
-def avg_pool_nd_op(
-    kernel_size: int | tuple[int, ...] | list[int],
-    stride: int | tuple[int, ...] | list[int],
-    padding: int | tuple[int, ...] | list[int],
-) -> pool_nd:
-    return pool_nd(kernel_size, stride, padding, mode="avg")
-
-
-def max_pool_nd_op(
-    kernel_size: int | tuple[int, ...] | list[int],
-    stride: int | tuple[int, ...] | list[int],
-    padding: int | tuple[int, ...] | list[int],
-) -> pool_nd:
-    return pool_nd(kernel_size, stride, padding, mode="max")
lucid/nn/functional/__init__.py
CHANGED
@@ -54,6 +54,10 @@ def tanh(input_: Tensor) -> Tensor:
     return _activation.tanh(input_)
 
 
+def silu(input_: Tensor) -> Tensor:
+    return _activation.silu(input_)
+
+
 def softmax(input_: Tensor, axis: int = -1) -> Tensor:
     return _activation.softmax(input_, axis)
 
lucid/nn/functional/_activation.py
CHANGED

@@ -1,6 +1,12 @@
 import lucid
 
 from lucid._tensor import Tensor
+from lucid.nn._kernel.activation import (
+    softmax_kernel,
+    sigmoid_kernel,
+    gelu_kernel,
+    silu_kernel,
+)
 
 
 def relu(input_: Tensor) -> Tensor:

@@ -9,14 +15,14 @@ def relu(input_: Tensor) -> Tensor:
 
 def leaky_relu(input_: Tensor, negative_slope: float = 0.01) -> Tensor:
     mask = input_ > 0
-    out = input_ * mask + input_ * negative_slope * (
+    out = input_ * mask + input_ * negative_slope * (~mask)
     return out
 
 
 def elu(input_: Tensor, alpha: float = 1.0) -> Tensor:
     mask = input_ >= 0
     pos = input_ * mask
-    neg = alpha * (lucid.exp(input_) - 1) * (
+    neg = alpha * (lucid.exp(input_) - 1) * (~mask)
     return pos + neg
 
 

@@ -26,29 +32,29 @@ def selu(input_: Tensor) -> Tensor:
 
     mask = input_ >= 0
     pos = _scale * input_ * mask
-    neg = _scale * _alpha * (lucid.exp(input_) - 1) * (
+    neg = _scale * _alpha * (lucid.exp(input_) - 1) * (~mask)
     return pos + neg
 
 
 def gelu(input_: Tensor) -> Tensor:
-
-    return
+    op = gelu_kernel()
+    return op(input_)
 
 
 def sigmoid(input_: Tensor) -> Tensor:
-
+    op = sigmoid_kernel()
+    return op(input_)
 
 
 def tanh(input_: Tensor) -> Tensor:
     return lucid.tanh(input_)
 
 
-def
-
-
+def silu(input_: Tensor) -> Tensor:
+    op = silu_kernel()
+    return op(input_)
 
-    e_input = lucid.exp(input_stable)
-    sum_e_input = e_input.sum(axis=axis, keepdims=True)
 
-
-
+def softmax(input_: Tensor, axis: int = -1) -> Tensor:
+    op = softmax_kernel(axis=axis)
+    return op(input_)
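Illustrative usage of the rerouted activations; the lucid.zeros(2, 3) call form mirrors lucid.zeros(L, S) in the attention diff below and is an assumption about the simplest way to build an input here.

import lucid
import lucid.nn.functional as F

x = lucid.zeros(2, 3)      # assumed shape-varargs form, as seen in the attention hunk
y = F.silu(x)              # now dispatches to silu_kernel instead of composed ops
p = F.softmax(x, axis=-1)  # softmax now runs through softmax_kernel(axis=axis)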
lucid/nn/functional/_attention.py
CHANGED

@@ -1,7 +1,10 @@
 import math
+
 import lucid
 import lucid.nn.functional as F
 
+from lucid.nn._kernel.attention import scaled_dot_product_attention_kernel
+
 from lucid._tensor import Tensor
 
 

@@ -14,6 +17,12 @@ def scaled_dot_product_attention(
     is_causal: bool = False,
     scale: float | None = None,
 ) -> Tensor:
+    if dropout_p == 0.0:
+        op = scaled_dot_product_attention_kernel(
+            attn_mask=attn_mask, is_causal=is_causal, scale=scale
+        )
+        return op(query, key, value)
+
     L, S = query.shape[-2], key.shape[-2]
     scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
     attn_bias = lucid.zeros(L, S, dtype=query.dtype).free()
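With this change, a zero-dropout call takes the fused kernel path rather than the composed fallback. A hedged sketch; the tensor shapes are illustrative and lucid.zeros is used as in the hunk above.

import lucid
import lucid.nn.functional as F

# (batch, seq_len, head_dim); only the last two axes are used by the wrapper.
q = lucid.zeros(1, 4, 8)
k = lucid.zeros(1, 4, 8)
v = lucid.zeros(1, 4, 8)

# dropout_p == 0.0 triggers the scaled_dot_product_attention_kernel branch added above.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)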
lucid/nn/functional/_conv.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Tuple, Optional
 
 import lucid
 from lucid._tensor import Tensor
-from lucid.
+from lucid.nn._kernel.conv import conv_nd_kernel
 
 
 def unfold(

@@ -66,17 +66,6 @@ def unfold(
     return col.reshape((N_out, C_filt))
 
 
-def _conv_tensor(
-    input_: Tensor,
-    weight: Tensor,
-    stride: Tuple[int, ...],
-    padding: Tuple[int, ...],
-    dilation: Tuple[int, ...],
-    groups: int,
-) -> Tensor:
-    return conv_nd_op(stride, padding, dilation, groups)(input_, weight)
-
-
 def conv(
     input_: Tensor,
     weight: Tensor,

@@ -92,7 +81,7 @@ def conv(
     if len(stride) != len(padding) or len(stride) != len(dilation):
         raise ValueError("Stride, padding, and dilation must have the same length.")
 
-    out =
+    out = conv_nd_kernel(stride, padding, dilation, groups)(input_, weight)
 
     if bias is not None:
         bias_sh = [1, weight.shape[0]] + [1] * (input_.ndim - 2)

@@ -181,9 +170,9 @@ def conv_transpose(
         zeros = lucid.zeros(*zero_shape, dtype=ups.dtype, device=ups.device)
         ups = lucid.concatenate([ups, zeros], axis=axis)
 
-        out_g =
-
-        )
+        out_g = conv_nd_kernel(
+            stride=(1,) * D, padding=pad_, dilation=dilation, groups=1
+        )(ups, w_t)
         outputs.append(out_g)
 
     output = lucid.concatenate(outputs, axis=1)
lucid/nn/functional/_loss.py
CHANGED
@@ -3,6 +3,12 @@ from typing import Literal
 import lucid
 from lucid._tensor import Tensor
 
+from lucid.nn._kernel.loss import (
+    cross_entropy_kernel,
+    binary_cross_entropy_kernel,
+    binary_cross_entropy_with_logits_kernel,
+)
+
 _ReductionType = Literal["mean", "sum"]
 
 

@@ -55,13 +61,14 @@ def binary_cross_entropy(
     reduction: _ReductionType | None = "mean",
     eps: float = 1e-7,
 ) -> Tensor:
-
-
+    has_weight = weight is not None
+    if weight is None:
+        weight = lucid.ones_like(input_, device=input_.device)
 
-
-
-
-    return
+    op = binary_cross_entropy_kernel(
+        reduction=reduction, eps=eps, has_weight=has_weight
+    )
+    return op(input_, target, weight)
 
 
 def binary_cross_entropy_with_logits(

@@ -71,19 +78,17 @@ def binary_cross_entropy_with_logits(
     pos_weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
 ) -> Tensor:
-
-
+    has_weight = weight is not None
+    has_pos_weight = pos_weight is not None
+    if weight is None:
+        weight = lucid.ones_like(input_, device=input_.device)
+    if pos_weight is None:
+        pos_weight = lucid.ones_like(input_, device=input_.device)
 
-
-
-
-
-    loss = lucid.maximum(input_, 0) - input_ * target + sp
-
-    if weight is not None:
-        loss *= weight
-
-    return _loss_reduction(loss, reduction)
+    op = binary_cross_entropy_with_logits_kernel(
+        reduction=reduction, has_weight=has_weight, has_pos_weight=has_pos_weight
+    )
+    return op(input_, target, weight, pos_weight)
 
 
 def cross_entropy(

@@ -94,20 +99,14 @@ def cross_entropy(
     eps: float = 1e-7,
     ignore_index: int | None = None,
 ) -> Tensor:
-
-
-
-
-
-
-
-
-    loss *= weight[target_int]
-
-    if ignore_index is not None:
-        return _ignore_index_loss(loss, target_int, ignore_index, reduction)
-
-    return _loss_reduction(loss, reduction)
+    has_weight = weight is not None
+    if weight is None:
+        weight = lucid.ones((input_.shape[1],), device=input_.device)
+
+    op = cross_entropy_kernel(
+        reduction=reduction, eps=eps, ignore_index=ignore_index, has_weight=has_weight
+    )
+    return op(input_, target, weight)
 
 
 def nll_loss(
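The wrappers now materialize default weight tensors and hand everything to the loss kernels, so an omitted weight behaves as all-ones. An illustrative call, assuming binary_cross_entropy_with_logits is re-exported from lucid.nn.functional as before and that lucid.zeros / lucid.ones accept the call forms shown in these hunks:

import lucid
import lucid.nn.functional as F

logits = lucid.zeros(4, 3)    # shape-varargs form, as in the attention diff above
targets = lucid.ones((4, 3))  # tuple form, as in cross_entropy above

# weight/pos_weight default to None, so the wrapper creates ones_like tensors
# and delegates to binary_cross_entropy_with_logits_kernel.
loss = F.binary_cross_entropy_with_logits(logits, targets)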
|