adv-optm 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,341 @@
import torch
import torch.distributed as dist

import math

from typing import Callable, Optional, Tuple

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf, _unnmf
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools

class Lion_Prodigy_adv(torch.optim.Optimizer):
    """
    Implements the Lion algorithm with the SMMF factorization technique and the
    Prodigy D-Adaptation method.

    This optimizer combines the Lion update rule with the memory-saving low-rank
    compression (SMMF) technique from https://arxiv.org/abs/2412.08894 and
    Prodigy's adaptive estimation of the step size `d`.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate scale applied on top of the adapted
            step size `d` (default: 1.0).
        betas (Tuple[float, float], optional): coefficients for computing
            running averages of the update (default: (0.9, 0.99)).
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: True).
        stochastic_rounding (bool, optional): whether to use stochastic
            rounding for BF16 parameter updates (default: True).
        orthogonal_gradient (bool, optional): whether to project each gradient
            onto the subspace orthogonal to its parameter (OrthoGrad) before the
            update (default: False).
        cautious_mask (bool, optional): whether to use the cautious masking
            technique (default: False).
        clip_threshold (float, optional): per-parameter gradient-norm clipping
            threshold, as proposed in the paper `Lions and Muons: Optimization via
            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192), to make
            Lion more stable. A value of 0.0 disables clipping (default: 0.0).
        nnmf_factor (bool, optional): whether to use the factorized (compressed)
            optimizer state instead of the uncompressed one (default: False).
        beta3 (float, optional): coefficient for Prodigy's exponential moving
            average of the adaptation statistics. Defaults to sqrt(betas[1])
            when None.
        d0 (float, optional): initial D estimate for D-adaptation (default: 1e-6).
            Rarely needs changing.
        d_coef (float, optional): coefficient in the expression for the estimate
            of d (default: 1.0). Values such as 0.5 and 2.0 typically work as
            well. Changing this parameter is the preferred way to tune the method.
        growth_rate (float, optional): prevent the D estimate from growing faster
            than this multiplicative rate (default: inf, i.e. unrestricted).
            Values like 1.02 give a kind of learning-rate warmup effect.
        safeguard_warmup (bool, optional): remove lr from the denominator
            accumulator to avoid issues during warmup (default: False).
        fsdp_in_use (bool, optional): if you're using sharded parameters, set
            this to True. The optimizer will attempt to auto-detect this, but
            the auto-detection won't work with implementations other than
            PyTorch's builtin version.
        slice_p (int, optional): reduce memory usage by calculating LR adaptation
            statistics on only every p-th entry of each tensor. For values
            greater than 1 this is an approximation to standard Prodigy. Values
            around 11 are reasonable (default: 11).
        prodigy_steps (int, optional): if greater than zero, disable Prodigy's
            step-size adjustments after the specified optimizer step and release
            all state memory required by Prodigy (default: 0).
        d_limiter (bool, optional): whether to clamp the new step-size estimate
            (`d_hat`) to prevent sudden, volatile increases in the adaptive step
            size (`d`) (default: True).
    """

    def __init__(
        self,
        params,
        lr: float = 1.0,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        vector_reshape: bool = True,
        stochastic_rounding: bool = True,
        orthogonal_gradient: bool = False,
        cautious_mask: bool = False,
        clip_threshold: float = 0.0,
        nnmf_factor: bool = False,
        # Prodigy parameters
        beta3: Optional[float] = None,
        d0: float = 1e-6,
        d_coef: float = 1.0,
        growth_rate: float = float('inf'),
        safeguard_warmup: bool = False,
        fsdp_in_use: bool = False,
        slice_p: int = 11,
        prodigy_steps: int = 0,
        d_limiter: bool = True,
    ):
        if not lr > 0.0:
            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
        if not all(0.0 <= beta <= 1.0 for beta in betas):
            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
        if not weight_decay >= 0.0:
            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            vector_reshape=vector_reshape,
            orthogonal_gradient=orthogonal_gradient,
            clip_threshold=clip_threshold,
            beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
            growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
            fsdp_in_use=fsdp_in_use,
            prodigy_steps=prodigy_steps,
            d_limiter=d_limiter,
        )
        self.stochastic_rounding = stochastic_rounding
        self.cautious_mask = cautious_mask
        self.factored = nnmf_factor
        self.fsdp_in_use = fsdp_in_use
        super().__init__(params, defaults)
        # Global state for accumulating metrics across parameter updates within a single step.
        self.init_step()

    @property
    def supports_fused_back_pass(self) -> bool:
        return True

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return True

    @property
    def supports_flat_params(self) -> bool:
        return False

    def init_step(self):
        """Resets accumulators and calculates dlr for the upcoming step."""
        self.d_denom = 0.0

        g_group = self.param_groups[0]
        self.beta1, self.beta2 = g_group['betas']
        self.beta3 = g_group['beta3']
        if self.beta3 is None:
            self.beta3 = math.sqrt(self.beta2)

        k = g_group['k']
        self.d = g_group['d']
        lr = g_group['lr']

        self.dlr = self.d * lr

        self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
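        # Note: `dlr = d * lr` is the effective step size used for every update
        # in the upcoming step, and the Prodigy numerator is decayed by beta3
        # here, mirroring the beta3 decay applied to the denominator
        # accumulator `s` in step_parameter().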
141
+
142
+ @torch.no_grad()
143
+ def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
144
+ """Performs a single optimization step on a single parameter."""
145
+ if p.grad is None:
146
+ return
147
+
148
+ if hasattr(p, "_fsdp_flattened"):
149
+ self.fsdp_in_use = True
150
+
151
+ grad = p.grad
152
+ if grad.dtype != torch.float32 and self.factored:
153
+ grad = grad.float()
154
+ if group["clip_threshold"] > 0.0:
155
+ grad_norm = torch.norm(grad.detach())
156
+ if grad_norm > group["clip_threshold"]:
157
+ clip_coef = group["clip_threshold"] / grad_norm
158
+ grad.mul_(clip_coef)
159
+ if group["orthogonal_gradient"]:
160
+ grad = _orthogonalize_gradient(p, grad)
161
+ state = self.state[p]
162
+
163
+ # State Initialization
164
+ if 'step' not in state:
165
+ state['step'] = 0
166
+
167
+ should_factor = (
168
+ self.factored and
169
+ not (len(p.shape) == 1 and not group['vector_reshape'])
170
+ )
171
+
172
+ state['factored'] = should_factor
173
+
174
+ dtype = torch.float32 if self.factored else p.dtype
175
+
176
+ slice_p = group['slice_p']
177
+
178
+ # D-Adaptation states
179
+ state['s'] = torch.zeros_like(p.flatten()[::slice_p]).detach()
180
+ if p.any():
181
+ state['p0'] = p.flatten()[::slice_p].detach().clone()
182
+ else:
183
+ state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
184
+
185
+ if state['factored']:
186
+ state['effective_shape'] = _get_effective_shape(p.numel())
187
+ d1, d2 = state['effective_shape']
188
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
189
+ state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
190
+ packed_d2 = (d2 + 7) // 8
191
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
192
+ else: # Fallback to standard Lion
193
+ state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
194
+
195
+ if state['factored']:
196
+ # Factored Path
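            # SMMF-style factored state: the momentum matrix is never stored
            # directly. Its sign pattern lives in the packed bitmask 'sign' and
            # its magnitude in the rank-1 nonnegative factors 'mu_m_nmf' and
            # 'mv_m_nmf'. Each step reconstructs the momentum, applies the Lion
            # update, then re-compresses it.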
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)
            # Reconstruct momentum m_{t-1}
            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
            del unpacked_sign
            if exp_avg.dtype != torch.float32:
                exp_avg = exp_avg.float()

            # Compute update term c_t = sign(β1*m_{t-1} + (1-β1)*d*g_t)
            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1 - self.beta1)).sign_()

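            # Cautious masking: zero out update components whose sign disagrees
            # with the current gradient, then rescale by the kept fraction so
            # the overall update magnitude is preserved on average.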
            if self.cautious_mask:
                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            # Parameter update: p_t = p_{t-1} - d*lr * sign(c_t)
            update_for_param = signed_update.view(p.shape).mul(self.dlr)

            # Update momentum m_t = β2*m_{t-1} + (1-β2)*d*g_t
            exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
            del grad_reshaped

            # Compress new momentum m_t and store factors
            state['sign'] = _pack_bools(exp_avg > 0)
            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
            del exp_avg

        else:
            # Fallback to standard (uncompressed) Lion logic
            exp_avg = state["exp_avg"]

            # Compute update term and sign for the update
            if exp_avg.dtype != torch.float32 and self.factored:
                exp_avg = exp_avg.float()
            signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=self.d * (1 - self.beta1)).sign_()

            if self.cautious_mask:
                mask = (signed_update * grad > 0).to(grad.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            update_for_param = signed_update.mul(self.dlr)

            # Update momentum m_t = β2*m_{t-1} + (1-β2)*d*g_t
            exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

        prodigy_steps = group['prodigy_steps']
        if prodigy_steps <= 0 or group['k'] < prodigy_steps:
            # --- Accumulate Prodigy stats ---
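            # Prodigy's estimate of d relies on two running statistics, computed
            # on every slice_p-th element:
            #   numerator   += (d / d0) * dlr * <g_k, x0 - x_k>
            #   s            = beta3 * s + (d / d0) * dlr * g_k
            #                  (d replaces dlr when safeguard_warmup is enabled)
            #   denominator += sum(|s|)
            # so that calculate_d() can form d_hat = d_coef * numerator / denominator.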
            d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
            s, p0 = state['s'], state['p0']
            grad_flat = grad.flatten().float()
            p_flat = p.data.flatten().float()
            p0 = p0.float()

            self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

            alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
            s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
            self.d_denom += s.abs().sum().item()

            del s, p0, grad_flat, p_flat, alpha
        else:
            # Free Prodigy state memory once prodigy_steps is reached
            if 's' in state:
                del state['s']
            if 'p0' in state:
                del state['p0']

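        # Decoupled weight decay: p <- p * (1 - weight_decay * d * lr), applied
        # independently of the sign update. BF16 parameters optionally use
        # stochastic rounding to reduce accumulation error.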
        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data,
                                alpha=-group["weight_decay"] * self.dlr)
            else:
                p.data.add_(
                    p.data, alpha=-group["weight_decay"] * self.dlr
                )

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update_for_param)
        else:
            p.data.add_(-update_for_param)

        del update_for_param

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.step_parameter(p, group, i)

        self.calculate_d()
        self.init_step()
        return loss

    def calculate_d(self):
        """Calculates the new `d` based on the accumulated stats."""
        g_group = self.param_groups[0]
        # Only perform d-adaptation if prodigy_steps has not been reached
        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])

        if prodigy_active:
            d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

            if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
                # Use the device of the first parameter to avoid hardcoding '.cuda()'
                device = self.param_groups[0]['params'][0].device
                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                global_d_numerator = dist_tensor[0].item()
                global_d_denom = dist_tensor[1].item()
            else:
                global_d_numerator = self.d_numerator
                global_d_denom = self.d_denom

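            # New step-size estimate: d_hat = d_coef * numerator / denominator.
            # d never decreases (it tracks the running maximum d_max) and its
            # per-step growth is capped by growth_rate; with d_limiter the raw
            # estimate is additionally clamped to at most d * 2**0.25 to avoid
            # sudden jumps.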
            d_hat = self.d
            if global_d_denom > 0:
                d_hat = d_coef * global_d_numerator / global_d_denom
                if g_group.get('d_limiter', False):
                    d_hat = min(self.d * (2 ** 0.25), d_hat)
                if self.d == g_group['d0']:
                    self.d = max(self.d, d_hat)
                d_max = max(d_max, d_hat)
                self.d = min(d_max, self.d * growth_rate)

            for group in self.param_groups:
                group['d_numerator'] = global_d_numerator
                group['d'] = self.d
                group['d_max'] = d_max

        # Increment step counter for all groups, regardless of whether d was updated
        for group in self.param_groups:
            group['k'] += 1
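
A minimal usage sketch for the class above, assuming the distribution installs as an `adv_optm` package that exports `Lion_Prodigy_adv` (the import path is illustrative and may differ). With Prodigy adaptation the learning rate is left at its default of 1.0 and `d` is estimated on the fly:

import torch
from adv_optm import Lion_Prodigy_adv  # hypothetical import path

model = torch.nn.Linear(128, 10)
optimizer = Lion_Prodigy_adv(model.parameters(), lr=1.0, weight_decay=0.01)

for step in range(100):
    x = torch.randn(32, 128)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # runs step_parameter per parameter, then calculate_d() and init_step()
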
@@ -0,0 +1,210 @@
import torch

from typing import Callable, Optional, Tuple

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf, _unnmf
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools

class Lion_adv(torch.optim.Optimizer):
    """
    Implements the Lion algorithm with the SMMF factorization technique.

    This optimizer combines the Lion update rule with the memory-saving low-rank
    compression (SMMF) technique from https://arxiv.org/abs/2412.08894.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate (default: 1e-4).
        betas (Tuple[float, float], optional): coefficients for computing
            running averages of the update (default: (0.9, 0.99)).
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0).
        vector_reshape (bool, optional): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: True).
        stochastic_rounding (bool, optional): whether to use stochastic
            rounding for BF16 parameter updates (default: True).
        orthogonal_gradient (bool, optional): whether to project each gradient
            onto the subspace orthogonal to its parameter (OrthoGrad) before the
            update (default: False).
        cautious_mask (bool, optional): whether to use the cautious masking
            technique (default: False).
        clip_threshold (float, optional): per-parameter gradient-norm clipping
            threshold, as proposed in the paper `Lions and Muons: Optimization via
            Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192), to make
            Lion more stable. A value of 0.0 disables clipping (default: 0.0).
        nnmf_factor (bool, optional): whether to use the factorized (compressed)
            optimizer state instead of the uncompressed one (default: True).
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        vector_reshape: bool = True,
        stochastic_rounding: bool = True,
        orthogonal_gradient: bool = False,
        cautious_mask: bool = False,
        clip_threshold: float = 0.0,
        nnmf_factor: bool = True,
    ):
        if not lr > 0.0:
            raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
        if not all(0.0 <= beta <= 1.0 for beta in betas):
            raise ValueError(f"Betas should be in [0.0, 1.0], but got {betas}")
        if not weight_decay >= 0.0:
            raise ValueError(f"Weight decay must be >= 0.0, but got {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            vector_reshape=vector_reshape,
            orthogonal_gradient=orthogonal_gradient,
            clip_threshold=clip_threshold,
        )
        self.stochastic_rounding = stochastic_rounding
        self.cautious_mask = cautious_mask
        self.factored = nnmf_factor
        super().__init__(params, defaults)

    @property
    def supports_fused_back_pass(self) -> bool:
        return True

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return True

    @property
    def supports_flat_params(self) -> bool:
        return False

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
        """Performs a single optimization step on a single parameter."""
        if p.grad is None:
            return

        grad = p.grad
        if grad.dtype != torch.float32 and self.factored:
            grad = grad.float()
        if group["clip_threshold"] > 0.0:
            grad_norm = torch.norm(grad.detach())
            if grad_norm > group["clip_threshold"]:
                clip_coef = group["clip_threshold"] / grad_norm
                grad.mul_(clip_coef)
        if group["orthogonal_gradient"]:
            grad = _orthogonalize_gradient(p, grad)
        state = self.state[p]

        # State Initialization
        if 'step' not in state:
            state['step'] = 0

            should_factor = (
                self.factored and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )

            state['factored'] = should_factor

            dtype = torch.float32 if self.factored else p.dtype

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']
                state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                packed_d2 = (d2 + 7) // 8
                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
            else:  # Fallback to standard Lion
                state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

        state['step'] += 1
        beta1, beta2 = group["betas"]
        lr = group["lr"]

        if state['factored']:
            # Factored Path
            d1, d2 = state['effective_shape']
            grad_reshaped = grad.view(d1, d2)
            # Reconstruct momentum m_{t-1}
            exp_avg = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
            torch.where(unpacked_sign, exp_avg, -exp_avg, out=exp_avg)
            del unpacked_sign
            if exp_avg.dtype != torch.float32:
                exp_avg = exp_avg.float()

            # Compute update term c_t
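            # The full Lion recurrence implemented below is:
            #   c_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t)
            #   p_t = p_{t-1} - lr * (c_t + weight_decay * p_{t-1})
            #   m_t = beta2 * m_{t-1} + (1 - beta2) * g_t
            # where the weight-decay term is applied separately further below,
            # before the sign update is subtracted.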
            signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1 - beta1)).sign_()

            if self.cautious_mask:
                mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            # Parameter update
            update_for_param = signed_update.view(p.shape).mul_(lr)

            # Standard Lion momentum update
            exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1 - beta2)
            del grad_reshaped

            # Compress new momentum m_t and store factors
            state['sign'] = _pack_bools(exp_avg > 0)
            _nnmf(exp_avg.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
            del exp_avg

        else:
            # Fallback to standard (uncompressed) Lion logic
            exp_avg = state["exp_avg"]

            # Compute update term and sign for the update
            if exp_avg.dtype != torch.float32 and self.factored:
                exp_avg = exp_avg.float()
            signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1 - beta1)).sign_()

            if self.cautious_mask:
                mask = (signed_update * grad > 0).to(grad.dtype)
                mask.div_(mask.mean().clamp_(min=1e-3))
                signed_update.mul_(mask)
                del mask

            update_for_param = signed_update.mul_(lr)

            # Standard Lion momentum update
            exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        # Decoupled weight decay: p <- p * (1 - weight_decay * lr)
        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data,
                                alpha=-group["weight_decay"] * lr)
            else:
                p.data.add_(
                    p.data, alpha=-group["weight_decay"] * lr
                )

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update_for_param)
        else:
            p.data.add_(-update_for_param)

        del update_for_param

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.step_parameter(p, group, i)

        return loss
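
Since `supports_fused_back_pass` is True, `step_parameter` can also be driven directly from gradient hooks so each parameter is updated, and its gradient freed, as soon as the gradient is ready. A minimal sketch, assuming the distribution installs as an `adv_optm` package exporting `Lion_adv`; the import path and the hook-based wiring are illustrative, not part of this file, and require PyTorch >= 2.1 for register_post_accumulate_grad_hook:

import torch
from adv_optm import Lion_adv  # hypothetical import path

model = torch.nn.Linear(128, 10).bfloat16()
opt = Lion_adv(model.parameters(), lr=1e-4, weight_decay=0.01, nnmf_factor=True)

for group in opt.param_groups:
    for i, p in enumerate(group["params"]):
        def hook(param, group=group, i=i):
            opt.step_parameter(param, group, i)  # fused back pass: update in backward
            param.grad = None  # free the gradient immediately
        p.register_post_accumulate_grad_hook(hook)

x = torch.randn(32, 128, dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()  # parameters are updated during the backward pass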