torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/reduce_outward_lr.py

@@ -1,30 +1,33 @@
 import torch

 from ...core import Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_states, unpack_dicts

 class ReduceOutwardLR(Transform):
-    """
-    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+    """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.

     This means updates that move weights towards zero have higher learning rates.
+
+    .. warning::
+        This sounded good but after testing turns out it sucks.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
         super().__init__(defaults, uses_grad=use_grad, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)

-        mul =
-        s =
+        mul = [s['mul'] for s in settings]
+        s = settings[0]
         use_grad = s['use_grad']
         invert = s['invert']

-        if use_grad: cur =
+        if use_grad: cur = grads
         else: cur = tensors
+        assert cur is not None

         # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
         if invert: mask = (params * cur) > 0
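For orientation, the `apply_tensors` signature above is how Transform subclasses now receive per-parameter state and settings. A hypothetical usage sketch, assuming this module is exported as `tz.m.ReduceOutwardLR` (the `tz.Modular` / `tz.m.LR` pattern appears in docstring examples elsewhere in this diff):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.ReduceOutwardLR(mul=0.5),  # scale down the step where update sign matches weight sign
        tz.m.LR(1e-2),
    )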
torchzero/modules/{projections/structural.py → experimental/structural_projections.py}

@@ -6,35 +6,18 @@ import torch
 from ...core import Chainable
 from ...utils import vec_to_tensors, TensorList
 from ..optimizers.shampoo import _merge_small_dims
-from
+from ..projections import ProjectionBase


-class VectorProjection(Projection):
-    """
-    flattens and concatenates all parameters into a vector
-    """
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

-
-    def project(self, tensors, vars, current):
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-    @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        return vec_to_tensors(vec=tensors[0], reference=vars.params)
-
-
-
-class TensorizeProjection(Projection):
+class TensorizeProjection(ProjectionBase):
     """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
     def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_side=max_side)
         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors,
-        params = vars.params
+    def project(self, tensors, params, grads, loss, states, settings, current):
         max_side = self.settings[params[0]]['max_side']
         num_elems = sum(t.numel() for t in tensors)

@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
         return [vec.view(dims)]

     @torch.no_grad
-    def unproject(self,
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
         remainder = self.global_state['remainder']
         # warnings.warn(f'{tensors[0].shape = }')
-        vec =
+        vec = projected_tensors[0].view(-1)
         if remainder > 0: vec = vec[:-remainder]
-        return vec_to_tensors(vec,
+        return vec_to_tensors(vec, params)

-class BlockPartition(
+class BlockPartition(ProjectionBase):
     """splits parameters into blocks (for now flatttens them and chunks)"""
     def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_size=max_size, batched=batched)
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, params, grads, loss, states, settings, current):
         partitioned = []
-        for p,t in zip(
+        for p,t in zip(params, tensors):
             settings = self.settings[p]
             max_size = settings['max_size']
             n = t.numel()

@@ -101,10 +84,10 @@ class BlockPartition(Projection):
         return partitioned

     @torch.no_grad
-    def unproject(self,
-        ti = iter(
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        ti = iter(projected_tensors)
         unprojected = []
-        for p in
+        for p in params:
             settings = self.settings[p]
             n = p.numel()

@@ -124,28 +107,3 @@ class BlockPartition(Projection):

         return unprojected

-
-class TensorNormsProjection(Projection):
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-    @torch.no_grad
-    def project(self, tensors, vars, current):
-        orig = self.get_state(f'{current}_orig', params=vars.params)
-        torch._foreach_copy_(orig, tensors)
-
-        norms = torch._foreach_norm(tensors)
-        self.get_state(f'{current}_orig_norms', params=vars.params, init=norms, cls=TensorList).set_(norms)
-
-        return [torch.stack(norms)]
-
-    @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        orig = self.get_state(f'{current}_orig', params=vars.params)
-        orig_norms = torch.stack(self.get_state(f'{current}_orig_norms', params=vars.params))
-        target_norms = tensors[0]
-
-        orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-        torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-        return orig
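To illustrate the signature change above (hooks now receive params, grads, loss, states and settings explicitly instead of a `vars` object), here is a hypothetical re-implementation of the removed VectorProjection against the new ProjectionBase interface. It is a sketch only, reusing the imports already present in this module, and is not part of the package:

    class FlattenProjection(ProjectionBase):
        """flattens and concatenates all tensors into a single vector (sketch)"""
        def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
            super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

        @torch.no_grad
        def project(self, tensors, params, grads, loss, states, settings, current):
            # concatenate every tensor into one flat vector
            return [torch.cat([t.view(-1) for t in tensors], dim=-1)]

        @torch.no_grad
        def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
            # split the flat vector back into tensors shaped like the parameters
            return vec_to_tensors(vec=projected_tensors[0], reference=params)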
torchzero/modules/experimental/subspace_preconditioners.py

@@ -5,7 +5,7 @@ import torch

 # import torchzero as tz

-from ...core import Transform, Chainable,
+from ...core import Transform, Chainable, apply_transform
 from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
 from ...utils import TensorList, vec_to_tensors_

@@ -38,15 +38,20 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d

 class RandomSubspacePreconditioning(Transform):
-    """
+    """Whitens in random slowly changing subspace.
+
+    .. warning::
+        Experimental and this is a barebones implementation.
+
+    """
     def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)

         if inner is not None: self.set_child('inner', inner)

-    def
-        settings =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        settings = settings[0]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
         beta = settings['beta']

@@ -65,7 +70,7 @@ class RandomSubspacePreconditioning(Transform):
         update_subspace_preconditioner_(g, basis, accumulator, beta)

         if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params, grads)
             g = torch.cat([t.view(-1) for t in tensors])

         try:

@@ -78,9 +83,14 @@ class RandomSubspacePreconditioning(Transform):


 class HistorySubspacePreconditioning(Transform):
-    """
+    """Whitens in subspace spanned by history of gradient differences.
+
+    .. warning::
+        Experimental and this is a barebones implementation.

-
+    Args:
+        beta - for preconditioner itself in the basis.
+        basis_beta - how much basis is allowed to change.
     """
     def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)

@@ -88,8 +98,8 @@ class HistorySubspacePreconditioning(Transform):

         if inner is not None: self.set_child('inner', inner)

-    def
-        settings =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        settings = settings[0]

         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']

@@ -122,7 +132,7 @@ class HistorySubspacePreconditioning(Transform):
         update_subspace_preconditioner_(g, basis, accumulator, beta)

         if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params, grads)
             g = torch.cat([t.view(-1) for t in tensors])

         try:
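One way to read the two momentum constants above (a sketch of the intent, not a line-by-line restatement of the implementation): with a d×k basis Q that changes slowly according to basis_beta, and an accumulator A of the projected second moment updated with rate beta, the step is whitened inside the subspace:

    c = Q^\top g, \qquad A \leftarrow \beta A + (1-\beta)\, c c^\top, \qquad \Delta = Q\, A^{-1/2}\, c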
torchzero/modules/experimental/tensor_adagrad.py

@@ -0,0 +1,42 @@
+from collections import deque
+
+import torch
+
+from ...core import Chainable, TensorwiseTransform
+from ...utils.linalg import matrix_power_eigh
+
+
+class TensorAdagrad(TensorwiseTransform):
+    """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
+
+    .. warning::
+        Experimental.
+    """
+    def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
+        defaults = dict(history_size=history_size, reg=reg)
+        super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        reg = setting['reg']
+        if 'history' not in state:
+            state['history'] = deque(maxlen=setting['history_size'])
+
+        g = tensor.view(-1)
+        history = state['history']
+        history.append(g.clone())
+
+        I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
+        g_k = history[0]
+        outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+        if len(history) > 1:
+            for g_k in list(history)[1:]:
+                outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+
+        state['outer'] = outer.add_(I)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        outer = state['outer']
+        P = matrix_power_eigh(outer, -1/2)
+        return (P @ tensor.ravel()).view_as(tensor)
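Written out, the update this new module computes per tensor is: with flattened gradient g, stored history g_1, ..., g_m and reg = ε,

    M = \varepsilon I + \sum_{k=1}^{m} \max(\langle g_k, g \rangle, \varepsilon)\, g_k g_k^\top, \qquad \Delta = M^{-1/2} g,

so each history vector's outer product is weighted by its (clipped) alignment with the current gradient before the inverse square root is taken.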
torchzero/modules/functional.py
CHANGED
@@ -7,8 +7,9 @@ storage is always indicated in the docstring.

 Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
 """
-
-from
+from collections.abc import Callable
+from typing import overload
+import torch

 from ..utils import NumberList, TensorList

@@ -206,4 +207,13 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )

+@overload
+def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
+@overload
+def safe_scaling_(tensors_: TensorList) -> TensorList: ...
+def safe_scaling_(tensors_: torch.Tensor | TensorList):
+    if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
+    else: scale = 1 / tensors_.abs().global_sum()
+    scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
+    return tensors_.mul_(scale)

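The new safe_scaling_ helper scales its argument in place by 1/Σ|t|, with the factor clipped to [eps, 1]. A small illustration, assuming the module is importable as torchzero.modules.functional:

    import torch
    from torchzero.modules.functional import safe_scaling_

    t = torch.tensor([3.0, -1.0])
    safe_scaling_(t)   # scale = 1/4 = 0.25 -> t becomes [0.75, -0.25]

    s = torch.tensor([0.1, 0.1])
    safe_scaling_(s)   # 1/0.2 = 5 is clipped to 1 -> s is left unchanged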
torchzero/modules/grad_approximation/fdm.py

@@ -77,8 +77,11 @@ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_
     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)

 _FD_FUNCS = {
+    "forward": _forward2,
     "forward2": _forward2,
+    "backward": _backward2,
     "backward2": _backward2,
+    "central": _central2,
     "central2": _central2,
     "central3": _central2, # they are the same
     "forward3": _forward3,

@@ -88,19 +91,43 @@ _FD_FUNCS = {


 class FDM(GradApproximator):
-    """Approximate gradients via finite difference method
+    """Approximate gradients via finite difference method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

     Args:
         h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
-        target (GradTarget, optional): what to set on
+        target (GradTarget, optional): what to set on var. Defaults to 'closure'.
+
+    Examples:
+        plain FDM:
+
+        .. code-block:: python
+
+            fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+
+        Any gradient-based method can use FDM-estimated gradients seamlessly.
+
+        .. code-block:: python
+
+            fdm_ncg = tz.Modular(
+                model.parameters(),
+                tz.m.FDM(),
+                # set hvp_method to "forward" so that it
+                # uses gradient difference instead of autograd
+                tz.m.NewtonCG(hvp_method="forward"),
+                tz.m.Backtracking()
+            )
     """
-    def __init__(self, h: float=1e-3, formula: _FD_Formula = '
+    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
         defaults = dict(h=h, formula=formula)
         super().__init__(defaults, target=target)

     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         grads = []
         loss_approx = None

torchzero/modules/grad_approximation/forward_gradient.py

@@ -4,26 +4,36 @@ from typing import Any, Literal

 import torch

-from ...utils import Distributions, NumberList, TensorList
+from ...utils import Distributions, NumberList, TensorList
 from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
 from .grad_approximator import GradApproximator, GradTarget
 from .rfdm import RandomizedFDM


 class ForwardGradient(RandomizedFDM):
-    """Forward gradient method
+    """Forward gradient method.
+
+    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+

     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
         beta (float, optional):
-
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
-            whether to pre-generate gradient samples before each step. Defaults to True.
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         jvp_method (str, optional):
-            how to calculate jacobian vector product, note that with `forward` and 'central' this is
+            how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
-        target (GradTarget, optional): what to set on
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
     """
     PRE_MULTIPLY_BY_H = False
     def __init__(

@@ -41,7 +51,7 @@ class ForwardGradient(RandomizedFDM):
         self.defaults['jvp_method'] = jvp_method

     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
         loss_approx = None


torchzero/modules/grad_approximation/grad_approximator.py

@@ -5,7 +5,7 @@ from typing import Any, Literal

 import torch

-from ...core import Module,
+from ...core import Module, Var

 GradTarget = Literal['update', 'grad', 'closure']
 _Scalar = torch.Tensor | float

@@ -14,53 +14,98 @@ class GradApproximator(Module, ABC):
     """Base class for gradient approximations.
     This is an abstract class, to use it, subclass it and override `approximate`.

+    GradientApproximator modifies the closure to evaluate the estimated gradients,
+    and further closure-based modules will use the modified closure.
+
     Args:
         defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
         target (str, optional):
-            whether to set `
-
+            whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
+
+    Example:
+
+        Basic SPSA method implementation.
+
+        .. code-block:: python
+
+            class SPSA(GradApproximator):
+                def __init__(self, h=1e-3):
+                    defaults = dict(h=h)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def approximate(self, closure, params, loss):
+                    perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                    # evaluate params + perturbation
+                    torch._foreach_add_(params, perturbation)
+                    loss_plus = closure(False)
+
+                    # evaluate params - perturbation
+                    torch._foreach_sub_(params, perturbation)
+                    torch._foreach_sub_(params, perturbation)
+                    loss_minus = closure(False)
+
+                    # restore original params
+                    torch._foreach_add_(params, perturbation)
+
+                    # calculate SPSA gradients
+                    spsa_grads = []
+                    for p, pert in zip(params, perturbation):
+                        settings = self.settings[p]
+                        h = settings['h']
+                        d = (loss_plus - loss_minus) / (2*(h**2))
+                        spsa_grads.append(pert * d)
+
+                    # returns tuple: (grads, loss, loss_approx)
+                    # loss must be with initial parameters
+                    # since we only evaluated loss with perturbed parameters
+                    # we only have loss_approx
+                    return spsa_grads, None, loss_plus
+
+    """
     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target

     @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
         """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""

-    def pre_step(self,
+    def pre_step(self, var: Var) -> Var | None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return
+        return var

     @torch.no_grad
-    def step(self,
-        ret = self.pre_step(
-        if isinstance(ret,
+    def step(self, var):
+        ret = self.pre_step(var)
+        if isinstance(ret, Var): var = ret

-        if
-        params, closure, loss =
+        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = var.params, var.closure, var.loss

         if self._target == 'closure':

             def approx_closure(backward=True):
                 if backward:
                     # set loss to None because closure might be evaluated at different points
-                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None
+                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                     for p, g in zip(params, grad): p.grad = g
-                    return l if l is not None else
+                    return l if l is not None else closure(False)
                 return closure(False)

-
-            return
+            var.closure = approx_closure
+            return var

-        # if
-        # warnings.warn('Using grad approximator when `
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss
-        if loss_approx is not None:
-        if loss is not None:
-        if self._target == 'grad':
-        elif self._target == 'update':
+        # if var.grad is not None:
+        #     warnings.warn('Using grad approximator when `var.grad` is already set.')
+        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+        if loss_approx is not None: var.loss_approx = loss_approx
+        if loss is not None: var.loss = var.loss_approx = loss
+        if self._target == 'grad': var.grad = list(grad)
+        elif self._target == 'update': var.update = list(grad)
         else: raise ValueError(self._target)
-        return
+        return var

-_FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', '
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']
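Written out, the SPSA example in the docstring above uses pert = hΔ with Rademacher Δ and computes d = (f(θ+hΔ) − f(θ−hΔ)) / (2h²), so each gradient component is

    \hat g_i = \frac{f(\theta + h\Delta) - f(\theta - h\Delta)}{2h^2}\, h\Delta_i = \frac{f(\theta + h\Delta) - f(\theta - h\Delta)}{2h}\,\Delta_i,

which coincides with the standard SPSA estimator (usually written with division by Δ_i; for Δ_i ∈ {−1, +1} the two forms are equal).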