adv-optm 1.2.dev9-py3-none-any.whl → 1.2.dev11-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "AdaMuon_adv",
21
21
  ]
22
22
 
23
- __version__ = "1.2.dev9"
23
+ __version__ = "1.2.dev11"
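The version string bumped above is exposed at runtime, so a one-liner confirms which build is installed (the import name adv_optm matches the paths listed in the RECORD further down):

import adv_optm

# Prints "1.2.dev11" for the new wheel, "1.2.dev9" for the old one.
print(adv_optm.__version__)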
adv_optm/optim/AdaMuon_adv.py CHANGED
@@ -1,8 +1,7 @@
1
1
  import torch
2
- from typing import Optional, Callable
2
+ from typing import Optional
3
3
 
4
4
  from .AdamW_adv import AdamW_adv
5
- from ..util.MuonAdam_helper import MuonAdamHelper
6
5
 
7
6
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
8
7
  from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -12,7 +11,7 @@ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
12
11
 
13
12
  class AdaMuon_adv(torch.optim.Optimizer):
14
13
  """
15
- Implements the AdaMuon optimizer algorithm.
14
+ Implements an advanced AdaMuon optimizer algorithm.
16
15
 
17
16
  AdaMuon combines the geometry-aware updates of Muon with the element-wise
18
17
  adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
@@ -26,9 +25,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
26
25
  3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
27
26
  allowing for reuse of learning rate schedules.
28
27
 
29
- Can also operate in a hybrid mode, using an auxiliary AdamW
30
- optimizer for specific parameters (e.g., biases, norms, embeddings) as
31
- defined by a `layer_key_fn`.
32
28
 
33
29
  Args:
34
30
  params (iterable): iterable of parameters to optimize or dicts defining
@@ -70,16 +66,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
70
66
  (default: 128)
71
67
  nnmf_factor (bool): whether to use the factorization or disable it to use
72
68
  the uncompressed optimizer. (default: False)
73
- MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
74
- Parameters designated by `layer_key_fn` will be optimized with
75
- AdamW_adv instead of Muon. (default: False)
76
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
77
- and returns a key. If the key is 'adam', the parameter is handled by
78
- the auxiliary AdamW optimizer. All other keys are handled by Muon.
79
- Only used when `MuonWithAuxAdam` is True. (default: None)
80
- adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
81
- to the auxiliary AdamW_adv optimizer. Only used when
82
- `MuonWithAuxAdam` is True. (default: None)
83
69
  """
84
70
 
85
71
  def __init__(
@@ -104,11 +90,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
104
90
  low_rank_ortho: bool = False,
105
91
  ortho_rank: int = 128,
106
92
  nnmf_factor: bool = False,
107
- # hybrid optimizer mode
108
- MuonWithAuxAdam: bool = False,
109
- layer_key_fn: Optional[Callable] = None,
110
- muon_adam_lr: float = 1e-4,
111
- adam_kwargs: Optional[dict] = None,
112
93
  ):
113
94
  if not (lr >= 0.0):
114
95
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -120,7 +101,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
120
101
  print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
121
102
  nesterov = False
122
103
 
123
- muon_defaults = {
104
+ defaults = {
124
105
  "lr": lr, "betas": betas, "weight_decay": weight_decay,
125
106
  "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
126
107
  "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -132,44 +113,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
132
113
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
133
114
  }
134
115
  self.stochastic_rounding = stochastic_rounding
135
- self.MuonWithAuxAdam = MuonWithAuxAdam
136
- self.helper = None
137
- self.aux_adam = None
138
-
139
- if not self.MuonWithAuxAdam:
140
- super().__init__(params, muon_defaults)
141
- return
142
116
 
143
- # HYBRID OPTIMIZER LOGIC
144
- adam_kwargs = adam_kwargs or {}
145
- self.aux_adam = AdamW_adv(
146
- [],
147
- lr=muon_adam_lr,
148
- **adam_kwargs,
149
- _is_delegate=True
150
- )
151
- adam_defaults = self.aux_adam.defaults
152
-
153
- final_param_groups = []
154
- _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
155
-
156
- for group in params:
157
- # All params in a group are of the same type
158
- first_param = group['params'][0]
159
- key = _layer_key_fn(first_param)
160
- optim_type = 'adam' if key == 'adam' else 'muon'
161
-
162
- new_group = group.copy()
163
- defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
164
-
165
- for key, value in defaults_to_use.items():
166
- new_group.setdefault(key, value)
167
- final_param_groups.append(new_group)
168
-
169
- super().__init__(final_param_groups, muon_defaults)
170
-
171
- # Now that self is initialized, create the helper
172
- self.helper = MuonAdamHelper(self, layer_key_fn)
117
+ super().__init__(params, defaults)
173
118
 
174
119
 
175
120
  @property
@@ -184,29 +129,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
184
129
  def supports_flat_params(self):
185
130
  return False
186
131
 
187
- @property
188
- def kourkoutas_helper(self):
189
- """
190
- Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
191
- if it exists. This allows external access for logging K-b.
192
- """
193
- if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
194
- return self.aux_adam.kourkoutas_helper
195
- return None
196
-
197
132
  @torch.no_grad()
198
133
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
199
- if self.MuonWithAuxAdam:
200
- optim_type = self.helper.get_optimizer_type(p)
201
- if optim_type == 'adam':
202
- # Delegate to the AdamW_adv optimizer's logic.
203
- # We need to temporarily "lend" our state and param_groups
204
- self.aux_adam.state = self.state
205
- self.aux_adam.param_groups = self.param_groups
206
-
207
- self.aux_adam.step_parameter(p, group, i)
208
- return
209
-
210
134
  if p.grad is None:
211
135
  return
212
136
 
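The AdaMuon docstring above keeps its description of an RMS-aligned rescaling step that matches the update magnitude of Adam. The exact implementation is not part of this diff; the following is only a minimal sketch of the idea, reusing the rms_target and eps parameter names from the defaults dict above (the numeric default here is illustrative):

import torch

def rms_align(update: torch.Tensor, rms_target: float, eps: float = 1e-8) -> torch.Tensor:
    # Scale the orthogonalized update so its root-mean-square equals rms_target,
    # which is what lets AdaMuon reuse learning-rate schedules tuned for Adam.
    rms = update.pow(2).mean().sqrt()
    return update * (rms_target / (rms + eps))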
adv_optm/optim/AdamW_adv.py CHANGED
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
73
73
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
74
74
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
75
75
  logging (default: 0).
76
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
76
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
77
77
  and returns a unique, hashable key representing its "layer" or "bucket".
78
78
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
79
79
  (default: None)
@@ -105,9 +105,8 @@ class AdamW_adv(torch.optim.Optimizer):
105
105
  tiny_spike: float = 1e-9,
106
106
  k_warmup_steps: int = 0,
107
107
  k_logging: int = 0,
108
- layer_key_kb_fn: Optional[Callable] = None,
108
+ layer_key_fn: Optional[Callable] = None,
109
109
  nnmf_factor: bool = False,
110
- _is_delegate: bool = False,
111
110
  ):
112
111
  if not (lr >= 0.0):
113
112
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -137,12 +136,11 @@ class AdamW_adv(torch.optim.Optimizer):
137
136
  self.use_AdEMAMix = use_AdEMAMix
138
137
  self.factored = nnmf_factor
139
138
  self.kourkoutas_beta = kourkoutas_beta
140
- self.layer_key_kb_fn = layer_key_kb_fn
141
- if not _is_delegate:
142
- super().__init__(params, defaults)
143
- else:
144
- self.defaults = defaults
145
- self.kourkoutas_helper = None
139
+ self.layer_key_fn = layer_key_fn
140
+ super().__init__(params, defaults)
141
+
142
+ if self.kourkoutas_beta:
143
+ self.kourkoutas_helper = KourkoutasHelper(self)
146
144
 
147
145
  @property
148
146
  def supports_fused_back_pass(self):
@@ -160,8 +158,6 @@ class AdamW_adv(torch.optim.Optimizer):
160
158
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
161
159
  if p.grad is None:
162
160
  return
163
- if group.get('kourkoutas_beta', False) and self.kourkoutas_helper is None:
164
- self.kourkoutas_helper = KourkoutasHelper(self)
165
161
 
166
162
  grad = p.grad
167
163
  if grad.dtype != torch.float32 and self.factored:
@@ -244,6 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):
244
240
 
245
241
  if state['factored']:
246
242
  d1, d2 = state['effective_shape']
243
+ grad_reshaped = grad.view(d1, d2)
247
244
 
248
245
  # Reconstruct momentum from previous step's factors
249
246
  if beta1 > 0:
@@ -253,7 +250,6 @@ class AdamW_adv(torch.optim.Optimizer):
253
250
  torch.where(unpacked_sign, mt, -mt, out=mt)
254
251
  del unpacked_sign
255
252
  # Update momentum in full-size
256
- grad_reshaped = grad.view(d1, d2)
257
253
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
258
254
  if self.grams_moment:
259
255
  mt.copy_(grad_reshaped.sign() * mt.abs())
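With the rename from layer_key_kb_fn to layer_key_fn (and the KourkoutasHelper now created directly in __init__ whenever kourkoutas_beta is set), Kourkoutas-β bucketing is configured through the new keyword. A hypothetical usage sketch: only the argument names come from the signature above; the grouping lambda, hyperparameter values, and top-level import are assumptions.

import torch
from adv_optm import AdamW_adv  # assumed top-level export

model = torch.nn.Linear(128, 64)

opt = AdamW_adv(
    model.parameters(),
    lr=1e-4,
    kourkoutas_beta=True,                   # builds the KourkoutasHelper in __init__
    layer_key_fn=lambda p: tuple(p.shape),  # bucket parameters by shape for the per-layer beta2
)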
adv_optm/optim/Adopt_adv.py CHANGED
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
91
91
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
92
92
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
93
93
  logging (default: 0).
94
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
94
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
95
95
  and returns a unique, hashable key representing its "layer" or "bucket".
96
96
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
97
97
  (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
125
125
  tiny_spike: float = 1e-9,
126
126
  k_warmup_steps: int = 0,
127
127
  k_logging: int = 0,
128
- layer_key_kb_fn: Optional[Callable] = None,
128
+ layer_key_fn: Optional[Callable] = None,
129
129
  nnmf_factor: bool = False,
130
130
  ):
131
131
  if not (lr >= 0.0):
@@ -166,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
166
166
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
167
167
  self.factored = nnmf_factor
168
168
  self.kourkoutas_beta = kourkoutas_beta
169
- self.layer_key_kb_fn = layer_key_kb_fn
169
+ self.layer_key_fn = layer_key_fn
170
170
  super().__init__(params, defaults)
171
171
 
172
172
  if self.kourkoutas_beta:
adv_optm/optim/Muon_adv.py CHANGED
@@ -1,8 +1,6 @@
1
1
  import torch
2
- from typing import Optional, Callable
3
-
2
+ from typing import Optional
4
3
  from .AdamW_adv import AdamW_adv
5
- from ..util.MuonAdam_helper import MuonAdamHelper
6
4
 
7
5
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
8
6
  from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -25,10 +23,6 @@ class Muon_adv(torch.optim.Optimizer):
25
23
  This implementation is designed for 2D parameters (e.g., linear layers) and
26
24
  can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
27
25
  flattening/reshaping them.
28
-
29
- Can also operate in a hybrid mode, using an auxiliary AdamW
30
- optimizer for specific parameters (e.g., biases, norms, embeddings) as
31
- defined by a `layer_key_fn`.
32
26
 
33
27
  Args:
34
28
  params (iterable): iterable of parameters to optimize or dicts defining
@@ -71,16 +65,6 @@ class Muon_adv(torch.optim.Optimizer):
71
65
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
72
66
  (default: 0.2)
73
67
  normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
74
- MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
75
- Parameters designated by `layer_key_fn` will be optimized with
76
- AdamW_adv instead of Muon. (default: False)
77
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
78
- and returns a key. If the key is 'adam', the parameter is handled by
79
- the auxiliary AdamW optimizer. All other keys are handled by Muon.
80
- Only used when `MuonWithAuxAdam` is True. (default: None)
81
- adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
82
- to the auxiliary AdamW_adv optimizer. Only used when
83
- `MuonWithAuxAdam` is True. (default: None)
84
68
  """
85
69
 
86
70
  def __init__(
@@ -108,11 +92,6 @@ class Muon_adv(torch.optim.Optimizer):
108
92
  normuon_eps: float = 1e-8,
109
93
  normuon_lr_scale: float = 0.2,
110
94
  normuon_atan2: bool = False,
111
- # hybrid optimizer mode
112
- MuonWithAuxAdam: bool = False,
113
- layer_key_fn: Optional[Callable] = None,
114
- muon_adam_lr: float = 1e-4,
115
- adam_kwargs: Optional[dict] = None,
116
95
  ):
117
96
  if not (lr >= 0.0):
118
97
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -128,7 +107,7 @@ class Muon_adv(torch.optim.Optimizer):
128
107
  print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
129
108
  nesterov = False
130
109
 
131
- muon_defaults = {
110
+ defaults = {
132
111
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
133
112
  "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
134
113
  "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -144,44 +123,7 @@ class Muon_adv(torch.optim.Optimizer):
144
123
  }
145
124
  self.stochastic_rounding = stochastic_rounding
146
125
 
147
- self.MuonWithAuxAdam = MuonWithAuxAdam
148
- self.helper = None
149
- self.aux_adam = None
150
-
151
- if not self.MuonWithAuxAdam:
152
- super().__init__(params, muon_defaults)
153
- return
154
-
155
- # HYBRID OPTIMIZER LOGIC
156
- adam_kwargs = adam_kwargs or {}
157
- self.aux_adam = AdamW_adv(
158
- [],
159
- lr=muon_adam_lr,
160
- **adam_kwargs,
161
- _is_delegate=True
162
- )
163
- adam_defaults = self.aux_adam.defaults
164
-
165
- final_param_groups = []
166
- _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
167
-
168
- for group in params:
169
- first_param = group['params'][0]
170
- key = _layer_key_fn(first_param)
171
- optim_type = 'adam' if key == 'adam' else 'muon'
172
-
173
- new_group = group.copy()
174
- defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
175
-
176
- for key, value in defaults_to_use.items():
177
- new_group.setdefault(key, value)
178
-
179
- final_param_groups.append(new_group)
180
-
181
- super().__init__(final_param_groups, muon_defaults)
182
-
183
- # Now that self is initialized, create the helper
184
- self.helper = MuonAdamHelper(self, layer_key_fn)
126
+ super().__init__(params, defaults)
185
127
 
186
128
 
187
129
  @property
@@ -196,30 +138,8 @@ class Muon_adv(torch.optim.Optimizer):
196
138
  def supports_flat_params(self):
197
139
  return False
198
140
 
199
- @property
200
- def kourkoutas_helper(self):
201
- """
202
- Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
203
- if it exists. This allows external access for logging K-b.
204
- """
205
- if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
206
- return self.aux_adam.kourkoutas_helper
207
- return None
208
-
209
141
  @torch.no_grad()
210
142
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
211
- if self.MuonWithAuxAdam:
212
- optim_type = self.helper.get_optimizer_type(p)
213
- if optim_type == 'adam':
214
- # Delegate to the AdamW_adv optimizer's logic.
215
- # We need to temporarily "lend" our state and param_groups
216
- # to the delegate so it has the full context to work with,
217
- # especially for features like Kourkoutas-beta.
218
- self.aux_adam.state = self.state
219
- self.aux_adam.param_groups = self.param_groups
220
- self.aux_adam.step_parameter(p, group, i)
221
- return
222
-
223
143
  if p.grad is None:
224
144
  return
225
145
 
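Both AdaMuon_adv and Muon_adv drop the MuonWithAuxAdam hybrid mode, the layer_key_fn routing, and the auxiliary AdamW delegate in this release. A plausible replacement, not prescribed by the package, is to split parameters yourself and drive a separate AdamW_adv for what the old mode routed to 'adam' (biases, norms, embeddings). This sketch assumes both classes are exported at the top level and expose the standard Optimizer.step()/zero_grad() API:

import torch
from adv_optm import Muon_adv, AdamW_adv  # assumed top-level exports

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

matrix_params = [p for p in model.parameters() if p.dim() == 2]  # 2D weights -> Muon
other_params = [p for p in model.parameters() if p.dim() != 2]   # biases/norms -> AdamW

muon_opt = Muon_adv(matrix_params, lr=1e-3)
adam_opt = AdamW_adv(other_params, lr=1e-4)

loss = model(torch.randn(8, 64)).sum()
loss.backward()
for opt in (muon_opt, adam_opt):  # one training step drives both optimizers
    opt.step()
    opt.zero_grad()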
adv_optm/optim/Prodigy_adv.py CHANGED
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
109
109
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
110
110
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
111
111
  logging (default: 0).
112
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
112
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
113
113
  and returns a unique, hashable key representing its "layer" or "bucket".
114
114
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
115
115
  (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
152
152
  tiny_spike: float = 1e-9,
153
153
  k_warmup_steps: int = 0,
154
154
  k_logging: int = 0,
155
- layer_key_kb_fn: Optional[Callable] = None,
155
+ layer_key_fn: Optional[Callable] = None,
156
156
  ):
157
157
  if not (lr >= 0.0):
158
158
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
205
205
  self.fsdp_in_use = fsdp_in_use
206
206
 
207
207
  self.kourkoutas_beta = kourkoutas_beta
208
- self.layer_key_kb_fn = layer_key_kb_fn
208
+ self.layer_key_fn = layer_key_fn
209
209
 
210
210
  super().__init__(params, defaults)
211
211
  if self.kourkoutas_beta:
@@ -516,7 +516,10 @@ class Prodigy_adv(torch.optim.Optimizer):
516
516
  if global_d_denom > 0:
517
517
  d_hat = d_coef * global_d_numerator / global_d_denom
518
518
  if g_group.get('d_limiter', False):
519
- d_hat = min(self.d * (2 ** 0.25), d_hat)
519
+ if g_group.get('Simplified_AdEMAMix', False):
520
+ d_hat = min(self.d * (2 ** 0.1), d_hat)
521
+ else:
522
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
520
523
  if self.d == g_group['d0']:
521
524
  self.d = max(self.d, d_hat)
522
525
  d_max = max(d_max, d_hat)
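The d_limiter change tightens the per-adaptation growth cap on Prodigy's step-size estimate d when Simplified_AdEMAMix is active: roughly 1.07x (2 ** 0.1) instead of roughly 1.19x (2 ** 0.25). In isolation, the clamp shown above amounts to:

def clamp_d_hat(d: float, d_hat: float, simplified_ademamix: bool) -> float:
    # Growth per adaptation is limited to ~1.07x with Simplified_AdEMAMix, ~1.19x otherwise.
    cap = 2 ** 0.1 if simplified_ademamix else 2 ** 0.25
    return min(d * cap, d_hat)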
adv_optm/optim/Simplified_AdEMAMix.py CHANGED
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
67
67
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
68
68
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
69
69
  logging (default: 0).
70
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
70
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
71
71
  and returns a unique, hashable key representing its "layer" or "bucket".
72
72
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
73
73
  (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
95
95
  tiny_spike: float = 1e-9,
96
96
  k_warmup_steps: int = 0,
97
97
  k_logging: int = 0,
98
- layer_key_kb_fn: Optional[Callable] = None,
98
+ layer_key_fn: Optional[Callable] = None,
99
99
  nnmf_factor: bool = False,
100
100
  ):
101
101
  if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
121
121
  self.stochastic_rounding = stochastic_rounding
122
122
  self.factored = nnmf_factor
123
123
  self.kourkoutas_beta = kourkoutas_beta
124
- self.layer_key_kb_fn = layer_key_kb_fn
124
+ self.layer_key_fn = layer_key_fn
125
125
  super().__init__(params, defaults)
126
126
 
127
127
  if self.kourkoutas_beta:
adv_optm/util/Kourkoutas.py CHANGED
@@ -24,57 +24,35 @@ class KourkoutasHelper:
24
24
  # making it compatible with fused back pass mechanisms.
25
25
  self._build_layer_info_if_needed()
26
26
 
27
- if self.optimizer.param_groups[0].get('k_logging', 0) > 0:
28
- self.print_layer_info()
29
-
30
27
  def _build_layer_info_if_needed(self):
31
28
  """Builds a map of layers and the parameters they contain."""
32
29
  if self._layer_info_built:
33
30
  return
34
31
 
35
- if hasattr(self.optimizer, 'layer_key_kb_fn') and self.optimizer.layer_key_kb_fn is not None:
32
+ if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
36
33
  # A custom key function was provided by the user. We will use it.
37
34
  pass
38
35
  else:
39
36
  # No key function was provided. Default to coarse, shape-based bucketing.
40
- self.optimizer.layer_key_kb_fn = lambda p: \
37
+ self.optimizer.layer_key_fn = lambda p: \
41
38
  (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
42
39
  else tuple(p.shape)
43
40
  # This ensures that we won't mix embeddings with tokens (1 to 10)
44
41
  # TODO find a better way to safeguard the embeddings
45
42
 
46
43
  for group in self.optimizer.param_groups:
44
+ if not group.get('kourkoutas_beta', False):
45
+ continue
46
+
47
47
  for p in group['params']:
48
48
  # The mapping is static and should not depend on the presence of a gradient.
49
- layer_key = self.optimizer.layer_key_kb_fn(p)
49
+ layer_key = self.optimizer.layer_key_fn(p)
50
50
  if layer_key not in self.layer_info:
51
51
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
52
52
  self.layer_info[layer_key]['params'].append(p)
53
-
54
- k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
55
- if k_logging_interval > 0:
56
- print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
57
53
 
58
54
  self._layer_info_built = True
59
55
 
60
- def print_layer_info(self):
61
- """Prints the contents of self.layer_info for debugging."""
62
- print("\n--- BEGIN self.layer_info DUMP ---")
63
- if not self.layer_info:
64
- print("Layer info is empty. Make sure the optimizer has parameters.")
65
- return
66
-
67
- for layer_key, info in self.layer_info.items():
68
- param_count = len(info['params'])
69
- first_param_details = ""
70
- if param_count > 0:
71
- p = info['params'][0]
72
- first_param_details = f" (Example param shape: {list(p.shape)}, dtype: {p.dtype})"
73
-
74
- print(f"Key: {layer_key}, Params: {param_count}{first_param_details}")
75
-
76
- print("--- END self.layer_info DUMP ---\n")
77
-
78
56
  def prepare_step(self, current_step: int):
79
57
  """
80
58
  Calculates dynamic beta2 for all layers using the completed scalar accumulators
@@ -82,9 +60,8 @@ class KourkoutasHelper:
82
60
  """
83
61
 
84
62
  beta2_log = []
85
- first_layer_key = next(iter(self.layer_info), None)
86
63
  # These are just for the sample log, initialize them
87
- sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4
64
+ sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),)*3
88
65
 
89
66
  # The optimizer that owns this helper holds the master defaults for K-b.
90
67
  # This is crucial in hybrid optimizers where some param_groups might not
@@ -94,7 +71,7 @@ class KourkoutasHelper:
94
71
  for layer_key, info in self.layer_info.items():
95
72
  params, group = info['params'], info['group_ref']
96
73
 
97
- if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
74
+ if not group.get('kourkoutas_beta', False):
98
75
  continue
99
76
 
100
77
  first_param_in_layer = info['params'][0]
@@ -121,7 +98,6 @@ class KourkoutasHelper:
121
98
  accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
122
99
 
123
100
  pooled_grad_norm = torch.sqrt(accumulator)
124
- prev_r_ema_val = r_ema_tensor.item() # for logging
125
101
 
126
102
  # Update the persistent EMA tensor in-place.
127
103
  r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
@@ -147,21 +123,9 @@ class KourkoutasHelper:
147
123
  if beta2_log:
148
124
  beta2_tensor = torch.tensor(beta2_log, device='cpu')
149
125
  self.last_beta2_stats = {
150
- 'min': beta2_tensor.min().item(),
151
- 'max': beta2_tensor.max().item(),
152
126
  'mean': beta2_tensor.mean().item(),
153
127
  }
154
128
 
155
- # Handle periodic console logging
156
- k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
157
- is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
158
- if is_logging_step and self.last_beta2_stats:
159
- if first_layer_key:
160
- print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{first_layer_key}':")
161
- print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema_tensor.item():.4e}")
162
- print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {self.layer_state[first_layer_key]['dynamic_beta2']:.4f}")
163
- print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
164
-
165
129
  def maybe_prepare_step(self, current_step: int):
166
130
  """
167
131
  A universal guard that calls prepare_step() exactly once per training step.
@@ -174,7 +138,7 @@ class KourkoutasHelper:
174
138
  """
175
139
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
176
140
  """
177
- layer_key = self.optimizer.layer_key_kb_fn(p)
141
+ layer_key = self.optimizer.layer_key_fn(p)
178
142
 
179
143
  if layer_key in self.layer_info:
180
144
  # Initialize the transient state for this layer if it's the first time in the step.
@@ -189,6 +153,6 @@ class KourkoutasHelper:
189
153
  """
190
154
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
191
155
  """
192
- layer_key = self.optimizer.layer_key_kb_fn(p)
156
+ layer_key = self.optimizer.layer_key_fn(p)
193
157
  # The default is the max value, which is correct for unmapped params or edge cases
194
158
  return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
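When no layer_key_fn is supplied, the helper falls back to the shape-based bucketing lambda kept in _build_layer_info_if_needed above. A small, self-contained demonstration of how that default groups parameters (the tensors are illustrative):

import torch

def default_key(p: torch.Tensor):
    # Mirrors the fallback in _build_layer_info_if_needed: isolate small 2D tensors
    # whose second dim looks like an embedding width, otherwise bucket by shape.
    if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096}:
        return (id(p),)
    return tuple(p.shape)

print(default_key(torch.empty(4, 768)))      # unique (id(p),) key: kept in its own bucket
print(default_key(torch.empty(1024, 1024)))  # (1024, 1024): bucketed with same-shaped weights
print(default_key(torch.empty(1024)))        # (1024,): 1D params share a bucket per shape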
adv_optm/util/__init__.py CHANGED
@@ -3,6 +3,7 @@ from .Effective_Shape import _get_effective_shape
3
3
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
4
4
  from .OrthoGrad import _orthogonalize_gradient
5
5
  from .Newton_Schulz import _newton_schulz_iteration
6
+
6
7
  __all__ = [
7
8
  "_pack_bools", "_unpack_bools",
8
9
  "add_stochastic_",
adv_optm-1.2.dev11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 1.2.dev9
3
+ Version: 1.2.dev11
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
adv_optm-1.2.dev11.dist-info/RECORD ADDED
@@ -0,0 +1,23 @@
1
+ adv_optm/__init__.py,sha256=vjm5Sc3hgTSy9qP73qVBVGp9zE2J6blsQmj6KxyI3GE,380
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=828WtdsaKXJqlZqFXE2yrsxY3Erxn-6N7CxV9jBXiaI,17880
3
+ adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
4
+ adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
5
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
+ adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
+ adv_optm/optim/Muon_adv.py,sha256=xGW9PafaIyi1noGhIgCWPwndI5bGX6kbxN-N-FQnr1U,19381
8
+ adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
9
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
10
+ adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
12
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
13
+ adv_optm/util/Kourkoutas.py,sha256=_fq2glPqKmzgWpLedfwq5EqIJAxICUK2fmUP-cdcgq0,7467
14
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
15
+ adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
16
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
17
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
18
+ adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
19
+ adv_optm-1.2.dev11.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
+ adv_optm-1.2.dev11.dist-info/METADATA,sha256=F6o4bbgIEjB9JS_9gediI_0-_rUkfsubKVtg5b4nrHE,14023
21
+ adv_optm-1.2.dev11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ adv_optm-1.2.dev11.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
+ adv_optm-1.2.dev11.dist-info/RECORD,,
adv_optm/util/MuonAdam_helper.py DELETED
@@ -1,32 +0,0 @@
1
- import torch
2
- from torch.optim import Optimizer
3
- from typing import Callable, Optional
4
-
5
- class MuonAdamHelper:
6
- """
7
- A helper class for Muon_adv to decide whether to use Muon or a delegate
8
- AdamW optimizer for a given parameter based on a keying function.
9
- """
10
- def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
11
- if not hasattr(optimizer, 'param_groups'):
12
- raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
13
- self.optimizer = optimizer
14
-
15
- if layer_key_fn is None:
16
- # If no function is provided, default all parameters to 'muon'.
17
- self.layer_key_fn = lambda p: 'muon'
18
- else:
19
- self.layer_key_fn = layer_key_fn
20
-
21
- def get_optimizer_type(self, p: "torch.Tensor") -> str:
22
- """
23
- Returns the designated optimizer type ('adam' or 'muon') for a parameter.
24
-
25
- The user-provided layer_key_fn should return 'adam' for parameters
26
- to be handled by the auxiliary AdamW optimizer. Any other return
27
- value is treated as 'muon'.
28
- """
29
- key = self.layer_key_fn(p)
30
- if key == 'adam':
31
- return 'adam'
32
- return 'muon'
adv_optm-1.2.dev9.dist-info/RECORD DELETED
@@ -1,24 +0,0 @@
1
- adv_optm/__init__.py,sha256=TzvKgGTLkK0_XANeZzhURcSO9xmtUi-H9_C7tV3rXn4,379
2
- adv_optm/optim/AdaMuon_adv.py,sha256=yr1oJV339Zv7D8n148O1FJJAgdOsH8NZDZTKlcDOyu0,21181
3
- adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
4
- adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
5
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
- adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
- adv_optm/optim/Muon_adv.py,sha256=HaF06fPKcKpVZY29_vqjWHAfivjvGntBuRyDDKj3Ozw,22784
8
- adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
9
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
10
- adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
12
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
13
- adv_optm/util/Kourkoutas.py,sha256=lObJGXmz3MqGSuu3DKqotSpZ0fuQFPE80R3zO_j3Z_Q,9707
14
- adv_optm/util/MuonAdam_helper.py,sha256=7rnNMujZVDaqo1g22QscMyPlZvIHQQSLHMED9_I8QWU,1250
15
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
16
- adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
17
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
18
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
19
- adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
20
- adv_optm-1.2.dev9.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
21
- adv_optm-1.2.dev9.dist-info/METADATA,sha256=GmAYWjZdfgvg9QbzyiV2PUNmzQFgJz8AjaY5F0x7Nv8,14022
22
- adv_optm-1.2.dev9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- adv_optm-1.2.dev9.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
24
- adv_optm-1.2.dev9.dist-info/RECORD,,