adv_optm-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,367 @@
1
+ import torch
2
+ from typing import Optional
3
+ import math
4
+
5
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
6
+ from ..util.Effective_Shape import _get_effective_shape
7
+ from ..util.NNMF import _nnmf,_unnmf
8
+ from ..util.OrthoGrad import _orthogonalize_gradient
9
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
10
+
11
+ class Prodigy_adv(torch.optim.Optimizer):
12
+ """
13
+ Implements a factored Prodigy/AdamW algorithm.
14
+ This is an advanced version of Prodigy with optional features like
15
+ low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
16
+
17
+ Args:
18
+ params (iterable): iterable of parameters to optimize or dicts defining
19
+ parameter groups
20
+ lr (float): learning rate (default: 1e-3)
21
+ betas (tuple[float, float]): coefficients used for computing running
22
+ averages of gradient and its square (default: (0.9, 0.999))
23
+ eps (float): term added to the denominator to improve
24
+ numerical stability (default: 1e-8)
25
+ weight_decay (float): weight decay (L2 penalty) (default: 0)
26
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
27
+ matrices to apply low-rank compression (default: True).
28
+ stochastic_rounding (bool): whether to use stochastic
29
+ rounding for BF16 parameter updates (default: True).
30
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
31
+ use_grams (bool): whether to use Grams-style updates. (default: False)
32
+ use_cautious (bool): whether to use cautious masking to align the gradient's
33
+ direction with the first moment's. (default: False)
34
+ use_orthograd (bool): whether to use OrthoGrad. (default: False)
35
+ use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
36
+ a second, slow-moving average of the momentum (`mt_slow`) which is
37
+ combined with the primary momentum (`mt`) to stabilize updates,
38
+ especially in noisy, small-batch settings. If `False`, the
39
+ optimizer behaves as standard Prodigy. (default: False)
40
+ beta3_ema (float): The decay rate for the slow exponential moving average of
41
+ the momentum (only used when `use_AdEMAMix` is `True`). A higher
42
+ value (e.g., 0.9999) gives the EMA a longer memory, making it more
43
+ stable but slower to adapt. A lower value (e.g., 0.999) is often
44
+ better for shorter training runs. (default: 0.9999)
45
+ alpha (float): The mixing coefficient that scales the slow momentum term
46
+ before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
47
+ A higher value increases the stabilizing influence of the slow
48
+ momentum. (default: 5.0)
49
+ t_alpha (Optional[int]): The number of steps for a linear warmup of the
50
+ `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
51
+ highly recommended to prevent instability at the beginning of training,
52
+ as it gradually introduces the stabilizing slow momentum term. During
53
+ the warmup, `alpha` ramps from 0 to its target value. If `None`,
54
+ the scheduler is disabled and the full `alpha` value is used from the start. (default: None)
55
+ factored (bool): whether to use the factorization or disable it to use
56
+ the uncompressed optimizer. (default: True)
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ params,
62
+ lr: float = 1e-3,
63
+ betas: tuple[float, float] = (0.9, 0.999),
64
+ eps: float = 1e-8,
65
+ weight_decay: float = 0.0,
66
+ vector_reshape: bool = True,
67
+ stochastic_rounding: bool = True,
68
+ use_atan2: bool = False,
69
+ use_cautious: bool = False,
70
+ use_grams: bool = False,
71
+ use_orthograd: bool = False,
72
+ use_AdEMAMix: bool = False,
73
+ beta3_ema: float = 0.9999,
74
+ alpha: float = 5.0,
75
+ t_alpha: int | None = None,
76
+ factored: bool = True,
77
+ # prodigy parameters
78
+ beta3: Optional[float] = None,
79
+ d0: float = 1e-6,
80
+ d_coef: float = 1,
81
+ growth_rate: float = float('inf'),
82
+ safeguard_warmup: bool = False,
83
+ slice_p: int = 11,
84
+ ):
85
+ if not (lr >= 0.0):
86
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
87
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
88
+ raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
89
+ if not (eps >= 0.0):
90
+ raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
91
+ if not (weight_decay >= 0.0):
92
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
93
+
94
+ defaults = {
95
+ "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
96
+ "vector_reshape": vector_reshape, "use_atan2": use_atan2,
97
+ "use_orthograd": use_orthograd,
98
+ "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
99
+ "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
100
+ "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
101
+ }
102
+ self.stochastic_rounding = stochastic_rounding
103
+ self.use_cautious = use_cautious
104
+ self.use_grams = use_grams
105
+ self.use_AdEMAMix = use_AdEMAMix
106
+ self.factored = factored
107
+ super().__init__(params, defaults)
108
+ self.init_step()
109
+
110
+ @property
111
+ def supports_fused_back_pass(self):
112
+ return True
113
+
114
+ @property
115
+ def supports_memory_efficient_fp16(self):
116
+ return True
117
+
118
+ @property
119
+ def supports_flat_params(self):
120
+ return False
121
+
122
+ def init_step(self):
123
+ """Resets accumulators and calculates dlr for the upcoming step."""
124
+ self.d_denom = 0.0
125
+
126
+ g_group = self.param_groups[0]
127
+ self.beta1, self.beta2 = g_group['betas']
128
+ self.beta3 = g_group['beta3']
129
+ if self.beta3 is None:
130
+ self.beta3 = math.sqrt(self.beta2)
131
+
132
+ k = g_group['k']
133
+ self.d = g_group['d']
134
+ lr = g_group['lr']
135
+
136
+ self.dlr = self.d * lr
137
+
138
+ self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
139
+
140
+ @torch.no_grad()
141
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
142
+ if p.grad is None:
143
+ return
144
+
145
+ grad = p.grad
146
+ if grad.dtype != torch.float32 and self.factored:
147
+ grad = grad.float()
148
+ if group["use_orthograd"]:
149
+ grad = _orthogonalize_gradient(p, grad)
150
+ state = self.state[p]
151
+
152
+ # State Initialization
153
+ if len(state) == 0:
154
+ state['step'] = 0
155
+
156
+ should_factor = (
157
+ self.factored and
158
+ not (len(p.shape) == 1 and not group['vector_reshape'])
159
+ )
160
+
161
+ state['factored'] = should_factor
162
+
163
+ slice_p = group['slice_p']
164
+
165
+ dtype = torch.float32 if self.factored else p.dtype
166
+ device = p.device
167
+
168
+ if state['factored']:
169
+ state['effective_shape'] = _get_effective_shape(p.numel())
170
+ d1, d2 = state['effective_shape']
171
+
172
+ # First moment (m)
173
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
174
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
175
+ if not self.use_grams:
176
+ packed_d2 = (d2 + 7) // 8
177
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
178
+ if self.use_AdEMAMix:
179
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
180
+ state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
181
+ packed_d2 = (d2 + 7) // 8
182
+ state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
183
+ # Second moment (v)
184
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
185
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
186
+ else: # Fallback to standard AdamW for non-factored tensors
187
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
188
+ if self.use_AdEMAMix:
189
+ state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
190
+ state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
191
+
192
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
193
+ if p.any():
194
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
195
+ else:
196
+ state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
197
+
198
+ if self.use_AdEMAMix:
199
+ beta3_ema = group['beta3_ema']
200
+ alpha = group['alpha']
201
+ t_alpha = group['t_alpha']
202
+ current_step = state['step'] + 1
203
+ alpha_t = alpha
204
+ if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
205
+ alpha_t = min(current_step * alpha / t_alpha, alpha)
206
+
207
+ if state['factored']:
208
+ d1, d2 = state['effective_shape']
209
+
210
+ # Reconstruct momentum from previous step's factors
211
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
212
+ if not self.use_grams:
213
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
214
+ torch.where(unpacked_sign, mt, -mt, out=mt)
215
+ del unpacked_sign
216
+ # Update momentum in full-size
217
+ grad_reshaped = grad.view(d1, d2)
218
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
219
+ if self.use_grams:
220
+ mt.copy_(grad_reshaped.sign() * mt.abs())
221
+ elif self.use_cautious:
222
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
223
+ mask.div_(mask.mean().clamp_(min=1e-3))
224
+ mt.mul_(mask)
225
+ del mask
226
+
227
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
228
+ vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
229
+
230
+ if self.use_AdEMAMix:
231
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
232
+ if state['sign_slow'].dtype != torch.uint8:
233
+ state['sign_slow'] = state['sign_slow'].to(torch.uint8)
234
+ unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
235
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
236
+ del unpacked_sign_slow
237
+
238
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
239
+ update_m = mt + (alpha_t * mt_slow)
240
+ else:
241
+ update_m = mt
242
+ del grad_reshaped
243
+
244
+ if group['use_atan2']:
245
+ a = 1.2732395
246
+ denom = vt.sqrt()
247
+ update = torch.atan2(update_m, denom).mul_(a)
248
+ else:
249
+ denom = vt.sqrt().add_(group['eps'])
250
+ update = update_m / denom
251
+ del update_m, denom
252
+
253
+ update = update.view(p.shape)
254
+ update.mul_(self.dlr)
255
+
256
+ # Compress updated moments and store new factors
257
+ if not self.use_grams:
258
+ state['sign'] = _pack_bools(mt > 0)
259
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
260
+ del mt
261
+ if self.use_AdEMAMix:
262
+ state['sign_slow'] = _pack_bools(mt_slow > 0)
263
+ _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
264
+ del mt_slow
265
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
266
+ del vt
267
+
268
+ else: # Standard AdamW logic for non-factored tensors
269
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
270
+
271
+ exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
272
+ if self.use_grams:
273
+ exp_avg = grad.sign() * exp_avg.abs()
274
+ elif self.use_cautious:
275
+ mask = (exp_avg * grad > 0).to(grad.dtype)
276
+ mask.div_(mask.mean().clamp_(min=1e-3))
277
+ exp_avg.mul_(mask)
278
+ del mask
279
+
280
+ if self.use_AdEMAMix:
281
+ exp_avg_slow = state['exp_avg_slow']
282
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
283
+ update_m = exp_avg + (alpha_t * exp_avg_slow)
284
+ else:
285
+ update_m = exp_avg
286
+
287
+ exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
288
+
289
+ if group['use_atan2']:
290
+ a = 1.2732395
291
+ denom = exp_avg_sq.sqrt()
292
+ update = torch.atan2(update_m, denom).mul_(a)
293
+ else:
294
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
295
+ update = update_m / denom
296
+ del update_m, denom
297
+
298
+ update = update.mul_(self.dlr)
299
+
300
+ # --- Accumulate Prodigy stats ---
301
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
302
+ s, p0 = state['s'], state['p0']
303
+ grad_flat = grad.flatten().float()
304
+ p_flat = p.data.flatten().float()
305
+ p0 = p0.float()
306
+
307
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
308
+
309
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
310
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
311
+ self.d_denom += s.abs().sum().item()
312
+
313
+ del s, p0, grad_flat, p_flat, alpha
314
+
315
+ # Decoupled weight decay
316
+ if group["weight_decay"] != 0:
317
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
318
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
319
+ else:
320
+ p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)
321
+
322
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
323
+ add_stochastic_(p.data, -update)
324
+ else:
325
+ p.data.add_(-update)
326
+ del update
327
+
328
+ state['step'] += 1
329
+
330
+ @torch.no_grad()
331
+ def step(self, closure=None):
332
+ """Performs a single optimization step."""
333
+ loss = None
334
+ if closure is not None:
335
+ with torch.enable_grad():
336
+ loss = closure()
337
+
338
+ for group in self.param_groups:
339
+ for i, p in enumerate(group['params']):
340
+ self.step_parameter(p, group, i)
341
+
342
+
343
+ self.calculate_d()
344
+ self.init_step()
345
+ return loss
346
+
347
+ def calculate_d(self):
348
+ """Calculates the new `d` based on the accumulated stats."""
349
+ g_group = self.param_groups[0]
350
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
351
+
352
+ global_d_numerator = self.d_numerator
353
+ global_d_denom = self.d_denom
354
+
355
+ d_hat = self.d
356
+ if global_d_denom > 0:
357
+ d_hat = d_coef * global_d_numerator / global_d_denom
358
+ if self.d == g_group['d0']:
359
+ self.d = max(self.d, d_hat)
360
+ d_max = max(d_max, d_hat)
361
+ self.d = min(d_max, self.d * growth_rate)
362
+
363
+ for group in self.param_groups:
364
+ group['d_numerator'] = global_d_numerator
365
+ group['d'] = self.d
366
+ group['d_max'] = d_max
367
+ group['k'] += 1
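For orientation, here is a minimal, illustrative way to drive the optimizer above. The class name, constructor arguments, and the `step_parameter`/`step` split come from this diff; the model, the data, and the `lr=1.0` choice (typical for Prodigy-style adaptive-LR optimizers) are assumptions.

```python
import torch
from adv_optm import Prodigy_adv  # exported by the package __init__ shown below

model = torch.nn.Linear(128, 64).bfloat16()
opt = Prodigy_adv(
    model.parameters(),
    lr=1.0,                    # Prodigy adapts the step size via `d`; lr acts as a multiplier
    weight_decay=0.01,
    factored=True,             # SMMF-style states: four small vectors + a 1-bit sign
    stochastic_rounding=True,  # BF16 parameters are updated with stochastic rounding
)

for _ in range(10):
    x = torch.randn(4, 128, dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()
    loss.backward()
    opt.step()        # runs step_parameter() for every parameter, then calculate_d() + init_step()
    opt.zero_grad()

# Because supports_fused_back_pass is True, a trainer can instead call
# opt.step_parameter(p, group, i) per parameter as each gradient is produced,
# which avoids keeping all gradients alive until the end of the backward pass.
```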
@@ -0,0 +1,9 @@
1
+ from .AdamW_adv import AdamW_adv
2
+ from .Prodigy_adv import Prodigy_adv
3
+ from .Adopt_adv import Adopt_adv
4
+
5
+ __all__ = [
6
+ "AdamW_adv",
7
+ "Prodigy_adv",
8
+ "Adopt_adv",
9
+ ]
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ def copy_stochastic_(target: Tensor, source: Tensor):
5
+ """
6
+ Nerogar's implementation of the stochastic rounding described in "Revisiting BFloat16 Training"
7
+ (https://arxiv.org/abs/2010.06192).
8
+ see:
9
+ https://github.com/pytorch/pytorch/issues/120376
10
+ https://github.com/Nerogar/OneTrainer/blob/daae18eaed8c0fa39289b2ff79cc2c1e08577fcb/modules/util/bf16_stochastic_rounding.py
11
+
12
+ Args:
13
+ target: the target tensor with dtype=bfloat16
14
+ source: the source tensor with dtype=float32
15
+ """
16
+ # create a random 16 bit integer
17
+ result = torch.randint_like(
18
+ source,
19
+ dtype=torch.int32,
20
+ low=0,
21
+ high=(1 << 16),
22
+ )
23
+
24
+ # add the random number to the lower 16 bit of the mantissa
25
+ result.add_(source.view(dtype=torch.int32))
26
+
27
+ # mask off the lower 16 bit of the mantissa
28
+ result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
29
+
30
+ # copy the higher 16 bit into the target tensor
31
+ target.copy_(result.view(dtype=torch.float32))
32
+
33
+ del result
34
+
35
+ def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
36
+ """
37
+ Adds `other` to `input` in place, using stochastic rounding.
38
+
39
+ Args:
40
+ input: the input tensor with dtype=bfloat16
41
+ other: the other tensor
42
+ alpha: a multiplier for other
43
+ """
44
+ result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)
45
+
46
+ result.add_(input, alpha=alpha)
47
+ copy_stochastic_(input, result)
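To see why stochastic rounding matters for BF16 training, here is a small illustrative comparison (the import path follows the `util` package `__init__` later in this diff; exact numbers vary with the seed):

```python
import torch
from adv_optm.util import add_stochastic_

torch.manual_seed(0)
plain = torch.ones(1000, dtype=torch.bfloat16)
stoch = torch.ones(1000, dtype=torch.bfloat16)
tiny = torch.full((1000,), 1e-3)       # float32 increment, smaller than the BF16 spacing near 1.0

for _ in range(100):
    plain += tiny.bfloat16()           # 1.0 + 0.001 rounds back to 1.0, so nothing accumulates
    add_stochastic_(stoch, tiny)       # rounds up with probability proportional to the remainder

print(plain.float().mean())            # ~1.0: the updates were silently lost
print(stoch.float().mean())            # ~1.1: unbiased accumulation on average
```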
@@ -0,0 +1,8 @@
1
+ def _get_effective_shape(numel: int) -> tuple[int, int]:
2
+ """Finds two factors of numel that are closest to its square root."""
3
+ if numel <= 0:
4
+ return (0, 0)
5
+ for i in reversed(range(1, int(numel ** 0.5) + 1)):
6
+ if numel % i == 0:
7
+ return (numel // i, i)
8
+ return (numel, 1)
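A few worked examples of the helper above (it is re-exported via `adv_optm.util`; values checked by hand):

```python
from adv_optm.util import _get_effective_shape

print(_get_effective_shape(8192))   # (128, 64): 64 is the largest divisor <= sqrt(8192)
print(_get_effective_shape(12))     # (4, 3)
print(_get_effective_shape(13))     # (13, 1): primes fall back to a flat shape
```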
adv_optm/util/NNMF.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+
3
+ def _unnmf(row_col: tuple) -> torch.Tensor:
4
+ """Reconstructs a matrix from its rank-1 factors (outer product)."""
5
+ return torch.outer(row_col[0], row_col[1])
6
+
7
+ def _nnmf(matrix: torch.Tensor, out: tuple):
8
+ """Performs a rank-1 non-negative matrix factorization."""
9
+ shape = matrix.shape
10
+ torch.sum(matrix, dim=1, out=out[0])
11
+ torch.sum(matrix, dim=0, out=out[1])
12
+ # Normalize one of the factors for stability
13
+ if shape[0] < shape[1]:
14
+ scale = out[0].sum()
15
+ if scale != 0: out[0].div_(scale)
16
+ else:
17
+ scale = out[1].sum()
18
+ if scale != 0: out[1].div_(scale)
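As a sanity check on `adv_optm/util/NNMF.py` above, a small round trip shows what the rank-1 factorization keeps and what it loses (the test matrix is arbitrary):

```python
import torch
from adv_optm.util.NNMF import _nnmf, _unnmf

m = torch.rand(4, 3)                 # non-negative, as used for |momentum| and the second moment
row, col = torch.empty(4), torch.empty(3)
_nnmf(m, out=(row, col))             # row/column sums, one factor normalized in place
approx = _unnmf((row, col))          # outer(row, col): the rank-1 reconstruction

print(approx.shape)                  # torch.Size([4, 3])
print((approx - m).abs().mean())     # > 0 in general: only rank-1 structure is preserved
```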
@@ -0,0 +1,22 @@
1
+ import torch
2
+
3
+ @torch.no_grad()
4
+ def _pack_bools(tensor: torch.Tensor) -> torch.Tensor:
5
+ """Packs a boolean tensor into a uint8 tensor to achieve 1-bit storage."""
6
+ n, m = tensor.shape
7
+ packed_m = (m + 7) // 8
8
+ padded_tensor = torch.nn.functional.pad(tensor, (0, packed_m * 8 - m), 'constant', 0)
9
+ reshaped = padded_tensor.view(n, packed_m, 8)
10
+ shifter = torch.arange(8, device=tensor.device, dtype=torch.uint8)
11
+ packed = (reshaped.to(torch.uint8) * (2**shifter)).sum(dim=2).to(torch.uint8)
12
+ return packed
13
+
14
+ @torch.no_grad()
15
+ def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
16
+ """Unpacks a uint8 tensor back into a boolean tensor."""
17
+ if packed_tensor.dtype != torch.uint8:
18
+ packed_tensor = packed_tensor.to(torch.uint8)
19
+ shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
20
+ unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
21
+ unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
22
+ return unpacked
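A quick round trip through the two functions above (the shape is arbitrary; 13 columns are chosen so that padding to a byte boundary is exercised):

```python
import torch
from adv_optm.util import _pack_bools, _unpack_bools

signs = torch.rand(5, 13) > 0.5              # boolean matrix
packed = _pack_bools(signs)                  # uint8, shape (5, 2): 13 bits per row plus 3 padding bits
restored = _unpack_bools(packed, original_m=13)

assert packed.dtype == torch.uint8
assert restored.shape == (5, 13)
assert bool((restored == signs).all())       # exact round trip
```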
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+ def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
4
+ """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
5
+ if grad.is_sparse: raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
6
+ original_shape = grad.shape
7
+ original_dtype = grad.dtype
8
+ w = p.view(-1).float()
9
+ g = grad.view(-1).float()
10
+ w_norm_sq = torch.dot(w, w).add_(1e-30)
11
+ proj = torch.dot(w, g) / w_norm_sq
12
+ g_orth = g.sub(w, alpha=proj)
13
+ g_norm = g.norm(2)
14
+ g_orth_norm = g_orth.norm(2).add_(1e-30)
15
+ g_orth_scaled = g_orth * (g_norm / g_orth_norm)
16
+ return g_orth_scaled.view(original_shape).to(original_dtype)
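A minimal check of the projection above: the returned gradient should be numerically orthogonal to the weights while keeping the original gradient norm.

```python
import torch
from adv_optm.util import _orthogonalize_gradient

p = torch.randn(32, 16)
g = torch.randn(32, 16)
g_orth = _orthogonalize_gradient(p, g)

print(torch.dot(p.flatten(), g_orth.flatten()))  # ~0: the component parallel to the weights is gone
print(g.norm().item(), g_orth.norm().item())     # nearly equal: the result is rescaled to ||g||
```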
@@ -0,0 +1,37 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from typing import Tuple
5
+
6
+ def _rsvd(A: torch.Tensor, rank: int, oversampling: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
7
+ """Performs Randomized SVD."""
8
+ orig_dtype, device, (m, n) = A.dtype, A.device, A.shape
9
+ A_float = A.float()
10
+ l, true_rank = rank + oversampling, min(m, n, rank)
11
+
12
+ if true_rank == 0:
13
+ return (
14
+ torch.zeros(m, rank, dtype=orig_dtype, device=device),
15
+ torch.zeros(rank, dtype=orig_dtype, device=device),
16
+ torch.zeros(rank, n, dtype=orig_dtype, device=device),
17
+ )
18
+
19
+ if l >= min(m, n): # Fallback to full SVD
20
+ U_full, S_full, Vh_full = torch.linalg.svd(A_float, full_matrices=False)
21
+ U, S, Vh = U_full[:, :true_rank], S_full[:true_rank], Vh_full[:true_rank, :]
22
+ else: # Standard RSVD path
23
+ Omega = torch.randn(n, l, dtype=A_float.dtype, device=device)
24
+ Y = A_float @ Omega
25
+ Q, _ = torch.linalg.qr(Y.float())
26
+ B = Q.T @ A_float
27
+ U_tilde, S, Vh = torch.linalg.svd(B.float(), full_matrices=False)
28
+ U, S, Vh = (Q @ U_tilde)[:, :true_rank], S[:true_rank], Vh[:true_rank, :]
29
+
30
+ if true_rank < rank: # Pad factors with zeros
31
+ U_padded = torch.zeros(m, rank, dtype=A_float.dtype, device=device)
32
+ S_padded = torch.zeros(rank, dtype=A_float.dtype, device=device)
33
+ Vh_padded = torch.zeros(rank, n, dtype=A_float.dtype, device=device)
34
+ U_padded[:, :true_rank], S_padded[:true_rank], Vh_padded[:true_rank, :] = U, S, Vh
35
+ U, S, Vh = U_padded, S_padded, Vh_padded
36
+
37
+ return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
@@ -0,0 +1,11 @@
1
+ from .BF16_Stochastic_Rounding import add_stochastic_, copy_stochastic_
2
+ from .Effective_Shape import _get_effective_shape
3
+ from .One_Bit_Boolean import _pack_bools, _unpack_bools
4
+ from .OrthoGrad import _orthogonalize_gradient
5
+
6
+ __all__ = [
7
+ "_pack_bools", "_unpack_bools",
8
+ "add_stochastic_",
9
+ "_get_effective_shape",
10
+ "_orthogonalize_gradient",
11
+ ]
@@ -0,0 +1,134 @@
1
+ Metadata-Version: 2.4
2
+ Name: adv_optm
3
+ Version: 0.1.0
4
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
+ Author: Koratahiu
7
+ Author-email: hiuhonor@gmail.com
8
+ License: Apache 2.0
9
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch>=2.0
19
+ Dynamic: author
20
+ Dynamic: author-email
21
+ Dynamic: classifier
22
+ Dynamic: description
23
+ Dynamic: description-content-type
24
+ Dynamic: home-page
25
+ Dynamic: keywords
26
+ Dynamic: license
27
+ Dynamic: license-file
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
31
+
32
+ # Advanced Optimizers
33
+
34
+ This repo introduces a new family of highly efficient, lightweight yet powerful optimizers, born from extensive research into recent academic literature and validated through practical training runs across diverse models.
35
+
36
+ ---
37
+
38
+ ### Install
39
+
40
+ `pip install adv_optm`
41
+
42
+ ---
43
+
44
+ ### Theory (Inspired by SMMF)
45
+
46
+ Based primarily on:
47
+ **[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**
48
+
49
+ The core innovation:
50
+ - Uses fast, non-negative matrix factorization (rank 1, à la Adafactor), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
51
+ - For the *signed first moment*, we split it into **sign + absolute value**:
52
+ - The sign is stored as a **1-bit state** via bitwise ops (SMMF originally used 8-bit storage, wasting 7 bits).
53
+ - The absolute value goes through the factor/reconstruct cycle using two factored vectors plus the sign state.
54
+ - Final storage: **four factored vectors + one 1-bit sign**.
55
+ - Updates behave like full-state Adam but with drastically reduced memory.
56
+
57
+ > ✅ **TL;DR**: Lightweight, strong, memory-efficient optimizer.
58
+
59
+ ---
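For the curious, the whole cycle can be sketched with the package's own helpers. This is a simplified, illustrative version of what the optimizers do internally for the first moment; Prodigy's `d` scaling and the second moment are omitted.

```python
import torch
from adv_optm.util import _get_effective_shape, _pack_bools, _unpack_bools
from adv_optm.util.NNMF import _nnmf, _unnmf

def factored_momentum_step(state, grad, beta1=0.9):
    """One factor -> reconstruct -> update -> factor cycle for the first moment."""
    d1, d2 = state['effective_shape']
    # 1) Reconstruct |m| from its two rank-1 factors and re-apply the 1-bit sign.
    m = _unnmf((state['mu'], state['mv']))
    sign = _unpack_bools(state['sign'], original_m=d2)
    m = torch.where(sign, m, -m)
    # 2) Ordinary EMA update, performed at full size.
    m.mul_(beta1).add_(grad.view(d1, d2), alpha=1 - beta1)
    # 3) Compress again: 1-bit sign plus rank-1 factors of |m|.
    state['sign'] = _pack_bools(m > 0)
    _nnmf(m.abs(), out=(state['mu'], state['mv']))
    return m  # full-size momentum, used only for this step's update

shape = _get_effective_shape(4096)          # (64, 64) for a 4096-element parameter
state = {
    'effective_shape': shape,
    'mu': torch.zeros(shape[0]),
    'mv': torch.zeros(shape[1]),
    'sign': _pack_bools(torch.zeros(shape, dtype=torch.bool)),
}
m = factored_momentum_step(state, torch.randn(4096))
```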
60
+
61
+ ### Memory Cost
62
+
63
+ - **Adopt_Factored** for full SDXL finetune: **328 MB** (4 small vectors + 1-bit state)
64
+ - **Adopt_Factored with AdEMAMix** for full SDXL finetune: **625 MB** (6 small vectors + two 1-bit states)
65
+ > SDXL is a 6.5 GB model.
66
+
67
+ ---
68
+
69
+ ### ⏱️ Speed (my tests in SDXL - BS 4)
70
+
71
+ - **Adopt_Factored**: ~10s/it
72
+ - **Adopt_Factored with AdEMAMix**: ~12s/it
73
+ - **Adafactor**: ~8.5s/it
74
+ → Overhead from compression/reconstruction cycles.
75
+ → It's faster than [MLorc](https://arxiv.org/abs/2506.01897) (~12s/it), which uses RSVD compression, and should be the fastest momentum-compression method available (AFAIK).
76
+
77
+ ---
78
+
79
+ ### 📈 Performance
80
+
81
+ - **Better than the Adafactor and CAME factorization methods**
82
+ - **Comparable or identical to Adam** (see SMMF paper results)
83
+
84
+ ---
85
+
86
+ ### Available Optimizers (all support the `factored` toggle)
87
+
88
+ Set `factored=False` to disable factorization and run as a full uncompressed optimizer (like vanilla Adam).
89
+
90
+ 1. **Adam** (`AdamW_adv`)
91
+ 2. **Prodigy** (`Prodigy_adv`)
92
+ 3. **Adopt** (`Adopt_adv`)
93
+
94
+ ---
95
+
96
+ ### Bonus Features (Built-in)
97
+
98
+ - **Fused Backward Pass**
99
+
100
+ - **Stochastic Rounding (SR)**: Improves quality and convergence for **BF16 training**.
101
+
102
+ - **[AdEMAMix](https://arxiv.org/abs/2409.03137)**
103
+ → This adds a second, slow-moving EMA, which is combined with the primary momentum to stabilize updates, especially during long runs of full finetuning.
104
+ → A higher value of `beta3_ema` (e.g., 0.9999) gives the EMA a longer memory, making it more stable but slower to adapt. A lower value (e.g., 0.999) is often better for shorter training runs (2k-4k steps).
105
+ → When `factored` is true, it compresses the new momentum in the same way as the first moment (1-bit state + 2 vectors). However, this introduces noticeable overhead as we are compressing/reconstructing a third state each step.
106
+
107
+ ⚠️ **Note**: AdEMAMix updates are more aggressive than those of plain Adam/Adopt, so use a 2x-5x smaller LR than usual (or use Prodigy).
108
+
109
+ ⚠️ **Note**: The factored AdEMAMix is **Experimental** (as it needs more tests and validation, but it should work). Also, Adopt with AdEMAMix is **Experimental** (as Adopt normalizes the gradients for the momentum).
110
+
111
+ - **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
112
+ → Robust `eps` replacement (no tuning!) + built-in gradient clipping
113
+ → *Ideal for ADOPT* (which normally needs a higher `eps` and clipping), so `use_atan2` is an all-in-one fix for it (a short sketch of the bounded update follows this list).
114
+
115
+ - **[OrthoGrad](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability)**
116
+ → Removes gradient component parallel to weights → prevents "naïve loss minimization" (NLM) → reduces natural overfitting
117
+ → Perfect for fine-tuning the direction of existing features (e.g., a full finetune, or further training of an already-trained LoRA) without weight decay erasing prior knowledge.
118
+
119
+ ⚠️ **Note**: OrthoGrad introduces **~33% time overhead**, so take this into account.
120
+
121
+ - **[Grams: Gradient Descent with Adaptive Momentum Scaling](https://github.com/Gunale0926/Grams)**
122
+ → Eliminates the need for 1-bit momentum sign storage by using the **sign of gradients** for the first moment.
123
+
124
+ ⚠️ **Not recommended for small batch sizes**: gradients are too noisy, which can destabilize the momentum (tested with Prodigy at BS 4, where it made the optimizer slower to find the LR and to converge).
125
+
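The `atan2` rule and its implicit clipping mentioned above can be seen in a few lines. This mirrors the `use_atan2` branch in the code earlier in this diff; `1.2732395 ≈ 4/π` is the constant used there.

```python
import torch

m = torch.tensor([1e-12, 1e-3, 10.0, -500.0])   # exp_avg-like numerators
v = torch.tensor([1e-12, 1e-6, 1.0, 4.0])        # exp_avg_sq-like denominators (before sqrt)

update = torch.atan2(m, v.sqrt()) * 1.2732395    # no eps needed, even when v is ~0
print(update)                                     # every element lies in (-2, 2)
```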
126
+ ### Other Notes
127
+
128
+ - **Adopt** skips the first step (it only initializes the states) and has built-in clipping, following the original optimizer, but both are skipped when `use_atan2` is enabled, since the update becomes scale-invariant and the state values no longer cause issues or instability.
129
+
130
+ - When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.
131
+
132
+ - I don't recommend using **OrthoGrad** for training LoRAs or embeddings: their weights are zero-initialized, so weight decay is safe and even beneficial for them (OrthoGrad is intended for fine-tuning pretrained weights without weight decay).
133
+
134
+ ---