lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,416 @@
+ import functools
+ from types import ModuleType
+
+ import numpy as np
+
+ from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
+ from lucid._backend.metal import mx
+ from lucid._tensor import Tensor
+ from lucid.types import _DeviceType, _TensorData
+
+
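+ # Cast index tensors to the backend's integer dtype; int32 on the Metal/MLX
+ # path is presumably chosen to match MLX's preferred index width.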
+ def _to_int(arr: _TensorData, lib_: ModuleType) -> _TensorData:
+     if lib_ is np:
+         return arr.astype(np.int64)
+     return arr.astype(mx.int32)
+
+
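+ # Fused cross-entropy over 2D logits [N, C]: the forward pass computes a
+ # numerically stable log-softmax and caches probs/targets/weights for backward.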
+ class cross_entropy_kernel(Operation):
+     def __init__(
+         self,
+         reduction: str | None = "mean",
+         eps: float = 1e-7,
+         ignore_index: int | None = None,
+         has_weight: bool = True,
+     ) -> None:
+         super().__init__()
+         self.reduction = reduction
+         self.eps = float(eps)
+         self.ignore_index = ignore_index
+         self.has_weight = bool(has_weight)
+
+         self._log_probs = None
+         self._probs = None
+         self._target = None
+         self._weight = None
+         self._valid_count = None
+
+     def clear(self) -> None:
+         super().clear()
+         self._log_probs = None
+         self._probs = None
+         self._target = None
+         self._weight = None
+         self._valid_count = None
+
+     @func_op(n_in=3, n_ret=1)
+     def cpu(self, logits: Tensor, target: Tensor, weight: Tensor) -> _FuncOpReturnType:
+         return self._forward(logits, target, weight, lib_=np, device="cpu")
+
+     @func_op(n_in=3, n_ret=1, device="gpu")
+     def gpu(self, logits: Tensor, target: Tensor, weight: Tensor) -> _FuncOpReturnType:
+         return self._forward(logits, target, weight, lib_=mx, device="gpu")
+
+     def _forward(
+         self,
+         logits: Tensor,
+         target: Tensor,
+         weight: Tensor,
+         lib_: ModuleType,
+         device: _DeviceType,
+     ) -> _FuncOpReturnType:
+         if logits.ndim != 2:
+             raise ValueError("cross_entropy expects 2D logits [N, C].")
+
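+         # Stable log-softmax via the max-shift trick:
+         # log p[i, j] = (x[i, j] - max_i) - log(sum_j exp(x[i, j] - max_i)).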
+         N, _ = logits.shape
+         max_val = lib_.max(logits.data, axis=1, keepdims=True)
+         exp_x = lib_.exp(logits.data - max_val)
+         sum_exp = lib_.sum(exp_x, axis=1, keepdims=True)
+         log_probs = (logits.data - max_val) - lib_.log(sum_exp)
+         probs = exp_x / sum_exp
+
+         target_int = _to_int(target.data, lib_)
+         if lib_ is np:
+             idx = np.arange(N, dtype=np.int64)
+         else:
+             idx = mx.arange(N, dtype=mx.int32)
+
+         if self.ignore_index is not None:
+             mask = target_int != self.ignore_index
+         else:
+             mask = None
+
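+         # NLL gather: the per-sample loss is -log p[i, target[i]].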
+         gather = log_probs[idx, target_int]
+         loss = -gather
+
+         if self.has_weight:
+             w = weight.data
+             loss = loss * w[target_int]
+         else:
+             w = None
+
+         if mask is not None:
+             loss = loss * mask.astype(loss.dtype)
+
+         if self.reduction is None:
+             out = loss
+             valid_count = None
+         elif self.reduction == "sum":
+             out = lib_.sum(loss)
+             valid_count = None
+         elif self.reduction == "mean":
+             if mask is None:
+                 valid_count = N
+             else:
+                 valid_count = lib_.sum(mask)
+             if hasattr(valid_count, "item") and valid_count.item() == 0:
+                 # Every position is ignored: return a zero loss, but still cache
+                 # the forward state so the backward pass does not crash.
+                 out = lib_.zeros((), dtype=loss.dtype)
+                 self._log_probs = log_probs
+                 self._probs = probs
+                 self._target = target_int
+                 self._weight = w
+                 self._valid_count = valid_count
+                 self.result = Tensor(out, device=device)
+                 return self.result, functools.partial(self.__grad__, lib_=lib_)
+             out = lib_.sum(loss) / valid_count
+         else:
+             raise ValueError("Invalid reduction type. Choose 'mean', 'sum', or 'none'.")
+
+         self._log_probs = log_probs
+         self._probs = probs
+         self._target = target_int
+         self._weight = w
+         self._valid_count = valid_count
+
+         self.result = Tensor(out, device=device)
+         return self.result, functools.partial(self.__grad__, lib_=lib_)
+
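+     # Backward: d(-log p[i, t_i]) / d logits = softmax(logits) - onehot(t_i),
+     # then scaled by class weights, the ignore mask, and the reduction factor.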
+     def __grad__(self, lib_: ModuleType) -> _GradType:
+         if self.result is None or self.result.grad is None:
+             raise RuntimeError("cross_entropy backward called before forward.")
+         if self._probs is None or self._target is None:
+             raise RuntimeError("cross_entropy cached data missing.")
+
+         probs = self._probs
+         target = self._target
+         N = probs.shape[0]
+         C = probs.shape[1]
+
+         if lib_ is np:
+             grad_input = probs.copy()
+             idx = np.arange(N, dtype=np.int64)
+         else:
+             grad_input = mx.array(probs)
+             idx = mx.arange(N, dtype=mx.int32)
+
+         grad_input[idx, target] = grad_input[idx, target] - 1
+
+         if self._weight is not None:
+             grad_input = grad_input * self._weight[target][:, None]
+
+         if self.ignore_index is not None:
+             mask = (target != self.ignore_index).astype(grad_input.dtype)
+             grad_input = grad_input * mask[:, None]
+
+         if self.reduction is None:
+             go = self.result.grad
+             grad_input = grad_input * go[:, None]
+         else:
+             go = self.result.grad
+             if self.reduction == "mean":
+                 if self._valid_count is None:
+                     grad_input = grad_input / N
+                 elif hasattr(self._valid_count, "item") and self._valid_count.item() == 0:
+                     # All positions ignored: the loss was zero, so the gradient is too.
+                     grad_input = grad_input * 0
+                 else:
+                     grad_input = grad_input / self._valid_count
+             grad_input = grad_input * go
+
+         dweight = None
+         if self._weight is not None:
+             if lib_ is np:
+                 dweight = np.zeros((C,), dtype=grad_input.dtype)
+             else:
+                 dweight = mx.zeros((C,), dtype=grad_input.dtype)
+
+             if self.reduction is None:
+                 go_vec = go
+             else:
+                 go_vec = None
+
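+             # d loss / d weight[c] accumulates -log p[i, c] over samples with
+             # target c, since loss_i = -w[t_i] * log p[i, t_i].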
+             for c in range(C):
+                 if lib_ is np:
+                     mask_c = target == c
+                     if self.ignore_index is not None:
+                         mask_c = mask_c & (target != self.ignore_index)
+                     if go_vec is None:
+                         contrib = -self._log_probs[mask_c, c]
+                         if self.reduction == "mean":
+                             denom = (
+                                 self._valid_count
+                                 if self._valid_count is not None
+                                 else N
+                             )
+                             contrib = contrib / denom
+                         dweight[c] = np.sum(contrib) * go
+                     else:
+                         contrib = -self._log_probs[mask_c, c]
+                         dweight[c] = np.sum(contrib * go_vec[mask_c])
+                 else:
+                     mask_c = target == c
+                     if self.ignore_index is not None:
+                         mask_c = mask_c & (target != self.ignore_index)
+
+                     contrib = -self._log_probs[mask_c, c]
+                     if self.reduction == "mean":
+                         denom = (
+                             self._valid_count
+                             if self._valid_count is not None
+                             else N
+                         )
+                         contrib = contrib / denom
+
+                     if go_vec is None:
+                         dweight = dweight.at[c].add(mx.sum(contrib) * go)
+                     else:
+                         dweight = dweight.at[c].add(
+                             mx.sum(contrib * go_vec[mask_c])
+                         )
+
+         return grad_input, None, dweight
+
+
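+ # Elementwise binary cross-entropy on probabilities; inputs are clipped to
+ # [eps, 1 - eps] so the logarithms stay finite.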
+ class binary_cross_entropy_kernel(Operation):
+     def __init__(
+         self,
+         reduction: str | None = "mean",
+         eps: float = 1e-7,
+         has_weight: bool = True,
+     ) -> None:
+         super().__init__()
+         self.reduction = reduction
+         self.eps = float(eps)
+         self.has_weight = bool(has_weight)
+
+         self._input = None
+         self._target = None
+         self._weight = None
+
+     def clear(self) -> None:
+         super().clear()
+         self._input = None
+         self._target = None
+         self._weight = None
+
+     @func_op(n_in=3, n_ret=1)
+     def cpu(self, input_: Tensor, target: Tensor, weight: Tensor) -> _FuncOpReturnType:
+         return self._forward(input_, target, weight, lib_=np, device="cpu")
+
+     @func_op(n_in=3, n_ret=1, device="gpu")
+     def gpu(self, input_: Tensor, target: Tensor, weight: Tensor) -> _FuncOpReturnType:
+         return self._forward(input_, target, weight, lib_=mx, device="gpu")
+
+     def _forward(
+         self,
+         input_: Tensor,
+         target: Tensor,
+         weight: Tensor,
+         lib_: ModuleType,
+         device: _DeviceType,
+     ) -> _FuncOpReturnType:
+         x = input_.data
+         t = target.data
+         x = lib_.clip(x, self.eps, 1.0 - self.eps)
+         loss = -(t * lib_.log(x) + (1 - t) * lib_.log(1 - x))
+
+         if self.has_weight:
+             loss = loss * weight.data
+
+         if self.reduction is None:
+             out = loss
+         elif self.reduction == "sum":
+             out = lib_.sum(loss)
+         elif self.reduction == "mean":
+             out = lib_.mean(loss)
+         else:
+             raise ValueError("Invalid reduction type. Choose 'mean', 'sum', or 'none'.")
+
+         self._input = x
+         self._target = t
+         self._weight = weight.data if self.has_weight else None
+
+         self.result = Tensor(out, device=device)
+         return self.result, self.__grad__
+
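+     # d/dx[-(t * log x + (1 - t) * log(1 - x))] = (x - t) / (x * (1 - x)).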
+     def __grad__(self) -> _GradType:
+         if self.result is None or self.result.grad is None:
+             raise RuntimeError("binary_cross_entropy backward called before forward.")
+         if self._input is None or self._target is None:
+             raise RuntimeError("binary_cross_entropy cached data missing.")
+
+         x = self._input
+         t = self._target
+         grad = (x - t) / (x * (1 - x))
+
+         if self._weight is not None:
+             grad = grad * self._weight
+
+         if self.reduction == "mean":
+             grad = grad * (self.result.grad / x.size)
+         else:
+             grad = grad * self.result.grad
+
+         return grad, None, None
+
+
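+ # Binary cross-entropy taken directly on logits, using a stable softplus so
+ # large-magnitude logits do not overflow.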
+ class binary_cross_entropy_with_logits_kernel(Operation):
+     def __init__(
+         self,
+         reduction: str | None = "mean",
+         has_weight: bool = True,
+         has_pos_weight: bool = True,
+     ) -> None:
+         super().__init__()
+         self.reduction = reduction
+         self.has_weight = bool(has_weight)
+         self.has_pos_weight = bool(has_pos_weight)
+
+         self._logits = None
+         self._target = None
+         self._weight = None
+         self._pos_weight = None
+
+     def clear(self) -> None:
+         super().clear()
+         self._logits = None
+         self._target = None
+         self._weight = None
+         self._pos_weight = None
+
+     @func_op(n_in=4, n_ret=1)
+     def cpu(
+         self, logits: Tensor, target: Tensor, weight: Tensor, pos_weight: Tensor
+     ) -> _FuncOpReturnType:
+         return self._forward(logits, target, weight, pos_weight, lib_=np, device="cpu")
+
+     @func_op(n_in=4, n_ret=1, device="gpu")
+     def gpu(
+         self, logits: Tensor, target: Tensor, weight: Tensor, pos_weight: Tensor
+     ) -> _FuncOpReturnType:
+         return self._forward(logits, target, weight, pos_weight, lib_=mx, device="gpu")
+
+     def _forward(
+         self,
+         logits: Tensor,
+         target: Tensor,
+         weight: Tensor,
+         pos_weight: Tensor,
+         lib_: ModuleType,
+         device: _DeviceType,
+     ) -> _FuncOpReturnType:
+         x = logits.data
+         t = target.data
+
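+         # Stable softplus(-x) = log(1 + exp(-x)), shifted by max(-x, 0) so the
+         # exponentials never overflow.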
+         max_val = lib_.maximum(-x, 0)
+         sp = max_val + lib_.log(lib_.exp(-max_val) + lib_.exp(-x - max_val))
+
+         if self.has_pos_weight:
+             pw = pos_weight.data
+             coeff = 1 + (pw - 1) * t
+             loss = (1 - t) * x + coeff * sp
+         else:
+             pw = None
+             # BCE-with-logits is (1 - t) * x + log(1 + exp(-x)); sp is already the
+             # stable form of the log term, so no extra max(x, 0) shift is needed.
+             loss = (1 - t) * x + sp
+
+         if self.has_weight:
+             loss = loss * weight.data
+
+         if self.reduction is None:
+             out = loss
+         elif self.reduction == "sum":
+             out = lib_.sum(loss)
+         elif self.reduction == "mean":
+             out = lib_.mean(loss)
+         else:
+             raise ValueError("Invalid reduction type. Choose 'mean', 'sum', or 'none'.")
+
+         self._logits = x
+         self._target = t
+         self._weight = weight.data if self.has_weight else None
+         self._pos_weight = pw
+
+         self.result = Tensor(out, device=device)
+         return self.result, functools.partial(self.__grad__, lib_=lib_)
+
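+     # Backward: with sig = sigmoid(x), d loss / dx = (1 - t) + coeff * (sig - 1),
+     # where coeff = 1 + (pw - 1) * t; this reduces to sig - t without pos_weight.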
+     def __grad__(self, lib_: ModuleType) -> _GradType:
+         if self.result is None or self.result.grad is None:
+             raise RuntimeError(
+                 "binary_cross_entropy_with_logits backward called before forward."
+             )
+         if self._logits is None or self._target is None:
+             raise RuntimeError("binary_cross_entropy_with_logits cached data missing.")
+
+         x = self._logits
+         t = self._target
+         sig = 1.0 / (1.0 + lib_.exp(-x))
+
+         if self._pos_weight is not None:
+             pw = self._pos_weight
+             # Exact derivative of (1 - t) * x + (1 + (pw - 1) * t) * softplus(-x);
+             # equals (sig - t) * (1 + (pw - 1) * t) for hard 0/1 targets.
+             grad = (1 - t) + (1 + (pw - 1) * t) * (sig - 1)
+         else:
+             grad = sig - t
+
+         if self._weight is not None:
+             grad = grad * self._weight
+
+         if self.reduction == "mean":
+             grad = grad * (self.result.grad / x.size)
+         else:
+             grad = grad * self.result.grad
+
+         return grad, None, None, None