heavyball 1.7.1__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -16
- heavyball/chainable.py +338 -190
- heavyball/helpers.py +804 -0
- heavyball/utils.py +813 -252
- heavyball-2.0.0.dev0.dist-info/METADATA +109 -0
- heavyball-2.0.0.dev0.dist-info/RECORD +9 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/WHEEL +1 -1
- heavyball/optimizations/__init__.py +0 -38
- heavyball/optimizations/integrator.py +0 -169
- heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1.dist-info/METADATA +0 -939
- heavyball-1.7.1.dist-info/RECORD +0 -11
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -1,14 +1,15 @@
+import copy
 import functools
 import math
 import random
-from typing import List, Literal, Optional, Union
+from typing import Iterable, List, Literal, Optional, Union
 
 import torch
 from torch import Tensor
 
 from . import utils
 
-
+use_default = utils.use_default
 
 
 def _key_in_state(state, key):
@@ -36,20 +37,52 @@ def _guard_in_state(state, key, template_fn):
 
 
 class FunctionTransform:
-    def __init__(self, fn):
+    def __init__(self, fn, names: list[str] | None = None):
+        if names is None:
+            names = []
         self.fn = fn
         self.fn_name = self.get_fn().__name__
+        self.transform_idx = None
+        self.is_initialized = False
+        self.names = names
 
-    def
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
        raise NotImplementedError
 
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+        raise NotImplementedError
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        states = [state(p) for p in param]
+        skip_update = False
+        for st, a in zip(states, zip(update, grad, param, *args)):
+            if self.transform_idx not in st.get("is_initialized", set()):
+                try:
+                    self._init(st, group, *a, **kwargs)
+                except SkipUpdate:
+                    skip_update = True
+                except:
+                    raise
+                finally:
+                    if "is_initialized" not in st:
+                        st["is_initialized"] = set()
+                    st["is_initialized"].add(self.transform_idx)
+        if skip_update:
+            raise SkipUpdate from None
+        vars = [[st.get(self.val_name(name), None) for st in states] for name in self.names]
+        return self._call(state, group, update, grad, param, vars, *args, **kwargs)
+
     def get_fn(self):
         if utils.hasattr_none(self.fn, "get_fn"):
             return self.fn.get_fn()
         return self.fn
 
     def val_name(self, name):
-
+        assert self.transform_idx is not None
+        return f"{self.fn_name}_{name}_{self.transform_idx}"
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
 
 
 def _zero_guard(state, key, ref, dtype):
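The refactor above splits a transform into `_init` (run once per parameter, tracked through the state's `is_initialized` set) and `_call` (run every step with the per-name state tensors gathered by `__call__`). A minimal sketch of a custom guard under this split, mirroring `ZeroGuard` from the hunk below; this is illustrative only and not part of the package:

    class EMAGuard(FunctionTransform):
        """Hypothetical guard that allocates one EMA buffer per parameter."""

        def __init__(self, fn):
            super().__init__(fn, names=["ema"])

        def _init(self, state, group, update, grad, param, *args, **kwargs):
            # runs once per parameter; keys are namespaced by transform_idx via val_name()
            for name in self.names:
                state[self.val_name(name)] = torch.zeros_like(param)

        def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
            # vars holds the buffers collected by __call__, one list per name
            return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)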
@@ -63,49 +96,102 @@ def _storage_dtype(group):
 
 
 class ZeroGuard(FunctionTransform):
     def __init__(self, fn, names):
-        super().__init__(fn)
-        self.names = names
+        super().__init__(fn, names)
 
-    def
-
-
-
-
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        for name in self.names:
+            _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
+class PrecondGradAccumGuard(FunctionTransform):
+    def __init__(self, fn):
+        super().__init__(fn, ["precond_grad_accum"])
+        self.steps_taken = 0
+        self.pass_through = None
+
+    def _accum(self, state, new):
+        self.steps_taken += 1
+        utils.stochastic_add_(state, new)
+
+    def _reset(self, state):
+        if self.steps_taken == 0:
+            self.steps_taken = 0
+            utils.zero_(state)
+
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        if self.pass_through is None:
+            self.pass_through = not group.get("precond_grad_accum", False)
+        if self.pass_through is False:
+            for name in self.names:
+                _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+        base_grad = update if group.get("momentum_into_precond_update", True) else grad
+        if self.pass_through:
+            return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
+
+        (vars,) = vars
+        if group["is_preconditioning"]:
+            if self.steps_taken:
+                self._accum(vars, base_grad)
+                utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+            else:
+                vars = base_grad
+        else:
+            self._accum(vars, base_grad)
+            vars = base_grad
+        try:
+            out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
+        finally:
+            if group["is_preconditioning"]:
+                self._reset(vars)
+
+        return out
+
+
 class CopyGuard(FunctionTransform):
     def __init__(self, fn, index, names):
-        super().__init__(fn)
+        super().__init__(fn, names)
         self.index = index
-        self.names = names
 
-    def
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
         val = [update, grad, param, *args][self.index]
-
-        [
-
-
+        for name in self.names:
+            state[self.val_name(name)] = torch.clone(val)
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
-class GeneralGuard(FunctionTransform):
+class GeneralGuard(FunctionTransform):
     def __init__(self, fn, names, init_fn, skip_first: bool = True):
-        super().__init__(fn)
-        self.names = names
+        super().__init__(fn, names)
         self.init_fn = init_fn
         self.skip_first = skip_first
-
-
-
-
-        for p
-
-
-
-
-
-
+        self.named_to_anonymous = None
+        self.anonymous_to_named = None
+
+    def _map(self, state_fn, param, mapping):
+        for p in param:
+            state = state_fn(p)
+            for name, mapped in mapping.items():
+                if mapped in state:
+                    raise ValueError(f"Name {name} already mapped to {mapped}")
+                if name in state:
+                    state[mapped] = state.pop(name)
+
+    def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+        self.init_fn(state, group, update, grad, param, **kwargs)
+        for name in self.names:
+            state[self.val_name(name)] = state.pop(name, None)
+        if self.skip_first:
+            raise SkipUpdate from None
+
+    def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
 class NoState(FunctionTransform):
|
|
124
210
|
skip_update = True
|
125
211
|
pass
|
126
212
|
if skip_update:
|
127
|
-
raise SkipUpdate
|
213
|
+
raise SkipUpdate from None
|
128
214
|
return updates
|
129
215
|
|
130
216
|
|
@@ -158,12 +244,23 @@ def exp_avg(group, update, grad, param, exp_avg):
|
|
158
244
|
return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
|
159
245
|
|
160
246
|
|
247
|
+
@copy_guard(2, "init")
|
248
|
+
@no_state
|
249
|
+
def weight_decay_to_init(group, update, grad, param, init):
|
250
|
+
utils.weight_decay_to_init_(
|
251
|
+
param,
|
252
|
+
init,
|
253
|
+
group["weight_decay_to_ema"] * group["lr"],
|
254
|
+
)
|
255
|
+
return update
|
256
|
+
|
257
|
+
|
161
258
|
@zero_guard("exp_avg")
|
162
259
|
@no_state
|
163
260
|
def weight_decay_to_ema(group, update, grad, param, exp_avg):
|
164
261
|
utils.weight_decay_to_ema_(
|
262
|
+
param,
|
165
263
|
exp_avg,
|
166
|
-
update,
|
167
264
|
utils.beta_debias(group["ema_beta"], group["step"]),
|
168
265
|
group["weight_decay_to_ema"] * group["lr"],
|
169
266
|
)
|
@@ -174,8 +271,8 @@ def weight_decay_to_ema(group, update, grad, param, exp_avg):
|
|
174
271
|
@no_state
|
175
272
|
def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
|
176
273
|
utils.l1_weight_decay_to_ema_(
|
274
|
+
param,
|
177
275
|
exp_avg,
|
178
|
-
update,
|
179
276
|
utils.beta_debias(group["ema_beta"], group["step"]),
|
180
277
|
group["weight_decay_to_ema"] * group["lr"],
|
181
278
|
)
|
@@ -221,7 +318,7 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
|
|
221
318
|
group["weight_decay"],
|
222
319
|
group["caution"],
|
223
320
|
)
|
224
|
-
raise SkipUpdate
|
321
|
+
raise SkipUpdate from None
|
225
322
|
|
226
323
|
|
227
324
|
@zero_guard("exp_avg", "exp_avg_sq")
|
@@ -246,7 +343,7 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
|
|
246
343
|
group["weight_decay"],
|
247
344
|
group["caution"],
|
248
345
|
)
|
249
|
-
raise SkipUpdate
|
346
|
+
raise SkipUpdate from None
|
250
347
|
|
251
348
|
|
252
349
|
@no_state
|
@@ -271,7 +368,26 @@ def update_by_schedule_free(group, update, grad, param, z):
|
|
271
368
|
group["step"],
|
272
369
|
group["weight_decay"],
|
273
370
|
)
|
274
|
-
raise SkipUpdate
|
371
|
+
raise SkipUpdate from None
|
372
|
+
|
373
|
+
|
374
|
+
@copy_guard(2, "z")
|
375
|
+
@zero_guard("exp_avg")
|
376
|
+
@no_state
|
377
|
+
def update_by_msam(group, update, grad, param, z, exp_avg):
|
378
|
+
utils.msam_(
|
379
|
+
group["lr"],
|
380
|
+
utils.beta_debias(utils.get_beta1(group), group["step"]),
|
381
|
+
param,
|
382
|
+
z,
|
383
|
+
update,
|
384
|
+
grad,
|
385
|
+
exp_avg,
|
386
|
+
group["caution"],
|
387
|
+
group["weight_decay"],
|
388
|
+
group["sam_step_size"],
|
389
|
+
)
|
390
|
+
raise SkipUpdate from None
|
275
391
|
|
276
392
|
|
277
393
|
@zero_guard("exp_avg", "exp_avg_sq")
|
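The new `weight_decay_to_init` transform stores a copy of each parameter at initialization (via `copy_guard(2, "init")`) and decays the live parameters toward that copy instead of toward zero. A rough sketch of the intended update, assuming `utils.weight_decay_to_init_` acts like a lerp toward the stored values (the real kernel may use foreach and stochastic-rounding ops):

    def weight_decay_to_init_sketch(params, inits, strength):
        # pull each parameter a small fraction of the way back toward its initial value
        for p, p0 in zip(params, inits):
            p.lerp_(p0, strength)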
@@ -279,7 +395,7 @@ def update_by_schedule_free(group, update, grad, param, z):
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group["step"] == 1:
         utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     if group["step"] == 2:
         update = utils.promote(update)
@@ -291,7 +407,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
             utils.beta_debias(utils.get_beta2(group), group["step"]),
             group["eps"],
         )
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     utils.fused_adopt_(
         param,
@@ -307,7 +423,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
         group["weight_decay"],
         group["caution"],
     )
-    raise SkipUpdate
+    raise SkipUpdate from None
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -315,7 +431,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group["step"] == 1:
         utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     if group["step"] == 2:
         update = utils.promote(update)
@@ -327,7 +443,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
             utils.beta_debias(utils.get_beta2(group), group["step"]),
             group["eps"],
         )
-        raise SkipUpdate
+        raise SkipUpdate from None
 
     return utils.adopt(
         update,
@@ -344,10 +460,11 @@ def _init_soap(state, group, update, grad, param, inner: str = ""):
 
 
 def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
-    Q
+    Q = utils.init_Q_exprs(
         grad,
         group["precond_init_scale"],
         group["precond_init_scale_scale"],
+        group["precond_init_scale_power"],
         group["max_size_triangular"],
         group["min_ndim_triangular"],
         group["memory_save_mode"],
@@ -356,32 +473,27 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
         dtype=getattr(torch, group["q_dtype"]),
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-
+    state["running_lower_bound"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+    if group["adaptive"]:
+        state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
         return
 
     state["Q_cache"] = [torch.empty_like(q) for q in Q]
 
-    expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
-    expr = ",".join(expr)
-    grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
-    out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
-    expr = f"{expr},{grad_expr}->{out_expr}"
-
-    state["cache_expr"] = expr
-
 
 def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
     state["U"], state["V"], state["d"] = utils.init_lra(
         grad,
+        group["param_count"],
         group["precond_init_scale"],
         group["precond_init_scale_scale"],
+        group["precond_init_scale_power"],
         group["rank"],
         getattr(param, "hessian_vector", None),
         getattr(param, "vector", None),
         dtype=getattr(torch, group["q_dtype"]),
     )
-    group["preconditioning_step"] = 0
 
 
 def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
@@ -402,12 +514,12 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
 
 @no_state_no_foreach
 def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
-    if update.dim()
+    if update.dim() < 2:
         return update
     original_shape = update.shape
     # doing it this way, as tmp and update are not guaranteed to share memory address or layout
     tmp = update.flatten(1, -1)
-    utils.inplace_orthogonal_(tmp,
+    utils.inplace_orthogonal_(tmp, out=tmp, scale_mode=scale_mode)
     return tmp.reshape(original_shape)
 
 
@@ -424,7 +536,7 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
 
 
 def _store_std(state, group, update, grad, param):
-    state["init_std"] = torch.std(
+    state["init_std"] = torch.std(param)
 
 
 @general_guard("init_std", init_fn=_store_std, skip_first=False)
@@ -483,9 +595,7 @@ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
 def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
-
-
-    grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
+    grad_projected = [utils.project(utils.promote(u), q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
     precond = fn(
         exp_avg,
@@ -500,7 +610,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
     for u, q, gg, ea in zip(update, Q, GG, exp_avg):
         utils.update_preconditioner(
-            u,
+            utils.promote(u),
             q,
             gg,
             ea,
@@ -512,35 +622,38 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond
 
 
-def _update_psgd_precond(
+def _update_psgd_precond(
+    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, prob: Optional[callable] = None
+) -> Optional[Tensor]:
     if prob is None:
         prob = utils.precond_update_prob_schedule()
 
     if not group["is_preconditioning"]:
-        return
+        return
 
     if utils.hasattr_none(param, "vector"):
         vector, hessian_vector = param.vector, param.hessian_vector
         del param.vector
         del param.hessian_vector
+    elif group["inverse_free"]:
+        vector, hessian_vector = None, grad
     else:
-        vector, hessian_vector = utils.dampen_grad(grad)
+        vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
 
-    utils.psgd_update_precond(
-        Q_mat,
-        exprs,
+    precond = (utils.inverse_free_psgd_update_precond if vector is None else utils.psgd_update_precond)(
         hessian_vector,
         group["precond_lr"],
         Q,
         group["store_triu_as_line"],
+        velocity,
+        utils.beta_debias(utils.get_beta2(group), group["step"]),
+        group["ortho_method"],
         vector,
+        running_lower_bound,
+        group["lower_bound_beta"],
+        group["precond_update_power_iterations"],
     )
-
-    if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
-        if group["store_triu_as_line"]:
-            utils.psgd_balance_Q([q_ for _, q_ in Q])
-        else:
-            utils.psgd_balance_Q(Q)
+    del vector, hessian_vector
 
     if isinstance(prob, float):
         float_prob = prob
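When a parameter carries no explicit Hessian-vector product and `inverse_free` is off, the update above falls back to `utils.dampen_grad(grad, group["dampening"])`, which pairs the gradient with a small random probe so PSGD can whiten it. A hypothetical sketch of that pairing; the helper's exact form is not shown in this diff:

    import torch

    def dampen_grad_sketch(grad: torch.Tensor, dampening: float = 2**-13):
        vector = torch.randn_like(grad)           # random probe
        return vector, grad + dampening * vector  # (vector, hessian_vector) pair fed to the precond update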
@@ -548,54 +661,45 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
         float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
     group["is_cached"] = should_use_cache = cached and float_prob < 0.5
 
-    if
-        return
-
-
+    if precond is not None:
+        return precond
+    if not should_use_cache or not cached:
+        return None # caching adds extra ops and is not worth the overhead when we precondition at every step
 
-
-    if not cached:
-        return q
-
-    for c_, q_ in zip(Q_cache, q):
+    for c_, q_ in zip(Q_cache, utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q):
         if q_.ndim == 2:
             torch.matmul(q_.T, q_, out=c_)
         else:
             torch.mul(q_, q_, out=c_)
-    return
+    return None
 
 
-def _cached_psgd_precond_grad(group,
+def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
+    kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
     if group.get("is_cached", False):
-        out = utils.precond_grad_cached_(
+        out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
-        out = utils.psgd_precond_grad(
+        out = utils.psgd_precond_grad(
+            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+        )
     group["caution"] = False # we already cautioned here - shouldn't do it again
     return out
 
 
-def _fused_cached_psgd_precond_grad(group, grad, param,
+def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
+    kwargs = {
+        "ea": update,
+        "caution": group["caution"],
+        "grad": grad,
+        "param": param,
+        "lr": group["lr"],
+        "decay": group["weight_decay"],
+    }
     if group.get("is_cached", False):
-        utils.fused_precond_grad_cached_(
-            cache_expr,
-            update,
-            param,
-            group["lr"],
-            grad,
-            group["weight_decay"],
-            group["caution"],
-            *Q_cache,
-        )
+        utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
         utils.fused_psgd_precond_grad(
-
-            update,
-            param,
-            group["lr"],
-            grad,
-            group["weight_decay"],
-            group["caution"],
-            *Q_mat,
+            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
        )
 
 
@@ -603,7 +707,7 @@ def _update_lra(
     group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
 ):
     if not group["is_preconditioning"]:
-        return utils.
+        return utils.multi_flatten((U, 1), (V, 1), (d, 0))
 
     if utils.hasattr_none(params[0], "hessian_vector"):
         vector = utils.flatten([p.vector for p in params])
@@ -613,127 +717,109 @@ def _update_lra(
             del p.hessian_vector
     else:
         vector, hessian_vector = utils.dampen_multiple(grads)
-
+    precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
+    return utils.update_lra_precond_(
+        U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed, bool(precond_step % 2)
+    )
 
 
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def scale_by_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
+def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def update_by_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
-    utils.apply_lra_update(param, update, u, v, d)
-    raise SkipUpdate
+def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
+    utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+    raise SkipUpdate from None
 
 
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
+def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
-def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-    u, v, d = _update_lra(group, U, V, d, param,
-    utils.apply_lra_update(param, update, u, v, d)
-    raise SkipUpdate
+def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
+    utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+    raise SkipUpdate from None
 
 
-@
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-    Q_mat = _update_psgd_precond(
-        cached,
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
 
 
-@
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-
-
-
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    return precond
+    if group.get("inverse_free", False):
+        precond = None
+    else:
+        precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+    new = _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    return new if precond is None else precond
 
 
-@
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-    raise SkipUpdate
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+    raise SkipUpdate from None
 
 
 @no_state
@@ -741,34 +827,31 @@ def sign(group, update, grad, param, graft: bool = True):
     return utils.sign_(update, graft)
 
 
-@
+@no_state
+def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
+    assert clip_fn is not None
+    return clip_fn(update)
+
+
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(
     group,
     update,
     grad,
     param,
+    update_to_precond,
     Q,
-    exprs,
     Q_cache,
-
+    velocity: Optional[List[Tensor]],
+    running_lower_bound: List[Tensor],
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-
-
-
-        cached,
-        Q_cache,
-        group,
-        param,
-        update if group["momentum_into_precond_update"] else grad,
-        Q_mat,
-        Q,
-        exprs,
-        prob,
-    )
-    raise SkipUpdate
+    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    raise SkipUpdate from None
 
 
 def palm_beta2(state, group, update, grad, param):
@@ -819,12 +902,54 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
     return _branch
 
 
+def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+    if retain:
+        if offset:
+            raise ValueError("offset cannot be retained")
+
+        offset = -1
+        for fn in fns:
+            while isinstance(fn, (FunctionTransform, functools.partial)):
+                if isinstance(fn, functools.partial):
+                    fn = fn.func
+                    continue
+                if fn.transform_idx is not None:
+                    offset = max(offset, fn.transform_idx)
+                fn = fn.fn
+        offset += 1  # if we found nothing, this will be 0. if we found something, we START at N+1
+
+    fns = [copy.deepcopy(fn) for fn in fns]
+    for fn in fns:
+        while isinstance(fn, (FunctionTransform, functools.partial)):
+            if isinstance(fn, functools.partial):
+                fn = fn.func
+                continue
+            if not retain or fn.transform_idx is None:
+                fn.transform_idx = offset
+                offset += 1
+            fn = fn.fn
+    return fns
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
+        defaults = {k: v for k, v in defaults.items() if v is not use_default}
         super().__init__(params, defaults, foreach)
-        self.fns =
+        self.fns = fns
+
+    @property
+    def fns(self):
+        return self._fns
+
+    @fns.setter
+    def fns(self, value):
+        self._fns = value
+        self._set_indices(retain=True)
+
+    def _set_indices(self, retain=True):
+        self._fns = set_indices(self.fns, retain)
 
     def _step(self, group):
         if "base_lr" not in group:
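`set_indices` walks each chain entry (unwrapping `functools.partial` and nested `FunctionTransform`s), deep-copies it, and assigns a unique `transform_idx`, so state keys of the form `{fn_name}_{name}_{idx}` cannot collide even when the same transform appears twice in one chain. An illustrative call using names from the diff above; `my_update_fn` is a placeholder:

    fns = set_indices([ZeroGuard(my_update_fn, ["exp_avg"]), ZeroGuard(my_update_fn, ["exp_avg"])])
    print([fn.transform_idx for fn in fns])  # e.g. [0, 1] -> distinct state keys per occurrence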
@@ -868,7 +993,6 @@ class ChainOpt(utils.StatefulOptimizer):
         group["step"] = None
 
 
-use_default = object()
 str_or_fn = Union[str, callable, None, Literal[use_default]]
 
 
@@ -1020,3 +1144,27 @@ class ScheduleFree(BaseOpt):
                 p32 = utils.promote(p.data)
                 p32.lerp_(end=z, weight=1 - beta1)
                 utils.copy_stochastic_(p.data, p32)
+
+
+class MSAM(BaseOpt):
+    def eval(self):
+        for group in self.param_groups:
+            group["train_mode"] = train_mode = not group.get("train_mode")
+            if not train_mode:
+                for p in group["params"]:
+                    state = self.state_(p)
+                    if "z" in state:
+                        p_copy = p.data.clone()
+                        utils.copy_stochastic_(p.data, state["z"])
+                        utils.copy_stochastic_(state["z"], p_copy)
+
+    def train(self):
+        for group in self.param_groups:
+            group["train_mode"] = train_mode = not group.get("train_mode")
+            if train_mode:
+                for p in group["params"]:
+                    state = self.state_(p)
+                    if "z" in state:
+                        p_copy = p.data.clone()
+                        utils.copy_stochastic_(p.data, state["z"])
+                        utils.copy_stochastic_(state["z"], p_copy)