adv-optm 1.2.dev1__py3-none-any.whl → 1.2.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +8 -4
- adv_optm/optim/Muon_adv.py +57 -2
- adv_optm/util/MuonAdam_helper.py +31 -0
- {adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/RECORD +9 -8
- {adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -107,6 +107,7 @@ class AdamW_adv(torch.optim.Optimizer):
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
+        _is_delegate: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -137,10 +138,11 @@ class AdamW_adv(torch.optim.Optimizer):
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
-
-
-
-        self.
+        if not _is_delegate:
+            super().__init__(params, defaults)
+        else:
+            self.defaults = defaults
+            self.kourkoutas_helper = None

     @property
     def supports_fused_back_pass(self):

@@ -158,6 +160,8 @@ class AdamW_adv(torch.optim.Optimizer):
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
         if p.grad is None:
             return
+        if group.get('kourkoutas_beta', False) and self.kourkoutas_helper is None:
+            self.kourkoutas_helper = KourkoutasHelper(self)

         grad = p.grad
         if grad.dtype != torch.float32 and self.factored:
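The net effect of the new `_is_delegate` flag is that another optimizer can construct an AdamW_adv instance purely as a container for hyperparameter defaults: `super().__init__` is skipped, so no parameter groups are registered, and the Kourkoutas helper stays at None until `step_parameter` lazily creates it. A minimal sketch of that construction (the empty parameter list and the flag value come from the Muon_adv diff below; the printed values are illustrative only):

from adv_optm.optim.AdamW_adv import AdamW_adv

# Delegate instance: holds AdamW_adv defaults but registers no parameters.
delegate = AdamW_adv([], lr=1e-4, _is_delegate=True)

print(delegate.defaults)           # the usual AdamW_adv defaults, with lr=1e-4
print(delegate.kourkoutas_helper)  # None; created lazily on the first step that needs it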
adv_optm/optim/Muon_adv.py
CHANGED
@@ -1,5 +1,8 @@
 import torch
-from typing import Optional
+from typing import Optional, Callable
+
+from .AdamW_adv import AdamW_adv
+from ..util.MuonAdam_helper import MuonAdamHelper

 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration

@@ -18,6 +21,10 @@ class Muon_adv(torch.optim.Optimizer):
     This implementation is designed for 2D parameters (e.g., linear layers) and
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
+
+    This version can also operate in a hybrid mode, using an auxiliary AdamW
+    optimizer for specific parameters (e.g., biases, norms, embeddings) as
+    defined by a `layer_key_fn`.

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining

@@ -39,6 +46,16 @@ class Muon_adv(torch.optim.Optimizer):
             matrices to apply low-rank compression (default: True).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
+            Parameters designated by `layer_key_fn` will be optimized with
+            AdamW_adv instead of Muon. (default: False)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a key. If the key is 'adam', the parameter is handled by
+            the auxiliary AdamW optimizer. All other keys are handled by Muon.
+            Only used when `MuonWithAuxAdam` is True. (default: None)
+        adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
+            to the auxiliary AdamW_adv optimizer. Only used when
+            `MuonWithAuxAdam` is True. (default: None)
     """

     def __init__(

@@ -55,6 +72,11 @@ class Muon_adv(torch.optim.Optimizer):
         vector_reshape_muon: bool = False,
         vector_reshape: bool = True,
         nnmf_factor: bool = False,
+        # hybrid optimizer mode
+        MuonWithAuxAdam: bool = False,
+        layer_key_fn: Optional[Callable] = None,
+        muon_adam_lr: float = 1e-4,
+        adam_kwargs: Optional[dict] = None,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")

@@ -73,8 +95,29 @@ class Muon_adv(torch.optim.Optimizer):
             "vector_reshape_muon": vector_reshape_muon,
         }
         self.stochastic_rounding = stochastic_rounding
+
+        self.MuonWithAuxAdam = MuonWithAuxAdam
+        self.helper = None
+        self.aux_adam = None
+
+        if self.MuonWithAuxAdam:
+            adam_kwargs = adam_kwargs or {}
+            # Create a delegate AdamW optimizer to get its default hyperparameters.
+            self.aux_adam = AdamW_adv(
+                [],
+                lr=muon_adam_lr,
+                **adam_kwargs,
+                _is_delegate=True
+            )
+            # Update the defaults dictionary
+            defaults.update(self.aux_adam.defaults)
+
         super().__init__(params, defaults)

+        if self.MuonWithAuxAdam:
+            self.helper = MuonAdamHelper(self, layer_key_fn)
+
+
     @property
     def supports_fused_back_pass(self):
         return True

@@ -89,6 +132,18 @@ class Muon_adv(torch.optim.Optimizer):

     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if self.MuonWithAuxAdam:
+            optim_type = self.helper.get_optimizer_type(p)
+            if optim_type == 'adam':
+                # Delegate to the AdamW_adv optimizer's logic.
+                # We need to temporarily "lend" our state and param_groups
+                # to the delegate so it has the full context to work with,
+                # especially for features like Kourkoutas-beta.
+                self.aux_adam.state = self.state
+                self.aux_adam.param_groups = self.param_groups
+                self.aux_adam.step_parameter(p, group, i)
+                return
+
         if p.grad is None:
             return

@@ -244,4 +299,4 @@ class Muon_adv(torch.optim.Optimizer):
         for i, p in enumerate(group['params']):
             self.step_parameter(p, group, i)

-        return loss
+        return loss
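Put together, the new arguments enable a hybrid setup along the following lines. This is a sketch, not code from the package: the keying rule (send 1-D parameters such as biases and norm weights to the auxiliary AdamW_adv, keep matrices with Muon) and the toy model are assumptions chosen for illustration.

import torch
from adv_optm.optim.Muon_adv import Muon_adv

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.LayerNorm(64),
)

# Hypothetical keying rule: anything that is not a matrix goes to AdamW_adv.
def layer_key_fn(p: torch.Tensor) -> str:
    return 'adam' if p.ndim <= 1 else 'muon'

opt = Muon_adv(
    model.parameters(),
    lr=0.02,
    MuonWithAuxAdam=True,
    layer_key_fn=layer_key_fn,
    muon_adam_lr=1e-4,
)

loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
opt.step()

Because `defaults` is updated with the delegate's defaults before `super().__init__`, every param group carries both the Muon and the AdamW hyperparameters, and `step_parameter` routes each tensor to the matching update rule via the MuonAdamHelper added below.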
adv_optm/util/MuonAdam_helper.py
ADDED

@@ -0,0 +1,31 @@
+from torch.optim import Optimizer
+from typing import Callable, Optional
+
+class MuonAdamHelper:
+    """
+    A helper class for Muon_adv to decide whether to use Muon or a delegate
+    AdamW optimizer for a given parameter based on a keying function.
+    """
+    def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
+        if not hasattr(optimizer, 'param_groups'):
+            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+        self.optimizer = optimizer
+
+        if layer_key_fn is None:
+            # If no function is provided, default all parameters to 'muon'.
+            self.layer_key_fn = lambda p: 'muon'
+        else:
+            self.layer_key_fn = layer_key_fn
+
+    def get_optimizer_type(self, p: "torch.Tensor") -> str:
+        """
+        Returns the designated optimizer type ('adam' or 'muon') for a parameter.
+
+        The user-provided layer_key_fn should return 'adam' for parameters
+        to be handled by the auxiliary AdamW optimizer. Any other return
+        value is treated as 'muon'.
+        """
+        key = self.layer_key_fn(p)
+        if key == 'adam':
+            return 'adam'
+        return 'muon'
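The helper treats 'adam' as the only special key: any other return value from `layer_key_fn` (a layer name, 'muon', None, and so on) falls back to the Muon path. A standalone probe of that behaviour, using a hypothetical stand-in for the parent optimizer (in normal use Muon_adv constructs the helper itself):

import torch
from adv_optm.util.MuonAdam_helper import MuonAdamHelper

class _FakeOptimizer:
    # Only needs a param_groups attribute to pass the constructor's check.
    param_groups = []

helper = MuonAdamHelper(_FakeOptimizer(), lambda p: 'adam' if p.ndim == 1 else 'linear')

print(helper.get_optimizer_type(torch.zeros(10)))      # 'adam'
print(helper.get_optimizer_type(torch.zeros(10, 10)))  # 'muon' (any non-'adam' key)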
{adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/RECORD
CHANGED

@@ -1,22 +1,23 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdamW_adv.py,sha256=
+adv_optm/__init__.py,sha256=THWhNF8-PI71K9Au4xAkuDs96YcEagJ-yT5r_g2-yKw,341
+adv_optm/optim/AdamW_adv.py,sha256=Zym0beeu0ye5_PgpAjpzcYghdPYFWs3gQzDmuPZVR80,17690
 adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=9K5YR3odaGfDDZzasletHRlqxG8xN9IXj6oiqx1CaEI,12423
 adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
 adv_optm/optim/__init__.py,sha256=3o2XJ4J-PUq3rJM2mBnmuHwbKNb4LuW-Ig_9aBC0ycc,431
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
 adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
+adv_optm/util/MuonAdam_helper.py,sha256=llPCc9MBFen_wodbY4G2E17tBZky8clDiJSZLHkMva8,1236
 adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev2.dist-info/METADATA,sha256=JTCPGBJUd4JR7DU26AhX8qSPzWrSVtEwv9Au7I3iEPY,14022
+adv_optm-1.2.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev2.dist-info/RECORD,,
{adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/WHEEL
File without changes

{adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.2.dev1.dist-info → adv_optm-1.2.dev2.dist-info}/top_level.txt
File without changes