adv-optm 1.2.dev19-py3-none-any.whl → 2.dev3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +11 -9
- adv_optm/optim/AdamW_adv.py +91 -61
- adv_optm/optim/Adopt_adv.py +113 -68
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +59 -43
- adv_optm/optim/Muon_adv.py +13 -9
- adv_optm/optim/Prodigy_adv.py +108 -86
- adv_optm/optim/Simplified_AdEMAMix.py +93 -52
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +10 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/METADATA +20 -20
- adv_optm-2.dev3.dist-info/RECORD +23 -0
- adv_optm-1.2.dev19.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev19.dist-info → adv_optm-2.dev3.dist-info}/top_level.txt +0 -0
adv_optm/optim/Lion_adv.py
CHANGED
@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
         stochastic_rounding (bool, optional): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         cautious_mask (bool): whether to use the cautious masking technique. (default: False).
-        clip_threshold (float, optional): whether to clip the gradients norm
-            per-parameter as proposed in the paper `Lions and Muons: Optimization via
-            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
-            (default: 0.0).
         nnmf_factor (bool): whether to use the factorization or use the
             uncompressed optimizer. (default: True)
     """
@@ -41,12 +37,13 @@ class Lion_adv(torch.optim.Optimizer):
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
-        clip_threshold: float = 0.0,
         nnmf_factor: bool = True,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -61,13 +58,20 @@ class Lion_adv(torch.optim.Optimizer):
             weight_decay=weight_decay,
             vector_reshape=vector_reshape,
             orthogonal_gradient=orthogonal_gradient,
-
+            compiled_optimizer=compiled_optimizer,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.factored = nnmf_factor
         super().__init__(params, defaults)

+        self.init_step()
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
+
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
@@ -80,50 +84,50 @@ class Lion_adv(torch.optim.Optimizer):
     def supports_flat_params(self) -> bool:
         return False

-
-
-
-
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)

-
-
-        grad = grad.float()
-        if group["clip_threshold"] > 0.0:
-            grad_norm = torch.norm(grad.detach())
-            if grad_norm > group["clip_threshold"]:
-                clip_coef = group["clip_threshold"] / grad_norm
-                grad.mul_(clip_coef)
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype

             if state['factored']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
-                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             else: # Fallback to standard Lion
                 state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

-
+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+        """Performs a single optimization step on a single parameter."""
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+
+        state = self.state[p]
+
+
         beta1, beta2 = group["betas"]
-        lr = group["lr"]

         if state['factored']:
             # Factored Path
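Note: in the factored branch above, the momentum sign matrix is stored bit-packed: packed_d2 = (d2 + 7) // 8 rounds up so that each uint8 column holds eight sign bits (the One_Bit_Boolean utility listed in this release handles the packing). A quick sketch of the memory implication, using an illustrative shape rather than anything taken from the package:

    import torch

    d1, d2 = 1024, 1000        # hypothetical effective 2D shape of a flattened parameter
    packed_d2 = (d2 + 7) // 8  # ceil(d2 / 8): one uint8 stores 8 sign bits
    sign = torch.zeros((d1, packed_d2), dtype=torch.uint8)
    print(sign.numel(), "bytes instead of", d1 * d2, "one-byte booleans")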
@@ -138,16 +142,16 @@ class Lion_adv(torch.optim.Optimizer):
                 exp_avg = exp_avg.float()

             # Compute update term c_t
-
+            update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

             if self.cautious_mask:
-                mask = (
+                mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
-
+                update.mul_(mask)
                 del mask

             # Parameter update
-
+            update = update.view(p.shape).mul_(lr)

             # Standard Lion momentum update
             exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
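Note: the cautious-mask change above collapses the mask computation onto one line. As a reading aid, here is a minimal standalone sketch of the same idea (function and variable names are illustrative, not taken from the package): the sign update is kept only where it agrees with the gradient, and the surviving entries are rescaled by the inverse of the mask mean so the average update magnitude is preserved.

    import torch

    def cautious_sign_update(exp_avg, grad, beta1):
        # Lion-style sign update: sign(beta1 * m + (1 - beta1) * g)
        update = exp_avg.mul(beta1).add(grad, alpha=1 - beta1).sign_()
        # Zero the entries whose sign disagrees with the gradient ...
        mask = (update * grad > 0).to(grad.dtype)
        # ... and rescale the survivors so the mean magnitude is unchanged.
        mask.div_(mask.mean().clamp_(min=1e-3))
        return update.mul_(mask)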
@@ -165,15 +169,15 @@ class Lion_adv(torch.optim.Optimizer):
             # Compute update term and sign for the update
             if exp_avg.dtype != torch.float32 and self.factored:
                 exp_avg = exp_avg.float()
-
+            update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

             if self.cautious_mask:
-                mask = (
+                mask = (update * grad > 0).to(grad.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
-
+                update.mul_(mask)
                 del mask

-
+            update.mul_(lr)

             # Standard Lion momentum update
             exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
@@ -188,11 +192,23 @@ class Lion_adv(torch.optim.Optimizer):
            )

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-           add_stochastic_(p.data, -
+           add_stochastic_(p.data, -update)
+       else:
+           p.data.add_(-update)
+
+       del update
+
+   @torch.no_grad()
+   def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+       if not group.get('compiled_optimizer', False):
+           self.__step_parameter(p, group, group["lr"])
        else:
-           p.
+           lr_tensor = torch.tensor(group["lr"], device=p.device)
+           self._compiled_step_parameter(p, group, lr_tensor)

-
+   def compile(self, *args, **kwargs):
+       self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None):
@@ -207,4 +223,4 @@ class Lion_adv(torch.optim.Optimizer):
                 if p.grad is not None:
                     self.step_parameter(p, group, i)

-        return loss
+        return loss
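Note: taken together, the Lion_adv changes drop clip_threshold and add an opt-in torch.compile path: state is built eagerly in init_step(), and step_parameter() dispatches either to the plain per-parameter step or to a compiled version that receives the learning rate as a tensor, so a changing scalar lr does not force recompilation. A minimal usage sketch, assuming the constructor signature shown in this diff and that the class is re-exported from the package root (the model and loss here are placeholders):

    import torch
    from adv_optm import Lion_adv  # assumed top-level re-export

    model = torch.nn.Linear(128, 64)
    opt = Lion_adv(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.99),
        cautious_mask=True,
        compiled_optimizer=True,  # wraps the per-parameter step in torch.compile(fullgraph=True)
    )

    loss = model(torch.randn(8, 128)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()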
adv_optm/optim/Muon_adv.py
CHANGED
@@ -41,7 +41,6 @@ class Muon_adv(torch.optim.Optimizer):
             stability. (default: 100.0)
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
-        orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
             matrices for muon NewtonSchulz (default: False).
         vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -60,6 +59,7 @@ class Muon_adv(torch.optim.Optimizer):
         normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
         normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
             (default: 0.2)
+        normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
         accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
             dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
         cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
@@ -92,7 +92,6 @@ class Muon_adv(torch.optim.Optimizer):
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         stochastic_rounding: bool = True,
-        orthogonal_gradient: bool = False,
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
@@ -104,6 +103,7 @@ class Muon_adv(torch.optim.Optimizer):
         beta2_normuon: float = 0.95,
         normuon_eps: float = 1e-8,
         normuon_lr_scale: float = 0.2,
+        normuon_atan2: bool = False,
         # CANS
         accelerated_ns: bool = False,
         cns_a_bound: float = 1e-4,
@@ -149,13 +149,13 @@ class Muon_adv(torch.optim.Optimizer):
            "vector_reshape": vector_reshape,
            "vector_reshape_muon": vector_reshape_muon,
            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-           "orthogonal_gradient": orthogonal_gradient,
            'compiled_optimizer': compiled_optimizer,
            # Low-rank Ortho
            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
            # NorMuon
            "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
            "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+           "normuon_atan2": normuon_atan2,
            # CANS
            "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
            # AdamW_adv defaults
@@ -293,10 +293,6 @@ class Muon_adv(torch.optim.Optimizer):
        nesterov = group['nesterov']
        Simplified_AdEMAMix = group['Simplified_AdEMAMix']
        alpha_grad = group['alpha_grad']
-       if grad.dtype != torch.float32 and state.get('factored', False):
-           grad = grad.float()
-       if group.get("orthogonal_gradient"):
-           grad = _orthogonalize_gradient(p, grad)

        if state['factored']: # Factored Muon

@@ -363,7 +359,11 @@ class Muon_adv(torch.optim.Optimizer):
                mean_squared_update = torch.mean(update.square(), dim=1)
                v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
                # Normalize update
-
+               if group['normuon_atan2']:
+                   a = 1.2732395
+                   update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+               else:
+                   update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                # Scale learning rate
                update_norm = torch.linalg.vector_norm(update)

@@ -464,7 +464,11 @@ class Muon_adv(torch.optim.Optimizer):
                mean_squared_update = torch.mean(update.square(), dim=1)
                v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
                # Normalize update
-
+               if group['normuon_atan2']:
+                   a = 1.2732395
+                   update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+               else:
+                   update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                # Scale learning rate
                update_norm = torch.linalg.vector_norm(update)
                scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
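Note: both NorMuon hunks add the same atan2 branch. The constant 1.2732395 is approximately 4/pi, the scale typically paired with atan2-style adaptive normalization: atan2 gives a bounded, epsilon-free alternative to dividing by sqrt(v_t) + eps. A small sketch of the two branches, with names kept close to the diff but otherwise illustrative:

    import torch

    def normuon_normalize(update, v_t, eps=1e-8, use_atan2=False):
        # v_t holds a per-row second-moment estimate of the update.
        denom = v_t.sqrt().unsqueeze(1)
        if use_atan2:
            a = 1.2732395  # ~ 4 / pi
            return torch.atan2(update, denom).mul_(a)
        return update / denom.add_(eps)

    u = 1e-3 * torch.randn(4, 3)
    v = torch.full((4,), 1e-6)
    print(normuon_normalize(u, v.clone(), use_atan2=False))
    print(normuon_normalize(u, v.clone(), use_atan2=True))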
|