lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
+import functools
+from types import ModuleType
+from typing import Sequence
+
+import numpy as np
+
+from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+from lucid._backend.metal import mx
+
+from lucid._tensor import Tensor
+from lucid.types import _DeviceType
+
+
+def _norm_axes(ndim: int, normalized_shape: Sequence[int]) -> tuple[int, ...]:
+    return tuple(range(ndim - len(normalized_shape), ndim))
+
+
+def _broadcast_shape(ndim: int, normalized_shape: Sequence[int]) -> tuple[int, ...]:
+    return (1,) * (ndim - len(normalized_shape)) + tuple(normalized_shape)
+
+
+class layer_norm_kernel(Operation):
+    def __init__(
+        self,
+        normalized_shape: Sequence[int],
+        eps: float = 1e-5,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.normalized_shape = tuple(int(v) for v in normalized_shape)
+        self.eps = float(eps)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._norm_axes = None
+        self._n = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._norm_axes = None
+        self._n = None
+
+    @func_op(n_in=3, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=3, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        norm_axes = _norm_axes(a.ndim, self.normalized_shape)
+        n = int(np.prod(self.normalized_shape))
+        mean = lib_.mean(a.data, axis=norm_axes, keepdims=True)
+        var = lib_.var(a.data, axis=norm_axes, keepdims=True)
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (a.data - mean) * rstd
+
+        out = xhat
+        if self.has_weight:
+            out = out * w.data.reshape(_broadcast_shape(a.ndim, self.normalized_shape))
+        if self.has_bias:
+            out = out + b.data.reshape(_broadcast_shape(a.ndim, self.normalized_shape))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._norm_axes = norm_axes
+        self._n = n
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("layer_norm backward called before forward.")
+
+        if self._xhat is None or self._rstd is None or self._norm_axes is None:
+            raise RuntimeError("layer_norm cached data missing.")
+
+        dy = self.result.grad
+        xhat = self._xhat
+        rstd = self._rstd
+        norm_axes = self._norm_axes
+        n = self._n if self._n is not None else int(np.prod(self.normalized_shape))
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(
+                _broadcast_shape(a.ndim, self.normalized_shape)
+            )
+            dyw = dy * w_broadcast
+        else:
+            dyw = dy
+
+        sum1 = lib_.sum(dyw, axis=norm_axes, keepdims=True)
+        sum2 = lib_.sum(dyw * xhat, axis=norm_axes, keepdims=True)
+
+        dx = (1.0 / n) * rstd * (n * dyw - sum1 - xhat * sum2)
+
+        reduce_axes = tuple(range(0, a.ndim - len(self.normalized_shape)))
+        if reduce_axes:
+            dweight = lib_.sum(dy * xhat, axis=reduce_axes)
+            dbias = lib_.sum(dy, axis=reduce_axes)
+        else:
+            dweight = dy * xhat
+            dbias = dy
+
+        return dx, dweight, dbias
+
+
+class batch_norm_kernel(Operation):
+    def __init__(
+        self,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        training: bool = True,
+        has_running: bool = True,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.eps = float(eps)
+        self.momentum = float(momentum)
+        self.training = bool(training)
+        self.has_running = bool(has_running)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._axes = None
+        self._m = None
+        self._use_batch_stats = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._axes = None
+        self._m = None
+        self._use_batch_stats = None
+
+    @func_op(n_in=5, n_ret=1, device="cpu")
+    def cpu(
+        self, a: Tensor, running_mean: Tensor, running_var: Tensor, w: Tensor, b: Tensor
+    ) -> _FuncOpReturnType:
+        return self._forward(a, running_mean, running_var, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=5, n_ret=1, device="gpu")
+    def gpu(
+        self, a: Tensor, running_mean: Tensor, running_var: Tensor, w: Tensor, b: Tensor
+    ) -> _FuncOpReturnType:
+        return self._forward(a, running_mean, running_var, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        running_mean: Tensor,
+        running_var: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        axes = (0,) + tuple(range(2, a.ndim))
+        m = int(np.prod([a.shape[i] for i in axes]))
+        use_batch_stats = self.training or not self.has_running
+
+        if use_batch_stats:
+            mean = lib_.mean(a.data, axis=axes, keepdims=True)
+            var = lib_.var(a.data, axis=axes, keepdims=True)
+
+            if self.training and self.has_running:
+                rm = (
+                    self.momentum * mean.reshape(-1)
+                    + (1 - self.momentum) * running_mean.data
+                )
+                rv = (
+                    self.momentum * var.reshape(-1)
+                    + (1 - self.momentum) * running_var.data
+                )
+                running_mean.data = rm
+                running_var.data = rv
+
+        else:
+            mean = running_mean.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+            var = running_var.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (a.data - mean) * rstd
+
+        out = xhat
+        if self.has_weight:
+            out = out * w.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+        if self.has_bias:
+            out = out + b.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._axes = axes
+        self._m = m
+        self._use_batch_stats = use_batch_stats
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("batch_norm backward called before forward.")
+
+        if self._rstd is None or self._axes is None or self._m is None:
+            raise RuntimeError("batch_norm cached data missing.")
+
+        dy = self.result.grad
+        axes = self._axes
+        m = self._m
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(1, -1, *([1] * (a.ndim - 2)))
+            dyw = dy * w_broadcast
+        else:
+            dyw = dy
+
+        if self._use_batch_stats:
+            xhat = self._xhat
+            rstd = self._rstd
+            sum1 = lib_.sum(dyw, axis=axes, keepdims=True)
+            sum2 = lib_.sum(dyw * xhat, axis=axes, keepdims=True)
+            dx = (1.0 / m) * rstd * (m * dyw - sum1 - xhat * sum2)
+        else:
+            rstd = self._rstd
+            dx = dyw * rstd
+
+        reduce_axes = (0,) + tuple(range(2, a.ndim))
+        dweight = lib_.sum(
+            dy * (self._xhat if self._xhat is not None else 1.0), axis=reduce_axes
+        )
+        dbias = lib_.sum(dy, axis=reduce_axes)
+
+        return dx, None, None, dweight, dbias
+
+
+class group_norm_kernel(Operation):
+    def __init__(
+        self,
+        num_groups: int,
+        eps: float = 1e-5,
+        has_weight: bool = True,
+        has_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.num_groups = int(num_groups)
+        self.eps = float(eps)
+        self.has_weight = bool(has_weight)
+        self.has_bias = bool(has_bias)
+
+        self._xhat = None
+        self._rstd = None
+        self._group_shape = None
+        self._reduce_axes = None
+        self._m = None
+
+    def clear(self) -> None:
+        super().clear()
+        self._xhat = None
+        self._rstd = None
+        self._group_shape = None
+        self._reduce_axes = None
+        self._m = None
+
+    @func_op(n_in=3, n_ret=1, device="cpu")
+    def cpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=np, device="cpu")
+
+    @func_op(n_in=3, n_ret=1, device="gpu")
+    def gpu(self, a: Tensor, w: Tensor, b: Tensor) -> _FuncOpReturnType:
+        return self._forward(a, w, b, lib_=mx, device="gpu")
+
+    def _forward(
+        self,
+        a: Tensor,
+        w: Tensor,
+        b: Tensor,
+        lib_: ModuleType,
+        device: _DeviceType,
+    ) -> _FuncOpReturnType:
+        N, C, *spatial = a.shape
+        if C % self.num_groups != 0:
+            raise ValueError("num_groups must divide channels.")
+
+        group_size = C // self.num_groups
+        group_shape = (N, self.num_groups, group_size, *spatial)
+        x = a.data.reshape(group_shape)
+        reduce_axes = (2,) + tuple(range(3, x.ndim))
+        m = int(np.prod([x.shape[i] for i in reduce_axes]))
+
+        mean = lib_.mean(x, axis=reduce_axes, keepdims=True)
+        var = lib_.var(x, axis=reduce_axes, keepdims=True)
+        rstd = 1.0 / lib_.sqrt(var + self.eps)
+        xhat = (x - mean) * rstd
+
+        out = xhat.reshape(a.shape)
+        if self.has_weight:
+            out = out * w.data.reshape(1, C, *([1] * len(spatial)))
+        if self.has_bias:
+            out = out + b.data.reshape(1, C, *([1] * len(spatial)))
+
+        self._xhat = xhat
+        self._rstd = rstd
+        self._group_shape = group_shape
+        self._reduce_axes = reduce_axes
+        self._m = m
+
+        self.result = Tensor(out, device=device)
+        return self.result, functools.partial(self.__grad__, a=a, w=w, b=b, lib_=lib_)
+
+    def __grad__(self, a: Tensor, w: Tensor, b: Tensor, lib_: ModuleType) -> _GradType:
+        if self.result is None or self.result.grad is None:
+            raise RuntimeError("group_norm backward called before forward.")
+
+        if (
+            self._xhat is None
+            or self._rstd is None
+            or self._group_shape is None
+            or self._reduce_axes is None
+            or self._m is None
+        ):
+            raise RuntimeError("group_norm cached data missing.")
+
+        dy = self.result.grad
+        N, C, *spatial = a.shape
+        dy_g = dy.reshape(self._group_shape)
+        xhat = self._xhat
+        rstd = self._rstd
+        axes = self._reduce_axes
+        m = self._m
+
+        if self.has_weight:
+            w_broadcast = w.data.reshape(1, C, *([1] * len(spatial)))
+            dyw = dy * w_broadcast
+            dyw_g = dyw.reshape(self._group_shape)
+        else:
+            dyw_g = dy_g
+
+        sum1 = lib_.sum(dyw_g, axis=axes, keepdims=True)
+        sum2 = lib_.sum(dyw_g * xhat, axis=axes, keepdims=True)
+        dx_g = (1.0 / m) * rstd * (m * dyw_g - sum1 - xhat * sum2)
+        dx = dx_g.reshape(a.shape)
+
+        reduce_axes = (0,) + tuple(range(2, a.ndim))
+        dweight = lib_.sum(dy * xhat.reshape(a.shape), axis=reduce_axes)
+        dbias = lib_.sum(dy, axis=reduce_axes)
+
+        return dx, dweight, dbias
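Note on the added file: the backward pass in layer_norm_kernel.__grad__ uses the standard closed-form layer-norm input gradient, dx = (rstd / n) * (n * dyw - sum(dyw) - xhat * sum(dyw * xhat)). The following standalone NumPy sketch is not part of the package; it simply checks that closed form against a finite-difference estimate, with illustrative names and shapes.

# Standalone NumPy sketch: verify the closed-form layer-norm gradient
# against central finite differences for a small (4, 8) input.
import numpy as np

def layer_norm_ref(x, w, b, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    xhat = (x - mean) / np.sqrt(var + eps)
    return xhat * w + b

def layer_norm_dx(x, w, dy, eps=1e-5):
    n = x.shape[-1]
    mean = x.mean(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    xhat = (x - mean) * rstd
    dyw = dy * w
    sum1 = dyw.sum(axis=-1, keepdims=True)
    sum2 = (dyw * xhat).sum(axis=-1, keepdims=True)
    return (1.0 / n) * rstd * (n * dyw - sum1 - xhat * sum2)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
w = rng.normal(size=(8,))
b = rng.normal(size=(8,))
dy = rng.normal(size=(4, 8))

dx = layer_norm_dx(x, w, dy)
num = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    xp = x.copy(); xp[idx] += h
    xm = x.copy(); xm[idx] -= h
    # scalar loss L = sum(out * dy); numerical dL/dx[idx]
    num[idx] = ((layer_norm_ref(xp, w, b) * dy).sum()
                - (layer_norm_ref(xm, w, b) * dy).sum()) / (2 * h)
assert np.allclose(dx, num, atol=1e-5)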
@@ -8,7 +8,7 @@ import numpy as np
 from lucid._tensor import Tensor
 from lucid._backend.core import (
     Operation,
-    unary_func_op,
+    func_op,
     _FuncOpReturnType,
     _GradType,
 )
@@ -92,11 +92,7 @@ def _where(lib_: ModuleType, cond: _Array, x: _Array, y: _Array) -> _Array:


 def _pool_forward_sum(
-    lib_: ModuleType,
-    x_pad: _Array,
-    out_dims: _Shape,
-    kernel_size: _Shape,
-    stride: _Shape,
+    x_pad: _Array, out_dims: _Shape, kernel_size: _Shape, stride: _Shape
 ) -> _Array:
     out = None
     for k_idx in itertools.product(*[range(k) for k in kernel_size]):
@@ -211,7 +207,7 @@ def _pool_backward_max(
     return _crop_padding(grad_input_pad, padding)


-class pool_nd(Operation):
+class pool_nd_kernel(Operation):
     def __init__(
         self,
         kernel_size: int | tuple[int, ...] | list[int],
@@ -259,7 +255,7 @@ class pool_nd(Operation):

         return kernel, stride, padding

-    @unary_func_op()
+    @func_op(n_in=1, n_ret=1)
     def cpu(self, a: Tensor) -> _FuncOpReturnType:
         kernel, stride, padding = self._normalize(a)
         out_dims = _pool_out_dims(a.shape[2:], kernel, stride, padding)
@@ -268,7 +264,7 @@ class pool_nd(Operation):

         x_pad = _pad_input(np, a.data, padding)
         if self.mode == "avg":
-            out_sum = _pool_forward_sum(np, x_pad, out_dims, kernel, stride)
+            out_sum = _pool_forward_sum(x_pad, out_dims, kernel, stride)
             out = out_sum / _prod(kernel)
         else:
             out, max_idx = _pool_forward_max(np, x_pad, out_dims, kernel, stride)
@@ -277,7 +273,7 @@ class pool_nd(Operation):
         self.result = Tensor(out)
         return self.result, partial(self.__grad__, lib_=np)

-    @unary_func_op(device="gpu")
+    @func_op(n_in=1, n_ret=1, device="gpu")
     def gpu(self, a: Tensor) -> _FuncOpReturnType:
         kernel, stride, padding = self._normalize(a)
         out_dims = _pool_out_dims(a.shape[2:], kernel, stride, padding)
@@ -286,7 +282,7 @@ class pool_nd(Operation):

         x_pad = _pad_input(mx, a.data, padding)
         if self.mode == "avg":
-            out_sum = _pool_forward_sum(mx, x_pad, out_dims, kernel, stride)
+            out_sum = _pool_forward_sum(x_pad, out_dims, kernel, stride)
             out = out_sum / _prod(kernel)
         else:
             out, max_idx = _pool_forward_max(mx, x_pad, out_dims, kernel, stride)
@@ -350,19 +346,3 @@ class pool_nd(Operation):
         if self.mode == "avg":
             return out_elems * kernel_elems
         return out_elems * max(kernel_elems - 1, 0)
-
-
-def avg_pool_nd_op(
-    kernel_size: int | tuple[int, ...] | list[int],
-    stride: int | tuple[int, ...] | list[int],
-    padding: int | tuple[int, ...] | list[int],
-) -> pool_nd:
-    return pool_nd(kernel_size, stride, padding, mode="avg")
-
-
-def max_pool_nd_op(
-    kernel_size: int | tuple[int, ...] | list[int],
-    stride: int | tuple[int, ...] | list[int],
-    padding: int | tuple[int, ...] | list[int],
-) -> pool_nd:
-    return pool_nd(kernel_size, stride, padding, mode="max")
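Aside: the signature change above drops the unused lib_ argument from _pool_forward_sum, which accumulates one strided slice of the padded input per kernel offset. The following standalone NumPy sketch illustrates that shifted-slice summation for 2D average pooling; it is illustrative only and not the package implementation.

# Illustrative NumPy sketch of shifted-slice summation for 2D average pooling.
import itertools
import numpy as np

def avg_pool2d_ref(x, kernel=(2, 2), stride=(2, 2)):
    N, C, H, W = x.shape
    out_h = (H - kernel[0]) // stride[0] + 1
    out_w = (W - kernel[1]) // stride[1] + 1
    out = np.zeros((N, C, out_h, out_w), dtype=x.dtype)
    # One strided slice per kernel offset, accumulated into the output.
    for ki, kj in itertools.product(range(kernel[0]), range(kernel[1])):
        out += x[
            :, :,
            ki : ki + out_h * stride[0] : stride[0],
            kj : kj + out_w * stride[1] : stride[1],
        ]
    return out / (kernel[0] * kernel[1])

x = np.arange(2 * 3 * 4 * 4, dtype=np.float64).reshape(2, 3, 4, 4)
pooled = avg_pool2d_ref(x)
assert pooled.shape == (2, 3, 2, 2)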
@@ -54,6 +54,10 @@ def tanh(input_: Tensor) -> Tensor:
     return _activation.tanh(input_)


+def silu(input_: Tensor) -> Tensor:
+    return _activation.silu(input_)
+
+
 def softmax(input_: Tensor, axis: int = -1) -> Tensor:
     return _activation.softmax(input_, axis)

@@ -1,6 +1,12 @@
 import lucid

 from lucid._tensor import Tensor
+from lucid.nn._kernel.activation import (
+    softmax_kernel,
+    sigmoid_kernel,
+    gelu_kernel,
+    silu_kernel,
+)


 def relu(input_: Tensor) -> Tensor:
@@ -9,14 +15,14 @@ def relu(input_: Tensor) -> Tensor:

 def leaky_relu(input_: Tensor, negative_slope: float = 0.01) -> Tensor:
     mask = input_ > 0
-    out = input_ * mask + input_ * negative_slope * (1 - mask)
+    out = input_ * mask + input_ * negative_slope * (~mask)
     return out


 def elu(input_: Tensor, alpha: float = 1.0) -> Tensor:
     mask = input_ >= 0
     pos = input_ * mask
-    neg = alpha * (lucid.exp(input_) - 1) * (1 - mask)
+    neg = alpha * (lucid.exp(input_) - 1) * (~mask)
     return pos + neg


@@ -26,29 +32,29 @@ def selu(input_: Tensor) -> Tensor:

     mask = input_ >= 0
     pos = _scale * input_ * mask
-    neg = _scale * _alpha * (lucid.exp(input_) - 1) * (1 - mask)
+    neg = _scale * _alpha * (lucid.exp(input_) - 1) * (~mask)
     return pos + neg


 def gelu(input_: Tensor) -> Tensor:
-    c = lucid.sqrt(2 / lucid.pi).free()
-    return 0.5 * input_ * (1 + lucid.tanh(c * (input_ + 0.044715 * input_**3)))
+    op = gelu_kernel()
+    return op(input_)


 def sigmoid(input_: Tensor) -> Tensor:
-    return 1 / (1 + lucid.exp(-input_))
+    op = sigmoid_kernel()
+    return op(input_)


 def tanh(input_: Tensor) -> Tensor:
     return lucid.tanh(input_)


-def softmax(input_: Tensor, axis: int = -1) -> Tensor:
-    input_max = lucid.max(input_, axis=axis, keepdims=True)
-    input_stable = input_ - input_max
+def silu(input_: Tensor) -> Tensor:
+    op = silu_kernel()
+    return op(input_)

-    e_input = lucid.exp(input_stable)
-    sum_e_input = e_input.sum(axis=axis, keepdims=True)

-    output = e_input / sum_e_input
-    return output
+def softmax(input_: Tensor, axis: int = -1) -> Tensor:
+    op = softmax_kernel(axis=axis)
+    return op(input_)
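Context: the functional gelu, sigmoid, silu, and softmax now dispatch to dedicated kernel Operations instead of composing elementary tensor ops. For reference, the removed softmax body computed the standard numerically stable form (subtract the per-axis max before exponentiating). A standalone NumPy restatement, not the package code, is shown below; the new softmax_kernel is presumably expected to produce the same values.

# Reference-only NumPy sketch of the numerically stable softmax that the
# removed pure-tensor implementation computed.
import numpy as np

def softmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_stable = x - x.max(axis=axis, keepdims=True)  # guard against overflow in exp
    e = np.exp(x_stable)
    return e / e.sum(axis=axis, keepdims=True)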
@@ -1,7 +1,10 @@
 import math
+
 import lucid
 import lucid.nn.functional as F

+from lucid.nn._kernel.attention import scaled_dot_product_attention_kernel
+
 from lucid._tensor import Tensor


@@ -14,6 +17,12 @@ def scaled_dot_product_attention(
     is_causal: bool = False,
     scale: float | None = None,
 ) -> Tensor:
+    if dropout_p == 0.0:
+        op = scaled_dot_product_attention_kernel(
+            attn_mask=attn_mask, is_causal=is_causal, scale=scale
+        )
+        return op(query, key, value)
+
     L, S = query.shape[-2], key.shape[-2]
     scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
     attn_bias = lucid.zeros(L, S, dtype=query.dtype).free()
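Context: the dropout-free path now routes through scaled_dot_product_attention_kernel, while the existing composite implementation remains as the fallback. The math of that path is the usual softmax(Q K^T * scale) V with scale defaulting to 1/sqrt(d), as the fallback code shows. The sketch below is an illustrative NumPy reference only (attn_mask handling is omitted and causal masking is simplified); it is not the kernel itself.

# Illustrative NumPy reference for the dropout-free attention path.
import math
import numpy as np

def sdpa_ref(q, k, v, scale=None, is_causal=False):
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    scores = q @ np.swapaxes(k, -1, -2) * scale
    if is_causal:
        L, S = scores.shape[-2], scores.shape[-1]
        scores = np.where(np.tril(np.ones((L, S), dtype=bool)), scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v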
@@ -3,7 +3,7 @@ from typing import Tuple, Optional

 import lucid
 from lucid._tensor import Tensor
-from lucid._backend.conv import conv_nd_op
+from lucid.nn._kernel.conv import conv_nd_kernel


 def unfold(
@@ -66,17 +66,6 @@ def unfold(
     return col.reshape((N_out, C_filt))


-def _conv_tensor(
-    input_: Tensor,
-    weight: Tensor,
-    stride: Tuple[int, ...],
-    padding: Tuple[int, ...],
-    dilation: Tuple[int, ...],
-    groups: int,
-) -> Tensor:
-    return conv_nd_op(stride, padding, dilation, groups)(input_, weight)
-
-
 def conv(
     input_: Tensor,
     weight: Tensor,
@@ -92,7 +81,7 @@ def conv(
     if len(stride) != len(padding) or len(stride) != len(dilation):
         raise ValueError("Stride, padding, and dilation must have the same length.")

-    out = _conv_tensor(input_, weight, stride, padding, dilation, groups)
+    out = conv_nd_kernel(stride, padding, dilation, groups)(input_, weight)

     if bias is not None:
         bias_sh = [1, weight.shape[0]] + [1] * (input_.ndim - 2)
@@ -181,9 +170,9 @@ def conv_transpose(
         zeros = lucid.zeros(*zero_shape, dtype=ups.dtype, device=ups.device)
         ups = lucid.concatenate([ups, zeros], axis=axis)

-        out_g = _conv_tensor(
-            ups, w_t, stride=(1,) * D, padding=pad_, dilation=dilation, groups=1
-        )
+        out_g = conv_nd_kernel(
+            stride=(1,) * D, padding=pad_, dilation=dilation, groups=1
+        )(ups, w_t)
         outputs.append(out_g)

     output = lucid.concatenate(outputs, axis=1)
@@ -3,6 +3,12 @@ from typing import Literal
 import lucid
 from lucid._tensor import Tensor

+from lucid.nn._kernel.loss import (
+    cross_entropy_kernel,
+    binary_cross_entropy_kernel,
+    binary_cross_entropy_with_logits_kernel,
+)
+
 _ReductionType = Literal["mean", "sum"]


@@ -55,13 +61,14 @@ def binary_cross_entropy(
     reduction: _ReductionType | None = "mean",
     eps: float = 1e-7,
 ) -> Tensor:
-    input_ = lucid.clip(input_, eps, 1 - eps)
-    loss = -target * lucid.log(input_) - (1 - target) * lucid.log(1 - input_)
+    has_weight = weight is not None
+    if weight is None:
+        weight = lucid.ones_like(input_, device=input_.device)

-    if weight is not None:
-        loss *= weight
-
-    return _loss_reduction(loss, reduction)
+    op = binary_cross_entropy_kernel(
+        reduction=reduction, eps=eps, has_weight=has_weight
+    )
+    return op(input_, target, weight)


 def binary_cross_entropy_with_logits(
@@ -71,19 +78,17 @@ def binary_cross_entropy_with_logits(
     pos_weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
 ) -> Tensor:
-    max_val = lucid.maximum(-input_, 0)
-    sp = max_val + lucid.log(lucid.exp(-max_val) + lucid.exp(-input_ - max_val))
+    has_weight = weight is not None
+    has_pos_weight = pos_weight is not None
+    if weight is None:
+        weight = lucid.ones_like(input_, device=input_.device)
+    if pos_weight is None:
+        pos_weight = lucid.ones_like(input_, device=input_.device)

-    if pos_weight is not None:
-        coeff = 1 + (pos_weight - 1) * target
-        loss = (1 - target) * input_ + coeff * sp
-    else:
-        loss = lucid.maximum(input_, 0) - input_ * target + sp
-
-    if weight is not None:
-        loss *= weight
-
-    return _loss_reduction(loss, reduction)
+    op = binary_cross_entropy_with_logits_kernel(
+        reduction=reduction, has_weight=has_weight, has_pos_weight=has_pos_weight
+    )
+    return op(input_, target, weight, pos_weight)


 def cross_entropy(
@@ -94,20 +99,14 @@ def cross_entropy(
     eps: float = 1e-7,
     ignore_index: int | None = None,
 ) -> Tensor:
-    exp_logits = lucid.exp(input_ - lucid.max(input_, axis=1, keepdims=True))
-    prob = exp_logits / lucid.sum(exp_logits, axis=1, keepdims=True)
-
-    indices = lucid.arange(input_.shape[0], device=input_.device).astype(lucid.Int)
-    target_int = target.astype(lucid.Int)
-
-    loss = -lucid.log(prob[indices, target_int] + eps)
-    if weight is not None:
-        loss *= weight[target_int]
-
-    if ignore_index is not None:
-        return _ignore_index_loss(loss, target_int, ignore_index, reduction)
-
-    return _loss_reduction(loss, reduction)
+    has_weight = weight is not None
+    if weight is None:
+        weight = lucid.ones((input_.shape[1],), device=input_.device)
+
+    op = cross_entropy_kernel(
+        reduction=reduction, eps=eps, ignore_index=ignore_index, has_weight=has_weight
+    )
+    return op(input_, target, weight)


 def nll_loss(
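Context: the loss functions now pass an all-ones placeholder weight tensor and delegate to the fused loss kernels. For reference, the removed cross_entropy body computed a numerically stable softmax over axis 1 followed by the negative log-likelihood of the target class, with optional per-class weights. The NumPy restatement below is illustrative only; the new cross_entropy_kernel is presumably expected to match it up to the eps term and reduction handling.

# Reference-only NumPy restatement of the removed cross_entropy body
# ("mean" reduction, no ignore_index handling).
import numpy as np

def cross_entropy_ref(logits, target, weight=None, eps=1e-7):
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    prob = e / e.sum(axis=1, keepdims=True)
    idx = np.arange(logits.shape[0])
    loss = -np.log(prob[idx, target.astype(int)] + eps)
    if weight is not None:
        loss = loss * weight[target.astype(int)]
    return loss.mean()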