py-sadl 1.0.2__py3-none-any.whl

sadl/grad_ops.py ADDED
@@ -0,0 +1,1158 @@
1
+ """Contains all operations that support gradient calculation.
2
+
3
+ Uses numpy as the backend.
4
+
5
+ The OpType enum is inspired by tinygrad's op categorization, thanks @tinygrad!
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass
12
+ from enum import Enum
13
+ from typing import TYPE_CHECKING, Any, Literal
14
+
15
+ from .backend import xp
16
+ from .utils import copy_array
17
+
18
+ if TYPE_CHECKING:
19
+ from .tensor import Tensor
20
+
21
+
22
+ # Type alias for gradient operations (returns plain arrays, not Tensors)
23
+ # Gradients are raw numerical buffers without computation graph overhead
24
+ GradOp = Callable[..., tuple["xp.ndarray | None", ...]]
25
+
26
+
27
+ class OpType(Enum):
28
+ """Operation category by computational behavior.
29
+
30
+ Inspired by tinygrad's op categorization.
31
+ """
32
+
33
+ ELEMENTWISE = "elementwise" # Point-wise: add, mul, sin, etc.
34
+ REDUCTION = "reduction" # Dimension reduction: sum, mean, max, etc.
35
+ MOVEMENT = "movement" # Data movement: copy_to_device, reshape, etc.
36
+ LINALG = "linalg" # Linear algebra: matmul, etc.
37
+
38
+
39
+ class OpInputs(Enum):
40
+ """Number of tensor inputs to an operation.
41
+
42
+ The enum value equals the input count, e.g. `OpInputs.BINARY.value == 2`.
43
+ """
44
+
45
+ UNARY = 1
46
+ BINARY = 2
47
+ TERNARY = 3
48
+
49
+
50
+ @dataclass(frozen=True)
51
+ class GradOpSpec:
52
+ """Specification for a gradient operation.
53
+
54
+ Attributes:
55
+ backward_fn (GradOp): The gradient computation function.
56
+ op_type (OpType): Operation category (elementwise, reduction, etc.).
57
+ op_inputs (OpInputs): Number of inputs (unary, binary, ternary).
58
+ forward_names (tuple[str, ...]): Forward op names mapping to this backward.
59
+ First name is canonical, others are aliases.
60
+ constraints (dict[str, str] | None): Input constraints for testing.
61
+ Maps input name to constraint type, e.g. ``{"x": "positive"}``.
62
+ skip_test (bool): Whether to skip automated finite difference testing.
63
+ skip_reason (str | None): Reason for skipping. Required if skip_test=True.
64
+ """
65
+
66
+ backward_fn: GradOp
67
+ op_type: OpType
68
+ op_inputs: OpInputs
69
+ forward_names: tuple[str, ...]
70
+ constraints: dict[str, str] | None = None
71
+ skip_test: bool = False
72
+ skip_reason: str | None = None
73
+
74
+ def __post_init__(self) -> None:
75
+ """Validate that skip_reason is provided when skip_test is True.
76
+
77
+ Raises:
78
+ ValueError: If skip_test is True but skip_reason is None or empty.
79
+ """
80
+ if self.skip_test and not self.skip_reason:
81
+ raise ValueError("skip_reason is required when skip_test=True")
82
+
83
+
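For illustration, a minimal sketch of the skip_test/skip_reason invariant enforced by `__post_init__` above; the placeholder lambda and the "noop" name are hypothetical, and the `sadl.grad_ops` module path is taken from the diff header.

```python
from sadl.grad_ops import GradOpSpec, OpInputs, OpType

try:
    GradOpSpec(
        backward_fn=lambda *args, **kwargs: (None,),  # placeholder backward fn
        op_type=OpType.ELEMENTWISE,
        op_inputs=OpInputs.UNARY,
        forward_names=("noop",),
        skip_test=True,  # skipping without giving a reason is rejected
    )
except ValueError as err:
    print(err)  # -> skip_reason is required when skip_test=True
```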
84
+ # The registry maps forward op names to their gradient specifications
85
+ _GRAD_OPS_REGISTRY: dict[str, GradOpSpec] = {}
86
+
87
+
88
+ def register_grad_op( # noqa: PLR0913
89
+ *,
90
+ op_type: OpType,
91
+ op_inputs: OpInputs,
92
+ forward_names: tuple[str, ...] | None = None,
93
+ constraints: dict[str, str] | None = None,
94
+ skip_test: bool = False,
95
+ skip_reason: str | None = None,
96
+ ) -> Callable[[GradOp], GradOp]:
97
+ """Decorator factory to register a gradient operation with metadata.
98
+
99
+ The decorated function should follow the naming convention ``<operation>_backward``.
100
+ It will be registered under all provided forward_names, or under the operation
101
+ name extracted from the function name if forward_names is None.
102
+
103
+ Args:
104
+ op_type (OpType): Operation category (elementwise, reduction, etc.).
105
+ op_inputs (OpInputs): Number of tensor inputs (unary, binary, ternary).
106
+ forward_names (tuple[str, ...] | None): Forward op names to register under.
107
+ If None, extracted from function name.
108
+ constraints (dict[str, str] | None): Input constraints for testing.
109
+ skip_test (bool): Whether to skip automated finite difference testing.
110
+ skip_reason (str | None): Reason for skipping. Required if skip_test=True.
111
+
112
+ Returns:
113
+ Callable[[GradOp], GradOp]: Decorator that registers the grad op.
114
+
115
+ Raises:
116
+ ValueError: If skip_test=True but skip_reason is not provided.
117
+ """
118
+ if skip_test and not skip_reason:
119
+ raise ValueError("skip_reason is required when skip_test=True")
120
+
121
+ def decorator(func: GradOp) -> GradOp:
122
+ canonical_name = func.__name__.rsplit("_", maxsplit=1)[0]
123
+ names = forward_names if forward_names is not None else (canonical_name,)
124
+
125
+ spec = GradOpSpec(
126
+ backward_fn=func,
127
+ op_type=op_type,
128
+ op_inputs=op_inputs,
129
+ forward_names=names,
130
+ constraints=constraints,
131
+ skip_test=skip_test,
132
+ skip_reason=skip_reason,
133
+ )
134
+
135
+ for name in names:
136
+ _GRAD_OPS_REGISTRY[name] = spec
137
+
138
+ return func
139
+
140
+ return decorator
141
+
142
+
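A minimal sketch of how registration is meant to be used, based on the decorator above; `clip_backward` and its `lo`/`hi` kwargs are hypothetical and exist only for illustration.

```python
from sadl.grad_ops import OpInputs, OpType, get_grad_op_spec, register_grad_op


@register_grad_op(op_type=OpType.ELEMENTWISE, op_inputs=OpInputs.UNARY)
def clip_backward(*inputs, compute_grad, grad_out, **kwargs):
    """Illustrative backward for a clip op: the gradient passes through inside the bounds."""
    x = inputs[0]
    lo, hi = kwargs.get("lo", 0.0), kwargs.get("hi", 1.0)
    mask = (x >= lo) & (x <= hi)
    return (grad_out * mask if compute_grad[0] else None,)


# The canonical name "clip" is derived from the function name "clip_backward".
spec = get_grad_op_spec("clip")
assert spec is not None and spec.op_type is OpType.ELEMENTWISE
```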
143
+ def normalize_grad_op_name(*, name: str, is_reduce: bool = False) -> str:
144
+ """Normalize operation name for registry lookup.
145
+
146
+ Handles the special case where "add" with is_reduce=True maps to "sum".
147
+
148
+ Args:
149
+ name (str): The operation name.
150
+ is_reduce (bool): Whether the operation is a reduction.
151
+
152
+ Returns:
153
+ str: The normalized operation name.
154
+
155
+ Examples:
156
+ >>> normalize_grad_op_name(name="add", is_reduce=True)
157
+ "sum"
158
+ >>> normalize_grad_op_name(name="add", is_reduce=False)
159
+ "add"
160
+ """
161
+ if name == "add" and is_reduce:
162
+ return "sum"
163
+ return name
164
+
165
+
166
+ def get_grad_op(name: str) -> GradOp | None:
167
+ """Get the backward function for a forward operation.
168
+
169
+ Args:
170
+ name (str): Forward operation name (e.g. "add", "matmul").
171
+
172
+ Returns:
173
+ GradOp | None: The gradient function, or None if not found.
174
+ """
175
+ spec = _GRAD_OPS_REGISTRY.get(normalize_grad_op_name(name=name))
176
+ return spec.backward_fn if spec is not None else None
177
+
178
+
179
+ def get_grad_op_spec(name: str) -> GradOpSpec | None:
180
+ """Get the full specification for a gradient operation.
181
+
182
+ Args:
183
+ name (str): Forward operation name.
184
+
185
+ Returns:
186
+ GradOpSpec | None: The full specification, or None if not found.
187
+ """
188
+ return _GRAD_OPS_REGISTRY.get(normalize_grad_op_name(name=name))
189
+
190
+
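Assuming the built-in ops below are registered at import time (as the decorators in this file suggest), lookups behave as follows; this is a sketch, not package documentation.

```python
from sadl.grad_ops import get_grad_op, get_grad_op_spec

assert get_grad_op("matmul") is not None      # registered further down in this file
assert get_grad_op("does_not_exist") is None  # unknown ops simply return None

# Aliases resolve to the same GradOpSpec instance: "abs" is an alias of "absolute".
assert get_grad_op_spec("abs") is get_grad_op_spec("absolute")
```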
191
+ def _broadcast_backward(
192
+ x: Tensor,
193
+ grad_out: xp.ndarray,
194
+ ) -> xp.ndarray:
195
+ """Applies a backward gradient operation on broadcasting.
196
+
197
+ Effectively collapses `grad_out` by summing over all
198
+ dimensions of `x` that were broadcasted.
199
+
200
+ Args:
201
+ x (Tensor): The Tensor that was broadcasted.
202
+ grad_out (xp.ndarray): The gradient of the following operation.
203
+
204
+ Returns:
205
+ xp.ndarray: The computed gradient.
206
+ """
207
+ if x.shape == grad_out.shape:
208
+ return grad_out # shapes are the same, no broadcast happened
209
+
210
+ collapse_dim: list[int] = []
211
+ for i in range(max(x.ndim, grad_out.ndim)):
212
+ idx_x = x.ndim - i - 1
213
+ idx_grad_out = grad_out.ndim - i - 1
214
+ if idx_x < 0 or x.shape[idx_x] < grad_out.shape[idx_grad_out]:
215
+ collapse_dim.append(idx_grad_out)
216
+
217
+ return xp.sum(grad_out, axis=tuple(collapse_dim), keepdims=True).reshape(x.shape)
218
+
219
+
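A small shape check of the collapse performed above, assuming `xp` is plain NumPy and that a bare ndarray can stand in for `Tensor` (only `.shape` and `.ndim` are used here); `_broadcast_backward` is private, so this is purely illustrative.

```python
import numpy as np

from sadl.grad_ops import _broadcast_backward

x = np.ones((1, 3))                      # was broadcast to (2, 3) in the forward pass
grad_out = np.arange(6.0).reshape(2, 3)  # upstream gradient has the broadcast shape

grad_x = _broadcast_backward(x, grad_out)
assert grad_x.shape == x.shape
# The broadcast dimension is summed away: grad_x == column sums of grad_out.
np.testing.assert_allclose(grad_x, grad_out.sum(axis=0, keepdims=True))
```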
220
+ def broadcastable(
221
+ elem_wise_backward_fn: Callable[..., Any],
222
+ ) -> Callable[..., Any]:
223
+ """A decorator to extend element-wise backward gradient computing functions.
224
+
225
+ This decorator adds the ability to support broadcasting.
226
+
227
+ Args:
228
+ elem_wise_backward_fn (Callable[..., Any]): The backward
229
+ function of an element-wise operation that should support
230
+ broadcasting.
231
+
232
+ Returns:
233
+ Callable[..., Any]: The wrapper function
234
+ that supports broadcasting.
235
+ """
236
+
237
+ def wrapper(
238
+ *inputs: Tensor,
239
+ compute_grad: tuple[bool, bool],
240
+ grad_out: xp.ndarray,
241
+ **kwargs: Any,
242
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
243
+ x, y = inputs
244
+ grad_x, grad_y = elem_wise_backward_fn(
245
+ *inputs,
246
+ compute_grad=compute_grad,
247
+ grad_out=grad_out,
248
+ **kwargs,
249
+ )
250
+ grad_x = _broadcast_backward(x, grad_x) if grad_x is not None else None
251
+ grad_y = _broadcast_backward(y, grad_y) if grad_y is not None else None
252
+ return grad_x, grad_y
253
+
254
+ # Trick so that the register_grad_op decorator registers the
255
+ # now broadcastable backward function under the name of the
256
+ # original backward function `elem_wise_backward_fn` (e.g. `add_backward`),
257
+ # instead of under the name `wrapper`
258
+ wrapper.__name__ = elem_wise_backward_fn.__name__
259
+
260
+ return wrapper
261
+
262
+
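A sketch of the decorator in action, using `add_backward` (defined further down in this file) and assuming `xp` is plain NumPy with ndarrays standing in for `Tensor`.

```python
import numpy as np

from sadl.grad_ops import add_backward

x = np.zeros((2, 3))
y = np.zeros((3,))            # broadcast against x in the forward pass
grad_out = np.ones((2, 3))

grad_x, grad_y = add_backward(x, y, compute_grad=(True, True), grad_out=grad_out)
assert grad_x.shape == (2, 3)
assert grad_y.shape == (3,)   # collapsed by summing over the broadcast axis
np.testing.assert_allclose(grad_y, np.full((3,), 2.0))
```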
263
+ @register_grad_op(
264
+ op_type=OpType.ELEMENTWISE,
265
+ op_inputs=OpInputs.UNARY,
266
+ forward_names=("absolute", "abs"),
267
+ )
268
+ def absolute_backward(
269
+ *inputs: Tensor,
270
+ compute_grad: tuple[bool],
271
+ grad_out: xp.ndarray,
272
+ ) -> tuple[xp.ndarray | None]:
273
+ """Computes gradients for `abs(x) = |x| = z`.
274
+
275
+ Args:
276
+ *inputs (Tensor): The inputs, expected to be of length 1.
277
+ Expects: `tuple[0]` to be `x`.
278
+ compute_grad (tuple[bool]): Flags indicating which input gradients
279
+ to compute, aligned with `inputs`.
280
+ grad_out (xp.ndarray): Upstream gradient.
281
+
282
+ Returns:
283
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
284
+ """
285
+ x = inputs[0]
286
+ x_grad = xp.sign(x) * grad_out if compute_grad[0] else None
287
+ return (x_grad,)
288
+
289
+
290
+ @register_grad_op(
291
+ op_type=OpType.ELEMENTWISE,
292
+ op_inputs=OpInputs.UNARY,
293
+ )
294
+ def negative_backward(
295
+ *inputs: Tensor, # noqa: ARG001
296
+ compute_grad: tuple[bool],
297
+ grad_out: xp.ndarray,
298
+ ) -> tuple[xp.ndarray | None]:
299
+ """Computes gradients for `-x = -1 * x = z`.
300
+
301
+ Args:
302
+ *inputs (Tensor): The inputs, expected to be of length 1.
303
+ Expects: `tuple[0]` to be `x`.
304
+ compute_grad (tuple[bool]): Flags indicating which input gradients
305
+ to compute, aligned with `inputs`.
306
+ grad_out (xp.ndarray): Upstream gradient.
307
+
308
+ Returns:
309
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
310
+ """
311
+ x_grad = -1 * grad_out if compute_grad[0] else None
312
+ return (x_grad,)
313
+
314
+
315
+ @register_grad_op(
316
+ op_type=OpType.ELEMENTWISE,
317
+ op_inputs=OpInputs.BINARY,
318
+ )
319
+ @broadcastable
320
+ def add_backward(
321
+ *inputs: Tensor, # noqa: ARG001
322
+ compute_grad: tuple[bool, bool],
323
+ grad_out: xp.ndarray,
324
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
325
+ """Computes gradients for `x + y = z`.
326
+
327
+ Args:
328
+ *inputs (Tensor): Two inputs `(x, y)`.
329
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
330
+ to compute, aligned with `inputs`.
331
+ grad_out (xp.ndarray): Upstream gradient.
332
+
333
+ Returns:
334
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
335
+ `None` where `compute_grad[i]` is False.
336
+ """
337
+ x_grad = grad_out if compute_grad[0] else None
338
+ y_grad = grad_out if compute_grad[1] else None
339
+ return x_grad, y_grad
340
+
341
+
342
+ # Theoretically we do not need a backward for subtract
343
+ # -> "x - y = x + (-y) = z", so we could chain the
344
+ # "negative" and "add" backward functions, but a single
345
+ # "substract" function is more efficient
346
+ @register_grad_op(
347
+ op_type=OpType.ELEMENTWISE,
348
+ op_inputs=OpInputs.BINARY,
349
+ )
350
+ @broadcastable
351
+ def subtract_backward(
352
+ *inputs: Tensor, # noqa: ARG001
353
+ compute_grad: tuple[bool, bool],
354
+ grad_out: xp.ndarray,
355
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
356
+ """Computes gradients for `x - y = z`.
357
+
358
+ Args:
359
+ *inputs (Tensor): Two inputs `(x, y)`.
360
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
361
+ to compute, aligned with `inputs`.
362
+ grad_out (xp.ndarray): Upstream gradient.
363
+
364
+ Returns:
365
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
366
+ `None` where `compute_grad[i]` is False.
367
+ """
368
+ x_grad = grad_out if compute_grad[0] else None
369
+ y_grad = -1 * grad_out if compute_grad[1] else None
370
+ return x_grad, y_grad
371
+
372
+
373
+ @register_grad_op(
374
+ op_type=OpType.ELEMENTWISE,
375
+ op_inputs=OpInputs.BINARY,
376
+ forward_names=("mul", "multiply"),
377
+ )
378
+ @broadcastable
379
+ def mul_backward(
380
+ *inputs: Tensor,
381
+ compute_grad: tuple[bool, bool],
382
+ grad_out: xp.ndarray,
383
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
384
+ """Computes the gradient for multiplication `x * y = z`.
385
+
386
+ Args:
387
+ *inputs (Tensor): The inputs, expected to be of length 2.
388
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
389
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
390
+ to compute, aligned with `inputs`.
391
+ grad_out (xp.ndarray): The gradient of the following
392
+ operation.
393
+
394
+ Returns:
395
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
396
+ `None` where `compute_grad[i]` is False.
397
+ """
398
+ x, y = inputs
399
+ grad_x = y * grad_out if compute_grad[0] else None
400
+ grad_y = x * grad_out if compute_grad[1] else None
401
+ return grad_x, grad_y
402
+
403
+
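A quick numerical sanity check of the product rule implemented above, assuming `xp` is plain NumPy and ndarrays stand in for `Tensor` inputs.

```python
import numpy as np

from sadl.grad_ops import mul_backward

rng = np.random.default_rng(0)
x, y = rng.normal(size=4), rng.normal(size=4)
grad_out = np.ones(4)

grad_x, grad_y = mul_backward(x, y, compute_grad=(True, True), grad_out=grad_out)
np.testing.assert_allclose(grad_x, y)  # d(x * y) / dx = y
np.testing.assert_allclose(grad_y, x)  # d(x * y) / dy = x
```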
404
+ @register_grad_op(
405
+ op_type=OpType.ELEMENTWISE,
406
+ op_inputs=OpInputs.BINARY,
407
+ forward_names=("div", "divide"),
408
+ constraints={"y": "positive"}, # avoid division by zero
409
+ )
410
+ @broadcastable
411
+ def div_backward(
412
+ *inputs: Tensor,
413
+ compute_grad: tuple[bool, bool],
414
+ grad_out: xp.ndarray,
415
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
416
+ """Computes the gradient for division `x / y = z`.
417
+
418
+ Args:
419
+ *inputs (Tensor): The inputs, expected to be of length 2.
420
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
421
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
422
+ to compute, aligned with `inputs`.
423
+ grad_out (xp.ndarray): The gradient of the following
424
+ operation.
425
+
426
+ Returns:
427
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
428
+ `None` where `compute_grad[i]` is False.
429
+ """
430
+ x, y = inputs
431
+ grad_x = grad_out / y if compute_grad[0] else None
432
+ grad_y = -x * grad_out / (y * y) if compute_grad[1] else None
433
+ return grad_x, grad_y
434
+
435
+
436
+ @register_grad_op(
437
+ op_type=OpType.LINALG,
438
+ op_inputs=OpInputs.BINARY,
439
+ )
440
+ def matmul_backward(
441
+ *inputs: Tensor,
442
+ compute_grad: tuple[bool, bool],
443
+ grad_out: xp.ndarray,
444
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
445
+ """Computes the gradient for matrix multiplication `AB = Z`.
446
+
447
+ Batched matrix multiplication is also supported, so inputs can
448
+ have dimensions `(..., i, j)` and `(..., j, k)`.
449
+
450
+ Args:
451
+ *inputs (Tensor): The inputs, expected to be of length 2.
452
+ Expects: `tuple[0]` to be the first matrix (or batch of matrices),
453
+ `tuple[1]` to be the second matrix (or batch of matrices).
454
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
455
+ to compute, aligned with `inputs`.
456
+ grad_out (xp.ndarray): The gradient of the following
457
+ operation.
458
+
459
+ Returns:
460
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(A, B)`, with
461
+ `None` where `compute_grad[i]` is False.
462
+ """
463
+ A, B = inputs
464
+ grad_x = xp.matmul(grad_out, xp.swapaxes(B, -2, -1)) if compute_grad[0] else None
465
+ grad_y = xp.matmul(xp.swapaxes(A, -2, -1), grad_out) if compute_grad[1] else None
466
+
467
+ return grad_x, grad_y
468
+
469
+
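A shape check for the batched case described above, assuming `xp` is plain NumPy and ndarrays stand in for `Tensor` inputs.

```python
import numpy as np

from sadl.grad_ops import matmul_backward

A = np.random.rand(5, 2, 3)    # (..., i, j)
B = np.random.rand(5, 3, 4)    # (..., j, k)
grad_out = np.ones((5, 2, 4))  # gradient w.r.t. Z = A @ B, shape (..., i, k)

grad_A, grad_B = matmul_backward(A, B, compute_grad=(True, True), grad_out=grad_out)
assert grad_A.shape == A.shape  # grad_out @ B^T
assert grad_B.shape == B.shape  # A^T @ grad_out
```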
470
+ @register_grad_op(
471
+ op_type=OpType.ELEMENTWISE,
472
+ op_inputs=OpInputs.UNARY,
473
+ constraints={"x": "positive"},
474
+ )
475
+ def sqrt_backward(
476
+ *inputs: Tensor,
477
+ compute_grad: tuple[bool],
478
+ grad_out: xp.ndarray,
479
+ ) -> tuple[xp.ndarray | None]:
480
+ """Computes the gradient for square root `sqrt(x) = z`.
481
+
482
+ Args:
483
+ *inputs (Tensor): The inputs, expected to be of length 1.
484
+ Expects: `tuple[0]` to be `x`.
485
+ compute_grad (tuple[bool]): Flags indicating which input gradients
486
+ to compute, aligned with `inputs`.
487
+ grad_out (xp.ndarray): The gradient of the following
488
+ operation.
489
+
490
+ Returns:
491
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
492
+ """
493
+ x = inputs[0]
494
+ grad_x = grad_out / (2 * xp.sqrt(x)) if compute_grad[0] else None
495
+ return (grad_x,)
496
+
497
+
498
+ @register_grad_op(
499
+ op_type=OpType.ELEMENTWISE,
500
+ op_inputs=OpInputs.BINARY,
501
+ constraints={"x": "positive"}, # avoid complex numbers with non-integer exponents
502
+ )
503
+ def power_backward(
504
+ *inputs: Tensor,
505
+ compute_grad: tuple[bool, bool],
506
+ grad_out: xp.ndarray,
507
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
508
+ """Computes the gradient for power `x^y = z`.
509
+
510
+ Args:
511
+ *inputs (Tensor): The inputs, expected to be of length 2.
512
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
513
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
514
+ to compute, aligned with `inputs`.
515
+ grad_out (xp.ndarray): The gradient of the following
516
+ operation.
517
+
518
+ Returns:
519
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
520
+ `None` where `compute_grad[i]` is False.
521
+ """
522
+ x, y = inputs
523
+ grad_x = y * xp.pow(x, y - 1) * grad_out if compute_grad[0] else None
524
+ grad_y = xp.pow(x, y) * xp.log(x) * grad_out if compute_grad[1] else None
525
+ return grad_x, grad_y
526
+
527
+
528
+ @register_grad_op(
529
+ op_type=OpType.ELEMENTWISE,
530
+ op_inputs=OpInputs.UNARY,
531
+ )
532
+ def square_backward(
533
+ *inputs: Tensor,
534
+ compute_grad: tuple[bool],
535
+ grad_out: xp.ndarray,
536
+ ) -> tuple[xp.ndarray | None]:
537
+ """Computes the gradient for the square op (`x^2 = z`).
538
+
539
+ Args:
540
+ *inputs (Tensor): The inputs, expected to be of length 1.
541
+ Expects: `tuple[0]` to be `x`.
542
+ compute_grad (tuple[bool]): Flags indicating which input gradients
543
+ to compute, aligned with `inputs`.
544
+ grad_out (xp.ndarray): The gradient of the following
545
+ operation.
546
+
547
+ Returns:
548
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
549
+ """
550
+ x = inputs[0]
551
+ grad_x = 2 * x * grad_out if compute_grad[0] else None
552
+ return (grad_x,)
553
+
554
+
555
+ @register_grad_op(
556
+ op_type=OpType.ELEMENTWISE,
557
+ op_inputs=OpInputs.UNARY,
558
+ )
559
+ def exp_backward(
560
+ *inputs: Tensor,
561
+ compute_grad: tuple[bool],
562
+ grad_out: xp.ndarray,
563
+ ) -> tuple[xp.ndarray | None]:
564
+ """Computes the gradient for exponentiation `e^x = z`.
565
+
566
+ Args:
567
+ *inputs (Tensor): The inputs, expected to be of length 1.
568
+ Expects: `tuple[0]` to be `x`.
569
+ compute_grad (tuple[bool]): Flags indicating which input gradients
570
+ to compute, aligned with `inputs`.
571
+ grad_out (xp.ndarray): The gradient of the following
572
+ operation.
573
+
574
+ Returns:
575
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
576
+ """
577
+ x = inputs[0]
578
+ grad_x = grad_out * xp.exp(x) if compute_grad[0] else None
579
+ return (grad_x,)
580
+
581
+
582
+ @register_grad_op(
583
+ op_type=OpType.ELEMENTWISE,
584
+ op_inputs=OpInputs.UNARY,
585
+ constraints={"x": "positive"},
586
+ )
587
+ def log_backward(
588
+ *inputs: Tensor,
589
+ compute_grad: tuple[bool],
590
+ grad_out: xp.ndarray,
591
+ ) -> tuple[xp.ndarray | None]:
592
+ """Computes the gradient for the logarithm `log(x) = z`.
593
+
594
+ Args:
595
+ *inputs (Tensor): The inputs, expected to be of length 1.
596
+ Expects: `tuple[0]` to be `x`.
597
+ compute_grad (tuple[bool]): Flags indicating which input gradients
598
+ to compute, aligned with `inputs`.
599
+ grad_out (xp.ndarray): The gradient of the following
600
+ operation.
601
+
602
+ Returns:
603
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
604
+ """
605
+ x = inputs[0]
606
+ grad_x = grad_out / x if compute_grad[0] else None
607
+ return (grad_x,)
608
+
609
+
610
+ @register_grad_op(
611
+ op_type=OpType.ELEMENTWISE,
612
+ op_inputs=OpInputs.UNARY,
613
+ )
614
+ def sin_backward(
615
+ *inputs: Tensor,
616
+ compute_grad: tuple[bool],
617
+ grad_out: xp.ndarray,
618
+ ) -> tuple[xp.ndarray | None]:
619
+ """Computes the gradient for the sinus function `sin(x) = z`.
620
+
621
+ Args:
622
+ *inputs (Tensor): The inputs, expected to be of length 1.
623
+ Expects: `tuple[0]` to be `x`.
624
+ compute_grad (tuple[bool]): Flags indicating which input gradients
625
+ to compute, aligned with `inputs`.
626
+ grad_out (xp.ndarray): The gradient of the following
627
+ operation.
628
+
629
+ Returns:
630
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
631
+ """
632
+ x = inputs[0]
633
+ grad_x = xp.cos(x) * grad_out if compute_grad[0] else None
634
+ return (grad_x,)
635
+
636
+
637
+ @register_grad_op(
638
+ op_type=OpType.ELEMENTWISE,
639
+ op_inputs=OpInputs.UNARY,
640
+ )
641
+ def cos_backward(
642
+ *inputs: Tensor,
643
+ compute_grad: tuple[bool],
644
+ grad_out: xp.ndarray,
645
+ ) -> tuple[xp.ndarray | None]:
646
+ """Computes the gradient for the cosine function `cos(x) = z`.
647
+
648
+ Args:
649
+ *inputs (Tensor): The inputs, expected to be of length 1.
650
+ Expects: `tuple[0]` to be `x`.
651
+ compute_grad (tuple[bool]): Flags indicating which input gradients
652
+ to compute, aligned with `inputs`.
653
+ grad_out (xp.ndarray): The gradient of the following
654
+ operation.
655
+
656
+ Returns:
657
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
658
+ """
659
+ x = inputs[0]
660
+ grad_x = -xp.sin(x) * grad_out if compute_grad[0] else None
661
+ return (grad_x,)
662
+
663
+
664
+ @register_grad_op(
665
+ op_type=OpType.REDUCTION,
666
+ op_inputs=OpInputs.UNARY,
667
+ )
668
+ def sum_backward(
669
+ *inputs: Tensor,
670
+ compute_grad: tuple[bool],
671
+ grad_out: xp.ndarray,
672
+ **kwargs: Any,
673
+ ) -> tuple[xp.ndarray | None]:
674
+ """Computes the gradient for the sum function `sum(x) = z`.
675
+
676
+ Args:
677
+ *inputs (Tensor): The inputs, expected to be of length 1.
678
+ Expects: `tuple[0]` to be `x`.
679
+ compute_grad (tuple[bool]): Flags indicating which input gradients
680
+ to compute, aligned with `inputs`.
681
+ grad_out (xp.ndarray): The gradient of the following
682
+ operation.
683
+ **kwargs (Any): Additional arguments; may contain `axis`,
684
+ the axis (or axes) of `x` over which the sum was taken
685
+ (defaults to all axes when absent).
686
+
687
+ Returns:
688
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
689
+ """
690
+ x = inputs[0]
691
+
692
+ grad_x: xp.ndarray | None = None
693
+
694
+ if compute_grad[0]:
695
+ kwargs_axis = kwargs.get("axis")
696
+ axis = make_axis(ndim=x.ndim, axis_candidate=kwargs_axis) if kwargs_axis is not None else tuple(range(x.ndim))
697
+
698
+ grad_x = xp.broadcast_to(xp.expand_dims(grad_out, axis=axis), shape=x.shape)
699
+
700
+ return (grad_x,)
701
+
702
+
703
+ @register_grad_op(
704
+ op_type=OpType.REDUCTION,
705
+ op_inputs=OpInputs.UNARY,
706
+ )
707
+ def mean_backward(
708
+ *inputs: Tensor,
709
+ compute_grad: tuple[bool],
710
+ grad_out: xp.ndarray,
711
+ **kwargs: Any,
712
+ ) -> tuple[xp.ndarray | None]:
713
+ """Computes the gradient for the mean function `mean(x) = z`.
714
+
715
+ Args:
716
+ *inputs (Tensor): The inputs, expected to be of length 1.
717
+ Expects: `tuple[0]` to be `x`.
718
+ compute_grad (tuple[bool]): Flags indicating which input gradients
719
+ to compute, aligned with `inputs`.
720
+ grad_out (xp.ndarray): The gradient of the following
721
+ operation.
722
+ **kwargs (Any): Additional arguments; may contain `axis`,
723
+ the axis (or axes) of `x` over which the mean was taken
724
+ (defaults to all axes when absent).
725
+
726
+ Returns:
727
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
728
+ """
729
+ x = inputs[0]
730
+
731
+ grad_x: xp.ndarray | None = None
732
+
733
+ if compute_grad[0]:
734
+ kwargs_axis = kwargs.get("axis")
735
+ axis = make_axis(ndim=x.ndim, axis_candidate=kwargs_axis) if kwargs_axis is not None else tuple(range(x.ndim))
736
+
737
+ grad_x = xp.broadcast_to(xp.expand_dims(grad_out, axis=axis), shape=x.shape)
738
+
739
+ n_reduced_elem = xp.prod([x.shape[a] for a in axis])
740
+
741
+ grad_x = grad_x / n_reduced_elem
742
+
743
+ return (grad_x,)
744
+
745
+
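A small check that the gradient of a full mean spreads `grad_out` uniformly scaled by 1/N, assuming `xp` is plain NumPy and an ndarray stands in for `Tensor`.

```python
import numpy as np

from sadl.grad_ops import mean_backward

x = np.arange(6.0).reshape(2, 3)
# No `axis` kwarg: the mean was taken over all 6 elements, so grad_out is a scalar.
(grad_x,) = mean_backward(x, compute_grad=(True,), grad_out=np.array(1.0))
np.testing.assert_allclose(grad_x, np.full((2, 3), 1.0 / 6.0))
```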
746
+ def _extremum_backward(
747
+ *inputs: Tensor,
748
+ op_type: Literal["min", "max"],
749
+ grad_out: xp.ndarray,
750
+ **kwargs: Any,
751
+ ) -> tuple[xp.ndarray]:
752
+ """Computes the gradient for an extremum function.
753
+
754
+ Extremum is either `min` or `max`: `f(x) = z, f ∈ {min, max}`.
755
+
756
+ Args:
757
+ *inputs (Tensor): The inputs, expected to be of length 1.
758
+ Expects: `tuple[0]` to be `x`.
759
+ op_type (Literal["min", "max"]): The extremum type for which to
760
+ compute the gradient. Must be in ["min", "max"].
761
+ grad_out (xp.ndarray): The gradient of the following
762
+ operation.
763
+ **kwargs (Any): Additional arguments, tries to extract:
764
+
765
+ - `axis`: Denoting the axis of x over which the
766
+ operation was done (defaults to None).
767
+
768
+ - `x_mask`: Stores whether a value in `x` is an extremum
769
+ (`min` for minimum, `max` for maximum)
770
+ along the reduced axes (must be of the same shape as `x`).
771
+ This argument is **required** to avoid numerical instability
772
+ when suppressing all non-extrema values (this is the backward function).
773
+
774
+ - `keepdims`: Whether the dimensions over which was
775
+ reduced were retained (defaults to False).
776
+
777
+ Returns:
778
+ tuple[xp.ndarray]: The computed gradient with respect to `x`.
779
+ """
780
+ x = inputs[0]
781
+
782
+ x_mask = kwargs.get("x_mask")
783
+ if x_mask is None:
784
+ raise ValueError(
785
+ f'Missing keyword argument "x_mask" is required in backward for {op_type}.'
786
+ )
787
+
788
+ assert x_mask.shape == x.shape, "Extremum mask must have the same shape as x"
789
+
790
+ kwargs_axis = kwargs.get("axis")
791
+ if kwargs_axis is None:
792
+ axis = tuple(range(x.ndim))
793
+ else:
794
+ axis = make_axis(ndim=x.ndim, axis_candidate=kwargs_axis)
795
+
796
+ if not kwargs.get("keepdims", False):
797
+ grad_out = xp.expand_dims(grad_out, axis=axis)
798
+
799
+ grad_out = xp.broadcast_to(grad_out, shape=x.shape)
800
+
801
+ count = xp.sum(x_mask, axis=axis, keepdims=True)
802
+ assert xp.all(count > 0), (
803
+ f'There must be at least one {op_type} along the reduced axis "{axis}"'
804
+ )
805
+
806
+ grad_x = xp.where(x_mask, grad_out / count, 0)
807
+ # Scale the gradient by the inverse of the number of extrema found
808
+ # along the reduced axes.
809
+ # If one axis had 3 extrema, the gradient is divided among them
810
+ # equally: the gradients for all 3 extrema are each scaled by 1/3.
811
+
812
+ return (grad_x,)
813
+
814
+
815
+ @register_grad_op(
816
+ op_type=OpType.REDUCTION,
817
+ op_inputs=OpInputs.UNARY,
818
+ skip_test=True,
819
+ skip_reason="requires x_mask computed during forward pass",
820
+ )
821
+ def max_backward(
822
+ *inputs: Tensor,
823
+ compute_grad: tuple[bool],
824
+ grad_out: xp.ndarray,
825
+ **kwargs: Any,
826
+ ) -> tuple[xp.ndarray | None]:
827
+ """Computes the gradient for the max function: max(x) = z.
828
+
829
+ Args:
830
+ *inputs (Tensor): The inputs, expected to be of length 1.
831
+ Expects: `tuple[0]` to be `x`.
832
+ compute_grad (tuple[bool]): Flags indicating which input gradients
833
+ to compute, aligned with `inputs`.
834
+ grad_out (xp.ndarray): The gradient of the following
835
+ operation.
836
+ **kwargs (Any): Additional arguments, tries to extract:
837
+
838
+ - `axis`: Denoting the axis of x over which the
839
+ operation was done (defaults to None).
840
+
841
+ - `x_mask`: Stores whether a value in `x` is a max
842
+ along the reduced axes (must be of the same shape as `x`).
843
+ This argument is **required** to avoid numerical instability
844
+ when suppressing all non-max values (this is the backward function).
845
+
846
+ - `keepdims`: Whether the dimensions over which was
847
+ reduced were retained (defaults to False).
848
+
849
+ Returns:
850
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
851
+ """
852
+ if not compute_grad[0]:
853
+ return (None,)
854
+ # else ->
855
+ return _extremum_backward(
856
+ *inputs,
857
+ op_type="max",
858
+ grad_out=grad_out,
859
+ **kwargs,
860
+ )
861
+
862
+
863
+ @register_grad_op(
864
+ op_type=OpType.REDUCTION,
865
+ op_inputs=OpInputs.UNARY,
866
+ skip_test=True,
867
+ skip_reason="requires x_mask computed during forward pass",
868
+ )
869
+ def min_backward(
870
+ *inputs: Tensor,
871
+ compute_grad: tuple[bool],
872
+ grad_out: xp.ndarray,
873
+ **kwargs: Any,
874
+ ) -> tuple[xp.ndarray | None]:
875
+ """Computes the gradient for the min function: min(x) = z.
876
+
877
+ Args:
878
+ *inputs (Tensor): The inputs, expected to be of length 1.
879
+ Expects: `tuple[0]` to be `x`.
880
+ compute_grad (tuple[bool]): Flags indicating which input gradients
881
+ to compute, aligned with `inputs`.
882
+ grad_out (xp.ndarray): The gradient of the following
883
+ operation.
884
+ **kwargs (Any): Additional arguments, tries to extract:
885
+
886
+ - `axis`: Denoting the axis of x over which the
887
+ operation was done (defaults to None).
888
+
889
+ - `x_mask`: Stores whether a value in `x` is a min
890
+ along the reduced axes (must be of the same shape as `x`).
891
+ This argument is **required** to avoid numerical instability
892
+ when suppressing all non-min values (this is the backward function).
893
+
894
+ - `keepdims`: Whether the dimensions over which was
895
+ reduced were retained (defaults to False).
896
+
897
+ Returns:
898
+ tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
899
+ """
900
+ if not compute_grad[0]:
901
+ return (None,)
902
+ # else ->
903
+ return _extremum_backward(
904
+ *inputs,
905
+ op_type="min",
906
+ grad_out=grad_out,
907
+ **kwargs,
908
+ )
909
+
910
+
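A sketch of calling `max_backward` with the `x_mask` that the forward pass is expected to record, assuming `xp` is plain NumPy and ndarrays stand in for `Tensor`.

```python
import numpy as np

from sadl.grad_ops import max_backward

x = np.array([[1.0, 5.0, 5.0],
              [2.0, 0.0, 7.0]])
x_mask = x == x.max(axis=1, keepdims=True)  # marks the maxima along axis 1
grad_out = np.ones(2)                       # upstream gradient of max(x, axis=1)

(grad_x,) = max_backward(
    x, compute_grad=(True,), grad_out=grad_out, axis=1, x_mask=x_mask
)
# Ties share the gradient equally: the two 5.0 entries each receive 0.5.
np.testing.assert_allclose(grad_x, [[0.0, 0.5, 0.5], [0.0, 0.0, 1.0]])
```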
911
+ def _element_wise_extremum_backward(
912
+ *inputs: Tensor,
913
+ compute_grad: tuple[bool, bool],
914
+ op_type: Literal["minimum", "maximum"],
915
+ grad_out: xp.ndarray,
916
+ **kwargs: Any,
917
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
918
+ """Computes the gradient for an element-wise extremum function.
919
+
920
+ Extremum is either `minimum` or `maximum`:
921
+ `f(x, y) = z, f ∈ {minimum, maximum}`.
922
+
923
+ Args:
924
+ *inputs (Tensor): The inputs, expected to be of length 2.
925
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
926
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
927
+ to compute, aligned with `inputs`.
928
+ op_type (Literal["minimum", "maximum"]): The extremum type
929
+ for which to compute the gradient.
930
+ Must be in ["minimum", "maximum"].
931
+ grad_out (xp.ndarray): The gradient of the following
932
+ operation.
933
+ **kwargs (Any): Additional arguments, tries to extract:
934
+
935
+ - `x_mask`: Stores whether a value in `x` is an extremum
936
+ (`min` for minimum, `max` for maximum) compared
937
+ to the value at the same location in `y`
938
+ (must be of the same shape as `x`). This argument
939
+ is **required** to avoid numerical instability
940
+ when suppressing all non-extrema values (this is the backward function).
941
+
942
+ - `where`: On which locations to apply `f`.
943
+
944
+ Returns:
945
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
946
+ `None` where `compute_grad[i]` is False.
947
+ """
948
+ x = inputs[0]
949
+ y = inputs[1]
950
+
951
+ if x.shape != y.shape:
952
+ raise ValueError(
953
+ f"Both inputs must be of same shape, got {x.shape} for x and {y.shape} for y"
954
+ )
955
+
956
+ x_mask = kwargs.get("x_mask")
957
+ if x_mask is None:
958
+ raise ValueError(
959
+ f'Missing keyword argument "x_mask" is required in backward for {op_type}.'
960
+ )
961
+
962
+ apply_grad_op = kwargs.get("where", 1)
963
+
964
+ assert x_mask.shape == x.shape, "Extremum mask must have the same shape as x"
965
+
966
+ both_extremum = x == y
967
+
968
+ x_grad: xp.ndarray | None = None
969
+ y_grad: xp.ndarray | None = None
970
+
971
+ if compute_grad[0]:
972
+ scale_x = xp.where(both_extremum, 0.5, x_mask)
973
+ x_grad = apply_grad_op * scale_x * grad_out
974
+
975
+ if compute_grad[1]:
976
+ scale_y = xp.where(both_extremum, 0.5, 1 - x_mask)
977
+ y_grad = apply_grad_op * scale_y * grad_out
978
+
979
+ return x_grad, y_grad
980
+
981
+
982
+ @register_grad_op(
983
+ op_type=OpType.ELEMENTWISE,
984
+ op_inputs=OpInputs.BINARY,
985
+ skip_test=True,
986
+ skip_reason="requires x_mask computed during forward pass",
987
+ )
988
+ @broadcastable
989
+ def maximum_backward(
990
+ *inputs: Tensor,
991
+ compute_grad: tuple[bool, bool],
992
+ grad_out: xp.ndarray,
993
+ **kwargs: Any,
994
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
995
+ """Computes the gradient for an element-wise maximum.
996
+
997
+ Function is `maximum(x, y) = z`.
998
+
999
+ Args:
1000
+ *inputs (Tensor): The inputs, expected to be of length 2.
1001
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
1002
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
1003
+ to compute, aligned with `inputs`.
1004
+ grad_out (xp.ndarray): The gradient of the following
1005
+ operation.
1006
+ **kwargs (Any): Additional arguments, tries to extract:
1007
+
1008
+ - `x_mask`: Stores whether a value in `x` is
1009
+ larger compared to the value at the same
1010
+ location in `y` (must be of the same shape as `x`).
1011
+ This argument is **required** to avoid numerical instability
1012
+ when suppressing all non-maximum values (this is the backward function).
1013
+
1014
+ Returns:
1015
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
1016
+ `None` where `compute_grad[i]` is False.
1017
+ """
1018
+ return _element_wise_extremum_backward(
1019
+ *inputs,
1020
+ compute_grad=compute_grad,
1021
+ op_type="maximum",
1022
+ grad_out=grad_out,
1023
+ **kwargs,
1024
+ )
1025
+
1026
+
1027
+ @register_grad_op(
1028
+ op_type=OpType.ELEMENTWISE,
1029
+ op_inputs=OpInputs.BINARY,
1030
+ skip_test=True,
1031
+ skip_reason="requires x_mask computed during forward pass",
1032
+ )
1033
+ @broadcastable
1034
+ def minimum_backward(
1035
+ *inputs: Tensor,
1036
+ compute_grad: tuple[bool, bool],
1037
+ grad_out: xp.ndarray,
1038
+ **kwargs: Any,
1039
+ ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
1040
+ """Computes the gradient for an element-wise minimum.
1041
+
1042
+ Function is `minimum(x, y) = z`.
1043
+
1044
+ Args:
1045
+ *inputs (Tensor): The inputs, expected to be of length 2.
1046
+ Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
1047
+ compute_grad (tuple[bool, bool]): Flags indicating which input gradients
1048
+ to compute, aligned with `inputs`.
1049
+ grad_out (xp.ndarray): The gradient of the following
1050
+ operation.
1051
+ **kwargs (Any): Additional arguments, tries to extract:
1052
+
1053
+ - `x_mask`: Stores whether a value in `x` is
1054
+ smaller compared to the value at the same
1055
+ location in `y` (must be of the same shape as `x`).
1056
+ This argument is **required** to avoid numerical instability
1057
+ when suppressing all non-minimum values (this is the backward function).
1058
+
1059
+ Returns:
1060
+ tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
1061
+ `None` where `compute_grad[i]` is False.
1062
+ """
1063
+ return _element_wise_extremum_backward(
1064
+ *inputs,
1065
+ compute_grad=compute_grad,
1066
+ op_type="minimum",
1067
+ grad_out=grad_out,
1068
+ **kwargs,
1069
+ )
1070
+
1071
+
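A sketch of the element-wise case with a tie, assuming `xp` is plain NumPy and ndarrays stand in for `Tensor` inputs; `x_mask` is the comparison the forward pass would record.

```python
import numpy as np

from sadl.grad_ops import maximum_backward

x = np.array([1.0, 3.0, 2.0])
y = np.array([2.0, 3.0, 0.0])
x_mask = x > y        # True where x wins the element-wise maximum
grad_out = np.ones(3)

grad_x, grad_y = maximum_backward(
    x, y, compute_grad=(True, True), grad_out=grad_out, x_mask=x_mask
)
# Where x == y the gradient is split 0.5 / 0.5 between the two inputs.
np.testing.assert_allclose(grad_x, [0.0, 0.5, 1.0])
np.testing.assert_allclose(grad_y, [1.0, 0.5, 0.0])
```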
1072
+ @register_grad_op(
1073
+ op_type=OpType.MOVEMENT,
1074
+ op_inputs=OpInputs.UNARY,
1075
+ skip_test=True,
1076
+ skip_reason="not testable with finite differences",
1077
+ )
1078
+ def copy_to_device_backward(
1079
+ *inputs: Tensor,
1080
+ compute_grad: tuple[bool],
1081
+ grad_out: xp.ndarray,
1082
+ ) -> tuple[xp.ndarray | None]:
1083
+ """Computes gradients for copying a Tensor to a (different) device.
1084
+
1085
+ Args:
1086
+ *inputs (Tensor): The input Tensor, should only be one, as
1087
+ a copy operation only operates on a single Tensor.
1088
+ compute_grad (tuple[bool]): Flags indicating which input gradients
1089
+ to compute, aligned with `inputs`.
1090
+ grad_out (xp.ndarray): Upstream gradient.
1091
+
1092
+ Returns:
1093
+ tuple[xp.ndarray | None]: Gradients for the Tensor, with
1094
+ `None` where `compute_grad[i]` is False.
1095
+ """
1096
+ x = inputs[0]
1097
+ # Just pass through grad_out by reverting the "copy_to_device" op
1098
+ # That means copying grad_out back to the device of x
1099
+ x_grad = copy_array(array=grad_out, device=x.device) if compute_grad[0] else None
1100
+ return (x_grad,)
1101
+
1102
+
1103
+ def make_axis(ndim: int, axis_candidate: Any) -> tuple[int, ...]:
1104
+ """Transforms the `axis` argument for numpy ops.
1105
+
1106
+ Returns a consistent `tuple[int, ...]` type.
1107
+ Transforms negative axes into positive ones for
1108
+ consistency.
1109
+
1110
+ Args:
1111
+ ndim (int): Number of dimensions of the array
1112
+ on which the numpy op is performed.
1113
+ axis_candidate (Any): The `axis` argument
1114
+ used in the numpy op.
1115
+
1116
+ Raises:
1117
+ ValueError: If `axis_candidate` is an invalid
1118
+ numpy `axis` type.
1119
+
1120
+ Returns:
1121
+ tuple[int, ...]: The consistent numpy `axis`
1122
+ argument.
1123
+ """
1124
+ if isinstance(axis_candidate, int):
1125
+ axis = (axis_candidate,)
1126
+ elif isinstance(axis_candidate, list) and all(isinstance(a, int) for a in axis_candidate):
1127
+ axis = tuple(axis_candidate)
1128
+
1129
+ elif isinstance(axis_candidate, tuple) and all(isinstance(a, int) for a in axis_candidate):
1130
+ axis = axis_candidate
1131
+
1132
+ else:
1133
+ raise ValueError(
1134
+ '"axis_candidate" must be in '
1135
+ f'[int, list[int], tuple[int]], found: "{type(axis_candidate).__name__}".'
1136
+ )
1137
+
1138
+ def make_positive(value: int) -> int:
1139
+ if value < 0:
1140
+ return ndim + value # normalize axes to positive values
1141
+ return value
1142
+
1143
+ normalized_axis = tuple(map(make_positive, axis))
1144
+
1145
+ assert all(0 <= a < ndim for a in normalized_axis)
1146
+
1147
+ return normalized_axis
1148
+
1149
+
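A few normalization examples for `make_axis`, following directly from the code above.

```python
from sadl.grad_ops import make_axis

assert make_axis(ndim=3, axis_candidate=1) == (1,)          # int -> 1-tuple
assert make_axis(ndim=3, axis_candidate=[-1, 0]) == (2, 0)  # negatives made positive
assert make_axis(ndim=4, axis_candidate=(0, -2)) == (0, 2)
```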
1150
+ __all__ = [
1151
+ "GradOp",
1152
+ "GradOpSpec",
1153
+ "OpInputs",
1154
+ "OpType",
1155
+ "get_grad_op",
1156
+ "get_grad_op_spec",
1157
+ "register_grad_op",
1158
+ ]