torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/experimental.py (deleted)

```diff
--- a/torchzero/modules/experimental/experimental.py
+++ /dev/null
@@ -1,294 +0,0 @@
-from collections import abc
-
-import torch
-
-from ...tensorlist import TensorList, where, Distributions
-from ...core import OptimizerModule
-from ...utils.derivatives import jacobian
-
-def _bool_ones_like(x):
-    return torch.ones_like(x, dtype=torch.bool)
-
-
-class MinibatchRprop(OptimizerModule):
-    """
-    for experiments, unlikely to work well on most problems.
-
-    explanation: does 2 steps per batch, applies rprop rule on the second step.
-    """
-    def __init__(
-        self,
-        nplus: float = 1.2,
-        nminus: float = 0.5,
-        lb: float | None = 1e-6,
-        ub: float | None = 50,
-        backtrack=True,
-        next_mode = 'continue',
-        increase_mul = 0.5,
-        alpha: float = 1,
-    ):
-        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, increase_mul=increase_mul)
-        super().__init__(defaults)
-        self.current_step = 0
-        self.backtrack = backtrack
-
-        self.next_mode = next_mode
-
-    @torch.no_grad
-    def step(self, vars):
-        if vars.closure is None: raise ValueError("Minibatch Rprop requires closure")
-        if vars.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
-        params = self.get_params()
-
-        nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
-        allowed, magnitudes = self.get_state_keys(
-            'allowed', 'magnitudes',
-            inits = [_bool_ones_like, torch.zeros_like],
-            params=params
-        )
-
-        g1_sign = vars.maybe_compute_grad_(params).sign() # no inplace to not modify grads
-        # initialize on 1st iteration
-        if self.current_step == 0:
-            magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
-            # ascent = magnitudes * g1_sign
-            # self.current_step += 1
-            # return ascent
-
-        # first step
-        ascent = g1_sign.mul_(magnitudes).mul_(allowed)
-        params -= ascent
-        with torch.enable_grad(): vars.fx0_approx = vars.closure()
-        f0 = vars.fx0; f1 = vars.fx0_approx
-        assert f0 is not None and f1 is not None
-
-        # if loss increased, reduce all lrs and undo the update
-        if f1 > f0:
-            increase_mul = self.get_group_key('increase_mul')
-            magnitudes.mul_(increase_mul).clamp_(lb, ub)
-            params += ascent
-            self.current_step += 1
-            return f0
-
-        # on `continue` we move to params after 1st update
-        # therefore state must be updated to have all attributes after 1st update
-        if self.next_mode == 'continue':
-            vars.fx0 = vars.fx0_approx
-            vars.grad = params.ensure_grad_().grad
-            sign = vars.grad.sign()
-
-        else:
-            sign = params.ensure_grad_().grad.sign_() # can use in-place as this is not fx0 grad
-
-        # compare 1st and 2nd gradients via rprop rule
-        prev = ascent
-        mul = sign * prev # prev is already multiplied by `allowed`
-
-        sign_changed = mul < 0
-        sign_same = mul > 0
-        zeroes = mul == 0
-
-        mul.fill_(1)
-        mul.masked_fill_(sign_changed, nminus)
-        mul.masked_fill_(sign_same, nplus)
-
-        # multiply magnitudes based on sign change and clamp to bounds
-        magnitudes.mul_(mul).clamp_(lb, ub)
-
-        # revert update if sign changed
-        if self.backtrack:
-            ascent2 = sign.mul_(magnitudes)
-            ascent2.masked_set_(sign_changed, prev.neg_())
-        else:
-            ascent2 = sign.mul_(magnitudes * ~sign_changed)
-
-        # update allowed to only have weights where last update wasn't reverted
-        allowed.set_(sign_same | zeroes)
-
-        self.current_step += 1
-
-        # update params or step
-        if self.next_mode == 'continue' or (self.next_mode == 'add' and self.next_module is None):
-            vars.ascent = ascent2
-            return self._update_params_or_step_with_next(vars, params)
-
-        if self.next_mode == 'add':
-            # undo 1st step
-            params += ascent
-            vars.ascent = ascent + ascent2
-            return self._update_params_or_step_with_next(vars, params)
-
-        if self.next_mode == 'undo':
-            params += ascent
-            vars.ascent = ascent2
-            return self._update_params_or_step_with_next(vars, params)
-
-        raise ValueError(f'invalid next_mode: {self.next_mode}')
-
-
-
-class GradMin(OptimizerModule):
-    """
-    for experiments, unlikely to work well on most problems.
-
-    explanation: calculate grads wrt sum of grads + loss.
-    """
-    def __init__(self, loss_term: float = 1, square=False, maximize_grad = False, create_graph = False):
-        super().__init__(dict(loss_term=loss_term))
-        self.square = square
-        self.maximize_grad = maximize_grad
-        self.create_graph = create_graph
-
-    @torch.no_grad
-    def step(self, vars):
-        if vars.closure is None: raise ValueError()
-        if vars.ascent is not None:
-            raise ValueError("GradMin doesn't accept ascent_direction")
-
-        params = self.get_params()
-        loss_term = self.get_group_key('loss_term')
-
-        self.zero_grad()
-        with torch.enable_grad():
-            vars.fx0 = vars.closure(False)
-            grads = jacobian([vars.fx0], params, create_graph=True, batched=False) # type:ignore
-            grads = TensorList(grads).squeeze_(0)
-            if self.square:
-                grads = grads ** 2
-            else:
-                grads = grads.abs()
-
-            if self.maximize_grad: grads: TensorList = grads - (vars.fx0 * loss_term) # type:ignore
-            else: grads = grads + (vars.fx0 * loss_term)
-            grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
-
-            if self.create_graph: grad_mean.backward(create_graph=True)
-            else: grad_mean.backward(retain_graph=False)
-
-        if self.maximize_grad: vars.grad = params.ensure_grad_().grad.neg_()
-        else: vars.grad = params.ensure_grad_().grad
-
-        vars.maybe_use_grad_(params)
-        return self._update_params_or_step_with_next(vars)
-
-
-class HVPDiagNewton(OptimizerModule):
-    """
-    for experiments, unlikely to work well on most problems.
-
-    explanation: may or may not approximate newton step if hessian is diagonal with 2 backward passes. Probably not.
-    """
-    def __init__(self, eps=1e-3):
-        super().__init__(dict(eps=eps))
-
-    @torch.no_grad
-    def step(self, vars):
-        if vars.closure is None: raise ValueError()
-        if vars.ascent is not None:
-            raise ValueError("HVPDiagNewton doesn't accept ascent_direction")
-
-        params = self.get_params()
-        eps = self.get_group_key('eps')
-        grad_fx0 = vars.maybe_compute_grad_(params).clone()
-        vars.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten
-
-        params += grad_fx0 * eps
-        with torch.enable_grad(): _ = vars.closure()
-
-        params -= grad_fx0 * eps
-
-        newton = grad_fx0 * ((grad_fx0 * eps) / (params.grad - grad_fx0))
-        newton.nan_to_num_(0,0,0)
-
-        vars.ascent = newton
-        return self._update_params_or_step_with_next(vars)
-
-
-
-class ReduceOutwardLR(OptimizerModule):
-    """
-    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
-
-    This means updates that move weights towards zero have higher learning rates.
-    """
-    def __init__(self, mul = 0.5, use_grad=False, invert=False):
-        defaults = dict(mul = mul)
-        super().__init__(defaults)
-
-        self.use_grad = use_grad
-        self.invert = invert
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        mul = self.get_group_key('mul')
-
-        if self.use_grad: cur = vars.maybe_compute_grad_(params)
-        else: cur = ascent
-
-        # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
-        if self.invert: mask = (params * cur) > 0
-        else: mask = (params * cur) < 0
-        ascent.masked_set_(mask, ascent*mul)
-
-        return ascent
-
-class NoiseSign(OptimizerModule):
-    """uses random vector with ascent sign"""
-    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-        super().__init__({})
-        self.alpha = alpha
-        self.distribution:Distributions = distribution
-
-
-    def _update(self, vars, ascent):
-        return ascent.sample_like(self.alpha, self.distribution).copysign_(ascent)
-
-class ParamSign(OptimizerModule):
-    """uses params with ascent sign"""
-    def __init__(self):
-        super().__init__({})
-
-
-    def _update(self, vars, ascent):
-        params = self.get_params()
-
-        return params.copysign(ascent)
-
-class NegParamSign(OptimizerModule):
-    """uses max(params_abs) - params_abs with ascent sign"""
-    def __init__(self):
-        super().__init__({})
-
-
-    def _update(self, vars, ascent):
-        neg_params = self.get_params().abs()
-        max = neg_params.total_max()
-        neg_params = neg_params.neg_().add(max)
-        return neg_params.copysign_(ascent)
-
-class InvParamSign(OptimizerModule):
-    """uses 1/(params_abs+eps) with ascent sign"""
-    def __init__(self, eps=1e-2):
-        super().__init__({})
-        self.eps = eps
-
-
-    def _update(self, vars, ascent):
-        inv_params = self.get_params().abs().add_(self.eps).reciprocal_()
-        return inv_params.copysign(ascent)
-
-
-class ParamWhereConsistentSign(OptimizerModule):
-    """where ascent and param signs are the same, it sets ascent to param value"""
-    def __init__(self, eps=1e-2):
-        super().__init__({})
-        self.eps = eps
-
-
-    def _update(self, vars, ascent):
-        params = self.get_params()
-        same_sign = params.sign() == ascent.sign()
-        ascent.masked_set_(same_sign, params)
-
-        return ascent
```
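For context, `MinibatchRprop` above applies the classic Rprop sign rule between two gradient evaluations on the same batch. A minimal standalone sketch of that per-weight step-size rule is shown below; it works on plain tensors with hypothetical helper names and is not the torchzero module API.

```python
import torch

# Minimal sketch of the Rprop step-size rule that MinibatchRprop applies between
# its two per-batch gradient evaluations. Standalone illustration only; the names
# here are hypothetical and do not correspond to torchzero's API.
def rprop_step(param, grad, prev_grad, step_size,
               nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    sign_product = grad * prev_grad
    # grow the step where the gradient sign is stable, shrink it where it flipped
    step_size = torch.where(sign_product > 0, step_size * nplus, step_size)
    step_size = torch.where(sign_product < 0, step_size * nminus, step_size)
    step_size = step_size.clamp(lb, ub)
    # where the sign flipped, skip the update this iteration (the "allowed" mask above)
    masked_grad = torch.where(sign_product < 0, torch.zeros_like(grad), grad)
    param = param - step_size * masked_grad.sign()
    return param, step_size

# toy usage on f(p) = sum(p**2)
p = torch.tensor([3.0, -2.0])
step = torch.full_like(p, 0.1)
prev = torch.zeros_like(p)
for _ in range(20):
    g = 2 * p  # gradient of sum(p**2)
    p, step = rprop_step(p, g, prev, step)
    prev = g
```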
torchzero/modules/experimental/quad_interp.py (deleted)

```diff
--- a/torchzero/modules/experimental/quad_interp.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import bisect
-
-import numpy as np
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizationVars
-from ..line_search.base_ls import LineSearchBase
-
-_FloatOrTensor = float | torch.Tensor
-
-def _ensure_float(x):
-    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    elif isinstance(x, np.ndarray): return x.item()
-    return float(x)
-
-class Point:
-    def __init__(self, x, fx, dfx = None):
-        self.x = x
-        self.fx = fx
-        self.dfx = dfx
-
-    def __repr__(self):
-        return f'Point(x={self.x:.2f}, fx={self.fx:.2f})'
-
-def _step_2poins(x1, f1, df1, x2, f2):
-    # we have two points and one derivative
-    # minimize the quadratic to obtain 3rd point and perform bracketing
-    a = (df1 * x2 - f2 - df1*x1 + f1) / (x1**2 - x2**2 - 2*x1**2 + 2*x1*x2)
-    b = df1 - 2*a*x1
-    # c = -(a*x1**2 + b*x1 - y1)
-    return -b / (2 * a), a
-
-class QuadraticInterpolation2Point(LineSearchBase):
-    """This is WIP, please don't use yet!
-    Use `torchzero.modules.MinimizeQuadraticLS` and `torchzero.modules.MinimizeQuadratic3PointsLS` instead.
-
-    Args:
-        lr (_type_, optional): _description_. Defaults to 1e-2.
-        log_lrs (bool, optional): _description_. Defaults to False.
-        max_evals (int, optional): _description_. Defaults to 2.
-        min_dist (_type_, optional): _description_. Defaults to 1e-2.
-    """
-    def __init__(self, lr=1e-2, log_lrs = False, max_evals = 2, min_dist = 1e-2,):
-        super().__init__({"lr": lr}, maxiter=None, log_lrs=log_lrs)
-        self.max_evals = max_evals
-        self.min_dist = min_dist
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        if vars.closure is None: raise ValueError('QuardaticLS requires closure')
-        closure = vars.closure
-        if vars.fx0 is None: vars.fx0 = vars.closure(False)
-        grad = vars.grad
-        if grad is None: grad = vars.ascent # in case we used FDM
-        if grad is None: raise ValueError('QuardaticLS requires gradients.')
-
-        params = self.get_params()
-        lr: float = self.get_first_group_key('lr') # this doesn't support variable lrs but we still want to support schedulers
-
-        # directional f'(x0)
-        # for each lr we step by this much
-        dfx0 = magn = grad.total_vector_norm(2)
-
-        # f(x1)
-        fx1 = self._evaluate_lr_(lr, closure, grad, params)
-
-        # make min_dist relative
-        min_dist = abs(lr) * self.min_dist
-        points = sorted([Point(0, _ensure_float(vars.fx0), dfx0), Point(lr, _ensure_float(fx1))], key = lambda x: x.fx)
-
-        for i in range(self.max_evals):
-            # find new point
-            p1, p2 = points
-            if p1.dfx is None: p1, p2 = p2, p1
-            xmin, curvature = _step_2poins(p1.x * magn, p1.fx, -p1.dfx, p2.x * magn, p2.fx) # type:ignore
-            xmin = _ensure_float(xmin/magn)
-            print(f'{xmin = }', f'{curvature = }, n_evals = {i+1}')
-
-            # if max_evals = 1, we just minimize a quadratic once
-            if i == self.max_evals - 1:
-                if curvature > 0: return xmin
-                return lr
-
-            # TODO: handle negative curvature
-            # if curvature < 0:
-            #     if points[0].x == 0: return lr
-            #     return points[0].x
-
-            # evaluate value and gradients at new point
-            fxmin = self._evaluate_lr_(xmin, closure, grad, params, backward=True)
-            dfxmin = -(params.grad * grad).total_sum()
-
-            # insort new point
-            bisect.insort(points, Point(xmin, _ensure_float(fxmin), dfxmin), key = lambda x: x.fx)
-
-            # pick 2 best points to find the new bracketing interval
-            points = sorted(points, key = lambda x: x.fx)[:2]
-            # TODO: new point might be worse than 2 existing ones which would lead to stagnation
-
-            # if points are too close, end the loop
-            if abs(points[0].x - points[1].x) < min_dist: break
-
-        return points[0].x
```
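`_step_2poins` above fits a parabola through one point with a known slope and a second point with only a value, then returns the vertex and the curvature. The same fit written in an algebraically simplified form, for illustration only (a sketch, not the deleted code's exact expression):

```python
# Sketch of the quadratic fit used by _step_2poins. Given f(x1)=f1 with slope
# f'(x1)=df1 and a second value f(x2)=f2, fit f(x) ~= a*x**2 + b*x + c and return
# the vertex -b/(2a) together with the curvature a.
def quadratic_vertex(x1, f1, df1, x2, f2):
    a = (f2 - f1 - df1 * (x2 - x1)) / (x2 - x1) ** 2
    b = df1 - 2 * a * x1
    return -b / (2 * a), a

# e.g. for f(x) = (x - 3)**2: f(0)=9, f'(0)=-6, f(1)=4 -> vertex at 3.0, curvature 1.0
print(quadratic_vertex(0.0, 9.0, -6.0, 1.0, 4.0))
```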
torchzero/modules/experimental/subspace.py (deleted)

```diff
--- a/torchzero/modules/experimental/subspace.py
+++ /dev/null
@@ -1,259 +0,0 @@
-import typing as T
-from abc import ABC, abstractmethod
-from collections import abc
-
-import torch
-
-from ... import tensorlist as tl
-from ...core import OptimizationVars, OptimizerModule, _Chain, _maybe_pass_backward
-# this whole thing can also be implemented via parameter vectors.
-# Need to test which one is more efficient...
-
-class Projection(ABC):
-    n = 1
-    @abstractmethod
-    def sample(self, params: tl.TensorList, vars: OptimizationVars) -> list[tl.TensorList]:
-        """Generate a projection.
-
-        Args:
-            params (tl.TensorList): tensor list of parameters.
-            state (OptimizationState): optimization state object.
-
-        Returns:
-            projection.
-        """
-
-class ProjRandom(Projection):
-    def __init__(self, n = 1, distribution: tl.Distributions = 'normal', ):
-        self.distribution: tl.Distributions = distribution
-        self.n = n
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        return [params.sample_like(distribution=self.distribution) for _ in range(self.n)]
-
-
-class Proj2Masks(Projection):
-    def __init__(self, n_pairs = 1):
-        """Similar to ProjRandom, but generates pairs of two random masks of 0s and 1s,
-        where second mask is an inverse of the first mask."""
-        self.n_pairs = n_pairs
-
-    @property
-    def n(self):
-        return self.n_pairs * 2
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        projections = []
-        for i in range(self.n_pairs):
-            mask = params.bernoulli_like(0.5)
-            mask2 = 1 - mask
-            projections.append(mask)
-            projections.append(mask2)
-
-        return projections
-
-
-class ProjAscent(Projection):
-    """Use ascent direction as the projection."""
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        if vars.ascent is None: raise ValueError
-        return [vars.ascent]
-
-class ProjAscentRay(Projection):
-    def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
-        self.eps = eps
-        self.distribution: tl.Distributions = distribution
-        self.n = n
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        if vars.ascent is None: raise ValueError
-        mean = params.total_mean().detach().cpu().item()
-        return [vars.ascent + vars.ascent.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
-
-class ProjGrad(Projection):
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        grad = vars.maybe_compute_grad_(params)
-        return [grad]
-
-class ProjGradRay(Projection):
-    def __init__(self, eps = 0.1, n = 1, distribution: tl.Distributions = 'normal', ):
-        self.eps = eps
-        self.distribution: tl.Distributions = distribution
-        self.n = n
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        grad = vars.maybe_compute_grad_(params)
-        mean = params.total_mean().detach().cpu().item()
-        return [grad + grad.sample_like(mean * self.eps, distribution=self.distribution) for _ in range(self.n)]
-
-class ProjGradAscentDifference(Projection):
-    def __init__(self, normalize=False):
-        """Use difference between gradient and ascent direction as projection.
-
-        Args:
-            normalize (bool, optional): normalizes grads and ascent projection to have norm = 1. Defaults to False.
-        """
-        self.normalize = normalize
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        grad = vars.maybe_compute_grad_(params)
-        if self.normalize:
-            return [vars.ascent / vars.ascent.total_vector_norm(2) - grad / grad.total_vector_norm(2)] # type:ignore
-
-        return [vars.ascent - grad] # type:ignore
-
-class ProjLastGradDifference(Projection):
-    def __init__(self):
-        """Use difference between last two gradients as the projection."""
-        self.last_grad = None
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        if self.last_grad is None:
-            self.last_grad = vars.maybe_compute_grad_(params)
-            return [self.last_grad]
-
-        grad = vars.maybe_compute_grad_(params)
-        diff = grad - self.last_grad
-        self.last_grad = grad
-        return [diff]
-
-class ProjLastAscentDifference(Projection):
-    def __init__(self):
-        """Use difference between last two ascent directions as the projection."""
-        self.last_direction = T.cast(tl.TensorList, None)
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars):
-        if self.last_direction is None:
-            self.last_direction: tl.TensorList = vars.ascent # type:ignore
-            return [self.last_direction]
-
-        diff = vars.ascent - self.last_direction # type:ignore
-        self.last_direction = vars.ascent # type:ignore
-        return [diff]
-
-class ProjNormalize(Projection):
-    def __init__(self, *projections: Projection):
-        """Normalizes all projections to have norm = 1."""
-        self.projections = projections
-
-    @property
-    def n(self):
-        return sum(proj.n for proj in self.projections)
-
-    def sample(self, params: tl.TensorList, vars: OptimizationVars): # type:ignore
-        vecs = [proj for obj in self.projections for proj in obj.sample(params, vars)]
-        norms = [v.total_vector_norm(2) for v in vecs]
-        return [v/norm if norm!=0 else v.randn_like() for v,norm in zip(vecs,norms)] # type:ignore
-
-class Subspace(OptimizerModule):
-    """This is pretty inefficient, I thought of a much better way to do this via jvp and I will rewrite this soon.
-
-    Optimizes parameters projected into a lower (or higher) dimensional subspace.
-
-    The subspace is a bunch of projections that go through the current point. Projections can be random,
-    or face in the direction of the gradient, or difference between last two gradients, etc. The projections
-    are updated every `update_every` steps.
-
-    Notes:
-        This doesn't work with anything that directly calculates the hessian or other quantities via `torch.autograd.grad`,
-        like `ExactNewton`. I will have to manually implement a subspace version for it.
-
-        This also zeroes parameters after each step, meaning it won't work with some integrations like nevergrad
-        (as they store their own parameters which don't get zeroed). It does however work with integrations like
-        `scipy.optimize` because they performs a full minimization on each step.
-        Another version of this which doesn't zero the params is under way.
-
-    Args:
-        projections (Projection | Iterable[Projection]):
-            list of projections - `Projection` objects that define the directions of the projections.
-            Each Projection object may generate one or multiple directions.
-        update_every (int, optional): generates new projections every n steps. Defaults to 1.
-    """
-    def __init__(
-        self,
-        modules: OptimizerModule | abc.Iterable[OptimizerModule],
-        projections: Projection | abc.Iterable[Projection],
-        update_every: int | None = 1,
-    ):
-        super().__init__({})
-        if isinstance(projections, Projection): projections = [projections]
-        self.projections = list(projections)
-        self._set_child_('subspace', modules)
-        self.update_every = update_every
-        self.current_step = 0
-
-        # cast them because they are guaranteed to be assigned on 1st step.
-        self.projection_vectors = T.cast(list[tl.TensorList], None)
-        self.projected_params = T.cast(torch.Tensor, None)
-
-
-    def _update_child_params_(self, child: "OptimizerModule"):
-        dtype = self._params[0].dtype
-        device = self._params[0].device
-        params = [torch.zeros(sum(proj.n for proj in self.projections), dtype = dtype, device = device, requires_grad=True)]
-        if child._has_custom_params: raise RuntimeError(f"Subspace child {child.__class__.__name__} can't have custom params.")
-        if not child._initialized:
-            child._initialize_(params, set_passed_params=False)
-        else:
-            child.param_groups = []
-            child.add_param_group({"params": params})
-
-    @torch.no_grad
-    def step(self, vars):
-        #if self.next_module is None: raise ValueError('RandomProjection needs a child')
-        if vars.closure is None: raise ValueError('RandomProjection needs a closure')
-        closure = vars.closure
-        params = self.get_params()
-
-        # every `regenerate_every` steps we generate new random projections.
-        if self.current_step == 0 or (self.update_every is not None and self.current_step % self.update_every == 0):
-
-            # generate n projection vetors
-            self.projection_vectors = [sample for proj in self.projections for sample in proj.sample(params, vars)]
-
-            # child params is n scalars corresponding to each projection vector
-            self.projected_params = self.children['subspace']._params[0] # type:ignore
-
-        # closure that takes the projected params from the child, puts them into full space params, and evaluates the loss
-        def projected_closure(backward = True):
-            residual = sum(vec * p for vec, p in zip(self.projection_vectors, self.projected_params))
-
-            # this in-place operation prevents autodiff from working
-            # we manually calculate the gradients as they are just a product
-            # therefore we need torch.no_grad here because optimizers call closure under torch.enabled_grad
-            with torch.no_grad(): params.add_(residual)
-
-            loss = _maybe_pass_backward(closure, backward)
-
-            if backward:
-                self.projected_params.grad = torch.cat([(params.grad * vec).total_sum().unsqueeze(0) for vec in self.projection_vectors])
-            with torch.no_grad(): params.sub_(residual)
-            return loss
-
-        # # if ascent direction is provided,
-        # # project the ascent direction into the projection space (need to test if this works)
-        # if ascent_direction is not None:
-        #     ascent_direction = tl.sum([ascent_direction*v for v in self.projection_vectors])
-
-        # perform a step with the child
-        subspace_state = vars.copy(False)
-        subspace_state.closure = projected_closure
-        subspace_state.ascent = None
-        if subspace_state.grad is not None:
-            subspace_state.grad = tl.TensorList([torch.cat([(params.grad * vec).total_sum().unsqueeze(0) for vec in self.projection_vectors])])
-        self.children['subspace'].step(subspace_state) # type:ignore
-
-        # that is going to update child's paramers, which we now project back to the full parameter space
-        residual = tl.sum([vec * p for vec, p in zip(self.projection_vectors, self.projected_params)])
-        vars.ascent = residual.neg_()
-
-        # move fx0 and fx0 approx to state
-        if subspace_state.fx0 is not None: vars.fx0 = subspace_state.fx0
-        if subspace_state.fx0_approx is not None: vars.fx0 = subspace_state.fx0_approx
-        # projected_params are residuals that have been applied to actual params on previous step in some way
-        # therefore they need to now become zero (otherwise they work like momentum with no decay).
-        # note: THIS WON'T WORK WITH INTEGRATIONS, UNLESS THEY PERFORM FULL MINIMIZATION EACH STEP
-        # because their params won't be zeroed.
-        self.projected_params.zero_()
-
-        self.current_step += 1
-        return self._update_params_or_step_with_next(vars)
```
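The `Subspace` module above optimizes a handful of scalar coefficients, one per projection vector, and expands them back into the full parameter space. A minimal standalone sketch of that idea on a plain objective with `torch.optim.SGD` (hypothetical names, not the torchzero module machinery):

```python
import torch

# Minimal sketch of the subspace idea: optimize a few coefficients c_i and map them
# back to full space as params + sum_i c_i * v_i. Illustration only, not torchzero API.
def loss_fn(x):                     # toy full-space objective
    return (x ** 2).sum()

full_params = torch.randn(1000)                       # frozen full-space point
proj_vectors = [torch.randn(1000) for _ in range(4)]  # projection directions
coeffs = torch.zeros(4, requires_grad=True)           # subspace parameters

opt = torch.optim.SGD([coeffs], lr=1e-3)
for _ in range(100):
    opt.zero_grad()
    residual = sum(c * v for c, v in zip(coeffs, proj_vectors))
    loss = loss_fn(full_params + residual)
    loss.backward()                                   # gradient w.r.t. the 4 coefficients
    opt.step()

# fold the optimized residual back into the full-space parameters
with torch.no_grad():
    full_params += sum(c * v for c, v in zip(coeffs, proj_vectors))
```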