adv-optm 1.2.dev14__py3-none-any.whl → 2.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of adv-optm might be problematic.

@@ -27,17 +27,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
  Coefficient in the expression for the estimate of d (default 1.0).
- Values such as 0.5 and 2.0 typically work as well.
+ Values such as 0.5 and 2.0 typically work as well.
  Changing this parameter is the preferred way to tune the method.
  growth_rate (float):
  prevent the D estimate from growing faster than this multiplicative rate.
@@ -47,8 +43,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  If you're using sharded parameters, this should be set to True. The optimizer
  will attempt to auto-detect this, but if you're using an implementation other
  than PyTorch's builtin version, the auto-detection won't work.
- slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
- pth entry of each tensor. For values greater than 1 this an an approximation to standard
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
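
For orientation, a minimal sketch (not taken from the package) of the slice_p idea documented above: the Prodigy statistics are computed on every pth entry of each flattened tensor, so the extra state only needs roughly numel/slice_p elements. All names below are illustrative.

import torch

slice_p = 11                                  # the documented default
p = torch.randn(1024, 1024)                   # a parameter tensor
p0_sliced = p.flatten()[::slice_p].clone()    # frozen initial values, sliced
s = torch.zeros_like(p0_sliced)               # accumulator, ~1/slice_p of the memory

grad = torch.randn_like(p)                    # stand-in gradient
grad_sliced = grad.flatten()[::slice_p]
# Numerator-style term: inner product of the sliced gradient with (x0 - x_t)
numerator_term = torch.dot(grad_sliced, p0_sliced - p.flatten()[::slice_p])
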
@@ -64,11 +60,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  lr: float = 1,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = False,
  # prodigy parameters
  beta3: float = None,
@@ -80,6 +75,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  slice_p: int = 11,
  prodigy_steps: int = 0,
  d_limiter: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -94,21 +91,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
  beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
- growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
+ growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, slice_p=slice_p,
  fsdp_in_use=fsdp_in_use,
  prodigy_steps=prodigy_steps,
  d_limiter=d_limiter,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
- # Global state for accumulating metrics across parameter updates within a single step.
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ self.device = self.param_groups[0]['params'][0].device
+
+ self.global_step = 0
  self.init_step()

+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=False, dynamic=False) #FIXME
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
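
A hedged usage sketch for the new compiled_optimizer flag added in this release: when it is enabled, the constructor above raises torch._dynamo's cache_size_limit and wraps the inner per-parameter step in torch.compile. The import path adv_optm is assumed from the package name; everything else follows the constructor signature shown in this hunk. Requires PyTorch 2.x for torch.compile.

import torch
from adv_optm import Lion_Prodigy_adv   # import path assumed from the package name

model = torch.nn.Linear(128, 128)
opt = Lion_Prodigy_adv(
    model.parameters(),
    weight_decay=0.0,
    compiled_optimizer=True,   # new flag: compiles the per-parameter update
)

x = torch.randn(32, 128)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
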
@@ -124,14 +128,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  def init_step(self):
  """Resets accumulators and calculates dlr for the upcoming step."""
  self.d_denom = 0.0
-
+
  g_group = self.param_groups[0]
  self.beta1, self.beta2 = g_group['betas']
  self.beta3 = g_group['beta3']
  if self.beta3 is None:
  self.beta3 = math.sqrt(self.beta2)
-
- k = g_group['k']
+
  self.d = g_group['d']
  lr = g_group['lr']

@@ -139,38 +142,21 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
-
- if hasattr(p, "_fsdp_flattened"):
- self.fsdp_in_use = True
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  slice_p = group['slice_p']
@@ -185,13 +171,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+
  if state['factored']:
  # Factored Path
  d1, d2 = state['effective_shape']
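
The factored path above stores the momentum's sign matrix bit-packed: each uint8 holds 8 signs, hence packed_d2 = (d2 + 7) // 8. The helpers below are an illustrative sketch of that storage layout, not the package's internal routines.

import torch

_BIT_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_signs(sign_bool: torch.Tensor) -> torch.Tensor:
    # sign_bool: (d1, d2) bool, True where the stored sign is positive
    d1, d2 = sign_bool.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=sign_bool.device)
    padded[:, :d2] = sign_bool
    bits = padded.view(d1, packed_d2, 8).to(torch.uint8)
    weights = _BIT_WEIGHTS.to(sign_bool.device)
    return (bits * weights).sum(dim=-1).to(torch.uint8)      # (d1, packed_d2) uint8

def unpack_signs(packed: torch.Tensor, d2: int) -> torch.Tensor:
    weights = _BIT_WEIGHTS.to(packed.device)
    bits = packed.unsqueeze(-1).bitwise_and(weights).ne(0)   # (d1, packed_d2, 8) bool
    return bits.reshape(packed.shape[0], -1)[:, :d2]

signs = torch.randn(64, 100) > 0
assert torch.equal(unpack_signs(pack_signs(signs), 100), signs)
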
@@ -205,7 +206,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1-self.beta1)).sign_()
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -214,10 +215,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  del mask

  # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
- update_for_param = signed_update.view(p.shape).mul(self.dlr)
+ update_for_param = signed_update.view(p.shape).mul(dlr)

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=d * (1 - self.beta2))
  del grad_reshaped

  # Compress new momentum m_t and store factors
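
For reference, the update performed in this hunk written as plain tensor operations, with d as Prodigy's distance estimate and dlr the adapted step size now passed into __step_parameter. This restates the code above and omits the NNMF reshaping, stochastic rounding and weight decay paths shown elsewhere in the diff.

import torch

def lion_prodigy_update(p, grad, exp_avg, d, dlr, beta1=0.9, beta2=0.99, cautious=False):
    # c_t = sign(beta1 * m_{t-1} + d * (1 - beta1) * g_t)
    signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=d * (1 - beta1)).sign_()
    if cautious:
        # keep only components whose sign agrees with the gradient, then renormalize
        mask = (signed_update * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        signed_update.mul_(mask)
    # p_t = p_{t-1} - dlr * c_t
    p.add_(signed_update, alpha=-dlr)
    # m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
    exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))

p = torch.randn(256, 256)
m = torch.zeros_like(p)
g = torch.randn_like(p)
lion_prodigy_update(p, g, m, d=1e-6, dlr=1e-6, cautious=True)
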
@@ -232,7 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=self.d * (1-self.beta1)).sign_()
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad > 0).to(grad.dtype)
@@ -240,41 +241,18 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  signed_update.mul_(mask)
  del mask

- update_for_param = signed_update.mul(self.dlr)
-
- # Update momentum
- exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
-
- prodigy_steps = group['prodigy_steps']
- if prodigy_steps <= 0 or group['k'] < prodigy_steps:
- # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
-
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
-
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ update_for_param = signed_update.mul(dlr)

- del s, p0, grad_flat, p_flat, alpha
- else:
- # Free memory if prodigy_steps is reached
- if 's' in state:
- del state['s']
- if 'p0' in state:
- del state['p0']
+ # Update momentum
+ exp_avg.mul_(self.beta2).add_(grad, alpha=d * (1 - self.beta2))

  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, p.data,
- alpha=-group["weight_decay"] * self.dlr)
+ alpha=-group["weight_decay"] * dlr)
  else:
  p.data.add_(
- p.data, alpha=-group["weight_decay"] * self.dlr
+ p.data, alpha=-group["weight_decay"] * dlr
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
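
The block removed above accumulated Prodigy's adaptation statistics inline; this hunk does not show where (or whether) that accumulation now lives, only that __step_parameter no longer does it and that the accumulators become tensors (see step_parameter below). As a reference for what those statistics are, a sketch restating the removed logic with the same names:

import torch

def accumulate_prodigy_stats(grad, p, p0, s, d, d0, dlr, beta3, slice_p,
                             d_numerator, safeguard_warmup=False):
    grad_flat = grad.flatten().float()
    p_flat = p.flatten().float()
    # numerator term: (d / d0) * dlr * <g_t, x0 - x_t>, computed on a strided slice
    d_numerator = d_numerator + (d / d0) * dlr * torch.dot(
        grad_flat[::slice_p], p0 - p_flat[::slice_p])
    # denominator term: L1 norm of the EMA-style accumulator s
    alpha = (d / d0) * d if safeguard_warmup else (d / d0) * dlr
    s.mul_(beta3).add_(grad_flat[::slice_p], alpha=alpha)
    return d_numerator, s.abs().sum()

p = torch.randn(512, 512)
p0 = p.flatten()[::11].clone()
s = torch.zeros_like(p0)
num, den = accumulate_prodigy_stats(torch.randn_like(p), p, p0, s,
                                    d=1e-6, d0=1e-6, dlr=1e-6, beta3=0.97,
                                    slice_p=11, d_numerator=0.0)
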
@@ -284,6 +262,29 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  del update_for_param

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if hasattr(p, "_fsdp_flattened"):
+ self.fsdp_in_use = True
+
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, self.d, self.dlr)
+ else:
+ d_tensor = torch.tensor(self.d, device=p.device)
+ dlr_tensor = torch.tensor(self.dlr, device=p.device)
+ self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
  """Performs a single optimization step."""
@@ -306,21 +307,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
  # Only perform d-adaptation if prodigy_steps has not been reached
- prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])

  if prodigy_active:
  d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
  if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist_tensor = torch.stack([self.d_numerator, self.d_denom])
  dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
  global_d_numerator = dist_tensor[0].item()
  global_d_denom = dist_tensor[1].item()
  else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
+ global_d_numerator = self.d_numerator.item()
+ global_d_denom = self.d_denom.item()

  d_hat = self.d
  if global_d_denom > 0:
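
With d_numerator and d_denom now held as tensors, the FSDP branch can torch.stack them and all_reduce in a single call. The rest of the d update is not visible in this hunk; for orientation, the reference Prodigy implementation derives the new d roughly as below (the package may differ, e.g. via the d_limiter option above, so treat this as an assumption):

def reference_prodigy_d_update(d, d0, d_max, numerator, denom,
                               d_coef=1.0, growth_rate=float('inf')):
    # reference-Prodigy-style logic, shown for orientation only
    if denom > 0:
        d_hat = d_coef * numerator / denom
        if d == d0:
            d = max(d, d_hat)
        d_max = max(d_max, d_hat)
        d = min(d_max, d * growth_rate)
    return d, d_max

d, d_max = reference_prodigy_d_update(d=1e-6, d0=1e-6, d_max=1e-6,
                                      numerator=3.2e-4, denom=1.5e2)
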
@@ -337,5 +336,4 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  group['d'] = self.d
  group['d_max'] = d_max
  # Increment step counter for all groups, regardless of whether d was updated
- for group in self.param_groups:
- group['k'] += 1
+ self.global_step += 1
@@ -27,10 +27,6 @@ class Lion_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  """
@@ -41,11 +37,10 @@ class Lion_adv(torch.optim.Optimizer):
  lr: float = 1e-4,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = True,
  ):
  if not lr > 0.0:
@@ -61,7 +56,6 @@ class Lion_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -80,48 +74,49 @@ class Lion_adv(torch.optim.Optimizer):
  def supports_flat_params(self) -> bool:
  return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

- state['step'] += 1
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+
+ state = self.state[p]
+
+
  beta1, beta2 = group["betas"]
  lr = group["lr"]

@@ -138,16 +133,16 @@ class Lion_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t
- signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
+ mask = (update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

  # Parameter update
- update_for_param = signed_update.view(p.shape).mul_(lr)
+ update = update.view(p.shape).mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
@@ -165,15 +160,15 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()
+ update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

  if self.cautious_mask:
- mask = (signed_update * grad > 0).to(grad.dtype)
+ mask = (update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
- signed_update.mul_(mask)
+ update.mul_(mask)
  del mask

- update_for_param = signed_update.mul_(lr)
+ update.mul_(lr)

  # Standard Lion momentum update
  exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
@@ -188,11 +183,11 @@ class Lion_adv(torch.optim.Optimizer):
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, -update_for_param)
+ add_stochastic_(p.data, -update)
  else:
- p.data.add_(-update_for_param)
+ p.data.add_(-update)

- del update_for_param
+ del update

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
@@ -207,4 +202,4 @@ class Lion_adv(torch.optim.Optimizer):
  if p.grad is not None:
  self.step_parameter(p, group, i)

- return loss
+ return loss
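
Finally, a hedged end-to-end sketch of using Lion_adv as it stands after these changes. The import path adv_optm is assumed from the package name; argument names come from the constructor hunk above, and whether init_step() must be called explicitly is not visible in this diff.

import torch
from adv_optm import Lion_adv   # import path assumed from the package name

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
opt = Lion_adv(
    model.parameters(),
    lr=1e-4,              # constructor default shown above
    betas=(0.9, 0.99),
    cautious_mask=True,   # optional cautious masking (see docstring hunk)
    nnmf_factor=True,     # factored (compressed) momentum path
)

for _ in range(3):
    x = torch.randn(16, 64)
    loss = model(x).square().mean()
    loss.backward()
    opt.step()            # iterates parameters and calls step_parameter internally
    opt.zero_grad()
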