py-sadl 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_sadl-1.0.2.dist-info/METADATA +338 -0
- py_sadl-1.0.2.dist-info/RECORD +13 -0
- py_sadl-1.0.2.dist-info/WHEEL +4 -0
- py_sadl-1.0.2.dist-info/licenses/LICENSE +21 -0
- sadl/__init__.py +74 -0
- sadl/backend.py +45 -0
- sadl/disk.py +147 -0
- sadl/function.py +415 -0
- sadl/grad_ops.py +1158 -0
- sadl/ops.py +67 -0
- sadl/optimizer.py +352 -0
- sadl/tensor.py +531 -0
- sadl/utils.py +33 -0
sadl/grad_ops.py
ADDED
|
@@ -0,0 +1,1158 @@
|
|
|
1
|
+
"""Contains all operations that support gradient calculation.
|
|
2
|
+
|
|
3
|
+
Uses numpy as the backend.
|
|
4
|
+
|
|
5
|
+
The OpType enum is inspired by tinygrad's op categorization, thanks @tinygrad!
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
14
|
+
|
|
15
|
+
from .backend import xp
|
|
16
|
+
from .utils import copy_array
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .tensor import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Signature of a registered backward function. Backward functions receive the
# forward inputs plus keyword arguments and return one entry per input: a plain
# array (never a Tensor, so computing gradients adds no graph bookkeeping) or
# None when that input's gradient was not requested.
GradOp = Callable[..., tuple["xp.ndarray | None", ...]]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpType(Enum):
    """Category of an operation, grouped by how it computes its output.

    The grouping follows tinygrad's op categorization.
    """

    # Applied independently to every element (add, mul, sin, ...).
    ELEMENTWISE = "elementwise"
    # Collapses one or more dimensions (sum, mean, max, ...).
    REDUCTION = "reduction"
    # Moves or re-views data (copy_to_device, reshape, ...).
    MOVEMENT = "movement"
    # Linear-algebra kernels (matmul, ...).
    LINALG = "linalg"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class OpInputs(Enum):
    """How many tensor inputs an operation consumes.

    Each member's value is that count, so ``OpInputs.TERNARY.value == 3``.
    """

    UNARY = 1
    BINARY = 2
    TERNARY = 3
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class GradOpSpec:
    """Immutable description of one registered gradient operation.

    Attributes:
        backward_fn (GradOp): Function computing the input gradients.
        op_type (OpType): Category of the forward op (elementwise, reduction, ...).
        op_inputs (OpInputs): Number of tensor inputs of the forward op.
        forward_names (tuple[str, ...]): Forward op names served by this
            backward; the first one is canonical, the rest are aliases.
        constraints (dict[str, str] | None): Per-input constraints used when
            generating test inputs, e.g. ``{"x": "positive"}``.
        skip_test (bool): If True, automated finite difference testing is skipped.
        skip_reason (str | None): Why the test is skipped; mandatory whenever
            ``skip_test`` is True.
    """

    backward_fn: GradOp
    op_type: OpType
    op_inputs: OpInputs
    forward_names: tuple[str, ...]
    constraints: dict[str, str] | None = None
    skip_test: bool = False
    skip_reason: str | None = None

    def __post_init__(self) -> None:
        """Reject specs that skip testing without saying why.

        Raises:
            ValueError: If ``skip_test`` is True while ``skip_reason`` is None
                or empty.
        """
        reason_missing = not self.skip_reason
        if reason_missing and self.skip_test:
            raise ValueError("skip_reason is required when skip_test=True")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# The registry maps forward op names to their gradient specifications
|
|
85
|
+
_GRAD_OPS_REGISTRY: dict[str, GradOpSpec] = {}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def register_grad_op(  # noqa: PLR0913
    *,
    op_type: OpType,
    op_inputs: OpInputs,
    forward_names: tuple[str, ...] | None = None,
    constraints: dict[str, str] | None = None,
    skip_test: bool = False,
    skip_reason: str | None = None,
) -> Callable[[GradOp], GradOp]:
    """Build a decorator that records a backward function in the registry.

    Backward functions are expected to be named ``<operation>_backward``.
    When ``forward_names`` is omitted, the registry key is derived from the
    function name by dropping everything from its last underscore on
    (``add_backward`` -> ``add``). Otherwise the function is registered under
    every name in ``forward_names``, the first of which is canonical.

    Args:
        op_type (OpType): Operation category (elementwise, reduction, etc.).
        op_inputs (OpInputs): Number of tensor inputs (unary, binary, ternary).
        forward_names (tuple[str, ...] | None): Forward op names to register
            under. If None, derived from the function name.
        constraints (dict[str, str] | None): Input constraints for testing.
        skip_test (bool): Whether to skip automated finite difference testing.
        skip_reason (str | None): Reason for skipping. Required if skip_test=True.

    Returns:
        Callable[[GradOp], GradOp]: Decorator that registers the function and
            hands it back unchanged.

    Raises:
        ValueError: If skip_test=True but skip_reason is not provided. This is
            raised immediately, before any function is decorated.
    """
    if skip_test and not skip_reason:
        raise ValueError("skip_reason is required when skip_test=True")

    def decorator(func: GradOp) -> GradOp:
        if forward_names is None:
            # "<operation>_backward" -> "<operation>"
            names = (func.__name__.rsplit("_", maxsplit=1)[0],)
        else:
            names = forward_names

        spec = GradOpSpec(
            backward_fn=func,
            op_type=op_type,
            op_inputs=op_inputs,
            forward_names=names,
            constraints=constraints,
            skip_test=skip_test,
            skip_reason=skip_reason,
        )
        # Every name (canonical and aliases) shares the same spec object;
        # re-registering a name overwrites the previous entry.
        _GRAD_OPS_REGISTRY.update(dict.fromkeys(names, spec))
        return func

    return decorator
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def normalize_grad_op_name(*, name: str, is_reduce: bool = False) -> str:
    """Map a forward op name to the key used in the gradient registry.

    Reducing with ``add`` is a sum, so that combination is looked up under
    ``"sum"``. Every other name is returned unchanged.

    Args:
        name (str): The operation name.
        is_reduce (bool): Whether the operation is a reduction.

    Returns:
        str: The registry key for the operation.

    Examples:
        >>> normalize_grad_op_name(name="add", is_reduce=True)
        'sum'
        >>> normalize_grad_op_name(name="add", is_reduce=False)
        'add'
    """
    is_add_reduction = is_reduce and name == "add"
    return "sum" if is_add_reduction else name
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def get_grad_op(name: str) -> GradOp | None:
    """Look up the backward function registered for a forward operation.

    The name goes through ``normalize_grad_op_name`` with its default
    ``is_reduce=False``, so it is used as-is (e.g. ``"add"`` stays ``"add"``).

    Args:
        name (str): Forward operation name (e.g. "add", "matmul").

    Returns:
        GradOp | None: The gradient function, or None if nothing is registered
            under that name.
    """
    key = normalize_grad_op_name(name=name)
    spec = _GRAD_OPS_REGISTRY.get(key)
    if spec is None:
        return None
    return spec.backward_fn
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def get_grad_op_spec(name: str) -> GradOpSpec | None:
    """Look up the full registry entry for a forward operation.

    Args:
        name (str): Forward operation name (canonical name or alias).

    Returns:
        GradOpSpec | None: The registered specification, or None when the name
            is unknown.
    """
    registry_key = normalize_grad_op_name(name=name)
    return _GRAD_OPS_REGISTRY.get(registry_key)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _broadcast_backward(
|
|
192
|
+
x: Tensor,
|
|
193
|
+
grad_out: xp.ndarray,
|
|
194
|
+
) -> xp.ndarray:
|
|
195
|
+
"""Applies a backward gradient operation on broadcasting.
|
|
196
|
+
|
|
197
|
+
Effectively collapses `grad_out` by summing over all
|
|
198
|
+
dimensions of `x` that were broadcasted.
|
|
199
|
+
|
|
200
|
+
Args:
|
|
201
|
+
x (Tensor): The Tensor that was broadcasted.
|
|
202
|
+
grad_out (xp.ndarray): The gradient of the following operation.
|
|
203
|
+
|
|
204
|
+
Returns:
|
|
205
|
+
xp.ndarray: The computed gradient.
|
|
206
|
+
"""
|
|
207
|
+
if x.shape == grad_out.shape:
|
|
208
|
+
return grad_out # shapes are the same, no broadcast happened
|
|
209
|
+
|
|
210
|
+
collapse_dim: list[int] = []
|
|
211
|
+
for i in range(max(x.ndim, grad_out.ndim)):
|
|
212
|
+
idx_x = x.ndim - i - 1
|
|
213
|
+
idx_grad_out = grad_out.ndim - i - 1
|
|
214
|
+
if idx_x < 0 or x.shape[idx_x] < grad_out.shape[idx_grad_out]:
|
|
215
|
+
collapse_dim.append(idx_grad_out)
|
|
216
|
+
|
|
217
|
+
return xp.sum(grad_out, axis=tuple(collapse_dim), keepdims=True).reshape(x.shape)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def broadcastable(
    elem_wise_backward_fn: Callable[..., Any],
) -> Callable[..., Any]:
    """A decorator adding broadcasting support to a binary element-wise backward.

    The wrapped function computes gradients shaped like the (possibly broadcast)
    forward output. The wrapper then reduces each non-``None`` gradient back to
    the shape of its input with ``_broadcast_backward``.

    The wrapper is built with ``functools.wraps``. It therefore keeps the
    wrapped function's ``__name__``, which ``register_grad_op`` uses to derive
    the registry key (e.g. ``add_backward`` -> ``add``), and also its
    ``__qualname__`` and docstring.

    Args:
        elem_wise_backward_fn (Callable[..., Any]): The backward function of a
            binary element-wise operation. It must accept ``(x, y)`` inputs plus
            ``compute_grad`` and ``grad_out`` keywords, and return a pair of
            gradients (or ``None``).

    Returns:
        Callable[..., Any]: The wrapper function that supports broadcasting.
    """
    from functools import wraps

    @wraps(elem_wise_backward_fn)
    def wrapper(
        *inputs: Tensor,
        compute_grad: tuple[bool, bool],
        grad_out: xp.ndarray,
        **kwargs: Any,
    ) -> tuple[xp.ndarray | None, xp.ndarray | None]:
        x, y = inputs
        grad_x, grad_y = elem_wise_backward_fn(
            *inputs,
            compute_grad=compute_grad,
            grad_out=grad_out,
            **kwargs,
        )
        grad_x = _broadcast_backward(x, grad_x) if grad_x is not None else None
        grad_y = _broadcast_backward(y, grad_y) if grad_y is not None else None
        return grad_x, grad_y

    return wrapper
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@register_grad_op(
|
|
264
|
+
op_type=OpType.ELEMENTWISE,
|
|
265
|
+
op_inputs=OpInputs.UNARY,
|
|
266
|
+
forward_names=("absolute", "abs"),
|
|
267
|
+
)
|
|
268
|
+
def absolute_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of ``abs(x) = |x| = z``.

    The derivative of ``|x|`` is ``sign(x)``, which evaluates to 0 at ``x == 0``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    return (xp.sign(x) * grad_out,)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@register_grad_op(
|
|
291
|
+
op_type=OpType.ELEMENTWISE,
|
|
292
|
+
op_inputs=OpInputs.UNARY,
|
|
293
|
+
)
|
|
294
|
+
def negative_backward(
    *inputs: Tensor,  # noqa: ARG001
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of ``-x = -1 * x = z``.

    The derivative is the constant -1, so the input values are not needed.

    Args:
        *inputs (Tensor): A single input ``x`` (unused).
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    if compute_grad[0]:
        return (-1 * grad_out,)
    return (None,)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@register_grad_op(
|
|
316
|
+
op_type=OpType.ELEMENTWISE,
|
|
317
|
+
op_inputs=OpInputs.BINARY,
|
|
318
|
+
)
|
|
319
|
+
@broadcastable
|
|
320
|
+
def add_backward(
    *inputs: Tensor,  # noqa: ARG001
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of ``x + y = z``.

    Both partial derivatives are 1, so each requested gradient is the upstream
    gradient itself (the same object, not a copy).

    Args:
        *inputs (Tensor): Two inputs ``(x, y)`` (unused).
        compute_grad (tuple[bool, bool]): Which of ``(x, y)`` need a gradient.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for ``(x, y)``,
            with ``None`` where ``compute_grad[i]`` is False.
    """
    want_x, want_y = compute_grad[0], compute_grad[1]
    return (
        grad_out if want_x else None,
        grad_out if want_y else None,
    )
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
# Theoretically we do not need a backward for subtract
|
|
343
|
+
# -> "x - y = x + (-y) = z", so we could chain the
|
|
344
|
+
# "negative" and "add" backward functions, but a single
|
|
345
|
+
# "subtract" function is more efficient
|
|
346
|
+
@register_grad_op(
|
|
347
|
+
op_type=OpType.ELEMENTWISE,
|
|
348
|
+
op_inputs=OpInputs.BINARY,
|
|
349
|
+
)
|
|
350
|
+
@broadcastable
|
|
351
|
+
def subtract_backward(
    *inputs: Tensor,  # noqa: ARG001
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of ``x - y = z``.

    The partial derivatives are +1 for ``x`` and -1 for ``y``.

    Args:
        *inputs (Tensor): Two inputs ``(x, y)`` (unused).
        compute_grad (tuple[bool, bool]): Which of ``(x, y)`` need a gradient.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for ``(x, y)``,
            with ``None`` where ``compute_grad[i]`` is False.
    """
    x_grad = None
    y_grad = None
    if compute_grad[0]:
        x_grad = grad_out
    if compute_grad[1]:
        y_grad = -1 * grad_out
    return x_grad, y_grad
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@register_grad_op(
|
|
374
|
+
op_type=OpType.ELEMENTWISE,
|
|
375
|
+
op_inputs=OpInputs.BINARY,
|
|
376
|
+
forward_names=("mul", "multiply"),
|
|
377
|
+
)
|
|
378
|
+
@broadcastable
|
|
379
|
+
def mul_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of the product ``x * y = z``.

    Each factor's gradient is the other factor times the upstream gradient.

    Args:
        *inputs (Tensor): Exactly two inputs, ``(x, y)``.
        compute_grad (tuple[bool, bool]): Which of ``(x, y)`` need a gradient.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for ``(x, y)``,
            with ``None`` where ``compute_grad[i]`` is False.
    """
    x, y = inputs
    grad_x = None
    grad_y = None
    if compute_grad[0]:
        grad_x = y * grad_out
    if compute_grad[1]:
        grad_y = x * grad_out
    return grad_x, grad_y
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@register_grad_op(
|
|
405
|
+
op_type=OpType.ELEMENTWISE,
|
|
406
|
+
op_inputs=OpInputs.BINARY,
|
|
407
|
+
forward_names=("div", "divide"),
|
|
408
|
+
constraints={"y": "positive"}, # avoid division by zero
|
|
409
|
+
)
|
|
410
|
+
@broadcastable
|
|
411
|
+
def div_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of the quotient ``x / y = z``.

    ``dz/dx = 1 / y`` and ``dz/dy = -x / y**2``.

    Args:
        *inputs (Tensor): Exactly two inputs, ``(x, y)``.
        compute_grad (tuple[bool, bool]): Which of ``(x, y)`` need a gradient.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for ``(x, y)``,
            with ``None`` where ``compute_grad[i]`` is False.
    """
    x, y = inputs
    grad_x = None
    grad_y = None
    if compute_grad[0]:
        grad_x = grad_out / y
    if compute_grad[1]:
        grad_y = -x * grad_out / (y * y)
    return grad_x, grad_y
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
@register_grad_op(
|
|
437
|
+
op_type=OpType.LINALG,
|
|
438
|
+
op_inputs=OpInputs.BINARY,
|
|
439
|
+
)
|
|
440
|
+
def matmul_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Computes the gradient for matrix multiplication `AB = Z`.

    Follows ``xp.matmul`` semantics:

    * Batched operands ``(..., i, j) @ (..., j, k)`` are supported, including
      broadcast batch dimensions (e.g. ``(b, i, j) @ (j, k)``). The gradient of
      a broadcast operand is summed over the batch dimensions it did not have,
      so every gradient has the shape of its input.
    * 1-D operands are promoted the way ``xp.matmul`` promotes them: a row
      vector on the left, a column vector on the right. Their gradients are
      returned with the original 1-D shape.

    Args:
        *inputs (Tensor): The inputs, expected to be of length 2.
            Expects: `tuple[0]` to be the first matrix (or batch of matrices),
            `tuple[1]` to be the second matrix (or batch of matrices).
        compute_grad (tuple[bool, bool]): Flags indicating which input gradients
            to compute, aligned with `inputs`.
        grad_out (xp.ndarray): The gradient of the following
            operation.

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(A, B)`, with
            `None` where `compute_grad[i]` is False.
    """
    A, B = inputs

    def _fit(grad: xp.ndarray, promoted_shape: tuple, shape: tuple) -> xp.ndarray:
        """Sum `grad` over broadcast batch dims, then undo any 1-D promotion."""
        if grad.shape != promoted_shape:
            n_leading = grad.ndim - len(promoted_shape)
            axes = tuple(range(n_leading)) + tuple(
                n_leading + i
                for i, size in enumerate(promoted_shape)
                if size != grad.shape[n_leading + i]
            )
            grad = xp.sum(grad, axis=axes, keepdims=True).reshape(promoted_shape)
        if grad.shape != shape:
            grad = grad.reshape(shape)
        return grad

    # Promote 1-D operands to matrices exactly like xp.matmul does...
    A2 = xp.expand_dims(A, 0) if A.ndim == 1 else A
    B2 = xp.expand_dims(B, -1) if B.ndim == 1 else B

    # ...and re-insert the axes xp.matmul dropped from the output because of it.
    # The appended column axis is restored first so that axis -2 is always valid.
    g = grad_out
    if B.ndim == 1:
        g = xp.expand_dims(g, -1)
    if A.ndim == 1:
        g = xp.expand_dims(g, -2)

    grad_x = (
        _fit(xp.matmul(g, xp.swapaxes(B2, -2, -1)), A2.shape, A.shape)
        if compute_grad[0]
        else None
    )
    grad_y = (
        _fit(xp.matmul(xp.swapaxes(A2, -2, -1), g), B2.shape, B.shape)
        if compute_grad[1]
        else None
    )

    return grad_x, grad_y
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
@register_grad_op(
|
|
471
|
+
op_type=OpType.ELEMENTWISE,
|
|
472
|
+
op_inputs=OpInputs.UNARY,
|
|
473
|
+
constraints={"x": "positive"},
|
|
474
|
+
)
|
|
475
|
+
def sqrt_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of ``sqrt(x) = z``.

    ``d sqrt(x) / dx = 1 / (2 * sqrt(x))``, which is undefined at ``x == 0``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    denominator = 2 * xp.sqrt(x)
    return (grad_out / denominator,)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
@register_grad_op(
|
|
499
|
+
op_type=OpType.ELEMENTWISE,
|
|
500
|
+
op_inputs=OpInputs.BINARY,
|
|
501
|
+
constraints={"x": "positive"}, # avoid complex numbers with non-integer exponents
|
|
502
|
+
)
|
|
503
|
+
def power_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Computes the gradient for power `x^y = z`.

    ``dz/dx = y * x^(y - 1)`` and ``dz/dy = x^y * ln(x)``. The latter needs
    ``x > 0`` to stay real.

    The forward op may broadcast ``x`` against ``y``, as the other binary
    element-wise ops do. Each gradient is therefore reduced back to the shape
    of its own input with ``_broadcast_backward``. ``xp.power`` is used
    instead of ``xp.pow`` because the latter only exists in NumPy >= 2.

    Args:
        *inputs (Tensor): The inputs, expected to be of length 2.
            Expects: `tuple[0]` to be `x`, `tuple[1]` to be `y`.
        compute_grad (tuple[bool, bool]): Flags indicating which input gradients
            to compute, aligned with `inputs`.
        grad_out (xp.ndarray): The gradient of the following
            operation.

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: Gradients for `(x, y)`, with
            `None` where `compute_grad[i]` is False.
    """
    x, y = inputs
    grad_x = None
    grad_y = None
    if compute_grad[0]:
        grad_x = _broadcast_backward(x, y * xp.power(x, y - 1) * grad_out)
    if compute_grad[1]:
        grad_y = _broadcast_backward(y, xp.power(x, y) * xp.log(x) * grad_out)
    return grad_x, grad_y
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@register_grad_op(
|
|
529
|
+
op_type=OpType.ELEMENTWISE,
|
|
530
|
+
op_inputs=OpInputs.UNARY,
|
|
531
|
+
)
|
|
532
|
+
def square_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the square op ``x^2 = z``.

    ``d(x^2)/dx = 2x``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    local_grad = 2 * x
    return (local_grad * grad_out,)
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
@register_grad_op(
|
|
556
|
+
op_type=OpType.ELEMENTWISE,
|
|
557
|
+
op_inputs=OpInputs.UNARY,
|
|
558
|
+
)
|
|
559
|
+
def exp_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the exponential ``e^x = z``.

    The exponential is its own derivative, so ``e^x`` is recomputed here.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    return (grad_out * xp.exp(x),)
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
@register_grad_op(
|
|
583
|
+
op_type=OpType.ELEMENTWISE,
|
|
584
|
+
op_inputs=OpInputs.UNARY,
|
|
585
|
+
constraints={"x": "positive"},
|
|
586
|
+
)
|
|
587
|
+
def log_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the natural logarithm ``log(x) = z``.

    ``d log(x) / dx = 1 / x``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    return (grad_out / x,)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
@register_grad_op(
|
|
611
|
+
op_type=OpType.ELEMENTWISE,
|
|
612
|
+
op_inputs=OpInputs.UNARY,
|
|
613
|
+
)
|
|
614
|
+
def sin_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the sine function ``sin(x) = z``.

    ``d sin(x) / dx = cos(x)``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    return (xp.cos(x) * grad_out,)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
@register_grad_op(
|
|
638
|
+
op_type=OpType.ELEMENTWISE,
|
|
639
|
+
op_inputs=OpInputs.UNARY,
|
|
640
|
+
)
|
|
641
|
+
def cos_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the cosine function ``cos(x) = z``.

    ``d cos(x) / dx = -sin(x)``.

    Args:
        *inputs (Tensor): A single input ``x``.
        compute_grad (tuple[bool]): ``compute_grad[0]`` tells whether the
            gradient of ``x`` is wanted.
        grad_out (xp.ndarray): Upstream gradient (dL/dz).

    Returns:
        tuple[xp.ndarray | None]: One-element tuple holding dL/dx, or ``None``
            if it was not requested.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)
    return (-xp.sin(x) * grad_out,)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
@register_grad_op(
|
|
665
|
+
op_type=OpType.REDUCTION,
|
|
666
|
+
op_inputs=OpInputs.UNARY,
|
|
667
|
+
)
|
|
668
|
+
def sum_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None]:
    """Computes the gradient for the sum function `sum(x) = z`.

    Every element of ``x`` contributes with weight 1 to the sum it was reduced
    into. The gradient is therefore ``grad_out`` broadcast back over the
    reduced axes.

    Args:
        *inputs (Tensor): The inputs, expected to be of length 1.
            Expects: `tuple[0]` to be `x`.
        compute_grad (tuple[bool]): Flags indicating which input gradients
            to compute, aligned with `inputs`.
        grad_out (xp.ndarray): The gradient of the following
            operation.
        **kwargs (Any): Forward-pass reduction arguments:

            - `axis`: The axes of x that were summed over. Accepts None (all
              axes, the default), a single int, or a sequence of ints.
            - `keepdims`: Whether the reduced axes were kept as size-1
              dimensions in the output (defaults to False).

    Returns:
        tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
            The gradient is a broadcast view of `grad_out`, which NumPy marks
            read-only.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)

    kwargs_axis = kwargs.get("axis")
    if kwargs_axis is None:
        axis = tuple(range(x.ndim))
    else:
        try:
            axis = tuple(kwargs_axis)
        except TypeError:  # a single integer axis, e.g. `axis=0`
            axis = (kwargs_axis,)

    # Re-insert the reduced axes as size-1 dims before broadcasting, unless
    # the forward pass already kept them (keepdims=True).
    if not kwargs.get("keepdims", False):
        grad_out = xp.expand_dims(grad_out, axis=axis)

    grad_x = xp.broadcast_to(grad_out, shape=x.shape)
    return (grad_x,)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
@register_grad_op(
|
|
704
|
+
op_type=OpType.REDUCTION,
|
|
705
|
+
op_inputs=OpInputs.UNARY,
|
|
706
|
+
)
|
|
707
|
+
def mean_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None]:
    """Computes the gradient for the mean function `mean(x) = z`.

    Each element of ``x`` contributes ``1 / N`` to the mean it was reduced
    into, where ``N`` is the number of elements reduced into one output value.
    The gradient is therefore ``grad_out / N`` broadcast back over the reduced
    axes.

    Args:
        *inputs (Tensor): The inputs, expected to be of length 1.
            Expects: `tuple[0]` to be `x`.
        compute_grad (tuple[bool]): Flags indicating which input gradients
            to compute, aligned with `inputs`.
        grad_out (xp.ndarray): The gradient of the following
            operation.
        **kwargs (Any): Forward-pass reduction arguments:

            - `axis`: The axes of x over which the mean was taken. Accepts None
              (all axes, the default), a single int, or a sequence of ints.
            - `keepdims`: Whether the reduced axes were kept as size-1
              dimensions in the output (defaults to False).

    Returns:
        tuple[xp.ndarray | None]: Gradient for `x`, or `None` if skipped.
    """
    x = inputs[0]
    if not compute_grad[0]:
        return (None,)

    kwargs_axis = kwargs.get("axis")
    if kwargs_axis is None:
        axis = tuple(range(x.ndim))
    else:
        try:
            axis = tuple(kwargs_axis)
        except TypeError:  # a single integer axis, e.g. `axis=0`
            axis = (kwargs_axis,)

    # Re-insert the reduced axes as size-1 dims before broadcasting, unless
    # the forward pass already kept them (keepdims=True).
    if not kwargs.get("keepdims", False):
        grad_out = xp.expand_dims(grad_out, axis=axis)

    grad_x = xp.broadcast_to(grad_out, shape=x.shape)

    n_reduced_elem = xp.prod([x.shape[a] for a in axis])

    return (grad_x / n_reduced_elem,)
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def _extremum_backward(
|
|
747
|
+
*inputs: Tensor,
|
|
748
|
+
op_type: Literal["min", "max"],
|
|
749
|
+
grad_out: xp.ndarray,
|
|
750
|
+
**kwargs: Any,
|
|
751
|
+
) -> tuple[xp.ndarray]:
|
|
752
|
+
"""Computes the gradient for an extremum function.
|
|
753
|
+
|
|
754
|
+
Extemum is either `min` or `max`: `f(x) = z, f ∈ {min, max}`.
|
|
755
|
+
|
|
756
|
+
Args:
|
|
757
|
+
*inputs (Tensor): The inputs, expected to be of length 1.
|
|
758
|
+
Expects: `tuple[0]` to be `x`.
|
|
759
|
+
op_type (Literal["min", "max"]): The extremum type for which to
|
|
760
|
+
compute the gradient. Must be in ["min", "max"].
|
|
761
|
+
grad_out (xp.ndarray): The gradient of the following
|
|
762
|
+
operation.
|
|
763
|
+
**kwargs (Any): Additional arguments, tries to extract:
|
|
764
|
+
|
|
765
|
+
- `axis`: Denoting the axis of x over which the
|
|
766
|
+
operation was done (defaults to None).
|
|
767
|
+
|
|
768
|
+
- `x_mask`: Stores whether a value in `x` is an extremum
|
|
769
|
+
(`min` for minimum, `max` for maximum)
|
|
770
|
+
along the reduced axes (must be of the same shape as `x`).
|
|
771
|
+
This argument is **required** to avoid numerical instability
|
|
772
|
+
when supressing all non-extrema values (this is the backward function).
|
|
773
|
+
|
|
774
|
+
- `keepdims`: Whether the dimensions over which was
|
|
775
|
+
reduced were retained (defaults to False).
|
|
776
|
+
|
|
777
|
+
Returns:
|
|
778
|
+
tuple[xp.ndarray]: The computed gradient with respect to `x`.
|
|
779
|
+
"""
|
|
780
|
+
x = inputs[0]
|
|
781
|
+
|
|
782
|
+
x_mask = kwargs.get("x_mask")
|
|
783
|
+
if x_mask is None:
|
|
784
|
+
raise ValueError(
|
|
785
|
+
f'Missing keyword argument "x_mask" is required in backward for {op_type}.'
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
assert x_mask.shape == x.shape, "Extremum mask must have the same shape as x"
|
|
789
|
+
|
|
790
|
+
kwargs_axis = kwargs.get("axis")
|
|
791
|
+
if kwargs_axis is None:
|
|
792
|
+
axis = tuple(range(x.ndim))
|
|
793
|
+
else:
|
|
794
|
+
axis = make_axis(ndim=x.ndim, axis_candidate=kwargs_axis)
|
|
795
|
+
|
|
796
|
+
if not kwargs.get("keepdims", False):
|
|
797
|
+
grad_out = xp.expand_dims(grad_out, axis=axis)
|
|
798
|
+
|
|
799
|
+
grad_out = xp.broadcast_to(grad_out, shape=x.shape)
|
|
800
|
+
|
|
801
|
+
count = xp.sum(x_mask, axis=axis, keepdims=True)
|
|
802
|
+
assert xp.all(count > 0), (
|
|
803
|
+
f'There must be at least one {op_type} along the reduced axis "{axis}"'
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
grad_x = xp.where(x_mask, grad_out / count, 0)
|
|
807
|
+
# Scale the gradient by the inverse of the number of extremas there were
|
|
808
|
+
# along the reduced axes.
|
|
809
|
+
# If one axis had 3 extremas, then the gradient is divided between them
|
|
810
|
+
# equally. In that case, the gradients for all 3 extremas are scaled by 1/3.
|
|
811
|
+
|
|
812
|
+
return (grad_x,)
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
@register_grad_op(
    op_type=OpType.REDUCTION,
    op_inputs=OpInputs.UNARY,
    skip_test=True,
    skip_reason="requires x_mask computed during forward pass",
)
def max_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the `max` reduction, `z = max(x)`.

    Delegates to `_extremum_backward` with `op_type="max"`, which routes the
    upstream gradient only to the entries flagged in `x_mask`. When several
    entries tie for the maximum along the reduced axes, the gradient is split
    equally between them.

    Args:
        *inputs (Tensor): Exactly one input, `x`.
        compute_grad (tuple[bool]): `compute_grad[0]` tells whether the
            gradient for `x` is needed.
        grad_out (xp.ndarray): Gradient flowing in from the downstream op.
        **kwargs (Any): Forwarded unchanged to the helper. Recognized keys:

            - `axis`: Axis/axes of `x` that were reduced
              (`None` means all axes).

            - `x_mask`: Mask with the same shape as `x` that marks the
              maxima along the reduced axes. **Required**; a `ValueError`
              is raised when it is missing.

            - `keepdims`: Whether the reduced dimensions were kept in the
              forward output (defaults to False).

    Returns:
        tuple[xp.ndarray | None]: `(grad_x,)`, or `(None,)` when
        `compute_grad[0]` is False.
    """
    needs_x_grad = compute_grad[0]
    if needs_x_grad:
        return _extremum_backward(
            *inputs,
            op_type="max",
            grad_out=grad_out,
            **kwargs,
        )
    return (None,)
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
@register_grad_op(
    op_type=OpType.REDUCTION,
    op_inputs=OpInputs.UNARY,
    skip_test=True,
    skip_reason="requires x_mask computed during forward pass",
)
def min_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None]:
    """Backward pass of the `min` reduction, `z = min(x)`.

    Delegates to `_extremum_backward` with `op_type="min"`, which routes the
    upstream gradient only to the entries flagged in `x_mask`. When several
    entries tie for the minimum along the reduced axes, the gradient is split
    equally between them.

    Args:
        *inputs (Tensor): Exactly one input, `x`.
        compute_grad (tuple[bool]): `compute_grad[0]` tells whether the
            gradient for `x` is needed.
        grad_out (xp.ndarray): Gradient flowing in from the downstream op.
        **kwargs (Any): Forwarded unchanged to the helper. Recognized keys:

            - `axis`: Axis/axes of `x` that were reduced
              (`None` means all axes).

            - `x_mask`: Mask with the same shape as `x` that marks the
              minima along the reduced axes. **Required**; a `ValueError`
              is raised when it is missing.

            - `keepdims`: Whether the reduced dimensions were kept in the
              forward output (defaults to False).

    Returns:
        tuple[xp.ndarray | None]: `(grad_x,)`, or `(None,)` when
        `compute_grad[0]` is False.
    """
    needs_x_grad = compute_grad[0]
    if needs_x_grad:
        return _extremum_backward(
            *inputs,
            op_type="min",
            grad_out=grad_out,
            **kwargs,
        )
    return (None,)
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
def _element_wise_extremum_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    op_type: Literal["minimum", "maximum"],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Shared backward pass for `z = f(x, y)`, with `f` in {minimum, maximum}.

    Each position's upstream gradient goes to whichever input won at that
    position, as recorded in `x_mask` during the forward pass. Where `x == y`
    the two inputs tie, and each receives half of the upstream gradient.

    Args:
        *inputs (Tensor): Two inputs, `x` followed by `y`, with identical shapes.
        compute_grad (tuple[bool, bool]): Whether the gradient for `x`
            (index 0) and for `y` (index 1) is needed.
        op_type (Literal["minimum", "maximum"]): Name of the forward op. It is
            only used in error messages.
        grad_out (xp.ndarray): Gradient flowing in from the downstream op.
        **kwargs (Any): Recognized keys:

            - `x_mask`: Same shape as `x`. It is truthy where `x` is the
              extremum (the smaller value for minimum, the larger for maximum)
              compared to `y` at the same position. **Required**.

            - `where`: Multiplicative factor applied to both gradients,
              meant to zero out positions where `f` was not applied
              (defaults to 1).

    Raises:
        ValueError: If `x` and `y` differ in shape, or `x_mask` is missing.

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: `(grad_x, grad_y)`, where
        an entry is `None` if the matching `compute_grad` flag is False.
    """
    x, y = inputs[0], inputs[1]

    if x.shape != y.shape:
        raise ValueError(
            f"Both inputs must be of same shape, got {x.shape} for x and {y.shape} for y"
        )

    x_mask = kwargs.get("x_mask")
    if x_mask is None:
        raise ValueError(
            f'Missing keyword argument "x_mask" is required in backward for {op_type}.'
        )

    where_factor = kwargs.get("where", 1)

    assert x_mask.shape == x.shape, "Extremum mask must have the same shape as x"

    # Positions where neither input strictly wins: the gradient is shared 50/50.
    tied = x == y

    def _routed_grad(untied_weight: Any) -> xp.ndarray:
        # Use 0.5 on ties, otherwise the given per-position weight (mask or its complement).
        weight = xp.where(tied, 0.5, untied_weight)
        return where_factor * weight * grad_out

    grad_x = _routed_grad(x_mask) if compute_grad[0] else None
    grad_y = _routed_grad(1 - x_mask) if compute_grad[1] else None

    return grad_x, grad_y
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
@register_grad_op(
    op_type=OpType.ELEMENTWISE,
    op_inputs=OpInputs.BINARY,
    skip_test=True,
    skip_reason="requires x_mask computed during forward pass",
)
@broadcastable
def maximum_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of the element-wise maximum, `z = maximum(x, y)`.

    Thin wrapper around `_element_wise_extremum_backward` with
    `op_type="maximum"`. Each position's gradient goes to the larger input.
    On ties (`x == y`) it is split 50/50.

    Args:
        *inputs (Tensor): Two inputs, `x` followed by `y`.
        compute_grad (tuple[bool, bool]): Whether the gradient for `x`
            (index 0) and for `y` (index 1) is needed.
        grad_out (xp.ndarray): Gradient flowing in from the downstream op.
        **kwargs (Any): Forwarded unchanged to the helper. Recognized keys:

            - `x_mask`: Same shape as `x`. It is truthy where `x` is larger
              than `y` at the same position. **Required**.

            - `where`: Optional factor applied to both gradients
              (defaults to 1).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: `(grad_x, grad_y)`, where
        an entry is `None` if the matching `compute_grad` flag is False.
    """
    grads = _element_wise_extremum_backward(
        *inputs,
        op_type="maximum",
        compute_grad=compute_grad,
        grad_out=grad_out,
        **kwargs,
    )
    return grads
|
|
1025
|
+
|
|
1026
|
+
|
|
1027
|
+
@register_grad_op(
    op_type=OpType.ELEMENTWISE,
    op_inputs=OpInputs.BINARY,
    skip_test=True,
    skip_reason="requires x_mask computed during forward pass",
)
@broadcastable
def minimum_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool, bool],
    grad_out: xp.ndarray,
    **kwargs: Any,
) -> tuple[xp.ndarray | None, xp.ndarray | None]:
    """Backward pass of the element-wise minimum, `z = minimum(x, y)`.

    Thin wrapper around `_element_wise_extremum_backward` with
    `op_type="minimum"`. Each position's gradient goes to the smaller input.
    On ties (`x == y`) it is split 50/50.

    Args:
        *inputs (Tensor): Two inputs, `x` followed by `y`.
        compute_grad (tuple[bool, bool]): Whether the gradient for `x`
            (index 0) and for `y` (index 1) is needed.
        grad_out (xp.ndarray): Gradient flowing in from the downstream op.
        **kwargs (Any): Forwarded unchanged to the helper. Recognized keys:

            - `x_mask`: Same shape as `x`. It is truthy where `x` is smaller
              than `y` at the same position. **Required**.

            - `where`: Optional factor applied to both gradients
              (defaults to 1).

    Returns:
        tuple[xp.ndarray | None, xp.ndarray | None]: `(grad_x, grad_y)`, where
        an entry is `None` if the matching `compute_grad` flag is False.
    """
    grads = _element_wise_extremum_backward(
        *inputs,
        op_type="minimum",
        compute_grad=compute_grad,
        grad_out=grad_out,
        **kwargs,
    )
    return grads
|
|
1070
|
+
|
|
1071
|
+
|
|
1072
|
+
@register_grad_op(
    op_type=OpType.MOVEMENT,
    op_inputs=OpInputs.UNARY,
    skip_test=True,
    skip_reason="not testable with finite differences",
)
def copy_to_device_backward(
    *inputs: Tensor,
    compute_grad: tuple[bool],
    grad_out: xp.ndarray,
    **kwargs: Any,  # noqa: ARG001
) -> tuple[xp.ndarray | None]:
    """Backward pass of copying a Tensor to a (possibly different) device.

    The copy is the identity on values, so the upstream gradient passes
    through unchanged. It is only moved back to the device the input lived
    on, which undoes the forward `copy_to_device`.

    Args:
        *inputs (Tensor): Exactly one input: the Tensor that was copied.
        compute_grad (tuple[bool]): `compute_grad[0]` tells whether the
            gradient for the input is needed.
        grad_out (xp.ndarray): Upstream gradient, located on the target device.
        **kwargs (Any): Ignored. Accepted so that this op has the same
            signature as the other grad ops, which all take forwarded
            keyword arguments. Without it, an extra keyword argument would
            raise `TypeError`.

    Returns:
        tuple[xp.ndarray | None]: `(grad_x,)`, where `grad_x` is `grad_out`
        copied to `inputs[0].device`, or `(None,)` when `compute_grad[0]`
        is False.
    """
    if not compute_grad[0]:
        return (None,)

    x = inputs[0]
    # Revert the forward op: send the gradient back to the device of x.
    return (copy_array(array=grad_out, device=x.device),)
|
|
1101
|
+
|
|
1102
|
+
|
|
1103
|
+
def make_axis(ndim: int, axis_candidate: Any) -> tuple[int, ...]:
    """Normalizes an `axis` argument of a numpy op into a tuple of non-negative ints.

    Accepts a single integer or a list/tuple of integers. Any integral type is
    allowed, including numpy integer scalars such as `np.int64`. Negative axes
    are converted to their positive equivalents (`-1` -> `ndim - 1`), and every
    entry is returned as a plain Python `int`.

    Args:
        ndim (int): Number of dimensions of the array the numpy op acts on.
        axis_candidate (Any): The `axis` argument that was given to the numpy op.

    Raises:
        ValueError: If `axis_candidate` is not an integer or a list/tuple of
            integers, or if any axis is out of range for `ndim` dimensions.

    Returns:
        tuple[int, ...]: The normalized axes, in the order they were given.
    """
    # Imported here so that this helper stays self-contained.
    from numbers import Integral

    # Integral covers Python ints (and bool, as before) plus numpy integer scalars.
    if isinstance(axis_candidate, Integral):
        axis = (axis_candidate,)
    elif isinstance(axis_candidate, (list, tuple)) and all(
        isinstance(a, Integral) for a in axis_candidate
    ):
        axis = tuple(axis_candidate)
    else:
        raise ValueError(
            '"axis_candidate" must be in '
            f'[int, list[int], tuple[int]], found: "{type(axis_candidate).__name__}".'
        )

    # Normalize negative axes to positive values and coerce to plain Python ints.
    normalized_axis = tuple(int(a) + ndim if a < 0 else int(a) for a in axis)

    # Raise explicitly instead of asserting, so the check survives `python -O`.
    invalid = [int(a) for a, n in zip(axis, normalized_axis) if not 0 <= n < ndim]
    if invalid:
        raise ValueError(
            f"axis {invalid} is out of bounds for array of dimension {ndim}"
        )

    return normalized_axis
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
# Public API of this module: the names exported by `from sadl.grad_ops import *`.
# Kept as a literal list so that static analyzers can read it.
__all__ = [
    "GradOp",
    "GradOpSpec",
    "OpInputs",
    "OpType",
    "get_grad_op",
    "get_grad_op_spec",
    "register_grad_op",
]
|