adv-optm 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (23)
  1. {adv_optm-0.1.2 → adv_optm-0.1.4}/PKG-INFO +1 -1
  2. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/__init__.py +1 -1
  3. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/AdamW_adv.py +296 -296
  4. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/Lion_Prodigy_adv.py +22 -8
  5. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/Lion_adv.py +242 -230
  6. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/Prodigy_adv.py +56 -51
  7. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-0.1.2 → adv_optm-0.1.4}/setup.py +1 -1
  9. {adv_optm-0.1.2 → adv_optm-0.1.4}/LICENSE +0 -0
  10. {adv_optm-0.1.2 → adv_optm-0.1.4}/README.md +0 -0
  11. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/Adopt_adv.py +0 -0
  12. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/__init__.py +0 -0
  13. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  14. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/Effective_Shape.py +0 -0
  15. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/NNMF.py +0 -0
  16. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/One_Bit_Boolean.py +0 -0
  17. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/OrthoGrad.py +0 -0
  18. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/util/__init__.py +0 -0
  19. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm.egg-info/SOURCES.txt +0 -0
  20. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm.egg-info/dependency_links.txt +0 -0
  21. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm.egg-info/requires.txt +0 -0
  22. {adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm.egg-info/top_level.txt +0 -0
  23. {adv_optm-0.1.2 → adv_optm-0.1.4}/setup.cfg +0 -0
{adv_optm-0.1.2 → adv_optm-0.1.4}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 0.1.2
+Version: 0.1.4
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
{adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/__init__.py
@@ -14,4 +14,4 @@ __all__ = [
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "0.1.2"
+__version__ = "0.1.4"
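
The only change to `adv_optm/__init__.py` is the `__version__` bump. As a quick, hypothetical post-upgrade check (assuming the package is importable as `adv_optm`, as the file layout above implies):

import adv_optm

# Only the version string moves in this file; the exported optimizers are unchanged.
print(adv_optm.__version__)  # expected: "0.1.4" for this release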
{adv_optm-0.1.2 → adv_optm-0.1.4}/adv_optm/optim/AdamW_adv.py
@@ -1,297 +1,297 @@
(The 296 removed lines are line-for-line identical to the 296 added lines that follow.)
+import torch
+from typing import Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf, _unnmf
+from ..util.OrthoGrad import _orthogonalize_gradient
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class AdamW_adv(torch.optim.Optimizer):
+    """
+    Implements a factored AdamW algorithm.
+    This is an advanced version of AdamW with optional features like
+    low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate (default: 1e-3)
+        betas (tuple[float, float]): coefficients used for computing running
+            averages of gradient and its square (default: (0.9, 0.999))
+        eps (float): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float): weight decay (L2 penalty) (default: 0)
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        stochastic_rounding (bool): whether to use stochastic
+            rounding for BF16 parameter updates (default: True).
+        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+        use_grams (bool): whether to use Grams-style updates. (default: False)
+        use_cautious (bool): whether to use cautious masking to align the gradient's
+            direction with the first moment's. (default: False)
+        use_orthograd (bool): whether to use OrthoGrad. (default: False)
+        use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+            a second, slow-moving average of the momentum (`mt_slow`) which is
+            combined with the primary momentum (`mt`) to stabilize updates,
+            especially in noisy, small-batch settings. If `False`, the
+            optimizer behaves as standard AdamW. (default: False)
+        beta3_ema (float): The decay rate for the slow exponential moving average of
+            the momentum (only used when `use_AdEMAMix` is `True`). A higher
+            value (e.g., 0.9999) gives the EMA a longer memory, making it more
+            stable but slower to adapt. A lower value (e.g., 0.999) is often
+            better for shorter training runs. (default: 0.9999)
+        alpha (float): The mixing coefficient that scales the slow momentum term
+            before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+            A higher value increases the stabilizing influence of the slow
+            momentum. (default: 5.0)
+        t_alpha (Optional[int]): The number of steps for a linear warmup of the
+            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+            highly recommended to prevent instability at the beginning of training,
+            as it gradually introduces the stabilizing slow momentum term. During
+            the warmup, `alpha` ramps from 0 to its target value. If `None`,
+            the scheduler is disabled and the full `alpha` is used from the start. (default: None)
+        factored (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: True)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        vector_reshape: bool = True,
+        stochastic_rounding: bool = True,
+        use_atan2: bool = False,
+        use_cautious: bool = False,
+        use_grams: bool = False,
+        use_orthograd: bool = False,
+        use_AdEMAMix: bool = False,
+        beta3_ema: float = 0.9999,
+        alpha: float = 5.0,
+        t_alpha: int | None = None,
+        factored: bool = True,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+        if not (eps >= 0.0):
+            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+
+        defaults = {
+            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+            "vector_reshape": vector_reshape, "use_atan2": use_atan2,
+            "use_orthograd": use_orthograd,
+            "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        self.use_cautious = use_cautious
+        self.use_grams = use_grams
+        self.use_AdEMAMix = use_AdEMAMix
+        self.factored = factored
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["use_orthograd"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+        beta1, beta2 = group['betas']
+
+        # State Initialization
+        if len(state) == 0:
+            state['step'] = 0
+
+            should_factor = (
+                self.factored and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            dtype = torch.float32 if self.factored else p.dtype
+            device = p.device
+
+            if state['factored']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+
+                # First moment (m)
+                if beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if self.use_AdEMAMix:
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                    packed_d2 = (d2 + 7) // 8
+                    state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                # Second moment (v)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else:  # Fallback to standard AdamW for non-factored tensors
+                if beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if self.use_AdEMAMix:
+                    state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
+                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+        if self.use_AdEMAMix:
+            beta3_ema = group['beta3_ema']
+            alpha = group['alpha']
+            t_alpha = group['t_alpha']
+            current_step = state['step'] + 1
+            alpha_t = alpha
+            if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
+                alpha_t = min(current_step * alpha / t_alpha, alpha)
+
+        if state['factored']:
+            d1, d2 = state['effective_shape']
+
+            # Reconstruct momentum from previous step's factors
+            if beta1 > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not self.use_grams:
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign
+                # Update momentum in full-size
+                grad_reshaped = grad.view(d1, d2)
+                mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+                if self.use_grams:
+                    mt.copy_(grad_reshaped.sign() * mt.abs())
+                elif self.use_cautious:
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask
+
+            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+            if self.use_AdEMAMix:
+                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                if state['sign_slow'].dtype != torch.uint8:
+                    state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                del unpacked_sign_slow
+
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+            else:
+                update = mt if beta1 > 0 else grad_reshaped
+            del grad_reshaped
+
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = vt.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = vt.sqrt()
+                update.div_(denom.add_(group['eps']))
+            del denom
+
+            update.view(p.shape).mul_(group['lr'])
+
+            # Compress updated moments and store new factors
+            if beta1 > 0:
+                if not self.use_grams:
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt
+            if self.use_AdEMAMix:
+                state['sign_slow'] = _pack_bools(mt_slow > 0)
+                _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                del mt_slow
+            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt
+
+        else:  # Standard AdamW logic for non-factored tensors
+            exp_avg_sq = state['exp_avg_sq']
+
+            if beta1 > 0:
+                exp_avg = state['exp_avg']
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                if self.use_grams:
+                    exp_avg = grad.sign() * exp_avg.abs()
+                elif self.use_cautious:
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg.mul_(mask)
+                    del mask
+
+            if self.use_AdEMAMix:
+                exp_avg_slow = state['exp_avg_slow']
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+            else:
+                update = exp_avg if beta1 > 0 else grad
+
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = exp_avg_sq.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = exp_avg_sq.sqrt()
+                update.div_(denom.add_(group['eps']))
+            del denom
+
+            update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        return loss
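
The constructor signature and the `step`/`step_parameter` methods re-added above define the public surface of `AdamW_adv`. The following is a minimal, untested usage sketch, not taken from the package's documentation: the model, batch, and hyperparameter values are placeholders, and the import path is the one implied by the file layout listed at the top of this diff.

import torch
from adv_optm.optim.AdamW_adv import AdamW_adv

# BF16 parameters exercise the stochastic-rounding path guarded by
# `p.dtype == torch.bfloat16 and self.stochastic_rounding`.
model = torch.nn.Linear(128, 64).bfloat16()

opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    factored=True,       # keep the factored (NNMF + packed-sign) optimizer state
    use_AdEMAMix=True,   # mix in the slow EMA: update = mt + alpha_t * mt_slow
    beta3_ema=0.9999,
    alpha=5.0,
    t_alpha=1000,        # during warmup: alpha_t = min(step * alpha / t_alpha, alpha)
)

for _ in range(10):
    x = torch.randn(32, 128, dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

Because `supports_fused_back_pass` returns `True`, a training framework that fuses optimizer work into the backward pass can instead call `opt.step_parameter(p, group, i)` for each parameter as soon as its gradient is available, which is what `step()` itself does over `self.param_groups`.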