adv_optm-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,437 @@
+ import torch
+ from typing import Callable, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper
+
+ class Adopt_adv(torch.optim.Optimizer):
+     """
+     Implements an advanced ADOPT algorithm.
+
+     The ADOPT update rule modifies Adam by:
+     1. **Initialization:** The second moment `vt` is initialized as `v₀ = g₀²`.
+     2. **Decorrelation:** The current gradient is normalized using the second-moment estimate
+        from the *previous* step (`v_{t-1}`).
+     3. **Order of Operations:** This normalization occurs *before* updating the
+        first-moment (momentum) estimate.
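+
+     As a simplified sketch of the default path (ignoring `eps`, gradient clipping,
+     weight decay, and the optional features documented below):
+
+         v₀ = g₀²
+         ĝ_t = g_t / sqrt(v_{t-1})          (Step A: normalize with the previous v)
+         m_t = β₁·m_{t-1} + (1 - β₁)·ĝ_t    (Step B: momentum of the normalized grad)
+         θ_t = θ_{t-1} - lr·m_t
+         v_t = β₂·v_{t-1} + (1 - β₂)·g_t²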
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-4)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of momentum and variance (default: (0.9, 0.9999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-6)
+         weight_decay (float): weight decay (L2 penalty) (default: 0)
+         clip_lambda (Callable, optional): A function that takes the current step
+             and returns a value used to clip the normalized gradient. Only used when
+             `use_atan2` is False. (default: `lambda step: step**0.25`)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices for low-rank compression (default: True).
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         use_atan2 (bool): whether to use an atan2-based normalization, which can
+             improve stability by removing the need for `eps`. (default: False)
+         cautious_mask (bool): whether to use cautious masking, which zeroes the parts
+             of the update whose sign disagrees with the current gradient. (default: False)
+         grams_moment (bool): whether to combine the gradient's direction with the
+             first moment's magnitude (default: False).
+         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+             a second, slow-moving average of the momentum (`mt_slow`) which is
+             combined with the primary momentum (`mt`) to stabilize updates,
+             especially in noisy, small-batch settings. If `False`, the
+             optimizer behaves as standard ADOPT. (default: False)
+         beta3_ema (float): The decay rate for the slow exponential moving average of
+             the momentum (only used when `use_AdEMAMix` is `True`). A higher
+             value (e.g., 0.9999) gives the EMA a longer memory, making it more
+             stable but slower to adapt. A lower value (e.g., 0.999) is often
+             better for shorter training runs. (default: 0.9999)
+         alpha (float): The mixing coefficient that scales the slow momentum term
+             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
+             A higher value increases the stabilizing influence of the slow
+             momentum. (default: 5.0)
+         t_alpha (Optional[int]): The number of steps for a linear warmup of the
+             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+             highly recommended to prevent instability at the beginning of training,
+             as it gradually introduces the stabilizing slow momentum term. During
+             the warmup, `alpha` ramps from 0 to its target value. If `None`,
+             the scheduler is disabled and the full `alpha` value is used from
+             the start. (default: None)
+         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+             This turns the EMA into an accumulator and changes the update numerator to
+             `alpha_grad * grad + mt`, which can be more responsive, especially for
+             small batch sizes. Enabling this will automatically disable `use_AdEMAMix`,
+             `cautious_mask`, `grams_moment`, and `use_atan2`. (default: False)
+         alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+             (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+             more responsive. For large batch sizes, use low values (e.g., 0-1) for
+             stability. (default: 100.0)
+         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+             If `False`, the optimizer behaves as standard ADOPT. (default: False)
+         beta2_min (float): The minimum value for dynamic β₂, used during periods of
+             high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+             (default: 0.9)
+         ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+             the pooled gradient norms. Corresponds to `α` in the paper.
+             (default: 0.95)
+         tiny_spike (float): A small constant added to the denominator of the
+             "sunspike" ratio calculation to prevent division by zero. Corresponds
+             to `ε_spike` in the paper. (default: 1e-9)
+         k_warmup_steps (int): The number of initial steps during which β₂ is held
+             at a fixed value before the dynamic logic activates. (default: 0)
+         k_logging (int): if > 0 and `kourkoutas_beta=True`, enables periodic console
+             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+             every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+             disable logging. (default: 0)
+         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+             and returns a unique, hashable key representing its "layer" or "bucket".
+             If `None`, parameters are bucketed by their memory ID (tensor-wise).
+             (default: None)
+         nnmf_factor (bool): whether to use NNMF factorization to compress the optimizer
+             state; if `False`, the uncompressed optimizer is used. (default: False)
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-4,
+         betas: tuple[float, float] = (0.9, 0.9999),
+         eps: float = 1e-6,
+         weight_decay: float = 0.0,
+         clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         use_atan2: bool = False,
+         cautious_mask: bool = False,
+         grams_moment: bool = False,
+         orthogonal_gradient: bool = False,
+         use_AdEMAMix: bool = False,
+         beta3_ema: float = 0.9999,
+         alpha: float = 5.0,
+         t_alpha: int | None = None,
+         Simplified_AdEMAMix: bool = False,
+         alpha_grad: float = 100.0,
+         kourkoutas_beta: bool = False,
+         beta2_min: float = 0.9,
+         ema_alpha: float = 0.95,
+         tiny_spike: float = 1e-9,
+         k_warmup_steps: int = 0,
+         k_logging: int = 0,
+         layer_key_fn: Optional[Callable] = None,
+         nnmf_factor: bool = False,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight decay should be >= 0.0. Got {weight_decay}")
+         if cautious_mask and grams_moment:
+             print("Warning: cautious_mask is incompatible with grams_moment; disabling cautious_mask.")
+             cautious_mask = False
+         if betas[0] == 0.0 and Simplified_AdEMAMix:
+             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+         if kourkoutas_beta and not (betas[1] > beta2_min):
+             raise ValueError(
+                 f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. "
+                 f"Got {betas[1]} and {beta2_min}"
+             )
+         if use_AdEMAMix and Simplified_AdEMAMix:
+             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix; disabling use_AdEMAMix.")
+         if grams_moment and Simplified_AdEMAMix:
+             print("Warning: grams_moment is incompatible with Simplified_AdEMAMix; disabling grams_moment.")
+         if cautious_mask and Simplified_AdEMAMix:
+             print("Warning: cautious_mask is incompatible with Simplified_AdEMAMix; disabling cautious_mask.")
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
+             "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+         }
+         self.clip_lambda = clip_lambda
+         self.stochastic_rounding = stochastic_rounding
+         self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
+         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
+         self.grams_moment = grams_moment and not Simplified_AdEMAMix
+         self.orthogonal_gradient = orthogonal_gradient
+         self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+         self.Simplified_AdEMAMix = Simplified_AdEMAMix
+         self.factored = nnmf_factor
+         self.kourkoutas_beta = kourkoutas_beta
+         self.layer_key_fn = layer_key_fn
+         super().__init__(params, defaults)
+
+         if self.kourkoutas_beta:
+             self.kourkoutas_helper = KourkoutasHelper(self)
+
+     @property
+     def supports_fused_back_pass(self): return True
+     @property
+     def supports_memory_efficient_fp16(self): return True
+     @property
+     def supports_flat_params(self): return False
+
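+     # step_parameter() updates a single parameter; step() below simply calls it in a
+     # loop. Trainers that support a fused back pass can instead call it from a
+     # gradient hook as soon as each parameter's gradient is ready
+     # (see supports_fused_back_pass above).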
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if self.factored and grad.dtype != torch.float32:
+             grad = grad.float()
+         if self.orthogonal_gradient:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization
+         if 'step' not in state:
+             state['step'] = 0
+
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+
+             if state['factored']:
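+                 # Factored path: view the parameter's elements as a (d1, d2) matrix and
+                 # store each moment as rank-1 NNMF factors (a length-d1 and a length-d2
+                 # vector) instead of a full-size tensor; the first moment's signs are
+                 # kept separately as packed 1-bit flags.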
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # m_0 = 0
+                 if group['betas'][0] > 0:
+                     state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                     if not self.grams_moment:
+                         packed_d2 = (d2 + 7) // 8
+                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 if self.use_AdEMAMix:
+                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 # v_0 = g_0^2 (SMMF_ADOPT NMF storage); square() rather than square_() so p.grad is not modified
+                 vt_init = grad.view(d1, d2).square()
+                 # Allocate NMF factors for vt
+                 state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                 # Initialize v_0 using NMF
+                 _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             else:  # Fallback for non-factored tensors
+                 if group['betas'][0] > 0:
+                     state['exp_avg'] = torch.zeros_like(p, dtype=dtype)  # m_0
+                 if self.use_AdEMAMix:
+                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
+                 state['exp_avg_sq'] = grad.square()  # v_0
+
+         beta1, beta2 = group['betas']
+
+         current_step = state['step']
+         if group.get('kourkoutas_beta', False):
+             # Call prepare_step() once at the beginning of the step for all params
+             self.kourkoutas_helper.maybe_prepare_step(current_step)
+             # Accumulate current grad's norm for the *next* step
+             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+             # Get the dynamic beta2 calculated in prepare_step()
+             beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
+         # The first step only initializes the state and performs no parameter update
+         # (unless use_atan2 is enabled, since atan2 is scale-invariant and can safely
+         # use v_0 = g_0^2 right away).
+         if state['step'] == 0 and not self.use_atan2:
+             state['step'] += 1
+             return
+
+         if self.use_AdEMAMix:
+             beta3_ema = group['beta3_ema']
+             alpha = group['alpha']
+             t_alpha = group['t_alpha']
+             # Use step+1 for 1-based step count in scheduler
+             alpha_step = state['step'] + 1
+             alpha_t = alpha
+             if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                 alpha_t = min(alpha_step * alpha / t_alpha, alpha)
+         if self.Simplified_AdEMAMix:
+             alpha_grad = group["alpha_grad"]
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+
+             # Reconstruct m_{t-1}
+             if beta1 > 0:
+                 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                 if not self.grams_moment:
+                     if state['sign'].dtype != torch.uint8:
+                         state['sign'] = state['sign'].to(torch.uint8)
+                     unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                     torch.where(unpacked_sign, mt, -mt, out=mt)
+                     del unpacked_sign
+
+             # Reconstruct AdEMAMix EMA
+             if self.use_AdEMAMix:
+                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 if state['sign_slow'].dtype != torch.uint8:
+                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                 del unpacked_sign_slow
+
+             # Reconstruct v_{t-1} using NNMF
+             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+
+             # ADOPT Step A: Decorrelate g_t using v_{t-1}
+             grad_reshaped = grad.view(d1, d2)
+             denom = vt.sqrt()
+
+             if self.use_atan2:
+                 normalized_grad = torch.atan2(grad_reshaped, denom)
+             else:
+                 normalized_grad = grad_reshaped / denom.add_(group['eps'])
+                 if self.clip_lambda is not None:
+                     clip_val = self.clip_lambda(state['step'])
+                     normalized_grad.clamp_(-clip_val, clip_val)
+             del denom
+
+             # ADOPT Step B: Update momentum m_t using normalized gradient
+             if beta1 > 0:
+                 if self.Simplified_AdEMAMix:
+                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                 else:
+                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+                 if self.grams_moment:
+                     update_mt = grad_reshaped.sign().mul_(mt.abs())
+                 elif self.cautious_mask:
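+                     # Cautious masking: zero the coordinates where the momentum and the
+                     # current gradient disagree in sign, then rescale by the kept fraction.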
+                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                     mask.div_(mask.mean().clamp_(min=1e-3))
+                     update_mt = mt.mul(mask)
+                     del mask
+                 else:
+                     update_mt = mt.clone()
+
+             if self.use_AdEMAMix:
+                 mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
+                 if beta1 > 0:
+                     update = torch.add(update_mt, mt_slow, alpha=alpha_t)
+                 else:
+                     update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
+             elif self.Simplified_AdEMAMix:
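+                 # Simplified AdEMAMix: the numerator is alpha_grad * ĝ_t + m_t, where m_t
+                 # acts as a leaky accumulator (the grad term was added above without the
+                 # usual (1 - beta1) factor).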
+                 update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
+             else:
+                 update = update_mt if beta1 > 0 else normalized_grad
+
+             update = update.view(p.shape)
+
+             if self.use_atan2:
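+                 # 1.2732395447351628 == 4/π: rescales the atan2-based update,
+                 # whose raw output saturates at ±π/2.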
+                 update.mul_(group['lr'] * 1.2732395447351628)
+             else:
+                 update.mul_(group['lr'])
+
+             # Update second moment v_t for the *next* step using raw g_t
+             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+             del grad_reshaped
+
+             # Compress and store new factors
+             if beta1 > 0:
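+                 # Re-compress m_t: keep its sign as packed 1-bit flags and factorize |m_t| with NNMF.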
+                 if not self.grams_moment:
+                     state['sign'] = _pack_bools(mt > 0)
+                 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                 del mt
+
+             if self.use_AdEMAMix:
+                 state['sign_slow'] = _pack_bools(mt_slow > 0)
+                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 del mt_slow
+
+             # Factorize v_t using NMF compression
+             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             del vt
+
+         else:  # Standard ADOPT logic for non-factored tensors
+             vt = state['exp_avg_sq']  # v_{t-1}
+
+             # ADOPT Step A: Decorrelate g_t using v_{t-1}
+             denom = vt.sqrt()
+
+             if self.use_atan2:
+                 normalized_grad = torch.atan2(grad, denom)
+             else:
+                 normalized_grad = grad / denom.add_(group['eps'])
+                 if self.clip_lambda is not None:
+                     clip_val = self.clip_lambda(state['step'])
+                     normalized_grad.clamp_(-clip_val, clip_val)
+             del denom
+
+             # ADOPT Step B: Update momentum m_t
+             if beta1 > 0:
+                 mt = state['exp_avg']  # m_{t-1}
+                 if self.Simplified_AdEMAMix:
+                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                 else:
+                     mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+
+                 if self.grams_moment:
+                     update_mt = grad.sign().mul_(mt.abs())
+                 elif self.cautious_mask:
+                     mask = (mt * grad > 0).to(grad.dtype)
+                     mask.div_(mask.mean().clamp_(min=1e-3))
+                     update_mt = mt.mul(mask)
+                     del mask
+                 else:
+                     update_mt = mt.clone()
+
+             if self.use_AdEMAMix:
+                 m_slow = state['exp_avg_slow']
+                 m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
+                 if beta1 > 0:
+                     update = torch.add(update_mt, m_slow, alpha=alpha_t)
+                 else:
+                     update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
+             elif self.Simplified_AdEMAMix:
+                 update = torch.add(update_mt, normalized_grad, alpha=alpha_grad)
+             else:
+                 update = update_mt if beta1 > 0 else normalized_grad
+
+             if self.use_atan2:
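+                 # Same 4/π rescaling as in the factored branch above.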
+                 update.mul_(group['lr'] * 1.2732395447351628)
+             else:
+                 update.mul_(group['lr'])
+
+             # Update second moment v_t for the next step using raw g_t
+             vt.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+         # Parameter Update
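+         # Decoupled weight decay: p ← p - lr * weight_decay * p, applied directly to
+         # the parameters rather than folded into the gradient.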
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
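+         # Apply the update; for BF16 parameters, stochastic rounding (if enabled)
+         # reduces the bias of repeatedly rounding small updates to nearest.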
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
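
Usage sketch (illustrative only, not part of the packaged file above; the `adv_optm` import path and the toy model are assumptions):

    import torch
    from adv_optm import Adopt_adv  # assumed import path; adjust to the package's actual layout

    model = torch.nn.Linear(128, 64)
    optimizer = Adopt_adv(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.9999),
        weight_decay=1e-2,
        kourkoutas_beta=True,  # layer-wise dynamic beta2
        nnmf_factor=True,      # compressed (factored) optimizer state
    )

    inputs, targets = torch.randn(8, 128), torch.randn(8, 64)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

For trainers with a fused back pass, the per-parameter `step_parameter(p, group, i)` entry point can be called from gradient hooks instead of `step()`.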