torchjd 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchjd-0.1.0/LICENSE +21 -0
- torchjd-0.1.0/PKG-INFO +55 -0
- torchjd-0.1.0/README.md +25 -0
- torchjd-0.1.0/pyproject.toml +61 -0
- torchjd-0.1.0/setup.cfg +4 -0
- torchjd-0.1.0/src/torchjd/__init__.py +1 -0
- torchjd-0.1.0/src/torchjd/_transform/__init__.py +22 -0
- torchjd-0.1.0/src/torchjd/_transform/_differentiation.py +34 -0
- torchjd-0.1.0/src/torchjd/_transform/_utils.py +62 -0
- torchjd-0.1.0/src/torchjd/_transform/aggregation.py +12 -0
- torchjd-0.1.0/src/torchjd/_transform/base.py +117 -0
- torchjd-0.1.0/src/torchjd/_transform/concatenation.py +65 -0
- torchjd-0.1.0/src/torchjd/_transform/diagonalize.py +36 -0
- torchjd-0.1.0/src/torchjd/_transform/grad.py +60 -0
- torchjd-0.1.0/src/torchjd/_transform/identity.py +22 -0
- torchjd-0.1.0/src/torchjd/_transform/init.py +30 -0
- torchjd-0.1.0/src/torchjd/_transform/jac.py +109 -0
- torchjd-0.1.0/src/torchjd/_transform/matrixify.py +25 -0
- torchjd-0.1.0/src/torchjd/_transform/reshape.py +26 -0
- torchjd-0.1.0/src/torchjd/_transform/scaling.py +21 -0
- torchjd-0.1.0/src/torchjd/_transform/stack.py +63 -0
- torchjd-0.1.0/src/torchjd/_transform/store.py +29 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/__init__.py +4 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/_utils.py +86 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/extrapolating.py +75 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/isolating.py +25 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/partitioning.py +81 -0
- torchjd-0.1.0/src/torchjd/_transform/strategy/unifying.py +43 -0
- torchjd-0.1.0/src/torchjd/_transform/subset.py +27 -0
- torchjd-0.1.0/src/torchjd/_transform/tensor_dict.py +210 -0
- torchjd-0.1.0/src/torchjd/aggregation/__init__.py +16 -0
- torchjd-0.1.0/src/torchjd/aggregation/_gramian_utils.py +46 -0
- torchjd-0.1.0/src/torchjd/aggregation/_normalizing.py +48 -0
- torchjd-0.1.0/src/torchjd/aggregation/_pref_vector_utils.py +26 -0
- torchjd-0.1.0/src/torchjd/aggregation/_str_utils.py +11 -0
- torchjd-0.1.0/src/torchjd/aggregation/aligned_mtl.py +129 -0
- torchjd-0.1.0/src/torchjd/aggregation/bases.py +89 -0
- torchjd-0.1.0/src/torchjd/aggregation/cagrad.py +105 -0
- torchjd-0.1.0/src/torchjd/aggregation/constant.py +67 -0
- torchjd-0.1.0/src/torchjd/aggregation/dualproj.py +131 -0
- torchjd-0.1.0/src/torchjd/aggregation/graddrop.py +85 -0
- torchjd-0.1.0/src/torchjd/aggregation/imtl_g.py +50 -0
- torchjd-0.1.0/src/torchjd/aggregation/krum.py +108 -0
- torchjd-0.1.0/src/torchjd/aggregation/mean.py +42 -0
- torchjd-0.1.0/src/torchjd/aggregation/mgda.py +85 -0
- torchjd-0.1.0/src/torchjd/aggregation/nash_mtl.py +221 -0
- torchjd-0.1.0/src/torchjd/aggregation/pcgrad.py +72 -0
- torchjd-0.1.0/src/torchjd/aggregation/random.py +47 -0
- torchjd-0.1.0/src/torchjd/aggregation/sum.py +40 -0
- torchjd-0.1.0/src/torchjd/aggregation/trimmed_mean.py +73 -0
- torchjd-0.1.0/src/torchjd/aggregation/upgrad.py +136 -0
- torchjd-0.1.0/src/torchjd/backward.py +95 -0
- torchjd-0.1.0/src/torchjd.egg-info/PKG-INFO +55 -0
- torchjd-0.1.0/src/torchjd.egg-info/SOURCES.txt +55 -0
- torchjd-0.1.0/src/torchjd.egg-info/dependency_links.txt +1 -0
- torchjd-0.1.0/src/torchjd.egg-info/requires.txt +5 -0
- torchjd-0.1.0/src/torchjd.egg-info/top_level.txt +1 -0
torchjd-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Valérian Rey, Pierre Quinton
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
torchjd-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: torchjd
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Library for Jacobian Descent with PyTorch.
|
|
5
|
+
Author-email: Valerian Rey <valerian.rey@gmail.com>, Pierre Quinton <pierre.quinton@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://torchjd.org/
|
|
7
|
+
Project-URL: Documentation, https://torchjd.org/
|
|
8
|
+
Project-URL: Source, https://github.com/TorchJD/torchjd
|
|
9
|
+
Project-URL: Changelog, https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Requires-Python: >=3.10
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
License-File: LICENSE
|
|
25
|
+
Requires-Dist: torch>=2.0.0
|
|
26
|
+
Requires-Dist: quadprog!=0.1.10,>=0.1.9
|
|
27
|
+
Requires-Dist: numpy<2.0.0,>=1.21.0
|
|
28
|
+
Requires-Dist: qpsolvers>=1.0.1
|
|
29
|
+
Requires-Dist: cvxpy>=1.3.0
|
|
30
|
+
|
|
31
|
+
#  TorchJD
|
|
32
|
+
|
|
33
|
+
TorchJD is a library enabling Jacobian descent with PyTorch, for optimization of neural networks
|
|
34
|
+
with multiple objectives.
|
|
35
|
+
|
|
36
|
+
> [!IMPORTANT]
|
|
37
|
+
> This library is currently in an early development stage. The API is subject to significant changes
|
|
38
|
+
> in future versions. Use with caution in production environments and be prepared for potential
|
|
39
|
+
> breaking changes in upcoming releases.
|
|
40
|
+
|
|
41
|
+
## Installation
|
|
42
|
+
<!-- start installation -->
|
|
43
|
+
TorchJD can be installed directly with pip:
|
|
44
|
+
```bash
|
|
45
|
+
pip install torchjd
|
|
46
|
+
```
|
|
47
|
+
<!-- end installation -->
|
|
48
|
+
|
|
49
|
+
## Compatibility
|
|
50
|
+
TorchJD requires Python 3.10, 3.11 or 3.12. It is only compatible with recent versions of PyTorch
|
|
51
|
+
(>= 2.0). For more information, read the `dependencies` in [pyproject.toml](./pyproject.toml).
|
|
52
|
+
|
|
53
|
+
## Contribution
|
|
54
|
+
|
|
55
|
+
Please read the [Contribution page](CONTRIBUTING.md).
|
torchjd-0.1.0/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
#  TorchJD
|
|
2
|
+
|
|
3
|
+
TorchJD is a library enabling Jacobian descent with PyTorch, for optimization of neural networks
|
|
4
|
+
with multiple objectives.
|
|
5
|
+
|
|
6
|
+
> [!IMPORTANT]
|
|
7
|
+
> This library is currently in an early development stage. The API is subject to significant changes
|
|
8
|
+
> in future versions. Use with caution in production environments and be prepared for potential
|
|
9
|
+
> breaking changes in upcoming releases.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
<!-- start installation -->
|
|
13
|
+
TorchJD can be installed directly with pip:
|
|
14
|
+
```bash
|
|
15
|
+
pip install torchjd
|
|
16
|
+
```
|
|
17
|
+
<!-- end installation -->
|
|
18
|
+
|
|
19
|
+
## Compatibility
|
|
20
|
+
TorchJD requires Python 3.10, 3.11 or 3.12. It is only compatible with recent versions of PyTorch
|
|
21
|
+
(>= 2.0). For more information, read the `dependencies` in [pyproject.toml](./pyproject.toml).
|
|
22
|
+
|
|
23
|
+
## Contribution
|
|
24
|
+
|
|
25
|
+
Please read the [Contribution page](CONTRIBUTING.md).
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "torchjd"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Library for Jacobian Descent with PyTorch."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{name = "Valerian Rey", email = "valerian.rey@gmail.com"},
|
|
8
|
+
{name = "Pierre Quinton", email = "pierre.quinton@gmail.com"}
|
|
9
|
+
]
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"torch>=2.0.0",
|
|
13
|
+
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
|
|
14
|
+
"numpy>=1.21.0, <2.0.0", # Does not work before 1.21. cvxpy is not yet compatible with numpy>=2.0.0. The upper cap should be removed when this becomes the case, or when their pyproject.toml reflects the incompatibility. See https://github.com/cvxpy/cvxpy/issues/2474.
|
|
15
|
+
"qpsolvers>=1.0.1", # Does not work before 1.0.1
|
|
16
|
+
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
|
|
17
|
+
]
|
|
18
|
+
classifiers = [
|
|
19
|
+
"Development Status :: 4 - Beta",
|
|
20
|
+
"Intended Audience :: Developers",
|
|
21
|
+
"Intended Audience :: Science/Research",
|
|
22
|
+
"License :: OSI Approved :: MIT License",
|
|
23
|
+
"Operating System :: OS Independent",
|
|
24
|
+
"Programming Language :: Python",
|
|
25
|
+
"Programming Language :: Python :: 3",
|
|
26
|
+
"Programming Language :: Python :: 3.10",
|
|
27
|
+
"Programming Language :: Python :: 3.11",
|
|
28
|
+
"Programming Language :: Python :: 3.12",
|
|
29
|
+
"Topic :: Scientific/Engineering",
|
|
30
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.urls]
|
|
34
|
+
Homepage = "https://torchjd.org/"
|
|
35
|
+
Documentation = "https://torchjd.org/"
|
|
36
|
+
Source = "https://github.com/TorchJD/torchjd"
|
|
37
|
+
Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"
|
|
38
|
+
|
|
39
|
+
[tool.pdm.dev-dependencies]
|
|
40
|
+
check = [
|
|
41
|
+
"pre-commit>=2.9.2" # isort doesn't work before 2.9.2
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
doc = [
|
|
45
|
+
"sphinx>=6.0, !=7.2.0, !=7.2.1, !=7.2.3, !=7.2.4, !=7.2.5", # Versions in [7.2.0, 7.2.5] have a bug with an internal torch import from _C
|
|
46
|
+
"furo>=2023.0, <2024.04.27", # Force it to be recent so that the theme looks better, 2024.04.27 seems to have bugged link colors
|
|
47
|
+
"tomli>=1.1", # The load function doesn't work similarly before 1.1
|
|
48
|
+
"sphinx-autodoc-typehints>=1.16.0", # Some problems with TypeVars before 1.16
|
|
49
|
+
"myst-parser>=3.0.1" # Never tested lower versions
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
test = [
|
|
53
|
+
"pytest>=7.3", # Before version 7.3, not all tests are run
|
|
54
|
+
"contexttimer>=0.3.3, <0.3.4", # The test requiring contexttimer is not often run, so it could silently break if we uncap this library
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
plot = [
|
|
58
|
+
"plotly>=5.19.0", # Recent version to avoid problems, could be relaxed
|
|
59
|
+
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
|
|
60
|
+
"kaleido==0.2.1", # Only works with locked version
|
|
61
|
+
]
|
torchjd-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from torchjd.backward import backward
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from torchjd._transform.aggregation import make_aggregation
|
|
2
|
+
from torchjd._transform.base import Composition, Conjunction, Transform
|
|
3
|
+
from torchjd._transform.concatenation import Concatenation
|
|
4
|
+
from torchjd._transform.diagonalize import Diagonalize
|
|
5
|
+
from torchjd._transform.grad import Grad
|
|
6
|
+
from torchjd._transform.identity import Identity
|
|
7
|
+
from torchjd._transform.init import Init
|
|
8
|
+
from torchjd._transform.jac import Jac
|
|
9
|
+
from torchjd._transform.matrixify import Matrixify
|
|
10
|
+
from torchjd._transform.reshape import Reshape
|
|
11
|
+
from torchjd._transform.scaling import Scaling
|
|
12
|
+
from torchjd._transform.stack import Stack
|
|
13
|
+
from torchjd._transform.store import Store
|
|
14
|
+
from torchjd._transform.subset import Subset
|
|
15
|
+
from torchjd._transform.tensor_dict import (
|
|
16
|
+
EmptyTensorDict,
|
|
17
|
+
Gradients,
|
|
18
|
+
GradientVectors,
|
|
19
|
+
JacobianMatrices,
|
|
20
|
+
Jacobians,
|
|
21
|
+
TensorDict,
|
|
22
|
+
)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Iterable, Sequence
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import ordered_set
|
|
7
|
+
from torchjd._transform.base import _A, Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _Differentiation(Transform[_A, _A], ABC):
    """
    Base class of the transforms that differentiate ``outputs`` with respect to ``inputs``.

    This is a backward operation: the keys required by the transform are the (forward) ``outputs``,
    and the keys it produces are the (forward) ``inputs``. The result has the same type as the
    tensor dict it is given.

    :param outputs: The tensors to differentiate. Must not contain duplicates.
    :param inputs: The tensors with respect to which ``outputs`` are differentiated. Must not
        contain duplicates.
    """

    def __init__(self, outputs: Iterable[Tensor], inputs: Iterable[Tensor]):
        self.outputs = ordered_set(outputs)
        self.inputs = ordered_set(inputs)

    def _compute(self, tensors: _A) -> _A:
        values_at_outputs = [tensors[key] for key in self.outputs]
        results = self._differentiate(values_at_outputs)
        # `_differentiate` is expected to return one tensor per input, in the order of `inputs`.
        paired = {key: result for key, result in zip(self.inputs, results)}
        return type(tensors)(paired)

    @abstractmethod
    def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        """
        Differentiates ``self.outputs`` with respect to ``self.inputs``.

        :param tensor_outputs: One tensor per element of ``self.outputs``, in the same order.
        :returns: One tensor per element of ``self.inputs``, in the same order.
        """
        raise NotImplementedError

    @property
    def required_keys(self) -> set[Tensor]:
        # Forward outputs are what this backward transform consumes.
        return {*self.outputs}

    @property
    def output_keys(self) -> set[Tensor]:
        # Forward inputs are what this backward transform produces.
        return {*self.inputs}
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torchjd._transform.tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
|
|
8
|
+
|
|
9
|
+
# Type variable for the elements of an ordered set, or the keys of a dict (any hashable type).
_KeyType = TypeVar("_KeyType", bound=Hashable)
# Type variable for the values of a dict.
_ValueType = TypeVar("_ValueType")
# An ordered set is represented as an OrderedDict whose values are all None (see `ordered_set`).
_OrderedSet: TypeAlias = OrderedDict[_KeyType, None]

# Type variables bound to TensorDict, used to type the inputs and outputs of transforms.
_A = TypeVar("_A", bound=TensorDict)
_B = TypeVar("_B", bound=TensorDict)
_C = TypeVar("_C", bound=TensorDict)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
    """
    Builds an ordered set (an ``OrderedDict`` mapping each element to ``None``) from ``elements``,
    preserving their order.

    :param elements: The elements of the set. The iterable is consumed exactly once.
    :raises ValueError: If ``elements`` contains duplicates.
    """
    as_list = list(elements)
    unique = OrderedDict((element, None) for element in as_list)
    if len(unique) == len(as_list):
        return unique
    raise ValueError(
        f"Parameter `elements` should contain unique elements. Found `elements = {as_list}`."
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
    """
    Merges ``dicts`` into a new dict. When a key appears in several dicts, the value from the last
    one wins, while the key keeps the position of its first appearance.
    """
    merged: dict[_KeyType, _ValueType] = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _materialize(
|
|
37
|
+
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
|
|
38
|
+
) -> tuple[Tensor, ...]:
|
|
39
|
+
"""
|
|
40
|
+
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
|
|
41
|
+
shape as the corresponding input. Returns the obtained sequence as a tuple.
|
|
42
|
+
|
|
43
|
+
Note that the name "materialize" comes from the flag `materialize_grads` from
|
|
44
|
+
`torch.autograd.grad`, which will be available in future torch releases.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
tensors = []
|
|
48
|
+
for optional_tensor, input in zip(optional_tensors, inputs):
|
|
49
|
+
if optional_tensor is None:
|
|
50
|
+
tensors.append(torch.zeros_like(input))
|
|
51
|
+
else:
|
|
52
|
+
tensors.append(optional_tensor)
|
|
53
|
+
return tuple(tensors)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _union(tensor_dicts: Iterable[_A]) -> _A:
    """
    Merges ``tensor_dicts`` into a single tensor dict.

    The type of the result is the least common ancestor (as computed by
    ``_least_common_ancestor``) of the types of all given tensor dicts, starting from
    ``EmptyTensorDict``. With no tensor dict given, an ``EmptyTensorDict`` is returned.
    """
    common_type: type[_A] = EmptyTensorDict
    merged: _A = EmptyTensorDict()
    for current in tensor_dicts:
        common_type = _least_common_ancestor(common_type, type(current))
        merged |= current
    return common_type(merged)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from torchjd._transform.base import Transform
|
|
2
|
+
from torchjd._transform.matrixify import Matrixify
|
|
3
|
+
from torchjd._transform.reshape import Reshape
|
|
4
|
+
from torchjd._transform.tensor_dict import Gradients, GradientVectors, JacobianMatrices, Jacobians
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def make_aggregation(
    strategy: Transform[JacobianMatrices, GradientVectors]
) -> Transform[Jacobians, Gradients]:
    """
    Wraps ``strategy`` so that it maps :class:`Jacobians` to :class:`Gradients`.

    The returned transform is the composition ``Reshape << strategy << Matrixify``. Both ``Reshape``
    and ``Matrixify`` are built on the `required_keys` of ``strategy``. The input
    :class:`Jacobians` are first converted by ``Matrixify`` into the :class:`JacobianMatrices`
    expected by ``strategy``. The :class:`GradientVectors` produced by ``strategy`` are then
    converted by ``Reshape`` into :class:`Gradients`. Presumably, ``Matrixify`` flattens each
    Jacobian into a matrix and ``Reshape`` restores the original shape of each key (TODO: confirm
    against those transforms).

    :param strategy: Transform aggregating Jacobian matrices into gradient vectors.
    :returns: The composed transform.
    :raises ValueError: Possibly, from ``Composition``, if the keys of the composed transforms do
        not match.
    """

    return Reshape(strategy.required_keys) << strategy << Matrixify(strategy.required_keys)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Generic, Sequence
|
|
5
|
+
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
from torchjd._transform._utils import _A, _B, _C, _union
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Transform(Generic[_B, _C], ABC):
    r"""
    Abstract base class for all transforms. Transforms are elementary building blocks of a Jacobian
    descent backward phase. A transform maps a
    :class:`~torchjd._transform.tensor_dict.TensorDict` to another. The input
    :class:`~torchjd._transform.tensor_dict.TensorDict` has keys `required_keys` and the output
    :class:`~torchjd._transform.tensor_dict.TensorDict` has keys `output_keys`.

    Formally a transform is a function:

    .. math::
        f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}

    where we have ``p`` `required_keys`, ``q`` `output_keys`, ``n_i`` is the number of elements in
    the value associated to the ``i`` th `required_key` of the input
    :class:`~torchjd._transform.tensor_dict.TensorDict` and ``m_j`` is the number of elements in the
    value associated to the ``j`` th `output_key` of the output
    :class:`~torchjd._transform.tensor_dict.TensorDict`.

    As they are mathematical functions, transforms can be composed together as long as their
    domains and ranges meaningfully match: ``outer << inner`` (see :meth:`compose`) is only valid
    when the `required_keys` of ``outer`` are exactly the `output_keys` of ``inner``. Transforms
    requiring the same keys and producing disjoint sets of keys can also be conjuncted with
    ``first | second`` (see :meth:`conjunct`). The result applies both transforms to the same
    input and merges their outputs.
    """

    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
        """
        Returns the composition ``self ∘ other``: ``other`` is applied first, and its output is
        given to ``self``. Also available as ``self << other``.
        """
        return Composition(self, other)

    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
        """
        Returns a transform applying both ``self`` and ``other`` to the same input and returning
        the union of their outputs. Also available as ``self | other``.
        """
        return Conjunction([self, other])

    def __repr__(self) -> str:
        return type(self).__name__

    @abstractmethod
    def _compute(self, input: _B) -> _C:
        """
        Applies the transform. Called by :meth:`__call__` after the keys of ``input`` have been
        validated against `required_keys`.
        """
        raise NotImplementedError

    def __call__(self, input: _B) -> _C:
        """
        Validates the keys of ``input`` against `required_keys` (via ``check_keys_are``), then
        applies the transform.
        """
        input.check_keys_are(self.required_keys)
        return self._compute(input)

    @property
    @abstractmethod
    def required_keys(self) -> set[Tensor]:
        """The keys that the input tensor dict must have."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_keys(self) -> set[Tensor]:
        """The keys of the output tensor dict."""
        raise NotImplementedError

    __lshift__ = compose
    __or__ = conjunct
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Composition(Transform[_A, _C]):
    """
    Transform applying ``inner`` first, then ``outer`` to the result.

    :param outer: The transform applied last.
    :param inner: The transform applied first.
    :raises ValueError: If the `required_keys` of ``outer`` differ from the `output_keys` of
        ``inner``.
    """

    def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
        outer_required = outer.required_keys
        inner_output = inner.output_keys
        if outer_required != inner_output:
            raise ValueError(
                "The `output_keys` of `inner` must match with the `required_keys` of "
                f"outer. Found {outer_required} and {inner_output}"
            )
        self.outer = outer
        self.inner = inner

    def __repr__(self) -> str:
        return f"{self.outer!r} ∘ {self.inner!r}"

    def _compute(self, input: _A) -> _C:
        return self.outer(self.inner(input))

    @property
    def required_keys(self) -> set[Tensor]:
        # The composition consumes what the first-applied transform consumes.
        return self.inner.required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The composition produces what the last-applied transform produces.
        return self.outer.output_keys
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class Conjunction(Transform[_A, _B]):
    """
    Transform applying each of ``transforms`` to the same input and returning the union of their
    outputs (see ``_union``).

    :param transforms: The transforms to conjunct. They must all require the same set of keys, and
        their sets of output keys must be pairwise disjoint.
    :raises ValueError: If the required keys differ between transforms, or if output keys overlap.
    """

    def __init__(self, transforms: Sequence[Transform[_A, _B]]):
        self.transforms = transforms

        self._required_keys = {key for transform in transforms for key in transform.required_keys}
        if any(transform.required_keys != self.required_keys for transform in transforms):
            raise ValueError("All transforms should require the same set of keys.")

        # Keep duplicates in a list so that overlapping output key sets can be detected.
        all_output_keys = [key for transform in transforms for key in transform.output_keys]
        self._output_keys = set(all_output_keys)
        if len(self._output_keys) < len(all_output_keys):
            raise ValueError("The sets of output keys of transforms should be disjoint.")

    def _compute(self, tensor_dict: _A) -> _B:
        partial_outputs = [transform(tensor_dict) for transform in self.transforms]
        return _union(partial_outputs)

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import _A, _materialize, dicts_union
|
|
7
|
+
from torchjd._transform.base import Transform
|
|
8
|
+
from torchjd._transform.tensor_dict import Jacobians
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Concatenation(Transform[_A, Jacobians]):
    """
    Transform applying each of ``transforms`` to the same input and concatenating the resulting
    :class:`Jacobians` along their first dimension (see ``_concatenate``).

    :param transforms: Non-empty sequence of transforms that all require the same set of keys.
    :raises ValueError: If ``transforms`` is empty, or if the transforms do not all require the
        same set of keys.
    """

    def __init__(self, transforms: Sequence[Transform[_A, Jacobians]]):
        if len(transforms) == 0:
            raise ValueError("Parameter `transforms` cannot be empty.")

        self.transforms = transforms

        self._required_keys = transforms[0].required_keys
        self._output_keys = {key for transform in transforms for key in transform.output_keys}

        # The first transform defines the reference set of required keys.
        if any(transform.required_keys != self.required_keys for transform in transforms[1:]):
            raise ValueError("All transforms should require the same set of keys.")

    def _compute(self, input: _A) -> Jacobians:
        return _concatenate([transform(input) for transform in self.transforms])

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _concatenate(jacobians_dicts: list[Jacobians]) -> Jacobians:
    """
    Merges a list of tensor dicts into a single dict of concatenated tensors. The keys of the
    result are the union of the keys of the input dicts. When a key is missing from some input
    dicts, its concatenated tensor is filled with zeros at the positions corresponding to those
    dicts.
    """
    # Deduplicate the keys first (in order of first appearance), so that the concatenated tensor
    # of each key is computed only once.
    concatenated = {}
    for key in dicts_union(jacobians_dicts):
        concatenated[key] = _concatenate_one_key(jacobians_dicts, key)
    return Jacobians(concatenated)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _concatenate_one_key(jacobian_dicts: list[Jacobians], input: Tensor) -> Tensor:
    """
    Builds the concatenated tensor of the key ``input`` from a list of tensor dicts.

    Each dict contributes its value for ``input`` if it has one. Otherwise it contributes a block
    of zeros of shape ``[first_dimension, *input.shape]``, where ``first_dimension`` is that of the
    dict. Blocks are concatenated along dimension 0, in the order of ``jacobian_dicts``.
    """
    optional_blocks: list[Tensor | None] = []
    zero_templates: list[Tensor] = []
    for jacobian_dict in jacobian_dicts:
        optional_blocks.append(jacobian_dict.get(input, None))
        # `expand` only builds a view; it serves as the shape/dtype/device template for zeros.
        zero_templates.append(input.expand(jacobian_dict.first_dimension, *input.shape))
    blocks = _materialize(optional_blocks, zero_templates)
    return torch.concatenate(blocks, dim=0)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._utils import ordered_set
|
|
7
|
+
from torchjd._transform.base import Transform
|
|
8
|
+
from torchjd._transform.tensor_dict import Gradients, Jacobians
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Diagonalize(Transform[Gradients, Jacobians]):
    """
    Transforms gradients into Jacobians by diagonalizing them.

    The values of the ``considered`` keys are flattened and concatenated, in the order of
    ``considered``, into a vector ``g`` of length ``n``. The matrix ``diag(g)`` of shape ``[n, n]``
    is then split column-wise into one block per key, and each block is reshaped to
    ``[n, *key.shape]``.

    :param considered: The keys to diagonalize. Must not contain duplicates.
    """

    def __init__(self, considered: Iterable[Tensor]):
        self.considered = ordered_set(considered)
        # Column range [begin, end) of each considered key in the concatenated flattened vector.
        self.indices: list[tuple[int, int]] = []
        begin = 0
        for tensor in self.considered:
            end = begin + tensor.numel()
            self.indices.append((begin, end))
            begin = end

    def _compute(self, tensors: Gradients) -> Jacobians:
        flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
        diagonal_matrix = torch.cat(flattened_considered_values).diag()
        n_rows = diagonal_matrix.shape[0]
        # The number of rows is given explicitly instead of being inferred with -1. A key with no
        # element yields an empty [n_rows, 0] slice, for which -1 is ambiguous and makes reshape
        # raise.
        diagonalized_tensors = {
            key: diagonal_matrix[:, begin:end].reshape((n_rows,) + key.shape)
            for (begin, end), key in zip(self.indices, self.considered)
        }
        return Jacobians(diagonalized_tensors)

    @property
    def required_keys(self) -> set[Tensor]:
        return set(self.considered)

    @property
    def output_keys(self) -> set[Tensor]:
        return set(self.considered)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import Iterable, Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform._differentiation import _Differentiation
|
|
7
|
+
from torchjd._transform._utils import _materialize
|
|
8
|
+
from torchjd._transform.tensor_dict import Gradients
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Grad(_Differentiation[Gradients]):
    """
    Transform computing, via ``_grad`` (a wrapper of ``torch.autograd.grad``), the gradients of
    ``outputs`` with respect to ``inputs``. The values of the input :class:`Gradients` are used as
    ``grad_outputs``.

    :param outputs: The tensors to differentiate.
    :param inputs: The tensors with respect to which ``outputs`` are differentiated.
    :param retain_graph: Whether to keep the autograd graph after differentiation.
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.retain_graph = retain_graph

    def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        # No higher-order graph is built (create_graph=False). Inputs that do not contribute to
        # the outputs are allowed (allow_unused=True); `_grad` replaces their None gradients by
        # zeros.
        return _grad(
            self.outputs,
            self.inputs,
            grad_outputs,
            retain_graph=self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _grad(
    outputs: Sequence[Tensor],
    inputs: Sequence[Tensor],
    grad_outputs: Sequence[Tensor],
    retain_graph: bool,
    create_graph: bool,
    allow_unused: bool,
) -> tuple[Tensor, ...]:
    """
    Wraps `torch.autograd.grad` to give it additional responsibilities that it should have, like
    being able to work with an empty sequence of `inputs` or of `outputs`.

    :param outputs: The tensors to differentiate.
    :param inputs: The tensors with respect to which `outputs` are differentiated.
    :param grad_outputs: One tensor per output, used as `grad_outputs` of `torch.autograd.grad`.
    :param retain_graph: Forwarded to `torch.autograd.grad`.
    :param create_graph: Forwarded to `torch.autograd.grad`.
    :param allow_unused: Forwarded to `torch.autograd.grad`.
    :returns: One gradient per input, in the order of `inputs`. A `None` gradient reported by
        autograd (for an unused input) is replaced by zeros shaped like that input.
    """

    if len(inputs) == 0:
        return tuple()

    if len(outputs) == 0:
        # Nothing depends on the inputs, so every gradient is zero. `zeros_like` (instead of an
        # uninitialized `torch.empty`) gives well-defined values with the dtype and device of each
        # input.
        return tuple(torch.zeros_like(input) for input in inputs)

    optional_grads = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=allow_unused,
    )
    return _materialize(optional_grads, inputs)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from torchjd._transform._utils import _A
|
|
6
|
+
from torchjd._transform.base import Transform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Identity(Transform[_A, _A]):
    """
    Transform returning its input unchanged (the very same object). Both its required keys and
    its output keys are the given ``required_keys``.

    :param required_keys: The keys of the tensor dicts this transform accepts.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = {*required_keys}

    def _compute(self, tensor_dict: _A) -> _A:
        return tensor_dict

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # Deliberately the same set object as `required_keys`.
        return self._required_keys
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torchjd._transform.base import Transform
|
|
7
|
+
from torchjd._transform.tensor_dict import EmptyTensorDict, Gradients
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Init(Transform[EmptyTensorDict, Gradients]):
    """
    Transform requiring no key and producing, for each of the given ``values``, the gradient of
    that value with respect to itself (a tensor of ones).

    :param values: The tensors whose gradients are initialized. They are the output keys.
    """

    def __init__(self, values: Iterable[Tensor]):
        self.values = set(values)

    def _compute(self, input: EmptyTensorDict) -> Gradients:
        r"""
        Computes the gradient of each tensor of ``values`` with respect to itself. Returns the
        result as a dictionary whose keys are exactly the tensors of ``values``. The gradient of
        each key is a tensor of 1s of identical shape (``torch.ones_like``), because
        :math:`\frac{\partial v}{\partial v} = 1` elementwise for any :math:`v`.
        """

        return Gradients({value: torch.ones_like(value) for value in self.values})

    @property
    def required_keys(self) -> set[Tensor]:
        return set()

    @property
    def output_keys(self) -> set[Tensor]:
        return self.values
|