heavyball 1.5.3__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the contents of the two publicly released package versions as they appear in their public registry; it is provided for informational purposes only.
heavyball/__init__.py CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
104
104
  Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
105
105
  https://arxiv.org/abs/2409.11321
106
106
  https://github.com/nikhilvyas/SOAP
107
-
108
- ScheduleFree:
109
- The Road Less Scheduled
110
- Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
111
- https://arxiv.org/abs/2405.15682
112
- https://github.com/facebookresearch/schedule_free
113
107
  """
114
108
  use_precond_schedule: bool = False
115
109
 
116
110
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
117
111
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
118
112
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
119
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
120
- split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
121
- mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
122
- beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
123
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
113
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
114
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
115
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
116
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
117
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
124
118
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
125
119
 
126
120
  defaults = locals()
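
The 1.6.1 signature above drops `data_format` and adds `storage_dtype` and `stochastic_schedule`. A minimal usage sketch against that signature, assuming heavyball 1.6.1 is installed (the model and data below are placeholders):

    import torch
    import torch.nn.functional as F
    import heavyball

    model = torch.nn.Linear(16, 4)
    opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3,
                                storage_dtype='float32',     # new in 1.6.1: presumably the dtype of optimizer state buffers
                                stochastic_schedule=False)   # new in 1.6.1

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()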
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
137
131
  C.scale_by_soap)
138
132
 
139
133
 
134
+ class ForeachSignLaProp(C.BaseOpt):
135
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
136
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
137
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
138
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
139
+ defaults = locals()
140
+ defaults.pop("self")
141
+ params = defaults.pop("params")
142
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
143
+
144
+
145
+ class ForeachSOLP(C.BaseOpt):
146
+ """
147
+ ForeachSOLP
148
+
149
+ Sources:
150
+ Baseline SOAP:
151
+ SOAP: Improving and Stabilizing Shampoo using Adam
152
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
153
+ https://arxiv.org/abs/2409.11321
154
+ https://github.com/nikhilvyas/SOAP
155
+ """
156
+ use_precond_schedule: bool = False
157
+
158
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
159
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
160
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
161
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
162
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
163
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
164
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
165
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
166
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
167
+
168
+ defaults = locals()
169
+ defaults.pop("self")
170
+ params = defaults.pop("params")
171
+
172
+ if use_precond_schedule:
173
+ del defaults['precondition_frequency']
174
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
175
+ else:
176
+ del defaults['precond_scheduler']
177
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
178
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
179
+ functools.partial(C.scale_by_soap, inner='laprop'))
180
+
181
+
140
182
  class PaLMForeachSOAP(ForeachSOAP):
141
183
  use_precond_schedule: bool = False
142
184
  palm: bool = True
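
Both additions reuse the existing chaining machinery: ForeachSignLaProp composes `C.scale_by_laprop` with `C.sign`, and ForeachSOLP is the SOAP class with LaProp (`inner='laprop'`) as the inner optimizer. A hedged instantiation sketch, assuming heavyball 1.6.1:

    import torch
    import heavyball

    params = [torch.nn.Parameter(torch.randn(32, 32))]

    # LaProp-normalized update followed by a sign step
    sign_laprop = heavyball.ForeachSignLaProp(params, lr=2.5e-3, betas=(0.9, 0.99))

    # SOAP preconditioning with LaProp instead of Adam as the inner step
    solp = heavyball.ForeachSOLP(params, lr=3e-3, precondition_frequency=2)

Note that ForeachSOLP is defined at module level but not listed in `__all__` below, so it is reachable as `heavyball.ForeachSOLP` though not via `from heavyball import *`; the `SignLaProp` alias is added further down.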
@@ -163,7 +205,6 @@ class OrthoLaProp(C.BaseOpt):
163
205
  C.orthogonalize_grad_to_param, C.scale_by_laprop)
164
206
 
165
207
 
166
-
167
208
  class LaPropOrtho(C.BaseOpt):
168
209
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
169
210
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -172,8 +213,8 @@ class LaPropOrtho(C.BaseOpt):
172
213
  defaults = locals()
173
214
  defaults.pop("self")
174
215
  params = defaults.pop("params")
175
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
176
- C.scale_by_laprop, C.orthogonalize_grad_to_param)
216
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
217
+ C.orthogonalize_grad_to_param)
177
218
 
178
219
 
179
220
  class ForeachPSGDKron(C.BaseOpt):
@@ -191,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
191
232
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
192
233
  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
193
234
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
194
- stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
235
+ stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
195
236
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
196
237
  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
197
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
238
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
198
239
  # expert parameters
199
240
  precond_init_scale=1.0, precond_lr=0.1):
200
241
  defaults = locals()
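
The default `stochastic_schedule` for ForeachPSGDKron flips from True to False here. As a rough illustration (not heavyball's implementation), the two schedules differ in whether the preconditioner refresh happens at a fixed cadence or by an independent coin flip with the same average rate:

    import random

    def should_update(step, prob, stochastic, rng):
        # illustrative only: decide whether to refresh the preconditioner this step
        if stochastic:
            return rng.random() < prob                       # random, prob on average
        return step % max(int(round(1 / prob)), 1) == 0      # deterministic, same average rate

    rng = random.Random(0)
    det = sum(should_update(s, 0.1, False, rng) for s in range(1, 101))
    sto = sum(should_update(s, 0.1, True, rng) for s in range(1, 101))
    print(det, sto)   # exactly 10 vs. roughly 10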
@@ -251,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
251
292
  CachedPSGDKron = ForeachCachedPSGDKron
252
293
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
253
294
  Muon = ForeachMuon
295
+ SignLaProp = ForeachSignLaProp
254
296
 
255
297
  __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
256
298
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
257
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
258
- "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
259
- "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
260
- "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
299
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
300
+ "ForeachAdamW", "ForeachSFAdamW",
301
+ "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
302
+ "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
303
+ 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
heavyball/chainable.py CHANGED
@@ -127,7 +127,7 @@ def zero_guard(*names):
127
127
 
128
128
 
129
129
  def copy_guard(index, *names):
130
- return functools.partial(CopyGuard, index=index, names=names, )
130
+ return functools.partial(CopyGuard, index=index, names=names)
131
131
 
132
132
 
133
133
  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -222,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
222
222
  @no_state
223
223
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
224
224
  if group['step'] == 1:
225
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
225
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
226
226
  raise SkipUpdate
227
227
 
228
228
  if group['step'] == 2:
229
229
  update = utils.promote(update)
230
230
  easq = utils.promote(exp_avg_sq)
231
231
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
232
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
232
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
233
+ group['eps'])
233
234
  raise SkipUpdate
234
235
 
235
236
  utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -241,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
241
242
  @no_state
242
243
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
243
244
  if group['step'] == 1:
244
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
245
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
245
246
  raise SkipUpdate
246
247
 
247
248
  if group['step'] == 2:
248
249
  update = utils.promote(update)
249
250
  easq = utils.promote(exp_avg_sq)
250
251
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
251
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
252
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
253
+ group['eps'])
252
254
  raise SkipUpdate
253
255
 
254
256
  return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
255
257
 
256
258
 
257
- def _init_soap(state, group, update, grad, param):
258
- utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
259
+ def _init_soap(state, group, update, grad, param, inner: str = ''):
260
+ utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])
259
261
 
260
262
 
261
263
  def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -343,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
343
345
  @general_guard("Q", "GG", init_fn=_init_soap)
344
346
  @no_state
345
347
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
346
- update = utils.promote(update)
348
+ update = utils.promote(update) # Promote to highest precision if needed
347
349
 
348
350
  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
349
351
  fn = _optim_fns[inner]
350
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
352
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
353
+ group['eps'])
351
354
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
352
355
 
353
- for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
354
- utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
356
+ for u, q, gg, ea in zip(update, Q, GG, exp_avg):
357
+ utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
355
358
  utils.beta_debias(group['shampoo_beta'], group['step']),
356
359
  group['is_preconditioning'])
357
360
  return precond
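
For a 2-D parameter, the `project` einsums used by `scale_by_soap` reduce to plain matrix products; the hunk above projects the update into the eigenbasis, runs the inner Adam/LaProp step there (now with `step - 1`), projects back, and then feeds `exp_avg` rather than `exp_avg_sq` to `update_preconditioner`. A sketch of the projection round-trip with illustrative orthogonal factors standing in for Q:

    import torch

    g = torch.randn(64, 32)                           # update in the original space
    q0 = torch.linalg.qr(torch.randn(64, 64))[0]      # stand-ins for the SOAP eigenbases
    q1 = torch.linalg.qr(torch.randn(32, 32))[0]

    g_proj = q0.T @ g @ q1                            # project(update, Q, back=False)
    inner = g_proj / g_proj.abs().clamp(min=1e-8)     # placeholder for the inner adam_/laprop_ step
    precond = q0 @ inner @ q1.T                       # project(inner, Q, back=True)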
@@ -503,7 +506,7 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
503
506
  def _branch(state, group, update, grad, param):
504
507
  outputs = []
505
508
  for branch in branches:
506
- branch_update = [torch.clone(g, memory_format=torch.preserve_format) for u in update]
509
+ branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
507
510
  branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
508
511
  if skip_update:
509
512
  raise ValueError("Branches should not skip updates")
heavyball/utils.py CHANGED
@@ -5,20 +5,30 @@ import random
5
5
  import string
6
6
  import warnings
7
7
  from typing import List, Optional, Tuple, Callable, Union
8
+ from unittest.mock import patch
8
9
 
9
10
  import numpy as np
10
11
  import torch
11
12
  from torch import Tensor
13
+ from torch._dynamo import config
12
14
  from torch._dynamo.exc import TorchDynamoException
13
15
  from torch.backends import cudnn, opt_einsum
14
16
  from torch.utils._pytree import tree_map
15
17
 
18
+ config.cache_size_limit = 2 ** 16
19
+
20
+ np.warnings = warnings
21
+
16
22
  compile_mode = "max-autotune-no-cudagraphs"
17
23
  dynamic = False
18
24
  compile_mode_recommended_to_none = None
19
25
  zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
20
26
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
21
27
 
28
+ base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
29
+ 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
30
+ 'weight_decay': 1e-4}
31
+
22
32
 
23
33
  def decorator(func):
24
34
  compiled = None
@@ -51,7 +61,7 @@ def decorator_knowngood(func: Callable):
51
61
  return _fn
52
62
 
53
63
 
54
- einsum_base = string.ascii_lowercase + string.ascii_uppercase
64
+ einsum_base = string.ascii_lowercase
55
65
 
56
66
 
57
67
  @decorator_knowngood
@@ -103,29 +113,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
103
113
  but we want to merge conv kernels into fan-in or at least merge the kernel
104
114
  so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
105
115
  would've done
116
+
117
+ By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
106
118
  """
107
- shape = grad.shape
108
119
  new_shape = []
109
-
110
- curr_shape = 1
111
-
112
- for sh in shape[1:][::-1]:
113
- temp_shape = curr_shape * sh
114
- if temp_shape > max_precond_dim:
115
- if curr_shape > 1:
116
- new_shape.append(curr_shape)
117
- curr_shape = sh
120
+ cum_size = 1
121
+
122
+ for s in grad.shape[1:][::-1]:
123
+ temp_size = cum_size * s
124
+ if temp_size > max_precond_dim:
125
+ if cum_size > 1:
126
+ new_shape.append(cum_size)
127
+ cum_size = s
118
128
  else:
119
- new_shape.append(sh)
120
- curr_shape = 1
129
+ new_shape.append(s)
130
+ cum_size = 1
121
131
  else:
122
- curr_shape = temp_shape
123
- new_shape = [*shape[:1], *new_shape[::-1]]
132
+ cum_size = temp_size
124
133
 
125
- if curr_shape > 1 or len(new_shape) == 0:
126
- new_shape.append(curr_shape)
134
+ if cum_size > 1:
135
+ new_shape.append(cum_size)
127
136
 
128
- new_grad = grad.reshape(new_shape) # needs to be .reshape() due to channels_last
137
+ new_shape = [grad.shape[0], *new_shape[::-1]]
138
+ new_grad = grad.reshape(new_shape)
129
139
  if not split:
130
140
  return new_grad
131
141
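
The rewritten loop above merges trailing dimensions from the right until the running product would exceed `max_precond_dim`. A standalone transcription of that logic (shapes only, no tensors) reproduces the docstring's example:

    def merge_dims(shape, max_precond_dim):
        # mirrors the 1.6.1 dim_merger above, operating on a plain shape tuple
        new_shape, cum_size = [], 1
        for s in shape[1:][::-1]:
            temp_size = cum_size * s
            if temp_size > max_precond_dim:
                if cum_size > 1:
                    new_shape.append(cum_size)
                    cum_size = s
                else:
                    new_shape.append(s)
                    cum_size = 1
            else:
                cum_size = temp_size
        if cum_size > 1:
            new_shape.append(cum_size)
        return [shape[0], *new_shape[::-1]]

    print(merge_dims((128, 64, 3, 3), 2048))   # [128, 576]
    print(merge_dims((128, 64, 3, 3), 64))     # [128, 64, 9]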
 
@@ -149,15 +159,17 @@ def beta_debias(beta, step):
149
159
  return 1 - (1 - beta) / (1 - beta ** step)
150
160
 
151
161
 
162
+ def eps_sqrt(item, eps):
163
+ return item.sqrt().clamp(min=eps)
164
+
165
+
152
166
  @decorator_knowngood
153
167
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
154
168
  out: List[Optional[Tensor]]):
155
- s32, g32 = [list(map(promote, x)) for x in (state, grad)]
156
- s32 = torch._foreach_mul(s32, beta2)
157
- s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
158
- denom = torch._foreach_sqrt(s32)
159
- denom = [d.clamp(min=eps) for d in denom]
160
- copy_stochastic_list_(state, s32)
169
+ g32 = promote(grad)
170
+ s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
171
+
172
+ denom = [eps_sqrt(d, eps) for d in s32]
161
173
 
162
174
  if out[0] is None:
163
175
  return denom
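
Taken together, `beta_debias`, `_lerp`, and `eps_sqrt` give the debiased exponential moving average of squared gradients followed by a clamped square root. A single-tensor sketch of what the compiled helpers above compute:

    import torch

    def beta_debias(beta, step):
        # same formula as above; shortens the effective horizon early in training
        return 1 - (1 - beta) / (1 - beta ** step)

    def eps_sqrt(item, eps):
        return item.sqrt().clamp(min=eps)

    state, grad = torch.zeros(4), torch.randn(4)
    beta2 = beta_debias(0.99, step=1)            # 0.0 at step 1, so state becomes grad ** 2
    state = state.lerp(grad * grad, 1 - beta2)
    denom = eps_sqrt(state, 1e-8)                # what Adam/LaProp later divide by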
@@ -189,7 +201,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
189
201
 
190
202
  @decorator_knowngood
191
203
  def _compilable_exp_avg_(state, grad, beta):
192
- lerped = _lerp32(state, grad, beta)
204
+ lerped = _lerp(state, grad, beta)
193
205
  copy_stochastic_list_(grad, lerped)
194
206
 
195
207
 
@@ -240,29 +252,31 @@ def clean():
240
252
  gc.collect()
241
253
 
242
254
 
243
- def set_torch():
255
+ def _ignore_warning(msg):
256
+ warnings.filterwarnings('ignore', f'.*{msg}.*')
257
+
258
+
259
+ def set_torch(benchmark_limit: int = 32):
244
260
  cudnn.benchmark = True
245
261
  cudnn.deterministic = False
262
+ cudnn.benchmark_limit = benchmark_limit
246
263
  torch.use_deterministic_algorithms(False)
247
264
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
248
- opt_einsum.enabled = True
249
- opt_einsum.strategy = "auto-hq"
265
+ opt_einsum.enabled = False
266
+ opt_einsum.strategy = "auto"
267
+
268
+ # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
269
+ _ignore_warning(
270
+ 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
271
+ _ignore_warning(
272
+ 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')
250
273
 
251
274
 
252
275
  @decorator
253
276
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
254
- """
255
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
256
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
257
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
258
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
259
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
260
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
261
- performance at all relative to UV^T, where USV^T = G is the SVD.
262
- """
263
277
  assert len(G.shape) == 2
264
278
  a, b, c = (3.4445, -4.7750, 2.0315)
265
- X = G.bfloat16()
279
+ X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
266
280
  X /= (X.norm() + eps) # ensure top singular value <= 1
267
281
  if G.size(0) > G.size(1):
268
282
  X = X.T
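
Only the head of `zeropower_via_newtonschulz5` is touched here (its docstring is dropped and the bfloat16 cast now preserves float64 inputs); the body is the familiar quintic Newton-Schulz orthogonalization. The sketch below uses the coefficients shown above but is written independently of heavyball's exact loop:

    import torch

    def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
        # approximate the orthogonal factor of G via a quintic Newton-Schulz iteration (illustrative)
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G / (G.norm() + eps)            # ensure the top singular value is <= 1
        transposed = G.size(0) > G.size(1)
        if transposed:
            X = X.T
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * A @ A) @ X
        return X.T if transposed else X

    Q = newton_schulz_orthogonalize(torch.randn(64, 32))
    print(torch.linalg.svdvals(Q))          # clustered near 1 rather than exactly 1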
@@ -319,7 +333,7 @@ def nesterov_momentum(state, grad, beta):
319
333
 
320
334
  @decorator_knowngood
321
335
  def _compilable_nesterov_ema_(state, grad, beta):
322
- ema32 = _lerp32(state, grad, beta)
336
+ ema32 = _lerp(state, grad, beta)
323
337
  stochastic_add_(grad, ema32, 1)
324
338
 
325
339
 
@@ -330,12 +344,11 @@ def nesterov_ema(state, grad, beta):
330
344
  return grad
331
345
 
332
346
 
347
+ @decorator_knowngood
333
348
  def _compilable_grafting(magnitude, direction):
334
349
  return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
335
350
 
336
351
 
337
- # mode in ("newtonschulz", "qr", "svd")
338
- # scale_mode in ("none", "scale", "graft")
339
352
  @decorator_knowngood
340
353
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
341
354
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
@@ -363,74 +376,84 @@ def _compilable_scatter_set(target, source, index):
363
376
  target[:] = source.contiguous()[index].reshape_as(target)
364
377
 
365
378
 
366
- def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
379
+ #@decorator_knowngood
380
+ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
367
381
  """
368
382
  Computes the eigenbases of the preconditioner using one round of power iteration
369
- followed by torch.linalg.qr decomposition.
383
+ followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
384
+
385
+ :param GG: List of accumulated gradient outer products.
386
+ :param Q: List of current eigenbases (updated in-place to Q_new).
387
+ :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
370
388
  """
371
- matrix = []
372
- orth_matrix = []
373
- for m, o in zip(GG, Q):
374
- if len(m) == 0:
375
- matrix.append([])
376
- orth_matrix.append([])
377
- continue
378
- if m.data.dtype != torch.float:
379
- matrix.append(promote(m.data))
380
- orth_matrix.append(promote(o.data))
381
- else:
382
- matrix.append(promote(m.data))
383
- orth_matrix.append(promote(o.data))
389
+ if isinstance(Q, list) and not Q:
390
+ return
384
391
 
385
- indices = []
392
+ if exp_avg is not None and exp_avg.dim() != len(Q):
393
+ raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")
386
394
 
387
- for ind, (m, o, q) in enumerate(zip(matrix, orth_matrix, Q)):
388
- if len(m) == 0:
389
- indices.append(None)
395
+ new_qs = []
396
+
397
+ for m, q in zip(GG, Q):
398
+ if m is None:
399
+ new_qs.append(None)
390
400
  continue
391
401
 
392
- tmp = m @ o
393
- est_eig = torch.einsum('ij,ij->j', o, tmp)
402
+ m = promote(m.data)
403
+ q_old = promote(q.data)
404
+
405
+ tmp = m @ q_old
406
+ est_eig = torch.einsum('ij,ij->j', q_old, tmp)
394
407
  sort_idx = torch.argsort(est_eig, descending=True)
395
- indices.append(sort_idx)
396
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")
397
408
 
398
- indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
399
- for i, ind in enumerate(indices))
400
- _compilable_scatter_set(exp_avg_sq, exp_avg_sq, indices)
409
+ tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
410
+ new_qs.append(tmp)
411
+
412
+ if exp_avg is None:
413
+ for q, q_new in zip(Q, new_qs):
414
+ copy_stochastic_(q, q_new)
415
+ return
416
+
417
+ assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
418
+ in_str = einsum_base[:exp_avg.dim()]
419
+ out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
420
+
421
+ from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
422
+ if not from_shampoo:
423
+ return
424
+
425
+ to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
426
+ out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
427
+
428
+ subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
429
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None])
430
+ copy_stochastic_(exp_avg, exp_avg_new)
431
+
432
+ for q, q_new in zip(Q, new_qs):
433
+ if q is not None:
434
+ copy_stochastic_(q, q_new)
401
435
 
402
436
 
403
437
  def get_orthogonal_matrix(mat):
404
438
  """
405
439
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
406
440
  """
407
- matrix = []
408
- for m in mat:
409
- if len(m) == 0:
410
- matrix.append([])
411
- continue
412
- if m.data.dtype != torch.float:
413
- float_data = False
414
- original_type = m.data.dtype
415
- original_device = m.data.device
416
- matrix.append(promote(m.data))
417
- else:
418
- float_data = True
419
- matrix.append(m.data)
420
441
 
421
442
  final = []
422
- for m in matrix:
423
- if len(m) == 0:
424
- final.append([])
443
+ for m in mat:
444
+ if m is None:
445
+ final.append(None)
425
446
  continue
426
447
 
448
+ m = promote(m.data)
449
+
427
450
  device, dtype = m.device, m.dtype
428
451
  for modifier in (None, torch.double, 'cpu'):
429
452
  if modifier is not None:
430
453
  m = m.to(modifier)
431
454
  try:
432
- Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
433
- dtype=dtype)
455
+ eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
456
+ eigvec = eigvec.to(device=device, dtype=dtype)
434
457
  break
435
458
  except torch.OutOfMemoryError:
436
459
  pass
@@ -440,9 +463,9 @@ def get_orthogonal_matrix(mat):
440
463
  else:
441
464
  raise RuntimeError("Failed to compute eigenvalues.")
442
465
 
443
- Q = torch.flip(Q, [1])
466
+ eigvec = torch.flip(eigvec, [1])
444
467
 
445
- final.append(Q)
468
+ final.append(eigvec)
446
469
 
447
470
  return final
448
471
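
For a 2-D state tensor the einsum assembled above is two basis changes: carry `exp_avg` from the old eigenbasis back to the original space with the old Q factors, then into the refreshed basis with the new ones. A sketch with illustrative orthogonal matrices, checked against the same subscripts:

    import torch

    exp_avg = torch.randn(8, 8)                         # first moment, stored in the old eigenbasis
    q_old = [torch.linalg.qr(torch.randn(8, 8))[0] for _ in range(2)]
    q_new = [torch.linalg.qr(torch.randn(8, 8))[0] for _ in range(2)]

    # 'ab,Aa,Bb,Ac,Bd->cd' from the code above, spelled out as matrix products
    rotated = q_new[0].T @ (q_old[0] @ exp_avg @ q_old[1].T) @ q_new[1]
    check = torch.einsum('ab,Aa,Bb,Ac,Bd->cd', exp_avg, q_old[0], q_old[1], q_new[0], q_new[1])
    assert torch.allclose(rotated, check, atol=1e-5)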
 
@@ -452,7 +475,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
452
475
  for x_, y_ in zip(x, y):
453
476
  x32 = promote(x_)
454
477
  y32 = promote(y_)
455
- copy_stochastic_(x_, x32.lerp(y32, a))
478
+ if x32.dtype != y32.dtype:
479
+ y32 = y32.to(x32.dtype)
480
+ copy_stochastic_(x_, x32 * (1 - a) + y32 * a)
456
481
 
457
482
 
458
483
  def get_beta1(group):
@@ -467,7 +492,7 @@ def get_beta1(group):
467
492
 
468
493
 
469
494
  def get_beta2(group):
470
- if 'beta2_scale' in group:
495
+ if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
471
496
  step = max(group.get("step", 1), 1)
472
497
  return 1 - step ** -group['beta2_scale']
473
498
  if 'betas' in group:
@@ -536,20 +561,32 @@ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
536
561
 
537
562
 
538
563
  @decorator
539
- def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
564
+ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
565
+ """
566
+ Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
567
+ Re-commited due to faulty merge
568
+ """
540
569
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
541
570
  return
542
571
 
543
- for idx, sh in enumerate(grad.shape):
544
- if sh > max_precond_dim:
572
+ for idx, m in enumerate(GG):
573
+ if not isinstance(m, Tensor):
545
574
  continue
546
575
  b = einsum_base[idx]
547
576
  g0 = einsum_base[:grad.dim()]
548
577
  g1 = g0.replace(b, b.upper())
549
578
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
550
- GG[idx].lerp_(outer_product, 1 - beta)
579
+ stochastic_lerp_(m, outer_product, 1 - beta)
580
+
581
+
582
+ def tree_apply(fn):
583
+ def _fn(*args):
584
+ return tree_map(fn, *args)
585
+
586
+ return _fn
551
587
 
552
588
 
589
+ @tree_apply
553
590
  def promote(x):
554
591
  if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
555
592
  return torch.float32
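
For a 2-D gradient with both sides below `max_precond_dim`, the per-dimension einsums in `update_ggt` are the two Gram matrices, blended into GG with the same lerp. A small sketch:

    import torch

    grad = torch.randn(16, 8)
    GG = [torch.zeros(16, 16), torch.zeros(8, 8)]
    beta = 0.95

    # 'ab,Ab->aA' (dim 0) and 'ab,aB->bB' (dim 1) reduce to these two products
    GG[0].lerp_(grad @ grad.T, 1 - beta)
    GG[1].lerp_(grad.T @ grad, 1 - beta)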
@@ -566,63 +603,85 @@ def min_dtype(xs: List[Tensor]):
566
603
  return torch.float32
567
604
 
568
605
 
569
- def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
606
+ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
570
607
  """
571
608
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
572
609
  """
573
- compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
610
+ update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
574
611
  if update_precond:
575
- get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)
612
+ get_orthogonal_matrix_QR(GG, Q, exp_avg)
576
613
 
577
614
 
578
- def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
615
+ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
579
616
  """
580
617
  Initializes the preconditioner matrices (L and R in the paper).
581
618
  """
582
619
  state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
583
- if grad.dim() == 1:
584
- if precondition_1d or grad.shape[0] > max_precond_dim:
585
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
586
- else:
587
- state['GG'].append([])
588
-
589
- else:
620
+ if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
590
621
  for sh in grad.shape:
591
- if sh > max_precond_dim:
592
- state['GG'].append([])
622
+ if sh > max_precond_dim or sh == 1:
623
+ # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
624
+ state['GG'].append(None)
593
625
  else:
594
626
  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
627
+ else:
628
+ state['GG'].append(None)
595
629
 
596
- compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
630
+ update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
597
631
  state['Q'] = get_orthogonal_matrix(state['GG'])
598
632
 
599
633
 
600
634
  @decorator
601
635
  def project(grad, Q, back: bool):
602
636
  """
603
-
604
637
  :param grad:
605
638
  :param Q:
606
- :param merge_dims:
607
- :param max_precond_dim:
608
639
  :param back: whether to project to Shampoo eigenbases or back to original space
609
640
  :return:
610
641
  """
611
642
  param = einsum_base[:grad.dim()]
612
- preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
643
+ preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
613
644
  if preconditioners:
614
645
  out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
615
- out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
646
+ out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
616
647
  grad = out.to(grad.dtype)
617
648
  return grad
618
649
 
619
650
 
651
+ def modify_closure(closure):
652
+ """
653
+ Modifies the closure function to use create_graph=True in backward().
654
+
655
+ Args:
656
+ closure: The closure function passed to the optimizer.
657
+
658
+ Returns:
659
+ The return value of the modified closure.
660
+ """
661
+
662
+ def patched_backward(self, *args, **kwargs):
663
+ kwargs['create_graph'] = True
664
+ return original_backward(self, *args, **kwargs)
665
+
666
+ original_backward = torch.Tensor.backward
667
+
668
+ with patch.object(torch.Tensor, 'backward', patched_backward):
669
+ return closure()
670
+
671
+
620
672
  class StatefulOptimizer(torch.optim.Optimizer):
673
+ """
674
+ finite_differences saves memory, but needs more compute. (Alternative is true HVP)
675
+ Both `True` and `False` have some edge cases they don't support, so experiment with it.
676
+ The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
677
+ Further notice that both methods have different numerics outputs
678
+ """
621
679
  ema_decay: float = 0.001
622
680
  compile_step: bool = False
623
681
  hessian_approx: bool = False
624
682
  precond_schedule: Union[Callable, float, None] = None
625
683
  stochastic_schedule: bool = False
684
+ finite_differences: bool = False
626
685
 
627
686
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
628
687
  super().__init__(params, {**defaults, 'foreach': foreach})
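
`init_preconditioner` now stores `None` for dimensions it skips (size 1, larger than `max_precond_dim`, or 1-D/scalar parameters without `precondition_1d`), and `project` simply drops those factors from its einsum. A sketch of just the per-dimension decision, mirroring the branch above:

    import torch

    def make_gg(grad, max_precond_dim, precondition_1d):
        GG = []
        if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
            for sh in grad.shape:
                if sh > max_precond_dim or sh == 1:
                    GG.append(None)            # this dimension is left unpreconditioned
                else:
                    GG.append(torch.zeros(sh, sh, dtype=grad.dtype))
        else:
            GG.append(None)
        return GG

    gg = make_gg(torch.randn(4, 1, 8192, 64), max_precond_dim=2048, precondition_1d=False)
    print([None if m is None else tuple(m.shape) for m in gg])   # [(4, 4), None, None, (64, 64)]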
@@ -735,38 +794,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
735
794
  set_(self.state_(p)['param_ema'], p.data)
736
795
  set_(p.data, ema_clone)
737
796
 
738
- def step(self, closure: Optional[Callable] = None):
739
- if self.precond_schedule is None:
740
- self._is_preconditioning = False
741
- else:
742
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
797
+ def _handle_closure(self, closure):
743
798
  hessian_approx = self.hessian_approx and self._is_preconditioning
799
+
744
800
  if closure is None:
745
801
  if hessian_approx:
746
802
  raise ValueError("Hessian approximation requires a closure.")
747
- loss = None
748
- else:
803
+ return None
804
+
805
+ if not hessian_approx:
749
806
  with torch.enable_grad():
750
807
  loss = closure()
751
- if hessian_approx:
752
- grads = []
753
- for group in self.param_groups:
754
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
755
- grads.append(g)
756
- p.vector = torch.randn_like(p)
757
- p.orig = p.data.clone()
758
- stochastic_add_(p.data, p.vector, tiny_bf16)
759
-
760
- with torch.enable_grad():
761
- closure()
762
-
763
- for group in self.param_groups:
764
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
765
- p.grad = grads.pop(0)
766
- stochastic_add_(g, p.grad, -1)
767
- p.hessian_vector = g
768
- p.data.copy_(p.orig)
769
- del p.orig
808
+ return loss
809
+
810
+ if self.finite_differences:
811
+ with torch.enable_grad():
812
+ loss = closure() # closure without retain_graph=True
813
+
814
+ grads = []
815
+ for group in self.param_groups:
816
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
817
+ grads.append(g)
818
+ p.vector = torch.randn_like(p)
819
+ p.orig = p.data.clone()
820
+ stochastic_add_(p.data, p.vector, tiny_bf16)
821
+ else:
822
+ with torch.enable_grad():
823
+ loss = modify_closure(closure)
824
+
825
+ if self.finite_differences:
826
+ with torch.enable_grad():
827
+ closure()
828
+
829
+ for group in self.param_groups:
830
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
831
+ p.grad = grads.pop(0)
832
+ stochastic_add_(g, p.grad, -1)
833
+ p.hessian_vector = g
834
+ p.data.copy_(p.orig)
835
+ del p.orig
836
+ else:
837
+ for group in self.param_groups:
838
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
839
+ p.grad = g
840
+ params, grads = zip(*[x for group in self.param_groups for x in
841
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
842
+ vs = [torch.randn_like(p) for p in params]
843
+ with torch.enable_grad():
844
+ hvs = torch.autograd.grad(grads, params, vs)
845
+
846
+ for p, g, v, hv in zip(params, grads, vs, hvs):
847
+ p.hessian_vector = hv
848
+ p.grad = g
849
+ p.vector = v
850
+
851
+ return loss
852
+
853
+ def step(self, closure: Optional[Callable] = None):
854
+ if self.precond_schedule is None:
855
+ self._is_preconditioning = False
856
+ else:
857
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
858
+ loss = self._handle_closure(closure)
770
859
 
771
860
  # we assume that parameters are constant and that there are no excessive recompiles
772
861
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
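
When `finite_differences` is False, the new `_handle_closure` builds an exact Hessian-vector product: `modify_closure` forces `create_graph=True` in the closure's backward, and `torch.autograd.grad(grads, params, vs)` then contracts the Hessian with random vectors. A standalone sketch of that HVP pattern, stripped of the optimizer plumbing:

    import torch

    param = torch.nn.Parameter(torch.randn(8))

    def closure():
        loss = (param ** 3).sum()
        loss.backward(create_graph=True)     # what modify_closure patches into the user's closure
        return loss

    closure()
    v = torch.randn_like(param)
    (hv,) = torch.autograd.grad([param.grad], [param], [v])    # Hessian-vector product
    assert torch.allclose(hv, 6 * param.detach() * v, atol=1e-5)
    param.grad = None                        # break the reference cycle create_graph warns about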
@@ -774,7 +863,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
774
863
  group['is_preconditioning'] = self._is_preconditioning
775
864
  self._step(group)
776
865
  if self.use_ema:
777
- self.ema_update(group)
866
+ self.ema_update()
778
867
 
779
868
  return loss
780
869
 
@@ -784,12 +873,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
784
873
  copy_stochastic_(t, s)
785
874
 
786
875
 
787
- def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
876
+ @decorator_knowngood
877
+ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
788
878
  ea32 = list(map(promote, state))
789
879
  grad = list(map(promote, grad))
790
880
  beta = promote(beta)
791
-
792
- ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
881
+ stochastic_lerp_(ea32, grad, 1 - beta)
793
882
  copy_stochastic_list_(state, ea32)
794
883
  return ea32
795
884
 
@@ -801,9 +890,8 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
801
890
  beta2 = beta_debias(beta2, step)
802
891
 
803
892
  g32 = list(map(promote, grad))
804
-
805
- exp_avg32 = _lerp32(exp_avg, g32, beta1)
806
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
893
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
894
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
807
895
  u32 = torch._foreach_div(exp_avg32, denom)
808
896
  copy_stochastic_list_(grad, u32)
809
897
 
@@ -824,9 +912,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
824
912
  beta2 = beta_debias(beta2, step)
825
913
 
826
914
  u32, g32 = [list(map(promote, x)) for x in [update, grad]]
827
-
828
- exp_avg32 = _lerp32(exp_avg, u32, beta1)
829
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
915
+ exp_avg32 = _lerp(exp_avg, u32, beta1)
916
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
830
917
  u32 = torch._foreach_div(exp_avg32, denom)
831
918
  _compilable_update_(y, u32, decay, lr, caution, g32)
832
919
 
@@ -836,7 +923,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
836
923
  caution: bool):
837
924
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
838
925
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
839
- return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
926
+ _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
840
927
 
841
928
 
842
929
  @decorator_knowngood
@@ -846,17 +933,16 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
846
933
  beta2 = beta_debias(beta2, step)
847
934
 
848
935
  gp32 = list(map(promote, grad))
849
-
850
- denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
936
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
851
937
  gp32 = torch._foreach_div(gp32, denom)
852
- gp32 = _lerp32(exp_avg, gp32, beta1)
853
-
938
+ gp32 = _lerp(exp_avg, gp32, beta1)
854
939
  copy_stochastic_list_(grad, gp32)
855
940
 
856
941
 
857
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
942
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
943
+ eps: float = 1e-8):
858
944
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
859
- beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0], eps)
945
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
860
946
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
861
947
  return grad
862
948
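
LaProp and Adam keep the same two EMAs but apply them in opposite order: Adam averages the raw gradient and divides at the end, while LaProp divides by the second-moment estimate first and averages the normalized gradient, as the two compiled helpers above show. A single-tensor sketch of the orderings (bias correction omitted):

    import torch

    g = torch.randn(10)
    exp_avg = torch.full((10,), 0.1)         # pretend running first moment
    exp_avg_sq = torch.full((10,), 0.2)      # pretend running second moment
    beta1, beta2, eps = 0.9, 0.99, 1e-8

    denom = exp_avg_sq.lerp(g * g, 1 - beta2).sqrt().clamp(min=eps)

    adam_update = exp_avg.lerp(g, 1 - beta1) / denom       # momentum first, then normalize
    laprop_update = exp_avg.lerp(g / denom, 1 - beta1)     # normalize first, then momentum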
 
@@ -864,23 +950,23 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
864
950
  @decorator_knowngood
865
951
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
866
952
  grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
867
- caution: bool):
953
+ caution: bool, eps: Tensor):
868
954
  beta1 = beta_debias(beta1, step)
869
955
  beta2 = beta_debias(beta2, step)
870
956
 
871
957
  u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
872
-
873
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
958
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
874
959
  u32 = torch._foreach_div(u32, denom)
875
- u32 = _lerp32(exp_avg, u32, beta1)
960
+ u32 = _lerp(exp_avg, u32, beta1)
876
961
  _compilable_update_(y, u32, decay, lr, caution, gp32)
877
962
 
878
963
 
879
964
  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
880
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
965
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
966
+ eps: float = 1e-8):
881
967
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
882
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
883
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
968
+ beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
969
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
884
970
 
885
971
 
886
972
  @decorator_knowngood
@@ -889,14 +975,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
889
975
  _compilable_update_(y, u32, decay, lr, caution, g32)
890
976
 
891
977
  beta1 = beta_debias(beta1, step)
892
- denom = torch._foreach_sqrt(exp_avg_sq32)
893
- denom = [d.clamp(min=eps) for d in denom]
894
- exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
895
- copy_stochastic_list_(exp_avg, exp_avg32)
978
+ denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
979
+ stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
896
980
 
897
981
  beta2 = beta_debias(beta2, step + 1)
898
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
899
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
982
+ stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
900
983
 
901
984
 
902
985
  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
@@ -906,27 +989,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e
906
989
 
907
990
 
908
991
  @decorator_knowngood
909
- def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
992
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
910
993
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
911
994
  update = [e.clone() for e in exp_avg]
912
995
 
913
996
  beta1 = beta_debias(beta1, step)
914
- denom = torch._foreach_sqrt(exp_avg_sq32)
915
- denom = [d.clamp(min=1e-8) for d in denom]
916
- exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
917
- copy_stochastic_list_(exp_avg, exp_avg32)
997
+ denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
998
+ stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
918
999
 
919
- beta2 = beta_debias(beta2, step + 1)
920
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
921
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
1000
+ stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
922
1001
 
923
1002
  copy_stochastic_list_(grad, update)
924
1003
 
925
1004
 
926
- def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
1005
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
927
1006
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
928
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
929
- _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
1007
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
1008
+ _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
930
1009
  return grad
931
1010
 
932
1011
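
In the ADOPT helpers above, the applied update is the first moment as it stood before the current gradient arrives; the current gradient is normalized by the previous second moment before being folded in, and its square enters `exp_avg_sq` only afterwards. A single-tensor sketch of that ordering (bias correction omitted):

    import torch

    g = torch.randn(10)
    exp_avg = torch.full((10,), 0.1)
    exp_avg_sq = torch.full((10,), 0.2)
    beta1, beta2, eps = 0.9, 0.9999, 1e-8

    update = exp_avg.clone()                                           # what this step applies
    exp_avg.lerp_(g / exp_avg_sq.sqrt().clamp(min=eps), 1 - beta1)     # uses the *old* second moment
    exp_avg_sq.lerp_(g * g, 1 - beta2)                                 # only now absorb g ** 2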
 
@@ -1005,7 +1084,6 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1005
1084
  reusable einsum expressions for updating Q and preconditioning gradient.
1006
1085
  """
1007
1086
  letters = string.ascii_lowercase + string.ascii_uppercase
1008
-
1009
1087
  dtype = dtype if dtype is not None else t.dtype
1010
1088
  shape = t.shape
1011
1089
 
@@ -1049,11 +1127,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1049
1127
  piece1A.append(letters[i])
1050
1128
  piece2A = piece2A + letters[i]
1051
1129
  piece3A = piece3A + letters[i]
1052
-
1053
1130
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1054
1131
  subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1055
1132
  exprGs.append(subscripts)
1056
-
1057
1133
  piece1P.append(letters[i + 13])
1058
1134
  piece2P.append(letters[i + 13])
1059
1135
  piece3P = piece3P + letters[i + 13]
@@ -1061,16 +1137,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1061
1137
  else:
1062
1138
  # use triangular matrix as preconditioner for this dim
1063
1139
  Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
1064
-
1065
1140
  piece1A.append(letters[i] + letters[i + 13])
1066
1141
  piece2A = piece2A + letters[i + 13]
1067
1142
  piece3A = piece3A + letters[i]
1068
-
1069
1143
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1070
1144
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1071
1145
  subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
1072
1146
  exprGs.append(subscripts)
1073
-
1074
1147
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1075
1148
  piece1P.append(a + b)
1076
1149
  piece2P.append(a + c)
@@ -1091,7 +1164,7 @@ def psgd_balance_Q(Q_in):
1091
1164
 
1092
1165
 
1093
1166
  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1094
- eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
1167
+ eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
1095
1168
  eps *= G.norm() / G.numel()
1096
1169
  G = G + torch.randn_like(G) * eps
1097
1170
  md = min_dtype(Q + [G])
@@ -1117,9 +1190,7 @@ def psgd_lb(A, max_abs):
1117
1190
  A /= max_abs
1118
1191
  a0 = torch.einsum('ij,ij->j', A, A)
1119
1192
  i = torch.argmax(a0)
1120
-
1121
1193
  x = torch.index_select(A, 1, i).flatten().contiguous()
1122
-
1123
1194
  x = torch.einsum('i,ij->j', x, A)
1124
1195
  x /= x.norm()
1125
1196
  x = torch.einsum('j,kj->k', x, A)
@@ -1132,15 +1203,12 @@ def psgd_lb(A, max_abs):
1132
1203
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1133
1204
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1134
1205
  exprA, exprGs, _ = exprs
1135
-
1136
1206
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1137
1207
 
1138
1208
  for q, exprG, o in zip(Q, exprGs, oq):
1139
1209
  term1 = promote(torch.einsum(exprG, A, A))
1140
1210
  term2 = promote(torch.einsum(exprG, conjB, conjB))
1141
-
1142
1211
  term1, term2 = term1 - term2, term1 + term2
1143
-
1144
1212
  term1 *= precond_lr
1145
1213
  norm = term2.norm(float('inf'))
1146
1214
  if q.dim() < 2:
@@ -1256,8 +1324,8 @@ def identity(x):
1256
1324
 
1257
1325
  @decorator_knowngood
1258
1326
  def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1259
- ema32 = _lerp32(ema, p, ema_decay)
1260
- _lerp32(p, ema32, 1 - weight_decay)
1327
+ ema32 = _lerp(ema, p, ema_decay)
1328
+ _lerp(p, ema32, 1 - weight_decay)
1261
1329
 
1262
1330
 
1263
1331
  def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
@@ -1267,10 +1335,10 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1267
1335
 
1268
1336
 
1269
1337
  @decorator_knowngood
1270
- def _compilable_l1_weight_decay_to_ema_(p, ema, ema_deacy, weight_decay):
1271
- ema32 = _lerp32(ema, p, ema_deacy)
1338
+ def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1339
+ ema32 = _lerp(ema, p, ema_decay)
1272
1340
  for p_, e_ in zip(p, ema32):
1273
- p32 = promote(p)
1341
+ p32 = promote(p_)
1274
1342
  p32 = p32 + (p32 - e_).sign() * weight_decay
1275
1343
  copy_stochastic_(p_, p32)
1276
1344
 
@@ -1447,7 +1515,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:
1447
1515
 
1448
1516
  if graft:
1449
1517
  out = _compilable_grafting(g, out)
1450
-
1451
1518
  copy_stochastic_(g, out)
1452
1519
 
1453
1520
 
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.5.3
3
+ Version: 1.6.1
4
4
  Summary: Efficient optimizers
5
- Home-page: https://github.com/clashluke/heavyball
6
- Author: Lucas Nestler
5
+ Home-page: https://github.com/HomebrewML/HeavyBall
6
+ Author: HeavyBall Authors
7
7
  Author-email: github.heavyball@nestler.sh
8
8
  License: BSD
9
9
  Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
300
300
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
301
301
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
302
302
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
303
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
303
+ correct_bias: bool = True, warmup_steps: int = 1,
304
304
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
305
305
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
306
306
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
324
324
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
325
325
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
326
326
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
327
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
328
327
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
329
328
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
330
329
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
931
930
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
932
931
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
933
932
  configurations.
934
-
@@ -0,0 +1,8 @@
1
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
2
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
3
+ heavyball/utils.py,sha256=iQxSQjw_sgJp4AvX71VdTJxJ_20Tdu7W2tdrYu5q2EI,55808
4
+ heavyball-1.6.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
+ heavyball-1.6.1.dist-info/METADATA,sha256=yFMCDJPpD5jVOFtL4l_pM3jTw3_ZizeTSQ_ugVHIWKM,43441
6
+ heavyball-1.6.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
+ heavyball-1.6.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
+ heavyball-1.6.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
2
- heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
3
- heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
4
- heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
6
- heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.5.3.dist-info/RECORD,,