adv-optm 1.2.dev19__py3-none-any.whl → 2.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  """
@@ -41,12 +37,13 @@ class Lion_adv(torch.optim.Optimizer):
  lr: float = 1e-4,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -61,13 +58,20 @@ class Lion_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.init_step()
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
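
The new constructor wiring above drops `clip_threshold` in favor of a `compiled_optimizer` flag, eagerly builds per-parameter state via `init_step()`, and, when compilation is requested, raises the Dynamo cache limit and compiles the per-parameter step with `fullgraph=True`. A minimal usage sketch follows; the import path and the toy model are assumptions for illustration, only the constructor arguments come from this diff.

# Hypothetical usage sketch; `from adv_optm import Lion_adv` and the toy model
# are assumptions, only the constructor arguments appear in the diff above.
import torch
from adv_optm import Lion_adv

model = torch.nn.Linear(128, 64)
opt = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,
    cautious_mask=True,
    nnmf_factor=True,
    compiled_optimizer=True,   # new flag: compiles the per-parameter step at construction
)

x = torch.randn(32, 128)
model(x).sum().backward()
opt.step()
opt.zero_grad()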
@@ -80,50 +84,50 @@ class Lion_adv(torch.optim.Optimizer):
  def supports_flat_params(self) -> bool:
  return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

- state['step'] += 1
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+
+ state = self.state[p]
+
+
  beta1, beta2 = group["betas"]
- lr = group["lr"]

  if state['factored']:
  # Factored Path
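
State initialization now lives in a dedicated `__init_state` helper: factored parameters get a row-magnitude vector (`mu_m_nmf`), a column-magnitude vector (`mv_m_nmf`), and a sign matrix bit-packed into `(d2 + 7) // 8` uint8 bytes per row, while non-factored parameters fall back to a dense `exp_avg`. The sketch below illustrates that layout and one way the signs could be packed; the packing helper is hypothetical, since the package's own pack/unpack routines are not part of this diff.

import torch

# Illustration of the factored state created in __init_state above: a parameter
# reshaped to (d1, d2) keeps one magnitude vector per dimension plus a sign
# matrix bit-packed into uint8, (d2 + 7) // 8 bytes per row.
d1, d2 = 4, 10
packed_d2 = (d2 + 7) // 8                   # 10 sign bits per row fit in 2 bytes

mu = torch.zeros(d1)                        # row magnitudes    ('mu_m_nmf')
mv = torch.zeros(d2)                        # column magnitudes ('mv_m_nmf')
sign = torch.zeros((d1, packed_d2), dtype=torch.uint8)    # packed signs ('sign')

# One way the signs of a dense (d1, d2) tensor could be packed, LSB first.
# Hypothetical helper; not the package's own pack/unpack code.
m = torch.randn(d1, d2)
bits = (m > 0).to(torch.uint8)                            # 1 where positive
bits = torch.nn.functional.pad(bits, (0, packed_d2 * 8 - d2))
weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
packed = (bits.view(d1, packed_d2, 8) * weights).sum(-1, dtype=torch.uint8)
print(packed.shape)                                       # torch.Size([4, 2])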
@@ -138,16 +142,16 @@ class Lion_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t
- signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

  # Parameter update
- update_for_param = signed_update.view(p.shape).mul_(lr)
+ update = update.view(p.shape).mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
@@ -165,15 +169,15 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad > 0).to(grad.dtype)
+ mask = (update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

- update_for_param = signed_update.mul_(lr)
+ update.mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
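
Both the factored and the fallback paths above only rename `signed_update` to `update`; the cautious masking itself is unchanged: update entries whose sign disagrees with the current gradient are zeroed, and the survivors are rescaled by the inverse of the mask's mean. A standalone sketch of that step, not the package's code:

import torch

def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Zero out update entries whose sign disagrees with the gradient,
    then rescale so the overall update magnitude is roughly preserved."""
    mask = (update * grad > 0).to(grad.dtype)    # 1 where signs agree, 0 otherwise
    mask.div_(mask.mean().clamp_(min=1e-3))      # rescale; clamp avoids division by ~0
    return update * mask

# Toy example: entries 0, 2 and 3 agree in sign and are scaled by 1/0.75.
u = torch.tensor([1.0, -1.0, 1.0, -1.0])
g = torch.tensor([0.5, 0.5, 2.0, -3.0])
print(cautious_mask(u, g))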
@@ -188,11 +192,23 @@ class Lion_adv(torch.optim.Optimizer):
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update_for_param)
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+
+ del update
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group["lr"])
  else:
- p.data.add_(-update_for_param)
+ lr_tensor = torch.tensor(group["lr"], device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor)

- del update_for_param
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
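
The rewritten `step_parameter` dispatches either to the eager `__step_parameter` or, when `compiled_optimizer` is set, to a `torch.compile`d wrapper, passing the learning rate as a 0-dim tensor, presumably so that a scheduler changing `group["lr"]` does not bake a new constant into the graph and trigger recompilation. A minimal sketch of the same pattern outside the class, with a stand-in update rule:

import torch

@torch.no_grad()
def raw_step(p: torch.Tensor, lr) -> None:
    # Stand-in for a per-parameter update; lr may be a Python float or a 0-dim tensor.
    if p.grad is None:
        return
    p.sub_(p.grad.sign() * lr)

compiled_step = torch.compile(raw_step)

p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.randn(4)

raw_step(p, 1e-4)                                # eager path: plain float lr
lr_tensor = torch.tensor(1e-4, device=p.device)  # compiled path: tensor lr, so changing
compiled_step(p, lr_tensor)                      # its value does not force a recompile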
@@ -207,4 +223,4 @@ class Lion_adv(torch.optim.Optimizer):
  if p.grad is not None:
  self.step_parameter(p, group, i)

- return loss
+ return loss
@@ -41,7 +41,6 @@ class Muon_adv(torch.optim.Optimizer):
  stability. (default: 100.0)
  stochastic_rounding (bool): whether to use stochastic rounding for
  BF16 parameter updates (default: True).
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
  vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
  matrices for muon NewtonSchulz (default: False).
  vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -60,6 +59,7 @@ class Muon_adv(torch.optim.Optimizer):
  normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
  normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
  (default: 0.2)
+ normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
  accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
  dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
  cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
@@ -92,7 +92,6 @@ class Muon_adv(torch.optim.Optimizer):
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
  stochastic_rounding: bool = True,
- orthogonal_gradient: bool = False,
  vector_reshape_muon: bool = False,
  vector_reshape: bool = False,
  nnmf_factor: bool = False,
@@ -104,6 +103,7 @@ class Muon_adv(torch.optim.Optimizer):
  beta2_normuon: float = 0.95,
  normuon_eps: float = 1e-8,
  normuon_lr_scale: float = 0.2,
+ normuon_atan2: bool = False,
  # CANS
  accelerated_ns: bool = False,
  cns_a_bound: float = 1e-4,
@@ -149,13 +149,13 @@ class Muon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
- "orthogonal_gradient": orthogonal_gradient,
  'compiled_optimizer': compiled_optimizer,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  # NorMuon
  "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
  "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+ "normuon_atan2": normuon_atan2,
  # CANS
  "accelerated_ns": accelerated_ns, "cns_a_bound": cns_a_bound,
  # AdamW_adv defaults
@@ -293,10 +293,6 @@ class Muon_adv(torch.optim.Optimizer):
  nesterov = group['nesterov']
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
  alpha_grad = group['alpha_grad']
- if grad.dtype != torch.float32 and state.get('factored', False):
- grad = grad.float()
- if group.get("orthogonal_gradient"):
- grad = _orthogonalize_gradient(p, grad)

  if state['factored']: # Factored Muon

@@ -363,7 +359,11 @@ class Muon_adv(torch.optim.Optimizer):
  mean_squared_update = torch.mean(update.square(), dim=1)
  v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
  # Normalize update
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ if group['normuon_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
  # Scale learning rate
  update_norm = torch.linalg.vector_norm(update)

@@ -464,7 +464,11 @@ class Muon_adv(torch.optim.Optimizer):
  mean_squared_update = torch.mean(update.square(), dim=1)
  v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
  # Normalize update
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+ if group['normuon_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
  # Scale learning rate
  update_norm = torch.linalg.vector_norm(update)
  scaled_lr = group['normuon_lr_scale'] * lr * (p.numel()**0.5) / update_norm.add_(group['normuon_eps'])
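
Both NorMuon normalization sites gain an optional atan2 branch. The constant 1.2732395 is approximately 4/π, so `a * atan2(x, x)` equals 1: the atan2 branch matches the division branch when the update magnitude equals the running RMS, stays bounded by `a * π/2 = 2` as the denominator shrinks, and therefore needs no epsilon. A small numeric illustration (not the package's code):

import torch

a = 1.2732395                               # ~ 4/pi, so a * atan2(x, x) == 1

u = torch.tensor([1.0, 0.5, 1.0, 3.0])      # update entries
d = torch.tensor([0.0, 0.5, 1.0, 1.0])      # sqrt of the second-moment estimate

div_branch = u / (d + 1e-8)                 # original: needs eps, explodes as d -> 0
atan2_branch = torch.atan2(u, d) * a        # new option: bounded by a*pi/2 = 2

print(div_branch)     # first entry is 1e8: the division blows up when d is 0
print(atan2_branch)   # first entry is 2.0: the atan2 form saturates instead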