torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/{projections/structural.py → experimental/structural_projections.py}
RENAMED

@@ -6,35 +6,18 @@ import torch
 from ...core import Chainable
 from ...utils import vec_to_tensors, TensorList
 from ..optimizers.shampoo import _merge_small_dims
-from
+from ..projections import ProjectionBase
 
 
-class VectorProjection(Projection):
-    """
-    flattens and concatenates all parameters into a vector
-    """
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
-
-    def project(self, tensors, var, current):
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return vec_to_tensors(vec=tensors[0], reference=var.params)
-
-
-
-class TensorizeProjection(Projection):
+class TensorizeProjection(ProjectionBase):
     """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
     def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_side=max_side)
         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors,
-        params = var.params
+    def project(self, tensors, params, grads, loss, states, settings, current):
         max_side = self.settings[params[0]]['max_side']
         num_elems = sum(t.numel() for t in tensors)
 
@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
         return [vec.view(dims)]
 
     @torch.no_grad
-    def unproject(self,
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
         remainder = self.global_state['remainder']
         # warnings.warn(f'{tensors[0].shape = }')
-        vec =
+        vec = projected_tensors[0].view(-1)
         if remainder > 0: vec = vec[:-remainder]
-        return vec_to_tensors(vec,
+        return vec_to_tensors(vec, params)
 
-class BlockPartition(
+class BlockPartition(ProjectionBase):
     """splits parameters into blocks (for now flatttens them and chunks)"""
     def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_size=max_size, batched=batched)
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
+    def project(self, tensors, params, grads, loss, states, settings, current):
         partitioned = []
-        for p,t in zip(
+        for p,t in zip(params, tensors):
             settings = self.settings[p]
             max_size = settings['max_size']
             n = t.numel()
@@ -101,10 +84,10 @@ class BlockPartition(Projection):
         return partitioned
 
     @torch.no_grad
-    def unproject(self,
-        ti = iter(
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        ti = iter(projected_tensors)
         unprojected = []
-        for p in
+        for p in params:
             settings = self.settings[p]
             n = p.numel()
 
@@ -124,28 +107,3 @@ class BlockPartition(Projection):
 
         return unprojected
 
-
-class TensorNormsProjection(Projection):
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-    @torch.no_grad
-    def project(self, tensors, var, current):
-        orig = self.get_state(var.params, f'{current}_orig')
-        torch._foreach_copy_(orig, tensors)
-
-        norms = torch._foreach_norm(tensors)
-        self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)
-
-        return [torch.stack(norms)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        orig = self.get_state(var.params, f'{current}_orig')
-        orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
-        target_norms = tensors[0]
-
-        orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-        torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-        return orig
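The hunks above move the projection hooks from the old `project(self, tensors, var, current)` / `unproject(self, tensors, var, current)` signatures to explicit `params`, `grads`, `loss`, `states` and `settings` arguments on `ProjectionBase`. A minimal sketch of a flatten-to-vector projection written against the new signatures; the absolute import paths are inferred from the relative imports shown above, and the constructor keywords are assumed to still be accepted without a `defaults` dict.

    # Sketch only: re-creates the removed VectorProjection against the new
    # ProjectionBase signatures shown in the hunks above. Import paths and the
    # accepted constructor keywords are assumptions inferred from this diff.
    import torch
    from torchzero.core import Chainable
    from torchzero.utils import vec_to_tensors
    from torchzero.modules.projections import ProjectionBase  # assumed public path

    class FlattenToVector(ProjectionBase):
        """Flattens and concatenates all tensors into a single vector."""

        def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
            super().__init__(modules, project_update=project_update,
                             project_params=project_params, project_grad=project_grad)

        @torch.no_grad
        def project(self, tensors, params, grads, loss, states, settings, current):
            # one flat vector that the wrapped modules will operate on
            return [torch.cat([t.view(-1) for t in tensors])]

        @torch.no_grad
        def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
            # split the flat vector back into tensors shaped like the parameters
            return vec_to_tensors(projected_tensors[0], params)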
torchzero/modules/experimental/subspace_preconditioners.py
CHANGED

@@ -38,14 +38,19 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d
 
 class RandomSubspacePreconditioning(Transform):
-    """Whitens in random slowly changing subspace.
+    """Whitens in random slowly changing subspace.
+
+    .. warning::
+        Experimental and this is a barebones implementation.
+
+    """
     def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)
 
         if inner is not None: self.set_child('inner', inner)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         settings = settings[0]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
@@ -79,7 +84,9 @@ class RandomSubspacePreconditioning(Transform):
 
 class HistorySubspacePreconditioning(Transform):
     """Whitens in subspace spanned by history of gradient differences.
-
+
+    .. warning::
+        Experimental and this is a barebones implementation.
 
     Args:
         beta - for preconditioner itself in the basis.
@@ -91,7 +98,7 @@ class HistorySubspacePreconditioning(Transform):
 
         if inner is not None: self.set_child('inner', inner)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         settings = settings[0]
 
         g = torch.cat([t.view(-1) for t in tensors])
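Both preconditioners now implement the renamed `apply_tensors(self, tensors, params, grads, loss, states, settings)` hook of `Transform`. A rough sketch of a custom transform written against that hook; `torchzero.core.Transform` as the import path and the convention of returning the processed update tensors are assumptions, the constructor and `settings[0]` usage mirror the hunks above.

    # Sketch, not part of the package: a sign-descent transform written against
    # the apply_tensors hook shown above.
    import torch
    from torchzero.core import Transform  # assumed import path

    class SignUpdate(Transform):
        """Replaces each update tensor with its sign, scaled by `alpha`."""

        def __init__(self, alpha: float = 1.0):
            defaults = dict(alpha=alpha)
            super().__init__(defaults, uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # settings is per-parameter, mirroring `settings = settings[0]` above
            alpha = settings[0]['alpha']
            # assumed convention: return the transformed update tensors
            return [t.sign().mul_(alpha) for t in tensors]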
torchzero/modules/experimental/{tada.py → tensor_adagrad.py}
RENAMED

@@ -6,17 +6,21 @@ from ...core import Chainable, TensorwiseTransform
 from ...utils.linalg import matrix_power_eigh
 
 
-class
-    """3rd order whitening (maybe normalizes skewness
+class TensorAdagrad(TensorwiseTransform):
+    """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
+
+    .. warning::
+        Experimental.
+    """
     def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
         defaults = dict(history_size=history_size, reg=reg)
         super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state,
-        reg =
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        reg = setting['reg']
         if 'history' not in state:
-            state['history'] = deque(maxlen=
+            state['history'] = deque(maxlen=setting['history_size'])
 
         g = tensor.view(-1)
         history = state['history']
@@ -32,7 +36,7 @@ class TAda(TensorwiseTransform):
         state['outer'] = outer.add_(I)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         outer = state['outer']
         P = matrix_power_eigh(outer, -1/2)
         return (P @ tensor.ravel()).view_as(tensor)
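The tensorwise hooks now receive a single per-tensor `setting` dict instead of the old trailing argument. A sketch of a per-tensor Adagrad-style accumulator against the `update_tensor`/`apply_tensor` interface above; the import and constructor keywords mirror the `TensorAdagrad` hunk, the accumulator itself is purely illustrative.

    # Illustrative sketch only: a per-tensor Adagrad accumulator written against
    # the update_tensor / apply_tensor hooks shown above.
    import torch
    from torchzero.core import Chainable, TensorwiseTransform  # path taken from the hunk header

    class PerTensorAdagrad(TensorwiseTransform):
        def __init__(self, eps: float = 1e-10, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
            defaults = dict(eps=eps)
            super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)

        @torch.no_grad
        def update_tensor(self, tensor, param, grad, loss, state, setting):
            # accumulate squared updates in per-tensor state
            if 'sq_sum' not in state:
                state['sq_sum'] = torch.zeros_like(tensor)
            state['sq_sum'].addcmul_(tensor, tensor)

        @torch.no_grad
        def apply_tensor(self, tensor, param, grad, loss, state, setting):
            # scale the update by the accumulated root-sum-of-squares
            return tensor / (state['sq_sum'].sqrt() + setting['eps'])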
torchzero/modules/functional.py
CHANGED

@@ -7,8 +7,9 @@ storage is always indicated in the docstring.
 
 Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
 """
-
-from
+from collections.abc import Callable
+from typing import overload
+import torch
 
 from ..utils import NumberList, TensorList
 
@@ -206,4 +207,13 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )
 
+@overload
+def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
+@overload
+def safe_scaling_(tensors_: TensorList) -> TensorList: ...
+def safe_scaling_(tensors_: torch.Tensor | TensorList):
+    if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
+    else: scale = 1 / tensors_.abs().global_sum()
+    scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
+    return tensors_.mul_(scale)
 
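`safe_scaling_` (added above) rescales an update in place by the reciprocal of its total absolute sum, with the scale clipped to `[eps, 1]` so updates are shrunk but never amplified. A small illustrative call, assuming nothing beyond the code shown in the hunk:

    # Tiny usage sketch for the safe_scaling_ helper added above.
    import torch
    from torchzero.modules.functional import safe_scaling_  # module path from this diff

    u = torch.tensor([3.0, -1.0])   # |u| sums to 4, so scale = 1/4
    safe_scaling_(u)                # in place: u becomes [0.75, -0.25]

    v = torch.tensor([0.1, 0.2])    # |v| sums to 0.3, 1/0.3 > 1 is clipped to 1
    safe_scaling_(v)                # v is unchanged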
torchzero/modules/grad_approximation/fdm.py
CHANGED

@@ -77,8 +77,11 @@ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_
     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
 
 _FD_FUNCS = {
+    "forward": _forward2,
     "forward2": _forward2,
+    "backward": _backward2,
     "backward2": _backward2,
+    "central": _central2,
     "central2": _central2,
     "central3": _central2, # they are the same
     "forward3": _forward3,
@@ -88,19 +91,43 @@ _FD_FUNCS = {
 
 
 class FDM(GradApproximator):
-    """Approximate gradients via finite difference method
+    """Approximate gradients via finite difference method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
     Args:
         h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         target (GradTarget, optional): what to set on var. Defaults to 'closure'.
+
+    Examples:
+        plain FDM:
+
+        .. code-block:: python
+
+            fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+
+        Any gradient-based method can use FDM-estimated gradients seamlessly.
+
+        .. code-block:: python
+
+            fdm_ncg = tz.Modular(
+                model.parameters(),
+                tz.m.FDM(),
+                # set hvp_method to "forward" so that it
+                # uses gradient difference instead of autograd
+                tz.m.NewtonCG(hvp_method="forward"),
+                tz.m.Backtracking()
+            )
     """
-    def __init__(self, h: float=1e-3, formula: _FD_Formula = '
+    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
         defaults = dict(h=h, formula=formula)
         super().__init__(defaults, target=target)
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         grads = []
         loss_approx = None
 
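`_FD_FUNCS` now also accepts the plain `'forward'`, `'backward'` and `'central'` spellings, and `'central'` becomes the `FDM` default. A construction-only sketch in the style of the docstring examples above; `model` is a placeholder and `'central'` maps to the same function as the older `'central2'` key.

    # Sketch: using the new formula aliases added to _FD_FUNCS above.
    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.FDM(h=1e-4, formula="central"),  # same function as the old 'central2' spelling
        tz.m.LR(1e-2),
    )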
torchzero/modules/grad_approximation/forward_gradient.py
CHANGED

@@ -4,14 +4,21 @@ from typing import Any, Literal
 
 import torch
 
-from ...utils import Distributions, NumberList, TensorList
+from ...utils import Distributions, NumberList, TensorList
 from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
 from .grad_approximator import GradApproximator, GradTarget
 from .rfdm import RandomizedFDM
 
 
 class ForwardGradient(RandomizedFDM):
-    """Forward gradient method
+    """Forward gradient method.
+
+    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
 
     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
@@ -24,6 +31,9 @@ class ForwardGradient(RandomizedFDM):
             how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
     """
     PRE_MULTIPLY_BY_H = False
     def __init__(
@@ -41,7 +51,7 @@ class ForwardGradient(RandomizedFDM):
         self.defaults['jvp_method'] = jvp_method
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
         loss_approx = None
 
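Since `ForwardGradient` shares the randomized-FDM interface, a usage sketch in the style of the FDM examples elsewhere in this diff; the `tz.m.ForwardGradient` registration name is assumed from the class above, and `model` is a placeholder.

    # Hedged sketch: forward-gradient estimation feeding a plain LR step.
    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.ForwardGradient(n_samples=4),  # average 4 JVP-based directional samples
        tz.m.LR(1e-3),
    )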
torchzero/modules/grad_approximation/grad_approximator.py
CHANGED

@@ -14,17 +14,62 @@ class GradApproximator(Module, ABC):
     """Base class for gradient approximations.
     This is an abstract class, to use it, subclass it and override `approximate`.
 
+    GradientApproximator modifies the closure to evaluate the estimated gradients,
+    and further closure-based modules will use the modified closure.
+
     Args:
         defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
         target (str, optional):
            whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
-
+
+    Example:
+
+        Basic SPSA method implementation.
+
+        .. code-block:: python
+
+            class SPSA(GradApproximator):
+                def __init__(self, h=1e-3):
+                    defaults = dict(h=h)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def approximate(self, closure, params, loss):
+                    perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                    # evaluate params + perturbation
+                    torch._foreach_add_(params, perturbation)
+                    loss_plus = closure(False)
+
+                    # evaluate params - perturbation
+                    torch._foreach_sub_(params, perturbation)
+                    torch._foreach_sub_(params, perturbation)
+                    loss_minus = closure(False)
+
+                    # restore original params
+                    torch._foreach_add_(params, perturbation)
+
+                    # calculate SPSA gradients
+                    spsa_grads = []
+                    for p, pert in zip(params, perturbation):
+                        settings = self.settings[p]
+                        h = settings['h']
+                        d = (loss_plus - loss_minus) / (2*(h**2))
+                        spsa_grads.append(pert * d)
+
+                    # returns tuple: (grads, loss, loss_approx)
+                    # loss must be with initial parameters
+                    # since we only evaluated loss with perturbed parameters
+                    # we only have loss_approx
+                    return spsa_grads, None, loss_plus
+
+    """
     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target
 
     @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
         """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
 
     def pre_step(self, var: Var) -> Var | None:
@@ -45,9 +90,9 @@ class GradApproximator(Module, ABC):
         def approx_closure(backward=True):
             if backward:
                 # set loss to None because closure might be evaluated at different points
-                grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None
+                grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                 for p, g in zip(params, grad): p.grad = g
-                return l if l is not None else
+                return l if l is not None else closure(False)
             return closure(False)
 
         var.closure = approx_closure
@@ -55,7 +100,7 @@ class GradApproximator(Module, ABC):
 
         # if var.grad is not None:
         # warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss
+        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
         if loss_approx is not None: var.loss_approx = loss_approx
         if loss is not None: var.loss = var.loss_approx = loss
         if self._target == 'grad': var.grad = list(grad)
@@ -63,4 +108,4 @@ class GradApproximator(Module, ABC):
         else: raise ValueError(self._target)
         return var
 
-_FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', '
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']