adv-optm 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (23)
  1. {adv_optm-0.1.3 → adv_optm-0.1.5}/PKG-INFO +1 -1
  2. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/__init__.py +1 -1
  3. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/AdamW_adv.py +308 -296
  4. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/Lion_Prodigy_adv.py +22 -6
  5. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/Lion_adv.py +242 -228
  6. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/Prodigy_adv.py +2 -2
  7. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-0.1.3 → adv_optm-0.1.5}/setup.py +1 -1
  9. {adv_optm-0.1.3 → adv_optm-0.1.5}/LICENSE +0 -0
  10. {adv_optm-0.1.3 → adv_optm-0.1.5}/README.md +0 -0
  11. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/optim/__init__.py +0 -0
  13. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  14. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/Effective_Shape.py +0 -0
  15. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/NNMF.py +0 -0
  16. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/One_Bit_Boolean.py +0 -0
  17. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/OrthoGrad.py +0 -0
  18. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm/util/__init__.py +0 -0
  19. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm.egg-info/SOURCES.txt +0 -0
  20. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm.egg-info/dependency_links.txt +0 -0
  21. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm.egg-info/requires.txt +0 -0
  22. {adv_optm-0.1.3 → adv_optm-0.1.5}/adv_optm.egg-info/top_level.txt +0 -0
  23. {adv_optm-0.1.3 → adv_optm-0.1.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 0.1.3
+ Version: 0.1.5
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -14,4 +14,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "0.1.3"
+ __version__ = "0.1.5"
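
The version bump in adv_optm/__init__.py is visible at runtime. A minimal check, assuming the distribution installs under the name shown in the page header:

    # pip install adv-optm==0.1.5
    import adv_optm
    print(adv_optm.__version__)  # expected: "0.1.5"
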
@@ -1,297 +1,309 @@
- import torch
- from typing import Optional
-
- from ..util.BF16_Stochastic_Rounding import add_stochastic_
- from ..util.Effective_Shape import _get_effective_shape
- from ..util.NNMF import _nnmf,_unnmf
- from ..util.OrthoGrad import _orthogonalize_gradient
- from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
-
- class AdamW_adv(torch.optim.Optimizer):
- """
- Implements a factored AdamW algorithm.
- This is an advanced version of AdamW with optional features like
- low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float): learning rate (default: 1e-3)
- betas (tuple[float, float]): coefficients used for computing running
- averages of gradient and its square (default: (0.9, 0.999))
- eps (float): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float): weight decay (L2 penalty) (default: 0)
- vector_reshape (bool): whether to reshape 1D vectors into 2D
- matrices to apply low-rank compression (default: True).
- stochastic_rounding (bool): whether to use stochastic
- rounding for BF16 parameter updates (default: True).
- use_atan2 (bool): whether to use the atan2 update rule. (default: False)
- use_grams (bool): whether to use Grams-style updates. (default: False)
- use_cautious (bool): whether to use cautious masking to align the gradient's
- direction with the first moment's. (default: False)
- use_orthograd (bool): whether to use OrthoGrad. (default: False)
- use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
- a second, slow-moving average of the momentum (`mt_slow`) which is
- combined with the primary momentum (`mt`) to stabilize updates,
- especially in noisy, small-batch settings. If `False`, the
- optimizer behaves as standard AdamW. (default: False)
- beta3_ema (float): The decay rate for the slow exponential moving average of
- the momentum (only used when `use_AdEMAMix` is `True`). A higher
- value (e.g., 0.9999) gives the EMA a longer memory, making it more
- stable but slower to adapt. A lower value (e.g., 0.999) is often
- better for shorter training runs. (default: 0.9999)
- alpha (float): The mixing coefficient that scales the slow momentum term
- before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
- A higher value increases the stabilizing influence of the slow
- momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled and th
- factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
- """
-
- def __init__(
- self,
- params,
- lr: float = 1e-3,
- betas: tuple[float, float] = (0.9, 0.999),
- eps: float = 1e-8,
- weight_decay: float = 0.0,
- vector_reshape: bool = True,
- stochastic_rounding: bool = True,
- use_atan2: bool = False,
- use_cautious: bool = False,
- use_grams: bool = False,
- use_orthograd: bool = False,
- use_AdEMAMix: bool = False,
- beta3_ema: float = 0.9999,
- alpha: float = 5.0,
- t_alpha: int | None = None,
- factored: bool = True,
- ):
- if not (lr >= 0.0):
- raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
- if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
- raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
- if not (eps >= 0.0):
- raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
- if not (weight_decay >= 0.0):
- raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
-
- defaults = {
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
- "vector_reshape": vector_reshape, "use_atan2": use_atan2,
- "use_orthograd": use_orthograd,
- "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
- }
- self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.use_grams = use_grams
- self.use_AdEMAMix = use_AdEMAMix
- self.factored = factored
- super().__init__(params, defaults)
-
- @property
- def supports_fused_back_pass(self):
- return True
-
- @property
- def supports_memory_efficient_fp16(self):
- return True
-
- @property
- def supports_flat_params(self):
- return False
-
- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
-
- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["use_orthograd"]:
- grad = _orthogonalize_gradient(p, grad)
- state = self.state[p]
-
- beta1, beta2 = group['betas']
-
- # State Initialization
- if len(state) == 0:
- state['step'] = 0
-
- should_factor = (
- self.factored and
- not (len(p.shape) == 1 and not group['vector_reshape'])
- )
-
- state['factored'] = should_factor
-
- dtype = torch.float32 if self.factored else p.dtype
- device = p.device
-
- if state['factored']:
- state['effective_shape'] = _get_effective_shape(p.numel())
- d1, d2 = state['effective_shape']
-
- # First moment (m)
- if beta1 > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
- if not self.use_grams:
- packed_d2 = (d2 + 7) // 8
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
- if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- packed_d2 = (d2 + 7) // 8
- state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- # Second moment (v)
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
- else: # Fallback to standard AdamW for non-factored tensors
- if beta1 > 0:
- state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
- if self.use_AdEMAMix:
- state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
- state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
-
- if self.use_AdEMAMix:
- beta3_ema = group['beta3_ema']
- alpha = group['alpha']
- t_alpha = group['t_alpha']
- current_step = state['step'] + 1
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
- alpha_t = min(current_step * alpha / t_alpha, alpha)
-
- if state['factored']:
- d1, d2 = state['effective_shape']
-
- # Reconstruct momentum from previous step's factors
- if beta1 > 0:
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- if not self.use_grams:
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, mt, -mt, out=mt)
- del unpacked_sign
- # Update momentum in full-size
- grad_reshaped = grad.view(d1, d2)
- mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
- if self.use_grams:
- mt.copy_(grad_reshaped.sign() * mt.abs())
- elif self.use_cautious:
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
- del mask
-
- vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
-
- if self.use_AdEMAMix:
- mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
- if state['sign_slow'].dtype != torch.uint8:
- state['sign_slow'] = state['sign_slow'].to(torch.uint8)
- unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
- torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
- del unpacked_sign_slow
-
- mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
- else:
- update = mt if beta1 > 0 else grad_reshaped
- del grad_reshaped
-
- if group['use_atan2']:
- a = 1.2732395
- denom = vt.sqrt()
- update.atan2_(denom).mul_(a)
- else:
- denom = vt.sqrt()
- update.div_(denom.add_(group['eps']))
- del denom
-
- update.view(p.shape).mul_(group['lr'])
-
- # Compress updated moments and store new factors
- if beta1 > 0:
- if not self.use_grams:
- state['sign'] = _pack_bools(mt > 0)
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
- del mt
- if self.use_AdEMAMix:
- state['sign_slow'] = _pack_bools(mt_slow > 0)
- _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
- del mt_slow
- _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
- del vt
-
- else: # Standard AdamW logic for non-factored tensors
- exp_avg_sq = state['exp_avg_sq']
-
- if beta1 > 0:
- exp_avg = state['exp_avg']
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- if self.use_grams:
- exp_avg = grad.sign() * exp_avg.abs()
- elif self.use_cautious:
- mask = (exp_avg * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.mul_(mask)
- del mask
-
- if self.use_AdEMAMix:
- exp_avg_slow = state['exp_avg_slow']
- exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
- update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
- else:
- update = exp_avg if beta1 > 0 else grad
-
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
-
- if group['use_atan2']:
- a = 1.2732395
- denom = exp_avg_sq.sqrt()
- update.atan2_(denom).mul_(a)
- else:
- denom = exp_avg_sq.sqrt()
- update.div_(denom.add_(group['eps']))
- del denom
-
- update.mul_(group['lr'])
-
- # Decoupled weight decay
- if group["weight_decay"] != 0:
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
- else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
-
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update)
- else:
- p.data.add_(-update)
- del update
-
- state['step'] += 1
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step."""
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for group in self.param_groups:
- for i, p in enumerate(group['params']):
- self.step_parameter(p, group, i)
-
+ import torch
+ from typing import Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class AdamW_adv(torch.optim.Optimizer):
+ """
+ Implements a factored AdamW algorithm.
+ This is an advanced version of AdamW with optional features like
+ low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float): learning rate (default: 1e-3)
+ betas (tuple[float, float]): coefficients used for computing running
+ averages of gradient and its square (default: (0.9, 0.999))
+ eps (float): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float): weight decay (L2 penalty) (default: 0).
+ use_bias_correction (bool): whether to use bias correction for the first
+ and second moment estimates, as in the original Adam paper.
+ (default: True)
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ stochastic_rounding (bool): whether to use stochastic
+ rounding for BF16 parameter updates (default: True).
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+ use_grams (bool): whether to use Grams-style updates. (default: False)
+ use_cautious (bool): whether to use cautious masking to align the gradient's
+ direction with the first moment's. (default: False)
+ use_orthograd (bool): whether to use OrthoGrad. (default: False)
+ use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+ a second, slow-moving average of the momentum (`mt_slow`) which is
+ combined with the primary momentum (`mt`) to stabilize updates,
+ especially in noisy, small-batch settings. If `False`, the
+ optimizer behaves as standard AdamW. (default: False)
+ beta3_ema (float): The decay rate for the slow exponential moving average of
+ the momentum (only used when `use_AdEMAMix` is `True`). A higher
+ value (e.g., 0.9999) gives the EMA a longer memory, making it more
+ stable but slower to adapt. A lower value (e.g., 0.999) is often
+ better for shorter training runs. (default: 0.9999)
+ alpha (float): The mixing coefficient that scales the slow momentum term
+ before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+ A higher value increases the stabilizing influence of the slow
+ momentum. (default: 5.0)
+ t_alpha (Optional[int]): The number of steps for a linear warmup of the
+ `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+ highly recommended to prevent instability at the beginning of training,
+ as it gradually introduces the stabilizing slow momentum term. During
+ the warmup, `alpha` ramps from 0 to its target value. If `None`,
+ the scheduler is disabled. (default: None)
+ factored (bool): whether to use the factorization or disable it to use
+ the uncompressed optimizer. (default: True)
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-3,
+ betas: tuple[float, float] = (0.9, 0.999),
+ eps: float = 1e-8,
+ weight_decay: float = 0.0,
+ use_bias_correction: bool = True,
+ vector_reshape: bool = True,
+ stochastic_rounding: bool = True,
+ use_atan2: bool = False,
+ use_cautious: bool = False,
+ use_grams: bool = False,
+ use_orthograd: bool = False,
+ use_AdEMAMix: bool = False,
+ beta3_ema: float = 0.9999,
+ alpha: float = 5.0,
+ t_alpha: int | None = None,
+ factored: bool = True,
+ ):
+ if not (lr >= 0.0):
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+ raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+ if not (eps >= 0.0):
+ raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+ if not (weight_decay >= 0.0):
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+
+ defaults = {
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+ "vector_reshape": vector_reshape, "use_atan2": use_atan2,
+ "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+ "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ }
+ self.stochastic_rounding = stochastic_rounding
+ self.use_cautious = use_cautious
+ self.use_grams = use_grams
+ self.use_AdEMAMix = use_AdEMAMix
+ self.factored = factored
+ super().__init__(params, defaults)
+
+ @property
+ def supports_fused_back_pass(self):
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return True
+
+ @property
+ def supports_flat_params(self):
+ return False
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["use_orthograd"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+ beta1, beta2 = group['betas']
+
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ should_factor = (
+ self.factored and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ dtype = torch.float32 if self.factored else p.dtype
+ device = p.device
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+
+ # First moment (m)
+ if beta1 > 0:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not self.use_grams:
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ if self.use_AdEMAMix:
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ # Second moment (v)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ else: # Fallback to standard AdamW for non-factored tensors
+ if beta1 > 0:
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+ if self.use_AdEMAMix:
+ state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+ step = state['step'] + 1
+ if group['use_bias_correction']:
+ bias_correction1 = 1.0 - beta1 ** step
+ bias_correction2 = 1.0 - beta2 ** step
+ else:
+ bias_correction1 = 1
+ bias_correction2 = 1
+ step_size = group['lr'] / bias_correction1
+
+ if self.use_AdEMAMix:
+ beta3_ema = group['beta3_ema']
+ alpha = group['alpha']
+ t_alpha = group['t_alpha']
+ alpha_t = alpha
+ if t_alpha is not None and t_alpha > 0 and step < t_alpha:
+ alpha_t = min(step * alpha / t_alpha, alpha)
+
+ if state['factored']:
+ d1, d2 = state['effective_shape']
+
+ # Reconstruct momentum from previous step's factors
+ if beta1 > 0:
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ if not self.use_grams:
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt, -mt, out=mt)
+ del unpacked_sign
+ # Update momentum in full-size
+ grad_reshaped = grad.view(d1, d2)
+ mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+ if self.use_grams:
+ mt.copy_(grad_reshaped.sign() * mt.abs())
+ elif self.use_cautious:
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ mt.mul_(mask)
+ del mask
+
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+ if self.use_AdEMAMix:
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ if state['sign_slow'].dtype != torch.uint8:
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+ del unpacked_sign_slow
+
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+ update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ else:
+ update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
+ del grad_reshaped
+
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = (vt.sqrt() / (bias_correction2**0.5))
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ update.view(p.shape).mul_(step_size)
+
+ # Compress updated moments and store new factors
+ if beta1 > 0:
+ if not self.use_grams:
+ state['sign'] = _pack_bools(mt > 0)
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt
+ if self.use_AdEMAMix:
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ del mt_slow
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt
+
+ else: # Standard AdamW logic for non-factored tensors
+ exp_avg_sq = state['exp_avg_sq']
+
+ if beta1 > 0:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ if self.use_grams:
+ exp_avg = grad.sign() * exp_avg.abs()
+ elif self.use_cautious:
+ mask = (exp_avg * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ exp_avg.mul_(mask)
+ del mask
+
+ if self.use_AdEMAMix:
+ exp_avg_slow = state['exp_avg_slow']
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+ update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ else:
+ update = exp_avg.clone() if beta1 > 0 else grad.clone()
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ update.mul_(step_size)
+
+ # Decoupled weight decay
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ else:
+ p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+ del update
+
+ state['step'] += 1
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.step_parameter(p, group, i)
+
  return loss
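
The substantive changes to AdamW_adv.py in 0.1.5 are the new use_bias_correction option (with bias_correction1/bias_correction2 and step_size) and the fact that update is now a clone of the momentum rather than an alias of it. The following standalone sketch, which is not the package's code, mirrors the bias-corrected step that the non-factored "+" branch above computes:

    import torch

    def bias_corrected_adam_step(p, grad, exp_avg, exp_avg_sq, step,
                                 lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # Moment updates, as in the diff above.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias correction introduced in 0.1.5: rescale the step and the denominator.
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr / bias_correction1
        denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
        p.add_(exp_avg / denom, alpha=-step_size)

With use_bias_correction=False both correction factors stay at 1, which reproduces the 0.1.3 scaling.
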
@@ -216,11 +216,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
  if self.variance_reduction:
- vr_term = grad_reshaped - state['prev_grad']
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+ if state['step'] == 1:
+ exp_avg.copy_(grad_reshaped)
+ else:
+ # Heuristic Prodigy-STORM update
+ correction = exp_avg.sub(state['prev_grad'])
+ grad_alpha = self.d * (1 - self.beta2) + self.beta2
+ exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
+ del correction, grad_alpha
  state['prev_grad'].copy_(grad_reshaped)
  else:
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2))
+ # Standard Prodigy-Lion
+ alpha = self.d * (1 - self.beta2)
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
  del grad_reshaped

  # Compress new momentum m_t and store factors
@@ -247,11 +255,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  # Update momentum
  if self.variance_reduction:
- vr_term = grad - state['prev_grad']
- exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+ if state['step'] == 1:
+ exp_avg.copy_(grad)
+ else:
+ # Heuristic Prodigy-STORM update
+ correction = exp_avg.sub(state['prev_grad'])
+ grad_alpha = self.d * (1 - self.beta2) + self.beta2
+ exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
+ del grad_alpha, correction
  state['prev_grad'].copy_(grad)
  else:
- exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2))
+ # Standard Prodigy-Lion
+ alpha = self.d * (1 - self.beta2)
+ exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)

  # --- Accumulate Prodigy stats ---
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
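
Both Lion_Prodigy_adv hunks replace the old one-line variance-reduction update with a bootstrapped recursion the new comments call a heuristic Prodigy-STORM update: the first step copies the raw gradient, and afterwards m_t = (d*(1-β2) + β2)*g_t + β2*(m_{t-1} - g_{t-1}). A minimal standalone sketch of that recursion (the function name and arguments are placeholders, not the optimizer's own state handling):

    import torch

    def prodigy_storm_momentum(exp_avg, prev_grad, grad, d, beta2, step):
        # Mirrors the "+" lines above for variance_reduction=True.
        if step == 1:
            exp_avg.copy_(grad)                  # bootstrap: m_1 = g_1
        else:
            correction = exp_avg.sub(prev_grad)  # m_{t-1} - g_{t-1}
            grad_alpha = d * (1 - beta2) + beta2
            exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=beta2)
        prev_grad.copy_(grad)                    # remember g_t for the next step
        return exp_avg
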
@@ -1,229 +1,243 @@
- import torch
-
- from typing import Tuple, Optional
-
- from ..util.BF16_Stochastic_Rounding import add_stochastic_
- from ..util.Effective_Shape import _get_effective_shape
- from ..util.NNMF import _nnmf,_unnmf
- from ..util.OrthoGrad import _orthogonalize_gradient
- from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
-
- class Lion_adv(torch.optim.Optimizer):
- """
- Implements the SMMF technique for Lion algorithm.
-
- This optimizer combines the Lion update rule with the memory-saving low-rank
- compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups.
- lr (float, optional): learning rate (default: 1e-4).
- betas (Tuple[float, float], optional): coefficients for computing
- running averages of the update (default: (0.9, 0.99)).
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
- vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
- matrices to apply low-rank compression (default: True).
- stochastic_rounding (bool, optional): whether to use stochastic
- rounding for BF16 parameter updates (default: True).
- use_cautious (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
- factored (bool): whether to use the factorization or use the
- uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
- """
-
- def __init__(
- self,
- params,
- lr: float = 1e-4,
- betas: Tuple[float, float] = (0.9, 0.99),
- weight_decay: float = 0.0,
- vector_reshape: bool = True,
- stochastic_rounding: bool = True,
- use_orthograd: bool = False,
- use_cautious: bool = False,
- clip_threshold: float = 0.0,
- factored: bool = True,
- variance_reduction: bool = False,
- ):
- if not lr > 0.0:
- raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
- if not all(0.0 <= beta <= 1.0 for beta in betas):
- raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
- if not weight_decay >= 0.0:
- raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
-
- defaults = dict(
- lr=lr,
- betas=betas,
- weight_decay=weight_decay,
- vector_reshape=vector_reshape,
- use_orthograd=use_orthograd,
- clip_threshold=clip_threshold,
- )
- self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.factored = factored
- self.variance_reduction = variance_reduction
- super().__init__(params, defaults)
-
- @property
- def supports_fused_back_pass(self) -> bool:
- return True
-
- @property
- def supports_memory_efficient_fp16(self) -> bool:
- return True
-
- @property
- def supports_flat_params(self) -> bool:
- return False
-
- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
-
- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["use_orthograd"]:
- grad = _orthogonalize_gradient(p, grad)
- state = self.state[p]
-
- # State Initialization
- if len(state) == 0:
- state['step'] = 0
-
- should_factor = (
- self.factored and
- not (len(p.shape) == 1 and not group['vector_reshape'])
- )
-
- state['factored'] = should_factor
-
- dtype = torch.float32 if self.factored else p.dtype
-
- if state['factored']:
- state['effective_shape'] = _get_effective_shape(p.numel())
- d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- packed_d2 = (d2 + 7) // 8
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
- else: # Fallback to standard Lion
- state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
-
- state['step'] += 1
- beta1, beta2 = group["betas"]
- lr = group["lr"]
-
- if state['factored']:
- # Factored Path
- d1, d2 = state['effective_shape']
- grad_reshaped = grad.view(d1, d2)
- # Reconstruct momentum m_{t-1}
- exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
- del unpacked_sign
- if exp_avg.dtype != torch.float32:
- exp_avg = exp_avg.float()
-
- # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
- signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
-
- if self.use_cautious:
- mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
- del mask
-
- # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
- update_for_param = signed_update.view(p.shape).mul_(lr)
-
- # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- if self.variance_reduction:
- vr_term = grad_reshaped - state['prev_grad']
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2).add_(vr_term, alpha=beta2)
- del vr_term
- state['prev_grad'].copy_(grad_reshaped)
- else:
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
- del grad_reshaped
-
- # Compress new momentum m_t and store factors
- state['sign'] = _pack_bools(exp_avg > 0)
- _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
- del exp_avg
-
- else:
- # Fallback to standard Lion logic
- exp_avg = state["exp_avg"]
-
- # Compute update term and sign for the update
- if exp_avg.dtype != torch.float32 and self.factored:
- exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
-
- if self.use_cautious:
- mask = (signed_update * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
- del mask
-
- update_for_param = signed_update.mul_(lr)
-
- # Update momentum
- if self.variance_reduction:
- vr_term = grad - state['prev_grad']
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2).add_(vr_term, alpha=beta2)
- state['prev_grad'].copy_(grad)
- else:
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
-
- if group["weight_decay"] != 0:
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data,
- alpha=-group["weight_decay"] * lr)
- else:
- p.data.add_(
- p.data, alpha=-group["weight_decay"] * lr
- )
-
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update_for_param)
- else:
- p.data.add_(-update_for_param)
-
- del update_for_param
-
- @torch.no_grad()
- def step(self, closure: Optional[callable] = None):
- """Performs a single optimization step."""
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for group in self.param_groups:
- for i, p in enumerate(group["params"]):
- if p.grad is not None:
- self.step_parameter(p, group, i)
-
+ import torch
+
+ from typing import Tuple, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Lion_adv(torch.optim.Optimizer):
+ """
+ Implements the SMMF technique for Lion algorithm.
+
+ This optimizer combines the Lion update rule with the memory-saving low-rank
+ compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate (default: 1e-4).
+ betas (Tuple[float, float], optional): coefficients for computing
+ running averages of the update (default: (0.9, 0.99)).
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ stochastic_rounding (bool, optional): whether to use stochastic
+ rounding for BF16 parameter updates (default: True).
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
+ clip_threshold (float, optional): whether to clip the gradients norm
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+ (default: 0.0).
+ factored (bool): whether to use the factorization or use the
+ uncompressed optimizer. (default: True)
+ variance_reduction (bool): whether to use the variance reduction technique
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-4,
+ betas: Tuple[float, float] = (0.9, 0.99),
+ weight_decay: float = 0.0,
+ vector_reshape: bool = True,
+ stochastic_rounding: bool = True,
+ use_orthograd: bool = False,
+ use_cautious: bool = False,
+ clip_threshold: float = 0.0,
+ factored: bool = True,
+ variance_reduction: bool = False,
+ ):
+ if not lr > 0.0:
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+ if not weight_decay >= 0.0:
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ weight_decay=weight_decay,
+ vector_reshape=vector_reshape,
+ use_orthograd=use_orthograd,
+ clip_threshold=clip_threshold,
+ )
+ self.stochastic_rounding = stochastic_rounding
+ self.use_cautious = use_cautious
+ self.factored = factored
+ self.variance_reduction = variance_reduction
+ super().__init__(params, defaults)
+
+ @property
+ def supports_fused_back_pass(self) -> bool:
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self) -> bool:
+ return True
+
+ @property
+ def supports_flat_params(self) -> bool:
+ return False
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["clip_threshold"] > 0.0:
+ grad_norm = torch.norm(grad.detach())
+ if grad_norm > group["clip_threshold"]:
+ clip_coef = group["clip_threshold"] / grad_norm
+ grad.mul_(clip_coef)
+ if group["use_orthograd"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ should_factor = (
+ self.factored and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ dtype = torch.float32 if self.factored else p.dtype
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+ else: # Fallback to standard Lion
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+ state['step'] += 1
+ beta1, beta2 = group["betas"]
+ lr = group["lr"]
+
+ if state['factored']:
+ # Factored Path
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ # Reconstruct momentum m_{t-1}
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+ del unpacked_sign
+ if exp_avg.dtype != torch.float32:
+ exp_avg = exp_avg.float()
+
+ # Compute update term c_t
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ # Parameter update
+ update_for_param = signed_update.view(p.shape).mul_(lr)
+
+ # Update momentum
+ if self.variance_reduction:
+ if state['step'] == 1:
+ exp_avg.copy_(grad_reshaped)
+ else:
+ # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+ correction = exp_avg.sub(state['prev_grad'])
+ # Calculate the new momentum and store it back into exp_avg
+ exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
+ del correction
+ # Update prev_grad for the next iteration
+ state['prev_grad'].copy_(grad_reshaped)
+ else:
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+
+ # Compress new momentum m_t and store factors
+ state['sign'] = _pack_bools(exp_avg > 0)
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del exp_avg
+
+ else:
+ # Fallback to standard Lion logic
+ exp_avg = state["exp_avg"]
+
+ # Compute update term and sign for the update
+ if exp_avg.dtype != torch.float32 and self.factored:
+ exp_avg = exp_avg.float()
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ update_for_param = signed_update.mul_(lr)
+
+ # Update momentum
+ if self.variance_reduction:
+ if state['step'] == 1:
+ exp_avg.copy_(grad)
+ else:
+ # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
+ correction = exp_avg.sub(state['prev_grad'])
+ # Calculate the new momentum and store it back into exp_avg
+ exp_avg.copy_(grad).add_(correction, alpha=beta2)
+ del correction
+ # Update prev_grad for the next iteration
+ state['prev_grad'].copy_(grad)
+ else:
+ # Standard Lion momentum update
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data,
+ alpha=-group["weight_decay"] * lr)
+ else:
+ p.data.add_(
+ p.data, alpha=-group["weight_decay"] * lr
+ )
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update_for_param)
+ else:
+ p.data.add_(-update_for_param)
+
+ del update_for_param
+
+ @torch.no_grad()
+ def step(self, closure: Optional[callable] = None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group["params"]):
+ if p.grad is not None:
+ self.step_parameter(p, group, i)
+
  return loss
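
Lion_adv.py gets the same restructuring, but without the Prodigy d scaling: with variance_reduction=True the new code bootstraps from the first gradient and then uses the simplified STORM form m_t = g_t + β2*(m_{t-1} - g_{t-1}). A usage sketch, assuming Lion_adv is exported from the package root like the optimizers listed in __all__ above; the model and loop are illustrative only:

    import torch
    from adv_optm import Lion_adv

    model = torch.nn.Linear(32, 4)
    opt = Lion_adv(model.parameters(), lr=1e-4, betas=(0.9, 0.99),
                   variance_reduction=True)  # exercises the new STORM momentum path

    for _ in range(3):
        loss = model(torch.randn(16, 32)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
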
@@ -265,7 +265,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
  update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
  else:
- update = mt if self.beta1 > 0 else grad_reshaped
+ update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped

  if group['use_atan2']:
@@ -311,7 +311,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
  update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
  else:
- update = exp_avg if self.beta1 > 0 else grad
+ update = exp_avg.clone() if self.beta1 > 0 else grad.clone()

  exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
 
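The only change in Prodigy_adv.py is that update becomes a clone of mt / exp_avg instead of an alias. The likely motivation, judging from the parallel change in AdamW_adv above, is that update is later modified in place (atan2_/div_/mul_), which would otherwise also mutate the momentum buffer before it is stored. A small illustration of the aliasing effect, unrelated to the package's own code:

    import torch

    mt = torch.tensor([1.0, -2.0, 3.0])
    update = mt              # 0.1.3 behaviour: update aliases the state tensor
    update.mul_(0.5)
    print(mt)                # tensor([ 0.5000, -1.0000,  1.5000]) -- state modified too

    mt = torch.tensor([1.0, -2.0, 3.0])
    update = mt.clone()      # 0.1.5 behaviour: independent copy
    update.mul_(0.5)
    print(mt)                # tensor([ 1., -2.,  3.]) -- state preserved
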
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 0.1.3
+ Version: 0.1.5
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="0.1.3",
+ version="0.1.5",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',