adv-optm 2.1.dev1__tar.gz → 2.1.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/PKG-INFO +1 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/__init__.py +1 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/AdaMuon_adv.py +8 -4
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/AdamW_adv.py +3 -2
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Adopt_adv.py +4 -3
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +1 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Muon_adv.py +5 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Prodigy_adv.py +3 -2
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +4 -8
- adv_optm-2.1.dev3/adv_optm/util/Kourkoutas.py +196 -0
- adv_optm-2.1.dev3/adv_optm/util/Muon_AuxAdam.py +194 -0
- adv_optm-2.1.dev3/adv_optm/util/Muon_util.py +318 -0
- adv_optm-2.1.dev3/adv_optm/util/OrthoGrad.py +21 -0
- adv_optm-2.1.dev3/adv_optm/util/__init__.py +0 -0
- adv_optm-2.1.dev3/adv_optm/util/factorization_util.py +105 -0
- adv_optm-2.1.dev3/adv_optm/util/lion_k.py +53 -0
- adv_optm-2.1.dev3/adv_optm/util/param_update.py +164 -0
- adv_optm-2.1.dev3/adv_optm/util/update_util.py +24 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/SOURCES.txt +10 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/setup.py +1 -1
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/LICENSE +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/README.md +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/setup.cfg +0 -0
|
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
|
|
|
9
9
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
10
10
|
from ..util import Muon_AuxAdam
|
|
11
11
|
|
|
12
|
-
A =
|
|
12
|
+
A = 4 / math.pi
|
|
13
13
|
|
|
14
14
|
class AdaMuon_adv(torch.optim.Optimizer):
|
|
15
15
|
"""
|
|
@@ -396,7 +396,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
396
396
|
del denom, vt_buf
|
|
397
397
|
|
|
398
398
|
# RMS-aligned scaling
|
|
399
|
-
step_scale = lr * A if group['use_atan2'] else lr
|
|
399
|
+
step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
|
|
400
400
|
rms_adjustment(update, group['rms_rescaling'], step_scale)
|
|
401
401
|
|
|
402
402
|
update = update.reshape(p.shape)
|
|
@@ -454,14 +454,18 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
454
454
|
del denom
|
|
455
455
|
|
|
456
456
|
# RMS-aligned rescaling
|
|
457
|
-
step_scale = lr * A if group['use_atan2'] else lr
|
|
457
|
+
step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
|
|
458
458
|
rms_adjustment(update, group['rms_rescaling'], step_scale)
|
|
459
459
|
|
|
460
460
|
update = update.reshape(original_shape)
|
|
461
461
|
|
|
462
462
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
463
463
|
|
|
464
|
-
|
|
464
|
+
if group.get('compiled_optimizer', False):
|
|
465
|
+
lr = torch.as_tensor(group['lr'])
|
|
466
|
+
else:
|
|
467
|
+
lr = group['lr']
|
|
468
|
+
compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
|
|
465
469
|
|
|
466
470
|
@torch.no_grad()
|
|
467
471
|
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
@@ -10,7 +10,7 @@ from ..util.update_util import _grams_update, _cautious_update
|
|
|
10
10
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
11
11
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
12
12
|
|
|
13
|
-
A =
|
|
13
|
+
A = 4 / math.pi
|
|
14
14
|
|
|
15
15
|
class AdamW_adv(torch.optim.Optimizer):
|
|
16
16
|
"""
|
|
@@ -233,7 +233,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
233
233
|
current_step = state['step']
|
|
234
234
|
if group.get('kourkoutas_beta', False):
|
|
235
235
|
# Call prepare_step() once at the beginning of the step for all params
|
|
236
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step)
|
|
236
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
237
237
|
# Get the dynamic beta2 calculated in prepare_step()
|
|
238
238
|
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
239
239
|
|
|
@@ -249,6 +249,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
249
249
|
random_int_tensor = None
|
|
250
250
|
|
|
251
251
|
if group.get('compiled_optimizer', False):
|
|
252
|
+
step_size = torch.as_tensor(step_size)
|
|
252
253
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
253
254
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
254
255
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
|
|
|
9
9
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
10
10
|
from ..util.update_util import _grams_update, _cautious_update
|
|
11
11
|
|
|
12
|
-
A =
|
|
12
|
+
A = 4 / math.pi
|
|
13
13
|
|
|
14
14
|
class Adopt_adv(torch.optim.Optimizer):
|
|
15
15
|
"""
|
|
@@ -258,7 +258,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
258
258
|
current_step = state['step']
|
|
259
259
|
if group.get('kourkoutas_beta', False):
|
|
260
260
|
# Call prepare_step() once at the beginning of the step for all params
|
|
261
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step)
|
|
261
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
262
262
|
# Get the dynamic beta2 calculated in prepare_step()
|
|
263
263
|
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
264
264
|
|
|
@@ -270,14 +270,15 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
270
270
|
random_int_tensor = None
|
|
271
271
|
|
|
272
272
|
if group.get('compiled_optimizer', False):
|
|
273
|
+
lr = torch.as_tensor(group['lr'])
|
|
273
274
|
if p.dtype == torch.bfloat16 and self.stochastic_rounding:
|
|
274
275
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
275
276
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
276
277
|
step_param_fn = self._compiled_step_parameter
|
|
277
278
|
else:
|
|
279
|
+
lr = group['lr']
|
|
278
280
|
step_param_fn = self._step_parameter
|
|
279
281
|
|
|
280
|
-
lr = group['lr']
|
|
281
282
|
|
|
282
283
|
step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor)
|
|
283
284
|
|
|
@@ -226,6 +226,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
226
226
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
227
227
|
# TODO, workaround until pytorch#169634 is fixed
|
|
228
228
|
d = torch.as_tensor(group['d'])
|
|
229
|
+
dlr = torch.as_tensor(group['dlr'])
|
|
229
230
|
step_param_fn = self._compiled_step_parameter
|
|
230
231
|
else:
|
|
231
232
|
d = group['d']
|
|
@@ -399,7 +399,11 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
399
399
|
|
|
400
400
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
401
401
|
|
|
402
|
-
|
|
402
|
+
if group.get('compiled_optimizer', False):
|
|
403
|
+
lr = torch.as_tensor(group['lr'])
|
|
404
|
+
else:
|
|
405
|
+
lr = group['lr']
|
|
406
|
+
compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
|
|
403
407
|
|
|
404
408
|
@torch.no_grad()
|
|
405
409
|
def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
|
|
@@ -11,7 +11,7 @@ from ..util.Kourkoutas import KourkoutasHelper
|
|
|
11
11
|
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
12
12
|
from ..util.update_util import _grams_update, _cautious_update
|
|
13
13
|
|
|
14
|
-
A =
|
|
14
|
+
A = 4 / math.pi
|
|
15
15
|
|
|
16
16
|
class Prodigy_adv(torch.optim.Optimizer):
|
|
17
17
|
"""
|
|
@@ -327,7 +327,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
327
327
|
current_step = state['step']
|
|
328
328
|
if group.get('kourkoutas_beta', False):
|
|
329
329
|
# Call prepare_step() once at the beginning of the step for all params
|
|
330
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step)
|
|
330
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
331
331
|
# Get the dynamic beta2 calculated in prepare_step()
|
|
332
332
|
beta2 = self.kourkoutas_helper.get_beta2(p, group)
|
|
333
333
|
else:
|
|
@@ -343,6 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
343
343
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
344
344
|
# TODO, workaround until pytorch#169634 is fixed
|
|
345
345
|
d = torch.as_tensor(group['d'])
|
|
346
|
+
dlr = torch.as_tensor(dlr)
|
|
346
347
|
step_param_fn = self._compiled_step_parameter
|
|
347
348
|
else:
|
|
348
349
|
d = group['d']
|
|
@@ -211,7 +211,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
211
211
|
current_step = state['step']
|
|
212
212
|
if group.get('kourkoutas_beta', False):
|
|
213
213
|
# Call prepare_step() once at the beginning of the step for all params
|
|
214
|
-
self.kourkoutas_helper.maybe_prepare_step(current_step)
|
|
214
|
+
self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
|
|
215
215
|
# Accumulate current grad's norm for the *next* step
|
|
216
216
|
self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
|
|
217
217
|
# Get the dynamic beta2 calculated in prepare_step()
|
|
@@ -244,7 +244,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
244
244
|
# Pre-generate random tensor for stochastic rounding if needed.
|
|
245
245
|
random_int_tensor = param_update._get_random_int_for_sr(p)
|
|
246
246
|
# TODO, workaround until pytorch#169634 is fixed
|
|
247
|
-
|
|
247
|
+
lr = torch.as_tensor(lr)
|
|
248
248
|
step_param_fn = self._compiled_step_parameter
|
|
249
249
|
else:
|
|
250
250
|
step_param_fn = self._step_parameter
|
|
@@ -289,10 +289,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
289
289
|
state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
|
|
290
290
|
del vt
|
|
291
291
|
|
|
292
|
-
|
|
293
|
-
update.mul_(sqrt_den_num)
|
|
294
|
-
|
|
295
|
-
update = update.view(p.shape).mul_(lr)
|
|
292
|
+
update = update.view(p.shape).mul_(lr * sqrt_den_num)
|
|
296
293
|
|
|
297
294
|
else: # Standard optimizer logic for non-factored tensors
|
|
298
295
|
exp_avg_sq = state['exp_avg_sq']
|
|
@@ -308,8 +305,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
|
|
|
308
305
|
update.div_(denom)
|
|
309
306
|
del denom
|
|
310
307
|
|
|
311
|
-
|
|
312
|
-
update.mul_(update_scaling)
|
|
308
|
+
update.mul_(lr * sqrt_den_num)
|
|
313
309
|
|
|
314
310
|
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
315
311
|
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.optim import Optimizer
|
|
3
|
+
|
|
4
|
+
class KourkoutasHelper:
    """
    A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.

    The helper groups parameters into "layers" (buckets) via
    ``optimizer.layer_key_fn``, accumulates the squared gradient norm of each
    bucket during a step, and turns it into a per-bucket dynamic ``beta2`` at
    the start of the following step.

    Per-bucket bookkeeping lives in ``self.layer_state``; the EMA tensor is
    additionally mirrored into ``optimizer.state[p]`` for every parameter of
    the bucket so that it is included in the optimizer's state dict.
    """

    def __init__(self, optimizer: Optimizer):
        # We need a reference to the optimizer to access its param_groups and state
        if not hasattr(optimizer, 'param_groups'):
            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
        self.optimizer = optimizer
        # layer_key -> {'sum_sq_accumulator', 'kourkoutas_r_ema', 'dynamic_beta2'}
        self.layer_state = {}
        # layer_key -> {'params': [...], 'group_ref': param_group}
        self.layer_info = {}
        self._layer_info_built = False
        self._current_step_prepared = -1

        # Store stats for external logging (e.g., TensorBoard)
        self.last_beta2_stats = {}

        # This ensures the map is complete before the first backward pass,
        # making it compatible with fused back pass mechanisms.
        self._build_layer_info_if_needed()

    @staticmethod
    def _default_layer_key(p):
        """
        Default coarse, shape-based bucketing key.

        2-D tensors with 1 to 10 rows and a width of 768, 1280 or 4096 get a
        key of their own (``(id(p),)``); every other tensor is bucketed by its
        shape. Per the original author's note this is meant to keep embeddings
        with 1 to 10 tokens from being pooled with other tensors.
        """
        # TODO find a better way to safeguard the embeddings
        if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096}:
            return (id(p),)
        return tuple(p.shape)

    @staticmethod
    def _resolve_setting(group, defaults, key):
        """
        Return ``group[key]`` when present, otherwise ``defaults[key]``.

        Unlike ``group.get(key, defaults[key])`` the fallback is only looked up
        when it is actually needed, so a group that carries the setting works
        even if the optimizer defaults do not define it.
        """
        if key in group:
            return group[key]
        return defaults[key]

    def _build_layer_info_if_needed(self):
        """Builds a map of layers and the parameters they contain."""
        if self._layer_info_built:
            return

        # Keep a user-supplied key function; otherwise install the default
        # on the optimizer, where the other methods read it from.
        if getattr(self.optimizer, 'layer_key_fn', None) is None:
            self.optimizer.layer_key_fn = self._default_layer_key

        for group in self.optimizer.param_groups:
            if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
                continue

            for p in group['params']:
                # The mapping is static and should not depend on the presence of a gradient.
                layer_key = self.optimizer.layer_key_fn(p)
                if layer_key not in self.layer_info:
                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                self.layer_info[layer_key]['params'].append(p)

        self._layer_info_built = True

    def _get_or_init_layer_ema_tensor(self, layer_key, layer_params, device):
        """
        Retrieves the EMA tensor for this layer.
        It handles synchronization between the internal layer_state and
        the external optimizer.state (which is required for state_dict saving/loading).
        """
        # Initialize container in layer_state if missing
        if layer_key not in self.layer_state:
            self.layer_state[layer_key] = {
                'sum_sq_accumulator': torch.tensor(0.0, device=device, dtype=torch.float32)
            }

        internal_ema = self.layer_state[layer_key].get('kourkoutas_r_ema')

        # Check optimizer.state for any existing state (e.g. from a loaded checkpoint)
        # We check the first parameter in the list to see if it has state.
        # If a checkpoint was loaded, optimizer.state[p] will contain the tensor.
        representative_p = layer_params[0]
        external_ema = self.optimizer.state[representative_p].get('kourkoutas_r_ema')

        # Case A: Desync detected (Optimizer has state, but Internal doesn't, or they differ).
        # This usually happens after load_state_dict(). We trust the optimizer.state.
        if external_ema is not None and (internal_ema is None or internal_ema is not external_ema):
            # Adopt the external tensor as our working tensor
            self.layer_state[layer_key]['kourkoutas_r_ema'] = external_ema

            # Ensure ALL params in this layer point to this exact tensor object
            # (Fixes any fragmentation if only some params had state)
            for p in layer_params:
                self.optimizer.state[p]['kourkoutas_r_ema'] = external_ema

            return external_ema

        # Case B: No state anywhere. Create new.
        if internal_ema is None:
            new_ema = torch.tensor(0.0, device=device, dtype=torch.float32)
            self.layer_state[layer_key]['kourkoutas_r_ema'] = new_ema

            # Register this tensor in optimizer.state for ALL params so it gets saved
            for p in layer_params:
                self.optimizer.state[p]['kourkoutas_r_ema'] = new_ema

            return new_ema

        # Case C: Internal state exists and looks valid.
        # We just need to ensure the link to optimizer.state is maintained (just in case).
        # This is a cheap reference assignment.
        for p in layer_params:
            if 'kourkoutas_r_ema' not in self.optimizer.state[p]:
                self.optimizer.state[p]['kourkoutas_r_ema'] = internal_ema

        return internal_ema

    def prepare_step(self, current_step: int, device):
        """
        Calculates dynamic beta2 for all layers using the completed scalar accumulators
        from the PREVIOUS step. Should be called once at the start of an optimizer step.

        Args:
            current_step: step counter compared against ``k_warmup_steps``.
            device: device on which missing accumulator / EMA tensors are created.
        """
        beta2_log = []
        master_defaults = self.optimizer.defaults

        for layer_key, info in self.layer_info.items():
            group = info['group_ref']

            if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
                continue

            # Retrieve the EMA tensor. This function ensures the tensor is present
            # in self.optimizer.state[p] for all parameters, ensuring state_dict support.
            r_ema_tensor = self._get_or_init_layer_ema_tensor(layer_key, info['params'], device)

            # Get accumulator
            accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
            pooled_grad_norm = torch.sqrt(accumulator)

            # Use group-specific K-b settings, falling back to the optimizer's master defaults.
            # In hybrid optimizers like Muon_adv, the Kourkoutas-related keys in the
            # defaults and param_groups are prefixed with 'adam_' to avoid conflicts.
            # We must detect this case and use the correct key names.
            prefix = 'adam_' if group.get('adam_kourkoutas_beta', False) else ''

            ema_alpha = self._resolve_setting(group, master_defaults, f'{prefix}ema_alpha')
            betas_tuple = self._resolve_setting(group, master_defaults, f'{prefix}betas')
            beta2_max = betas_tuple[1]
            beta2_min = self._resolve_setting(group, master_defaults, f'{prefix}beta2_min')
            tiny_spike = self._resolve_setting(group, master_defaults, f'{prefix}tiny_spike')
            k_warmup_steps = self._resolve_setting(group, master_defaults, f'{prefix}k_warmup_steps')

            # Update the persistent EMA tensor in-place.
            r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)

            # Calculate Beta2
            if current_step < k_warmup_steps:
                beta2 = beta2_max
            else:
                # raw >= 0, so sun lies in [0, 1) and beta2 in (beta2_min, beta2_max].
                raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
                sun = raw / (1.0 + raw)
                beta2 = beta2_max - (beta2_max - beta2_min) * sun

            # Store the final calculated beta2 in the helper's transient state for this step.
            # It is kept as a tensor only for groups flagged 'compiled_optimizer'.
            keep_as_tensor = group.get('compiled_optimizer', False)
            if isinstance(beta2, torch.Tensor) and not keep_as_tensor:
                beta2 = beta2.item()
            self.layer_state[layer_key]['dynamic_beta2'] = beta2

            # Reset the accumulator for the next optimizer step.
            accumulator.zero_()

            beta2_log.append(beta2)

        # Compute stats for TensorBoard
        if beta2_log:
            beta2_tensor = torch.as_tensor(beta2_log, device='cpu')
            self.last_beta2_stats = {
                'mean': beta2_tensor.mean().item()
            }

    def maybe_prepare_step(self, current_step: int, device):
        """
        A universal guard that calls prepare_step() exactly once per training step.
        """
        if self._current_step_prepared < current_step:
            self.prepare_step(current_step, device)
            self._current_step_prepared = current_step

    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
        """
        Accumulates the squared L2 norm of a single gradient for the next step's calculation.

        Does nothing for parameters whose layer is unmapped or whose layer
        state has not been created yet (it is created by prepare_step()).
        """
        layer_key = self.optimizer.layer_key_fn(p)

        if layer_key in self.layer_info and layer_key in self.layer_state:
            # Accumulate for the *next* step's prepare_step call
            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()

    def get_beta2(self, p: torch.Tensor, group: dict) -> float:
        """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.

        The stored value may be a tensor rather than a float for groups
        flagged 'compiled_optimizer' (see prepare_step()).
        """
        layer_key = self.optimizer.layer_key_fn(p)
        # The default is the max value, which is correct for unmapped params or edge cases
        betas = group.get('betas', group.get('adam_betas'))
        beta2_default = betas[1] if betas else 0.999
        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from ..util import param_update
|
|
6
|
+
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
7
|
+
from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
|
|
8
|
+
from ..util.update_util import _grams_update, _cautious_update
|
|
9
|
+
|
|
10
|
+
A = 4 / math.pi
|
|
11
|
+
|
|
12
|
+
@torch.no_grad()
|
|
13
|
+
def _init_auxadam_state(self, p, group):
|
|
14
|
+
state = self.state[p]
|
|
15
|
+
|
|
16
|
+
state['step'] = 0
|
|
17
|
+
|
|
18
|
+
state['factored'] = (
|
|
19
|
+
group['adam_nnmf_factor'] and
|
|
20
|
+
not (len(p.shape) == 1 and not group['vector_reshape'])
|
|
21
|
+
)
|
|
22
|
+
dtype = torch.float32 if state['factored'] else p.dtype
|
|
23
|
+
device = p.device
|
|
24
|
+
|
|
25
|
+
if state['factored']:
|
|
26
|
+
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
27
|
+
d1, d2 = state['effective_shape']
|
|
28
|
+
# First moment (m)
|
|
29
|
+
if group['adam_betas'][0] > 0:
|
|
30
|
+
state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
|
|
31
|
+
state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
32
|
+
packed_d2 = (d2 + 7) // 8
|
|
33
|
+
state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
|
|
34
|
+
if group.get('adam_use_AdEMAMix'):
|
|
35
|
+
state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
|
|
36
|
+
state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
37
|
+
packed_d2 = (d2 + 7) // 8
|
|
38
|
+
state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
|
|
39
|
+
# Second moment (v)
|
|
40
|
+
state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
|
|
41
|
+
state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
42
|
+
else: # Fallback to standard AdamW for non-factored tensors
|
|
43
|
+
if group['adam_betas'][0] > 0:
|
|
44
|
+
state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
|
|
45
|
+
if group.get('adam_use_AdEMAMix'):
|
|
46
|
+
state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
|
|
47
|
+
state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@torch.no_grad()
def _adam_step_parameter(self, p, grad, state, group, is_compiled, random_int_tensor):
    """
    Perform one auxiliary-Adam update of parameter ``p``.

    Args:
        self: the owning optimizer; ``self.kourkoutas_helper`` is consulted
            when truthy.
        p: parameter being updated.
        grad: gradient for ``p``.
        state: optimizer state of ``p`` (see ``_init_auxadam_state``).
        group: param group holding the ``adam_*`` hyperparameters.
        is_compiled: when true the inner step runs under ``torch.compile``.
        random_int_tensor: forwarded unchanged to
            ``param_update.apply_parameter_update``.
    """
    step = state['step']

    beta1_adam, beta2_adam = group['adam_betas']

    if self.kourkoutas_helper:
        # Prepare Kourkoutas-β once per optimizer step.
        self.kourkoutas_helper.maybe_prepare_step(step, p.device)
        # Get the dynamic beta2_adam calculated in prepare_step()
        beta2_adam = self.kourkoutas_helper.get_beta2(p, group)

    if group['adam_use_bias_correction']:
        current_step = step + 1
        # Bias correction uses the configured betas. They are bound to their
        # own names so the dynamic beta2 fetched above is not overwritten and
        # still drives the second-moment update below.
        # NOTE(review): confirm the configured beta2 (not the dynamic one) is
        # the intended base for the bias correction.
        bc_beta1, bc_beta2 = group['adam_betas']
        bias_correction1 = 1.0 - bc_beta1 ** current_step
        sqrt_bias_correction2 = (1.0 - bc_beta2 ** current_step)**0.5
    else:
        bias_correction1 = 1.0
        sqrt_bias_correction2 = 1.0

    state['step'] += 1

    step_size = group['lr'] / bias_correction1

    if group.get('compiled_optimizer', False):
        step_size = torch.as_tensor(step_size)

    # self, p, beta1_adam and beta2_adam are captured from the enclosing scope.
    @torch.compile(fullgraph=True, disable=not is_compiled)
    def compiled_muon_step_parameter(state, grad, group, step_size, sqrt_bias_correction2, random_int_tensor):
        if grad.dtype != torch.float32 and state.get('factored', False):
            grad = grad.float()
        if group.get("adam_orthogonal_gradient"):
            grad = _orthogonalize_gradient(p, grad)

        if self.kourkoutas_helper:
            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)

        if group.get('adam_use_AdEMAMix'):
            beta3_ema = group['adam_beta3_ema']
            alpha = group['adam_alpha']

        if state['factored']:
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)

            # Reconstruct momentum from previous step's factors
            if beta1_adam > 0:
                mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)

                # Update momentum in full-size
                mt.lerp_(grad_reshaped, 1.0 - beta1_adam)

                # Factorize a copy: mt itself is reused (and possibly modified
                # in place) as the update below.
                state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)

                if group.get('adam_grams_moment'):
                    update_mt = _grams_update(mt, grad_reshaped, inplace=True)
                elif group.get('adam_cautious_mask'):
                    update_mt = _cautious_update(mt, grad_reshaped, inplace=True)
                else:
                    update_mt = mt

            vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
            vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)

            if group.get('adam_use_AdEMAMix'):
                mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)

                mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)

                if beta1_adam > 0:
                    update = update_mt.add_(mt_slow, alpha=alpha)
                else:
                    update = grad_reshaped.add(mt_slow, alpha=alpha)
                # Factorize
                state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
                del mt_slow
            else:
                if beta1_adam > 0:
                    update = update_mt
                else:
                    update = grad_reshaped.clone()

            denom = vt.sqrt()
            denom.div_(sqrt_bias_correction2)
            if group['adam_use_atan2']:
                update.atan2_(denom)
            else:
                denom.add_(group['adam_eps'])
                update.div_(denom)
            del denom

            # Factorize
            state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
            del vt

            # Scale by A exactly when atan2 was applied above: same
            # 'adam_use_atan2' flag as the branch above and as the
            # non-factored path.
            update_scaling = step_size * A if group['adam_use_atan2'] else step_size
            update = update.view(p.shape).mul_(update_scaling)

        else:  # Standard AdamW logic for non-factored tensors
            if beta1_adam > 0:
                exp_avg = state['exp_avg']
                exp_avg.lerp_(grad, 1.0 - beta1_adam)

                if group.get('adam_grams_moment'):
                    update_mt = _grams_update(exp_avg, grad)
                elif group.get('adam_cautious_mask'):
                    update_mt = _cautious_update(exp_avg, grad)
                else:
                    # Clone so the in-place ops below leave the stored
                    # momentum untouched.
                    update_mt = exp_avg.clone()

            if group.get('adam_use_AdEMAMix'):
                exp_avg_slow = state['exp_avg_slow']
                exp_avg_slow.lerp_(grad, 1.0 - beta3_ema)

                if beta1_adam > 0:
                    update = update_mt.add_(exp_avg_slow, alpha=alpha)
                else:
                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
            else:
                update = update_mt if beta1_adam > 0 else grad.clone()

            exp_avg_sq = state['exp_avg_sq']
            exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad, value=1 - beta2_adam)

            denom = exp_avg_sq.sqrt()
            denom.div_(sqrt_bias_correction2)
            if group.get('adam_use_atan2'):
                update.atan2_(denom)
            else:
                denom.add_(group['adam_eps'])
                update.div_(denom)
            del denom

            update_scaling = step_size * A if group['adam_use_atan2'] else step_size
            update.mul_(update_scaling)

        param_update.apply_parameter_update(self, p, group, update, step_size, group["adam_weight_decay"], random_int_tensor=random_int_tensor)

    compiled_muon_step_parameter(state, grad, group, step_size, sqrt_bias_correction2, random_int_tensor)
|