adv-optm 1.2.dev7__py3-none-any.whl → 1.2.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +74 -68
- adv_optm/optim/Muon_adv.py +3 -3
- adv_optm/util/Kourkoutas.py +1 -1
- {adv_optm-1.2.dev7.dist-info → adv_optm-1.2.dev8.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev7.dist-info → adv_optm-1.2.dev8.dist-info}/RECORD +9 -9
- {adv_optm-1.2.dev7.dist-info → adv_optm-1.2.dev8.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev7.dist-info → adv_optm-1.2.dev8.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev7.dist-info → adv_optm-1.2.dev8.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
|
@@ -3,7 +3,6 @@ from typing import Optional, Callable
|
|
|
3
3
|
|
|
4
4
|
from .AdamW_adv import AdamW_adv
|
|
5
5
|
from ..util.MuonAdam_helper import MuonAdamHelper
|
|
6
|
-
from ..util.Kourkoutas import KourkoutasHelper
|
|
7
6
|
|
|
8
7
|
from ..util.BF16_Stochastic_Rounding import add_stochastic_
|
|
9
8
|
from ..util.Newton_Schulz import _newton_schulz_iteration
|
|
@@ -64,22 +63,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
64
63
|
matrices for muon NewtonSchulz (default: False).
|
|
65
64
|
vector_reshape (bool): whether to reshape 1D vectors into 2D
|
|
66
65
|
matrices to apply low-rank compression (default: True).
|
|
66
|
+
low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
|
|
67
|
+
projects the update to a lower rank before orthogonalization.
|
|
68
|
+
(default: False)
|
|
69
|
+
ortho_rank (int): The rank for low-rank orthogonalization.
|
|
70
|
+
(default: 128)
|
|
67
71
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
68
72
|
the uncompressed optimizer. (default: False)
|
|
69
|
-
kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
|
|
70
|
-
If `False`, the optimizer behaves as standard AdamW. (default: False)
|
|
71
|
-
beta2_min (float): The minimum value for dynamic β₂, used during periods of
|
|
72
|
-
high gradient variance ("sunspikes"). Must be less than `betas[1]`.
|
|
73
|
-
(default: 0.88)
|
|
74
|
-
ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
|
|
75
|
-
the pooled gradient norms. Corresponds to `α` in the paper.
|
|
76
|
-
(default: 0.93)
|
|
77
|
-
tiny_spike (float): A small constant added to the denominator of the
|
|
78
|
-
"sunspike" ratio calculation to prevent division by zero. Corresponds
|
|
79
|
-
to `ε_spike` in the paper. (default: 1e-9)
|
|
80
|
-
k_warmup_steps (int): The number of initial steps during which β₂ is held
|
|
81
|
-
at a fixed beta2 value before the
|
|
82
|
-
dynamic logic activates. (default: 0)
|
|
83
73
|
MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
|
|
84
74
|
Parameters designated by `layer_key_fn` will be optimized with
|
|
85
75
|
AdamW_adv instead of Muon. (default: False)
|
|
@@ -110,15 +100,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
110
100
|
alpha_grad: float = 100.0,
|
|
111
101
|
vector_reshape_muon: bool = False,
|
|
112
102
|
vector_reshape: bool = False,
|
|
103
|
+
# Low-rank Muon
|
|
104
|
+
low_rank_ortho: bool = False,
|
|
105
|
+
ortho_rank: int = 128,
|
|
113
106
|
nnmf_factor: bool = False,
|
|
114
|
-
# K-b parameters
|
|
115
|
-
kourkoutas_beta: bool = False,
|
|
116
|
-
beta2_min: float = 0.9,
|
|
117
|
-
ema_alpha: float = 0.95,
|
|
118
|
-
tiny_spike: float = 1e-9,
|
|
119
|
-
k_warmup_steps: int = 0,
|
|
120
|
-
k_logging: int = 0,
|
|
121
|
-
layer_key_kb_fn: Optional[Callable] = None,
|
|
122
107
|
# hybrid optimizer mode
|
|
123
108
|
MuonWithAuxAdam: bool = False,
|
|
124
109
|
layer_key_fn: Optional[Callable] = None,
|
|
@@ -142,14 +127,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
142
127
|
"vector_reshape": vector_reshape,
|
|
143
128
|
"vector_reshape_muon": vector_reshape_muon,
|
|
144
129
|
"nesterov":nesterov, "use_atan2":use_atan2,
|
|
145
|
-
"Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
|
|
146
|
-
|
|
147
|
-
"
|
|
130
|
+
"Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
|
|
131
|
+
# Low-rank Ortho
|
|
132
|
+
"low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
|
|
148
133
|
}
|
|
149
134
|
self.stochastic_rounding = stochastic_rounding
|
|
150
|
-
self._kourkoutas_beta = kourkoutas_beta
|
|
151
|
-
self._kourkoutas_helper = None
|
|
152
|
-
self.layer_key_kb_fn = layer_key_kb_fn
|
|
153
135
|
self.MuonWithAuxAdam = MuonWithAuxAdam
|
|
154
136
|
self.helper = None
|
|
155
137
|
self.aux_adam = None
|
|
@@ -182,14 +164,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
182
164
|
|
|
183
165
|
for key, value in defaults_to_use.items():
|
|
184
166
|
new_group.setdefault(key, value)
|
|
185
|
-
if '_kourkoutas_beta' not in new_group:
|
|
186
|
-
if optim_type == 'adam':
|
|
187
|
-
new_group['_kourkoutas_beta'] = False
|
|
188
|
-
else:
|
|
189
|
-
new_group['_kourkoutas_beta'] = muon_defaults['_kourkoutas_beta']
|
|
190
167
|
final_param_groups.append(new_group)
|
|
191
168
|
|
|
192
|
-
super().__init__(final_param_groups,
|
|
169
|
+
super().__init__(final_param_groups, muon_defaults)
|
|
193
170
|
|
|
194
171
|
# Now that self is initialized, create the helper
|
|
195
172
|
self.helper = MuonAdamHelper(self, layer_key_fn)
|
|
@@ -219,9 +196,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
219
196
|
|
|
220
197
|
@torch.no_grad()
|
|
221
198
|
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
222
|
-
if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
|
|
223
|
-
self._kourkoutas_helper = KourkoutasHelper(self)
|
|
224
|
-
|
|
225
199
|
if self.MuonWithAuxAdam:
|
|
226
200
|
optim_type = self.helper.get_optimizer_type(p)
|
|
227
201
|
if optim_type == 'adam':
|
|
@@ -277,7 +251,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
277
251
|
|
|
278
252
|
# Retrieve hyperparameters
|
|
279
253
|
beta1, beta2 = group['betas']
|
|
280
|
-
current_step = state['step']
|
|
281
254
|
nesterov = group['nesterov']
|
|
282
255
|
Simplified_AdEMAMix = group['Simplified_AdEMAMix']
|
|
283
256
|
alpha_grad = group['alpha_grad']
|
|
@@ -303,20 +276,37 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
303
276
|
signed_m_buf = torch.sign(mt_buf)
|
|
304
277
|
del grad_reshaped
|
|
305
278
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
279
|
+
# Orthogonalization step
|
|
280
|
+
if group['low_rank_ortho']:
|
|
281
|
+
# Low-Rank Orthogonalization on the reconstructed matrix
|
|
282
|
+
M = signed_m_buf
|
|
283
|
+
r = min(group['ortho_rank'], M.shape[0], M.shape[1])
|
|
284
|
+
if r > 0:
|
|
285
|
+
G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
|
|
286
|
+
MG = M @ G_sketch
|
|
287
|
+
if MG.dtype != torch.float32:
|
|
288
|
+
MG_dtype = M.dtype
|
|
289
|
+
Q, _ = torch.linalg.qr(MG.float())
|
|
290
|
+
Q = Q.to(MG_dtype)
|
|
291
|
+
else:
|
|
292
|
+
Q, _ = torch.linalg.qr(MG)
|
|
293
|
+
projected_M = Q.T @ M
|
|
294
|
+
ortho_projected_M = _newton_schulz_iteration(
|
|
295
|
+
projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
296
|
+
)
|
|
297
|
+
update = Q @ ortho_projected_M
|
|
298
|
+
else: # Fallback for invalid rank
|
|
299
|
+
update = _newton_schulz_iteration(
|
|
300
|
+
signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
301
|
+
)
|
|
302
|
+
else:
|
|
303
|
+
# Original full Newton-Schulz
|
|
304
|
+
update = _newton_schulz_iteration(
|
|
305
|
+
signed_m_buf,
|
|
306
|
+
steps=group['ns_steps'],
|
|
307
|
+
eps=group['ns_eps'],
|
|
308
|
+
coeffs=group['ns_coeffs'],
|
|
309
|
+
)
|
|
320
310
|
|
|
321
311
|
# Reconstruct second momentum from previous step's factors
|
|
322
312
|
vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
|
|
@@ -381,25 +371,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
381
371
|
if len(p.shape) > 2:
|
|
382
372
|
signed_m_buf = signed_m_buf.view(p.shape[0], -1)
|
|
383
373
|
|
|
384
|
-
#
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
374
|
+
# Orthogonalization step
|
|
375
|
+
if group['low_rank_ortho']:
|
|
376
|
+
# Low-Rank Orthogonalization on the reconstructed matrix
|
|
377
|
+
M = signed_m_buf
|
|
378
|
+
r = min(group['ortho_rank'], M.shape[0], M.shape[1])
|
|
379
|
+
if r > 0:
|
|
380
|
+
G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
|
|
381
|
+
MG = M @ G_sketch
|
|
382
|
+
if MG.dtype != torch.float32:
|
|
383
|
+
MG_dtype = M.dtype
|
|
384
|
+
Q, _ = torch.linalg.qr(MG.float())
|
|
385
|
+
Q = Q.to(MG_dtype)
|
|
386
|
+
else:
|
|
387
|
+
Q, _ = torch.linalg.qr(MG)
|
|
388
|
+
projected_M = Q.T @ M
|
|
389
|
+
ortho_projected_M = _newton_schulz_iteration(
|
|
390
|
+
projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
391
|
+
)
|
|
392
|
+
update = Q @ ortho_projected_M
|
|
393
|
+
else: # Fallback for invalid rank
|
|
394
|
+
update = _newton_schulz_iteration(
|
|
395
|
+
signed_m_buf, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
396
|
+
)
|
|
397
|
+
else:
|
|
398
|
+
# Original full Newton-Schulz
|
|
399
|
+
update = _newton_schulz_iteration(
|
|
400
|
+
signed_m_buf,
|
|
401
|
+
steps=group['ns_steps'],
|
|
402
|
+
eps=group['ns_eps'],
|
|
403
|
+
coeffs=group['ns_coeffs'],
|
|
404
|
+
)
|
|
391
405
|
|
|
392
406
|
if len(p.shape) > 2 or state['reshaped_1d_muon']:
|
|
393
407
|
update = update.view(p.shape)
|
|
394
408
|
|
|
395
|
-
if group['_kourkoutas_beta']:
|
|
396
|
-
# Call prepare_step() once at the beginning of the step for all params
|
|
397
|
-
self._kourkoutas_helper.maybe_prepare_step(current_step)
|
|
398
|
-
# Accumulate current sign-stabilized orthogonal update's norm for the *next* step
|
|
399
|
-
self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
|
|
400
|
-
# Get the dynamic beta2 calculated in prepare_step()
|
|
401
|
-
beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
|
|
402
|
-
|
|
403
409
|
vt_buf = state['second_momentum_buffer']
|
|
404
410
|
vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
|
|
405
411
|
|
adv_optm/optim/Muon_adv.py
CHANGED
|
@@ -178,7 +178,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
178
178
|
|
|
179
179
|
final_param_groups.append(new_group)
|
|
180
180
|
|
|
181
|
-
super().__init__(final_param_groups,
|
|
181
|
+
super().__init__(final_param_groups, muon_defaults)
|
|
182
182
|
|
|
183
183
|
# Now that self is initialized, create the helper
|
|
184
184
|
self.helper = MuonAdamHelper(self, layer_key_fn)
|
|
@@ -292,10 +292,10 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
292
292
|
del grad_reshaped
|
|
293
293
|
|
|
294
294
|
# Orthogonalization step
|
|
295
|
-
if group['
|
|
295
|
+
if group['low_rank_ortho']:
|
|
296
296
|
# Low-Rank Orthogonalization on the reconstructed matrix
|
|
297
297
|
M = update
|
|
298
|
-
r = min(group['
|
|
298
|
+
r = min(group['ortho_rank'], M.shape[0], M.shape[1])
|
|
299
299
|
if r > 0:
|
|
300
300
|
G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
|
|
301
301
|
MG = M @ G_sketch
|
adv_optm/util/Kourkoutas.py
CHANGED
|
@@ -94,7 +94,7 @@ class KourkoutasHelper:
|
|
|
94
94
|
for layer_key, info in self.layer_info.items():
|
|
95
95
|
params, group = info['params'], info['group_ref']
|
|
96
96
|
|
|
97
|
-
if not group.get('kourkoutas_beta', False):
|
|
97
|
+
if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
|
|
98
98
|
continue
|
|
99
99
|
|
|
100
100
|
first_param_in_layer = info['params'][0]
|
|
@@ -1,24 +1,24 @@
|
|
|
1
|
-
adv_optm/__init__.py,sha256=
|
|
2
|
-
adv_optm/optim/AdaMuon_adv.py,sha256=
|
|
1
|
+
adv_optm/__init__.py,sha256=3jnQBYhDjdSEYZxoyhxo98rcBQQVKcAUSFljeebo5X0,379
|
|
2
|
+
adv_optm/optim/AdaMuon_adv.py,sha256=MJfrkPfpR9uRcgB-srphwmon55xKNshVDJBfTybHHUM,21045
|
|
3
3
|
adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
|
|
4
4
|
adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
|
|
5
5
|
adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
|
|
6
6
|
adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
|
|
7
|
-
adv_optm/optim/Muon_adv.py,sha256=
|
|
7
|
+
adv_optm/optim/Muon_adv.py,sha256=HaF06fPKcKpVZY29_vqjWHAfivjvGntBuRyDDKj3Ozw,22784
|
|
8
8
|
adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
|
|
9
9
|
adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
|
|
10
10
|
adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
|
|
11
11
|
adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
|
|
12
12
|
adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
|
|
13
|
-
adv_optm/util/Kourkoutas.py,sha256=
|
|
13
|
+
adv_optm/util/Kourkoutas.py,sha256=lObJGXmz3MqGSuu3DKqotSpZ0fuQFPE80R3zO_j3Z_Q,9707
|
|
14
14
|
adv_optm/util/MuonAdam_helper.py,sha256=7rnNMujZVDaqo1g22QscMyPlZvIHQQSLHMED9_I8QWU,1250
|
|
15
15
|
adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
|
|
16
16
|
adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
|
|
17
17
|
adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
|
|
18
18
|
adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
|
|
19
19
|
adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
|
|
20
|
-
adv_optm-1.2.
|
|
21
|
-
adv_optm-1.2.
|
|
22
|
-
adv_optm-1.2.
|
|
23
|
-
adv_optm-1.2.
|
|
24
|
-
adv_optm-1.2.
|
|
20
|
+
adv_optm-1.2.dev8.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
21
|
+
adv_optm-1.2.dev8.dist-info/METADATA,sha256=BMJLtbcTfygjSR8YXCbml_c_0suVEUv97oasoN6jSVs,14022
|
|
22
|
+
adv_optm-1.2.dev8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
adv_optm-1.2.dev8.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
|
|
24
|
+
adv_optm-1.2.dev8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|