torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
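Several of the entries above are pure renames: `torchzero/utils/linalg` becomes a top-level `torchzero/linalg` package. As a reading aid, here is a minimal sketch of the import-path migration this implies; the listing does not show which names each module exports, so only the package paths are illustrated.

```python
# Sketch of the {utils/linalg -> linalg} move listed above; paths only,
# since the diff listing does not show the symbols these modules export.
try:
    from torchzero import linalg          # 0.4.0 layout: torchzero/linalg/__init__.py
except ImportError:
    from torchzero.utils import linalg    # 0.3.14 layout: torchzero/utils/linalg/__init__.py
```

The same substitution applies to the individual modules that moved with the package (`linear_operator`, `qr`, `solve`, `benchmark`).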
torchzero/core/transform.py
CHANGED
@@ -1,440 +1,336 @@

The file is rewritten around the new `Objective` API. Removed from 0.3.14: the old `Transform` whose constructor took `defaults`, `uses_grad`, `uses_loss`, `concat_params`, `update_freq`, `inner` and `target`, its `keyed_transform_update`/`keyed_transform_apply` helpers that keyed state on the parameter tensors, and its `step(var)` method that dispatched on `self._target` over the 'update', 'grad', 'params_direct', 'params_difference', 'update_difference' and 'closure' targets; also removed are the `TensorwiseTransform` subclass (per-parameter `update_tensor`/`apply_tensor`, overriding `update_tensors`) and the module-level `apply_transform` helper used for inner chains. The new file contains:

````python
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from operator import itemgetter
from typing import Any, final, cast, TYPE_CHECKING

import torch

from .module import Module
from ..utils import vec_to_tensors, safe_dict_update_

if TYPE_CHECKING:
    from .chain import Chainable
    from .objective import Objective


class Transform(Module):
    """``Transform`` is a ``Module`` with only optional children.

    ``Transform`` if more flexible in that as long as there are no children, it can use a custom list of states
    and settings instead of ``self.state`` and ``self.setting``.

    To use, subclass this and override ``update_states`` and ``apply_states``.
    """
    def __init__(self, defaults: dict[str, Any] | None = None, update_freq: int = 1, inner: "Chainable | None" = None):

        # store update_freq in defaults so that it is scheduleable
        if defaults is None: defaults = {}
        safe_dict_update_(defaults, {"__update_freq": update_freq})

        super().__init__(defaults)

        self._objective = None
        if inner is not None:
            self.set_child("inner", inner)

    # settings shouldn't mutate, so they are typed as Sequence[Mapping]
    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
        """Updates ``states``. This should not modify ``objective.update``."""

    @abstractmethod
    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
        """Updates ``objective`` using ``states``."""

    def _get_states_settings(self, objective: "Objective") -> tuple[list, tuple]:
        # itemgetter is faster
        # but need to make sure it returns a tuple, as if there is a single param, it returns the value
        getter = itemgetter(*objective.params)
        is_single = len(objective.params) == 1
        states = getter(self.state)
        settings = getter(self.settings)

        if is_single:
            states = [states, ]
            settings = (settings, )

        else:
            states = list(states) # itemgetter returns tuple

        return states, settings

    @final
    def update(self, objective:"Objective"):
        step = self.increment_counter("__step", 0)

        if step % self.settings[objective.params[0]]["__update_freq"] == 0:
            states, settings = self._get_states_settings(objective)
            self.update_states(objective=objective, states=states, settings=settings)

    @final
    def apply(self, objective: "Objective"):

        # inner step
        if "inner" in self.children:
            inner = self.children["inner"]
            objective = inner.step(objective)

        # apply and return
        states, settings = self._get_states_settings(objective)
        return self.apply_states(objective=objective, states=states, settings=settings)


class TensorTransform(Transform):
    """``TensorTransform`` is a ``Transform`` that doesn't use ``Objective``, instead it operates
    on lists of tensors directly.

    This has a ``concat_params`` setting which is used in quite a few modules, for example it is optional
    in all full-matrix method like Quasi-Newton or full-matrix Adagrad.

    To use, subclass this and override one of ``single_tensor_update`` or ``multi_tensor_update``,
    and one of ``single_tensor_apply`` or ``multi_tensor_apply``.

    For copying:

    multi tensor:
    ```
    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
        ...
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        ...
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        ...
    ```

    single tensor:

    ```
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        ...
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        ...
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        ...
    ```
    """
    def __init__(
        self,
        defaults: dict[str, Any] | None = None,
        update_freq: int = 1,
        concat_params: bool = False,
        uses_grad: bool = False,
        uses_loss: bool = False,
        inner: "Chainable | None" = None,
    ):
        super().__init__(defaults, update_freq=update_freq, inner=inner)

        self._concat_params = concat_params
        self._uses_grad = uses_grad
        self._uses_loss = uses_loss

    # ------------------------------- single tensor ------------------------------ #
    def single_tensor_initialize(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> None:
        """initialize ``state`` before first ``update``.
        """

    def single_tensor_update(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> None:
        """Updates ``state``. This should not modify ``tensor``.
        """

    def single_tensor_apply(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> torch.Tensor:
        """Updates ``tensor`` and returns it. This shouldn't modify ``state`` if possible.
        """
        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `single_tensor_apply`.")

    # ------------------------------- multi tensor ------------------------------- #
    def multi_tensor_initialize(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> None:
        """initialize ``states`` before first ``update``.
        By default calls ``single_tensor_initialize`` on all tensors.
        """
        if grads is None:
            grads = cast(list, [None] * len(tensors))

        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            self.single_tensor_initialize(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

    def multi_tensor_update(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> None:
        """Updates ``states``. This should not modify ``tensor``.
        By default calls ``single_tensor_update`` on all tensors.
        """

        if grads is None:
            grads = cast(list, [None] * len(tensors))

        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            self.single_tensor_update(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

    def multi_tensor_apply(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> Sequence[torch.Tensor]:
        """Updates ``tensors`` and returns it. This shouldn't modify ``state`` if possible.
        By default calls ``single_tensor_apply`` on all tensors.
        """

        if grads is None:
            grads = cast(list, [None] * len(tensors))

        ret = []
        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            u = self.single_tensor_apply(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
            ret.append(u)

        return ret

    def _get_grads_loss(self, objective: "Objective"):
        """evaluates grads and loss only if needed"""

        if self._uses_grad: grads = objective.get_grads()
        else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used

        if self._uses_loss: loss = objective.get_loss(backward=False)
        else: loss = None

        return grads, loss

    @torch.no_grad
    def _get_cat_updates_params_grads(self, objective: "Objective", grads: list[torch.Tensor] | None):
        assert self._concat_params

        cat_updates = [torch.cat([u.ravel() for u in objective.get_updates()])]
        cat_params = [torch.cat([p.ravel() for p in objective.params])]

        if grads is None: cat_grads = None
        else: cat_grads = [torch.cat([g.ravel() for g in grads])]

        return cat_updates, cat_params, cat_grads

    def _gather_tensors(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
        """returns everything for ``multi_tensor_*``. Concatenates if ``self._concat_params``.
        evaluates grads and loss if ``self._uses_grad`` and ``self._uses_loss``"""

        # evaluate grads and loss if `self._uses_grad` and `self._uses_loss`
        grads, loss = self._get_grads_loss(objective)

        # gather all things
        # concatenate everything to a vec if `self._concat_params`
        if self._concat_params:
            tensors, params, grads = self._get_cat_updates_params_grads(objective, grads)
            states = [states[0]]; settings = [settings[0]]

        # or take original values
        else:
            tensors=objective.get_updates()
            params = objective.params

        return tensors, params, grads, loss, states, settings

    @final
    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)

        # initialize before the first update
        num_updates = self.increment_counter("__num_updates", 0)
        if num_updates == 0:
            self.multi_tensor_initialize(
                tensors=tensors,
                params=params,
                grads=grads,
                loss=loss,
                states=states,
                settings=settings
            )

        # update
        self.multi_tensor_update(
            tensors=tensors,
            params=params,
            grads=grads,
            loss=loss,
            states=states,
            settings=settings
        )

    @final
    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
        # note: _gather tensors will re-cat again if `_concat_params`, this is necessary because objective
        # may have been modified in functional logic, there is no way to know if that happened

        # apply
        ret = self.multi_tensor_apply(
            tensors=tensors,
            params=params,
            grads=grads,
            loss=loss,
            states=states,
            settings=settings
        )

        # uncat if needed and set objective.updates and return objective
        if self._concat_params:
            objective.updates = vec_to_tensors(ret[0], objective.params)

        else:
            objective.updates = list(ret)

        return objective

    # make sure _concat_params, _uses_grad and _uses_loss are saved in `state_dict`
    def _extra_pack(self):
        return {
            "__concat_params": self._concat_params,
            "__uses_grad": self._uses_grad,
            "__uses_loss": self._uses_loss,
        }

    def _extra_unpack(self, d):
        self._concat_params = d["__concat_params"]
        self._uses_grad = d["__uses_grad"]
        self._uses_loss = d["__uses_loss"]
````