heavyball 1.1.0__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.1.0 → heavyball-1.1.1}/PKG-INFO +1 -1
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/chainable.py +8 -7
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/utils.py +19 -4
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.1.0 → heavyball-1.1.1}/setup.py +1 -1
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_channels_last.py +2 -2
- {heavyball-1.1.0 → heavyball-1.1.1}/LICENSE +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/README.md +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/__init__.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/setup.cfg +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_params.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_q.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_storage.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_caution.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_closure.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_ema.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_foreach.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_hook.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_mars.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_memory.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_merge.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_no_grad.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_psgd.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_soap.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.1}/test/test_stochastic_updates.py +0 -0
```diff
--- heavyball-1.1.0/heavyball/chainable.py
+++ heavyball-1.1.1/heavyball/chainable.py
@@ -140,16 +140,14 @@ class SkipUpdate(ValueError):
 @zero_guard("exp_avg")
 @no_state
 def exp_avg(group, update, grad, param, exp_avg):
-    utils.
-    return exp_avg
+    return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-
-
-    return out
+    return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
```
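Both helpers now return the result of a single `utils` call instead of assembling the averaging inline. For orientation, only the name `beta_debias` appears in this diff; the body below is an assumption (sign conventions for the returned value vary between codebases), showing the usual way Adam-style bias correction is folded into an effective beta:

```python
import torch

def beta_debias(beta: float, step: int) -> float:
    # Assumed body: effective beta such that a plain lerp reproduces the
    # bias-corrected EMA m_t / (1 - beta**t) at every step.
    return 1 - (1 - beta) / (1 - beta ** step)

ema, beta = torch.zeros(3), 0.9
for step, x in enumerate(torch.randn(10, 3), start=1):
    # lerp_(x, 1 - beta_hat) == beta_hat * ema + (1 - beta_hat) * x;
    # at step 1 the weight is exactly 1, so ema starts unbiased at x.
    ema.lerp_(x, 1 - beta_debias(beta, step))
```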
```diff
@@ -350,9 +348,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
+    old = update
+    update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
-
+    out = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+    return torch.as_strided(out, old.shape, old.stride())
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
```
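The added lines make `scale_by_psgd` layout-safe: the update is forced into a contiguous layout for the einsum-based preconditioner, and `torch.as_strided` then reinterprets the result with the caller's original strides (e.g. channels-last). A standalone illustration of that round trip, not heavyball code:

```python
import torch

g = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)

old = g
g = g.to(memory_format=torch.contiguous_format)  # einsum-friendly copy
out = g * 2  # stand-in for the preconditioned update
restored = torch.as_strided(out, old.shape, old.stride())
assert restored.stride() == old.stride()  # caller sees the original layout
```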
```diff
@@ -400,7 +401,7 @@ def apply_to_idx(fn, idx):
 
 
 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g) for g in grad]
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     skip_update = False
     for fn in fns:
         try:
```
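`torch.clone` already defaults to `memory_format=torch.preserve_format`, so this change documents intent rather than altering behavior: the cloned update inherits each gradient's layout instead of being assumed contiguous. A quick check in plain PyTorch:

```python
import torch

g = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
u = torch.clone(g, memory_format=torch.preserve_format)
assert u.is_contiguous(memory_format=torch.channels_last)
assert u.stride() == g.stride()  # layout carried over to the clone
```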
```diff
--- heavyball-1.1.0/heavyball/utils.py
+++ heavyball-1.1.1/heavyball/utils.py
@@ -201,6 +201,21 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad
 
 
+@decorator_knowngood
+def _compilable_exp_avg_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    [s.lerp_(g, beta) for s, g in zip(s32, g32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+def scale_by_exp_avg_(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_exp_avg_(state, grad, beta)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
     p_norm = torch._foreach_norm(parameters)
```
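The new pair follows the file's pattern for low-precision state: `promote` upcasts to fp32, the lerp runs in fp32, and `copy_stochastic_list_` writes the averaged value back into both buffers, stochastically rounding when they are bf16. That helper is not part of this diff; the bit trick its name refers to usually looks like the sketch below (an assumption, applied per tensor by the list variant):

```python
import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """Copy fp32 `source` into `target`, stochastically rounding to bf16."""
    if target.dtype != torch.bfloat16:
        target.copy_(source)
        return
    # bf16 keeps the top 16 bits of an fp32 pattern; adding uniform noise
    # to the 16 bits about to be dropped rounds up with probability equal
    # to the discarded fraction, making the rounding unbiased on average.
    src = source.view(torch.int32)
    bits = src + torch.randint_like(src, 1 << 16)
    bits &= -(1 << 16)  # zero the low 16 bits; the cast below is then exact
    target.copy_(bits.view(torch.float32))
```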
```diff
@@ -321,13 +336,13 @@ def nesterov_momentum(state, grad, beta):
 
 @decorator_knowngood
 def inplace_orthogonal_(x, mode, out):
-    if mode == 'qr':
+    if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
+        y = zeropower_via_newtonschulz5(x, 5)
+    elif mode == 'qr':
         y = torch.linalg.qr(x).Q
     elif mode == 'svd':
         u, s, v = torch.linalg.svd(x)
         y = u @ v.T
-    elif mode == 'newtonschulz':
-        y = zeropower_via_newtonschulz5(x, 5)
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
     set_(out, y)
```
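The reordering changes dispatch as well as style: any non-square input now takes the Newton-Schulz path regardless of the configured mode. `zeropower_via_newtonschulz5` itself is unchanged and not shown in this diff; the widely used quintic iteration that name refers to (coefficients from the Muon optimizer reference; heavyball's copy may differ) is:

```python
import torch

def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7):
    # Odd quintic iteration converging toward U @ V.T for G = U @ S @ V.T,
    # i.e. G with all singular values pushed toward 1.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)  # Frobenius norm bounds the spectral norm
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X
```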
```diff
@@ -363,7 +378,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-        inplace_orthogonal_(tmp[:, sort_idx], q)
+        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
                     for i, ind in enumerate(indices))
```
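Note that the 1.1.0 call site passed only two positional arguments to the three-parameter `inplace_orthogonal_(x, mode, out)` shown above, binding `q` to `mode` and leaving `out` unfilled, which would raise a `TypeError` at call time; 1.1.1 threads the module-level `zeroth_power_mode` through explicitly.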
```diff
--- heavyball-1.1.0/test/test_channels_last.py
+++ heavyball-1.1.1/test/test_channels_last.py
@@ -17,7 +17,7 @@ config.cache_size_limit = 128
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 1)])
-def test_foreach(opt, size, depth: int, iterations: int =
+def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 1):
     set_torch()
     opt = getattr(heavyball, opt)
 
```
```diff
@@ -34,7 +34,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 32, outer_iterations:
     if is_channels_last:
         model.to(memory_format=torch.channels_last)
 
-    o = get_optim(opt, model.parameters(), lr=1e-
+    o = get_optim(opt, model.parameters(), lr=1e-5, weight_decay=1e-4, warmup_steps=16)
 
     for _ in range(iterations):
         loss = model(torch.randn((1024, size, 4, 4), device='cuda')).square().mean()
```
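The test changes raise the iteration count to 1024 and soften the optimizer settings (lr 1e-5, weight decay 1e-4, 16 warmup steps). For readers unfamiliar with what this suite exercises, a minimal channels-last layout check in plain PyTorch, with sizes mirroring the test's `size=128` and batch 1024:

```python
import torch
from torch import nn

model = nn.Conv2d(128, 128, 3, padding=1)
model.to(memory_format=torch.channels_last)
# 4D parameters adopt NHWC-style strides; the optimizer must preserve them.
assert model.weight.is_contiguous(memory_format=torch.channels_last)

x = torch.randn(1024, 128, 4, 4).to(memory_format=torch.channels_last)
assert x.stride()[1] == 1  # channels are innermost after conversion
```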