heavyball 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +11 -9
- heavyball/utils.py +30 -12
- {heavyball-1.2.1.dist-info → heavyball-1.2.3.dist-info}/METADATA +1 -1
- heavyball-1.2.3.dist-info/RECORD +8 -0
- heavyball-1.2.1.dist-info/RECORD +0 -8
- {heavyball-1.2.1.dist-info → heavyball-1.2.3.dist-info}/LICENSE +0 -0
- {heavyball-1.2.1.dist-info → heavyball-1.2.3.dist-info}/WHEEL +0 -0
- {heavyball-1.2.1.dist-info → heavyball-1.2.3.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -183,8 +183,8 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
|
|
183
183
|
@no_state
|
184
184
|
def update_by_schedule_free(group, update, grad, param, z):
|
185
185
|
group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
|
186
|
-
utils.get_beta1(group), param, z, update,
|
187
|
-
group['weight_decay'])
|
186
|
+
utils.get_beta1(group), param, z, update, grad, group['caution'],
|
187
|
+
group['r'], group['step'], group['weight_decay'])
|
188
188
|
raise SkipUpdate
|
189
189
|
|
190
190
|
|
@@ -287,16 +287,19 @@ def heavyball_momentum(group, updates, grads, params, momentum):
|
|
287
287
|
return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
|
288
288
|
|
289
289
|
|
290
|
+
_optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
|
291
|
+
|
292
|
+
|
290
293
|
@zero_guard("exp_avg", "exp_avg_sq")
|
291
294
|
@general_guard("Q", "GG", init_fn=_init_soap)
|
292
295
|
@no_state
|
293
|
-
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
|
296
|
+
def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
|
294
297
|
update = utils.promote(update)
|
295
298
|
|
296
299
|
grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
|
297
|
-
|
298
|
-
|
299
|
-
precond = [utils.project(p, q,
|
300
|
+
fn = _optim_fns[inner]
|
301
|
+
precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
|
302
|
+
precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
|
300
303
|
|
301
304
|
for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
|
302
305
|
utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
|
@@ -355,8 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
|
|
355
358
|
update = update.to(memory_format=torch.contiguous_format)
|
356
359
|
Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
|
357
360
|
_update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
|
358
|
-
|
359
|
-
return torch.as_strided(out, old.shape, old.stride())
|
361
|
+
return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
|
360
362
|
|
361
363
|
|
362
364
|
@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
|
@@ -512,7 +514,7 @@ class BaseOpt(ChainOpt):
|
|
512
514
|
|
513
515
|
fns = tuple(fns)
|
514
516
|
|
515
|
-
self.compile_step =
|
517
|
+
self.compile_step = default(compile_step, self.compile_step)
|
516
518
|
if default(palm, self.palm):
|
517
519
|
fns = (palm_beta2,) + fns
|
518
520
|
if default(gradient_clipping, self.gradient_clipping) is not None:
|
heavyball/utils.py
CHANGED
@@ -61,22 +61,25 @@ def warmup(lr: float, step: int, warmup_steps: int):
|
|
61
61
|
|
62
62
|
|
63
63
|
@decorator_knowngood
|
64
|
-
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor,
|
65
|
-
beta1: Tensor, decay: float):
|
66
|
-
for op, oz, g_ in zip(p, z, grad):
|
67
|
-
|
68
|
-
p_, z_,
|
64
|
+
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
|
65
|
+
beta1: Tensor, decay: float, grad: List[Tensor], caution):
|
66
|
+
for op, oz, u_, g_ in zip(p, z, update, grad):
|
67
|
+
u_ = u_.view_as(op)
|
68
|
+
p_, z_, u_ = map(promote, (op, oz, u_))
|
69
69
|
if decay != 0:
|
70
|
-
|
70
|
+
u_ = u_ + p_ * decay
|
71
|
+
if caution:
|
72
|
+
u_ = _compilable_cautioning(u_, g_)
|
71
73
|
p_ = p_.lerp(z_, ckp1)
|
72
|
-
p_ = p_ +
|
73
|
-
z_ = z_ +
|
74
|
+
p_ = p_ + u_ * (lr * (beta1 * (1 - ckp1)) - lr)
|
75
|
+
z_ = z_ + u_ * -lr
|
74
76
|
copy_stochastic_(op, p_)
|
75
77
|
copy_stochastic_(oz, z_)
|
76
78
|
|
77
79
|
|
78
80
|
def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
|
79
|
-
z: List[Tensor],
|
81
|
+
z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
|
82
|
+
step: int = 0, decay: float = 0.0):
|
80
83
|
weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
|
81
84
|
weight_sum = weight_sum + weight
|
82
85
|
|
@@ -85,9 +88,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
|
|
85
88
|
except ZeroDivisionError:
|
86
89
|
ckp1 = 0
|
87
90
|
|
88
|
-
|
91
|
+
update, parameters, z = list_guard(update, parameters, z)
|
89
92
|
lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
|
90
|
-
_compilable_schedule_free_(parameters, z, ckp1,
|
93
|
+
_compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
|
91
94
|
return weight_sum
|
92
95
|
|
93
96
|
|
@@ -1220,13 +1223,16 @@ def update_triu_(q_state, materialised):
|
|
1220
1223
|
assert shape0 == shape1
|
1221
1224
|
copy_stochastic_(q, m)
|
1222
1225
|
|
1226
|
+
|
1223
1227
|
_warned = set()
|
1224
1228
|
|
1229
|
+
|
1225
1230
|
def warn_once(msg):
|
1226
1231
|
if msg not in _warned:
|
1227
1232
|
warnings.warn(msg)
|
1228
1233
|
_warned.add(msg)
|
1229
1234
|
|
1235
|
+
|
1230
1236
|
def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
|
1231
1237
|
name: str = 'cumulative_prob'):
|
1232
1238
|
group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
|
@@ -1369,4 +1375,16 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
|
|
1369
1375
|
seen_params.clear()
|
1370
1376
|
|
1371
1377
|
for p in parameters:
|
1372
|
-
p.register_post_accumulate_grad_hook(_step)
|
1378
|
+
p.register_post_accumulate_grad_hook(_step)
|
1379
|
+
|
1380
|
+
|
1381
|
+
@decorator_knowngood
|
1382
|
+
def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
1383
|
+
mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
|
1384
|
+
update = update.masked_fill(mask, 0)
|
1385
|
+
return update
|
1386
|
+
|
1387
|
+
|
1388
|
+
def disable_caution_scaling():
|
1389
|
+
global _compilable_cautioning
|
1390
|
+
_compilable_cautioning = _compilable_caution_no_scale
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
+
heavyball/chainable.py,sha256=u9w2z_aSslcokWVCiiXQJ8GSPlOhgrOFUYAwt2JfTzI,21100
|
3
|
+
heavyball/utils.py,sha256=I2zfiB_-EP35LYr-vLyxPNl8_uJo2se3Id0IWjZeVjg,47951
|
4
|
+
heavyball-1.2.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.2.3.dist-info/METADATA,sha256=EtW_3QIUKrKpyYUfXmGQm3_EpZkr8oQyow7gAyC4Ges,12022
|
6
|
+
heavyball-1.2.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.2.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.2.3.dist-info/RECORD,,
|
heavyball-1.2.1.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
|
2
|
-
heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
|
3
|
-
heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
|
4
|
-
heavyball-1.2.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.2.1.dist-info/METADATA,sha256=eBqL5rxjms6zSclp1IjQKI8vmv1OOgvrmdOroZnE1GQ,12022
|
6
|
-
heavyball-1.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.2.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|