torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/adasoap.py (new file)
@@ -0,0 +1,282 @@
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import Chainable, Transform, apply
+ from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+ @torch.no_grad
+ def update_soap_covariances_(
+     grad: torch.Tensor,
+     GGs_: list[torch.Tensor | None],
+     GG_sqs: list[torch.Tensor | None],
+     beta: float | None,
+     precond_beta: float | None,
+ ):
+     for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
+         if GG is None: continue
+         assert GG_sq is not None
+
+         if precond_beta is None: GG_sq.addcmul_(GG, GG)
+         else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
+
+         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
+         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
+         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+ @torch.no_grad
+ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+     """
+     Projects the gradient to the eigenbases of the preconditioner.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+         else:
+             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ @torch.no_grad
+ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
+     """
+     Projects the gradient back to the original space.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+         else:
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+ @torch.no_grad
+ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+     """
+     matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m in mat:
+         if m is None: continue
+         if len(m) == 0:
+             matrix.append([])
+             continue
+         if m.dtype != torch.float:
+             original_type = m.dtype
+             original_device = m.device
+             matrix.append(m.float())
+         else:
+             float_data = True
+             matrix.append(m)
+
+     final = []
+     for m in matrix:
+         if len(m) == 0:
+             final.append([])
+             continue
+         try:
+             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+         except Exception:
+             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+             Q = Q.to(m.dtype)
+         Q = torch.flip(Q, [1])
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+     return final
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+ @torch.no_grad
+ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using one round of power iteration
+     followed by torch.linalg.qr decomposition.
+     """
+     matrix = []
+     orth_matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m,o in zip(GG, Q_list):
+         if m is None: continue
+         assert o is not None
+
+         if len(m) == 0:
+             matrix.append([])
+             orth_matrix.append([])
+             continue
+         if m.data.dtype != torch.float:
+             original_type = m.data.dtype
+             original_device = m.data.device
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+         else:
+             float_data = True
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+
+     final = []
+     for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
+         if len(m)==0:
+             final.append([])
+             continue
+         est_eig = torch.diag(o.T @ m @ o)
+         sort_idx = torch.argsort(est_eig, descending=True)
+         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+         o = o[:,sort_idx]
+         power_iter = m @ o
+         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+
+     return final, exp_avg_sq
+
+ class AdaSOAP(Transform):
+     """SOAP with diagonally preconditioned GG^Ts
+
+     precond_beta - beta for GG^T squares
+     """
+     def __init__(
+         self,
+         beta1: float = 0.95,
+         beta2: float = 0.95,
+         shampoo_beta: float | None = 0.95,
+         precond_beta: float | None = 0.95,
+         precond_freq: int = 10,
+         merge_small: bool = True,
+         max_dim: int = 2_000,
+         precondition_1d: bool = True,
+         eps: float = 1e-8,
+         decay: float | None = None,
+         alpha: float = 1,
+         unprojected_exp_avg: bool = True,
+         bias_correction: bool = True,
+     ):
+         defaults = dict(
+             beta1=beta1,
+             beta2=beta2,
+             shampoo_beta=shampoo_beta,
+             precond_beta=precond_beta,
+             precond_freq=precond_freq,
+             merge_small=merge_small,
+             max_dim=max_dim,
+             precondition_1d=precondition_1d,
+             eps=eps,
+             decay=decay,
+             unprojected_exp_avg=unprojected_exp_avg,
+             bias_correction=bias_correction,
+             alpha=alpha,
+         )
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         updates = []
+         # update preconditioners
+         for i,(p,t) in enumerate(zip(params, tensors)):
+             state = self.state[p]
+             settings = self.settings[p]
+             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
+                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+             precond_beta = settings['precond_beta']
+
+             if merge_small:
+                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+             # initialize state on 1st step
+             if 'GG' not in state:
+                 state["exp_avg"] = torch.zeros_like(t)
+                 state["exp_avg_sq"] = torch.zeros_like(t)
+
+                 if not precondition_1d and t.ndim <= 1:
+                     state['GG'] = []
+                     state['GG_sq'] = []
+
+                 else:
+                     state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+                     state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+
+                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                 if len([i is not None for i in state['GG']]) == 0:
+                     state['GG'] = None
+                     state['GG_sq'] = None
+
+                 if state['GG'] is not None:
+                     assert state['GG_sq'] is not None
+                     update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                     GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
+                     state['Q'] = get_orthogonal_matrix(GG_precond)
+
+                 state['step'] = 0
+                 updates.append(tensors[i].sign())
+                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
+                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             t_projected = None
+             if state['GG'] is not None:
+                 t_projected = project(t, state['Q'])
+
+             # exponential moving averages
+             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
+             exp_avg: torch.Tensor = state["exp_avg"]
+             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+             if unprojected_exp_avg or t_projected is None:
+                 exp_avg.lerp_(t, 1-beta1)
+             else:
+                 exp_avg.lerp_(t_projected, 1-beta1)
+
+             if t_projected is None:
+                 exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+             else:
+                 exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+
+             # project exponential moving averages if they are accumulated unprojected
+             exp_avg_projected = exp_avg
+             if unprojected_exp_avg and t_projected is not None:
+                 exp_avg_projected = project(exp_avg, state['Q'])
+
+             exp_avg_sq_projected = exp_avg_sq
+
+             denom = exp_avg_sq_projected.sqrt().add_(eps)
+             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+             # to the original space
+             update = exp_avg_projected / denom
+             if t_projected is not None:
+                 update = project_back(update, state["Q"])
+
+             if settings['bias_correction']:
+                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
+                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
+                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+             elif alpha is not None:
+                 update *= alpha
+
+             if merge_small:
+                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+             updates.append(update)
+             state["step"] += 1
+
+             # Update is done after the gradient step to avoid using current gradients in the projection.
+             if state['GG'] is not None:
+                 update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
+                 if state['step'] % settings['precond_freq'] == 0:
+                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
+
+         return updates
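AdaSOAP subclasses Transform, so it is meant to be chained with other torchzero modules rather than used as a standalone optimizer. A minimal usage sketch might look like the following; the tz.Modular entry point and the tz.m.experimental.AdaSOAP / tz.m.LR export paths are assumptions for illustration, not something this diff confirms:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    # chain the experimental AdaSOAP transform with a learning-rate module (paths assumed)
    opt = tz.Modular(model.parameters(), tz.m.experimental.AdaSOAP(), tz.m.LR(1e-2))

    for _ in range(10):
        x, y = torch.randn(32, 10), torch.randn(32, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()

The transform only produces the preconditioned update; step size, weight decay and similar concerns are left to whatever modules follow it in the chain.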
torchzero/modules/experimental/algebraic_newton.py (new file)
@@ -0,0 +1,145 @@
+ import warnings
+ from functools import partial
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+ import torchalgebras as ta
+
+ from ...core import Chainable, apply, Module
+ from ...utils import vec_to_tensors, TensorList
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     hessian_mat,
+     jacobian_and_hessian_wrt,
+ )
+
+ class MaxItersReached(Exception): pass
+ def tropical_lstsq(
+     H: torch.Tensor,
+     g: torch.Tensor,
+     solver,
+     maxiter,
+     tol,
+     algebra,
+     verbose,
+ ):
+     """it can run on any algebra with add despite it saying tropical"""
+     algebra = ta.get_algebra(algebra)
+
+     x = torch.zeros_like(g, requires_grad=True)
+     best_x = x.detach().clone()
+     best_loss = float('inf')
+     opt = solver([x])
+
+     niter = 0
+     def closure(backward=True):
+         nonlocal niter, best_x, best_loss
+         if niter == maxiter: raise MaxItersReached
+         niter += 1
+
+         g_hat = algebra.mm(H, x)
+         loss = torch.nn.functional.mse_loss(g_hat, g)
+         if loss < best_loss:
+             best_x = x.detach().clone()
+             best_loss = loss.detach()
+
+         if backward:
+             opt.zero_grad()
+             loss.backward()
+         return loss
+
+     loss = None
+     prev_loss = float('inf')
+     for i in range(maxiter):
+         try:
+             loss = opt.step(closure)
+             if loss == 0: break
+             if tol is not None and prev_loss - loss < tol: break
+             prev_loss = loss
+         except MaxItersReached:
+             break
+
+     if verbose: print(f'{best_loss = } after {niter} iters')
+     return best_x.detach()
+
+ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemiring()):
+     if reg!=0:
+         I = ta.AlgebraicTensor(torch.eye(H.size(-1), dtype=H.dtype, device=H.device), algebra)
+         I = I * reg
+         H = algebra.add(H, I.data)
+     return H
+
+
+ class AlgebraicNewton(Module):
+     """newton in other algebras, not practical because solving linear system is very hard."""
+     def __init__(
+         self,
+         reg: float | None = None,
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+         solver=lambda p: torch.optim.LBFGS(p, line_search_fn='strong_wolfe'),
+         maxiter=1000,
+         tol: float | None = 1e-10,
+         algebra: ta.Algebra | str = 'tropical max',
+         verbose: bool = False,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize)
+         super().__init__(defaults)
+
+         self.algebra = ta.get_algebra(algebra)
+         self.lstsq_args:dict = dict(solver=solver, maxiter=maxiter, tol=tol, algebra=algebra, verbose=verbose)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         hessian_method = settings['hessian_method']
+         vectorize = settings['vectorize']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hessian_method == 'autograd':
+             with torch.enable_grad():
+                 loss = vars.loss = vars.loss_approx = closure(False)
+                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                 g_list = [t[0] for t in g_list] # remove leading dim from loss
+                 vars.grad = g_list
+                 H = hessian_list_to_mat(H_list)
+
+         elif hessian_method in ('func', 'autograd.functional'):
+             strat = 'forward-mode' if vectorize else 'reverse-mode'
+             with torch.enable_grad():
+                 g_list = vars.get_grad(retain_graph=True)
+                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+         else:
+             raise ValueError(hessian_method)
+
+         # -------------------------------- inner step -------------------------------- #
+         if 'inner' in self.children:
+             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+         g = torch.cat([t.view(-1) for t in g_list])
+
+         # ------------------------------- regulazition ------------------------------- #
+         if reg is not None: H = tikhonov(H, reg)
+
+         # ----------------------------------- solve ---------------------------------- #
+         tropical_update = tropical_lstsq(H, g, **self.lstsq_args)
+         # what now? w - u is not defined, it is defined for max version if u < w
+         # w = params.to_vec()
+         # w_hat = self.algebra.sub(w, tropical_update)
+         # update = w_hat - w
+         # no
+         # it makes sense to solve tropical system and sub normally
+         # the only thing is that tropical system can have no solutions
+
+         vars.update = vec_to_tensors(tropical_update, params)
+         return vars
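The inner solve in tropical_lstsq fits g ≈ H ⊗ x by mean squared error, where ⊗ is matrix multiplication in the chosen algebra. Under the default 'tropical max' algebra, multiplication becomes addition and addition becomes max; a minimal sketch of that product is below (torchalgebras' own algebra.mm may be implemented differently):

    import torch

    def maxplus_mv(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # in the (max, +) semiring, (H ⊗ x)_i = max_j (H[i, j] + x[j])
        return (H + x.unsqueeze(0)).amax(dim=1)

    H = torch.tensor([[0.0, -1.0], [2.0, 0.5]])
    x = torch.tensor([1.0, 3.0])
    print(maxplus_mv(H, x))  # tensor([2.0000, 3.5000])

Because a max-plus system need not have an exact solution, the module solves the system in the tropical algebra but applies the resulting update with ordinary subtraction, as the trailing comments above explain.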
torchzero/modules/experimental/curveball.py (new file)
@@ -0,0 +1,89 @@
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Module, Target, Transform, Chainable, apply
+ from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+
+ def curveball(
+     tensors: TensorList,
+     z_: TensorList,
+     Hz: TensorList,
+     momentum: float | NumberList,
+     precond_lr: float | NumberList,
+ ):
+     """returns z_, clone it!!!"""
+     delta = Hz + tensors
+     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
+     return z_
+
+
+ class CurveBall(Module):
+     """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.
+
+     For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.
+
+     Args:
+         precond_lr (float, optional): learning rate for updating preconditioned gradients. Defaults to 1e-3.
+         momentum (float, optional): decay rate for preconditioned gradients. Defaults to 0.9.
+         hvp_method (str, optional): how to calculate hessian vector products. Defaults to "autograd".
+         h (float, optional): finite difference step size for when hvp_method is set to finite difference. Defaults to 1e-3.
+         reg (float, optional): hessian regularization. Defaults to 1.
+         inner (Chainable | None, optional): Inner modules. Defaults to None.
+     """
+     def __init__(
+         self,
+         precond_lr: float=1e-3,
+         momentum: float=0.9,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h: float = 1e-3,
+         reg: float = 1,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg)
+         super().__init__(defaults)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+
+         params = vars.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+
+         precond_lr, momentum, reg = self.get_settings('momentum', 'decay_rate', 'reg', params=params, cls=NumberList)
+
+
+         closure = vars.closure
+         assert closure is not None
+
+         z, Hz = self.get_state('z', 'Hz', params=params, cls=TensorList)
+
+         if hvp_method == 'autograd':
+             grad = vars.get_grad(create_graph=True)
+             Hvp = hvp(params, grad, z)
+
+         elif hvp_method == 'forward':
+             loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=vars.get_grad(), normalize=True)
+
+         elif hvp_method == 'central':
+             loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
+
+         else:
+             raise ValueError(hvp_method)
+
+
+         Hz.set_(Hvp + z*reg)
+
+
+         update = vars.get_update()
+         if 'inner' in self.children:
+             update = apply(self.children['inner'], update, params, grads=vars.grad, vars=vars)
+
+         z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
+         vars.update = z.neg()
+
+         return vars
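For reference, with ρ = momentum, β = precond_lr and λ = reg, each step above computes

    Hz ← Hvp(z) + λz        (Hessian-vector product via autograd or finite differences)
    Δ  = Hz + g             (g is the incoming update from vars.get_update())
    z  ← ρz − βΔ

and the module outputs −z as its update. This matches the basic CurveBall recurrence; the closed-form ρ, β and α from the paper are the part this implementation omits, as noted in the docstring.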