adv-optm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

adv_optm/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .optim import (
+     AdamW_adv,
+     Prodigy_adv,
+     Adopt_adv,
+ )
+
+ __all__ = [
+     "AdamW_adv",
+     "Prodigy_adv",
+     "Adopt_adv",
+ ]
+
+ __version__ = "0.1.0"
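
For orientation, the snippet below shows how the exported optimizers are constructed and driven. It is an illustrative sketch, not part of the package diff; the model and training loop are placeholders, and only arguments that appear in the signatures below are used.

    import torch
    from adv_optm import AdamW_adv

    model = torch.nn.Linear(128, 64)   # placeholder model

    optimizer = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        use_AdEMAMix=True,   # enable the slow momentum EMA
        beta3=0.9999,
        alpha=5.0,
        t_alpha=1000,        # ramp alpha linearly over the first 1000 steps
    )

    for _ in range(10):                # toy training loop
        loss = model(torch.randn(32, 128)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
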
@@ -0,0 +1,293 @@
+ import torch
+ from typing import Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class AdamW_adv(torch.optim.Optimizer):
+     """
+     Implements a factored AdamW algorithm.
+     This is an advanced version of AdamW with optional features such as
+     low-rank factorization of the optimizer states (SMMF), OrthoGrad, etc.
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-3)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of the gradient and its square (default: (0.9, 0.999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float): weight decay (L2 penalty) (default: 0)
+         use_bias_correction (bool): whether to apply Adam's bias correction. (default: False)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: True).
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+         use_grams (bool): whether to use Grams-style updates. (default: False)
+         use_cautious (bool): whether to use cautious masking to align the gradient's
+             direction with the first moment's. (default: False)
+         use_orthograd (bool): whether to use OrthoGrad. (default: False)
+         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+             a second, slow-moving average of the momentum (`mt_slow`) which is
+             combined with the primary momentum (`mt`) to stabilize updates,
+             especially in noisy, small-batch settings. If `False`, the
+             optimizer behaves as standard AdamW. (default: False)
+         beta3 (float): The decay rate for the slow exponential moving average of
+             the momentum (only used when `use_AdEMAMix` is `True`). A higher
+             value (e.g., 0.9999) gives the EMA a longer memory, making it more
+             stable but slower to adapt. A lower value (e.g., 0.999) is often
+             better for shorter training runs. (default: 0.9999)
+         alpha (float): The mixing coefficient that scales the slow momentum term
+             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+             A higher value increases the stabilizing influence of the slow
+             momentum. (default: 5.0)
+         t_alpha (Optional[int]): The number of steps for a linear warmup of the
+             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+             highly recommended to prevent instability at the beginning of training,
+             as it gradually introduces the stabilizing slow momentum term. During
+             the warmup, `alpha` ramps from 0 to its target value. If `None`,
+             the scheduler is disabled and the full `alpha` value is used from
+             the start. (default: None)
+         factored (bool): whether to use the factorization or disable it to use
+             the uncompressed optimizer. (default: True)
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         betas: tuple[float, float] = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0.0,
+         use_bias_correction: bool = False,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         use_atan2: bool = False,
+         use_cautious: bool = False,
+         use_grams: bool = False,
+         use_orthograd: bool = False,
+         use_AdEMAMix: bool = False,
+         beta3: float = 0.9999,
+         alpha: float = 5.0,
+         t_alpha: int | None = None,
+         factored: bool = True,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
+             "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+             "beta3": beta3, "alpha": alpha, "t_alpha": t_alpha,
+         }
+         self.stochastic_rounding = stochastic_rounding
+         self.use_cautious = use_cautious
+         self.use_grams = use_grams
+         self.use_AdEMAMix = use_AdEMAMix
+         self.factored = factored
+         super().__init__(params, defaults)
+
+     @property
+     def supports_fused_back_pass(self):
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return True
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if grad.dtype != torch.float32 and self.factored:
+             grad = grad.float()
+         if group["use_orthograd"]:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization
+         if len(state) == 0:
+             state['step'] = 0
+
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+             device = p.device
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # First moment (m)
+                 state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                 if not self.use_grams:
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                 if self.use_AdEMAMix:
+                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 # Second moment (v)
+                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+             else:  # Fallback to standard AdamW for non-factored tensors
+                 state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                 if self.use_AdEMAMix:
+                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
+                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+         beta1, beta2 = group['betas']
+         if self.use_AdEMAMix:
+             beta3 = group['beta3']
+             alpha = group['alpha']
+             t_alpha = group['t_alpha']
+             current_step = state['step'] + 1
+             alpha_t = alpha
+             if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
+                 alpha_t = min(current_step * alpha / t_alpha, alpha)
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+
+             # Reconstruct momentum from previous step's factors
+             mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+             if not self.use_grams:
+                 unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                 torch.where(unpacked_sign, mt, -mt, out=mt)
+                 del unpacked_sign
+             # Update momentum in full-size
+             grad_reshaped = grad.view(d1, d2)
+             mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+             if self.use_grams:
+                 mt.copy_(grad_reshaped.sign() * mt.abs())
+             elif self.use_cautious:
+                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                 mask.div_(mask.mean().clamp_(min=1e-3))
+                 mt.mul_(mask)
+                 del mask
+
+             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+             if self.use_AdEMAMix:
+                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 if state['sign_slow'].dtype != torch.uint8:
+                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                 del unpacked_sign_slow
+
+                 mt_slow.mul_(beta3).add_(grad_reshaped, alpha=1.0 - beta3)
+                 update_m = mt + (alpha_t * mt_slow)
+             else:
+                 update_m = mt
+             del grad_reshaped
+
+             if group['use_atan2']:
+                 a = 1.2732395
+                 denom = vt.sqrt()
+                 update = torch.atan2(update_m, denom).mul_(a)
+             else:
+                 denom = vt.sqrt().add_(group['eps'])
+                 update = update_m / denom
+             del update_m, denom
+
+             update = update.view(p.shape)
+             update.mul_(group['lr'])
+
+             # Compress updated moments and store new factors
+             if not self.use_grams:
+                 state['sign'] = _pack_bools(mt > 0)
+             _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+             del mt
+             if self.use_AdEMAMix:
+                 state['sign_slow'] = _pack_bools(mt_slow > 0)
+                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 del mt_slow
+             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             del vt
+
+         else:  # Standard AdamW logic for non-factored tensors
+             exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+             exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+             if self.use_grams:
+                 exp_avg = grad.sign() * exp_avg.abs()
+             elif self.use_cautious:
+                 mask = (exp_avg * grad > 0).to(grad.dtype)
+                 mask.div_(mask.mean().clamp_(min=1e-3))
+                 exp_avg.mul_(mask)
+                 del mask
+
+             if self.use_AdEMAMix:
+                 exp_avg_slow = state['exp_avg_slow']
+                 exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
+                 update_m = exp_avg + (alpha_t * exp_avg_slow)
+             else:
+                 update_m = exp_avg
+
+             exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+             if group['use_atan2']:
+                 a = 1.2732395
+                 denom = exp_avg_sq.sqrt()
+                 update = torch.atan2(update_m, denom).mul_(a)
+             else:
+                 denom = exp_avg_sq.sqrt().add_(group['eps'])
+                 update = update_m / denom
+             del update_m, denom
+
+             update = update.mul_(group['lr'])
+
+         # Decoupled weight decay
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
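
Two of the optional behaviours above are compact enough to state on their own: the AdEMAMix alpha warmup and the cautious mask. The sketch below re-derives both on plain dense tensors; the helper names `ademamix_alpha` and `cautious_mask` are illustrative only and not part of the package API.

    import torch

    def ademamix_alpha(step: int, alpha: float, t_alpha: int | None) -> float:
        # Mirrors the scheduler above: alpha ramps linearly from 0 to its target
        # over t_alpha steps, then stays constant; None disables the warmup.
        if t_alpha is None or t_alpha <= 0 or step >= t_alpha:
            return alpha
        return min(step * alpha / t_alpha, alpha)

    def cautious_mask(momentum: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        # Keep only coordinates where momentum and gradient agree in sign,
        # rescaled so the surviving coordinates preserve the mean update magnitude.
        mask = (momentum * grad > 0).to(grad.dtype)
        return mask / mask.mean().clamp(min=1e-3)

    print(ademamix_alpha(250, alpha=5.0, t_alpha=1000))   # 1.25, a quarter of the way through the warmup
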
@@ -0,0 +1,336 @@
+ import torch
+ from typing import Callable, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Adopt_adv(torch.optim.Optimizer):
+     """
+     Implements a fusion of SMMF and the ADOPT algorithm.
+
+     The ADOPT update rule modifies Adam by:
+     1. **Initialization:** The second moment `v` is initialized as `v₀ = g₀²`.
+     2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
+        from the *previous* step (`v_{t-1}`).
+     3. **Order of Operations:** This normalization occurs *before* updating the
+        first-moment (momentum) estimate.
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-4)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of momentum and variance (default: (0.9, 0.9999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-6)
+         weight_decay (float): weight decay (L2 penalty) (default: 0)
+         clip_lambda (Callable, optional): A function that takes the current step
+             and returns a value to clip the normalized gradient. Only used when
+             `use_atan2` is False. (default: `lambda step: step**0.25`)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices for low-rank compression (default: True).
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         use_atan2 (bool): whether to use an atan2-based normalization, which can
+             improve stability by removing the need for `eps`. (default: False)
+         use_cautious (bool): whether to use cautious masking to align the gradient's
+             direction with the first moment's. (default: True)
+         use_grams (bool): whether to combine the gradient's direction with the
+             first moment's magnitude (default: False).
+         use_orthograd (bool): whether to use OrthoGrad. (default: False)
+         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+             a second, slow-moving average of the momentum (`mt_slow`) which is
+             combined with the primary momentum (`mt`) to stabilize updates,
+             especially in noisy, small-batch settings. If `False`, the
+             optimizer behaves as standard ADOPT. (default: False)
+         beta3 (float): The decay rate for the slow exponential moving average of
+             the momentum (only used when `use_AdEMAMix` is `True`). A higher
+             value (e.g., 0.9999) gives the EMA a longer memory, making it more
+             stable but slower to adapt. A lower value (e.g., 0.999) is often
+             better for shorter training runs. (default: 0.9999)
+         alpha (float): The mixing coefficient that scales the slow momentum term
+             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+             A higher value increases the stabilizing influence of the slow
+             momentum. (default: 5.0)
+         t_alpha (Optional[int]): The number of steps for a linear warmup of the
+             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+             highly recommended to prevent instability at the beginning of training,
+             as it gradually introduces the stabilizing slow momentum term. During
+             the warmup, `alpha` ramps from 0 to its target value. If `None`,
+             the scheduler is disabled and the full `alpha` value is used from
+             the start. (default: None)
+         factored (bool): whether to use the factorization or disable it to use
+             the uncompressed optimizer. (default: True)
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-4,
+         betas: tuple[float, float] = (0.9, 0.9999),
+         eps: float = 1e-6,
+         weight_decay: float = 0.0,
+         clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         use_atan2: bool = False,
+         use_cautious: bool = True,
+         use_grams: bool = False,
+         use_orthograd: bool = False,
+         use_AdEMAMix: bool = False,
+         beta3: float = 0.9999,
+         alpha: float = 5.0,
+         t_alpha: int | None = None,
+         factored: bool = True,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "vector_reshape": vector_reshape, "beta3": beta3, "alpha": alpha,
+             "t_alpha": t_alpha,
+         }
+         self.clip_lambda = clip_lambda
+         self.stochastic_rounding = stochastic_rounding
+         self.use_atan2 = use_atan2
+         self.use_cautious = use_cautious
+         self.use_grams = use_grams
+         self.use_orthograd = use_orthograd
+         self.use_AdEMAMix = use_AdEMAMix
+         self.factored = factored
+         super().__init__(params, defaults)
+
+     @property
+     def supports_fused_back_pass(self): return True
+     @property
+     def supports_memory_efficient_fp16(self): return True
+     @property
+     def supports_flat_params(self): return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if self.factored and grad.dtype != torch.float32:
+             grad = grad.float()
+         if self.use_orthograd:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization
+         if len(state) == 0:
+             state['step'] = 0
+
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # m_0 = 0
+                 state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                 if not self.use_grams:
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 if self.use_AdEMAMix:
+                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
+                 vt_init = grad.view(d1, d2).square()  # not in-place, so p.grad is left unmodified
+                 # Allocate NMF factors for v
+                 state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                 # Initialize v_0 using NMF
+                 _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             else:  # Fallback for non-factored tensors
+                 state['exp_avg'] = torch.zeros_like(p, dtype=dtype)  # m_0
+                 if self.use_AdEMAMix:
+                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
+                 state['exp_avg_sq'] = grad.square()  # v_0
+
+         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
+         if state['step'] == 0 and not self.use_atan2:
+             state['step'] += 1
+             return
+
+         beta1, beta2 = group['betas']
+         if self.use_AdEMAMix:
+             beta3 = group['beta3']
+             alpha = group['alpha']
+             t_alpha = group['t_alpha']
+             # Use step+1 for 1-based step count in scheduler
+             current_step = state['step'] + 1
+             alpha_t = alpha
+             if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
+                 alpha_t = min(current_step * alpha / t_alpha, alpha)
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+
+             # Reconstruct m_{t-1}
+             mt_prev = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+             if not self.use_grams:
+                 if state['sign'].dtype != torch.uint8:
+                     state['sign'] = state['sign'].to(torch.uint8)
+                 unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                 torch.where(unpacked_sign, mt_prev, -mt_prev, out=mt_prev)
+                 del unpacked_sign
+
+             # Reconstruct AdEMAMix EMA
+             if self.use_AdEMAMix:
+                 mt_slow_prev = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 if state['sign_slow'].dtype != torch.uint8:
+                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                 torch.where(unpacked_sign_slow, mt_slow_prev, -mt_slow_prev, out=mt_slow_prev)
+                 del unpacked_sign_slow
+
+             # Reconstruct v_{t-1} using NNMF
+             vt_prev = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+
+             # ADOPT Step A: Decorrelate g_t using v_{t-1}
+             grad_reshaped = grad.view(d1, d2)
+             denom = vt_prev.sqrt()
+
+             if self.use_atan2:
+                 normalized_grad = torch.atan2(grad_reshaped, denom)
+             else:
+                 normalized_grad = grad_reshaped / denom.add_(group['eps'])
+                 if self.clip_lambda is not None:
+                     clip_val = self.clip_lambda(state['step'])
+                     normalized_grad.clamp_(-clip_val, clip_val)
+             del denom
+
+             # ADOPT Step B: Update momentum m_t using normalized gradient
+             mt = mt_prev.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+             if self.use_grams:
+                 mt = grad_reshaped.sign() * mt.abs()
+             elif self.use_cautious:
+                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                 mask.div_(mask.mean().clamp_(min=1e-3))
+                 mt.mul_(mask)
+                 del mask
+
+             if self.use_AdEMAMix:
+                 mt_slow = mt_slow_prev.mul_(beta3).add_(normalized_grad, alpha=1.0 - beta3)
+                 update = mt + (alpha_t * mt_slow)
+                 update = update.view(p.shape)
+             else:
+                 update = mt.view(p.shape)
+
+             if self.use_atan2:
+                 update.mul_(group['lr'] * 1.2732395447351628)
+             else:
+                 update.mul_(group['lr'])
+
+             # Update second moment v_t for the *next* step using raw g_t
+             vt_updated = vt_prev.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+             del grad_reshaped
+
+             # Compress and store new factors
+             if not self.use_grams:
+                 state['sign'] = _pack_bools(mt > 0)
+             _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+
+             if self.use_AdEMAMix:
+                 state['sign_slow'] = _pack_bools(mt_slow > 0)
+                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+
+             # factorize v_t using NMF compression
+             _nnmf(vt_updated, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+
+         else:  # Standard ADOPT logic for non-factored tensors
+             m, v = state['exp_avg'], state['exp_avg_sq']  # m_{t-1}, v_{t-1}
+
+             if self.use_AdEMAMix:
+                 m_slow = state['exp_avg_slow']
+
+             # ADOPT Step A: Decorrelate g_t using v_{t-1}
+             denom = v.sqrt()
+
+             if self.use_atan2:
+                 normalized_grad = torch.atan2(grad, denom)
+             else:
+                 normalized_grad = grad / denom.add_(group['eps'])
+                 if self.clip_lambda is not None:
+                     clip_val = self.clip_lambda(state['step'])
+                     normalized_grad.clamp_(-clip_val, clip_val)
+             del denom
+
+             # ADOPT Step B: Update momentum m_t
+             m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+
+             if self.use_grams:
+                 m = grad.sign() * m.abs()
+             elif self.use_cautious:
+                 mask = (m * grad > 0).to(grad.dtype)
+                 mask.div_(mask.mean().clamp_(min=1e-3))
+                 m.mul_(mask)
+                 del mask
+
+             if self.use_AdEMAMix:
+                 m_slow.mul_(beta3).add_(normalized_grad, alpha=1.0 - beta3)
+                 update = m + (alpha_t * m_slow)
+             else:
+                 update = m
+
+             if self.use_atan2:
+                 update.mul_(group['lr'] * 1.2732395447351628)
+             else:
+                 update.mul_(group['lr'])
+
+             # Update second moment v_t for the next step using raw g_t
+             v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+         # Parameter Update
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
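
For reference, the ADOPT ordering used above (normalize g_t by v_{t-1}, update the momentum with the normalized gradient, then refresh v_t with the raw gradient) reduces to the following on a single dense tensor. This is a stripped-down sketch with an assumed signature; it omits the factorization, weight decay, the cautious/Grams/AdEMAMix variants, and BF16 handling.

    import torch

    def adopt_step(p, grad, m, v, step, lr=1e-4, beta1=0.9, beta2=0.9999,
                   eps=1e-6, clip_lambda=lambda s: s ** 0.25):
        if step == 0:
            v.copy_(grad.square())                   # v_0 = g_0^2, no parameter update yet
            return
        normalized = grad / v.sqrt().add_(eps)       # Step A: decorrelate g_t with v_{t-1}
        clip = clip_lambda(step)
        normalized.clamp_(-clip, clip)
        m.mul_(beta1).add_(normalized, alpha=1 - beta1)       # Step B: momentum on normalized grad
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # v_t from raw g_t, used next step
        p.add_(m, alpha=-lr)
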