adv-optm 1.1.0.dev3__py3-none-any.whl → 1.1.0.dev5__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

This version of adv-optm has been flagged as potentially problematic.

@@ -1,315 +1,315 @@
 import torch
 import torch.distributed as dist
 
 import math
 
 from typing import Tuple, Optional
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
 
 class Lion_Prodigy_adv(torch.optim.Optimizer):
     """
     Implements the SMMF technique and Prodigy D-Adaptation method for Lion algorithm.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups.
         lr (float, optional): learning rate (default: 1e-4).
         betas (Tuple[float, float], optional): coefficients for computing
             running averages of the update (default: (0.9, 0.99)).
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
         vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
             matrices to apply low-rank compression (default: True).
         stochastic_rounding (bool, optional): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         cautious_mask (bool): whether to use the cautious masking technique. (default: False).
         clip_threshold (float, optional): whether to clip the gradients norm
             per-parameter as proposed in the paper `Lions and Muons: Optimization via
             Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
             (default: 0.0).
         nnmf_factor (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
         d0 (float):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
             Coefficient in the expression for the estimate of d (default 1.0).
             Values such as 0.5 and 2.0 typically work as well.
             Changing this parameter is the preferred way to tune the method.
         growth_rate (float):
             prevent the D estimate from growing faster than this multiplicative rate.
             Default is inf, for unrestricted. Values like 1.02 give a kind of learning
             rate warmup effect.
         fsdp_in_use (bool):
             If you're using sharded parameters, this should be set to True. The optimizer
             will attempt to auto-detect this, but if you're using an implementation other
             than PyTorch's builtin version, the auto-detection won't work.
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
             pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
     """
 
     def __init__(
         self,
         params,
         lr: float = 1,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
         clip_threshold: float = 0.0,
         nnmf_factor: bool = True,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
         d_coef: float = 1,
         growth_rate: float = float('inf'),
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
         if not all(0.0 <= beta <= 1.0 for beta in betas):
             raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
         if not weight_decay >= 0.0:
             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
 
         defaults = dict(
             lr=lr,
             betas=betas,
             weight_decay=weight_decay,
             vector_reshape=vector_reshape,
             orthogonal_gradient=orthogonal_gradient,
             clip_threshold=clip_threshold,
             beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
             growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
             fsdp_in_use=fsdp_in_use,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
         # Global state for accumulating metrics across parameter updates within a single step.
         self.init_step()
 
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
 
     @property
     def supports_memory_efficient_fp16(self) -> bool:
         return True
 
     @property
     def supports_flat_params(self) -> bool:
         return False
 
     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
         self.d_denom = 0.0
 
         g_group = self.param_groups[0]
         self.beta1, self.beta2 = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
             self.beta3 = math.sqrt(self.beta2)
 
         k = g_group['k']
         self.d = g_group['d']
         lr = g_group['lr']
 
         self.dlr = self.d * lr
 
         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
 
     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
         """Performs a single optimization step on a single parameter."""
         if p.grad is None:
             return
 
         if hasattr(p, "_fsdp_flattened"):
             self.fsdp_in_use = True
 
         grad = p.grad
         if grad.dtype != torch.float32 and self.factored:
             grad = grad.float()
         if group["clip_threshold"] > 0.0:
             grad_norm = torch.norm(grad.detach())
             if grad_norm > group["clip_threshold"]:
                 clip_coef = group["clip_threshold"] / grad_norm
                 grad.mul_(clip_coef)
         if group["orthogonal_gradient"]:
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
         # State Initialization
-        if len(state) == 0:
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )
 
             state['factored'] = should_factor
 
             dtype = torch.float32 if self.factored else p.dtype
 
             slice_p = group['slice_p']
 
             # D-Adaptation states
             state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
             if p.any():
                 state['p0'] = p.flatten()[::slice_p].detach().clone()
             else:
                 state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
 
             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
                 state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             else: # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
 
         if state['factored']:
             # Factored Path
             d1, d2 = state['effective_shape']
             grad_reshaped = grad.view(d1, d2)
             # Reconstruct momentum m_{t-1}
             exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
             unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
             torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
             del unpacked_sign
             if exp_avg.dtype != torch.float32:
                 exp_avg = exp_avg.float()
 
             # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
             signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1-self.beta1)).sign_()
 
             if self.cautious_mask:
                 mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
                 signed_update.mul_(mask)
                 del mask
 
             # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
             update_for_param = signed_update.view(p.shape).mul(self.dlr)
 
             # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
             exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
             del grad_reshaped
 
             # Compress new momentum m_t and store factors
             state['sign'] = _pack_bools(exp_avg > 0)
             _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
             del exp_avg
 
         else:
             # Fallback to standard Lion logic
             exp_avg = state["exp_avg"]
 
             # Compute update term and sign for the update
             if exp_avg.dtype != torch.float32 and self.factored:
                 exp_avg = exp_avg.float()
             signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=self.d * (1-self.beta1)).sign_()
 
             if self.cautious_mask:
                 mask = (signed_update * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
                 signed_update.mul_(mask)
                 del mask
 
             update_for_param = signed_update.mul(self.dlr)
 
             # Update momentum
             exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
 
         # --- Accumulate Prodigy stats ---
         d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
         s, p0 = state['s'], state['p0']
         grad_flat = grad.flatten().float()
         p_flat = p.data.flatten().float()
         p0 = p0.float()
 
         self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
         alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
         s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
         self.d_denom += s.abs().sum().item()
 
         del s, p0, grad_flat, p_flat, alpha
 
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                 add_stochastic_(p.data, p.data,
                                 alpha=-group["weight_decay"] * self.dlr)
             else:
                 p.data.add_(
                     p.data, alpha=-group["weight_decay"] * self.dlr
                 )
 
         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update_for_param)
         else:
             p.data.add_(-update_for_param)
 
         del update_for_param
 
     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
         """Performs a single optimization step."""
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
         for group in self.param_groups:
             for i, p in enumerate(group["params"]):
                 if p.grad is not None:
                     self.step_parameter(p, group, i)
 
 
         self.calculate_d()
         self.init_step()
         return loss
 
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
         d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
 
         if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
             # Use the device of the first parameter to avoid hardcoding '.cuda()'
             device = self.param_groups[0]['params'][0].device
             dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
             dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
             global_d_numerator = dist_tensor[0].item()
             global_d_denom = dist_tensor[1].item()
         else:
             global_d_numerator = self.d_numerator
             global_d_denom = self.d_denom
 
         d_hat = self.d
         if global_d_denom > 0:
             d_hat = d_coef * global_d_numerator / global_d_denom
             if self.d == g_group['d0']:
                 self.d = max(self.d, d_hat)
             d_max = max(d_max, d_hat)
             self.d = min(d_max, self.d * growth_rate)
 
         for group in self.param_groups:
             group['d_numerator'] = global_d_numerator
             group['d'] = self.d
             group['d_max'] = d_max
             group['k'] += 1
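The only functional change in this hunk is the state-initialization guard in step_parameter, which moves from len(state) == 0 to 'step' not in state; the same edit appears again in Lion_adv in the hunk below. The snippet that follows is illustrative only (it is not taken from the package) and shows how the two guards diverge once something else has already written an entry into a parameter's state dict before the first step:

def init_old(state: dict) -> None:
    # pre-dev5 guard: initialize only when the state dict is completely empty
    if len(state) == 0:
        state["step"] = 0  # stands in for allocating 'exp_avg', 's', 'p0', ...

def init_new(state: dict) -> None:
    # dev5 guard: initialize whenever 'step' itself is missing
    if "step" not in state:
        state["step"] = 0

pre_seeded = {"injected_by_wrapper": True}  # hypothetical key written before step()

a = dict(pre_seeded)
init_old(a)
assert "step" not in a  # old guard skips initialization entirely

b = dict(pre_seeded)
init_new(b)
assert "step" in b      # new guard still initializes

With the old guard, any pre-existing key suppresses allocation of the momentum and Prodigy buffers for that parameter; the new guard keys initialization on the presence of 'step' alone.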
@@ -99,7 +99,7 @@ class Lion_adv(torch.optim.Optimizer):
         state = self.state[p]
 
         # State Initialization
-        if len(state) == 0:
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (
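For orientation, here is a hedged usage sketch of the Lion_Prodigy_adv class shown in the first hunk. The top-level adv_optm import path and the toy training step are assumptions made for illustration; only the constructor arguments come from the diffed source. Note that the constructor default is lr=1 (the usual Prodigy convention, with the adapted d supplying the scale), while the docstring still lists 1e-4.

import torch

# Hypothetical import path; the wheel's actual module layout is not shown in this diff.
from adv_optm import Lion_Prodigy_adv

model = torch.nn.Linear(128, 64)
opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,               # Prodigy-style: the adapted d scales the effective step size
    betas=(0.9, 0.99),
    weight_decay=0.0,
    nnmf_factor=True,     # factored (SMMF-style) momentum storage
    slice_p=11,           # sample every 11th entry for the d-adaptation statistics
)

loss = model(torch.randn(8, 128)).pow(2).mean()
loss.backward()
opt.step()                # per-parameter step_parameter(), then calculate_d() and init_step()
opt.zero_grad()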