adv-optm 0.1.9__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

@@ -1,478 +1,475 @@
1
- import torch
2
- import torch.distributed as dist
3
-
4
- import math
5
-
6
- from ..util.BF16_Stochastic_Rounding import add_stochastic_
7
- from ..util.Effective_Shape import _get_effective_shape
8
- from ..util.NNMF import _nnmf,_unnmf
9
- from ..util.OrthoGrad import _orthogonalize_gradient
10
- from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
11
-
12
- class Prodigy_adv(torch.optim.Optimizer):
13
- """
14
- Implements a factored Prodigy/AdamW algorithm.
15
- This is an advanced version of Prodigy with optional features like
16
- low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
17
-
18
- Args:
19
- params (iterable): iterable of parameters to optimize or dicts defining
20
- parameter groups
21
- lr (float): learning rate (default: 1)
22
- betas (tuple[float, float]): coefficients used for computing running
23
- averages of gradient and its square (default: (0.9, 0.999))
24
- eps (float): term added to the denominator to improve
25
- numerical stability (default: 1e-8)
26
- weight_decay (float): weight decay (L2 penalty) (default: 0)
27
- vector_reshape (bool): whether to reshape 1D vectors into 2D
28
- matrices to apply low-rank compression (default: True).
29
- stochastic_rounding (bool): whether to use stochastic
30
- rounding for BF16 parameter updates (default: True).
31
- use_atan2 (bool): whether to use the atan2 update rule. (default: False)
32
- use_grams (bool): whether to use Grams-style updates. (default: False)
33
- use_cautious (bool): whether to use cautious masking to align the gradient's
34
- direction with the first moment's. (default: False)
35
- use_orthograd (bool): whether to use OrthoGrad. (default: False)
36
- use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
37
- a second, slow-moving average of the momentum (`mt_slow`) which is
38
- combined with the primary momentum (`mt`) to stabilize updates,
39
- especially in noisy, small-batch settings. If `False`, the
40
- optimizer behaves as standard AdamW. (default: False)
41
- beta3_ema (float): The decay rate for the slow exponential moving average of
42
- the momentum (only used when `use_AdEMAMix` is `True`). A higher
43
- value (e.g., 0.9999) gives the EMA a longer memory, making it more
44
- stable but slower to adapt. A lower value (e.g., 0.999) is often
45
- better for shorter training runs. (default: 0.9999)
46
- alpha (float): The mixing coefficient that scales the slow momentum term
47
- before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
48
- A higher value increases the stabilizing influence of the slow
49
- momentum. (default: 5.0)
50
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
51
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
52
- highly recommended to prevent instability at the beginning of training,
53
- as it gradually introduces the stabilizing slow momentum term. During
54
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
55
- the scheduler is disabled.
56
- Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
57
- This changes the EMA into an accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
58
- more responsive, especially for small batch sizes. Enabling this will
59
- automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
60
- and `use_atan2`. (default: False)
61
- alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
62
- (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
63
- current gradient. For small batch sizes, use high values (e.g., 10-100) to be
64
- more responsive. For large batch sizes, use low values (e.g., 0-1) for
65
- stability. (default: 100.0)
66
- factored (bool): whether to use the factorization or disable it to use
67
- the uncompressed optimizer. (default: False)
68
- d0 (float):
69
- Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
70
- d_coef (float):
71
- Coefficient in the expression for the estimate of d (default 1.0).
72
- Values such as 0.5 and 2.0 typically work as well.
73
- Changing this parameter is the preferred way to tune the method.
74
- growth_rate (float):
75
- prevent the D estimate from growing faster than this multiplicative rate.
76
- Default is inf, for unrestricted. Values like 1.02 give a kind of learning
77
- rate warmup effect.
78
- fsdp_in_use (bool):
79
- If you're using sharded parameters, this should be set to True. The optimizer
80
- will attempt to auto-detect this, but if you're using an implementation other
81
- than PyTorch's builtin version, the auto-detection won't work.
82
- slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
83
- p-th entry of each tensor. For values greater than 1 this is an approximation to standard
84
- Prodigy. Values ~11 are reasonable (default 11).
85
- prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
86
- after the specified optimiser step and release all state memory required by Prodigy
87
- (default: 0).
88
- """
89
-
90
- def __init__(
91
- self,
92
- params,
93
- lr: float = 1,
94
- betas: tuple[float, float] = (0.9, 0.999),
95
- eps: float = 1e-8,
96
- weight_decay: float = 0.0,
97
- vector_reshape: bool = True,
98
- stochastic_rounding: bool = True,
99
- use_atan2: bool = False,
100
- use_cautious: bool = False,
101
- use_grams: bool = False,
102
- use_orthograd: bool = False,
103
- use_AdEMAMix: bool = False,
104
- beta3_ema: float = 0.9999,
105
- alpha: float = 5.0,
106
- t_alpha: int | None = None,
107
- Simplified_AdEMAMix: bool = False,
108
- alpha_grad: float = 100.0,
109
- factored: bool = False,
110
- # prodigy parameters
111
- beta3: float = None,
112
- d0: float = 1e-6,
113
- d_coef: float = 1,
114
- growth_rate: float = float('inf'),
115
- safeguard_warmup: bool = False,
116
- fsdp_in_use: bool = False,
117
- slice_p: int = 11,
118
- prodigy_steps: int = 0,
119
- ):
120
- if not (lr >= 0.0):
121
- raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
122
- if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
123
- raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
124
- if not (eps >= 0.0):
125
- raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
126
- if not (weight_decay >= 0.0):
127
- raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
128
- if not (prodigy_steps >= 0):
129
- raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
130
- if use_cautious and use_grams:
131
- print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
132
- use_cautious = False
133
- if betas[0] == 0.0 and Simplified_AdEMAMix:
134
- raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
135
- if use_AdEMAMix and Simplified_AdEMAMix:
136
- print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
137
- if use_grams and Simplified_AdEMAMix:
138
- print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
139
- if use_cautious and Simplified_AdEMAMix:
140
- print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
141
- if use_atan2 and Simplified_AdEMAMix:
142
- print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
143
- use_atan2 = False
144
-
145
- defaults = {
146
- "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
147
- "vector_reshape": vector_reshape, "use_atan2": use_atan2,
148
- "use_orthograd": use_orthograd,
149
- "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
150
- "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
151
- "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
152
- "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
153
- "alpha_grad": alpha_grad,
154
- }
155
- self.stochastic_rounding = stochastic_rounding
156
- self.use_cautious = use_cautious and not Simplified_AdEMAMix
157
- self.use_grams = use_grams and not Simplified_AdEMAMix
158
- self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
159
- self.Simplified_AdEMAMix = Simplified_AdEMAMix
160
- self.factored = factored
161
- self.fsdp_in_use = fsdp_in_use
162
- super().__init__(params, defaults)
163
- self.init_step()
164
-
165
- @property
166
- def supports_fused_back_pass(self):
167
- return True
168
-
169
- @property
170
- def supports_memory_efficient_fp16(self):
171
- return True
172
-
173
- @property
174
- def supports_flat_params(self):
175
- return False
176
-
177
- def init_step(self):
178
- """Resets accumulators and calculates dlr for the upcoming step."""
179
- self.d_denom = 0.0
180
-
181
- g_group = self.param_groups[0]
182
- self.beta1, self.beta2 = g_group['betas']
183
- self.beta3 = g_group['beta3']
184
- if self.beta3 is None:
185
- self.beta3 = math.sqrt(self.beta2)
186
-
187
- k = g_group['k']
188
- self.d = g_group['d']
189
- lr = g_group['lr']
190
-
191
- self.dlr = self.d * lr
192
-
193
- self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
194
-
195
- @torch.no_grad()
196
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
197
- if p.grad is None:
198
- return
199
-
200
- if hasattr(p, "_fsdp_flattened"):
201
- self.fsdp_in_use = True
202
-
203
- grad = p.grad
204
- if grad.dtype != torch.float32 and self.factored:
205
- grad = grad.float()
206
- if group["use_orthograd"]:
207
- grad = _orthogonalize_gradient(p, grad)
208
- state = self.state[p]
209
-
210
- # State Initialization
211
- if len(state) == 0:
212
- state['step'] = 0
213
-
214
- should_factor = (
215
- self.factored and
216
- not (len(p.shape) == 1 and not group['vector_reshape'])
217
- )
218
-
219
- state['factored'] = should_factor
220
-
221
- slice_p = group['slice_p']
222
-
223
- dtype = torch.float32 if self.factored else p.dtype
224
- device = p.device
225
-
226
- if state['factored']:
227
- state['effective_shape'] = _get_effective_shape(p.numel())
228
- d1, d2 = state['effective_shape']
229
-
230
- # First moment (m)
231
- if self.beta1 > 0:
232
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
233
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
234
- if not self.use_grams:
235
- packed_d2 = (d2 + 7) // 8
236
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
237
- if self.use_AdEMAMix:
238
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
239
- state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
240
- packed_d2 = (d2 + 7) // 8
241
- state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
242
- # Second moment (v)
243
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
244
- state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
245
- else: # Fallback to standard AdamW for non-factored tensors
246
- if self.beta1 > 0:
247
- state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
248
- if self.use_AdEMAMix:
249
- state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
250
- state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
251
-
252
- state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
253
- if p.any():
254
- state['p0'] = p.flatten()[::slice_p].detach().clone()
255
- else:
256
- state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
257
-
258
- if self.use_AdEMAMix:
259
- beta3_ema = group['beta3_ema']
260
- alpha = group['alpha']
261
- t_alpha = group['t_alpha']
262
- current_step = state['step'] + 1
263
- alpha_t = alpha
264
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
265
- alpha_t = min(current_step * alpha / t_alpha, alpha)
266
- if self.Simplified_AdEMAMix:
267
- alpha_grad = group["alpha_grad"]
268
-
269
- if state['factored']:
270
- d1, d2 = state['effective_shape']
271
-
272
- grad_reshaped = grad.view(d1, d2)
273
-
274
- # Reconstruct momentum from previous step's factors
275
- if self.beta1 > 0:
276
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
277
- if not self.use_grams:
278
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
279
- torch.where(unpacked_sign, mt, -mt, out=mt)
280
- del unpacked_sign
281
- # Update momentum in full-size
282
- if self.Simplified_AdEMAMix:
283
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
284
- else:
285
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
286
- if self.use_grams:
287
- mt.copy_(grad_reshaped.sign() * mt.abs())
288
- elif self.use_cautious:
289
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
290
- mask.div_(mask.mean().clamp_(min=1e-3))
291
- mt.mul_(mask)
292
- del mask
293
-
294
- vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
295
- vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
296
-
297
- if self.use_AdEMAMix:
298
- mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
299
- if state['sign_slow'].dtype != torch.uint8:
300
- state['sign_slow'] = state['sign_slow'].to(torch.uint8)
301
- unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
302
- torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
303
- del unpacked_sign_slow
304
- mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
305
- if self.beta1 > 0:
306
- update = torch.add(mt, mt_slow, alpha=alpha_t)
307
- else:
308
- update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
309
- elif self.Simplified_AdEMAMix:
310
- update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
311
- else:
312
- update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
313
- del grad_reshaped
314
-
315
- if group['use_atan2']:
316
- a = 1.2732395
317
- denom = vt.sqrt()
318
- update.atan2_(denom).mul_(a)
319
- else:
320
- denom = vt.sqrt()
321
- update.div_(denom.add_(self.d * group['eps']))
322
- del denom
323
-
324
- update = update.view(p.shape).mul_(self.dlr)
325
-
326
- # Compress updated moments and store new factors
327
- if self.beta1 > 0:
328
- if not self.use_grams:
329
- state['sign'] = _pack_bools(mt > 0)
330
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
331
- del mt
332
- if self.use_AdEMAMix:
333
- state['sign_slow'] = _pack_bools(mt_slow > 0)
334
- _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
335
- del mt_slow
336
- _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
337
- del vt
338
-
339
- else: # Standard AdamW logic for non-factored tensors
340
- exp_avg_sq = state['exp_avg_sq']
341
-
342
- if self.beta1 > 0:
343
- exp_avg = state['exp_avg']
344
- if self.Simplified_AdEMAMix:
345
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
346
- else:
347
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
348
- if self.use_grams:
349
- exp_avg = grad.sign() * exp_avg.abs()
350
- elif self.use_cautious:
351
- mask = (exp_avg * grad > 0).to(grad.dtype)
352
- mask.div_(mask.mean().clamp_(min=1e-3))
353
- exp_avg.mul_(mask)
354
- del mask
355
-
356
- if self.use_AdEMAMix:
357
- exp_avg_slow = state['exp_avg_slow']
358
- exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
359
- if self.beta1 > 0:
360
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
361
- else:
362
- update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
363
- elif self.Simplified_AdEMAMix:
364
- update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
365
- else:
366
- update = exp_avg.clone() if self.beta1 > 0 else grad.clone()
367
-
368
- exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
369
-
370
- if group['use_atan2']:
371
- a = 1.2732395
372
- denom = exp_avg_sq.sqrt()
373
- update.atan2_(denom).mul_(a)
374
- else:
375
- denom = exp_avg_sq.sqrt()
376
- update.div_(denom.add_(self.d * group['eps']))
377
- del denom
378
-
379
- update.mul_(self.dlr)
380
-
381
- # --- Accumulate Prodigy stats ---
382
- prodigy_steps = group['prodigy_steps']
383
- if prodigy_steps <= 0 or group['k'] < prodigy_steps:
384
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
385
- s, p0 = state['s'], state['p0']
386
- grad_flat = grad.flatten().float()
387
- p_flat = p.data.flatten().float()
388
- p0 = p0.float()
389
-
390
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
391
-
392
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
393
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
394
- self.d_denom += s.abs().sum().item()
395
-
396
- del s, p0, grad_flat, p_flat, alpha
397
- else:
398
- # Free memory if prodigy_steps is reached
399
- if 's' in state:
400
- del state['s']
401
- if 'p0' in state:
402
- del state['p0']
403
-
404
- # Decoupled weight decay
405
- if group["weight_decay"] != 0:
406
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
407
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
408
- else:
409
- p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)
410
-
411
- if p.dtype == torch.bfloat16 and self.stochastic_rounding:
412
- add_stochastic_(p.data, -update)
413
- else:
414
- p.data.add_(-update)
415
- del update
416
-
417
- state['step'] += 1
418
-
419
- @torch.no_grad()
420
- def step(self, closure=None):
421
- """Performs a single optimization step."""
422
- loss = None
423
- if closure is not None:
424
- with torch.enable_grad():
425
- loss = closure()
426
-
427
- for group in self.param_groups:
428
- for i, p in enumerate(group['params']):
429
- self.step_parameter(p, group, i)
430
-
431
-
432
- self.calculate_d()
433
- self.init_step()
434
- return loss
435
-
436
- def calculate_d(self):
437
- """Calculates the new `d` based on the accumulated stats."""
438
- g_group = self.param_groups[0]
439
-
440
- # Only perform d-adaptation if prodigy_steps has not been reached
441
- prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
442
-
443
- if prodigy_active:
444
- d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
445
-
446
- if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
447
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
448
- device = self.param_groups[0]['params'][0].device
449
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
450
- dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
451
- global_d_numerator = dist_tensor[0].item()
452
- global_d_denom = dist_tensor[1].item()
453
- else:
454
- global_d_numerator = self.d_numerator
455
- global_d_denom = self.d_denom
456
-
457
- d_hat = self.d
458
- if global_d_denom > 0:
459
- if self.Simplified_AdEMAMix and g_group['alpha_grad'] > 0:
460
- # A simple and effective hack to make Prodigy compatible with Simplified_AdEMAMix's large step sizes:
461
- # by dividing by alpha_grad we make sure that the d_numerator contributions that were influenced by (alpha_grad * grad)
462
- # are now normalized by alpha_grad. This is a heuristic, since the update is also influenced by
463
- # the increasing and decaying accumulator, but it is effective and it worked for me (for LoRA/finetune).
464
- global_d_numerator /= g_group['alpha_grad']
465
- d_hat = d_coef * global_d_numerator / global_d_denom
466
- if self.d == g_group['d0']:
467
- self.d = max(self.d, d_hat)
468
- d_max = max(d_max, d_hat)
469
- self.d = min(d_max, self.d * growth_rate)
470
-
471
- for group in self.param_groups:
472
- group['d_numerator'] = global_d_numerator
473
- group['d'] = self.d
474
- group['d_max'] = d_max
475
-
476
- # Increment step counter for all groups, regardless of whether d was updated
477
- for group in self.param_groups:
478
- group['k'] += 1
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ import math
5
+
6
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
7
+ from ..util.Effective_Shape import _get_effective_shape
8
+ from ..util.NNMF import _nnmf,_unnmf
9
+ from ..util.OrthoGrad import _orthogonalize_gradient
10
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
11
+
12
+ class Prodigy_adv(torch.optim.Optimizer):
13
+ """
14
+ Implements a factored Prodigy/AdamW algorithm.
15
+ This is an advanced version of Prodigy with optional features like
16
+ low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
17
+
18
+ Args:
19
+ params (iterable): iterable of parameters to optimize or dicts defining
20
+ parameter groups
21
+ lr (float): learning rate (default: 1)
22
+ betas (tuple[float, float]): coefficients used for computing running
23
+ averages of gradient and its square (default: (0.9, 0.999))
24
+ eps (float): term added to the denominator to improve
25
+ numerical stability (default: 1e-8)
26
+ weight_decay (float): weight decay (L2 penalty) (default: 0)
27
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
28
+ matrices to apply low-rank compression (default: True).
29
+ stochastic_rounding (bool): whether to use stochastic
30
+ rounding for BF16 parameter updates (default: True).
31
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
32
+ grams_moment (bool): whether to use Grams-style updates. (default: False)
33
+ cautious_mask (bool): whether to use cautious masking to align the gradient's
34
+ direction with the first moment's. (default: False)
35
+ use_orthograd (bool): whether to use OrthoGrad. (default: False)
36
+ use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
37
+ a second, slow-moving average of the momentum (`mt_slow`) which is
38
+ combined with the primary momentum (`mt`) to stabilize updates,
39
+ especially in noisy, small-batch settings. If `False`, the
40
+ optimizer behaves as standard AdamW. (default: False)
41
+ beta3_ema (float): The decay rate for the slow exponential moving average of
42
+ the momentum (only used when `use_AdEMAMix` is `True`). A higher
43
+ value (e.g., 0.9999) gives the EMA a longer memory, making it more
44
+ stable but slower to adapt. A lower value (e.g., 0.999) is often
45
+ better for shorter training runs. (default: 0.9999)
46
+ alpha (float): The mixing coefficient that scales the slow momentum term
47
+ before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
48
+ A higher value increases the stabilizing influence of the slow
49
+ momentum. (default: 5.0)
50
+ t_alpha (Optional[int]): The number of steps for a linear warmup of the
51
+ `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
52
+ highly recommended to prevent instability at the beginning of training,
53
+ as it gradually introduces the stabilizing slow momentum term. During
54
+ the warmup, `alpha` ramps from 0 to its target value. If `None`,
55
+ the scheduler is disabled.
56
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
57
+ This changes the EMA into an accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
58
+ more responsive, especially for small batch sizes. Enabling this will
59
+ automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
60
+ and `use_atan2`. (default: False)
61
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
62
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
63
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
64
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
65
+ stability. (default: 100.0)
66
+ nnmf_factor (bool): whether to use the factorization or disable it to use
67
+ the uncompressed optimizer. (default: False)
68
+ d0 (float):
69
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
70
+ d_coef (float):
71
+ Coefficient in the expression for the estimate of d (default 1.0).
72
+ Values such as 0.5 and 2.0 typically work as well.
73
+ Changing this parameter is the preferred way to tune the method.
74
+ growth_rate (float):
75
+ prevent the D estimate from growing faster than this multiplicative rate.
76
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
77
+ rate warmup effect.
78
+ fsdp_in_use (bool):
79
+ If you're using sharded parameters, this should be set to True. The optimizer
80
+ will attempt to auto-detect this, but if you're using an implementation other
81
+ than PyTorch's builtin version, the auto-detection won't work.
82
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
83
+ p-th entry of each tensor. For values greater than 1 this is an approximation to standard
84
+ Prodigy. Values ~11 are reasonable (default 11).
85
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
86
+ after the specified optimiser step and release all state memory required by Prodigy
87
+ (default: 0).
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ params,
93
+ lr: float = 1,
94
+ betas: tuple[float, float] = (0.9, 0.999),
95
+ eps: float = 1e-8,
96
+ weight_decay: float = 0.0,
97
+ vector_reshape: bool = True,
98
+ stochastic_rounding: bool = True,
99
+ use_atan2: bool = False,
100
+ cautious_mask: bool = False,
101
+ grams_moment: bool = False,
102
+ use_orthograd: bool = False,
103
+ use_AdEMAMix: bool = False,
104
+ beta3_ema: float = 0.9999,
105
+ alpha: float = 5.0,
106
+ t_alpha: int | None = None,
107
+ Simplified_AdEMAMix: bool = False,
108
+ alpha_grad: float = 100.0,
109
+ nnmf_factor: bool = False,
110
+ # prodigy parameters
111
+ beta3: float = None,
112
+ d0: float = 1e-6,
113
+ d_coef: float = 1,
114
+ growth_rate: float = float('inf'),
115
+ safeguard_warmup: bool = False,
116
+ fsdp_in_use: bool = False,
117
+ slice_p: int = 11,
118
+ prodigy_steps: int = 0,
119
+ ):
120
+ if not (lr >= 0.0):
121
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
122
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
123
+ raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
124
+ if not (eps >= 0.0):
125
+ raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
126
+ if not (weight_decay >= 0.0):
127
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
128
+ if not (prodigy_steps >= 0):
129
+ raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
130
+ if cautious_mask and grams_moment:
131
+ print("Warning: cautious is incompatible with grams, Disabling cautious.")
132
+ cautious_mask = False
133
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
134
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
135
+ if use_AdEMAMix and Simplified_AdEMAMix:
136
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
137
+ if grams_moment and Simplified_AdEMAMix:
138
+ print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
139
+ if cautious_mask and Simplified_AdEMAMix:
140
+ print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
141
+ if use_atan2 and Simplified_AdEMAMix:
142
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
143
+ use_atan2 = False
144
+ if Simplified_AdEMAMix and alpha_grad > 0:
145
+ # Scale d_coef by alpha_grad; this forces Prodigy to behave well with Simplified_AdEMAMix.
146
+ d_coef = d_coef/alpha_grad
147
+
148
+ defaults = {
149
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
150
+ "vector_reshape": vector_reshape, "use_atan2": use_atan2,
151
+ "use_orthograd": use_orthograd,
152
+ "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
153
+ "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
154
+ "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
155
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
156
+ "alpha_grad": alpha_grad,
157
+ }
158
+ self.stochastic_rounding = stochastic_rounding
159
+ self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
160
+ self.grams_moment = grams_moment and not Simplified_AdEMAMix
161
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
162
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
163
+ self.factored = nnmf_factor
164
+ self.fsdp_in_use = fsdp_in_use
165
+ super().__init__(params, defaults)
166
+ self.init_step()
167
+
168
+ @property
169
+ def supports_fused_back_pass(self):
170
+ return True
171
+
172
+ @property
173
+ def supports_memory_efficient_fp16(self):
174
+ return True
175
+
176
+ @property
177
+ def supports_flat_params(self):
178
+ return False
179
+
180
+ def init_step(self):
181
+ """Resets accumulators and calculates dlr for the upcoming step."""
182
+ self.d_denom = 0.0
183
+
184
+ g_group = self.param_groups[0]
185
+ self.beta1, self.beta2 = g_group['betas']
186
+ self.beta3 = g_group['beta3']
187
+ if self.beta3 is None:
188
+ self.beta3 = math.sqrt(self.beta2)
189
+
190
+ k = g_group['k']
191
+ self.d = g_group['d']
192
+ lr = g_group['lr']
193
+
194
+ self.dlr = self.d * lr
195
+
196
+ self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
197
+
198
+ @torch.no_grad()
199
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
200
+ if p.grad is None:
201
+ return
202
+
203
+ if hasattr(p, "_fsdp_flattened"):
204
+ self.fsdp_in_use = True
205
+
206
+ grad = p.grad
207
+ if grad.dtype != torch.float32 and self.factored:
208
+ grad = grad.float()
209
+ if group["use_orthograd"]:
210
+ grad = _orthogonalize_gradient(p, grad)
211
+ state = self.state[p]
212
+
213
+ # State Initialization
214
+ if len(state) == 0:
215
+ state['step'] = 0
216
+
217
+ should_factor = (
218
+ self.factored and
219
+ not (len(p.shape) == 1 and not group['vector_reshape'])
220
+ )
221
+
222
+ state['factored'] = should_factor
223
+
224
+ slice_p = group['slice_p']
225
+
226
+ dtype = torch.float32 if self.factored else p.dtype
227
+ device = p.device
228
+
229
+ if state['factored']:
230
+ state['effective_shape'] = _get_effective_shape(p.numel())
231
+ d1, d2 = state['effective_shape']
232
+
233
+ # First moment (m)
234
+ if self.beta1 > 0:
235
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
236
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
237
+ if not self.grams_moment:
238
+ packed_d2 = (d2 + 7) // 8
239
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
240
+ if self.use_AdEMAMix:
241
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
242
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
243
+ packed_d2 = (d2 + 7) // 8
244
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
245
+ # Second moment (v)
246
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
247
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
248
+ else: # Fallback to standard AdamW for non-factored tensors
249
+ if self.beta1 > 0:
250
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
251
+ if self.use_AdEMAMix:
252
+ state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
253
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
254
+
255
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
256
+ if p.any():
257
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
258
+ else:
259
+ state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
260
+
261
+ if self.use_AdEMAMix:
262
+ beta3_ema = group['beta3_ema']
263
+ alpha = group['alpha']
264
+ t_alpha = group['t_alpha']
265
+ current_step = state['step'] + 1
266
+ alpha_t = alpha
267
+ if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
268
+ alpha_t = min(current_step * alpha / t_alpha, alpha)
269
+ if self.Simplified_AdEMAMix:
270
+ alpha_grad = group["alpha_grad"]
271
+
272
+ if state['factored']:
273
+ d1, d2 = state['effective_shape']
274
+
275
+ grad_reshaped = grad.view(d1, d2)
276
+
277
+ # Reconstruct momentum from previous step's factors
278
+ if self.beta1 > 0:
279
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
280
+ if not self.grams_moment:
281
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
282
+ torch.where(unpacked_sign, mt, -mt, out=mt)
283
+ del unpacked_sign
284
+ # Update momentum in full-size
285
+ if self.Simplified_AdEMAMix:
286
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
287
+ else:
288
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
289
+ if self.grams_moment:
290
+ mt.copy_(grad_reshaped.sign() * mt.abs())
291
+ elif self.cautious_mask:
292
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
293
+ mask.div_(mask.mean().clamp_(min=1e-3))
294
+ mt.mul_(mask)
295
+ del mask
296
+
297
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
298
+ vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
299
+
300
+ if self.use_AdEMAMix:
301
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
302
+ if state['sign_slow'].dtype != torch.uint8:
303
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
304
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
305
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
306
+ del unpacked_sign_slow
307
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
308
+ if self.beta1 > 0:
309
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
310
+ else:
311
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
312
+ elif self.Simplified_AdEMAMix:
313
+ update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
314
+ else:
315
+ update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
316
+ del grad_reshaped
317
+
318
+ if group['use_atan2']:
319
+ a = 1.2732395
320
+ denom = vt.sqrt()
321
+ update.atan2_(denom).mul_(a)
322
+ else:
323
+ denom = vt.sqrt()
324
+ update.div_(denom.add_(self.d * group['eps']))
325
+ del denom
326
+
327
+ update = update.view(p.shape).mul_(self.dlr)
328
+
329
+ # Compress updated moments and store new factors
330
+ if self.beta1 > 0:
331
+ if not self.grams_moment:
332
+ state['sign'] = _pack_bools(mt > 0)
333
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
334
+ del mt
335
+ if self.use_AdEMAMix:
336
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
337
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
338
+ del mt_slow
339
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
340
+ del vt
341
+
342
+ else: # Standard AdamW logic for non-factored tensors
343
+ exp_avg_sq = state['exp_avg_sq']
344
+
345
+ if self.beta1 > 0:
346
+ exp_avg = state['exp_avg']
347
+ if self.Simplified_AdEMAMix:
348
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
349
+ else:
350
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
351
+ if self.grams_moment:
352
+ exp_avg = grad.sign() * exp_avg.abs()
353
+ elif self.cautious_mask:
354
+ mask = (exp_avg * grad > 0).to(grad.dtype)
355
+ mask.div_(mask.mean().clamp_(min=1e-3))
356
+ exp_avg.mul_(mask)
357
+ del mask
358
+
359
+ if self.use_AdEMAMix:
360
+ exp_avg_slow = state['exp_avg_slow']
361
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
362
+ if self.beta1 > 0:
363
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
364
+ else:
365
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
366
+ elif self.Simplified_AdEMAMix:
367
+ update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
368
+ else:
369
+ update = exp_avg.clone() if self.beta1 > 0 else grad.clone()
370
+
371
+ exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
372
+
373
+ if group['use_atan2']:
374
+ a = 1.2732395
375
+ denom = exp_avg_sq.sqrt()
376
+ update.atan2_(denom).mul_(a)
377
+ else:
378
+ denom = exp_avg_sq.sqrt()
379
+ update.div_(denom.add_(self.d * group['eps']))
380
+ del denom
381
+
382
+ update.mul_(self.dlr)
383
+
384
+ # --- Accumulate Prodigy stats ---
385
+ prodigy_steps = group['prodigy_steps']
386
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
387
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
388
+ s, p0 = state['s'], state['p0']
389
+ grad_flat = grad.flatten().float()
390
+ p_flat = p.data.flatten().float()
391
+ p0 = p0.float()
392
+
393
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
394
+
395
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
396
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
397
+ self.d_denom += s.abs().sum().item()
398
+
399
+ del s, p0, grad_flat, p_flat, alpha
400
+ else:
401
+ # Free memory if prodigy_steps is reached
402
+ if 's' in state:
403
+ del state['s']
404
+ if 'p0' in state:
405
+ del state['p0']
406
+
407
+ # Decoupled weight decay
408
+ if group["weight_decay"] != 0:
409
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
410
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
411
+ else:
412
+ p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)
413
+
414
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
415
+ add_stochastic_(p.data, -update)
416
+ else:
417
+ p.data.add_(-update)
418
+ del update
419
+
420
+ state['step'] += 1
421
+
422
+ @torch.no_grad()
423
+ def step(self, closure=None):
424
+ """Performs a single optimization step."""
425
+ loss = None
426
+ if closure is not None:
427
+ with torch.enable_grad():
428
+ loss = closure()
429
+
430
+ for group in self.param_groups:
431
+ for i, p in enumerate(group['params']):
432
+ self.step_parameter(p, group, i)
433
+
434
+
435
+ self.calculate_d()
436
+ self.init_step()
437
+ return loss
438
+
439
+ def calculate_d(self):
440
+ """Calculates the new `d` based on the accumulated stats."""
441
+ g_group = self.param_groups[0]
442
+
443
+ # Only perform d-adaptation if prodigy_steps has not been reached
444
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
445
+
446
+ if prodigy_active:
447
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
448
+
449
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
450
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
451
+ device = self.param_groups[0]['params'][0].device
452
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
453
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
454
+ global_d_numerator = dist_tensor[0].item()
455
+ global_d_denom = dist_tensor[1].item()
456
+ else:
457
+ global_d_numerator = self.d_numerator
458
+ global_d_denom = self.d_denom
459
+
460
+ d_hat = self.d
461
+ if global_d_denom > 0:
462
+ d_hat = d_coef * global_d_numerator / global_d_denom
463
+ if self.d == g_group['d0']:
464
+ self.d = max(self.d, d_hat)
465
+ d_max = max(d_max, d_hat)
466
+ self.d = min(d_max, self.d * growth_rate)
467
+
468
+ for group in self.param_groups:
469
+ group['d_numerator'] = global_d_numerator
470
+ group['d'] = self.d
471
+ group['d_max'] = d_max
472
+
473
+ # Increment step counter for all groups, regardless of whether d was updated
474
+ for group in self.param_groups:
475
+ group['k'] += 1
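
The most visible API change between 0.1.9 and 1.0.1 is the renaming of three constructor arguments: use_cautious becomes cautious_mask, use_grams becomes grams_moment, and factored becomes nnmf_factor. A minimal migration sketch follows, assuming the class is importable as adv_optm.Prodigy_adv; the import path and the toy model are placeholders, not taken from this diff:

    import torch
    from adv_optm import Prodigy_adv  # assumed import path

    model = torch.nn.Linear(128, 128)

    # 0.1.9-style construction (old argument names):
    # opt = Prodigy_adv(model.parameters(), lr=1.0, use_cautious=True, factored=True)

    # 1.0.1-style construction (argument names from the diff above):
    opt = Prodigy_adv(
        model.parameters(),
        lr=1.0,              # Prodigy-style optimizers are normally run with lr around 1
        cautious_mask=True,  # was: use_cautious
        grams_moment=False,  # was: use_grams
        nnmf_factor=True,    # was: factored
    )

    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()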
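
The other functional change is where the Simplified_AdEMAMix correction for Prodigy is applied: 0.1.9 divided global_d_numerator by alpha_grad inside calculate_d (and wrote the divided value back into the group), while 1.0.1 divides d_coef by alpha_grad once in __init__. For a single d update the two placements are interchangeable, since d_hat = d_coef * d_numerator / d_denom; a tiny arithmetic check, not taken from the package:

    # Where the 1/alpha_grad factor is applied does not change a single d_hat estimate.
    d_coef, alpha_grad = 1.0, 100.0
    d_numerator, d_denom = 4.2, 7.0

    d_hat_old = d_coef * (d_numerator / alpha_grad) / d_denom  # 0.1.9: scale the numerator in calculate_d
    d_hat_new = (d_coef / alpha_grad) * d_numerator / d_denom  # 1.0.1: scale d_coef once in __init__
    assert abs(d_hat_old - d_hat_new) < 1e-12

Across steps the behaviour is not identical, though: in 0.1.9 the already-divided numerator was stored and rescaled again on subsequent calls to calculate_d, whereas 1.0.1 accumulates the numerator unscaled and applies the factor only through d_coef.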
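
The class also exposes supports_fused_back_pass together with a per-parameter step_parameter(p, group, i), which suggests it can be driven one tensor at a time during the backward pass instead of via step(). Below is a minimal sketch of such wiring using PyTorch's post-accumulate-grad hooks; the hook setup is an illustration and not part of adv-optm, and the import path is again assumed:

    import torch
    from adv_optm import Prodigy_adv  # assumed import path

    model = torch.nn.Linear(64, 64)
    opt = Prodigy_adv(model.parameters())

    if getattr(opt, "supports_fused_back_pass", False):
        for group in opt.param_groups:
            for i, p in enumerate(group["params"]):
                def hook(param, group=group, i=i):
                    opt.step_parameter(param, group, i)  # update this tensor as soon as its grad is ready
                    param.grad = None                    # free the gradient immediately
                p.register_post_accumulate_grad_hook(hook)

    loss = model(torch.randn(8, 64)).pow(2).mean()
    loss.backward()    # parameter updates happen inside backward via the hooks; opt.step() is not called
    opt.calculate_d()  # finish the Prodigy d-adaptation for this step, mirroring what step() does
    opt.init_step()    # reset accumulators for the next step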