adv-optm 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of adv-optm has been flagged as potentially problematic.

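For orientation, the file shown in the diff below defines the `Lion_adv` optimizer class. The following is a hypothetical usage sketch: only the class name and constructor arguments appear in the diff, and the top-level import path is an assumption, not something this page confirms.

import torch
from torch import nn

# Assumed import path; the diff only shows the class, not the package layout.
from adv_optm import Lion_adv

model = nn.Linear(128, 64)
opt = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,
    factored=True,             # SMMF low-rank compression of the momentum
    variance_reduction=False,  # the option whose update rule changed in 0.1.4
)

x, y = torch.randn(32, 128), torch.randn(32, 64)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()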
@@ -1,231 +1,243 @@
- import torch
-
- from typing import Tuple, Optional
-
- from ..util.BF16_Stochastic_Rounding import add_stochastic_
- from ..util.Effective_Shape import _get_effective_shape
- from ..util.NNMF import _nnmf,_unnmf
- from ..util.OrthoGrad import _orthogonalize_gradient
- from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
-
- class Lion_adv(torch.optim.Optimizer):
- """
- Implements the SMMF technique for Lion algorithm.
-
- This optimizer combines the Lion update rule with the memory-saving low-rank
- compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups.
- lr (float, optional): learning rate (default: 1e-4).
- betas (Tuple[float, float], optional): coefficients for computing
- running averages of the update (default: (0.9, 0.99)).
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
- vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
- matrices to apply low-rank compression (default: True).
- stochastic_rounding (bool, optional): whether to use stochastic
- rounding for BF16 parameter updates (default: True).
- use_cautious (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
- factored (bool): whether to use the factorization or use the
- uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
- """
-
- def __init__(
- self,
- params,
- lr: float = 1e-4,
- betas: Tuple[float, float] = (0.9, 0.99),
- weight_decay: float = 0.0,
- vector_reshape: bool = True,
- stochastic_rounding: bool = True,
- use_orthograd: bool = False,
- use_cautious: bool = False,
- clip_threshold: float = 0.0,
- factored: bool = True,
- variance_reduction: bool = False,
- ):
- if not lr > 0.0:
- raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
- if not all(0.0 <= beta <= 1.0 for beta in betas):
- raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
- if not weight_decay >= 0.0:
- raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
- if variance_reduction and use_cautious:
- print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
-
- defaults = dict(
- lr=lr,
- betas=betas,
- weight_decay=weight_decay,
- vector_reshape=vector_reshape,
- use_orthograd=use_orthograd,
- clip_threshold=clip_threshold,
- )
- self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.factored = factored
- self.variance_reduction = variance_reduction
- super().__init__(params, defaults)
-
- @property
- def supports_fused_back_pass(self) -> bool:
- return True
-
- @property
- def supports_memory_efficient_fp16(self) -> bool:
- return True
-
- @property
- def supports_flat_params(self) -> bool:
- return False
-
- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
-
- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["use_orthograd"]:
- grad = _orthogonalize_gradient(p, grad)
- state = self.state[p]
-
- # State Initialization
- if len(state) == 0:
- state['step'] = 0
-
- should_factor = (
- self.factored and
- not (len(p.shape) == 1 and not group['vector_reshape'])
- )
-
- state['factored'] = should_factor
-
- dtype = torch.float32 if self.factored else p.dtype
-
- if state['factored']:
- state['effective_shape'] = _get_effective_shape(p.numel())
- d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- packed_d2 = (d2 + 7) // 8
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
- else: # Fallback to standard Lion
- state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-
- state['step'] += 1
- beta1, beta2 = group["betas"]
- lr = group["lr"]
-
- if state['factored']:
- # Factored Path
- d1, d2 = state['effective_shape']
- grad_reshaped = grad.view(d1, d2)
- # Reconstruct momentum m_{t-1}
- exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
- del unpacked_sign
- if exp_avg.dtype != torch.float32:
- exp_avg = exp_avg.float()
-
- # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
- signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
-
- if self.use_cautious:
- mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
- del mask
-
- # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
- update_for_param = signed_update.view(p.shape).mul_(lr)
-
- # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- if self.variance_reduction:
- vr_term = grad_reshaped - state['prev_grad']
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2).add_(vr_term, alpha=beta2)
- del vr_term
- state['prev_grad'].copy_(grad_reshaped)
- else:
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
- del grad_reshaped
-
- # Compress new momentum m_t and store factors
- state['sign'] = _pack_bools(exp_avg > 0)
- _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
- del exp_avg
-
- else:
- # Fallback to standard Lion logic
- exp_avg = state["exp_avg"]
-
- # Compute update term and sign for the update
- if exp_avg.dtype != torch.float32 and self.factored:
- exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
-
- if self.use_cautious:
- mask = (signed_update * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
- del mask
-
- update_for_param = signed_update.mul_(lr)
-
- # Update momentum
- if self.variance_reduction:
- vr_term = grad - state['prev_grad']
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2).add_(vr_term, alpha=beta2)
- state['prev_grad'].copy_(grad)
- else:
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
-
- if group["weight_decay"] != 0:
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data,
- alpha=-group["weight_decay"] * lr)
- else:
- p.data.add_(
- p.data, alpha=-group["weight_decay"] * lr
- )
-
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update_for_param)
- else:
- p.data.add_(-update_for_param)
-
- del update_for_param
-
- @torch.no_grad()
- def step(self, closure: Optional[callable] = None):
- """Performs a single optimization step."""
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for group in self.param_groups:
- for i, p in enumerate(group["params"]):
- if p.grad is not None:
- self.step_parameter(p, group, i)
-
+ import torch
+
+ from typing import Tuple, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Lion_adv(torch.optim.Optimizer):
+ """
+ Implements the SMMF technique for Lion algorithm.
+
+ This optimizer combines the Lion update rule with the memory-saving low-rank
+ compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate (default: 1e-4).
+ betas (Tuple[float, float], optional): coefficients for computing
+ running averages of the update (default: (0.9, 0.99)).
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ stochastic_rounding (bool, optional): whether to use stochastic
+ rounding for BF16 parameter updates (default: True).
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
+ clip_threshold (float, optional): whether to clip the gradients norm
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+ (default: 0.0).
+ factored (bool): whether to use the factorization or use the
+ uncompressed optimizer. (default: True)
+ variance_reduction (bool): whether to use the variance reduction technique
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-4,
+ betas: Tuple[float, float] = (0.9, 0.99),
+ weight_decay: float = 0.0,
+ vector_reshape: bool = True,
+ stochastic_rounding: bool = True,
+ use_orthograd: bool = False,
+ use_cautious: bool = False,
+ clip_threshold: float = 0.0,
+ factored: bool = True,
+ variance_reduction: bool = False,
+ ):
+ if not lr > 0.0:
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+ if not weight_decay >= 0.0:
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ weight_decay=weight_decay,
+ vector_reshape=vector_reshape,
+ use_orthograd=use_orthograd,
+ clip_threshold=clip_threshold,
+ )
+ self.stochastic_rounding = stochastic_rounding
+ self.use_cautious = use_cautious
+ self.factored = factored
+ self.variance_reduction = variance_reduction
+ super().__init__(params, defaults)
+
+ @property
+ def supports_fused_back_pass(self) -> bool:
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self) -> bool:
+ return True
+
+ @property
+ def supports_flat_params(self) -> bool:
+ return False
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["clip_threshold"] > 0.0:
+ grad_norm = torch.norm(grad.detach())
+ if grad_norm > group["clip_threshold"]:
+ clip_coef = group["clip_threshold"] / grad_norm
+ grad.mul_(clip_coef)
+ if group["use_orthograd"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ should_factor = (
+ self.factored and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ dtype = torch.float32 if self.factored else p.dtype
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+ else: # Fallback to standard Lion
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+ state['step'] += 1
+ beta1, beta2 = group["betas"]
+ lr = group["lr"]
+
+ if state['factored']:
+ # Factored Path
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ # Reconstruct momentum m_{t-1}
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+ del unpacked_sign
+ if exp_avg.dtype != torch.float32:
+ exp_avg = exp_avg.float()
+
+ # Compute update term c_t
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ # Parameter update
+ update_for_param = signed_update.view(p.shape).mul_(lr)
+
+ # Update momentum
+ if self.variance_reduction:
+ if state['step'] == 1:
+ exp_avg.copy_(grad_reshaped)
+ else:
+ # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+ correction = exp_avg.sub(state['prev_grad'])
+ # Calculate the new momentum and store it back into exp_avg
+ exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
+ del correction
+ # Update prev_grad for the next iteration
+ state['prev_grad'].copy_(grad_reshaped)
+ else:
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+
+ # Compress new momentum m_t and store factors
+ state['sign'] = _pack_bools(exp_avg > 0)
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del exp_avg
+
+ else:
+ # Fallback to standard Lion logic
+ exp_avg = state["exp_avg"]
+
+ # Compute update term and sign for the update
+ if exp_avg.dtype != torch.float32 and self.factored:
+ exp_avg = exp_avg.float()
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ update_for_param = signed_update.mul_(lr)
+
+ # Update momentum
+ if self.variance_reduction:
+ if state['step'] == 1:
+ exp_avg.copy_(grad)
+ else:
+ # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+ correction = exp_avg.sub(state['prev_grad'])
+ # Calculate the new momentum and store it back into exp_avg
+ exp_avg.copy_(grad).add_(correction, alpha=beta2)
+ del correction
+ # Update prev_grad for the next iteration
+ state['prev_grad'].copy_(grad)
+ else:
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data,
+ alpha=-group["weight_decay"] * lr)
+ else:
+ p.data.add_(
+ p.data, alpha=-group["weight_decay"] * lr
+ )
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update_for_param)
+ else:
+ p.data.add_(-update_for_param)
+
+ del update_for_param
+
+ @torch.no_grad()
+ def step(self, closure: Optional[callable] = None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group["params"]):
+ if p.grad is not None:
+ self.step_parameter(p, group, i)
+
  return loss
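The substantive change in Lion_adv between 0.1.2 and 0.1.4 is the `variance_reduction` momentum rule (the input-validation warning for combining it with `use_cautious` was also dropped). The sketch below restates the two rules on plain tensors so they can be compared side by side; it is not code from the package, and the function names are illustrative only.

import torch

def vr_momentum_0_1_2(m_prev, g, g_prev, beta2):
    # 0.1.2 rule: m_t = beta2 * m_{t-1} + (1 - beta2) * g_t + beta2 * (g_t - g_{t-1})
    return beta2 * m_prev + (1 - beta2) * g + beta2 * (g - g_prev)

def vr_momentum_0_1_4(m_prev, g, g_prev, beta2, step):
    # 0.1.4 rule (simplified STORM): m_t = g_t + beta2 * (m_{t-1} - g_{t-1}),
    # with m_1 = g_1 on the very first step
    if step == 1:
        return g.clone()
    return g + beta2 * (m_prev - g_prev)

torch.manual_seed(0)
g_prev, g = torch.randn(4), torch.randn(4)
m_prev = g_prev.clone()
print(vr_momentum_0_1_2(m_prev, g, g_prev, beta2=0.99))
print(vr_momentum_0_1_4(m_prev, g, g_prev, beta2=0.99, step=2))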
@@ -194,11 +194,12 @@ class Prodigy_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # First moment (m)
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
- if not self.use_grams:
- packed_d2 = (d2 + 7) // 8
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ if self.beta1 > 0:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not self.use_grams:
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
  state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -208,7 +209,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
- state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+ if self.beta1 > 0:
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
@@ -231,22 +233,24 @@ class Prodigy_adv(torch.optim.Optimizer):
  if state['factored']:
  d1, d2 = state['effective_shape']

- # Reconstruct momentum from previous step's factors
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- if not self.use_grams:
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, mt, -mt, out=mt)
- del unpacked_sign
- # Update momentum in full-size
  grad_reshaped = grad.view(d1, d2)
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
- if self.use_grams:
- mt.copy_(grad_reshaped.sign() * mt.abs())
- elif self.use_cautious:
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
- del mask
+
+ # Reconstruct momentum from previous step's factors
+ if self.beta1 > 0:
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ if not self.use_grams:
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt, -mt, out=mt)
+ del unpacked_sign
+ # Update momentum in full-size
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+ if self.use_grams:
+ mt.copy_(grad_reshaped.sign() * mt.abs())
+ elif self.use_cautious:
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ mt.mul_(mask)
+ del mask

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
@@ -258,30 +262,29 @@ class Prodigy_adv(torch.optim.Optimizer):
  unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
  del unpacked_sign_slow
-
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
- update_m = mt + (alpha_t * mt_slow)
+ update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
  else:
- update_m = mt
+ update = mt if self.beta1 > 0 else grad_reshaped
  del grad_reshaped

  if group['use_atan2']:
  a = 1.2732395
  denom = vt.sqrt()
- update = torch.atan2(update_m, denom).mul_(a)
+ update.atan2_(denom).mul_(a)
  else:
- denom = vt.sqrt().add_(self.d * group['eps'])
- update = update_m / denom
- del update_m, denom
+ denom = vt.sqrt()
+ update.div_(denom.add_(self.d * group['eps']))
+ del denom

- update = update.view(p.shape)
- update.mul_(self.dlr)
+ update.view(p.shape).mul_(self.dlr)

  # Compress updated moments and store new factors
- if not self.use_grams:
- state['sign'] = _pack_bools(mt > 0)
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
- del mt
+ if self.beta1 > 0:
+ if not self.use_grams:
+ state['sign'] = _pack_bools(mt > 0)
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt
  if self.use_AdEMAMix:
  state['sign_slow'] = _pack_bools(mt_slow > 0)
  _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -290,36 +293,38 @@ class Prodigy_adv(torch.optim.Optimizer):
  del vt

  else: # Standard AdamW logic for non-factored tensors
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
- if self.use_grams:
- exp_avg = grad.sign() * exp_avg.abs()
- elif self.use_cautious:
- mask = (exp_avg * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.mul_(mask)
- del mask
+ exp_avg_sq = state['exp_avg_sq']
+
+ if self.beta1 > 0:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+ if self.use_grams:
+ exp_avg = grad.sign() * exp_avg.abs()
+ elif self.use_cautious:
+ mask = (exp_avg * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ exp_avg.mul_(mask)
+ del mask

  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
- update_m = exp_avg + (alpha_t * exp_avg_slow)
+ update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
  else:
- update_m = exp_avg
+ update = exp_avg if self.beta1 > 0 else grad

  exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))

  if group['use_atan2']:
  a = 1.2732395
  denom = exp_avg_sq.sqrt()
- update = torch.atan2(update_m, denom).mul_(a)
+ update.atan2_(denom).mul_(a)
  else:
- denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
- update = update_m / denom
- del update_m, denom
+ denom = exp_avg_sq.sqrt()
+ update.div_(denom.add_(self.d * group['eps']))
+ del denom

- update = update.mul_(self.dlr)
+ update.mul_(self.dlr)

  # --- Accumulate Prodigy stats ---
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
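The Prodigy_adv hunks all serve one change: first-moment state is only allocated, reconstructed, and re-compressed when `self.beta1 > 0`, and with `beta1 == 0` the (d-scaled) gradient feeds the denominator step directly; the atan2/eps normalization is also now applied in place on `update`. The following is a rough sketch of the resulting non-factored path against plain tensors, not the package's real method; the function signature and bare `state` dict are illustrative assumptions.

import torch

def adam_like_update(grad, state, beta1, beta2, d, eps, use_atan2=False):
    # Second-moment state is always kept; the first moment only when beta1 > 0.
    exp_avg_sq = state.setdefault('exp_avg_sq', torch.zeros_like(grad))
    if beta1 > 0:
        exp_avg = state.setdefault('exp_avg', torch.zeros_like(grad))
        exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
        update = exp_avg.clone()
    else:
        # beta1 == 0: no first moment is stored; use the raw gradient.
        update = grad.clone()
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))
    if use_atan2:
        update = torch.atan2(update, exp_avg_sq.sqrt()).mul_(1.2732395)
    else:
        update = update / (exp_avg_sq.sqrt() + d * eps)
    return update

state = {}
g = torch.randn(8)
print(adam_like_update(g, state, beta1=0.0, beta2=0.999, d=1e-6, eps=1e-8))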