heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/pure_psgd.py DELETED
@@ -1,105 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-import torch
-from heavyball.utils import identity
-
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
-    line_to_triu, triu_to_line, promote
-
-
-class ForeachPurePSGD(PSGDBase):
-    """
-    Kronecker Factorized PSGD WITHOUT Momentum
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        if clip_fn is None:
-            clip_fn = identity
-
-        assert not mars, "MARS is not supported in this optimizer"
-
-        defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
-                        mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        should_update = self.should_update(group)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-
-            vals.append((p, g, state["Q"]))
-
-        if not vals:
-            return
-
-        p_list, grad_list, Q_list = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        Q_list = list(Q_list)
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            q_orig = Q_list.pop(0)
-            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
-            if group:
-                q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *q)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
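For context, ForeachPurePSGD followed the standard torch.optim interface. Below is a minimal usage sketch against the 0.25.1 constructor shown in the deleted file; the toy model, data, and hyperparameter choices are illustrative only, and the top-level import assumes the class was re-exported from heavyball/__init__.py as in other 0.25.x releases:

    import torch
    from heavyball import ForeachPurePSGD  # present in 0.25.1; removed in 1.0.0

    model = torch.nn.Linear(16, 1)
    # keyword arguments mirror the deleted signature above
    opt = ForeachPurePSGD(model.parameters(), lr=1e-3, weight_decay=0.0,
                          max_size_triangular=2048, store_triu_as_line=True)

    for _ in range(10):
        x = torch.randn(32, 16)
        loss = model(x).square().mean()
        opt.zero_grad()
        loss.backward()
        # each step may refresh the Kronecker-factored preconditioner (per the update
        # probability schedule) before applying the preconditioned gradient
        opt.step()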
heavyball/schedule_free_palm_foreach_soap.py DELETED
@@ -1,136 +0,0 @@
-import random
-
-import torch
-
-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote, decorator_knowngood, \
-    mars_correction
-
-
-@decorator_knowngood
-def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
-    eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
-    denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
-    torch._foreach_div_(gp32, denom)
-
-    copy_stochastic_list_(exp_avg_sq, eas32)
-    copy_stochastic_list_(grad_projected, gp32)
-
-
-class SFPaLMForeachSOAP(ScheduleFree):
-    """
-    SFPaLMForeachSOAP
-
-    Sources:
-        Baseline SOAP:
-            SOAP: Improving and Stabilizing Shampoo using Adam
-            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
-            https://arxiv.org/abs/2409.11321
-            https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
-
-        Beta2 Schedule:
-            PaLM: Scaling Language Modeling with Pathways
-            Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh, Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Levskaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick, Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Brennan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas Eck, Jeff Dean, Slav Petrov, Noah Fiedel
-            https://arxiv.org/abs/2204.02311
-    """
-
-    def __init__(self, params, lr: float = 3e-3, beta=0.9, beta2_scale: float = 0.8, eps: float = 1e-8,
-                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
-                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
-                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
-        if betas[0] is not None:
-            beta = betas[0]
-
-        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
-
-        defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
-                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
-                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
-                    'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
-                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
-                    'caution': caution, 'mars_gamma': mars_gamma}
-        super().__init__(params, defaults, foreach)
-        self._data_format = data_format
-        self.rng = random.Random(0x120983109)
-
-    def _step(self, group):
-        vals = []
-        max_precond_dim = group['max_precond_dim']
-        precondition_1d = group['precondition_1d']
-        mars = group['mars']
-
-        step = group['step'] = group.get("step", 0) + 1
-
-        for p in group["params"]:
-            if p.grad is None:
-                continue
-            grad = p.grad.float()
-            vals.append((p, grad))
-
-        if not vals:
-            return
-
-        p_list, grad = zip(*vals)
-
-        adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
-            state = self.state_(p)
-
-            if "z" not in state:
-                state["z"] = torch.clone(p).float()
-                state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32, memory_format=torch.preserve_format)
-                if mars:
-                    state['mars_prev_grad'] = g.clone()
-                init_preconditioner(g, state, max_precond_dim, precondition_1d)
-                update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
-                continue  # first step is skipped so that we never use the current gradients in the projection.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            grad_projected = project(g, state['Q'], False)
-            z, exp_avg_sq = state["z"], state["exp_avg_sq"]
-            vals.append((p, g, grad_projected, z, exp_avg_sq))
-
-        if not vals:
-            return
-
-        p_list, grad, grad_projected, z, exp_avg_sq = zip(*vals)
-
-        beta2 = 1 - max(step, 1) ** -group['beta2_scale']
-        new_debiased2 = beta_debias(beta2, step)
-
-        # Decay the first and second moment running average coefficient
-        # In-place operations to update the averages at the same time
-        old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
-        _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
-
-        update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
-
-        for p, g, gp in zip(p_list, grad, grad_projected):
-            state = self.state_(p)
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-            set_(gp, project(gp, state['Q'], back=True))
-
-            update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)
-
-        # Weight decay calculated at y
-        if group["weight_decay"] > 0:
-            torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
-
-        lr = warmup(group['lr'], step, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
-                                             z, grad_projected, group['r'], step)
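SFPaLMForeachSOAP combined SOAP preconditioning with the schedule-free update and the PaLM beta2 schedule (beta2 = 1 - step**(-beta2_scale)). Below is a minimal usage sketch against the 0.25.1 signature above; the model and loop are illustrative, the top-level import assumes re-export from heavyball/__init__.py, and the train()/eval() switch is an assumed API mirroring facebookresearch/schedule_free, where the averaged weights must be swapped in before evaluation:

    import torch
    from heavyball import SFPaLMForeachSOAP  # present in 0.25.1; removed/renamed in 1.0.0

    model = torch.nn.Linear(16, 1)
    opt = SFPaLMForeachSOAP(model.parameters(), lr=3e-3, weight_decay=0.01,
                            precondition_frequency=2, beta2_scale=0.8)

    for _ in range(100):
        x = torch.randn(32, 16)
        loss = model(x).square().mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # assumed API: switch to the schedule-free averaged weights for evaluation, then back
    opt.eval()
    with torch.no_grad():
        val_loss = model(torch.randn(32, 16)).square().mean()
    opt.train()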
heavyball-0.25.1.dist-info/RECORD DELETED
@@ -1,28 +0,0 @@
-heavyball/__init__.py,sha256=RdUfGDTXw-rtoQJNediWnhDseYyyWNPVsr6tRq_ucp8,2813
-heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
-heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
-heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
-heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,2860
-heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
-heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
-heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
-heavyball/foreach_soap.py,sha256=ntFqg0fbkZ8EzERGlypXB8JWoGJ1sAY59f0CuWh_d48,4801
-heavyball/foreach_solp.py,sha256=1r7x_FUZRaUsoSLSvi-Z_-pZNtZrMresVJGq9m1EREA,4563
-heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
-heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
-heavyball/palm_foreach_soap.py,sha256=fbRL1Tx9YeQ16sHWFPtY5Kj60BFV2AMngOnTiE4muK0,6231
-heavyball/palm_foreach_solp.py,sha256=N3M3tnahOfSHvLu3en76JTI1yo-ISEbSliSKlpt8ZWw,5994
-heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
-heavyball/precond_schedule_foreach_solp.py,sha256=xGEQ6HHUTCKeT9-ppEbLTdXVAfE74P0tph0qS16USyg,4768
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
-heavyball/precond_schedule_palm_foreach_solp.py,sha256=gaoJwJo_ZBnYuMamgFepnV9iWpUCmbYrxMWiL1QkPh0,6253
-heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
-heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
-heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
-heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
-heavyball/utils.py,sha256=_KvCCCnsu_l4I_OhiRr4noAiwUvzctN05JAuYPkrxXY,41191
-heavyball-0.25.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.25.1.dist-info/METADATA,sha256=WWR7dX_i7dcF-73-VJ42qcRFwZRL3unOSEwO4EM96e0,11926
-heavyball-0.25.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.25.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.25.1.dist-info/RECORD,,