torchjd 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchjd/__init__.py +1 -0
- torchjd/_transform/__init__.py +22 -0
- torchjd/_transform/_differentiation.py +34 -0
- torchjd/_transform/_utils.py +62 -0
- torchjd/_transform/aggregation.py +12 -0
- torchjd/_transform/base.py +117 -0
- torchjd/_transform/concatenation.py +65 -0
- torchjd/_transform/diagonalize.py +36 -0
- torchjd/_transform/grad.py +60 -0
- torchjd/_transform/identity.py +22 -0
- torchjd/_transform/init.py +30 -0
- torchjd/_transform/jac.py +109 -0
- torchjd/_transform/matrixify.py +25 -0
- torchjd/_transform/reshape.py +26 -0
- torchjd/_transform/scaling.py +21 -0
- torchjd/_transform/stack.py +63 -0
- torchjd/_transform/store.py +29 -0
- torchjd/_transform/strategy/__init__.py +4 -0
- torchjd/_transform/strategy/_utils.py +86 -0
- torchjd/_transform/strategy/extrapolating.py +75 -0
- torchjd/_transform/strategy/isolating.py +25 -0
- torchjd/_transform/strategy/partitioning.py +81 -0
- torchjd/_transform/strategy/unifying.py +43 -0
- torchjd/_transform/subset.py +27 -0
- torchjd/_transform/tensor_dict.py +210 -0
- torchjd/aggregation/__init__.py +16 -0
- torchjd/aggregation/_gramian_utils.py +46 -0
- torchjd/aggregation/_normalizing.py +48 -0
- torchjd/aggregation/_pref_vector_utils.py +26 -0
- torchjd/aggregation/_str_utils.py +11 -0
- torchjd/aggregation/aligned_mtl.py +129 -0
- torchjd/aggregation/bases.py +89 -0
- torchjd/aggregation/cagrad.py +105 -0
- torchjd/aggregation/constant.py +67 -0
- torchjd/aggregation/dualproj.py +131 -0
- torchjd/aggregation/graddrop.py +85 -0
- torchjd/aggregation/imtl_g.py +50 -0
- torchjd/aggregation/krum.py +108 -0
- torchjd/aggregation/mean.py +42 -0
- torchjd/aggregation/mgda.py +85 -0
- torchjd/aggregation/nash_mtl.py +221 -0
- torchjd/aggregation/pcgrad.py +72 -0
- torchjd/aggregation/random.py +47 -0
- torchjd/aggregation/sum.py +40 -0
- torchjd/aggregation/trimmed_mean.py +73 -0
- torchjd/aggregation/upgrad.py +136 -0
- torchjd/backward.py +95 -0
- torchjd-0.1.0.dist-info/LICENSE +21 -0
- torchjd-0.1.0.dist-info/METADATA +55 -0
- torchjd-0.1.0.dist-info/RECORD +52 -0
- torchjd-0.1.0.dist-info/WHEEL +5 -0
- torchjd-0.1.0.dist-info/top_level.txt +1 -0
torchjd/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from torchjd.backward import backward
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from torchjd._transform.aggregation import make_aggregation
|
|
2
|
+
from torchjd._transform.base import Composition, Conjunction, Transform
|
|
3
|
+
from torchjd._transform.concatenation import Concatenation
|
|
4
|
+
from torchjd._transform.diagonalize import Diagonalize
|
|
5
|
+
from torchjd._transform.grad import Grad
|
|
6
|
+
from torchjd._transform.identity import Identity
|
|
7
|
+
from torchjd._transform.init import Init
|
|
8
|
+
from torchjd._transform.jac import Jac
|
|
9
|
+
from torchjd._transform.matrixify import Matrixify
|
|
10
|
+
from torchjd._transform.reshape import Reshape
|
|
11
|
+
from torchjd._transform.scaling import Scaling
|
|
12
|
+
from torchjd._transform.stack import Stack
|
|
13
|
+
from torchjd._transform.store import Store
|
|
14
|
+
from torchjd._transform.subset import Subset
|
|
15
|
+
from torchjd._transform.tensor_dict import (
|
|
16
|
+
EmptyTensorDict,
|
|
17
|
+
Gradients,
|
|
18
|
+
GradientVectors,
|
|
19
|
+
JacobianMatrices,
|
|
20
|
+
Jacobians,
|
|
21
|
+
TensorDict,
|
|
22
|
+
)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Iterable, Sequence
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import ordered_set
|
|
7
|
+
from torchjd._transform.base import _A, Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _Differentiation(Transform[_A, _A], ABC):
    """
    Base class of the transforms that differentiate some `outputs` with respect to some `inputs`.

    The tensor dict given to the transform is indexed by the `outputs`; the tensor dict it returns
    is indexed by the `inputs` and has the same type as the one it received. Subclasses only have
    to provide `_differentiate`.
    """

    def __init__(self, outputs: Iterable[Tensor], inputs: Iterable[Tensor]):
        # `ordered_set` keeps the iteration order and rejects duplicated tensors.
        self.outputs = ordered_set(outputs)
        self.inputs = ordered_set(inputs)

    @property
    def required_keys(self) -> set[Tensor]:
        # The forward-direction outputs are what the backward direction consumes.
        return set(self.outputs)

    @property
    def output_keys(self) -> set[Tensor]:
        # The forward-direction inputs are what the backward direction produces.
        return set(self.inputs)

    @abstractmethod
    def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        """
        Returns one differentiation per element of `self.inputs`, in the same order, given the
        values associated to `self.outputs` (also in order).
        """

        raise NotImplementedError

    def _compute(self, tensors: _A) -> _A:
        # Values are looked up in the order of `self.outputs` so that they line up with the
        # tensors that the subclass differentiates.
        values_of_outputs = [tensors[key] for key in self.outputs]
        differentiations = self._differentiate(values_of_outputs)

        # Pair every input with its differentiation and rebuild the same kind of tensor dict.
        result_type = type(tensors)
        return result_type({key: value for key, value in zip(self.inputs, differentiations)})
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torchjd._transform.tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
|
|
8
|
+
|
|
9
|
+
_KeyType = TypeVar("_KeyType", bound=Hashable)
|
|
10
|
+
_ValueType = TypeVar("_ValueType")
|
|
11
|
+
_OrderedSet: TypeAlias = OrderedDict[_KeyType, None]
|
|
12
|
+
|
|
13
|
+
_A = TypeVar("_A", bound=TensorDict)
|
|
14
|
+
_B = TypeVar("_B", bound=TensorDict)
|
|
15
|
+
_C = TypeVar("_C", bound=TensorDict)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
    """
    Builds an insertion-ordered set (an `OrderedDict` whose values are all `None`) from the
    provided elements.

    :param elements: The elements of the set. They must be hashable and pairwise distinct.
    :raises ValueError: If `elements` contains duplicates.
    """

    # Materialize first: `elements` may be a one-shot iterator and is needed twice below.
    elements = list(elements)
    unique = OrderedDict((element, None) for element in elements)

    # Duplicates collapse into a single key, which makes the two lengths differ.
    if len(unique) != len(elements):
        raise ValueError(
            f"Parameter `elements` should contain unique elements. Found `elements = {elements}`."
        )

    return unique
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
    """
    Merges several dicts into a new plain `dict`.

    Keys keep the position of their first occurrence; when a key appears in several dicts, the
    value coming from the last of them wins.
    """

    merged = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _materialize(
|
|
37
|
+
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
|
|
38
|
+
) -> tuple[Tensor, ...]:
|
|
39
|
+
"""
|
|
40
|
+
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
|
|
41
|
+
shape as the corresponding input. Returns the obtained sequence as a tuple.
|
|
42
|
+
|
|
43
|
+
Note that the name "materialize" comes from the flag `materialize_grads` from
|
|
44
|
+
`torch.autograd.grad`, which will be available in future torch releases.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
tensors = []
|
|
48
|
+
for optional_tensor, input in zip(optional_tensors, inputs):
|
|
49
|
+
if optional_tensor is None:
|
|
50
|
+
tensors.append(torch.zeros_like(input))
|
|
51
|
+
else:
|
|
52
|
+
tensors.append(optional_tensor)
|
|
53
|
+
return tuple(tensors)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _union(tensor_dicts: Iterable[_A]) -> _A:
    """
    Merges the provided tensor dicts into a single one.

    The entries are accumulated with `|=`, starting from an `EmptyTensorDict`. The type of the
    returned tensor dict is obtained by folding `_least_common_ancestor` over the types of the
    provided tensor dicts (presumably their most specific common base class; see `tensor_dict`).
    """

    merged: _A = EmptyTensorDict()
    common_type: type[_A] = EmptyTensorDict

    for current in tensor_dicts:
        common_type = _least_common_ancestor(common_type, type(current))
        merged |= current

    # Re-wrap the accumulated entries in the type that is common to every merged tensor dict.
    return common_type(merged)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from torchjd._transform.base import Transform
|
|
2
|
+
from torchjd._transform.matrixify import Matrixify
|
|
3
|
+
from torchjd._transform.reshape import Reshape
|
|
4
|
+
from torchjd._transform.tensor_dict import Gradients, GradientVectors, JacobianMatrices, Jacobians
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def make_aggregation(
    strategy: Transform[JacobianMatrices, GradientVectors]
) -> Transform[Jacobians, Gradients]:
    """
    Wraps a strategy working on matrices and vectors so that it can be applied to tensor-shaped
    jacobians.

    The returned transform is the composition `Reshape ∘ strategy ∘ Matrixify`, where both
    `Matrixify` and `Reshape` are built from the `required_keys` of the strategy.

    :param strategy: The transform mapping jacobian matrices to gradient vectors.
    """

    keys = strategy.required_keys

    # Same grouping as `Reshape(keys) << strategy << Matrixify(keys)`: `<<` is left-associative.
    reshaped_strategy = Reshape(keys) << strategy
    return reshaped_strategy << Matrixify(keys)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Generic, Sequence
|
|
5
|
+
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
from torchjd._transform._utils import _A, _B, _C, _union
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Transform(Generic[_B, _C], ABC):
    r"""
    Abstract base class of every transform, the elementary building block of a jacobian descent
    backward phase. A transform maps a :class:`~torchjd._transform.tensor_dict.TensorDict` whose
    keys are `required_keys` to another :class:`~torchjd._transform.tensor_dict.TensorDict` whose
    keys are `output_keys`.

    Formally a transform is a function:

    .. math::
        f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}

    where ``p`` is the number of `required_keys`, ``q`` is the number of `output_keys`, ``n_i`` is
    the number of elements in the value associated to the ``i`` th `required_key` of the input
    :class:`~torchjd._transform.tensor_dict.TensorDict` and ``m_j`` is the number of elements in
    the value associated to the ``j`` th `output_key` of the output
    :class:`~torchjd._transform.tensor_dict.TensorDict`.

    Being mathematical functions, transforms can be composed (``outer << inner``, see
    :class:`Composition`) as long as the `output_keys` of the inner one are the `required_keys` of
    the outer one. Transforms can also be put in conjunction (``first | second``, see
    :class:`Conjunction`).
    """

    @property
    @abstractmethod
    def required_keys(self) -> set[Tensor]:
        """Keys that the input of the transform must have."""

        raise NotImplementedError

    @property
    @abstractmethod
    def output_keys(self) -> set[Tensor]:
        """Keys of the output of the transform."""

        raise NotImplementedError

    @abstractmethod
    def _compute(self, input: _B) -> _C:
        """Applies the transform to an input whose keys have already been checked."""

        raise NotImplementedError

    def __call__(self, input: _B) -> _C:
        # The input validates its own keys against `required_keys` before any computation.
        input.check_keys_are(self.required_keys)
        return self._compute(input)

    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
        """Returns the transform applying `other` first and `self` on its result."""

        return Composition(self, other)

    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
        """Returns the conjunction of `self` and `other`."""

        return Conjunction([self, other])

    def __repr__(self) -> str:
        return type(self).__name__

    # Operator aliases: `outer << inner` composes, `first | second` conjuncts.
    __lshift__ = compose
    __or__ = conjunct
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Composition(Transform[_A, _C]):
    """
    Transform applying `inner` and then `outer` on the result.

    :param outer: The transform applied last. Its `required_keys` must be the `output_keys` of
        `inner`.
    :param inner: The transform applied first.
    :raises ValueError: If the `output_keys` of `inner` differ from the `required_keys` of `outer`.
    """

    def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
        keys_match = outer.required_keys == inner.output_keys
        if not keys_match:
            raise ValueError(
                "The `output_keys` of `inner` must match with the `required_keys` of "
                f"outer. Found {outer.required_keys} and {inner.output_keys}"
            )

        self.outer = outer
        self.inner = inner

    @property
    def required_keys(self) -> set[Tensor]:
        # The composition consumes whatever its first applied transform consumes.
        return self.inner.required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The composition produces whatever its last applied transform produces.
        return self.outer.output_keys

    def _compute(self, input: _A) -> _C:
        return self.outer(self.inner(input))

    def __repr__(self) -> str:
        return f"{self.outer!r} ∘ {self.inner!r}"
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class Conjunction(Transform[_A, _B]):
    """
    Transform applying several transforms to the same input and merging their outputs into a
    single tensor dict.

    :param transforms: The transforms to apply. They must all have the same `required_keys`, and
        their `output_keys` must be pairwise disjoint.
    :raises ValueError: If the transforms do not all require the same keys, or if an output key is
        produced by more than one of them.
    """

    def __init__(self, transforms: Sequence[Transform[_A, _B]]):
        self.transforms = transforms

        # Every transform must require exactly the union of all the required keys, i.e. they must
        # all require the same keys.
        self._required_keys = {key for t in transforms for key in t.required_keys}
        if any(t.required_keys != self.required_keys for t in transforms):
            raise ValueError("All transforms should require the same set of keys.")

        # A key produced by two transforms appears twice in the list but only once in the set.
        all_output_keys = [key for t in transforms for key in t.output_keys]
        self._output_keys = set(all_output_keys)
        if len(all_output_keys) != len(self._output_keys):
            raise ValueError("The sets of output keys of transforms should be disjoint.")

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys

    def _compute(self, tensor_dict: _A) -> _B:
        partial_outputs = [transform(tensor_dict) for transform in self.transforms]
        return _union(partial_outputs)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import _A, _materialize, dicts_union
|
|
7
|
+
from torchjd._transform.base import Transform
|
|
8
|
+
from torchjd._transform.tensor_dict import Jacobians
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Concatenation(Transform[_A, Jacobians]):
    """
    Transform applying several transforms to the same input and concatenating, key by key, the
    jacobians that they produce (see `_concatenate`).

    :param transforms: The transforms to apply. There must be at least one, and they must all
        have the same `required_keys`.
    :raises ValueError: If `transforms` is empty or if the transforms do not all require the same
        keys.
    """

    def __init__(self, transforms: Sequence[Transform[_A, Jacobians]]):
        if len(transforms) == 0:
            raise ValueError("Parameter `transforms` cannot be empty.")

        self.transforms = transforms

        # The first transform sets the reference that the other ones are compared to.
        self._required_keys = transforms[0].required_keys
        self._output_keys = set().union(*(t.output_keys for t in transforms))

        if any(t.required_keys != self.required_keys for t in transforms[1:]):
            raise ValueError("All transforms should require the same set of keys.")

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys

    def _compute(self, input: _A) -> Jacobians:
        jacobians_per_transform = [transform(input) for transform in self.transforms]
        return _concatenate(jacobians_per_transform)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _concatenate(jacobians_dicts: list[Jacobians]) -> Jacobians:
    """
    Turns a list of jacobians dicts into a single dict of concatenated jacobians.

    The keys of the result are the union of the keys of the provided dicts. When a key is missing
    from some of the dicts, the corresponding rows of its concatenated jacobian are zeros.
    """

    # The keys are deduplicated (through the union) before any concatenation happens, so that the
    # concatenated jacobian of a key shared by several dicts is computed only once.
    concatenated = {}
    for key in dicts_union(jacobians_dicts):
        concatenated[key] = _concatenate_one_key(jacobians_dicts, key)

    return Jacobians(concatenated)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _concatenate_one_key(jacobian_dicts: list[Jacobians], input: Tensor) -> Tensor:
    """
    Builds the concatenated jacobian associated to the key `input`, from a list of jacobians dicts.

    Each dict contributes `first_dimension` rows: its own jacobian for `input` when it has one,
    zeros otherwise. The contributions are concatenated along dimension 0, in the order of the
    list.
    """

    candidates = []
    templates = []
    for jacobian_dict in jacobian_dicts:
        # `None` marks a dict that has no jacobian for this key.
        candidates.append(jacobian_dict.get(input))
        # Template with the shape of the jacobian that this dict would have for this key; it is
        # only used to create zeros when the jacobian is missing.
        templates.append(input.expand(jacobian_dict.first_dimension, *input.shape))

    return torch.concatenate(_materialize(candidates, templates), dim=0)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import ordered_set
|
|
7
|
+
from torchjd._transform.base import Transform
|
|
8
|
+
from torchjd._transform.tensor_dict import Gradients, Jacobians
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Diagonalize(Transform[Gradients, Jacobians]):
    """
    Transform turning gradients into jacobians by placing all their elements on the diagonal of a
    square matrix.

    The values associated to the `considered` keys are flattened and concatenated (in the order of
    `considered`) into a vector, from which a diagonal matrix is built. The jacobian of a key is
    the block of columns of that matrix corresponding to the key, reshaped to
    `(-1,) + key.shape`.

    :param considered: The keys to diagonalize. They must be pairwise distinct.
    """

    def __init__(self, considered: Iterable[Tensor]):
        self.considered = ordered_set(considered)

        # `indices[i]` is the (begin, end) range of columns taken by the i-th considered key.
        self.indices: list[tuple[int, int]] = []
        offset = 0
        for tensor in self.considered:
            n_elements = tensor.numel()
            self.indices.append((offset, offset + n_elements))
            offset += n_elements

    @property
    def required_keys(self) -> set[Tensor]:
        return set(self.considered)

    @property
    def output_keys(self) -> set[Tensor]:
        return set(self.considered)

    def _compute(self, tensors: Gradients) -> Jacobians:
        flat_values = torch.cat([tensors[key].reshape([-1]) for key in self.considered])
        diagonal_matrix = flat_values.diag()

        jacobians = {}
        for key, (begin, end) in zip(self.considered, self.indices):
            # Columns `begin:end` are those of `key`; each row is reshaped like `key`.
            jacobians[key] = diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)

        return Jacobians(jacobians)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import Iterable, Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._differentiation import _Differentiation
|
|
7
|
+
from torchjd._transform._utils import _materialize
|
|
8
|
+
from torchjd._transform.tensor_dict import Gradients
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Grad(_Differentiation[Gradients]):
    """
    Differentiation transform computing, through `_grad`, the gradients of `outputs` with respect
    to `inputs`.

    :param outputs: The tensors to differentiate.
    :param inputs: The tensors with respect to which the differentiation is made.
    :param retain_graph: Value forwarded as `retain_graph` to `torch.autograd.grad`.
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.retain_graph = retain_graph

    def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        # No graph is created for the result, and inputs that are not used to compute the outputs
        # are accepted: `_grad` gives them a gradient of zeros.
        differentiation_options = dict(
            retain_graph=self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
        return _grad(
            outputs=self.outputs,
            inputs=self.inputs,
            grad_outputs=grad_outputs,
            **differentiation_options,
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _grad(
|
|
33
|
+
outputs: Sequence[Tensor],
|
|
34
|
+
inputs: Sequence[Tensor],
|
|
35
|
+
grad_outputs: Sequence[Tensor],
|
|
36
|
+
retain_graph: bool,
|
|
37
|
+
create_graph: bool,
|
|
38
|
+
allow_unused: bool,
|
|
39
|
+
) -> tuple[Tensor, ...]:
|
|
40
|
+
"""
|
|
41
|
+
Wraps `autograd.grad` to give it additional responsibilities that it should have (like being
|
|
42
|
+
able to work with an empty sequence of `inputs`).
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
if len(inputs) == 0:
|
|
46
|
+
return tuple()
|
|
47
|
+
|
|
48
|
+
if len(outputs) == 0:
|
|
49
|
+
return tuple([torch.empty(input.shape) for input in inputs])
|
|
50
|
+
|
|
51
|
+
optional_grads = torch.autograd.grad(
|
|
52
|
+
outputs,
|
|
53
|
+
inputs,
|
|
54
|
+
grad_outputs=grad_outputs,
|
|
55
|
+
retain_graph=retain_graph,
|
|
56
|
+
create_graph=create_graph,
|
|
57
|
+
allow_unused=allow_unused,
|
|
58
|
+
)
|
|
59
|
+
grads = _materialize(optional_grads, inputs)
|
|
60
|
+
return grads
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from torchjd._transform._utils import _A
|
|
6
|
+
from torchjd._transform.base import Transform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Identity(Transform[_A, _A]):
    """
    Transform returning its input unchanged.

    :param required_keys: The keys that the input must have. They are also the keys of the output.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # Nothing is added nor removed: the output has the keys of the input.
        return self._required_keys

    def _compute(self, tensor_dict: _A) -> _A:
        return tensor_dict
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform.base import Transform
|
|
7
|
+
from torchjd._transform.tensor_dict import EmptyTensorDict, Gradients
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Init(Transform[EmptyTensorDict, Gradients]):
    """
    Transform creating, from an empty tensor dict, the gradients of some values with respect to
    themselves.

    :param values: The tensors for which the initial gradients are created.
    """

    def __init__(self, values: Iterable[Tensor]):
        self.values = set(values)

    @property
    def required_keys(self) -> set[Tensor]:
        # Nothing is read from the input.
        return set()

    @property
    def output_keys(self) -> set[Tensor]:
        return self.values

    def _compute(self, input: EmptyTensorDict) -> Gradients:
        r"""
        Computes the gradient of each of the ``values`` with respect to itself and returns the
        results as a dictionary whose keys are the ``values``. Each gradient is a tensor of 1s
        with the shape of its key, because :math:`\frac{\partial v}{\partial v} = 1` for any
        :math:`v`.
        """

        initial_gradients = {}
        for value in self.values:
            initial_gradients[value] = torch.ones_like(value)

        return Gradients(initial_gradients)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from itertools import accumulate
|
|
2
|
+
from typing import Iterable, Sequence
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Size, Tensor
|
|
6
|
+
|
|
7
|
+
from torchjd._transform._differentiation import _Differentiation
|
|
8
|
+
from torchjd._transform._utils import _materialize
|
|
9
|
+
from torchjd._transform.tensor_dict import Jacobians
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Jac(_Differentiation[Jacobians]):
    """
    Differentiation transform computing, through `_jac`, the jacobians of `outputs` with respect
    to `inputs`.

    :param outputs: The tensors to differentiate.
    :param inputs: The tensors with respect to which the differentiation is made.
    :param chunk_size: Value forwarded as `chunk_size` to `torch.vmap` by `_jac`.
    :param retain_graph: Value forwarded as `retain_graph` to `torch.autograd.grad`.
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        chunk_size: int | None,
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.chunk_size = chunk_size
        self.retain_graph = retain_graph

    def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        # No graph is created for the result, and inputs that are not used to compute the outputs
        # are accepted: `_jac` gives them a jacobian of zeros.
        differentiation_options = dict(
            chunk_size=self.chunk_size,
            retain_graph=self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
        return _jac(
            outputs=self.outputs,
            inputs=self.inputs,
            jac_outputs=jac_outputs,
            **differentiation_options,
        )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _jac(
|
|
37
|
+
outputs: Sequence[Tensor],
|
|
38
|
+
inputs: Sequence[Tensor],
|
|
39
|
+
jac_outputs: Sequence[Tensor],
|
|
40
|
+
chunk_size: int | None,
|
|
41
|
+
retain_graph: bool,
|
|
42
|
+
create_graph: bool,
|
|
43
|
+
allow_unused: bool,
|
|
44
|
+
) -> tuple[Tensor, ...]:
|
|
45
|
+
"""
|
|
46
|
+
Wraps the call to `autograd.grad` to compute the jacobian with respect to each input, in an
|
|
47
|
+
optimized way. The first dimension of the jacobians is equal to the length of the sequence of
|
|
48
|
+
`outputs`, which should be the same as the length of the sequence of `grad_outputs`. This should
|
|
49
|
+
be equivalent to calling `_grad(outputs[i], inputs, grad_outputs[i], ...)` for all i
|
|
50
|
+
sequentially, and stacking the elements of each resulting tuple.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
if len(inputs) == 0:
|
|
54
|
+
return tuple()
|
|
55
|
+
|
|
56
|
+
n_outputs = len(outputs)
|
|
57
|
+
if len(jac_outputs) != n_outputs:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
"Parameters `outputs` and `jac_outputs` should be sequences of the same length. Found "
|
|
60
|
+
f"`len(outputs) = {n_outputs}` and `len(jac_outputs) = {len(jac_outputs)}`."
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if n_outputs == 0:
|
|
64
|
+
return tuple([torch.empty((0,) + input.shape) for input in inputs])
|
|
65
|
+
|
|
66
|
+
def get_vjp(v):
|
|
67
|
+
optional_grads = torch.autograd.grad(
|
|
68
|
+
outputs,
|
|
69
|
+
inputs,
|
|
70
|
+
grad_outputs=v,
|
|
71
|
+
retain_graph=retain_graph,
|
|
72
|
+
create_graph=create_graph,
|
|
73
|
+
allow_unused=allow_unused,
|
|
74
|
+
)
|
|
75
|
+
grads = _materialize(optional_grads, inputs=inputs)
|
|
76
|
+
return torch.concatenate([grad.reshape([-1]) for grad in grads])
|
|
77
|
+
|
|
78
|
+
grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs)
|
|
79
|
+
|
|
80
|
+
lengths = [input.numel() for input in inputs]
|
|
81
|
+
jacobian_matrices = _extract_sub_matrices(grouped_jacobian_matrix, lengths)
|
|
82
|
+
|
|
83
|
+
shapes = [input.shape for input in inputs]
|
|
84
|
+
jacobians = _reshape_matrices(jacobian_matrices, shapes)
|
|
85
|
+
|
|
86
|
+
return tuple(jacobians)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
|
|
90
|
+
cumulative_lengths = [*accumulate(lengths)]
|
|
91
|
+
|
|
92
|
+
if cumulative_lengths[-1] != matrix.shape[1]:
|
|
93
|
+
raise ValueError(
|
|
94
|
+
"The sum of the provided lengths should be equal to the number of columns in the "
|
|
95
|
+
"provided matrix."
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
start_indices = [0] + cumulative_lengths[:-1]
|
|
99
|
+
end_indices = cumulative_lengths
|
|
100
|
+
return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
|
|
104
|
+
if len(matrices) != len(shapes):
|
|
105
|
+
raise ValueError(
|
|
106
|
+
"Parameters `matrices` and `shapes` should contain the same number of elements."
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
return [matrix.view((matrix.shape[0],) + shape) for matrix, shape in zip(matrices, shapes)]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from torchjd._transform.base import Transform
|
|
6
|
+
from torchjd._transform.tensor_dict import JacobianMatrices, Jacobians
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Matrixify(Transform[Jacobians, JacobianMatrices]):
    """
    Transform flattening every jacobian into a matrix: the first dimension is kept and all the
    other ones are merged into a single one.

    :param required_keys: The keys that the input must have. They are also the keys of the output.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    def _compute(self, jacobians: Jacobians) -> JacobianMatrices:
        # The number of columns is computed explicitly rather than inferred with `-1`: `view`
        # cannot infer a dimension for a tensor with no element (e.g. a jacobian with 0 rows),
        # for which `view(0, -1)` raises an error.
        jacobian_matrices = {
            key: jacobian.view(jacobian.shape[0], jacobian.shape[1:].numel())
            for key, jacobian in jacobians.items()
        }
        return JacobianMatrices(jacobian_matrices)

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The values are reshaped but the keys are left untouched.
        return self._required_keys
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from torchjd._transform.base import Transform
|
|
6
|
+
from torchjd._transform.tensor_dict import Gradients, GradientVectors
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Reshape(Transform[GradientVectors, Gradients]):
    """
    Transform giving to every gradient vector the shape of the key that it is associated to.

    :param required_keys: The keys that the input must have. They are also the keys of the output.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The values are reshaped but the keys are left untouched.
        return self._required_keys

    def _compute(self, gradient_vectors: GradientVectors) -> Gradients:
        gradients = {}
        for key, gradient_vector in gradient_vectors.items():
            # A view is enough: only the shape changes, the data is shared with the vector.
            gradients[key] = gradient_vector.view(key.shape)

        return Gradients(gradients)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from torch import Tensor
|
|
2
|
+
|
|
3
|
+
from torchjd._transform._utils import _A
|
|
4
|
+
from torchjd._transform.base import Transform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Scaling(Transform[_A, _A]):
    """
    Transform multiplying the value associated to each key by a scalar.

    :param scalings: The factor to apply to the value of each key. Its keys are both the
        `required_keys` and the `output_keys` of the transform.
    """

    def __init__(self, scalings: dict[Tensor, float]):
        self.scalings = scalings

    @property
    def required_keys(self) -> set[Tensor]:
        return set(self.scalings)

    @property
    def output_keys(self) -> set[Tensor]:
        return set(self.scalings)

    def _compute(self, tensor_dict: _A) -> _A:
        scaled = {}
        for key, scaling in self.scalings.items():
            scaled[key] = scaling * tensor_dict[key]

        # The output is the same kind of tensor dict as the input.
        result_type = type(tensor_dict)
        return result_type(scaled)
|