adv-optm 2.1.dev2.tar.gz → 2.1.dev3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/PKG-INFO +1 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/__init__.py +1 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/AdaMuon_adv.py +8 -4
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/AdamW_adv.py +2 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Adopt_adv.py +3 -2
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +1 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Muon_adv.py +5 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Prodigy_adv.py +2 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +3 -7
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/Kourkoutas.py +1 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/Muon_AuxAdam.py +5 -2
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/setup.py +1 -1
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/LICENSE +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/README.md +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/param_update.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.1.dev2 → adv_optm-2.1.dev3}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util import Muon_AuxAdam
 
-A =
+A = 4 / math.pi
 
 class AdaMuon_adv(torch.optim.Optimizer):
     """
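The newly pinned constant also lands in AdamW_adv, Adopt_adv, Prodigy_adv, and Muon_AuxAdam below (the pre-change right-hand side of `A =` is not preserved in this diff view). Given the `step_scale = lr * A if group['use_atan2']` lines in the hunks that follow, 4/pi ≈ 1.2732 plays the usual role of the atan2-update scale: `atan2(m, sqrt(v))` is bounded to (-pi/2, pi/2) and needs no epsilon, and the 4/pi factor restores a step magnitude comparable to the plain `m / sqrt(v)` ratio. A minimal sketch of that update shape, assuming this is the intent; the helper name is illustrative, not from the package:

    import math
    import torch

    A = 4 / math.pi  # ~1.2732, compensates for atan2's bounded (-pi/2, pi/2) range

    def atan2_update(exp_avg, exp_avg_sq, lr):
        # torch.atan2(m, sqrt(v)) behaves like m / sqrt(v) for small ratios,
        # saturates smoothly for large ones, and stays finite even when the
        # second moment is exactly zero, so no epsilon term is needed.
        return torch.atan2(exp_avg, exp_avg_sq.sqrt()).mul_(lr * A)

    m = torch.tensor([0.1, -0.3])
    v = torch.tensor([0.04, 0.0])       # zero variance is harmless here
    print(atan2_update(m, v, lr=1e-3))  # bounded, epsilon-free step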
@@ -396,7 +396,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
 del denom, vt_buf
 
 # RMS-aligned scaling
-step_scale = lr * A if group['use_atan2'] else lr
+step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
 rms_adjustment(update, group['rms_rescaling'], step_scale)
 
 update = update.reshape(p.shape)
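Gating the 4/pi factor on `not group['normuon_variant']` suggests the NorMuon-style path normalizes its update on its own, so applying the atan2 compensation there would overscale the step.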
@@ -454,14 +454,18 @@ class AdaMuon_adv(torch.optim.Optimizer):
 del denom
 
 # RMS-aligned rescaling
-step_scale = lr * A if group['use_atan2'] else lr
+step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
 rms_adjustment(update, group['rms_rescaling'], step_scale)
 
 update = update.reshape(original_shape)
 
 param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
 
-
+if group.get('compiled_optimizer', False):
+    lr = torch.as_tensor(group['lr'])
+else:
+    lr = group['lr']
+compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
 
 @torch.no_grad()
 def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
adv_optm/optim/AdamW_adv.py

@@ -10,7 +10,7 @@ from ..util.update_util import _grams_update, _cautious_update
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 
-A =
+A = 4 / math.pi
 
 class AdamW_adv(torch.optim.Optimizer):
     """
@@ -249,6 +249,7 @@ class AdamW_adv(torch.optim.Optimizer):
 random_int_tensor = None
 
 if group.get('compiled_optimizer', False):
+    step_size = torch.as_tensor(step_size)
     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
         # Pre-generate random tensor for stochastic rounding if needed.
         random_int_tensor = param_update._get_random_int_for_sr(p)
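This `torch.as_tensor(step_size)` line is one instance of a pattern repeated below for `lr`, `d`, and `dlr`, matching the existing `# TODO, workaround until pytorch#169634 is fixed` comments: a Python float passed into a `torch.compile`d function is specialized into the graph as a constant, so every scheduler-driven change of the value triggers a recompile, whereas a 0-d tensor remains a real graph input. A minimal illustration under that assumption (the function is hypothetical, not the package's API):

    import torch

    @torch.compile(fullgraph=True)
    def apply_update(p, update, lr):
        # With lr as a 0-d tensor this graph is reused across lr values;
        # with a Python float, lr would be burned in as a constant.
        p.sub_(update * lr)

    p, u = torch.zeros(8), torch.ones(8)
    for lr in (1e-3, 5e-4, 1e-4):                # e.g. a decaying schedule
        apply_update(p, u, torch.as_tensor(lr))  # no per-step recompilation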
adv_optm/optim/Adopt_adv.py

@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.Kourkoutas import KourkoutasHelper
 from ..util.update_util import _grams_update, _cautious_update
 
-A =
+A = 4 / math.pi
 
 class Adopt_adv(torch.optim.Optimizer):
     """
@@ -270,14 +270,15 @@ class Adopt_adv(torch.optim.Optimizer):
 random_int_tensor = None
 
 if group.get('compiled_optimizer', False):
+    lr = torch.as_tensor(group['lr'])
     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
         # Pre-generate random tensor for stochastic rounding if needed.
         random_int_tensor = param_update._get_random_int_for_sr(p)
     step_param_fn = self._compiled_step_parameter
 else:
+    lr = group['lr']
     step_param_fn = self._step_parameter
 
-lr = group['lr']
 
 step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor)
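Hoisting `lr = group['lr']` into the two branches (instead of assigning it once after them) lets the compiled path receive a 0-d tensor while the eager path keeps the plain float; it is the same scalar-specialization workaround as above.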
adv_optm/optim/Lion_Prodigy_adv.py

@@ -226,6 +226,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     random_int_tensor = param_update._get_random_int_for_sr(p)
     # TODO, workaround until pytorch#169634 is fixed
     d = torch.as_tensor(group['d'])
+    dlr = torch.as_tensor(group['dlr'])
     step_param_fn = self._compiled_step_parameter
 else:
     d = group['d']
adv_optm/optim/Muon_adv.py

@@ -399,7 +399,11 @@ class Muon_adv(torch.optim.Optimizer):
 
 param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
 
-
+if group.get('compiled_optimizer', False):
+    lr = torch.as_tensor(group['lr'])
+else:
+    lr = group['lr']
+compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
 
 @torch.no_grad()
 def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
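This mirrors the AdaMuon_adv change above: `lr` is materialized as a tensor only when `compiled_optimizer` is set, then forwarded to `compiled_muon_step_parameter`.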
adv_optm/optim/Prodigy_adv.py

@@ -11,7 +11,7 @@ from ..util.Kourkoutas import KourkoutasHelper
 from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
 from ..util.update_util import _grams_update, _cautious_update
 
-A =
+A = 4 / math.pi
 
 class Prodigy_adv(torch.optim.Optimizer):
     """
@@ -343,6 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
     random_int_tensor = param_update._get_random_int_for_sr(p)
     # TODO, workaround until pytorch#169634 is fixed
     d = torch.as_tensor(group['d'])
+    dlr = torch.as_tensor(dlr)
     step_param_fn = self._compiled_step_parameter
 else:
     d = group['d']
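Note the small asymmetry with Lion_Prodigy_adv above: there the wrapped value comes from `group['dlr']`, while here the locally computed `dlr` is converted in place.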
adv_optm/optim/Simplified_AdEMAMix.py

@@ -244,7 +244,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
     # Pre-generate random tensor for stochastic rounding if needed.
     random_int_tensor = param_update._get_random_int_for_sr(p)
     # TODO, workaround until pytorch#169634 is fixed
-
+    lr = torch.as_tensor(lr)
     step_param_fn = self._compiled_step_parameter
 else:
     step_param_fn = self._step_parameter
@@ -289,10 +289,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
     state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
     del vt
 
-
-    update.mul_(sqrt_den_num)
-
-    update = update.view(p.shape).mul_(lr)
+    update = update.view(p.shape).mul_(lr * sqrt_den_num)
 
 else: # Standard optimizer logic for non-factored tensors
     exp_avg_sq = state['exp_avg_sq']
@@ -308,8 +305,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 update.div_(denom)
 del denom
 
-
-update.mul_(update_scaling)
+update.mul_(lr * sqrt_den_num)
 
 param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
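Both Simplified_AdEMAMix hunks fold two consecutive in-place multiplies into one. Scalar factors commute, so `u.mul_(a).mul_(b)` equals `u.mul_(a * b)` up to float rounding, and the fused form makes a single pass over the tensor instead of two. A quick check with arbitrary values:

    import torch

    u1 = torch.randn(1024)
    u2 = u1.clone()
    lr, sqrt_den_num = 1e-3, 2.5

    u1.mul_(sqrt_den_num).mul_(lr)  # two element-wise passes
    u2.mul_(lr * sqrt_den_num)      # one pass; scalars folded on the host

    print(torch.allclose(u1, u2))   # True (up to rounding)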
adv_optm/util/Kourkoutas.py

@@ -182,7 +182,7 @@ class KourkoutasHelper:
     """
     layer_key = self.optimizer.layer_key_fn(p)
 
-    if layer_key in self.layer_info:
+    if layer_key in self.layer_info and layer_key in self.layer_state:
         # Accumulate for the *next* step's prepare_step call
         self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
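The added membership test is a defensive guard: apparently `layer_info` can hold a key whose per-layer accumulator in `layer_state` has not been created yet, and `self.layer_state[layer_key]` would otherwise raise a KeyError in that window.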
adv_optm/util/Muon_AuxAdam.py

@@ -7,7 +7,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
 from ..util.update_util import _grams_update, _cautious_update
 
-A =
+A = 4 / math.pi
 
 @torch.no_grad()
 def _init_auxadam_state(self, p, group):
@@ -56,7 +56,7 @@ def _adam_step_parameter(self, p, grad, state, group, is_compiled, random_int_tensor
 
 if self.kourkoutas_helper:
     # Prepare Kourkoutas-β once per optimizer step.
-    self.kourkoutas_helper.maybe_prepare_step(step)
+    self.kourkoutas_helper.maybe_prepare_step(step, p.device)
     # Get the dynamic beta2_adam calculated in prepare_step()
     beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
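Passing `p.device` presumably lets the helper keep its per-step bookkeeping on the parameter's device rather than a default one, which matters once parameters live on more than one device.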
@@ -73,6 +73,9 @@ def _adam_step_parameter(self, p, grad, state, group, is_compiled, random_int_tensor
 
 step_size = group['lr'] / bias_correction1
 
+if group.get('compiled_optimizer', False):
+    step_size = torch.as_tensor(step_size)
+
 @torch.compile(fullgraph=True, disable= not is_compiled)
 def compiled_muon_step_parameter(state, grad, group, step_size, sqrt_bias_correction2, random_int_tensor):
     if grad.dtype != torch.float32 and state.get('factored', False):