heavyball 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- heavyball/chainable.py +9 -7
- heavyball/utils.py +1 -0
- {heavyball-1.2.0.dist-info → heavyball-1.2.2.dist-info}/METADATA +1 -1
- heavyball-1.2.2.dist-info/RECORD +8 -0
- heavyball-1.2.0.dist-info/RECORD +0 -8
- {heavyball-1.2.0.dist-info → heavyball-1.2.2.dist-info}/LICENSE +0 -0
- {heavyball-1.2.0.dist-info → heavyball-1.2.2.dist-info}/WHEEL +0 -0
- {heavyball-1.2.0.dist-info → heavyball-1.2.2.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -287,16 +287,19 @@ def heavyball_momentum(group, updates, grads, params, momentum):
|
|
287
287
|
return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
|
288
288
|
|
289
289
|
|
290
|
+
_optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
|
291
|
+
|
292
|
+
|
290
293
|
@zero_guard("exp_avg", "exp_avg_sq")
|
291
294
|
@general_guard("Q", "GG", init_fn=_init_soap)
|
292
295
|
@no_state
|
293
|
-
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
|
296
|
+
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
|
294
297
|
update = utils.promote(update)
|
295
298
|
|
296
299
|
grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
|
297
|
-
|
298
|
-
|
299
|
-
precond = [utils.project(p, q,
|
300
|
+
fn = _optim_fns[inner]
|
301
|
+
precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
|
302
|
+
precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
|
300
303
|
|
301
304
|
for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
|
302
305
|
utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
|
@@ -355,8 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
|
|
355
358
|
update = update.to(memory_format=torch.contiguous_format)
|
356
359
|
Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
|
357
360
|
_update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
|
358
|
-
|
359
|
-
return torch.as_strided(out, old.shape, old.stride())
|
361
|
+
return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
|
360
362
|
|
361
363
|
|
362
364
|
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
|
@@ -512,7 +514,7 @@ class BaseOpt(ChainOpt):
|
|
512
514
|
|
513
515
|
fns = tuple(fns)
|
514
516
|
|
515
|
-
self.compile_step =
|
517
|
+
self.compile_step = default(compile_step, self.compile_step)
|
516
518
|
if default(palm, self.palm):
|
517
519
|
fns = (palm_beta2,) + fns
|
518
520
|
if default(gradient_clipping, self.gradient_clipping) is not None:
|
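heavyball/utils.py
CHANGED
[diff content for this file (+1 -0) was not captured in this view]
heavyball-1.2.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@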
|
|
1
|
+
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
+
heavyball/chainable.py,sha256=O3vBuwDKSkcl6yGrcEgSylqZ6htjRTg1NIA1sNY4KcA,21076
|
3
|
+
heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
|
4
|
+
heavyball-1.2.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.2.2.dist-info/METADATA,sha256=k_qyB9aR8PpreGu_XxNv9z5pyTnH1TMcA8Ah3k1hUgQ,12022
|
6
|
+
heavyball-1.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.2.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.2.2.dist-info/RECORD,,
|
heavyball-1.2.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
-
heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
|
3
|
-
heavyball/utils.py,sha256=H0r2GpqRS1c6qIYqW5rFYA-020AVVVWbfGne17mzlcM,47377
|
4
|
-
heavyball-1.2.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.2.0.dist-info/METADATA,sha256=YzMGNrvU_RIKGn13r8GO8kp05s9Me5PWyD3KvEd09Uo,12022
|
6
|
-
heavyball-1.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.2.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|