heavyball 0.14.7__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -1,3 +1,5 @@
+ from .cached_psgd_kron import ForeachCachedPSGDKron
+ from .delayed_psgd import ForeachDelayedPSGD
  from .foreach_adamw import ForeachAdamW
  from .foreach_adopt import ForeachADOPT
  from .foreach_laprop import ForeachLaProp
@@ -12,11 +14,31 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
  from .psgd_kron import ForeachPSGDKron
  from .pure_psgd import ForeachPurePSGD
  from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
- from .delayed_psgd import ForeachDelayedPSGD

  PalmForEachSoap = PaLMForeachSOAP

+ PaLMSOAP = PaLMForeachSOAP
+ PaLMSFAdamW = PaLMForeachSFAdamW
+ PaLMSFSoap = SFPaLMForeachSOAP
+ PaLMForeachSOAP = PaLMForeachSOAP
+ PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+ SOAP = ForeachSOAP
+ SFAdamW = ForeachSFAdamW
+ LaProp = ForeachLaProp
+ ADOPT = ForeachADOPT
+ PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+ PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+ PSGDKron = ForeachPSGDKron
+ AdamW = ForeachAdamW
+ PurePSGD = ForeachPurePSGD
+ PaLMPAdam = ForeachPaLMPAdam
+ DelayedPSGD = ForeachDelayedPSGD
+ CachedPSGDKron = ForeachCachedPSGDKron
+
  __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
             'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-            'ForeachPaLMPAdam', 'ForeachDelayedPSGD']
+            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+            'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
+            'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+            'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+            'CachedPSGDKron']
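The new module-level names are plain aliases, so `heavyball.CachedPSGDKron` and `heavyball.ForeachCachedPSGDKron` refer to the same class. A minimal usage sketch (the model and data here are made up; the constructor arguments come from the `ForeachCachedPSGDKron` signature shown later in this diff):

```python
import torch
import heavyball

model = torch.nn.Linear(8, 8)
# Short alias added in 0.15.x; identical to heavyball.ForeachCachedPSGDKron.
opt = heavyball.CachedPSGDKron(model.parameters(), lr=1e-3, beta=0.9, weight_decay=0.0)

for _ in range(3):
    loss = model(torch.randn(16, 8)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```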
heavyball/cached_psgd_kron.py ADDED
@@ -0,0 +1,141 @@
+ """
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
+ from typing import Optional
+
+ import torch
+ from heavyball.utils import einsum_base
+
+ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+     precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+ class ForeachCachedPSGDKron(PSGDBase):
+     """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+     Args:
+         params (iterable): Iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float): Learning rate.
+         b1 (float): Momentum parameter.
+         weight_decay (float): Weight decay (L2 penalty).
+         preconditioner_update_probability (callable or float, optional): Probability of
+             updating the preconditioner. If None, defaults to a schedule that anneals
+             from 1.0 to 0.03 by 4000 steps.
+         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+         min_ndim_triangular (int): Minimum number of dimensions a layer needs
+             to have triangular preconditioners.
+         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+             to set all preconditioners to be triangular, 'one_diag' sets the largest
+             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+             to be diagonal.
+         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+             update instead of raw gradients.
+     """
+
+     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+         if not 0.0 <= lr:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if not 0.0 <= beta < 1.0:
+             raise ValueError(f"Invalid beta parameter: {beta}")
+         if not 0.0 <= weight_decay:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         if preconditioner_update_probability is None:
+             preconditioner_update_probability = precond_update_prob_schedule()
+         if clip_fn is None:
+             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+         self.preconditioner_update_probability = preconditioner_update_probability
+         self.clip_fn = clip_fn
+
+         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                         # precond lr hardcoded to 0.1
+                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
+         super().__init__(params, defaults)
+
+         self._prob_step = 0
+
+     def _step(self, group):
+         # update preconditioners all together
+         update_prob = self.preconditioner_update_probability
+         if callable(update_prob):
+             update_prob = update_prob(self._prob_step)
+         do_update = self.rng.random() < update_prob
+         self._prob_step += 1
+
+         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+         precond_init_scale = group['precond_init_scale']
+         max_size_triangular = group['max_size_triangular']
+         min_ndim_triangular = group['min_ndim_triangular']
+         memory_save_mode = group['memory_save_mode']
+         precond_lr = group['precond_lr']
+         weight_decay = group['weight_decay']
+         lr = group['lr']
+         beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']
+
+         vals = []
+
+         for p, g in split_p_and_g_in_group(group):
+             state = self.state_(p)
+
+             if 'Q' not in state:
+                 state["exp_avg"] = torch.zeros_like(g)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                  memory_save_mode, dtype=g.dtype)
+                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                 expr = ','.join(expr)
+                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                 expr = f'{expr},{grad_expr}->{out_expr}'
+
+                 state['cache_expr'] = expr
+
+             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+         if not vals:
+             return
+
+         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+         del vals
+
+         group["step"] += 1
+
+         torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+             exp_avg_list)
+         for i, (p, g) in enumerate(zip(p_list, grad_list)):
+             cached_q = Q_cache_list.pop(0)
+             q_orig = Q_list.pop(0)
+             ea = exp_avg_list.pop(0)
+
+             if do_update:
+                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                 self.balance([g], [q])
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                                [q_orig] if store_triu_as_line else None)
+                 for c_, q_ in zip(cached_q, q):
+                     if q_.ndim == 2:
+                         torch.matmul(q_.T.conj(), q_, out=c_)
+                     else:
+                         torch.mul(q_.conj(), q_, out=c_)
+
+             set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+         grad_list = self.clip_fn(grad_list)
+
+         lr = -warmup(lr, group['step'], group['warmup_steps'])
+         update_param_(p_list, grad_list, lr, weight_decay)
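The new class differs from `ForeachPSGDKron` mainly in the `Q_cache`/`cache_expr` machinery above: for each parameter it stores `Q.T.conj() @ Q` (2-D factors) or `Q.conj() * Q` (1-D factors) and builds, once, an einsum string that applies every cached factor to the momentum buffer in a single `torch.einsum` call. Those products are refreshed only on preconditioner-update steps, which is the point of the cache. A self-contained sketch of the string construction for one 3×5 parameter, assuming `einsum_base` is simply the lowercase alphabet (an assumption; the real constant lives in `heavyball.utils`):

```python
import string
import torch

einsum_base = string.ascii_lowercase  # assumed stand-in for heavyball.utils.einsum_base

g = torch.randn(3, 5)                    # momentum/gradient for a 3x5 parameter
Q = [torch.randn(3, 3), torch.randn(5)]  # one triangular (2-D) and one diagonal (1-D) factor

# Same construction as in ForeachCachedPSGDKron._step:
expr = ','.join(f'{c.upper()}{c}' if q.ndim == 2 else c for c, q in zip(einsum_base, Q))
grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
cache_expr = f'{expr},{grad_expr}->{out_expr}'
print(cache_expr)  # 'Aa,b,ab->Ab'

# The cache holds Q.T.conj() @ Q for 2-D factors and Q.conj() * Q for 1-D ones,
# so one einsum applies the full preconditioner without touching Q itself.
cached = [Q[0].T.conj() @ Q[0], Q[1].conj() * Q[1]]
precond_g = torch.einsum(cache_expr, *cached, g)
print(precond_g.shape)  # torch.Size([3, 5])
```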
heavyball/delayed_psgd.py CHANGED
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
-
  from heavyball.utils import copy_stochastic_list_
+
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
      precond_update_prob_schedule, split_p_and_g_in_group, triu_to_line, line_to_triu, set_

@@ -38,7 +38,7 @@ class ForeachDelayedPSGD(PSGDBase):
      def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                   max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                   momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: callable = None):
+                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= beta < 1.0:
@@ -58,18 +58,13 @@ class ForeachDelayedPSGD(PSGDBase):
                          momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                          # precond lr hardcoded to 0.1
                          precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split)
+                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                         store_triu_as_line=store_triu_as_line)
          super().__init__(params, defaults)

          self._prob_step = 0

-     @torch.no_grad()
-     def step(self, closure=None):
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
+     def _step(self, group):
          # update preconditioners all together
          update_prob = self.preconditioner_update_probability
          if callable(update_prob):
@@ -77,55 +72,52 @@ class ForeachDelayedPSGD(PSGDBase):
          do_update = self.rng.random() < update_prob
          self._prob_step += 1

-         for group in self.param_groups:
-             momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-             precond_init_scale = group['precond_init_scale']
-             max_size_triangular = group['max_size_triangular']
-             min_ndim_triangular = group['min_ndim_triangular']
-             memory_save_mode = group['memory_save_mode']
-             precond_lr = group['precond_lr']
-             weight_decay = group['weight_decay']
-             lr = group['lr']
-             beta = group['beta']
-
-             vals = []
-
-             for p, g in split_p_and_g_in_group(group):
-                 state = self.state_(p)
+         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+         precond_init_scale = group['precond_init_scale']
+         max_size_triangular = group['max_size_triangular']
+         min_ndim_triangular = group['min_ndim_triangular']
+         memory_save_mode = group['memory_save_mode']
+         precond_lr = group['precond_lr']
+         weight_decay = group['weight_decay']
+         lr = group['lr']
+         beta = group['beta']
+         store_triu_as_line = group['store_triu_as_line']

-                 if 'Q' not in state:
-                     state["exp_avg"] = torch.zeros_like(g)
-                     Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                      memory_save_mode, dtype=g.dtype)
-                     state["Q"] = triu_to_line(Q)
+         vals = []

-                 vals.append((p, g, state["exp_avg"], state["Q"]))
+         for p, g in split_p_and_g_in_group(group):
+             state = self.state_(p)

-             if not vals:
-                 continue
+             if 'Q' not in state:
+                 state["exp_avg"] = torch.zeros_like(g)
+                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                  memory_save_mode, dtype=g.dtype)
+                 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q

-             p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-             del vals
+             vals.append((p, g, state["exp_avg"], state["Q"]))

-             group["step"] += 1
+         if not vals:
+             return

-             torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+         p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
+         del vals

-             Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-             for i, (p, g) in enumerate(zip(p_list, grad_list)):
-                 q_orig = Q_list.pop(0)
-                 ea = exp_avg_list.pop(0)
-                 q = line_to_triu(q_orig)
-                 self.balance(do_update, [g], [q])
-                 new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+         group["step"] += 1

-                 if do_update:
-                     self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
-                 set_(g, new)
+         torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))

-             grad_list = self.clip_fn(grad_list)
+         Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
+         for i, (p, g) in enumerate(zip(p_list, grad_list)):
+             q_orig = Q_list.pop(0)
+             ea = exp_avg_list.pop(0)
+             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+             if do_update:
+                 self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                 self.balance([g], [q])
+             set_(g, new)

-             lr = -warmup(lr, group['step'], group['warmup_steps'])
-             update_param_(p_list, grad_list, lr, weight_decay)
+         grad_list = self.clip_fn(grad_list)

-         return loss
+         lr = -warmup(lr, group['step'], group['warmup_steps'])
+         update_param_(p_list, grad_list, lr, weight_decay)
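Both PSGD variants in this release (and `ForeachCachedPSGDKron` above) update the momentum buffer with a single in-place `torch._foreach_lerp_` whose weight is `(1 - beta) / (1 - beta ** step)`. That weight folds Adam-style bias correction into the EMA, so `exp_avg` always holds the debiased average. A quick standalone check of the identity (not heavyball code):

```python
import torch

torch.manual_seed(0)
beta = 0.9
grads = [torch.randn(4) for _ in range(10)]

s = torch.zeros(4)  # classic EMA with explicit bias correction
m = torch.zeros(4)  # heavyball-style: the buffer holds the debiased average directly

for step, g in enumerate(grads, start=1):
    s = beta * s + (1 - beta) * g
    debiased = s / (1 - beta ** step)

    m.lerp_(g, (1 - beta) / (1 - beta ** step))  # same weight as the _foreach_lerp_ call above

    assert torch.allclose(m, debiased, atol=1e-6)

print("lerp weight (1 - beta) / (1 - beta**step) reproduces the bias-corrected EMA")
```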
heavyball/foreach_adamw.py CHANGED
@@ -10,42 +10,32 @@ class ForeachAdamW(StatefulOptimizer):
                          lr_max=-1.0, weight_decay=weight_decay)
          super().__init__(params, defaults)

-     def step(self, closure=None):
-         """Performs a single optimization step.
+     def _step(self, group):
+         eps = group['eps']
+         decay = group['weight_decay']
+         k = group['k']

-         Arguments:
-             closure (callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-         """
+         if not group['train_mode']:
+             raise Exception("Not in train mode!")

-         loss = None
-         if closure is not None:
-             loss = closure()
+         active_p = [p for p in group['params'] if p.grad is not None]

-         for group in self.param_groups:
-             eps = group['eps']
-             decay = group['weight_decay']
-             k = group['k']
+         if not active_p:
+             return

-             if not group['train_mode']:
-                 raise Exception("Not in train mode!")
+         for p in active_p:
+             if 'exp_avg' not in self.state_(p):
+                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)

-             active_p = [p for p in group['params'] if p.grad is not None]
+         y, grad, exp_avg_sq, exp_avg = zip(
+             *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])

-             for p in active_p:
-                 if 'exp_avg' not in self.state_(p):
-                     self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                     self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+         # Decay the first and second moment running average coefficient
+         torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
+         denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)

-             y, grad, exp_avg_sq, exp_avg = zip(
-                 *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-             # Decay the first and second moment running average coefficient
-             torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-             denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-
-             # Normalize grad in-place for memory efficiency
-             lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-             update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
-             group['k'] = k + 1
-         return loss
+         # Normalize grad in-place for memory efficiency
+         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+         update_param_(y, exp_avg, lr, decay, lambda p, e, l: torch._foreach_addcdiv_(p, e, denom, l))
+         group['k'] = k + 1
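This hunk, like the PSGD ones above and the ADOPT/LaProp/SFAdamW hunks below, replaces the public `step(closure)` method with a per-group `_step(group)` hook. The closure handling and the loop over `param_groups` presumably move into the shared `StatefulOptimizer` base class, which is not part of this diff; the sketch below only illustrates that dispatch pattern and is not heavyball's actual implementation:

```python
import torch


class StatefulOptimizerSketch(torch.optim.Optimizer):
    """Illustrative base class: step() evaluates the closure once, then defers to _step() per group."""

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            self._step(group)  # subclasses implement the per-group update, as in this diff
        return loss

    def _step(self, group):
        raise NotImplementedError
```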
heavyball/foreach_adopt.py CHANGED
@@ -11,51 +11,41 @@ class ForeachADOPT(StatefulOptimizer):
                          lr_max=-1.0, weight_decay=weight_decay)
          super().__init__(params, defaults)

-     def step(self, closure=None):
-         """Performs a single optimization step.
-
-         Arguments:
-             closure (callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-         """
-
-         loss = None
-         if closure is not None:
-             loss = closure()
-
-         for group in self.param_groups:
-             eps = group['eps']
-             decay = group['weight_decay']
-             k = group['k']
-
-             if not group['train_mode']:
-                 raise Exception("Not in train mode!")
-
-             active_p = [p for p in group['params'] if p.grad is not None]
-
-             for p in active_p:
-                 if 'exp_avg' not in self.state_(p):
-                     self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                     self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
-
-             y, grad, exp_avg_sq, exp_avg = zip(
-                 *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-             if k > 1:
-                 lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
-
-                 update_param_(y, exp_avg, lr, decay)
-             if k > 0:
-                 beta1 = beta_debias(group['betas'][0], k)
-                 denom = torch._foreach_sqrt(exp_avg_sq)
-                 torch._foreach_maximum_(denom, eps)
-                 torch._foreach_mul_(exp_avg, beta1)
-                 torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-             beta2 = beta_debias(group['betas'][1], k + 1)
-             torch._foreach_mul_(exp_avg_sq, beta2)
-             torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-             del grad
-
-             group['k'] = k + 1
-         return loss
+     def _step(self, group):
+         eps = group['eps']
+         decay = group['weight_decay']
+         k = group['k']
+
+         if not group['train_mode']:
+             raise Exception("Not in train mode!")
+
+         active_p = [p for p in group['params'] if p.grad is not None]
+
+         if not active_p:
+             return
+
+         for p in active_p:
+             if 'exp_avg' not in self.state_(p):
+                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+
+         y, grad, exp_avg_sq, exp_avg = zip(
+             *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+
+         if k > 1:
+             lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
+
+             update_param_(y, exp_avg, lr, decay)
+         if k > 0:
+             beta1 = beta_debias(group['betas'][0], k)
+             denom = torch._foreach_sqrt(exp_avg_sq)
+             torch._foreach_maximum_(denom, eps)
+             torch._foreach_mul_(exp_avg, beta1)
+             torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+
+         beta2 = beta_debias(group['betas'][1], k + 1)
+         torch._foreach_mul_(exp_avg_sq, beta2)
+         torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
+         del grad
+
+         group['k'] = k + 1
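The ordering in `_step` is what makes this ADOPT rather than Adam: the parameter is updated from momentum accumulated in earlier steps (`k > 1`), the momentum is then refreshed using a denominator that still excludes the current gradient (`k > 0`), and only afterwards is `exp_avg_sq` updated with the current gradient. A single-tensor paraphrase (illustrative only; the debiasing, warmup, and weight decay of the actual code are omitted):

```python
import torch


def adopt_style_update(p, g, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, k):
    """One-parameter sketch of the update order above (not the heavyball API)."""
    if k > 1:
        p.add_(exp_avg, alpha=-lr)                       # step with momentum from previous iterations
    if k > 0:
        denom = exp_avg_sq.sqrt().clamp_(min=eps)        # second moment does not yet see the current grad
        exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second moment updated last
```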
heavyball/foreach_laprop.py CHANGED
@@ -11,46 +11,36 @@ class ForeachLaProp(StatefulOptimizer):
                          lr_max=-1.0, weight_decay=weight_decay)
          super().__init__(params, defaults)

-     def step(self, closure=None):
-         """Performs a single optimization step.
+     def _step(self, group):
+         eps = group['eps']
+         decay = group['weight_decay']
+         k = group['k']

-         Arguments:
-             closure (callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-         """
+         if not group['train_mode']:
+             raise Exception("Not in train mode!")

-         loss = None
-         if closure is not None:
-             loss = closure()
+         active_p = [p for p in group['params'] if p.grad is not None]

-         for group in self.param_groups:
-             eps = group['eps']
-             decay = group['weight_decay']
-             k = group['k']
+         if not active_p:
+             return

-             if not group['train_mode']:
-                 raise Exception("Not in train mode!")
+         for p in active_p:
+             if 'exp_avg' not in self.state_(p):
+                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)

-             active_p = [p for p in group['params'] if p.grad is not None]
+         y, grad, exp_avg_sq, exp_avg = zip(
+             *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])

-             for p in active_p:
-                 if 'exp_avg' not in self.state_(p):
-                     self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
-                     self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+         # Decay the first and second moment running average coefficient
+         denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+         beta1 = beta_debias(group['betas'][0], k + 1)
+         torch._foreach_mul_(exp_avg, beta1)
+         torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
+         del grad

-             y, grad, exp_avg_sq, exp_avg = zip(
-                 *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
+         # Normalize grad in-place for memory efficiency
+         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
+         update_param_(y, exp_avg, lr, decay)

-             # Decay the first and second moment running average coefficient
-             denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
-             beta1 = beta_debias(group['betas'][0], k + 1)
-             torch._foreach_mul_(exp_avg, beta1)
-             torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-             del grad
-
-             # Normalize grad in-place for memory efficiency
-             lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-             update_param_(y, exp_avg, lr, decay)
-
-             group['k'] = k + 1
-         return loss
+         group['k'] = k + 1
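LaProp's distinguishing step is visible in the new `_step`: the gradient is divided by the second-moment RMS *before* it enters the first-moment EMA, and the parameter update then applies `exp_avg` directly, with no further division. A single-tensor paraphrase (illustrative; debiasing, warmup, and the exact eps handling of `exp_avg_sq_` are omitted):

```python
import torch


def laprop_style_update(p, g, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    """One-parameter sketch of the LaProp update above (not the heavyball API)."""
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)  # momentum of the normalized gradient
    p.add_(exp_avg, alpha=-lr)                               # no division at update time
```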
heavyball/foreach_sfadamw.py CHANGED
@@ -13,52 +13,42 @@ class ForeachSFAdamW(ScheduleFree):
                          foreach=foreach)
          super().__init__(params, defaults)

-     def step(self, closure=None):
-         """Performs a single optimization step.
+     def _step(self, group):
+         eps = group['eps']
+         decay = group['weight_decay']
+         k = group['k']

-         Arguments:
-             closure (callable, optional): A closure that reevaluates the model
-                 and returns the loss.
-         """
+         if not group['train_mode']:
+             raise Exception("Not in train mode!")

-         loss = None
-         if closure is not None:
-             loss = closure()
+         active_p = [p for p in group['params'] if p.grad is not None]

-         for group in self.param_groups:
-             eps = group['eps']
-             decay = group['weight_decay']
-             k = group['k']
+         if not active_p:
+             return

-             if not group['train_mode']:
-                 raise Exception("Not in train mode!")
+         for p in active_p:
+             if 'z' not in self.state_(p):
+                 self.state_(p)['z'] = torch.clone(p.data)
+                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)

-             active_p = [p for p in group['params'] if p.grad is not None]
+         y, grad, exp_avg_sq, z = zip(
+             *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])

-             for p in active_p:
-                 if 'z' not in self.state_(p):
-                     self.state_(p)['z'] = torch.clone(p.data)
-                     self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+         # Decay the first moment running average coefficient
+         old_debiased = beta_debias(group['betas'][1], k + 1)

-             y, grad, exp_avg_sq, z = zip(
-                 *[(p.data, p.grad.float(), self.state_(p)['exp_avg_sq'], self.state_(p)['z']) for p in active_p])
+         # Decay the first and second moment running average coefficient
+         denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)

-             # Decay the first moment running average coefficient
-             old_debiased = beta_debias(group['betas'][1], k + 1)
+         # Normalize grad in-place for memory efficiency
+         torch._foreach_div_(grad, denom)

-             # Decay the first and second moment running average coefficient
-             denom = exp_avg_sq_(exp_avg_sq, grad, old_debiased, eps)
+         # Weight decay calculated at y
+         if decay != 0:
+             torch._foreach_add_(grad, y, alpha=decay)

-             # Normalize grad in-place for memory efficiency
-             torch._foreach_div_(grad, denom)
+         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
+         group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
+                                              y, z, grad, group['r'], k + 1)

-             # Weight decay calculated at y
-             if decay != 0:
-                 torch._foreach_add_(grad, y, alpha=decay)
-
-             lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-             group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-                                                  y, z, grad, group['r'], k + 1)
-
-             group['k'] = k + 1
-         return loss
+         group['k'] = k + 1
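In the schedule-free variant the Adam machinery only produces the normalized (and optionally weight-decayed) gradient; the parameter update itself is delegated to `schedule_free_`, which maintains the extra `z` sequence and the running `weight_sum`. That helper is not part of this diff. The sketch below loosely follows the published schedule-free update of Defazio et al. and is only an approximation of what heavyball's `schedule_free_` computes; in particular, the weighting by `r` and `weight_lr_power` is simplified:

```python
import torch


def schedule_free_sketch(y, z, grad, lr, beta, weight_sum, weight_lr_power, step, r=0.0):
    """Illustrative schedule-free step for one tensor; not heavyball's schedule_free_."""
    weight = (step ** r) * (lr ** weight_lr_power)   # this step's contribution to the running average
    weight_sum += weight
    ckp1 = weight / weight_sum if weight_sum else 0.0

    y.lerp_(z, ckp1)                                  # pull y toward the averaged iterate
    y.add_(grad, alpha=lr * (beta * (1 - ckp1) - 1))  # interpolated gradient step evaluated at y
    z.add_(grad, alpha=-lr)                           # z follows plain, unaveraged gradient descent
    return weight_sum
```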