heavyball 0.25.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +192 -29
- heavyball/chainable.py +475 -0
- heavyball/utils.py +334 -180
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -91
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -100
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.0.dist-info/RECORD +0 -24
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
ADDED
@@ -0,0 +1,475 @@
|
|
1
|
+
import functools
|
2
|
+
import random
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from . import utils
|
8
|
+
|
9
|
+
# Probability handed to `precond_schedule` (under the name "balance_prob") when deciding whether the
# PSGD preconditioner factors should be re-balanced after a preconditioner update.
balance_probability: float = 0.01
|
10
|
+
|
11
|
+
|
12
|
+
def _key_in_state(state, key):
|
13
|
+
if isinstance(key, str):
|
14
|
+
return key in state
|
15
|
+
for k in key:
|
16
|
+
if isinstance(k, (tuple, list)):
|
17
|
+
continue
|
18
|
+
if k not in state:
|
19
|
+
return False
|
20
|
+
return True
|
21
|
+
|
22
|
+
|
23
|
+
def _inplace_guard_(state, key, template_fn):
    """Run `template_fn` when `key` is not yet present in `state`.

    `template_fn` is called for its side effects only; its return value is discarded and nothing is
    written to `state` here.

    :return: True if `template_fn` was invoked (the key was missing), False otherwise.
    """
    if _key_in_state(state, key):
        return False
    template_fn()
    return True
|
28
|
+
|
29
|
+
|
30
|
+
def _guard_in_state(state, key, template_fn):
    """Return `state[key]`, first storing `template_fn()` under `key` if it is missing.

    :param state: mutable mapping holding the optimizer state of one parameter.
    :param key: name of the entry.
    :param template_fn: zero-argument factory, only called when the entry does not exist yet.
    """
    if _key_in_state(state, key):
        return state[key]
    state[key] = template_fn()
    return state[key]
|
34
|
+
|
35
|
+
|
36
|
+
def _zero_guard(state, key, ref, dtype):
    """Return `state[key]`, creating it as a zero tensor shaped like `ref` when it is missing.

    :param state: mutable mapping holding the optimizer state of one parameter.
    :param key: name of the buffer.
    :param ref: tensor whose shape/layout the new buffer copies.
    :param dtype: dtype of the newly created buffer. Previously this argument was ignored and the
        buffer was always float32, so a group's `storage_dtype` had no effect.
    """

    def _make_zeros():
        return torch.zeros_like(ref, dtype=dtype, memory_format=torch.preserve_format)

    return _guard_in_state(state, key, _make_zeros)
|
39
|
+
|
40
|
+
|
41
|
+
def _storage_dtype(group):
|
42
|
+
dtype = group.get('storage_dtype', "float32")
|
43
|
+
return getattr(torch, dtype)
|
44
|
+
|
45
|
+
|
46
|
+
def zero_guard(*names):
    """Decorator factory: make sure zero-initialised state buffers exist and pass them to `fn`.

    For every name in `names`, a list with one buffer per parameter is built (created through
    `_zero_guard` with the group's storage dtype when missing). These lists are appended, in the order
    of `names`, after the positional arguments of the wrapped function.
    """

    def _outer(fn):
        def _fn(state, group, update, grad, param, *args, **kwargs):
            buffers = []
            for name in names:
                per_param = []
                for p in param:
                    per_param.append(_zero_guard(state(p), name, p, _storage_dtype(group)))
                buffers.append(per_param)
            return fn(state, group, update, grad, param, *args, *buffers, **kwargs)

        return _fn

    return _outer
|
55
|
+
|
56
|
+
|
57
|
+
def copy_guard(index, *names):
    """Decorator factory: make sure state entries initialised as clones exist and pass them to `fn`.

    `index` selects the source list out of `[update, grad, param, *args]`. For every name in `names`
    a list with one state entry per parameter is built; a missing entry is created with `torch.clone`
    of the matching element of the source list. The lists are appended after the positional arguments
    of the wrapped function.
    """

    def _outer(fn):
        def _fn(state, group, update, grad, param, *args, **kwargs):
            source = [update, grad, param, *args][index]
            copies = []
            for name in names:
                per_param = []
                for p, v in zip(param, source):
                    # the factory is only called (immediately) when the entry is missing
                    per_param.append(_guard_in_state(state(p), name, lambda: torch.clone(v)))
                copies.append(per_param)
            return fn(state, group, update, grad, param, *args, *copies, **kwargs)

        return _fn

    return _outer
|
68
|
+
|
69
|
+
|
70
|
+
def general_guard(*names, init_fn):
    """Decorator factory: lazily initialise per-parameter state via `init_fn`, then pass it to `fn`.

    :param names: state entries to hand to the wrapped function. A plain string is a required entry
        (read with `st[name]`); any other entry is treated as a `(key, default)` pair and read with
        `st.get(key, default)`.
    :param init_fn: called as `init_fn(st, group, u, g, p, **kwargs)` for every parameter whose
        required entries are not all present.
    :raises SkipUpdate: after all parameters were visited, if `init_fn` ran for at least one of them.
    """

    def _outer(fn):
        def _fn(state, group, update, grad, param, *args, **kwargs):
            collected = []
            initialised_any = False
            for p, g, u in zip(param, grad, update):
                st = state(p)
                if _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs)):
                    initialised_any = True
                row = []
                for name in names:
                    if isinstance(name, str):
                        row.append(st[name])
                    else:
                        row.append(st.get(name[0], name[1]))
                collected.append(row)
            if initialised_any:
                raise SkipUpdate
            # transpose: one sequence per name, each holding one value per parameter
            return fn(state, group, update, grad, param, *args, *zip(*collected), **kwargs)

        return _fn

    return _outer
|
86
|
+
|
87
|
+
|
88
|
+
def no_state(fn):
    """Adapt `fn`, which does not take the state accessor, to the `(state, ...)` calling convention."""

    def _fn(state, *args, **kwargs):
        # `state` is accepted only to satisfy the chain signature; it is not forwarded
        del state
        return fn(*args, **kwargs)

    return _fn
|
93
|
+
|
94
|
+
|
95
|
+
def no_state_no_foreach(fn):
    """Adapt a per-tensor function to the list-based `(state, group, *lists)` calling convention.

    The positional lists are zipped and `fn(group, *elements, **kwargs)` is called once per element
    tuple. Previously the wrapper returned from inside the loop, so only the first element was ever
    processed and a bare value (not a list) was returned.

    :return: list with one result per element tuple, or None when the lists are empty (which is what
        the previous implementation returned in that case).
    :raises SkipUpdate: after every element was processed, if `fn` raised `SkipUpdate` for any of them.
    """

    def _fn(state, group, *args, **kwargs):
        results = []
        visited = False
        skipped = False
        for elements in zip(*args):
            visited = True
            try:
                results.append(fn(group, *elements, **kwargs))
            except SkipUpdate:
                # keep going: the remaining tensors still need their own update
                skipped = True
        if skipped:
            raise SkipUpdate
        return results if visited else None

    return _fn
|
101
|
+
|
102
|
+
|
103
|
+
class SkipUpdate(ValueError):
    """Raised by a chained transform to signal that the final parameter update must not be applied."""
|
105
|
+
|
106
|
+
|
107
|
+
@zero_guard("exp_avg")
@no_state
def exp_avg(group, update, grad, param, exp_avg):
    """Fold `update` into the stored `exp_avg` buffers and return those buffers.

    The weight passed to `utils.stochastic_lerp_` is `utils.beta_debias(beta1, step)`.
    """
    weight = utils.beta_debias(utils.get_beta1(group), group["step"])
    utils.stochastic_lerp_(exp_avg, update, weight)
    return exp_avg
|
112
|
+
|
113
|
+
|
114
|
+
@zero_guard("exp_avg_sq")
@no_state
def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
    """Return `update` scaled through `utils.scale_by_exp_avg_sq_` using a debiased beta2.

    :raises SkipUpdate: on the first step, after the second-moment buffer has been updated.
    """
    beta2 = utils.beta_debias(utils.get_beta2(group), group["step"])
    scaled = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, beta2, group['eps'])
    if group['step'] == 1:
        # the statistics were still accumulated above; only the parameter update is skipped
        raise SkipUpdate
    return scaled
|
122
|
+
|
123
|
+
|
124
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
    """Return the result of `utils.adam_` applied to `update` with the group's betas, step and eps."""
    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    step = group['step']
    eps = group['eps']
    return utils.adam_(exp_avg, exp_avg_sq, update, beta1, beta2, step, eps)
|
129
|
+
|
130
|
+
|
131
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
    """Hand `param` and `update` to `utils.fused_adam_` instead of returning an update.

    :raises SkipUpdate: always, so that `chain` does not apply a second parameter update.
    """
    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    step = group['step']
    lr = group['lr']
    eps = group['eps']
    weight_decay = group['weight_decay']
    caution = group['caution']
    utils.fused_adam_(param, exp_avg, exp_avg_sq, update, beta1, beta2, step, lr, eps, weight_decay, caution)
    raise SkipUpdate
|
137
|
+
|
138
|
+
|
139
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
    """Return the result of `utils.laprop_` applied to `update` with the group's betas, step and eps."""
    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    step = group['step']
    eps = group['eps']
    return utils.laprop_(exp_avg, exp_avg_sq, update, beta1, beta2, step, eps)
|
144
|
+
|
145
|
+
|
146
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
    """Hand `param` and `update` to `utils.fused_laprop_` instead of returning an update.

    :raises SkipUpdate: always, so that `chain` does not apply a second parameter update.
    """
    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    step = group['step']
    lr = group['lr']
    weight_decay = group['weight_decay']
    caution = group['caution']
    utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, beta1, beta2, step, lr, weight_decay, caution)
    raise SkipUpdate
|
152
|
+
|
153
|
+
|
154
|
+
@copy_guard(2, "z")
@no_state
def update_by_schedule_free(group, update, grad, param, z):
    """Run `utils.schedule_free_` on `param`/`z` and store its return value as `group['weight_sum']`.

    `z` is the per-parameter state entry that `copy_guard` initialises as a clone of the parameter.

    :raises SkipUpdate: always, so that `chain` does not apply a second parameter update.
    """
    lr = group['lr']
    weight_lr_power = group['weight_lr_power']
    previous_sum = group.get('weight_sum', 0)
    beta1 = utils.get_beta1(group)
    new_sum = utils.schedule_free_(lr, weight_lr_power, previous_sum, beta1, param, z, update, group['r'],
                                   group['step'], group['weight_decay'])
    group['weight_sum'] = new_sum
    raise SkipUpdate
|
161
|
+
|
162
|
+
|
163
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
    """Fused ADOPT-style step: the first two steps only seed the statistics, later ones update `param`.

    Step 1 writes the second-moment buffer, step 2 seeds `exp_avg` from the normalised update and
    refreshes the second moment, and every later step calls `utils.fused_adopt_` with `step - 2`.

    :raises SkipUpdate: always; the parameter update (if any) is done inside `utils.fused_adopt_`.
    """
    step = group['step']

    if step == 1:
        utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
        raise SkipUpdate

    if step == 2:
        update = utils.promote(update)
        easq = utils.promote(exp_avg_sq)
        for ea, u, easq_ in zip(exp_avg, update, easq):
            # normalise by the square root of the stored second moment, floored at eps
            utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps']))
        beta2 = utils.beta_debias(utils.get_beta2(group), group['step'])
        utils.exp_avg_sq_(exp_avg_sq, update, beta2, 1)
        raise SkipUpdate

    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, beta1, beta2, step - 2, group['lr'], group['eps'],
                       group['weight_decay'], group['caution'])
    raise SkipUpdate
|
180
|
+
|
181
|
+
|
182
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@no_state
def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
    """ADOPT-style scaling: the first two steps only seed the statistics, later ones return an update.

    Step 1 writes the second-moment buffer, step 2 seeds `exp_avg` from the normalised update and
    refreshes the second moment; both raise `SkipUpdate`. Later steps return `utils.adopt(...)`
    called with `step - 2`.

    :raises SkipUpdate: on steps 1 and 2.
    """
    step = group['step']

    if step == 1:
        utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
        raise SkipUpdate

    if step == 2:
        update = utils.promote(update)
        easq = utils.promote(exp_avg_sq)
        for ea, u, easq_ in zip(exp_avg, update, easq):
            # normalise by the square root of the stored second moment, floored at eps
            utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps']))
        beta2 = utils.beta_debias(utils.get_beta2(group), group['step'])
        utils.exp_avg_sq_(exp_avg_sq, update, beta2, 1)
        raise SkipUpdate

    beta1 = utils.get_beta1(group)
    beta2 = utils.get_beta2(group)
    return utils.adopt(update, exp_avg_sq, exp_avg, beta1, beta2, step - 2)
|
197
|
+
|
198
|
+
|
199
|
+
def _init_soap(state, group, update, grad, param):
    """State initialiser for SOAP: delegate to `utils.init_preconditioner` for one parameter.

    `update` and `param` are unused; they are part of the `init_fn` signature used by `general_guard`.
    """
    beta2 = utils.get_beta2(group)
    max_precond_dim = group['max_precond_dim']
    precondition_1d = group['precondition_1d']
    utils.init_preconditioner(grad, state, beta2, max_precond_dim, precondition_1d)
|
201
|
+
|
202
|
+
|
203
|
+
def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
    """State initialiser for PSGD: create `Q`/`exprs` and, when `cached`, the cache buffers.

    Writes `state["exprs"]` and `state["Q"]` (converted with `utils.triu_to_line` when the group sets
    `store_triu_as_line`). With `cached=True` it additionally writes `state['Q_cache']` (uninitialised
    tensors shaped like each factor) and `state['cache_expr']`, an einsum expression string.

    `update`, `param` and `prob` are unused here; they are part of the signature `general_guard` calls.
    """
    q_dtype = getattr(torch, group['q_dtype'])
    Q, exprs = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'],
                                  group['min_ndim_triangular'], group['memory_save_mode'], dtype=q_dtype)
    state["exprs"] = exprs
    if group['store_triu_as_line']:
        state["Q"] = utils.triu_to_line(Q)
    else:
        state["Q"] = Q

    if not cached:
        return

    state['Q_cache'] = [torch.empty_like(q) for q in Q]

    # one einsum operand per factor: two indices ("Aa") for 2-D factors, a single index otherwise
    factor_terms = []
    for c, q_ in zip(utils.einsum_base, Q):
        factor_terms.append(f'{c.upper()}{c}' if q_.ndim == 2 else c)
    factors = ','.join(factor_terms)
    grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
    # an output axis is upper-cased exactly when a factor term introduced that upper-case index
    out_expr = ''.join(c.upper() if c.upper() in factors else c for c in grad_expr)

    state['cache_expr'] = f'{factors},{grad_expr}->{out_expr}'
|
221
|
+
|
222
|
+
|
223
|
+
def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'):
    """Decide whether the preconditioner should be updated at the group's current step.

    The first applicable rule wins:
    1. `group['precondition_frequency']`: update when `step` is positive and a multiple of it.
    2. `group['precond_scheduler']`: delegate to `utils.precond_schedule`.
    3. `prob` given: delegate to `utils.psgd_should_update` under `name`.

    Rules 2 and 3 receive a `random.Random` seeded from the step, so the generator passed in is the
    same for a given step.

    :raises ValueError: if none of the three rules applies.
    """
    step = group['step']
    if 'precondition_frequency' in group:
        frequency = group['precondition_frequency']
        return step > 0 and step % frequency == 0

    rng = random.Random(0x172381 ^ step)
    if 'precond_scheduler' in group:
        return utils.precond_schedule(step, group['precond_scheduler'], rng)
    if prob is None:
        raise ValueError("No preconditioner update schedule specified.")
    return utils.psgd_should_update(group, prob, rng, name=name)
|
233
|
+
|
234
|
+
|
235
|
+
@zero_guard("exp_avg", "exp_avg_sq")
@general_guard("Q", "GG", init_fn=_init_soap)
@no_state
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
    """SOAP scaling: run Adam on the projected update, project the result back, refresh preconditioners.

    :return: list with one preconditioned update per parameter.
    """
    update = utils.promote(update)

    # rotate into the preconditioner eigenbasis (back=False), run Adam there ...
    grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
    precond = utils.adam_(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group),
                          utils.scalar_guard(group['step'], exp_avg[0]))
    # ... and rotate the result back to parameter space (back=True). This call previously passed
    # False as well, i.e. it applied the forward projection a second time instead of inverting it.
    precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

    for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
        utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
                                    utils.beta_debias(group['shampoo_beta'], group['step']), precond_schedule(group))
    return precond
|
250
|
+
|
251
|
+
|
252
|
+
def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
    """Update the PSGD preconditioner of one tensor when the schedule says so, then maybe balance it.

    :param prob: update-probability schedule; `utils.precond_update_prob_schedule()` is used when None.
    :param Q: stored factors; the schedule's bookkeeping name is derived from `id(Q)`.
    :param Q_mat: factors in the form `utils.psgd_update_precond` consumes.

    `param` is not used. Returns None in every case.
    """
    schedule = utils.precond_update_prob_schedule() if prob is None else prob
    if not precond_schedule(group, schedule, name=f"cumulative_prob_{id(Q)}"):
        return

    promoted = [utils.promote(q_) for q_ in Q]
    utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], promoted, group['store_triu_as_line'])

    # balancing is only attempted for tensors with more than one dimension
    if grad.dim() <= 1 or not precond_schedule(group, balance_probability, "balance_prob"):
        return
    if group['store_triu_as_line']:
        # entries are pairs here; only the second element of each pair is balanced
        utils.psgd_balance_Q([q_ for _, q_ in promoted])
    else:
        utils.psgd_balance_Q(promoted)
|
266
|
+
|
267
|
+
|
268
|
+
def _update_psgd_cache(cached, Q_cache, q):
|
269
|
+
if not cached:
|
270
|
+
return q
|
271
|
+
|
272
|
+
for c_, q_ in zip(Q_cache, q):
|
273
|
+
if q_.ndim == 2:
|
274
|
+
torch.matmul(q_.T, q_, out=c_)
|
275
|
+
else:
|
276
|
+
torch.mul(q_, q_, out=c_)
|
277
|
+
return Q_cache
|
278
|
+
|
279
|
+
|
280
|
+
def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
    """Return the PSGD-preconditioned `update`, using the cached factors when `cached` is true.

    :param cache_expr: einsum expression for the cached path.
    :param exprs: expressions for the uncached path; only the last one is used.
    :param Q_mat: factors for the uncached path.
    :param Q_cache: cached factors for the cached path.
    """
    if cached:
        # the operands are the cached factors; previously the characters of `cache_expr` itself
        # were unpacked here instead of `Q_cache`
        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
|
284
|
+
|
285
|
+
|
286
|
+
def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, exprs, update, Q_mat, Q_cache):
    """Precondition `update` and apply it to `param` through the fused utils call.

    Dispatches to `utils.fused_precond_grad_cached_` (with `Q_cache`) when `cached`, otherwise to
    `utils.fused_psgd_precond_grad` (with `Q_mat` and the last entry of `exprs`). Returns None.
    """
    if cached:
        utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
                                         group['caution'], *Q_cache)
        return
    expr = exprs[-1]
    utils.fused_psgd_precond_grad(expr, update, param, group['lr'], grad, group['weight_decay'],
                                  group['caution'], *Q_mat)
|
293
|
+
|
294
|
+
|
295
|
+
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@no_state_no_foreach
def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                  prob: Optional[callable] = None):
    """PSGD scaling: update the preconditioner first, then return the preconditioned `update`.

    The preconditioner update is driven by `update` (not `grad`). `cached` is accepted, but the
    uncached preconditioning path is always used here.
    """
    if group['store_triu_as_line']:
        Q_mat = utils.line_to_triu(Q)
    else:
        Q_mat = Q
    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
    preconditioned = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
    return preconditioned
|
302
|
+
|
303
|
+
|
304
|
+
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@no_state_no_foreach
def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                          prob: Optional[callable] = None):
    """Delayed PSGD scaling: precondition `update` with the current factors, then update them.

    The returned value is computed before `_update_psgd_precond` runs. `cached` is accepted, but
    the uncached preconditioning path is always used here.
    """
    if group['store_triu_as_line']:
        Q_mat = utils.line_to_triu(Q)
    else:
        Q_mat = Q
    preconditioned = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
    return preconditioned
|
312
|
+
|
313
|
+
|
314
|
+
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@no_state_no_foreach
def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
    """Fused PSGD step: update the preconditioner, then precondition and apply `update` to `param`.

    `update` is passed both as the gradient argument and as the update argument of the fused helper.

    :raises SkipUpdate: always, so that `chain` does not apply a second parameter update.
    """
    if group['store_triu_as_line']:
        Q_mat = utils.line_to_triu(Q)
    else:
        Q_mat = Q
    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
    _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
    raise SkipUpdate
|
322
|
+
|
323
|
+
|
324
|
+
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@no_state_no_foreach
def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
    """Fused delayed PSGD step: precondition and apply `update` to `param`, then update the factors.

    `update` is passed both as the gradient argument and as the update argument of the fused helper.

    :raises SkipUpdate: always, so that `chain` does not apply a second parameter update.
    """
    if group['store_triu_as_line']:
        Q_mat = utils.line_to_triu(Q)
    else:
        Q_mat = Q
    _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
    raise SkipUpdate
|
332
|
+
|
333
|
+
|
334
|
+
def palm_beta2(state, group, update, grad, param):
    """Chain step that overwrites `group['betas']` with a step-dependent beta2 and passes `update` on.

    beta2 is set to `1 - step ** -beta2_scale`; beta1 is kept as reported by `utils.get_beta1`.
    """
    step = group['step']
    exponent = -group['beta2_scale']
    beta2 = 1 - step ** exponent
    beta1 = utils.get_beta1(group)
    group['betas'] = (beta1, beta2)
    return update
|
338
|
+
|
339
|
+
|
340
|
+
def apply_to_idx(fn, idx):
    """Build a chain step that calls `fn` on one of `[state, group, update, grad, param]`.

    :param fn: single-argument callable.
    :param idx: index into the argument list above (2 selects `update`).
    :return: a function with the chain signature returning whatever `fn` returns.
    """

    def _fn(state, group, update, grad, param):
        positional = [state, group, update, grad, param]
        selected = positional[idx]
        return fn(selected)

    return _fn
|
346
|
+
|
347
|
+
|
348
|
+
def chain(state: Union[callable, dict], group, grad, param, *fns):
    """Run `fns` in order on the gradient and apply the final result to `param`.

    Each function is called as `fn(state, group, update, grad, param)` and its return value becomes
    the next `update`. A function raising `SkipUpdate` leaves `update` unchanged and the chain goes
    on, but the final `utils.update_param_` call is suppressed. A function returning None stops the
    chain and also suppresses the final call.
    """
    update = grad
    skipped = False
    for fn in fns:
        try:
            update = fn(state, group, update, grad, param)
        except SkipUpdate:
            skipped = True
            continue
        if update is None:
            break

    if skipped or update is None:
        return
    utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
|
361
|
+
|
362
|
+
|
363
|
+
class ChainOpt(utils.StatefulOptimizer):
    """Optimizer whose step is a `chain` of transform functions applied to each parameter group."""

    def __init__(self, params, defaults, foreach: bool, *fns):
        """
        :param params: parameters or parameter groups, forwarded to `utils.StatefulOptimizer`.
        :param defaults: default group options, forwarded to `utils.StatefulOptimizer`.
        :param foreach: forwarded to `utils.StatefulOptimizer`.
        :param fns: chain steps, each called as `fn(state, group, update, grad, param)`.
        """
        super().__init__(params, defaults, foreach)

        self.fns = tuple(fns)

    def _step(self, group):
        """Advance `group` by one step: update step/lr bookkeeping, then run the chain.

        `group['lr']` is stored negated; while `step < warmup_steps` it is additionally scaled by
        `step / warmup_steps`. The un-negated value seen on the first call is kept as `base_lr`.
        """
        if 'base_lr' not in group:
            group['base_lr'] = group['lr']
        step = group['step'] = group.get('step', 0) + 1
        if group['warmup_steps'] and step < group['warmup_steps']:
            group['lr'] = -group['base_lr'] * step / group['warmup_steps']
        else:
            group['lr'] = -group['base_lr']

        pairs = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
        if not pairs:
            # nothing to update in this group; unpacking `zip(*[])` below would raise ValueError
            return
        p, g = zip(*pairs)

        if not group['foreach'] or len(p) == 1:
            for param, grad in zip(p, g):
                chain(self.state_, group, [grad], [param], *self.fns)
            return

        chain(self.state_, group, g, p, *self.fns)
|
386
|
+
|
387
|
+
|
388
|
+
# Sentinel meaning "fall back to the class-level default"; distinct from None.
use_default = object()
# Annotation alias for clipping arguments: a function name, a callable, None, or the `use_default`
# sentinel. The sentinel is represented by its type (`object`): `Union[...]` rejects non-callable
# instances such as `object()` with a TypeError on Python versions before 3.11.
str_or_fn = Union[str, callable, None, object]
|
390
|
+
|
391
|
+
|
392
|
+
def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
    """Resolve a clipping specification to a callable.

    :param name: callable, name of a clipping function in `utils`, or None/`use_default`.
    :param default_val: used in place of `name` when `name` is None or `use_default`.
    :return: the callable itself, or the `utils` attribute with the given name.
    :raises ValueError: if the resolved value is neither callable nor one of the known names.
    """
    resolved = default(name, default_val)
    if callable(resolved):
        return resolved
    known = ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress')
    if resolved not in known:
        raise ValueError(f"Clipping function {resolved} not found")
    return getattr(utils, resolved)
|
399
|
+
|
400
|
+
|
401
|
+
def default(a, b):
    """Return `b` when `a` is None or the `use_default` sentinel, otherwise `a`."""
    if a is None or a is use_default:
        return b
    return a
|
403
|
+
|
404
|
+
|
405
|
+
# Maps each `scale_by_*` transform to the fused `update_by_*` transform that replaces it.
# not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
_scale_to_update_map = {
    scale_by_delayed_psgd: update_by_delayed_psgd,
    scale_by_psgd: update_by_psgd,
    scale_by_adam: update_by_adam,
    scale_by_laprop: update_by_laprop,
    scale_by_adopt: update_by_adopt,
}
|
411
|
+
|
412
|
+
|
413
|
+
class BaseOpt(ChainOpt):
    """`ChainOpt` with optional gradient/update clipping, PaLM beta2 and fusion of the last step.

    The class attributes are the fallbacks used when the matching constructor argument is None or
    `use_default`.
    """
    gradient_clipping: str_or_fn = None
    update_clipping: str_or_fn = None
    palm: bool = False
    auto_fuse: bool = True

    def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
                 palm: bool = None, *fns):
        """
        :param gradient_clipping: callable or name of a `utils` clipping function applied to the
            incoming update before every other step; None disables it.
        :param update_clipping: callable or name of a `utils` clipping function applied after every
            other step; None disables it.
        :param palm: when truthy, `palm_beta2` is prepended to the chain.
        :param fns: chain steps. If no update clipping is used and `auto_fuse` is set, a trailing
            `scale_by_*` step (optionally wrapped in `functools.partial`) listed in
            `_scale_to_update_map` is replaced by its fused `update_by_*` counterpart.
        :raises ValueError: if update clipping is requested together with an `update_by_*` step, or
            if a clipping name is not recognised by `_get_clip_fn`.
        """
        # Resolve the arguments against the class-level fallbacks once and use the resolved values
        # everywhere. Previously the raw arguments were wrapped, so a class-level default led to
        # `apply_to_idx(None, ...)` / `apply_to_idx(use_default, ...)`, and names were never resolved.
        gradient_clipping = default(gradient_clipping, self.gradient_clipping)
        update_clipping = default(update_clipping, self.update_clipping)
        if gradient_clipping is not None:
            gradient_clipping = _get_clip_fn(gradient_clipping, None)
        if update_clipping is not None:
            update_clipping = _get_clip_fn(update_clipping, None)

        if update_clipping is None:
            if fns and self.auto_fuse:
                args, kwargs = None, None
                fn = fns[-1]
                if isinstance(fn, functools.partial):
                    fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
                if fn in _scale_to_update_map:
                    fn = _scale_to_update_map[fn]
                    if args is not None:
                        # re-apply the bound arguments of the original partial to the fused step
                        fn = functools.partial(fn, *args, **kwargs)
                    fns = tuple(fns)[:-1] + (fn,)
        elif any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
            raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
        fns = tuple(fns)

        if default(palm, self.palm):
            fns = (palm_beta2,) + fns
        if gradient_clipping is not None:
            # index 2 of (state, group, update, grad, param) is `update`
            fns = (apply_to_idx(gradient_clipping, 2),) + fns
        if update_clipping is not None:
            fns = fns + (apply_to_idx(update_clipping, 2),)

        super().__init__(params, defaults, foreach, *fns)
|
445
|
+
|
446
|
+
|
447
|
+
class ScheduleFree(BaseOpt):
    # Adds eval()/train() switches that move each parameter between two interpolations with its
    # stored `z` state entry. Groups are expected to carry a boolean `train_mode` entry.

    def eval(self):
        """Switch every group that is in train mode (and has beta1 > 0) to evaluation weights.

        Each parameter with a `z` state entry is interpolated towards `z` with weight
        `1 - 1 / beta1` and written back with `utils.copy_stochastic_`.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1 = utils.get_beta1(group)
            if beta1 > 0 and train_mode:
                for p in group['params']:
                    state = self.state_(p)
                    if 'z' in state:
                        # Set p.data to x
                        z = utils.promote(state['z'])
                        p32 = utils.promote(p.data)
                        p32.lerp_(end=z, weight=1 - 1 / beta1)
                        utils.copy_stochastic_(p.data, p32)
                # NOTE(review): placed inside the `if` (flag flips only when weights were moved);
                # the original indentation is not recoverable from this view - confirm.
                group['train_mode'] = False

    def train(self):
        """Switch every group that is not in train mode (and has beta1 > 0) back to training weights.

        Each parameter with a `z` state entry is interpolated towards `z` with weight `1 - beta1`
        and written back with `utils.copy_stochastic_`.
        """
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1 = utils.get_beta1(group)
            if beta1 > 0 and not train_mode:
                for p in group['params']:
                    state = self.state_(p)
                    if 'z' in state:
                        z = utils.promote(state['z'])
                        p32 = utils.promote(p.data)
                        p32.lerp_(end=z, weight=1 - beta1)
                        utils.copy_stochastic_(p.data, p32)
                # NOTE(review): same placement assumption as in eval() - confirm.
                group['train_mode'] = True
|