heavyball 0.25.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.25.0
+ Version: 1.0.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,11 +32,12 @@ A simple package of efficient optimizers
  The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple,
  largely static alternative to `torch.optim` with more and better optimizers.

- Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-12-07, 1.0.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

  ## Features

+ * **Optax-like API**: `C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)`
  * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
  * **Inplace EMA**: Same math, but less memory, less compute and higher stability
  * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
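The **Optax-like API** bullet added above references the new `heavyball.chainable` module (also visible in the 1.0.0 RECORD below). A minimal sketch of how such a composition might be used, assuming `ChainOpt` takes parameters, a learning rate, and transform functions as in the bullet; the surrounding model and training step are illustrative, not taken from the package docs:

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)
C = heavyball.chainable
# Compose an exponential-moving-average transform with Adam scaling,
# mirroring the `ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)` pattern above.
opt = C.ChainOpt(model.parameters(), 1e-3, C.exp_avg, C.scale_by_adam)

loss = model(torch.randn(4, 16)).square().mean()
loss.backward()
opt.step()       # assumed: ChainOpt follows the usual torch.optim step/zero_grad interface
opt.zero_grad()
```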
@@ -130,6 +131,6 @@ To access `heavyball.utils`, you need to explicitly `import heavyball.utils`.\
  It has several handy functions:

  * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
- * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls
+ * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables `torch.compile`
  * `zeroth_power_mode`, a string determining whether to use QR, newtonschulz{iterations}, SVD, or eigh to approximate
  the eigenvectors. Eigh has the highest precision and cost
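A short usage sketch of the `heavyball.utils` hooks listed above; the attribute names (`set_torch`, `compile_mode`, `zeroth_power_mode`) come from the bullets, while the specific values assigned here are assumptions:

```python
import heavyball.utils

heavyball.utils.set_torch()               # TF32, opt_einsum, benchmark, ...
heavyball.utils.compile_mode = None       # per the bullet above: disables torch.compile
heavyball.utils.zeroth_power_mode = 'qr'  # assumed spelling; cheaper than 'eigh', less precise
```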
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
+ heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
+ heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
+ heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
+ heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.0.0.dist-info/RECORD,,
@@ -1,135 +0,0 @@
- """
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
- from typing import Optional
-
- import torch
- from heavyball.utils import min_dtype, precond_grad_cached_
-
- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
-
-
- class ForeachCachedDelayedPSGDKron(PSGDBase):
-     """
-     Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
-
-
-     Args:
-         params (iterable): Iterable of parameters to optimize or dicts defining
-             parameter groups.
-         lr (float): Learning rate.
-         beta (float): Momentum parameter.
-         weight_decay (float): Weight decay (L2 penalty).
-         preconditioner_update_probability (callable or float, optional): Probability of
-             updating the preconditioner. If None, defaults to a schedule that anneals
-             from 1.0 to 0.03 by 4000 steps.
-         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-         min_ndim_triangular (int): Minimum number of dimensions a layer needs
-             to have triangular preconditioners.
-         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-             to set all preconditioners to be triangular, 'one_diag' sets the largest
-             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-             to be diagonal.
-         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-             update instead of raw gradients.
-     """
-
-     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                  storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                  #
-                  # expert parameters
-                  precond_init_scale=1.0, precond_lr=0.1):
-         if not 0.0 <= lr:
-             raise ValueError(f"Invalid learning rate: {lr}")
-         if not 0.0 <= beta < 1.0:
-             raise ValueError(f"Invalid beta parameter: {beta}")
-         if not 0.0 <= weight_decay:
-             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-         if clip_fn is None:
-             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                         storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
-         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-     def _step(self, group):
-         should_update = self.should_update(group)
-         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-         precond_init_scale = group['precond_init_scale']
-         max_size_triangular = group['max_size_triangular']
-         min_ndim_triangular = group['min_ndim_triangular']
-         memory_save_mode = group['memory_save_mode']
-         precond_lr = group['precond_lr']
-         weight_decay = group['weight_decay']
-         lr = group['lr']
-         beta = group['beta']
-         store_triu_as_line = group['store_triu_as_line']
-         q_dtype = getattr(torch, group['q_dtype'])
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         vals = []
-
-         for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-             state = self.state_(p)
-
-             if 'Q' not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=q_dtype)
-                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
-
-                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
-                 expr = ','.join(expr)
-                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
-                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
-                 expr = f'{expr},{grad_expr}->{out_expr}'
-
-                 state['cache_expr'] = expr
-
-             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
-
-         if not vals:
-             return
-
-         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
-         del vals
-
-         group["step"] += 1
-
-         stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
-
-         lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
-             exp_avg_list)
-         for i, (p, g) in enumerate(zip(p_list, grad_list)):
-             cached_q = Q_cache_list.pop(0)
-             q_orig = Q_list.pop(0)
-             ea = exp_avg_list.pop(0)
-
-             precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)
-
-             if should_update:
-                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                 q32 = [promote(q_) for q_ in q]
-                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                                store_triu_as_line)
-                 for c_, q_ in zip(cached_q, q):
-                     if q_.ndim == 2:
-                         torch.matmul(q_.T.conj(), q_, out=c_)
-                     else:
-                         torch.mul(q_.conj(), q_, out=c_)
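For reference, the deleted class above exposed a standard `torch.optim`-style interface. A hedged usage sketch against the 0.25.0 constructor signature shown in this hunk; the import path, model, and data are illustrative assumptions:

```python
import torch
from heavyball import ForeachCachedDelayedPSGDKron  # assumed 0.25.0 export path

model = torch.nn.Linear(32, 32)
opt = ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3, beta=0.9,
                                   weight_decay=0.0, store_triu_as_line=True)

for _ in range(10):
    loss = model(torch.randn(8, 32)).square().mean()
    loss.backward()
    opt.step()        # preconditioner update is applied one step late (delayed PSGD)
    opt.zero_grad()
```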
@@ -1,136 +0,0 @@
- """
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
- from typing import Optional
-
- import torch
-
- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
-
-
- class ForeachCachedPSGDKron(PSGDBase):
-     """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
-
-     Args:
-         params (iterable): Iterable of parameters to optimize or dicts defining
-             parameter groups.
-         lr (float): Learning rate.
-         beta (float): Momentum parameter.
-         weight_decay (float): Weight decay (L2 penalty).
-         preconditioner_update_probability (callable or float, optional): Probability of
-             updating the preconditioner. If None, defaults to a schedule that anneals
-             from 1.0 to 0.03 by 4000 steps.
-         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-         min_ndim_triangular (int): Minimum number of dimensions a layer needs
-             to have triangular preconditioners.
-         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-             to set all preconditioners to be triangular, 'one_diag' sets the largest
-             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-             to be diagonal.
-         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-             update instead of raw gradients.
-     """
-
-     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                  storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                  orthogonalize_output: bool = False,
-                  #
-                  # expert parameters
-                  precond_init_scale=1.0, precond_lr=0.1):
-         if not 0.0 <= lr:
-             raise ValueError(f"Invalid learning rate: {lr}")
-         if not 0.0 <= beta < 1.0:
-             raise ValueError(f"Invalid beta parameter: {beta}")
-         if not 0.0 <= weight_decay:
-             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-         if clip_fn is None:
-             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                         storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
-                         orthogonalize_output=orthogonalize_output)
-         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-     def _step(self, group):
-         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-         precond_init_scale = group['precond_init_scale']
-         max_size_triangular = group['max_size_triangular']
-         min_ndim_triangular = group['min_ndim_triangular']
-         memory_save_mode = group['memory_save_mode']
-         precond_lr = group['precond_lr']
-         weight_decay = group['weight_decay']
-         lr = group['lr']
-         beta = group['beta']
-         store_triu_as_line = group['store_triu_as_line']
-         q_dtype = getattr(torch, group['q_dtype'])
-         storage_dtype = getattr(torch, group['storage_dtype'])
-         orthogonalize_output = group['orthogonalize_output']
-         should_update = self.should_update(group)
-
-         vals = []
-
-         for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-             state = self.state_(p)
-
-             if 'Q' not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=q_dtype)
-                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-                 state['Q_cache'] = [torch.empty_like(q) for q in Q]
-
-                 expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
-                 expr = ','.join(expr)
-                 grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
-                 out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
-                 expr = f'{expr},{grad_expr}->{out_expr}'
-
-                 state['cache_expr'] = expr
-
-             vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
-
-         if not vals:
-             return
-
-         p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
-         del vals
-
-         group["step"] += 1
-
-         stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
-
-         lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-         grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
-             exp_avg_list)
-         for i, (p, g) in enumerate(zip(p_list, grad_list)):
-             cached_q = Q_cache_list.pop(0)
-             q_orig = Q_list.pop(0)
-             ea = exp_avg_list.pop(0)
-
-             if should_update:
-                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                 q32 = [promote(q_) for q_ in q]
-                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                                store_triu_as_line)
-                 for c_, q_ in zip(cached_q, q):
-                     if q_.ndim == 2:
-                         torch.matmul(q_.T.conj(), q_, out=c_)
-                     else:
-                         torch.mul(q_.conj(), q_, out=c_)
-
-             precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
-                                  group['caution'], g)
heavyball/delayed_psgd.py DELETED
@@ -1,122 +0,0 @@
- """
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
- import torch
- from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
-
- from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-     triu_to_line, line_to_triu, promote, _compilable_update_, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
-     new = psgd_precond_grad(False, exprs, ea, *q)
-     _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
-
-
- class ForeachDelayedPSGD(PSGDBase):
-     """
-     Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
-
-     Args:
-         params (iterable): Iterable of parameters to optimize or dicts defining
-             parameter groups.
-         lr (float): Learning rate.
-         beta (float): Momentum parameter.
-         weight_decay (float): Weight decay (L2 penalty).
-         preconditioner_update_probability (callable or float, optional): Probability of
-             updating the preconditioner. If None, defaults to a schedule that anneals
-             from 1.0 to 0.03 by 4000 steps.
-         max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-         min_ndim_triangular (int): Minimum number of dimensions a layer needs
-             to have triangular preconditioners.
-         memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-             to set all preconditioners to be triangular, 'one_diag' sets the largest
-             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-             to be diagonal.
-         momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-             update instead of raw gradients.
-     """
-
-     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                  q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
-                  mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,  #
-                  # expert parameters
-                  precond_init_scale=1.0, precond_lr=0.1):
-         if not 0.0 <= lr:
-             raise ValueError(f"Invalid learning rate: {lr}")
-         if not 0.0 <= beta < 1.0:
-             raise ValueError(f"Invalid beta parameter: {beta}")
-         if not 0.0 <= weight_decay:
-             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-         if clip_fn is None:
-             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                         storage_dtype=storage_dtype,
-                         caution=caution, mars_gamma=mars_gamma, mars=mars)
-         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-     def _step(self, group):
-         should_update = self.should_update(group)
-         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-         precond_init_scale = group['precond_init_scale']
-         max_size_triangular = group['max_size_triangular']
-         min_ndim_triangular = group['min_ndim_triangular']
-         memory_save_mode = group['memory_save_mode']
-         precond_lr = group['precond_lr']
-         weight_decay = group['weight_decay']
-         lr = group['lr']
-         beta = group['beta']
-         store_triu_as_line = group['store_triu_as_line']
-         q_dtype = getattr(torch, group['q_dtype'])
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         vals = []
-
-         for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-             state = self.state_(p)
-
-             if 'Q' not in state:
-                 state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                  memory_save_mode, dtype=q_dtype)
-                 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
-
-             vals.append((p, g, state["exp_avg"], state["Q"]))
-
-         if not vals:
-             return
-
-         p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-         del vals
-
-         group["step"] += 1
-
-         stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
-
-         lr = -warmup(lr, group['step'], group['warmup_steps'])
-         lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)
-
-         Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-         for i, (p, g) in enumerate(zip(p_list, grad_list)):
-             q_orig = Q_list.pop(0)
-             ea = exp_avg_list.pop(0)
-             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-             _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"][-1], ea, p, lr, weight_decay, self.clip_fn,
-                                            group['caution'], g)
-             if should_update:
-                 q32 = [promote(q_) for q_ in q]
-                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                                store_triu_as_line)
@@ -1,63 +0,0 @@
- import torch
- import torch.optim
- from heavyball.utils import copy_stochastic_list_
-
- from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-
-     torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
-     denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))
-
-     update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
-                   grad=g32)
-
-     copy_stochastic_list_(exp_avg, exp_avg32)
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
- class ForeachAdamW(StatefulOptimizer):
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                         lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                         mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach)
-
-     def _step(self, group):
-         eps = group['eps']
-         decay = group['weight_decay']
-         k = group['k']
-
-         if not group['train_mode']:
-             raise Exception("Not in train mode!")
-
-         active_p = [p for p in group['params'] if p.grad is not None]
-
-         if not active_p:
-             return
-
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         for p in active_p:
-             if 'exp_avg' not in self.state_(p):
-                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-         y, grad, exp_avg_sq, exp_avg = zip(
-             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-         if group['mars']:
-             self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-         _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
-                           group['caution'])
-
-         group['k'] = k + 1
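The deleted `ForeachAdamW` above likewise followed the usual optimizer pattern. A minimal sketch using the constructor defaults visible in this hunk; the import path, model, and training loop are illustrative assumptions:

```python
import torch
from heavyball import ForeachAdamW  # assumed 0.25.0 export path

model = torch.nn.Linear(64, 10)
opt = ForeachAdamW(model.parameters(), lr=0.0025, betas=(0.9, 0.99),
                   weight_decay=0.01, warmup_steps=100)

for _ in range(100):
    logits = model(torch.randn(16, 64))
    loss = torch.nn.functional.cross_entropy(logits, torch.randint(0, 10, (16,)))
    loss.backward()
    opt.step()
    opt.zero_grad()
```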
@@ -1,83 +0,0 @@
- import torch
- import torch.optim
- from heavyball.utils import copy_stochastic_list_
-
- from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-     update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
-
-     beta1 = beta_debias(beta1, step)
-     denom = torch._foreach_sqrt(exp_avg_sq32)
-     torch._foreach_maximum_(denom, eps)
-     torch._foreach_mul_(exp_avg32, beta1)
-     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-
-     beta2 = beta_debias(beta2, step + 1)
-     torch._foreach_mul_(exp_avg_sq32, beta2)
-     [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
-     copy_stochastic_list_(exp_avg, exp_avg32)
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
- class ForeachADOPT(StatefulOptimizer):
-
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                         lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                         mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach)
-
-     def _step(self, group):
-         eps = group['eps']
-         decay = group['weight_decay']
-         k = group['k']
-
-         if not group['train_mode']:
-             raise Exception("Not in train mode!")
-
-         active_p = [p for p in group['params'] if p.grad is not None]
-
-         if not active_p:
-             return
-
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         for p in active_p:
-             if 'exp_avg' not in self.state_(p):
-                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-         y, grad, exp_avg_sq, exp_avg = zip(
-             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-         group['k'] = k + 1
-
-         if group['mars']:
-             self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-         if k > 1:
-             lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
-             lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-             k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-             _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
-             return
-
-         grad = [promote(g) for g in grad]
-         if k > 0:
-             beta1 = beta_debias(group['betas'][0], k)
-             denom = torch._foreach_sqrt(exp_avg_sq)
-             torch._foreach_maximum_(denom, eps)
-             torch._foreach_mul_(exp_avg, beta1)
-             torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-         beta2 = beta_debias(group['betas'][1], k + 1)
-         torch._foreach_mul_(exp_avg_sq, beta2)
-         torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-         del grad
@@ -1,67 +0,0 @@
- import torch
- import torch.optim
-
- from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_, decorator_knowngood
-
-
- @decorator_knowngood
- def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-
-     denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
-
-     beta1 = beta_debias(beta1, step)
-     torch._foreach_mul_(exp_avg32, beta1)
-     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-
-     update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)
-
-     copy_stochastic_list_(exp_avg, exp_avg32)
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
- class ForeachLaProp(StatefulOptimizer):
-
-     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
-                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                  mars_gamma: float = 0.0025):
-         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                         lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                         mars_gamma=mars_gamma)
-         super().__init__(params, defaults, foreach)
-
-     def _step(self, group):
-         eps = group['eps']
-         decay = group['weight_decay']
-         k = group['k']
-
-         if not group['train_mode']:
-             raise Exception("Not in train mode!")
-
-         active_p = [p for p in group['params'] if p.grad is not None]
-
-         if not active_p:
-             return
-
-         storage_dtype = getattr(torch, group['storage_dtype'])
-
-         for p in active_p:
-             if 'exp_avg' not in self.state_(p):
-                 self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                 self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-         y, grad, exp_avg_sq, exp_avg = zip(
-             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg'])  #
-               for p in active_p])
-
-         if group['mars']:
-             self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-
-         _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
-                           group['caution'])
-
-         group['k'] = k + 1