torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/nystrom.py

@@ -6,7 +6,7 @@ import torch
 from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
-from ...core import Chainable,
+from ...core import Chainable, apply_transform, Module
 from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
 
 class NystromSketchAndSolve(Module):
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
         rank: int,
         reg: float = 1e-3,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-        h=1e-
+        h=1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
@@ -26,10 +26,10 @@ class NystromSketchAndSolve(Module):
         self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self,
-        params = TensorList(
+    def step(self, var):
+        params = TensorList(var.params)
 
-        closure =
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -47,7 +47,7 @@ class NystromSketchAndSolve(Module):
 
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
                 with torch.enable_grad():
@@ -57,7 +57,7 @@ class NystromSketchAndSolve(Module):
         else:
 
             with torch.enable_grad():
-                grad =
+                grad = var.get_grad()
 
             if hvp_method == 'forward':
                 def H_mm(x):
@@ -74,14 +74,14 @@ class NystromSketchAndSolve(Module):
 
 
         # -------------------------------- inner step -------------------------------- #
-        b =
+        b = var.get_update()
         if 'inner' in self.children:
-            b =
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
 
         # ------------------------------ sketch&n&solve ------------------------------ #
         x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
-
-        return
+        var.update = vec_to_tensors(x, reference=params)
+        return var
 
 
 
@@ -93,7 +93,7 @@ class NystromPCG(Module):
         tol=1e-3,
         reg: float = 1e-6,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-        h=1e-
+        h=1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
@@ -104,10 +104,10 @@ class NystromPCG(Module):
         self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self,
-        params = TensorList(
+    def step(self, var):
+        params = TensorList(var.params)
 
-        closure =
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -129,7 +129,7 @@ class NystromPCG(Module):
 
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
                 with torch.enable_grad():
@@ -139,7 +139,7 @@ class NystromPCG(Module):
         else:
 
             with torch.enable_grad():
-                grad =
+                grad = var.get_grad()
 
             if hvp_method == 'forward':
                 def H_mm(x):
@@ -156,13 +156,13 @@ class NystromPCG(Module):
 
 
         # -------------------------------- inner step -------------------------------- #
-        b =
+        b = var.get_update()
         if 'inner' in self.children:
-            b =
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
 
         # ------------------------------ sketch&n&solve ------------------------------ #
         x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
-
-        return
+        var.update = vec_to_tensors(x, reference=params)
+        return var
 
 
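The pattern that recurs throughout this release is visible here: `Module.step` now receives a single `var` object (gradients via `var.get_grad()`, the current update via `var.get_update()`, the closure via `var.closure`) and finishes by writing to `var.update` and returning `var`. The sketch below illustrates that interface with a toy module; the class itself and its constructor call are illustrative assumptions based on the hunks above, not documented API.

import torch
from torchzero.core import Module  # same package path used by the relative imports above

class SignUpdate(Module):
    """Hypothetical module that replaces the update with its elementwise sign."""
    def __init__(self):
        super().__init__({})  # assumed: Module.__init__ takes a defaults dict, as in this diff

    @torch.no_grad
    def step(self, var):
        update = var.get_update()               # list of update tensors (0.3.10 interface)
        var.update = [u.sign() for u in update] # write the new update back onto var
        return var                              # modules now return var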
torchzero/modules/smoothing/gaussian.py

@@ -6,7 +6,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Modular, Module,
+from ...core import Modular, Module, Var
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -17,24 +17,24 @@ class Reformulation(Module, ABC):
         super().__init__(defaults)
 
     @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor],
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
         """returns loss and gradient, if backward is False then gradient can be None"""
 
-    def pre_step(self,
+    def pre_step(self, var: Var) -> Var | None:
         """This runs once before each step, whereas `closure` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return
+        return var
 
-    def step(self,
-        ret = self.pre_step(
-        if isinstance(ret,
+    def step(self, var):
+        ret = self.pre_step(var)
+        if isinstance(ret, Var): var = ret
 
-        if
-        params, closure =
+        if var.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = var.params, var.closure
 
 
         def modified_closure(backward=True):
-            loss, grad = self.closure(backward, closure, params,
+            loss, grad = self.closure(backward, closure, params, var)
 
             if grad is not None:
                 for p,g in zip(params, grad):
@@ -42,8 +42,8 @@ class Reformulation(Module, ABC):
 
             return loss
 
-
-        return
+        var.closure = modified_closure
+        return var
 
 
 def _decay_sigma_(self: Module, params):
@@ -58,7 +58,7 @@ def _generate_perturbations_to_state_(self: Module, params: TensorList, n_sample
     for param, prt in zip(params, zip(*perturbations)):
         self.state[param]['perturbations'] = prt
 
-def _clear_state_hook(optimizer: Modular,
+def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
     for m in optimizer.unrolled_modules:
         if m is not self:
             m.reset()
@@ -85,12 +85,12 @@ class GaussianHomotopy(Reformulation):
         else: self.global_state['generator'] = None
         return self.global_state['generator']
 
-    def pre_step(self,
-        params = TensorList(
+    def pre_step(self, var):
+        params = TensorList(var.params)
         settings = self.settings[params[0]]
         n_samples = settings['n_samples']
-        init_sigma = self.
-        sigmas = self.get_state('sigma',
+        init_sigma = [self.settings[p]['init_sigma'] for p in params]
+        sigmas = self.get_state(params, 'sigma', init=init_sigma)
 
         if any('perturbations' not in self.state[p] for p in params):
             generator = self._get_generator(settings['seed'], params)
@@ -109,9 +109,9 @@ class GaussianHomotopy(Reformulation):
         tol = settings['tol']
         if tol is not None and not decayed:
             if not any('prev_params' in self.state[p] for p in params):
-                prev_params = self.get_state('prev_params',
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
             else:
-                prev_params = self.get_state('prev_params',
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
             s = params - prev_params
 
             if s.abs().global_max() <= tol:
@@ -124,10 +124,10 @@ class GaussianHomotopy(Reformulation):
             generator = self._get_generator(settings['seed'], params)
             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
         if settings['clear_state']:
-
+            var.post_step_hooks.append(partial(_clear_state_hook, self=self))
 
     @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, var):
         params = TensorList(params)
 
         settings = self.settings[params[0]]
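For reference, the `Reformulation` contract after this change: subclasses implement `closure(backward, closure, params, var)` returning `(loss, grad)` (grad may be None when backward is False) and can optionally override `pre_step(var)`. A minimal sketch, assuming `Reformulation` is importable from this module; the objective scaling it performs is made up for illustration.

import torch
from torchzero.modules.smoothing.gaussian import Reformulation  # assumed import path

class ScaledObjective(Reformulation):
    """Hypothetical reformulation that minimizes c * f(x) instead of f(x)."""
    def __init__(self, c: float = 0.5):
        super().__init__(dict(c=c))

    def closure(self, backward, closure, params, var):
        c = self.settings[params[0]]['c']
        loss = closure(backward)  # closure(True) also populates .grad on params
        grad = None
        if backward:
            grad = [c * (p.grad if p.grad is not None else torch.zeros_like(p)) for p in params]
        return loss * c, grad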
torchzero/modules/smoothing/laplacian.py

@@ -67,7 +67,7 @@ class LaplacianSmoothing(Transform):
             minimum number of elements in a parameter to apply laplacian smoothing to.
             Only has effect if `layerwise` is True. Defaults to 4.
         target (str, optional):
-            what to set on
+            what to set on var.
 
     Reference:
         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
@@ -82,19 +82,17 @@ class LaplacianSmoothing(Transform):
 
 
     @torch.no_grad
-    def
-    layerwise =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        layerwise = settings[0]['layerwise']
 
         # layerwise laplacian smoothing
         if layerwise:
 
             # precompute the denominator for each layer and store it in each parameters state
             smoothed_target = TensorList()
-            for p, t in zip(params, tensors):
-
-
-            state = self.state[p]
-            if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
+            for p, t, state, setting in zip(params, tensors, states, settings):
+                if p.numel() > setting['min_numel']:
+                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
                     smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                 else:
                     smoothed_target.append(t)
@@ -106,7 +104,7 @@ class LaplacianSmoothing(Transform):
         # precompute full denominator
         tensors = TensorList(tensors)
         if self.global_state.get('full_denominator', None) is None:
-            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(),
+            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])
 
         # apply the smoothing
         vec = tensors.to_vec()
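The other recurring change is the `Transform` interface: the per-tensor hook is now `apply(self, tensors, params, grads, loss, states, settings)`, where `states` and `settings` are per-parameter lists aligned with `tensors`. A minimal sketch of a custom transform against that signature; the class itself is made up, and `uses_grad=False` mirrors the constructors shown elsewhere in this diff.

import torch
from torchzero.core import Transform  # as imported elsewhere in this diff

class ClampUpdate(Transform):
    """Hypothetical transform that clamps each update tensor to [-value, value]."""
    def __init__(self, value: float = 1.0):
        super().__init__(dict(value=value), uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # settings[i] is the settings dict for params[i]; states[i] is its per-parameter state
        return [t.clamp_(-s['value'], s['value']) for t, s in zip(tensors, settings)]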
torchzero/modules/weight_decay/__init__.py

@@ -1 +1 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, NormalizedWeightDecay
torchzero/modules/weight_decay/weight_decay.py

@@ -1,9 +1,11 @@
 from collections.abc import Iterable, Sequence
+from typing import Literal
 
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
 
 @torch.no_grad
 def weight_decay_(
@@ -25,12 +27,44 @@ class WeightDecay(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-    weight_decay =
-    ord =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+        ord = settings[0]['ord']
 
         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
+class NormalizedWeightDecay(Transform):
+    def __init__(
+        self,
+        weight_decay: float = 0.1,
+        ord: int = 2,
+        norm_input: Literal["update", "grad", "params"] = "update",
+        target: Target = "update",
+    ):
+        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+
+    @torch.no_grad
+    def apply(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+
+        ord = settings[0]['ord']
+        norm_input = settings[0]['norm_input']
+
+        if norm_input == 'update': src = TensorList(tensors)
+        elif norm_input == 'grad':
+            assert grads is not None
+            src = TensorList(grads)
+        elif norm_input == 'params':
+            src = TensorList(params)
+        else:
+            raise ValueError(norm_input)
+
+        norm = src.global_vector_norm(ord)
+
+        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
+
+
 @torch.no_grad
 def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
     """directly decays weights in-place"""
@@ -44,9 +78,9 @@ class DirectWeightDecay(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self,
-        weight_decay = self.get_settings('weight_decay',
-        ord = self.settings[
+    def step(self, var):
+        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+        ord = self.settings[var.params[0]]['ord']
 
-        decay_weights_(
-        return
+        decay_weights_(var.params, weight_decay, ord)
+        return var
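A usage note on the helpers above: `decay_weights_` remains a standalone in-place decay, and the new `NormalizedWeightDecay` scales the decay coefficient by the global norm of the update, gradient, or parameters, chosen via `norm_input`. A brief sketch; how the module is composed into an optimizer is omitted because it is not shown in this diff, and the tensors below are invented for illustration.

import torch
from torchzero.modules.weight_decay import decay_weights_, NormalizedWeightDecay

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]

# standalone in-place decay, as re-exported by the __init__ above
with torch.no_grad():
    decay_weights_(params, weight_decay=0.01)

# decay scaled by the global gradient norm (constructor arguments as defined above)
nwd = NormalizedWeightDecay(weight_decay=0.1, ord=2, norm_input="grad")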
torchzero/modules/wrappers/optim_wrapper.py

@@ -24,8 +24,8 @@ class Wrap(Module):
         return super().set_param_groups(param_groups)
 
     @torch.no_grad
-    def step(self,
-        params =
+    def step(self, var):
+        params = var.params
 
         # initialize opt on 1st step
         if self.optimizer is None:
@@ -35,18 +35,18 @@ class Wrap(Module):
 
         # set grad to update
         orig_grad = [p.grad for p in params]
-        for p, u in zip(params,
+        for p, u in zip(params, var.get_update()):
             p.grad = u
 
         # if this module is last, can step with _opt directly
         # direct step can't be applied if next module is LR but _opt doesn't support lr,
         # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if
-        (
+        if var.is_last and (
+            (var.last_module_lrs is None)
             or
-        (('lr' in self.optimizer.defaults) and (len(set(
+            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
         ):
-            lr = 1 if
+            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
 
             # update optimizer lr with desired lr
             if lr != 1:
@@ -68,19 +68,19 @@ class Wrap(Module):
             for p, g in zip(params, orig_grad):
                 p.grad = g
 
-
-            return
+            var.stop = True; var.skip_update = True
+            return var
 
         # this is not the last module, meaning update is difference in parameters
         params_before_step = [p.clone() for p in params]
         self.optimizer.step() # step and update params
         for p, g in zip(params, orig_grad):
             p.grad = g
-
+        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
         for p, o in zip(params, params_before_step):
             p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return
+        return var
 
     def reset(self):
         super().reset()
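The non-terminal branch of `Wrap.step` recovers the wrapped optimizer's step as an explicit update: snapshot the parameters, let the wrapped optimizer step, take before minus after as the update, then roll the parameters back. The same trick in isolation, as a hedged sketch in plain PyTorch; `SGD` and the constant gradient are only for illustration.

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
opt = torch.optim.SGD(params, lr=0.1)

for p in params:
    p.grad = torch.ones_like(p)

before = [p.detach().clone() for p in params]
opt.step()

# update = before - after, i.e. the step the wrapped optimizer wanted to take
update = list(torch._foreach_sub(before, [p.detach() for p in params]))

# roll the parameters back so the update can be applied by later modules instead
with torch.no_grad():
    for p, b in zip(params, before):
        p.copy_(b)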
torchzero/optim/wrappers/directsearch.py (new file)

@@ -0,0 +1,244 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import directsearch
+import numpy as np
+import torch
+from directsearch.ds import DEFAULT_PARAMS
+
+from ...modules.second_order.newton import tikhonov_
+from ...utils import Optimizer, TensorList
+
+
+def _ensure_float(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return x.item()
+    return float(x)
+
+def _ensure_numpy(x):
+    if isinstance(x, torch.Tensor): return x.detach().cpu()
+    if isinstance(x, np.ndarray): return x
+    return np.array(x)
+
+
+Closure = Callable[[bool], Any]
+
+
+class DirectSearch(Optimizer):
+    """Use directsearch as pytorch optimizer.
+
+    Note that this performs full minimization on each step,
+    so usually you would want to perform a single step, although performing multiple steps will refine the
+    solution.
+
+    Args:
+        params (_type_): _description_
+        maxevals (_type_, optional): _description_. Defaults to DEFAULT_PARAMS['maxevals'].
+    """
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+        sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        use_stochastic_three_points = DEFAULT_PARAMS['use_stochastic_three_points'], # Boolean for a specific method
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+
+class DirectSearchDS(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+class DirectSearchProbabilistic(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_probabilistic_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+class DirectSearchSubspace(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        rho = DEFAULT_PARAMS['rho'], # Forcing function
+        sketch_dim = DEFAULT_PARAMS['sketch_dim'], # Target dimension for sketching
+        sketch_type = DEFAULT_PARAMS['sketch_type'], # Sketching technique
+        poll_type = DEFAULT_PARAMS['poll_type'], # Polling direction type
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_max = DEFAULT_PARAMS['alpha_max'], # Maximum value for the stepsize
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        gamma_inc = DEFAULT_PARAMS['gamma_inc'], # Increasing factor for the stepsize
+        gamma_dec = DEFAULT_PARAMS['gamma_dec'], # Decreasing factor for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+        rho_uses_normd = DEFAULT_PARAMS['rho_uses_normd'], # Forcing function based on direction norm
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_subspace_directsearch(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
+
+
+
+class DirectSearchSTP(Optimizer):
+    def __init__(
+        self,
+        params,
+        maxevals = DEFAULT_PARAMS['maxevals'], # Maximum number of function evaluations
+        alpha0 = DEFAULT_PARAMS['alpha0'], # Original stepsize value
+        alpha_min = DEFAULT_PARAMS['alpha_min'], # Minimum value for the stepsize
+        verbose = DEFAULT_PARAMS['verbose'], # Display information about the method
+        print_freq = DEFAULT_PARAMS['print_freq'], # How frequently to display information
+    ):
+        super().__init__(params, {})
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        res = directsearch.solve_stp(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.f
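A quick usage sketch for the new wrappers. Per the docstring, one `step` runs a full derivative-free minimization, so a single call is usually enough, and the closure is only ever evaluated with `backward=False`. The quadratic objective below is invented for illustration; the import path follows the file location shown in this diff.

import torch
from torchzero.optim.wrappers.directsearch import DirectSearch

x = torch.nn.Parameter(torch.tensor([1.5, -2.0]))
opt = DirectSearch([x], maxevals=2000)

def closure(backward=True):
    # derivative-free: the wrapper only calls closure(False), so no .backward() is needed
    return ((x - 3.0) ** 2).sum()

final_loss = opt.step(closure)  # returns res.f, the best objective value found
print(final_loss, x.detach())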