py_sadl-1.0.2-py3-none-any.whl

sadl/ops.py ADDED
@@ -0,0 +1,67 @@
+"""All custom or extended (from numpy/cupy) operations on SADL Tensors."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from .backend import xp
+from .tensor import (
+    _GRAD_MODE_ENABLED,
+    Tensor,
+)
+from .tensor import (
+    _copy_to_device as copy_to_device,
+)
+
+
+def ones_like(
+    other: Tensor,
+    *,
+    dtype: Any = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    """Create a Tensor of ones with the same shape and device as `other`.
+
+    Args:
+        other (Tensor): The tensor to match shape and device from.
+        dtype (Any): Override dtype. Defaults to None (use other's dtype).
+        requires_grad (bool): Whether to track gradients. Defaults to False.
+
+    Returns:
+        Tensor: A tensor of ones.
+    """
+    # Use xp.ones(shape) instead of xp.ones_like(tensor) to avoid
+    # triggering __array_function__ on the Tensor
+    result: Tensor = xp.ones(other.shape, dtype=dtype or other.dtype).view(Tensor)
+    result.requires_grad = _GRAD_MODE_ENABLED and requires_grad
+    return result
+
+
+def zeros_like(
+    other: Tensor,
+    *,
+    dtype: Any = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    """Create a Tensor of zeros with the same shape and device as `other`.
+
+    Args:
+        other (Tensor): The tensor to match shape and device from.
+        dtype (Any): Override dtype. Defaults to None (use other's dtype).
+        requires_grad (bool): Whether to track gradients. Defaults to False.
+
+    Returns:
+        Tensor: A tensor of zeros.
+    """
+    # Use xp.zeros(shape) instead of xp.zeros_like(tensor) to avoid
+    # triggering __array_function__ on the Tensor
+    result: Tensor = xp.zeros(other.shape, dtype=dtype or other.dtype).view(Tensor)
+    result.requires_grad = _GRAD_MODE_ENABLED and requires_grad
+    return result
+
+
+__all__ = [
+    "copy_to_device",
+    "ones_like",
+    "zeros_like",
+]
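
Example (illustrative, not part of the package diff): a minimal usage sketch of the helpers in sadl/ops.py. The `tensor(...)` factory and the arguments it accepts are assumptions inferred from how sadl/optimizer.py imports and calls it; the exact factory API is not shown in this diff.

    # Hypothetical usage sketch -- assumes `sadl.tensor.tensor` builds a Tensor
    # from array-like data, as the imports in sadl/optimizer.py suggest.
    from sadl import ops
    from sadl.tensor import tensor

    x = tensor([[1.0, 2.0], [3.0, 4.0]])

    ones = ops.ones_like(x)                      # same shape/dtype/device as x
    zeros = ops.zeros_like(x, dtype="float32")   # dtype override
    seed = ops.ones_like(x, requires_grad=True)  # tracked only while grad mode is enabled

    assert ones.shape == x.shape and zeros.shape == x.shape
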
sadl/optimizer.py ADDED
@@ -0,0 +1,352 @@
+import logging
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from collections.abc import Iterable, ValuesView
+from itertools import chain
+
+from .backend import TensorDevice, xp
+from .tensor import Parameter, Tensor, no_grad, no_grad_fn, tensor
+
+logger = logging.getLogger(__name__)
+
+
+def toposort(root: Tensor) -> list[Tensor]:
+    """Performs topological sort on a graph.
+
+    `root` is the starting point of the graph.
+
+    Args:
+        root (Tensor): The starting point of the graph.
+            Expected to have an attribute `src` denoting
+            a list of its neighbors, also being of type `Tensor`.
+
+    Raises:
+        ValueError: If the graph, with the starting point
+            given by `root`, is not a DAG.
+
+    Returns:
+        list[Tensor]: The ordered nodes.
+    """
+    ordered_nodes: list[Tensor] = []
+    currently_visiting: set[Tensor] = set()
+    done: set[Tensor] = set()
+
+    stack: list[tuple[Tensor, bool]] = [(root, False)]
+    while stack:
+        node, visited = stack.pop()
+        if node in done:
+            continue
+
+        if visited:
+            ordered_nodes.append(node)
+            currently_visiting.discard(node)
+            done.add(node)
+            continue
+
+        stack.append((node, True))
+        currently_visiting.add(node)
+        for neighbor in reversed(node.src):
+            if neighbor in done:
+                continue
+            if neighbor in currently_visiting:
+                raise ValueError("Cycle in computation graph detected, but only DAG allowed!")
+            stack.append((neighbor, False))
+
+    return ordered_nodes
+
+
+class Optimizer(ABC):
+    """Abstract base class for all optimizers."""
+
+    def __init__(self, params: list[Parameter], lr: float = 1e-3):
+        if len(params) == 0:
+            raise ValueError("Must pass at least one parameter to optimize.")
+        for param in params:
+            if not isinstance(param, Parameter):
+                raise TypeError("All parameters passed to the optimizer must be of type Parameter.")
+            if not param.keep_grad:
+                raise ValueError(
+                    "Attribute keep_grad must always be True for all parameters "
+                    "to avoid clearing gradients during the backward pass. "
+                    "This is important for cases like gradient accumulation."
+                )
+            if len(param.src) > 0:
+                raise ValueError(
+                    "Parameters should always be leaves and therefore "
+                    "should not have any parents/src "
+                    "from which they were created."
+                )
+
+        self.params = params
+        self.lr = tensor(lr, device=self.params[0].device)
+
+    @property
+    def device(self) -> tuple[TensorDevice, ...] | None:
+        """The devices on which the optimizer state is currently located.
+
+        Can be multiple if the optimizer state is sharded across multiple
+        devices.
+
+        Returns:
+            tuple[TensorDevice, ...] | None: The devices on which the optimizer
+                state is currently located. None, if the optimizer has
+                no state Tensors.
+        """
+        unique_devices = {attr.device for attr in self.state}
+        if len(unique_devices) == 0:
+            return None
+        return tuple(unique_devices)
+
+    # "no_grad_fn" technically not needed here, as optimizer state Tensors
+    # are leaves (not part of a computation graph). We still annotate
+    # just to be safe.
+    @no_grad_fn
+    def copy_to_device(self, device: TensorDevice) -> "Optimizer":
+        """Copy the optimizer state to the specified `device`.
+
+        Args:
+            device (TensorDevice): The device to copy the state to.
+
+        Returns:
+            Optimizer: self, for method chaining.
+        """
+        for key, state_tensor in self.get_state().items():
+            setattr(self, key, state_tensor.copy_to_device(device))
+
+        return self
+
+    # "no_grad_fn" technically not needed here, as optimizer state Tensors
+    # are leaves (not part of a computation graph). We still annotate
+    # just to be safe.
+    @no_grad_fn
+    def get_state(self, to_device: TensorDevice | None = None) -> OrderedDict[str, Tensor]:
+        """The state of the optimizer.
+
+        Note: Only **direct** attributes of the Optimizer class of type **Tensor**
+        will be considered in the state.
+
+        Args:
+            to_device (TensorDevice | None): If specified, copy each
+                Tensor in the state to this device in the returned dict.
+                If `None`, the device of the Tensors is not changed. Defaults to None.
+
+        Returns:
+            OrderedDict[str, Tensor]: Dict containing the state.
+        """
+        state = OrderedDict[str, Tensor]()
+        for key, data in vars(self).items():
+            if isinstance(data, Tensor):
+                state[key] = data.copy_to_device(to_device) if to_device is not None else data
+        return state
+
+    @property
+    def state(self) -> ValuesView[Tensor]:
+        """The Tensors forming the state of the optimizer.
+
+        Returns:
+            ValuesView[Tensor]: A view over the
+                state Tensors.
+        """
+        return self.get_state().values()
+
+    # "no_grad_fn" technically not needed here, as optimizer state Tensors
+    # are leaves (not part of a computation graph). We still annotate
+    # just to be safe.
+    @no_grad_fn
+    def load_state(
+        self,
+        *,
+        state: OrderedDict[str, Tensor],
+        match_device: bool = False,
+        partial: bool = False,
+    ) -> "Optimizer":
+        """Load/initialize the state of the optimizer.
+
+        Args:
+            state (OrderedDict[str, Tensor]): The state of the optimizer.
+            match_device (bool): If True, copy each loaded Tensor
+                to the target's device before assignment.
+                If False, raises on device mismatch. Defaults to False.
+            partial (bool): If True, allow missing keys in `state`.
+                If False, raises on missing keys. Defaults to False.
+
+        Returns:
+            Optimizer: self, for method chaining.
+        """
+        for key, data in self.get_state().items():  # only Tensor-valued state attributes
+            init_data = state.get(key, None)
+            if init_data is None:
+                if not partial:
+                    raise KeyError(f'Optimizer state "{key}" not found in passed state!')
+                continue
+
+            if not isinstance(init_data, Tensor):
+                raise TypeError(
+                    'Data in passed state must be "Tensor", '
+                    f'found "{type(init_data).__name__}" ({init_data})'
+                )
+            if data.shape != init_data.shape:
+                raise ValueError(
+                    f"Shape of seed Tensor does not align with shape of "
+                    f'target state Tensor "{key}". Found "{init_data.shape}", '
+                    f'expected "{data.shape}".'
+                )
+            if match_device:
+                init_data = init_data.copy_to_device(data.device)
+            elif data.device != init_data.device:
+                raise ValueError(
+                    f"Device of seed Tensor does not align with device of "
+                    f'target state Tensor "{key}". Found "{init_data.device}", '
+                    f'expected "{data.device}".'
+                )
+            setattr(self, key, init_data)
+
+        return self
+
+    def _clear_graph(self, topo_nodes: Iterable[Tensor]) -> None:
+        """Clears computation graph structure and gradients after backward pass.
+
+        Removes references to parent tensors, backward functions, and operation
+        context to free memory. Gradients are cleared for non-parameter tensors
+        (keep_grad=False).
+
+        Args:
+            topo_nodes (Iterable[Tensor]): Nodes in topological order from the
+                computation graph to clear.
+        """
+        for node in topo_nodes:
+            node.detach(in_place=True)
+
+    def _clear_activation_gradients(self, topo_nodes: Iterable[Tensor]) -> None:
+        """Clears gradients of activation tensors before backward pass.
+
+        Ensures activations start with no gradients, preventing stale gradients
+        from previous backward passes if tensors are reused. Parameters
+        (keep_grad=True) retain their gradients for accumulation.
+
+        Args:
+            topo_nodes (Iterable[Tensor]): Nodes in topological order from the
+                computation graph whose gradients should be cleared.
+        """
+        for node in topo_nodes:
+            if not node.keep_grad:
+                node.grad = None
+
+    def backward(self, loss: Tensor) -> None:
+        """Perform backpropagation on the computation graph with respect to `loss`.
+
+        Note: Does **not** support accumulating multiple losses
+        into the gradient of `loss`; that only works for parameters.
+        Sum the individual losses into a single loss Tensor instead.
+
+        Args:
+            loss (Tensor): The loss with respect to which the
+                gradients are computed.
+
+        Raises:
+            ValueError: If the loss is not a scalar.
+        """
+        if not isinstance(loss, Tensor) or loss.size > 1:
+            raise ValueError("Expected 'loss' argument to be a scalar Tensor.")
+        if loss.keep_grad:
+            raise ValueError(
+                "keep_grad=True not supported for the loss Tensor. "
+                "Use multiple losses and accumulate/sum them into a new "
+                "loss Tensor instead."
+            )
+
+        node_order = toposort(loss)
+
+        # Clear the gradients of all activations:
+        # (this is necessary to make "node.grad is None" below well-defined)
+        self._clear_activation_gradients(topo_nodes=node_order)
+
+        # do not use xp.ones_like(loss) here, because "loss" is a Tensor that requires
+        # a gradient, which will trigger "__array_function__" and try to track this
+        # -> another fix would be moving it down into the "no_grad()" context
+        loss.grad = xp.ones(loss.shape, dtype=loss.dtype)  # seed gradient for loss
+
+        with no_grad():
+            for node in reversed(node_order):
+                # "node.grad is None" -> Means: In the current computation graph,
+                # this node has not received any gradients, meaning it does not
+                # contribute to the loss => We can skip it
+                # "node.is_leaf()" -> Means: In the current computation graph,
+                # this node has no parents to pass gradients to, meaning we can skip it
+                # "not any(src_requires_grad)" -> Means: In the current computation graph,
+                # no parent requires a gradient to be passed to it, meaning we can skip
+                # the current node
+                src_requires_grad = [t.requires_grad for t in node.src]
+                if node.grad is None or node.is_leaf() or not any(src_requires_grad):
+                    continue
+
+                if node.backward_fn is None:
+                    raise ValueError(
+                        f'"backward_fn" for node "{id(node)}" in computation graph is "None"!'
+                    )
+
+                logger.debug(f'Calling backward function: "{node.backward_fn.__name__}"')
+
+                src_grads = node.backward_fn(
+                    *node.src,
+                    compute_grad=src_requires_grad,
+                    grad_out=node.grad,
+                    **node.op_ctx,
+                )
+                assert len(src_grads) == len(node.src)
+
+                for src, src_grad in zip(node.src, src_grads, strict=True):
+                    if src_grad is None:
+                        continue
+
+                    assert src.shape == src_grad.shape
+                    current_src_grad = src.grad if src.grad is not None else xp.zeros_like(src)
+                    src.grad = current_src_grad + src_grad
+
+        self._clear_graph(topo_nodes=node_order)
+
+    def zero_grad(self, additional_tensors: list[Tensor] | None = None) -> None:
+        """Clears the gradients of all parameters that are optimized.
+
+        Applied to all parameters in `self.params`.
+
+        Args:
+            additional_tensors (list[Tensor], optional): Extra Tensors for
+                which to reset gradients. This could be activations that have
+                explicitly set their `keep_grad` attribute to `True`, meaning
+                they are not cleared before the backward pass of the graph
+                they are part of. Defaults to None.
+        """
+        for param in chain(self.params, additional_tensors or []):
+            param.grad = None
+
+    @no_grad_fn
+    @abstractmethod
+    def step(self) -> None:
+        """The step function to update the parameters.
+
+        Must be implemented by the specific optimizer.
+        """
+
+
+class SGD(Optimizer):
+    """Stochastic gradient descent optimizer."""
+
+    @no_grad_fn
+    def step(self) -> None:
+        """Performs a single gradient descent step.
+
+        Raises:
+            ValueError: If a parameter has no gradient.
+        """
+        for param in self.params:
+            if param.grad is None:
+                raise ValueError("Gradient of parameter must not be None in step function")
+            param -= self.lr * param.grad  # noqa: PLW2901
+
+
+__all__ = [
+    "SGD",
+    "Optimizer",
+    "toposort",
+]
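
Example (illustrative, not part of the package diff): a rough training-loop sketch around Optimizer/SGD and the state helpers. The Parameter constructor, the loss expression, and the `data` iterable are hypothetical placeholders not defined in this diff; only the optimizer calls (backward, step, zero_grad, get_state, load_state) come from the code above, and Parameters are assumed to default to keep_grad=True as Optimizer.__init__ requires.

    # Hypothetical sketch -- Parameter construction and the forward pass are
    # placeholders; they are not shown in this diff.
    from sadl.optimizer import SGD
    from sadl.tensor import Parameter, tensor

    w = Parameter(tensor([0.0, 0.0]))    # assumed: Parameter wraps a Tensor, keep_grad=True
    opt = SGD(params=[w], lr=1e-2)

    for x, y in data:                    # `data`: placeholder iterable of (input, target)
        loss = ((w * x - y) ** 2).sum()  # any scalar Tensor reachable from the params
        opt.backward(loss)               # toposorts the graph, accumulates .grad on sources
        opt.step()                       # SGD update: param -= lr * param.grad
        opt.zero_grad()                  # clear accumulated parameter gradients

    # Optimizer state (direct Tensor attributes such as `lr`) can be round-tripped:
    state = opt.get_state()              # OrderedDict[str, Tensor]
    opt.load_state(state=state, match_device=True)
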