torchjd 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. torchjd/__init__.py +1 -0
  2. torchjd/_transform/__init__.py +22 -0
  3. torchjd/_transform/_differentiation.py +34 -0
  4. torchjd/_transform/_utils.py +62 -0
  5. torchjd/_transform/aggregation.py +12 -0
  6. torchjd/_transform/base.py +117 -0
  7. torchjd/_transform/concatenation.py +65 -0
  8. torchjd/_transform/diagonalize.py +36 -0
  9. torchjd/_transform/grad.py +60 -0
  10. torchjd/_transform/identity.py +22 -0
  11. torchjd/_transform/init.py +30 -0
  12. torchjd/_transform/jac.py +109 -0
  13. torchjd/_transform/matrixify.py +25 -0
  14. torchjd/_transform/reshape.py +26 -0
  15. torchjd/_transform/scaling.py +21 -0
  16. torchjd/_transform/stack.py +63 -0
  17. torchjd/_transform/store.py +29 -0
  18. torchjd/_transform/strategy/__init__.py +4 -0
  19. torchjd/_transform/strategy/_utils.py +86 -0
  20. torchjd/_transform/strategy/extrapolating.py +75 -0
  21. torchjd/_transform/strategy/isolating.py +25 -0
  22. torchjd/_transform/strategy/partitioning.py +81 -0
  23. torchjd/_transform/strategy/unifying.py +43 -0
  24. torchjd/_transform/subset.py +27 -0
  25. torchjd/_transform/tensor_dict.py +210 -0
  26. torchjd/aggregation/__init__.py +16 -0
  27. torchjd/aggregation/_gramian_utils.py +46 -0
  28. torchjd/aggregation/_normalizing.py +48 -0
  29. torchjd/aggregation/_pref_vector_utils.py +26 -0
  30. torchjd/aggregation/_str_utils.py +11 -0
  31. torchjd/aggregation/aligned_mtl.py +129 -0
  32. torchjd/aggregation/bases.py +89 -0
  33. torchjd/aggregation/cagrad.py +105 -0
  34. torchjd/aggregation/constant.py +67 -0
  35. torchjd/aggregation/dualproj.py +131 -0
  36. torchjd/aggregation/graddrop.py +85 -0
  37. torchjd/aggregation/imtl_g.py +50 -0
  38. torchjd/aggregation/krum.py +108 -0
  39. torchjd/aggregation/mean.py +42 -0
  40. torchjd/aggregation/mgda.py +85 -0
  41. torchjd/aggregation/nash_mtl.py +221 -0
  42. torchjd/aggregation/pcgrad.py +72 -0
  43. torchjd/aggregation/random.py +47 -0
  44. torchjd/aggregation/sum.py +40 -0
  45. torchjd/aggregation/trimmed_mean.py +73 -0
  46. torchjd/aggregation/upgrad.py +136 -0
  47. torchjd/backward.py +95 -0
  48. torchjd-0.1.0.dist-info/LICENSE +21 -0
  49. torchjd-0.1.0.dist-info/METADATA +55 -0
  50. torchjd-0.1.0.dist-info/RECORD +52 -0
  51. torchjd-0.1.0.dist-info/WHEEL +5 -0
  52. torchjd-0.1.0.dist-info/top_level.txt +1 -0
torchjd/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from torchjd.backward import backward
@@ -0,0 +1,22 @@
1
+ from torchjd._transform.aggregation import make_aggregation
2
+ from torchjd._transform.base import Composition, Conjunction, Transform
3
+ from torchjd._transform.concatenation import Concatenation
4
+ from torchjd._transform.diagonalize import Diagonalize
5
+ from torchjd._transform.grad import Grad
6
+ from torchjd._transform.identity import Identity
7
+ from torchjd._transform.init import Init
8
+ from torchjd._transform.jac import Jac
9
+ from torchjd._transform.matrixify import Matrixify
10
+ from torchjd._transform.reshape import Reshape
11
+ from torchjd._transform.scaling import Scaling
12
+ from torchjd._transform.stack import Stack
13
+ from torchjd._transform.store import Store
14
+ from torchjd._transform.subset import Subset
15
+ from torchjd._transform.tensor_dict import (
16
+ EmptyTensorDict,
17
+ Gradients,
18
+ GradientVectors,
19
+ JacobianMatrices,
20
+ Jacobians,
21
+ TensorDict,
22
+ )
@@ -0,0 +1,34 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Iterable, Sequence
3
+
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import ordered_set
7
+ from torchjd._transform.base import _A, Transform
8
+
9
+
10
class _Differentiation(Transform[_A, _A], ABC):
    """
    Abstract transform that differentiates a fixed collection of ``outputs`` with respect to a
    fixed collection of ``inputs``.

    The tensor dict given to the transform must hold one value per element of ``outputs``; the
    returned tensor dict (of the same type as the given one) holds one value per element of
    ``inputs``. The differentiation itself is delegated to :meth:`_differentiate`, which
    subclasses have to implement.
    """

    def __init__(self, outputs: Iterable[Tensor], inputs: Iterable[Tensor]):
        # `ordered_set` keeps the iteration order and rejects duplicated elements.
        self.outputs = ordered_set(outputs)
        self.inputs = ordered_set(inputs)

    @property
    def required_keys(self) -> set[Tensor]:
        # What is an output of the forward pass is what this (backward) transform consumes.
        return {*self.outputs}

    @property
    def output_keys(self) -> set[Tensor]:
        # What is an input of the forward pass is what this (backward) transform produces.
        return {*self.inputs}

    @abstractmethod
    def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        """
        Computes the derivatives from the values associated to ``self.outputs`` (given in the
        order of ``self.outputs``). The result is paired, in order, with ``self.inputs``.
        """

        raise NotImplementedError

    def _compute(self, tensors: _A) -> _A:
        # Read the values in the order of `self.outputs` so that they line up with it.
        values = [tensors[key] for key in self.outputs]
        derivatives = self._differentiate(values)

        # Pair each input with its derivative; both follow the order of `self.inputs`.
        per_input = {key: derivative for key, derivative in zip(self.inputs, derivatives)}
        return type(tensors)(per_input)
@@ -0,0 +1,62 @@
1
+ from collections import OrderedDict
2
+ from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from torchjd._transform.tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
8
+
9
+ _KeyType = TypeVar("_KeyType", bound=Hashable)
10
+ _ValueType = TypeVar("_ValueType")
11
+ _OrderedSet: TypeAlias = OrderedDict[_KeyType, None]
12
+
13
+ _A = TypeVar("_A", bound=TensorDict)
14
+ _B = TypeVar("_B", bound=TensorDict)
15
+ _C = TypeVar("_C", bound=TensorDict)
16
+
17
+
18
+ def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
19
+ elements = list(elements)
20
+ result = OrderedDict.fromkeys(elements, None)
21
+ if len(elements) != len(result):
22
+ raise ValueError(
23
+ f"Parameter `elements` should contain unique elements. Found `elements = {elements}`."
24
+ )
25
+
26
+ return result
27
+
28
+
29
+ def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
30
+ result = {}
31
+ for d in dicts:
32
+ result |= d
33
+ return result
34
+
35
+
36
+ def _materialize(
37
+ optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
38
+ ) -> tuple[Tensor, ...]:
39
+ """
40
+ Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
41
+ shape as the corresponding input. Returns the obtained sequence as a tuple.
42
+
43
+ Note that the name "materialize" comes from the flag `materialize_grads` from
44
+ `torch.autograd.grad`, which will be available in future torch releases.
45
+ """
46
+
47
+ tensors = []
48
+ for optional_tensor, input in zip(optional_tensors, inputs):
49
+ if optional_tensor is None:
50
+ tensors.append(torch.zeros_like(input))
51
+ else:
52
+ tensors.append(optional_tensor)
53
+ return tuple(tensors)
54
+
55
+
56
def _union(tensor_dicts: Iterable[_A]) -> _A:
    """
    Merges the provided tensor dicts into a single one.

    The merge starts from an ``EmptyTensorDict`` and absorbs the provided tensor dicts one after
    the other with the ``|=`` operator. The type of the result is obtained by repeatedly applying
    ``_least_common_ancestor`` to ``EmptyTensorDict`` and the types of the provided tensor dicts.
    """

    merged: _A = EmptyTensorDict()
    merged_type: type[_A] = EmptyTensorDict

    for current in tensor_dicts:
        merged_type = _least_common_ancestor(merged_type, type(current))
        merged |= current

    # Re-wrap the accumulated content in the type that is common to all the merged dicts.
    return merged_type(merged)
@@ -0,0 +1,12 @@
1
+ from torchjd._transform.base import Transform
2
+ from torchjd._transform.matrixify import Matrixify
3
+ from torchjd._transform.reshape import Reshape
4
+ from torchjd._transform.tensor_dict import Gradients, GradientVectors, JacobianMatrices, Jacobians
5
+
6
+
7
def make_aggregation(
    strategy: Transform[JacobianMatrices, GradientVectors]
) -> Transform[Jacobians, Gradients]:
    """
    Wraps ``strategy`` so that it can be applied to jacobians and give gradients.

    The returned transform chains three steps: a ``Matrixify`` is applied first, then the
    ``strategy``, and finally a ``Reshape``. Both the ``Matrixify`` and the ``Reshape`` are built
    on the ``required_keys`` of the ``strategy``.
    """

    keys = strategy.required_keys

    # `<<` is the composition: the right-most transform is the first one to be applied.
    reshape_after_strategy = Reshape(keys) << strategy
    return reshape_after_strategy << Matrixify(keys)
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Generic, Sequence
5
+
6
+ from torch import Tensor
7
+
8
+ from torchjd._transform._utils import _A, _B, _C, _union
9
+
10
+
11
class Transform(Generic[_B, _C], ABC):
    r"""
    Abstract base class of all transforms, the elementary building blocks of a jacobian descent
    backward phase. A transform maps a :class:`~torchjd._transform.tensor_dict.TensorDict` whose
    keys are `required_keys` to a :class:`~torchjd._transform.tensor_dict.TensorDict` whose keys
    are `output_keys`.

    Formally, a transform is a function:

    .. math::
        f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}

    where ``p`` is the number of `required_keys`, ``q`` is the number of `output_keys`, ``n_i``
    is the number of elements of the value associated to the ``i`` th `required_key` of the input
    and ``m_j`` is the number of elements of the value associated to the ``j`` th `output_key` of
    the output.

    Transforms can be combined in two ways:

    - ``outer << inner`` (or ``outer.compose(inner)``) gives a ``Composition``, that applies
      ``inner`` and then ``outer``.
    - ``first | second`` (or ``first.conjunct(second)``) gives a ``Conjunction``, that applies
      both transforms to the same input and merges their outputs.
    """

    @property
    @abstractmethod
    def required_keys(self) -> set[Tensor]:
        """Keys that the tensor dict given to the transform must have."""

        raise NotImplementedError

    @property
    @abstractmethod
    def output_keys(self) -> set[Tensor]:
        """Keys of the tensor dict returned by the transform."""

        raise NotImplementedError

    @abstractmethod
    def _compute(self, input: _B) -> _C:
        """Computes the output of the transform. The keys of ``input`` are checked beforehand."""

        raise NotImplementedError

    def __call__(self, input: _B) -> _C:
        # Validate the keys of the input before handing it over to the actual computation.
        input.check_keys_are(self.required_keys)
        result = self._compute(input)
        return result

    def __repr__(self) -> str:
        return type(self).__name__

    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
        """Returns the transform that applies ``other`` first and then ``self``."""

        return Composition(self, other)

    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
        """Returns the conjunction of ``self`` and ``other``, in this order."""

        return Conjunction([self, other])

    # Operator shortcuts: `a << b` is `a.compose(b)` and `a | b` is `a.conjunct(b)`.
    __lshift__ = compose
    __or__ = conjunct
62
+
63
+
64
class Composition(Transform[_A, _C]):
    """
    Transform that applies ``inner`` to its input and then ``outer`` to the result.

    Raises a ``ValueError`` at construction if the `required_keys` of ``outer`` differ from the
    `output_keys` of ``inner``.
    """

    def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
        keys_are_compatible = outer.required_keys == inner.output_keys
        if not keys_are_compatible:
            raise ValueError(
                "The `output_keys` of `inner` must match with the `required_keys` of "
                f"outer. Found {outer.required_keys} and {inner.output_keys}"
            )

        self.outer = outer
        self.inner = inner

    @property
    def required_keys(self) -> set[Tensor]:
        # The input of the composition is consumed by the inner transform.
        return self.inner.required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The output of the composition is produced by the outer transform.
        return self.outer.output_keys

    def _compute(self, input: _A) -> _C:
        return self.outer(self.inner(input))

    def __repr__(self) -> str:
        return f"{self.outer!r} ∘ {self.inner!r}"
88
+
89
+
90
class Conjunction(Transform[_A, _B]):
    """
    Transform that applies every transform of ``transforms`` to the same input and merges all the
    obtained tensor dicts into a single one.

    Raises a ``ValueError`` at construction if the transforms do not all have the same
    `required_keys`, or if some output key is produced by more than one transform.
    """

    def __init__(self, transforms: Sequence[Transform[_A, _B]]):
        self.transforms = transforms

        all_required_keys = set()
        for transform in transforms:
            all_required_keys.update(transform.required_keys)
        self._required_keys = all_required_keys

        # If every transform requires exactly the union of the required keys, they all agree.
        if any(transform.required_keys != self.required_keys for transform in transforms):
            raise ValueError("All transforms should require the same set of keys.")

        all_output_keys = []
        for transform in transforms:
            all_output_keys.extend(transform.output_keys)
        self._output_keys = set(all_output_keys)

        # A key produced by two transforms appears twice in the list but only once in the set.
        if len(all_output_keys) != len(self._output_keys):
            raise ValueError("The sets of output keys of transforms should be disjoint.")

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys

    def _compute(self, tensor_dict: _A) -> _B:
        partial_outputs = [transform(tensor_dict) for transform in self.transforms]
        return _union(partial_outputs)
@@ -0,0 +1,65 @@
1
+ from typing import Sequence
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import _A, _materialize, dicts_union
7
+ from torchjd._transform.base import Transform
8
+ from torchjd._transform.tensor_dict import Jacobians
9
+
10
+
11
class Concatenation(Transform[_A, Jacobians]):
    """
    Transform that applies every transform of ``transforms`` to the same input and concatenates,
    key by key, the jacobians that they return (see ``_concatenate``).

    Raises a ``ValueError`` at construction if ``transforms`` is empty or if the transforms do
    not all have the same `required_keys`.
    """

    def __init__(self, transforms: Sequence[Transform[_A, Jacobians]]):
        if len(transforms) < 1:
            raise ValueError("Parameter `transforms` cannot be empty.")

        self.transforms = transforms

        # The first transform is the reference that the other ones are compared to.
        self._required_keys = transforms[0].required_keys
        self._output_keys = set().union(*(transform.output_keys for transform in transforms))

        if any(other.required_keys != self.required_keys for other in transforms[1:]):
            raise ValueError("All transforms should require the same set of keys.")

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys

    def _compute(self, input: _A) -> Jacobians:
        jacobians_per_transform = [transform(input) for transform in self.transforms]
        return _concatenate(jacobians_per_transform)
37
+
38
+
39
def _concatenate(jacobians_dicts: list[Jacobians]) -> Jacobians:
    """
    Merges a list of jacobians dicts into a single one, whose set of keys is the union of the
    sets of keys of the provided dicts. The value associated to a key is the concatenation of the
    values found for this key in the provided dicts. For the dicts that do not have the key, the
    concatenated tensor is filled with zeroes at the corresponding positions.
    """

    # Iterating over the union of the dicts visits each key exactly once, even when several dicts
    # share it. This matters because the concatenation for a key is computed each time the key
    # is visited: duplicated keys would duplicate computations.
    concatenated = {}
    for key in dicts_union(jacobians_dicts):
        concatenated[key] = _concatenate_one_key(jacobians_dicts, key)

    return Jacobians(concatenated)
53
+
54
+
55
def _concatenate_one_key(jacobian_dicts: list[Jacobians], input: Tensor) -> Tensor:
    """
    Builds the concatenated tensor of the key ``input`` from a list of jacobians dicts. The
    concatenation is made along the first dimension, following the order of the list. A dict
    that does not have the key contributes with zeroes, with as many rows as its
    ``first_dimension``.
    """

    # One template per dict, shaped like the jacobian this dict would hold for `input`. It is
    # only used to create the zeroes of the dicts that do not have the key.
    templates = [
        input.expand(jacobian_dict.first_dimension, *input.shape)
        for jacobian_dict in jacobian_dicts
    ]
    found_jacobians = [jacobian_dict.get(input) for jacobian_dict in jacobian_dicts]

    return torch.concatenate(_materialize(found_jacobians, templates), dim=0)
@@ -0,0 +1,36 @@
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import ordered_set
7
+ from torchjd._transform.base import Transform
8
+ from torchjd._transform.tensor_dict import Gradients, Jacobians
9
+
10
+
11
class Diagonalize(Transform[Gradients, Jacobians]):
    """
    Transform that turns the gradients associated to the ``considered`` tensors into jacobians.

    The gradients are flattened and concatenated (in the order of ``considered``) to form the
    diagonal of a square matrix. The jacobian of a key is made of the columns of this matrix that
    correspond to the elements of the key, reshaped to ``(-1,) + key.shape``.
    """

    def __init__(self, considered: Iterable[Tensor]):
        self.considered = ordered_set(considered)

        # For each considered tensor, the (begin, end) range of columns that it occupies in the
        # diagonal matrix. The ranges follow each other in the order of `considered`.
        self.indices: list[tuple[int, int]] = []
        offset = 0
        for tensor in self.considered:
            n_elements = tensor.numel()
            self.indices.append((offset, offset + n_elements))
            offset += n_elements

    @property
    def required_keys(self) -> set[Tensor]:
        return {*self.considered}

    @property
    def output_keys(self) -> set[Tensor]:
        return {*self.considered}

    def _compute(self, tensors: Gradients) -> Jacobians:
        flat_values = torch.cat([tensors[key].reshape([-1]) for key in self.considered])
        diagonal_matrix = torch.diag(flat_values)

        jacobians = {}
        for key, (begin, end) in zip(self.considered, self.indices):
            columns = diagonal_matrix[:, begin:end]
            jacobians[key] = columns.reshape((-1,) + key.shape)

        return Jacobians(jacobians)
@@ -0,0 +1,60 @@
1
+ from typing import Iterable, Sequence
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._differentiation import _Differentiation
7
+ from torchjd._transform._utils import _materialize
8
+ from torchjd._transform.tensor_dict import Gradients
9
+
10
+
11
class Grad(_Differentiation[Gradients]):
    """
    Differentiation transform that computes gradients through ``_grad``.

    ``retain_graph`` is forwarded as is to ``_grad``. The graph of the derivatives is never
    created (``create_graph=False``) and the inputs are allowed to be unused
    (``allow_unused=True``).
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.retain_graph = retain_graph

    def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        autograd_options = dict(
            retain_graph=self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
        return _grad(
            outputs=self.outputs,
            inputs=self.inputs,
            grad_outputs=grad_outputs,
            **autograd_options,
        )
30
+
31
+
32
def _grad(
    outputs: Sequence[Tensor],
    inputs: Sequence[Tensor],
    grad_outputs: Sequence[Tensor],
    retain_graph: bool,
    create_graph: bool,
    allow_unused: bool,
) -> tuple[Tensor, ...]:
    """
    Wraps `autograd.grad` to give it additional responsibilities that it should have (like being
    able to work with an empty sequence of `inputs` or an empty sequence of `outputs`).

    Returns one gradient per element of `inputs`, in the same order:

    - If `inputs` is empty, an empty tuple is returned.
    - If `outputs` is empty, nothing is differentiated, so each gradient is a tensor of zeros
      with the same shape, dtype and device as the corresponding input.
    - Otherwise, `torch.autograd.grad` is called with the provided arguments, and each `None`
      that it returns is replaced by zeros shaped like the corresponding input.
    """

    if len(inputs) == 0:
        return tuple()

    if len(outputs) == 0:
        # `torch.empty` would give uninitialized values, on the default dtype and device. The
        # gradient of "no output" is zero, and it has to look like the input it is associated to.
        return tuple(torch.zeros_like(input) for input in inputs)

    optional_grads = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=allow_unused,
    )

    # With `allow_unused=True`, the gradient of an unused input is `None`: make it explicit.
    return _materialize(optional_grads, inputs)
@@ -0,0 +1,22 @@
1
+ from typing import Iterable
2
+
3
+ from torch import Tensor
4
+
5
+ from torchjd._transform._utils import _A
6
+ from torchjd._transform.base import Transform
7
+
8
+
9
class Identity(Transform[_A, _A]):
    """
    Transform that returns the tensor dict that it is given, unchanged. Its `output_keys` are its
    `required_keys`.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = {*required_keys}

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._required_keys

    def _compute(self, tensor_dict: _A) -> _A:
        # The very same object is handed back: no copy is made.
        return tensor_dict
@@ -0,0 +1,30 @@
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform.base import Transform
7
+ from torchjd._transform.tensor_dict import EmptyTensorDict, Gradients
8
+
9
+
10
class Init(Transform[EmptyTensorDict, Gradients]):
    """
    Transform that creates the initial gradients of a backward phase: it requires no key and
    associates a tensor of 1s to each tensor of ``values``.
    """

    def __init__(self, values: Iterable[Tensor]):
        self.values = {*values}

    @property
    def required_keys(self) -> set[Tensor]:
        return set()

    @property
    def output_keys(self) -> set[Tensor]:
        return self.values

    def _compute(self, input: EmptyTensorDict) -> Gradients:
        r"""
        Computes the gradient of each tensor of ``values`` with respect to itself. The keys of
        the returned dictionary are the ``values``. The gradient associated to a value is a
        tensor of 1s of identical shape, because :math:`\frac{\partial v}{\partial v} = 1` for
        any :math:`v`.
        """

        ones_per_value = {}
        for value in self.values:
            ones_per_value[value] = torch.ones_like(value)
        return Gradients(ones_per_value)
@@ -0,0 +1,109 @@
1
+ from itertools import accumulate
2
+ from typing import Iterable, Sequence
3
+
4
+ import torch
5
+ from torch import Size, Tensor
6
+
7
+ from torchjd._transform._differentiation import _Differentiation
8
+ from torchjd._transform._utils import _materialize
9
+ from torchjd._transform.tensor_dict import Jacobians
10
+
11
+
12
class Jac(_Differentiation[Jacobians]):
    """
    Differentiation transform that computes jacobians through ``_jac``.

    ``chunk_size`` and ``retain_graph`` are forwarded as they are to ``_jac``. The graph of the
    derivatives is never created (``create_graph=False``) and the inputs are allowed to be unused
    (``allow_unused=True``).
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        chunk_size: int | None,
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.retain_graph = retain_graph
        self.chunk_size = chunk_size

    def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        autograd_options = dict(
            retain_graph=self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
        return _jac(
            outputs=self.outputs,
            inputs=self.inputs,
            jac_outputs=jac_outputs,
            chunk_size=self.chunk_size,
            **autograd_options,
        )
34
+
35
+
36
+ def _jac(
37
+ outputs: Sequence[Tensor],
38
+ inputs: Sequence[Tensor],
39
+ jac_outputs: Sequence[Tensor],
40
+ chunk_size: int | None,
41
+ retain_graph: bool,
42
+ create_graph: bool,
43
+ allow_unused: bool,
44
+ ) -> tuple[Tensor, ...]:
45
+ """
46
+ Wraps the call to `autograd.grad` to compute the jacobian with respect to each input, in an
47
+ optimized way. The first dimension of the jacobians is equal to the length of the sequence of
48
+ `outputs`, which should be the same as the length of the sequence of `grad_outputs`. This should
49
+ be equivalent to calling `_grad(outputs[i], inputs, grad_outputs[i], ...)` for all i
50
+ sequentially, and stacking the elements of each resulting tuple.
51
+ """
52
+
53
+ if len(inputs) == 0:
54
+ return tuple()
55
+
56
+ n_outputs = len(outputs)
57
+ if len(jac_outputs) != n_outputs:
58
+ raise ValueError(
59
+ "Parameters `outputs` and `jac_outputs` should be sequences of the same length. Found "
60
+ f"`len(outputs) = {n_outputs}` and `len(jac_outputs) = {len(jac_outputs)}`."
61
+ )
62
+
63
+ if n_outputs == 0:
64
+ return tuple([torch.empty((0,) + input.shape) for input in inputs])
65
+
66
+ def get_vjp(v):
67
+ optional_grads = torch.autograd.grad(
68
+ outputs,
69
+ inputs,
70
+ grad_outputs=v,
71
+ retain_graph=retain_graph,
72
+ create_graph=create_graph,
73
+ allow_unused=allow_unused,
74
+ )
75
+ grads = _materialize(optional_grads, inputs=inputs)
76
+ return torch.concatenate([grad.reshape([-1]) for grad in grads])
77
+
78
+ grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs)
79
+
80
+ lengths = [input.numel() for input in inputs]
81
+ jacobian_matrices = _extract_sub_matrices(grouped_jacobian_matrix, lengths)
82
+
83
+ shapes = [input.shape for input in inputs]
84
+ jacobians = _reshape_matrices(jacobian_matrices, shapes)
85
+
86
+ return tuple(jacobians)
87
+
88
+
89
def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
    """
    Splits the columns of `matrix` into consecutive sub-matrices. The i-th sub-matrix has all
    the rows of `matrix` and `lengths[i]` of its columns, starting right after the columns of the
    previous sub-matrix.

    If `lengths` is empty, `matrix` must have no column and an empty list is returned.

    Raises a `ValueError` if the sum of `lengths` differs from the number of columns of `matrix`.
    """

    end_indices = list(accumulate(lengths))

    # With no length, there is no last cumulative length to read: the total is simply 0.
    total_length = end_indices[-1] if end_indices else 0
    if total_length != matrix.shape[1]:
        raise ValueError(
            "The sum of the provided lengths should be equal to the number of columns in the "
            "provided matrix."
        )

    # Each sub-matrix starts where the previous one ends, the first one starting at column 0.
    start_indices = [0, *end_indices[:-1]]
    return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
101
+
102
+
103
def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
    """
    Gives to each matrix the shape found at the same position in `shapes`, preceded by the number
    of rows of the matrix. The results are views of the matrices, returned as a list.

    Raises a `ValueError` if `matrices` and `shapes` do not have the same number of elements.
    """

    if len(matrices) != len(shapes):
        raise ValueError(
            "Parameters `matrices` and `shapes` should contain the same number of elements."
        )

    reshaped = []
    for matrix, shape in zip(matrices, shapes):
        n_rows = matrix.shape[0]
        reshaped.append(matrix.view((n_rows,) + shape))
    return reshaped
@@ -0,0 +1,25 @@
1
+ from typing import Iterable
2
+
3
+ from torch import Tensor
4
+
5
+ from torchjd._transform.base import Transform
6
+ from torchjd._transform.tensor_dict import JacobianMatrices, Jacobians
7
+
8
+
9
class Matrixify(Transform[Jacobians, JacobianMatrices]):
    """
    Transform that views each jacobian as a matrix: the first dimension is kept and all the
    other ones are flattened into a single one. Its `output_keys` are its `required_keys`.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = {*required_keys}

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._required_keys

    def _compute(self, jacobians: Jacobians) -> JacobianMatrices:
        matrices = {}
        for key, jacobian in jacobians.items():
            n_rows = jacobian.shape[0]
            matrices[key] = jacobian.view(n_rows, -1)
        return JacobianMatrices(matrices)
@@ -0,0 +1,26 @@
1
+ from typing import Iterable
2
+
3
+ from torch import Tensor
4
+
5
+ from torchjd._transform.base import Transform
6
+ from torchjd._transform.tensor_dict import Gradients, GradientVectors
7
+
8
+
9
class Reshape(Transform[GradientVectors, Gradients]):
    """
    Transform that views each gradient vector with the shape of the key that it is associated to.
    Its `output_keys` are its `required_keys`.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = {*required_keys}

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._required_keys

    def _compute(self, gradient_vectors: GradientVectors) -> Gradients:
        gradients = {}
        for key, gradient_vector in gradient_vectors.items():
            gradients[key] = gradient_vector.view(key.shape)
        return Gradients(gradients)
@@ -0,0 +1,21 @@
1
+ from torch import Tensor
2
+
3
+ from torchjd._transform._utils import _A
4
+ from torchjd._transform.base import Transform
5
+
6
+
7
class Scaling(Transform[_A, _A]):
    """
    Transform that multiplies the value associated to each key by the scaling factor given for
    this key in ``scalings``. The result has the same type as the provided tensor dict.
    """

    def __init__(self, scalings: dict[Tensor, float]):
        self.scalings = scalings

    @property
    def required_keys(self) -> set[Tensor]:
        return set(self.scalings)

    @property
    def output_keys(self) -> set[Tensor]:
        return set(self.scalings)

    def _compute(self, tensor_dict: _A) -> _A:
        scaled = {}
        for key, scaling in self.scalings.items():
            scaled[key] = scaling * tensor_dict[key]
        return type(tensor_dict)(scaled)