torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/rprop.py
@@ -1,99 +1,342 @@ (the 99 removed lines of the previous version are not expanded; the 342 new lines follow)

```python
import torch

from ...core import Module, Target, Transform
from ...utils import NumberList, TensorList, as_tensorlist


def _bool_ones_like(x):
    return torch.ones_like(x, dtype=torch.bool)

def sign_consistency_lrs_(
    tensors: TensorList,
    prev_: TensorList,
    lrs_: TensorList,
    nplus: float | NumberList,
    nminus: float | NumberList,
    lb: float | NumberList,
    ub: float | NumberList,
    step: int,
):
    """returns `lrs_`"""
    sign = tensors.sign()
    if step == 0:
        prev_.set_(sign)
        return lrs_.clamp_(lb, ub)

    mul = sign * prev_
    prev_.set_(sign)

    sign_changed = mul < 0
    sign_same = mul > 0

    mul.fill_(1)
    mul.masked_fill_(sign_changed, nminus)
    mul.masked_fill_(sign_same, nplus)

    # multiply magnitudes based on sign change and clamp to bounds
    lrs_.mul_(mul).clamp_(lb, ub)
    return lrs_

def scale_by_sign_change_(
    tensors_: TensorList,
    cur: TensorList,
    prev_: TensorList,
    lrs_: TensorList,
    nplus: float | NumberList,
    nminus: float | NumberList,
    lb: float | NumberList,
    ub: float | NumberList,
    step: int,
):
    """returns `tensors_`"""
    lrs_ = sign_consistency_lrs_(cur, prev_=prev_, lrs_=lrs_, nplus=nplus, nminus=nminus,
                                 lb=lb, ub=ub, step=step)
    return tensors_.mul_(lrs_)

def backtrack_on_sign_change_(
    tensors_: TensorList,
    cur: TensorList,
    prev_: TensorList,
    backtrack: bool,
    step: int
):
    """returns `tensors_`."""
    if step == 0:
        prev_.set_(cur)
        return tensors_

    # mask will be > 0 for parameters where both signs are the same
    mask = (cur * prev_) < 0
    if backtrack: tensors_.masked_set_(mask, prev_)
    else: tensors_.select_set_(mask, 0)

    prev_.set_(cur)
    return tensors_

def rprop_(
    tensors_: TensorList,
    prev_: TensorList,
    allowed_: TensorList,
    magnitudes_: TensorList,
    nplus: float | NumberList,
    nminus: float | NumberList,
    lb: float | NumberList,
    ub: float | NumberList,
    alpha: float | NumberList,
    backtrack: bool,
    step: int,
):
    """returns new tensors."""

    sign = tensors_.sign_()

    # initialize on 1st step
    if step == 0:
        magnitudes_.fill_(alpha).clamp_(lb, ub)
        new_tensors = magnitudes_ * sign
        prev_.copy_(new_tensors)
        return new_tensors

    mul = (sign * prev_).mul_(allowed_)

    sign_changed = mul < 0
    sign_same = mul > 0
    zeroes = mul == 0

    mul.fill_(1)
    mul.masked_fill_(sign_changed, nminus)
    mul.masked_fill_(sign_same, nplus)

    # multiply magnitudes based on sign change and clamp to bounds
    magnitudes_.mul_(mul).clamp_(lb, ub)

    # revert update if sign changed
    if backtrack:
        new_tensors = sign.mul_(magnitudes_)
        new_tensors.masked_set_(sign_changed, prev_.neg_())
    else:
        new_tensors = sign.mul_(magnitudes_ * ~sign_changed)

    # update allowed to only have weights where last update wasn't reverted
    allowed_.set_(sign_same | zeroes)

    prev_.copy_(new_tensors)
    return new_tensors



class Rprop(Transform):
    """
    Resilient propagation. The update magnitude gets multiplied by `nplus` if the gradient didn't change its sign,
    or by `nminus` if it did. The update is then applied with the sign of the current gradient.

    Additionally, if the gradient changes sign, the update for that weight is reverted,
    and on the next step the magnitude for that weight is left unchanged.

    Compared to pytorch this also implements the backtracking update when the sign changes.
    To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.

    Args:
        nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
        nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
        lb (float): minimum step size, can be None (default: 1e-6).
        ub (float): maximum step size, can be None (default: 50).
        backtrack (bool):
            if True, when the ascent sign changes, undoes the last weight update; otherwise sets the update to 0.
            When this is False, this exactly matches pytorch Rprop. (default: True)
        alpha (float): initial per-parameter learning rate (default: 1).

    Reference:
        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float = 1e-6,
        ub: float = 50,
        backtrack=True,
        alpha: float = 1,
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, backtrack=backtrack)
        self.current_step = 0
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        nplus, nminus, lb, ub, alpha = self.get_settings('nplus', 'nminus', 'lb', 'ub', 'alpha', params=params, cls=NumberList)
        prev, allowed, magnitudes = self.get_state(
            'prev', 'allowed', 'magnitudes',
            params=params,
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls=TensorList,
        )

        target = rprop_(
            tensors_ = as_tensorlist(tensors),
            prev_ = prev,
            allowed_ = allowed,
            magnitudes_ = magnitudes,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            alpha = alpha,
            backtrack=self.settings[params[0]]['backtrack'],
            step=self.current_step,
        )

        self.current_step += 1
        return target


class ScaleLRBySignChange(Transform):
    """
    The learning rate gets multiplied by `nplus` if the ascent/gradient didn't change its sign,
    or by `nminus` if it did.

    This is part of the RProp update rule.

    Args:
        nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign.
        nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign.
        lb (float): lower bound for lr.
        ub (float): upper bound for lr.
        alpha (float): initial learning rate.
    """

    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb=1e-6,
        ub=50.0,
        alpha=1.0,
        use_grad=False,
        target: Target = "update",
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
        super().__init__(defaults, uses_grad=use_grad, target=target)
        self.current_step = 0

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        target = as_tensorlist(tensors)
        use_grad = self.settings[params[0]]['use_grad']
        if use_grad: cur = as_tensorlist(grads)
        else: cur = target

        nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
        prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)

        if self.current_step == 0:
            lrs.set_(target.full_like(self.get_settings('alpha', params=params)))

        target = scale_by_sign_change_(
            tensors_ = target,
            cur = cur,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = self.current_step,
        )
        self.current_step += 1
        return target

class BacktrackOnSignChange(Transform):
    """Negates or undoes the update for parameters where the gradient or update sign changes.

    This is part of the RProp update rule.

    Args:
        normalize (bool, optional): renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        use_grad (bool, optional):
            if True, tracks sign change of the gradient,
            otherwise tracks sign change of the update. Defaults to True.
        backtrack (bool, optional):
            if True, undoes the update when the sign changes, otherwise negates it.
            Defaults to True.
    """
    def __init__(self, use_grad=False, backtrack=True, target: Target = 'update'):
        defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
        super().__init__(defaults, uses_grad=use_grad)
        self.current_step = 0

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        target = as_tensorlist(tensors)
        settings = self.settings[params[0]]
        use_grad = settings['use_grad']
        backtrack = settings['backtrack']

        if use_grad: cur = as_tensorlist(grads)
        else: cur = target

        target = backtrack_on_sign_change_(
            tensors_ = target,
            cur = cur,
            prev_ = self.get_state('prev', params=params, cls=TensorList),
            backtrack = backtrack,
            step = self.current_step,
        )

        self.current_step += 1
        return target

class SignConsistencyMask(Transform):
    """0 if the sign changed, 1 otherwise."""
    def __init__(self, target: Target = 'update'):
        super().__init__({}, uses_grad=False, target=target)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        prev = self.get_state('prev', params=params, cls=TensorList)
        mask = prev.mul_(tensors).gt_(0)
        prev.set_(tensors)
        return mask


class SignConsistencyLRs(Transform):
    """The LR for each weight is increased when two consecutive update signs are the same, decreased otherwise. This returns the LRs themselves."""
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float | None = 1e-6,
        ub: float | None = 50,
        alpha: float = 1,
        target: Target = 'update'
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub)
        super().__init__(defaults, uses_grad=False, target=target)
        self.current_step = 0

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        target = as_tensorlist(tensors)
        nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
        prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)

        if self.current_step == 0:
            lrs.set_(target.full_like(self.get_settings('alpha', params=params)))

        target = sign_consistency_lrs_(
            tensors = target,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = self.current_step,
        )
        self.current_step += 1
        return target.clone()
```
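For readers unfamiliar with Rprop, the sketch below re-implements the `rprop_` update rule from the hunk above for a single plain `torch.Tensor`, dropping the `TensorList`/`Transform` machinery. It is an illustration only, not torchzero's API; the function name `rprop_step` and the toy quadratic objective are made up for the demo.

```python
import torch

def rprop_step(grad, prev, allowed, magnitudes, nplus=1.2, nminus=0.5,
               lb=1e-6, ub=50.0, alpha=1.0, backtrack=True, step=0):
    sign = grad.sign()
    if step == 0:
        magnitudes.fill_(alpha).clamp_(lb, ub)
        update = magnitudes * sign
        prev.copy_(update)
        return update

    mul = sign * prev * allowed          # >0: same sign, <0: sign flipped, ==0: previously reverted
    sign_changed = mul < 0
    sign_same = mul > 0

    factor = torch.ones_like(mul)
    factor[sign_changed] = nminus
    factor[sign_same] = nplus
    magnitudes.mul_(factor).clamp_(lb, ub)   # grow/shrink step sizes, keep them inside [lb, ub]

    if backtrack:
        update = sign * magnitudes
        update[sign_changed] = -prev[sign_changed]   # undo the previous step for flipped weights
    else:
        update = sign * magnitudes * ~sign_changed

    allowed.copy_((sign_same | (mul == 0)).float())  # freeze magnitude growth right after a revert
    prev.copy_(update)
    return update

# toy usage: minimize f(x) = x.pow(2).sum() by subtracting the update from x
x = torch.tensor([3.0, -2.0, 0.5])
prev, allowed, magnitudes = torch.zeros_like(x), torch.ones_like(x), torch.zeros_like(x)
for step in range(30):
    grad = 2 * x
    x -= rprop_step(grad, prev, allowed, magnitudes, alpha=0.1, step=step)
print(x)   # x approaches 0
```

The point to notice is that the per-weight step sizes adapt from sign agreement alone, so the raw gradient magnitude never enters the update.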
torchzero/modules/optimizers/shampoo.py (new file)
@@ -0,0 +1,197 @@

```python
from collections.abc import Sequence
from operator import itemgetter
from functools import partial
import numpy as np
import torch

from ...core import Chainable, Transform, apply
from ...utils.linalg import matrix_power_eigh
from ...utils import set_storage_


def update_shampoo_preconditioner_(
    grad: torch.Tensor,
    accumulators_: list[torch.Tensor | None],
    preconditioners_: list[torch.Tensor | None],
    step: int,
    update_freq: int,
    exp_override: int | None,
    beta: float | None,
):
    for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
        if accumulator is None: continue
        assert preconditioner is not None

        axes = list(range(i)) + list(range(i + 1, grad.ndim))
        if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
        else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

        if step % update_freq == 0:
            matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
            set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))


def apply_shampoo_preconditioner(
    tensor: torch.Tensor,
    preconditioners_: list[torch.Tensor | None],
    decay: float | None,
):
    for i, preconditioner in enumerate(preconditioners_):
        if preconditioner is None: continue
        tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
        if decay is not None: preconditioner.mul_(decay)
    return tensor


def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, beta: float | None):
    if beta is None: diagonal_accumulator_.add_(grad.pow(2))
    else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)

def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, decay: float | None, eps: float):
    grad_.div_(diagonal_accumulator_.sqrt() + eps)
    if decay is not None: diagonal_accumulator_.mul_(decay)
    return grad_

def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
    """a safer merger"""
    if tensor.ndim == 0: return tensor, None, None
    sort_idxs = np.argsort(tensor.shape)
    if tensor.shape[sort_idxs[0]] > max_dim:
        return tensor, None, None

    tensor = tensor.permute(*sort_idxs)
    flatten_end_idx = 0
    flat_sizes = []
    flat_numel = 1
    for i, size in enumerate(tensor.shape):
        if flat_numel * size <= max_dim:
            flatten_end_idx = i
            flat_numel *= size
            flat_sizes.append(size)
        else:
            break

    if flatten_end_idx != 0:
        tensor = tensor.flatten(end_dim=flatten_end_idx)

    return tensor, flat_sizes, sort_idxs

def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None, sort_idxs: np.ndarray | Sequence[int] | None):
    if flat_sizes is None: return tensor
    assert sort_idxs is not None
    tensor = tensor.unflatten(0, flat_sizes)
    return tensor.permute(*np.argsort(sort_idxs))


class Shampoo(Transform):
    """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

    Args:
        decay (float | None, optional): slowly decays preconditioners. Defaults to None.
        beta (float | None, optional):
            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
        update_freq (int, optional): preconditioner update frequency. Defaults to 10.
        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to None.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
        adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
    """
    def __init__(
        self,
        decay: float | None = None,
        beta: float | None = None,
        reg: float = 1e-6,
        update_freq: int = 10,
        exp_override: int | None = None,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
        inner: Chainable | None = None,
    ):
        defaults = dict(decay=decay, beta=beta, reg=reg, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d, adagrad_eps=adagrad_eps)
        super().__init__(defaults, uses_grad=False)

        if inner is not None:
            self.set_child('inner', inner)

    def transform(self, tensors, params, grads, vars):
        merged_target = [] # target with merged dims

        # update preconditioners
        for i, (p, t) in enumerate(zip(params, tensors)):
            state = self.state[p]
            settings = self.settings[p]
            beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
                'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
            merged_target.append(t)

            # initialize accumulators and preconditioners for each dim on 1st step
            if 'accumulators' not in state:

                if not precondition_1d and t.ndim <= 1:
                    state['accumulators'] = []

                else:
                    state['accumulators'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
                    state['preconditioners'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]

                # either scalar parameter, 1d with precondition_1d=False, or too big, then basic diagonal preconditioner is used.
                if len([i is not None for i in state['accumulators']]) == 0:
                    state['diagonal_accumulator'] = torch.zeros_like(t)

                state['step'] = 0

            # update preconditioners
            if 'diagonal_accumulator' in state:
                update_diagonal_(t, state['diagonal_accumulator'], beta)
            else:
                update_shampoo_preconditioner_(
                    t,
                    accumulators_=state['accumulators'],
                    preconditioners_=state['preconditioners'],
                    step=state['step'],
                    update_freq=update_freq,
                    exp_override=exp_override,
                    beta=beta,
                )

        # inner step
        if 'inner' in self.children:
            tensors = apply(self.children['inner'], tensors, params=params, grads=grads, vars=vars)

            # have to merge small dims again
            merged_target = [] # target with merged dims
            for i, (p, t) in enumerate(zip(params, tensors)):
                state = self.state[p]
                settings = self.settings[p]
                if settings['merge_small']:
                    t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
                merged_target.append(t)

        # precondition
        for i, (p, t) in enumerate(zip(params, merged_target)):
            state = self.state[p]
            settings = self.settings[p]
            decay, merge_small, adagrad_eps = itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)

            if 'diagonal_accumulator' in state:
                tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
            else:
                tensors[i] = apply_shampoo_preconditioner(t, preconditioners_=state['preconditioners'], decay=decay)

            if merge_small:
                tensors[i] = _unmerge_small_dims(tensors[i], state['flat_sizes'], state['sort_idxs'])

            state['step'] += 1

        return tensors
```
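As a worked illustration of what `update_shampoo_preconditioner_` and `apply_shampoo_preconditioner` compute, the snippet below applies the Shampoo rule to a single 2-D gradient using plain `torch.linalg` calls. The `matrix_power_sym` helper and its small ridge term are stand-ins for the library's `matrix_power_eigh` (whose exact signature is not shown in this diff); in the module itself, initializing the accumulators to the identity plays a similar stabilizing role.

```python
import torch

def matrix_power_sym(A: torch.Tensor, p: float, ridge: float = 1e-10) -> torch.Tensor:
    # A**p for a symmetric PSD matrix via eigendecomposition
    eigvals, eigvecs = torch.linalg.eigh(A + ridge * torch.eye(A.shape[0]))
    return eigvecs @ torch.diag(eigvals.clamp_min(ridge) ** p) @ eigvecs.T

torch.manual_seed(0)
G = torch.randn(6, 4)                    # fake gradient of a 6x4 weight matrix

# one accumulator per dimension, built the way update_shampoo_preconditioner_ builds them:
# for dim i, contract the gradient with itself over every other axis
L = torch.tensordot(G, G, ([1], [1]))    # dim-0 accumulator, equals G @ G.T, shape (6, 6)
R = torch.tensordot(G, G, ([0], [0]))    # dim-1 accumulator, equals G.T @ G, shape (4, 4)

# with exp_override=None the exponent is -1/(2*ndim) = -1/4 for a 2-D tensor
L_pre = matrix_power_sym(L, -0.25)
R_pre = matrix_power_sym(R, -0.25)

# apply_shampoo_preconditioner repeatedly contracts dim 0 with each preconditioner,
# which for a matrix works out to L^(-1/4) @ G @ R^(-1/4)
update = torch.tensordot(torch.tensordot(G, L_pre, ([0], [0])), R_pre, ([0], [0]))
assert torch.allclose(update, L_pre @ G @ R_pre, atol=1e-5)
print(update.shape)                      # torch.Size([6, 4])
```

For tensors with more than two dimensions the same pattern repeats once per mode, while dimensions of size 1 or larger than `max_dim` are skipped (their preconditioner stays `None`).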