torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
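
Two structural changes dominate this release: the `torchzero/utils/linalg` helpers were promoted to a top-level `torchzero/linalg` package (see the `{utils/linalg → linalg}` entries above), and the monolithic `torchzero/optim/wrappers/scipy.py` was split into a `torchzero/optim/wrappers/scipy/` subpackage. For code that imported the moved linear-algebra modules directly, only the import path changes. A minimal sketch, assuming the moved modules keep their public names (the `Dense` operator is taken from the `linear_operator.py` entry above and from the old import removed in the diff below):

```python
# torchzero 0.3.15 -- linear-algebra helpers lived under torchzero.utils.linalg
# from torchzero.utils.linalg.linear_operator import Dense

# torchzero 0.4.0 -- the subpackage is now the top-level torchzero.linalg
from torchzero.linalg.linear_operator import Dense
```

The per-file diff reproduced below covers `torchzero/modules/adaptive/adagrad.py`, which was rewritten from the old `Transform` base class onto the new `TensorTransform` API.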
--- torchzero/modules/adaptive/adagrad.py (0.3.15)
+++ torchzero/modules/adaptive/adagrad.py (0.4.0)
@@ -1,62 +1,14 @@
-from operator import itemgetter
 from typing import Literal
-
 import torch
+
 from ...core import (
     Chainable,
-
-    Target,
-    TensorwiseTransform,
-    Transform,
-    Var,
-    apply_transform,
+    TensorTransform,
 )
-from ...utils import NumberList, TensorList, unpack_dicts
-from ...
-from ..functional import add_power_, lerp_power_, root, epsilon_step_size
-from ...utils.linalg.linear_operator import Dense
-
-def adagrad_(
-    tensors_: TensorList,
-    sq_sum_: TensorList,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    use_sqrt: bool = True,
-    divide: bool = False,
-
-    decay: float | None = None,
-    beta: float | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
-
-    if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-    else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
-    if decay is not None:
-        sq_sum_.mul_(1-decay)
-
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-    if divide: sq_sum_ = sq_sum_ / max(step, 1)
-
-    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+from ...utils import NumberList, TensorList, unpack_dicts
+from ...linalg.matrix_power import matrix_power as _matrix_power, MatrixPowerMethod

-
-
-
-
-class Adagrad(Transform):
+class Adagrad(TensorTransform):
     """Adagrad, divides by sum of past squares of gradients.

     This implementation is identical to ``torch.optim.Adagrad``.
@@ -72,103 +24,53 @@ class Adagrad(Transform):
     """
     def __init__(
         self,
+
+        # hyperparams
         lr_decay: float = 0,
         initial_accumulator_value: float = 0,
         eps: float = 1e-10,
         alpha: float = 1,
-
-
-        divide: bool=False,
-        beta:float | None = None,
-        decay: float | None = None,
+
+        # tfms
         inner: Chainable | None = None,
+        accumulator_tfm: Chainable | None = None
     ):
-        defaults =
-
-        super().__init__(defaults=defaults,
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["accumulator_tfm"]
+        super().__init__(defaults=defaults, inner=inner)

-
-        self.set_child('inner', inner)
+        self.set_child('accumulator', accumulator_tfm)

     @torch.no_grad
-    def
-
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        state["accumulator"] = torch.full_like(tensor, fill_value=setting["initial_accumulator_value"])

-
-
-
-
-        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-        # initialize accumulator on 1st step
-        if step == 1:
-            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-        return adagrad_(
-            tensors,
-            sq_sum_=sq_sum,
-            alpha=alpha,
-            lr_decay=lr_decay,
-            eps=eps,
-            step=step,
-            pow=pow,
-            use_sqrt=use_sqrt,
-            divide=divide,
-
-            beta = self.defaults["beta"],
-            decay = self.defaults["decay"],
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )
-
-
-def lerp(start, end, weight):
-    return start + weight * (end - start)
-
-def adagrad_norm_(
-    tensors_: TensorList,
-    accumulator: float | torch.Tensor,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    use_sqrt: bool = True,
-    divide: bool = False,
-
-    decay: float | None = None,
-    beta: float | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        torch._foreach_addcmul_([state["accumulator"] for state in states], tensors, tensors)
+        self.increment_counter("step", start=0)

-
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors_ = TensorList(tensors)
+        step = self.global_state["step"] # 0 on first apply
+        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)

-
-
+        accumulator = [state["accumulator"] for state in states]
+        accumulator = TensorList(self.inner_step_tensors(
+            "accumulator", tensors=accumulator, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

-
-
+        denom = accumulator.sqrt().add_(eps)
+        tensors_ /= denom

-
-
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+        clr = alpha / (1 + step * lr_decay)
+        tensors_.lazy_mul_(clr)

-
+        return tensors_

-    if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
-    else: tensors_.div_(eps + accumulator).mul_(clr)

-    return tensors_, accumulator

-class AdagradNorm(
+class AdagradNorm(TensorTransform):
     """Adagrad-Norm, divides by sum of past means of squares of gradients.

     Args:
@@ -176,7 +78,6 @@ class AdagradNorm(Transform):
         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
         eps (float, optional): division epsilon. Defaults to 1e-10.
         alpha (float, optional): step size. Defaults to 1.
-        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
     """
@@ -185,71 +86,104 @@ class AdagradNorm(Transform):
         lr_decay: float = 0,
         initial_accumulator_value: float = 0,
         eps: float = 1e-10,
-        alpha: float = 1,
-        pow: float = 2,
-        use_sqrt: bool = True,
-        divide: bool=False,
         beta:float | None = None,
-
+        beta_debias: bool = True,
+        layerwise: bool = True,
+        use_sqrt: bool = True,
+        alpha: float = 1,
         inner: Chainable | None = None,
     ):
-        defaults =
-
-        super().__init__(defaults=defaults,
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults=defaults, inner=inner)
+
+    @torch.no_grad
+    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
+
+        # layerwise initialize in each state
+        if settings[0]["layerwise"]:
+            for tensor, state, setting in zip(tensors, states, settings):
+
+                initial_accumulator_value = setting["initial_accumulator_value"]
+                state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)
+
+        # global initialize in global state
+        else:
+            initial_accumulator_value = settings[0]["initial_accumulator_value"]
+            tensor = tensors[0]
+            self.global_state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)
+
+    def _get_accumulator(self, states, settings) -> torch.Tensor | TensorList:
+        layerwise = settings[0]["layerwise"]
+        if layerwise:
+            return TensorList(s["accumulator"] for s in states)
+
+        return self.global_state["accumulator"]

-
-
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        accumulator = self._get_accumulator(states, settings)
+        self.increment_counter("step", start=0)
+
+        # compute squared gradient norm (gg)
+        if isinstance(accumulator, TensorList): gg = tensors.tensorwise_dot(tensors)
+        else: gg = tensors.dot(tensors)
+
+        # update the accumulator
+        beta = settings[0]["beta"]
+        if beta is None: accumulator.add_(gg) # pyright:ignore[reportArgumentType]
+        else: accumulator.lerp_(gg, weight=1-beta) # pyright:ignore[reportArgumentType, reportCallIssue]

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-
-
+        accumulator = self._get_accumulator(states, settings)
+        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)
+        step = self.global_state["step"] # 0 on 1st step
+        fs = settings[0]
+        beta = fs["beta"]

-
+        # ------------------------ debias if beta is not None ------------------------ #
+        if fs["beta_debias"] and beta is not None:
+            accumulator = accumulator / (1 - beta ** (step + 1))

-        accumulator = self.global_state.get("accumulator", initial_accumulator_value)

-
-
-
-
-
-            eps=eps,
-            step=step,
-            use_sqrt=use_sqrt,
-            divide=divide,
+        # ---------------------------- compute denominator --------------------------- #
+        if fs["use_sqrt"]:
+            denom = accumulator.sqrt().add_(eps) # pyright:ignore[reportArgumentType]
+        else:
+            denom = accumulator + eps # pyright:ignore[reportOperatorIssue]

-            beta = self.defaults["beta"],
-            decay = self.defaults["decay"],
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )

-
+        # ---------------------------- compute the update ---------------------------- #
+        tensors /= denom
+        clr = alpha / (1 + step * lr_decay) # lr decay
+        tensors.lazy_mul_(clr)

+        return tensors

-
+
+
+class FullMatrixAdagrad(TensorTransform):
     """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

     Note:
         A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.

     Args:
-
-        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
-        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
-        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
+        reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
         precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
+        beta_debias (bool, optional): whether to use debiasing, only has effect when ``beta`` is not ``None``. Defaults to True.
         init (Literal[str], optional):
             how to initialize the accumulator.
             - "identity" - with identity matrix (default).
             - "zeros" - with zero matrix.
             - "ones" - with matrix of ones.
             - "GGT" - with the first outer product
-
+        matrix_power (float, optional): accumulator matrix power. Defaults to -1/2.
+        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
         inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.

     ## Examples:
@@ -284,73 +218,89 @@ class FullMatrixAdagrad(TensorwiseTransform):
     """
     def __init__(
         self,
+        reg: float = 1e-12,
+        precond_freq: int = 1,
         beta: float | None = None,
-
-
+        beta_debias: bool=True,
+        init: Literal["identity", "zeros", "GGT"] = "identity",
+        matrix_power: float = -1/2,
+        matrix_power_method: MatrixPowerMethod = "eigh_abs",
         concat_params=True,
-
-        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
-        reg: float = 1e-12,
-        divide: bool = False,
+
         inner: Chainable | None = None,
+        accumulator_tfm: Chainable | None = None
     ):
-        defaults =
-
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["concat_params"], defaults["accumulator_tfm"]
+        super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)
+
+        self.set_child("accumulator", accumulator_tfm)

     @torch.no_grad
-    def
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
         G = tensor.ravel()
-        GG = torch.outer(G, G)
-
+        GGᵀ = torch.outer(G, G)
+
+        # initialize
+        if "accumulator" not in state:
+            init = setting['init']
+            if init == 'identity': state['accumulator'] = torch.eye(GGᵀ.size(0), device=GGᵀ.device, dtype=GGᵀ.dtype)
+            elif init == 'zeros': state['accumulator'] = torch.zeros_like(GGᵀ)
+            elif init == 'GGT': state['accumulator'] = GGᵀ.clone()
+            else: raise ValueError(init)
+
+        # update
         beta = setting['beta']
-
+        accumulator: torch.Tensor = state["accumulator"]

-        if
-
-            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-            elif init == 'ones': state['GG'] = torch.ones_like(GG)
-            elif init == 'GGT': state['GG'] = GG.clone()
-            else: raise ValueError(init)
-        if decay is not None: state['GG'].mul_(decay)
+        if beta is None: accumulator.add_(GGᵀ)
+        else: accumulator.lerp_(GGᵀ, 1-beta)

-
-
-        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+        # update number of GGᵀ in accumulator for divide
+        state['num_GGTs'] = state.get('num_GGTs', 0) + 1

     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         step = state.get('step', 0)
         state['step'] = step + 1

-
-
-
+        accumulator: torch.Tensor = state['accumulator']
+        accumulator = self.inner_step_tensors("accumulator", [accumulator], clone=True, must_exist=False)[0]
+
         precond_freq = setting['precond_freq']
         reg = setting['reg']
+        beta = setting["beta"]

-
-
+        # add regularizer
         if reg != 0:
-
+            device = accumulator.device; dtype = accumulator.dtype
+            accumulator = accumulator + torch.eye(accumulator.size(0), device=device, dtype=dtype).mul_(reg)

+        # for single value use sqrt
         if tensor.numel() == 1:
-
-            if sqrt: return tensor / GG.sqrt()
-            return tensor / GG
+            dir = tensor.mul_(accumulator.squeeze() ** setting["matrix_power"])

-
-
+        # otherwise use matrix inverse square root
+        else:
+
+            # compute inverse square root and store to state
+            try:
                 if "B" not in state or step % precond_freq == 0:
-                    B = state["B"] =
+                    B = state["B"] = _matrix_power(accumulator, setting["matrix_power"], method=setting["matrix_power_method"])
                 else:
                     B = state["B"]

-
+                dir = (B @ tensor.ravel()).view_as(tensor)
+
+            # fallback to diagonal Adagrad on fail
+            except torch.linalg.LinAlgError:
+                dir = tensor.mul_(accumulator.diagonal() ** setting["matrix_power"])

-
-
-
-
-
+        # debias
+        if setting["beta_debias"] and beta is not None:
+            num_GGTs = state.get('num_GGTs', 1)
+            bias_correction = 1 - beta ** num_GGTs
+            dir *= bias_correction ** 0.5

-        return
+        return dir