torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60):
  1. torchzero/__init__.py +3 -1
  2. torchzero/_minimize/__init__.py +0 -0
  3. torchzero/_minimize/methods.py +95 -0
  4. torchzero/_minimize/minimize.py +518 -0
  5. torchzero/core/__init__.py +5 -5
  6. torchzero/core/chain.py +2 -1
  7. torchzero/core/functional.py +2 -1
  8. torchzero/core/module.py +75 -4
  9. torchzero/core/transform.py +6 -5
  10. torchzero/linalg/eigh.py +116 -68
  11. torchzero/linalg/linear_operator.py +1 -0
  12. torchzero/linalg/orthogonalize.py +60 -5
  13. torchzero/linalg/sketch.py +39 -0
  14. torchzero/modules/__init__.py +1 -0
  15. torchzero/modules/adaptive/adagrad.py +2 -0
  16. torchzero/modules/adaptive/adam.py +5 -1
  17. torchzero/modules/adaptive/adan.py +3 -0
  18. torchzero/modules/adaptive/ggt.py +20 -18
  19. torchzero/modules/adaptive/lion.py +3 -1
  20. torchzero/modules/adaptive/mars.py +6 -5
  21. torchzero/modules/adaptive/msam.py +3 -0
  22. torchzero/modules/adaptive/rmsprop.py +2 -0
  23. torchzero/modules/adaptive/rprop.py +9 -7
  24. torchzero/modules/adaptive/shampoo.py +9 -1
  25. torchzero/modules/adaptive/soap.py +32 -29
  26. torchzero/modules/basis/__init__.py +2 -0
  27. torchzero/modules/basis/ggt_basis.py +199 -0
  28. torchzero/modules/basis/soap_basis.py +254 -0
  29. torchzero/modules/clipping/ema_clipping.py +32 -27
  30. torchzero/modules/clipping/growth_clipping.py +1 -0
  31. torchzero/modules/experimental/__init__.py +1 -6
  32. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  33. torchzero/modules/experimental/cubic_adam.py +4 -0
  34. torchzero/modules/grad_approximation/__init__.py +3 -2
  35. torchzero/modules/least_squares/gn.py +6 -0
  36. torchzero/modules/misc/gradient_accumulation.py +1 -0
  37. torchzero/modules/misc/misc.py +6 -0
  38. torchzero/modules/momentum/averaging.py +6 -0
  39. torchzero/modules/momentum/momentum.py +4 -0
  40. torchzero/modules/ops/__init__.py +0 -1
  41. torchzero/modules/ops/accumulate.py +4 -0
  42. torchzero/modules/ops/higher_level.py +6 -1
  43. torchzero/modules/second_order/inm.py +4 -0
  44. torchzero/modules/second_order/newton.py +11 -3
  45. torchzero/modules/second_order/newton_cg.py +7 -3
  46. torchzero/modules/second_order/nystrom.py +14 -19
  47. torchzero/modules/second_order/rsn.py +37 -6
  48. torchzero/modules/trust_region/trust_region.py +2 -1
  49. torchzero/utils/benchmarks/logistic.py +33 -18
  50. torchzero/utils/params.py +13 -1
  51. torchzero/utils/tensorlist.py +2 -2
  52. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
  53. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
  54. torchzero/modules/experimental/adanystrom.py +0 -258
  55. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  56. torchzero/modules/experimental/eigen_sr1.py +0 -182
  57. torchzero/modules/experimental/eigengrad.py +0 -207
  58. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  59. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
  60. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,7 @@ from ...core import Chainable, TensorTransform
7
7
  from ...linalg import torch_linalg, regularize_eigh
8
8
  from .lre_optimizers import LREOptimizerBase
9
9
 
10
- def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
10
+ def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol, matrix_power=-1/2):
11
11
  """returns U ``(ndim, rank)``, L ``(rank, )``"""
12
12
  if isinstance(history, torch.Tensor):
13
13
  M = history
@@ -27,7 +27,7 @@ def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, t
27
27
  if L is None or Q is None: # this means there are no finite eigenvalues
28
28
  return None, None
29
29
 
30
- U = (M @ Q) * L.rsqrt()
30
+ U = (M @ Q) * L.pow(matrix_power)
31
31
 
32
32
  # this damping is added after computing U, this is why I didn't use one in linalg.regularize_eig
33
33
  # that's because we damp singular values this way
@@ -44,14 +44,13 @@ class GGT(TensorTransform):
44
44
  """
45
45
  GGT method from https://arxiv.org/pdf/1806.02958
46
46
 
47
- The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
48
- But it uses eigendecomposition on MM to get U and S^2 because that is faster when you don't neeed V.
47
+ The update rule is to stack recent gradients into M and
48
+ compute eigendecomposition of M M^T via eigendecomposition of M^T M.
49
49
 
50
50
  This is equivalent to full-matrix Adagrad on recent gradients.
51
51
 
52
52
  Args:
53
53
  history_size (int, optional): number of past gradients to store. Defaults to 10.
54
- beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
55
54
  update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
56
55
  eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
57
56
  truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
@@ -105,7 +104,8 @@ class GGT(TensorTransform):
105
104
  truncate: int | None = None,
106
105
  damping: float = 1e-4,
107
106
  rdamping: float = 0,
108
- eigenbasis_optimizer: LREOptimizerBase | None = None,
107
+ matrix_power: float = -1/2,
108
+ basis_optimizer: LREOptimizerBase | None = None,
109
109
  concat_params: bool = True,
110
110
 
111
111
  inner: Chainable | None = None,
@@ -114,6 +114,7 @@ class GGT(TensorTransform):
114
114
  del defaults['self'], defaults['inner'], defaults['concat_params']
115
115
 
116
116
  super().__init__(defaults, concat_params=concat_params, inner=inner)
117
+ self.add_projected_keys("grad", "history")
117
118
 
118
119
  @torch.no_grad
119
120
  def single_tensor_update(self, tensor, param, grad, loss, state, setting):
@@ -141,14 +142,15 @@ class GGT(TensorTransform):
141
142
  rdamping=setting["rdamping"],
142
143
  truncate=setting["truncate"],
143
144
  eig_tol=setting["eig_tol"],
145
+ matrix_power=setting["matrix_power"],
144
146
  )
145
147
 
146
- # reproject eigenbasis optimizer
147
- eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
148
- if eigenbasis_optimizer is not None:
148
+ # reproject basis optimizer
149
+ basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
150
+ if basis_optimizer is not None:
149
151
  if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
150
- eigenbasis_state = state["eigenbasis_state"]
151
- eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)
152
+ basis_state = state["basis_state"]
153
+ basis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=basis_state)
152
154
 
153
155
 
154
156
  # store new factors
@@ -169,18 +171,18 @@ class GGT(TensorTransform):
169
171
 
170
172
  L = state['L']
171
173
 
172
- # step with eigenbasis optimizer
173
- eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
174
- if eigenbasis_optimizer is not None:
174
+ # step with basis optimizer
175
+ basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
176
+ if basis_optimizer is not None:
175
177
 
176
- if "eigenbasis_state" not in state: state["eigenbasis_state"] = {}
177
- eigenbasis_state = state["eigenbasis_state"]
178
+ if "basis_state" not in state: state["basis_state"] = {}
179
+ basis_state = state["basis_state"]
178
180
 
179
- update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
181
+ update = basis_optimizer.step(g, L=L, Q=U, state=basis_state)
180
182
  return update.view_as(tensor)
181
183
 
182
184
  # or just whiten
183
185
  z = U.T @ g
184
- update = (U * L.rsqrt()) @ z
186
+ update = (U * L.pow(setting["matrix_power"])) @ z
185
187
  return update.view_as(tensor)
186
188
 
@@ -23,9 +23,11 @@ class Lion(TensorTransform):
23
23
  defaults = dict(beta1=beta1, beta2=beta2)
24
24
  super().__init__(defaults)
25
25
 
26
+ self.add_projected_keys("grad", "exp_avg")
27
+
26
28
  @torch.no_grad
27
29
  def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
28
30
  beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
29
- exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
31
+ exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
30
32
  return lion_(TensorList(tensors), exp_avg, beta1, beta2)
31
33
 
@@ -6,13 +6,13 @@ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
6
6
 
7
7
  def mars_correction_(
8
8
  tensors_: TensorList,
9
- prev_: TensorList,
9
+ g_prev_: TensorList,
10
10
  beta: float | NumberList,
11
11
  scaling: float | NumberList,
12
12
  max_norm: float | NumberList | None,
13
13
  ):
14
- dg = (tensors_ - prev_).mul_(scaling * beta / (1-beta))
15
- prev_.copy_(tensors_)
14
+ dg = (tensors_ - g_prev_).mul_(scaling * beta / (1-beta))
15
+ g_prev_.copy_(tensors_)
16
16
 
17
17
  c = tensors_.add_(dg)
18
18
  if max_norm is not None:
@@ -63,16 +63,17 @@ class MARSCorrection(TensorTransform):
63
63
  ):
64
64
  defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
65
65
  super().__init__(defaults)
66
+ self.add_projected_keys("grad", "g_prev")
66
67
 
67
68
  @torch.no_grad
68
69
  def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
69
- prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
70
+ g_prev = unpack_states(states, tensors, 'g_prev', init=tensors, cls=TensorList)
70
71
  beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
71
72
  max_norm = settings[0]['max_norm']
72
73
 
73
74
  return mars_correction_(
74
75
  tensors_=TensorList(tensors),
75
- prev_=prev,
76
+ g_prev_=g_prev,
76
77
  beta=beta,
77
78
  scaling=scaling,
78
79
  max_norm=max_norm,
@@ -121,6 +121,8 @@ class MSAMMomentum(TensorTransform):
121
121
  defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
122
122
  super().__init__(defaults, uses_grad=False)
123
123
 
124
+ self.add_projected_keys("grad", "velocity")
125
+
124
126
  @torch.no_grad
125
127
  def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
126
128
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
@@ -180,6 +182,7 @@ class MSAM(Transform):
180
182
  super().__init__(defaults)
181
183
 
182
184
  self.set_child('modules', modules)
185
+ self.add_projected_keys("grad", "velocity")
183
186
 
184
187
 
185
188
  @torch.no_grad
@@ -38,6 +38,8 @@ class RMSprop(TensorTransform):
38
38
  super().__init__(defaults, inner=inner)
39
39
 
40
40
  self.set_child('exp_avg_sq', exp_avg_sq_tfm)
41
+ self.add_projected_keys("grad", "exp_avg")
42
+ self.add_projected_keys("grad_sq", "exp_avg_sq", "exp_avg_sq_max")
41
43
 
42
44
  @torch.no_grad
43
45
  def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
@@ -128,15 +128,15 @@ def rprop_(
128
128
 
129
129
  class Rprop(TensorTransform):
130
130
  """
131
- Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
132
- or `nminus` if it did. Then the update is applied with the sign of the current gradient.
131
+ Resilient propagation. The update magnitude gets multiplied by ``nplus`` if gradient didn't change the sign,
132
+ or ``nminus`` if it did. Then the update is applied with the sign of the current gradient.
133
133
 
134
134
  Additionally, if gradient changes sign, the update for that weight is reverted.
135
135
  Next step, magnitude for that weight won't change.
136
136
 
137
137
  Compared to pytorch this also implements backtracking update when sign changes.
138
138
 
139
- This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.
139
+ This implementation is identical to ``torch.optim.Rprop`` if ``backtrack`` is set to False.
140
140
 
141
141
  Args:
142
142
  nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -164,6 +164,8 @@ class Rprop(TensorTransform):
164
164
  defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
165
165
  super().__init__(defaults, uses_grad=False)
166
166
 
167
+ self.add_projected_keys("grad", "prev")
168
+
167
169
  @torch.no_grad
168
170
  def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
169
171
  step = self.global_state.get('step', 0)
@@ -196,14 +198,14 @@ class Rprop(TensorTransform):
196
198
 
197
199
  class ScaleLRBySignChange(TensorTransform):
198
200
  """
199
- learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
200
- or `nminus` if it did.
201
+ learning rate gets multiplied by ``nplus`` if ascent/gradient didn't change the sign,
202
+ or ``nminus`` if it did.
201
203
 
202
204
  This is part of RProp update rule.
203
205
 
204
206
  Args:
205
- nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
206
- nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
207
+ nplus (float): learning rate gets multiplied by ``nplus`` if ascent/gradient didn't change the sign
208
+ nminus (float): learning rate gets multiplied by ``nminus`` if ascent/gradient changed the sign
207
209
  lb (float): lower bound for lr.
208
210
  ub (float): upper bound for lr.
209
211
  alpha (float): initial learning rate.
@@ -207,6 +207,9 @@ class Shampoo(TensorTransform):
207
207
  if setting["merge_small"]:
208
208
  tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
209
209
 
210
+ if "inner" not in self.children:
211
+ state["merged"] = tensor
212
+
210
213
  if 'diagonal_accumulator' in state:
211
214
  update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
212
215
  else:
@@ -227,10 +230,15 @@ class Shampoo(TensorTransform):
227
230
 
228
231
  state["step"] += 1
229
232
 
233
+
230
234
  @torch.no_grad
231
235
  def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
236
+
232
237
  if setting["merge_small"]:
233
- tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
238
+ if "inner" not in self.children:
239
+ tensor = state.pop("merged")
240
+ else:
241
+ tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
234
242
 
235
243
  if 'diagonal_accumulator' in state:
236
244
  dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
@@ -1,12 +1,13 @@
1
- from operator import itemgetter
2
1
  import warnings
2
+ from operator import itemgetter
3
3
 
4
4
  import torch
5
5
 
6
- from ...core import TensorTransform, Chainable
7
- from ...utils import unpack_dicts, unpack_states, TensorList, NumberList
8
- from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
6
+ from ...core import Chainable, TensorTransform
9
7
  from ...linalg import torch_linalg
8
+ from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
9
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
10
+
10
11
 
11
12
  @torch.no_grad
12
13
  def update_soap_covariances_(
@@ -221,25 +222,38 @@ class SOAP(TensorTransform):
221
222
  return TensorList(tensors).clamp(-0.1, 0.1)
222
223
  # return TensorList(tensors).zero_()
223
224
 
224
-
225
225
  fs = settings[0]
226
- merged = []
226
+ merged_updates = [] # for when exp_avg is maintained unprojected
227
+ merged_grads = [] # this doesn't go into preconditioner
227
228
  projected = []
228
- # ---------------------------------- project --------------------------------- #
229
229
 
230
- for tensor, state, setting in zip(tensors, states, settings):
230
+ # -------------------------------- inner step -------------------------------- #
231
+ updates = tensors
232
+ has_inner = "inner" in self.children
233
+ if has_inner:
234
+ updates = self.inner_step_tensors("inner", updates, clone=True,
235
+ params=params, grads=grads, loss=loss)
236
+
237
+ # ---------------------------------- project --------------------------------- #
238
+ for grad, update, state, setting in zip(tensors, updates, states, settings):
231
239
  if setting["merge_small"]:
232
- tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
240
+ update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
241
+ if has_inner: # grad is a different tensor, merge it too
242
+ grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
243
+ else: # in this case update is still just grad
244
+ grad = update
233
245
 
234
- merged.append(tensor)
246
+ merged_updates.append(update)
247
+ merged_grads.append(grad)
235
248
 
236
249
  if state['GG'] is not None:
237
- tensor = project(tensor, state['Q'])
250
+ update = project(update, state['Q'])
251
+
252
+ projected.append(update)
238
253
 
239
- projected.append(tensor)
240
254
 
241
255
  # ------------------------ run adam in projected space ----------------------- #
242
- exp_avg_proj, exp_avg_sq_proj = unpack_states(states, tensors, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
256
+ exp_avg_proj, exp_avg_sq_proj = unpack_states(states, projected, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
243
257
  alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)
244
258
 
245
259
  # lerp exp_avg in projected space
@@ -249,15 +263,17 @@ class SOAP(TensorTransform):
249
263
  # or lerp in original space and project
250
264
  else:
251
265
  exp_avg = exp_avg_proj
252
- exp_avg.lerp_(merged, weight=1-beta1)
266
+ exp_avg.lerp_(merged_updates, weight=1-beta1)
253
267
  exp_avg_proj = []
254
268
  for t, state, setting in zip(exp_avg, states, settings):
255
269
  if state['GG'] is not None:
256
270
  t = project(t, state["Q"])
257
271
  exp_avg_proj.append(t)
258
272
 
273
+ # lerp exp_avg_sq
259
274
  exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)
260
275
 
276
+ # adam direction
261
277
  denom = exp_avg_sq_proj.sqrt().add_(eps)
262
278
  dirs_proj = exp_avg_proj / denom
263
279
 
@@ -272,27 +288,14 @@ class SOAP(TensorTransform):
272
288
 
273
289
  dirs.append(dir)
274
290
 
275
-
276
- # -------------------------------- inner step -------------------------------- #
277
- if "inner" in self.children:
278
- tensors = self.inner_step_tensors("inner", tensors, clone=False,
279
- params=params, grads=grads,loss=loss)
280
-
281
- # we now have to re-merge small dims on updated tensors
282
- merged = []
283
- for tensor, state, setting in zip(tensors, states, settings):
284
- if setting["merge_small"]:
285
- tensor, _, _ = _merge_small_dims(tensor, setting["max_dim"])
286
- merged.append(tensor)
287
-
288
291
  # -------------------------- update preconditioners -------------------------- #
289
292
  # Update is done after the gradient step to avoid using current gradients in the projection.
290
293
 
291
- for tensor, state, setting in zip(merged, states, settings):
294
+ for grad, state, setting in zip(merged_grads, states, settings):
292
295
  if state['GG'] is not None:
293
296
 
294
297
  # lerp covariances
295
- update_soap_covariances_(tensor, state['GG'], beta=setting["shampoo_beta"])
298
+ update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])
296
299
 
297
300
  # (state['step'] - 1) since we start updating on 2nd step
298
301
  if (state['step'] - 1) % setting['precond_freq'] == 0:
@@ -0,0 +1,2 @@
1
+ from .soap_basis import SOAPBasis
2
+ from .ggt_basis import GGTBasis
@@ -0,0 +1,199 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+
5
+ from ...core import Chainable, TensorTransform
6
+ from ...utils import set_storage_
7
+ from ..adaptive.ggt import ggt_update
8
+
9
+
10
+ def _cubic_reproject(C: torch.Tensor, cu: torch.Tensor, approx:bool):
11
+ if approx: return C.pow(3) @ cu
12
+
13
+ n = cu.numel()
14
+ T = torch.zeros([n,n,n], device=cu.device, dtype=cu.dtype)
15
+ T[range(n),range(n),range(n)] = cu
16
+ T = torch.einsum('ai,bj,ck,ijk->abc', C, C, C, T)
17
+ n2 = T.size(0)
18
+ return T[range(n2), range(n2), range(n2)]
19
+
20
class GGTBasis(TensorTransform):
    """
    Run another optimizer in the GGT eigenbasis. The basis has at most ``history_size`` vectors,
    so it is possible to run expensive methods such as full-matrix Adagrad/Adam inside it.

    The basis is built by stacking recent gradients into M and computing the
    eigendecomposition of M M^T via the eigendecomposition of the smaller M^T M
    (using the same ``ggt_update`` routine as ``GGT``, i.e. full-matrix Adagrad on recent gradients).
    If no basis has been computed yet (e.g. ``ggt_update`` found no finite eigenvalues),
    the update falls back to element-wise preconditioning by the RMS of the stored history.

    Note:
        the buffers of ``basis_opt`` are re-projected whenever the basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:

        ``Adagrad``, ``FullMatrixAdagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``GGT``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.

        Additionally most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However modules that use weight values, such as ``WeightDecay``, can't be supported, as weights can't be projected.

        Also, if you, say, use ``EMA`` on the output of ``Pow(2)``, the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like ``EMASquared``, ``SqrtEMASquared`` to get correct reprojections.

    Args:
        basis_opt (Chainable): module or modules to run in the GGT eigenbasis.
        history_size (int, optional):
            number of past gradients to store; this is also the maximum rank of the basis. Defaults to 100.
        update_freq (int, optional): frequency of updating the basis (U and L). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than the largest eigenvalue. Defaults to 1e-7.
        truncate (int | None, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to the largest eigenvalue. Defaults to 0.
        matrix_power (float, optional):
            power of the eigenvalues that ``ggt_update`` uses to scale the basis vectors.
            With the default ``-1/2`` the basis vectors should be orthonormal. Other values rescale them,
            which also rescales the projected gradients that ``basis_opt`` receives. Defaults to -1/2.
        approx_sq_reproject (bool, optional):
            kept for backward compatibility. Squared-gradient-like buffers are reprojected as the
            diagonal of ``C diag(b) C^T``, which equals ``C.pow(2) @ b`` exactly, so this flag
            does not change the result. Defaults to False.
        approx_cu_reproject (bool, optional):
            kept for backward compatibility. It is forwarded to ``_cubic_reproject``, whose two
            code paths compute the same quantity. Defaults to False.
        inner (Chainable | None, optional):
            output of this module is projected and ``basis_opt`` will run on it, but the basis is updated
            from original gradients.

    Examples:
        Adam in GGT eigenbasis:
        ```python
        opt = tz.Optimizer(
            model.parameters(),
            tz.m.GGTBasis(tz.m.Adam(beta2=0.99)),
            tz.m.LR(1e-3)
        )
        ```

        Full-matrix Adam in GGT eigenbasis. We can define full-matrix Adam through ``FullMatrixAdagrad``.
        ```python
        opt = tz.Optimizer(
            model.parameters(),
            tz.m.GGTBasis(
                [tz.m.FullMatrixAdagrad(beta=0.99, inner=tz.m.EMA(0.9, debias=True))]
            ),
            tz.m.LR(1e-3)
        )
        ```

        LaProp in GGT eigenbasis:
        ```python

        # we define LaProp through other modules, moved it out for brevity
        laprop = (
            tz.m.RMSprop(0.95),
            tz.m.Debias(beta1=None, beta2=0.95),
            tz.m.EMA(0.95),
            tz.m.Debias(beta1=0.95, beta2=None),
        )

        opt = tz.Optimizer(
            model.parameters(),
            tz.m.GGTBasis(laprop),
            tz.m.LR(1e-3)
        )
        ```

    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization // International Conference on Machine Learning. PMLR, 2019. pp. 102-110.
    """

    def __init__(
        self,
        basis_opt: Chainable,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        matrix_power: float = -1/2,
        approx_sq_reproject: bool = False,
        approx_cu_reproject: bool = False,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']

        # all parameters are concatenated, so each state buffer is a single flat vector
        super().__init__(defaults, concat_params=True, inner=inner)
        self.set_child("basis_opt", basis_opt)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        """Append ``tensor`` to the history and, every ``update_freq`` steps, recompute the basis.

        When both the old and the new basis exist, the projected buffers of ``basis_opt``
        are rotated into the new basis with the change-of-basis matrix ``C = U_new^T U_old``.
        """
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        # rolling window of flattened past tensors; maxlen is fixed when the deque is first created
        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']
        history.append(tensor.clone().view(-1))

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq != 0:
            return

        # compute new factors
        L = state.get("L", None)
        U = state.get("U", None)

        L_new, U_new = ggt_update(
            history,
            damping=setting["damping"],
            rdamping=setting["rdamping"],
            truncate=setting["truncate"],
            eig_tol=setting["eig_tol"],
            matrix_power=setting["matrix_power"],
        )

        if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
            # Reproject the basis optimizer's buffers. This only happens once an old basis exists,
            # so basis_opt has already stepped at least once. Because parameters are concatenated,
            # each buffer is a single rank-length vector (or a rank x rank matrix for covariances).
            C = U_new.T @ U  # change of basis matrix, (rank_new, rank_old)

            # gradient-like buffers rotate linearly
            for (buff,) in self.get_child_projected_buffers("basis_opt", "grad"):
                set_storage_(buff, C @ buff)

            # diagonal of second moment: diag(C diag(b) C^T) == C.pow(2) @ b exactly,
            # so the n x n diag_embed is unnecessary (approx_sq_reproject has no effect on the value)
            C_sq = C.pow(2)
            for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_sq"):
                set_storage_(buff, C_sq @ buff)

            # diagonal of third moment
            for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_cu"):
                set_storage_(buff, _cubic_reproject(C, buff, setting["approx_cu_reproject"]))

            # full covariance matrices rotate on both sides
            for (buff,) in self.get_child_projected_buffers("basis_opt", "covariance"):
                set_storage_(buff, C @ buff @ C.T)

        # store new factors; keep the previous ones if the decomposition failed
        if L_new is not None: state["L"] = L_new
        if U_new is not None: state["U"] = U_new

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        """Project ``tensor`` onto the basis, step ``basis_opt`` there, and map the result back."""
        g = tensor.view(-1)
        U = state.get('U', None)

        if U is None:
            # no basis yet: fall back to element-wise preconditioning by the RMS of the history.
            # note that this divides in place, so ``tensor`` (which ``g`` views) is modified too.
            history = torch.stack(tuple(state["history"]), 0)
            g /= history.square().mean(0).sqrt().add(1e-8)
            return g.view_as(tensor)

        # project -> step in basis -> unproject
        g_proj = U.T @ g
        dir_proj = self.inner_step_tensors("basis_opt", tensors=[g_proj], clone=False, grads=[g_proj])[0]
        update = U @ dir_proj
        return update.view_as(tensor)
199
+