heavyball 1.4.3__py3-none-any.whl → 1.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -482,7 +482,7 @@ def scalar_guard(*args):
     out = []
     for x in xs:
         if isinstance(x, float):
-            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+            out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
         elif isinstance(x, int):
            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
         else:
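
In 1.4.3, `scalar_guard` always wrapped Python floats as float32 scalars; 1.4.4 uses the promoted dtype of the reference tensor instead, so the wrapped scalar matches the compute precision of the parameter it accompanies. A minimal sketch of the effect, using a hypothetical `_promote` stand-in for heavyball's `promote` (assumed here to upcast half-precision floats to float32 and leave wider dtypes alone):

```python
import torch


def _promote(dtype: torch.dtype) -> torch.dtype:
    # Hypothetical stand-in for heavyball's promote(): upcast half-precision
    # floats to float32, keep float32/float64 as-is.
    return torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype


ref = torch.zeros(4, dtype=torch.float64)
# 1.4.3: always float32.  1.4.4: follows promote(ref.dtype), float64 here.
lr = torch.empty((), dtype=_promote(ref.dtype), device=ref.device).fill_(1e-3)
print(lr.dtype)  # torch.float64 under this stand-in
```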
@@ -1043,7 +1043,8 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
         if q.dim() <= 1:
             conjB /= q
         else:
-            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(conjB)
+            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
+                conjB)
         if i < order - 1:
             conjB = torch.transpose(conjB, i, order - 1)
     return A, conjB
@@ -1286,7 +1287,6 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
 
@@ -1338,25 +1338,20 @@ def caution(g, update):
     return _compilable_cautioning(g, update)
 
 
-def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
+def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
     """Anneal preconditioner update probability during beginning of training.
 
     PSGD benefits from more preconditioner updates at the beginning of training,
     but once the preconditioner is learned the update probability can drop low.
 
     This schedule is an exponential anneal with a flat start. Default settings keep
-    update probability at 1.0 for 200 steps then exponentially anneal down to
-    `min_prob` by 4000 steps. Default settings work very well for most models and
+    update probability at `max_prob` for 1000 steps then exponentially anneal down to
+    `min_prob` by ~4000 steps. Default settings work very well for most models and
     training regimes.
     """
 
     def _schedule(n):
-        if n < flat_start:  # higher numerical stability
-            return max_prob
-
-        n -= flat_start
-        prob = max_prob * math.exp(-decay * (n - flat_start))
-        return max(min_prob, min(max_prob, prob))
+        return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
 
     return _schedule
 
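
The replacement `_schedule` is a single closed-form expression with a multiplicative per-step `decay` (the removed version combined `math.exp` with an additive `decay` and subtracted `flat_start` twice). A standalone sketch of the new schedule with the 1.4.4 defaults, restated here for illustration; the printed values are approximate:

```python
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
    # Flat at max_prob for flat_start steps, then geometric decay, floored at min_prob.
    def _schedule(n):
        return max(min_prob, max_prob * decay ** max(n - flat_start, 0))

    return _schedule


schedule = precond_update_prob_schedule()
print(schedule(500))             # 1.0 -- still in the flat phase
print(round(schedule(4000), 3))  # ~0.05: 0.999 ** 3000 is roughly exp(-3)
print(schedule(10000))           # 0.03 -- the min_prob floor is reached around step ~4500
```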
@@ -1375,12 +1370,18 @@ def merge_group(group, *tensors):
 
 
 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
-    def _step(p: Tensor, o: torch.optim.Optimizer):
+    optimizers = {}
+
+    def _step(p: Tensor):
+        o = optimizers[p]
         o.step()
         o.zero_grad()
 
     for p in model.parameters():
-        p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
+        optimizers[p] = optimizer([p], *args, **kwargs)
+        p.register_post_accumulate_grad_hook(_step)
+
+    return optimizers
 
 
 def fused_hook(parameters, optimizer, *args, **kwargs):
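
`hook_optimizer_into_model` now stores one optimizer per parameter in a dict keyed by the parameter tensor and returns that dict, rather than binding each optimizer into its hook with `functools.partial`. A usage sketch under the assumption that any optimizer factory with the `optimizer([p], **kwargs)` calling convention works here; `torch.optim.SGD` is used as a stand-in:

```python
import torch
from torch import nn

from heavyball.utils import hook_optimizer_into_model

model = nn.Linear(8, 2)
# One optimizer per parameter; each steps and zeros its grad as soon as that
# parameter's gradient has been accumulated (post-accumulate-grad hook).
optimizers = hook_optimizer_into_model(model, torch.optim.SGD, lr=1e-2)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # hooks fire here: per-parameter step() + zero_grad()

# New in 1.4.4: the per-parameter optimizers are handed back, e.g. for checkpointing.
states = [opt.state_dict() for opt in optimizers.values()]
```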
@@ -1401,6 +1402,8 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
     for p in parameters:
         p.register_post_accumulate_grad_hook(_step)
 
+    return o
+
 
 @decorator_knowngood
 def _compilable_caution_no_scale(g: Tensor, update: Tensor):
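
`fused_hook` gets the matching change: the optimizer `o` it builds internally is now returned instead of being reachable only through the registered hooks. A brief sketch, again with `torch.optim.SGD` standing in for the optimizer factory:

```python
import torch
from torch import nn

from heavyball.utils import fused_hook

model = nn.Linear(8, 2)
opt = fused_hook(list(model.parameters()), torch.optim.SGD, lr=1e-2)

model(torch.randn(4, 8)).sum().backward()  # the registered hooks drive the optimizer
checkpoint = opt.state_dict()              # possible now that the optimizer is returned
```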
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.4.3
+Version: 1.4.4
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
+heavyball/utils.py,sha256=lFwN8T-dlldmOe-Qd6iWhSqqNfWl7IBawLWAo5l9rPw,48071
+heavyball-1.4.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.4.4.dist-info/METADATA,sha256=w5nAamE6sr08elqo2fS6B_kXktOMXxFQvyJTkRT4Eqo,43584
+heavyball-1.4.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.4.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.4.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
-heavyball/utils.py,sha256=x0rSU8lko7ACdI9GuTLC0wP6HwIZxwB8f8tukBOR0xA,48129
-heavyball-1.4.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.4.3.dist-info/METADATA,sha256=RM_pOme3dsQL-drKcKD6FJ0qE3SSh4JdPM-kC9vpbeU,43584
-heavyball-1.4.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.4.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.4.3.dist-info/RECORD,,