adv-optm 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/Lion_Prodigy_adv.py +0 -2
- adv_optm/optim/Lion_adv.py +0 -2
- adv_optm/optim/Prodigy_adv.py +56 -51
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/METADATA +1 -1
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/RECORD +9 -9
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/WHEEL +0 -0
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/top_level.txt +0 -0
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -82,8 +82,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
         if not weight_decay >= 0.0:
             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
-        if variance_reduction and use_cautious:
-            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
 
         defaults = dict(
             lr=lr,
adv_optm/optim/Lion_adv.py
CHANGED
@@ -57,8 +57,6 @@ class Lion_adv(torch.optim.Optimizer):
             raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
         if not weight_decay >= 0.0:
             raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")
-        if variance_reduction and use_cautious:
-            print("Warning: Using both 'variance_reduction' and 'use_cautious' is not recommended and may lead to unintended effects.")
 
         defaults = dict(
             lr=lr,
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -194,11 +194,12 @@ class Prodigy_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
-                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
-                if not self.use_grams:
-                    packed_d2 = (d2 + 7) // 8
-                    state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                if self.beta1 > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                    if not self.use_grams:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
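The `sign` buffer allocated above stores one boolean per bit, which is why its second dimension is `(d2 + 7) // 8` bytes for `d2` sign flags. The real helpers (`_pack_bools` / `_unpack_bools` in adv_optm/util/One_Bit_Boolean.py) are not part of this diff, so the following is only a minimal sketch of the standard packing trick, with hypothetical names:

    import torch

    def pack_bools(b: torch.Tensor) -> torch.Tensor:
        # Sketch only: (d1, d2) bool -> (d1, (d2 + 7) // 8) uint8, eight flags per byte.
        d1, d2 = b.shape
        packed_d2 = (d2 + 7) // 8
        padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=b.device)
        padded[:, :d2] = b
        weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=b.device)
        return (padded.view(d1, packed_d2, 8).to(torch.uint8) * weights).sum(dim=-1).to(torch.uint8)

    def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
        # Inverse of pack_bools; original_m trims the pad bits, like the diff's original_m=d2.
        weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
        bits = packed.unsqueeze(-1).bitwise_and(weights) != 0
        return bits.view(packed.shape[0], -1)[:, :original_m]

This 8x compression is what makes keeping per-element momentum signs cheap next to the factored magnitudes.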
@@ -208,7 +209,8 @@ class Prodigy_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                if self.beta1 > 0:
+                    state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
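Both init hunks share one theme: first-moment state (`mu_m_nmf`/`mv_m_nmf`/`sign` in the factored path, `exp_avg` here) is now allocated only when `beta1 > 0`, so momentum-free runs skip a full-size buffer per parameter. A back-of-the-envelope check with a hypothetical 4096x4096 fp32 weight:

    import torch

    p = torch.empty(4096, 4096, dtype=torch.float32)
    saved = p.numel() * p.element_size()  # bytes that exp_avg would have used
    print(f"{saved / 2**20:.0f} MiB saved per such tensor when beta1 == 0")  # 64 MiB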
@@ -231,22 +233,24 @@ class Prodigy_adv(torch.optim.Optimizer):
                 if state['factored']:
                     d1, d2 = state['effective_shape']
 
-                    # Reconstruct momentum from previous step's factors
-                    mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
-                    if not self.use_grams:
-                        unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
-                        torch.where(unpacked_sign, mt, -mt, out=mt)
-                        del unpacked_sign
-                    # Update momentum in full-size
                     grad_reshaped = grad.view(d1, d2)
-                    mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
-                    if self.use_grams:
-                        mt.copy_(grad_reshaped.sign() * mt.abs())
-                    elif self.use_cautious:
-                        mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
-                        mask.div_(mask.mean().clamp_(min=1e-3))
-                        mt.mul_(mask)
-                        del mask
+
+                    # Reconstruct momentum from previous step's factors
+                    if self.beta1 > 0:
+                        mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                        if not self.use_grams:
+                            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                            torch.where(unpacked_sign, mt, -mt, out=mt)
+                            del unpacked_sign
+                        # Update momentum in full-size
+                        mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+                        if self.use_grams:
+                            mt.copy_(grad_reshaped.sign() * mt.abs())
+                        elif self.use_cautious:
+                            mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                            mask.div_(mask.mean().clamp_(min=1e-3))
+                            mt.mul_(mask)
+                            del mask
 
                     vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
                     vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
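The `elif self.use_cautious:` branch above is the "cautious optimizer" mask: components of the momentum update whose sign disagrees with the current gradient are zeroed, and the survivors are rescaled by the surviving fraction. As a standalone sketch (hypothetical helper mirroring the hunk's arithmetic):

    import torch

    def apply_cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        # Zero sign-disagreeing components, then rescale so the mean step size is kept.
        mask = (update * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))  # clamp avoids div-by-zero when almost nothing agrees
        return update.mul_(mask)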
@@ -258,30 +262,29 @@ class Prodigy_adv(torch.optim.Optimizer):
                             unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                             torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                             del unpacked_sign_slow
-
                         mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
-                        update = mt + (alpha_t * mt_slow)
+                        update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
                     else:
-                        update = mt
+                        update = mt if self.beta1 > 0 else grad_reshaped
                     del grad_reshaped
 
                     if group['use_atan2']:
                         a = 1.2732395
                         denom = vt.sqrt()
-                        update
+                        update.atan2_(denom).mul_(a)
                     else:
-                        denom = vt.sqrt()
-                        update
-                        del
+                        denom = vt.sqrt()
+                        update.div_(denom.add_(self.d * group['eps']))
+                        del denom
 
-                    update
-                    update.mul_(self.dlr)
+                    update.view(p.shape).mul_(self.dlr)
 
                     # Compress updated moments and store new factors
-                    if not self.use_grams:
-                        state['sign'] = _pack_bools(mt > 0)
-                    _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
-                    del mt
+                    if self.beta1 > 0:
+                        if not self.use_grams:
+                            state['sign'] = _pack_bools(mt > 0)
+                        _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                        del mt
                     if self.use_AdEMAMix:
                         state['sign_slow'] = _pack_bools(mt_slow > 0)
                         _nnmf(mt_slow.abs(), out=(state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
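The factored path keeps each moment as two vectors and round-trips through `_unnmf`/`_nnmf` every step; the sign bitmap exists because the factorization represents magnitudes only. The package's implementation lives in adv_optm/util/NNMF.py and is not shown in this diff; a minimal rank-1 sketch assuming the common row-sum/column-sum construction:

    import torch

    def nnmf_rank1(A: torch.Tensor):
        # A >= 0 of shape (d1, d2); returns u (d1,), v (d2,) with A ~ torch.outer(u, v).
        u = A.sum(dim=1)
        v = A.sum(dim=0)
        v = v / v.sum().clamp(min=torch.finfo(A.dtype).tiny)  # rows of outer(u, v) then match A's row sums
        return u, v

    def unnmf(factors):
        u, v = factors
        return torch.outer(u, v)

Per-moment storage drops from d1*d2 to d1 + d2, which is where the factored path's memory savings come from.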
@@ -290,36 +293,38 @@ class Prodigy_adv(torch.optim.Optimizer):
                     del vt
 
                 else: # Standard AdamW logic for non-factored tensors
-
-
-
-
-                    exp_avg
-
-
-
-
-
+                    exp_avg_sq = state['exp_avg_sq']
+
+                    if self.beta1 > 0:
+                        exp_avg = state['exp_avg']
+                        exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+                        if self.use_grams:
+                            exp_avg = grad.sign() * exp_avg.abs()
+                        elif self.use_cautious:
+                            mask = (exp_avg * grad > 0).to(grad.dtype)
+                            mask.div_(mask.mean().clamp_(min=1e-3))
+                            exp_avg.mul_(mask)
+                            del mask
 
                     if self.use_AdEMAMix:
                         exp_avg_slow = state['exp_avg_slow']
                         exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
-                        update = exp_avg + (alpha_t * exp_avg_slow)
+                        update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
                     else:
-                        update = exp_avg
+                        update = exp_avg if self.beta1 > 0 else grad
 
                     exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
 
                     if group['use_atan2']:
                         a = 1.2732395
                         denom = exp_avg_sq.sqrt()
-                        update
+                        update.atan2_(denom).mul_(a)
                     else:
-                        denom = exp_avg_sq.sqrt()
-                        update
-                        del
+                        denom = exp_avg_sq.sqrt()
+                        update.div_(denom.add_(self.d * group['eps']))
+                        del denom
 
-                    update
+                    update.mul_(self.dlr)
 
                     # --- Accumulate Prodigy stats ---
                     d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
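In both the factored and non-factored branches the denominator step comes in two flavors: the classic `update / (sqrt(v) + d * eps)`, or the `use_atan2` variant, where 1.2732395 is 4/pi (atan2's ±pi/2 range is rescaled to ±2, and no eps is needed because atan2 saturates instead of diverging as the denominator approaches zero). A side-by-side sketch with hypothetical names:

    import torch

    A = 1.2732395  # 4 / pi

    def normalize_update(update, v, d, eps, use_atan2=True):
        denom = v.sqrt()
        if use_atan2:
            # Bounded and eps-free: smooth saturation instead of blow-up near denom == 0.
            return torch.atan2(update, denom).mul_(A)
        return update.div_(denom.add_(d * eps))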
{adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-adv_optm/__init__.py,sha256=
+adv_optm/__init__.py,sha256=5Mmq6ovFOuVKvEuEVVHD4UfO9ObsxbJ4KtsuoOtgvxc,252
 adv_optm/optim/AdamW_adv.py,sha256=_4Vt79EB18rnIkHttA0CdMpli8sZ5f03pesdrwT5K58,12887
 adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=
-adv_optm/optim/Lion_adv.py,sha256=
-adv_optm/optim/Prodigy_adv.py,sha256=
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=Gc_o0HqZTau-cyP2x4ssKgdQnPYPjJsPVsqTFgz6vGw,13918
+adv_optm/optim/Lion_adv.py,sha256=ZxnV6fQUvOmvJVkeUbStNjeXBWMxDlfMcSLjNpL1uFU,9226
+adv_optm/optim/Prodigy_adv.py,sha256=H7MrZMjCkZdsHBXY17Jm7aTFNySoVkIXQSszdoHn6u4,17697
 adv_optm/optim/__init__.py,sha256=e5UighM92LDvDB2JJwj8gDsTpXEedpytScwqS6F2FR8,300
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
@@ -11,8 +11,8 @@ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
-adv_optm-0.1.
+adv_optm-0.1.3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-0.1.3.dist-info/METADATA,sha256=xv7wytTibFrp0MWf2htvY8N413qNPQs6P9cB-r2HOPY,5846
+adv_optm-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-0.1.3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-0.1.3.dist-info/RECORD,,
{adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/WHEEL
File without changes
{adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/licenses/LICENSE
File without changes
{adv_optm-0.1.2.dist-info → adv_optm-0.1.3.dist-info}/top_level.txt
File without changes