heavyball 1.1.3-py3-none-any.whl → 1.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +39 -17
- heavyball/utils.py +76 -77
- {heavyball-1.1.3.dist-info → heavyball-1.2.1.dist-info}/METADATA +1 -1
- heavyball-1.2.1.dist-info/RECORD +8 -0
- heavyball-1.1.3.dist-info/RECORD +0 -8
- {heavyball-1.1.3.dist-info → heavyball-1.2.1.dist-info}/LICENSE +0 -0
- {heavyball-1.1.3.dist-info → heavyball-1.2.1.dist-info}/WHEEL +0 -0
- {heavyball-1.1.3.dist-info → heavyball-1.2.1.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -1,5 +1,6 @@
 import functools
 import random
+import warnings
 from typing import Optional, Union, Literal

 import torch
@@ -251,7 +252,11 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     step = group['step']
     if 'precondition_frequency' in group:
         return step > 0 and step % group['precondition_frequency'] == 0
-    rng = random.Random(0x172381 ^ step)
+    if isinstance(step, torch.Tensor):
+        utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
+        rng = random.Random(0x172381)
+    else:
+        rng = random.Random(0x172381 ^ step)
     if 'precond_scheduler' in group:
         return utils.precond_schedule(step, group['precond_scheduler'], rng)
     if prob is not None:
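For readers following the schedule change: with an integer step the behaviour is unchanged (a per-step deterministic RNG), while a tensor-valued step now warns once and falls back to a fixed seed. A minimal standalone sketch of that branching, using a hypothetical should_precondition helper and an assumed update probability rather than the library's actual API:

import random
import warnings


def should_precondition(step, precondition_frequency=None, prob=0.1):
    # Hypothetical sketch of the branching added to precond_schedule.
    if precondition_frequency is not None:
        return step > 0 and step % precondition_frequency == 0
    if not isinstance(step, int):  # stands in for the torch.Tensor check
        warnings.warn("step-dependent schedule unavailable; using a fixed seed")
        rng = random.Random(0x172381)
    else:
        rng = random.Random(0x172381 ^ step)  # deterministic per step
    return rng.random() < prob  # assumed probability; the real schedule is configurable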
@@ -414,6 +419,8 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):


 class ChainOpt(utils.StatefulOptimizer):
+    compile_step: bool = False
+
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
         self.fns = tuple(fns)
@@ -421,23 +428,36 @@ class ChainOpt(utils.StatefulOptimizer):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
-        step = group['step'] = group.get('step', 0) + 1
-        if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = -group['base_lr'] * step / group['warmup_steps']
-        else:
-            group['lr'] = -group['base_lr']

         vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)

+        for param in p:
+            state = self.state_(param)
+            if 'step' not in state:
+                if self.compile_step:
+                    step = utils.scalar_guard(0, param)
+                state['step'] = step
+            step = state['step'].add_(1)
+            break
+
+        group['step'] = step
+
+        if group['warmup_steps'] and step < group['warmup_steps']:
+            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+        else:
+            group['lr'] = group['base_lr']
+
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
-            return
+        else:
+            chain(self.state_, group, g, p, *self.fns)

-        chain(self.state_, group, g, p, *self.fns)
+        group['lr'] = None
+        group['step'] = None


 use_default = object()
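Two behavioural points in this hunk: the step counter now lives in per-parameter state (optionally as a tensor when compile_step is set), and lr is stored positive with a plain linear warmup ramp; the negation is now applied inside the update kernel in utils (see the _compilable_update_ hunk below). A minimal sketch of the warmup arithmetic, with an illustrative helper name:

def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Mirrors the linear warmup in ChainOpt._step; lr stays positive here.
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr


# e.g. with base_lr=1e-3 and warmup_steps=100:
# warmup_lr(1e-3, 1, 100) == 1e-05, warmup_lr(1e-3, 50, 100) == 0.0005, warmup_lr(1e-3, 100, 100) == 0.001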
@@ -454,15 +474,15 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):


 def default(a, b):
-    return b if a is
+    return b if a is use_default else a


 # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
-_scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd, #
-                        scale_by_psgd: update_by_psgd, #
-                        scale_by_adam: update_by_adam, #
-                        scale_by_laprop: update_by_laprop, #
-                        scale_by_adopt: update_by_adopt}
+_scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+                        scale_by_psgd.get_fn(): update_by_psgd, #
+                        scale_by_adam.get_fn(): update_by_adam, #
+                        scale_by_laprop.get_fn(): update_by_laprop, #
+                        scale_by_adopt.get_fn(): update_by_adopt}


 class BaseOpt(ChainOpt):
@@ -470,16 +490,17 @@ class BaseOpt(ChainOpt):
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
-    compile_step: bool = False

     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, *fns):
+                 palm: bool = use_default, compile_step: bool = use_default, *fns):
         if default(update_clipping, self.update_clipping) is None:
             if fns and self.auto_fuse:
                 args, kwargs = None, None
                 fn = fns[-1]
                 if isinstance(fn, functools.partial):
-                    fn, args, kwargs =
+                    fn, args, kwargs = fn.func, fn.args, fn.keywords
+                if isinstance(fn, FunctionTransform):
+                    fn = fn.get_fn()
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
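The fusion lookup now keys _scale_to_update_map on the underlying callables (.get_fn()), so a trailing transform that arrives wrapped in a functools.partial and/or a FunctionTransform is unwrapped before the scale-to-update substitution. A rough sketch of that unwrapping, using a hypothetical helper and duck-typing in place of heavyball's actual classes:

import functools


def resolve_fused(fn, scale_to_update):
    # Hypothetical sketch of the lookup performed in BaseOpt.__init__.
    args, kwargs = None, None
    if isinstance(fn, functools.partial):
        fn, args, kwargs = fn.func, fn.args, fn.keywords
    if hasattr(fn, "get_fn"):  # stands in for isinstance(fn, FunctionTransform)
        fn = fn.get_fn()
    if fn in scale_to_update:
        fn = scale_to_update[fn]
        if args is not None:
            fn = functools.partial(fn, *args, **kwargs)
    return fn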
@@ -491,6 +512,7 @@ class BaseOpt(ChainOpt):

         fns = tuple(fns)

+        self.compile_step = default(compile_step, self.compile_step)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
heavyball/utils.py
CHANGED
@@ -1,16 +1,9 @@
-"""
-
-
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
 import functools
 import gc
 import math
 import random
 import string
+import warnings
 from typing import List, Optional, Tuple, Callable, Union

 import numpy as np
@@ -70,16 +63,16 @@ def warmup(lr: float, step: int, warmup_steps: int):
 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
                                beta1: Tensor, decay: float):
-
-
-
+    for op, oz, g_ in zip(p, z, grad):
+        g_ = g_.view_as(op)
+        p_, z_, g_ = map(promote, (op, oz, g_))
         if decay != 0:
-            g_
-        p_.
-        p_
-        z_
-
-
+            g_ = g_ + p_ * decay
+        p_ = p_.lerp(z_, ckp1)
+        p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
+        z_ = z_ + g_ * -lr
+        copy_stochastic_(op, p_)
+        copy_stochastic_(oz, z_)


 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
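The rewritten kernel promotes each (parameter, z, gradient) triple to fp32, does the arithmetic out-of-place, and writes back with stochastic rounding. Stripped of tensors and rounding, the per-element update in the added lines reads as this plain-float sketch:

def schedule_free_step(p, z, g, lr, beta1, ckp1, decay=0.0):
    # Scalar restatement of the arithmetic in _compilable_schedule_free_.
    if decay != 0:
        g = g + p * decay                        # decoupled weight decay folded into the gradient
    p = p + (z - p) * ckp1                       # p.lerp(z, ckp1)
    p = p + g * (lr * (beta1 * (1 - ckp1)) - lr)
    z = z - lr * g
    return p, z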
@@ -164,9 +157,9 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
                            out: List[Optional[Tensor]]):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta2)
-    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
-    [d.clamp_(min=eps) for d in denom]
+    denom = [d.clamp(min=eps) for d in denom]
     copy_stochastic_list_(state, s32)

     if out[0] is None:
@@ -184,13 +177,9 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    s32 = torch._foreach_mul(s32, beta2)
-    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
-    denom = torch._foreach_sqrt(s32)
-    [d.clamp_(min=eps) for d in denom]
+    g32 = promote(grad)
+    denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
     out = torch._foreach_div(g32, denom)
-    copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, out)


@@ -201,10 +190,10 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad


+# TODO: This lerp was fucked - check other lerps
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+    s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)

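The removed line interpolated with weight beta, so for a typical beta ≈ 0.9 the running average jumped almost all the way to the incoming gradient; the new line uses 1 - beta, the conventional EMA weighting beta·s + (1 - beta)·g. A quick numeric check of the lerp identity:

s, g, beta = 1.0, 3.0, 0.9
old = s + (g - s) * beta        # s.lerp(g, beta)     -> 2.8 (dominated by the new gradient)
new = s + (g - s) * (1 - beta)  # s.lerp(g, 1 - beta) -> 1.2 (dominated by the running average)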
@@ -218,14 +207,16 @@ def scale_by_exp_avg_(state, grad, beta):

 @decorator_knowngood
 def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
-
-
-    torch.
-    torch.
-    torch.
-    torch.
-    torch.
-    torch.
+    p32, g32 = [list(map(promote, x)) for x in (parameters, gradients)]
+    p_norm = torch._foreach_norm(p32)
+    g_norm = torch._foreach_norm(g32)
+    p_norm = torch._foreach_maximum(p_norm, minimum)
+    g_norm = torch._foreach_maximum(g_norm, eps)
+    p_norm = torch._foreach_div(p_norm, g_norm)
+    p_norm = torch._foreach_mul(p_norm, clip_val)
+    p_norm = torch._foreach_minimum(p_norm, 1)
+    g32 = torch._foreach_mul(g32, p_norm)
+    copy_stochastic_list_(gradients, g32)


 def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
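The rewritten _compilable_agc_ promotes to fp32 and computes the clip factor with out-of-place foreach ops before stochastically writing the scaled gradients back. Per tensor, the effective rule is g ← g · min(1, clip_val · max(‖p‖, minimum) / max(‖g‖, eps)); a scalar sketch with a hypothetical helper (the default minimum/eps values here are assumptions for illustration):

def agc_scale(p_norm: float, g_norm: float, clip_val: float, minimum: float = 1e-3, eps: float = 1e-8) -> float:
    # Scale factor adaptive gradient clipping applies to the gradient,
    # following the foreach pipeline in _compilable_agc_.
    return min(1.0, clip_val * max(p_norm, minimum) / max(g_norm, eps))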
@@ -246,10 +237,6 @@ def is_compiling():


 def set_(dst: Tensor, src: Tensor):
-    if not is_compiling() and src.data_ptr() == dst.data_ptr():
-        return
-    if src.shape != dst.shape:
-        src = src.reshape_as(dst)
     dst.copy_(src)


@@ -306,7 +293,7 @@ def ortho(x):
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta)
-    torch._foreach_add_(s32, g32)
+    s32 = torch._foreach_add(s32, g32)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)

@@ -315,8 +302,8 @@ def _compilable_heavyball_momentum_(state, grad, beta):
 def _compilable_nesterov_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta)
-    torch._foreach_add_(s32, g32)
-    [g
+    s32 = torch._foreach_add(s32, g32)
+    g32 = [g + s * beta for g, s in zip(g32, s32)]
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, g32)

@@ -353,7 +340,7 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
     elif scale_mode == "scale":
         y *= max(1, x.size(0) / x.size(1)) ** 0.5
     elif scale_mode == "graft":
-        y *= x.norm() / y.norm().
+        y *= x.norm() / y.norm().clamp(min=1e-6)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
     set_(out, y)
@@ -509,8 +496,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-
-        copy_stochastic_(x_, x32)
+        copy_stochastic_(x_, x32 + y32 * alpha)


 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
@@ -634,10 +620,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
                                beta1: float = -1.0):
         for p in group["params"]:
-            if skip_none and p.grad is None:
-                continue
-
             if p.grad is None:
+                if skip_none:
+                    continue
                 grad = None
             else:
                 if should_promote:
@@ -792,7 +777,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
     exp_avg32 = _lerp32(exp_avg, u32, beta1)
     denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
-    _compilable_update_(y, u32, decay,
+    _compilable_update_(y, u32, decay, lr, caution, g32)


 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -837,7 +822,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(u32, denom)
     u32 = _lerp32(exp_avg, u32, beta1)
-    _compilable_update_(y, u32, decay,
+    _compilable_update_(y, u32, decay, lr, caution, gp32)


 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -850,22 +835,19 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
 @decorator_knowngood
 def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
-    _compilable_update_(y, u32, decay,
+    _compilable_update_(y, u32, decay, lr, caution, g32)

     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
-    [
-    exp_avg32 =
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+    denom = [d.clamp(min=eps) for d in denom]
+    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
     copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    exp_avg_sq32 =
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
+    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


-
 def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
@@ -879,14 +861,12 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):

     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
-    [
-    exp_avg32 =
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+    denom = [d.clamp(min=1e-8) for d in denom]
+    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
     copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    exp_avg_sq32 =
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

     copy_stochastic_list_(grad, update)
@@ -921,39 +901,31 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):


 def copy_stochastic_(target: Tensor, source: Tensor):
-    if not is_compiling() and target.data_ptr() == source.data_ptr():
-        return
     if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
         _compilable_copy_stochastic_(target, source.float())
     set_(target, source)


 @decorator_knowngood
-def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor,
+def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
                         g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
     p32, u32 = [list(map(promote, x)) for x in [p, u]]

-
-    torch._foreach_mul_(p32, 1 - decay * lr)
-
-    for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
+    for p32_, u32_, g_, p_ in zip(p32, u32, g, p):  # lr is data-dependent -> can't compile a foreach
         if caution:
             u32_ = _compilable_cautioning(promote(g_), u32_)
-
-
-    copy_stochastic_list_(p, p32)
+        p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
+        copy_stochastic_(p_, p32_)


-def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float,
-
+def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
+                  grad: List[Tensor] = None):
     param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
-
-    add_fn = stochastic_add_
-    _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
+    _compilable_update_(param, update, decay, lr, caution, grad)


 def precond_schedule(step, precond_scheduler, rng):
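_compilable_update_ now receives lr and caution directly (the old add_fn indirection is gone) and fuses decoupled weight decay with the step into one fp32 expression per parameter before the stochastic write-back. Ignoring the cautioning mask, the per-parameter rule is p ← p·(1 − decay·lr) − lr·u; a scalar sketch:

def apply_update(p: float, u: float, lr: float, decay: float) -> float:
    # Scalar restatement of the fused expression in _compilable_update_:
    # decoupled weight decay, then a step of size lr against the update direction.
    return p * (1 - decay * lr) - lr * u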
@@ -1194,6 +1166,7 @@ def identity(x):

 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+    # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
     g32 = list(map(promote, grad))
     [g.mul_(1 / scale) for g in g32]
     tanh = torch._foreach_tanh(g32)
@@ -1247,6 +1220,12 @@ def update_triu_(q_state, materialised):
         assert shape0 == shape1
         copy_stochastic_(q, m)

+_warned = set()
+
+def warn_once(msg):
+    if msg not in _warned:
+        warnings.warn(msg)
+        _warned.add(msg)

 def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
                        name: str = 'cumulative_prob'):
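warn_once deduplicates by message text, so repeated calls with the same string produce a single warnings.warn. Usage sketch (module path as shipped in this wheel):

from heavyball import utils

utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")  # no second warning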
@@ -1291,6 +1270,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
     return new.to(ea.dtype)


+@decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
     precond = psgd_precond_grad(expr, grad, *preconds)
     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
@@ -1371,3 +1351,22 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

     for p in model.parameters():
         p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
+
+
+def fused_hook(parameters, optimizer, *args, **kwargs):
+    parameters = list(parameters)
+    param_count = len(parameters)
+    seen_params = set()
+
+    o = optimizer(parameters, *args, **kwargs)
+
+    def _step(p: Tensor):
+        seen_params.add(p)
+
+        if len(seen_params) < param_count:
+            o.step()
+            o.zero_grad()
+            seen_params.clear()
+
+    for p in parameters:
+        p.register_post_accumulate_grad_hook(_step)
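Where hook_optimizer_into_model builds one optimizer per parameter, the new fused_hook shares a single optimizer across the whole parameter list and drives step()/zero_grad() from post-accumulate-grad hooks during backward(). A usage sketch; the toy model and the choice of heavyball.ForeachAdamW are illustrative assumptions, not taken from this diff:

import torch
import heavyball
from heavyball import utils

model = torch.nn.Linear(4, 4)
# Registers a post-accumulate-grad hook on every parameter; the shared optimizer
# steps and zeroes gradients from inside backward(), so no explicit step() call.
utils.fused_hook(model.parameters(), heavyball.ForeachAdamW, lr=1e-3)

loss = model(torch.randn(2, 4)).square().mean()
loss.backward()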
heavyball-1.2.1.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
+heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
+heavyball-1.2.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.2.1.dist-info/METADATA,sha256=eBqL5rxjms6zSclp1IjQKI8vmv1OOgvrmdOroZnE1GQ,12022
+heavyball-1.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.2.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.2.1.dist-info/RECORD,,
heavyball-1.1.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=IrTlhiHvBlc1SaUUQDb1ulVfj0nrPJcqoP52AXWO3cI,20362
-heavyball/utils.py,sha256=0j5wRDYeI9Elz9m8tcP7CZNhj_9OIWEF_uQpb0LTrYM,47814
-heavyball-1.1.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.1.3.dist-info/METADATA,sha256=ILjX-OviZL9vO7bgw7ldGx-8XXIa1YfWRYCW94EtAOI,12022
-heavyball-1.1.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.1.3.dist-info/RECORD,,
File without changes
|
File without changes
|
File without changes
|