adv-optm 1.1.3__tar.gz → 1.2.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/PKG-INFO +1 -1
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/__init__.py +3 -1
- adv_optm-1.2.dev1/adv_optm/optim/Muon_adv.py +247 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/__init__.py +2 -0
- adv_optm-1.2.dev1/adv_optm/util/Newton_Schulz.py +48 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/__init__.py +2 -1
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/SOURCES.txt +2 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/setup.py +1 -1
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/LICENSE +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/README.md +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.1.3 → adv_optm-1.2.dev1}/setup.cfg +0 -0
|
@@ -5,6 +5,7 @@ from .optim import (
|
|
|
5
5
|
Simplified_AdEMAMix,
|
|
6
6
|
Lion_adv,
|
|
7
7
|
Lion_Prodigy_adv,
|
|
8
|
+
Muon_adv,
|
|
8
9
|
)
|
|
9
10
|
|
|
10
11
|
__all__ = [
|
|
@@ -14,6 +15,7 @@ __all__ = [
|
|
|
14
15
|
"Simplified_AdEMAMix",
|
|
15
16
|
"Lion_adv",
|
|
16
17
|
"Lion_Prodigy_adv",
|
|
18
|
+
"Muon_adv",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
|
-
__version__ = "1.
|
|
21
|
+
__version__ = "1.2.dev1"
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from ..util.BF16_Stochastic_Rounding import add_stochastic_
|
|
5
|
+
from ..util.Newton_Schulz import _newton_schulz_iteration
|
|
6
|
+
from ..util.Effective_Shape import _get_effective_shape
|
|
7
|
+
from ..util.NNMF import _nnmf,_unnmf
|
|
8
|
+
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
|
|
9
|
+
|
|
10
|
+
class Muon_adv(torch.optim.Optimizer):
|
|
11
|
+
"""
|
|
12
|
+
Implements an advanced Muon algorithm.
|
|
13
|
+
|
|
14
|
+
Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
|
|
15
|
+
the hidden layers of neural networks. It applies SGD with momentum and then
|
|
16
|
+
orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
|
|
17
|
+
|
|
18
|
+
This implementation is designed for 2D parameters (e.g., linear layers) and
|
|
19
|
+
can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
|
|
20
|
+
flattening/reshaping them.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
params (iterable): iterable of parameters to optimize or dicts defining
|
|
24
|
+
parameter groups.
|
|
25
|
+
lr (float): learning rate (default: 1e-3).
|
|
26
|
+
beta1 (float): momentum factor (default: 0.9).
|
|
27
|
+
weight_decay (float): weight decay (L2 penalty) (default: 0).
|
|
28
|
+
nesterov (bool): enables Nesterov momentum (default: True).
|
|
29
|
+
ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
|
|
30
|
+
ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
|
|
31
|
+
ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
|
|
32
|
+
quintic polynomial in the Newton-Schulz iteration.
|
|
33
|
+
(default: (3.4445, -4.7750, 2.0315)).
|
|
34
|
+
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
35
|
+
BF16 parameter updates (default: True).
|
|
36
|
+
vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
|
|
37
|
+
matrices for muon NewtonSchulz (default: False).
|
|
38
|
+
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
39
|
+
matrices to apply low-rank compression (default: True).
|
|
40
|
+
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
41
|
+
the uncompressed optimizer. (default: False)
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
params,
|
|
47
|
+
lr: float = 1e-3,
|
|
48
|
+
beta1: float = 0.9,
|
|
49
|
+
weight_decay: float = 0.0,
|
|
50
|
+
nesterov: bool = True,
|
|
51
|
+
ns_steps: int = 5,
|
|
52
|
+
ns_eps: float = 1e-7,
|
|
53
|
+
ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
|
|
54
|
+
stochastic_rounding: bool = True,
|
|
55
|
+
vector_reshape_muon: bool = False,
|
|
56
|
+
vector_reshape: bool = True,
|
|
57
|
+
nnmf_factor: bool = False,
|
|
58
|
+
):
|
|
59
|
+
if not (lr >= 0.0):
|
|
60
|
+
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
61
|
+
if not (0.0 <= beta1 < 1.0):
|
|
62
|
+
raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
|
|
63
|
+
if not (weight_decay >= 0.0):
|
|
64
|
+
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
65
|
+
if not (ns_steps > 0):
|
|
66
|
+
raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
|
|
67
|
+
|
|
68
|
+
defaults = {
|
|
69
|
+
"lr": lr, "beta1": beta1, "weight_decay": weight_decay,
|
|
70
|
+
"nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
|
|
71
|
+
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
72
|
+
"vector_reshape": vector_reshape,
|
|
73
|
+
"vector_reshape_muon": vector_reshape_muon,
|
|
74
|
+
}
|
|
75
|
+
self.stochastic_rounding = stochastic_rounding
|
|
76
|
+
super().__init__(params, defaults)
|
|
77
|
+
|
|
78
|
+
@property
|
|
79
|
+
def supports_fused_back_pass(self):
|
|
80
|
+
return True
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def supports_memory_efficient_fp16(self):
|
|
84
|
+
return True
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def supports_flat_params(self):
|
|
88
|
+
return False
|
|
89
|
+
|
|
90
|
+
@torch.no_grad()
|
|
91
|
+
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
92
|
+
if p.grad is None:
|
|
93
|
+
return
|
|
94
|
+
|
|
95
|
+
grad = p.grad
|
|
96
|
+
state = self.state[p]
|
|
97
|
+
|
|
98
|
+
# State Initialization
|
|
99
|
+
if 'step' not in state:
|
|
100
|
+
state['step'] = 0
|
|
101
|
+
|
|
102
|
+
should_factor = (
|
|
103
|
+
group['nnmf_factor'] and
|
|
104
|
+
not (len(p.shape) == 1 and not group['vector_reshape'])
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
state['factored'] = should_factor
|
|
108
|
+
|
|
109
|
+
state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
|
|
110
|
+
|
|
111
|
+
dtype = torch.float32 if group['nnmf_factor'] else p.dtype
|
|
112
|
+
device = p.device
|
|
113
|
+
if group['vector_reshape'] or state['reshaped_1d_muon']:
|
|
114
|
+
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
115
|
+
d1, d2 = state['effective_shape']
|
|
116
|
+
if state['factored']:
|
|
117
|
+
state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
|
|
118
|
+
state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
119
|
+
packed_d2 = (d2 + 7) // 8
|
|
120
|
+
state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
|
|
121
|
+
else:
|
|
122
|
+
if len(p.shape) >= 2:
|
|
123
|
+
state['momentum_buffer'] = torch.zeros_like(p)
|
|
124
|
+
if state['reshaped_1d_muon']:
|
|
125
|
+
state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
|
|
126
|
+
elif len(p.shape) == 1:
|
|
127
|
+
state['momentum_buffer'] = torch.zeros_like(p)
|
|
128
|
+
|
|
129
|
+
beta1 = group['beta1']
|
|
130
|
+
nesterov = group['nesterov']
|
|
131
|
+
|
|
132
|
+
if state['factored']: # Factored Muon
|
|
133
|
+
|
|
134
|
+
# Reconstruct momentum from previous step's factors & sign
|
|
135
|
+
d1, d2 = state['effective_shape']
|
|
136
|
+
mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
|
|
137
|
+
unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
|
|
138
|
+
torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
|
|
139
|
+
del unpacked_sign
|
|
140
|
+
|
|
141
|
+
# Update momentum in full-size
|
|
142
|
+
grad_reshaped = grad.view(d1, d2)
|
|
143
|
+
mt_buf.mul_(beta1).add_(grad_reshaped)
|
|
144
|
+
|
|
145
|
+
if nesterov:
|
|
146
|
+
# Nesterov momentum
|
|
147
|
+
update = grad_reshaped.add(mt_buf, alpha=beta1)
|
|
148
|
+
else:
|
|
149
|
+
# Standard momentum
|
|
150
|
+
update = mt_buf.clone()
|
|
151
|
+
del grad_reshaped
|
|
152
|
+
|
|
153
|
+
update = _newton_schulz_iteration(
|
|
154
|
+
update,
|
|
155
|
+
steps=group['ns_steps'],
|
|
156
|
+
eps=group['ns_eps'],
|
|
157
|
+
coeffs=group['ns_coeffs'],
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
update = update.view(p.shape).mul_(group['lr'])
|
|
161
|
+
|
|
162
|
+
state['sign'] = _pack_bools(mt_buf > 0)
|
|
163
|
+
_nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
|
|
164
|
+
del mt_buf
|
|
165
|
+
|
|
166
|
+
else: # Standard Muon logic for non-factored tensors
|
|
167
|
+
|
|
168
|
+
if len(p.shape) >= 2 or state['reshaped_1d_muon']:
|
|
169
|
+
|
|
170
|
+
# Momentum update
|
|
171
|
+
mt_buf = state['momentum_buffer']
|
|
172
|
+
if state['reshaped_1d_muon']:
|
|
173
|
+
d1, d2 = state['effective_shape']
|
|
174
|
+
grad_reshaped = grad.view(d1, d2)
|
|
175
|
+
mt_buf.mul_(beta1).add_(grad_reshaped)
|
|
176
|
+
else:
|
|
177
|
+
mt_buf.mul_(beta1).add_(grad)
|
|
178
|
+
|
|
179
|
+
if nesterov:
|
|
180
|
+
# Nesterov momentum
|
|
181
|
+
if state['reshaped_1d_muon']:
|
|
182
|
+
update = grad_reshaped.add(mt_buf, alpha=beta1)
|
|
183
|
+
del grad_reshaped
|
|
184
|
+
else:
|
|
185
|
+
update = grad.add(mt_buf, alpha=beta1)
|
|
186
|
+
else:
|
|
187
|
+
# Standard momentum
|
|
188
|
+
update = mt_buf.clone()
|
|
189
|
+
|
|
190
|
+
# For Conv layers (4D) or other high-dim tensors, flatten to 2D
|
|
191
|
+
if len(p.shape) > 2:
|
|
192
|
+
update = update.view(p.shape[0], -1)
|
|
193
|
+
|
|
194
|
+
# NewtonSchulz
|
|
195
|
+
update = _newton_schulz_iteration(
|
|
196
|
+
update,
|
|
197
|
+
steps=group['ns_steps'],
|
|
198
|
+
eps=group['ns_eps'],
|
|
199
|
+
coeffs=group['ns_coeffs'],
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# Reshape back to original if we flattened or reshaped
|
|
203
|
+
if len(p.shape) > 2 or state['reshaped_1d_muon']:
|
|
204
|
+
update = update.view(p.shape)
|
|
205
|
+
|
|
206
|
+
update.mul_(group['lr'])
|
|
207
|
+
|
|
208
|
+
else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
|
|
209
|
+
# Momentum update
|
|
210
|
+
mt_buf = state['momentum_buffer']
|
|
211
|
+
mt_buf.mul_(beta1).add_(grad)
|
|
212
|
+
if nesterov:
|
|
213
|
+
# Nesterov momentum
|
|
214
|
+
update = grad.add(mt_buf, alpha=beta1)
|
|
215
|
+
else:
|
|
216
|
+
# Standard momentum
|
|
217
|
+
update = mt_buf.clone()
|
|
218
|
+
update.mul_(group['lr'])
|
|
219
|
+
|
|
220
|
+
# Decoupled weight decay
|
|
221
|
+
if group["weight_decay"] != 0:
|
|
222
|
+
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
223
|
+
add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
|
|
224
|
+
else:
|
|
225
|
+
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
|
|
226
|
+
|
|
227
|
+
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
228
|
+
add_stochastic_(p.data, -update)
|
|
229
|
+
else:
|
|
230
|
+
p.data.add_(-update)
|
|
231
|
+
del update
|
|
232
|
+
|
|
233
|
+
state['step'] += 1
|
|
234
|
+
|
|
235
|
+
@torch.no_grad()
|
|
236
|
+
def step(self, closure=None):
|
|
237
|
+
"""Performs a single optimization step."""
|
|
238
|
+
loss = None
|
|
239
|
+
if closure is not None:
|
|
240
|
+
with torch.enable_grad():
|
|
241
|
+
loss = closure()
|
|
242
|
+
|
|
243
|
+
for group in self.param_groups:
|
|
244
|
+
for i, p in enumerate(group['params']):
|
|
245
|
+
self.step_parameter(p, group, i)
|
|
246
|
+
|
|
247
|
+
return loss
|
|
@@ -4,6 +4,7 @@ from .Adopt_adv import Adopt_adv
|
|
|
4
4
|
from .Simplified_AdEMAMix import Simplified_AdEMAMix
|
|
5
5
|
from .Lion_adv import Lion_adv
|
|
6
6
|
from .Lion_Prodigy_adv import Lion_Prodigy_adv
|
|
7
|
+
from .Muon_adv import Muon_adv
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"AdamW_adv",
|
|
@@ -12,4 +13,5 @@ __all__ = [
|
|
|
12
13
|
"Simplified_AdEMAMix",
|
|
13
14
|
"Lion_adv",
|
|
14
15
|
"Lion_Prodigy_adv",
|
|
16
|
+
"Muon_adv",
|
|
15
17
|
]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
@torch.no_grad()
def _newton_schulz_iteration(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-7,
    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
) -> torch.Tensor:
    """
    Performs the Newton-Schulz iteration to approximately orthogonalize a matrix.
    This is the core computation of the Muon optimizer.

    The input is cast to bfloat16 and divided by its (Frobenius) norm. It then
    goes through ``steps`` applications of the quintic map
    ``X <- a*X + (b*A + c*A@A) @ X`` with ``A = X @ X.T``. The result is cast
    back to ``G.dtype``. The input tensor itself is never modified.

    Args:
        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
        steps (int): The number of iterations to run.
        eps (float): Small constant for numerical stability during normalization.
        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
            quintic polynomial update.

    Returns:
        torch.Tensor: The orthogonalized matrix, with the same shape and dtype as
        ``G``. For tall inputs (more rows than columns) the iteration runs on
        the transpose, and the result is returned as a transposed view. It may
        therefore be non-contiguous; use ``reshape`` rather than ``view`` on it.
    """
    assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."

    a, b, c = coeffs

    X = G.to(torch.bfloat16)

    # Normalize the matrix out of place: when G is already bfloat16,
    # `G.to(torch.bfloat16)` returns G itself, and an in-place division would
    # silently rescale the caller's tensor.
    X = X / (X.norm() + eps)

    # Handle non-square matrices by transposing the taller one, so the
    # Gram matrix A below is the smaller of the two possible products.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T

    # Perform the iterative updates
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    # Transpose back if necessary
    if transposed:
        X = X.T

    return X.to(G.dtype)
|
|
@@ -2,10 +2,11 @@ from .BF16_Stochastic_Rounding import add_stochastic_
|
|
|
2
2
|
from .Effective_Shape import _get_effective_shape
|
|
3
3
|
from .One_Bit_Boolean import _pack_bools, _unpack_bools
|
|
4
4
|
from .OrthoGrad import _orthogonalize_gradient
|
|
5
|
-
|
|
5
|
+
from .Newton_Schulz import _newton_schulz_iteration
|
|
6
6
|
__all__ = [
|
|
7
7
|
"_pack_bools", "_unpack_bools",
|
|
8
8
|
"add_stochastic_",
|
|
9
9
|
"_get_effective_shape",
|
|
10
10
|
"_orthogonalize_gradient",
|
|
11
|
+
"_newton_schulz_iteration",
|
|
11
12
|
]
|
|
@@ -11,6 +11,7 @@ adv_optm/optim/AdamW_adv.py
|
|
|
11
11
|
adv_optm/optim/Adopt_adv.py
|
|
12
12
|
adv_optm/optim/Lion_Prodigy_adv.py
|
|
13
13
|
adv_optm/optim/Lion_adv.py
|
|
14
|
+
adv_optm/optim/Muon_adv.py
|
|
14
15
|
adv_optm/optim/Prodigy_adv.py
|
|
15
16
|
adv_optm/optim/Simplified_AdEMAMix.py
|
|
16
17
|
adv_optm/optim/__init__.py
|
|
@@ -18,6 +19,7 @@ adv_optm/util/BF16_Stochastic_Rounding.py
|
|
|
18
19
|
adv_optm/util/Effective_Shape.py
|
|
19
20
|
adv_optm/util/Kourkoutas.py
|
|
20
21
|
adv_optm/util/NNMF.py
|
|
22
|
+
adv_optm/util/Newton_Schulz.py
|
|
21
23
|
adv_optm/util/One_Bit_Boolean.py
|
|
22
24
|
adv_optm/util/OrthoGrad.py
|
|
23
25
|
adv_optm/util/__init__.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|