heavyball-1.5.3-py3-none-any.whl → heavyball-1.6.0-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
heavyball/__init__.py CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
104
104
  Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
105
105
  https://arxiv.org/abs/2409.11321
106
106
  https://github.com/nikhilvyas/SOAP
107
-
108
- ScheduleFree:
109
- The Road Less Scheduled
110
- Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
111
- https://arxiv.org/abs/2405.15682
112
- https://github.com/facebookresearch/schedule_free
113
107
  """
114
108
  use_precond_schedule: bool = False
115
109
 
116
110
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
117
111
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
118
112
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
119
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
120
- split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
121
- mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
122
- beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
123
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
113
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
114
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
115
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
116
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
117
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
124
118
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
125
119
 
126
120
  defaults = locals()
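A minimal usage sketch of the reordered ForeachSOAP signature, assuming heavyball 1.6.0 is installed; the toy model, data, and hyperparameter values are illustrative, while the keyword names come from the hunk above (data_format, removed above, is no longer accepted):

    import torch
    from heavyball import ForeachSOAP

    model = torch.nn.Linear(16, 4)
    # storage_dtype and stochastic_schedule are new keyword arguments in 1.6.0
    opt = ForeachSOAP(model.parameters(), lr=3e-3, betas=(0.9, 0.95),
                      storage_dtype='float32', stochastic_schedule=False)

    for _ in range(3):
        loss = model(torch.randn(8, 16)).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()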
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
137
131
  C.scale_by_soap)
138
132
 
139
133
 
134
+ class ForeachSignLaProp(C.BaseOpt):
135
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
136
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
137
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
138
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
139
+ defaults = locals()
140
+ defaults.pop("self")
141
+ params = defaults.pop("params")
142
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
143
+
144
+
145
+ class ForeachSOLP(C.BaseOpt):
146
+ """
147
+ ForeachSOLP
148
+
149
+ Sources:
150
+ Baseline SOAP:
151
+ SOAP: Improving and Stabilizing Shampoo using Adam
152
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
153
+ https://arxiv.org/abs/2409.11321
154
+ https://github.com/nikhilvyas/SOAP
155
+ """
156
+ use_precond_schedule: bool = False
157
+
158
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
159
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
160
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
161
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
162
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
163
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
164
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
165
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
166
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
167
+
168
+ defaults = locals()
169
+ defaults.pop("self")
170
+ params = defaults.pop("params")
171
+
172
+ if use_precond_schedule:
173
+ del defaults['precondition_frequency']
174
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
175
+ else:
176
+ del defaults['precond_scheduler']
177
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
178
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
179
+ functools.partial(C.scale_by_soap, inner='laprop'))
180
+
181
+
140
182
  class PaLMForeachSOAP(ForeachSOAP):
141
183
  use_precond_schedule: bool = False
142
184
  palm: bool = True
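The two new classes compose existing transforms: ForeachSignLaProp chains C.scale_by_laprop with C.sign, and ForeachSOLP reuses the SOAP preconditioner with a LaProp inner step (functools.partial(C.scale_by_soap, inner='laprop')). A hedged instantiation sketch; the model and hyperparameter values are illustrative, the keyword names come from the signatures above:

    import torch
    from heavyball import ForeachSignLaProp, ForeachSOLP

    model = torch.nn.Linear(16, 4)
    sign_laprop = ForeachSignLaProp(model.parameters(), lr=0.0025, betas=(0.9, 0.99))
    solp = ForeachSOLP(model.parameters(), lr=3e-3, shampoo_beta=0.95,
                       precondition_frequency=2)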
@@ -163,7 +205,6 @@ class OrthoLaProp(C.BaseOpt):
163
205
  C.orthogonalize_grad_to_param, C.scale_by_laprop)
164
206
 
165
207
 
166
-
167
208
  class LaPropOrtho(C.BaseOpt):
168
209
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
169
210
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -172,8 +213,8 @@ class LaPropOrtho(C.BaseOpt):
172
213
  defaults = locals()
173
214
  defaults.pop("self")
174
215
  params = defaults.pop("params")
175
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
176
- C.scale_by_laprop, C.orthogonalize_grad_to_param)
216
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
217
+ C.orthogonalize_grad_to_param)
177
218
 
178
219
 
179
220
  class ForeachPSGDKron(C.BaseOpt):
@@ -191,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
191
232
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
192
233
  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
193
234
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
194
- stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
235
+ stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
195
236
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
196
237
  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
197
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
238
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
198
239
  # expert parameters
199
240
  precond_init_scale=1.0, precond_lr=0.1):
200
241
  defaults = locals()
@@ -251,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
251
292
  CachedPSGDKron = ForeachCachedPSGDKron
252
293
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
253
294
  Muon = ForeachMuon
295
+ SignLaProp = ForeachSignLaProp
254
296
 
255
297
  __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
256
298
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
257
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
258
- "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
259
- "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
260
- "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
299
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
300
+ "ForeachAdamW", "ForeachSFAdamW",
301
+ "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
302
+ "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
303
+ 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
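The alias table gains a short name for the new optimizer, matching the existing pattern (Muon, CachedPSGDKron, and so on); a quick check, assuming the 1.6.0 wheel is importable:

    import heavyball
    assert heavyball.SignLaProp is heavyball.ForeachSignLaProp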
heavyball/chainable.py CHANGED
@@ -127,7 +127,7 @@ def zero_guard(*names):
127
127
 
128
128
 
129
129
  def copy_guard(index, *names):
130
- return functools.partial(CopyGuard, index=index, names=names, )
130
+ return functools.partial(CopyGuard, index=index, names=names)
131
131
 
132
132
 
133
133
  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -222,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
222
222
  @no_state
223
223
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
224
224
  if group['step'] == 1:
225
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
225
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
226
226
  raise SkipUpdate
227
227
 
228
228
  if group['step'] == 2:
229
229
  update = utils.promote(update)
230
230
  easq = utils.promote(exp_avg_sq)
231
231
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
232
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
232
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
233
+ group['eps'])
233
234
  raise SkipUpdate
234
235
 
235
236
  utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -241,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
241
242
  @no_state
242
243
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
243
244
  if group['step'] == 1:
244
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
245
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
245
246
  raise SkipUpdate
246
247
 
247
248
  if group['step'] == 2:
248
249
  update = utils.promote(update)
249
250
  easq = utils.promote(exp_avg_sq)
250
251
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
251
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
252
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
253
+ group['eps'])
252
254
  raise SkipUpdate
253
255
 
254
256
  return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
255
257
 
256
258
 
257
- def _init_soap(state, group, update, grad, param):
258
- utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
259
+ def _init_soap(state, group, update, grad, param, inner: str = ''):
260
+ utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])
259
261
 
260
262
 
261
263
  def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
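Both ADOPT hunks replace the raw exp_avg_sq_ accumulation with scale_by_exp_avg_sq_, so the first two steps now use group['eps'] rather than a hard-coded 1. The warm-up control flow is easier to see in a single-tensor sketch; this is a simplified rendering (no foreach lists, no compilation, no beta_debias bias correction), and the helper name adopt_scale is illustrative:

    import torch

    def adopt_scale(step, grad, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, eps=1e-8):
        # step 1: seed the second moment from the first gradient, emit no update
        if step == 1:
            exp_avg_sq.copy_(grad * grad)
            return None
        # step 2: seed the first moment with the normalized gradient, still no update
        if step == 2:
            exp_avg.copy_(grad / exp_avg_sq.sqrt().clamp(min=eps))
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            return None
        # step >= 3: normalize by the previous second moment, then update both moments
        update = grad / exp_avg_sq.sqrt().clamp(min=eps)
        exp_avg.lerp_(update, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        return exp_avg.clone()

    m, v = torch.zeros(3), torch.zeros(3)
    for step in range(1, 5):
        out = adopt_scale(step, torch.randn(3), m, v)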
@@ -343,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
343
345
  @general_guard("Q", "GG", init_fn=_init_soap)
344
346
  @no_state
345
347
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
346
- update = utils.promote(update)
348
+ update = utils.promote(update) # Promote to highest precision if needed
347
349
 
348
350
  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
349
351
  fn = _optim_fns[inner]
350
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
352
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
353
+ group['eps'])
351
354
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
352
355
 
353
- for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
354
- utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
356
+ for u, q, gg, ea in zip(update, Q, GG, exp_avg):
357
+ utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
355
358
  utils.beta_debias(group['shampoo_beta'], group['step']),
356
359
  group['is_preconditioning'])
357
360
  return precond
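scale_by_soap now feeds the inner optimizer group['step'] - 1 and rotates exp_avg (rather than exp_avg_sq) when refreshing the preconditioner. Schematically, for a single 2-D parameter, the projected-update path looks like the sketch below; the real utils.project handles arbitrary merged shapes via einsum, and inner_step is a stand-in for the Adam/LaProp call:

    import torch

    def soap_like_step(update, q_left, q_right, inner_step):
        # rotate into the eigenbasis, precondition there, rotate back
        projected = q_left.T @ update @ q_right
        preconditioned = inner_step(projected)
        return q_left @ preconditioned @ q_right.T

    u = torch.randn(6, 4)
    ql, _ = torch.linalg.qr(torch.randn(6, 6))
    qr_, _ = torch.linalg.qr(torch.randn(4, 4))
    out = soap_like_step(u, ql, qr_, lambda x: x.sign())  # trivial stand-in inner step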
@@ -503,7 +506,7 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
503
506
  def _branch(state, group, update, grad, param):
504
507
  outputs = []
505
508
  for branch in branches:
506
- branch_update = [torch.clone(g, memory_format=torch.preserve_format) for u in update]
509
+ branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
507
510
  branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
508
511
  if skip_update:
509
512
  raise ValueError("Branches should not skip updates")
heavyball/utils.py CHANGED
@@ -1,24 +1,40 @@
1
+ import copy
1
2
  import functools
2
3
  import gc
4
+ import inspect
3
5
  import math
4
6
  import random
5
7
  import string
8
+ import sys
9
+ import time
6
10
  import warnings
11
+ from datetime import datetime
7
12
  from typing import List, Optional, Tuple, Callable, Union
13
+ from unittest.mock import patch
8
14
 
15
+ import hyperopt
9
16
  import numpy as np
10
17
  import torch
11
18
  from torch import Tensor
19
+ from torch._dynamo import config
12
20
  from torch._dynamo.exc import TorchDynamoException
13
21
  from torch.backends import cudnn, opt_einsum
14
22
  from torch.utils._pytree import tree_map
15
23
 
24
+ config.cache_size_limit = 2 ** 16
25
+
26
+ np.warnings = warnings
27
+
16
28
  compile_mode = "max-autotune-no-cudagraphs"
17
29
  dynamic = False
18
30
  compile_mode_recommended_to_none = None
19
31
  zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
20
32
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
21
33
 
34
+ base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
35
+ 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
36
+ 'weight_decay': 1e-4}
37
+
22
38
 
23
39
  def decorator(func):
24
40
  compiled = None
@@ -51,7 +67,7 @@ def decorator_knowngood(func: Callable):
51
67
  return _fn
52
68
 
53
69
 
54
- einsum_base = string.ascii_lowercase + string.ascii_uppercase
70
+ einsum_base = string.ascii_lowercase
55
71
 
56
72
 
57
73
  @decorator_knowngood
@@ -103,29 +119,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
103
119
  but we want to merge conv kernels into fan-in or at least merge the kernel
104
120
  so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
105
121
  would've done
122
+
123
+ By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
106
124
  """
107
- shape = grad.shape
108
125
  new_shape = []
109
-
110
- curr_shape = 1
111
-
112
- for sh in shape[1:][::-1]:
113
- temp_shape = curr_shape * sh
114
- if temp_shape > max_precond_dim:
115
- if curr_shape > 1:
116
- new_shape.append(curr_shape)
117
- curr_shape = sh
126
+ cum_size = 1
127
+
128
+ for s in grad.shape[1:][::-1]:
129
+ temp_size = cum_size * s
130
+ if temp_size > max_precond_dim:
131
+ if cum_size > 1:
132
+ new_shape.append(cum_size)
133
+ cum_size = s
118
134
  else:
119
- new_shape.append(sh)
120
- curr_shape = 1
135
+ new_shape.append(s)
136
+ cum_size = 1
121
137
  else:
122
- curr_shape = temp_shape
123
- new_shape = [*shape[:1], *new_shape[::-1]]
138
+ cum_size = temp_size
124
139
 
125
- if curr_shape > 1 or len(new_shape) == 0:
126
- new_shape.append(curr_shape)
140
+ if cum_size > 1:
141
+ new_shape.append(cum_size)
127
142
 
128
- new_grad = grad.reshape(new_shape) # needs to be .reshape() due to channels_last
143
+ new_shape = [grad.shape[0], *new_shape[::-1]]
144
+ new_grad = grad.reshape(new_shape)
129
145
  if not split:
130
146
  return new_grad
131
147
 
@@ -153,12 +169,11 @@ def beta_debias(beta, step):
153
169
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
154
170
  out: List[Optional[Tensor]]):
155
171
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
156
- s32 = torch._foreach_mul(s32, beta2)
157
- s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
158
- denom = torch._foreach_sqrt(s32)
159
- denom = [d.clamp(min=eps) for d in denom]
172
+ s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
160
173
  copy_stochastic_list_(state, s32)
161
174
 
175
+ denom = [d.sqrt().clamp(min=eps) for d in s32]
176
+
162
177
  if out[0] is None:
163
178
  return denom
164
179
 
@@ -174,7 +189,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
174
189
 
175
190
  @decorator_knowngood
176
191
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
177
- g32 = promote(grad)
192
+ g32 = list(map(promote, grad))
178
193
  denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
179
194
  out = torch._foreach_div(g32, denom)
180
195
  copy_stochastic_list_(grad, out)
@@ -189,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
189
204
 
190
205
  @decorator_knowngood
191
206
  def _compilable_exp_avg_(state, grad, beta):
192
- lerped = _lerp32(state, grad, beta)
207
+ lerped = _lerp(state, grad, beta)
193
208
  copy_stochastic_list_(grad, lerped)
194
209
 
195
210
 
@@ -240,29 +255,31 @@ def clean():
240
255
  gc.collect()
241
256
 
242
257
 
243
- def set_torch():
258
+ def _ignore_warning(msg):
259
+ warnings.filterwarnings('ignore', f'.*{msg}.*')
260
+
261
+
262
+ def set_torch(benchmark_limit: int = 32):
244
263
  cudnn.benchmark = True
245
264
  cudnn.deterministic = False
265
+ cudnn.benchmark_limit = benchmark_limit
246
266
  torch.use_deterministic_algorithms(False)
247
267
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
248
268
  opt_einsum.enabled = True
249
- opt_einsum.strategy = "auto-hq"
269
+ opt_einsum.strategy = "dp"
270
+
271
+ # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
272
+ _ignore_warning(
273
+ 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
274
+ _ignore_warning(
275
+ 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')
250
276
 
251
277
 
252
278
  @decorator
253
279
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
254
- """
255
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
256
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
257
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
258
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
259
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
260
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
261
- performance at all relative to UV^T, where USV^T = G is the SVD.
262
- """
263
280
  assert len(G.shape) == 2
264
281
  a, b, c = (3.4445, -4.7750, 2.0315)
265
- X = G.bfloat16()
282
+ X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
266
283
  X /= (X.norm() + eps) # ensure top singular value <= 1
267
284
  if G.size(0) > G.size(1):
268
285
  X = X.T
@@ -319,7 +336,7 @@ def nesterov_momentum(state, grad, beta):
319
336
 
320
337
  @decorator_knowngood
321
338
  def _compilable_nesterov_ema_(state, grad, beta):
322
- ema32 = _lerp32(state, grad, beta)
339
+ ema32 = _lerp(state, grad, beta)
323
340
  stochastic_add_(grad, ema32, 1)
324
341
 
325
342
 
@@ -330,12 +347,11 @@ def nesterov_ema(state, grad, beta):
330
347
  return grad
331
348
 
332
349
 
350
+ @decorator_knowngood
333
351
  def _compilable_grafting(magnitude, direction):
334
352
  return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
335
353
 
336
354
 
337
- # mode in ("newtonschulz", "qr", "svd")
338
- # scale_mode in ("none", "scale", "graft")
339
355
  @decorator_knowngood
340
356
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
341
357
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
@@ -363,74 +379,82 @@ def _compilable_scatter_set(target, source, index):
363
379
  target[:] = source.contiguous()[index].reshape_as(target)
364
380
 
365
381
 
366
- def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
382
+ @decorator_knowngood
383
+ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
367
384
  """
368
385
  Computes the eigenbases of the preconditioner using one round of power iteration
369
- followed by torch.linalg.qr decomposition.
386
+ followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
387
+
388
+ :param GG: List of accumulated gradient outer products.
389
+ :param Q: List of current eigenbases (updated in-place to Q_new).
390
+ :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
370
391
  """
371
- matrix = []
372
- orth_matrix = []
373
- for m, o in zip(GG, Q):
374
- if len(m) == 0:
375
- matrix.append([])
376
- orth_matrix.append([])
377
- continue
378
- if m.data.dtype != torch.float:
379
- matrix.append(promote(m.data))
380
- orth_matrix.append(promote(o.data))
381
- else:
382
- matrix.append(promote(m.data))
383
- orth_matrix.append(promote(o.data))
392
+ if isinstance(Q, list) and not Q:
393
+ return
394
+
395
+ if exp_avg is not None and exp_avg.dim() != len(Q):
396
+ raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")
384
397
 
385
- indices = []
398
+ new_qs = []
386
399
 
387
- for ind, (m, o, q) in enumerate(zip(matrix, orth_matrix, Q)):
400
+ for m, q in zip(GG, Q):
388
401
  if len(m) == 0:
389
- indices.append(None)
390
402
  continue
391
403
 
392
- tmp = m @ o
393
- est_eig = torch.einsum('ij,ij->j', o, tmp)
404
+ m = promote(m.data)
405
+ q_old = promote(q.data)
406
+
407
+ tmp = m @ q_old
408
+ est_eig = torch.einsum('ij,ij->j', q_old, tmp)
394
409
  sort_idx = torch.argsort(est_eig, descending=True)
395
- indices.append(sort_idx)
396
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")
397
410
 
398
- indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
399
- for i, ind in enumerate(indices))
400
- _compilable_scatter_set(exp_avg_sq, exp_avg_sq, indices)
411
+ tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
412
+ new_qs.append(tmp)
413
+
414
+ if exp_avg is None:
415
+ for q, q_new in zip(Q, new_qs):
416
+ copy_stochastic_(q, q_new)
417
+ return
418
+
419
+ assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
420
+ in_str = einsum_base[:exp_avg.dim()]
421
+ out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
422
+
423
+ from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
424
+ if not from_shampoo:
425
+ return
426
+
427
+ to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
428
+ out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
429
+
430
+ subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
431
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
432
+ copy_stochastic_(exp_avg, exp_avg_new)
433
+
434
+ for q, q_new in zip(Q, new_qs):
435
+ copy_stochastic_(q, q_new)
401
436
 
402
437
 
403
438
  def get_orthogonal_matrix(mat):
404
439
  """
405
440
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
406
441
  """
407
- matrix = []
408
- for m in mat:
409
- if len(m) == 0:
410
- matrix.append([])
411
- continue
412
- if m.data.dtype != torch.float:
413
- float_data = False
414
- original_type = m.data.dtype
415
- original_device = m.data.device
416
- matrix.append(promote(m.data))
417
- else:
418
- float_data = True
419
- matrix.append(m.data)
420
442
 
421
443
  final = []
422
- for m in matrix:
444
+ for m in mat:
423
445
  if len(m) == 0:
424
446
  final.append([])
425
447
  continue
426
448
 
449
+ m = promote(m.data)
450
+
427
451
  device, dtype = m.device, m.dtype
428
452
  for modifier in (None, torch.double, 'cpu'):
429
453
  if modifier is not None:
430
454
  m = m.to(modifier)
431
455
  try:
432
- Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
433
- dtype=dtype)
456
+ eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
457
+ eigvec = eigvec.to(device=device, dtype=dtype)
434
458
  break
435
459
  except torch.OutOfMemoryError:
436
460
  pass
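For a 2-D tensor with both dimensions preconditioned, the einsum assembled above becomes 'ab,Aa,Bb,Ac,Bd->cd' under the einsum_base letter convention, which is just two changes of basis: out of the old eigenbases and into the freshly computed ones. A numerical check with stand-in orthogonal matrices (q_old and q_new are placeholders for the current and new eigenbases):

    import torch

    d0, d1 = 5, 3
    exp_avg = torch.randn(d0, d1)
    q_old = [torch.linalg.qr(torch.randn(d0, d0))[0], torch.linalg.qr(torch.randn(d1, d1))[0]]
    q_new = [torch.linalg.qr(torch.randn(d0, d0))[0], torch.linalg.qr(torch.randn(d1, d1))[0]]

    rotated = torch.einsum('ab,Aa,Bb,Ac,Bd->cd', exp_avg, *q_old, *q_new)
    reference = q_new[0].T @ (q_old[0] @ exp_avg @ q_old[1].T) @ q_new[1]
    assert torch.allclose(rotated, reference, atol=1e-5)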
@@ -440,9 +464,9 @@ def get_orthogonal_matrix(mat):
440
464
  else:
441
465
  raise RuntimeError("Failed to compute eigenvalues.")
442
466
 
443
- Q = torch.flip(Q, [1])
467
+ eigvec = torch.flip(eigvec, [1])
444
468
 
445
- final.append(Q)
469
+ final.append(eigvec)
446
470
 
447
471
  return final
448
472
 
@@ -467,7 +491,7 @@ def get_beta1(group):
467
491
 
468
492
 
469
493
  def get_beta2(group):
470
- if 'beta2_scale' in group:
494
+ if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
471
495
  step = max(group.get("step", 1), 1)
472
496
  return 1 - step ** -group['beta2_scale']
473
497
  if 'betas' in group:
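With the added palm gate, the 1 - step ** -beta2_scale schedule no longer overrides betas for every optimizer that happens to carry beta2_scale in its defaults; it only applies when palm is enabled. For reference, the values it produces with the default beta2_scale of 0.8:

    for step in (1, 10, 100, 1000):
        print(step, 1 - step ** -0.8)  # ~0.0, ~0.84, ~0.97, ~0.996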
@@ -536,20 +560,32 @@ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
536
560
 
537
561
 
538
562
  @decorator
539
- def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
563
+ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
564
+ """
565
+ Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
566
+ Re-commited due to faulty merge
567
+ """
540
568
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
541
569
  return
542
570
 
543
- for idx, sh in enumerate(grad.shape):
544
- if sh > max_precond_dim:
571
+ for idx, m in enumerate(GG):
572
+ if not isinstance(m, Tensor):
545
573
  continue
546
574
  b = einsum_base[idx]
547
575
  g0 = einsum_base[:grad.dim()]
548
576
  g1 = g0.replace(b, b.upper())
549
577
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
550
- GG[idx].lerp_(outer_product, 1 - beta)
578
+ m.lerp_(outer_product, 1 - beta)
579
+
580
+
581
+ def tree_apply(fn):
582
+ def _fn(*args):
583
+ return tree_map(fn, *args)
584
+
585
+ return _fn
551
586
 
552
587
 
588
+ @tree_apply
553
589
  def promote(x):
554
590
  if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
555
591
  return torch.float32
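update_ggt (renamed from compute_ggt) now iterates over the GG entries themselves and skips non-tensor placeholders instead of re-checking dimension sizes. The per-dimension einsum it accumulates is an ordinary gradient outer product; for a 2-D gradient the two subscripts it builds reduce to familiar matrix products:

    import torch

    g = torch.randn(4, 3)
    left = torch.einsum('ab,Ab->aA', g, g)   # dim-0 outer product, equals g @ g.T
    right = torch.einsum('ab,aB->bB', g, g)  # dim-1 outer product, equals g.T @ g
    assert torch.allclose(left, g @ g.T, atol=1e-6)
    assert torch.allclose(right, g.T @ g, atol=1e-6)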
@@ -566,45 +602,38 @@ def min_dtype(xs: List[Tensor]):
566
602
  return torch.float32
567
603
 
568
604
 
569
- def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
605
+ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
570
606
  """
571
607
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
572
608
  """
573
- compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
609
+ update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
574
610
  if update_precond:
575
- get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)
611
+ get_orthogonal_matrix_QR(GG, Q, exp_avg)
576
612
 
577
613
 
578
- def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
614
+ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
579
615
  """
580
616
  Initializes the preconditioner matrices (L and R in the paper).
581
617
  """
582
618
  state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
583
- if grad.dim() == 1:
584
- if precondition_1d or grad.shape[0] > max_precond_dim:
585
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
586
- else:
587
- state['GG'].append([])
588
-
589
- else:
619
+ if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
590
620
  for sh in grad.shape:
591
621
  if sh > max_precond_dim:
592
- state['GG'].append([])
622
+ state['GG'].append(None)
593
623
  else:
594
624
  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
625
+ else:
626
+ state['GG'].append(None)
595
627
 
596
- compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
628
+ update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
597
629
  state['Q'] = get_orthogonal_matrix(state['GG'])
598
630
 
599
631
 
600
632
  @decorator
601
633
  def project(grad, Q, back: bool):
602
634
  """
603
-
604
635
  :param grad:
605
636
  :param Q:
606
- :param merge_dims:
607
- :param max_precond_dim:
608
637
  :param back: whether to project to Shampoo eigenbases or back to original space
609
638
  :return:
610
639
  """
@@ -617,12 +646,40 @@ def project(grad, Q, back: bool):
617
646
  return grad
618
647
 
619
648
 
649
+ def modify_closure(closure):
650
+ """
651
+ Modifies the closure function to use create_graph=True in backward().
652
+
653
+ Args:
654
+ closure: The closure function passed to the optimizer.
655
+
656
+ Returns:
657
+ The return value of the modified closure.
658
+ """
659
+
660
+ def patched_backward(self, *args, **kwargs):
661
+ kwargs['create_graph'] = True
662
+ return original_backward(self, *args, **kwargs)
663
+
664
+ original_backward = torch.Tensor.backward
665
+
666
+ with patch.object(torch.Tensor, 'backward', patched_backward):
667
+ return closure()
668
+
669
+
620
670
  class StatefulOptimizer(torch.optim.Optimizer):
671
+ """
672
+ finite_differences saves memory, but needs more compute. (Alternative is true HVP)
673
+ Both `True` and `False` have some edge cases they don't support, so experiment with it.
674
+ The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
675
+ Further notice that both methods have different numerics outputs
676
+ """
621
677
  ema_decay: float = 0.001
622
678
  compile_step: bool = False
623
679
  hessian_approx: bool = False
624
680
  precond_schedule: Union[Callable, float, None] = None
625
681
  stochastic_schedule: bool = False
682
+ finite_differences: bool = False
626
683
 
627
684
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
628
685
  super().__init__(params, {**defaults, 'foreach': foreach})
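modify_closure and the finite_differences flag cover the two ways the optimizer can obtain Hessian-vector products: a true double-backward HVP (the closure re-run with create_graph=True) or a finite difference of gradients around a perturbed parameter. A self-contained comparison on a toy quartic loss; the step size eps and the loss itself are illustrative, and the library's own finite-difference path perturbs by tiny_bf16 times a random vector and keeps the undivided difference:

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)
    v = torch.randn(4)

    # exact Hessian-vector product via double backward (needs create_graph=True)
    loss = (w ** 4).sum()
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    (hv_exact,) = torch.autograd.grad(g, w, grad_outputs=v)

    # finite-difference estimate: (grad(w + eps*v) - grad(w)) / eps
    def grad_at(x):
        x = x.detach().requires_grad_(True)
        (gx,) = torch.autograd.grad((x ** 4).sum(), x)
        return gx

    eps = 1e-3
    hv_fd = (grad_at(w + eps * v) - grad_at(w)) / eps
    assert torch.allclose(hv_exact, hv_fd, rtol=1e-2, atol=1e-2)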
@@ -735,38 +792,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
735
792
  set_(self.state_(p)['param_ema'], p.data)
736
793
  set_(p.data, ema_clone)
737
794
 
738
- def step(self, closure: Optional[Callable] = None):
739
- if self.precond_schedule is None:
740
- self._is_preconditioning = False
741
- else:
742
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
795
+ def _handle_closure(self, closure):
743
796
  hessian_approx = self.hessian_approx and self._is_preconditioning
797
+
744
798
  if closure is None:
745
799
  if hessian_approx:
746
800
  raise ValueError("Hessian approximation requires a closure.")
747
- loss = None
748
- else:
801
+ return None
802
+
803
+ if not hessian_approx:
749
804
  with torch.enable_grad():
750
805
  loss = closure()
751
- if hessian_approx:
752
- grads = []
753
- for group in self.param_groups:
754
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
755
- grads.append(g)
756
- p.vector = torch.randn_like(p)
757
- p.orig = p.data.clone()
758
- stochastic_add_(p.data, p.vector, tiny_bf16)
759
-
760
- with torch.enable_grad():
761
- closure()
762
-
763
- for group in self.param_groups:
764
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
765
- p.grad = grads.pop(0)
766
- stochastic_add_(g, p.grad, -1)
767
- p.hessian_vector = g
768
- p.data.copy_(p.orig)
769
- del p.orig
806
+ return loss
807
+
808
+ if self.finite_differences:
809
+ with torch.enable_grad():
810
+ loss = closure() # closure without retain_graph=True
811
+
812
+ grads = []
813
+ for group in self.param_groups:
814
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
815
+ grads.append(g)
816
+ p.vector = torch.randn_like(p)
817
+ p.orig = p.data.clone()
818
+ stochastic_add_(p.data, p.vector, tiny_bf16)
819
+ else:
820
+ with torch.enable_grad():
821
+ loss = modify_closure(closure)
822
+
823
+ if self.finite_differences:
824
+ with torch.enable_grad():
825
+ closure()
826
+
827
+ for group in self.param_groups:
828
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
829
+ p.grad = grads.pop(0)
830
+ stochastic_add_(g, p.grad, -1)
831
+ p.hessian_vector = g
832
+ p.data.copy_(p.orig)
833
+ del p.orig
834
+ else:
835
+ for group in self.param_groups:
836
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
837
+ p.grad = g
838
+ params, grads = zip(*[x for group in self.param_groups for x in
839
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
840
+ vs = [torch.randn_like(p) for p in params]
841
+ with torch.enable_grad():
842
+ hvs = torch.autograd.grad(grads, params, vs)
843
+
844
+ for p, g, v, hv in zip(params, grads, vs, hvs):
845
+ p.hessian_vector = hv
846
+ p.grad = g
847
+ p.vector = v
848
+
849
+ return loss
850
+
851
+ def step(self, closure: Optional[Callable] = None):
852
+ if self.precond_schedule is None:
853
+ self._is_preconditioning = False
854
+ else:
855
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
856
+ loss = self._handle_closure(closure)
770
857
 
771
858
  # we assume that parameters are constant and that there are no excessive recompiles
772
859
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
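step() now routes the closure through _handle_closure, but plain closure-based training is unchanged. A hedged sketch, assuming any heavyball optimizer from __all__ (ForeachPSGDKron here) with an illustrative model and learning rate; the finite_differences switch above only matters for the hessian_approx variants such as the NewtonPSGD optimizers:

    import torch
    from heavyball import ForeachPSGDKron

    model = torch.nn.Linear(8, 1)
    data, target = torch.randn(32, 8), torch.randn(32, 1)
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3)

    def closure():
        opt.zero_grad()
        loss = (model(data) - target).square().mean()
        loss.backward()
        return loss

    for _ in range(3):
        opt.step(closure)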
@@ -774,7 +861,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
774
861
  group['is_preconditioning'] = self._is_preconditioning
775
862
  self._step(group)
776
863
  if self.use_ema:
777
- self.ema_update(group)
864
+ self.ema_update()
778
865
 
779
866
  return loss
780
867
 
@@ -784,12 +871,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
784
871
  copy_stochastic_(t, s)
785
872
 
786
873
 
787
- def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
874
+ @decorator_knowngood
875
+ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
788
876
  ea32 = list(map(promote, state))
789
877
  grad = list(map(promote, grad))
790
878
  beta = promote(beta)
791
-
792
- ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
879
+ ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
793
880
  copy_stochastic_list_(state, ea32)
794
881
  return ea32
795
882
 
@@ -801,10 +888,9 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
801
888
  beta2 = beta_debias(beta2, step)
802
889
 
803
890
  g32 = list(map(promote, grad))
804
-
805
- exp_avg32 = _lerp32(exp_avg, g32, beta1)
806
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
807
- u32 = torch._foreach_div(exp_avg32, denom)
891
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
892
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
893
+ u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
808
894
  copy_stochastic_list_(grad, u32)
809
895
 
810
896
 
@@ -824,9 +910,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
824
910
  beta2 = beta_debias(beta2, step)
825
911
 
826
912
  u32, g32 = [list(map(promote, x)) for x in [update, grad]]
827
-
828
- exp_avg32 = _lerp32(exp_avg, u32, beta1)
829
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
913
+ exp_avg32 = _lerp(exp_avg, u32, beta1)
914
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
830
915
  u32 = torch._foreach_div(exp_avg32, denom)
831
916
  _compilable_update_(y, u32, decay, lr, caution, g32)
832
917
 
@@ -836,7 +921,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
836
921
  caution: bool):
837
922
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
838
923
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
839
- return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
924
+ _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
840
925
 
841
926
 
842
927
  @decorator_knowngood
@@ -846,17 +931,16 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
846
931
  beta2 = beta_debias(beta2, step)
847
932
 
848
933
  gp32 = list(map(promote, grad))
849
-
850
- denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
934
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
851
935
  gp32 = torch._foreach_div(gp32, denom)
852
- gp32 = _lerp32(exp_avg, gp32, beta1)
853
-
936
+ gp32 = _lerp(exp_avg, gp32, beta1)
854
937
  copy_stochastic_list_(grad, gp32)
855
938
 
856
939
 
857
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
940
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
941
+ eps: float = 1e-8):
858
942
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
859
- beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0], eps)
943
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
860
944
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
861
945
  return grad
862
946
 
@@ -864,23 +948,23 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
864
948
  @decorator_knowngood
865
949
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
866
950
  grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
867
- caution: bool):
951
+ caution: bool, eps: Tensor):
868
952
  beta1 = beta_debias(beta1, step)
869
953
  beta2 = beta_debias(beta2, step)
870
954
 
871
955
  u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
872
-
873
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
956
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
874
957
  u32 = torch._foreach_div(u32, denom)
875
- u32 = _lerp32(exp_avg, u32, beta1)
958
+ u32 = _lerp(exp_avg, u32, beta1)
876
959
  _compilable_update_(y, u32, decay, lr, caution, gp32)
877
960
 
878
961
 
879
962
  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
880
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
963
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
964
+ eps: float = 1e-8):
881
965
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
882
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
883
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
966
+ beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
967
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
884
968
 
885
969
 
886
970
  @decorator_knowngood
@@ -917,7 +1001,7 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
917
1001
  copy_stochastic_list_(exp_avg, exp_avg32)
918
1002
 
919
1003
  beta2 = beta_debias(beta2, step + 1)
920
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
1004
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
921
1005
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
922
1006
 
923
1007
  copy_stochastic_list_(grad, update)
@@ -1005,7 +1089,6 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1005
1089
  reusable einsum expressions for updating Q and preconditioning gradient.
1006
1090
  """
1007
1091
  letters = string.ascii_lowercase + string.ascii_uppercase
1008
-
1009
1092
  dtype = dtype if dtype is not None else t.dtype
1010
1093
  shape = t.shape
1011
1094
 
@@ -1049,11 +1132,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1049
1132
  piece1A.append(letters[i])
1050
1133
  piece2A = piece2A + letters[i]
1051
1134
  piece3A = piece3A + letters[i]
1052
-
1053
1135
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1054
1136
  subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1055
1137
  exprGs.append(subscripts)
1056
-
1057
1138
  piece1P.append(letters[i + 13])
1058
1139
  piece2P.append(letters[i + 13])
1059
1140
  piece3P = piece3P + letters[i + 13]
@@ -1061,16 +1142,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1061
1142
  else:
1062
1143
  # use triangular matrix as preconditioner for this dim
1063
1144
  Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
1064
-
1065
1145
  piece1A.append(letters[i] + letters[i + 13])
1066
1146
  piece2A = piece2A + letters[i + 13]
1067
1147
  piece3A = piece3A + letters[i]
1068
-
1069
1148
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1070
1149
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1071
1150
  subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
1072
1151
  exprGs.append(subscripts)
1073
-
1074
1152
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1075
1153
  piece1P.append(a + b)
1076
1154
  piece2P.append(a + c)
@@ -1091,7 +1169,7 @@ def psgd_balance_Q(Q_in):
1091
1169
 
1092
1170
 
1093
1171
  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1094
- eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
1172
+ eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
1095
1173
  eps *= G.norm() / G.numel()
1096
1174
  G = G + torch.randn_like(G) * eps
1097
1175
  md = min_dtype(Q + [G])
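The jitter added to G before computing A and conjB is now scaled by the machine epsilon of G's own dtype rather than always float32's. The resulting sqrt(eps) factors differ substantially per dtype:

    import math
    import torch

    for dt in (torch.float32, torch.bfloat16, torch.float64):
        print(dt, math.sqrt(torch.finfo(dt).eps))  # ~3.5e-4, ~8.8e-2, ~1.5e-8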
@@ -1117,9 +1195,7 @@ def psgd_lb(A, max_abs):
1117
1195
  A /= max_abs
1118
1196
  a0 = torch.einsum('ij,ij->j', A, A)
1119
1197
  i = torch.argmax(a0)
1120
-
1121
1198
  x = torch.index_select(A, 1, i).flatten().contiguous()
1122
-
1123
1199
  x = torch.einsum('i,ij->j', x, A)
1124
1200
  x /= x.norm()
1125
1201
  x = torch.einsum('j,kj->k', x, A)
@@ -1132,15 +1208,12 @@ def psgd_lb(A, max_abs):
1132
1208
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1133
1209
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1134
1210
  exprA, exprGs, _ = exprs
1135
-
1136
1211
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1137
1212
 
1138
1213
  for q, exprG, o in zip(Q, exprGs, oq):
1139
1214
  term1 = promote(torch.einsum(exprG, A, A))
1140
1215
  term2 = promote(torch.einsum(exprG, conjB, conjB))
1141
-
1142
1216
  term1, term2 = term1 - term2, term1 + term2
1143
-
1144
1217
  term1 *= precond_lr
1145
1218
  norm = term2.norm(float('inf'))
1146
1219
  if q.dim() < 2:
@@ -1256,8 +1329,8 @@ def identity(x):
1256
1329
 
1257
1330
  @decorator_knowngood
1258
1331
  def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1259
- ema32 = _lerp32(ema, p, ema_decay)
1260
- _lerp32(p, ema32, 1 - weight_decay)
1332
+ ema32 = _lerp(ema, p, ema_decay)
1333
+ _lerp(p, ema32, 1 - weight_decay)
1261
1334
 
1262
1335
 
1263
1336
  def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
@@ -1267,10 +1340,10 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1267
1340
 
1268
1341
 
1269
1342
  @decorator_knowngood
1270
- def _compilable_l1_weight_decay_to_ema_(p, ema, ema_deacy, weight_decay):
1271
- ema32 = _lerp32(ema, p, ema_deacy)
1343
+ def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1344
+ ema32 = _lerp(ema, p, ema_decay)
1272
1345
  for p_, e_ in zip(p, ema32):
1273
- p32 = promote(p)
1346
+ p32 = promote(p_)
1274
1347
  p32 = p32 + (p32 - e_).sign() * weight_decay
1275
1348
  copy_stochastic_(p_, p32)
1276
1349
 
@@ -1447,7 +1520,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:
1447
1520
 
1448
1521
  if graft:
1449
1522
  out = _compilable_grafting(g, out)
1450
-
1451
1523
  copy_stochastic_(g, out)
1452
1524
 
1453
1525
 
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.5.3
3
+ Version: 1.6.0
4
4
  Summary: Efficient optimizers
5
- Home-page: https://github.com/clashluke/heavyball
6
- Author: Lucas Nestler
5
+ Home-page: https://github.com/HomebrewML/HeavyBall
6
+ Author: HeavyBall Authors
7
7
  Author-email: github.heavyball@nestler.sh
8
8
  License: BSD
9
9
  Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
300
300
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
301
301
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
302
302
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
303
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
303
+ correct_bias: bool = True, warmup_steps: int = 1,
304
304
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
305
305
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
306
306
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
324
324
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
325
325
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
326
326
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
327
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
328
327
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
329
328
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
330
329
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
931
930
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
932
931
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
933
932
  configurations.
934
-
@@ -0,0 +1,8 @@
1
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
2
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
3
+ heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
4
+ heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
+ heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
6
+ heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
+ heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
+ heavyball-1.6.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
2
- heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
3
- heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
4
- heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
6
- heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.5.3.dist-info/RECORD,,