torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/reformulation.py
ADDED
@@ -0,0 +1,65 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+
+import torch
+
+from .module import Chainable, Modular, Module, Var
+
+
+class Reformulation(Module, ABC):
+    def __init__(self, defaults: dict | None, modules: Chainable | None):
+        super().__init__(defaults)
+
+        if modules is not None:
+            self.set_child("modules", modules)
+
+    @abstractmethod
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+        """
+        returns (loss, gradient), if backward is False then gradient can be None.
+
+        If evaluating original loss/gradient at x_0, set them to ``var``.
+        """
+
+    def pre_step(self, var: Var) -> Var | None:
+        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+
+    def step(self, var):
+        ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
+        if isinstance(ret, Var): var = ret
+
+        if var.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = var.params, var.closure
+
+        # step with children
+        if 'modules' in self.children:
+
+            # make a reformulated closure
+            def modified_closure(backward=True):
+                loss, grad = self.closure(backward, closure, params, var)
+
+                if grad is not None:
+                    for p,g in zip(params, grad):
+                        p.grad = g
+
+                return loss
+
+            # set it to a new Var object
+            modified_var = var.clone(clone_update=False)
+            modified_var.closure = modified_closure
+
+            # step with child
+            modules = self.children['modules']
+            modified_var = modules.step(modified_var)
+
+            # modified_var.loss and grad refers to loss and grad of a modified objective
+            # so we only take the update
+            var.update = modified_var.update
+
+        # or just evaluate new closure and set to update
+        else:
+            loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
+            if grad is not None: var.update = list(grad)
+
+        return var
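Note: the following is a minimal, hypothetical sketch of how the new Reformulation base class above might be subclassed, to make the closure/pre_step contract concrete. It assumes the Module/Var API shown in this diff, that closures follow the closure(backward=True) convention used above, and that the class is importable from torchzero.core.reformulation; the L2Penalty name and its parameters are illustrative, not part of the package.

import torch
from torchzero.core.reformulation import Reformulation  # path added in this diff (assumed importable)

class L2Penalty(Reformulation):
    """Hypothetical example: reformulates f(x) into f(x) + 0.5 * alpha * ||x||^2."""

    def __init__(self, alpha: float = 1e-2, modules=None):
        super().__init__(defaults=dict(alpha=alpha), modules=modules)

    def closure(self, backward, closure, params, var):
        alpha = self.defaults["alpha"]
        # penalty value, computed without building an autograd graph
        penalty = sum(p.detach().pow(2).sum() for p in params)

        if backward:
            loss = closure()  # original closure with backward populates p.grad
            grad = [p.grad.clone() if p.grad is not None else torch.zeros_like(p) for p in params]
            for g, p in zip(grad, params):
                g.add_(p.detach(), alpha=alpha)  # gradient of 0.5 * alpha * ||p||^2 is alpha * p
            return loss + 0.5 * alpha * penalty, grad

        loss = closure(False)  # loss only, no gradient requested
        return loss + 0.5 * alpha * penalty, None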
torchzero/core/transform.py
CHANGED
@@ -1,18 +1,36 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable,
+from collections.abc import Iterable, Mapping, Sequence
 from typing import Any, Literal, final
 
 import torch
 
-from ..utils import
-from .module import
+from ..utils import TensorList, set_storage_, vec_to_tensors
+from .module import Chain, Chainable, Module, Var
 
 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
 
+
 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform.
+    This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.
 
     A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
+    It has two methods:
+
+    - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
+      It may be called multiple times before ``apply_tensors``.
+    - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+
+    Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+
+    Transform can be applied to tensors corresponding to custom parameters
+    by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
+    parameters will be keys to store per-parameter states, so they should remain the same python objects.
+
+    Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
+    ``transform_update`` and ``transform_apply``.
+
+    A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.
 
     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -21,63 +39,63 @@ class Transform(Module, ABC):
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
             what to set on var. Defaults to 'update'.
+
     """
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
-        scale_first: bool = False,
         inner: Chainable | None = None,
         target: Target = 'update',
     ):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._uses_loss = uses_loss
         self._concat_params = concat_params
         self._update_freq = update_freq
-        self._scale_first = scale_first
         self._inner = inner
+        self._var = None
 
-    def
+    def update_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> None:
-        """
+        """update function, this shouldn't be called directly. Updates this module."""
 
     @abstractmethod
-    def
+    def apply_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> Sequence[torch.Tensor]:
-        """Applies the update rule to `tensors
+        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+        If possible, this shouldn't modify the internal state of this transform."""
 
     @final
     @torch.no_grad
-    def
+    def transform_update(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]] | None,
-    ) ->
-        """
-        un_tensors = tensors
-        un_params = params
-        un_grads = grads
+    ) -> None:
+        """Updates this transform from an arbitrary sequence of tensors."""
         if self._concat_params:
             tensors = [torch.cat([t.ravel() for t in tensors])]
             params = [torch.cat([p.ravel() for p in params])]
@@ -86,53 +104,67 @@ class Transform(Module, ABC):
         if settings is None:
             settings = [self.defaults for _ in tensors]
 
-        step = self.global_state.get('__step', 0)
+        step = self.global_state.get('__step', 0) # that way it gets reset correctly
+        self.global_state['__step'] = step + 1
+
         num = len(tensors)
         states = states[:num]
         settings = settings[:num]
 
-
-
-
+        # update transform
+        if step % self._update_freq == 0:
+            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
 
-        #
-
-
-
-            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+        # store for transform_apply
+        self.global_state["__tensors"] = tensors
+        self.global_state["__params"] = params
+        self.global_state["__grads"] = grads
 
-
-
-
+
+    @final
+    @torch.no_grad
+    def transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors.
+        This can be used after ``transform_update`` has been used at least once."""
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+
+        tensors = self.global_state.pop("__tensors")
+        params = self.global_state.pop("__params")
+        grads = self.global_state.pop("__grads")
 
         # step with inner
         if self._inner is not None:
-            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
         if self._concat_params:
             tensors = [torch.cat([t.ravel() for t in tensors])]
 
         # apply transform
-        tensors = list(self.
-
-        # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
-            torch._foreach_mul_(tensors, scale_factor)
+        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
 
-        self.global_state['__step'] = step + 1
         if self._concat_params:
             tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
-        return tensors
 
+        return tensors
 
-
-    def keyed_transform(
-        self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
-    ):
-        """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
         if self._concat_params:
             p = params[0]
             states = [self.state[p]]
@@ -145,42 +177,128 @@ class Transform(Module, ABC):
             states.append(self.state[p])
             settings.append(self.settings[p])
 
-        return
+        return states, settings
+
+    @final
+    @torch.no_grad
+    def keyed_transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    @final
+    @torch.no_grad
+    def keyed_transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    def pre_step(self, var: Var) -> None:
+        """Logic to run pre-transform, this way transform has access to Var."""
+    def post_step(self, var: Var) -> None:
+        """Logic to run post-transform, this way transform has access to Var."""
+
+    def update(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
 
-    def step(self, var: Var) -> Var:
         # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
         if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
         params=var.params
+        self.pre_step(var)
+
+        # update
+        self._var = var
+        self.keyed_transform_update(update, params, var.grad, var.loss)
+        self._var = None
+
+    def apply(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+
+        # apply
+        self._var = var
+        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self._var = None
+
+        self.post_step(var)
+        return var
+
+    def step(self, var: Var) -> Var:
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+        if self._uses_grad or self._target == 'grad': var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)
+        self._var = var
 
         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
             update = var.get_update()
-
+            self.keyed_transform_update(update, params, var.grad, var.loss)
+            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+            self._var = None
             return var
 
         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
             grad = var.get_grad()
-
+            self.keyed_transform_update(grad, params, grad, var.loss)
+            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+            self._var = None
            return var
 
         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-
+            self.keyed_transform_update(var.params, params, var.grad, var.loss)
+            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
             for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            self._var = None
             return var
 
         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-
+            p_clone = [p.clone() for p in var.params]
+            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(var.params, new_params))
+            self._var = None
             return var
 
         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
             update = var.get_update()
-
+            u_clone = [u.clone() for u in update]
+            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(update, new_update))
+            self._var = None
            return var
 
         # ---------------------------------- closure --------------------------------- #
@@ -189,11 +307,17 @@ class Transform(Module, ABC):
         if original_closure is None: raise ValueError('Target = "closure", but closure is None')
 
         params = var.params
+        parent_var = self._var
         def transformed_closure(backward=True):
             if backward:
                 loss = original_closure()
                 current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
+
+                self._var = parent_var
+                self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
+                self._var = None
+
                 for p, g in zip(params, transformed_grad):
                     p.grad = g
 
@@ -203,6 +327,8 @@ class Transform(Module, ABC):
                 return loss
 
         var.closure = transformed_closure
+        self.post_step(var)
+        self._var = None
         return var
 
         # ---------------------------------- invalid --------------------------------- #
@@ -212,7 +338,7 @@
 class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.
 
-    This is an abstract class, to use it, subclass it and override `
+    This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.
 
     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -225,10 +351,10 @@ class TensorwiseTransform(Transform, ABC):
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
-        scale_first: bool = False,
         inner: Chainable | None = None,
         target: Target = 'update',
     ):
@@ -237,7 +363,7 @@ class TensorwiseTransform(Transform, ABC):
             uses_grad=uses_grad,
             concat_params=concat_params,
             update_freq=update_freq,
-
+            uses_loss=uses_loss,
             inner=inner,
             target=target,
         )
@@ -247,9 +373,9 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-
+        setting: Mapping[str, Any],
     ) -> None:
         """Updates this transform. By default does nothing - if logic is in `apply` method."""
 
@@ -259,20 +385,20 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """Applies the update rule to `tensor`."""
 
     @final
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
             self.update_tensor(t, p, g, loss, state, setting)
 
     @final
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         applied = []
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
@@ -284,7 +410,7 @@ def apply_transform(
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-    loss: torch.Tensor | None = None,
+    loss: torch.Tensor | float | None = None,
    var: Var | None = None,
    current_step: int = 0,
 ):
@@ -292,9 +418,10 @@ def apply_transform(
         var = Var(params=params, closure=None, model=None, current_step=current_step)
         var.loss = loss
 
-    if isinstance(tfm, Transform):
+    if isinstance(tfm, Transform) and tfm._target == 'update':
         if tfm._uses_grad and grads is None: grads = var.get_grad()
-
+        tfm.keyed_transform_update(tensors, params, grads, loss)
+        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))
 
     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
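Note: as a reading aid for the new update/apply split above, here is a minimal, hypothetical Transform subclass following the update_tensors/apply_tensors signatures and the keyed helpers shown in this diff. EmaSign is illustrative only; it assumes Transform is importable from torchzero.core.transform and that per-parameter state and settings behave as the code above suggests.

import torch
from torchzero.core.transform import Transform  # module shown in this diff (assumed importable)

class EmaSign(Transform):
    """Hypothetical example: tracks an exponential moving average of the incoming
    tensors in update_tensors and returns its sign in apply_tensors."""

    def __init__(self, beta: float = 0.9):
        super().__init__(defaults=dict(beta=beta))

    def update_tensors(self, tensors, params, grads, loss, states, settings):
        # stateful half: may be called several times before apply_tensors
        for t, state, setting in zip(tensors, states, settings):
            ema = state.setdefault("ema", torch.zeros_like(t))
            ema.mul_(setting["beta"]).add_(t, alpha=1 - setting["beta"])

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # stateless half: read the stored EMA and return the transformed tensors
        return [state["ema"].sign() for state in states]

# applied manually to arbitrary tensors via the keyed helpers added above,
# with `params` acting as state keys (assumed usage, following the docstring)
t = EmaSign()
params = [torch.zeros(3)]
grads = [torch.randn(3)]
t.keyed_transform_update(grads, params, grads, loss=None)
out = t.keyed_transform_apply(grads, params, grads, loss=None)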
torchzero/modules/__init__.py
CHANGED
@@ -1,14 +1,23 @@
+from . import experimental
 from .clipping import *
+from .conjugate_gradient import *
 from .grad_approximation import *
+from .higher_order import *
+from .least_squares import *
 from .line_search import *
-from .
+from .misc import *
 from .momentum import *
 from .ops import *
-from .
+from .adaptive import *
 from .projections import *
 from .quasi_newton import *
+from .second_order import *
 from .smoothing import *
+from .step_size import *
+from .termination import *
+from .trust_region import *
+from .variance_reduction import *
 from .weight_decay import *
 from .wrappers import *
-from .
-from .
+from .restarts import *
+from .zeroth_order import *
torchzero/modules/adaptive/__init__.py
ADDED
@@ -0,0 +1,30 @@
+from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm
+
+# from .curveball import CurveBall
+# from .spectral import SpectralPreconditioner
+from .adahessian import AdaHessian
+from .adam import Adam
+from .adan import Adan
+from .adaptive_heavyball import AdaptiveHeavyBall
+from .aegd import AEGD
+from .esgd import ESGD
+from .lmadagrad import LMAdagrad
+from .lion import Lion
+from .mars import MARSCorrection
+from .matrix_momentum import MatrixMomentum
+from .msam import MSAM, MSAMObjective
+from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .natural_gradient import NaturalGradient
+from .orthograd import OrthoGrad, orthograd_
+from .rmsprop import RMSprop
+from .rprop import (
+    BacktrackOnSignChange,
+    Rprop,
+    ScaleLRBySignChange,
+    SignConsistencyLRs,
+    SignConsistencyMask,
+)
+from .sam import ASAM, SAM
+from .shampoo import Shampoo
+from .soap import SOAP
+from .sophia_h import SophiaH