adv-optm 1.2.dev18__py3-none-any.whl → 2.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  """
@@ -41,12 +37,13 @@ class Lion_adv(torch.optim.Optimizer):
  lr: float = 1e-4,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -61,13 +58,20 @@ class Lion_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.init_step()
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
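
The new compiled_optimizer flag changes the constructor: per-parameter state is now initialized eagerly via init_step(), and when compiled_optimizer=True the constructor raises torch._dynamo's cache size limit and compiles the inner step with fullgraph=True. A hedged usage sketch; the adv_optm import path is an assumption, while the Lion_adv signature comes from this diff:

import torch
from adv_optm import Lion_adv  # assumed import path for the wheel's package

model = torch.nn.Linear(128, 128)
opt = Lion_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,
    compiled_optimizer=True,  # new in this release; defaults to False
)

# Standard training-loop usage; step_parameter dispatches to the compiled path.
loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()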
@@ -80,50 +84,50 @@ class Lion_adv(torch.optim.Optimizer):
  def supports_flat_params(self) -> bool:
  return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

- state['step'] += 1
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+
+ state = self.state[p]
+
+
  beta1, beta2 = group["betas"]
- lr = group["lr"]

  if state['factored']:
  # Factored Path
@@ -138,16 +142,16 @@ class Lion_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t
- signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

  # Parameter update
- update_for_param = signed_update.view(p.shape).mul_(lr)
+ update = update.view(p.shape).mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
@@ -165,15 +169,15 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad > 0).to(grad.dtype)
+ mask = (update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

- update_for_param = signed_update.mul_(lr)
+ update.mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
@@ -188,11 +192,23 @@ class Lion_adv(torch.optim.Optimizer):
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update_for_param)
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+
+ del update
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group["lr"])
  else:
- p.data.add_(-update_for_param)
+ lr_tensor = torch.tensor(group["lr"], device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor)

- del update_for_param
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
@@ -207,4 +223,4 @@ class Lion_adv(torch.optim.Optimizer):
  if p.grad is not None:
  self.step_parameter(p, group, i)

- return loss
+ return loss
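
The compiled path above follows a common torch.compile pattern: compile the per-parameter step once, and pass the learning rate as a 0-dim tensor instead of a Python float so a changing schedule does not force recompilation. A simplified, self-contained illustration of that pattern (not the package's own code):

import torch

@torch.no_grad()
def _step(p: torch.Tensor, grad: torch.Tensor, lr: torch.Tensor) -> None:
    # Stand-in for a per-parameter update: a plain sign step scaled by lr.
    p.sub_(lr * grad.sign())

compiled_step = torch.compile(_step, fullgraph=True)

p = torch.randn(8, 8)
grad = torch.randn_like(p)

# Changing the value of a tensor argument does not change the traced graph,
# so sweeping the learning rate reuses the same compiled code.
for lr_value in (1e-4, 5e-5, 1e-5):
    compiled_step(p, grad, torch.tensor(lr_value, device=p.device))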
@@ -41,7 +41,6 @@ class Muon_adv(torch.optim.Optimizer):
  stability. (default: 100.0)
  stochastic_rounding (bool): whether to use stochastic rounding for
  BF16 parameter updates (default: True).
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
  vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
  matrices for muon NewtonSchulz (default: False).
  vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -77,7 +76,6 @@ class Muon_adv(torch.optim.Optimizer):
  adam_beta3_ema (float): Beta3 for AdEMAMix.
  adam_alpha (float): Alpha for AdEMAMix.
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
- adam_nnmf_factor (bool): 1-bit factored for AdamW.
  """

  def __init__(
@@ -93,7 +91,6 @@ class Muon_adv(torch.optim.Optimizer):
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
  stochastic_rounding: bool = True,
- orthogonal_gradient: bool = False,
  vector_reshape_muon: bool = False,
  vector_reshape: bool = False,
  nnmf_factor: bool = False,
@@ -128,7 +125,6 @@ class Muon_adv(torch.optim.Optimizer):
  adam_ema_alpha: float = 0.95,
  adam_tiny_spike: float = 1e-9,
  adam_k_warmup_steps: int = 0,
- adam_nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -151,7 +147,6 @@ class Muon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
- "orthogonal_gradient": orthogonal_gradient,
  'compiled_optimizer': compiled_optimizer,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
@@ -170,7 +165,6 @@ class Muon_adv(torch.optim.Optimizer):
  "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
  "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
  "adam_k_warmup_steps": adam_k_warmup_steps,
- "adam_nnmf_factor":adam_nnmf_factor,
  }
  self.stochastic_rounding = stochastic_rounding
  self.compiled_optimizer = compiled_optimizer
@@ -296,10 +290,6 @@ class Muon_adv(torch.optim.Optimizer):
  nesterov = group['nesterov']
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
  alpha_grad = group['alpha_grad']
- if grad.dtype != torch.float32 and state.get('factored', False):
- grad = grad.float()
- if group.get("orthogonal_gradient"):
- grad = _orthogonalize_gradient(p, grad)

  if state['factored']: # Factored Muon