heavyball 0.24.3__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/foreach_soap.py CHANGED
@@ -71,12 +71,11 @@ class ForeachSOAP(StatefulOptimizer):
71
71
  # Decay the first and second moment running average coefficient
72
72
  # In-place operations to update the averages at the same time
73
73
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
74
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
75
74
 
76
75
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
77
76
 
78
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
79
- state = self.state_(p)
77
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
78
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
80
79
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
81
80
  # i.e. projecting to the eigenbases of matrices in state['GG']
82
81
  exp_avg_projected = project(ea, state['Q'], False)
@@ -81,11 +81,10 @@ class PaLMForeachSOAP(StatefulOptimizer):
81
81
  # In-place operations to update the averages at the same time
82
82
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
83
83
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
84
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
85
84
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
86
85
 
87
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
88
- state = self.state_(p)
86
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
87
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
89
88
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
90
89
  # i.e. projecting to the eigenbases of matrices in state['GG']
91
90
  exp_avg_projected = project(ea, state['Q'], False)
@@ -73,12 +73,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
73
73
  # Decay the first and second moment running average coefficient
74
74
  # In-place operations to update the averages at the same time
75
75
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
76
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
77
76
 
78
77
  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
79
78
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
80
79
 
81
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
80
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
81
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
82
82
  state = self.state_(p)
83
83
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
84
84
  # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -84,12 +84,12 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
84
84
  # In-place operations to update the averages at the same time
85
85
  beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
86
86
  step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
87
- denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
88
87
 
89
88
  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
90
89
  step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
91
90
 
92
- for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
91
+ for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
92
+ d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
93
93
  state = self.state_(p)
94
94
  # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
95
95
  # i.e. projecting to the eigenbases of matrices in state['GG']
heavyball/utils.py CHANGED
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
492
492
  super().__init__(params, {**defaults, 'foreach': foreach})
493
493
  self.fake_groups = {}
494
494
  self.use_ema = use_ema
495
-
496
- def key(self, param: Tensor):
497
- return (param.data_ptr(), tuple(param.shape))
495
+ self.mapping = {}
498
496
 
499
497
  def get_groups(self, group):
500
498
  if group['foreach']:
501
499
  return [group]
502
500
 
503
501
  for p in group['params']:
504
- if self.key(p) not in self.fake_groups:
505
- self.fake_groups[self.key(p)] = {**group, 'params': [p]}
502
+ if p not in self.fake_groups:
503
+ self.fake_groups[p] = {**group, 'params': [p]}
506
504
 
507
- return [self.fake_groups[self.key(p)] for p in group['params']]
505
+ return [self.fake_groups[p] for p in group['params']]
508
506
 
509
507
  def state_(self, arg: Tensor):
510
- return self.state[self.key(arg)]
508
+ return self.state[self.mapping.get(arg, arg)]
511
509
 
512
510
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
513
511
  for p, g in zip(p_list, g_list):
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
538
536
  p_views = merge_group(group, p)
539
537
  if grad is not None:
540
538
  grad = merge_group(group, grad)
539
+ for i, pv in enumerate(p_views):
540
+ self.mapping[pv] = (p, i)
541
541
  if isinstance(p_views, Tensor):
542
542
  yield p_views, grad
543
543
  continue
@@ -622,11 +622,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
622
622
  for top_group in self.param_groups:
623
623
  for group in self.get_groups(top_group):
624
624
  self._step(group)
625
+ self.mapping.clear()
625
626
  if self.use_ema:
626
627
  self.ema_update(group)
628
+
627
629
  return loss
628
630
 
629
631
 
632
+
630
633
  class ScheduleFree(StatefulOptimizer):
631
634
  def eval(self):
632
635
  for group in self.param_groups:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.24.3
3
+ Version: 0.25.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -6,19 +6,19 @@ heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,28
6
6
  heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
7
7
  heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
8
8
  heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
9
- heavyball/foreach_soap.py,sha256=408jRysE9ek0ea-TphhSBMTa9zcjkgMX3qlx8qTCt34,4803
9
+ heavyball/foreach_soap.py,sha256=Tgwg4_Sir9nI_3R85f8NMQagquUBJmAEMQqh0uD3b0Y,4771
10
10
  heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
11
11
  heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
12
- heavyball/palm_foreach_soap.py,sha256=cExM9nTC3zAgsRr42VOIMWNwYA4dAJaA8-pIo7SWilc,6230
13
- heavyball/precond_schedule_foreach_soap.py,sha256=EL_Z-v5l7BC98QgI-Zg9iyM77TAreVgD5Zln59ewGoI,4966
14
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=HWo2t7yY-_n4pPGmDiELccy0jdELTVhdlH-eyFBih5k,6502
12
+ heavyball/palm_foreach_soap.py,sha256=zSjpYYm1hfgIudjo_q3ozu3Vkfhz8w8im1c-ou1U3sI,6198
13
+ heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
14
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
15
15
  heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
16
16
  heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
17
17
  heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
19
- heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
20
- heavyball-0.24.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.24.3.dist-info/METADATA,sha256=32T-Q-a4k096KjxoR-3DQt25XpO_h0zs7lWKTDQLugI,11926
22
- heavyball-0.24.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.24.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.24.3.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=QAHOZj__Kn1vpSSBm6otfKb656bqoHUyZXrVJrB_23U,40145
20
+ heavyball-0.25.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.25.0.dist-info/METADATA,sha256=T649SIhfWXSJVTRXJnYLctkD1fQJl95r05Mrhdw8nck,11926
22
+ heavyball-0.25.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.25.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.25.0.dist-info/RECORD,,