heavyball 1.7.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +276 -37
- heavyball/chainable.py +419 -206
- heavyball/helpers.py +808 -0
- heavyball/utils.py +1062 -315
- heavyball-2.0.0.dist-info/METADATA +122 -0
- heavyball-2.0.0.dist-info/RECORD +9 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/WHEEL +1 -1
- heavyball-1.7.2.dist-info/METADATA +0 -939
- heavyball-1.7.2.dist-info/RECORD +0 -8
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -1,14 +1,16 @@
+import copy
 import functools
 import math
 import random
-from
+from collections.abc import Iterable as _Iterable
+from typing import Iterable, List, Literal, Optional, Union
 
 import torch
 from torch import Tensor
 
 from . import utils
 
-
+use_default = utils.use_default
 
 
 def _key_in_state(state, key):
@@ -36,20 +38,68 @@ def _guard_in_state(state, key, template_fn):
 
 
 class FunctionTransform:
-    def __init__(self, fn):
+    def __init__(self, fn, names: list[str] | None = None):
+        if names is None:
+            names = []
         self.fn = fn
         self.fn_name = self.get_fn().__name__
+        self.transform_idx = None
+        self.is_initialized = False
+        self.names = names
 
-    def
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        raise NotImplementedError
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         raise NotImplementedError
 
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        states = [state(p) for p in param]
+        skip_update = False
+        for st, a in zip(states, zip(update, grad, param, *args)):
+            if self.transform_idx not in st.get("is_initialized", set()):
+                try:
+                    self._init(st, group, *a, **kwargs)
+                except SkipUpdate:
+                    skip_update = True
+                except:
+                    raise
+                finally:
+                    if "is_initialized" not in st:
+                        st["is_initialized"] = set()
+                    st["is_initialized"].add(self.transform_idx)
+        if skip_update:
+            raise SkipUpdate from None
+        vars = [[st.get(self.val_name(name), None) for st in states] for name in self.names]
+        return self._call(state, group, update, grad, param, vars, *args, **kwargs)
+
     def get_fn(self):
         if utils.hasattr_none(self.fn, "get_fn"):
             return self.fn.get_fn()
         return self.fn
 
     def val_name(self, name):
-
+        assert self.transform_idx is not None
+        return f"{self.fn_name}_{name}_{self.transform_idx}"
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
+
+
+class Branch:
+    def __init__(self, branches: List[List[callable]], merge_fn: callable):
+        self.branches = branches
+        self.merge_fn = merge_fn
+
+    def __call__(self, state, group, update, grad, param):
+        outputs = []
+        for branch in self.branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return self.merge_fn(outputs)
 
 
 def _zero_guard(state, key, ref, dtype):
@@ -63,49 +113,102 @@ def _storage_dtype(group):
 
 class ZeroGuard(FunctionTransform):
     def __init__(self, fn, names):
-        super().__init__(fn)
-        self.names = names
+        super().__init__(fn, names)
 
-    def
-
-
-
-
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        for name in self.names:
+            _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
+class PrecondGradAccumGuard(FunctionTransform):
+    def __init__(self, fn):
+        super().__init__(fn, ["precond_grad_accum"])
+        self.steps_taken = 0
+        self.pass_through = None
+
+    def _accum(self, state, new):
+        self.steps_taken += 1
+        utils.stochastic_add_(state, new)
+
+    def _reset(self, state):
+        if self.steps_taken != 0:
+            self.steps_taken = 0
+            utils.zero_(state)
+
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        if self.pass_through is None:
+            self.pass_through = not group.get("precond_grad_accum", False)
+        if self.pass_through is False:
+            for name in self.names:
+                _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+        base_grad = update if group.get("momentum_into_precond_update", True) else grad
+        if self.pass_through:
+            return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
+
+        (vars,) = vars
+        if group["is_preconditioning"]:
+            if self.steps_taken:
+                self._accum(vars, base_grad)
+                utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+            else:
+                vars = base_grad
+        else:
+            self._accum(vars, base_grad)
+            vars = base_grad
+        try:
+            out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
+        finally:
+            if group["is_preconditioning"]:
+                self._reset(vars)
+
+        return out
+
+
 class CopyGuard(FunctionTransform):
     def __init__(self, fn, index, names):
-        super().__init__(fn)
+        super().__init__(fn, names)
         self.index = index
-        self.names = names
 
-    def
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
         val = [update, grad, param, *args][self.index]
-
-        [
-
-
+        for name in self.names:
+            state[self.val_name(name)] = torch.clone(val)
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
-class GeneralGuard(FunctionTransform):
+class GeneralGuard(FunctionTransform):
     def __init__(self, fn, names, init_fn, skip_first: bool = True):
-        super().__init__(fn)
-        self.names = names
+        super().__init__(fn, names)
         self.init_fn = init_fn
         self.skip_first = skip_first
-
-
-
-
-        for p
-
-
-
-
-
-
+        self.named_to_anonymous = None
+        self.anonymous_to_named = None
+
+    def _map(self, state_fn, param, mapping):
+        for p in param:
+            state = state_fn(p)
+            for name, mapped in mapping.items():
+                if mapped in state:
+                    raise ValueError(f"Name {name} already mapped to {mapped}")
+                if name in state:
+                    state[mapped] = state.pop(name)
+
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        self.init_fn(state, group, update, grad, param, **kwargs)
+        for name in self.names:
+            state[self.val_name(name)] = state.pop(name, None)
+        if self.skip_first:
+            raise SkipUpdate from None
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
 class NoState(FunctionTransform):
@@ -124,10 +227,27 @@ class NoStateNoForeach(FunctionTransform):
                 skip_update = True
                 pass
         if skip_update:
-            raise SkipUpdate
+            raise SkipUpdate from None
         return updates
 
 
+class SqueezeGrad(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        original_shapes = [u.shape for u in update]
+        update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
+        grad = [x.view_as(u) for x, u in zip(grad, update)]
+        param = [x.view_as(u) for x, u in zip(param, update)]
+        args = list(args)
+        for i, a in enumerate(args):
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                args[i] = [x.view_as(u) for x, u in zip(a, update)]
+        for k, a in kwargs.items():
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                kwargs[k] = [x.view_as(u) for x, u in zip(a, update)]
+        out = self.fn(state, group, update, grad, param, *args, **kwargs)
+        return [o.view(s) for o, s in zip(out, original_shapes)]
+
+
 def zero_guard(*names):
     return functools.partial(ZeroGuard, names=names)
 
@@ -158,12 +278,23 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+@copy_guard(2, "init")
+@no_state
+def weight_decay_to_init(group, update, grad, param, init):
+    utils.stochastic_lerp_(param, init, group["weight_decay_to_ema"] * group["lr"])
+    return update
+
+
+def identity(state, group, update, grad, param):
+    return update
+
+
 @zero_guard("exp_avg")
 @no_state
 def weight_decay_to_ema(group, update, grad, param, exp_avg):
     utils.weight_decay_to_ema_(
+        param,
         exp_avg,
-        update,
         utils.beta_debias(group["ema_beta"], group["step"]),
         group["weight_decay_to_ema"] * group["lr"],
     )
@@ -174,8 +305,8 @@ def weight_decay_to_ema(group, update, grad, param, exp_avg):
 @no_state
 def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
     utils.l1_weight_decay_to_ema_(
+        param,
         exp_avg,
-        update,
         utils.beta_debias(group["ema_beta"], group["step"]),
         group["weight_decay_to_ema"] * group["lr"],
     )
@@ -221,7 +352,27 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
         group["weight_decay"],
         group["caution"],
     )
-    raise SkipUpdate
+    raise SkipUpdate from None
+
+
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def update_by_adamc(group, update, grad, param, exp_avg, exp_avg_sq):
+    utils.fused_adam_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["eps"],
+        group["lr"] * group["weight_decay"] / group["max_lr"],
+        group["caution"],
+    )
+    raise SkipUpdate from None
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -246,7 +397,7 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
         group["weight_decay"],
         group["caution"],
     )
-    raise SkipUpdate
+    raise SkipUpdate from None
 
 
 @no_state
@@ -271,7 +422,26 @@ def update_by_schedule_free(group, update, grad, param, z):
         group["step"],
         group["weight_decay"],
     )
-    raise SkipUpdate
+    raise SkipUpdate from None
+
+
+@copy_guard(2, "z")
+@zero_guard("exp_avg")
+@no_state
+def update_by_msam(group, update, grad, param, z, exp_avg):
+    utils.msam_(
+        group["lr"],
+        utils.beta_debias(utils.get_beta1(group), group["step"]),
+        param,
+        z,
+        update,
+        grad,
+        exp_avg,
+        group["caution"],
+        group["weight_decay"],
+        group["sam_step_size"],
+    )
+    raise SkipUpdate from None
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -279,7 +449,7 @@ def update_by_schedule_free(group, update, grad, param, z):
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group["step"] == 1:
         utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     if group["step"] == 2:
         update = utils.promote(update)
@@ -291,7 +461,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
             utils.beta_debias(utils.get_beta2(group), group["step"]),
             group["eps"],
         )
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     utils.fused_adopt_(
         param,
@@ -307,7 +477,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
         group["weight_decay"],
         group["caution"],
     )
-    raise SkipUpdate
+    raise SkipUpdate from None
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -315,7 +485,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group["step"] == 1:
         utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     if group["step"] == 2:
         update = utils.promote(update)
@@ -327,7 +497,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
             utils.beta_debias(utils.get_beta2(group), group["step"]),
             group["eps"],
         )
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     return utils.adopt(
         update,
@@ -344,10 +514,11 @@ def _init_soap(state, group, update, grad, param, inner: str = ""):
|
|
344
514
|
|
345
515
|
|
346
516
|
def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
|
347
|
-
Q
|
517
|
+
Q = utils.init_Q_exprs(
|
348
518
|
grad,
|
349
519
|
group["precond_init_scale"],
|
350
520
|
group["precond_init_scale_scale"],
|
521
|
+
group["precond_init_scale_power"],
|
351
522
|
group["max_size_triangular"],
|
352
523
|
group["min_ndim_triangular"],
|
353
524
|
group["memory_save_mode"],
|
@@ -356,32 +527,28 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
         dtype=getattr(torch, group["q_dtype"]),
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-
+    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
+    if group["adaptive"]:
+        state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
         return
 
     state["Q_cache"] = [torch.empty_like(q) for q in Q]
 
-    expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
-    expr = ",".join(expr)
-    grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
-    out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
-    expr = f"{expr},{grad_expr}->{out_expr}"
-
-    state["cache_expr"] = expr
-
 
 def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
     state["U"], state["V"], state["d"] = utils.init_lra(
         grad,
+        group["param_count"],
         group["precond_init_scale"],
         group["precond_init_scale_scale"],
+        group["precond_init_scale_power"],
         group["rank"],
         getattr(param, "hessian_vector", None),
         getattr(param, "vector", None),
         dtype=getattr(torch, group["q_dtype"]),
     )
-    group["preconditioning_step"] = 0
 
 
 def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
@@ -402,12 +569,12 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
 
 @no_state_no_foreach
 def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
-    if update.dim()
+    if update.dim() < 2:
         return update
     original_shape = update.shape
     # doing it this way, as tmp and update are not guaranteed to share memory address or layout
     tmp = update.flatten(1, -1)
-    utils.inplace_orthogonal_(tmp,
+    utils.inplace_orthogonal_(tmp, out=tmp, scale_mode=scale_mode)
     return tmp.reshape(original_shape)
 
 
@@ -424,7 +591,7 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
 
 
 def _store_std(state, group, update, grad, param):
-    state["init_std"] = torch.std(
+    state["init_std"] = torch.std(param)
 
 
 @general_guard("init_std", init_fn=_store_std, skip_first=False)
@@ -483,9 +650,7 @@ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}
|
|
483
650
|
@general_guard("Q", "GG", init_fn=_init_soap)
|
484
651
|
@no_state
|
485
652
|
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
|
486
|
-
|
487
|
-
|
488
|
-
grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
|
653
|
+
grad_projected = [utils.project(utils.promote(u), q, False) for u, q in zip(update, Q)]
|
489
654
|
fn = _optim_fns[inner]
|
490
655
|
precond = fn(
|
491
656
|
exp_avg,
|
@@ -500,7 +665,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
     for u, q, gg, ea in zip(update, Q, GG, exp_avg):
         utils.update_preconditioner(
-            u,
+            utils.promote(u),
             q,
             gg,
             ea,
@@ -512,35 +677,38 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond
 
 
-def _update_psgd_precond(
+def _update_psgd_precond(
+    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, step, prob: Optional[callable] = None
+) -> Optional[Tensor]:
     if prob is None:
         prob = utils.precond_update_prob_schedule()
 
     if not group["is_preconditioning"]:
-        return
+        return
 
     if utils.hasattr_none(param, "vector"):
         vector, hessian_vector = param.vector, param.hessian_vector
         del param.vector
         del param.hessian_vector
+    elif group["inverse_free"]:
+        vector, hessian_vector = None, grad
     else:
-        vector, hessian_vector = utils.dampen_grad(grad)
+        vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
 
-    utils.psgd_update_precond(
-        Q_mat,
-        exprs,
+    precond = utils.psgd_update_precond(
         hessian_vector,
         group["precond_lr"],
         Q,
         group["store_triu_as_line"],
+        velocity,
+        utils.get_beta2(group),
+        group["ortho_method"],
         vector,
+        running_lower_bound,
+        group["lower_bound_beta"],
+        group["precond_update_power_iterations"],
     )
-
-    if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
-        if group["store_triu_as_line"]:
-            utils.psgd_balance_Q([q_ for _, q_ in Q])
-        else:
-            utils.psgd_balance_Q(Q)
+    del vector, hessian_vector
 
     if isinstance(prob, float):
         float_prob = prob
@@ -548,54 +716,45 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
         float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
     group["is_cached"] = should_use_cache = cached and float_prob < 0.5
 
-    if
-        return
-
-
+    if precond is not None:
+        return precond
+    if not should_use_cache or not cached:
+        return None # caching adds extra ops and is not worth the overhead when we precondition at every step
 
-
-    if not cached:
-        return q
-
-    for c_, q_ in zip(Q_cache, q):
+    for c_, q_ in zip(Q_cache, utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q):
         if q_.ndim == 2:
             torch.matmul(q_.T, q_, out=c_)
         else:
             torch.mul(q_, q_, out=c_)
-    return
+    return None
 
 
-def _cached_psgd_precond_grad(group,
+def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
+    kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
     if group.get("is_cached", False):
-        out = utils.precond_grad_cached_(
+        out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
-        out = utils.psgd_precond_grad(
+        out = utils.psgd_precond_grad(
+            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+        )
     group["caution"] = False # we already cautioned here - shouldn't do it again
     return out
 
 
-def _fused_cached_psgd_precond_grad(group, grad, param,
+def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
+    kwargs = {
+        "ea": update,
+        "caution": group["caution"],
+        "grad": grad,
+        "param": param,
+        "lr": group["lr"],
+        "decay": group["weight_decay"],
+    }
     if group.get("is_cached", False):
-        utils.fused_precond_grad_cached_(
-            cache_expr,
-            update,
-            param,
-            group["lr"],
-            grad,
-            group["weight_decay"],
-            group["caution"],
-            *Q_cache,
-        )
+        utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
         utils.fused_psgd_precond_grad(
-
-            update,
-            param,
-            group["lr"],
-            grad,
-            group["weight_decay"],
-            group["caution"],
-            *Q_mat,
+            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
         )
 
 
@@ -603,7 +762,7 @@ def _update_lra(
     group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
 ):
     if not group["is_preconditioning"]:
-        return utils.
+        return utils.multi_flatten((U, 1), (V, 1), (d, 0))
 
     if utils.hasattr_none(params[0], "hessian_vector"):
         vector = utils.flatten([p.vector for p in params])
@@ -613,127 +772,121 @@ def _update_lra(
             del p.hessian_vector
     else:
         vector, hessian_vector = utils.dampen_multiple(grads)
-
+    precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
+    return utils.update_lra_precond_(
+        U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed, bool(precond_step % 2)
+    )
 
 
+@SqueezeGrad
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def scale_by_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
+def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@SqueezeGrad
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def update_by_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
-    utils.apply_lra_update(param, update, u, v, d)
-    raise SkipUpdate
+def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
+    utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+    raise SkipUpdate from None
 
 
+@SqueezeGrad
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
+def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@SqueezeGrad
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
-    utils.apply_lra_update(param, update, u, v, d)
-    raise SkipUpdate
+def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
+    utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+    raise SkipUpdate from None
 
 
-@
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-    Q_mat = _update_psgd_precond(
-        cached,
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
 
 
-@
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
    Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-
-
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
+    if group.get("inverse_free", False):
+        precond = None
+    else:
+        precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+    new = _update_psgd_precond(
+        cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
     )
-    return precond
+    return new if precond is None else precond
 
 
-@
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-    raise SkipUpdate
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+    raise SkipUpdate from None
 
 
 @no_state
@@ -741,34 +894,33 @@ def sign(group, update, grad, param, graft: bool = True):
     return utils.sign_(update, graft)
 
 
-@
+@no_state
+def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
+    assert clip_fn is not None
+    return clip_fn(update)
+
+
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-        cached,
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    raise SkipUpdate
+    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    raise SkipUpdate from None
 
 
 def palm_beta2(state, group, update, grad, param):
@@ -805,26 +957,63 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
     utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
 
 
-def
-
-
-    for branch in branches:
-        branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
-        branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
-        if skip_update:
-            raise ValueError("Branches should not skip updates")
-        outputs.append(branch_update)
-    return merge_fn(outputs)
+def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+    if retain and offset:
+        raise ValueError("offset cannot be retained")
 
-
+    def _walk(obj):
+        stack = [obj]
+        while stack:
+            cur = stack.pop()
+            if isinstance(cur, FunctionTransform):
+                yield cur
+                stack.append(cur.fn)
+            elif isinstance(cur, functools.partial):
+                stack.append(cur.func)
+            elif isinstance(cur, Branch):
+                for branch in cur.branches:
+                    stack.extend(branch)
+            elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
+                stack.extend(cur)
+
+    if retain:
+        offset = max((ft.transform_idx for ft in _walk(fns) if ft.transform_idx is not None), default=-1) + 1
+
+    new_fns = [copy.deepcopy(fn) for fn in fns]
+    for ft in _walk(new_fns):
+        if not retain or ft.transform_idx is None:
+            ft.transform_idx, offset = offset, offset + 1
+
+    return new_fns
 
 
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
+    global_defaults = {
+        "caution": False,
+        "lr": 1,
+        "warmup_steps": 0,
+        "weight_decay": 0,
+        "eps": 1e-8,
+    }
 
     def __init__(self, params, defaults, foreach: bool, *fns):
-
-
+        base = self.global_defaults.copy()
+        base.update({k: v for k, v in defaults.items() if v is not use_default})
+        super().__init__(params, base, foreach)
+        self.fns = fns
+
+    @property
+    def fns(self):
+        return self._fns
+
+    @fns.setter
+    def fns(self, value):
+        self._fns = value
+        self._set_indices(retain=True)
+
+    def _set_indices(self, retain=True):
+        self._fns = set_indices(self.fns, retain)
 
     def _step(self, group):
         if "base_lr" not in group:
@@ -868,7 +1057,6 @@ class ChainOpt(utils.StatefulOptimizer):
             group["step"] = None
 
 
-use_default = object()
 str_or_fn = Union[str, callable, None, Literal[use_default]]
 
 
@@ -931,7 +1119,7 @@ class BaseOpt(ChainOpt):
 
     update_clipping: str_or_fn = None
         The function to use for clipping the outgoing updates before applying them, after all other transformations.
-        This will turn off
+        This will turn off fused updates.
        This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
 
     """
@@ -945,11 +1133,11 @@ class BaseOpt(ChainOpt):
         self,
         params,
         defaults,
-        foreach: bool,
-        gradient_clipping: str_or_fn,
-        update_clipping: str_or_fn,
+        foreach: bool = True,
+        gradient_clipping: str_or_fn = None,
+        update_clipping: str_or_fn = None,
         palm: bool = use_default,
-
+        fns: Iterable[callable] = (),
         compile_step: bool = use_default,
         promote: bool = use_default,
     ):
@@ -957,6 +1145,7 @@ class BaseOpt(ChainOpt):
            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
 
         args, kwargs = None, None
+        fns = tuple(fns)
         fn = fns[-1]
         if isinstance(fn, functools.partial):
            fn, args, kwargs = fn.func, fn.args, fn.keywords
@@ -1020,3 +1209,27 @@ class ScheduleFree(BaseOpt):
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - beta1)
                         utils.copy_stochastic_(p.data, p32)
+
+
+class MSAM(BaseOpt):
+    def eval(self):
+        for group in self.param_groups:
+            group["train_mode"] = train_mode = not group.get("train_mode")
+            if not train_mode:
+                for p in group["params"]:
+                    state = self.state_(p)
+                    if "z" in state:
+                        p_copy = p.data.clone()
+                        utils.copy_stochastic_(p.data, state["z"])
+                        utils.copy_stochastic_(state["z"], p_copy)
+
+    def train(self):
+        for group in self.param_groups:
+            group["train_mode"] = train_mode = not group.get("train_mode")
+            if train_mode:
+                for p in group["params"]:
+                    state = self.state_(p)
+                    if "z" in state:
+                        p_copy = p.data.clone()
+                        utils.copy_stochastic_(p.data, state["z"])
+                        utils.copy_stochastic_(state["z"], p_copy)