heavyball-1.4.3-py3-none-any.whl → heavyball-1.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +17 -14
- {heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/METADATA +1 -1
- heavyball-1.4.4.dist-info/RECORD +8 -0
- heavyball-1.4.3.dist-info/RECORD +0 -8
- {heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/LICENSE +0 -0
- {heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/WHEEL +0 -0
- {heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -482,7 +482,7 @@ def scalar_guard(*args):
     out = []
     for x in xs:
         if isinstance(x, float):
-            out.append(torch.empty((), dtype=
+            out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
        elif isinstance(x, int):
             out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
         else:
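In the new branch a Python float is materialized as a 0-dim tensor on the reference tensor's device, with its dtype run through heavyball's `promote` first. A minimal sketch of the effect, with a stand-in `promote` (assumed here to widen half-precision dtypes to float32):

import torch

def promote(dtype):
    # stand-in for heavyball's promote(); assumption: widen bf16/fp16 to fp32
    return torch.float32 if dtype in (torch.bfloat16, torch.float16) else dtype

ref = torch.zeros(4, dtype=torch.bfloat16)
# the fixed line: the scalar becomes a 0-dim tensor matching ref's device,
# so downstream compiled kernels receive tensors rather than Python floats
lr = torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(1e-3)
assert lr.ndim == 0 and lr.dtype == torch.float32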
@@ -1043,7 +1043,8 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
         if q.dim() <= 1:
             conjB /= q
         else:
-            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
+            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
+                conjB)
         if i < order - 1:
             conjB = torch.transpose(conjB, i, order - 1)
     return A, conjB
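The rewrapped call is easier to read once its contract is spelled out: with `q` upper-triangular, it solves X·Q = B, flattening every leading dimension of `conjB` into rows and restoring the shape afterwards. A self-contained sketch with hypothetical shapes:

import torch

q = torch.triu(torch.rand(3, 3)) + 2 * torch.eye(3)  # upper-triangular, well-conditioned
conjB = torch.randn(2, 4, 3)                          # trailing dim equals q.size(0)

# left=False solves X @ q = B (rather than q @ X = B)
x = torch.linalg.solve_triangular(
    q, conjB.reshape(-1, q.size(0)), upper=True, left=False
).reshape_as(conjB)
torch.testing.assert_close(x @ q, conjB)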
@@ -1286,7 +1287,6 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
 
@@ -1338,25 +1338,20 @@ def caution(g, update):
     return _compilable_cautioning(g, update)
 
 
-def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.
+def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
     """Anneal preconditioner update probability during beginning of training.
 
     PSGD benefits from more preconditioner updates at the beginning of training,
     but once the preconditioner is learned the update probability can drop low.
 
     This schedule is an exponential anneal with a flat start. Default settings keep
-    update probability at
-    `min_prob` by 4000 steps. Default settings work very well for most models and
+    update probability at `max_prob` for 1000 steps then exponentially anneal down to
+    `min_prob` by ~4000 steps. Default settings work very well for most models and
     training regimes.
     """
 
     def _schedule(n):
-
-        return max_prob
-
-        n -= flat_start
-        prob = max_prob * math.exp(-decay * (n - flat_start))
-        return max(min_prob, min(max_prob, prob))
+        return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
 
     return _schedule
 
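The removed `_schedule` body returned `max_prob` unconditionally, leaving the anneal below it unreachable; the new closed form actually performs it. Reproducing the new body to spot-check the defaults:

def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
    def _schedule(n):
        # flat at max_prob for flat_start steps, then geometric decay, floored at min_prob
        return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
    return _schedule

schedule = precond_update_prob_schedule()
print(schedule(500))   # 1.0   (inside the flat start)
print(schedule(2000))  # ~0.37 (0.999 ** 1000)
print(schedule(5000))  # 0.03  (the floor is reached near step 4500 with these defaults)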
@@ -1375,12 +1370,18 @@ def merge_group(group, *tensors):
 
 
 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
-
+    optimizers = {}
+
+    def _step(p: Tensor):
+        o = optimizers[p]
         o.step()
         o.zero_grad()
 
     for p in model.parameters():
-        p
+        optimizers[p] = optimizer([p], *args, **kwargs)
+        p.register_post_accumulate_grad_hook(_step)
+
+    return optimizers
 
 
 def fused_hook(parameters, optimizer, *args, **kwargs):
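As rewritten, the helper builds one optimizer per parameter and registers a hook that steps and zeros it the moment that parameter's gradient finishes accumulating, now returning the whole mapping. A hedged usage sketch, with `torch.optim.SGD` standing in for any `optimizer([p], *args, **kwargs)` factory:

import torch
from heavyball.utils import hook_optimizer_into_model

model = torch.nn.Linear(4, 4)
opts = hook_optimizer_into_model(model, torch.optim.SGD, lr=1e-3)

loss = model(torch.randn(2, 4)).square().mean()
loss.backward()        # each hook fires per parameter, stepping its own optimizer
assert len(opts) == 2  # one optimizer each for weight and bias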
@@ -1401,6 +1402,8 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
     for p in parameters:
         p.register_post_accumulate_grad_hook(_step)
 
+    return o
+
 
 @decorator_knowngood
 def _compilable_caution_no_scale(g: Tensor, update: Tensor):
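The added `return o` means `fused_hook` now hands back the single fused optimizer instead of discarding it (previously the call returned None). Sketch of the call shape, again with `torch.optim.SGD` as a stand-in:

import torch
from heavyball.utils import fused_hook

params = list(torch.nn.Linear(4, 4).parameters())
opt = fused_hook(params, torch.optim.SGD, lr=1e-3)  # now usable for state_dict(), inspection, etc.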
heavyball-1.4.4.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
+heavyball/utils.py,sha256=lFwN8T-dlldmOe-Qd6iWhSqqNfWl7IBawLWAo5l9rPw,48071
+heavyball-1.4.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.4.4.dist-info/METADATA,sha256=w5nAamE6sr08elqo2fS6B_kXktOMXxFQvyJTkRT4Eqo,43584
+heavyball-1.4.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.4.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.4.4.dist-info/RECORD,,
heavyball-1.4.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
-heavyball/utils.py,sha256=x0rSU8lko7ACdI9GuTLC0wP6HwIZxwB8f8tukBOR0xA,48129
-heavyball-1.4.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.4.3.dist-info/METADATA,sha256=RM_pOme3dsQL-drKcKD6FJ0qE3SSh4JdPM-kC9vpbeU,43584
-heavyball-1.4.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.4.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.4.3.dist-info/RECORD,,
{heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/LICENSE: file without changes
{heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/WHEEL: file without changes
{heavyball-1.4.3.dist-info → heavyball-1.4.4.dist-info}/top_level.txt: file without changes