adv-optm 1.2.dev9__py3-none-any.whl → 1.2.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +4 -80
- adv_optm/optim/AdamW_adv.py +8 -12
- adv_optm/optim/Adopt_adv.py +3 -3
- adv_optm/optim/Muon_adv.py +3 -83
- adv_optm/optim/Prodigy_adv.py +7 -4
- adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- adv_optm/util/Kourkoutas.py +10 -46
- adv_optm/util/__init__.py +1 -0
- {adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/METADATA +1 -1
- adv_optm-1.2.dev11.dist-info/RECORD +23 -0
- adv_optm/util/MuonAdam_helper.py +0 -32
- adv_optm-1.2.dev9.dist-info/RECORD +0 -24
- {adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -1,8 +1,7 @@
 import torch
-from typing import Optional
+from typing import Optional
 
 from .AdamW_adv import AdamW_adv
-from ..util.MuonAdam_helper import MuonAdamHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -12,7 +11,7 @@ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
 
 class AdaMuon_adv(torch.optim.Optimizer):
     """
-
+    IImplements an advanced AdaMuon optimizer algorithm.
 
     AdaMuon combines the geometry-aware updates of Muon with the element-wise
     adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
@@ -26,9 +25,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
     3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
        allowing for reuse of learning rate schedules.
 
-    Can also operate in a hybrid mode, using an auxiliary AdamW
-    optimizer for specific parameters (e.g., biases, norms, embeddings) as
-    defined by a `layer_key_fn`.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -70,16 +66,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
            (default: 128)
        nnmf_factor (bool): whether to use the factorization or disable it to use
            the uncompressed optimizer. (default: False)
-       MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
-           Parameters designated by `layer_key_fn` will be optimized with
-           AdamW_adv instead of Muon. (default: False)
-       layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-           and returns a key. If the key is 'adam', the parameter is handled by
-           the auxiliary AdamW optimizer. All other keys are handled by Muon.
-           Only used when `MuonWithAuxAdam` is True. (default: None)
-       adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
-           to the auxiliary AdamW_adv optimizer. Only used when
-           `MuonWithAuxAdam` is True. (default: None)
    """
 
    def __init__(
@@ -104,11 +90,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
        low_rank_ortho: bool = False,
        ortho_rank: int = 128,
        nnmf_factor: bool = False,
-        # hybrid optimizer mode
-        MuonWithAuxAdam: bool = False,
-        layer_key_fn: Optional[Callable] = None,
-        muon_adam_lr: float = 1e-4,
-        adam_kwargs: Optional[dict] = None,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -120,7 +101,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
            nesterov = False
 
-
+        defaults = {
            "lr": lr, "betas": betas, "weight_decay": weight_decay,
            "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
            "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -132,44 +113,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
        }
        self.stochastic_rounding = stochastic_rounding
-        self.MuonWithAuxAdam = MuonWithAuxAdam
-        self.helper = None
-        self.aux_adam = None
-
-        if not self.MuonWithAuxAdam:
-            super().__init__(params, muon_defaults)
-            return
 
-
-        adam_kwargs = adam_kwargs or {}
-        self.aux_adam = AdamW_adv(
-            [],
-            lr=muon_adam_lr,
-            **adam_kwargs,
-            _is_delegate=True
-        )
-        adam_defaults = self.aux_adam.defaults
-
-        final_param_groups = []
-        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
-        for group in params:
-            # All params in a group are of the same type
-            first_param = group['params'][0]
-            key = _layer_key_fn(first_param)
-            optim_type = 'adam' if key == 'adam' else 'muon'
-
-            new_group = group.copy()
-            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
-
-            for key, value in defaults_to_use.items():
-                new_group.setdefault(key, value)
-            final_param_groups.append(new_group)
-
-        super().__init__(final_param_groups, muon_defaults)
-
-        # Now that self is initialized, create the helper
-        self.helper = MuonAdamHelper(self, layer_key_fn)
+        super().__init__(params, defaults)
 
 
    @property
@@ -184,29 +129,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
    def supports_flat_params(self):
        return False
 
-    @property
-    def kourkoutas_helper(self):
-        """
-        Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
-        if it exists. This allows external access for logging K-b.
-        """
-        if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
-            return self.aux_adam.kourkoutas_helper
-        return None
-
    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if self.MuonWithAuxAdam:
-            optim_type = self.helper.get_optimizer_type(p)
-            if optim_type == 'adam':
-                # Delegate to the AdamW_adv optimizer's logic.
-                # We need to temporarily "lend" our state and param_groups
-                self.aux_adam.state = self.state
-                self.aux_adam.param_groups = self.param_groups
-
-                self.aux_adam.step_parameter(p, group, i)
-                return
-
        if p.grad is None:
            return
 
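Both AdaMuon_adv and Muon_adv (further down in this diff) drop the hybrid `MuonWithAuxAdam` mode together with its `layer_key_fn`, `muon_adam_lr`, and `adam_kwargs` arguments, the auxiliary-AdamW construction logic, and the delegation branch in `step_parameter`. Code that relied on that mode can instead partition the parameters up front and run two plain optimizers side by side. The sketch below is illustrative only, not part of the package: it assumes the classes are importable from the package root, keep the constructor keywords shown in this diff, and expose the usual `step()`/`zero_grad()` interface; the 2D-versus-rest split is just an example routing rule.

import torch
from adv_optm import AdaMuon_adv, AdamW_adv  # import path assumed, not verified

def build_optimizers(model: torch.nn.Module):
    # Example routing rule: 2D weight matrices go to AdaMuon, everything else
    # (biases, norms, conv kernels) goes to the AdamW variant. Adapt the rule
    # to taste, e.g. also route embeddings to AdamW by parameter name.
    muon_params = [p for p in model.parameters() if p.requires_grad and p.dim() == 2]
    adam_params = [p for p in model.parameters() if p.requires_grad and p.dim() != 2]
    return AdaMuon_adv(muon_params, lr=1e-3), AdamW_adv(adam_params, lr=1e-4)

In the training loop, call step() and zero_grad() on both optimizers after loss.backward(), exactly as with any pair of torch.optim optimizers.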
|
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every logging steps. Useful for debugging and tuning. Set to 0 to disable
            logging (default: 0).
-
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
            and returns a unique, hashable key representing its "layer" or "bucket".
            If `None`, parameters are bucketed by their memory ID (tensor-wise).
            (default: None)
@@ -105,9 +105,8 @@ class AdamW_adv(torch.optim.Optimizer):
        tiny_spike: float = 1e-9,
        k_warmup_steps: int = 0,
        k_logging: int = 0,
-
+        layer_key_fn: Optional[Callable] = None,
        nnmf_factor: bool = False,
-        _is_delegate: bool = False,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -137,12 +136,11 @@ class AdamW_adv(torch.optim.Optimizer):
        self.use_AdEMAMix = use_AdEMAMix
        self.factored = nnmf_factor
        self.kourkoutas_beta = kourkoutas_beta
-        self.
-
-
-
-        self.
-        self.kourkoutas_helper = None
+        self.layer_key_fn = layer_key_fn
+        super().__init__(params, defaults)
+
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
 
    @property
    def supports_fused_back_pass(self):
@@ -160,8 +158,6 @@ class AdamW_adv(torch.optim.Optimizer):
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
        if p.grad is None:
            return
-        if group.get('kourkoutas_beta', False) and self.kourkoutas_helper is None:
-            self.kourkoutas_helper = KourkoutasHelper(self)
 
        grad = p.grad
        if grad.dtype != torch.float32 and self.factored:
@@ -244,6 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):
 
        if state['factored']:
            d1, d2 = state['effective_shape']
+            grad_reshaped = grad.view(d1, d2)
 
            # Reconstruct momentum from previous step's factors
            if beta1 > 0:
@@ -253,7 +250,6 @@ class AdamW_adv(torch.optim.Optimizer):
                torch.where(unpacked_sign, mt, -mt, out=mt)
                del unpacked_sign
                # Update momentum in full-size
-                grad_reshaped = grad.view(d1, d2)
                mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
                if self.grams_moment:
                    mt.copy_(grad_reshaped.sign() * mt.abs())
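AdamW_adv now exposes `layer_key_fn` as a public constructor argument (the internal `_is_delegate` flag is gone), and the Kourkoutas-β helper is created once in `__init__` instead of lazily inside `step_parameter`. The same `layer_key_fn` argument is added to Adopt_adv, Prodigy_adv, and Simplified_AdEMAMix below. A hedged usage sketch, assuming only the keyword names visible in this diff (`kourkoutas_beta`, `layer_key_fn`) and leaving every other argument at its default:

import torch
from adv_optm import AdamW_adv  # import path assumed, not verified

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 4))

# Bucket parameters by the module that owns them, so all tensors of a module
# share one Kourkoutas-beta bucket; per the docstring any hashable key works.
owner = {id(p): name for name, m in model.named_modules()
         for p in m.parameters(recurse=False)}

opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    kourkoutas_beta=True,                 # enable layer-wise dynamic beta2
    layer_key_fn=lambda p: owner[id(p)],  # one bucket per owning module
)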
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every logging steps. Useful for debugging and tuning. Set to 0 to disable
            logging (default: 0).
-
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
            and returns a unique, hashable key representing its "layer" or "bucket".
            If `None`, parameters are bucketed by their memory ID (tensor-wise).
            (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
        tiny_spike: float = 1e-9,
        k_warmup_steps: int = 0,
        k_logging: int = 0,
-
+        layer_key_fn: Optional[Callable] = None,
        nnmf_factor: bool = False,
    ):
        if not (lr >= 0.0):
@@ -166,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
        self.Simplified_AdEMAMix = Simplified_AdEMAMix
        self.factored = nnmf_factor
        self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_fn = layer_key_fn
        super().__init__(params, defaults)
 
        if self.kourkoutas_beta:
adv_optm/optim/Muon_adv.py
CHANGED
@@ -1,8 +1,6 @@
 import torch
-from typing import Optional
-
+from typing import Optional
 from .AdamW_adv import AdamW_adv
-from ..util.MuonAdam_helper import MuonAdamHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -25,10 +23,6 @@ class Muon_adv(torch.optim.Optimizer):
     This implementation is designed for 2D parameters (e.g., linear layers) and
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
-
-    Can also operate in a hybrid mode, using an auxiliary AdamW
-    optimizer for specific parameters (e.g., biases, norms, embeddings) as
-    defined by a `layer_key_fn`.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -71,16 +65,6 @@ class Muon_adv(torch.optim.Optimizer):
        normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
            (default: 0.2)
        normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
-       MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
-           Parameters designated by `layer_key_fn` will be optimized with
-           AdamW_adv instead of Muon. (default: False)
-       layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-           and returns a key. If the key is 'adam', the parameter is handled by
-           the auxiliary AdamW optimizer. All other keys are handled by Muon.
-           Only used when `MuonWithAuxAdam` is True. (default: None)
-       adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
-           to the auxiliary AdamW_adv optimizer. Only used when
-           `MuonWithAuxAdam` is True. (default: None)
    """
 
    def __init__(
@@ -108,11 +92,6 @@ class Muon_adv(torch.optim.Optimizer):
        normuon_eps: float = 1e-8,
        normuon_lr_scale: float = 0.2,
        normuon_atan2: bool = False,
-        # hybrid optimizer mode
-        MuonWithAuxAdam: bool = False,
-        layer_key_fn: Optional[Callable] = None,
-        muon_adam_lr: float = 1e-4,
-        adam_kwargs: Optional[dict] = None,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -128,7 +107,7 @@ class Muon_adv(torch.optim.Optimizer):
            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
            nesterov = False
 
-
+        defaults = {
            "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
            "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
            "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -144,44 +123,7 @@ class Muon_adv(torch.optim.Optimizer):
        }
        self.stochastic_rounding = stochastic_rounding
 
-
-        self.helper = None
-        self.aux_adam = None
-
-        if not self.MuonWithAuxAdam:
-            super().__init__(params, muon_defaults)
-            return
-
-        # HYBRID OPTIMIZER LOGIC
-        adam_kwargs = adam_kwargs or {}
-        self.aux_adam = AdamW_adv(
-            [],
-            lr=muon_adam_lr,
-            **adam_kwargs,
-            _is_delegate=True
-        )
-        adam_defaults = self.aux_adam.defaults
-
-        final_param_groups = []
-        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
-        for group in params:
-            first_param = group['params'][0]
-            key = _layer_key_fn(first_param)
-            optim_type = 'adam' if key == 'adam' else 'muon'
-
-            new_group = group.copy()
-            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
-
-            for key, value in defaults_to_use.items():
-                new_group.setdefault(key, value)
-
-            final_param_groups.append(new_group)
-
-        super().__init__(final_param_groups, muon_defaults)
-
-        # Now that self is initialized, create the helper
-        self.helper = MuonAdamHelper(self, layer_key_fn)
+        super().__init__(params, defaults)
 
 
    @property
@@ -196,30 +138,8 @@ class Muon_adv(torch.optim.Optimizer):
    def supports_flat_params(self):
        return False
 
-    @property
-    def kourkoutas_helper(self):
-        """
-        Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
-        if it exists. This allows external access for logging K-b.
-        """
-        if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
-            return self.aux_adam.kourkoutas_helper
-        return None
-
    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if self.MuonWithAuxAdam:
-            optim_type = self.helper.get_optimizer_type(p)
-            if optim_type == 'adam':
-                # Delegate to the AdamW_adv optimizer's logic.
-                # We need to temporarily "lend" our state and param_groups
-                # to the delegate so it has the full context to work with,
-                # especially for features like Kourkoutas-beta.
-                self.aux_adam.state = self.state
-                self.aux_adam.param_groups = self.param_groups
-                self.aux_adam.step_parameter(p, group, i)
-                return
-
        if p.grad is None:
            return
 
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every logging steps. Useful for debugging and tuning. Set to 0 to disable
            logging (default: 0).
-
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
            and returns a unique, hashable key representing its "layer" or "bucket".
            If `None`, parameters are bucketed by their memory ID (tensor-wise).
            (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
        tiny_spike: float = 1e-9,
        k_warmup_steps: int = 0,
        k_logging: int = 0,
-
+        layer_key_fn: Optional[Callable] = None,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
        self.fsdp_in_use = fsdp_in_use
 
        self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_fn = layer_key_fn
 
        super().__init__(params, defaults)
        if self.kourkoutas_beta:
@@ -516,7 +516,10 @@ class Prodigy_adv(torch.optim.Optimizer):
            if global_d_denom > 0:
                d_hat = d_coef * global_d_numerator / global_d_denom
                if g_group.get('d_limiter', False):
-
+                    if g_group.get('Simplified_AdEMAMix', False):
+                        d_hat = min(self.d * (2 ** 0.1), d_hat)
+                    else:
+                        d_hat = min(self.d * (2 ** 0.25), d_hat)
                if self.d == g_group['d0']:
                    self.d = max(self.d, d_hat)
                d_max = max(d_max, d_hat)
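The last hunk above tightens the `d_limiter` behaviour: when `Simplified_AdEMAMix` is active, a limited step may grow the Prodigy step-size estimate `d` by at most a factor of 2**0.1 (about 1.072x), versus 2**0.25 (about 1.189x) otherwise. A standalone restatement of that branch for clarity; the function name is mine, not the package's:

def limited_d_hat(d: float, d_hat: float, simplified_ademamix: bool) -> float:
    # Mirrors the branch added above: cap the proposed d_hat relative to the
    # current d before Prodigy decides whether to adopt it.
    cap = 2 ** 0.1 if simplified_ademamix else 2 ** 0.25   # ~1.0718 vs ~1.1892
    return min(d * cap, d_hat)

assert limited_d_hat(1.0, 10.0, True) == 2 ** 0.1    # growth capped at ~7.2% per step
assert limited_d_hat(1.0, 10.0, False) == 2 ** 0.25  # growth capped at ~18.9% per step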
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every logging steps. Useful for debugging and tuning. Set to 0 to disable
            logging (default: 0).
-
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
            and returns a unique, hashable key representing its "layer" or "bucket".
            If `None`, parameters are bucketed by their memory ID (tensor-wise).
            (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
        tiny_spike: float = 1e-9,
        k_warmup_steps: int = 0,
        k_logging: int = 0,
-
+        layer_key_fn: Optional[Callable] = None,
        nnmf_factor: bool = False,
    ):
        if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
        self.stochastic_rounding = stochastic_rounding
        self.factored = nnmf_factor
        self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_fn = layer_key_fn
        super().__init__(params, defaults)
 
        if self.kourkoutas_beta:
adv_optm/util/Kourkoutas.py
CHANGED
@@ -24,57 +24,35 @@ class KourkoutasHelper:
         # making it compatible with fused back pass mechanisms.
         self._build_layer_info_if_needed()
 
-        if self.optimizer.param_groups[0].get('k_logging', 0) > 0:
-            self.print_layer_info()
-
     def _build_layer_info_if_needed(self):
         """Builds a map of layers and the parameters they contain."""
         if self._layer_info_built:
             return
 
-        if hasattr(self.optimizer, '
+        if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
             # A custom key function was provided by the user. We will use it.
             pass
         else:
             # No key function was provided. Default to coarse, shape-based bucketing.
-            self.optimizer.
+            self.optimizer.layer_key_fn = lambda p: \
                 (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
                 else tuple(p.shape)
             # This ensures that we won't mix embeddings with tokens (1 to 10)
             # TODO find a better way to safeguard the embeddings
 
         for group in self.optimizer.param_groups:
+            if not group.get('kourkoutas_beta', False):
+                continue
+
             for p in group['params']:
                 # The mapping is static and should not depend on the presence of a gradient.
-                layer_key = self.optimizer.
+                layer_key = self.optimizer.layer_key_fn(p)
                 if layer_key not in self.layer_info:
                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                 self.layer_info[layer_key]['params'].append(p)
-
-        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
-        if k_logging_interval > 0:
-            print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
 
         self._layer_info_built = True
 
-    def print_layer_info(self):
-        """Prints the contents of self.layer_info for debugging."""
-        print("\n--- BEGIN self.layer_info DUMP ---")
-        if not self.layer_info:
-            print("Layer info is empty. Make sure the optimizer has parameters.")
-            return
-
-        for layer_key, info in self.layer_info.items():
-            param_count = len(info['params'])
-            first_param_details = ""
-            if param_count > 0:
-                p = info['params'][0]
-                first_param_details = f" (Example param shape: {list(p.shape)}, dtype: {p.dtype})"
-
-            print(f"Key: {layer_key}, Params: {param_count}{first_param_details}")
-
-        print("--- END self.layer_info DUMP ---\n")
-
     def prepare_step(self, current_step: int):
         """
         Calculates dynamic beta2 for all layers using the completed scalar accumulators
@@ -82,9 +60,8 @@ class KourkoutasHelper:
         """
 
         beta2_log = []
-        first_layer_key = next(iter(self.layer_info), None)
         # These are just for the sample log, initialize them
-        sun, pooled_grad_norm,
+        sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),)*3
 
         # The optimizer that owns this helper holds the master defaults for K-b.
         # This is crucial in hybrid optimizers where some param_groups might not
@@ -94,7 +71,7 @@ class KourkoutasHelper:
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
 
-            if not group.get('kourkoutas_beta', False)
+            if not group.get('kourkoutas_beta', False):
                 continue
 
             first_param_in_layer = info['params'][0]
@@ -121,7 +98,6 @@ class KourkoutasHelper:
             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
 
             pooled_grad_norm = torch.sqrt(accumulator)
-            prev_r_ema_val = r_ema_tensor.item() # for logging
 
             # Update the persistent EMA tensor in-place.
             r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
@@ -147,21 +123,9 @@ class KourkoutasHelper:
         if beta2_log:
             beta2_tensor = torch.tensor(beta2_log, device='cpu')
             self.last_beta2_stats = {
-                'min': beta2_tensor.min().item(),
-                'max': beta2_tensor.max().item(),
                 'mean': beta2_tensor.mean().item(),
             }
 
-        # Handle periodic console logging
-        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
-        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
-        if is_logging_step and self.last_beta2_stats:
-            if first_layer_key:
-                print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{first_layer_key}':")
-                print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema_tensor.item():.4e}")
-                print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {self.layer_state[first_layer_key]['dynamic_beta2']:.4f}")
-            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
-
     def maybe_prepare_step(self, current_step: int):
         """
         A universal guard that calls prepare_step() exactly once per training step.
@@ -174,7 +138,7 @@ class KourkoutasHelper:
         """
         Accumulates the squared L2 norm of a single gradient for the next step's calculation.
         """
-        layer_key = self.optimizer.
+        layer_key = self.optimizer.layer_key_fn(p)
 
         if layer_key in self.layer_info:
             # Initialize the transient state for this layer if it's the first time in the step.
@@ -189,6 +153,6 @@ class KourkoutasHelper:
         """
         Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
         """
-        layer_key = self.optimizer.
+        layer_key = self.optimizer.layer_key_fn(p)
         # The default is the max value, which is correct for unmapped params or edge cases
         return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
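KourkoutasHelper now reads `layer_key_fn` directly off the owning optimizer, skips param groups with Kourkoutas-β disabled, and installs the shape-based default shown above only when no key function was supplied. To make that default concrete, here is a small illustration; the surrounding code is mine, only the lambda body is copied from the diff:

import torch

default_key = lambda p: \
    (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
    else tuple(p.shape)

w1 = torch.empty(4096, 4096)  # ordinary 2D weight         -> bucketed by shape
w2 = torch.empty(4096, 4096)  # same shape                  -> shares w1's bucket
tok = torch.empty(4, 768)     # few-row, embedding-width 2D -> isolated per tensor

print(default_key(w1) == default_key(w2))  # True: both map to (4096, 4096)
print(default_key(tok))                    # (id(tok),): its own bucket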
adv_optm/util/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from .Effective_Shape import _get_effective_shape
 from .One_Bit_Boolean import _pack_bools, _unpack_bools
 from .OrthoGrad import _orthogonalize_gradient
 from .Newton_Schulz import _newton_schulz_iteration
+
 __all__ = [
     "_pack_bools", "_unpack_bools",
     "add_stochastic_",
adv_optm-1.2.dev11.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+adv_optm/__init__.py,sha256=vjm5Sc3hgTSy9qP73qVBVGp9zE2J6blsQmj6KxyI3GE,380
+adv_optm/optim/AdaMuon_adv.py,sha256=828WtdsaKXJqlZqFXE2yrsxY3Erxn-6N7CxV9jBXiaI,17880
+adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
+adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
+adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
+adv_optm/optim/Muon_adv.py,sha256=xGW9PafaIyi1noGhIgCWPwndI5bGX6kbxN-N-FQnr1U,19381
+adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
+adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=_fq2glPqKmzgWpLedfwq5EqIJAxICUK2fmUP-cdcgq0,7467
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
+adv_optm-1.2.dev11.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev11.dist-info/METADATA,sha256=F6o4bbgIEjB9JS_9gediI_0-_rUkfsubKVtg5b4nrHE,14023
+adv_optm-1.2.dev11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev11.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev11.dist-info/RECORD,,
adv_optm/util/MuonAdam_helper.py
DELETED
@@ -1,32 +0,0 @@
-import torch
-from torch.optim import Optimizer
-from typing import Callable, Optional
-
-class MuonAdamHelper:
-    """
-    A helper class for Muon_adv to decide whether to use Muon or a delegate
-    AdamW optimizer for a given parameter based on a keying function.
-    """
-    def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
-        if not hasattr(optimizer, 'param_groups'):
-            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
-        self.optimizer = optimizer
-
-        if layer_key_fn is None:
-            # If no function is provided, default all parameters to 'muon'.
-            self.layer_key_fn = lambda p: 'muon'
-        else:
-            self.layer_key_fn = layer_key_fn
-
-    def get_optimizer_type(self, p: "torch.Tensor") -> str:
-        """
-        Returns the designated optimizer type ('adam' or 'muon') for a parameter.
-
-        The user-provided layer_key_fn should return 'adam' for parameters
-        to be handled by the auxiliary AdamW optimizer. Any other return
-        value is treated as 'muon'.
-        """
-        key = self.layer_key_fn(p)
-        if key == 'adam':
-            return 'adam'
-        return 'muon'
adv_optm-1.2.dev9.dist-info/RECORD
DELETED
@@ -1,24 +0,0 @@
-adv_optm/__init__.py,sha256=TzvKgGTLkK0_XANeZzhURcSO9xmtUi-H9_C7tV3rXn4,379
-adv_optm/optim/AdaMuon_adv.py,sha256=yr1oJV339Zv7D8n148O1FJJAgdOsH8NZDZTKlcDOyu0,21181
-adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
-adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
-adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=HaF06fPKcKpVZY29_vqjWHAfivjvGntBuRyDDKj3Ozw,22784
-adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
-adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=lObJGXmz3MqGSuu3DKqotSpZ0fuQFPE80R3zO_j3Z_Q,9707
-adv_optm/util/MuonAdam_helper.py,sha256=7rnNMujZVDaqo1g22QscMyPlZvIHQQSLHMED9_I8QWU,1250
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
-adv_optm-1.2.dev9.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.dev9.dist-info/METADATA,sha256=GmAYWjZdfgvg9QbzyiV2PUNmzQFgJz8AjaY5F0x7Nv8,14022
-adv_optm-1.2.dev9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.dev9.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.dev9.dist-info/RECORD,,
{adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/WHEEL
File without changes
{adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/licenses/LICENSE
File without changes
{adv_optm-1.2.dev9.dist-info → adv_optm-1.2.dev11.dist-info}/top_level.txt
File without changes