torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
CHANGED

@@ -1,18 +1,18 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Sequence
-from typing import Any, Literal
+from collections.abc import Iterable, Sequence, Mapping
+from typing import Any, Literal, final

 import torch

-from ..utils import set_storage_
-from .module import Module,
+from ..utils import set_storage_, TensorList, vec_to_tensors
+from .module import Module, Var, Chain, Chainable

 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.

-
+    A transform is a module that can also be applied manually to an arbitrary sequence of tensors.

     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -20,62 +20,283 @@ class Transform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._uses_loss = uses_loss
+        self._concat_params = concat_params
+        self._update_freq = update_freq
+        self._scale_first = scale_first
+        self._inner = inner
+
+    def update_tensors(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> None:
+        """update function, this shouldn't be called directly. Updates this module."""

     @abstractmethod
-    def
-
+    def apply_tensors(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> Sequence[torch.Tensor]:
+        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+        If possible, this shouldn't modify the internal state of this transform."""
+
+    @final
+    @torch.no_grad
+    def transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> None:
+        """Updates this transform from an arbitrary sequence of tensors."""
+        if self._concat_params:
+            tensors = [torch.cat([t.ravel() for t in tensors])]
+            params = [torch.cat([p.ravel() for p in params])]
+            grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        step = self.global_state.get('__step', 0) # that way it gets reset correctly
+        self.global_state['__step'] = step + 1
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        scale_factor = 1
+
+        # scaling factor for 1st step
+        if self._scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+
+        # update transform
+        if step % self._update_freq == 0:
+            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # store for transform_apply
+        self.global_state["__tensors"] = tensors
+        self.global_state["__params"] = params
+        self.global_state["__grads"] = grads
+        self.global_state["__scale_factor"] = scale_factor
+
+
+    @final
+    @torch.no_grad
+    def transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors.
+        This can be used after ``transform_update`` has been used at least once."""
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+
+        tensors = self.global_state.pop("__tensors")
+        params = self.global_state.pop("__params")
+        grads = self.global_state.pop("__grads")
+        scale_factor = self.global_state.pop("__scale_factor")
+
+        # step with inner
+        if self._inner is not None:
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            if self._concat_params:
+                tensors = [torch.cat([t.ravel() for t in tensors])]
+
+        # apply transform
+        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+
+        # scale initial step, when preconditioner might not have been applied
+        if self._scale_first and self.global_state['__step'] == 1:
+            torch._foreach_mul_(tensors, scale_factor)
+
+        if self._concat_params:
+            tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+        return tensors
+
+    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
+        if self._concat_params:
+            p = params[0]
+            states = [self.state[p]]
+            settings = [self.settings[p]]

-
-
-
-
+        else:
+            states = []
+            settings = []
+            for p in params:
+                states.append(self.state[p])
+                settings.append(self.settings[p])
+
+        return states, settings
+
+    @final
+    @torch.no_grad
+    def keyed_transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    @final
+    @torch.no_grad
+    def keyed_transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    def pre_step(self, var: Var) -> None:
+        """Logic to run pre-transform, this way transform has access to Var."""
+    def post_step(self, var: Var) -> None:
+        """Logic to run post-transform, this way transform has access to Var."""
+
+    def update(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)
+
+        # update
+        self.keyed_transform_update(update, params, var.grad, var.loss)
+
+    def apply(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+
+        # apply
+        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self.post_step(var)
+        return var
+
+    def step(self, var: Var) -> Var:
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+        if self._uses_grad or self._target == 'grad': var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
-
-
+            update = var.get_update()
+            self.keyed_transform_update(update, params, var.grad, var.loss)
+            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+            return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
-
-
+            grad = var.get_grad()
+            self.keyed_transform_update(grad, params, grad, var.loss)
+            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+            return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-
-
-
+            self.keyed_transform_update(var.params, params, var.grad, var.loss)
+            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
+            for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-
-
-
+            p_clone = [p.clone() for p in var.params]
+            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(var.params, new_params))
+            return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
-            update =
-
-
-
+            update = var.get_update()
+            u_clone = [u.clone() for u in update]
+            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(update, new_update))
+            return var

         # ---------------------------------- closure --------------------------------- #
         if self._target == 'closure':
-            original_closure =
+            original_closure = var.closure
             if original_closure is None: raise ValueError('Target = "closure", but closure is None')

-            params =
+            params = var.params
             def transformed_closure(backward=True):
                 if backward:
                     loss = original_closure()
                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
+                    self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                    transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
                     for p, g in zip(params, transformed_grad):
                         p.grad = g

@@ -84,14 +305,15 @@ class Transform(Module, ABC):

                 return loss

-
-
+            var.closure = transformed_closure
+            self.post_step(var)
+            return var

         # ---------------------------------- invalid --------------------------------- #
         raise ValueError(f'Invalid target: {self._target}')


-class TensorwiseTransform(
+class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.

     This is an abstract class, to use it, subclass it and override `transform`.
@@ -102,151 +324,97 @@ class TensorwiseTransform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(
-
-
-
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
+        super().__init__(
+            defaults=defaults,
+            uses_grad=uses_grad,
+            concat_params=concat_params,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            uses_loss=uses_loss,
+            inner=inner,
+            target=target,
+        )
+
+    def update_tensor(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def
+    def apply_tensor(
         self,
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-
+        loss: torch.Tensor | float | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
-        """
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-            vars.update = transformed_update
-            return vars
-
-        # ----------------------------------- grad ----------------------------------- #
-        if self._target == 'grad':
-            grad = vars.get_grad()
-            transformed_grad = []
-
-            for p, g in zip(params, grad):
-                transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-            vars.grad = transformed_grad
-            return vars
-
-        # ------------------------------- params_direct ------------------------------ #
-        if self._target == 'params_direct':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-
-            for p, g in zip(params, grad):
-                set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
-
-            return vars
-
-        # ----------------------------- params_difference ---------------------------- #
-        if self._target == 'params_difference':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_params = []
-
-            for p, g in zip(params, grad):
-                transformed_params.append(
-                    self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(params, transformed_params))
-            return vars
-
-        # ----------------------------- update_difference ---------------------------- #
-        if self._target == 'update_difference':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                transformed_update.append(
-                    self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(update, transformed_update))
-            return vars
-
-        # ---------------------------------- closure --------------------------------- #
-        if self._target == 'closure':
-            original_closure = vars.closure
-            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-            params = vars.params
-            def transformed_closure(backward=True):
-                if backward:
-                    loss = original_closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = []
-
-                    for p, g in zip(params, grad):
-                        transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-                    for p, g in zip(params, transformed_grad):
-                        p.grad = g
-
-                else:
-                    loss = original_closure(False)
-
-                return loss
-
-            vars.closure = transformed_closure
-            return vars
-
-        # ---------------------------------- invalid --------------------------------- #
-        raise ValueError(f'Invalid target: {self._target}')
-
-
-
-def apply(
+        """Applies the update rule to `tensor`."""
+
+    @final
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            self.update_tensor(t, p, g, loss, state, setting)
+
+    @final
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        applied = []
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            applied.append(self.apply_tensor(t, p, g, loss, state, setting))
+        return applied
+
+def apply_transform(
     tfm: Chainable,
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-
+    loss: torch.Tensor | float | None = None,
+    var: Var | None = None,
     current_step: int = 0,
 ):
-    if
-
-
-
-
-
-
-
-    if tfm._uses_grad: grads_list = vars.get_grad()
-    else: grads_list = [None] * len(tensors)
-    return [tfm.transform(t, p, g, vars) for t,p,g in zip(tensors,params,grads_list)]
+    if var is None:
+        var = Var(params=params, closure=None, model=None, current_step=current_step)
+        var.loss = loss
+
+    if isinstance(tfm, Transform) and tfm._target == 'update':
+        if tfm._uses_grad and grads is None: grads = var.get_grad()
+        tfm.keyed_transform_update(tensors, params, grads, loss)
+        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
         for module in tfm:
-            tensors =
+            tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
         return tensors

     if isinstance(tfm, Module):
-
-
-
-
-        assert
-        return
+        cvar = var.clone(clone_update=False)
+        cvar.update = tensors
+        cvar = tfm.step(cvar)
+        var.update_attrs_from_clone_(cvar)
+        assert cvar.update is not None
+        return cvar.update

     raise TypeError(type(tfm))
torchzero/modules/__init__.py
CHANGED

@@ -1,7 +1,7 @@
 from .clipping import *
 from .grad_approximation import *
 from .line_search import *
-from .
+from .step_size import *
 from .momentum import *
 from .ops import *
 from .optimizers import *
@@ -11,3 +11,5 @@ from .smoothing import *
 from .weight_decay import *
 from .wrappers import *
 from .second_order import *
+from .higher_order import *
+from .misc import *