heavyball 1.2.2__tar.gz → 1.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. {heavyball-1.2.2 → heavyball-1.2.3}/PKG-INFO +1 -1
  2. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball/chainable.py +2 -2
  3. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball/utils.py +30 -12
  4. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.2.2 → heavyball-1.2.3}/setup.py +1 -1
  6. {heavyball-1.2.2 → heavyball-1.2.3}/LICENSE +0 -0
  7. {heavyball-1.2.2 → heavyball-1.2.3}/README.md +0 -0
  8. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball/__init__.py +0 -0
  9. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.2.2 → heavyball-1.2.3}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.2.2 → heavyball-1.2.3}/setup.cfg +0 -0
  14. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_caution.py +0 -0
  18. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_channels_last.py +0 -0
  19. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_closure.py +0 -0
  20. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_ema.py +0 -0
  21. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_foreach.py +0 -0
  22. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_hook.py +0 -0
  23. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_mars.py +0 -0
  24. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_memory.py +0 -0
  25. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_merge.py +0 -0
  26. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_no_grad.py +0 -0
  27. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_psgd.py +0 -0
  28. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_soap.py +0 -0
  29. {heavyball-1.2.2 → heavyball-1.2.3}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.2
3
+ Version: 1.2.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -183,8 +183,8 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
183
183
  @no_state
184
184
  def update_by_schedule_free(group, update, grad, param, z):
185
185
  group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
186
- utils.get_beta1(group), param, z, update, group['r'], group['step'],
187
- group['weight_decay'])
186
+ utils.get_beta1(group), param, z, update, grad, group['caution'],
187
+ group['r'], group['step'], group['weight_decay'])
188
188
  raise SkipUpdate
189
189
 
190
190
 
@@ -61,22 +61,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
61
61
 
62
62
 
63
63
  @decorator_knowngood
64
- def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
65
- beta1: Tensor, decay: float):
66
- for op, oz, g_ in zip(p, z, grad):
67
- g_ = g_.view_as(op)
68
- p_, z_, g_ = map(promote, (op, oz, g_))
64
+ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
65
+ beta1: Tensor, decay: float, grad: List[Tensor], caution):
66
+ for op, oz, u_, g_ in zip(p, z, update, grad):
67
+ u_ = u_.view_as(op)
68
+ p_, z_, u_ = map(promote, (op, oz, u_))
69
69
  if decay != 0:
70
- g_ = g_ + p_ * decay
70
+ u_ = u_ + p_ * decay
71
+ if caution:
72
+ u_ = _compilable_cautioning(u_, g_)
71
73
  p_ = p_.lerp(z_, ckp1)
72
- p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
73
- z_ = z_ + g_ * -lr
74
+ p_ = p_ + u_ * (lr * (beta1 * (1 - ckp1)) - lr)
75
+ z_ = z_ + u_ * -lr
74
76
  copy_stochastic_(op, p_)
75
77
  copy_stochastic_(oz, z_)
76
78
 
77
79
 
78
80
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
79
- z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
81
+ z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
82
+ step: int = 0, decay: float = 0.0):
80
83
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
81
84
  weight_sum = weight_sum + weight
82
85
 
@@ -85,9 +88,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
85
88
  except ZeroDivisionError:
86
89
  ckp1 = 0
87
90
 
88
- grad, parameters, z = list_guard(grad, parameters, z)
91
+ update, parameters, z = list_guard(update, parameters, z)
89
92
  lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
90
- _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
93
+ _compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
91
94
  return weight_sum
92
95
 
93
96
 
@@ -1220,13 +1223,16 @@ def update_triu_(q_state, materialised):
1220
1223
  assert shape0 == shape1
1221
1224
  copy_stochastic_(q, m)
1222
1225
 
1226
+
1223
1227
  _warned = set()
1224
1228
 
1229
+
1225
1230
  def warn_once(msg):
1226
1231
  if msg not in _warned:
1227
1232
  warnings.warn(msg)
1228
1233
  _warned.add(msg)
1229
1234
 
1235
+
1230
1236
  def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1231
1237
  name: str = 'cumulative_prob'):
1232
1238
  group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
@@ -1369,4 +1375,16 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1369
1375
  seen_params.clear()
1370
1376
 
1371
1377
  for p in parameters:
1372
- p.register_post_accumulate_grad_hook(_step)
1378
+ p.register_post_accumulate_grad_hook(_step)
1379
+
1380
+
1381
+ @decorator_knowngood
1382
+ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
1383
+ mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
1384
+ update = update.masked_fill(mask, 0)
1385
+ return update
1386
+
1387
+
1388
+ def disable_caution_scaling():
1389
+ global _compilable_cautioning
1390
+ _compilable_cautioning = _compilable_caution_no_scale
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.2
3
+ Version: 1.2.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='1.2.2',
13
+ version='1.2.3',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes