adv-optm 0.1.2-py3-none-any.whl → 0.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -14,4 +14,4 @@ __all__ = [
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "0.1.2"
+__version__ = "0.1.3"
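
The only functional change in adv_optm/__init__.py is the version bump. A minimal sketch for confirming which release is actually installed after upgrading (standard library only; the distribution name is taken from the metadata below):

import adv_optm
from importlib.metadata import version

# Both should report 0.1.3 once the new wheel is installed.
print(version("adv_optm"))     # distribution metadata
print(adv_optm.__version__)    # module attribute set in adv_optm/__init__.py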

adv_optm/optim/Lion_Prodigy_adv.py CHANGED
@@ -82,8 +82,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
         if not weight_decay >= 0.0:
             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
-        if variance_reduction and use_cautious:
-            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
 
         defaults = dict(
             lr=lr,

adv_optm/optim/Lion_adv.py CHANGED
@@ -57,8 +57,6 @@ class Lion_adv(torch.optim.Optimizer):
             raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
         if not weight_decay >= 0.0:
             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
-        if variance_reduction and use_cautious:
-            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
 
         defaults = dict(
             lr=lr,
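
In both Lion_Prodigy_adv and Lion_adv, 0.1.3 only removes the console warning that 0.1.2 printed when variance_reduction and use_cautious were enabled together; the lr, betas, and weight_decay validation is unchanged. A hedged sketch of the affected call (the keyword names come from the checks above; the top-level import and the lr value are assumptions):

import torch
from adv_optm import Lion_adv

model = torch.nn.Linear(16, 16)

# 0.1.2 printed "Warning: Using both 'variance_reduction' and 'use_cautious' ..."
# for this combination; 0.1.3 constructs silently with the same arguments.
opt = Lion_adv(model.parameters(), lr=1e-4, variance_reduction=True, use_cautious=True)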

adv_optm/optim/Prodigy_adv.py CHANGED
@@ -194,11 +194,12 @@ class Prodigy_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
-                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
-                if not self.use_grams:
-                    packed_d2 = (d2 + 7) // 8
-                    state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if self.beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -208,7 +209,8 @@ class Prodigy_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if self.beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
@@ -231,22 +233,24 @@ class Prodigy_adv(torch.optim.Optimizer):
             if state['factored']:
                 d1, d2 = state['effective_shape']
 
-                # Reconstruct momentum from previous step's factors
-                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-                if not self.use_grams:
-                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
-                    torch.where(unpacked_sign, mt, -mt, out=mt)
-                    del unpacked_sign
-                # Update momentum in full-size
                 grad_reshaped = grad.view(d1, d2)
-                mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
-                if self.use_grams:
-                    mt.copy_(grad_reshaped.sign() * mt.abs())
-                elif self.use_cautious:
-                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
-                    mask.div_(mask.mean().clamp_(min=1e-3))
-                    mt.mul_(mask)
-                    del mask
+
+                # Reconstruct momentum from previous step's factors
+                if self.beta1 > 0:
+                    mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                    if not self.use_grams:
+                        unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                        torch.where(unpacked_sign, mt, -mt, out=mt)
+                        del unpacked_sign
+                    # Update momentum in full-size
+                    mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+                    if self.use_grams:
+                        mt.copy_(grad_reshaped.sign() * mt.abs())
+                    elif self.use_cautious:
+                        mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        mt.mul_(mask)
+                        del mask
 
                 vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
                 vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
@@ -258,30 +262,29 @@ class Prodigy_adv(torch.optim.Optimizer):
                         unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                         torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                         del unpacked_sign_slow
-
                     mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
-                    update_m = mt + (alpha_t * mt_slow)
+                    update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
                 else:
-                    update_m = mt
+                    update = mt if self.beta1 > 0 else grad_reshaped
                 del grad_reshaped
 
                 if group['use_atan2']:
                     a = 1.2732395
                     denom = vt.sqrt()
-                    update = torch.atan2(update_m, denom).mul_(a)
+                    update.atan2_(denom).mul_(a)
                 else:
-                    denom = vt.sqrt().add_(self.d * group['eps'])
-                    update = update_m / denom
-                    del update_m, denom
+                    denom = vt.sqrt()
+                    update.div_(denom.add_(self.d * group['eps']))
+                    del denom
 
-                update = update.view(p.shape)
-                update.mul_(self.dlr)
+                update.view(p.shape).mul_(self.dlr)
 
                 # Compress updated moments and store new factors
-                if not self.use_grams:
-                    state['sign'] = _pack_bools(mt > 0)
-                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
-                del mt
+                if self.beta1 > 0:
+                    if not self.use_grams:
+                        state['sign'] = _pack_bools(mt > 0)
+                    _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                    del mt
                 if self.use_AdEMAMix:
                     state['sign_slow'] = _pack_bools(mt_slow > 0)
                     _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -290,36 +293,38 @@ class Prodigy_adv(torch.optim.Optimizer):
                 del vt
 
             else: # Standard AdamW logic for non-factored tensors
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-
-                exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
-                if self.use_grams:
-                    exp_avg = grad.sign() * exp_avg.abs()
-                elif self.use_cautious:
-                    mask = (exp_avg * grad > 0).to(grad.dtype)
-                    mask.div_(mask.mean().clamp_(min=1e-3))
-                    exp_avg.mul_(mask)
-                    del mask
+                exp_avg_sq = state['exp_avg_sq']
+
+                if self.beta1 > 0:
+                    exp_avg = state['exp_avg']
+                    exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+                    if self.use_grams:
+                        exp_avg = grad.sign() * exp_avg.abs()
+                    elif self.use_cautious:
+                        mask = (exp_avg * grad > 0).to(grad.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        exp_avg.mul_(mask)
+                        del mask
 
                 if self.use_AdEMAMix:
                     exp_avg_slow = state['exp_avg_slow']
                     exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
-                    update_m = exp_avg + (alpha_t * exp_avg_slow)
+                    update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
                 else:
-                    update_m = exp_avg
+                    update = exp_avg if self.beta1 > 0 else grad
 
                 exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
 
                 if group['use_atan2']:
                     a = 1.2732395
                     denom = exp_avg_sq.sqrt()
-                    update = torch.atan2(update_m, denom).mul_(a)
+                    update.atan2_(denom).mul_(a)
                 else:
-                    denom = exp_avg_sq.sqrt().add_(self.d * group['eps'])
-                    update = update_m / denom
-                    del update_m, denom
+                    denom = exp_avg_sq.sqrt()
+                    update.div_(denom.add_(self.d * group['eps']))
+                    del denom
 
-                update = update.mul_(self.dlr)
+                update.mul_(self.dlr)
 
                 # --- Accumulate Prodigy stats ---
                 d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
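
Taken together, the Prodigy_adv changes gate every first-moment buffer and update behind self.beta1 > 0 in both the factored and the standard AdamW paths, substitute the (d-scaled) gradient for the momentum when beta1 is 0, and switch the atan2/eps normalization to in-place atan2_/div_ on the update tensor. A hedged usage sketch of the momentum-free configuration this appears to enable (the top-level import, the betas keyword, and the remaining defaults are assumptions; only the beta1/beta2 handling is taken from the diff above):

import torch
from adv_optm import Prodigy_adv

model = torch.nn.Linear(64, 64)

# With beta1 = 0, the new guards skip allocating exp_avg / factored
# first-moment state, and the update falls back to the gradient
# normalized by the second moment (the `else grad` branches above).
opt = Prodigy_adv(model.parameters(), betas=(0.0, 0.999))

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()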

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 0.1.2
+Version: 0.1.3
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

@@ -1,9 +1,9 @@
-adv_optm/__init__.py,sha256=BNYlxkuU8MFsWSY1_PLzp2XBSzpt-sxhnVuWVKRZGZ8,252
+adv_optm/__init__.py,sha256=5Mmq6ovFOuVKvEuEVVHD4UfO9ObsxbJ4KtsuoOtgvxc,252
 adv_optm/optim/AdamW_adv.py,sha256=_4Vt79EB18rnIkHttA0CdMpli8sZ5f03pesdrwT5K58,12887
 adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=ql6506h_IIZvTPdGYrQdd6iEhCXHTMntqmg739fc_dw,14102
-adv_optm/optim/Lion_adv.py,sha256=jOoRbJ6u9HCK7IBI9ILOCcwprKIGTUNvUzhRd99WJK0,9410
-adv_optm/optim/Prodigy_adv.py,sha256=InR50MoE32zG6qgEkg_JzXl7uXAVRy4EYG0JDl4eKok,17324
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=Gc_o0HqZTau-cyP2x4ssKgdQnPYPjJsPVsqTFgz6vGw,13918
+adv_optm/optim/Lion_adv.py,sha256=ZxnV6fQUvOmvJVkeUbStNjeXBWMxDlfMcSLjNpL1uFU,9226
+adv_optm/optim/Prodigy_adv.py,sha256=H7MrZMjCkZdsHBXY17Jm7aTFNySoVkIXQSszdoHn6u4,17697
 adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
@@ -11,8 +11,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-0.1.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-0.1.2.dist-info/METADATA,sha256=iV5GBWtl4WphBeSIIsUoq1ay6-GJGnDD3XF6aSWWrqg,5846
-adv_optm-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-0.1.2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-0.1.2.dist-info/RECORD,,
+adv_optm-0.1.3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.3.dist-info/METADATA,sha256=xv7wytTibFrp0MWf2htvY8N413qNPQs6P9cB-r2HOPY,5846
+adv_optm-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.3.dist-info/RECORD,,