adv-optm 1.1.3__py3-none-any.whl → 1.2.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- adv_optm/__init__.py +3 -1
- adv_optm/optim/AdamW_adv.py +8 -4
- adv_optm/optim/Muon_adv.py +302 -0
- adv_optm/optim/__init__.py +2 -0
- adv_optm/util/MuonAdam_helper.py +31 -0
- adv_optm/util/Newton_Schulz.py +48 -0
- adv_optm/util/__init__.py +2 -1
- {adv_optm-1.1.3.dist-info → adv_optm-1.2.dev2.dist-info}/METADATA +1 -1
- adv_optm-1.2.dev2.dist-info/RECORD +23 -0
- adv_optm-1.1.3.dist-info/RECORD +0 -20
- {adv_optm-1.1.3.dist-info → adv_optm-1.2.dev2.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.3.dist-info → adv_optm-1.2.dev2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.3.dist-info → adv_optm-1.2.dev2.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
|
@@ -5,6 +5,7 @@ from .optim import (
|
|
|
5
5
|
Simplified_AdEMAMix,
|
|
6
6
|
Lion_adv,
|
|
7
7
|
Lion_Prodigy_adv,
|
|
8
|
+
Muon_adv,
|
|
8
9
|
)
|
|
9
10
|
|
|
10
11
|
__all__ = [
|
|
@@ -14,6 +15,7 @@ __all__ = [
|
|
|
14
15
|
"Simplified_AdEMAMix",
|
|
15
16
|
"Lion_adv",
|
|
16
17
|
"Lion_Prodigy_adv",
|
|
18
|
+
"Muon_adv",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
|
-
__version__ = "1.1.3"
|
|
21
|
+
__version__ = "1.2.dev2"
|
adv_optm/optim/AdamW_adv.py
CHANGED
|
@@ -107,6 +107,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
107
107
|
k_logging: int = 0,
|
|
108
108
|
layer_key_fn: Optional[Callable] = None,
|
|
109
109
|
nnmf_factor: bool = False,
|
|
110
|
+
_is_delegate: bool = False,
|
|
110
111
|
):
|
|
111
112
|
if not (lr >= 0.0):
|
|
112
113
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -137,10 +138,11 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
137
138
|
self.factored = nnmf_factor
|
|
138
139
|
self.kourkoutas_beta = kourkoutas_beta
|
|
139
140
|
self.layer_key_fn = layer_key_fn
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
self.
|
|
141
|
+
if not _is_delegate:
|
|
142
|
+
super().__init__(params, defaults)
|
|
143
|
+
else:
|
|
144
|
+
self.defaults = defaults
|
|
145
|
+
self.kourkoutas_helper = None
|
|
144
146
|
|
|
145
147
|
@property
|
|
146
148
|
def supports_fused_back_pass(self):
|
|
@@ -158,6 +160,8 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
158
160
|
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
159
161
|
if p.grad is None:
|
|
160
162
|
return
|
|
163
|
+
if group.get('kourkoutas_beta', False) and self.kourkoutas_helper is None:
|
|
164
|
+
self.kourkoutas_helper = KourkoutasHelper(self)
|
|
161
165
|
|
|
162
166
|
grad = p.grad
|
|
163
167
|
if grad.dtype != torch.float32 and self.factored:
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional, Callable
|
|
3
|
+
|
|
4
|
+
from .AdamW_adv import AdamW_adv
|
|
5
|
+
from ..util.MuonAdam_helper import MuonAdamHelper
|
|
6
|
+
|
|
7
|
+
from ..util.BF16_Stochastic_Rounding import add_stochastic_
|
|
8
|
+
from ..util.Newton_Schulz import _newton_schulz_iteration
|
|
9
|
+
from ..util.Effective_Shape import _get_effective_shape
|
|
10
|
+
from ..util.NNMF import _nnmf,_unnmf
|
|
11
|
+
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
|
|
12
|
+
|
|
13
|
+
class Muon_adv(torch.optim.Optimizer):
|
|
14
|
+
"""
|
|
15
|
+
Implements an advanced Muon algorithm.
|
|
16
|
+
|
|
17
|
+
Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
|
|
18
|
+
the hidden layers of neural networks. It applies SGD with momentum and then
|
|
19
|
+
orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
|
|
20
|
+
|
|
21
|
+
This implementation is designed for 2D parameters (e.g., linear layers) and
|
|
22
|
+
can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
|
|
23
|
+
flattening/reshaping them.
|
|
24
|
+
|
|
25
|
+
This version can also operate in a hybrid mode, using an auxiliary AdamW
|
|
26
|
+
optimizer for specific parameters (e.g., biases, norms, embeddings) as
|
|
27
|
+
defined by a `layer_key_fn`.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
params (iterable): iterable of parameters to optimize or dicts defining
|
|
31
|
+
parameter groups.
|
|
32
|
+
lr (float): learning rate (default: 1e-3).
|
|
33
|
+
beta1 (float): momentum factor (default: 0.9).
|
|
34
|
+
weight_decay (float): weight decay (L2 penalty) (default: 0).
|
|
35
|
+
nesterov (bool): enables Nesterov momentum (default: True).
|
|
36
|
+
ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
|
|
37
|
+
ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
|
|
38
|
+
ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
|
|
39
|
+
quintic polynomial in the Newton-Schulz iteration.
|
|
40
|
+
(default: (3.4445, -4.7750, 2.0315)).
|
|
41
|
+
stochastic_rounding (bool): whether to use stochastic rounding for
|
|
42
|
+
BF16 parameter updates (default: True).
|
|
43
|
+
vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
|
|
44
|
+
matrices for muon NewtonSchulz (default: False).
|
|
45
|
+
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
46
|
+
matrices to apply low-rank compression (default: True).
|
|
47
|
+
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
48
|
+
the uncompressed optimizer. (default: False)
|
|
49
|
+
MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
|
|
50
|
+
Parameters designated by `layer_key_fn` will be optimized with
|
|
51
|
+
AdamW_adv instead of Muon. (default: False)
|
|
52
|
+
layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
|
|
53
|
+
and returns a key. If the key is 'adam', the parameter is handled by
|
|
54
|
+
the auxiliary AdamW optimizer. All other keys are handled by Muon.
|
|
55
|
+
Only used when `MuonWithAuxAdam` is True. (default: None)
|
|
56
|
+
adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
|
|
57
|
+
to the auxiliary AdamW_adv optimizer. Only used when
|
|
58
|
+
`MuonWithAuxAdam` is True. (default: None)
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
def __init__(
|
|
62
|
+
self,
|
|
63
|
+
params,
|
|
64
|
+
lr: float = 1e-3,
|
|
65
|
+
beta1: float = 0.9,
|
|
66
|
+
weight_decay: float = 0.0,
|
|
67
|
+
nesterov: bool = True,
|
|
68
|
+
ns_steps: int = 5,
|
|
69
|
+
ns_eps: float = 1e-7,
|
|
70
|
+
ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
|
|
71
|
+
stochastic_rounding: bool = True,
|
|
72
|
+
vector_reshape_muon: bool = False,
|
|
73
|
+
vector_reshape: bool = True,
|
|
74
|
+
nnmf_factor: bool = False,
|
|
75
|
+
# hybrid optimizer mode
|
|
76
|
+
MuonWithAuxAdam: bool = False,
|
|
77
|
+
layer_key_fn: Optional[Callable] = None,
|
|
78
|
+
muon_adam_lr: float = 1e-4,
|
|
79
|
+
adam_kwargs: Optional[dict] = None,
|
|
80
|
+
):
|
|
81
|
+
if not (lr >= 0.0):
|
|
82
|
+
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
83
|
+
if not (0.0 <= beta1 < 1.0):
|
|
84
|
+
raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
|
|
85
|
+
if not (weight_decay >= 0.0):
|
|
86
|
+
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
87
|
+
if not (ns_steps > 0):
|
|
88
|
+
raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
|
|
89
|
+
|
|
90
|
+
defaults = {
|
|
91
|
+
"lr": lr, "beta1": beta1, "weight_decay": weight_decay,
|
|
92
|
+
"nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
|
|
93
|
+
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
94
|
+
"vector_reshape": vector_reshape,
|
|
95
|
+
"vector_reshape_muon": vector_reshape_muon,
|
|
96
|
+
}
|
|
97
|
+
self.stochastic_rounding = stochastic_rounding
|
|
98
|
+
|
|
99
|
+
self.MuonWithAuxAdam = MuonWithAuxAdam
|
|
100
|
+
self.helper = None
|
|
101
|
+
self.aux_adam = None
|
|
102
|
+
|
|
103
|
+
if self.MuonWithAuxAdam:
|
|
104
|
+
adam_kwargs = adam_kwargs or {}
|
|
105
|
+
# Create a delegate AdamW optimizer to get its default hyperparameters.
|
|
106
|
+
self.aux_adam = AdamW_adv(
|
|
107
|
+
[],
|
|
108
|
+
lr=muon_adam_lr,
|
|
109
|
+
**adam_kwargs,
|
|
110
|
+
_is_delegate=True
|
|
111
|
+
)
|
|
112
|
+
# Update the defaults dictionary
|
|
113
|
+
defaults.update(self.aux_adam.defaults)
|
|
114
|
+
|
|
115
|
+
super().__init__(params, defaults)
|
|
116
|
+
|
|
117
|
+
if self.MuonWithAuxAdam:
|
|
118
|
+
self.helper = MuonAdamHelper(self, layer_key_fn)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def supports_fused_back_pass(self):
|
|
123
|
+
return True
|
|
124
|
+
|
|
125
|
+
@property
|
|
126
|
+
def supports_memory_efficient_fp16(self):
|
|
127
|
+
return True
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def supports_flat_params(self):
|
|
131
|
+
return False
|
|
132
|
+
|
|
133
|
+
@torch.no_grad()
|
|
134
|
+
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
135
|
+
if self.MuonWithAuxAdam:
|
|
136
|
+
optim_type = self.helper.get_optimizer_type(p)
|
|
137
|
+
if optim_type == 'adam':
|
|
138
|
+
# Delegate to the AdamW_adv optimizer's logic.
|
|
139
|
+
# We need to temporarily "lend" our state and param_groups
|
|
140
|
+
# to the delegate so it has the full context to work with,
|
|
141
|
+
# especially for features like Kourkoutas-beta.
|
|
142
|
+
self.aux_adam.state = self.state
|
|
143
|
+
self.aux_adam.param_groups = self.param_groups
|
|
144
|
+
self.aux_adam.step_parameter(p, group, i)
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
if p.grad is None:
|
|
148
|
+
return
|
|
149
|
+
|
|
150
|
+
grad = p.grad
|
|
151
|
+
state = self.state[p]
|
|
152
|
+
|
|
153
|
+
# State Initialization
|
|
154
|
+
if 'step' not in state:
|
|
155
|
+
state['step'] = 0
|
|
156
|
+
|
|
157
|
+
should_factor = (
|
|
158
|
+
group['nnmf_factor'] and
|
|
159
|
+
not (len(p.shape) == 1 and not group['vector_reshape'])
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
state['factored'] = should_factor
|
|
163
|
+
|
|
164
|
+
state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
|
|
165
|
+
|
|
166
|
+
dtype = torch.float32 if group['nnmf_factor'] else p.dtype
|
|
167
|
+
device = p.device
|
|
168
|
+
if group['vector_reshape'] or state['reshaped_1d_muon']:
|
|
169
|
+
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
170
|
+
d1, d2 = state['effective_shape']
|
|
171
|
+
if state['factored']:
|
|
172
|
+
state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
|
|
173
|
+
state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
174
|
+
packed_d2 = (d2 + 7) // 8
|
|
175
|
+
state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
|
|
176
|
+
else:
|
|
177
|
+
if len(p.shape) >= 2:
|
|
178
|
+
state['momentum_buffer'] = torch.zeros_like(p)
|
|
179
|
+
if state['reshaped_1d_muon']:
|
|
180
|
+
state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
|
|
181
|
+
elif len(p.shape) == 1:
|
|
182
|
+
state['momentum_buffer'] = torch.zeros_like(p)
|
|
183
|
+
|
|
184
|
+
beta1 = group['beta1']
|
|
185
|
+
nesterov = group['nesterov']
|
|
186
|
+
|
|
187
|
+
if state['factored']: # Factored Muon
|
|
188
|
+
|
|
189
|
+
# Reconstruct momentum from previous step's factors & sign
|
|
190
|
+
d1, d2 = state['effective_shape']
|
|
191
|
+
mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
|
|
192
|
+
unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
|
|
193
|
+
torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
|
|
194
|
+
del unpacked_sign
|
|
195
|
+
|
|
196
|
+
# Update momentum in full-size
|
|
197
|
+
grad_reshaped = grad.view(d1, d2)
|
|
198
|
+
mt_buf.mul_(beta1).add_(grad_reshaped)
|
|
199
|
+
|
|
200
|
+
if nesterov:
|
|
201
|
+
# Nesterov momentum
|
|
202
|
+
update = grad_reshaped.add(mt_buf, alpha=beta1)
|
|
203
|
+
else:
|
|
204
|
+
# Standard momentum
|
|
205
|
+
update = mt_buf.clone()
|
|
206
|
+
del grad_reshaped
|
|
207
|
+
|
|
208
|
+
update = _newton_schulz_iteration(
|
|
209
|
+
update,
|
|
210
|
+
steps=group['ns_steps'],
|
|
211
|
+
eps=group['ns_eps'],
|
|
212
|
+
coeffs=group['ns_coeffs'],
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
update = update.view(p.shape).mul_(group['lr'])
|
|
216
|
+
|
|
217
|
+
state['sign'] = _pack_bools(mt_buf > 0)
|
|
218
|
+
_nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
|
|
219
|
+
del mt_buf
|
|
220
|
+
|
|
221
|
+
else: # Standard Muon logic for non-factored tensors
|
|
222
|
+
|
|
223
|
+
if len(p.shape) >= 2 or state['reshaped_1d_muon']:
|
|
224
|
+
|
|
225
|
+
# Momentum update
|
|
226
|
+
mt_buf = state['momentum_buffer']
|
|
227
|
+
if state['reshaped_1d_muon']:
|
|
228
|
+
d1, d2 = state['effective_shape']
|
|
229
|
+
grad_reshaped = grad.view(d1, d2)
|
|
230
|
+
mt_buf.mul_(beta1).add_(grad_reshaped)
|
|
231
|
+
else:
|
|
232
|
+
mt_buf.mul_(beta1).add_(grad)
|
|
233
|
+
|
|
234
|
+
if nesterov:
|
|
235
|
+
# Nesterov momentum
|
|
236
|
+
if state['reshaped_1d_muon']:
|
|
237
|
+
update = grad_reshaped.add(mt_buf, alpha=beta1)
|
|
238
|
+
del grad_reshaped
|
|
239
|
+
else:
|
|
240
|
+
update = grad.add(mt_buf, alpha=beta1)
|
|
241
|
+
else:
|
|
242
|
+
# Standard momentum
|
|
243
|
+
update = mt_buf.clone()
|
|
244
|
+
|
|
245
|
+
# For Conv layers (4D) or other high-dim tensors, flatten to 2D
|
|
246
|
+
if len(p.shape) > 2:
|
|
247
|
+
update = update.view(p.shape[0], -1)
|
|
248
|
+
|
|
249
|
+
# NewtonSchulz
|
|
250
|
+
update = _newton_schulz_iteration(
|
|
251
|
+
update,
|
|
252
|
+
steps=group['ns_steps'],
|
|
253
|
+
eps=group['ns_eps'],
|
|
254
|
+
coeffs=group['ns_coeffs'],
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# Reshape back to original if we flattened or reshaped
|
|
258
|
+
if len(p.shape) > 2 or state['reshaped_1d_muon']:
|
|
259
|
+
update = update.view(p.shape)
|
|
260
|
+
|
|
261
|
+
update.mul_(group['lr'])
|
|
262
|
+
|
|
263
|
+
else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
|
|
264
|
+
# Momentum update
|
|
265
|
+
mt_buf = state['momentum_buffer']
|
|
266
|
+
mt_buf.mul_(beta1).add_(grad)
|
|
267
|
+
if nesterov:
|
|
268
|
+
# Nesterov momentum
|
|
269
|
+
update = grad.add(mt_buf, alpha=beta1)
|
|
270
|
+
else:
|
|
271
|
+
# Standard momentum
|
|
272
|
+
update = mt_buf.clone()
|
|
273
|
+
update.mul_(group['lr'])
|
|
274
|
+
|
|
275
|
+
# Decoupled weight decay
|
|
276
|
+
if group["weight_decay"] != 0:
|
|
277
|
+
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
278
|
+
add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
|
|
279
|
+
else:
|
|
280
|
+
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
|
|
281
|
+
|
|
282
|
+
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
283
|
+
add_stochastic_(p.data, -update)
|
|
284
|
+
else:
|
|
285
|
+
p.data.add_(-update)
|
|
286
|
+
del update
|
|
287
|
+
|
|
288
|
+
state['step'] += 1
|
|
289
|
+
|
|
290
|
+
@torch.no_grad()
|
|
291
|
+
def step(self, closure=None):
|
|
292
|
+
"""Performs a single optimization step."""
|
|
293
|
+
loss = None
|
|
294
|
+
if closure is not None:
|
|
295
|
+
with torch.enable_grad():
|
|
296
|
+
loss = closure()
|
|
297
|
+
|
|
298
|
+
for group in self.param_groups:
|
|
299
|
+
for i, p in enumerate(group['params']):
|
|
300
|
+
self.step_parameter(p, group, i)
|
|
301
|
+
|
|
302
|
+
return loss
|
adv_optm/optim/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from .Adopt_adv import Adopt_adv
|
|
|
4
4
|
from .Simplified_AdEMAMix import Simplified_AdEMAMix
|
|
5
5
|
from .Lion_adv import Lion_adv
|
|
6
6
|
from .Lion_Prodigy_adv import Lion_Prodigy_adv
|
|
7
|
+
from .Muon_adv import Muon_adv
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"AdamW_adv",
|
|
@@ -12,4 +13,5 @@ __all__ = [
|
|
|
12
13
|
"Simplified_AdEMAMix",
|
|
13
14
|
"Lion_adv",
|
|
14
15
|
"Lion_Prodigy_adv",
|
|
16
|
+
"Muon_adv",
|
|
15
17
|
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from torch.optim import Optimizer
|
|
2
|
+
from typing import Callable, Optional
|
|
3
|
+
|
|
4
|
+
class MuonAdamHelper:
    """
    A helper class for Muon_adv to decide whether to use Muon or a delegate
    AdamW optimizer for a given parameter based on a keying function.
    """
    def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
        """
        Args:
            optimizer: the owning optimizer; must expose `param_groups`.
            layer_key_fn: callable mapping a parameter to a key, or None to
                route every parameter to Muon.

        Raises:
            TypeError: if `optimizer` has no `param_groups` attribute.
        """
        if not hasattr(optimizer, 'param_groups'):
            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
        self.optimizer = optimizer

        if layer_key_fn is None:
            # If no function is provided, default all parameters to 'muon'.
            # A named static method (rather than a lambda) keeps the helper,
            # and any optimizer holding it, picklable.
            self.layer_key_fn = self._default_layer_key
        else:
            self.layer_key_fn = layer_key_fn

    @staticmethod
    def _default_layer_key(p) -> str:
        """Default keying function: every parameter is handled by Muon."""
        return 'muon'

    def get_optimizer_type(self, p: "torch.Tensor") -> str:
        """
        Returns the designated optimizer type ('adam' or 'muon') for a parameter.

        The user-provided layer_key_fn should return 'adam' for parameters
        to be handled by the auxiliary AdamW optimizer. Any other return
        value is treated as 'muon'.
        """
        return 'adam' if self.layer_key_fn(p) == 'adam' else 'muon'
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
@torch.no_grad()
def _newton_schulz_iteration(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-7,
    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
) -> torch.Tensor:
    """
    Performs the Newton-Schulz iteration to approximately orthogonalize a matrix.
    This is the core computation of the Muon optimizer.

    The input tensor is never modified; all work happens on a bfloat16 copy.

    Args:
        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
        steps (int): The number of iterations to run.
        eps (float): Small constant for numerical stability during normalization.
        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
            quintic polynomial update.

    Returns:
        torch.Tensor: The orthogonalized matrix, with the shape and dtype of `G`.
    """
    assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."

    a, b, c = coeffs

    # copy=True is essential: when G is already bfloat16, a plain `.to()`
    # returns G itself and the in-place normalization below would silently
    # overwrite the caller's tensor.
    X = G.to(torch.bfloat16, copy=True)

    # Normalize the matrix
    X.div_(X.norm() + eps)

    # Handle non-square matrices by transposing the taller one, so the
    # Gram matrix X @ X.T below is built on the smaller dimension.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T

    # Perform the iterative updates: X <- a*X + (b*A + c*A@A) @ X, A = X @ X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    # Transpose back if necessary
    if transposed:
        X = X.T

    return X.to(G.dtype)
|
adv_optm/util/__init__.py
CHANGED
|
@@ -2,10 +2,11 @@ from .BF16_Stochastic_Rounding import add_stochastic_
|
|
|
2
2
|
from .Effective_Shape import _get_effective_shape
|
|
3
3
|
from .One_Bit_Boolean import _pack_bools, _unpack_bools
|
|
4
4
|
from .OrthoGrad import _orthogonalize_gradient
|
|
5
|
-
|
|
5
|
+
from .Newton_Schulz import _newton_schulz_iteration
|
|
6
6
|
__all__ = [
|
|
7
7
|
"_pack_bools", "_unpack_bools",
|
|
8
8
|
"add_stochastic_",
|
|
9
9
|
"_get_effective_shape",
|
|
10
10
|
"_orthogonalize_gradient",
|
|
11
|
+
"_newton_schulz_iteration",
|
|
11
12
|
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
adv_optm/__init__.py,sha256=THWhNF8-PI71K9Au4xAkuDs96YcEagJ-yT5r_g2-yKw,341
|
|
2
|
+
adv_optm/optim/AdamW_adv.py,sha256=Zym0beeu0ye5_PgpAjpzcYghdPYFWs3gQzDmuPZVR80,17690
|
|
3
|
+
adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
|
|
4
|
+
adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
|
|
5
|
+
adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
|
|
6
|
+
adv_optm/optim/Muon_adv.py,sha256=9K5YR3odaGfDDZzasletHRlqxG8xN9IXj6oiqx1CaEI,12423
|
|
7
|
+
adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
|
|
8
|
+
adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
|
|
9
|
+
adv_optm/optim/__init__.py,sha256=3o2XJ4J-PUq3rJM2mBnmuHwbKNb4LuW-Ig_9aBC0ycc,431
|
|
10
|
+
adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
|
|
11
|
+
adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
|
|
12
|
+
adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
|
|
13
|
+
adv_optm/util/MuonAdam_helper.py,sha256=llPCc9MBFen_wodbY4G2E17tBZky8clDiJSZLHkMva8,1236
|
|
14
|
+
adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
|
|
15
|
+
adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
|
|
16
|
+
adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
|
|
17
|
+
adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
|
|
18
|
+
adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
|
|
19
|
+
adv_optm-1.2.dev2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
20
|
+
adv_optm-1.2.dev2.dist-info/METADATA,sha256=JTCPGBJUd4JR7DU26AhX8qSPzWrSVtEwv9Au7I3iEPY,14022
|
|
21
|
+
adv_optm-1.2.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
+
adv_optm-1.2.dev2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
|
|
23
|
+
adv_optm-1.2.dev2.dist-info/RECORD,,
|
adv_optm-1.1.3.dist-info/RECORD
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
adv_optm/__init__.py,sha256=9UZMsxIFudooscrxW4TwKgj3PkrKdC5ZFEOAkYpkrMw,306
|
|
2
|
-
adv_optm/optim/AdamW_adv.py,sha256=7vWfPS2J54U9ZKFQiNJ_l86PvITb0MQ61Fy4Fzmf1d4,17479
|
|
3
|
-
adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
|
|
4
|
-
adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
|
|
5
|
-
adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
|
|
6
|
-
adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
|
|
7
|
-
adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
|
|
8
|
-
adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
|
|
9
|
-
adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
|
|
10
|
-
adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
|
|
11
|
-
adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
|
|
12
|
-
adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
|
|
13
|
-
adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
|
|
14
|
-
adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
|
|
15
|
-
adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
|
|
16
|
-
adv_optm-1.1.3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
17
|
-
adv_optm-1.1.3.dist-info/METADATA,sha256=IGemhIn9C4Zg9nE5VaiZjVuRqnBGNxlLNaXabRVXG8Y,14019
|
|
18
|
-
adv_optm-1.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
19
|
-
adv_optm-1.1.3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
|
|
20
|
-
adv_optm-1.1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|