adv-optm 1.2.dev19-py3-none-any.whl → 2.dev2-py3-none-any.whl
This diff shows the changes between package versions as published to their public registry and is provided for informational purposes only.
Potentially problematic release.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +11 -9
- adv_optm/optim/AdamW_adv.py +91 -61
- adv_optm/optim/Adopt_adv.py +113 -68
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +59 -43
- adv_optm/optim/Muon_adv.py +13 -12
- adv_optm/optim/Prodigy_adv.py +108 -86
- adv_optm/optim/Simplified_AdEMAMix.py +93 -52
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +10 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/METADATA +20 -20
- adv_optm-2.dev2.dist-info/RECORD +23 -0
- adv_optm-1.2.dev19.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev2.dist-info}/top_level.txt +0 -0
adv_optm/optim/Lion_adv.py
CHANGED
@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
         stochastic_rounding (bool, optional): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         cautious_mask (bool): whether to use the cautious masking technique. (default: False).
-        clip_threshold (float, optional): whether to clip the gradients norm
-            per-parameter as proposed in the paper `Lions and Muons: Optimization via
-            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
-            (default: 0.0).
         nnmf_factor (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
     """
@@ -41,12 +37,13 @@ class Lion_adv(torch.optim.Optimizer):
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
-        clip_threshold: float = 0.0,
         nnmf_factor: bool = True,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -61,13 +58,20 @@ class Lion_adv(torch.optim.Optimizer):
             weight_decay=weight_decay,
             vector_reshape=vector_reshape,
             orthogonal_gradient=orthogonal_gradient,
-
+            compiled_optimizer=compiled_optimizer,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.factored = nnmf_factor
         super().__init__(params, defaults)

+        self.init_step()
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
+
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
@@ -80,50 +84,50 @@ class Lion_adv(torch.optim.Optimizer):
     def supports_flat_params(self) -> bool:
         return False

-
-
-
-
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)

-
-
-        grad = grad.float()
-        if group["clip_threshold"] > 0.0:
-            grad_norm = torch.norm(grad.detach())
-            if grad_norm > group["clip_threshold"]:
-                clip_coef = group["clip_threshold"] / grad_norm
-                grad.mul_(clip_coef)
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype

             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
-                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             else:  # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

-
+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+
+        state = self.state[p]
+
+
         beta1, beta2 = group["betas"]
-        lr = group["lr"]

         if state['factored']:
             # Factored Path
@@ -138,16 +142,16 @@ class Lion_adv(torch.optim.Optimizer):
             exp_avg = exp_avg.float()

             # Compute update term c_t
-
+            update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

             if self.cautious_mask:
-                mask = (
+                mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
-
+                update.mul_(mask)
                 del mask

             # Parameter update
-
+            update = update.view(p.shape).mul_(lr)

             # Standard Lion momentum update
             exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
@@ -165,15 +169,15 @@ class Lion_adv(torch.optim.Optimizer):
             # Compute update term and sign for the update
             if exp_avg.dtype != torch.float32 and self.factored:
                 exp_avg = exp_avg.float()
-
+            update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

             if self.cautious_mask:
-                mask = (
+                mask = (update * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
-
+                update.mul_(mask)
                 del mask

-
+            update.mul_(lr)

             # Standard Lion momentum update
             exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
@@ -188,11 +192,23 @@ class Lion_adv(torch.optim.Optimizer):
             )

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-            add_stochastic_(p.data, -
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+
+        del update
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, group["lr"])
         else:
-            p.
+            lr_tensor = torch.tensor(group["lr"], device=p.device)
+            self._compiled_step_parameter(p, group, lr_tensor)

-
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

     @torch.no_grad()
     def step(self, closure: Optional[callable] = None):
@@ -207,4 +223,4 @@ class Lion_adv(torch.optim.Optimizer):
                 if p.grad is not None:
                     self.step_parameter(p, group, i)

-        return loss
+        return loss
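Context on the new constructor path above: `init_step()` pre-creates per-parameter state, and when `compiled_optimizer=True` the per-parameter step is wrapped with `torch.compile` and the learning rate is passed as a tensor so LR changes do not trigger recompilation. Below is a minimal usage sketch against the new signature; the import path and the toy model are assumptions for illustration, not part of this diff:

import torch
from adv_optm.optim.Lion_adv import Lion_adv  # assumed import path for the module shown above

# Toy model; any module with bf16 parameters exercises the stochastic-rounding path.
model = torch.nn.Linear(128, 64).bfloat16()

opt = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    cautious_mask=True,        # mask update entries whose sign disagrees with the gradient
    nnmf_factor=True,          # factored (1-bit sign + NNMF) momentum storage
    compiled_optimizer=True,   # new in this release: compiles the per-parameter step
)

out = model(torch.randn(8, 128).bfloat16())
out.float().square().mean().backward()
opt.step()
opt.zero_grad()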
adv_optm/optim/Muon_adv.py
CHANGED
@@ -41,7 +41,6 @@ class Muon_adv(torch.optim.Optimizer):
             stability. (default: 100.0)
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
-        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -60,6 +59,7 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
         normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
             (default: 0.2)
+        normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
         accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
             dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
         cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
@@ -76,7 +76,6 @@ class Muon_adv(torch.optim.Optimizer):
         adam_beta3_ema (float): Beta3 for AdEMAMix.
         adam_alpha (float): Alpha for AdEMAMix.
         adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
-        adam_nnmf_factor (bool): 1-bit factored for AdamW.
     """

     def __init__(
@@ -92,7 +91,6 @@ class Muon_adv(torch.optim.Optimizer):
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         stochastic_rounding: bool = True,
-        orthogonal_gradient: bool = False,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
@@ -104,6 +102,7 @@ class Muon_adv(torch.optim.Optimizer):
         beta2_normuon: float = 0.95,
         normuon_eps: float = 1e-8,
         normuon_lr_scale: float = 0.2,
+        normuon_atan2: bool = False,
         # CANS
         accelerated_ns: bool = False,
         cns_a_bound: float = 1e-4,
@@ -126,7 +125,6 @@ class Muon_adv(torch.optim.Optimizer):
         adam_ema_alpha: float = 0.95,
         adam_tiny_spike: float = 1e-9,
         adam_k_warmup_steps: int = 0,
-        adam_nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -149,13 +147,13 @@ class Muon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-            "orthogonal_gradient": orthogonal_gradient,
             'compiled_optimizer': compiled_optimizer,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             # NorMuon
             "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
             "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+            "normuon_atan2": normuon_atan2,
             # CANS
             "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
             # AdamW_adv defaults
@@ -167,7 +165,6 @@ class Muon_adv(torch.optim.Optimizer):
             "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
             "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
             "adam_k_warmup_steps": adam_k_warmup_steps,
-            "adam_nnmf_factor":adam_nnmf_factor,
         }
         self.stochastic_rounding = stochastic_rounding
         self.compiled_optimizer = compiled_optimizer
@@ -293,10 +290,6 @@ class Muon_adv(torch.optim.Optimizer):
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
         alpha_grad = group['alpha_grad']
-        if grad.dtype != torch.float32 and state.get('factored', False):
-            grad = grad.float()
-        if group.get("orthogonal_gradient"):
-            grad = _orthogonalize_gradient(p, grad)

         if state['factored']:  # Factored Muon

@@ -363,7 +356,11 @@ class Muon_adv(torch.optim.Optimizer):
             mean_squared_update = torch.mean(update.square(), dim=1)
             v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
             # Normalize update
-
+            if group['normuon_atan2']:
+                a = 1.2732395
+                update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+            else:
+                update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
             # Scale learning rate
             update_norm = torch.linalg.vector_norm(update)

@@ -464,7 +461,11 @@ class Muon_adv(torch.optim.Optimizer):
             mean_squared_update = torch.mean(update.square(), dim=1)
             v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
             # Normalize update
-
+            if group['normuon_atan2']:
+                a = 1.2732395
+                update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+            else:
+                update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
             # Scale learning rate
             update_norm = torch.linalg.vector_norm(update)
             scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
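For reference, the new `normuon_atan2` branch shown in both hunks replaces the division `update / (sqrt(v_t) + eps)` with `atan2(update, sqrt(v_t)) * 1.2732395` (approximately 4/π), a bounded normalization that needs no epsilon. A standalone sketch of the two branches, using made-up tensor shapes purely for illustration:

import torch

# Made-up shapes: a 2D update matrix and its per-row second-moment estimate v_t.
update = torch.randn(4, 8)
v_t = torch.rand(4) + 1e-3
normuon_eps = 1e-8

normuon_atan2 = True
if normuon_atan2:
    # Bounded normalization: atan2 saturates large ratios; ~4/pi rescales the output range.
    a = 1.2732395
    normalized = torch.atan2(update, v_t.sqrt().unsqueeze(1)) * a
else:
    # Previous behavior: plain division with an epsilon for stability.
    normalized = update / (v_t.sqrt().unsqueeze(1) + normuon_eps)

print(normalized.shape)  # torch.Size([4, 8])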