torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/adasoap.py
@@ -2,11 +2,18 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform, apply
+from ...core import Chainable, Transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ..optimizers.soap import (
+    get_orthogonal_matrix,
+    get_orthogonal_matrix_QR,
+    project,
+    project_back,
+)
+
 
 @torch.no_grad
-def update_soap_covariances_(
+def update_adasoap_covariances_(
     grad: torch.Tensor,
     GGs_: list[torch.Tensor | None],
     GG_sqs: list[torch.Tensor | None],
@@ -24,127 +31,16 @@ def update_soap_covariances_(
         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
 
 class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts
+    """SOAP with diagonally preconditioned GG^Ts.
+
+    .. warning::
+        Experimental.
 
     precond_beta - beta for GG^T squares
+
+    Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
     """
     def __init__(
         self,
@@ -180,15 +76,14 @@ class AdaSOAP(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+
            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
-            precond_beta = settings['precond_beta']
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
+            precond_beta = setting['precond_beta']
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -213,7 +108,7 @@ class AdaSOAP(Transform):
 
             if state['GG'] is not None:
                 assert state['GG_sq'] is not None
-                update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
                 state['Q'] = get_orthogonal_matrix(GG_precond)
 
@@ -259,7 +154,7 @@
             if t_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if settings['bias_correction']:
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -274,9 +169,9 @@ class AdaSOAP(Transform):
 
             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
-                update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                if state['step'] % settings['precond_freq'] == 0:
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
 
         return updates
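
As an aside on what the renamed helper does (not part of the diff itself): update_adasoap_covariances_ keeps, per tensor dimension, an EMA of g gᵀ contracted over the other axes plus a second accumulator GG_sq driven by precond_beta, and the eigenbasis Q is then computed from the elementwise ratio GG / (GG_sq + 1e-8) rather than from GG directly as in plain SOAP. A minimal standalone sketch of that step for a single 2D gradient, using plain torch instead of the torchzero Transform machinery, and assuming GG_sq is an EMA of the squared entries (its update rule sits outside this hunk):

    import torch

    def adasoap_covariances(grad, GG, GG_sq, shampoo_beta=0.95, precond_beta=0.999):
        # one accumulator pair per dimension of `grad`
        for dim in range(grad.ndim):
            axes = [d for d in range(grad.ndim) if d != dim]        # contract every other axis
            outer = torch.tensordot(grad, grad, dims=(axes, axes))  # ~ g g^T along `dim`
            GG[dim].lerp_(outer, 1 - shampoo_beta)                  # EMA of g g^T, as in SOAP
            GG_sq[dim].lerp_(outer**2, 1 - precond_beta)            # assumed: EMA of its square
        # "diagonally preconditioned" covariances whose eigenbases replace SOAP's Q
        return [gg / (sq + 1e-8) for gg, sq in zip(GG, GG_sq)]

    g = torch.randn(4, 3)
    GG, GG_sq = [torch.zeros(4, 4), torch.zeros(3, 3)], [torch.zeros(4, 4), torch.zeros(3, 3)]
    precond = adasoap_covariances(g, GG, GG_sq)
    Q = [torch.linalg.eigh(m + 1e-30 * torch.eye(len(m)))[1] for m in precond]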

torchzero/modules/experimental/cosine.py (new file)
@@ -0,0 +1,214 @@
+"""A bunch of useless modules that I hate and that didn't work"""
+import torch
+
+from ...core import Chainable, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
+
+class CosineStepSize(Transform):
+    """Adaptive step size based on cosine similarity
+
+    VERDICT: Useless. This is too unstable.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        init (float, optional): initial step size. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, init=init, eps=eps)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
+        unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        new_alpha = []
+        for s, sc in zip(states, scale):
+            s['alpha'] *= 1 + cos_sim * sc
+            new_alpha.append(s['alpha'])
+
+        tensors.mul_(new_alpha)
+        prev.copy_(tensors)
+
+        return tensors
+
+
+
+class CosineDebounce(Transform):
+    """Debouncing when cosine similarity is less than 0.
+
+    VERDICT: Useless. This doesn't help at all.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim < -eps:
+            undo = prev.neg().mul_(-cos_sim * scale)
+            comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb)
+
+        prev.copy_(tensors)
+        return tensors
+
+
+
+class CosineMomentum(Transform):
+    """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
+
+    VERDICT: Useless. Worse than all other momentums.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 1.
+        nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
+        power (float, optional): power for beta. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
+    """
+    def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
+        eps = settings[0]['eps']
+        nesterov = settings[0]['nesterov']
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
+
+        tensors = as_tensorlist(tensors)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        beta = (1 - (cos_sim*scale)) ** power
+        if nesterov:
+            exp_avg.add_(tensors.mul(beta))
+            return tensors.add_(exp_avg)
+        else:
+            exp_avg.add_(tensors.mul_(beta))
+            return exp_avg.clone()
+
+
+class AdaptiveDifference(Transform):
+    """VERDICT: Useless. Doesn't help (sort of to be expected)."""
+    def __init__(self, inner:Chainable | None = None):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        diff = tensors - prev.graft_(tensors)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff.graft_(tensors))
+
+        return tensors
+
+class AdaptiveDifferenceEMA(Transform):
+    """VERDICT: better than non-EMA but still useless."""
+    def __init__(self, beta=0.99, inner:Chainable | None = None):
+        defaults = dict(beta=beta)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
+        prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
+
+        diff = (tensors - prev.graft_(tensors)).graft_(tensors)
+        diff_exp_avg.lerp_(diff, 1-beta)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff_exp_avg.graft(tensors))
+
+        return tensors
+
+
+class ScaledAdaptiveDifference(Transform):
+    """VERDICT: Useless and doesn't help."""
+    def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
+        defaults = dict(scale=scale, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
+
+        cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim > 0:
+            tensors.add_(prev_tensors*(cos_sim*scale))
+
+        else:
+            undo = prev_tensors.neg().mul_(-cos_sim*scale)
+            comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
+
+        diff = tensors - prev_tensors.graft_(tensors)
+        prev_tensors.copy_(tensors)
+        diff.graft_(tensors)
+        tensors.add_(diff)
+        prev_update.copy_(tensors)
+
+        return tensors
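
For a concrete sense of the CosineStepSize rule above (alpha ← alpha · (1 + scale · cos(u_t, u_{t-1}))), here is a small self-contained sketch with plain tensors rather than the torchzero Transform API; the names are illustrative only:

    import torch

    def cosine_step_size(update, prev, alpha, scale=0.95, eps=1e-12):
        # cosine similarity between the current and previous update, in [-1, 1]
        cos_sim = (update.dot(prev) / (update.norm() * prev.norm()).clamp(min=eps)).item()
        alpha = alpha * (1 + scale * cos_sim)  # grow when aligned, shrink when opposing
        return update * alpha, alpha

    prev = torch.tensor([1.0, 0.0])
    cur = torch.tensor([0.8, 0.6])   # cos = 0.8, so alpha goes from 1.0 to 1.76
    scaled, alpha = cosine_step_size(cur, prev, alpha=1.0)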

torchzero/modules/experimental/cubic_adam.py (new file)
@@ -0,0 +1,97 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+
+def signed_cbrt(x: TensorList) -> TensorList:
+    return x.sign() * x.abs().pow(1/3)
+
+def cubic_adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    exp_avg_cu_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    debiased: bool,
+    step: int,
+):
+    exp_avg_.lerp_(tensors, 1-beta1)
+    exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+    exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+    if debiased:
+        m1 = exp_avg_ / (1 - beta1 ** step)
+        m2 = exp_avg_sq_ / (1 - beta2 ** step)
+        m3 = exp_avg_cu_ / (1 - beta3 ** step)
+    else:
+        m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+    # adam minimizes ax^2 + bx
+    # we are going to minimize ax^3 + bx^2 + cx
+    A = signed_cbrt(m3)
+    B = m2.sqrt()
+    C = m1
+    discriminant = B.pow(2) - 4 * A * C
+
+    denom = 2 * A
+    root = discriminant.clamp(min=0).sqrt_()
+
+    x0 = (-B + root) / (denom + eps)
+    x1 = (-B - root) / (denom + eps)
+
+    f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+    f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+    x_star = x0.where(f0 < f1, x1)
+
+    adam = -C / (B + eps)
+    x_star = adam.where(discriminant < 0, x_star)
+
+    return x_star.mul_(-alpha)
+
+class CubicAdam(Transform):
+    """Adam which has 3rd momentum and minimizes a cubic polynomial.
+
+    VERDICT: can outperform Adam very slightly. Usually very similar performance.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.99,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        debiased:bool=True,
+        alpha: float = 1.,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
+        exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+        return cubic_adam_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            exp_avg_cu_=exp_avg_cu,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            debiased=settings[0]['debiased'],
+            step=step,
+        )
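
The closed form used by cubic_adam_ can be read directly off the code above: with A = cbrt(m3), B = sqrt(m2) and C = m1 it minimizes the cubic model

    \min_x \; \tfrac{A}{3}x^3 + \tfrac{B}{2}x^2 + Cx
    \;\Longrightarrow\;
    Ax^2 + Bx + C = 0
    \;\Longrightarrow\;
    x = \frac{-B \pm \sqrt{B^2 - 4AC}}{2A}

keeping whichever root gives the smaller model value; when the discriminant B^2 - 4AC is negative it falls back to x = -C/B = -m1/sqrt(m2), i.e. the usual (debiased) Adam direction up to eps.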

torchzero/modules/experimental/curveball.py
@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch
 
-from ...core import Module, Target, Transform, Chainable, apply
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
 
@@ -47,27 +47,27 @@ class CurveBall(Module):
         if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, vars):
+    def step(self, var):
 
-        params = vars.params
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         h = settings['h']
 
-        precond_lr, momentum, reg = self.get_settings('momentum', 'decay_rate', 'reg', params=params, cls=NumberList)
+        precond_lr, momentum, reg = self.get_settings(params, 'precond_lr', 'momentum', 'reg', cls=NumberList)
 
 
-        closure = vars.closure
+        closure = var.closure
         assert closure is not None
 
-        z, Hz = self.get_state('z', 'Hz', params=params, cls=TensorList)
+        z, Hz = self.get_state(params, 'z', 'Hz', cls=TensorList)
 
         if hvp_method == 'autograd':
-            grad = vars.get_grad(create_graph=True)
+            grad = var.get_grad(create_graph=True)
             Hvp = hvp(params, grad, z)
 
        elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=vars.get_grad(), normalize=True)
+            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)
 
         elif hvp_method == 'central':
             loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
@@ -79,11 +79,11 @@ class CurveBall(Module):
         Hz.set_(Hvp + z*reg)
 
 
-        update = vars.get_update()
+        update = var.get_update()
         if 'inner' in self.children:
-            update = apply(self.children['inner'], update, params, grads=vars.grad, vars=vars)
+            update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)
 
         z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
-        vars.update = z.neg()
+        var.update = z.neg()
 
-        return vars
+        return var
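
For reference on the three hvp_method branches above: 'autograd' differentiates through the gradient to obtain the Hessian-vector product Hz exactly, while the finite-difference variants approximate it from one or two extra gradient evaluations; the standard forms are

    Hz \approx \frac{\nabla f(\theta + hz) - \nabla f(\theta)}{h} \quad\text{(forward)}
    \qquad
    Hz \approx \frac{\nabla f(\theta + hz) - \nabla f(\theta - hz)}{2h} \quad\text{(central)}

with the caveat that the exact scaling/normalization applied by hvp_fd_forward and hvp_fd_central is outside this hunk.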

torchzero/modules/{projections → experimental}/dct.py
@@ -1,13 +1,13 @@
 from typing import Literal
 import torch
 import torch_dct
-from .projection import Projection
+from ..projections import ProjectionBase
 from ...core import Chainable
 
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))
 
-class DCTProjection(Projection):
+class DCTProjection(ProjectionBase):
    # norm description copied from pytorch docstring
    """Project update into Discrete Cosine Transform space, requires `torch_dct` library.
 
@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']
 
@@ -54,18 +54,18 @@
         return projected
 
     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+        settings = projected_settings[0]
         dims = settings['dims']
         norm = settings['norm']
 
         unprojected = []
-        for u in tensors:
-            dim = min(u.ndim, dims)
+        for t in projected_tensors:
+            dim = min(t.ndim, dims)
 
-            if dim == 1: idct = torch_dct.idct(u, norm = norm)
-            elif dim == 2: idct = torch_dct.idct_2d(u, norm=norm)
-            elif dim == 3: idct = torch_dct.idct_3d(u, norm=norm)
+            if dim == 1: idct = torch_dct.idct(t, norm = norm)
+            elif dim == 2: idct = torch_dct.idct_2d(t, norm=norm)
+            elif dim == 3: idct = torch_dct.idct_3d(t, norm=norm)
             else: raise ValueError(f"Unsupported number of dimensions {dim}")
 
             unprojected.append(reverse_dims(idct))
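
Since unproject() now mirrors the new project() signature, the pair remains a plain DCT round trip; a quick sanity check with torch_dct (assuming project() applies the matching forward transform for the tensor's dimensionality):

    import torch
    import torch_dct

    x = torch.randn(8, 8)
    proj = torch_dct.dct_2d(x, norm='ortho')      # what project() would produce for a 2D tensor
    back = torch_dct.idct_2d(proj, norm='ortho')  # what unproject() applies
    assert torch.allclose(x, back, atol=1e-5)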