adv-optm 1.2.dev8__tar.gz → 1.2.dev10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/PKG-INFO +1 -1
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/AdaMuon_adv.py +12 -24
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/AdamW_adv.py +4 -4
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Adopt_adv.py +3 -3
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Muon_adv.py +5 -22
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Prodigy_adv.py +7 -4
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Kourkoutas.py +9 -6
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/__init__.py +1 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/SOURCES.txt +0 -1
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/setup.py +1 -1
- adv_optm-1.2.dev8/adv_optm/util/MuonAdam_helper.py +0 -32
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/LICENSE +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/README.md +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/setup.cfg +0 -0
adv_optm/optim/AdaMuon_adv.py

@@ -1,8 +1,7 @@
 import torch
-from typing import Optional
+from typing import Optional
 
 from .AdamW_adv import AdamW_adv
-from ..util.MuonAdam_helper import MuonAdamHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -73,10 +72,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
 Parameters designated by `layer_key_fn` will be optimized with
 AdamW_adv instead of Muon. (default: False)
-layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-and returns a key. If the key is 'adam', the parameter is handled by
-the auxiliary AdamW optimizer. All other keys are handled by Muon.
-Only used when `MuonWithAuxAdam` is True. (default: None)
 adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
 to the auxiliary AdamW_adv optimizer. Only used when
 `MuonWithAuxAdam` is True. (default: None)
@@ -106,7 +101,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
 nnmf_factor: bool = False,
 # hybrid optimizer mode
 MuonWithAuxAdam: bool = False,
-layer_key_fn: Optional[Callable] = None,
 muon_adam_lr: float = 1e-4,
 adam_kwargs: Optional[dict] = None,
 ):
@@ -132,10 +126,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
 "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
 }
 self.stochastic_rounding = stochastic_rounding
+
 self.MuonWithAuxAdam = MuonWithAuxAdam
-self.helper = None
 self.aux_adam = None
-
+
 if not self.MuonWithAuxAdam:
 super().__init__(params, muon_defaults)
 return
@@ -151,26 +145,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
 adam_defaults = self.aux_adam.defaults
 
 final_param_groups = []
-_layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
 for group in params:
-
-first_param = group['params'][0]
-key = _layer_key_fn(first_param)
-optim_type = 'adam' if key == 'adam' else 'muon'
-
-new_group = group.copy()
+optim_type = group.get('optim_type', 'muon')
 defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
 
+new_group = group.copy()
 for key, value in defaults_to_use.items():
 new_group.setdefault(key, value)
 final_param_groups.append(new_group)
 
 super().__init__(final_param_groups, muon_defaults)
 
-# Now that self is initialized, create the helper
-self.helper = MuonAdamHelper(self, layer_key_fn)
-
 
 @property
 def supports_fused_back_pass(self):
@@ -197,13 +182,14 @@ class AdaMuon_adv(torch.optim.Optimizer):
 @torch.no_grad()
 def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
 if self.MuonWithAuxAdam:
-optim_type =
+optim_type = group.get('optim_type')
 if optim_type == 'adam':
 # Delegate to the AdamW_adv optimizer's logic.
 # We need to temporarily "lend" our state and param_groups
+# to the delegate so it has the full context to work with,
+# especially for features like Kourkoutas-beta.
 self.aux_adam.state = self.state
 self.aux_adam.param_groups = self.param_groups
-
 self.aux_adam.step_parameter(p, group, i)
 return
 
@@ -327,7 +313,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
 # RMS-aligned rescaling
 rms_target = group['rms_target']
 num_elements = update.numel()
-
+# Add eps to prevent division by zero
+scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
 
 update.mul_(scaling_factor)
 update = update.view(p.shape).mul_(group['lr'])
@@ -422,7 +409,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
 # RMS-aligned rescaling
 rms_target = group['rms_target']
 num_elements = update.numel()
-
+# Add eps to prevent division by zero
+scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])
 
 update.mul_(scaling_factor)
 del num_elements, scaling_factor
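The hunks above replace the `layer_key_fn`/`MuonAdamHelper` routing with an `optim_type` key read directly from each param group (`group.get('optim_type', 'muon')`). A minimal usage sketch of the new hybrid mode, assuming the top-level import path and default hyperparameters; the parameter split below is illustrative, not taken from the package:

import torch
from adv_optm import AdaMuon_adv  # import path assumed; adjust to the package layout

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),   # 2-D weight: a candidate for Muon
    torch.nn.Linear(64, 10),
)

# Hypothetical grouping: send 2-D weights to Muon and everything else to the
# auxiliary AdamW_adv delegate by tagging each group with 'optim_type'.
muon_params = [p for p in model.parameters() if p.dim() == 2]
adam_params = [p for p in model.parameters() if p.dim() != 2]

param_groups = [
    {"params": muon_params, "optim_type": "muon"},
    {"params": adam_params, "optim_type": "adam"},
]

opt = AdaMuon_adv(param_groups, MuonWithAuxAdam=True)

Groups that omit `optim_type` fall back to `'muon'`, per the `group.get('optim_type', 'muon')` default in the hunk, and missing hyperparameters are filled in from the Muon or AdamW defaults via `setdefault`.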
adv_optm/optim/AdamW_adv.py

@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
 logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
 every logging steps. Useful for debugging and tuning. Set to 0 to disable
 logging (default: 0).
-
+layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
 and returns a unique, hashable key representing its "layer" or "bucket".
 If `None`, parameters are bucketed by their memory ID (tensor-wise).
 (default: None)
@@ -105,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
 tiny_spike: float = 1e-9,
 k_warmup_steps: int = 0,
 k_logging: int = 0,
-
+layer_key_fn: Optional[Callable] = None,
 nnmf_factor: bool = False,
 _is_delegate: bool = False,
 ):
@@ -137,7 +137,7 @@ class AdamW_adv(torch.optim.Optimizer):
 self.use_AdEMAMix = use_AdEMAMix
 self.factored = nnmf_factor
 self.kourkoutas_beta = kourkoutas_beta
-self.
+self.layer_key_fn = layer_key_fn
 if not _is_delegate:
 super().__init__(params, defaults)
 else:
@@ -244,6 +244,7 @@ class AdamW_adv(torch.optim.Optimizer):
 
 if state['factored']:
 d1, d2 = state['effective_shape']
+grad_reshaped = grad.view(d1, d2)
 
 # Reconstruct momentum from previous step's factors
 if beta1 > 0:
@@ -253,7 +254,6 @@ class AdamW_adv(torch.optim.Optimizer):
 torch.where(unpacked_sign, mt, -mt, out=mt)
 del unpacked_sign
 # Update momentum in full-size
-grad_reshaped = grad.view(d1, d2)
 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
 if self.grams_moment:
 mt.copy_(grad_reshaped.sign() * mt.abs())
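`layer_key_fn` is now threaded through AdamW_adv (and, per the hunks below, Adopt_adv, Prodigy_adv and Simplified_AdEMAMix) so Kourkoutas-β can bucket parameters by layer instead of per-tensor. A minimal sketch, assuming the top-level import path and that `kourkoutas_beta=True` is the constructor flag behind `self.kourkoutas_beta`; the module-name keying is illustrative only:

import torch
from adv_optm import AdamW_adv  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 8))

# Hypothetical keying: bucket each parameter by the module that owns it, so a
# Linear layer's weight and bias share one Kourkoutas-beta bucket.
owner = {id(p): name.rsplit(".", 1)[0] for name, p in model.named_parameters()}

opt = AdamW_adv(
    model.parameters(),
    kourkoutas_beta=True,
    layer_key_fn=lambda p: owner[id(p)],  # any unique, hashable key per "layer"
)

Leaving `layer_key_fn=None` keeps the previous behaviour described in the docstring: each tensor is its own bucket, keyed by memory ID.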
adv_optm/optim/Adopt_adv.py

@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
 logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
 every logging steps. Useful for debugging and tuning. Set to 0 to disable
 logging (default: 0).
-
+layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
 and returns a unique, hashable key representing its "layer" or "bucket".
 If `None`, parameters are bucketed by their memory ID (tensor-wise).
 (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
 tiny_spike: float = 1e-9,
 k_warmup_steps: int = 0,
 k_logging: int = 0,
-
+layer_key_fn: Optional[Callable] = None,
 nnmf_factor: bool = False,
 ):
 if not (lr >= 0.0):
@@ -166,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
 self.Simplified_AdEMAMix = Simplified_AdEMAMix
 self.factored = nnmf_factor
 self.kourkoutas_beta = kourkoutas_beta
-self.
+self.layer_key_fn = layer_key_fn
 super().__init__(params, defaults)
 
 if self.kourkoutas_beta:
adv_optm/optim/Muon_adv.py

@@ -1,8 +1,6 @@
 import torch
-from typing import Optional
-
+from typing import Optional
 from .AdamW_adv import AdamW_adv
-from ..util.MuonAdam_helper import MuonAdamHelper
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -74,10 +72,6 @@ class Muon_adv(torch.optim.Optimizer):
 MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
 Parameters designated by `layer_key_fn` will be optimized with
 AdamW_adv instead of Muon. (default: False)
-layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-and returns a key. If the key is 'adam', the parameter is handled by
-the auxiliary AdamW optimizer. All other keys are handled by Muon.
-Only used when `MuonWithAuxAdam` is True. (default: None)
 adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
 to the auxiliary AdamW_adv optimizer. Only used when
 `MuonWithAuxAdam` is True. (default: None)
@@ -110,7 +104,6 @@ class Muon_adv(torch.optim.Optimizer):
 normuon_atan2: bool = False,
 # hybrid optimizer mode
 MuonWithAuxAdam: bool = False,
-layer_key_fn: Optional[Callable] = None,
 muon_adam_lr: float = 1e-4,
 adam_kwargs: Optional[dict] = None,
 ):
@@ -145,9 +138,8 @@ class Muon_adv(torch.optim.Optimizer):
 self.stochastic_rounding = stochastic_rounding
 
 self.MuonWithAuxAdam = MuonWithAuxAdam
-self.helper = None
 self.aux_adam = None
-
+
 if not self.MuonWithAuxAdam:
 super().__init__(params, muon_defaults)
 return
@@ -163,25 +155,16 @@ class Muon_adv(torch.optim.Optimizer):
 adam_defaults = self.aux_adam.defaults
 
 final_param_groups = []
-_layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
 for group in params:
-
-key = _layer_key_fn(first_param)
-optim_type = 'adam' if key == 'adam' else 'muon'
-
-new_group = group.copy()
+optim_type = group.get('optim_type', 'muon')
 defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
 
+new_group = group.copy()
 for key, value in defaults_to_use.items():
 new_group.setdefault(key, value)
-
 final_param_groups.append(new_group)
 
 super().__init__(final_param_groups, muon_defaults)
-
-# Now that self is initialized, create the helper
-self.helper = MuonAdamHelper(self, layer_key_fn)
 
 
 @property
@@ -209,7 +192,7 @@ class Muon_adv(torch.optim.Optimizer):
 @torch.no_grad()
 def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
 if self.MuonWithAuxAdam:
-optim_type =
+optim_type = group.get('optim_type')
 if optim_type == 'adam':
 # Delegate to the AdamW_adv optimizer's logic.
 # We need to temporarily "lend" our state and param_groups
adv_optm/optim/Prodigy_adv.py

@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
 logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
 every logging steps. Useful for debugging and tuning. Set to 0 to disable
 logging (default: 0).
-
+layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
 and returns a unique, hashable key representing its "layer" or "bucket".
 If `None`, parameters are bucketed by their memory ID (tensor-wise).
 (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
 tiny_spike: float = 1e-9,
 k_warmup_steps: int = 0,
 k_logging: int = 0,
-
+layer_key_fn: Optional[Callable] = None,
 ):
 if not (lr >= 0.0):
 raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
 self.fsdp_in_use = fsdp_in_use
 
 self.kourkoutas_beta = kourkoutas_beta
-self.
+self.layer_key_fn = layer_key_fn
 
 super().__init__(params, defaults)
 if self.kourkoutas_beta:
@@ -516,7 +516,10 @@ class Prodigy_adv(torch.optim.Optimizer):
 if global_d_denom > 0:
 d_hat = d_coef * global_d_numerator / global_d_denom
 if g_group.get('d_limiter', False):
-
+if g_group.get('Simplified_AdEMAMix', False):
+d_hat = min(self.d * (2 ** 0.1), d_hat)
+else:
+d_hat = min(self.d * (2 ** 0.25), d_hat)
 if self.d == g_group['d0']:
 self.d = max(self.d, d_hat)
 d_max = max(d_max, d_hat)
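The last Prodigy_adv hunk tightens the `d_limiter` cap: with `Simplified_AdEMAMix` enabled, one adjustment can grow `d` by at most a factor of 2**0.1 (≈1.07) rather than 2**0.25 (≈1.19). A standalone sketch of just that branch, with placeholder numbers in place of the optimizer's accumulated statistics:

def limit_d_hat(d: float, d_hat: float, simplified_ademamix: bool) -> float:
    """Cap the proposed step-size estimate d_hat relative to the current d,
    mirroring the new d_limiter branch in Prodigy_adv."""
    growth = 2 ** 0.1 if simplified_ademamix else 2 ** 0.25
    return min(d * growth, d_hat)

# Example: a proposed d_hat far above the current d gets clamped.
print(limit_d_hat(d=1e-4, d_hat=5e-4, simplified_ademamix=True))   # ~1.072e-4
print(limit_d_hat(d=1e-4, d_hat=5e-4, simplified_ademamix=False))  # ~1.189e-4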
adv_optm/optim/Simplified_AdEMAMix.py

@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
 every logging steps. Useful for debugging and tuning. Set to 0 to disable
 logging (default: 0).
-
+layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
 and returns a unique, hashable key representing its "layer" or "bucket".
 If `None`, parameters are bucketed by their memory ID (tensor-wise).
 (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 tiny_spike: float = 1e-9,
 k_warmup_steps: int = 0,
 k_logging: int = 0,
-
+layer_key_fn: Optional[Callable] = None,
 nnmf_factor: bool = False,
 ):
 if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 self.stochastic_rounding = stochastic_rounding
 self.factored = nnmf_factor
 self.kourkoutas_beta = kourkoutas_beta
-self.
+self.layer_key_fn = layer_key_fn
 super().__init__(params, defaults)
 
 if self.kourkoutas_beta:
adv_optm/util/Kourkoutas.py

@@ -32,21 +32,24 @@ class KourkoutasHelper:
 if self._layer_info_built:
 return
 
-if hasattr(self.optimizer, '
+if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
 # A custom key function was provided by the user. We will use it.
 pass
 else:
 # No key function was provided. Default to coarse, shape-based bucketing.
-self.optimizer.
+self.optimizer.layer_key_fn = lambda p: \
 (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
 else tuple(p.shape)
 # This ensures that we won't mix embeddings with tokens (1 to 10)
 # TODO find a better way to safeguard the embeddings
 
 for group in self.optimizer.param_groups:
+if not group.get('kourkoutas_beta', False):
+continue
+
 for p in group['params']:
 # The mapping is static and should not depend on the presence of a gradient.
-layer_key = self.optimizer.
+layer_key = self.optimizer.layer_key_fn(p)
 if layer_key not in self.layer_info:
 self.layer_info[layer_key] = {'params': [], 'group_ref': group}
 self.layer_info[layer_key]['params'].append(p)
@@ -94,7 +97,7 @@ class KourkoutasHelper:
 for layer_key, info in self.layer_info.items():
 params, group = info['params'], info['group_ref']
 
-if not group.get('kourkoutas_beta', False)
+if not group.get('kourkoutas_beta', False):
 continue
 
 first_param_in_layer = info['params'][0]
@@ -174,7 +177,7 @@ class KourkoutasHelper:
 """
 Accumulates the squared L2 norm of a single gradient for the next step's calculation.
 """
-layer_key = self.optimizer.
+layer_key = self.optimizer.layer_key_fn(p)
 
 if layer_key in self.layer_info:
 # Initialize the transient state for this layer if it's the first time in the step.
@@ -189,6 +192,6 @@ class KourkoutasHelper:
 """
 Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
 """
-layer_key = self.optimizer.
+layer_key = self.optimizer.layer_key_fn(p)
 # The default is the max value, which is correct for unmapped params or edge cases
 return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
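With no user-supplied `layer_key_fn`, the helper now installs the shape-based default shown in the first hunk above: parameters with the same shape share a Kourkoutas-β bucket, except small 2-D tensors whose second dimension is 768, 1280 or 4096, which get a per-tensor key (the embedding safeguard mentioned in the comments). A small sketch of the keys that default produces, using the lambda verbatim from the hunk:

import torch

# Default bucketing lambda, copied from the KourkoutasHelper hunk above.
layer_key_fn = lambda p: \
    (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
    else tuple(p.shape)

w1 = torch.zeros(4096, 4096)   # ordinary weight   -> bucketed by shape
w2 = torch.zeros(4096, 4096)   # same shape        -> same bucket as w1
tok = torch.zeros(2, 768)      # tiny 2x768 tensor -> per-tensor bucket

print(layer_key_fn(w1) == layer_key_fn(w2))   # True: shared (4096, 4096) bucket
print(layer_key_fn(tok))                      # (id(tok),): isolated bucket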
adv_optm/util/__init__.py

@@ -3,6 +3,7 @@ from .Effective_Shape import _get_effective_shape
 from .One_Bit_Boolean import _pack_bools, _unpack_bools
 from .OrthoGrad import _orthogonalize_gradient
 from .Newton_Schulz import _newton_schulz_iteration
+
 __all__ = [
 "_pack_bools", "_unpack_bools",
 "add_stochastic_",
adv_optm.egg-info/SOURCES.txt

@@ -19,7 +19,6 @@ adv_optm/optim/__init__.py
 adv_optm/util/BF16_Stochastic_Rounding.py
 adv_optm/util/Effective_Shape.py
 adv_optm/util/Kourkoutas.py
-adv_optm/util/MuonAdam_helper.py
 adv_optm/util/NNMF.py
 adv_optm/util/Newton_Schulz.py
 adv_optm/util/One_Bit_Boolean.py
adv_optm/util/MuonAdam_helper.py (deleted)

@@ -1,32 +0,0 @@
-import torch
-from torch.optim import Optimizer
-from typing import Callable, Optional
-
-class MuonAdamHelper:
-"""
-A helper class for Muon_adv to decide whether to use Muon or a delegate
-AdamW optimizer for a given parameter based on a keying function.
-"""
-def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
-if not hasattr(optimizer, 'param_groups'):
-raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
-self.optimizer = optimizer
-
-if layer_key_fn is None:
-# If no function is provided, default all parameters to 'muon'.
-self.layer_key_fn = lambda p: 'muon'
-else:
-self.layer_key_fn = layer_key_fn
-
-def get_optimizer_type(self, p: "torch.Tensor") -> str:
-"""
-Returns the designated optimizer type ('adam' or 'muon') for a parameter.
-
-The user-provided layer_key_fn should return 'adam' for parameters
-to be handled by the auxiliary AdamW optimizer. Any other return
-value is treated as 'muon'.
-"""
-key = self.layer_key_fn(p)
-if key == 'adam':
-return 'adam'
-return 'muon'