torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/core/preconditioner.py
ADDED

@@ -0,0 +1,137 @@
+from abc import ABC, abstractmethod
+from collections import ChainMap, defaultdict
+from collections.abc import Mapping, Sequence
+from typing import Any, overload, final
+
+import torch
+
+from .module import Module, Chainable, Vars
+from .transform import apply, Transform, Target
+from ..utils import TensorList, vec_to_tensors
+
+class Preconditioner(Transform):
+    """Abstract class for a preconditioner."""
+    def __init__(
+        self,
+        defaults: dict | None,
+        uses_grad: bool,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = "update",
+    ):
+        if defaults is None: defaults = {}
+        defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
+        super().__init__(defaults, uses_grad=uses_grad, target=target)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @abstractmethod
+    def update(self, tensors: list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
+        """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
+
+    @abstractmethod
+    def apply(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
+        """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
+
+
+    def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
+        step = self.global_state.get('step', 0)
+        states = [self.state[p] for p in params]
+        settings = [self.settings[p] for p in params]
+        global_settings = settings[0]
+        update_freq = global_settings['__update_freq']
+
+        scale_first = global_settings['__scale_first']
+        scale_factor = 0
+        if scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = TensorList(tensors).abs().sum()
+
+        # update preconditioner
+        if step % update_freq == 0:
+            self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
+
+        # step with inner
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
+
+        # apply preconditioner
+        tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
+
+        # scale initial step, when preconditioner might not have been applied
+        if scale_first and step == 0:
+            torch._foreach_div_(tensors, scale_factor)
+
+        self.global_state['step'] = step + 1
+        return tensors
+
+    def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
+        step = self.global_state.get('step', 0)
+        tensors_vec = torch.cat([t.ravel() for t in tensors])
+        params_vec = torch.cat([p.ravel() for p in params])
+        grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+        states = [self.state[params[0]]]
+        settings = [self.settings[params[0]]]
+        global_settings = settings[0]
+        update_freq = global_settings['__update_freq']
+
+        scale_first = global_settings['__scale_first']
+        scale_factor = 0
+        if scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = tensors_vec.abs().sum()
+
+        # update preconditioner
+        if step % update_freq == 0:
+            self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
+
+        # step with inner
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
+            tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
+
+        # apply preconditioner
+        tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
+
+        # scale initial step, when preconditioner might not have been applied
+        if scale_first and step == 0:
+            if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
+                tensors_vec /= scale_factor
+
+        tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
+        self.global_state['step'] = step + 1
+        return tensors
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        concat_params = self.settings[params[0]]['__concat_params']
+        if concat_params: return self._concat_transform(tensors, params, grads, vars)
+        return self._tensor_wise_transform(tensors, params, grads, vars)
+
+class TensorwisePreconditioner(Preconditioner, ABC):
+    @abstractmethod
+    def update_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
+        """update preconditioner with `tensor`"""
+
+    @abstractmethod
+    def apply_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+        """apply preconditioner to `tensor`"""
+
+    @final
+    def update(self, tensors, params, grads, states, settings):
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            self.update_tensor(t, p, g, state, setting)
+
+    @final
+    def apply(self, tensors, params, grads, states, settings):
+        preconditioned = []
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            preconditioned.append(self.apply_tensor(t, p, g, state, setting))
+        return preconditioned
+
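To give a sense of how the new preconditioner API is meant to be subclassed, here is a minimal sketch (not part of the package) of a diagonal, Adagrad-style preconditioner built on TensorwisePreconditioner. The class name DiagonalPreconditioner, the eps setting, and the assumption that per-parameter settings expose the constructor defaults are illustrative only; the import path is taken from the file list above.

import torch
from torchzero.core.preconditioner import TensorwisePreconditioner  # path assumed from the file list

class DiagonalPreconditioner(TensorwisePreconditioner):
    """Hypothetical example: divides each update by the root of its accumulated squares."""
    def __init__(self, eps: float = 1e-8, update_freq: int = 1):
        # `defaults` are assumed to end up in per-parameter `settings`;
        # `update_freq` is forwarded to the Preconditioner base class shown in the diff
        super().__init__(defaults=dict(eps=eps), uses_grad=False, update_freq=update_freq)

    def update_tensor(self, tensor, param, grad, state, settings):
        # accumulate squared entries in this parameter's state dict
        if 'accum' not in state:
            state['accum'] = torch.zeros_like(tensor)
        state['accum'].add_(tensor * tensor)

    def apply_tensor(self, tensor, param, grad, state, settings):
        # Adagrad-style diagonal preconditioning of the incoming tensor
        return tensor / (state['accum'].sqrt() + settings['eps'])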
torchzero/core/transform.py
ADDED

@@ -0,0 +1,252 @@
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from typing import Any, Literal
+
+import torch
+
+from ..utils import set_storage_
+from .module import Module, Vars, Chain, Chainable
+
+Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
+
+class Transform(Module, ABC):
+    """Base class for a transform.
+
+    This is an abstract class, to use it, subclass it and override `transform`.
+
+    Args:
+        defaults (dict[str,Any] | None): dict with default values.
+        uses_grad (bool):
+            Set this to True if `transform` method uses the `grad` argument. This will ensure
+            `grad` is always computed and can't be None. Otherwise set to False.
+        target (Target, optional):
+            what to set on vars. Defaults to 'update'.
+    """
+    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+        super().__init__(defaults)
+        self._target: Target = target
+        self._uses_grad = uses_grad
+
+    @abstractmethod
+    def transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> Iterable[torch.Tensor]:
+        """applies the update rule to `target`."""
+
+    def step(self, vars: Vars) -> Vars:
+        # vars may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._uses_grad: vars.get_grad()
+        params=vars.params; grad = vars.grad
+
+        # ---------------------------------- update ---------------------------------- #
+        if self._target == 'update':
+            vars.update = list(self.transform(vars.get_update(), params, grad, vars))
+            return vars
+
+        # ----------------------------------- grad ----------------------------------- #
+        if self._target == 'grad':
+            vars.grad = list(self.transform(vars.get_grad(), params, grad, vars))
+            return vars
+
+        # ------------------------------- params_direct ------------------------------ #
+        if self._target == 'params_direct':
+            new_params = self.transform(vars.params, params, grad, vars)
+            for p, new_p in zip(vars.params, new_params): set_storage_(p, new_p)
+            return vars
+
+        # ----------------------------- params_differnce ----------------------------- #
+        if self._target == 'params_difference':
+            new_params = tuple(self.transform([p.clone() for p in vars.params], params, grad, vars))
+            vars.update = list(torch._foreach_sub(vars.params, new_params))
+            return vars
+
+        # ----------------------------- update_difference ---------------------------- #
+        if self._target == 'update_difference':
+            update = vars.get_update()
+            new_update = tuple(self.transform([u.clone() for u in update], params, grad, vars))
+            vars.update = list(torch._foreach_sub(update, new_update))
+            return vars
+
+        # ---------------------------------- closure --------------------------------- #
+        if self._target == 'closure':
+            original_closure = vars.closure
+            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
+
+            params = vars.params
+            def transformed_closure(backward=True):
+                if backward:
+                    loss = original_closure()
+                    current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                    transformed_grad = list(self.transform(current_grad, params, grad, vars))
+                    for p, g in zip(params, transformed_grad):
+                        p.grad = g
+
+                else:
+                    loss = original_closure(False)
+
+                return loss
+
+            vars.closure = transformed_closure
+            return vars
+
+        # ---------------------------------- invalid --------------------------------- #
+        raise ValueError(f'Invalid target: {self._target}')
+
+
+class TensorwiseTransform(Module, ABC):
+    """Base class for a parameter-wise transform.
+
+    This is an abstract class, to use it, subclass it and override `transform`.
+
+    Args:
+        defaults (dict[str,Any] | None): dict with default values.
+        uses_grad (bool):
+            Set this to True if `transform` method uses the `grad` argument. This will ensure
+            `grad` is always computed and can't be None. Otherwise set to False.
+        target (Target, optional):
+            what to set on vars. Defaults to 'update'.
+    """
+    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+        super().__init__(defaults)
+        self._target: Target = target
+        self._uses_grad: bool = uses_grad
+
+    @abstractmethod
+    def transform(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        vars: Vars,
+    ) -> torch.Tensor:
+        """applies the update rule to `target`"""
+
+    def step(self, vars: Vars) -> Vars:
+        params = vars.params
+        if self._uses_grad and vars.grad is None: vars.get_grad()
+
+        # ---------------------------------- update ---------------------------------- #
+        if self._target == 'update':
+            update = vars.get_update()
+            grad = vars.grad if vars.grad is not None else [None] * len(params)
+            transformed_update = []
+
+            for p, g, u in zip(params, grad, update):
+                # settings = self.settings[p] # couldn't make typing work with this
+                #, self.transform(target=u, param=p, grad=g, vars=vars, **{k:settings[k] for k in self.defaults})
+                transformed_update.append(self.transform(tensor=u, param=p, grad=g, vars=vars))
+
+            vars.update = transformed_update
+            return vars
+
+        # ----------------------------------- grad ----------------------------------- #
+        if self._target == 'grad':
+            grad = vars.get_grad()
+            transformed_grad = []
+
+            for p, g in zip(params, grad):
+                transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
+
+            vars.grad = transformed_grad
+            return vars
+
+        # ------------------------------- params_direct ------------------------------ #
+        if self._target == 'params_direct':
+            grad = vars.grad if vars.grad is not None else [None] * len(params)
+
+            for p, g in zip(params, grad):
+                set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
+
+            return vars
+
+        # ----------------------------- params_difference ---------------------------- #
+        if self._target == 'params_difference':
+            grad = vars.grad if vars.grad is not None else [None] * len(params)
+            transformed_params = []
+
+            for p, g in zip(params, grad):
+                transformed_params.append(
+                    self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
+                )
+
+            vars.update = list(torch._foreach_sub(params, transformed_params))
+            return vars
+
+        # ----------------------------- update_difference ---------------------------- #
+        if self._target == 'update_difference':
+            update = vars.get_update()
+            grad = vars.grad if vars.grad is not None else [None] * len(params)
+            transformed_update = []
+
+            for p, g, u in zip(params, grad, update):
+                transformed_update.append(
+                    self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
+                )
+
+            vars.update = list(torch._foreach_sub(update, transformed_update))
+            return vars
+
+        # ---------------------------------- closure --------------------------------- #
+        if self._target == 'closure':
+            original_closure = vars.closure
+            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
+
+            params = vars.params
+            def transformed_closure(backward=True):
+                if backward:
+                    loss = original_closure()
+                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                    transformed_grad = []
+
+                    for p, g in zip(params, grad):
+                        transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
+
+                    for p, g in zip(params, transformed_grad):
+                        p.grad = g
+
+                else:
+                    loss = original_closure(False)
+
+                return loss
+
+            vars.closure = transformed_closure
+            return vars
+
+        # ---------------------------------- invalid --------------------------------- #
+        raise ValueError(f'Invalid target: {self._target}')
+
+
+
+def apply(
+    tfm: Chainable,
+    tensors: list[torch.Tensor],
+    params: list[torch.Tensor],
+    grads: list[torch.Tensor] | None,
+    vars: Vars | None = None,
+    current_step: int = 0,
+):
+    if vars is None: vars = Vars(params=params, closure=None, model=None, current_step=current_step)
+    if isinstance(tfm, Transform):
+        if tfm._uses_grad and grads is None: grads = vars.get_grad()
+        return list(tfm.transform(tensors, params, grads, vars))
+
+    if isinstance(tfm, TensorwiseTransform):
+        grads_list = grads
+        if grads_list is None:
+            if tfm._uses_grad: grads_list = vars.get_grad()
+            else: grads_list = [None] * len(tensors)
+        return [tfm.transform(t, p, g, vars) for t,p,g in zip(tensors,params,grads_list)]
+
+    if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
+    if isinstance(tfm, Sequence):
+        for module in tfm:
+            tensors = apply(module, tensors=tensors, params=params, grads=grads, vars=vars)
+        return tensors
+
+    if isinstance(tfm, Module):
+        cvars = vars.clone(clone_update=False)
+        cvars.update = tensors
+        cvars = tfm.step(cvars)
+        vars.update_attrs_from_clone_(cvars)
+        assert cvars.update is not None
+        return cvars.update
+
+    raise TypeError(type(tfm))
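As a rough sketch of how the Transform contract above is meant to be used (again, not part of the diff), a hypothetical sign transform can subclass it and be driven either from a module chain via step or through the standalone apply helper; the class name Sign is mine and the import path is assumed from the file list.

import torch
from torchzero.core.transform import Transform, apply  # path assumed from the file list

class Sign(Transform):
    """Hypothetical example: replaces the targeted tensors with their elementwise sign."""
    def __init__(self):
        # no per-parameter settings, gradient not required, acts on the update
        super().__init__(defaults=None, uses_grad=False, target='update')

    def transform(self, tensors, params, grads, vars):
        # called by Transform.step (or by apply) with the tensors selected by `target`
        return [t.sign() for t in tensors]

# standalone use: `apply` builds a fresh Vars when none is passed
params = [torch.randn(3, requires_grad=True)]
update = [torch.randn(3)]
signed = apply(Sign(), tensors=update, params=params, grads=None)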
torchzero/modules/__init__.py
CHANGED

@@ -1,21 +1,13 @@
-
-
-
-
-from
-from . import
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .optimizers import *
-from .orthogonalization import *
-from .quasi_newton import *
-from .regularization import *
-from .scheduling import *
-from .second_order import *
-from .smoothing import *
-from .weight_averaging import *
+from .clipping import *
+from .grad_approximation import *
+from .line_search import *
+from .lr import *
+from .momentum import *
+from .ops import *
+from .optimizers import *
+from .projections import *
+from .quasi_newton import *
+from .smoothing import *
+from .weight_decay import *
+from .wrappers import *
+from .second_order import *