torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
CHANGED

@@ -1,18 +1,18 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Sequence
-from typing import Any, Literal
+from collections.abc import Iterable, Sequence, Mapping
+from typing import Any, Literal, final

 import torch

-from ..utils import set_storage_
-from .module import Module,
+from ..utils import set_storage_, TensorList, vec_to_tensors
+from .module import Module, Var, Chain, Chainable

 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.

-
+    A transform is a module that can also be applied manually to an arbitrary sequence of tensors.

     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -20,62 +20,180 @@ class Transform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._concat_params = concat_params
+        self._update_freq = update_freq
+        self._scale_first = scale_first
+        self._inner = inner
+
+    def update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def
-
+    def apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> Sequence[torch.Tensor]:
+        """Applies the update rule to `tensors`."""
+
+    @final
+    @torch.no_grad
+    def transform(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors."""
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+        if self._concat_params:
+            tensors = [torch.cat([t.ravel() for t in tensors])]
+            params = [torch.cat([p.ravel() for p in params])]
+            grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        step = self.global_state.get('__step', 0)
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        update_freq = self._update_freq
+        scale_first = self._scale_first
+        scale_factor = 1
+
+        # scaling factor for 1st step
+        if scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+
+        # update transform
+        if step % update_freq == 0:
+            self.update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # step with inner
+        if self._inner is not None:
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            if self._concat_params:
+                tensors = [torch.cat([t.ravel() for t in tensors])]
+
+        # apply transform
+        tensors = list(self.apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+
+        # scale initial step, when preconditioner might not have been applied
+        if scale_first and step == 0:
+            torch._foreach_mul_(tensors, scale_factor)
+
+        self.global_state['__step'] = step + 1
+        if self._concat_params:
+            tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+        return tensors
+

-
-
-
-
+    @torch.no_grad
+    def keyed_transform(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+    ):
+        """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+        if self._concat_params:
+            p = params[0]
+            states = [self.state[p]]
+            settings = [self.settings[p]]
+
+        else:
+            states = []
+            settings = []
+            for p in params:
+                states.append(self.state[p])
+                settings.append(self.settings[p])
+
+        return self.transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+    def step(self, var: Var) -> Var:
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._uses_grad: var.get_grad()
+        params=var.params

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
-
-
+            update = var.get_update()
+            var.update = list(self.keyed_transform(update, params, var.grad, var.loss))
+            return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
-
-
+            grad = var.get_grad()
+            var.grad = list(self.keyed_transform(grad, params, grad, var.loss))
+            return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-            new_params = self.
-            for p, new_p in zip(
-            return
+            new_params = self.keyed_transform(var.params, params, var.grad, var.loss)
+            for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-            new_params = tuple(self.
-
-            return
+            new_params = tuple(self.keyed_transform([p.clone() for p in var.params], params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(var.params, new_params))
+            return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
-            update =
-            new_update = tuple(self.
-
-            return
+            update = var.get_update()
+            new_update = tuple(self.keyed_transform([u.clone() for u in update], params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(update, new_update))
+            return var

         # ---------------------------------- closure --------------------------------- #
         if self._target == 'closure':
-            original_closure =
+            original_closure = var.closure
             if original_closure is None: raise ValueError('Target = "closure", but closure is None')

-            params =
+            params = var.params
             def transformed_closure(backward=True):
                 if backward:
                     loss = original_closure()
                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = list(self.
+                    transformed_grad = list(self.keyed_transform(current_grad, params, var.grad, var.loss))
                     for p, g in zip(params, transformed_grad):
                         p.grad = g

@@ -84,14 +202,14 @@ class Transform(Module, ABC):

                 return loss

-
-            return
+            var.closure = transformed_closure
+            return var

         # ---------------------------------- invalid --------------------------------- #
         raise ValueError(f'Invalid target: {self._target}')


-class TensorwiseTransform(
+class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.

     This is an abstract class, to use it, subclass it and override `transform`.
@@ -102,151 +220,94 @@ class TensorwiseTransform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(
-
-
-
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
+        super().__init__(
+            defaults=defaults,
+            uses_grad=uses_grad,
+            concat_params=concat_params,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            inner=inner,
+            target=target,
+        )
+
+    def update_tensor(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        settings: Mapping[str, Any],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def
+    def apply_tensor(
         self,
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        settings: Mapping[str, Any],
     ) -> torch.Tensor:
-        """
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-            vars.update = transformed_update
-            return vars
-
-        # ----------------------------------- grad ----------------------------------- #
-        if self._target == 'grad':
-            grad = vars.get_grad()
-            transformed_grad = []
-
-            for p, g in zip(params, grad):
-                transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-            vars.grad = transformed_grad
-            return vars
-
-        # ------------------------------- params_direct ------------------------------ #
-        if self._target == 'params_direct':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-
-            for p, g in zip(params, grad):
-                set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
-
-            return vars
-
-        # ----------------------------- params_difference ---------------------------- #
-        if self._target == 'params_difference':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_params = []
-
-            for p, g in zip(params, grad):
-                transformed_params.append(
-                    self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(params, transformed_params))
-            return vars
-
-        # ----------------------------- update_difference ---------------------------- #
-        if self._target == 'update_difference':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                transformed_update.append(
-                    self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(update, transformed_update))
-            return vars
-
-        # ---------------------------------- closure --------------------------------- #
-        if self._target == 'closure':
-            original_closure = vars.closure
-            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-            params = vars.params
-            def transformed_closure(backward=True):
-                if backward:
-                    loss = original_closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = []
-
-                    for p, g in zip(params, grad):
-                        transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-                    for p, g in zip(params, transformed_grad):
-                        p.grad = g
-
-                else:
-                    loss = original_closure(False)
-
-                return loss
-
-            vars.closure = transformed_closure
-            return vars
-
-        # ---------------------------------- invalid --------------------------------- #
-        raise ValueError(f'Invalid target: {self._target}')
-
-
-
-def apply(
+        """Applies the update rule to `tensor`."""
+
+    @final
+    def update(self, tensors, params, grads, loss, states, settings):
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            self.update_tensor(t, p, g, loss, state, setting)
+
+    @final
+    def apply(self, tensors, params, grads, loss, states, settings):
+        applied = []
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            applied.append(self.apply_tensor(t, p, g, loss, state, setting))
+        return applied
+
+def apply_transform(
     tfm: Chainable,
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-
+    loss: torch.Tensor | None = None,
+    var: Var | None = None,
     current_step: int = 0,
 ):
-    if
-
-
-        return list(tfm.transform(tensors, params, grads, vars))
+    if var is None:
+        var = Var(params=params, closure=None, model=None, current_step=current_step)
+    var.loss = loss

-    if isinstance(tfm,
-
-
-        if tfm._uses_grad: grads_list = vars.get_grad()
-        else: grads_list = [None] * len(tensors)
-        return [tfm.transform(t, p, g, vars) for t,p,g in zip(tensors,params,grads_list)]
+    if isinstance(tfm, Transform):
+        if tfm._uses_grad and grads is None: grads = var.get_grad()
+        return list(tfm.keyed_transform(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
         for module in tfm:
-            tensors =
+            tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
         return tensors

     if isinstance(tfm, Module):
-
-
-
-
-        assert
-        return
+        cvar = var.clone(clone_update=False)
+        cvar.update = tensors
+        cvar = tfm.step(cvar)
+        var.update_attrs_from_clone_(cvar)
+        assert cvar.update is not None
+        return cvar.update

     raise TypeError(type(tfm))
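For orientation, here is a minimal sketch of what a subclass of the reworked `Transform` base class could look like. The `Scale` module and its `factor` setting are hypothetical and not part of torchzero; the sketch only follows the `apply` signature added in the diff above, where per-parameter `states` and `settings` are passed in explicitly instead of being read from `self.state` / `self.settings`.

# Hypothetical sketch against the new Transform API shown above;
# `Scale` and its `factor` setting are illustrative, not part of torchzero.
import torch
from torchzero.core.transform import Transform

class Scale(Transform):
    """Multiplies the incoming update by a constant factor."""
    def __init__(self, factor: float = 0.1):
        super().__init__(defaults=dict(factor=factor), uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # settings[i] is the per-parameter settings mapping for tensors[i];
        # previously this was looked up via self.settings[params[i]].
        return [t.mul_(s['factor']) for t, s in zip(tensors, settings)]

Such a subclass is then driven through the `transform` / `keyed_transform` entry points shown in the diff, with the parameter tensors serving as keys for per-parameter state and settings.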
torchzero/modules/clipping/clipping.py
CHANGED

@@ -151,8 +151,8 @@ class ClipValue(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
-        value =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

 class ClipNorm(Transform):
@@ -186,9 +186,9 @@ class ClipNorm(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
-        max_norm =
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        max_norm = NumberList(s['max_norm'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
             tensors_ = TensorList(tensors),
             min = 0,
@@ -232,9 +232,9 @@ class Normalize(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
-        norm_value =
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        norm_value = NumberList(s['norm_value'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

         _clip_norm_(
             tensors_ = TensorList(tensors),
@@ -311,8 +311,8 @@ class Centralize(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
-        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)

torchzero/modules/clipping/ema_clipping.py
CHANGED

@@ -4,8 +4,8 @@ from collections.abc import Iterable, Sequence

 import torch

-from ...core import Module, Target, Transform,
-from ...utils import NumberList, TensorList, generic_eq
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, generic_eq, unpack_dicts, unpack_states

 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -34,13 +34,14 @@ class ClipNormByEMA(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
-        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
-
-        beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
+
+        beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)
+
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)

-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
         ema.lerp_(tensors, 1-beta)

         if tensorwise:
@@ -48,7 +49,7 @@ class ClipNormByEMA(Transform):

             # clip ema norm growth
             if max_ema_growth is not None:
-                prev_ema_norm =
+                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
                 ema.div_(ema_denom)
@@ -119,17 +120,17 @@ class ClipValueByEMA(Transform):
         self.set_child('ema_tfm', ema_tfm)

     @torch.no_grad
-    def
-        ema_init = itemgetter('ema_init')(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        ema_init = itemgetter('ema_init')(settings[0])

-        beta =
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
         tensors = TensorList(tensors)

-        ema =
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)

         if 'ema_tfm' in self.children:
-            ema = TensorList(
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))

         tensors.clip_(-ema, ema)
         return tensors

torchzero/modules/clipping/growth_clipping.py
CHANGED

@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
             bounds the tracked multiplicative clipping decay to prevent collapse to 0.
             Next update is at most :code:`max(previous update * mul, max_decay)`.
             Defaults to 2.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var.. Defaults to "update".
     """
     def __init__(
         self,
@@ -33,12 +33,10 @@ class ClipValueGrowth(TensorwiseTransform):
         super().__init__(defaults, uses_grad=False, target=target)


-    def
-        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(settings)
         add: float | None

-        state = self.state[param]
-
         if add is None and mul is None:
             return tensor

@@ -133,7 +131,7 @@ class ClipNormGrowth(Transform):
         ord (float, optional): norm order. Defaults to 2.
         parameterwise (bool, optional):
             if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -150,35 +148,35 @@ class ClipNormGrowth(Transform):



-    def
-        parameterwise =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)

         if parameterwise:
             ts = tensors
-            stts =
-            stns =
+            stts = states
+            stns = settings

         else:
             ts = [tensors.to_vec()]
             stts = [self.global_state]
-            stns = [
+            stns = [settings[0]]


-        for t,state,
+        for t, state, setting in zip(ts, stts, stns):
             if 'prev_norm' not in state:
-                state['prev_norm'] = torch.linalg.vector_norm(t, ord=
+                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                 state['prev_denom'] = 1
                 continue

             _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                 tensor_ = t,
                 prev_norm = state['prev_norm'],
-                add =
-                mul =
-                min_value =
-                max_decay =
-                ord =
+                add = setting['add'],
+                mul = setting['mul'],
+                min_value = setting['min_value'],
+                max_decay = setting['max_decay'],
+                ord = setting['ord'],
             )

         if not parameterwise:
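Analogously, a minimal sketch of the per-tensor variant these clipping modules migrated to. The `ClampAbs` class and its `limit` setting are hypothetical; the sketch only illustrates the `apply_tensor` signature, where the per-parameter `state` and `settings` are supplied as arguments rather than fetched via `self.state[param]`.

# Hypothetical sketch against the new TensorwiseTransform API;
# `ClampAbs` and its `limit` setting are illustrative, not part of torchzero.
import torch
from torchzero.core.transform import TensorwiseTransform

class ClampAbs(TensorwiseTransform):
    """Clamps each update tensor to [-limit, limit]."""
    def __init__(self, limit: float = 1.0):
        super().__init__(defaults=dict(limit=limit), uses_grad=False)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, settings):
        # state and settings belong to this parameter only, mirroring the
        # ClipValueGrowth.apply_tensor change in the diff above.
        return tensor.clamp_(-settings['limit'], settings['limit'])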