adv-optm 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,374 @@
+ import torch
+ from typing import Optional, Callable
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper
+
+ class AdamW_adv(torch.optim.Optimizer):
+     """
+     Implements an advanced AdamW algorithm.
+
+     This is an extended version of AdamW with optional features such as
+     low-rank factorization of the optimizer states (SMMF), OrthoGrad,
+     cautious/Grams-style updates, AdEMAMix, and Kourkoutas-β (see the
+     Example below).
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-3)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of the gradient and its square (default: (0.9, 0.999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float): weight decay (L2 penalty) (default: 0).
+         use_bias_correction (bool): whether to use bias correction for the first
+             and second moment estimates, as in the original Adam paper.
+             (default: True)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: True).
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True).
+         use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+         grams_moment (bool): whether to use Grams-style updates. (default: False)
+         cautious_mask (bool): whether to apply cautious masking, which zeroes
+             update components whose sign disagrees with the current gradient.
+             (default: False)
+         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+         use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
+             a second, slow-moving average of the momentum (`mt_slow`) which is
+             combined with the primary momentum (`mt`) to stabilize updates,
+             especially in noisy, small-batch settings. If `False`, the
+             optimizer behaves as standard AdamW. (default: False)
+         beta3_ema (float): The decay rate for the slow exponential moving average
+             of the momentum (only used when `use_AdEMAMix` is `True`). A higher
+             value (e.g., 0.9999) gives the EMA a longer memory, making it more
+             stable but slower to adapt. A lower value (e.g., 0.999) is often
+             better for shorter training runs. (default: 0.9999)
+         alpha (float): The mixing coefficient that scales the slow momentum term
+             before it is added to the fast momentum term
+             (`update = mt + alpha * mt_slow`). A higher value increases the
+             stabilizing influence of the slow momentum. (default: 5.0)
+         t_alpha (Optional[int]): The number of steps for a linear warmup of the
+             `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
+             highly recommended to prevent instability at the beginning of training,
+             as it gradually introduces the stabilizing slow momentum term. During
+             the warmup, `alpha` ramps from 0 to its target value. If `None`,
+             the scheduler is disabled. (default: None)
+         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+             If `False`, the optimizer behaves as standard AdamW. (default: False)
+         beta2_min (float): The minimum value for dynamic β₂, used during periods of
+             high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+             (default: 0.9)
+         ema_alpha (float): The decay rate for the Exponential Moving Average (EMA)
+             of the pooled gradient norms. Corresponds to `α` in the paper.
+             (default: 0.95)
+         tiny_spike (float): A small constant added to the denominator of the
+             "sunspike" ratio calculation to prevent division by zero. Corresponds
+             to `ε_spike` in the paper. (default: 1e-9)
+         k_warmup_steps (int): The number of initial steps during which β₂ is held
+             at a fixed value before the dynamic logic activates. (default: 0)
+         k_logging (int): if > 0 and `kourkoutas_beta=True`, enables periodic console
+             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+             every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+             disable logging. (default: 0)
+         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+             and returns a unique, hashable key representing its "layer" or "bucket".
+             If `None`, parameters are bucketed by their memory ID (tensor-wise).
+             (default: None)
+         nnmf_factor (bool): whether to use the non-negative matrix factorization of
+             the optimizer states; if `False`, the uncompressed states are used.
+             (default: False)
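+
+     Example:
+         A minimal usage sketch. The import path below is an assumption; it
+         presumes `AdamW_adv` is re-exported at the package root:
+
+             import torch
+             from adv_optm import AdamW_adv
+
+             model = torch.nn.Linear(128, 64)
+             optimizer = AdamW_adv(
+                 model.parameters(),
+                 lr=1e-3,
+                 weight_decay=1e-2,
+                 use_AdEMAMix=True,      # add the slow momentum EMA
+                 kourkoutas_beta=True,   # layer-wise dynamic beta2
+                 nnmf_factor=True,       # factorized (SMMF) optimizer states
+             )
+
+             loss = model(torch.randn(32, 128)).pow(2).mean()
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()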
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         betas: tuple[float, float] = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0.0,
+         use_bias_correction: bool = True,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         use_atan2: bool = False,
+         cautious_mask: bool = False,
+         grams_moment: bool = False,
+         orthogonal_gradient: bool = False,
+         use_AdEMAMix: bool = False,
+         beta3_ema: float = 0.9999,
+         alpha: float = 5.0,
+         t_alpha: int | None = None,
+         kourkoutas_beta: bool = False,
+         beta2_min: float = 0.9,
+         ema_alpha: float = 0.95,
+         tiny_spike: float = 1e-9,
+         k_warmup_steps: int = 0,
+         k_logging: int = 0,
+         layer_key_fn: Optional[Callable] = None,
+         nnmf_factor: bool = False,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+         if kourkoutas_beta and not (betas[1] > beta2_min):
+             raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
+         if cautious_mask and grams_moment:
+             print("Warning: cautious_mask is incompatible with grams_moment; disabling cautious_mask.")
+             cautious_mask = False
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
+             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+         }
+         self.stochastic_rounding = stochastic_rounding
+         self.cautious_mask = cautious_mask
+         self.grams_moment = grams_moment
+         self.use_AdEMAMix = use_AdEMAMix
+         self.factored = nnmf_factor
+         self.kourkoutas_beta = kourkoutas_beta
+         self.layer_key_fn = layer_key_fn
+         super().__init__(params, defaults)
+
+         if self.kourkoutas_beta:
+             self.kourkoutas_helper = KourkoutasHelper(self)
+
+     @property
+     def supports_fused_back_pass(self):
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return True
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if grad.dtype != torch.float32 and self.factored:
+             grad = grad.float()
+         if group["orthogonal_gradient"]:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization (runs once per parameter)
+         if 'step' not in state:
+             state['step'] = 0
+
+             # 1D tensors are left unfactored unless vector_reshape allows
+             # reshaping them into a 2D matrix
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+             device = p.device
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # First moment (m)
+                 if group['betas'][0] > 0:
+                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                     if not self.grams_moment:
+                         packed_d2 = (d2 + 7) // 8
+                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                 if self.use_AdEMAMix:
+                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                     packed_d2 = (d2 + 7) // 8
+                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                 # Second moment (v)
+                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+             else:  # Fallback to standard AdamW for non-factored tensors
+                 if group['betas'][0] > 0:
+                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                 if self.use_AdEMAMix:
+                     state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+         beta1, beta2 = group['betas']
+
+         current_step = state['step']
+         if group.get('kourkoutas_beta', False):
+             # Call prepare_step() once at the beginning of the step for all params
+             self.kourkoutas_helper.maybe_prepare_step(current_step)
+             # Accumulate current grad's norm for the *next* step
+             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+             # Get the dynamic beta2 calculated in prepare_step()
+             beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
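+             # Sketch of what the helper provides (per the docstring above):
+             # gradient norms are pooled per layer/bucket (via `layer_key_fn`),
+             # an EMA of the pooled norm is kept with decay `ema_alpha`, and a
+             # "sunspike" ratio (guarded by `tiny_spike`) moves beta2 between
+             # `beta2_min` and `betas[1]` once `k_warmup_steps` have passed.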
219
+
220
+ step = state['step'] + 1
221
+ if group['use_bias_correction']:
222
+ bias_correction1 = 1.0 - beta1 ** step
223
+ if group.get('kourkoutas_beta', False):
224
+ bias_correction2 = 1.0 - group['betas'][1] ** step
225
+ # Use beta2_max for bias correction
226
+ else:
227
+ bias_correction2 = 1.0 - beta2 ** step
228
+ else:
229
+ bias_correction1 = 1
230
+ bias_correction2 = 1
231
+ step_size = group['lr'] / bias_correction1
232
+
233
+ if self.use_AdEMAMix:
234
+ beta3_ema = group['beta3_ema']
235
+ alpha = group['alpha']
236
+ t_alpha = group['t_alpha']
237
+ alpha_t = alpha
238
+ if t_alpha is not None and t_alpha > 0 and step < t_alpha:
239
+ alpha_t = min(step * alpha / t_alpha, alpha)
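+             # Illustration only: with alpha=5.0 and t_alpha=1000, alpha_t is
+             # 1.25 at step 250 and 2.5 at step 500; once step >= t_alpha the
+             # warmup branch is skipped and alpha_t stays at the full 5.0.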
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+             grad_reshaped = grad.view(d1, d2)
+
+             # Reconstruct momentum from previous step's factors
+             if beta1 > 0:
+                 mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                 if not self.grams_moment:
+                     unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                     torch.where(unpacked_sign, mt, -mt, out=mt)
+                     del unpacked_sign
+                 # Update momentum in full size
+                 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+                 if self.grams_moment:
+                     update_mt = grad_reshaped.sign().mul_(mt.abs())
+                 elif self.cautious_mask:
+                     # Keep only components whose momentum sign agrees with the
+                     # gradient, rescaled to preserve the mean update magnitude
+                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                     mask.div_(mask.mean().clamp_(min=1e-3))
+                     update_mt = mt.mul(mask)
+                     del mask
+                 else:
+                     update_mt = mt.clone()
+
+             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+             if self.use_AdEMAMix:
+                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 if state['sign_slow'].dtype != torch.uint8:
+                     state['sign_slow'] = state['sign_slow'].to(torch.uint8)
+                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
+                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
+                 del unpacked_sign_slow
+
+                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
+                 if beta1 > 0:
+                     update = torch.add(update_mt, mt_slow, alpha=alpha_t)
+                 else:
+                     update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+             else:
+                 update = update_mt if beta1 > 0 else grad_reshaped.clone()
+             del grad_reshaped
+
+             if group['use_atan2']:
+                 a = 1.2732395  # ~4/pi, scaling used by the atan2 update rule
+                 denom = (vt.sqrt() / (bias_correction2**0.5))
+                 update.atan2_(denom).mul_(a)
+             else:
+                 denom = (vt.sqrt() / (bias_correction2**0.5)).add_(group['eps'])
+                 update.div_(denom)
+             del denom
+
+             update = update.view(p.shape).mul_(step_size)
+
+             # Compress updated moments and store new factors
+             if beta1 > 0:
+                 if not self.grams_moment:
+                     state['sign'] = _pack_bools(mt > 0)
+                 _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                 del mt
+             if self.use_AdEMAMix:
+                 state['sign_slow'] = _pack_bools(mt_slow > 0)
+                 _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+                 del mt_slow
+             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             del vt
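+             # Storage note: each d1 x d2 moment is held as two non-negative
+             # rank-1 factors (d1 + d2 values) from _nnmf, plus a bit-packed
+             # sign matrix of d1 * ceil(d2 / 8) bytes for the signed momenta,
+             # instead of d1 * d2 full-precision entries.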
+
+         else:  # Standard AdamW logic for non-factored tensors
+             exp_avg_sq = state['exp_avg_sq']
+
+             if beta1 > 0:
+                 exp_avg = state['exp_avg']
+                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                 if self.grams_moment:
+                     update_mt = grad.sign().mul_(exp_avg.abs())
+                 elif self.cautious_mask:
+                     mask = (exp_avg * grad > 0).to(grad.dtype)
+                     mask.div_(mask.mean().clamp_(min=1e-3))
+                     update_mt = exp_avg.mul(mask)
+                     del mask
+                 else:
+                     update_mt = exp_avg.clone()
+
+             if self.use_AdEMAMix:
+                 exp_avg_slow = state['exp_avg_slow']
+                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
+                 if beta1 > 0:
+                     update = torch.add(update_mt, exp_avg_slow, alpha=alpha_t)
+                 else:
+                     update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+             else:
+                 update = update_mt if beta1 > 0 else grad.clone()
+
+             exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+             if group['use_atan2']:
+                 a = 1.2732395
+                 denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5))
+                 update.atan2_(denom).mul_(a)
+             else:
+                 denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group['eps'])
+                 update.div_(denom)
+             del denom
+
+             update.mul_(step_size)
+
+         # Decoupled weight decay
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
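+         # For BF16 parameters, add_stochastic_ applies the update with
+         # stochastic rounding (util.BF16_Stochastic_Rounding) so small updates
+         # are not systematically lost to round-to-nearest BF16 precision.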
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss