adv-optm 1.2.0 (adv_optm-1.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,539 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ import math
5
+
6
+ from typing import Optional, Callable
7
+
8
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
9
+ from ..util.Effective_Shape import _get_effective_shape
10
+ from ..util.NNMF import _nnmf,_unnmf
11
+ from ..util.OrthoGrad import _orthogonalize_gradient
12
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
13
+ from ..util.Kourkoutas import KourkoutasHelper
14
+
15
+ class Prodigy_adv(torch.optim.Optimizer):
16
+ """
17
+ Implements an advanced Prodigy algorithm.
18
+ This is an advanced version of Prodigy with optional features like
19
+ low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
20
+
21
+ Args:
22
+ params (iterable): iterable of parameters to optimize or dicts defining
23
+ parameter groups
24
+ lr (float): learning rate multiplier applied on top of the adaptive step size `d` (default: 1)
25
+ betas (tuple[float, float]): coefficients used for computing running
26
+ averages of gradient and its square (default: (0.9, 0.999))
27
+ eps (float): term added to the denominator to improve
28
+ numerical stability (default: 1e-8)
29
+ weight_decay (float): weight decay (L2 penalty) (default: 0)
30
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
31
+ matrices to apply low-rank compression (default: True).
32
+ stochastic_rounding (bool): whether to use stochastic
33
+ rounding for BF16 parameter updates (default: True).
34
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
35
+ grams_moment (bool): whether to use Grams-style updates. (default: False)
36
+ cautious_mask (bool): whether to use cautious masking to align the gradient's
37
+ direction with the first moment's. (default: False)
38
+ orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
39
+ use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
40
+ a second, slow-moving average of the momentum (`mt_slow`) which is
41
+ combined with the primary momentum (`mt`) to stabilize updates,
42
+ especially in noisy, small-batch settings. If `False`, the
43
+ optimizer falls back to its standard Prodigy update. (default: False)
44
+ beta3_ema (float): The decay rate for the slow exponential moving average of
45
+ the momentum (only used when `use_AdEMAMix` is `True`). A higher
46
+ value (e.g., 0.9999) gives the EMA a longer memory, making it more
47
+ stable but slower to adapt. A lower value (e.g., 0.999) is often
48
+ better for shorter training runs. (default: 0.9999)
49
+ alpha (float): The mixing coefficient that scales the slow momentum term
50
+ before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
51
+ A higher value increases the stabilizing influence of the slow
52
+ momentum. (default: 5.0)
53
+ t_alpha (Optional[int]): The number of steps for a linear warmup of the
54
+ `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
55
+ highly recommended to prevent instability at the beginning of training,
56
+ as it gradually introduces the stabilizing slow momentum term. During
57
+ the warmup, `alpha` ramps from 0 to its target value. If `None`,
58
+ the scheduler is disabled. (default: None)
59
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
60
+ This changes the EMA into an accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
61
+ more responsive, especially for small batch sizes. Enabling this will
62
+ automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
63
+ and `use_atan2`. (default: False)
64
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
65
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
66
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
67
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
68
+ stability. (default: 100.0)
69
+ nnmf_factor (bool): whether to apply the low-rank (NNMF) factorization to the optimizer
70
+ states; if disabled, the uncompressed optimizer states are used. (default: False)
71
+ d0 (float):
72
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
73
+ d_coef (float):
74
+ Coefficient in the expression for the estimate of d (default 1.0).
75
+ Values such as 0.5 and 2.0 typically work as well.
76
+ Changing this parameter is the preferred way to tune the method.
77
+ growth_rate (float):
78
+ Prevent the D estimate from growing faster than this multiplicative rate.
79
+ Default is inf (unrestricted). Values like 1.02 give a kind of learning
80
+ rate warmup effect.
81
+ fsdp_in_use (bool):
82
+ If you're using sharded parameters, this should be set to True. The optimizer
83
+ will attempt to auto-detect this, but if you're using an implementation other
84
+ than PyTorch's builtin version, the auto-detection won't work.
85
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
86
+ pth entry of each tensor. For values greater than 1 this is an approximation to standard
87
+ Prodigy. Values ~11 are reasonable (default 11).
88
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
89
+ after the specified optimizer step and release all state memory required by Prodigy
90
+ (default: 0).
91
+ d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
92
+ to prevent sudden, volatile increases in the adaptive step size (`d`).
93
+ (default: False)
94
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
95
+ If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
96
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
97
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
98
+ (default: 0.9)
99
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
100
+ the pooled gradient norms. Corresponds to `α` in the paper.
101
+ (default: 0.95)
102
+ tiny_spike (float): A small constant added to the denominator of the
103
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
104
+ to `ε_spike` in the paper. (default: 1e-9)
105
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
106
+ at a fixed β₂ value before the dynamic logic activates.
107
+ (default: 0)
108
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
109
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
110
+ every `k_logging` steps. Useful for debugging and tuning. Set to 0 to disable
111
+ logging (default: 0).
112
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
113
+ and returns a unique, hashable key representing its "layer" or "bucket".
114
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
115
+ (default: None)
116
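+
+ Example:
+ A minimal usage sketch (the `adv_optm` import path, the model, and the training
+ loop below are illustrative assumptions, not part of this module):
+
+ from adv_optm import Prodigy_adv
+
+ model = torch.nn.Linear(10, 2)
+ optimizer = Prodigy_adv(model.parameters(), lr=1.0, weight_decay=0.01)
+
+ for inputs, targets in dataloader:
+     optimizer.zero_grad()
+     loss = torch.nn.functional.mse_loss(model(inputs), targets)
+     loss.backward()
+     optimizer.step()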
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ params,
121
+ lr: float = 1,
122
+ betas: tuple[float, float] = (0.9, 0.999),
123
+ eps: float = 1e-8,
124
+ weight_decay: float = 0.0,
125
+ vector_reshape: bool = True,
126
+ stochastic_rounding: bool = True,
127
+ use_atan2: bool = False,
128
+ cautious_mask: bool = False,
129
+ grams_moment: bool = False,
130
+ orthogonal_gradient: bool = False,
131
+ use_AdEMAMix: bool = False,
132
+ beta3_ema: float = 0.9999,
133
+ alpha: float = 5.0,
134
+ t_alpha: int | None = None,
135
+ Simplified_AdEMAMix: bool = False,
136
+ alpha_grad: float = 100.0,
137
+ nnmf_factor: bool = False,
138
+ # prodigy parameters
139
+ beta3: float | None = None,
140
+ d0: float = 1e-6,
141
+ d_coef: float = 1,
142
+ growth_rate: float = float('inf'),
143
+ safeguard_warmup: bool = False,
144
+ fsdp_in_use: bool = False,
145
+ slice_p: int = 11,
146
+ prodigy_steps: int = 0,
147
+ d_limiter: bool = False,
148
+ # K-b parameters
149
+ kourkoutas_beta: bool = False,
150
+ beta2_min: float = 0.9,
151
+ ema_alpha: float = 0.95,
152
+ tiny_spike: float = 1e-9,
153
+ k_warmup_steps: int = 0,
154
+ k_logging: int = 0,
155
+ layer_key_fn: Optional[Callable] = None,
156
+ ):
157
+ if not (lr >= 0.0):
158
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
159
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
160
+ raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
161
+ if not (eps >= 0.0):
162
+ raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
163
+ if not (weight_decay >= 0.0):
164
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
165
+ if not (prodigy_steps >= 0):
166
+ raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
167
+ if cautious_mask and grams_moment:
168
+ print("Warning: cautious is incompatible with grams, Disabling cautious.")
169
+ cautious_mask = False
170
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
171
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
172
+ if use_AdEMAMix and Simplified_AdEMAMix:
173
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
174
+ if grams_moment and Simplified_AdEMAMix:
175
+ print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
176
+ if cautious_mask and Simplified_AdEMAMix:
177
+ print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
178
+ if use_atan2 and Simplified_AdEMAMix:
179
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
180
+ use_atan2 = False
181
+ if kourkoutas_beta and not (betas[1] > beta2_min):
182
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
183
+ if Simplified_AdEMAMix and alpha_grad > 0 and not d_limiter:
184
+ # Scale d_coef by 1/alpha_grad; this forces Prodigy to behave well with Simplified_AdEMAMix.
185
+ d_coef = d_coef/alpha_grad
186
+
187
+ defaults = {
188
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
189
+ "vector_reshape": vector_reshape, "use_atan2": use_atan2,
190
+ "orthogonal_gradient": orthogonal_gradient,
191
+ "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
192
+ "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
193
+ "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
194
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
195
+ "alpha_grad": alpha_grad,
196
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
197
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
198
+ }
199
+ self.stochastic_rounding = stochastic_rounding
200
+ self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
201
+ self.grams_moment = grams_moment and not Simplified_AdEMAMix
202
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
203
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
204
+ self.factored = nnmf_factor
205
+ self.fsdp_in_use = fsdp_in_use
206
+
207
+ self.kourkoutas_beta = kourkoutas_beta
208
+ self.layer_key_fn = layer_key_fn
209
+
210
+ super().__init__(params, defaults)
211
+ if self.kourkoutas_beta:
212
+ self.kourkoutas_helper = KourkoutasHelper(self)
213
+ self.init_step()
214
+
215
+ @property
216
+ def supports_fused_back_pass(self):
217
+ return True
218
+
219
+ @property
220
+ def supports_memory_efficient_fp16(self):
221
+ return True
222
+
223
+ @property
224
+ def supports_flat_params(self):
225
+ return False
226
+
227
+ def init_step(self):
228
+ """Resets accumulators and calculates dlr for the upcoming step."""
229
+ self.d_denom = 0.0
230
+
231
+ g_group = self.param_groups[0]
232
+ self.beta1, self.beta2_default = g_group['betas']
233
+ self.beta3 = g_group['beta3']
234
+ if self.beta3 is None:
235
+ self.beta3 = math.sqrt(self.beta2_default)
236
+
237
+ self.d = g_group['d']
238
+ lr = g_group['lr']
239
+
240
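+ # Effective step size for this optimization step: the adapted scale d times the user lr.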
+ self.dlr = self.d * lr
241
+ self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
242
+
243
+ @torch.no_grad()
244
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
245
+ if p.grad is None:
246
+ return
247
+
248
+ if hasattr(p, "_fsdp_flattened"):
249
+ self.fsdp_in_use = True
250
+
251
+ grad = p.grad
252
+ if grad.dtype != torch.float32 and self.factored:
253
+ grad = grad.float()
254
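+ # OrthoGrad (hedged summary): _orthogonalize_gradient is expected to remove the
+ # component of grad parallel to p, so only the orthogonal part enters the moments.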
+ if group["orthogonal_gradient"]:
255
+ grad = _orthogonalize_gradient(p, grad)
256
+ state = self.state[p]
257
+
258
+ # State Initialization
259
+ if 'step' not in state:
260
+ state['step'] = 0
261
+
262
+ should_factor = (
263
+ self.factored and
264
+ not (len(p.shape) == 1 and not group['vector_reshape'])
265
+ )
266
+
267
+ state['factored'] = should_factor
268
+
269
+ slice_p = group['slice_p']
270
+
271
+ dtype = torch.float32 if self.factored else p.dtype
272
+ device = p.device
273
+
274
+ if state['factored']:
275
+ state['effective_shape'] = _get_effective_shape(p.numel())
276
+ d1, d2 = state['effective_shape']
277
+
278
+ # First moment (m)
279
+ if self.beta1 > 0:
280
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
281
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
282
+ if not self.grams_moment:
283
+ packed_d2 = (d2 + 7) // 8
284
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
285
+ if self.use_AdEMAMix:
286
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
287
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
288
+ packed_d2 = (d2 + 7) // 8
289
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
290
+ # Second moment (v)
291
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
292
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
293
+ else: # Fallback to standard AdamW for non-factored tensors
294
+ if self.beta1 > 0:
295
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
296
+ if self.use_AdEMAMix:
297
+ state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
298
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
299
+
300
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
301
+ if p.any():
302
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
303
+ else:
304
+ state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
305
+
306
+ current_step = state['step']
307
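+ # Kourkoutas-β (hedged summary of KourkoutasHelper): per-layer gradient norms are
+ # pooled and smoothed with an EMA (ema_alpha); a "sunspike" ratio then moves β₂
+ # between beta2_min and betas[1]. For the first k_warmup_steps a fixed β₂ is used.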
+ if group.get('kourkoutas_beta', False):
308
+ # Call prepare_step() once at the beginning of the step for all params
309
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
310
+ # Accumulate current grad's norm for the *next* step
311
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
312
+ # Get the dynamic beta2 calculated in prepare_step()
313
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
314
+ else:
315
+ beta2 = self.beta2_default
316
+
317
+ if self.use_AdEMAMix:
318
+ beta3_ema = group['beta3_ema']
319
+ alpha = group['alpha']
320
+ t_alpha = group['t_alpha']
321
+ alpha_step = state['step'] + 1
322
+ alpha_t = alpha
323
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
324
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
325
+ if self.Simplified_AdEMAMix:
326
+ alpha_grad = group["alpha_grad"]
327
+
328
+ if state['factored']:
329
+ d1, d2 = state['effective_shape']
330
+
331
+ grad_reshaped = grad.view(d1, d2)
332
+
333
+ # Reconstruct momentum from previous step's factors
334
+ if self.beta1 > 0:
335
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
336
+ if not self.grams_moment:
337
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
338
+ torch.where(unpacked_sign, mt, -mt, out=mt)
339
+ del unpacked_sign
340
+ # Update momentum in full-size
341
+ if self.Simplified_AdEMAMix:
342
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
343
+ else:
344
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
345
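+ # Grams: use the gradient's sign with the momentum's magnitude.
+ # Cautious: keep only momentum entries whose sign agrees with the gradient,
+ # rescaled by the inverse of the mask's mean.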
+ if self.grams_moment:
346
+ update_mt = (grad_reshaped.sign().mul_(mt.abs()))
347
+ elif self.cautious_mask:
348
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
349
+ mask.div_(mask.mean().clamp_(min=1e-3))
350
+ update_mt = mt.mul(mask)
351
+ del mask
352
+ else:
353
+ update_mt = mt.clone()
354
+
355
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
356
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
357
+
358
+ if self.use_AdEMAMix:
359
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
360
+ if state['sign_slow'].dtype != torch.uint8:
361
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
362
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
363
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
364
+ del unpacked_sign_slow
365
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
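+ # AdEMAMix numerator: fast momentum (or d * grad when beta1 == 0) plus alpha_t * slow EMA.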
366
+ if self.beta1 > 0:
367
+ update = torch.add(update_mt, mt_slow, alpha=alpha_t)
368
+ else:
369
+ update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
370
+ elif self.Simplified_AdEMAMix:
371
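+ # Simplified AdEMAMix numerator: mt + alpha_grad * d * grad, where mt acts as an
+ # accumulator (no (1 - beta1) factor was applied above).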
+ update = torch.add(update_mt, grad_reshaped, alpha=alpha_grad * self.d)
372
+ else:
373
+ update = update_mt if self.beta1 > 0 else grad_reshaped.mul(self.d)
374
+ del grad_reshaped
375
+
376
+ if group['use_atan2']:
377
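+ # atan2 update: replaces update / (sqrt(v) + eps) with atan2(update, sqrt(v)),
+ # which needs no eps and is bounded; 1.2732395 ≈ 4/pi rescales the output
+ # (each element ends up in roughly [-2, 2] before multiplying by dlr).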
+ a = 1.2732395
378
+ denom = vt.sqrt()
379
+ update.atan2_(denom).mul_(a)
380
+ else:
381
+ denom = vt.sqrt()
382
+ update.div_(denom.add_(self.d * group['eps']))
383
+ del denom
384
+
385
+ update = update.view(p.shape).mul_(self.dlr)
386
+
387
+ # Compress updated moments and store new factors
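+ # Signs are packed to 1 bit per element (_pack_bools); magnitudes are stored as
+ # rank-1 non-negative factors via _nnmf; vt is non-negative and factorized directly.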
388
+ if self.beta1 > 0:
389
+ if not self.grams_moment:
390
+ state['sign'] = _pack_bools(mt > 0)
391
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
392
+ del mt
393
+ if self.use_AdEMAMix:
394
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
395
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
396
+ del mt_slow
397
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
398
+ del vt
399
+
400
+ else: # Standard AdamW logic for non-factored tensors
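+ # Same update rules as the factored branch above, operating on dense exp_avg /
+ # exp_avg_sq states instead of reconstructed NNMF factors.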
401
+ exp_avg_sq = state['exp_avg_sq']
402
+
403
+ if self.beta1 > 0:
404
+ exp_avg = state['exp_avg']
405
+ if self.Simplified_AdEMAMix:
406
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
407
+ else:
408
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
409
+ if self.grams_moment:
410
+ update_mt = grad.sign().mul_(exp_avg.abs())
411
+ elif self.cautious_mask:
412
+ mask = (exp_avg * grad > 0).to(grad.dtype)
413
+ mask.div_(mask.mean().clamp_(min=1e-3))
414
+ update_mt = exp_avg.mul(mask)
415
+ del mask
416
+ else:
417
+ update_mt = exp_avg.clone()
418
+
419
+ if self.use_AdEMAMix:
420
+ exp_avg_slow = state['exp_avg_slow']
421
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
422
+ if self.beta1 > 0:
423
+ update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
424
+ else:
425
+ update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
426
+ elif self.Simplified_AdEMAMix:
427
+ update = torch.add(update_mt, grad, alpha=alpha_grad * self.d)
428
+ else:
429
+ update = update_mt if self.beta1 > 0 else grad.mul(self.d)
430
+
431
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
432
+
433
+ if group['use_atan2']:
434
+ a = 1.2732395
435
+ denom = exp_avg_sq.sqrt()
436
+ update.atan2_(denom).mul_(a)
437
+ else:
438
+ denom = exp_avg_sq.sqrt()
439
+ update.div_(denom.add_(self.d * group['eps']))
440
+ del denom
441
+
442
+ update.mul_(self.dlr)
443
+
444
+ # --- Accumulate Prodigy stats ---
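+ # Per Prodigy: d_numerator accumulates dlr * (d/d0) * <g, x0 - x> and d_denom
+ # accumulates sum(|s|), both over every slice_p-th element; calculate_d() later
+ # forms d_hat = d_coef * d_numerator / d_denom.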
445
+ prodigy_steps = group['prodigy_steps']
446
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
447
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
448
+ s, p0 = state['s'], state['p0']
449
+ grad_flat = grad.flatten().float()
450
+ p_flat = p.data.flatten().float()
451
+ p0 = p0.float()
452
+
453
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
454
+
455
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
456
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
457
+ self.d_denom += s.abs().sum().item()
458
+
459
+ del s, p0, grad_flat, p_flat, alpha
460
+ else:
461
+ # Free memory if prodigy_steps is reached
462
+ if 's' in state:
463
+ del state['s']
464
+ if 'p0' in state:
465
+ del state['p0']
466
+
467
+ # Decoupled weight decay
468
+ if group["weight_decay"] != 0:
469
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
470
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
471
+ else:
472
+ p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)
473
+
474
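+ # BF16 parameters are updated via add_stochastic_ (stochastic rounding) so that
+ # small updates are not systematically lost to rounding.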
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
475
+ add_stochastic_(p.data, -update)
476
+ else:
477
+ p.data.add_(-update)
478
+ del update
479
+
480
+ state['step'] += 1
481
+
482
+ @torch.no_grad()
483
+ def step(self, closure=None):
484
+ """Performs a single optimization step."""
485
+ loss = None
486
+ if closure is not None:
487
+ with torch.enable_grad():
488
+ loss = closure()
489
+
490
+ for group in self.param_groups:
491
+ for i, p in enumerate(group['params']):
492
+ self.step_parameter(p, group, i)
493
+
494
+ self.calculate_d()
495
+ self.init_step()
496
+ return loss
497
+
498
+ def calculate_d(self):
499
+ """Calculates the new `d` based on the accumulated stats."""
500
+ g_group = self.param_groups[0]
501
+
502
+ # Only perform d-adaptation if prodigy_steps has not been reached
503
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
504
+
505
+ if prodigy_active:
506
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
507
+
508
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
509
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
510
+ device = self.param_groups[0]['params'][0].device
511
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
512
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
513
+ global_d_numerator = dist_tensor[0].item()
514
+ global_d_denom = dist_tensor[1].item()
515
+ else:
516
+ global_d_numerator = self.d_numerator
517
+ global_d_denom = self.d_denom
518
+
519
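+ # Prodigy D-adaptation: d_hat = d_coef * numerator / denominator; d is
+ # non-decreasing and its per-step growth is capped by growth_rate (and, when
+ # enabled, by the d_limiter clamp below).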
+ d_hat = self.d
520
+ if global_d_denom > 0:
521
+ d_hat = d_coef * global_d_numerator / global_d_denom
522
+ if g_group.get('d_limiter', False):
523
+ if self.Simplified_AdEMAMix:
524
+ d_hat = min(self.d * (2 ** 0.1), d_hat)
525
+ else:
526
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
527
+ if self.d == g_group['d0']:
528
+ self.d = max(self.d, d_hat)
529
+ d_max = max(d_max, d_hat)
530
+ self.d = min(d_max, self.d * growth_rate)
531
+
532
+ for group in self.param_groups:
533
+ group['d_numerator'] = global_d_numerator
534
+ group['d'] = self.d
535
+ group['d_max'] = d_max
536
+
537
+ # Increment step counter for all groups, regardless of whether d was updated
538
+ for group in self.param_groups:
539
+ group['k'] += 1