heavyball 0.23.0__tar.gz → 0.23.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {heavyball-0.23.0 → heavyball-0.23.3}/PKG-INFO +1 -1
  2. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/delayed_psgd.py +2 -2
  3. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/foreach_adamw.py +2 -2
  4. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/foreach_adopt.py +2 -2
  5. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/foreach_laprop.py +2 -2
  6. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/foreach_sfadamw.py +2 -3
  7. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/palm_foreach_sfadamw.py +2 -2
  8. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/precond_schedule_sfpsoap.py +2 -2
  9. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/schedule_free_palm_foreach_soap.py +3 -3
  10. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/utils.py +45 -22
  11. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball.egg-info/PKG-INFO +1 -1
  12. {heavyball-0.23.0 → heavyball-0.23.3}/setup.py +1 -1
  13. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_mars.py +3 -3
  14. {heavyball-0.23.0 → heavyball-0.23.3}/LICENSE +0 -0
  15. {heavyball-0.23.0 → heavyball-0.23.3}/README.md +0 -0
  16. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/__init__.py +0 -0
  17. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/cached_delayed_psgd_kron.py +0 -0
  18. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/cached_psgd_kron.py +0 -0
  19. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/foreach_soap.py +0 -0
  20. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/p_adam.py +0 -0
  21. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/palm_foreach_soap.py +0 -0
  22. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/precond_schedule_foreach_soap.py +0 -0
  23. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  24. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/psgd_kron.py +0 -0
  25. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball/pure_psgd.py +0 -0
  26. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball.egg-info/SOURCES.txt +0 -0
  27. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-0.23.0 → heavyball-0.23.3}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-0.23.0 → heavyball-0.23.3}/setup.cfg +0 -0
  31. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_bf16_params.py +0 -0
  32. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_bf16_q.py +0 -0
  33. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_bf16_storage.py +0 -0
  34. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_caution.py +0 -0
  35. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_closure.py +0 -0
  36. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_ema.py +0 -0
  37. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_foreach.py +0 -0
  38. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_memory.py +0 -0
  39. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_merge.py +0 -0
  40. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_no_grad.py +0 -0
  41. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_psgd.py +0 -0
  42. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_soap.py +0 -0
  43. {heavyball-0.23.0 → heavyball-0.23.3}/test/test_stochastic_updates.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.23.0
+Version: 0.23.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball/delayed_psgd.py
@@ -8,10 +8,10 @@ import torch
 from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    triu_to_line, line_to_triu, promote, _compilable_update_
+    triu_to_line, line_to_triu, promote, _compilable_update_, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
     new = psgd_precond_grad(False, exprs, ea, *q)
     _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
heavyball/foreach_adamw.py
@@ -2,10 +2,10 @@ import torch
 import torch.optim
 from heavyball.utils import copy_stochastic_list_
 
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
heavyball/foreach_adopt.py
@@ -2,10 +2,10 @@ import torch
 import torch.optim
 from heavyball.utils import copy_stochastic_list_
 
-from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote
+from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
     update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
heavyball/foreach_laprop.py
@@ -1,10 +1,10 @@
 import torch
 import torch.optim
 
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
+from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
heavyball/foreach_sfadamw.py
@@ -1,11 +1,10 @@
 import torch
 import torch.optim
-from heavyball.utils import get_ckp1, copy_stochastic_list_
 
-from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
+from .utils import get_ckp1, copy_stochastic_list_, warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
     old_debiased2 = beta_debias(beta2, step)
 
heavyball/palm_foreach_sfadamw.py
@@ -2,10 +2,10 @@ import torch
 import torch.optim
 
 from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, promote, \
-    _compilable_schedule_free_, copy_stochastic_list_
+    _compilable_schedule_free_, copy_stochastic_list_, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
     old_debiased2 = beta_debias(beta2, step)
 
heavyball/precond_schedule_sfpsoap.py
@@ -4,10 +4,10 @@ import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
     beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
-    promote
+    promote, decorator_knowngood
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
heavyball/schedule_free_palm_foreach_soap.py
@@ -1,13 +1,13 @@
 import random
 
 import torch
-from heavyball.utils import mars_correction
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote
+    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote, decorator_knowngood, \
+    mars_correction
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
heavyball/utils.py
@@ -10,8 +10,11 @@ import torch
 from torch import Tensor
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
+from torch._dynamo.exc import TorchDynamoException
 
-compile_mode = None
+compile_mode = "max-autotune-no-cudagraphs"
+dynamic = False
+compile_mode_recommended_to_none = None
 zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
 
 
@@ -24,7 +27,22 @@ def decorator(func):
             return func(*args, **kwargs)
         nonlocal compiled
         if compiled is None:
-            compiled = torch.compile(func, fullgraph=True, dynamic=False, mode=compile_mode)
+            compiled = torch.compile(func, fullgraph=True, dynamic=dynamic, mode=compile_mode_recommended_to_none)
+        return compiled(*args, **kwargs)
+
+    return _fn
+
+
+def decorator_knowngood(func):
+    compiled = None
+
+    @functools.wraps(func)
+    def _fn(*args, **kwargs):
+        if compile_mode is None:
+            return func(*args, **kwargs)
+        nonlocal compiled
+        if compiled is None:
+            compiled = torch.compile(func, fullgraph=True, dynamic=dynamic, mode=compile_mode)
+        return compiled(*args, **kwargs)
 
     return _fn
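Editor's note (not part of the diff): the new module-level switches are read lazily inside the decorator, so downstream code can flip them before the first call to any decorated helper. A minimal usage sketch, assuming only the names introduced in this hunk; fused_add_ is a hypothetical example function:

    import torch
    import heavyball.utils as hbu

    # Setting compile_mode to None makes every @decorator_knowngood helper run eagerly;
    # the default "max-autotune-no-cudagraphs" compiles once, on the first invocation.
    hbu.compile_mode = None

    @hbu.decorator_knowngood
    def fused_add_(xs, ys, alpha):
        # hypothetical example: a tiny foreach-style kernel worth compiling
        for x, y in zip(xs, ys):
            x.add_(y, alpha=alpha)

    fused_add_([torch.zeros(4)], [torch.ones(4)], 0.5)  # runs eagerly here because compile_mode is None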
heavyball/utils.py
@@ -39,7 +57,7 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor, beta1: Tensor):
     p32, z32, g32 = [promote(x) for x in (p, z, grad)]
     for p_, z_, g_ in zip(p32, z32, g32):
@@ -139,7 +157,7 @@ def beta_debias(beta, step):
     return 1 - (1 - beta) / (1 - beta ** step)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]):
     torch._foreach_mul_(state, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
@@ -175,7 +193,7 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
 def is_compiling():
     try:
         return torch.compiler.is_compiling()
-    except AttributeError:
+    except TorchDynamoException:
         return True
 
 
@@ -339,7 +357,7 @@ def get_orthogonal_matrix(mat):
     return final
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -368,7 +386,7 @@ def scalar_guard(x, ref):
     return x
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
@@ -595,7 +613,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
         else:
             with torch.enable_grad():
                 loss = closure()
-        with torch.no_grad():
+
+        # we assume that parameters are constant and that there are no excessive recompiles
+        with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
             for top_group in self.param_groups:
                 for group in self.get_groups(top_group):
                     self._step(group)
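Editor's note (not part of the diff): Dynamo normally stops recompiling a code object once torch._dynamo.config.cache_size_limit is hit, which matters here because each parameter shape can trigger its own compile. The hunk above lifts that cap for the duration of the step. A brief standalone sketch of the same pattern, assuming the private torch._dynamo.utils.disable_cache_limit helper remains available:

    import torch

    @torch.compile(fullgraph=True, dynamic=False)
    def scaled_add_(x, y, alpha):
        x.add_(y, alpha=alpha)

    xs = [torch.zeros(n) for n in (8, 16, 32, 64)]
    with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
        # each new static shape triggers a recompile; without the context manager,
        # torch._dynamo.config.cache_size_limit would eventually cut compilation off
        for x in xs:
            scaled_add_(x, torch.ones_like(x), 0.1)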
heavyball/utils.py
@@ -643,7 +663,7 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
         copy_stochastic_(t, s)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
                          grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
@@ -667,7 +687,7 @@ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor]
     return denom
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
@@ -686,12 +706,12 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
 def copy_stochastic_(target: Tensor, source: Tensor):
     if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
-    if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
-        set_(target, source)
-    _compilable_copy_stochastic_(target, source)
+    if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
+        _compilable_copy_stochastic_(target, source.float())
+    set_(target, source)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
                         g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
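Editor's note (not part of the diff): the reworked copy_stochastic_ above now routes bfloat16 targets through _compilable_copy_stochastic_ before the plain set_. As a minimal standalone sketch of the underlying bit trick (following the PyTorch issue linked in the docstring; stochastic_bf16_copy_ is a hypothetical name, not heavyball's API):

    import torch

    def stochastic_bf16_copy_(target: torch.Tensor, source: torch.Tensor) -> None:
        """Copy a float32 `source` into a bfloat16 `target` with stochastic rounding."""
        # random 16-bit payload for the mantissa bits that bfloat16 truncation discards
        noise = torch.randint_like(source, dtype=torch.int32, low=0, high=1 << 16)
        bits = noise.add_(source.view(dtype=torch.int32))
        bits.bitwise_and_(-65536)  # keep only the upper 16 bits (0xFFFF0000 as a signed int32)
        target.copy_(bits.view(dtype=torch.float32))

    t = torch.zeros(3, dtype=torch.bfloat16)
    stochastic_bf16_copy_(t, torch.tensor([0.1, 1e-3, 3.14159]))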
heavyball/utils.py
@@ -852,7 +872,7 @@ def psgd_lb(A, max_abs):
     return x
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
@@ -885,7 +905,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         stochastic_add_([o], [term1], -1)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
     """Precondition gradient G with preconditioner Q."""
     md = min_dtype(preconds)
@@ -1030,12 +1050,15 @@ class PSGDBase(StatefulOptimizer):
 
 
 # TODO: Figure out why this sometimes crashes
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+#@decorator_knowngood
 def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
                                      clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
-    md = min_dtype(cached_q + [ea])
-    new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
-    update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
+    md = min_dtype(list(cached_q) + [ea])
+    args = [q.to(md) for q in cached_q]
+    args = args + [ea.to(md)]
+    new = torch.einsum(expr, *args)
+    new = new.to(torch.float32)
+    _compilable_update_([param], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
 
 
 def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
@@ -1044,7 +1067,7 @@ def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: T
     _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_mars_correction_(g: Tensor, old_g: Tensor, a: Tensor):
     g_copy = [g_.clone() for g_ in g]
     _compilable_stochastic_lerp_(g, old_g, a)
@@ -1058,7 +1081,7 @@ def mars_correction(g, old_g, beta1, gamma):
    _compilable_mars_correction_(g, old_g, a)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+@decorator_knowngood
 def _compilable_cautioning_(g: Tensor, update: Tensor):
     mask = (g * update) > 0
     update.masked_fill_(~mask, 0)
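Editor's note (not part of the diff): the hunk above only swaps the decorator, but _compilable_cautioning_ is what backs the `caution` flag threaded through the update helpers earlier in this file. A tiny out-of-place restatement of the mask, with apply_caution_ as a hypothetical name (heavyball's compiled version operates in place):

    import torch

    def apply_caution_(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        # zero every update entry whose sign disagrees with the gradient's
        mask = (grad * update) > 0
        return update.masked_fill(~mask, 0)

    g = torch.tensor([1.0, -2.0, 3.0])
    u = torch.tensor([0.5, 0.5, -0.5])
    print(apply_caution_(g, u))  # tensor([0.5000, 0.0000, 0.0000])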
heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.23.0
+Version: 0.23.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.23.0',
+    version='0.23.3',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),
test/test_mars.py
@@ -12,7 +12,7 @@ config.cache_size_limit = 128
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 2)])
-def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: int = 2):
+def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: int = 1):
     set_torch()
     opt = getattr(heavyball, opt)
     if ScheduleFree in opt.__mro__:
@@ -27,11 +27,11 @@ def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations:
         losses.append([])
 
     for i in range(outer_iterations):
-        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().double()
+        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
         o = get_optim(opt, model.parameters(), lr=1e-5, mars=mars)
 
         for _ in range(iterations):
-            loss = model(torch.randn((1024, size), device='cuda', dtype=torch.double)).square().mean()
+            loss = model(torch.randn((1024, size), device='cuda')).square().mean()
            loss.backward()
            o.step()
            o.zero_grad()