adv-optm 1.2.dev6__py3-none-any.whl → 1.2.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/Muon_adv.py +165 -18
- {adv_optm-1.2.dev6.dist-info → adv_optm-1.2.dev7.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev6.dist-info → adv_optm-1.2.dev7.dist-info}/RECORD +7 -7
- {adv_optm-1.2.dev6.dist-info → adv_optm-1.2.dev7.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev6.dist-info → adv_optm-1.2.dev7.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev6.dist-info → adv_optm-1.2.dev7.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/Muon_adv.py
CHANGED
|
@@ -18,6 +18,10 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
18
18
|
the hidden layers of neural networks. It applies SGD with momentum and then
|
|
19
19
|
orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
|
|
20
20
|
|
|
21
|
+
NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
|
|
22
|
+
adaptive learning rates, combining the benefits of orthogonalization with
|
|
23
|
+
second-order momentum statistics.
|
|
24
|
+
|
|
21
25
|
This implementation is designed for 2D parameters (e.g., linear layers) and
|
|
22
26
|
can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
|
|
23
27
|
flattening/reshaping them.
|
|
@@ -54,6 +58,19 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
54
58
|
matrices to apply low-rank compression (default: True).
|
|
55
59
|
nnmf_factor (bool): whether to use the factorization or disable it to use
|
|
56
60
|
the uncompressed optimizer. (default: False)
|
|
61
|
+
low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
|
|
62
|
+
projects the update to a lower rank before orthogonalization.
|
|
63
|
+
(default: False)
|
|
64
|
+
ortho_rank (int): The rank for low-rank orthogonalization.
|
|
65
|
+
(default: 128)
|
|
66
|
+
normuon_variant (bool): If True, enables the NorMuon update rule, which adds
|
|
67
|
+
neuron-wise normalization. (default: False)
|
|
68
|
+
beta2_normuon (float): The exponential decay rate for the second moment estimates
|
|
69
|
+
used in NorMuon. (default: 0.95)
|
|
70
|
+
normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
|
|
71
|
+
normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
|
|
72
|
+
(default: 0.2)
|
|
73
|
+
normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
|
|
57
74
|
MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
|
|
58
75
|
Parameters designated by `layer_key_fn` will be optimized with
|
|
59
76
|
AdamW_adv instead of Muon. (default: False)
|
|
@@ -82,6 +99,15 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
82
99
|
vector_reshape_muon: bool = False,
|
|
83
100
|
vector_reshape: bool = False,
|
|
84
101
|
nnmf_factor: bool = False,
|
|
102
|
+
# Low-rank Muon
|
|
103
|
+
low_rank_ortho: bool = False,
|
|
104
|
+
ortho_rank: int = 128,
|
|
105
|
+
# NorMuon additions
|
|
106
|
+
normuon_variant: bool = False,
|
|
107
|
+
beta2_normuon: float = 0.95,
|
|
108
|
+
normuon_eps: float = 1e-8,
|
|
109
|
+
normuon_lr_scale: float = 0.2,
|
|
110
|
+
normuon_atan2: bool = False,
|
|
85
111
|
# hybrid optimizer mode
|
|
86
112
|
MuonWithAuxAdam: bool = False,
|
|
87
113
|
layer_key_fn: Optional[Callable] = None,
|
|
@@ -92,6 +118,8 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
92
118
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
93
119
|
if not (0.0 <= beta1 < 1.0):
|
|
94
120
|
raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
|
|
121
|
+
if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
|
|
122
|
+
raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
|
|
95
123
|
if not (weight_decay >= 0.0):
|
|
96
124
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
97
125
|
if not (ns_steps > 0):
|
|
@@ -106,10 +134,16 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
106
134
|
"ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
|
|
107
135
|
"vector_reshape": vector_reshape,
|
|
108
136
|
"vector_reshape_muon": vector_reshape_muon,
|
|
109
|
-
"Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
|
|
137
|
+
"Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
|
|
138
|
+
# Low-rank Ortho
|
|
139
|
+
"low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
|
|
140
|
+
# NorMuon
|
|
141
|
+
"normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
|
|
142
|
+
"normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
|
|
143
|
+
"normuon_atan2": normuon_atan2,
|
|
110
144
|
}
|
|
111
145
|
self.stochastic_rounding = stochastic_rounding
|
|
112
|
-
|
|
146
|
+
|
|
113
147
|
self.MuonWithAuxAdam = MuonWithAuxAdam
|
|
114
148
|
self.helper = None
|
|
115
149
|
self.aux_adam = None
|
|
@@ -223,6 +257,12 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
223
257
|
elif len(p.shape) == 1:
|
|
224
258
|
state['momentum_buffer'] = torch.zeros_like(p)
|
|
225
259
|
|
|
260
|
+
# NorMuon state initialization
|
|
261
|
+
if group['normuon_variant']:
|
|
262
|
+
if len(p.shape) >= 2 or state['reshaped_1d_muon']:
|
|
263
|
+
num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
|
|
264
|
+
state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
|
|
265
|
+
|
|
226
266
|
beta1 = group['beta1']
|
|
227
267
|
nesterov = group['nesterov']
|
|
228
268
|
Simplified_AdEMAMix = group['Simplified_AdEMAMix']
|
|
@@ -251,14 +291,60 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
251
291
|
update = mt_buf.clone()
|
|
252
292
|
del grad_reshaped
|
|
253
293
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
294
|
+
# Orthogonalization step
|
|
295
|
+
if group['low_rank_muon']:
|
|
296
|
+
# Low-Rank Orthogonalization on the reconstructed matrix
|
|
297
|
+
M = update
|
|
298
|
+
r = min(group['low_rank_rank'], M.shape[0], M.shape[1])
|
|
299
|
+
if r > 0:
|
|
300
|
+
G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
|
|
301
|
+
MG = M @ G_sketch
|
|
302
|
+
if MG.dtype != torch.float32:
|
|
303
|
+
MG_dtype = M.dtype
|
|
304
|
+
Q, _ = torch.linalg.qr(MG.float())
|
|
305
|
+
Q = Q.to(MG_dtype)
|
|
306
|
+
else:
|
|
307
|
+
Q, _ = torch.linalg.qr(MG)
|
|
308
|
+
projected_M = Q.T @ M
|
|
309
|
+
ortho_projected_M = _newton_schulz_iteration(
|
|
310
|
+
projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
311
|
+
)
|
|
312
|
+
update = Q @ ortho_projected_M
|
|
313
|
+
else: # Fallback for invalid rank
|
|
314
|
+
update = _newton_schulz_iteration(
|
|
315
|
+
update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
|
|
316
|
+
)
|
|
317
|
+
else:
|
|
318
|
+
# Original full Newton-Schulz
|
|
319
|
+
update = _newton_schulz_iteration(
|
|
320
|
+
update,
|
|
321
|
+
steps=group['ns_steps'],
|
|
322
|
+
eps=group['ns_eps'],
|
|
323
|
+
coeffs=group['ns_coeffs'],
|
|
324
|
+
)
|
|
260
325
|
|
|
261
|
-
|
|
326
|
+
|
|
327
|
+
if group['normuon_variant'] and 'normuon_v' in state:
|
|
328
|
+
v_t = state['normuon_v']
|
|
329
|
+
beta2_normuon = group['beta2_normuon']
|
|
330
|
+
# Update 2nd moment estimate
|
|
331
|
+
mean_squared_update = torch.mean(update.square(), dim=1)
|
|
332
|
+
v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
|
|
333
|
+
# Normalize update
|
|
334
|
+
if group['normuon_atan2']:
|
|
335
|
+
a = 1.2732395
|
|
336
|
+
update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
|
|
337
|
+
else:
|
|
338
|
+
update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
|
|
339
|
+
# Scale learning rate
|
|
340
|
+
update_norm = torch.linalg.vector_norm(update)
|
|
341
|
+
if update_norm > 1e-12:
|
|
342
|
+
scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
|
|
343
|
+
else:
|
|
344
|
+
scaled_lr = 0.0
|
|
345
|
+
update = update.view(p.shape).mul_(scaled_lr)
|
|
346
|
+
else: # Original Muon learning rate application
|
|
347
|
+
update = update.view(p.shape).mul_(group['lr'])
|
|
262
348
|
|
|
263
349
|
state['sign'] = _pack_bools(mt_buf > 0)
|
|
264
350
|
_nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
|
|
@@ -298,19 +384,80 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
298
384
|
if len(p.shape) > 2:
|
|
299
385
|
update = update.view(p.shape[0], -1)
|
|
300
386
|
|
|
301
|
-
#
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
387
|
+
# Orthogonalization step
|
|
388
|
+
if group['low_rank_ortho']:
|
|
389
|
+
# Low-Rank Orthogonalization based on Gaussian Sketching
|
|
390
|
+
M = update
|
|
391
|
+
r = min(group['ortho_rank'], M.shape[0], M.shape[1])
|
|
392
|
+
|
|
393
|
+
if r > 0:
|
|
394
|
+
# 1. Sketch the matrix
|
|
395
|
+
G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
|
|
396
|
+
MG = M @ G_sketch
|
|
397
|
+
|
|
398
|
+
# 2. QR decomposition to get orthogonal basis Q
|
|
399
|
+
if MG.dtype != torch.float32:
|
|
400
|
+
MG_dtype = M.dtype
|
|
401
|
+
Q, _ = torch.linalg.qr(MG.float())
|
|
402
|
+
Q = Q.to(MG_dtype)
|
|
403
|
+
else:
|
|
404
|
+
Q, _ = torch.linalg.qr(MG)
|
|
405
|
+
|
|
406
|
+
# 3. Project M onto the basis
|
|
407
|
+
projected_M = Q.T @ M
|
|
408
|
+
|
|
409
|
+
# 4. Orthogonalize the smaller projected matrix
|
|
410
|
+
ortho_projected_M = _newton_schulz_iteration(
|
|
411
|
+
projected_M,
|
|
412
|
+
steps=group['ns_steps'],
|
|
413
|
+
eps=group['ns_eps'],
|
|
414
|
+
coeffs=group['ns_coeffs'],
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
# 5. Project back to the original space
|
|
418
|
+
update = Q @ ortho_projected_M
|
|
419
|
+
else: # Fallback for invalid rank
|
|
420
|
+
update = _newton_schulz_iteration(
|
|
421
|
+
update,
|
|
422
|
+
steps=group['ns_steps'],
|
|
423
|
+
eps=group['ns_eps'],
|
|
424
|
+
coeffs=group['ns_coeffs'],
|
|
425
|
+
)
|
|
426
|
+
else:
|
|
427
|
+
# Original NewtonSchulz
|
|
428
|
+
update = _newton_schulz_iteration(
|
|
429
|
+
update,
|
|
430
|
+
steps=group['ns_steps'],
|
|
431
|
+
eps=group['ns_eps'],
|
|
432
|
+
coeffs=group['ns_coeffs'],
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
# NorMuon Logic
|
|
436
|
+
if group['normuon_variant'] and 'normuon_v' in state:
|
|
437
|
+
v_t = state['normuon_v']
|
|
438
|
+
beta2_normuon = group['beta2_normuon']
|
|
439
|
+
# Update 2nd moment estimate
|
|
440
|
+
mean_squared_update = torch.mean(update.square(), dim=1)
|
|
441
|
+
v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
|
|
442
|
+
# Normalize update
|
|
443
|
+
if group['normuon_atan2']:
|
|
444
|
+
a = 1.2732395
|
|
445
|
+
update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
|
|
446
|
+
else:
|
|
447
|
+
update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
|
|
448
|
+
# Scale learning rate
|
|
449
|
+
update_norm = torch.linalg.vector_norm(update)
|
|
450
|
+
if update_norm > 1e-12:
|
|
451
|
+
scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
|
|
452
|
+
else:
|
|
453
|
+
scaled_lr = 0.0
|
|
454
|
+
update.mul_(scaled_lr)
|
|
455
|
+
else: # Original Muon learning rate application
|
|
456
|
+
update.mul_(group['lr'])
|
|
308
457
|
|
|
309
458
|
# Reshape back to original if we flattened or reshaped
|
|
310
459
|
if len(p.shape) > 2 or state['reshaped_1d_muon']:
|
|
311
460
|
update = update.view(p.shape)
|
|
312
|
-
|
|
313
|
-
update.mul_(group['lr'])
|
|
314
461
|
|
|
315
462
|
else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
|
|
316
463
|
# Momentum update
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
adv_optm/__init__.py,sha256=
|
|
1
|
+
adv_optm/__init__.py,sha256=93-4akpaONvZ7BCkwDSYs3i28lI-aFV1RVwsEq1UZhU,379
|
|
2
2
|
adv_optm/optim/AdaMuon_adv.py,sha256=hTGSH8wzmQ-NYIcqV6EAEbqCxxfEwmmMWaIadX1qiuQ,21009
|
|
3
3
|
adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
|
|
4
4
|
adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
|
|
5
5
|
adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
|
|
6
6
|
adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
|
|
7
|
-
adv_optm/optim/Muon_adv.py,sha256=
|
|
7
|
+
adv_optm/optim/Muon_adv.py,sha256=JBLLfU83lRwezowI6A4JQAO1-NBLvSDOB8Dsad5zuHU,22775
|
|
8
8
|
adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
|
|
9
9
|
adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
|
|
10
10
|
adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
|
|
@@ -17,8 +17,8 @@ adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnC
|
|
|
17
17
|
adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
|
|
18
18
|
adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
|
|
19
19
|
adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
|
|
20
|
-
adv_optm-1.2.
|
|
21
|
-
adv_optm-1.2.
|
|
22
|
-
adv_optm-1.2.
|
|
23
|
-
adv_optm-1.2.
|
|
24
|
-
adv_optm-1.2.
|
|
20
|
+
adv_optm-1.2.dev7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
21
|
+
adv_optm-1.2.dev7.dist-info/METADATA,sha256=mAEVDwu_gh6S-fN6LBfEJoYdn_5LJLOw_nHRZcE7orw,14022
|
|
22
|
+
adv_optm-1.2.dev7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
adv_optm-1.2.dev7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
|
|
24
|
+
adv_optm-1.2.dev7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|