torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/ops/debug.py
DELETED
@@ -1,25 +0,0 @@
-from collections import deque
-
-import torch
-
-from ...core import Module
-from ...utils.tensorlist import Distributions
-
-class PrintUpdate(Module):
-    def __init__(self, text = 'update = ', print_fn = print):
-        defaults = dict(text=text, print_fn=print_fn)
-        super().__init__(defaults)
-
-    def step(self, var):
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
-        return var
-
-class PrintShape(Module):
-    def __init__(self, text = 'shapes = ', print_fn = print):
-        defaults = dict(text=text, print_fn=print_fn)
-        super().__init__(defaults)
-
-    def step(self, var):
-        shapes = [u.shape for u in var.update] if var.update is not None else None
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
-        return var
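The debug helpers themselves are not dropped from the package: the file list above adds `torchzero/modules/misc/debug.py` (+48), so `PrintUpdate`-style modules appear to have moved rather than vanished. As a rough sketch of how such a module slots into a modular optimizer chain (assuming the `tz.Modular` / `tz.m` API from the torchzero README; module names and ordering here are illustrative, not taken from this diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# A torchzero optimizer is a chain of modules applied to the update in order.
# Putting a print/debug module in the middle lets you inspect the update as it
# leaves the preceding module on every step.
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),          # produces the Adam update
    tz.m.PrintUpdate(),   # assumed export after the move to modules/misc/debug.py
    tz.m.LR(1e-2),        # final step-size scaling
)

x, y = torch.randn(8, 4), torch.randn(8, 1)
torch.nn.functional.mse_loss(model(x), y).backward()
opt.step()
```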
torchzero/modules/ops/misc.py
DELETED
@@ -1,418 +0,0 @@
-from collections import deque
-from collections.abc import Iterable
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
-from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
-
-
-class Previous(TensorwiseTransform):
-    """Maintains an update from n steps back, for example if n=1, returns previous update"""
-    def __init__(self, n=1, target: Target = 'update'):
-        defaults = dict(n=n)
-        super().__init__(uses_grad=False, defaults=defaults, target=target)
-
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        n = settings['n']
-
-        if 'history' not in state:
-            state['history'] = deque(maxlen=n+1)
-
-        state['history'].append(tensor)
-
-        return state['history'][0]
-
-
-class LastDifference(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
-        difference = torch._foreach_sub(tensors, prev)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return difference
-
-class LastGradDifference(Module):
-    """Difference between past two grads."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, var):
-        grad = var.get_grad()
-        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
-        difference = torch._foreach_sub(grad, prev_grad)
-        for p, c in zip(prev_grad, grad): p.set_(c)
-        var.update = list(difference)
-        return var
-
-
-class LastProduct(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
-        prod = torch._foreach_mul(tensors, prev)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return prod
-
-class LastRatio(Transform):
-    """Ratio between past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
-        defaults = dict(numerator=numerator)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-        numerator = settings[0]['numerator']
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-        else: ratio = torch._foreach_div(prev, tensors)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return ratio
-
-class LastAbsoluteRatio(Transform):
-    """Ratio between absolute values of past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
-        defaults = dict(numerator=numerator, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-        numerator = settings[0]['numerator']
-        eps = NumberList(s['eps'] for s in settings)
-
-        torch._foreach_abs_(tensors)
-        torch._foreach_clamp_min_(prev, eps)
-
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-        else: ratio = torch._foreach_div(prev, tensors)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return ratio
-
-class GradSign(Transform):
-    """copy gradient sign to update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        return [t.copysign_(g) for t,g in zip(tensors, grads)]
-
-class UpdateSign(Transform):
-    """use per-weight magnitudes from grad while using sign from update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
-
-class GraftToGrad(Transform):
-    """use gradient norm and update direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class GraftGradToUpdate(Transform):
-    """use update norm and gradient direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
-
-
-class GraftToParams(Transform):
-    """makes update norm be set to parameter norm, but norm won't go below eps"""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class Relative(Transform):
-    """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
-    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
-        defaults = dict(min_value=min_value)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
-        torch._foreach_mul_(tensors, mul)
-        return tensors
-
-class FillLoss(Module):
-    """makes tensors filled with loss value times alpha"""
-    def __init__(self, alpha: float = 1, backward: bool = True):
-        defaults = dict(alpha=alpha, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha = self.get_settings(var.params, 'alpha')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
-        return var
-
-class MulByLoss(Module):
-    """multiplies update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(var.update, mul)
-        return var
-
-class DivByLoss(Module):
-    """divides update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(var.update, mul)
-        return var
-
-
-
-def _sequential_step(self: Module, var: Var, sequential: bool):
-    params = var.params
-    steps = self.settings[params[0]]['steps']
-
-    if sequential: modules = self.get_children_sequence()
-    else: modules = [self.children['module']] * steps
-
-    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
-
-    # store original params unless this is last module and can update params directly
-    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
-
-    # first step - pass var as usual
-    var = modules[0].step(var)
-    new_var = var
-
-    # subsequent steps - update parameters and create new var
-    if len(modules) > 1:
-        for m in modules[1:]:
-
-            # update params
-            if (not new_var.skip_update):
-                if new_var.last_module_lrs is not None:
-                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-                torch._foreach_sub_(params, new_var.get_update())
-
-            # create new var since we are at a new point, that means grad, update and loss will be None
-            new_var = Var(params=new_var.params, closure=new_var.closure,
-                          model=new_var.model, current_step=new_var.current_step + 1)
-
-            # step
-            new_var = m.step(new_var)
-
-    # final parameter update
-    if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-        torch._foreach_sub_(params, new_var.get_update())
-
-    # if last module, update is applied so return new var
-    if params_before_steps is None:
-        new_var.stop = True
-        new_var.skip_update = True
-        return new_var
-
-    # otherwise use parameter difference as update
-    var.update = list(torch._foreach_sub(params_before_steps, params))
-    for p, bef in zip(params, params_before_steps):
-        p.set_(bef) # pyright:ignore[reportArgumentType]
-    return var
-
-class Multistep(Module):
-    def __init__(self, module: Chainable, steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_child('module', module)
-
-    @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=False)
-
-class Sequential(Module):
-    def __init__(self, modules: Iterable[Chainable], steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_children_sequence(modules)
-
-    @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=True)
-
-
-class GradientAccumulation(Module):
-    """gradient accumulation"""
-    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-        defaults = dict(n=n, mean=mean, stop=stop)
-        super().__init__(defaults)
-        self.set_child('modules', modules)
-
-
-    @torch.no_grad
-    def step(self, var):
-        accumulator = self.get_state(var.params, 'accumulator')
-        settings = self.settings[var.params[0]]
-        n = settings['n']; mean = settings['mean']; stop = settings['stop']
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        # add update to accumulator
-        torch._foreach_add_(accumulator, var.get_update())
-
-        # step with accumulated updates
-        if step % n == 0:
-            if mean:
-                torch._foreach_div_(accumulator, n)
-
-            var.update = [a.clone() for a in accumulator]
-            var = self.children['modules'].step(var)
-
-            # zero accumulator
-            torch._foreach_zero_(accumulator)
-
-        else:
-            # prevent update
-            if stop:
-                var.stop=True
-                var.skip_update=True
-
-        return var
-
-
-class Dropout(Transform):
-    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensors = TensorList(tensors)
-        p = NumberList(s['p'] for s in settings)
-        graft = settings[0]['graft']
-
-        if graft:
-            target_norm = tensors.global_vector_norm()
-            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
-
-        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-
-class WeightDropout(Module):
-    """Applies dropout directly to weights."""
-    def __init__(self, p: float = 0.5, graft: bool = True):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(var.params)
-        p = NumberList(self.settings[p]['p'] for p in params)
-        mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
-
-        @torch.no_grad
-        def dropout_closure(backward=True):
-            orig_params = params.clone()
-            params.mul_(mask)
-            if backward:
-                with torch.enable_grad(): loss = closure()
-            else:
-                loss = closure(False)
-            params.copy_(orig_params)
-            return loss
-
-        var.closure = dropout_closure
-        return var
-
-class NoiseSign(Transform):
-    """uses random vector with update sign"""
-    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-        defaults = dict(distribution=distribution, alpha=alpha)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        alpha = [s['alpha'] for s in settings]
-        distribution = self.settings[params[0]]['distribution']
-        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
-
-
-class NegateOnLossIncrease(Module):
-    def __init__(self, backtrack=True):
-        defaults = dict(backtrack=backtrack)
-        super().__init__(defaults=defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[var.params[0]]['backtrack']
-
-        update = var.get_update()
-        f_0 = var.get_loss(backward=False)
-
-        torch._foreach_sub_(var.params, update)
-        f_1 = closure(False)
-
-        if f_1 <= f_0:
-            if var.is_last and var.last_module_lrs is None:
-                var.stop = True
-                var.skip_update = True
-                return var
-
-            torch._foreach_add_(var.params, update)
-            return var
-
-        torch._foreach_add_(var.params, update)
-        if backtrack:
-            torch._foreach_neg_(var.update)
-        else:
-            torch._foreach_zero_(var.update)
-        return var
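Most of what this file did survives the reorganization: the file list above adds `torchzero/modules/misc/misc.py`, `misc/gradient_accumulation.py`, `misc/multistep.py` and `misc/split.py`. The grafting transforms (`GraftToGrad`, `GraftGradToUpdate`, `GraftToParams`) are thin wrappers over one operation: keep the direction of one set of tensors, but give it the norm of another. A minimal standalone sketch of that operation in plain PyTorch (the real `TensorList.graft_` also has a `tensorwise` mode and configurable `ord`; this shows only the global L2 case and is not torchzero's implementation):

```python
import torch

def graft(direction, magnitude, eps=1e-6):
    """Rescale `direction` so its global L2 norm matches that of `magnitude`."""
    dir_norm = torch.sqrt(sum(d.pow(2).sum() for d in direction))
    mag_norm = torch.sqrt(sum(m.pow(2).sum() for m in magnitude))
    scale = mag_norm / dir_norm.clamp(min=eps)   # avoid blowing up a near-zero direction
    return [d * scale for d in direction]

# e.g. GraftToGrad keeps the update's direction but the gradient's norm:
update = [torch.randn(3, 3), torch.randn(5)]
grad = [torch.randn(3, 3), torch.randn(5)]
grafted = graft(update, grad)
```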
torchzero/modules/ops/split.py
DELETED
@@ -1,75 +0,0 @@
-from collections.abc import Callable
-from typing import cast
-
-import torch
-
-from ...core import Chainable, Module, Var
-
-
-def _split(
-    module: Module,
-    idxs,
-    params,
-    var: Var,
-):
-    split_params = [p for i,p in enumerate(params) if i in idxs]
-
-    split_grad = None
-    if var.grad is not None:
-        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
-
-    split_update = None
-    if var.update is not None:
-        split_update = [u for i,u in enumerate(var.update) if i in idxs]
-
-    split_var = var.clone(clone_update=False)
-    split_var.params = split_params
-    split_var.grad = split_grad
-    split_var.update = split_update
-
-    split_var = module.step(split_var)
-
-    if (var.grad is None) and (split_var.grad is not None):
-        var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
-    if split_var.update is not None:
-
-        if var.update is None:
-            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
-            else: var.update = [g.clone() for g in var.grad]
-
-        for idx, u in zip(idxs, split_var.update):
-            var.update[idx] = u
-
-    var.update_attrs_from_clone_(split_var)
-    return var
-
-class Split(Module):
-    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
-    def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
-        defaults = dict(filter=filter)
-        super().__init__(defaults)
-
-        if true is not None: self.set_child('true', true)
-        if false is not None: self.set_child('false', false)
-
-    def step(self, var):
-
-        params = var.params
-        filter = self.settings[params[0]]['filter']
-
-        true_idxs = []
-        false_idxs = []
-        for i,p in enumerate(params):
-            if filter(p): true_idxs.append(i)
-            else: false_idxs.append(i)
-
-        if 'true' in self.children:
-            true = self.children['true']
-            var = _split(true, idxs=true_idxs, params=params, var=var)
-
-        if 'false' in self.children:
-            false = self.children['false']
-            var = _split(false, idxs=false_idxs, params=params, var=var)
-
-        return var
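`Split` routes each parameter through one of two sub-chains based on a predicate, and the file list adds `torchzero/modules/misc/split.py` (+123) in 0.3.13, so the feature looks relocated rather than removed. A hedged usage sketch, assuming the `tz.Modular` / `tz.m` API and that the new module keeps the `filter` / `true` / `false` signature shown above (the exact 0.3.13 exports are not visible in this diff):

```python
import torch
import torchzero as tz

model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Linear(32, 1))

opt = tz.Modular(
    model.parameters(),
    tz.m.Split(
        filter=lambda p: p.ndim == 2,              # route weight matrices one way...
        true=[tz.m.Adam(), tz.m.Orthogonalize()],  # ...through Adam + Muon-style orthogonalization
        false=tz.m.Adam(),                         # biases and other tensors get plain Adam
    ),
    tz.m.LR(1e-2),
)
```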
torchzero/modules/optimizers/__init__.py
DELETED
@@ -1,18 +0,0 @@
-from .adagrad import Adagrad, FullMatrixAdagrad
-from .adam import Adam
-from .lion import Lion
-from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
-from .rmsprop import RMSprop
-from .rprop import (
-    BacktrackOnSignChange,
-    Rprop,
-    ScaleLRBySignChange,
-    SignConsistencyLRs,
-    SignConsistencyMask,
-)
-from .shampoo import Shampoo
-from .soap import SOAP
-from .orthograd import OrthoGrad, orthograd_
-from .sophia_h import SophiaH
-# from .curveball import CurveBall
-# from .spectral import SpectralPreconditioner
torchzero/modules/optimizers/adagrad.py
DELETED
@@ -1,155 +0,0 @@
-from operator import itemgetter
-from typing import Literal
-
-import torch
-from ...core import (
-    Chainable,
-    Module,
-    Target,
-    TensorwiseTransform,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ...utils.linalg import matrix_power_eigh
-from ..functional import add_power_, lerp_power_, root
-
-
-def adagrad_(
-    tensors_: TensorList,
-    sq_sum_: TensorList,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    use_sqrt: bool = True,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
-
-    sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
-
-    return tensors_
-
-
-
-
-class Adagrad(Transform):
-    """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
-
-    Args:
-        lr_decay (float, optional): learning rate decay. Defaults to 0.
-        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-        eps (float, optional): division epsilon. Defaults to 1e-10.
-        alpha (float, optional): step size. Defaults to 1.
-        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
-        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
-        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
-    """
-    def __init__(
-        self,
-        lr_decay: float = 0,
-        initial_accumulator_value: float = 0,
-        eps: float = 1e-10,
-        alpha: float = 1,
-        pow: float = 2,
-        use_sqrt: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt)
-        super().__init__(defaults=defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensors = TensorList(tensors)
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
-
-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
-
-        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-        # initialize accumulator on 1st step
-        if step == 1:
-            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-        return adagrad_(
-            tensors,
-            sq_sum_=sq_sum,
-            alpha=alpha,
-            lr_decay=lr_decay,
-            eps=eps,
-            step=self.global_state["step"],
-            pow=pow,
-            use_sqrt=use_sqrt,
-
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )
-
-
-
-class FullMatrixAdagrad(TensorwiseTransform):
-    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=False, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', inner: Chainable | None = None):
-        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
-
-    @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, settings):
-        G = tensor.ravel()
-        GG = torch.outer(G, G)
-        decay = settings['decay']
-        beta = settings['beta']
-        init = settings['init']
-
-        if 'GG' not in state:
-            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
-            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-            elif init == 'ones': state['GG'] = torch.ones_like(GG)
-            elif init == 'GGT': state['GG'] = GG.clone()
-            else: raise ValueError(init)
-        if decay is not None: state['GG'].mul_(decay)
-
-        if beta is not None: state['GG'].lerp_(GG, 1-beta)
-        else: state['GG'].add_(GG)
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        GG = state['GG']
-        sqrt = settings['sqrt']
-
-        if tensor.numel() == 1:
-            GG = GG.squeeze()
-            if sqrt: return tensor / GG.sqrt()
-            return tensor / GG
-
-        try:
-            if sqrt: B = matrix_power_eigh(GG, -1/2)
-            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
-
-        except torch.linalg.LinAlgError:
-            scale = 1 / tensor.abs().max()
-            return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
-
-        return (B @ tensor.ravel()).view_as(tensor)
-