adv-optm 1.0.5-py3-none-any.whl → 1.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

adv_optm/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "1.0.5"
+__version__ = "1.0.6"

adv_optm/optim/Adopt_adv.py CHANGED
@@ -156,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
             state = self.state[p]
 
+            beta1, beta2 = group['betas']
+
             # State Initialization
             if len(state) == 0:
                 state['step'] = 0
@@ -174,11 +176,12 @@ class Adopt_adv(torch.optim.Optimizer):
                     d1, d2 = state['effective_shape']
 
                     # m_0 = 0
-                    state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
-                    state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
-                    if not self.grams_moment:
-                        packed_d2 = (d2 + 7) // 8
-                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+                    if beta1 > 0:
+                        state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                        state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                        if not self.grams_moment:
+                            packed_d2 = (d2 + 7) // 8
+                            state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                     if self.use_AdEMAMix:
                         state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                         state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -192,7 +195,8 @@ class Adopt_adv(torch.optim.Optimizer):
                     # Initialize v_0 using NMF
                     _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
                 else: # Fallback for non-factored tensors
-                    state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
+                    if beta1 > 0:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                     state['exp_avg_sq'] = grad.square() # v_0
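
The `_nnmf`/`_unnmf` pair used here (implemented in adv_optm/util/NNMF.py per the RECORD below) stores each nonnegative moment matrix as a rank-1 factorization in the spirit of Adafactor's row/column statistics, so a (d1, d2) matrix costs O(d1 + d2) memory instead of O(d1 * d2). A minimal sketch of such a pair, as an illustrative reimplementation rather than the shipped code:

    import torch

    def nnmf(x):
        # x: nonnegative matrix of shape (d1, d2)
        r = x.sum(dim=1)                        # row factor, (d1,)
        c = x.sum(dim=0)                        # column factor, (d2,)
        return r, c / c.sum().clamp(min=1e-30)  # O(d1 + d2) storage

    def unnmf(factors):
        r, c = factors
        return torch.outer(r, c)  # rank-1 reconstruction with exact row sums

Because such a factorization only works on nonnegative input, the first moment has to be split into a factorable magnitude and separately stored signs, which is what the `sign` buffer above holds.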
@@ -202,7 +206,6 @@ class Adopt_adv(torch.optim.Optimizer):
                 state['step'] += 1
                 return
 
-            beta1, beta2 = group['betas']
             if self.use_AdEMAMix:
                 beta3_ema = group['beta3_ema']
                 alpha = group['alpha']
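
With the `group['betas']` read hoisted above the init block, all first-moment state ('exp_avg', or 'mu_m_nmf'/'mv_m_nmf'/'sign' in the factored path) is now allocated only when beta1 > 0, so running momentum-free also saves that memory. A hypothetical usage sketch (the constructor signature is inferred from the `group['betas']` access, and a top-level export of Adopt_adv is assumed):

    import torch
    from adv_optm import Adopt_adv

    model = torch.nn.Linear(128, 128)
    opt = Adopt_adv(model.parameters(), lr=1e-3, betas=(0.0, 0.9999))

    loss = model(torch.randn(4, 128)).sum()
    loss.backward()
    opt.step()  # beta1 == 0: no momentum buffers are created; the update is
                # built from the normalized gradient alone (see the hunks below)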
@@ -219,13 +222,14 @@ class Adopt_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # Reconstruct m_{t-1}
-                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-                if not self.grams_moment:
-                    if state['sign'].dtype != torch.uint8:
-                        state['sign'] = state['sign'].to(torch.uint8)
-                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
-                    torch.where(unpacked_sign, mt, -mt, out=mt)
-                    del unpacked_sign
+                if beta1 > 0:
+                    mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                    if not self.grams_moment:
+                        if state['sign'].dtype != torch.uint8:
+                            state['sign'] = state['sign'].to(torch.uint8)
+                        unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                        torch.where(unpacked_sign, mt, -mt, out=mt)
+                        del unpacked_sign
 
                 # Reconstruct AdEMAMix EMA
                 if self.use_AdEMAMix:
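
The sign tensor reconstructed here is the 1-bit-per-entry compression from adv_optm/util/One_Bit_Boolean.py: NNMF only factors nonnegative matrices, so |m| is factored and the signs are bit-packed eight to a uint8, which is why the init hunk sizes it as `packed_d2 = (d2 + 7) // 8`. An illustrative pack/unpack pair under that assumption (not the shipped implementation):

    import torch

    _BITS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

    def pack_bools(b):
        # (d1, d2) bool -> (d1, ceil(d2 / 8)) uint8, eight signs per byte
        d1, d2 = b.shape
        packed_d2 = (d2 + 7) // 8
        padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=b.device)
        padded[:, :d2] = b
        weights = _BITS.to(b.device)
        return (padded.view(d1, packed_d2, 8).to(torch.uint8) * weights).sum(-1).to(torch.uint8)

    def unpack_bools(packed, original_m):
        # inverse of pack_bools; original_m trims the padding columns
        weights = _BITS.to(packed.device)
        bits = packed.unsqueeze(-1).bitwise_and(weights).ne(0)  # (d1, p, 8) bool
        return bits.view(packed.shape[0], -1)[:, :original_m]

This keeps the per-matrix sign overhead at d1 * ceil(d2 / 8) bytes rather than a full floating-point tensor.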
@@ -253,25 +257,29 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom
 
                 # ADOPT Step B: Update momentum m_t using normalized gradient
-                if self.Simplified_AdEMAMix:
-                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
-                else:
-                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
-                if self.grams_moment:
-                    mt = grad_reshaped.sign() * mt.abs()
-                elif self.cautious_mask:
-                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
-                    mask.div_(mask.mean().clamp_(min=1e-3))
-                    mt.mul_(mask)
-                    del mask
+                if beta1 > 0:
+                    if self.Simplified_AdEMAMix:
+                        mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                    else:
+                        mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+                    if self.grams_moment:
+                        mt = grad_reshaped.sign() * mt.abs()
+                    elif self.cautious_mask:
+                        mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        mt.mul_(mask)
+                        del mask
 
                 if self.use_AdEMAMix:
                     mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    if beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
-                    update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
+                    update = torch.add(mt, normalized_grad, alpha=alpha_grad)
                 else:
-                    update = mt.clone()
+                    update = mt.clone() if beta1 > 0 else normalized_grad
 
                 update = update.view(p.shape)
 
@@ -285,10 +293,11 @@ class Adopt_adv(torch.optim.Optimizer):
                 del grad_reshaped
 
                 # Compress and store new factors
-                if not self.grams_moment:
-                    state['sign'] = _pack_bools(mt > 0)
-                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
-                del mt
+                if beta1 > 0:
+                    if not self.grams_moment:
+                        state['sign'] = _pack_bools(mt > 0)
+                    _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                    del mt
 
                 if self.use_AdEMAMix:
                     state['sign_slow'] = _pack_bools(mt_slow > 0)
@@ -300,10 +309,7 @@ class Adopt_adv(torch.optim.Optimizer):
                     del vt
 
             else: # Standard ADOPT logic for non-factored tensors
-                m, v = state['exp_avg'], state['exp_avg_sq'] # m_{t-1}, v_{t-1}
-
-                if self.use_AdEMAMix:
-                    m_slow = state['exp_avg_slow']
+                v = state['exp_avg_sq'] # v_{t-1}
 
                 # ADOPT Step A: Decorrelate g_t using v_{t-1}
                 denom = v.sqrt()
@@ -318,10 +324,12 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom
 
                 # ADOPT Step B: Update momentum m_t
-                if self.Simplified_AdEMAMix:
-                    m.mul_(beta1).add_(normalized_grad, alpha=1.0)
-                else:
-                    m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+                if beta1 > 0:
+                    m = state['exp_avg'] # m_{t-1}
+                    if self.Simplified_AdEMAMix:
+                        m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                    else:
+                        m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
                 if self.grams_moment:
                     m = grad.sign() * m.abs()
@@ -332,12 +340,16 @@ class Adopt_adv(torch.optim.Optimizer):
                     del mask
 
                 if self.use_AdEMAMix:
+                    m_slow = state['exp_avg_slow']
                     m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-                    update = torch.add(m, m_slow, alpha=alpha_t)
+                    if beta1 > 0:
+                        update = torch.add(m, m_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
-                    update = torch.add(m, grad, alpha=alpha_grad)
+                    update = torch.add(m, normalized_grad, alpha=alpha_grad)
                 else:
-                    update = m.clone()
+                    update = m.clone() if beta1 > 0 else normalized_grad
 
                 if self.use_atan2:
                     update.mul_(group['lr'] * 1.2732395447351628)
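
Two details worth pinning down in the hunks above. First, 1.2732395447351628 is 4/π, the usual scale for an atan2-shaped step: with a positive second argument atan2 lies in (-π/2, π/2), so the scaled update stays bounded by 2·lr without an eps term. Second, with beta1 == 0 the new code never touches 'exp_avg' and uses the Step-A output directly. A condensed paraphrase of the non-factored path under those assumptions (AdEMAMix, grams, cautious masking, and the atan2 branch omitted; not the shipped code):

    import torch

    def adopt_core_step(p, grad, state, lr, beta1, beta2, eps=1e-6):
        v = state['exp_avg_sq']                           # v_{t-1}
        normalized_grad = grad / v.sqrt().clamp(min=eps)  # ADOPT Step A: decorrelate g_t
        if beta1 > 0:
            m = state['exp_avg']                          # exists only when beta1 > 0
            m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)  # ADOPT Step B
            update = m.clone()
        else:
            update = normalized_grad                      # 1.0.6: momentum-free path
        p.add_(update, alpha=-lr)
        v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)  # v_t for the next step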

adv_optm/optim/Prodigy_adv.py CHANGED
@@ -308,11 +308,11 @@ class Prodigy_adv(torch.optim.Optimizer):
                     if self.beta1 > 0:
                         update = torch.add(mt, mt_slow, alpha=alpha_t)
                     else:
-                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+                        update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
                 else:
-                    update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
+                    update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
                 del grad_reshaped
 
                 if group['use_atan2']:
@@ -362,11 +362,11 @@ class Prodigy_adv(torch.optim.Optimizer):
                     if self.beta1 > 0:
                         update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
                     else:
-                        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+                        update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
                 else:
-                    update = exp_avg.clone() if self.beta1 > 0 else grad.clone()
+                    update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
 
                 exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
 
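Both Prodigy_adv hunks fix the same beta1 == 0 scale bug: the last context line shows exp_avg_sq accumulating d^2 * (1 - beta2) * g^2, so any momentum term implicitly carries Prodigy's step-size estimate d while a raw gradient does not. In 1.0.5 the momentum-free update was therefore too small by a factor of d; 1.0.6 multiplies the gradient by self.d to match. A toy check (lr and eps ignored):

    import torch

    d, beta2 = 10.0, 0.999
    g = torch.tensor([0.02])
    v = d * d * (1.0 - beta2) * g * g      # v_1 as the line above accumulates it

    print((g.mul(d) / v.sqrt()).item())    # 1.0.6: ~31.6, the d factors cancel
    print((g / v.sqrt()).item())           # 1.0.5: ~3.16, shrinks by 1/d as d grows
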
adv_optm-1.0.5.dist-info/METADATA → adv_optm-1.0.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.0.5
+Version: 1.0.6
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

adv_optm-1.0.5.dist-info/RECORD → adv_optm-1.0.6.dist-info/RECORD RENAMED
@@ -1,9 +1,9 @@
-adv_optm/__init__.py,sha256=9sM4fP1pj9divFhLVUzHbBWe50H82H3FYIGVIHTHpkg,306
+adv_optm/__init__.py,sha256=dAbueuVEIGoYrYXx8UE4ATfFBH5wEKrpkXGPTjFH0r0,306
 adv_optm/optim/AdamW_adv.py,sha256=aTuYcJgd_EcZOrs6TDgBrBKw3wtU5LPzE5WvTBDDeEo,14317
-adv_optm/optim/Adopt_adv.py,sha256=KdEVSl2w1gRXFtz2fwCVT4i9inTspp-PQq3mobpa-9A,17476
+adv_optm/optim/Adopt_adv.py,sha256=FTpDDSlYruZDt1VVLgEI_bADiO8f26j-utQs7Gn2fFA,18108
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
 adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
-adv_optm/optim/Prodigy_adv.py,sha256=8XUpu19BaBmHb-R9K3jgwySDbtVaLU1_Drtttc_zITs,22461
+adv_optm/optim/Prodigy_adv.py,sha256=G8xXLO9YBeLb9574uS0HpdY9w3ojblaV-PJFghUnToQ,22493
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=tb3d6Cw_nGwcTzYUhDnKqyP7GzjD1hn8k4WqGG5lhmw,9813
 adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
@@ -12,8 +12,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.0.5.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.0.5.dist-info/METADATA,sha256=ouxI4cwBQ2IPuOjrkA478XwSetGP6ku51vW1QxHIGcY,8422
-adv_optm-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.0.5.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.0.5.dist-info/RECORD,,
+adv_optm-1.0.6.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.0.6.dist-info/METADATA,sha256=3PslWXH0ysoiXU83vN3F9kWRw48fwUM4H1z1tMyEGvI,8422
+adv_optm-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.0.6.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.0.6.dist-info/RECORD,,