heavyball 0.25.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py ADDED
@@ -0,0 +1,475 @@
+ import functools
+ import random
+ from typing import Optional, Union
+
+ import torch
+
+ from . import utils
+
+ balance_probability: float = 0.01
+
+
+ def _key_in_state(state, key):
+     if isinstance(key, str):
+         return key in state
+     for k in key:
+         if isinstance(k, (tuple, list)):
+             continue
+         if k not in state:
+             return False
+     return True
+
+
+ def _inplace_guard_(state, key, template_fn):
+     key_not_in_state = not _key_in_state(state, key)
+     if key_not_in_state:
+         template_fn()
+     return key_not_in_state
+
+
+ def _guard_in_state(state, key, template_fn):
+     if not _key_in_state(state, key):
+         state[key] = template_fn()
+     return state[key]
+
+
+ def _zero_guard(state, key, ref, dtype):
+     return _guard_in_state(state, key,
+                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
+
+
+ def _storage_dtype(group):
+     dtype = group.get('storage_dtype', "float32")
+     return getattr(torch, dtype)
+
+
+ def zero_guard(*names):
+     def _outer(fn):
+         def _fn(state, group, update, grad, param, *args, **kwargs):
+             vars = [[_zero_guard(state(p), name, p, _storage_dtype(group)) for p in param] for name in names]
+             return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+
+         return _fn
+
+     return _outer
+
+
+ def copy_guard(index, *names):
+     def _outer(fn):
+         def _fn(state, group, update, grad, param, *args, **kwargs):
+             val = [update, grad, param, *args][index]
+             vars = [[_guard_in_state(state(p), name, lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+                     for name in names]
+             return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+
+         return _fn
+
+     return _outer
+
+
+ def general_guard(*names, init_fn):
+     def _outer(fn):
+         def _fn(state, group, update, grad, param, *args, **kwargs):
+             vars = []
+             skip_update = False
+             for p, g, u in zip(param, grad, update):
+                 st = state(p)
+                 skip_update |= _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs))
+                 vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in names])
+             if skip_update:
+                 raise SkipUpdate
+             return fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
+
+         return _fn
+
+     return _outer
+
+
+ def no_state(fn):
+     def _fn(state, *args, **kwargs):
+         return fn(*args, **kwargs)
+
+     return _fn
+
+
+ def no_state_no_foreach(fn):
+     def _fn(state, group, *args, **kwargs):
+         for a in zip(*args):
+             return fn(group, *a, **kwargs)
+
+     return _fn
+
+
+ class SkipUpdate(ValueError):
+     pass
+
+
+ @zero_guard("exp_avg")
+ @no_state
+ def exp_avg(group, update, grad, param, exp_avg):
+     utils.stochastic_lerp_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
+     return exp_avg
+
+
+ @zero_guard("exp_avg_sq")
+ @no_state
+ def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
+     out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])
+     if group['step'] == 1:
+         raise SkipUpdate
+     return out
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+     return utils.adam_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],  #
+                        group['eps'])
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+     utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+                       group['lr'], group['eps'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
+     return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+                          group['eps'])
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
+     utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+                         group['step'], group['lr'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @copy_guard(2, "z")
+ @no_state
+ def update_by_schedule_free(group, update, grad, param, z):
+     group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
+                                                utils.get_beta1(group), param, z, update, group['r'], group['step'],
+                                                group['weight_decay'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
+     if group['step'] == 1:
+         utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+         raise SkipUpdate
+
+     if group['step'] == 2:
+         update = utils.promote(update)
+         easq = utils.promote(exp_avg_sq)
+         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+         utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+         raise SkipUpdate
+
+     utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+                        group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
+     if group['step'] == 1:
+         utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+         raise SkipUpdate
+
+     if group['step'] == 2:
+         update = utils.promote(update)
+         easq = utils.promote(exp_avg_sq)
+         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+         utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+         raise SkipUpdate
+
+     return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
+
+
+ def _init_soap(state, group, update, grad, param):
+     utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
+
+
+ def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+     Q, state["exprs"] = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'],
+                                            group['min_ndim_triangular'], group['memory_save_mode'],
+                                            dtype=getattr(torch, group['q_dtype']))
+     state["Q"] = utils.triu_to_line(Q) if group['store_triu_as_line'] else Q
+
+     if not cached:
+         return
+
+     state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+     expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
+     expr = ','.join(expr)
+     grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
+     out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+     expr = f'{expr},{grad_expr}->{out_expr}'
+
+     state['cache_expr'] = expr
+
+
+ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'):
+     step = group['step']
+     if 'precondition_frequency' in group:
+         return step > 0 and step % group['precondition_frequency'] == 0
+     rng = random.Random(0x172381 ^ step)
+     if 'precond_scheduler' in group:
+         return utils.precond_schedule(step, group['precond_scheduler'], rng)
+     if prob is not None:
+         return utils.psgd_should_update(group, prob, rng, name=name)
+     raise ValueError("No preconditioner update schedule specified.")
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @general_guard("Q", "GG", init_fn=_init_soap)
+ @no_state
+ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
+     update = utils.promote(update)
+
+     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
+     precond = utils.adam_(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group),
+                           utils.scalar_guard(group['step'], exp_avg[0]))
+     precond = [utils.project(p, q, False) for p, q in zip(precond, Q)]
+
+     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
+         utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
+                                     utils.beta_debias(group['shampoo_beta'], group['step']), precond_schedule(group))
+     return precond
+
+
+ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+     if prob is None:
+         prob = utils.precond_update_prob_schedule()
+     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
+         return
+
+     Q = [utils.promote(q_) for q_ in Q]
+     utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
+
+     if grad.dim() > 1 and precond_schedule(group, balance_probability, "balance_prob"):
+         if group['store_triu_as_line']:
+             utils.psgd_balance_Q([q_ for _, q_ in Q])
+         else:
+             utils.psgd_balance_Q(Q)
+
+
+ def _update_psgd_cache(cached, Q_cache, q):
+     if not cached:
+         return q
+
+     for c_, q_ in zip(Q_cache, q):
+         if q_.ndim == 2:
+             torch.matmul(q_.T, q_, out=c_)
+         else:
+             torch.mul(q_, q_, out=c_)
+     return Q_cache
+
+
+ def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
+     if cached:
+         return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
+     return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+
+
+ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, exprs, update, Q_mat, Q_cache):
+     if cached:
+         utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
+                                          group['caution'], *Q_cache)
+     else:
+         utils.fused_psgd_precond_grad(exprs[-1], update, param, group['lr'], grad, group['weight_decay'],
+                                       group['caution'], *Q_mat)
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                   prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     return _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                           prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     precond = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     return precond
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                    prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
+     raise SkipUpdate
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                            prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     raise SkipUpdate
+
+
+ def palm_beta2(state, group, update, grad, param):
+     beta2 = 1 - group['step'] ** -group['beta2_scale']
+     group['betas'] = (utils.get_beta1(group), beta2)
+     return update
+
+
+ def apply_to_idx(fn, idx):
+     def _fn(state, group, update, grad, param):
+         args = [state, group, update, grad, param]
+         return fn(args[idx])
+
+     return _fn
+
+
+ def chain(state: Union[callable, dict], group, grad, param, *fns):
+     update = grad
+     skip_update = False
+     for fn in fns:
+         try:
+             update = fn(state, group, update, grad, param)
+         except SkipUpdate:
+             skip_update = True
+             continue
+         if update is None:
+             break
+     if not skip_update and update is not None:
+         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
+
+
+ class ChainOpt(utils.StatefulOptimizer):
+     def __init__(self, params, defaults, foreach: bool, *fns):
+         super().__init__(params, defaults, foreach)
+
+         self.fns = tuple(fns)
+
+     def _step(self, group):
+         if 'base_lr' not in group:
+             group['base_lr'] = group['lr']
+         step = group['step'] = group.get('step', 0) + 1
+         if group['warmup_steps'] and step < group['warmup_steps']:
+             group['lr'] = -group['base_lr'] * step / group['warmup_steps']
+         else:
+             group['lr'] = -group['base_lr']
+
+         p, g = zip(*list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group))))
+
+         if not group['foreach'] or len(p) == 1:
+             for param, grad in zip(p, g):
+                 chain(self.state_, group, [grad], [param], *self.fns)
+             return
+
+         chain(self.state_, group, g, p, *self.fns)
+
+
+ use_default = object()
+ str_or_fn = Union[str, callable, None, use_default]
+
+
+ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
+     name = default(name, default_val)
+     if callable(name):
+         return name
+     elif name not in ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress'):
+         raise ValueError(f"Clipping function {name} not found")
+     return getattr(utils, name)
+
+
+ def default(a, b):
+     return b if a is None or a is use_default else a
+
+
+ # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
+ _scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd,  #
+                         scale_by_psgd: update_by_psgd,  #
+                         scale_by_adam: update_by_adam,  #
+                         scale_by_laprop: update_by_laprop,  #
+                         scale_by_adopt: update_by_adopt}
+
+
+ class BaseOpt(ChainOpt):
+     gradient_clipping: str_or_fn = None
+     update_clipping: str_or_fn = None
+     palm: bool = False
+     auto_fuse: bool = True
+
+     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
+                  palm: bool = None, *fns):
+         if default(update_clipping, self.update_clipping) is None:
+             if fns and self.auto_fuse:
+                 args, kwargs = None, None
+                 fn = fns[-1]
+                 if isinstance(fn, functools.partial):
+                     fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
+                 if fn in _scale_to_update_map:
+                     fn = _scale_to_update_map[fn]
+                     if args is not None:
+                         fn = functools.partial(fn, *args, **kwargs)
+                     fns = tuple(fns)[:-1] + (fn,)
+         else:
+             if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
+                 raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+         fns = tuple(fns)
+
+         if default(palm, self.palm):
+             fns = (palm_beta2,) + fns
+         if default(gradient_clipping, self.gradient_clipping) is not None:
+             fns = (apply_to_idx(gradient_clipping, 2),) + fns
+         if default(update_clipping, self.update_clipping) is not None:
+             fns = fns + (apply_to_idx(update_clipping, 2),)
+
+         super().__init__(params, defaults, foreach, *fns)
+
+
+ class ScheduleFree(BaseOpt):
+     def eval(self):
+         for group in self.param_groups:
+             train_mode = group['train_mode']
+             beta1 = utils.get_beta1(group)
+             if beta1 > 0 and train_mode:
+                 for p in group['params']:
+                     state = self.state_(p)
+                     if 'z' in state:
+                         # Set p.data to x
+                         z = utils.promote(state['z'])
+                         p32 = utils.promote(p.data)
+                         p32.lerp_(end=z, weight=1 - 1 / beta1)
+                         utils.copy_stochastic_(p.data, p32)
+             group['train_mode'] = False
+
+     def train(self):
+         for group in self.param_groups:
+             train_mode = group['train_mode']
+             beta1 = utils.get_beta1(group)
+             if beta1 > 0 and not train_mode:
+                 for p in group['params']:
+                     state = self.state_(p)
+                     if 'z' in state:
+                         z = utils.promote(state['z'])
+                         p32 = utils.promote(p.data)
+                         p32.lerp_(end=z, weight=1 - beta1)
+                         utils.copy_stochastic_(p.data, p32)
+             group['train_mode'] = True
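
For orientation, here is a minimal sketch of how the pieces in this new module compose into a concrete optimizer. Only `BaseOpt`, `scale_by_adam`, the auto-fusing via `_scale_to_update_map`, and the group keys read by `ChainOpt._step` and `chain()` come from the diff above; the class name `SketchAdamW`, the concrete defaults, and the assumption that the `utils.StatefulOptimizer` base supplies the usual `step()`, `zero_grad()`, `state_()`, and `split_p_and_g_in_group()` machinery are illustrative, not part of the package.

import torch

from heavyball import chainable as C


class SketchAdamW(C.BaseOpt):
    """Hypothetical AdamW-style optimizer assembled from the chainable transforms."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 warmup_steps=0, foreach=True, gradient_clipping=None, update_clipping=None):
        # These keys are the ones read by ChainOpt._step, chain(), and the adam transforms.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        warmup_steps=warmup_steps, caution=False, foreach=foreach,
                        storage_dtype='float32')
        # With no update clipping set, BaseOpt auto-fuses scale_by_adam into update_by_adam,
        # so the parameter update happens inside the fused transform rather than in chain().
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,
                         C.scale_by_adam)


model = torch.nn.Linear(4, 4)
opt = SketchAdamW(model.parameters(), lr=1e-3)
opt.zero_grad()
model(torch.randn(2, 4)).square().mean().backward()
opt.step()

Swapping `C.scale_by_adam` for `C.scale_by_laprop`, `C.scale_by_psgd`, or `C.scale_by_soap` (together with the extra group keys those transforms read, such as `precond_init_scale` or `max_precond_dim`) would follow the same pattern.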