adv-optm 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -2,12 +2,16 @@ from .optim import (
  AdamW_adv,
  Prodigy_adv,
  Adopt_adv,
+ Lion_adv,
+ Lion_Prodigy_adv,
  )

  __all__ = [
  "AdamW_adv",
  "Prodigy_adv",
  "Adopt_adv",
+ "Lion_adv",
+ "Lion_Prodigy_adv",
  ]

- __version__ = "0.1.1"
+ __version__ = "0.1.2"
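For orientation, a minimal sketch of how the two newly exported optimizers might be constructed after this change (constructor arguments are taken from the docstrings shown later in this diff; the model and loss are placeholders, not part of the package):

```python
import torch
from adv_optm import Lion_adv, Lion_Prodigy_adv

model = torch.nn.Linear(128, 64)  # placeholder model

# Plain factored Lion with an explicit learning rate.
opt = Lion_adv(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0)

# Prodigy-adapted Lion: lr stays at 1 and the optimizer estimates the step size (d).
# opt = Lion_Prodigy_adv(model.parameters(), lr=1, d0=1e-6, slice_p=11)

loss = model(torch.randn(8, 128)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```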
@@ -22,7 +22,6 @@ class AdamW_adv(torch.optim.Optimizer):
  eps (float): term added to the denominator to improve
  numerical stability (default: 1e-8)
  weight_decay (float): weight decay (L2 penalty) (default: 0)
- use_bias_correction (boolean): Turn on Adam's bias correction. (default: False)
  vector_reshape (bool): whether to reshape 1D vectors into 2D
  matrices to apply low-rank compression (default: True).
  stochastic_rounding (bool): whether to use stochastic
@@ -63,7 +62,6 @@ class AdamW_adv(torch.optim.Optimizer):
  betas: tuple[float, float] = (0.9, 0.999),
  eps: float = 1e-8,
  weight_decay: float = 0.0,
- use_bias_correction: bool = False,
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
@@ -88,7 +86,7 @@ class AdamW_adv(torch.optim.Optimizer):
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
- "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
+ "use_orthograd": use_orthograd,
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
  }
  self.stochastic_rounding = stochastic_rounding
@@ -122,6 +120,8 @@ class AdamW_adv(torch.optim.Optimizer):
  grad = _orthogonalize_gradient(p, grad)
  state = self.state[p]

+ beta1, beta2 = group['betas']
+
  # State Initialization
  if len(state) == 0:
  state['step'] = 0
@@ -141,11 +141,12 @@ class AdamW_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # First moment (m)
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
- state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
- if not self.use_grams:
- packed_d2 = (d2 + 7) // 8
- state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ if beta1 > 0:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ if not self.use_grams:
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
  state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -155,12 +156,12 @@ class AdamW_adv(torch.optim.Optimizer):
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
- state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+ if beta1 > 0:
+ state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

- beta1, beta2 = group['betas']
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
@@ -174,21 +175,22 @@ class AdamW_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # Reconstruct momentum from previous step's factors
- mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- if not self.use_grams:
- unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, mt, -mt, out=mt)
- del unpacked_sign
- # Update momentum in full-size
- grad_reshaped = grad.view(d1, d2)
- mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
- if self.use_grams:
- mt.copy_(grad_reshaped.sign() * mt.abs())
- elif self.use_cautious:
- mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- mt.mul_(mask)
- del mask
+ if beta1 > 0:
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ if not self.use_grams:
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt, -mt, out=mt)
+ del unpacked_sign
+ # Update momentum in full-size
+ grad_reshaped = grad.view(d1, d2)
+ mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
+ if self.use_grams:
+ mt.copy_(grad_reshaped.sign() * mt.abs())
+ elif self.use_cautious:
+ mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ mt.mul_(mask)
+ del mask

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
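The `use_cautious` branch above masks out momentum entries whose sign disagrees with the current gradient, then rescales the survivors. A standalone sketch of that masking step on plain tensors (the function name is illustrative, not part of the package API):

```python
import torch

def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Zero the entries of `update` that point against `grad`, then rescale."""
    mask = (update * grad > 0).to(grad.dtype)   # keep only sign-agreeing entries
    mask.div_(mask.mean().clamp_(min=1e-3))     # rescale so the average step size is preserved
    return update * mask

u = torch.randn(4, 4)
g = torch.randn(4, 4)
print(cautious_mask(u, g))
```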
@@ -202,28 +204,28 @@ class AdamW_adv(torch.optim.Optimizer):
  del unpacked_sign_slow

  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
- update_m = mt + (alpha_t * mt_slow)
+ update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
  else:
- update_m = mt
+ update = mt if beta1 > 0 else grad_reshaped
  del grad_reshaped

  if group['use_atan2']:
  a = 1.2732395
  denom = vt.sqrt()
- update = torch.atan2(update_m, denom).mul_(a)
+ update.atan2_(denom).mul_(a)
  else:
- denom = vt.sqrt().add_(group['eps'])
- update = update_m / denom
- del update_m, denom
+ denom = vt.sqrt()
+ update.div_(denom.add_(group['eps']))
+ del denom

- update = update.view(p.shape)
- update.mul_(group['lr'])
+ update.view(p.shape).mul_(group['lr'])

  # Compress updated moments and store new factors
- if not self.use_grams:
- state['sign'] = _pack_bools(mt > 0)
- _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
- del mt
+ if beta1 > 0:
+ if not self.use_grams:
+ state['sign'] = _pack_bools(mt > 0)
+ _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt
  if self.use_AdEMAMix:
  state['sign_slow'] = _pack_bools(mt_slow > 0)
  _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
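The reworked branch above switches to in-place ops (`atan2_`, `div_`) and, when `beta1 == 0`, skips the first moment entirely and normalizes the (reshaped) gradient directly. A rough sketch of the two denominator variants on plain tensors (illustrative only; 1.2732395 ≈ 4/π is the scaling constant used above):

```python
import torch

a = 1.2732395                # ~4/pi, as in the diff above
update = torch.randn(4, 4)   # stands in for mt (or the raw gradient when beta1 == 0)
vt = torch.rand(4, 4)        # second-moment estimate
eps = 1e-8

use_atan2 = True
if use_atan2:
    update = update.atan2(vt.sqrt()).mul(a)   # bounded, eps-free normalization
else:
    update = update / vt.sqrt().add(eps)      # classic Adam-style denominator
```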
@@ -232,36 +234,38 @@ class AdamW_adv(torch.optim.Optimizer):
  del vt

  else: # Standard AdamW logic for non-factored tensors
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- if self.use_grams:
- exp_avg = grad.sign() * exp_avg.abs()
- elif self.use_cautious:
- mask = (exp_avg * grad > 0).to(grad.dtype)
- mask.div_(mask.mean().clamp_(min=1e-3))
- exp_avg.mul_(mask)
- del mask
+ exp_avg_sq = state['exp_avg_sq']
+
+ if beta1 > 0:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ if self.use_grams:
+ exp_avg = grad.sign() * exp_avg.abs()
+ elif self.use_cautious:
+ mask = (exp_avg * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ exp_avg.mul_(mask)
+ del mask

  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
- update_m = exp_avg + (alpha_t * exp_avg_slow)
+ update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
  else:
- update_m = exp_avg
+ update = exp_avg if beta1 > 0 else grad

  exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

  if group['use_atan2']:
  a = 1.2732395
  denom = exp_avg_sq.sqrt()
- update = torch.atan2(update_m, denom).mul_(a)
+ update.atan2_(denom).mul_(a)
  else:
- denom = exp_avg_sq.sqrt().add_(group['eps'])
- update = update_m / denom
- del update_m, denom
+ denom = exp_avg_sq.sqrt()
+ update.div_(denom.add_(group['eps']))
+ del denom

- update = update.mul_(group['lr'])
+ update.mul_(group['lr'])

  # Decoupled weight decay
  if group["weight_decay"] != 0:
@@ -30,8 +30,6 @@ class Adopt_adv(torch.optim.Optimizer):
  clip_lambda (Callable, optional): A function that takes the current step
  and returns a value to clip the normalized gradient. Only used when
  `use_atan2` is False. (default: `lambda step: step**0.25`)
- rank (int): the rank for the low-rank approximation (default: 4).
- oversampling (int): oversampling parameter for Randomized SVD. (default: 0).
  vector_reshape (bool): whether to reshape 1D vectors into 2D
  matrices for low-rank compression (default: True).
  stochastic_rounding (bool): whether to use stochastic
@@ -192,29 +190,29 @@ class Adopt_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # Reconstruct m_{t-1}
- mt_prev = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
  if not self.use_grams:
  if state['sign'].dtype != torch.uint8:
  state['sign'] = state['sign'].to(torch.uint8)
  unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
- torch.where(unpacked_sign, mt_prev, -mt_prev, out=mt_prev)
+ torch.where(unpacked_sign, mt, -mt, out=mt)
  del unpacked_sign

  # Reconstruct AdEMAMix EMA
  if self.use_AdEMAMix:
- mt_slow_prev = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
  if state['sign_slow'].dtype != torch.uint8:
  state['sign_slow'] = state['sign_slow'].to(torch.uint8)
  unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
- torch.where(unpacked_sign_slow, mt_slow_prev, -mt_slow_prev, out=mt_slow_prev)
+ torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
  del unpacked_sign_slow

  # Reconstruct v_{t-1} using NNMF
- vt_prev = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+ vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))

  # ADOPT Step A: Decorrelate g_t using v_{t-1}
  grad_reshaped = grad.view(d1, d2)
- denom = vt_prev.sqrt()
+ denom = vt.sqrt()

  if self.use_atan2:
  normalized_grad = torch.atan2(grad_reshaped, denom)
@@ -226,7 +224,7 @@ class Adopt_adv(torch.optim.Optimizer):
  del denom

  # ADOPT Step B: Update momentum m_t using normalized gradient
- mt = mt_prev.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.use_grams:
  mt = grad_reshaped.sign() * mt.abs()
  elif self.use_cautious:
@@ -236,7 +234,7 @@ class Adopt_adv(torch.optim.Optimizer):
  del mask

  if self.use_AdEMAMix:
- mt_slow = mt_slow_prev.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
+ mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  update = mt + (alpha_t * mt_slow)
  update = update.view(p.shape)
  else:
@@ -248,20 +246,23 @@ class Adopt_adv(torch.optim.Optimizer):
  update.mul_(group['lr'])

  # Update second moment v_t for the *next* step using raw g_t
- vt_updated = vt_prev.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
  del grad_reshaped

  # Compress and store new factors
  if not self.use_grams:
  state['sign'] = _pack_bools(mt > 0)
  _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt

  if self.use_AdEMAMix:
  state['sign_slow'] = _pack_bools(mt_slow > 0)
  _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
+ del mt_slow

  # factorize v_t using NMF compression
- _nnmf(vt_updated, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt

  else: # Standard ADOPT logic for non-factored tensors
  m, v = state['exp_avg'], state['exp_avg_sq'] # m_{t-1}, v_{t-1}
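For readers following the ADOPT hunks above, the optimizer's two-step structure (normalize the gradient with the previous second moment, then feed the normalized gradient into the momentum) can be summarized with this rough sketch on plain tensors (names and constants are illustrative, not the package's code path):

```python
import torch

beta1, beta2, eps = 0.9, 0.9999, 1e-6
p = torch.randn(16, 16)
grad = torch.randn_like(p)
mt = torch.zeros_like(p)   # m_{t-1}
vt = torch.ones_like(p)    # v_{t-1}; the real optimizer seeds it from g_0^2

# Step A: decorrelate g_t using v_{t-1}
normalized_grad = grad / vt.sqrt().clamp_(min=eps)
# Step B: momentum update uses the normalized gradient
mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
# Parameter step, then v_t is refreshed with the *raw* gradient for the next step
p -= 1e-3 * mt
vt.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
```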
@@ -0,0 +1,335 @@
+ import torch
+ import torch.distributed as dist
+
+ import math
+
+ from typing import Tuple, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Lion_Prodigy_adv(torch.optim.Optimizer):
+ """
+ Implements the SMMF technique and Prodigy D-Adaptation method for Lion algorithm.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate (default: 1e-4).
+ betas (Tuple[float, float], optional): coefficients for computing
+ running averages of the update (default: (0.9, 0.99)).
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ stochastic_rounding (bool, optional): whether to use stochastic
+ rounding for BF16 parameter updates (default: True).
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
+ clip_threshold (float, optional): whether to clip the gradients norm
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+ (default: 0.0).
+ factored (bool): whether to use the factorization or use the
+ uncompressed optimizer. (default: True)
+ variance_reduction (bool): whether to use the variance reduction technique
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+ d0 (float):
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+ d_coef (float):
+ Coefficient in the expression for the estimate of d (default 1.0).
+ Values such as 0.5 and 2.0 typically work as well.
+ Changing this parameter is the preferred way to tune the method.
+ growth_rate (float):
+ prevent the D estimate from growing faster than this multiplicative rate.
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+ rate warmup effect.
+ fsdp_in_use (bool):
+ If you're using sharded parameters, this should be set to True. The optimizer
+ will attempt to auto-detect this, but if you're using an implementation other
+ than PyTorch's builtin version, the auto-detection won't work.
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
+ Prodigy. Values ~11 are reasonable (default 11).
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1,
+ betas: Tuple[float, float] = (0.9, 0.99),
+ weight_decay: float = 0.0,
+ vector_reshape: bool = True,
+ stochastic_rounding: bool = True,
+ use_orthograd: bool = False,
+ use_cautious: bool = False,
+ clip_threshold: float = 0.0,
+ factored: bool = True,
+ variance_reduction: bool = False,
+ # prodigy parameters
+ beta3: float = None,
+ d0: float = 1e-6,
+ d_coef: float = 1,
+ growth_rate: float = float('inf'),
+ safeguard_warmup: bool = False,
+ fsdp_in_use: bool = False,
+ slice_p: int = 11,
+ ):
+ if not lr > 0.0:
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+ if not weight_decay >= 0.0:
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+ if variance_reduction and use_cautious:
+ print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
+
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ weight_decay=weight_decay,
+ vector_reshape=vector_reshape,
+ use_orthograd=use_orthograd,
+ clip_threshold=clip_threshold,
+ beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
+ growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
+ fsdp_in_use=fsdp_in_use,
+ )
+ self.stochastic_rounding = stochastic_rounding
+ self.use_cautious = use_cautious
+ self.factored = factored
+ self.variance_reduction = variance_reduction
+ self.fsdp_in_use = fsdp_in_use
+ super().__init__(params, defaults)
+ # Global state for accumulating metrics across parameter updates within a single step.
+ self.init_step()
+
+ @property
+ def supports_fused_back_pass(self) -> bool:
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self) -> bool:
+ return True
+
+ @property
+ def supports_flat_params(self) -> bool:
+ return False
+
+ def init_step(self):
+ """Resets accumulators and calculates dlr for the upcoming step."""
+ self.d_denom = 0.0
+
+ g_group = self.param_groups[0]
+ self.beta1, self.beta2 = g_group['betas']
+ self.beta3 = g_group['beta3']
+ if self.beta3 is None:
+ self.beta3 = math.sqrt(self.beta2)
+
+ k = g_group['k']
+ self.d = g_group['d']
+ lr = g_group['lr']
+
+ self.dlr = self.d * lr
+
+ self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ if hasattr(p, "_fsdp_flattened"):
+ self.fsdp_in_use = True
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["clip_threshold"] > 0.0:
+ grad_norm = torch.norm(grad.detach())
+ if grad_norm > group["clip_threshold"]:
+ clip_coef = group["clip_threshold"] / grad_norm
+ grad.mul_(clip_coef)
+ if group["use_orthograd"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ should_factor = (
+ self.factored and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ dtype = torch.float32 if self.factored else p.dtype
+
+ slice_p = group['slice_p']
+
+ # D-Adaptation states
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
+ if p.any():
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
+ else:
+ state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+ else: # Fallback to standard Lion
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+ if state['factored']:
+ # Factored Path
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ # Reconstruct momentum m_{t-1}
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+ del unpacked_sign
+ if exp_avg.dtype != torch.float32:
+ exp_avg = exp_avg.float()
+
+ # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=(1-self.beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
+ update_for_param = signed_update.view(p.shape).mul(self.dlr)
+
+ # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
+ if self.variance_reduction:
+ vr_term = grad_reshaped - state['prev_grad']
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+ state['prev_grad'].copy_(grad_reshaped)
+ else:
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1-self.beta2))
+ del grad_reshaped
+
+ # Compress new momentum m_t and store factors
+ state['sign'] = _pack_bools(exp_avg > 0)
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del exp_avg
+
+ else:
+ # Fallback to standard Lion logic
+ exp_avg = state["exp_avg"]
+
+ # Compute update term and sign for the update
+ if exp_avg.dtype != torch.float32 and self.factored:
+ exp_avg = exp_avg.float()
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=(1-self.beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ update_for_param = signed_update.mul(self.dlr)
+
+ # Update momentum
+ if self.variance_reduction:
+ vr_term = grad - state['prev_grad']
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2)).add_(vr_term, alpha=self.beta2)
+ state['prev_grad'].copy_(grad)
+ else:
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1-self.beta2))
+
+ # --- Accumulate Prodigy stats ---
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+ s, p0 = state['s'], state['p0']
+ grad_flat = grad.flatten().float()
+ p_flat = p.data.flatten().float()
+ p0 = p0.float()
+
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+ self.d_denom += s.abs().sum().item()
+
+ del s, p0, grad_flat, p_flat, alpha
+
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data,
+ alpha=-group["weight_decay"] * self.dlr)
+ else:
+ p.data.add_(
+ p.data, alpha=-group["weight_decay"] * self.dlr
+ )
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update_for_param)
+ else:
+ p.data.add_(-update_for_param)
+
+ del update_for_param
+
+ @torch.no_grad()
+ def step(self, closure: Optional[callable] = None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group["params"]):
+ if p.grad is not None:
+ self.step_parameter(p, group, i)
+
+
+ self.calculate_d()
+ self.init_step()
+ return loss
+
+ def calculate_d(self):
+ """Calculates the new `d` based on the accumulated stats."""
+ g_group = self.param_groups[0]
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ device = self.param_groups[0]['params'][0].device
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0].item()
+ global_d_denom = dist_tensor[1].item()
+ else:
+ global_d_numerator = self.d_numerator
+ global_d_denom = self.d_denom
+
+ d_hat = self.d
+ if global_d_denom > 0:
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if self.d == g_group['d0']:
+ self.d = max(self.d, d_hat)
+ d_max = max(d_max, d_hat)
+ self.d = min(d_max, self.d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d'] = self.d
+ group['d_max'] = d_max
+ group['k'] += 1
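A short, hedged usage sketch for the new `Lion_Prodigy_adv` class above (the model and training loop are placeholders; `lr` is left at 1 so the Prodigy machinery controls the effective step size through `d`):

```python
import torch
from adv_optm import Lion_Prodigy_adv

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
opt = Lion_Prodigy_adv(model.parameters(), lr=1, weight_decay=0.01, slice_p=11)

for _ in range(10):
    x = torch.randn(64, 32)
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()       # accumulates Prodigy stats, then recomputes d via calculate_d()
    opt.zero_grad()
```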
@@ -0,0 +1,231 @@
+ import torch
+
+ from typing import Tuple, Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Lion_adv(torch.optim.Optimizer):
+ """
+ Implements the SMMF technique for Lion algorithm.
+
+ This optimizer combines the Lion update rule with the memory-saving low-rank
+ compression (SMMF) technique from https://arxiv.org/abs/2412.08894.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate (default: 1e-4).
+ betas (Tuple[float, float], optional): coefficients for computing
+ running averages of the update (default: (0.9, 0.99)).
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
+ vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ stochastic_rounding (bool, optional): whether to use stochastic
+ rounding for BF16 parameter updates (default: True).
+ use_cautious (bool): whether to use the cautious masking technique. (default: False).
+ clip_threshold (float, optional): whether to clip the gradients norm
+ per-parameter as proposed in the paper `Lions and Muons: Optimization via
+ Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
+ (default: 0.0).
+ factored (bool): whether to use the factorization or use the
+ uncompressed optimizer. (default: True)
+ variance_reduction (bool): whether to use the variance reduction technique
+ from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-4,
+ betas: Tuple[float, float] = (0.9, 0.99),
+ weight_decay: float = 0.0,
+ vector_reshape: bool = True,
+ stochastic_rounding: bool = True,
+ use_orthograd: bool = False,
+ use_cautious: bool = False,
+ clip_threshold: float = 0.0,
+ factored: bool = True,
+ variance_reduction: bool = False,
+ ):
+ if not lr > 0.0:
+ raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
+ if not all(0.0 <= beta <= 1.0 for beta in betas):
+ raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
+ if not weight_decay >= 0.0:
+ raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
+ if variance_reduction and use_cautious:
+ print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
+
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ weight_decay=weight_decay,
+ vector_reshape=vector_reshape,
+ use_orthograd=use_orthograd,
+ clip_threshold=clip_threshold,
+ )
+ self.stochastic_rounding = stochastic_rounding
+ self.use_cautious = use_cautious
+ self.factored = factored
+ self.variance_reduction = variance_reduction
+ super().__init__(params, defaults)
+
+ @property
+ def supports_fused_back_pass(self) -> bool:
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self) -> bool:
+ return True
+
+ @property
+ def supports_flat_params(self) -> bool:
+ return False
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["clip_threshold"] > 0.0:
+ grad_norm = torch.norm(grad.detach())
+ if grad_norm > group["clip_threshold"]:
+ clip_coef = group["clip_threshold"] / grad_norm
+ grad.mul_(clip_coef)
+ if group["use_orthograd"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ should_factor = (
+ self.factored and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ dtype = torch.float32 if self.factored else p.dtype
+
+ if state['factored']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
+ else: # Fallback to standard Lion
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+ if self.variance_reduction:
+ state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
+
+ state['step'] += 1
+ beta1, beta2 = group["betas"]
+ lr = group["lr"]
+
+ if state['factored']:
+ # Factored Path
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ # Reconstruct momentum m_{t-1}
+ exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
+ del unpacked_sign
+ if exp_avg.dtype != torch.float32:
+ exp_avg = exp_avg.float()
+
+ # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
+ update_for_param = signed_update.view(p.shape).mul_(lr)
+
+ # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
+ if self.variance_reduction:
+ vr_term = grad_reshaped - state['prev_grad']
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2).add_(vr_term, alpha=beta2)
+ del vr_term
+ state['prev_grad'].copy_(grad_reshaped)
+ else:
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
+ del grad_reshaped
+
+ # Compress new momentum m_t and store factors
+ state['sign'] = _pack_bools(exp_avg > 0)
+ _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del exp_avg
+
+ else:
+ # Fallback to standard Lion logic
+ exp_avg = state["exp_avg"]
+
+ # Compute update term and sign for the update
+ if exp_avg.dtype != torch.float32 and self.factored:
+ exp_avg = exp_avg.float()
+ signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+
+ if self.use_cautious:
+ mask = (signed_update * grad > 0).to(grad.dtype)
+ mask.div_(mask.mean().clamp_(min=1e-3))
+ signed_update.mul_(mask)
+ del mask
+
+ update_for_param = signed_update.mul_(lr)
+
+ # Update momentum
+ if self.variance_reduction:
+ vr_term = grad - state['prev_grad']
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2).add_(vr_term, alpha=beta2)
+ state['prev_grad'].copy_(grad)
+ else:
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
+
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data,
+ alpha=-group["weight_decay"] * lr)
+ else:
+ p.data.add_(
+ p.data, alpha=-group["weight_decay"] * lr
+ )
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update_for_param)
+ else:
+ p.data.add_(-update_for_param)
+
+ del update_for_param
+
+ @torch.no_grad()
+ def step(self, closure: Optional[callable] = None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group["params"]):
+ if p.grad is not None:
+ self.step_parameter(p, group, i)
+
+ return loss
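Stripped of the factorization and bookkeeping, the core Lion rule implemented above is just the sign of an interpolated momentum. A rough reference sketch of one uncompressed step (illustrative helper, not the package's code path):

```python
import torch

def lion_step(p, grad, exp_avg, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One plain Lion update on a single tensor."""
    if wd != 0:
        p.data.mul_(1 - lr * wd)                               # decoupled weight decay
    update = exp_avg.mul(beta1).add(grad, alpha=1 - beta1).sign_()
    p.data.add_(update, alpha=-lr)                             # p <- p - lr * sign(c_t)
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)            # momentum for the next step
    return p, exp_avg
```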
@@ -18,7 +18,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  Args:
  params (iterable): iterable of parameters to optimize or dicts defining
  parameter groups
- lr (float): learning rate (default: 1e-3)
+ lr (float): learning rate (default: 1)
  betas (tuple[float, float]): coefficients used for computing running
  averages of gradient and its square (default: (0.9, 0.999))
  eps (float): term added to the denominator to improve
@@ -71,13 +71,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  than PyTorch's builtin version, the auto-detection won't work.
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
  pth entry of each tensor. For values greater than 1 this an an approximation to standard
- Prodigy. Values ~11 are reasonable (default 1).
+ Prodigy. Values ~11 are reasonable (default 11).
  """

  def __init__(
  self,
  params,
- lr: float = 1e-3,
+ lr: float = 1,
  betas: tuple[float, float] = (0.9, 0.999),
  eps: float = 1e-8,
  weight_decay: float = 0.0,
@@ -270,7 +270,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  denom = vt.sqrt()
  update = torch.atan2(update_m, denom).mul_(a)
  else:
- denom = vt.sqrt().add_(group['eps'])
+ denom = vt.sqrt().add_(self.d * group['eps'])
  update = update_m / denom
  del update_m, denom

@@ -315,7 +315,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  denom = exp_avg_sq.sqrt()
  update = torch.atan2(update_m, denom).mul_(a)
  else:
- denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
+ denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
  update = update_m / denom
  del update_m, denom

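The two hunks above scale `eps` by the current Prodigy estimate `d`, which keeps the stabilizer proportional to the adapted step size (the reference Prodigy implementation uses the same `sqrt(v_t) + d * eps` denominator). A rough sketch of the resulting denominator on plain tensors (names are illustrative):

```python
import torch

d, eps = 0.05, 1e-8              # d is the Prodigy-estimated step size
update_m = torch.randn(8, 8)
exp_avg_sq = torch.rand(8, 8)

denom = exp_avg_sq.sqrt().add_(d * eps)   # eps scales with d, as in the hunks above
update = update_m / denom
```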
@@ -1,9 +1,13 @@
  from .AdamW_adv import AdamW_adv
  from .Prodigy_adv import Prodigy_adv
  from .Adopt_adv import Adopt_adv
+ from .Lion_adv import Lion_adv
+ from .Lion_Prodigy_adv import Lion_Prodigy_adv

  __all__ = [
  "AdamW_adv",
  "Prodigy_adv",
  "Adopt_adv",
+ "Lion_adv",
+ "Lion_Prodigy_adv",
  ]
adv_optm/util/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from .BF16_Stochastic_Rounding import add_stochastic_, copy_stochastic_
+ from .BF16_Stochastic_Rounding import add_stochastic_
  from .Effective_Shape import _get_effective_shape
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
  from .OrthoGrad import _orthogonalize_gradient
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 0.1.1
+ Version: 0.1.2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -47,7 +47,7 @@ Based primarily on:
  **[SMMF: Square-Matricized Momentum Factorization for Memory-Efficient Optimization](https://arxiv.org/abs/2412.08894)**

  The core innovation:
- - Uses fast, non-negative matrix factorization (rank 1, à la Adafactor), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
+ - Uses fast, non-negative matrix factorization (NNMF - rank 1), but **reconstructs the full state before each update** to preserve momentum accuracy, then re-factors afterward (factor → reconstruct → update → factor cycle).
  - For the *signed first moment*, we split into **sign + absolute value**:
  - Sign is stored as **1-bit state** via bitwise ops (SMMF originally used 8-bit with 7 bits wasted).
  - Absolute value goes through the factor/reconstruct cycle using two factored vectors + the signed state.
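The package's `_nnmf`/`_unnmf` and `_pack_bools`/`_unpack_bools` helpers are not shown in this diff; the following is only a minimal sketch of the idea they implement (Adafactor-style rank-1 non-negative factorization plus 1-bit sign storage), assuming simple row/column-sum factors, which may differ from the actual implementation:

```python
import torch

def nnmf_rank1(m_abs: torch.Tensor):
    """Rank-1 non-negative factorization: m_abs ≈ torch.outer(u, v) / v.sum()."""
    u = m_abs.sum(dim=1)   # row sums, shape (d1,)
    v = m_abs.sum(dim=0)   # column sums, shape (d2,)
    return u, v

def unnmf_rank1(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return torch.outer(u, v) / v.sum().clamp(min=1e-30)

def pack_signs(signs: torch.Tensor) -> torch.Tensor:
    """Pack a boolean matrix into uint8, eight columns per byte."""
    d1, d2 = signs.shape
    pad = (-d2) % 8
    bits = torch.nn.functional.pad(signs.to(torch.uint8), (0, pad)).view(d1, -1, 8)
    weights = 2 ** torch.arange(8, dtype=torch.uint8)
    return (bits * weights).sum(dim=-1).to(torch.uint8)

# factor -> reconstruct -> update -> factor cycle for a signed first moment
m = torch.randn(6, 10)
sign_bits = pack_signs(m > 0)      # 1-bit sign state
u, v = nnmf_rank1(m.abs())         # two factored vectors
m_magnitude = unnmf_rank1(u, v)    # reconstructed |m| for the next update
```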
@@ -106,8 +106,6 @@ Set `Factored=False` to disable factorization and run as a full uncompressed opt

  ⚠️ **Note**: AdEMAMix updates are more aggressive than normal Adam/Adopt, so use a x2-x5 smaller LR than usual (or use Prodigy).

- ⚠️ **Note**: The factored AdEMAMix is **Experimental** (as it needs more tests and validation, but it should work). Also, Adopt with AdEMAMix is **Experimental** (as Adopt normalizes the gradients for the momentum).
-
  - **[`atan2` smoothing & scaling](https://github.com/lucidrains/adam-atan2-pytorch)**
  → Robust `eps` replacement (no tuning!) + built-in gradient clipping
  → *Ideal for ADOPT* (which normally needs higher `eps` and clipping), so `use_atan2` is all-in-one for it.
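A compact sketch of what the `atan2` replacement for `eps` does to an Adam-style update (the constant 1.2732395 ≈ 4/π rescales the bounded output, matching the optimizer code earlier in this diff; the tensors here are placeholders):

```python
import torch

m = torch.randn(5)   # first moment
v = torch.rand(5)    # second moment
eps = 1e-8

classic = m / (v.sqrt() + eps)                     # needs a tuned eps and external clipping
smoothed = torch.atan2(m, v.sqrt()) * 1.2732395    # bounded to ±2, no eps, built-in clipping

print(classic, smoothed)
```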
@@ -129,6 +127,4 @@ Set `Factored=False` to disable factorization and run as a full uncompressed opt

  - When `use_atan2` is True, `eps` will be ignored and you should also disable any gradient clipping.

- - I don't recommend using **OrthoGrad** for training LoRA or embeddings, as their weights are zero-initialized and using weight decay for them should be safe and also beneficial (OrthoGrad is intended for fine-tuning pretrained models with no weight decay).
-
  ---
@@ -0,0 +1,18 @@
+ adv_optm/__init__.py,sha256=BNYlxkuU8MFsWSY1_PLzp2XBSzpt-sxhnVuWVKRZGZ8,252
+ adv_optm/optim/AdamW_adv.py,sha256=_4Vt79EB18rnIkHttA0CdMpli8sZ5f03pesdrwT5K58,12887
+ adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=ql6506h_IIZvTPdGYrQdd6iEhCXHTMntqmg739fc_dw,14102
+ adv_optm/optim/Lion_adv.py,sha256=jOoRbJ6u9HCK7IBI9ILOCcwprKIGTUNvUzhRd99WJK0,9410
+ adv_optm/optim/Prodigy_adv.py,sha256=InR50MoE32zG6qgEkg_JzXl7uXAVRy4EYG0JDl4eKok,17324
+ adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+ adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+ adv_optm-0.1.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ adv_optm-0.1.2.dist-info/METADATA,sha256=iV5GBWtl4WphBeSIIsUoq1ay6-GJGnDD3XF6aSWWrqg,5846
+ adv_optm-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ adv_optm-0.1.2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+ adv_optm-0.1.2.dist-info/RECORD,,
@@ -1,37 +0,0 @@
- import torch
- from torch import Tensor
-
- from typing import Tuple
-
- def _rsvd(A: torch.Tensor, rank: int, oversampling: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Performs Randomized SVD."""
- orig_dtype, device, (m, n) = A.dtype, A.device, A.shape
- A_float = A.float()
- l, true_rank = rank + oversampling, min(m, n, rank)
-
- if true_rank == 0:
- return (
- torch.zeros(m, rank, dtype=orig_dtype, device=device),
- torch.zeros(rank, dtype=orig_dtype, device=device),
- torch.zeros(rank, n, dtype=orig_dtype, device=device),
- )
-
- if l >= min(m, n): # Fallback to full SVD
- U_full, S_full, Vh_full = torch.linalg.svd(A_float, full_matrices=False)
- U, S, Vh = U_full[:, :true_rank], S_full[:true_rank], Vh_full[:true_rank, :]
- else: # Standard RSVD path
- Omega = torch.randn(n, l, dtype=A_float.dtype, device=device)
- Y = A_float @ Omega
- Q, _ = torch.linalg.qr(Y.float())
- B = Q.T @ A_float
- U_tilde, S, Vh = torch.linalg.svd(B.float(), full_matrices=False)
- U, S, Vh = (Q @ U_tilde)[:, :true_rank], S[:true_rank], Vh[:true_rank, :]
-
- if true_rank < rank: # Pad factors with zeros
- U_padded = torch.zeros(m, rank, dtype=A_float.dtype, device=device)
- S_padded = torch.zeros(rank, dtype=A_float.dtype, device=device)
- Vh_padded = torch.zeros(rank, n, dtype=A_float.dtype, device=device)
- U_padded[:, :true_rank], S_padded[:true_rank], Vh_padded[:true_rank, :] = U, S, Vh
- U, S, Vh = U_padded, S_padded, Vh_padded
-
- return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
@@ -1,17 +0,0 @@
- adv_optm/__init__.py,sha256=Ol6hg_EdQH1AXJsa_9l5iWnlUXuOXwD-6eU1OweL87A,172
- adv_optm/optim/AdamW_adv.py,sha256=VGGzLhLh6CdY4I8mxmlzIC90rWnc9oGNuuXK8vE1dE0,12729
- adv_optm/optim/Adopt_adv.py,sha256=-GRpXWISCq6HPkd7UB1S57jSzsg2D3nAhAt6082_7Ms,14992
- adv_optm/optim/Prodigy_adv.py,sha256=5N5GsTWYg_0q_R95E_ryZVa3zSe-q30p_bFK5dXOUpM,17311
- adv_optm/optim/__init__.py,sha256=kX9MQhLQZGlKFPCGLXsZtooigs4wXULTEmNSSOJvcCY,178
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
- adv_optm/util/Randomized_SVD.py,sha256=TFG417hh1t5f1n_mChnbgdQhpMoi37O04xVCe8wz8Qc,1708
- adv_optm/util/__init__.py,sha256=3yYKo23JDfHDZdGcjrDKxH8nYjk5KDB-i44kW-J4sPk,367
- adv_optm-0.1.1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
- adv_optm-0.1.1.dist-info/METADATA,sha256=Mej63zbzvVh1YkAydQojP6SZSqz_46JA6-Y_3i3b2Fs,6342
- adv_optm-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- adv_optm-0.1.1.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
- adv_optm-0.1.1.dist-info/RECORD,,