torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/__init__.py
@@ -3,13 +3,22 @@ from .adadam import Adadam
  from .adamY import AdamY
  from .adasoap import AdaSOAP
  from .curveball import CurveBall
- from .soapy import SOAPY
+ from .eigendescent import EigenDescent
+ from .etf import (
+     ExponentialTrajectoryFit,
+     ExponentialTrajectoryFitV2,
+     PointwiseExponential,
+ )
  from .gradmin import GradMin
+ from .newton_solver import NewtonSolver
+ from .newtonnewton import NewtonNewton
  from .reduce_outward_lr import ReduceOutwardLR
+ from .soapy import SOAPY
  from .spectral import SpectralPreconditioner
+ from .structured_newton import StructuredNewton
  from .subspace_preconditioners import (
      HistorySubspacePreconditioning,
      RandomSubspacePreconditioning,
  )
- from .tropical_newton import TropicalNewton
- from .newton_solver import NewtonSolver
+ from .tada import TAda
+ from .diagonal_higher_order_newton import DiagonalHigherOrderNewton
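A quick way to see the effect of this hunk is to import the newly re-exported experimental modules. This is a minimal sketch assuming torchzero 0.3.10 is installed; it only checks that the names resolve.

```python
# Minimal import check for the names newly exported above (assumes torchzero >= 0.3.10).
from torchzero.modules.experimental import (
    DiagonalHigherOrderNewton,
    EigenDescent,
    NewtonNewton,
    StructuredNewton,
    TAda,
)

print(EigenDescent, TAda)  # both resolve after this release
```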
torchzero/modules/experimental/absoap.py
@@ -1,12 +1,14 @@
  from operator import itemgetter
+ from typing import Literal

  import torch
- from typing import Literal
- from ...core import Chainable, Transform, apply
+
+ from ...core import Chainable, Transform
  from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR

  @torch.no_grad
- def update_soap_covariances_(
+ def update_absoap_covariances_(
      g1: torch.Tensor,
      g2: torch.Tensor,
      GGs_: list[torch.Tensor | None],
@@ -19,138 +21,33 @@ def update_soap_covariances_(
          if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
          else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

- @torch.no_grad
- def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-     """
-     Projects the gradient to the eigenbases of the preconditioner.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-         else:
-             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors

- @torch.no_grad
- def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-     """
-     Projects the gradient back to the original space.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-         else:
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
- @torch.no_grad
- def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-     """
-     matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m in mat:
-         if m is None: continue
-         if len(m) == 0:
-             matrix.append([])
-             continue
-         if m.dtype != torch.float:
-             original_type = m.dtype
-             original_device = m.device
-             matrix.append(m.float())
-         else:
-             float_data = True
-             matrix.append(m)
-
-     final = []
-     for m in matrix:
-         if len(m) == 0:
-             final.append([])
-             continue
-         try:
-             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-         except Exception:
-             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-             Q = Q.to(m.dtype)
-         Q = torch.flip(Q, [1])
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-     return final
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
- @torch.no_grad
- def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using one round of power iteration
-     followed by torch.linalg.qr decomposition.
-     """
-     matrix = []
-     orth_matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m,o in zip(GG, Q_list):
-         if m is None: continue
-         assert o is not None
-
-         if len(m) == 0:
-             matrix.append([])
-             orth_matrix.append([])
-             continue
-         if m.data.dtype != torch.float:
-             original_type = m.data.dtype
-             original_device = m.data.device
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-         else:
-             float_data = True
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-
-     final = []
-     for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-         if len(m)==0:
-             final.append([])
-             continue
-         est_eig = torch.diag(o.T @ m @ o)
-         sort_idx = torch.argsort(est_eig, descending=True)
-         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-         o = o[:,sort_idx]
-         power_iter = m @ o
-         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-
-     return final, exp_avg_sq
-
- Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
+ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
  class ABSOAP(Transform):
-     """SOAP but with two extra letters included in its name in order to improve converence
-
-     so what you can do is choose what goes into what ,and that is supposed to be good.
+     """SOAP but with some extra options for testing. Please note that this is experimental and isn't guaranteed to work.
+
+     Args:
+         scale_by_s - whether to scale y by s
+         gg1 - 1st vector into GGᵀ
+         gg2 - 2nd vector into GGᵀ
+         ema1 - vector into 1st momentum
+         ema2 - 2 vectors into 2nd momentum
+         rel1 - if True, multiplies gg1 by params
+         rel2 - same but for gg2
+         norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
+
+     letters:
+         p - params
+         g - grad
+         s - param difference
+         y - grad difference
+         gy - g+y
+         sy - s+y
+         sn - s normalized
+         yn - y normalized
+         gys - g + y#g
+         sys - s + y#s

-     new args
-
-     scale by s whether to scale gradient differences by parameter differences
-
-     y_to_ema2 whether to use gradient differences for exponential moving average too
-
-     okay I changed these args into another ones
-
-     BASICALLY THIS IS FOR MY EXPERIMENTS

      """
      def __init__(
@@ -166,8 +63,8 @@ class ABSOAP(Transform):
          alpha: float = 1,
          bias_correction: bool = True,
          scale_by_s: bool = True,
-         first: Source='g',
-         second: Source='g',
+         gg1: Source='g',
+         gg2: Source='g',
          ema1: Source='g',
          ema2: tuple[Source, Source] = ('g','g'),
          rel1: bool=False,
@@ -189,29 +86,27 @@ class ABSOAP(Transform):
              scale_by_s=scale_by_s,
              ema1=ema1,
              ema2=ema2,
-             first=first,
-             second=second,
+             first=gg1,
+             second=gg2,
              rel1=rel1, rel2=rel2,
              norm=norm,
          )
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply(self, tensors, params, grads, loss, states, settings):
          updates = []
          # update preconditioners
-         for i,(p,t) in enumerate(zip(params, tensors)):
-             state = self.state[p]
-             settings = self.settings[p]
+         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
              beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
-             scale_by_s = settings['scale_by_s']
-             ema1 = settings['ema1']
-             ema2 = settings['ema2']
-             first=settings['first']
-             second=settings['second']
-             rel1 = settings['rel1']; rel2 = settings['rel2']
-             norm=settings['norm']
+                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+             scale_by_s = setting['scale_by_s']
+             ema1 = setting['ema1']
+             ema2 = setting['ema2']
+             first=setting['first']
+             second=setting['second']
+             rel1 = setting['rel1']; rel2 = setting['rel2']
+             norm=setting['norm']

              if merge_small:
                  t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -219,8 +114,8 @@
              if 'g_prev' not in state:
                  state['p_prev'] = p.clone()
                  state['g_prev'] = t.clone()
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue
+                 # updates.append(tensors[i].clip(-0.1,0.1))
+                 # continue

              p_prev = state['p_prev']
              g_prev = state['g_prev']
@@ -270,11 +165,10 @@
                  t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
                  t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable

-
              # initialize state on 1st step
              if 'GG' not in state:
                  state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq"] = torch.ones_like(t)
+                 state["exp_avg_sq"] = torch.zeros_like(t)

                  if not precondition_1d and t.ndim <= 1:
                      state['GG'] = []
@@ -287,7 +181,7 @@
                      state['GG'] = None

                  if state['GG'] is not None:
-                     update_soap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
+                     update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
                      state['Q'] = get_orthogonal_matrix(state['GG'])

                  state['step'] = 0
@@ -334,7 +228,7 @@
              if z1_projected is not None:
                  update = project_back(update, state["Q"])

-             if settings['bias_correction']:
+             if setting['bias_correction']:
                  bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                  bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                  update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
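The scale applied here is the usual Adam-style debiasing factor. A small numeric sketch (values are illustrative, not torchzero defaults) of how it behaves over the first steps:

```python
# Adam-style debiasing factor used above: sqrt(1 - beta2**(t+1)) / (1 - beta1**(t+1)) * alpha.
# It compensates for the zero-initialised moment EMAs and converges to alpha as t grows.
beta1, beta2, alpha = 0.95, 0.95, 1.0
for t in range(4):
    bc1 = 1.0 - beta1 ** (t + 1)
    bc2 = 1.0 - beta2 ** (t + 1)
    print(t, (bc2 ** 0.5) / bc1 * alpha)
# t=0: sqrt(0.05)/0.05 ~= 4.47, then ~= 3.20, 2.65, 2.32, ... -> 1.0 as t grows
```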
@@ -349,8 +243,8 @@

              # Update is done after the gradient step to avoid using current gradients in the projection.
              if state['GG'] is not None:
-                 update_soap_covariances_(t1, t2, state['GG'], shampoo_beta)
-                 if state['step'] % settings['precond_freq'] == 0:
+                 update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
+                 if state['step'] % setting['precond_freq'] == 0:
                      state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])

          return updates
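get_orthogonal_matrix_QR, called on the precond_freq schedule above, refreshes each eigenbasis with one power iteration followed by a QR factorisation instead of a full eigendecomposition. A standalone sketch of that refresh step in plain PyTorch (shapes are illustrative):

```python
import torch

n = 8
g = torch.randn(n, 16)
GG = g @ g.T                                    # accumulated gradient covariance factor
Q_old = torch.linalg.qr(torch.randn(n, n)).Q    # previous orthogonal eigenbasis estimate
Q_new, _ = torch.linalg.qr(GG @ Q_old)          # one power iteration, re-orthogonalised by QR
print(torch.dist(Q_new.T @ Q_new, torch.eye(n)))  # ~0: columns remain orthonormal
```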
torchzero/modules/experimental/adadam.py
@@ -50,7 +50,7 @@ def adadam_(
      return None

  class Adadam(Module):
-     """Adam with a diagonally preconditioned preconditioner."""
+     """Adam with a diagonally preconditioned preconditioner. Please note that this is experimental and isn't guaranteed to work."""
      def __init__(
          self,
          beta1: float = 0.9,
@@ -67,31 +67,32 @@ class Adadam(Module):
          self.getter = itemgetter('amsgrad','pow','debiased')

      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+         params = var.params

-         beta1,beta2,precond_beta,eps,alpha=self.get_settings('beta1','beta2','precond_beta','eps','alpha', params=vars.params, cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+         beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
+         amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])

          if amsgrad:
-             exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
              max_exp_avg_sq = None
              max_exp_avg_qu = None

          # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if vars.is_last:
-             if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-             passed_params = TensorList(vars.params)
-             vars.stop = True
-             vars.skip_update = True
+         if var.is_last:
+             if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+             passed_params = TensorList(var.params)
+             var.stop = True
+             var.skip_update = True

          else:
              passed_params = None

-         vars.update = adadam_(
-             tensors=TensorList(vars.get_update()),
+         var.update = adadam_(
+             tensors=TensorList(var.get_update()),
              exp_avg_=exp_avg,
              exp_avg_sq_=exp_avg_sq,
              exp_avg_qu_=exp_avg_qu,
@@ -108,4 +109,4 @@
              params_=passed_params,
          )

-         return vars
+         return var
torchzero/modules/experimental/adamY.py
@@ -62,17 +62,7 @@ def adamy_(
      return None

  class AdamY(Module):
-     """Adam but uses scaled gradient differences for second momentum.
-
-     Args:
-         beta1 (float, optional): momentum. Defaults to 0.9.
-         beta2 (float, optional): second momentum. Defaults to 0.999.
-         eps (float, optional): epsilon. Defaults to 1e-8.
-         alpha (float, optional): learning rate. Defaults to 1.
-         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-         pow (float, optional): power used in second momentum power and root. Defaults to 2.
-         debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-     """
+     """Adam but uses scaled gradient differences for second momentum. Please note that this is experimental and isn't guaranteed to work."""
      def __init__(
          self,
          beta1: float = 0.9,
@@ -88,36 +78,36 @@ class AdamY(Module):
          self.getter = itemgetter('amsgrad','pow','debiased')

      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-         beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+         beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
+         amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])

          if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None

          # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if vars.is_last:
-             if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-             passed_params = TensorList(vars.params)
-             vars.stop = True
-             vars.skip_update = True
+         if var.is_last:
+             if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+             passed_params = TensorList(var.params)
+             var.stop = True
+             var.skip_update = True

          else:
              passed_params = None

-         p_prev = self.get_state('p_prev', params=vars.params, cls=TensorList)
-         g_prev = self.get_state('g_prev', params=vars.params, cls=TensorList)
+         p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+         g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)


-         vars.update = adamy_(
-             p=TensorList(vars.params),
+         var.update = adamy_(
+             p=TensorList(var.params),
              p_prev=p_prev,
-             g=TensorList(vars.get_update()),
+             g=TensorList(var.get_update()),
              g_prev=g_prev,
              exp_avg_=exp_avg,
              exp_avg_sq_=exp_avg_sq,
@@ -132,4 +122,4 @@
              params_=passed_params,
          )

-         return vars
+         return var
torchzero/modules/experimental/adasoap.py
@@ -2,11 +2,18 @@ from operator import itemgetter

  import torch

- from ...core import Chainable, Transform, apply
+ from ...core import Chainable, Transform
  from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ..optimizers.soap import (
+     get_orthogonal_matrix,
+     get_orthogonal_matrix_QR,
+     project,
+     project_back,
+ )
+

  @torch.no_grad
- def update_soap_covariances_(
+ def update_adasoap_covariances_(
      grad: torch.Tensor,
      GGs_: list[torch.Tensor | None],
      GG_sqs: list[torch.Tensor | None],
@@ -24,125 +31,9 @@ def update_soap_covariances_(
          if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
          else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

- @torch.no_grad
- def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-     """
-     Projects the gradient to the eigenbases of the preconditioner.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-         else:
-             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors
-
- @torch.no_grad
- def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-     """
-     Projects the gradient back to the original space.
-     """
-     for mat in Q:
-         if mat is None: continue
-         if len(mat) > 0:
-             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-         else:
-             permute_order = list(range(1, len(tensors.shape))) + [0]
-             tensors = tensors.permute(permute_order)
-
-     return tensors
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
- @torch.no_grad
- def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-     """
-     matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m in mat:
-         if m is None: continue
-         if len(m) == 0:
-             matrix.append([])
-             continue
-         if m.dtype != torch.float:
-             original_type = m.dtype
-             original_device = m.device
-             matrix.append(m.float())
-         else:
-             float_data = True
-             matrix.append(m)
-
-     final = []
-     for m in matrix:
-         if len(m) == 0:
-             final.append([])
-             continue
-         try:
-             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-         except Exception:
-             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-             Q = Q.to(m.dtype)
-         Q = torch.flip(Q, [1])
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-     return final
-
- # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
- @torch.no_grad
- def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-     """
-     Computes the eigenbases of the preconditioner using one round of power iteration
-     followed by torch.linalg.qr decomposition.
-     """
-     matrix = []
-     orth_matrix = []
-     float_data = False
-     original_type = original_device = None
-     for m,o in zip(GG, Q_list):
-         if m is None: continue
-         assert o is not None
-
-         if len(m) == 0:
-             matrix.append([])
-             orth_matrix.append([])
-             continue
-         if m.data.dtype != torch.float:
-             original_type = m.data.dtype
-             original_device = m.data.device
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-         else:
-             float_data = True
-             matrix.append(m.data.float())
-             orth_matrix.append(o.data.float())
-
-     final = []
-     for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-         if len(m)==0:
-             final.append([])
-             continue
-         est_eig = torch.diag(o.T @ m @ o)
-         sort_idx = torch.argsort(est_eig, descending=True)
-         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-         o = o[:,sort_idx]
-         power_iter = m @ o
-         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
-         final.append(Q)
-
-     return final, exp_avg_sq

  class AdaSOAP(Transform):
-     """SOAP with diagonally preconditioned GG^Ts
+     """SOAP with diagonally preconditioned GG^Ts. Please note that this is experimental and isn't guaranteed to work.

      precond_beta - beta for GG^T squares
      """
@@ -180,15 +71,14 @@ class AdaSOAP(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply(self, tensors, params, grads, loss, states, settings):
          updates = []
          # update preconditioners
-         for i,(p,t) in enumerate(zip(params, tensors)):
-             state = self.state[p]
-             settings = self.settings[p]
+         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+
              beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
-             precond_beta = settings['precond_beta']
+                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
+             precond_beta = setting['precond_beta']

              if merge_small:
                  t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -213,14 +103,14 @@

                  if state['GG'] is not None:
                      assert state['GG_sq'] is not None
-                     update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                     update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                      GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
                      state['Q'] = get_orthogonal_matrix(GG_precond)

                  state['step'] = 0
                  updates.append(tensors[i].clip(-0.1,0.1))
                  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+                 # that can mess with other modules scaling

              # Projecting gradients to the eigenbases of Shampoo's preconditioner
              # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -259,7 +149,7 @@
              if t_projected is not None:
                  update = project_back(update, state["Q"])

-             if settings['bias_correction']:
+             if setting['bias_correction']:
                  bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                  bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                  update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -274,9 +164,9 @@

              # Update is done after the gradient step to avoid using current gradients in the projection.
              if state['GG'] is not None:
-                 update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                 update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                  GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                 if state['step'] % settings['precond_freq'] == 0:
+                 if state['step'] % setting['precond_freq'] == 0:
                      state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])

          return updates