adv-optm 1.1.1__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-1.1.1 → adv_optm-1.1.2}/PKG-INFO +1 -1
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/__init__.py +1 -1
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/Lion_Prodigy_adv.py +62 -36
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.1.1 → adv_optm-1.1.2}/setup.py +1 -1
- {adv_optm-1.1.1 → adv_optm-1.1.2}/LICENSE +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/README.md +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.1.1 → adv_optm-1.1.2}/setup.cfg +0 -0
@@ -50,6 +50,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
             pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
+        prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
+            after the specified optimiser step and release all state memory required by Prodigy
+            (default: 0).
+        d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+            to prevent sudden, volatile increases in the adaptive step size (`d`).
+            (default: True)
     """

     def __init__(
@@ -63,7 +69,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         orthogonal_gradient: bool = False,
         cautious_mask: bool = False,
         clip_threshold: float = 0.0,
-        nnmf_factor: bool =
+        nnmf_factor: bool = False,
         # prodigy parameters
         beta3: float = None,
         d0: float = 1e-6,
@@ -72,6 +78,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         safeguard_warmup: bool = False,
         fsdp_in_use: bool = False,
         slice_p: int = 11,
+        prodigy_steps: int = 0,
+        d_limiter: bool = True,
     ):
         if not lr > 0.0:
             raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -90,6 +98,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
             beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
             growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
             fsdp_in_use=fsdp_in_use,
+            prodigy_steps=prodigy_steps,
+            d_limiter=d_limiter,
         )
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
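Taken together, the constructor changes above introduce two new keyword arguments. A minimal usage sketch follows; it assumes `Lion_Prodigy_adv` is importable from the top-level `adv_optm` package (the diff only shows `adv_optm/optim/Lion_Prodigy_adv.py`) and uses a placeholder model, and since the diff does not show how `calculate_d()` is wired into `step()`, only construction is illustrated.

```python
# Illustrative sketch only; the import path and lr value are assumptions.
import torch
from adv_optm import Lion_Prodigy_adv  # assumed top-level export

model = torch.nn.Linear(128, 10)  # placeholder model
optimizer = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,              # Prodigy-style optimizers are typically run with lr=1.0
    prodigy_steps=1000,  # new in 1.1.2: stop d-adaptation and free its state after step 1000
    d_limiter=True,      # new in 1.1.2: clamp each new d_hat estimate (the default)
)
```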
@@ -235,20 +245,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
                 # Update momentum
                 exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

-
-
-
-
-
-
+                prodigy_steps = group['prodigy_steps']
+                if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+                    # --- Accumulate Prodigy stats ---
+                    d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
+                    s, p0 = state['s'], state['p0']
+                    grad_flat = grad.flatten().float()
+                    p_flat = p.data.flatten().float()
+                    p0 = p0.float()

-
+                    self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

-
-
-
+                    alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+                    s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
+                    self.d_denom += s.abs().sum().item()

-
+                    del s, p0, grad_flat, p_flat, alpha
+                else:
+                    # Free memory if prodigy_steps is reached
+                    if 's' in state:
+                        del state['s']
+                    if 'p0' in state:
+                        del state['p0']

                 if group["weight_decay"] != 0:
                     if p.dtype == torch.bfloat16 and self.stochastic_rounding:
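For reference, a small sketch of how the state release could be observed once `group['k']` has passed `prodigy_steps`; it assumes `optimizer` is a `Lion_Prodigy_adv` instance that has already taken enough steps, and relies only on the `'s'` and `'p0'` state keys shown in the hunk above.

```python
# Sketch: once the Prodigy phase has ended, the per-parameter buffers
# 's' and 'p0' should have been deleted from the optimizer state.
for p, state in optimizer.state.items():
    assert 's' not in state and 'p0' not in state, "Prodigy state not yet released"
```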
@@ -287,29 +305,37 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
-
-
-
-
-
-
-        dist.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Only perform d-adaptation if prodigy_steps has not been reached
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+
+        if prodigy_active:
+            d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
+
+            if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+                # Use the device of the first parameter to avoid hardcoding '.cuda()'
+                device = self.param_groups[0]['params'][0].device
+                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_d_numerator = dist_tensor[0].item()
+                global_d_denom = dist_tensor[1].item()
+            else:
+                global_d_numerator = self.d_numerator
+                global_d_denom = self.d_denom
+
+            d_hat = self.d
+            if global_d_denom > 0:
+                d_hat = d_coef * global_d_numerator / global_d_denom
+                if g_group['d_limiter']:
+                    d_hat = min(self.d * (2 ** 0.25), d_hat)
+                if self.d == g_group['d0']:
+                    self.d = max(self.d, d_hat)
+                d_max = max(d_max, d_hat)
+                self.d = min(d_max, self.d * growth_rate)
+
+            for group in self.param_groups:
+                group['d_numerator'] = global_d_numerator
+                group['d'] = self.d
+                group['d_max'] = d_max
+
+        # Increment step counter for all groups, regardless of whether d was updated
         for group in self.param_groups:
-            group['d_numerator'] = global_d_numerator
-            group['d'] = self.d
-            group['d_max'] = d_max
             group['k'] += 1
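The `d_limiter` branch above caps how fast `d` can grow between `calculate_d()` calls: each new estimate is clamped to at most `2 ** 0.25` (about 1.19x) of the current `d`. A small numeric illustration with made-up values:

```python
# Made-up numbers, only to show the effect of the clamp in calculate_d.
d = 1e-4                 # current adaptive step size
raw_d_hat = 5e-4         # hypothetical unclamped estimate (d_coef * numerator / denominator)
d_hat = min(d * (2 ** 0.25), raw_d_hat)
print(d_hat)             # ~1.189e-04: per-update growth capped at 2 ** 0.25 (~19%)
```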