adv-optm 1.1.0.dev2__py3-none-any.whl → 1.1.0.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

@@ -1,440 +1,436 @@
1
- import torch
2
- from typing import Callable, Optional
3
-
4
- from ..util.BF16_Stochastic_Rounding import add_stochastic_
5
- from ..util.Effective_Shape import _get_effective_shape
6
- from ..util.NNMF import _nnmf, _unnmf
7
- from ..util.OrthoGrad import _orthogonalize_gradient
8
- from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
9
- from ..util.Kourkoutas import KourkoutasHelper
10
-
11
- class Adopt_adv(torch.optim.Optimizer):
12
- """
13
- Implements a fusion of SMMF, and the ADOPT algorithm.
14
-
15
- The ADOPT update rule modifies Adam by:
16
- 1. **Initialization:** The second moment `v` is initialized as `v₀ = g₀²`.
17
- 2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
18
- from the *previous* step (`v_{t-1}`).
19
- 3. **Order of Operations:** This normalization occurs *before* updating the
20
- first-moment (momentum) estimate.
21
-
22
- Args:
23
- params (iterable): iterable of parameters to optimize or dicts defining
24
- parameter groups
25
- lr (float): learning rate (default: 1e-4)
26
- betas (tuple[float, float]): coefficients used for computing running
27
- averages of momentum and variance (default: (0.9, 0.9999))
28
- eps (float): term added to the denominator to improve
29
- numerical stability (default: 1e-6)
30
- weight_decay (float): weight decay (L2 penalty) (default: 0)
31
- clip_lambda (Callable, optional): A function that takes the current step
32
- and returns a value to clip the normalized gradient. Only used when
33
- `use_atan2` is False. (default: `lambda step: step**0.25`)
34
- vector_reshape (bool): whether to reshape 1D vectors into 2D
35
- matrices for low-rank compression (default: True).
36
- stochastic_rounding (bool): whether to use stochastic
37
- rounding for BF16 parameter updates (default: True).
38
- use_atan2 (bool): whether to use an atan2-based normalization, which can
39
- improve stability by removing the need for `eps`. (default: False)
40
- cautious_mask (bool): whether to use cautious masking to align the gradient's
41
- direction with the first moment's. (default: False)
42
- grams_moment (bool): whether to combine the gradient's direction with the
43
- first moment's magnitude (default: False).
44
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
45
- use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
46
- a second, slow-moving average of the momentum (`mt_slow`) which is
47
- combined with the primary momentum (`mt`) to stabilize updates,
48
- especially in noisy, small-batch settings. If `False`, the
49
- optimizer behaves as standard ADOPT. (default: False)
50
- beta3_ema (float): The decay rate for the slow exponential moving average of
51
- the momentum (only used when `use_AdEMAMix` is `True`). A higher
52
- value (e.g., 0.9999) gives the EMA a longer memory, making it more
53
- stable but slower to adapt. A lower value (e.g., 0.999) is often
54
- better for shorter training runs. (default: 0.9999)
55
- alpha (float): The mixing coefficient that scales the slow momentum term
56
- before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
57
- A higher value increases the stabilizing influence of the slow
58
- momentum. (default: 5.0)
59
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
60
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
61
- highly recommended to prevent instability at the beginning of training,
62
- as it gradually introduces the stabilizing slow momentum term. During
63
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
64
- the scheduler is disabled and the full `alpha` value is used from
65
- the start. (default: None)
66
- Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
67
- This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
68
- more responsive, especially for small batch sizes. Enabling this will
69
- automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
70
- and `use_atan2`. (default: False)
71
- alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
72
- (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
73
- current gradient. For small batch sizes, use high values (e.g., 10-100) to be
74
- more responsive. For large batch sizes, use low values (e.g., 0-1) for
75
- stability. (default: 100.0)
76
- kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
77
- If `False`, the optimizer behaves as standard Adopt. (default: False)
78
- beta2_min (float): The minimum value for dynamic β₂, used during periods of
79
- high gradient variance ("sunspikes"). Must be less than `betas[1]`.
80
- (default: 0.88)
81
- ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
82
- the pooled gradient norms. Corresponds to `α` in the paper.
83
- (default: 0.93)
84
- tiny_spike (float): A small constant added to the denominator of the
85
- "sunspike" ratio calculation to prevent division by zero. Corresponds
86
- to `ε_spike` in the paper. (default: 1e-9)
87
- k_warmup_steps (int): The number of initial steps during which β₂ is held
88
- at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
89
- dynamic logic activates. (default: 0)
90
- k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
91
- logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
92
- every logging steps. Useful for debugging and tuning. Set to 0 to disable
93
- logging (default: 0).
94
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
95
- and returns a unique, hashable key representing its "layer" or "bucket".
96
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
97
- (default: None)
98
- nnmf_factor (bool): whether to use the factorization or disable it to use
99
- the uncompressed optimizer. (default: False)
100
- """
101
-
102
- def __init__(
103
- self,
104
- params,
105
- lr: float = 1e-4,
106
- betas: tuple[float, float] = (0.9, 0.9999),
107
- eps: float = 1e-6,
108
- weight_decay: float = 0.0,
109
- clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
110
- vector_reshape: bool = True,
111
- stochastic_rounding: bool = True,
112
- use_atan2: bool = False,
113
- cautious_mask: bool = False,
114
- grams_moment: bool = False,
115
- orthogonal_gradient: bool = False,
116
- use_AdEMAMix: bool = False,
117
- beta3_ema: float = 0.9999,
118
- alpha: float = 5.0,
119
- t_alpha: int | None = None,
120
- Simplified_AdEMAMix: bool = False,
121
- alpha_grad: float = 100.0,
122
- kourkoutas_beta: bool = False,
123
- beta2_min: float = 0.88,
124
- ema_alpha: float = 0.93,
125
- tiny_spike: float = 1e-9,
126
- k_warmup_steps: int = 0,
127
- k_logging: int = 0,
128
- layer_key_fn: Optional[Callable] = None,
129
- nnmf_factor: bool = False,
130
- ):
131
- if not (lr >= 0.0):
132
- raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
133
- if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
134
- raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
135
- if not (eps >= 0.0):
136
- raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
137
- if not (weight_decay >= 0.0):
138
- raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
139
- if cautious_mask and grams_moment:
140
- print("Warning: cautious is incompatible with grams, Disabling cautious.")
141
- cautious_mask = False
142
- if betas[0] == 0.0 and Simplified_AdEMAMix:
143
- raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
144
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
145
- if use_AdEMAMix and Simplified_AdEMAMix:
146
- print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
147
- if grams_moment and Simplified_AdEMAMix:
148
- print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
149
- if cautious_mask and Simplified_AdEMAMix:
150
- print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
151
- if use_atan2 and Simplified_AdEMAMix:
152
- print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
153
- use_atan2 = False
154
-
155
- defaults = {
156
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
157
- "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
158
- "t_alpha": t_alpha, "alpha_grad": alpha_grad,
159
- "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
160
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
161
- }
162
- self.clip_lambda = clip_lambda
163
- self.stochastic_rounding = stochastic_rounding
164
- self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
165
- self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
166
- self.grams_moment = grams_moment and not Simplified_AdEMAMix
167
- self.orthogonal_gradient = orthogonal_gradient
168
- self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
169
- self.Simplified_AdEMAMix = Simplified_AdEMAMix
170
- self.factored = nnmf_factor
171
- self.kourkoutas_beta = kourkoutas_beta
172
- self.layer_key_fn = layer_key_fn
173
- super().__init__(params, defaults)
174
-
175
- if self.kourkoutas_beta:
176
- self.kourkoutas_helper = KourkoutasHelper(self)
177
-
178
- @property
179
- def supports_fused_back_pass(self): return True
180
- @property
181
- def supports_memory_efficient_fp16(self): return True
182
- @property
183
- def supports_flat_params(self): return False
184
-
185
- @torch.no_grad()
186
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
187
- if p.grad is None:
188
- return
189
-
190
- grad = p.grad
191
- if self.factored and grad.dtype != torch.float32:
192
- grad = grad.float()
193
- if self.orthogonal_gradient:
194
- grad = _orthogonalize_gradient(p, grad)
195
- state = self.state[p]
196
-
197
- # State Initialization
198
- if len(state) == 0:
199
- state['step'] = 0
200
-
201
- should_factor = (
202
- self.factored and
203
- not (len(p.shape) == 1 and not group['vector_reshape'])
204
- )
205
-
206
- state['factored'] = should_factor
207
-
208
- dtype = torch.float32 if self.factored else p.dtype
209
-
210
- if state['factored']:
211
- state['effective_shape'] = _get_effective_shape(p.numel())
212
- d1, d2 = state['effective_shape']
213
-
214
- # m_0 = 0
215
- if group['betas'][0] > 0:
216
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
217
- state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
218
- if not self.grams_moment:
219
- packed_d2 = (d2 + 7) // 8
220
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
221
- if self.use_AdEMAMix:
222
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
223
- state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
224
- packed_d2 = (d2 + 7) // 8
225
- state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
226
- # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
227
- vt_init = grad.view(d1, d2).square_()
228
- # Allocate NMF factors for v
229
- state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
230
- state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
231
- # Initialize v_0 using NMF
232
- _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
233
- else: # Fallback for non-factored tensors
234
- if group['betas'][0] > 0:
235
- state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
236
- if self.use_AdEMAMix:
237
- state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
238
- state['exp_avg_sq'] = grad.square() # v_0
239
-
240
- beta1, beta2 = group['betas']
241
-
242
- current_step = state['step']
243
- if group['kourkoutas_beta']:
244
- # Call prepare_step() once at the beginning of the step for all params
245
- self.kourkoutas_helper.maybe_prepare_step(current_step)
246
- # Accumulate current grad's norm for the *next* step
247
- self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
248
- # Get the dynamic beta2 calculated in prepare_step()
249
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
250
-
251
- # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
252
- if state['step'] == 0 and not self.use_atan2:
253
- state['step'] += 1
254
- return
255
-
256
- if self.use_AdEMAMix:
257
- beta3_ema = group['beta3_ema']
258
- alpha = group['alpha']
259
- t_alpha = group['t_alpha']
260
- # Use step+1 for 1-based step count in scheduler
261
- alpha_step = state['step'] + 1
262
- alpha_t = alpha
263
- if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
264
- alpha_t = min(alpha_step * alpha / t_alpha, alpha)
265
- if self.Simplified_AdEMAMix:
266
- alpha_grad = group["alpha_grad"]
267
-
268
- if state['factored']:
269
- d1, d2 = state['effective_shape']
270
-
271
- # Reconstruct m_{t-1}
272
- if beta1 > 0:
273
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
274
- if not self.grams_moment:
275
- if state['sign'].dtype != torch.uint8:
276
- state['sign'] = state['sign'].to(torch.uint8)
277
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
278
- torch.where(unpacked_sign, mt, -mt, out=mt)
279
- del unpacked_sign
280
-
281
- # Reconstruct AdEMAMix EMA
282
- if self.use_AdEMAMix:
283
- mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
284
- if state['sign_slow'].dtype != torch.uint8:
285
- state['sign_slow'] = state['sign_slow'].to(torch.uint8)
286
- unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
287
- torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
288
- del unpacked_sign_slow
289
-
290
- # Reconstruct v_{t-1} using NNMF
291
- vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
292
-
293
- # ADOPT Step A: Decorrelate g_t using v_{t-1}
294
- grad_reshaped = grad.view(d1, d2)
295
- denom = vt.sqrt()
296
-
297
- if self.use_atan2:
298
- normalized_grad = torch.atan2(grad_reshaped, denom)
299
- else:
300
- normalized_grad = grad_reshaped / denom.add_(group['eps'])
301
- if self.clip_lambda is not None:
302
- clip_val = self.clip_lambda(state['step'])
303
- normalized_grad.clamp_(-clip_val, clip_val)
304
- del denom
305
-
306
- # ADOPT Step B: Update momentum m_t using normalized gradient
307
- if beta1 > 0:
308
- if self.Simplified_AdEMAMix:
309
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
310
- else:
311
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
312
- if self.grams_moment:
313
- mt = grad_reshaped.sign() * mt.abs()
314
- elif self.cautious_mask:
315
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
316
- mask.div_(mask.mean().clamp_(min=1e-3))
317
- mt.mul_(mask)
318
- del mask
319
-
320
- if self.use_AdEMAMix:
321
- mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
322
- if beta1 > 0:
323
- update = torch.add(mt, mt_slow, alpha=alpha_t)
324
- else:
325
- update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
326
- elif self.Simplified_AdEMAMix:
327
- update = torch.add(mt, normalized_grad, alpha=alpha_grad)
328
- else:
329
- update = mt.clone() if beta1 > 0 else normalized_grad
330
-
331
- update = update.view(p.shape)
332
-
333
- if self.use_atan2:
334
- update.mul_(group['lr'] * 1.2732395447351628)
335
- else:
336
- update.mul_(group['lr'])
337
-
338
- # Update second moment v_t for the *next* step using raw g_t
339
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
340
- del grad_reshaped
341
-
342
- # Compress and store new factors
343
- if beta1 > 0:
344
- if not self.grams_moment:
345
- state['sign'] = _pack_bools(mt > 0)
346
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
347
- del mt
348
-
349
- if self.use_AdEMAMix:
350
- state['sign_slow'] = _pack_bools(mt_slow > 0)
351
- _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
352
- del mt_slow
353
-
354
- # factorize v_t using NMF compression
355
- _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
356
- del vt
357
-
358
- else: # Standard ADOPT logic for non-factored tensors
359
- v = state['exp_avg_sq'] # v_{t-1}
360
-
361
- # ADOPT Step A: Decorrelate g_t using v_{t-1}
362
- denom = v.sqrt()
363
-
364
- if self.use_atan2:
365
- normalized_grad = torch.atan2(grad, denom)
366
- else:
367
- normalized_grad = grad / denom.add_(group['eps'])
368
- if self.clip_lambda is not None:
369
- clip_val = self.clip_lambda(state['step'])
370
- normalized_grad.clamp_(-clip_val, clip_val)
371
- del denom
372
-
373
- # ADOPT Step B: Update momentum m_t
374
- if beta1 > 0:
375
- m = state['exp_avg'] # m_{t-1},
376
- if self.Simplified_AdEMAMix:
377
- m.mul_(beta1).add_(normalized_grad, alpha=1.0)
378
- else:
379
- m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
380
-
381
- if self.grams_moment:
382
- m = grad.sign() * m.abs()
383
- elif self.cautious_mask:
384
- mask = (m * grad > 0).to(grad.dtype)
385
- mask.div_(mask.mean().clamp_(min=1e-3))
386
- m.mul_(mask)
387
- del mask
388
-
389
- if self.use_AdEMAMix:
390
- m_slow = state['exp_avg_slow']
391
- m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
392
- if beta1 > 0:
393
- update = torch.add(m, m_slow, alpha=alpha_t)
394
- else:
395
- update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
396
- elif self.Simplified_AdEMAMix:
397
- update = torch.add(m, normalized_grad, alpha=alpha_grad)
398
- else:
399
- update = m.clone() if beta1 > 0 else normalized_grad
400
-
401
- if self.use_atan2:
402
- update.mul_(group['lr'] * 1.2732395447351628)
403
- else:
404
- update.mul_(group['lr'])
405
-
406
- # Update second moment v_t for the next step using raw g_t
407
- v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
408
-
409
- # Parameter Update
410
- if group["weight_decay"] != 0:
411
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
412
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
413
- else:
414
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
415
-
416
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
417
- add_stochastic_(p.data, -update)
418
- else:
419
- p.data.add_(-update)
420
- del update
421
-
422
- state['step'] += 1
423
-
424
- @torch.no_grad()
425
- def step(self, closure=None):
426
- """Performs a single optimization step."""
427
- loss = None
428
- if closure is not None:
429
- with torch.enable_grad():
430
- loss = closure()
431
-
432
- for group in self.param_groups:
433
- for i, p in enumerate(group['params']):
434
- self.step_parameter(p, group, i)
435
-
436
- if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
437
- first_param_state = self.state[self.param_groups[0]['params'][0]]
438
- step_num = first_param_state['step']
439
-
1
+ import torch
2
+ from typing import Callable, Optional
3
+
4
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
5
+ from ..util.Effective_Shape import _get_effective_shape
6
+ from ..util.NNMF import _nnmf, _unnmf
7
+ from ..util.OrthoGrad import _orthogonalize_gradient
8
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
9
+ from ..util.Kourkoutas import KourkoutasHelper
10
+
11
+ class Adopt_adv(torch.optim.Optimizer):
12
+ """
13
+ Implements a fusion of SMMF, and the ADOPT algorithm.
14
+
15
+ The ADOPT update rule modifies Adam by:
16
+ 1. **Initialization:** The second moment `v` is initialized as `v₀ = g₀²`.
17
+ 2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
18
+ from the *previous* step (`v_{t-1}`).
19
+ 3. **Order of Operations:** This normalization occurs *before* updating the
20
+ first-moment (momentum) estimate.
21
+
22
+ Args:
23
+ params (iterable): iterable of parameters to optimize or dicts defining
24
+ parameter groups
25
+ lr (float): learning rate (default: 1e-4)
26
+ betas (tuple[float, float]): coefficients used for computing running
27
+ averages of momentum and variance (default: (0.9, 0.9999))
28
+ eps (float): term added to the denominator to improve
29
+ numerical stability (default: 1e-6)
30
+ weight_decay (float): weight decay (L2 penalty) (default: 0)
31
+ clip_lambda (Callable, optional): A function that takes the current step
32
+ and returns a value to clip the normalized gradient. Only used when
33
+ `use_atan2` is False. (default: `lambda step: step**0.25`)
34
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
35
+ matrices for low-rank compression (default: True).
36
+ stochastic_rounding (bool): whether to use stochastic
37
+ rounding for BF16 parameter updates (default: True).
38
+ use_atan2 (bool): whether to use an atan2-based normalization, which can
39
+ improve stability by removing the need for `eps`. (default: False)
40
+ cautious_mask (bool): whether to use cautious masking to align the gradient's
41
+ direction with the first moment's. (default: False)
42
+ grams_moment (bool): whether to combine the gradient's direction with the
43
+ first moment's magnitude (default: False).
44
+ orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
45
+ use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
46
+ a second, slow-moving average of the momentum (`mt_slow`) which is
47
+ combined with the primary momentum (`mt`) to stabilize updates,
48
+ especially in noisy, small-batch settings. If `False`, the
49
+ optimizer behaves as standard ADOPT. (default: False)
50
+ beta3_ema (float): The decay rate for the slow exponential moving average of
51
+ the momentum (only used when `use_AdEMAMix` is `True`). A higher
52
+ value (e.g., 0.9999) gives the EMA a longer memory, making it more
53
+ stable but slower to adapt. A lower value (e.g., 0.999) is often
54
+ better for shorter training runs. (default: 0.9999)
55
+ alpha (float): The mixing coefficient that scales the slow momentum term
56
+ before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
57
+ A higher value increases the stabilizing influence of the slow
58
+ momentum. (default: 5.0)
59
+ t_alpha (Optional[int]): The number of steps for a linear warmup of the
60
+ `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
61
+ highly recommended to prevent instability at the beginning of training,
62
+ as it gradually introduces the stabilizing slow momentum term. During
63
+ the warmup, `alpha` ramps from 0 to its target value. If `None`,
64
+ the scheduler is disabled and the full `alpha` value is used from
65
+ the start. (default: None)
66
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
67
+ This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
68
+ more responsive, especially for small batch sizes. Enabling this will
69
+ automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
70
+ and `use_atan2`. (default: False)
71
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
72
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
73
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
74
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
75
+ stability. (default: 100.0)
76
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
77
+ If `False`, the optimizer behaves as standard Adopt. (default: False)
78
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
79
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
80
+ (default: 0.88)
81
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
82
+ the pooled gradient norms. Corresponds to `α` in the paper.
83
+ (default: 0.93)
84
+ tiny_spike (float): A small constant added to the denominator of the
85
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
86
+ to `ε_spike` in the paper. (default: 1e-9)
87
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
88
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
89
+ dynamic logic activates. (default: 0)
90
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
91
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
92
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
93
+ logging (default: 0).
94
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
95
+ and returns a unique, hashable key representing its "layer" or "bucket".
96
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
97
+ (default: None)
98
+ nnmf_factor (bool): whether to use the factorization or disable it to use
99
+ the uncompressed optimizer. (default: False)
100
+ """
101
+
102
+ def __init__(
103
+ self,
104
+ params,
105
+ lr: float = 1e-4,
106
+ betas: tuple[float, float] = (0.9, 0.9999),
107
+ eps: float = 1e-6,
108
+ weight_decay: float = 0.0,
109
+ clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
110
+ vector_reshape: bool = True,
111
+ stochastic_rounding: bool = True,
112
+ use_atan2: bool = False,
113
+ cautious_mask: bool = False,
114
+ grams_moment: bool = False,
115
+ orthogonal_gradient: bool = False,
116
+ use_AdEMAMix: bool = False,
117
+ beta3_ema: float = 0.9999,
118
+ alpha: float = 5.0,
119
+ t_alpha: int | None = None,
120
+ Simplified_AdEMAMix: bool = False,
121
+ alpha_grad: float = 100.0,
122
+ kourkoutas_beta: bool = False,
123
+ beta2_min: float = 0.88,
124
+ ema_alpha: float = 0.93,
125
+ tiny_spike: float = 1e-9,
126
+ k_warmup_steps: int = 0,
127
+ k_logging: int = 0,
128
+ layer_key_fn: Optional[Callable] = None,
129
+ nnmf_factor: bool = False,
130
+ ):
131
+ if not (lr >= 0.0):
132
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
133
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
134
+ raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
135
+ if not (eps >= 0.0):
136
+ raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
137
+ if not (weight_decay >= 0.0):
138
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
139
+ if cautious_mask and grams_moment:
140
+ print("Warning: cautious is incompatible with grams, Disabling cautious.")
141
+ cautious_mask = False
142
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
143
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
144
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
145
+ if use_AdEMAMix and Simplified_AdEMAMix:
146
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
147
+ if grams_moment and Simplified_AdEMAMix:
148
+ print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
149
+ if cautious_mask and Simplified_AdEMAMix:
150
+ print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
151
+ if use_atan2 and Simplified_AdEMAMix:
152
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
153
+ use_atan2 = False
154
+
155
+ defaults = {
156
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
157
+ "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
158
+ "t_alpha": t_alpha, "alpha_grad": alpha_grad,
159
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
160
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
161
+ }
162
+ self.clip_lambda = clip_lambda
163
+ self.stochastic_rounding = stochastic_rounding
164
+ self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
165
+ self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
166
+ self.grams_moment = grams_moment and not Simplified_AdEMAMix
167
+ self.orthogonal_gradient = orthogonal_gradient
168
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
169
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
170
+ self.factored = nnmf_factor
171
+ self.kourkoutas_beta = kourkoutas_beta
172
+ self.layer_key_fn = layer_key_fn
173
+ super().__init__(params, defaults)
174
+
175
+ if self.kourkoutas_beta:
176
+ self.kourkoutas_helper = KourkoutasHelper(self)
177
+
178
+ @property
179
+ def supports_fused_back_pass(self): return True
180
+ @property
181
+ def supports_memory_efficient_fp16(self): return True
182
+ @property
183
+ def supports_flat_params(self): return False
184
+
185
+ @torch.no_grad()
186
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
187
+ if p.grad is None:
188
+ return
189
+
190
+ grad = p.grad
191
+ if self.factored and grad.dtype != torch.float32:
192
+ grad = grad.float()
193
+ if self.orthogonal_gradient:
194
+ grad = _orthogonalize_gradient(p, grad)
195
+ state = self.state[p]
196
+
197
+ # State Initialization
198
+ if len(state) == 0:
199
+ state['step'] = 0
200
+
201
+ should_factor = (
202
+ self.factored and
203
+ not (len(p.shape) == 1 and not group['vector_reshape'])
204
+ )
205
+
206
+ state['factored'] = should_factor
207
+
208
+ dtype = torch.float32 if self.factored else p.dtype
209
+
210
+ if state['factored']:
211
+ state['effective_shape'] = _get_effective_shape(p.numel())
212
+ d1, d2 = state['effective_shape']
213
+
214
+ # m_0 = 0
215
+ if group['betas'][0] > 0:
216
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
217
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
218
+ if not self.grams_moment:
219
+ packed_d2 = (d2 + 7) // 8
220
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
221
+ if self.use_AdEMAMix:
222
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
223
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
224
+ packed_d2 = (d2 + 7) // 8
225
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
226
+ # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
227
+ vt_init = grad.view(d1, d2).square_()
228
+ # Allocate NMF factors for v
229
+ state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
230
+ state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
231
+ # Initialize v_0 using NMF
232
+ _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
233
+ else: # Fallback for non-factored tensors
234
+ if group['betas'][0] > 0:
235
+ state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
236
+ if self.use_AdEMAMix:
237
+ state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
238
+ state['exp_avg_sq'] = grad.square() # v_0
239
+
240
+ beta1, beta2 = group['betas']
241
+
242
+ current_step = state['step']
243
+ if group['kourkoutas_beta']:
244
+ # Call prepare_step() once at the beginning of the step for all params
245
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
246
+ # Accumulate current grad's norm for the *next* step
247
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
248
+ # Get the dynamic beta2 calculated in prepare_step()
249
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
250
+
251
+ # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
252
+ if state['step'] == 0 and not self.use_atan2:
253
+ state['step'] += 1
254
+ return
255
+
256
+ if self.use_AdEMAMix:
257
+ beta3_ema = group['beta3_ema']
258
+ alpha = group['alpha']
259
+ t_alpha = group['t_alpha']
260
+ # Use step+1 for 1-based step count in scheduler
261
+ alpha_step = state['step'] + 1
262
+ alpha_t = alpha
263
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
264
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
265
+ if self.Simplified_AdEMAMix:
266
+ alpha_grad = group["alpha_grad"]
267
+
268
+ if state['factored']:
269
+ d1, d2 = state['effective_shape']
270
+
271
+ # Reconstruct m_{t-1}
272
+ if beta1 > 0:
273
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
274
+ if not self.grams_moment:
275
+ if state['sign'].dtype != torch.uint8:
276
+ state['sign'] = state['sign'].to(torch.uint8)
277
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
278
+ torch.where(unpacked_sign, mt, -mt, out=mt)
279
+ del unpacked_sign
280
+
281
+ # Reconstruct AdEMAMix EMA
282
+ if self.use_AdEMAMix:
283
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
284
+ if state['sign_slow'].dtype != torch.uint8:
285
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
286
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
287
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
288
+ del unpacked_sign_slow
289
+
290
+ # Reconstruct v_{t-1} using NNMF
291
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
292
+
293
+ # ADOPT Step A: Decorrelate g_t using v_{t-1}
294
+ grad_reshaped = grad.view(d1, d2)
295
+ denom = vt.sqrt()
296
+
297
+ if self.use_atan2:
298
+ normalized_grad = torch.atan2(grad_reshaped, denom)
299
+ else:
300
+ normalized_grad = grad_reshaped / denom.add_(group['eps'])
301
+ if self.clip_lambda is not None:
302
+ clip_val = self.clip_lambda(state['step'])
303
+ normalized_grad.clamp_(-clip_val, clip_val)
304
+ del denom
305
+
306
+ # ADOPT Step B: Update momentum m_t using normalized gradient
307
+ if beta1 > 0:
308
+ if self.Simplified_AdEMAMix:
309
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
310
+ else:
311
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
312
+ if self.grams_moment:
313
+ mt = grad_reshaped.sign() * mt.abs()
314
+ elif self.cautious_mask:
315
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
316
+ mask.div_(mask.mean().clamp_(min=1e-3))
317
+ mt.mul_(mask)
318
+ del mask
319
+
320
+ if self.use_AdEMAMix:
321
+ mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
322
+ if beta1 > 0:
323
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
324
+ else:
325
+ update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
326
+ elif self.Simplified_AdEMAMix:
327
+ update = torch.add(mt, normalized_grad, alpha=alpha_grad)
328
+ else:
329
+ update = mt.clone() if beta1 > 0 else normalized_grad
330
+
331
+ update = update.view(p.shape)
332
+
333
+ if self.use_atan2:
334
+ update.mul_(group['lr'] * 1.2732395447351628)
335
+ else:
336
+ update.mul_(group['lr'])
337
+
338
+ # Update second moment v_t for the *next* step using raw g_t
339
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
340
+ del grad_reshaped
341
+
342
+ # Compress and store new factors
343
+ if beta1 > 0:
344
+ if not self.grams_moment:
345
+ state['sign'] = _pack_bools(mt > 0)
346
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
347
+ del mt
348
+
349
+ if self.use_AdEMAMix:
350
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
351
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
352
+ del mt_slow
353
+
354
+ # factorize v_t using NMF compression
355
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
356
+ del vt
357
+
358
+ else: # Standard ADOPT logic for non-factored tensors
359
+ v = state['exp_avg_sq'] # v_{t-1}
360
+
361
+ # ADOPT Step A: Decorrelate g_t using v_{t-1}
362
+ denom = v.sqrt()
363
+
364
+ if self.use_atan2:
365
+ normalized_grad = torch.atan2(grad, denom)
366
+ else:
367
+ normalized_grad = grad / denom.add_(group['eps'])
368
+ if self.clip_lambda is not None:
369
+ clip_val = self.clip_lambda(state['step'])
370
+ normalized_grad.clamp_(-clip_val, clip_val)
371
+ del denom
372
+
373
+ # ADOPT Step B: Update momentum m_t
374
+ if beta1 > 0:
375
+ m = state['exp_avg'] # m_{t-1},
376
+ if self.Simplified_AdEMAMix:
377
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0)
378
+ else:
379
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
380
+
381
+ if self.grams_moment:
382
+ m = grad.sign() * m.abs()
383
+ elif self.cautious_mask:
384
+ mask = (m * grad > 0).to(grad.dtype)
385
+ mask.div_(mask.mean().clamp_(min=1e-3))
386
+ m.mul_(mask)
387
+ del mask
388
+
389
+ if self.use_AdEMAMix:
390
+ m_slow = state['exp_avg_slow']
391
+ m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
392
+ if beta1 > 0:
393
+ update = torch.add(m, m_slow, alpha=alpha_t)
394
+ else:
395
+ update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
396
+ elif self.Simplified_AdEMAMix:
397
+ update = torch.add(m, normalized_grad, alpha=alpha_grad)
398
+ else:
399
+ update = m.clone() if beta1 > 0 else normalized_grad
400
+
401
+ if self.use_atan2:
402
+ update.mul_(group['lr'] * 1.2732395447351628)
403
+ else:
404
+ update.mul_(group['lr'])
405
+
406
+ # Update second moment v_t for the next step using raw g_t
407
+ v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
408
+
409
+ # Parameter Update
410
+ if group["weight_decay"] != 0:
411
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
412
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
413
+ else:
414
+ p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
415
+
416
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
417
+ add_stochastic_(p.data, -update)
418
+ else:
419
+ p.data.add_(-update)
420
+ del update
421
+
422
+ state['step'] += 1
423
+
424
+ @torch.no_grad()
425
+ def step(self, closure=None):
426
+ """Performs a single optimization step."""
427
+ loss = None
428
+ if closure is not None:
429
+ with torch.enable_grad():
430
+ loss = closure()
431
+
432
+ for group in self.param_groups:
433
+ for i, p in enumerate(group['params']):
434
+ self.step_parameter(p, group, i)
435
+
440
436
  return loss
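
The ADOPT ordering described in the class docstring above (initialise v from the first gradient, normalise the current gradient with the previous step's v, then fold the normalised gradient into the momentum) can be restated as a short, self-contained PyTorch sketch. The sketch is illustrative only and is not part of either release: the helper name adopt_toy_step is invented here, and the import path in the commented usage lines is an assumption, since the diff only shows relative imports inside the package.

import torch

def adopt_toy_step(p, grad, state, lr=1e-4, beta1=0.9, beta2=0.9999,
                   eps=1e-6, clip_lambda=lambda step: step ** 0.25):
    # First call only initialises state: v_0 = g_0^2, no parameter update.
    if "step" not in state:
        state["m"] = torch.zeros_like(p)
        state["v"] = grad.square()
        state["step"] = 1
        return
    # Step A: decorrelate g_t using v_{t-1} from the previous step.
    normalized = grad / state["v"].sqrt().add_(eps)
    clip = clip_lambda(state["step"])
    normalized.clamp_(-clip, clip)
    # Step B: the momentum update sees the normalised gradient, not raw g_t.
    state["m"].mul_(beta1).add_(normalized, alpha=1.0 - beta1)
    p.add_(state["m"], alpha=-lr)
    # v_t for the next step is accumulated from the raw gradient.
    state["v"].mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    state["step"] += 1

# Hypothetical usage of the packaged optimizer (module path assumed, not
# confirmed by this diff):
# from adv_optm import Adopt_adv
# optimizer = Adopt_adv(model.parameters(), lr=1e-4, nnmf_factor=True)

As a side note, the constant 1.2732395447351628 that the class multiplies into the learning rate on the atan2 path is 4/π, which presumably rescales atan2's output (π/4 when the gradient and the denominator have equal magnitude) back to 1.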