torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/rprop.py

@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
 
 
 def _bool_ones_like(x):
@@ -135,7 +135,8 @@ class Rprop(Transform):
     Next step, magnitude for that weight won't change.
 
     Compared to pytorch this also implements backtracking update when sign changes.
-
+
+    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.
 
     Args:
         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -161,20 +162,22 @@
         alpha: float = 1,
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
-        self.current_step = 0
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
-
-
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
+        prev, allowed, magnitudes = unpack_states(
+            states, tensors,
             'prev','allowed','magnitudes',
-            params=params,
             init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
             cls = TensorList,
         )
 
-
+        tensors = rprop_(
             tensors_ = as_tensorlist(tensors),
             prev_ = prev,
             allowed_ = allowed,
@@ -184,12 +187,11 @@ class Rprop(Transform):
             lb = lb,
             ub = ub,
             alpha = alpha,
-            backtrack=
-            step=
+            backtrack=settings[0]['backtrack'],
+            step=step,
         )
 
-
-        return target
+        return tensors
 
 
 class ScaleLRBySignChange(Transform):
@@ -220,23 +222,25 @@ class ScaleLRBySignChange(Transform):
     ):
         defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
         super().__init__(defaults, uses_grad=use_grad, target=target)
-        self.current_step = 0
 
     @torch.no_grad
-    def
-
-
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
         if use_grad: cur = as_tensorlist(grads)
-        else: cur =
+        else: cur = tensors
 
-        nplus, nminus, lb, ub =
-        prev, lrs =
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if
-            lrs.set_(
+        if step == 0:
+            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))
 
-
-        tensors_ =
+        tensors = scale_by_sign_change_(
+            tensors_ = tensors,
             cur = cur,
             prev_ = prev,
             lrs_ = lrs,
@@ -244,10 +248,9 @@ class ScaleLRBySignChange(Transform):
             nminus = nminus,
             lb = lb,
             ub = ub,
-            step =
+            step = step,
         )
-
-        return target
+        return tensors
 
 class BacktrackOnSignChange(Transform):
     """Negates or undoes update for parameters where where gradient or update sign changes.
@@ -268,44 +271,77 @@ class BacktrackOnSignChange(Transform):
     def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
         defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
         super().__init__(defaults, uses_grad=use_grad)
-        self.current_step = 0
 
     @torch.no_grad
-    def
-
-
-
-
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
+        backtrack = settings[0]['backtrack']
 
         if use_grad: cur = as_tensorlist(grads)
-        else: cur =
+        else: cur = tensors
 
-
-        tensors_ =
+        tensors = backtrack_on_sign_change_(
+            tensors_ = tensors,
             cur = cur,
-            prev_ =
+            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
             backtrack = backtrack,
-            step =
+            step = step,
         )
 
-
-        return target
+        return tensors
 
 class SignConsistencyMask(Transform):
-    """
+    """
+    Outputs a mask of sign consistency of current and previous inputs.
+
+    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+    Examples:
+
+        GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyMask()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(self,target: Target = 'update'):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def
-        prev =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
-        prev.
+        prev.copy_(tensors)
        return mask
 
 
 class SignConsistencyLRs(Transform):
-    """
+    """Outputs per-weight learning rates based on consecutive sign consistency.
+
+    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+    Examples:
+
+        GD scaled by consecutive gradient sign consistency
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyLRs()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -317,16 +353,18 @@ class SignConsistencyLRs(Transform):
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
         super().__init__(defaults, uses_grad=False, target = target)
-        self.current_step = 0
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
         target = as_tensorlist(tensors)
-        nplus, nminus, lb, ub =
-        prev, lrs =
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if
-            lrs.set_(target.full_like(
+        if step == 0:
+            lrs.set_(target.full_like([s['alpha'] for s in settings]))
 
         target = sign_consistency_lrs_(
             tensors = target,
@@ -336,7 +374,6 @@ class SignConsistencyLRs(Transform):
             nminus = nminus,
             lb = lb,
             ub = ub,
-            step =
+            step = step,
         )
-        self.current_step += 1
         return target.clone()
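All of the modules in this file rely on the same sign-consistency test: multiply the current update by the previous one and check the sign of the product (the `prev.mul_(tensors).gt_(0)` line above). Below is a self-contained plain-PyTorch sketch of that test, outside the torchzero Transform API and with made-up variable names, just to illustrate the idea:

```python
import torch

def sign_consistency_mask(current: torch.Tensor, prev: torch.Tensor) -> torch.Tensor:
    # True where current and previous values agree in sign, False where the sign flipped;
    # a non-in-place version of `prev.mul_(tensors).gt_(0)` from SignConsistencyMask above
    return (prev * current) > 0

# toy usage: zero out the step for weights whose gradient sign just flipped
prev_grad = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([0.3, 0.4, -0.1])
mask = sign_consistency_mask(grad, prev_grad)   # tensor([ True, False, False])
step = -1e-2 * grad * mask                      # masked entries receive a zero step
```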
torchzero/modules/optimizers/sam.py (new file)

@@ -0,0 +1,163 @@
+from contextlib import nullcontext
+import torch
+from ...utils import TensorList, NumberList
+from ...core import Module
+
+
+class SAM(Module):
+    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+        asam (bool, optional):
+            enables ASAM variant which makes perturbation relative to weight magnitudes.
+            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+            The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
+            it has larger :code:`rho` by default.
+
+    Examples:
+        SAM-SGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.LR(1e-2)
+            )
+
+        SAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+    """
+    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+
+        params = var.params
+        closure = var.closure
+        zero_grad = var.zero_grad
+        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+        s = self.settings[var.params[0]]
+        eps = s['eps']
+        asam = s['asam']
+
+        # 1/p + 1/q = 1
+        # okay, authors of SAM paper, I will manually solve your equation
+        # so q = -p/(1-p)
+        q = -p / (1-p)
+        # as a validation for 2 it is -2 / -1 = 2
+
+        @torch.no_grad
+        def sam_closure(backward=True):
+            orig_grads = None
+            if not backward:
+                # if backward is False, make sure this doesn't modify gradients
+                # to avoid issues
+                orig_grads = [p.grad for p in params]
+
+            # gradient at initial parameters
+            zero_grad()
+            with torch.enable_grad():
+                closure()
+
+            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+            grad_abs = grad.abs()
+
+            # compute e
+            term1 = grad.sign().mul_(rho)
+            term2 = grad_abs.pow(q-1)
+
+            if asam:
+                grad_abs.mul_(torch._foreach_abs(params))
+
+            denom = grad_abs.pow_(q).sum().pow(1/p)
+
+            e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+            if asam:
+                e.mul_(torch._foreach_pow(params, 2))
+
+            # calculate loss and gradient approximation of inner problem
+            torch._foreach_add_(params, e)
+            if backward:
+                zero_grad()
+                with torch.enable_grad():
+                    # this sets .grad attributes
+                    sam_loss = closure()
+
+            else:
+                sam_loss = closure(False)
+
+            # and restore initial parameters
+            torch._foreach_sub_(params, e)
+
+            if orig_grads is not None:
+                for param,orig_grad in zip(params, orig_grads):
+                    param.grad = orig_grad
+
+            return sam_loss
+
+        var.closure = sam_closure
+        return var
+
+# different class because defaults for SAM are bad for ASAM
+class ASAM(SAM):
+    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+
+    Examples:
+        ASAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ASAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+    """
+    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+        super().__init__(rho=rho, p=p, eps=eps, asam=True)
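The comment block in `sam_closure` solves the dual-norm relation 1/p + 1/q = 1, giving q = p/(p-1) (written as -p/(1-p) in the code), and then builds the SAM-paper perturbation e = rho · sign(g) · |g|^(q-1) / (Σ|g|^q)^(1/p), which for the default p = 2 is simply rho · g / ‖g‖₂. Below is a single-tensor sketch of that computation in plain torch; the class above operates on a TensorList and additionally handles the ASAM rescaling:

```python
import torch

def sam_perturbation(grad: torch.Tensor, rho: float = 0.05, p: float = 2.0, eps: float = 1e-10) -> torch.Tensor:
    q = p / (p - 1)                              # dual exponent from 1/p + 1/q = 1
    term1 = grad.sign() * rho                    # rho * sign(g)
    term2 = grad.abs().pow(q - 1)                # |g|^(q-1)
    denom = grad.abs().pow(q).sum().pow(1 / p)   # (sum |g|^q)^(1/p)
    return term1 * term2 / denom.clamp(min=eps)

g = torch.tensor([3.0, -4.0])
e = sam_perturbation(g, rho=0.1)   # for p=2: rho * g / ||g||_2 = [0.06, -0.08]
```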
torchzero/modules/optimizers/shampoo.py

@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 
-from ...core import Chainable, Transform,
+from ...core import Chainable, Transform, apply_transform
 from ...utils.linalg import matrix_power_eigh
 from ...utils import set_storage_
 
@@ -59,7 +59,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
     if tensor.shape[sort_idxs[0]] > max_dim:
         return tensor, None, None
 
-    tensor = tensor.permute(*sort_idxs)
+    tensor = tensor.permute(*sort_idxs.tolist())
     flatten_end_idx = 0
     flat_sizes = []
     flat_numel = 1
@@ -80,19 +80,28 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     if flat_sizes is None: return tensor
     assert sort_idxs is not None
     tensor = tensor.unflatten(0, flat_sizes)
-    return tensor.permute(*np.argsort(sort_idxs))
+    return tensor.permute(*np.argsort(sort_idxs).tolist())
 
 
 class Shampoo(Transform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
 
+    .. note::
+        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+    .. note::
+        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+    .. note::
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
     Args:
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
             if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to
+        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
         max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
         precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,39 +110,62 @@ class Shampoo(Transform):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
+
+    Examples:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with Shampoo preconditioner
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(1e-3)
+            )
     """
     def __init__(
         self,
         decay: float | None = None,
         beta: float | None = None,
-        reg: float = 1e-6,
         update_freq: int = 10,
-        exp_override: int | None =
+        exp_override: int | None = 2,
         merge_small: bool = True,
         max_dim: int = 2_000,
         precondition_1d: bool = True,
         adagrad_eps: float = 1e-8,
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta,
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
         super().__init__(defaults, uses_grad=False)
 
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
-
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        merged_tensors = [] # target with merged dims
 
         # update preconditioners
-        for i,(
-
-
-            beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
+
+            merged_tensors.append(t)
 
             # initialize accumulators and preconditioners for each dim on 1st step
             if 'accumulators' not in state:
@@ -167,22 +199,18 @@ class Shampoo(Transform):
 
         # inner step
         if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)
 
         # have to merge small dims again
-
-        for i,(
-
-
-
-            t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
-            merged_target.append(t)
+        merged_tensors = [] # target with merged dims
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            if setting['merge_small']:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
+            merged_tensors.append(t)
 
         # precondition
-        for i,
-
-            settings = self.settings[p]
-            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)
+        for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
+            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)
 
             if 'diagonal_accumulator' in state:
                 tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)