adv-optm 1.1.2__py3-none-any.whl → 1.2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of adv-optm has been flagged as potentially problematic by the registry.
- adv_optm/__init__.py +3 -1
- adv_optm/optim/AdamW_adv.py +2 -2
- adv_optm/optim/Adopt_adv.py +1 -1
- adv_optm/optim/Lion_Prodigy_adv.py +1 -1
- adv_optm/optim/Muon_adv.py +247 -0
- adv_optm/optim/Prodigy_adv.py +2 -2
- adv_optm/optim/Simplified_AdEMAMix.py +2 -2
- adv_optm/optim/__init__.py +2 -0
- adv_optm/util/Newton_Schulz.py +48 -0
- adv_optm/util/__init__.py +2 -1
- {adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/METADATA +1 -1
- adv_optm-1.2.dev1.dist-info/RECORD +22 -0
- adv_optm-1.1.2.dist-info/RECORD +0 -20
- {adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from .optim import (
    Simplified_AdEMAMix,
    Lion_adv,
    Lion_Prodigy_adv,
+   Muon_adv,
)

__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
    "Simplified_AdEMAMix",
    "Lion_adv",
    "Lion_Prodigy_adv",
+   "Muon_adv",
]

-__version__ = "1.1.2"
+__version__ = "1.2.dev1"
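The net effect of this change is that the new optimizer is re-exported from the package root and the version string is bumped. A quick sanity check, assuming the 1.2.dev1 wheel is installed under its usual import name adv_optm:

import adv_optm
from adv_optm import Muon_adv

print(adv_optm.__version__)  # expected: "1.2.dev1"
print(Muon_adv)              # the newly exported optimizer class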
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -209,7 +209,7 @@ class AdamW_adv(torch.optim.Optimizer):
        beta1, beta2 = group['betas']

        current_step = state['step']
-       if group
+       if group.get('kourkoutas_beta', False):
            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
            # Accumulate current grad's norm for the *next* step
@@ -220,7 +220,7 @@ class AdamW_adv(torch.optim.Optimizer):
        step = state['step'] + 1
        if group['use_bias_correction']:
            bias_correction1 = 1.0 - beta1 ** step
-           if group
+           if group.get('kourkoutas_beta', False):
                bias_correction2 = 1.0 - group['betas'][1] ** step
                # Use beta2_max for bias correction
            else:
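In both hunks the Kourkoutas-beta branch is now entered via dict.get with a False default, so a parameter group that simply never sets kourkoutas_beta skips the adaptive-beta2 path instead of failing on a missing key. A minimal sketch with a hypothetical group dict:

# Hypothetical param group created without the 'kourkoutas_beta' key.
group = {'lr': 1e-3, 'betas': (0.9, 0.999), 'use_bias_correction': True}

if group.get('kourkoutas_beta', False):
    print("Kourkoutas beta2 adaptation enabled for this group")
else:
    print("fixed beta2 path")  # taken here, and no KeyError is raised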
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -240,7 +240,7 @@ class Adopt_adv(torch.optim.Optimizer):
        beta1, beta2 = group['betas']

        current_step = state['step']
-       if group
+       if group.get('kourkoutas_beta', False):
            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
            # Accumulate current grad's norm for the *next* step
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -325,7 +325,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
        d_hat = self.d
        if global_d_denom > 0:
            d_hat = d_coef * global_d_numerator / global_d_denom
-           if g_group
+           if g_group.get('d_limiter', False):
                d_hat = min(self.d * (2 ** 0.25), d_hat)
        if self.d == g_group['d0']:
            self.d = max(self.d, d_hat)
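The d_limiter guard in the Prodigy-style optimizers caps how quickly the estimated step size d may grow: a proposed d_hat is clipped to at most 2 ** 0.25 (about 1.19x) of the current value, so d can at most double over four consecutive updates. A small worked example with made-up numbers:

d = 1.0              # current estimate (made-up value)
d_hat_raw = 5.0      # hypothetical unclamped proposal from the numerator/denominator ratio

d_hat = min(d * (2 ** 0.25), d_hat_raw)
print(d_hat)         # ~1.1892: growth limited to ~19% for this update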
adv_optm/optim/Muon_adv.py
ADDED
@@ -0,0 +1,247 @@
+import torch
+from typing import Optional
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Newton_Schulz import _newton_schulz_iteration
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf, _unnmf
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class Muon_adv(torch.optim.Optimizer):
+    """
+    Implements an advanced Muon algorithm.
+
+    Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
+    the hidden layers of neural networks. It applies SGD with momentum and then
+    orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
+
+    This implementation is designed for 2D parameters (e.g., linear layers) and
+    can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
+    flattening/reshaping them.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): learning rate (default: 1e-3).
+        beta1 (float): momentum factor (default: 0.9).
+        weight_decay (float): weight decay (L2 penalty) (default: 0).
+        nesterov (bool): enables Nesterov momentum (default: True).
+        ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+        ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+        ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+            quintic polynomial in the Newton-Schulz iteration.
+            (default: (3.4445, -4.7750, 2.0315)).
+        stochastic_rounding (bool): whether to use stochastic rounding for
+            BF16 parameter updates (default: True).
+        vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
+            matrices for muon NewtonSchulz (default: False).
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        nnmf_factor (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: False)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        beta1: float = 0.9,
+        weight_decay: float = 0.0,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+        ns_eps: float = 1e-7,
+        ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+        stochastic_rounding: bool = True,
+        vector_reshape_muon: bool = False,
+        vector_reshape: bool = True,
+        nnmf_factor: bool = False,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (0.0 <= beta1 < 1.0):
+            raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not (ns_steps > 0):
+            raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+
+        defaults = {
+            "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
+            "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
+            "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+            "vector_reshape": vector_reshape,
+            "vector_reshape_muon": vector_reshape_muon,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        super().__init__(params, defaults)
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+
+        # State Initialization
+        if 'step' not in state:
+            state['step'] = 0
+
+            should_factor = (
+                group['nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
+
+            dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+            device = p.device
+            if group['vector_reshape'] or state['reshaped_1d_muon']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+            if state['factored']:
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+            else:
+                if len(p.shape) >= 2:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+                if state['reshaped_1d_muon']:
+                    state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+                elif len(p.shape) == 1:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+
+        beta1 = group['beta1']
+        nesterov = group['nesterov']
+
+        if state['factored']:  # Factored Muon
+
+            # Reconstruct momentum from previous step's factors & sign
+            d1, d2 = state['effective_shape']
+            mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+            del unpacked_sign
+
+            # Update momentum in full-size
+            grad_reshaped = grad.view(d1, d2)
+            mt_buf.mul_(beta1).add_(grad_reshaped)
+
+            if nesterov:
+                # Nesterov momentum
+                update = grad_reshaped.add(mt_buf, alpha=beta1)
+            else:
+                # Standard momentum
+                update = mt_buf.clone()
+            del grad_reshaped
+
+            update = _newton_schulz_iteration(
+                update,
+                steps=group['ns_steps'],
+                eps=group['ns_eps'],
+                coeffs=group['ns_coeffs'],
+            )
+
+            update = update.view(p.shape).mul_(group['lr'])
+
+            state['sign'] = _pack_bools(mt_buf > 0)
+            _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del mt_buf
+
+        else:  # Standard Muon logic for non-factored tensors
+
+            if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                if state['reshaped_1d_muon']:
+                    d1, d2 = state['effective_shape']
+                    grad_reshaped = grad.view(d1, d2)
+                    mt_buf.mul_(beta1).add_(grad_reshaped)
+                else:
+                    mt_buf.mul_(beta1).add_(grad)
+
+                if nesterov:
+                    # Nesterov momentum
+                    if state['reshaped_1d_muon']:
+                        update = grad_reshaped.add(mt_buf, alpha=beta1)
+                        del grad_reshaped
+                    else:
+                        update = grad.add(mt_buf, alpha=beta1)
+                else:
+                    # Standard momentum
+                    update = mt_buf.clone()
+
+                # For Conv layers (4D) or other high-dim tensors, flatten to 2D
+                if len(p.shape) > 2:
+                    update = update.view(p.shape[0], -1)
+
+                # NewtonSchulz
+                update = _newton_schulz_iteration(
+                    update,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
+
+                # Reshape back to original if we flattened or reshaped
+                if len(p.shape) > 2 or state['reshaped_1d_muon']:
+                    update = update.view(p.shape)
+
+                update.mul_(group['lr'])
+
+            else:  # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                mt_buf.mul_(beta1).add_(grad)
+                if nesterov:
+                    # Nesterov momentum
+                    update = grad.add(mt_buf, alpha=beta1)
+                else:
+                    # Standard momentum
+                    update = mt_buf.clone()
+                update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        return loss
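A minimal usage sketch for the new class, assuming it is imported from the package root as added in adv_optm/__init__.py; parameter names follow the __init__ signature above:

import torch
from adv_optm import Muon_adv

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Defaults per the signature above: beta1=0.9, nesterov=True, ns_steps=5, nnmf_factor=False.
opt = Muon_adv(model.parameters(), lr=1e-3, weight_decay=0.01)

x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()       # 2D weights get the orthogonalized momentum update; 1D biases fall back to SGD momentum
opt.zero_grad()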
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -304,7 +304,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

        current_step = state['step']
-       if group
+       if group.get('kourkoutas_beta', False):
            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
            # Accumulate current grad's norm for the *next* step
@@ -515,7 +515,7 @@ class Prodigy_adv(torch.optim.Optimizer):
        d_hat = self.d
        if global_d_denom > 0:
            d_hat = d_coef * global_d_numerator / global_d_denom
-           if g_group
+           if g_group.get('d_limiter', False):
                d_hat = min(self.d * (2 ** 0.25), d_hat)
        if self.d == g_group['d0']:
            self.d = max(self.d, d_hat)
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -191,7 +191,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
        beta1_final, beta2 = group["betas"]

        current_step = state['step']
-       if group
+       if group.get('kourkoutas_beta', False):
            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
            # Accumulate current grad's norm for the *next* step
@@ -210,7 +210,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

        if group['use_bias_correction']:
            state['num_sum'] = beta1 * state['num_sum'] + 1.0
-           if group
+           if group.get('kourkoutas_beta', False):
                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
            else:
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
adv_optm/optim/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from .Adopt_adv import Adopt_adv
from .Simplified_AdEMAMix import Simplified_AdEMAMix
from .Lion_adv import Lion_adv
from .Lion_Prodigy_adv import Lion_Prodigy_adv
+from .Muon_adv import Muon_adv

__all__ = [
    "AdamW_adv",
@@ -12,4 +13,5 @@ __all__ = [
    "Simplified_AdEMAMix",
    "Lion_adv",
    "Lion_Prodigy_adv",
+   "Muon_adv",
]
adv_optm/util/Newton_Schulz.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+
+@torch.no_grad()
+def _newton_schulz_iteration(
+    G: torch.Tensor,
+    steps: int = 5,
+    eps: float = 1e-7,
+    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+) -> torch.Tensor:
+    """
+    Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+    This is the core computation of the Muon optimizer.
+
+    Args:
+        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+        steps (int): The number of iterations to run.
+        eps (float): Small constant for numerical stability during normalization.
+        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+            quintic polynomial update.
+
+    Returns:
+        torch.Tensor: The orthogonalized matrix.
+    """
+    assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
+
+    a, b, c = coeffs
+
+    X = G.to(torch.bfloat16)
+
+    # Normalize the matrix
+    X.div_(X.norm() + eps)
+
+    # Handle non-square matrices by transposing the taller one
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+
+    # Perform the iterative updates
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * (A @ A)
+        X = a * X + B @ X
+
+    # Transpose back if necessary
+    if transposed:
+        X = X.T
+
+    return X.to(G.dtype)
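As a quick illustration of what the helper does, the sketch below (assuming the 1.2.dev1 wheel is installed so the private helper is importable) orthogonalizes a random matrix and inspects its singular values, which the quintic iteration drives toward 1 rather than exactly onto 1:

import torch
from adv_optm.util import _newton_schulz_iteration

G = torch.randn(16, 32)             # arbitrary "gradient" matrix
X = _newton_schulz_iteration(G, steps=5)

print(torch.linalg.svdvals(X))      # values clustered near 1.0 (approximate orthogonality)
print(X.shape, X.dtype)             # shape preserved, result cast back to G.dtype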
adv_optm/util/__init__.py
CHANGED
@@ -2,10 +2,11 @@ from .BF16_Stochastic_Rounding import add_stochastic_
from .Effective_Shape import _get_effective_shape
from .One_Bit_Boolean import _pack_bools, _unpack_bools
from .OrthoGrad import _orthogonalize_gradient
-
+from .Newton_Schulz import _newton_schulz_iteration
__all__ = [
    "_pack_bools", "_unpack_bools",
    "add_stochastic_",
    "_get_effective_shape",
    "_orthogonalize_gradient",
+   "_newton_schulz_iteration",
]
adv_optm-1.2.dev1.dist-info/RECORD
ADDED
@@ -0,0 +1,22 @@
+adv_optm/__init__.py,sha256=fdpoxqapAZbMiPax_P4Zm9PkN--71G0yds0q0V9oAbo,341
+adv_optm/optim/AdamW_adv.py,sha256=7vWfPS2J54U9ZKFQiNJ_l86PvITb0MQ61Fy4Fzmf1d4,17479
+adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
+adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
+adv_optm/optim/Muon_adv.py,sha256=eXqPL6GIWutBJpP7Yb_qIk7pGAjwfTAloCFRDhkRoUU,9908
+adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
+adv_optm/optim/__init__.py,sha256=3o2XJ4J-PUq3rJM2mBnmuHwbKNb4LuW-Ig_9aBC0ycc,431
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
+adv_optm-1.2.dev1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev1.dist-info/METADATA,sha256=ofbAQu0ldYk8udMEC0jLcI9Ex2a6M8iaXheDTo4Un3M,14022
+adv_optm-1.2.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev1.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev1.dist-info/RECORD,,
adv_optm-1.1.2.dist-info/RECORD
DELETED
@@ -1,20 +0,0 @@
-adv_optm/__init__.py,sha256=IJAqLP1mOIBeEFFeMCrFvxEwK7oz-g5SRAEfmukmy9o,306
-adv_optm/optim/AdamW_adv.py,sha256=ddEUVOif1gfZPgEJNrEGZ2wnha4MPMWw5ppPd8acQ3o,17457
-adv_optm/optim/Adopt_adv.py,sha256=fhH3hS9K6z5Blxc7NFfzpCrUGbl9EQnwLPmKDxBC1zg,21415
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=0AY-2hjdNKnxiY4MUG4Y4HCe1DvesuwaaxRzrHUkAGA,14606
-adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Prodigy_adv.py,sha256=nD59cAWOJJCjZdIiuD5hD9MWO5sTjPQSvq-3dwGTcEM,25875
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=gPjMhKulzmAeO42foe-d7xW0AcB50vKFYsvHgxbD3uc,12949
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.1.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.1.2.dist-info/METADATA,sha256=mtTfygEQn52Jqwc_W7rnDhJwRAArvaEjiK0s-cyDFVQ,14019
-adv_optm-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.1.2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.1.2.dist-info/RECORD,,
{adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/WHEEL
File without changes
{adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/licenses/LICENSE
File without changes
{adv_optm-1.1.2.dist-info → adv_optm-1.2.dev1.dist-info}/top_level.txt
File without changes