adv-optm 1.0.4.tar.gz → 1.0.6.tar.gz
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-1.0.4 → adv_optm-1.0.6}/PKG-INFO +1 -1
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/__init__.py +1 -1
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/Adopt_adv.py +55 -43
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/Prodigy_adv.py +4 -4
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.0.4 → adv_optm-1.0.6}/setup.py +1 -1
- {adv_optm-1.0.4 → adv_optm-1.0.6}/LICENSE +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/README.md +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.0.4 → adv_optm-1.0.6}/setup.cfg +0 -0
adv_optm/optim/Adopt_adv.py

@@ -156,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
             state = self.state[p]

+            beta1, beta2 = group['betas']
+
             # State Initialization
             if len(state) == 0:
                 state['step'] = 0
@@ -174,11 +176,12 @@ class Adopt_adv(torch.optim.Optimizer):
                     d1, d2 = state['effective_shape']

                     # m_0 = 0
-
-
-
-
-
+                    if beta1 > 0:
+                        state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                        state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                        if not self.grams_moment:
+                            packed_d2 = (d2 + 7) // 8
+                            state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                     if self.use_AdEMAMix:
                         state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                         state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
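The new initialization only allocates the factored first-moment state (and its packed sign matrix) when beta1 > 0, so running with momentum disabled no longer pays for those buffers. The sign matrix is stored one bit per element: a (d1, d2) boolean mask packed into (d1, (d2 + 7) // 8) uint8 bytes, which is exactly the shape created above. The package's real helpers live in adv_optm/util/One_Bit_Boolean.py and are not part of this diff; the snippet below is only a minimal sketch of that packing idea, and the function names, LSB-first bit order, and zero-padding are assumptions rather than the library's API.

import torch

_BIT_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_bools(mask):
    # mask: (d1, d2) bool -> (d1, (d2 + 7) // 8) uint8, LSB-first within each byte.
    d1, d2 = mask.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=mask.device)
    padded[:, :d2] = mask
    bits = padded.view(d1, packed_d2, 8).to(torch.uint8)
    return (bits * _BIT_WEIGHTS.to(mask.device)).sum(dim=-1).to(torch.uint8)

def unpack_bools(packed, original_m):
    # Inverse: (d1, packed_d2) uint8 -> (d1, original_m) bool.
    weights = _BIT_WEIGHTS.to(packed.device)
    bits = packed.unsqueeze(-1).bitwise_and(weights).ne(0)
    return bits.reshape(packed.shape[0], -1)[:, :original_m]

Keeping only the sign bits plus a non-negative factorization of the magnitude is what lets the optimizer hold the first moment at a fraction of full-precision cost.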
@@ -192,7 +195,8 @@ class Adopt_adv(torch.optim.Optimizer):
                     # Initialize v_0 using NMF
                     _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
                 else:  # Fallback for non-factored tensors
-
+                    if beta1 > 0:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=dtype)  # m_0
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                     state['exp_avg_sq'] = grad.square()  # v_0
@@ -202,7 +206,6 @@ class Adopt_adv(torch.optim.Optimizer):
                 state['step'] += 1
                 return

-            beta1, beta2 = group['betas']
             if self.use_AdEMAMix:
                 beta3_ema = group['beta3_ema']
                 alpha = group['alpha']
@@ -219,13 +222,14 @@ class Adopt_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']

                 # Reconstruct m_{t-1}
-
-
-                if
-                state['sign']
-
-
-
+                if beta1 > 0:
+                    mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                    if not self.grams_moment:
+                        if state['sign'].dtype != torch.uint8:
+                            state['sign'] = state['sign'].to(torch.uint8)
+                        unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                        torch.where(unpacked_sign, mt, -mt, out=mt)
+                        del unpacked_sign

                 # Reconstruct AdEMAMix EMA
                 if self.use_AdEMAMix:
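The reconstruction above rebuilds an approximate m_{t-1} from two 1-D factors plus the packed sign matrix. `_nnmf` and `_unnmf` come from adv_optm/util/NNMF.py, which this diff does not show; the sketch below illustrates one common rank-1 non-negative factorization (Adafactor-style row/column statistics) that matches the shapes used here. Treat the exact algorithm as an assumption, not the package's implementation.

import torch

def nnmf_rank1(x, out):
    # x: non-negative (d1, d2) matrix; out: preallocated 1-D factors of shape (d1,) and (d2,).
    mu, mv = out
    mu.copy_(x.mean(dim=1))                              # row statistics
    mv.copy_(x.mean(dim=0) / x.mean().clamp(min=1e-30))  # column statistics, rescaled
    return mu, mv

def unnmf_rank1(factors):
    # Rank-1 reconstruction: outer product of the two stored factors.
    mu, mv = factors
    return torch.outer(mu, mv)

For a genuinely rank-1 non-negative matrix the round trip is exact; in general it is an approximation, which is why only |m_t| is factored and the signs are kept separately.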
@@ -253,25 +257,29 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom

                 # ADOPT Step B: Update momentum m_t using normalized gradient
-                if
-
-
-
-
-
-
-
-
-
-
+                if beta1 > 0:
+                    if self.Simplified_AdEMAMix:
+                        mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                    else:
+                        mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+                    if self.grams_moment:
+                        mt = grad_reshaped.sign() * mt.abs()
+                    elif self.cautious_mask:
+                        mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        mt.mul_(mask)
+                        del mask

                 if self.use_AdEMAMix:
                     mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(mt, mt_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
-                    update = torch.add(mt,
+                    update = torch.add(mt, normalized_grad, alpha=alpha_grad)
                 else:
-                    update = mt.clone()
+                    update = mt.clone() if beta1 > 0 else normalized_grad

                 update = update.view(p.shape)
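The net effect of this hunk is that the first moment becomes optional: with beta1 == 0 no momentum buffer exists and the update falls back to the ADOPT-normalized gradient, while the AdEMAMix slow EMA keeps working either way. A condensed restatement of the added branch logic (illustrative only, not the package's code; as in the diff, the Simplified_AdEMAMix path assumes momentum is enabled):

import torch

def select_update(normalized_grad, mt=None, mt_slow=None, *,
                  beta1=0.0, alpha_t=0.0, alpha_grad=0.0,
                  use_ademamix=False, simplified_ademamix=False):
    # Mirrors the branches added in 1.0.6: every path works when beta1 == 0.
    if use_ademamix:
        base = mt if beta1 > 0 else normalized_grad
        return torch.add(base, mt_slow, alpha=alpha_t)
    if simplified_ademamix:
        return torch.add(mt, normalized_grad, alpha=alpha_grad)
    return mt.clone() if beta1 > 0 else normalized_grad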
@@ -285,10 +293,11 @@ class Adopt_adv(torch.optim.Optimizer):
                 del grad_reshaped

                 # Compress and store new factors
-                if
-
-
-
+                if beta1 > 0:
+                    if not self.grams_moment:
+                        state['sign'] = _pack_bools(mt > 0)
+                    _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                    del mt

                 if self.use_AdEMAMix:
                     state['sign_slow'] = _pack_bools(mt_slow > 0)
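Between steps, the factored path therefore keeps only three small tensors per first moment: the packed sign bits, a row factor, and a column factor. A self-contained round-trip sketch of that compress/reconstruct cycle, with plain tensor ops standing in for `_pack_bools`, `_nnmf`, and `_unnmf`:

import torch

mt = torch.randn(64, 32)                     # stand-in for the reshaped momentum

# Compress: keep sign bits plus rank-1 row/column statistics of the magnitude.
sign_bits = mt > 0
magnitude = mt.abs()
row = magnitude.mean(dim=1)
col = magnitude.mean(dim=0) / magnitude.mean().clamp(min=1e-30)

# Reconstruct on the next step: outer product, then re-apply the signs.
mt_approx = torch.outer(row, col)
mt_approx = torch.where(sign_bits, mt_approx, -mt_approx)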
@@ -300,10 +309,7 @@ class Adopt_adv(torch.optim.Optimizer):
                 del vt

             else:  # Standard ADOPT logic for non-factored tensors
-
-
-                if self.use_AdEMAMix:
-                    m_slow = state['exp_avg_slow']
+                v = state['exp_avg_sq']  # v_{t-1}

                 # ADOPT Step A: Decorrelate g_t using v_{t-1}
                 denom = v.sqrt()
@@ -318,10 +324,12 @@ class Adopt_adv(torch.optim.Optimizer):
                 del denom

                 # ADOPT Step B: Update momentum m_t
-                if
-                m
-
-
+                if beta1 > 0:
+                    m = state['exp_avg']  # m_{t-1}
+                    if self.Simplified_AdEMAMix:
+                        m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                    else:
+                        m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

                 if self.grams_moment:
                     m = grad.sign() * m.abs()
@@ -332,12 +340,16 @@ class Adopt_adv(torch.optim.Optimizer):
                     del mask

                 if self.use_AdEMAMix:
+                    m_slow = state['exp_avg_slow']
                     m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-
+                    if beta1 > 0:
+                        update = torch.add(m, m_slow, alpha=alpha_t)
+                    else:
+                        update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
-                    update = torch.add(m,
+                    update = torch.add(m, normalized_grad, alpha=alpha_grad)
                 else:
-                    update = m.clone()
+                    update = m.clone() if beta1 > 0 else normalized_grad

                 if self.use_atan2:
                     update.mul_(group['lr'] * 1.2732395447351628)
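For users, the practical change in 1.0.6 is that Adopt_adv can now run with the first moment disabled on both the factored and non-factored paths. A hypothetical usage sketch follows; the import path, argument names, and defaults are assumptions (check the package's README), and only the behaviour when the first beta is zero is what this diff actually adds:

import torch
from adv_optm import Adopt_adv  # assumed export; adjust to the package's actual API

model = torch.nn.Linear(128, 64)
# A first beta of 0 (assumed to map to beta1) skips allocating exp_avg / the
# factored momentum state entirely and applies the normalized gradient directly.
opt = Adopt_adv(model.parameters(), lr=1e-3, betas=(0.0, 0.9999))

loss = model(torch.randn(16, 128)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()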
adv_optm/optim/Prodigy_adv.py

@@ -308,11 +308,11 @@ class Prodigy_adv(torch.optim.Optimizer):
                     if self.beta1 > 0:
                         update = torch.add(mt, mt_slow, alpha=alpha_t)
                     else:
-                        update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+                        update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
                 else:
-                    update = mt.clone() if self.beta1 > 0 else grad_reshaped.
+                    update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
                 del grad_reshaped

                 if group['use_atan2']:
@@ -362,11 +362,11 @@ class Prodigy_adv(torch.optim.Optimizer):
                     if self.beta1 > 0:
                         update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
                     else:
-                        update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+                        update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
                 elif self.Simplified_AdEMAMix:
                     update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
                 else:
-                    update = exp_avg.clone() if self.beta1 > 0 else grad.
+                    update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)

                 exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
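Both Prodigy_adv hunks fix the same oversight: when beta1 == 0 the raw gradient was used directly, without Prodigy's adaptive step-size estimate d, even though the second-moment accumulator (last context line above) is scaled by d². Multiplying the gradient by d keeps the numerator on the same scale as the denominator. A simplified sketch of the corrected momentum-free path; the eps placement and the omission of the atan2 branch and learning-rate factor are assumptions:

import torch

def prodigy_step_no_momentum(grad, exp_avg_sq, d, beta2, eps=1e-8):
    # 1.0.6 behaviour: scale the raw gradient by the current d estimate ...
    update = grad.mul(d)
    # ... so it matches the d**2-scaled second moment used for normalization.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=d * d * (1.0 - beta2))
    return update / (exp_avg_sq.sqrt() + d * eps)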