adv-optm 1.2.dev8.tar.gz → 1.2.dev10.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/AdaMuon_adv.py +12 -24
  4. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/AdamW_adv.py +4 -4
  5. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Adopt_adv.py +3 -3
  6. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Muon_adv.py +5 -22
  7. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Prodigy_adv.py +7 -4
  8. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
  9. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Kourkoutas.py +9 -6
  10. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/__init__.py +1 -0
  11. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/PKG-INFO +1 -1
  12. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/SOURCES.txt +0 -1
  13. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/setup.py +1 -1
  14. adv_optm-1.2.dev8/adv_optm/util/MuonAdam_helper.py +0 -32
  15. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/LICENSE +0 -0
  16. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/README.md +0 -0
  17. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  18. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/Lion_adv.py +0 -0
  19. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/optim/__init__.py +0 -0
  20. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  21. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Effective_Shape.py +0 -0
  22. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/NNMF.py +0 -0
  23. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/Newton_Schulz.py +0 -0
  24. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/One_Bit_Boolean.py +0 -0
  25. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm/util/OrthoGrad.py +0 -0
  26. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-1.2.dev8 → adv_optm-1.2.dev10}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev8
+ Version: 1.2.dev10
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

adv_optm/__init__.py
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "1.2.dev8"
+ __version__ = "1.2.dev10"

adv_optm/optim/AdaMuon_adv.py
@@ -1,8 +1,7 @@
  import torch
- from typing import Optional, Callable
+ from typing import Optional

  from .AdamW_adv import AdamW_adv
- from ..util.MuonAdam_helper import MuonAdamHelper

  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -73,10 +72,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
  Parameters designated by `layer_key_fn` will be optimized with
  AdamW_adv instead of Muon. (default: False)
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
- and returns a key. If the key is 'adam', the parameter is handled by
- the auxiliary AdamW optimizer. All other keys are handled by Muon.
- Only used when `MuonWithAuxAdam` is True. (default: None)
  adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
  to the auxiliary AdamW_adv optimizer. Only used when
  `MuonWithAuxAdam` is True. (default: None)
@@ -106,7 +101,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  nnmf_factor: bool = False,
  # hybrid optimizer mode
  MuonWithAuxAdam: bool = False,
- layer_key_fn: Optional[Callable] = None,
  muon_adam_lr: float = 1e-4,
  adam_kwargs: Optional[dict] = None,
  ):
@@ -132,10 +126,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  }
  self.stochastic_rounding = stochastic_rounding
+
  self.MuonWithAuxAdam = MuonWithAuxAdam
- self.helper = None
  self.aux_adam = None
-
+
  if not self.MuonWithAuxAdam:
  super().__init__(params, muon_defaults)
  return
@@ -151,26 +145,17 @@ class AdaMuon_adv(torch.optim.Optimizer):
  adam_defaults = self.aux_adam.defaults

  final_param_groups = []
- _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
  for group in params:
- # All params in a group are of the same type
- first_param = group['params'][0]
- key = _layer_key_fn(first_param)
- optim_type = 'adam' if key == 'adam' else 'muon'
-
- new_group = group.copy()
+ optim_type = group.get('optim_type', 'muon')
  defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults

+ new_group = group.copy()
  for key, value in defaults_to_use.items():
  new_group.setdefault(key, value)
  final_param_groups.append(new_group)

  super().__init__(final_param_groups, muon_defaults)

- # Now that self is initialized, create the helper
- self.helper = MuonAdamHelper(self, layer_key_fn)
-

  @property
  def supports_fused_back_pass(self):
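
With MuonAdamHelper gone, hybrid routing is now declared on the param groups themselves: a group whose `optim_type` is 'adam' is delegated to the auxiliary AdamW_adv, and any other value (or a missing key) falls back to Muon. A minimal usage sketch, assuming AdaMuon_adv is importable from the package root and leaving other hyperparameters at their defaults; the model and group split are illustrative:

    import torch
    from adv_optm import AdaMuon_adv  # assumed import path

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))

    # Send 2-D weight matrices to Muon and everything else to the auxiliary AdamW.
    param_groups = [
        {"params": [p for p in model.parameters() if p.ndim == 2], "optim_type": "muon"},
        {"params": [p for p in model.parameters() if p.ndim != 2], "optim_type": "adam"},
    ]

    opt = AdaMuon_adv(param_groups, MuonWithAuxAdam=True, muon_adam_lr=1e-4)
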
@@ -197,13 +182,14 @@ class AdaMuon_adv(torch.optim.Optimizer):
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
  if self.MuonWithAuxAdam:
- optim_type = self.helper.get_optimizer_type(p)
+ optim_type = group.get('optim_type')
  if optim_type == 'adam':
  # Delegate to the AdamW_adv optimizer's logic.
  # We need to temporarily "lend" our state and param_groups
+ # to the delegate so it has the full context to work with,
+ # especially for features like Kourkoutas-beta.
  self.aux_adam.state = self.state
  self.aux_adam.param_groups = self.param_groups
-
  self.aux_adam.step_parameter(p, group, i)
  return

@@ -327,7 +313,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  # RMS-aligned rescaling
  rms_target = group['rms_target']
  num_elements = update.numel()
- scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+ # Add eps to prevent division by zero
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])

  update.mul_(scaling_factor)
  update = update.view(p.shape).mul_(group['lr'])
@@ -422,7 +409,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
  # RMS-aligned rescaling
  rms_target = group['rms_target']
  num_elements = update.numel()
- scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+ # Add eps to prevent division by zero
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm() + group['eps'])

  update.mul_(scaling_factor)
  del num_elements, scaling_factor
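
The two RMS-aligned rescaling sites above now add the group's eps to the denominator. The factor being applied is effectively

    scaling_factor = rms_target * sqrt(update.numel()) / (update.norm() + eps)

which drives the root-mean-square of the rescaled update toward rms_target while avoiding a NaN when the orthogonalized update happens to be all zeros.
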

adv_optm/optim/AdamW_adv.py
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -105,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_kb_fn: Optional[Callable] = None,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  _is_delegate: bool = False,
  ):
@@ -137,7 +137,7 @@ class AdamW_adv(torch.optim.Optimizer):
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_kb_fn = layer_key_kb_fn
+ self.layer_key_fn = layer_key_fn
  if not _is_delegate:
  super().__init__(params, defaults)
  else:
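
The Kourkoutas-β bucketing hook is renamed from `layer_key_kb_fn` to `layer_key_fn` here, and the same rename is applied to Adopt_adv, Prodigy_adv and Simplified_AdEMAMix below. A hedged constructor sketch using the new keyword; the bucketing rule, model, and import path are illustrative:

    import torch
    from adv_optm import AdamW_adv  # assumed import path

    model = torch.nn.Linear(128, 128)

    opt = AdamW_adv(
        model.parameters(),
        kourkoutas_beta=True,                   # enable per-layer dynamic beta2
        layer_key_fn=lambda p: tuple(p.shape),  # bucket parameters by shape (key must be hashable)
    )
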
@@ -244,6 +244,7 @@

  if state['factored']:
  d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)

  # Reconstruct momentum from previous step's factors
  if beta1 > 0:
@@ -253,7 +254,6 @@
  torch.where(unpacked_sign, mt, -mt, out=mt)
  del unpacked_sign
  # Update momentum in full-size
- grad_reshaped = grad.view(d1, d2)
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
  if self.grams_moment:
  mt.copy_(grad_reshaped.sign() * mt.abs())

adv_optm/optim/Adopt_adv.py
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_kb_fn: Optional[Callable] = None,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -166,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_kb_fn = layer_key_kb_fn
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

  if self.kourkoutas_beta:

adv_optm/optim/Muon_adv.py
@@ -1,8 +1,6 @@
  import torch
- from typing import Optional, Callable
-
+ from typing import Optional
  from .AdamW_adv import AdamW_adv
- from ..util.MuonAdam_helper import MuonAdamHelper

  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Newton_Schulz import _newton_schulz_iteration
@@ -74,10 +72,6 @@ class Muon_adv(torch.optim.Optimizer):
  MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
  Parameters designated by `layer_key_fn` will be optimized with
  AdamW_adv instead of Muon. (default: False)
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
- and returns a key. If the key is 'adam', the parameter is handled by
- the auxiliary AdamW optimizer. All other keys are handled by Muon.
- Only used when `MuonWithAuxAdam` is True. (default: None)
  adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
  to the auxiliary AdamW_adv optimizer. Only used when
  `MuonWithAuxAdam` is True. (default: None)
@@ -110,7 +104,6 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_atan2: bool = False,
  # hybrid optimizer mode
  MuonWithAuxAdam: bool = False,
- layer_key_fn: Optional[Callable] = None,
  muon_adam_lr: float = 1e-4,
  adam_kwargs: Optional[dict] = None,
  ):
@@ -145,9 +138,8 @@ class Muon_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding

  self.MuonWithAuxAdam = MuonWithAuxAdam
- self.helper = None
  self.aux_adam = None
-
+
  if not self.MuonWithAuxAdam:
  super().__init__(params, muon_defaults)
  return
@@ -163,25 +155,16 @@ class Muon_adv(torch.optim.Optimizer):
  adam_defaults = self.aux_adam.defaults

  final_param_groups = []
- _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
-
  for group in params:
- first_param = group['params'][0]
- key = _layer_key_fn(first_param)
- optim_type = 'adam' if key == 'adam' else 'muon'
-
- new_group = group.copy()
+ optim_type = group.get('optim_type', 'muon')
  defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults

+ new_group = group.copy()
  for key, value in defaults_to_use.items():
  new_group.setdefault(key, value)
-
  final_param_groups.append(new_group)

  super().__init__(final_param_groups, muon_defaults)
-
- # Now that self is initialized, create the helper
- self.helper = MuonAdamHelper(self, layer_key_fn)


  @property
@@ -209,7 +192,7 @@ class Muon_adv(torch.optim.Optimizer):
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
  if self.MuonWithAuxAdam:
- optim_type = self.helper.get_optimizer_type(p)
+ optim_type = group.get('optim_type')
  if optim_type == 'adam':
  # Delegate to the AdamW_adv optimizer's logic.
  # We need to temporarily "lend" our state and param_groups

adv_optm/optim/Prodigy_adv.py
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_kb_fn: Optional[Callable] = None,
+ layer_key_fn: Optional[Callable] = None,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.fsdp_in_use = fsdp_in_use

  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_kb_fn = layer_key_kb_fn
+ self.layer_key_fn = layer_key_fn

  super().__init__(params, defaults)
  if self.kourkoutas_beta:
@@ -516,7 +516,10 @@ class Prodigy_adv(torch.optim.Optimizer):
  if global_d_denom > 0:
  d_hat = d_coef * global_d_numerator / global_d_denom
  if g_group.get('d_limiter', False):
- d_hat = min(self.d * (2 ** 0.25), d_hat)
+ if g_group.get('Simplified_AdEMAMix', False):
+ d_hat = min(self.d * (2 ** 0.1), d_hat)
+ else:
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
  if self.d == g_group['d0']:
  self.d = max(self.d, d_hat)
  d_max = max(d_max, d_hat)
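
With `d_limiter` set, the growth of Prodigy's distance estimate per re-estimation is now capped more tightly when the group also runs Simplified_AdEMAMix:

    d_hat <- min(d * 2**0.1, d_hat)    # with Simplified_AdEMAMix: at most ~7% growth per estimation
    d_hat <- min(d * 2**0.25, d_hat)   # otherwise: at most ~19% growth, as before

(2**0.1 ≈ 1.072, 2**0.25 ≈ 1.189.)
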

adv_optm/optim/Simplified_AdEMAMix.py
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_kb_fn: Optional[Callable] = None,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_kb_fn = layer_key_kb_fn
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

  if self.kourkoutas_beta:

adv_optm/util/Kourkoutas.py
@@ -32,21 +32,24 @@ class KourkoutasHelper:
  if self._layer_info_built:
  return

- if hasattr(self.optimizer, 'layer_key_kb_fn') and self.optimizer.layer_key_kb_fn is not None:
+ if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
  # A custom key function was provided by the user. We will use it.
  pass
  else:
  # No key function was provided. Default to coarse, shape-based bucketing.
- self.optimizer.layer_key_kb_fn = lambda p: \
+ self.optimizer.layer_key_fn = lambda p: \
  (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
  else tuple(p.shape)
  # This ensures that we won't mix embeddings with tokens (1 to 10)
  # TODO find a better way to safeguard the embeddings

  for group in self.optimizer.param_groups:
+ if not group.get('kourkoutas_beta', False):
+ continue
+
  for p in group['params']:
  # The mapping is static and should not depend on the presence of a gradient.
- layer_key = self.optimizer.layer_key_kb_fn(p)
+ layer_key = self.optimizer.layer_key_fn(p)
  if layer_key not in self.layer_info:
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
  self.layer_info[layer_key]['params'].append(p)
@@ -94,7 +97,7 @@ class KourkoutasHelper:
  for layer_key, info in self.layer_info.items():
  params, group = info['params'], info['group_ref']

- if not group.get('kourkoutas_beta', False) and not group.get('_kourkoutas_beta', False):
+ if not group.get('kourkoutas_beta', False):
  continue

  first_param_in_layer = info['params'][0]
@@ -174,7 +177,7 @@ class KourkoutasHelper:
  """
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
  """
- layer_key = self.optimizer.layer_key_kb_fn(p)
+ layer_key = self.optimizer.layer_key_fn(p)

  if layer_key in self.layer_info:
  # Initialize the transient state for this layer if it's the first time in the step.
@@ -189,6 +192,6 @@ class KourkoutasHelper:
  """
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
  """
- layer_key = self.optimizer.layer_key_kb_fn(p)
+ layer_key = self.optimizer.layer_key_fn(p)
  # The default is the max value, which is correct for unmapped params or edge cases
  return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
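
When no `layer_key_fn` is supplied, the helper installs the shape-based default shown above and now skips param groups that do not have `kourkoutas_beta` enabled. A small sketch of what the default keying rule produces; the tensors are illustrative and `default_key` is a stand-in for the installed lambda:

    import torch

    def default_key(p):
        # Mirrors the default installed by KourkoutasHelper when no layer_key_fn is given.
        if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096}:
            return (id(p),)    # isolated bucket: the embedding/token-count guard above
        return tuple(p.shape)  # otherwise, one shared bucket per tensor shape

    default_key(torch.zeros(4096, 4096))  # -> (4096, 4096): shared with every tensor of that shape
    default_key(torch.zeros(2, 768))      # -> a unique (id(p),) key for this tensor
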

adv_optm/util/__init__.py
@@ -3,6 +3,7 @@ from .Effective_Shape import _get_effective_shape
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
  from .OrthoGrad import _orthogonalize_gradient
  from .Newton_Schulz import _newton_schulz_iteration
+
  __all__ = [
  "_pack_bools", "_unpack_bools",
  "add_stochastic_",

adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev8
+ Version: 1.2.dev10
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

adv_optm.egg-info/SOURCES.txt
@@ -19,7 +19,6 @@ adv_optm/optim/__init__.py
  adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py
  adv_optm/util/Kourkoutas.py
- adv_optm/util/MuonAdam_helper.py
  adv_optm/util/NNMF.py
  adv_optm/util/Newton_Schulz.py
  adv_optm/util/One_Bit_Boolean.py

setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.2.dev8",
+ version="1.2.dev10",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',

adv_optm/util/MuonAdam_helper.py (deleted)
@@ -1,32 +0,0 @@
- import torch
- from torch.optim import Optimizer
- from typing import Callable, Optional
-
- class MuonAdamHelper:
- """
- A helper class for Muon_adv to decide whether to use Muon or a delegate
- AdamW optimizer for a given parameter based on a keying function.
- """
- def __init__(self, optimizer: Optimizer, layer_key_fn: Optional[Callable]):
- if not hasattr(optimizer, 'param_groups'):
- raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
- self.optimizer = optimizer
-
- if layer_key_fn is None:
- # If no function is provided, default all parameters to 'muon'.
- self.layer_key_fn = lambda p: 'muon'
- else:
- self.layer_key_fn = layer_key_fn
-
- def get_optimizer_type(self, p: "torch.Tensor") -> str:
- """
- Returns the designated optimizer type ('adam' or 'muon') for a parameter.
-
- The user-provided layer_key_fn should return 'adam' for parameters
- to be handled by the auxiliary AdamW optimizer. Any other return
- value is treated as 'muon'.
- """
- key = self.layer_key_fn(p)
- if key == 'adam':
- return 'adam'
- return 'muon'