torchjd 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. torchjd-0.1.0/LICENSE +21 -0
  2. torchjd-0.1.0/PKG-INFO +55 -0
  3. torchjd-0.1.0/README.md +25 -0
  4. torchjd-0.1.0/pyproject.toml +61 -0
  5. torchjd-0.1.0/setup.cfg +4 -0
  6. torchjd-0.1.0/src/torchjd/__init__.py +1 -0
  7. torchjd-0.1.0/src/torchjd/_transform/__init__.py +22 -0
  8. torchjd-0.1.0/src/torchjd/_transform/_differentiation.py +34 -0
  9. torchjd-0.1.0/src/torchjd/_transform/_utils.py +62 -0
  10. torchjd-0.1.0/src/torchjd/_transform/aggregation.py +12 -0
  11. torchjd-0.1.0/src/torchjd/_transform/base.py +117 -0
  12. torchjd-0.1.0/src/torchjd/_transform/concatenation.py +65 -0
  13. torchjd-0.1.0/src/torchjd/_transform/diagonalize.py +36 -0
  14. torchjd-0.1.0/src/torchjd/_transform/grad.py +60 -0
  15. torchjd-0.1.0/src/torchjd/_transform/identity.py +22 -0
  16. torchjd-0.1.0/src/torchjd/_transform/init.py +30 -0
  17. torchjd-0.1.0/src/torchjd/_transform/jac.py +109 -0
  18. torchjd-0.1.0/src/torchjd/_transform/matrixify.py +25 -0
  19. torchjd-0.1.0/src/torchjd/_transform/reshape.py +26 -0
  20. torchjd-0.1.0/src/torchjd/_transform/scaling.py +21 -0
  21. torchjd-0.1.0/src/torchjd/_transform/stack.py +63 -0
  22. torchjd-0.1.0/src/torchjd/_transform/store.py +29 -0
  23. torchjd-0.1.0/src/torchjd/_transform/strategy/__init__.py +4 -0
  24. torchjd-0.1.0/src/torchjd/_transform/strategy/_utils.py +86 -0
  25. torchjd-0.1.0/src/torchjd/_transform/strategy/extrapolating.py +75 -0
  26. torchjd-0.1.0/src/torchjd/_transform/strategy/isolating.py +25 -0
  27. torchjd-0.1.0/src/torchjd/_transform/strategy/partitioning.py +81 -0
  28. torchjd-0.1.0/src/torchjd/_transform/strategy/unifying.py +43 -0
  29. torchjd-0.1.0/src/torchjd/_transform/subset.py +27 -0
  30. torchjd-0.1.0/src/torchjd/_transform/tensor_dict.py +210 -0
  31. torchjd-0.1.0/src/torchjd/aggregation/__init__.py +16 -0
  32. torchjd-0.1.0/src/torchjd/aggregation/_gramian_utils.py +46 -0
  33. torchjd-0.1.0/src/torchjd/aggregation/_normalizing.py +48 -0
  34. torchjd-0.1.0/src/torchjd/aggregation/_pref_vector_utils.py +26 -0
  35. torchjd-0.1.0/src/torchjd/aggregation/_str_utils.py +11 -0
  36. torchjd-0.1.0/src/torchjd/aggregation/aligned_mtl.py +129 -0
  37. torchjd-0.1.0/src/torchjd/aggregation/bases.py +89 -0
  38. torchjd-0.1.0/src/torchjd/aggregation/cagrad.py +105 -0
  39. torchjd-0.1.0/src/torchjd/aggregation/constant.py +67 -0
  40. torchjd-0.1.0/src/torchjd/aggregation/dualproj.py +131 -0
  41. torchjd-0.1.0/src/torchjd/aggregation/graddrop.py +85 -0
  42. torchjd-0.1.0/src/torchjd/aggregation/imtl_g.py +50 -0
  43. torchjd-0.1.0/src/torchjd/aggregation/krum.py +108 -0
  44. torchjd-0.1.0/src/torchjd/aggregation/mean.py +42 -0
  45. torchjd-0.1.0/src/torchjd/aggregation/mgda.py +85 -0
  46. torchjd-0.1.0/src/torchjd/aggregation/nash_mtl.py +221 -0
  47. torchjd-0.1.0/src/torchjd/aggregation/pcgrad.py +72 -0
  48. torchjd-0.1.0/src/torchjd/aggregation/random.py +47 -0
  49. torchjd-0.1.0/src/torchjd/aggregation/sum.py +40 -0
  50. torchjd-0.1.0/src/torchjd/aggregation/trimmed_mean.py +73 -0
  51. torchjd-0.1.0/src/torchjd/aggregation/upgrad.py +136 -0
  52. torchjd-0.1.0/src/torchjd/backward.py +95 -0
  53. torchjd-0.1.0/src/torchjd.egg-info/PKG-INFO +55 -0
  54. torchjd-0.1.0/src/torchjd.egg-info/SOURCES.txt +55 -0
  55. torchjd-0.1.0/src/torchjd.egg-info/dependency_links.txt +1 -0
  56. torchjd-0.1.0/src/torchjd.egg-info/requires.txt +5 -0
  57. torchjd-0.1.0/src/torchjd.egg-info/top_level.txt +1 -0
torchjd-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Valérian Rey, Pierre Quinton
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torchjd-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.1
2
+ Name: torchjd
3
+ Version: 0.1.0
4
+ Summary: Library for Jacobian Descent with PyTorch.
5
+ Author-email: Valerian Rey <valerian.rey@gmail.com>, Pierre Quinton <pierre.quinton@gmail.com>
6
+ Project-URL: Homepage, https://torchjd.org/
7
+ Project-URL: Documentation, https://torchjd.org/
8
+ Project-URL: Source, https://github.com/TorchJD/torchjd
9
+ Project-URL: Changelog, https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: OS Independent
15
+ Classifier: Programming Language :: Python
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.10
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Classifier: Topic :: Scientific/Engineering
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.10
23
+ Description-Content-Type: text/markdown
24
+ License-File: LICENSE
25
+ Requires-Dist: torch>=2.0.0
26
+ Requires-Dist: quadprog!=0.1.10,>=0.1.9
27
+ Requires-Dist: numpy<2.0.0,>=1.21.0
28
+ Requires-Dist: qpsolvers>=1.0.1
29
+ Requires-Dist: cvxpy>=1.3.0
30
+
31
+ # ![image](docs/source/icons/favicon-32x32.png) TorchJD
32
+
33
+ TorchJD is a library enabling Jacobian descent with PyTorch, for optimization of neural networks
34
+ with multiple objectives.
35
+
36
+ > [!IMPORTANT]
37
+ > This library is currently in an early development stage. The API is subject to significant changes
38
+ > in future versions. Use with caution in production environments and be prepared for potential
39
+ > breaking changes in upcoming releases.
40
+
41
+ ## Installation
42
+ <!-- start installation -->
43
+ TorchJD can be installed directly with pip:
44
+ ```bash
45
+ pip install torchjd
46
+ ```
47
+ <!-- end installation -->
48
+
49
+ ## Compatibility
50
+ TorchJD requires Python 3.10, 3.11 or 3.12. It is only compatible with recent versions of PyTorch
51
+ (>= 2.0). For more information, read the `dependencies` in [pyproject.toml](./pyproject.toml).
52
+
53
+ ## Contribution
54
+
55
+ Please read the [Contribution page](CONTRIBUTING.md).
@@ -0,0 +1,25 @@
1
+ # ![image](docs/source/icons/favicon-32x32.png) TorchJD
2
+
3
+ TorchJD is a library enabling Jacobian descent with PyTorch, for optimization of neural networks
4
+ with multiple objectives.
5
+
6
+ > [!IMPORTANT]
7
+ > This library is currently in an early development stage. The API is subject to significant changes
8
+ > in future versions. Use with caution in production environments and be prepared for potential
9
+ > breaking changes in upcoming releases.
10
+
11
+ ## Installation
12
+ <!-- start installation -->
13
+ TorchJD can be installed directly with pip:
14
+ ```bash
15
+ pip install torchjd
16
+ ```
17
+ <!-- end installation -->
18
+
19
+ ## Compatibility
20
+ TorchJD requires Python 3.10, 3.11 or 3.12. It is only compatible with recent versions of PyTorch
21
+ (>= 2.0). For more information, read the `dependencies` in [pyproject.toml](./pyproject.toml).
22
+
23
+ ## Contribution
24
+
25
+ Please read the [Contribution page](CONTRIBUTING.md).
@@ -0,0 +1,61 @@
1
+ [project]
2
+ name = "torchjd"
3
+ version = "0.1.0"
4
+ description = "Library for Jacobian Descent with PyTorch."
5
+ readme = "README.md"
6
+ authors = [
7
+ {name = "Valerian Rey", email = "valerian.rey@gmail.com"},
8
+ {name = "Pierre Quinton", email = "pierre.quinton@gmail.com"}
9
+ ]
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "torch>=2.0.0",
13
+ "quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
14
+ "numpy>=1.21.0, <2.0.0", # Does not work before 1.21. cvxpy is not yet compatible with numpy>=2.0.0. The upper cap should be removed when this becomes the case, or when their pyproject.toml reflects the incompatibility. See https://github.com/cvxpy/cvxpy/issues/2474.
15
+ "qpsolvers>=1.0.1", # Does not work before 1.0.1
16
+ "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
17
+ ]
18
+ classifiers = [
19
+ "Development Status :: 4 - Beta",
20
+ "Intended Audience :: Developers",
21
+ "Intended Audience :: Science/Research",
22
+ "License :: OSI Approved :: MIT License",
23
+ "Operating System :: OS Independent",
24
+ "Programming Language :: Python",
25
+ "Programming Language :: Python :: 3",
26
+ "Programming Language :: Python :: 3.10",
27
+ "Programming Language :: Python :: 3.11",
28
+ "Programming Language :: Python :: 3.12",
29
+ "Topic :: Scientific/Engineering",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ ]
32
+
33
+ [project.urls]
34
+ Homepage = "https://torchjd.org/"
35
+ Documentation = "https://torchjd.org/"
36
+ Source = "https://github.com/TorchJD/torchjd"
37
+ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"
38
+
39
+ [tool.pdm.dev-dependencies]
40
+ check = [
41
+ "pre-commit>=2.9.2" # isort doesn't work before 2.9.2
42
+ ]
43
+
44
+ doc = [
45
+ "sphinx>=6.0, !=7.2.0, !=7.2.1, !=7.2.3, !=7.2.4, !=7.2.5", # Versions in [7.2.0, 7.2.5] have a bug with an internal torch import from _C
46
+ "furo>=2023.0, <2024.04.27", # Force it to be recent so that the theme looks better, 2024.04.27 seems to have bugged link colors
47
+ "tomli>=1.1", # The load function doesn't work similarly before 1.1
48
+ "sphinx-autodoc-typehints>=1.16.0", # Some problems with TypeVars before 1.16
49
+ "myst-parser>=3.0.1" # Never tested lower versions
50
+ ]
51
+
52
+ test = [
53
+ "pytest>=7.3", # Before version 7.3, not all tests are run
54
+ "contexttimer>=0.3.3, <0.3.4", # The test requiring contexttimer is not often run, so it could silently break if we uncap this library
55
+ ]
56
+
57
+ plot = [
58
+ "plotly>=5.19.0", # Recent version to avoid problems, could be relaxed
59
+ "dash>=2.16.0", # Recent version to avoid problems, could be relaxed
60
+ "kaleido==0.2.1", # Only works with locked version
61
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1 @@
1
+ from torchjd.backward import backward
@@ -0,0 +1,22 @@
1
+ from torchjd._transform.aggregation import make_aggregation
2
+ from torchjd._transform.base import Composition, Conjunction, Transform
3
+ from torchjd._transform.concatenation import Concatenation
4
+ from torchjd._transform.diagonalize import Diagonalize
5
+ from torchjd._transform.grad import Grad
6
+ from torchjd._transform.identity import Identity
7
+ from torchjd._transform.init import Init
8
+ from torchjd._transform.jac import Jac
9
+ from torchjd._transform.matrixify import Matrixify
10
+ from torchjd._transform.reshape import Reshape
11
+ from torchjd._transform.scaling import Scaling
12
+ from torchjd._transform.stack import Stack
13
+ from torchjd._transform.store import Store
14
+ from torchjd._transform.subset import Subset
15
+ from torchjd._transform.tensor_dict import (
16
+ EmptyTensorDict,
17
+ Gradients,
18
+ GradientVectors,
19
+ JacobianMatrices,
20
+ Jacobians,
21
+ TensorDict,
22
+ )
@@ -0,0 +1,34 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Iterable, Sequence
3
+
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import ordered_set
7
+ from torchjd._transform.base import _A, Transform
8
+
9
+
10
class _Differentiation(Transform[_A, _A], ABC):
    """
    Base class of the transforms that differentiate a fixed collection of ``outputs`` with
    respect to a fixed collection of ``inputs``.

    Subclasses only provide :meth:`_differentiate`. This class reads the values associated to the
    outputs in the provided tensor dict and stores the results under the inputs.

    :param outputs: Tensors to differentiate. Duplicates are rejected by ``ordered_set``.
    :param inputs: Tensors with respect to which the differentiation is made. Duplicates are
        rejected by ``ordered_set``.
    """

    def __init__(self, outputs: Iterable[Tensor], inputs: Iterable[Tensor]):
        self.outputs = ordered_set(outputs)
        self.inputs = ordered_set(inputs)

    @property
    def required_keys(self) -> set[Tensor]:
        # outputs in the forward direction become inputs in the backward direction, and vice-versa
        return set(self.outputs)

    @property
    def output_keys(self) -> set[Tensor]:
        # outputs in the forward direction become inputs in the backward direction, and vice-versa
        return set(self.inputs)

    @abstractmethod
    def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        """
        Differentiates, given one value per element of ``self.outputs`` (in the same order).
        Implementations are expected to return one tensor per element of ``self.inputs``, in the
        same order, since the results are zipped with ``self.inputs``.
        """

        raise NotImplementedError

    def _compute(self, tensors: _A) -> _A:
        # Values are gathered in the iteration order of `self.outputs`.
        values = [tensors[key] for key in self.outputs]
        derivatives = self._differentiate(values)

        # The result is wrapped in the same tensor dict type as the one that was received.
        result_type = type(tensors)
        return result_type(dict(zip(self.inputs, derivatives)))
@@ -0,0 +1,62 @@
1
+ from collections import OrderedDict
2
+ from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from torchjd._transform.tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
8
+
9
+ _KeyType = TypeVar("_KeyType", bound=Hashable)
10
+ _ValueType = TypeVar("_ValueType")
11
+ _OrderedSet: TypeAlias = OrderedDict[_KeyType, None]
12
+
13
+ _A = TypeVar("_A", bound=TensorDict)
14
+ _B = TypeVar("_B", bound=TensorDict)
15
+ _C = TypeVar("_C", bound=TensorDict)
16
+
17
+
18
+ def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
19
+ elements = list(elements)
20
+ result = OrderedDict.fromkeys(elements, None)
21
+ if len(elements) != len(result):
22
+ raise ValueError(
23
+ f"Parameter `elements` should contain unique elements. Found `elements = {elements}`."
24
+ )
25
+
26
+ return result
27
+
28
+
29
+ def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
30
+ result = {}
31
+ for d in dicts:
32
+ result |= d
33
+ return result
34
+
35
+
36
+ def _materialize(
37
+ optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
38
+ ) -> tuple[Tensor, ...]:
39
+ """
40
+ Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
41
+ shape as the corresponding input. Returns the obtained sequence as a tuple.
42
+
43
+ Note that the name "materialize" comes from the flag `materialize_grads` from
44
+ `torch.autograd.grad`, which will be available in future torch releases.
45
+ """
46
+
47
+ tensors = []
48
+ for optional_tensor, input in zip(optional_tensors, inputs):
49
+ if optional_tensor is None:
50
+ tensors.append(torch.zeros_like(input))
51
+ else:
52
+ tensors.append(optional_tensor)
53
+ return tuple(tensors)
54
+
55
+
56
def _union(tensor_dicts: Iterable[_A]) -> _A:
    """
    Merges several tensor dicts into a single one.

    The entries are accumulated into a fresh ``EmptyTensorDict``. The result is then rebuilt with
    the type obtained by folding ``_least_common_ancestor`` over the types of the provided dicts,
    starting from ``EmptyTensorDict``.
    """

    merged: _A = EmptyTensorDict()
    merged_type: type[_A] = EmptyTensorDict

    for current in tensor_dicts:
        merged_type = _least_common_ancestor(merged_type, type(current))
        # In-place `|=` is kept deliberately: the tensor dict classes are defined outside this
        # module and may not support the other mutation methods of `dict`.
        merged |= current

    return merged_type(merged)
@@ -0,0 +1,12 @@
1
+ from torchjd._transform.base import Transform
2
+ from torchjd._transform.matrixify import Matrixify
3
+ from torchjd._transform.reshape import Reshape
4
+ from torchjd._transform.tensor_dict import Gradients, GradientVectors, JacobianMatrices, Jacobians
5
+
6
+
7
def make_aggregation(
    strategy: Transform[JacobianMatrices, GradientVectors]
) -> Transform[Jacobians, Gradients]:
    """
    Wraps a strategy working on ``JacobianMatrices`` / ``GradientVectors`` into a transform
    working on ``Jacobians`` / ``Gradients``.

    The returned transform is the composition ``Reshape << strategy << Matrixify``, where both
    ``Matrixify`` and ``Reshape`` are built from the ``required_keys`` of ``strategy``.

    :param strategy: The transform to wrap.
    """

    reshape = Reshape(strategy.required_keys)
    reshaped_strategy = reshape << strategy
    matrixify = Matrixify(strategy.required_keys)
    return reshaped_strategy << matrixify
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Generic, Sequence
5
+
6
+ from torch import Tensor
7
+
8
+ from torchjd._transform._utils import _A, _B, _C, _union
9
+
10
+
11
class Transform(Generic[_B, _C], ABC):
    r"""
    Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
    descent backward phase. A transform maps a :class:`~torchjd._transform.tensor_dict.TensorDict`
    to another. The input :class:`~torchjd._transform.tensor_dict.TensorDict` has keys
    `required_keys` and the output :class:`~torchjd._transform.tensor_dict.TensorDict` has keys
    `output_keys`.

    Formally a transform is a function:

    .. math::
        f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}

    where we have ``p`` `required_keys`, ``q`` `output_keys`, ``n_i`` is the number of elements in
    the value associated to the ``i`` th `required_key` of the input
    :class:`~torchjd._transform.tensor_dict.TensorDict` and ``m_j`` is the number of elements in
    the value associated to the ``j`` th `output_key` of the output
    :class:`~torchjd._transform.tensor_dict.TensorDict`.

    As they are mathematical functions, transforms can be composed together (``outer << inner``,
    see :meth:`compose`) and combined by conjunction (``first | second``, see :meth:`conjunct`).
    """

    @property
    @abstractmethod
    def required_keys(self) -> set[Tensor]:
        """Keys that the input of this transform must have."""

        raise NotImplementedError

    @property
    @abstractmethod
    def output_keys(self) -> set[Tensor]:
        """Keys of the output of this transform."""

        raise NotImplementedError

    @abstractmethod
    def _compute(self, input: _B) -> _C:
        """Applies the transform to an input whose keys have already been checked."""

        raise NotImplementedError

    def __call__(self, input: _B) -> _C:
        # The keys are validated before any computation is started.
        expected_keys = self.required_keys
        input.check_keys_are(expected_keys)
        return self._compute(input)

    def __repr__(self) -> str:
        cls = type(self)
        return cls.__name__

    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
        """Returns the transform applying ``other`` first, and then ``self``."""

        return Composition(self, other)

    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
        """Returns the conjunction of ``self`` and ``other``, in that order."""

        return Conjunction([self, other])

    # Operator aliases: `outer << inner` composes, `first | second` conjuncts.
    __lshift__ = compose
    __or__ = conjunct
62
+
63
+
64
class Composition(Transform[_A, _C]):
    """
    Transform applying ``inner`` first and then ``outer`` to the result.

    :param outer: The transform applied last.
    :param inner: The transform applied first.
    :raises ValueError: If the ``output_keys`` of ``inner`` differ from the ``required_keys`` of
        ``outer``.
    """

    def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
        produced_keys = inner.output_keys
        expected_keys = outer.required_keys
        if expected_keys != produced_keys:
            # Each value is labeled and reported in the order in which the message names it, so
            # that the two sets cannot be mistaken for one another.
            raise ValueError(
                "The `output_keys` of `inner` must match with the `required_keys` of `outer`. "
                f"Found `inner.output_keys = {produced_keys}` and "
                f"`outer.required_keys = {expected_keys}`."
            )
        self.outer = outer
        self.inner = inner

    def __repr__(self) -> str:
        return repr(self.outer) + " ∘ " + repr(self.inner)

    def _compute(self, input: _A) -> _C:
        intermediate = self.inner(input)
        return self.outer(intermediate)

    @property
    def required_keys(self) -> set[Tensor]:
        # The composition consumes what its first-applied transform consumes.
        return self.inner.required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The composition produces what its last-applied transform produces.
        return self.outer.output_keys
88
+
89
+
90
class Conjunction(Transform[_A, _B]):
    """
    Transform applying several transforms to the same input and merging their outputs into a
    single tensor dict.

    :param transforms: The transforms to apply. They must all have the same ``required_keys`` and
        their ``output_keys`` must be pairwise disjoint.
    :raises ValueError: If one of the two conditions above is not satisfied.
    """

    def __init__(self, transforms: Sequence[Transform[_A, _B]]):
        self.transforms = transforms

        self._required_keys = {key for t in transforms for key in t.required_keys}
        # The union equals each individual set if and only if all the sets are identical.
        if any(t.required_keys != self.required_keys for t in transforms):
            raise ValueError("All transforms should require the same set of keys.")

        all_output_keys = [key for t in transforms for key in t.output_keys]
        self._output_keys = set(all_output_keys)

        # A key produced by two transforms shows up twice in the list but once in the set.
        if len(all_output_keys) != len(self._output_keys):
            raise ValueError("The sets of output keys of transforms should be disjoint.")

    def _compute(self, tensor_dict: _A) -> _B:
        partial_outputs = [transform(tensor_dict) for transform in self.transforms]
        return _union(partial_outputs)

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys
@@ -0,0 +1,65 @@
1
+ from typing import Sequence
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import _A, _materialize, dicts_union
7
+ from torchjd._transform.base import Transform
8
+ from torchjd._transform.tensor_dict import Jacobians
9
+
10
+
11
class Concatenation(Transform[_A, Jacobians]):
    """
    Transform applying several transforms to the same input and concatenating the jacobians that
    they produce, key by key (see ``_concatenate``).

    :param transforms: The transforms to apply. There must be at least one, and they must all have
        the same ``required_keys``.
    :raises ValueError: If ``transforms`` is empty or if the ``required_keys`` differ.
    """

    def __init__(self, transforms: Sequence[Transform[_A, Jacobians]]):
        if len(transforms) == 0:
            raise ValueError("Parameter `transforms` cannot be empty.")

        self.transforms = transforms

        # The first transform serves as the reference for the required keys.
        self._required_keys = transforms[0].required_keys

        output_keys = set()
        for transform in transforms:
            output_keys.update(transform.output_keys)
        self._output_keys = output_keys

        if any(t.required_keys != self.required_keys for t in transforms[1:]):
            raise ValueError("All transforms should require the same set of keys.")

    def _compute(self, input: _A) -> Jacobians:
        return _concatenate([transform(input) for transform in self.transforms])

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        return self._output_keys
37
+
38
+
39
def _concatenate(jacobians_dicts: list[Jacobians]) -> Jacobians:
    """
    Merges a list of tensor dicts into a single dict of concatenated tensors. The keys of the
    result are the union of the keys of the provided dicts. Where a key is missing from one of the
    dicts, the part of the concatenated tensor corresponding to that dict is filled with zeros.
    """

    # The keys are deduplicated before anything is concatenated, so that the concatenated tensor
    # of a key shared by several dicts is computed only once.
    keys = dicts_union(jacobians_dicts).keys()

    concatenated = {}
    for key in keys:
        concatenated[key] = _concatenate_one_key(jacobians_dicts, key)
    return Jacobians(concatenated)
53
+
54
+
55
def _concatenate_one_key(jacobian_dicts: list[Jacobians], input: Tensor) -> Tensor:
    """
    Builds the concatenated tensor associated to the key ``input``, from a list of tensor dicts.

    Each dict contributes its value for ``input`` if it has one. Otherwise it contributes zeros of
    shape ``(first_dimension, *input.shape)``, ``first_dimension`` being read from that dict. The
    contributions are concatenated along dimension 0, in the order of the list.
    """

    optional_jacobians = []
    placeholders = []
    for jacobian_dict in jacobian_dicts:
        optional_jacobians.append(jacobian_dict.get(input))
        # Only the shape of the expanded input matters: it is the template of the zeros that
        # stand in for a missing jacobian.
        placeholders.append(input.expand(jacobian_dict.first_dimension, *input.shape))

    pieces = _materialize(optional_jacobians, placeholders)
    return torch.concatenate(pieces, dim=0)
@@ -0,0 +1,36 @@
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._utils import ordered_set
7
+ from torchjd._transform.base import Transform
8
+ from torchjd._transform.tensor_dict import Gradients, Jacobians
9
+
10
+
11
class Diagonalize(Transform[Gradients, Jacobians]):
    """
    Transform turning gradients into jacobians by placing all the flattened gradient values on
    the diagonal of a square matrix, and handing to each key the columns of that matrix
    corresponding to its own values, reshaped to ``(-1, *key.shape)``.

    :param considered: The keys whose gradients are diagonalized. Duplicates are rejected by
        ``ordered_set``. The order of the keys determines the order of the rows.
    """

    def __init__(self, considered: Iterable[Tensor]):
        self.considered = ordered_set(considered)

        # (begin, end) column range of each considered tensor in the diagonal matrix.
        self.indices: list[tuple[int, int]] = []
        offset = 0
        for tensor in self.considered:
            next_offset = offset + tensor.numel()
            self.indices.append((offset, next_offset))
            offset = next_offset

    def _compute(self, tensors: Gradients) -> Jacobians:
        flat_values = [tensors[key].reshape([-1]) for key in self.considered]
        diagonal_matrix = torch.cat(flat_values).diag()

        jacobians = {}
        for (begin, end), key in zip(self.indices, self.considered):
            columns = diagonal_matrix[:, begin:end]
            jacobians[key] = columns.reshape((-1,) + key.shape)
        return Jacobians(jacobians)

    @property
    def required_keys(self) -> set[Tensor]:
        return set(self.considered)

    @property
    def output_keys(self) -> set[Tensor]:
        return set(self.considered)
@@ -0,0 +1,60 @@
1
+ from typing import Iterable, Sequence
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform._differentiation import _Differentiation
7
+ from torchjd._transform._utils import _materialize
8
+ from torchjd._transform.tensor_dict import Gradients
9
+
10
+
11
class Grad(_Differentiation[Gradients]):
    """
    Differentiation transform delegating to ``_grad``, with ``create_graph=False`` and
    ``allow_unused=True``.

    :param outputs: Tensors to differentiate.
    :param inputs: Tensors with respect to which the differentiation is made.
    :param retain_graph: Forwarded as the ``retain_graph`` argument of ``_grad``. Defaults to
        ``False``.
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        retain_graph: bool = False,
    ):
        super().__init__(outputs, inputs)
        self.retain_graph = retain_graph

    def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        gradients = _grad(
            self.outputs,
            self.inputs,
            grad_outputs,
            self.retain_graph,
            create_graph=False,
            allow_unused=True,
        )
        return gradients
30
+
31
+
32
+ def _grad(
33
+ outputs: Sequence[Tensor],
34
+ inputs: Sequence[Tensor],
35
+ grad_outputs: Sequence[Tensor],
36
+ retain_graph: bool,
37
+ create_graph: bool,
38
+ allow_unused: bool,
39
+ ) -> tuple[Tensor, ...]:
40
+ """
41
+ Wraps `autograd.grad` to give it additional responsibilities that it should have (like being
42
+ able to work with an empty sequence of `inputs`).
43
+ """
44
+
45
+ if len(inputs) == 0:
46
+ return tuple()
47
+
48
+ if len(outputs) == 0:
49
+ return tuple([torch.empty(input.shape) for input in inputs])
50
+
51
+ optional_grads = torch.autograd.grad(
52
+ outputs,
53
+ inputs,
54
+ grad_outputs=grad_outputs,
55
+ retain_graph=retain_graph,
56
+ create_graph=create_graph,
57
+ allow_unused=allow_unused,
58
+ )
59
+ grads = _materialize(optional_grads, inputs)
60
+ return grads
@@ -0,0 +1,22 @@
1
+ from typing import Iterable
2
+
3
+ from torch import Tensor
4
+
5
+ from torchjd._transform._utils import _A
6
+ from torchjd._transform.base import Transform
7
+
8
+
9
class Identity(Transform[_A, _A]):
    """
    Transform returning its input unchanged.

    :param required_keys: The keys that the input must have. They are also the keys of the output.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = {*required_keys}

    @property
    def required_keys(self) -> set[Tensor]:
        return self._required_keys

    @property
    def output_keys(self) -> set[Tensor]:
        # The input is returned as-is, so the output keys are the required keys.
        return self._required_keys

    def _compute(self, tensor_dict: _A) -> _A:
        return tensor_dict
@@ -0,0 +1,30 @@
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torchjd._transform.base import Transform
7
+ from torchjd._transform.tensor_dict import EmptyTensorDict, Gradients
8
+
9
+
10
class Init(Transform[EmptyTensorDict, Gradients]):
    """
    Transform creating the initial gradients of a set of tensors. It requires no key.

    :param values: The tensors for which the gradients are created. They are stored as a ``set``.
    """

    def __init__(self, values: Iterable[Tensor]):
        self.values = set(values)

    @property
    def required_keys(self) -> set[Tensor]:
        return set()

    @property
    def output_keys(self) -> set[Tensor]:
        return self.values

    def _compute(self, input: EmptyTensorDict) -> Gradients:
        r"""
        Computes the gradient of each of the ``values`` with respect to itself. Returns the result
        as a dictionary whose keys are the ``values``. Each gradient is a tensor of 1s of the same
        shape as its key, because :math:`\frac{\partial v}{\partial v} = 1` for any :math:`v`.
        """

        gradients = {}
        for value in self.values:
            gradients[value] = torch.ones_like(value)
        return Gradients(gradients)