torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/optimizers/adam.py
@@ -0,0 +1,118 @@
from collections import abc

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule

def _adam_step(ascent: TensorList, exp_avg: TensorList, exp_avg_sq: TensorList, alpha, beta1, beta2, eps, step: int, max_exp_avg_sqs: TensorList | None, params: TensorList | None = None):
    # Decay the first and second moment running average coefficient
    exp_avg.lerp_compat_(ascent, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(ascent, ascent.conj(), value=1 - beta2)

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if max_exp_avg_sqs is not None:
        max_exp_avg_sqs.maximum_(exp_avg_sq)
        denom = max_exp_avg_sqs.sqrt().div_(bias_correction2**0.5).add_(eps)
    else:
        denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5).add_(eps)

    if params is None:
        return (exp_avg / denom).mul_(alpha / bias_correction1)

    # else directly apply the update to params
    params.addcdiv_(exp_avg, denom, value = -(alpha / bias_correction1))
    return None


class Adam(OptimizerModule):
    """Adam. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`.

    Args:
        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
        amsgrad (bool, optional):
            whether to use the AMSGrad variant of this algorithm from
            On the Convergence of Adam and Beyond (default: False).
        alpha (float, optional): learning rate. Defaults to 1.
    """
    def __init__(self, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, alpha: float = 1, amsgrad=False):
        defaults = dict(alpha = alpha, beta1=beta1, beta2=beta2, eps=eps)
        super().__init__(defaults)

        self.cur_step = 1
        self.amsgrad = amsgrad

    @torch.no_grad
    def step(self, state):
        # Adam step is a bit different from other optimizer steps -
        # due to how common it is, I implemented two additional optimizations:

        # 1st - if next module is None, or if next module is LR and the module after it is None,
        # this will directly update parameters using `addcdiv_`

        # 2nd - if next module is LR, adam will "fuse" with it to avoid an additional add operation.

        # the optimizations are quite verbose and seem to barely have any effect, so I probably won't implement
        # this for other modules

        settings = self.get_all_group_keys()

        if self.amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sqs = self.get_state_keys('exp_avg', 'exp_avg_sq', 'max_exp_avg_sqs')
        else:
            exp_avg, exp_avg_sq = self.get_state_keys('exp_avg', 'exp_avg_sq')
            max_exp_avg_sqs = None

        params = None

        # apply addcdiv if next module is None
        if self.next_module is None: params = self.get_params()

        # fuse with the LR module if it is next
        if self.next_module is not None and self.next_module.IS_LR_MODULE:
            alpha = self.next_module.get_group_key('lr') * settings['alpha']
            self.next_module._skip = True # type:ignore

            # apply addcdiv if the module after LR is None
            if self.next_module.next_module is None: params = self.get_params()

        else:
            alpha = settings['alpha']

        # if ascent is None, get params so their gradient can be used as the initial ascent
        if state.ascent is None:
            if params is None: pg = self.get_params()
            else: pg = params
        else:
            pg = None

        ret = _adam_step(
            ascent=state.maybe_use_grad_(pg),
            exp_avg = exp_avg,
            exp_avg_sq = exp_avg_sq,
            alpha = alpha,
            beta1 = settings['beta1'],
            beta2 = settings['beta2'],
            eps = settings['eps'],
            step = self.cur_step,
            max_exp_avg_sqs = max_exp_avg_sqs,
            params = params
        )

        self.cur_step += 1
        if params is None:
            assert ret is not None
            state.ascent = ret
            return self._update_params_or_step_with_next(state)

        # next module is either None or LR
        if self.next_module is None: return state.get_loss()

        # step with LR, which has _skip = True so it won't apply lr, but may step with the scheduler
        self.next_module._update(state, None) # type:ignore
        return state.get_loss()
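For reference, a minimal single-tensor sketch of the math that `_adam_step` applies across a whole TensorList, with the module plumbing (fused LR, direct `addcdiv_` into params) stripped out. The function name and toy call below are illustrative only, not part of torchzero's API.

import torch

def adam_update(g, exp_avg, exp_avg_sq, alpha, beta1, beta2, eps, step, max_exp_avg_sq=None):
    exp_avg.lerp_(g, 1 - beta1)                              # first moment EMA
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second moment EMA
    bc1 = 1 - beta1 ** step
    bc2 = 1 - beta2 ** step
    if max_exp_avg_sq is not None:                           # AMSGrad: use the running max
        max_exp_avg_sq.copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
        denom = max_exp_avg_sq.sqrt() / bc2 ** 0.5 + eps
    else:
        denom = exp_avg_sq.sqrt() / bc2 ** 0.5 + eps
    return (exp_avg / denom) * (alpha / bc1)

g = torch.randn(4)
print(adam_update(g, torch.zeros(4), torch.zeros(4), alpha=1e-3,
                  beta1=0.9, beta2=0.999, eps=1e-8, step=1))  # first step is roughly +-alpha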
torchzero/modules/optimizers/lion.py
@@ -0,0 +1,28 @@
import torch

from ...core import OptimizerModule
from ...tensorlist import TensorList


def _lion_step_(ascent: TensorList, ema: TensorList, beta1, beta2):
    update = ema.lerp_compat(ascent, 1 - beta1).sign_()
    ema.lerp_compat_(ascent, 1 - beta2)
    return update

class Lion(OptimizerModule):
    """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

    Args:
        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
        beta2 (float, optional): momentum factor. Defaults to 0.99.
    """

    def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
        defaults = dict(beta1=beta1, beta2=beta2)
        super().__init__(defaults)

    @torch.no_grad
    def _update(self, state, ascent):
        beta1, beta2 = self.get_group_keys('beta1', 'beta2')
        ema = self.get_state_key('ema')
        return _lion_step_(ascent, ema, beta1, beta2)
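A plain-tensor sketch of the same rule as `_lion_step_`: the update is the sign of an interpolation between the EMA and the current gradient, while the EMA itself is tracked with the slower `beta2`. This is illustrative only; the module version operates on a TensorList and is scaled by a later LR module.

import torch

def lion_update(g, ema, beta1=0.9, beta2=0.99):
    update = torch.lerp(ema, g, 1 - beta1).sign()  # direction: sign of the interpolated momentum
    ema.lerp_(g, 1 - beta2)                        # momentum EMA uses the slower beta2
    return update

g, ema = torch.randn(5), torch.zeros(5)
print(lion_update(g, ema))  # entries are -1, 0 or 1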
torchzero/modules/optimizers/rmsprop.py
@@ -0,0 +1,51 @@
from collections import abc

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule


def _rmsprop_step_(ascent: TensorList, mean_sqr: TensorList, smoothing, eps: TensorList):
    mean_sqr.mul_(smoothing).addcmul_(ascent, ascent, value = 1 - smoothing)
    return ascent.div_(mean_sqr.sqrt().add_(eps))

def _centered_rmsprop_step_(ascent: TensorList, mean_sqr: TensorList, mean: TensorList, smoothing, eps: TensorList):
    mean_sqr.mul_(smoothing).addcmul_(ascent, ascent, value = 1 - smoothing)
    mean.lerp_compat_(ascent, 1 - smoothing)
    return ascent.div_(mean_sqr.addcmul(mean, mean, value=-1).sqrt_().add_(eps))

class RMSProp(OptimizerModule):
    """
    Divides the ascent direction by the square root of the running average of its square.

    Exactly matches `torch.optim.RMSprop`.

    Args:
        smoothing (float, optional):
            smoothing constant (decay rate of the running average of the squared ascent).
            Defaults to 0.99.
        eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-8.
        centered (bool, optional):
            if True, computes centered RMSProp, where the gradient is normalized by an estimate of its variance.
            Defaults to False.

    reference
        https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
    """
    def __init__(self, smoothing: float = 0.99, eps: float = 1e-8, centered=False):

        defaults = dict(smoothing = smoothing, eps = eps)
        super().__init__(defaults)
        self.centered = centered

    @torch.no_grad
    def _update(self, state, ascent):
        settings = self.get_all_group_keys()
        if self.centered:
            mean, mean_sqr = self.get_state_keys('mean', 'mean_sqr')
            updated_direction = _centered_rmsprop_step_(ascent, mean_sqr, mean, settings['smoothing'], settings['eps'])
        else:
            mean_sqr = self.get_state_key('mean_sqr')
            updated_direction = _rmsprop_step_(ascent, mean_sqr, settings['smoothing'], settings['eps'])
        return updated_direction
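A single-tensor sketch of the plain and centered denominators computed by `_rmsprop_step_` and `_centered_rmsprop_step_`: centering subtracts the squared running mean, so the update is divided by an estimate of the gradient's standard deviation rather than its RMS. Names are illustrative, not the torchzero API.

import torch

def rmsprop_update(g, mean_sqr, smoothing=0.99, eps=1e-8, mean=None):
    mean_sqr.mul_(smoothing).addcmul_(g, g, value=1 - smoothing)  # running E[g^2]
    if mean is None:
        return g / (mean_sqr.sqrt() + eps)                        # plain RMSProp
    mean.lerp_(g, 1 - smoothing)                                  # running E[g]
    return g / ((mean_sqr - mean * mean).sqrt() + eps)            # centered: divide by std estimate

g = torch.randn(3)
print(rmsprop_update(g, torch.zeros(3)))
print(rmsprop_update(g, torch.zeros(3), mean=torch.zeros(3)))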
torchzero/modules/optimizers/rprop.py
@@ -0,0 +1,99 @@
from collections import abc

import torch

from ...tensorlist import TensorList, where
from ...core import OptimizerModule


def _bool_ones_like(x):
    return torch.ones_like(x, dtype=torch.bool)

class Rprop(OptimizerModule):
    """
    Resilient propagation. The update magnitude gets multiplied by `nplus` if the gradient didn't change sign,
    or by `nminus` if it did. The update is then applied with the sign of the current gradient.

    Additionally, if the gradient changes sign, the update for that weight is reverted,
    and its magnitude won't change on the next step.

    Compared to PyTorch, this also backtracks the update when the sign changes.
    To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.

    Args:
        nplus (float): multiplicative increase factor for when the ascent didn't change sign (default: 1.2).
        nminus (float): multiplicative decrease factor for when the ascent changed sign (default: 0.5).
        lb (float): minimum step size, can be None (default: 1e-6).
        ub (float): maximum step size, can be None (default: 50).
        backtrack (bool):
            if True, undoes the last weight update when the ascent sign changes, otherwise sets the update to 0.
            When this is False, this exactly matches PyTorch Rprop. (default: True)
        alpha (float): learning rate (default: 1).

    reference
        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float | None = 1e-6,
        ub: float | None = 50,
        backtrack=True,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
        super().__init__(defaults)
        self.current_step = 0
        self.backtrack = backtrack

    @torch.no_grad
    def _update(self, state, ascent):
        params = self.get_params()

        sign = ascent.sign_()
        nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
        prev, allowed, magnitudes = self.get_state_keys(
            'prev', 'allowed', 'magnitudes',
            inits = [torch.zeros_like, _bool_ones_like, torch.zeros_like],
            params=params
        )

        # initialize on 1st step
        if self.current_step == 0:
            magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
            ascent = magnitudes * sign
            prev.copy_(ascent)
            self.current_step += 1
            return ascent

        mul = (sign * prev).mul_(allowed)

        sign_changed = mul < 0
        sign_same = mul > 0
        zeroes = mul == 0

        mul.fill_(1)
        mul.masked_fill_(sign_changed, nminus)
        mul.masked_fill_(sign_same, nplus)

        # multiply magnitudes based on sign change and clamp to bounds
        magnitudes.mul_(mul).clamp_(lb, ub)

        # revert update if sign changed
        if self.backtrack:
            ascent = sign.mul_(magnitudes)
            ascent.masked_set_(sign_changed, prev.neg_())
        else:
            ascent = sign.mul_(magnitudes * ~sign_changed)

        # update allowed to only have weights where the last update wasn't reverted
        allowed.set_(sign_same | zeroes)

        prev.copy_(ascent)
        self.current_step += 1
        return ascent
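An element-wise sketch of the rule described in the docstring above: grow the per-weight step size by `nplus` when the sign agrees with the previous update, shrink it by `nminus` and backtrack when it flips. This toy version omits the `allowed` mask that freezes the magnitude on the step right after a revert; names are illustrative only.

import torch

def rprop_update(g, prev_update, magnitudes, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    sign = g.sign()
    agree = sign * prev_update.sign()
    magnitudes = torch.where(agree > 0, magnitudes * nplus, magnitudes)   # same sign: grow
    magnitudes = torch.where(agree < 0, magnitudes * nminus, magnitudes)  # flipped: shrink
    magnitudes = magnitudes.clamp(lb, ub)
    update = sign * magnitudes
    update = torch.where(agree < 0, -prev_update, update)                 # backtrack on sign flips
    return update, magnitudes

g = torch.tensor([0.3, -0.2, 0.1])
upd, mags = rprop_update(g, prev_update=torch.tensor([0.1, 0.1, -0.1]),
                         magnitudes=torch.full((3,), 0.1))
print(upd, mags)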
torchzero/modules/optimizers/sgd.py
@@ -0,0 +1,54 @@
import typing as T

import torch

from ...core import OptimizerModule
from ..momentum.momentum import _heavyball_step, _nesterov_step_

class SGD(OptimizerModule):
    """Same as `torch.optim.SGD` but as an optimizer module. Exactly matches `torch.optim.SGD`, except
    Nesterov momentum additionally supports dampening, and negative momentum is allowed.

    Args:
        momentum (float, optional): momentum. Defaults to 0.
        dampening (float, optional): momentum dampening. Defaults to 0.
        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
        nesterov (bool, optional):
            enables Nesterov momentum, otherwise uses heavy-ball momentum. Defaults to False.
        alpha (float, optional): learning rate. Defaults to 1.
    """
    def __init__(
        self,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        alpha: float = 1,
    ):

        defaults = dict(alpha=alpha, momentum=momentum, dampening=dampening, weight_decay=weight_decay)
        super().__init__(defaults)
        self.nesterov = nesterov
        self.current_step = 0

    @torch.no_grad
    def _update(self, state, ascent):
        params = self.get_params()
        settings = self.get_all_group_keys()

        if any(i != 0 for i in settings['weight_decay']):
            ascent += params * settings['weight_decay']

        if any(i != 1 for i in settings['alpha']):
            ascent *= settings['alpha']

        if any(i != 0 for i in settings['momentum']):
            velocity = self.get_state_key('velocity', init = torch.zeros_like if self.nesterov else ascent)
            # consistency with pytorch, which on the first step only initializes momentum
            if self.current_step > 0 or self.nesterov:
                # the nesterov step can be done in-place, polyak (heavy-ball) returns a new direction
                if self.nesterov: _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
                else: ascent = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])

        self.current_step += 1
        return ascent
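The `_heavyball_step` and `_nesterov_step_` helpers live in torchzero/modules/momentum/momentum.py, which is not part of this excerpt. For reference, a standard heavy-ball / Nesterov-with-dampening formulation (matching `torch.optim.SGD`'s buffer update) looks like the sketch below; the actual torchzero implementation may differ in detail.

import torch

def heavyball(ascent, velocity, momentum=0.9, dampening=0.0):
    velocity.mul_(momentum).add_(ascent, alpha=1 - dampening)  # v = m*v + (1-d)*g
    return velocity.clone()

def nesterov(ascent, velocity, momentum=0.9, dampening=0.0):
    velocity.mul_(momentum).add_(ascent, alpha=1 - dampening)
    return ascent.add(velocity, alpha=momentum)                # g + m*v look-ahead

g = torch.randn(3)
print(heavyball(g.clone(), torch.zeros(3)))
print(nesterov(g.clone(), torch.zeros(3)))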
torchzero/modules/orthogonalization/newtonschulz.py
@@ -0,0 +1,159 @@
"""
Newton-Schulz iteration code is taken from https://github.com/KellerJordan/Muon

Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and Franz Cecista and Laker Newhouse and Jeremy Bernstein.
Muon: An optimizer for hidden layers in neural networks (2024). URL: https://kellerjordan.github.io/posts/muon
"""
from collections.abc import Iterable

import torch

from ...core import OptimizerModule, _Targets
# from ...utils.compile import maybe_compile

def _zeropower_via_newtonschulz5(G, steps):
    """
    code from https://github.com/KellerJordan/Muon

    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T

    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(0) > G.size(1):
        X = X.T

    return X

_compiled_zeropower_via_newtonschulz5 = torch.compile(_zeropower_via_newtonschulz5)


def zeropower_via_newtonschulz_(params: Iterable[torch.Tensor], steps: int = 6, adaptive = False, compiled = True):
    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of the gradients of an iterable of parameters.

    This sets gradients in-place.

    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

    The orthogonalization code is taken from https://github.com/KellerJordan/Muon

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        steps (int):
            The number of Newton-Schulz iterations to run (6 is probably always enough). Defaults to 6.
        adaptive (bool, optional):
            Enables adaptation to the scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to False.
        compiled (bool, optional):
            Uses the compiled Newton-Schulz iteration function. Faster, but won't work on Windows. Defaults to True.
    """
    if compiled: fn = _compiled_zeropower_via_newtonschulz5
    else: fn = _zeropower_via_newtonschulz5
    for p in params:
        if p.grad is not None and p.grad.ndim >= 2 and min(p.grad.shape) >= 2:
            G = p.grad.view(p.grad.shape[0], -1)
            X = fn(G, steps)

            if adaptive:
                # this is from https://github.com/leloykun/adaptive-muon
                X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X) # Adaptive scaling, `(G * X).sum() * X` == (G.T @ X).trace() * X

            p.grad = X.reshape_as(p.grad).to(p.grad, copy=False)


class ZeropowerViaNewtonSchulz(OptimizerModule):
    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

    To disable orthogonalization for a parameter, put it into a parameter group with "newtonshultz" = False.
    The Muon page says that embeddings and classifier heads should not be orthogonalized.

    The orthogonalization code is taken from https://github.com/KellerJordan/Muon.

    Note that unlike this module, Muon also uses Adam for gradients that are not orthogonalized,
    so I'd still recommend using it. Maybe use `Wrap` to wrap it into a module (I will make Muon
    with selectable modules to optimize non-Muon params soon).

    However, not using Adam, or putting an Adam module after this to apply it to ALL updates, both seem
    to work quite well too.

    Args:
        ns_steps (int, optional):
            The number of Newton-Schulz iterations to run (6 is probably always enough). Defaults to 6.
        adaptive (bool, optional):
            Enables adaptation to the scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to False.
        compiled (bool, optional):
            Uses the compiled Newton-Schulz iteration function. Faster, but won't work on Windows. Defaults to True.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent.

            "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
    """
    def __init__(self, ns_steps = 6, adaptive = False, compiled=True, target: _Targets = 'ascent'):
        defaults = dict(newtonshultz = True, ns_steps=ns_steps, adaptive=adaptive)
        super().__init__(defaults, target=target)

        if compiled: self._zeropower_via_newtonschulz5 = _compiled_zeropower_via_newtonschulz5
        else: self._zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5

    def _update(self, state, ascent):
        toggle, ns_steps, adaptive = self.get_group_keys('newtonshultz', 'ns_steps', 'adaptive', cls=list)

        for asc, enable, steps, ada in zip(ascent, toggle, ns_steps, adaptive):
            if enable and len([i for i in asc.shape if i > 1]) != 0:
                G = asc.view(asc.shape[0], -1)
                X = self._zeropower_via_newtonschulz5(G, steps)

                if ada:
                    # this is from https://github.com/leloykun/adaptive-muon
                    X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X) # Adaptive scaling, `(G * X).sum() * X` == (G.T @ X).trace() * X

                asc.set_(X.reshape_as(asc).to(asc, copy=False)) # type:ignore

        return ascent



class DualNormCorrection(OptimizerModule):
    """Dual norm correction from https://github.com/leloykun/adaptive-muon.

    Description from the page:

    Single-line modification to any (dualizer-based) optimizer that allows the optimizer to adapt to the scale of the gradients as they change during training.
    This is done by scaling the dualized gradient by the clipped dual norm of the original gradient.
    """
    def __init__(self, adaptive_scale_min: int | None = -1, adaptive_scale_max: int | None = 1):
        defaults = dict(adaptive_scale_min = adaptive_scale_min, adaptive_scale_max = adaptive_scale_max)
        super().__init__(defaults)

    def _update(self, state, ascent):
        params = self.get_params()
        adaptive_scale_min, adaptive_scale_max = self.get_group_keys('adaptive_scale_min', 'adaptive_scale_max')

        for asc, grad, min, max in zip(ascent, state.maybe_compute_grad_(params), adaptive_scale_min, adaptive_scale_max):
            if len([i for i in asc.shape if i > 1]) != 0:
                scale = torch.einsum('ij,ij->', grad.view(grad.shape[0], -1), asc.view(asc.shape[0], -1))
                if min is not None or max is not None: scale = scale.clip(min, max)
                asc *= scale

        return ascent
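A usage sketch for the function-style helper `zeropower_via_newtonschulz_` defined above. The deep import path is assumed from the file listing at the top of this diff; `compiled=False` avoids `torch.compile` (e.g. on Windows). Bias vectors are left untouched because the helper skips gradients with fewer than 2 dimensions.

import torch
from torchzero.modules.orthogonalization.newtonschulz import zeropower_via_newtonschulz_

model = torch.nn.Linear(16, 8)
model(torch.randn(4, 16)).square().mean().backward()
zeropower_via_newtonschulz_(model.parameters(), steps=6, adaptive=False, compiled=False)
print(model.weight.grad.shape)  # weight gradient replaced by its orthogonalized counterpart in-place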
torchzero/modules/orthogonalization/svd.py
@@ -0,0 +1,86 @@
"""Orthogonalization code adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
"""
import logging
from collections.abc import Iterable, Sequence

import torch

from ...core import OptimizerModule, _Targets

@torch.no_grad()
def _orthogonalize_update_(updates: Sequence[torch.Tensor], toggle = None, warn_fail=True) -> Sequence[torch.Tensor]:
    """adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers"""
    if toggle is None: toggle = [True] * len(updates)

    # Orthogonalise the gradients using SVD
    for grad, orth in zip(updates, toggle):
        if orth and grad.ndim > 1:
            G: torch.Tensor = grad.view(grad.shape[0], -1)
            orth_G: torch.Tensor | None = None
            try:
                u, s, vt = torch.linalg.svd(G, full_matrices=False) # pylint:disable=not-callable
                orth_G = u @ vt
            except RuntimeError:
                # if warn: logging.warning('Failed to perform SVD, adding some noise.')
                try:
                    u, s, v = torch.svd_lowrank(
                        G,
                        q=1, # assume rank is at least 1
                        M=1e-4 * G.mean() * torch.randn_like(G))
                    orth_G = u @ v.T
                except RuntimeError:
                    if warn_fail: logging.error(('Failed to perform SVD with noise,'
                                                 ' skipping gradient orthogonalisation'))
            if orth_G is not None:
                grad.set_(orth_G.reshape_as(grad)) # type:ignore

    return updates

def orthogonalize_grad_(params: Iterable[torch.Tensor], warn_fail=False):
    """Orthogonalizes gradients of an iterable of parameters.

    This updates gradients in-place.

    The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        warn_fail (bool, optional):
            whether to print a warning when orthogonalization fails and gradients are not
            orthogonalized. Defaults to False.
    """
    grads = [p.grad for p in params if p.grad is not None]
    _orthogonalize_update_(grads, warn_fail=warn_fail)

class Orthogonalize(OptimizerModule):
    """Orthogonalizes the update using SVD.

    To disable orthogonalization for a parameter, put it into a parameter group with "orth" = False.

    The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers

    Tip: :py:class:`tz.m.ZeropowerViaNewtonSchulz` is a significantly faster version of this.

    Args:
        warn_fail (bool, optional):
            whether to print a warning when orthogonalization fails and gradients are not
            orthogonalized. Defaults to True.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent.

            "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
    """
    def __init__(self, warn_fail=True, target: _Targets = 'ascent'):
        defaults = dict(orth = True)
        super().__init__(defaults, target = target)
        self.warn_fail = warn_fail

    def _update(self, state, ascent):
        toggle = self.get_group_key('orth', cls=list)
        _orthogonalize_update_(ascent, toggle, self.warn_fail)
        return ascent
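What `_orthogonalize_update_` does to each matrix-shaped update, in isolation (the noisy `torch.svd_lowrank` fallback is not shown): the SVD factors are recombined with all singular values set to 1, which makes the rows of the result orthonormal.

import torch

G = torch.randn(8, 16)
u, s, vt = torch.linalg.svd(G, full_matrices=False)
orth_G = u @ vt  # same singular vectors, singular values replaced by 1
print(torch.allclose(orth_G @ orth_G.T, torch.eye(8), atol=1e-5))  # True: rows are orthonormal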
torchzero/modules/regularization/__init__.py
@@ -0,0 +1,22 @@
r"""
This includes regularization modules like weight decay.
"""
from .dropout import Dropout
from .noise import AddNoise, Random, add_noise_
from .normalization import (
    Centralize,
    ClipNorm,
    ClipValue,
    Normalize,
    centralize_grad_,
    clip_grad_norm_,
    clip_grad_value_,
    normalize_grad_,
)
from .weight_decay import (
    WeightDecay,
    l1_regularize_,
    l2_regularize_,
    weight_decay_penalty,
)
from .ortho_grad import OrthoGrad, orthograd_
torchzero/modules/regularization/dropout.py
@@ -0,0 +1,34 @@
import typing as T
from collections import abc

import torch

from ...tensorlist import Distributions, TensorList
from ...core import OptimizerModule


class Dropout(OptimizerModule):
    """
    Applies dropout to the update - sets random elements to 0.

    This can be used to apply learning rate dropout if put after other modules, or gradient dropout
    if put first.

    Args:
        p (float, optional): probability to replace an update value with zero. Defaults to 0.5.

    reference
        *Lin, H., Zeng, W., Zhuang, Y., Ding, X., Huang, Y., & Paisley, J. (2022).
        Learning rate dropout. IEEE Transactions on Neural Networks and Learning Systems,
        34(11), 9029-9039.*
    """
    def __init__(self, p: float = 0.5):
        defaults = dict(p = p)
        super().__init__(defaults)

    @torch.no_grad
    def _update(self, state, ascent):
        p = self.get_group_key('p')

        ascent *= ascent.bernoulli_like(p)
        return ascent
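The core of the Dropout module on a plain tensor. Whether `TensorList.bernoulli_like(p)` draws the keep-mask or the drop-mask depends on its implementation, which is not part of this excerpt; the sketch below uses `torch.bernoulli`, where an entry survives with probability `p` and is zeroed otherwise. Placed after other modules this drops entries of the already-scaled update (learning rate dropout); placed first, it drops gradient entries.

import torch

update = torch.randn(6)
p = 0.5
mask = torch.bernoulli(torch.full_like(update, p))  # 1 with probability p, else 0
print(update * mask)  # surviving entries keep their value, the rest become 0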