heavyball 1.2.2.tar.gz → 1.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.2.2 → heavyball-1.3.0}/PKG-INFO +1 -1
  2. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball/chainable.py +10 -9
  3. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball/utils.py +35 -18
  4. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.2.2 → heavyball-1.3.0}/setup.py +1 -1
  6. {heavyball-1.2.2 → heavyball-1.3.0}/LICENSE +0 -0
  7. {heavyball-1.2.2 → heavyball-1.3.0}/README.md +0 -0
  8. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball/__init__.py +0 -0
  9. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.2.2 → heavyball-1.3.0}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.2.2 → heavyball-1.3.0}/setup.cfg +0 -0
  14. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_caution.py +0 -0
  18. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_channels_last.py +0 -0
  19. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_closure.py +0 -0
  20. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_ema.py +0 -0
  21. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_foreach.py +0 -0
  22. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_hook.py +0 -0
  23. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_mars.py +0 -0
  24. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_memory.py +0 -0
  25. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_merge.py +0 -0
  26. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_no_grad.py +0 -0
  27. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_psgd.py +0 -0
  28. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_soap.py +0 -0
  29. {heavyball-1.2.2 → heavyball-1.3.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.2
3
+ Version: 1.3.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -183,8 +183,8 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
183
183
  @no_state
184
184
  def update_by_schedule_free(group, update, grad, param, z):
185
185
  group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
186
- utils.get_beta1(group), param, z, update, group['r'], group['step'],
187
- group['weight_decay'])
186
+ utils.get_beta1(group), param, z, update, grad, group['caution'],
187
+ group['r'], group['step'], group['weight_decay'])
188
188
  raise SkipUpdate
189
189
 
190
190
 
@@ -438,14 +438,15 @@ class ChainOpt(utils.StatefulOptimizer):
438
438
 
439
439
  for param in p:
440
440
  state = self.state_(param)
441
- if 'step' not in state:
442
- if self.compile_step:
443
- step = utils.scalar_guard(0, param)
444
- state['step'] = step
445
- step = state['step'].add_(1)
441
+ if 'step' in state:
442
+ step = state['step']
443
+ elif self.compile_step:
444
+ step = utils.scalar_guard(0, param)
445
+ else:
446
+ step = 0
446
447
  break
447
448
 
448
- group['step'] = step
449
+ group['step'] = state['step'] = step = step + 1
449
450
 
450
451
  if group['warmup_steps'] and step < group['warmup_steps']:
451
452
  group['lr'] = group['base_lr'] * step / group['warmup_steps']
@@ -494,7 +495,7 @@ class BaseOpt(ChainOpt):
494
495
  auto_fuse: bool = True
495
496
 
496
497
  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
497
- palm: bool = use_default, compile_step: bool = use_default, *fns):
498
+ palm: bool = use_default, *fns, compile_step: bool = use_default):
498
499
  if default(update_clipping, self.update_clipping) is None:
499
500
  if fns and self.auto_fuse:
500
501
  args, kwargs = None, None
@@ -61,22 +61,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
61
61
 
62
62
 
63
63
  @decorator_knowngood
64
- def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
65
- beta1: Tensor, decay: float):
66
- for op, oz, g_ in zip(p, z, grad):
67
- g_ = g_.view_as(op)
68
- p_, z_, g_ = map(promote, (op, oz, g_))
64
+ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
65
+ beta1: Tensor, decay: float, grad: List[Tensor], caution):
66
+ for op, oz, u_, g_ in zip(p, z, update, grad):
67
+ u_ = u_.view_as(op)
68
+ p_, z_, u_ = map(promote, (op, oz, u_))
69
69
  if decay != 0:
70
- g_ = g_ + p_ * decay
70
+ u_ = u_ + p_ * decay
71
+ if caution:
72
+ u_ = _compilable_cautioning(u_, g_)
71
73
  p_ = p_.lerp(z_, ckp1)
72
- p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
73
- z_ = z_ + g_ * -lr
74
+ p_ = p_ + u_ * (lr * (beta1 * (1 - ckp1)) - lr)
75
+ z_ = z_ + u_ * -lr
74
76
  copy_stochastic_(op, p_)
75
77
  copy_stochastic_(oz, z_)
76
78
 
77
79
 
78
80
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
79
- z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
81
+ z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
82
+ step: int = 0, decay: float = 0.0):
80
83
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
81
84
  weight_sum = weight_sum + weight
82
85
 
@@ -85,9 +88,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
85
88
  except ZeroDivisionError:
86
89
  ckp1 = 0
87
90
 
88
- grad, parameters, z = list_guard(grad, parameters, z)
91
+ update, parameters, z, grad = list_guard(update, parameters, z, grad)
89
92
  lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
90
- _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
93
+ _compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
91
94
  return weight_sum
92
95
 
93
96
 
@@ -909,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
909
912
  @decorator_knowngood
910
913
  def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
911
914
  g: List[Optional[Tensor]]):
912
- u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
913
- p32, u32 = [list(map(promote, x)) for x in [p, u]]
914
-
915
- for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
915
+ for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
916
+ u_ = promote(u_.view_as(p_))
917
+ p32_ = promote(p_)
916
918
  if caution:
917
- u32_ = _compilable_cautioning(promote(g_), u32_)
918
- p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
919
+ u_ = _compilable_cautioning(promote(g_), u_)
920
+ p32_ = p32_ * (1 - decay * lr) + u_ * -lr
919
921
  copy_stochastic_(p_, p32_)
920
922
 
921
923
 
@@ -1220,13 +1222,16 @@ def update_triu_(q_state, materialised):
1220
1222
  assert shape0 == shape1
1221
1223
  copy_stochastic_(q, m)
1222
1224
 
1225
+
1223
1226
  _warned = set()
1224
1227
 
1228
+
1225
1229
  def warn_once(msg):
1226
1230
  if msg not in _warned:
1227
1231
  warnings.warn(msg)
1228
1232
  _warned.add(msg)
1229
1233
 
1234
+
1230
1235
  def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1231
1236
  name: str = 'cumulative_prob'):
1232
1237
  group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
@@ -1369,4 +1374,16 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1369
1374
  seen_params.clear()
1370
1375
 
1371
1376
  for p in parameters:
1372
- p.register_post_accumulate_grad_hook(_step)
1377
+ p.register_post_accumulate_grad_hook(_step)
1378
+
1379
+
1380
+ @decorator_knowngood
1381
+ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
1382
+ mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
1383
+ update = update.masked_fill(mask, 0)
1384
+ return update
1385
+
1386
+
1387
+ def disable_caution_scaling():
1388
+ global _compilable_cautioning
1389
+ _compilable_cautioning = _compilable_caution_no_scale
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.2
3
+ Version: 1.3.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='1.2.2',
13
+ version='1.3.0',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes