adv-optm 2.1.dev1__tar.gz → 2.1.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/PKG-INFO +1 -1
  2. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/__init__.py +1 -1
  3. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/AdaMuon_adv.py +8 -4
  4. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/AdamW_adv.py +3 -2
  5. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Adopt_adv.py +4 -3
  6. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +1 -0
  7. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Muon_adv.py +5 -1
  8. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Prodigy_adv.py +3 -2
  9. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +4 -8
  10. adv_optm-2.1.dev3/adv_optm/util/Kourkoutas.py +196 -0
  11. adv_optm-2.1.dev3/adv_optm/util/Muon_AuxAdam.py +194 -0
  12. adv_optm-2.1.dev3/adv_optm/util/Muon_util.py +318 -0
  13. adv_optm-2.1.dev3/adv_optm/util/OrthoGrad.py +21 -0
  14. adv_optm-2.1.dev3/adv_optm/util/__init__.py +0 -0
  15. adv_optm-2.1.dev3/adv_optm/util/factorization_util.py +105 -0
  16. adv_optm-2.1.dev3/adv_optm/util/lion_k.py +53 -0
  17. adv_optm-2.1.dev3/adv_optm/util/param_update.py +164 -0
  18. adv_optm-2.1.dev3/adv_optm/util/update_util.py +24 -0
  19. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
  20. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/SOURCES.txt +10 -1
  21. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/setup.py +1 -1
  22. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/LICENSE +0 -0
  23. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/README.md +0 -0
  24. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/Lion_adv.py +0 -0
  25. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm/optim/__init__.py +0 -0
  26. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-2.1.dev1 → adv_optm-2.1.dev3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 2.1.dev1
+ Version: 2.1.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -20,4 +20,4 @@ __all__ = [
      "AdaMuon_adv",
  ]
  
- __version__ = "2.1.dev1"
+ __version__ = "2.1.dev3"
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
  from ..util import Muon_AuxAdam
  
- A = torch.as_tensor(4 / math.pi)
+ A = 4 / math.pi
  
  class AdaMuon_adv(torch.optim.Optimizer):
      """
@@ -396,7 +396,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  del denom, vt_buf
  
  # RMS-aligned scaling
- step_scale = lr * A if group['use_atan2'] else lr
+ step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
  rms_adjustment(update, group['rms_rescaling'], step_scale)
  
  update = update.reshape(p.shape)
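
The `A = 4 / math.pi` change above (repeated in the other optimizer files below) swaps a 0-dim CPU tensor for a plain Python float, likely so the constant folds cleanly into both eager and compiled code instead of living on a fixed device. The constant itself is the RMS-matching factor for the atan2-style division these optimizers use in place of `m / (sqrt(v) + eps)`: when `m` is about the size of `sqrt(v)`, `atan2` returns roughly pi/4, and multiplying by 4/pi restores a step of roughly unit size. A minimal sketch of that rule, with hypothetical `m` and `v` tensors rather than this package's buffers:

    import math
    import torch

    A = 4 / math.pi  # rescales the bounded atan2 output back toward a unit-RMS step

    def atan2_update(m: torch.Tensor, v: torch.Tensor, lr: float) -> torch.Tensor:
        # atan2 is bounded and well-defined at v == 0, so no eps term is needed.
        update = torch.atan2(m, v.sqrt())
        return update.mul_(lr * A)

The hunk above also stops applying the factor when `normuon_variant` is set, presumably because that variant normalizes the update itself.
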
@@ -454,14 +454,18 @@ class AdaMuon_adv(torch.optim.Optimizer):
  del denom
  
  # RMS-aligned rescaling
- step_scale = lr * A if group['use_atan2'] else lr
+ step_scale = lr * A if group['use_atan2'] and not group['normuon_variant'] else lr
  rms_adjustment(update, group['rms_rescaling'], step_scale)
  
  update = update.reshape(original_shape)
  
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
  
- compiled_muon_step_parameter(state, grad, group, group['lr'], random_int_tensor)
+ if group.get('compiled_optimizer', False):
+     lr = torch.as_tensor(group['lr'])
+ else:
+     lr = group['lr']
+ compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
  
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
@@ -10,7 +10,7 @@ from ..util.update_util import _grams_update, _cautious_update
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
  
- A = torch.as_tensor(4 / math.pi)
+ A = 4 / math.pi
  
  class AdamW_adv(torch.optim.Optimizer):
      """
@@ -233,7 +233,7 @@ class AdamW_adv(torch.optim.Optimizer):
  current_step = state['step']
  if group.get('kourkoutas_beta', False):
      # Call prepare_step() once at the beginning of the step for all params
-     self.kourkoutas_helper.maybe_prepare_step(current_step)
+     self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
      # Get the dynamic beta2 calculated in prepare_step()
      beta2 = self.kourkoutas_helper.get_beta2(p, group)
  
@@ -249,6 +249,7 @@
  random_int_tensor = None
  
  if group.get('compiled_optimizer', False):
+     step_size = torch.as_tensor(step_size)
      if p.dtype == torch.bfloat16 and self.stochastic_rounding:
          # Pre-generate random tensor for stochastic rounding if needed.
          random_int_tensor = param_update._get_random_int_for_sr(p)
@@ -9,7 +9,7 @@ from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.Kourkoutas import KourkoutasHelper
  from ..util.update_util import _grams_update, _cautious_update
  
- A = torch.as_tensor(4 / math.pi)
+ A = 4 / math.pi
  
  class Adopt_adv(torch.optim.Optimizer):
      """
@@ -258,7 +258,7 @@ class Adopt_adv(torch.optim.Optimizer):
  current_step = state['step']
  if group.get('kourkoutas_beta', False):
      # Call prepare_step() once at the beginning of the step for all params
-     self.kourkoutas_helper.maybe_prepare_step(current_step)
+     self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
      # Get the dynamic beta2 calculated in prepare_step()
      beta2 = self.kourkoutas_helper.get_beta2(p, group)
  
@@ -270,14 +270,15 @@
  random_int_tensor = None
  
  if group.get('compiled_optimizer', False):
+     lr = torch.as_tensor(group['lr'])
      if p.dtype == torch.bfloat16 and self.stochastic_rounding:
          # Pre-generate random tensor for stochastic rounding if needed.
          random_int_tensor = param_update._get_random_int_for_sr(p)
      step_param_fn = self._compiled_step_parameter
  else:
+     lr = group['lr']
      step_param_fn = self._step_parameter
  
- lr = group['lr']
  
  step_param_fn(p, grad, state, group, lr, beta1, beta2, random_int_tensor)
  
@@ -226,6 +226,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
      random_int_tensor = param_update._get_random_int_for_sr(p)
      # TODO, workaround until pytorch#169634 is fixed
      d = torch.as_tensor(group['d'])
+     dlr = torch.as_tensor(group['dlr'])
      step_param_fn = self._compiled_step_parameter
  else:
      d = group['d']
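
The `torch.as_tensor` wrappers added across these hunks (for `lr`, `dlr`, and `step_size`) all serve the same workaround flagged by the `# TODO, workaround until pytorch#169634 is fixed` comment: `torch.compile` specializes a compiled graph on Python scalars, so a hyperparameter that changes between steps (a scheduled learning rate, Prodigy's growing `d`) would keep triggering recompilation, while a 0-dim tensor is traced as a graph input and new values reuse the same graph. A self-contained sketch of the difference (illustrative `sgd_step`, not this package's API):

    import torch

    @torch.compile(fullgraph=True)
    def sgd_step(p: torch.Tensor, grad: torch.Tensor, lr) -> None:
        p.sub_(lr * grad)

    p, g = torch.zeros(8), torch.ones(8)

    # Python float: the value is baked into the compiled graph, so a
    # schedule that changes lr every step would recompile repeatedly.
    sgd_step(p, g, 1e-3)

    # 0-dim tensor: the value is a runtime input; updating it reuses the graph.
    sgd_step(p, g, torch.as_tensor(1e-3))
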
@@ -399,7 +399,11 @@ class Muon_adv(torch.optim.Optimizer):
  
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
  
- compiled_muon_step_parameter(state, grad, group, group['lr'], random_int_tensor)
+ if group.get('compiled_optimizer', False):
+     lr = torch.as_tensor(group['lr'])
+ else:
+     lr = group['lr']
+ compiled_muon_step_parameter(state, grad, group, lr, random_int_tensor)
  
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
@@ -11,7 +11,7 @@ from ..util.Kourkoutas import KourkoutasHelper
  from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
  from ..util.update_util import _grams_update, _cautious_update
  
- A = torch.as_tensor(4 / math.pi)
+ A = 4 / math.pi
  
  class Prodigy_adv(torch.optim.Optimizer):
      """
@@ -327,7 +327,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  current_step = state['step']
  if group.get('kourkoutas_beta', False):
      # Call prepare_step() once at the beginning of the step for all params
-     self.kourkoutas_helper.maybe_prepare_step(current_step)
+     self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
      # Get the dynamic beta2 calculated in prepare_step()
      beta2 = self.kourkoutas_helper.get_beta2(p, group)
  else:
@@ -343,6 +343,7 @@ class Prodigy_adv(torch.optim.Optimizer):
      random_int_tensor = param_update._get_random_int_for_sr(p)
      # TODO, workaround until pytorch#169634 is fixed
      d = torch.as_tensor(group['d'])
+     dlr = torch.as_tensor(dlr)
      step_param_fn = self._compiled_step_parameter
  else:
      d = group['d']
@@ -211,7 +211,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  current_step = state['step']
  if group.get('kourkoutas_beta', False):
      # Call prepare_step() once at the beginning of the step for all params
-     self.kourkoutas_helper.maybe_prepare_step(current_step)
+     self.kourkoutas_helper.maybe_prepare_step(current_step, p.device)
      # Accumulate current grad's norm for the *next* step
      self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
      # Get the dynamic beta2 calculated in prepare_step()
@@ -244,7 +244,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
      # Pre-generate random tensor for stochastic rounding if needed.
      random_int_tensor = param_update._get_random_int_for_sr(p)
      # TODO, workaround until pytorch#169634 is fixed
-     sqrt_den_num = torch.as_tensor(sqrt_den_num)
+     lr = torch.as_tensor(lr)
      step_param_fn = self._compiled_step_parameter
  else:
      step_param_fn = self._step_parameter
@@ -289,10 +289,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
      state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
      del vt
  
-     if group['use_bias_correction']:
-         update.mul_(sqrt_den_num)
-
-     update = update.view(p.shape).mul_(lr)
+     update = update.view(p.shape).mul_(lr * sqrt_den_num)
  
  else: # Standard optimizer logic for non-factored tensors
      exp_avg_sq = state['exp_avg_sq']
@@ -308,8 +308,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  update.div_(denom)
  del denom
  
- update_scaling = lr * sqrt_den_num if group['use_bias_correction'] else lr
- update.mul_(update_scaling)
+ update.mul_(lr * sqrt_den_num)
  
  param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
 
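The two Simplified_AdEMAMix hunks above fold the bias-correction branch into a single multiply. This is only equivalent if `sqrt_den_num` is defined as 1.0 whenever `use_bias_correction` is off, which the surrounding (unchanged) code presumably guarantees; the earlier `lr = torch.as_tensor(lr)` swap moves the compile workaround from `sqrt_den_num` to `lr` accordingly. In outline:

    # before: one conditional plus an unconditional scale
    if group['use_bias_correction']:
        update.mul_(sqrt_den_num)
    update = update.view(p.shape).mul_(lr)

    # after: one statement, relying on sqrt_den_num == 1.0 when correction is off
    update = update.view(p.shape).mul_(lr * sqrt_den_num)
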
@@ -0,0 +1,196 @@
+ import torch
+ from torch.optim import Optimizer
+
+ class KourkoutasHelper:
+     """
+     A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+     """
+     def __init__(self, optimizer: Optimizer):
+         # We need a reference to the optimizer to access its param_groups and state
+         if not hasattr(optimizer, 'param_groups'):
+             raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+         self.optimizer = optimizer
+         self.layer_state = {}
+
+         self.layer_info = {}
+         self._layer_info_built = False
+         self._current_step_prepared = -1
+
+         # Store stats for external logging (e.g., TensorBoard)
+         self.last_beta2_stats = {}
+
+         # This ensures the map is complete before the first backward pass,
+         # making it compatible with fused back pass mechanisms.
+         self._build_layer_info_if_needed()
+
+     def _build_layer_info_if_needed(self):
+         """Builds a map of layers and the parameters they contain."""
+         if self._layer_info_built:
+             return
+
+         if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
+             # A custom key function was provided by the user. We will use it.
+             pass
+         else:
+             # No key function was provided. Default to coarse, shape-based bucketing.
+             self.optimizer.layer_key_fn = lambda p: \
+                 (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
+                 else tuple(p.shape)
+             # This ensures that we won't mix embeddings with tokens (1 to 10)
+             # TODO find a better way to safeguard the embeddings
+
+         for group in self.optimizer.param_groups:
+             if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
+                 continue
+
+             for p in group['params']:
+                 # The mapping is static and should not depend on the presence of a gradient.
+                 layer_key = self.optimizer.layer_key_fn(p)
+                 if layer_key not in self.layer_info:
+                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+                 self.layer_info[layer_key]['params'].append(p)
+
+         self._layer_info_built = True
+
+     def _get_or_init_layer_ema_tensor(self, layer_key, layer_params, device):
+         """
+         Retrieves the EMA tensor for this layer.
+         It handles synchronization between the internal layer_state and
+         the external optimizer.state (which is required for state_dict saving/loading).
+         """
+         # Initialize container in layer_state if missing
+         if layer_key not in self.layer_state:
+             self.layer_state[layer_key] = {
+                 'sum_sq_accumulator': torch.tensor(0.0, device=device, dtype=torch.float32)
+             }
+
+         internal_ema = self.layer_state[layer_key].get('kourkoutas_r_ema')
+
+         # Check optimizer.state for any existing state (e.g. from a loaded checkpoint)
+         # We check the first parameter in the list to see if it has state.
+         # If a checkpoint was loaded, optimizer.state[p] will contain the tensor.
+         representative_p = layer_params[0]
+         external_ema = self.optimizer.state[representative_p].get('kourkoutas_r_ema')
+
+         # Case A: Desync detected (Optimizer has state, but Internal doesn't, or they differ).
+         # This usually happens after load_state_dict(). We trust the optimizer.state.
+         if external_ema is not None and (internal_ema is None or internal_ema is not external_ema):
+             # Adopt the external tensor as our working tensor
+             self.layer_state[layer_key]['kourkoutas_r_ema'] = external_ema
+
+             # Ensure ALL params in this layer point to this exact tensor object
+             # (Fixes any fragmentation if only some params had state)
+             for p in layer_params:
+                 self.optimizer.state[p]['kourkoutas_r_ema'] = external_ema
+
+             return external_ema
+
+         # Case B: No state anywhere. Create new.
+         if internal_ema is None:
+             new_ema = torch.tensor(0.0, device=device, dtype=torch.float32)
+             self.layer_state[layer_key]['kourkoutas_r_ema'] = new_ema
+
+             # Register this tensor in optimizer.state for ALL params so it gets saved
+             for p in layer_params:
+                 self.optimizer.state[p]['kourkoutas_r_ema'] = new_ema
+
+             return new_ema
+
+         # Case C: Internal state exists and looks valid.
+         # We just need to ensure the link to optimizer.state is maintained (just in case).
+         # This is a cheap reference assignment.
+         for p in layer_params:
+             if 'kourkoutas_r_ema' not in self.optimizer.state[p]:
+                 self.optimizer.state[p]['kourkoutas_r_ema'] = internal_ema
+
+         return internal_ema
+
+     def prepare_step(self, current_step: int, device):
+         """
+         Calculates dynamic beta2 for all layers using the completed scalar accumulators
+         from the PREVIOUS step. Should be called once at the start of an optimizer step.
+         """
+         beta2_log = []
+         master_defaults = self.optimizer.defaults
+
+         for layer_key, info in self.layer_info.items():
+             group = info['group_ref']
+
+             if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
+                 continue
+
+             # Retrieve the EMA tensor. This function ensures the tensor is present
+             # in self.optimizer.state[p] for all parameters, ensuring state_dict support.
+             r_ema_tensor = self._get_or_init_layer_ema_tensor(layer_key, info['params'], device)
+
+             # Get accumulator
+             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
+             pooled_grad_norm = torch.sqrt(accumulator)
+
+             # Use group-specific K-b settings, falling back to the optimizer's master defaults.
+             # This makes the helper robust against param groups that enable kourkoutas_beta
+             # but are missing the other required hyperparameters.
+             # In hybrid optimizers like Muon_adv, the Kourkoutas-related keys in the
+             # defaults and param_groups are prefixed with 'adam_' to avoid conflicts.
+             # We must detect this case and use the correct key names.
+             prefix = 'adam_' if group.get('adam_kourkoutas_beta', False) else ''
+
+             ema_alpha = group.get(f'{prefix}ema_alpha', master_defaults[f'{prefix}ema_alpha'])
+             betas_tuple = group.get(f'{prefix}betas', master_defaults[f'{prefix}betas'])
+             beta2_max = betas_tuple[1]
+             beta2_min = group.get(f'{prefix}beta2_min', master_defaults[f'{prefix}beta2_min'])
+             tiny_spike = group.get(f'{prefix}tiny_spike', master_defaults[f'{prefix}tiny_spike'])
+             k_warmup_steps = group.get(f'{prefix}k_warmup_steps', master_defaults[f'{prefix}k_warmup_steps'])
+
+             # Update the persistent EMA tensor in-place.
+             r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
+
+             # Calculate Beta2
+             if current_step < k_warmup_steps:
+                 beta2 = beta2_max
+             else:
+                 raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
+                 sun = raw / (1.0 + raw)
+                 beta2 = beta2_max - (beta2_max - beta2_min) * sun
+
+             # Store the final calculated beta2 in the helper's transient state for this step.
+             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) and not group.get('compiled_optimizer', False) else beta2
+
+             # Reset the accumulator for the next optimizer step.
+             accumulator.zero_()
+
+             beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])
+
+         # Compute stats for TensorBoard
+         if beta2_log:
+             beta2_tensor = torch.as_tensor(beta2_log, device='cpu')
+             self.last_beta2_stats = {
+                 'mean': beta2_tensor.mean().item()
+             }
+
+     def maybe_prepare_step(self, current_step: int, device):
+         """
+         A universal guard that calls prepare_step() exactly once per training step.
+         """
+         if self._current_step_prepared < current_step:
+             self.prepare_step(current_step, device)
+             self._current_step_prepared = current_step
+
+     def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+         """
+         Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+         """
+         layer_key = self.optimizer.layer_key_fn(p)
+
+         if layer_key in self.layer_info and layer_key in self.layer_state:
+             # Accumulate for the *next* step's prepare_step call
+             self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+     def get_beta2(self, p: torch.Tensor, group: dict) -> float:
+         """
+         Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+         """
+         layer_key = self.optimizer.layer_key_fn(p)
+         # The default is the max value, which is correct for unmapped params or edge cases
+         beta2_default = group.get('betas', group.get('adam_betas'))[1] if group.get('betas', group.get('adam_betas')) else 0.999
+         return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
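
For orientation, this is the call pattern the optimizers in this release use around KourkoutasHelper, condensed from the hunks above (the enclosing optimizer class is elided):

    import torch

    helper = KourkoutasHelper(optimizer)  # builds the layer -> params map once

    def per_param_step(p, grad, state, group):
        current_step = state['step']
        if group.get('kourkoutas_beta', False):
            # Runs prepare_step() once per training step (guarded internally):
            # converts last step's pooled gradient norms into a per-layer beta2.
            helper.maybe_prepare_step(current_step, p.device)
            beta2 = helper.get_beta2(p, group)
        else:
            beta2 = group['betas'][1]
        # ... the second-moment EMA update uses this (possibly dynamic) beta2 ...
        # Feed this step's gradient into next step's beta2 computation:
        helper.accumulate_gradient_sq_norm(p, grad)

The `device` argument added to `maybe_prepare_step` in this release is what lets the helper lazily create its per-layer accumulator and EMA tensors on the right device.
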
@@ -0,0 +1,194 @@
+ import torch
+
+ import math
+
+ from ..util import param_update
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.factorization_util import _get_effective_shape, _reconstruct_state, _factorize_state
+ from ..util.update_util import _grams_update, _cautious_update
+
+ A = 4 / math.pi
+
+ @torch.no_grad()
+ def _init_auxadam_state(self, p, group):
+     state = self.state[p]
+
+     state['step'] = 0
+
+     state['factored'] = (
+         group['adam_nnmf_factor'] and
+         not (len(p.shape) == 1 and not group['vector_reshape'])
+     )
+     dtype = torch.float32 if state['factored'] else p.dtype
+     device = p.device
+
+     if state['factored']:
+         state['effective_shape'] = _get_effective_shape(p.numel())
+         d1, d2 = state['effective_shape']
+         # First moment (m)
+         if group['adam_betas'][0] > 0:
+             state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+             state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+             packed_d2 = (d2 + 7) // 8
+             state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+         if group.get('adam_use_AdEMAMix'):
+             state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+             state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+             packed_d2 = (d2 + 7) // 8
+             state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
+         # Second moment (v)
+         state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+         state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+     else:  # Fallback to standard AdamW for non-factored tensors
+         if group['adam_betas'][0] > 0:
+             state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+         if group.get('adam_use_AdEMAMix'):
+             state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
+         state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+
+ @torch.no_grad()
+ def _adam_step_parameter(self, p, grad, state, group, is_compiled, random_int_tensor):
+
+     step = state['step']
+
+     beta1_adam, beta2_adam = group['adam_betas']
+
+     if self.kourkoutas_helper:
+         # Prepare Kourkoutas-β once per optimizer step.
+         self.kourkoutas_helper.maybe_prepare_step(step, p.device)
+         # Get the dynamic beta2_adam calculated in prepare_step()
+         beta2_adam = self.kourkoutas_helper.get_beta2(p, group)
+
+     if group['adam_use_bias_correction']:
+         current_step = step + 1
+         beta1_adam, beta2_adam = group['adam_betas']
+         bias_correction1 = 1.0 - beta1_adam ** current_step
+         sqrt_bias_correction2 = (1.0 - beta2_adam ** current_step)**0.5
+     else:
+         bias_correction1 = 1.0
+         sqrt_bias_correction2 = 1.0
+
+     state['step'] += 1
+
+     step_size = group['lr'] / bias_correction1
+
+     if group.get('compiled_optimizer', False):
+         step_size = torch.as_tensor(step_size)
+
+     @torch.compile(fullgraph=True, disable=not is_compiled)
+     def compiled_muon_step_parameter(state, grad, group, step_size, sqrt_bias_correction2, random_int_tensor):
+         if grad.dtype != torch.float32 and state.get('factored', False):
+             grad = grad.float()
+         if group.get("adam_orthogonal_gradient"):
+             grad = _orthogonalize_gradient(p, grad)
+
+         if self.kourkoutas_helper:
+             # Accumulate current grad's norm for the *next* step
+             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+         if group.get('adam_use_AdEMAMix'):
+             beta3_ema = group['adam_beta3_ema']
+             alpha = group['adam_alpha']
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+             grad_reshaped = grad.view(d1, d2)
+
+             # Reconstruct momentum from previous step's factors
+             if beta1_adam > 0:
+                 mt = _reconstruct_state((state['mu_m_nmf'], state['mv_m_nmf'], state['sign'], d2), signed=True)
+
+                 # Update momentum in full-size
+                 mt.lerp_(grad_reshaped, 1.0 - beta1_adam)
+
+                 # Factorize
+                 state['mu_m_nmf'], state['mv_m_nmf'], state['sign'] = _factorize_state(mt.clone(), signed=True)
+
+                 if group.get('adam_grams_moment'):
+                     update_mt = _grams_update(mt, grad_reshaped, inplace=True)
+                 elif group.get('adam_cautious_mask'):
+                     update_mt = _cautious_update(mt, grad_reshaped, inplace=True)
+                 else:
+                     update_mt = mt
+
+             vt = _reconstruct_state((state['mu_v_nmf'], state['mv_v_nmf']), signed=False)
+             vt.mul_(beta2_adam).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2_adam)
+
+             if group.get('adam_use_AdEMAMix'):
+                 mt_slow = _reconstruct_state((state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'], d2), signed=True)
+
+                 mt_slow.lerp_(grad_reshaped, 1.0 - beta3_ema)
+
+                 if beta1_adam > 0:
+                     update = update_mt.add_(mt_slow, alpha=alpha)
+                 else:
+                     update = grad_reshaped.add(mt_slow, alpha=alpha)
+                 # Factorize
+                 state['mu_m_slow_nmf'], state['mv_m_slow_nmf'], state['sign_slow'] = _factorize_state(mt_slow, signed=True)
+                 del mt_slow
+             else:
+                 if beta1_adam > 0:
+                     update = update_mt
+                 else:
+                     update = grad_reshaped.clone()
+
+             if group['adam_use_atan2']:
+                 denom = vt.sqrt()
+                 denom.div_(sqrt_bias_correction2)
+                 update.atan2_(denom)
+             else:
+                 denom = vt.sqrt()
+                 denom.div_(sqrt_bias_correction2).add_(group['adam_eps'])
+                 update.div_(denom)
+                 del denom
+
+             # Factorize
+             state['mu_v_nmf'], state['mv_v_nmf'] = _factorize_state(vt, signed=False)
+             del vt
+
+             update_scaling = step_size * A if group['adam_use_atan2'] else step_size
+             update = update.view(p.shape).mul_(update_scaling)
+
+         else:  # Standard AdamW logic for non-factored tensors
+             if beta1_adam > 0:
+                 exp_avg = state['exp_avg']
+                 exp_avg.lerp_(grad, 1.0 - beta1_adam)
+
+                 if group.get('adam_grams_moment'):
+                     update_mt = _grams_update(exp_avg, grad)
+                 elif group.get('adam_cautious_mask'):
+                     update_mt = _cautious_update(exp_avg, grad)
+                 else:
+                     update_mt = exp_avg.clone()
+
+             if group.get('adam_use_AdEMAMix'):
+                 exp_avg_slow = state['exp_avg_slow']
+                 exp_avg_slow.lerp_(grad, 1.0 - beta3_ema)
+
+                 if beta1_adam > 0:
+                     update = update_mt.add_(exp_avg_slow, alpha=alpha)
+                 else:
+                     update = torch.add(grad, exp_avg_slow, alpha=alpha)
+             else:
+                 update = update_mt if beta1_adam > 0 else grad.clone()
+
+             exp_avg_sq = state['exp_avg_sq']
+             exp_avg_sq.mul_(beta2_adam).addcmul_(grad, grad, value=1 - beta2_adam)
+
+             if group.get('adam_use_atan2'):
+                 denom = exp_avg_sq.sqrt()
+                 denom.div_(sqrt_bias_correction2)
+                 update.atan2_(denom)
+             else:
+                 denom = exp_avg_sq.sqrt()
+                 denom.div_(sqrt_bias_correction2).add_(group['adam_eps'])
+                 update.div_(denom)
+                 del denom
+
+             update_scaling = step_size * A if group['adam_use_atan2'] else step_size
+             update.mul_(update_scaling)
+
+         param_update.apply_parameter_update(self, p, group, update, step_size, group["adam_weight_decay"], random_int_tensor=random_int_tensor)
+
+     compiled_muon_step_parameter(state, grad, group, step_size, sqrt_bias_correction2, random_int_tensor)
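
The factored branches above all follow the same reconstruct, update, refactorize cycle: full-size moments are rebuilt from rank-1 factors (plus a bit-packed sign matrix for signed moments), updated against the current gradient, then factorized again, so only O(d1 + d2) state persists between steps. `_reconstruct_state` and `_factorize_state` live in the new `factorization_util.py`, which this diff lists but does not show; the sketch below substitutes a simple Adafactor-style row/column factorization, not the package's actual method, to illustrate the round trip for the non-negative second moment:

    import torch

    def factorize(vt: torch.Tensor, eps: float = 1e-30):
        # Stand-in rank-1 factorization: vt ~ outer(row_sums, col_sums / total).
        mu = vt.sum(dim=1)                            # (d1,)
        mv = vt.sum(dim=0) / vt.sum().clamp_min(eps)  # (d2,)
        return mu, mv

    def reconstruct(mu: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
        return torch.outer(mu, mv)                    # full (d1, d2) approximation

    # One step of the cycle, as in the factored branch of _adam_step_parameter:
    d1, d2 = 64, 128
    mu, mv = torch.zeros(d1), torch.zeros(d2)         # persisted O(d1 + d2) state
    grad = torch.randn(d1, d2)

    vt = reconstruct(mu, mv)                          # rebuild second moment
    vt.mul_(0.999).addcmul_(grad, grad, value=0.001)  # EMA update at full size
    mu, mv = factorize(vt)                            # compress back to factors
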