heavyball 0.25.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py ADDED
@@ -0,0 +1,532 @@
+ import functools
+ import random
+ from typing import Optional, Union
+
+ import torch
+
+ from . import utils
+
+ balance_probability: float = 0.01
+
+
+ def _key_in_state(state, key):
+     if isinstance(key, str):
+         return key in state
+     for k in key:
+         if isinstance(k, (tuple, list)):
+             continue
+         if k not in state:
+             return False
+     return True
+
+
+ def _inplace_guard_(state, key, template_fn):
+     key_not_in_state = not _key_in_state(state, key)
+     if key_not_in_state:
+         template_fn()
+     return key_not_in_state
+
+
+ def _guard_in_state(state, key, template_fn):
+     if not _key_in_state(state, key):
+         state[key] = template_fn()
+     return state[key]
+
+
+ class FunctionTransform:
+     def __init__(self, fn):
+         self.fn = fn
+         self.fn_name = self.get_fn().__name__
+
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         raise NotImplementedError
+
+     def get_fn(self):
+         if hasattr(self.fn, 'get_fn'):
+             return self.fn.get_fn()
+         return self.fn
+
+     def val_name(self, name):
+         return f"{self.fn_name}_{name}"
+
+
+ def _zero_guard(state, key, ref, dtype):
+     return _guard_in_state(state, key,
+                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
+
+
+ def _storage_dtype(group):
+     dtype = group.get('storage_dtype', "float32")
+     return getattr(torch, dtype)
+
+
+ class ZeroGuard(FunctionTransform):
+     def __init__(self, fn, names):
+         super().__init__(fn)
+         self.names = names
+
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
+                 for name in self.names]
+         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
+
+
+ class CopyGuard(FunctionTransform):
+     def __init__(self, fn, index, names):
+         super().__init__(fn)
+         self.index = index
+         self.names = names
+
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         val = [update, grad, param, *args][self.index]
+         vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
+                 for name in self.names]
+         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
+
+
+ class GeneralGuard(FunctionTransform): # We can't guard against reuse in the general case
+     def __init__(self, fn, names, init_fn):
+         super().__init__(fn)
+         self.names = names
+         self.init_fn = init_fn
+
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         vars = []
+         skip_update = False
+         for p, g, u in zip(param, grad, update):
+             st = state(p)
+             skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
+             vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
+         if skip_update:
+             raise SkipUpdate
+         return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
+
+
+ class NoState(FunctionTransform):
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         return self.fn(group, update, grad, param, *args, **kwargs)
+
+
+ class NoStateNoForeach(FunctionTransform):
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         for a in zip(update, grad, param, *args):
+             return self.fn(group, *a, **kwargs)
+
+
+ def zero_guard(*names):
+     return functools.partial(ZeroGuard, names=names)
+
+
+ def copy_guard(index, *names):
+     return functools.partial(CopyGuard, index=index, names=names)
+
+
+ def general_guard(*names, init_fn):
+     return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
+
+
+ def no_state(fn):
+     return NoState(fn)
+
+
+ def no_state_no_foreach(fn):
+     return NoStateNoForeach(fn)
+
+
+ class SkipUpdate(ValueError):
+     pass
+
+
+ @zero_guard("exp_avg")
+ @no_state
+ def exp_avg(group, update, grad, param, exp_avg):
+     utils.stochastic_lerp_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
+     return exp_avg
+
+
+ @zero_guard("exp_avg_sq")
+ @no_state
+ def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
+     out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])
+     return out
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+     return utils.adam_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'], #
+                        group['eps'])
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+     utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+                       group['lr'], group['eps'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
+     return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+                          group['eps'])
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
+     utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+                         group['step'], group['lr'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @copy_guard(2, "z")
+ @no_state
+ def update_by_schedule_free(group, update, grad, param, z):
+     group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
+                                                utils.get_beta1(group), param, z, update, group['r'], group['step'],
+                                                group['weight_decay'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
+     if group['step'] == 1:
+         utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+         raise SkipUpdate
+
+     if group['step'] == 2:
+         update = utils.promote(update)
+         easq = utils.promote(exp_avg_sq)
+         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+         utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+         raise SkipUpdate
+
+     utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+                        group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
+     raise SkipUpdate
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
+     if group['step'] == 1:
+         utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+         raise SkipUpdate
+
+     if group['step'] == 2:
+         update = utils.promote(update)
+         easq = utils.promote(exp_avg_sq)
+         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+         utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+         raise SkipUpdate
+
+     return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
+
+
+ def _init_soap(state, group, update, grad, param):
+     utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
+
+
+ def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+     Q, state["exprs"] = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'],
+                                            group['min_ndim_triangular'], group['memory_save_mode'],
+                                            dtype=getattr(torch, group['q_dtype']))
+     state["Q"] = utils.triu_to_line(Q) if group['store_triu_as_line'] else Q
+
+     if not cached:
+         return
+
+     state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+     expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
+     expr = ','.join(expr)
+     grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
+     out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+     expr = f'{expr},{grad_expr}->{out_expr}'
+
+     state['cache_expr'] = expr
+
+
+ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'):
+     step = group['step']
+     if 'precondition_frequency' in group:
+         return step > 0 and step % group['precondition_frequency'] == 0
+     rng = random.Random(0x172381 ^ step)
+     if 'precond_scheduler' in group:
+         return utils.precond_schedule(step, group['precond_scheduler'], rng)
+     if prob is not None:
+         return utils.psgd_should_update(group, prob, rng, name=name)
+     raise ValueError("No preconditioner update schedule specified.")
+
+
+ @no_state_no_foreach
+ def orthogonalize_update(group, update, grad, param):
+     if update.dim() == 1:
+         return update
+     original_shape = update.shape
+     # doing it this way, as tmp and update are not guaranteed to share memory address or layout
+     tmp = update.flatten(1, -1)
+     utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+     return tmp.reshape(original_shape)
+
+
+ @zero_guard("momentum")
+ @no_state
+ def nesterov_momentum(group, updates, grads, params, momentum):
+     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
+
+
+ @zero_guard("momentum")
+ @no_state
+ def heavyball_momentum(group, updates, grads, params, momentum):
+     return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @general_guard("Q", "GG", init_fn=_init_soap)
+ @no_state
+ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
+     update = utils.promote(update)
+
+     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
+     precond = utils.adam_(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group),
+                           utils.scalar_guard(group['step'], exp_avg[0]))
+     precond = [utils.project(p, q, False) for p, q in zip(precond, Q)]
+
+     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
+         utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
+                                     utils.beta_debias(group['shampoo_beta'], group['step']), precond_schedule(group))
+     return precond
+
+
+ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+     if prob is None:
+         prob = utils.precond_update_prob_schedule()
+     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
+         return
+
+     Q = [utils.promote(q_) for q_ in Q]
+     utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
+
+     if grad.dim() > 1 and precond_schedule(group, balance_probability, "balance_prob"):
+         if group['store_triu_as_line']:
+             utils.psgd_balance_Q([q_ for _, q_ in Q])
+         else:
+             utils.psgd_balance_Q(Q)
+
+
+ def _update_psgd_cache(cached, Q_cache, q):
+     if not cached:
+         return q
+
+     for c_, q_ in zip(Q_cache, q):
+         if q_.ndim == 2:
+             torch.matmul(q_.T, q_, out=c_)
+         else:
+             torch.mul(q_, q_, out=c_)
+     return Q_cache
+
+
+ def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
+     if cached:
+         return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
+     return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+
+
+ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, exprs, update, Q_mat, Q_cache):
+     if cached:
+         utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
+                                          group['caution'], *Q_cache)
+     else:
+         utils.fused_psgd_precond_grad(exprs[-1], update, param, group['lr'], grad, group['weight_decay'],
+                                       group['caution'], *Q_mat)
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                   prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     return _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                           prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     precond = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     return precond
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                    prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
+     raise SkipUpdate
+
+
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @no_state_no_foreach
+ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
+                            prob: Optional[callable] = None):
+     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
+     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+     raise SkipUpdate
+
+
+ def palm_beta2(state, group, update, grad, param):
+     beta2 = 1 - group['step'] ** -group['beta2_scale']
+     group['betas'] = (utils.get_beta1(group), beta2)
+     return update
+
+
+ def apply_to_idx(fn, idx):
+     def _fn(state, group, update, grad, param):
+         args = [state, group, update, grad, param]
+         return fn(args[idx])
+
+     return _fn
+
+
+ def chain(state: Union[callable, dict], group, grad, param, *fns):
+     update = [torch.clone(g) for g in grad]
+     skip_update = False
+     for fn in fns:
+         try:
+             update = fn(state, group, update, grad, param)
+         except SkipUpdate:
+             skip_update = True
+             continue
+         if update is None:
+             break
+     if not skip_update and update is not None:
+         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
+
+
+ class ChainOpt(utils.StatefulOptimizer):
+     def __init__(self, params, defaults, foreach: bool, *fns):
+         super().__init__(params, defaults, foreach)
+
+         self.fns = tuple(fns)
+
+     def _step(self, group):
+         if 'base_lr' not in group:
+             group['base_lr'] = group['lr']
+         step = group['step'] = group.get('step', 0) + 1
+         if group['warmup_steps'] and step < group['warmup_steps']:
+             group['lr'] = -group['base_lr'] * step / group['warmup_steps']
+         else:
+             group['lr'] = -group['base_lr']
+
+         vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+         if not vals:
+             return
+         p, g = zip(*vals)
+
+         if not group['foreach'] or len(p) == 1:
+             for param, grad in zip(p, g):
+                 chain(self.state_, group, [grad], [param], *self.fns)
+             return
+
+         chain(self.state_, group, g, p, *self.fns)
+
+
+ use_default = object()
+ str_or_fn = Union[str, callable, None, use_default]
+
+
+ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
+     name = default(name, default_val)
+     if callable(name):
+         return name
+     elif name not in ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress'):
+         raise ValueError(f"Clipping function {name} not found")
+     return getattr(utils, name)
+
+
+ def default(a, b):
+     return b if a is None or a is use_default else a
+
+
+ # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
+ _scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd, #
+                         scale_by_psgd: update_by_psgd, #
+                         scale_by_adam: update_by_adam, #
+                         scale_by_laprop: update_by_laprop, #
+                         scale_by_adopt: update_by_adopt}
+
+
+ class BaseOpt(ChainOpt):
+     gradient_clipping: str_or_fn = None
+     update_clipping: str_or_fn = None
+     palm: bool = False
+     auto_fuse: bool = True
+
+     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
+                  palm: bool = None, *fns):
+         if default(update_clipping, self.update_clipping) is None:
+             if fns and self.auto_fuse:
+                 args, kwargs = None, None
+                 fn = fns[-1]
+                 if isinstance(fn, functools.partial):
+                     fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
+                 if fn in _scale_to_update_map:
+                     fn = _scale_to_update_map[fn]
+                     if args is not None:
+                         fn = functools.partial(fn, *args, **kwargs)
+                     fns = tuple(fns)[:-1] + (fn,)
+         else:
+             if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
+                 raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+         fns = tuple(fns)
+
+         if default(palm, self.palm):
+             fns = (palm_beta2,) + fns
+         if default(gradient_clipping, self.gradient_clipping) is not None:
+             fns = (apply_to_idx(gradient_clipping, 2),) + fns
+         if default(update_clipping, self.update_clipping) is not None:
+             fns = fns + (apply_to_idx(update_clipping, 2),)
+
+         super().__init__(params, defaults, foreach, *fns)
+
+
+ class ScheduleFree(BaseOpt):
+     def eval(self):
+         for group in self.param_groups:
+             train_mode = group['train_mode']
+             beta1 = utils.get_beta1(group)
+             if beta1 > 0 and train_mode:
+                 for p in group['params']:
+                     state = self.state_(p)
+                     if 'z' in state:
+                         # Set p.data to x
+                         z = utils.promote(state['z'])
+                         p32 = utils.promote(p.data)
+                         p32.lerp_(end=z, weight=1 - 1 / beta1)
+                         utils.copy_stochastic_(p.data, p32)
+             group['train_mode'] = False
+
+     def train(self):
+         for group in self.param_groups:
+             train_mode = group['train_mode']
+             beta1 = utils.get_beta1(group)
+             if beta1 > 0 and not train_mode:
+                 for p in group['params']:
+                     state = self.state_(p)
+                     if 'z' in state:
+                         z = utils.promote(state['z'])
+                         p32 = utils.promote(p.data)
+                         p32.lerp_(end=z, weight=1 - beta1)
+                         utils.copy_stochastic_(p.data, p32)
+             group['train_mode'] = True
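
For orientation, a minimal usage sketch (not part of the diff above): `chain()` clones the gradient, feeds it through the listed transforms, and applies whatever survives via `utils.update_param_`, while `BaseOpt.auto_fuse` swaps a trailing `scale_by_*` transform for its fused `update_by_*` counterpart (see `_scale_to_update_map`) when no update clipping is requested. The class name `SketchAdamW` and the default hyperparameter values below are illustrative assumptions, not shipped in the package; the `defaults` dict only populates group keys that the transforms shown above actually read.

import heavyball.chainable as C

class SketchAdamW(C.BaseOpt):
    # Hypothetical subclass for illustration only; heavyball defines its own
    # optimizer classes elsewhere in the package.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 warmup_steps=0, foreach=True, gradient_clipping=None, update_clipping=None):
        # Only keys read by the transforms/chain above are populated; the values are assumed defaults.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        warmup_steps=warmup_steps, foreach=foreach, caution=False,
                        storage_dtype='float32')
        # The trailing scale_by_adam is auto-fused into update_by_adam by BaseOpt.
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping,
                         False, C.scale_by_adam)

# opt = SketchAdamW(model.parameters(), lr=3e-4)
# loss.backward(); opt.step(); opt.zero_grad()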