torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/soap.py

@@ -2,7 +2,7 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Transform, apply
+from ...core import Chainable, Transform, apply_transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
@@ -24,11 +24,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
     Projects the gradient to the eigenbases of the preconditioner.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
         else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
             permute_order = list(range(1, len(tensors.shape))) + [0]
             tensors = tensors.permute(permute_order)

@@ -40,8 +38,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
     Projects the gradient back to the original space.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
         else:
             permute_order = list(range(1, len(tensors.shape))) + [0]
@@ -59,8 +56,7 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     float_data = False
     original_type = original_device = None
     for m in mat:
-        if m is None: continue
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             continue
         if m.dtype != torch.float:
@@ -100,13 +96,11 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     float_data = False
     original_type = original_device = None
     for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             orth_matrix.append([])
             continue
+        assert o is not None
         if m.data.dtype != torch.float:
             original_type = m.data.dtype
             original_device = m.data.device
@@ -152,11 +146,28 @@ class SOAP(Transform):
            epsilon for dividing first momentum by second. Defaults to 1e-8.
        decay (float | None, optional):
            Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
-        unprojected_exp_avg (bool, optional):
-            whether to update first momentum in unprojected space. Both true and false work and lead to different
-            results but True usually works better. Defaults to True.
+        alpha (float, optional):
+            learning rate. Defaults to 1.
        bias_correction (bool, optional):
            enables adam bias correction. Defaults to True.
+
+    Examples:
+        SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
+
+        Stabilized SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SOAP(),
+                tz.m.NormalizeByEMA(max_ema_growth=1.2),
+                tz.m.LR(1e-2)
+            )
    """
    def __init__(
        self,
@@ -170,7 +181,6 @@ class SOAP(Transform):
        eps: float = 1e-8,
        decay: float | None = None,
        alpha: float = 1,
-        unprojected_exp_avg: bool = True,
        bias_correction: bool = True,
    ):
        defaults = dict(
@@ -183,21 +193,18 @@ class SOAP(Transform):
            precondition_1d=precondition_1d,
            eps=eps,
            decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
            bias_correction=bias_correction,
            alpha=alpha,
        )
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        updates = []
        # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -205,7 +212,7 @@ class SOAP(Transform):
            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
+                state["exp_avg_sq_projected"] = torch.zeros_like(t)

                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []
@@ -235,35 +242,31 @@ class SOAP(Transform):
            # exponential moving averages
            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]

-            if unprojected_exp_avg or t_projected is None:
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
+            exp_avg.lerp_(t, 1-beta1)

            if t_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
            else:
-                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)

            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
-            if unprojected_exp_avg and t_projected is not None:
+            if t_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])

-            exp_avg_sq_projected = exp_avg_sq
-
            denom = exp_avg_sq_projected.sqrt().add_(eps)
            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            update = exp_avg_projected / denom
+
            if t_projected is not None:
                update = project_back(update, state["Q"])

-            if settings['bias_correction']:
+            if setting['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -279,7 +282,7 @@ class SOAP(Transform):
            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                update_soap_covariances_(t, state['GG'], shampoo_beta)
-                if state['step'] % settings['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+                if state['step'] % setting['precond_freq'] == 0:
+                    state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])

        return updates
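
The SOAP rewrite above reflects the transform API change that runs through this release: `apply` is renamed to `apply_transform`, and `Transform` subclasses now implement `apply_tensors(self, tensors, params, grads, loss, states, settings)` instead of `transform(self, tensors, params, grads, vars)`, receiving per-parameter state and settings dicts directly rather than reading `self.state[p]` / `self.settings[p]`. Below is a minimal sketch of a custom transform under the new signature; it assumes only what this diff shows about the base class (the `defaults` / `uses_grad` constructor arguments and the zipped `states` / `settings` lists), and the class itself is a hypothetical example, not part of torchzero.

    import torch
    from torchzero.core import Transform

    class SignScale(Transform):
        # hypothetical example: scale the sign of each update by a per-parameter "scale" setting
        def __init__(self, scale: float = 1.0):
            defaults = dict(scale=scale)
            super().__init__(defaults, uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            updates = []
            # per-parameter state dicts and setting dicts arrive zipped with the update tensors
            for t, state, setting in zip(tensors, states, settings):
                state["step"] = state.get("step", 0) + 1  # persistent per-parameter state
                updates.append(t.sign() * setting["scale"])
            return updates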
torchzero/modules/optimizers/sophia_h.py

@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch

-from ...core import Module, Target, Transform, Chainable, apply
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -35,6 +35,74 @@ def sophia_H(


 class SophiaH(Module):
+    """SophiaH optimizer from https://arxiv.org/abs/2305.14342
+
+    This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.
+
+    .. note::
+        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.96.
+        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
+        precond_scale (float, optional):
+            scale of the preconditioner. Defaults to 1.
+        clip (float, optional):
+            clips update to (-clip, clip). Defaults to 1.
+        eps (float, optional):
+            clips hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Using SophiaH:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(),
+                tz.m.LR(0.1)
+            )
+
+        SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+        SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+                tz.m.LR(0.1)
+            )
+
+    """
    def __init__(
        self,
        beta1: float = 0.96,
@@ -56,8 +124,8 @@ class SophiaH(Module):
        self.set_child('inner', inner)

    @torch.no_grad
-    def step(self, vars):
-        params = vars.params
+    def step(self, var):
+        params = var.params
        settings = self.settings[params[0]]
        hvp_method = settings['hvp_method']
        fd_h = settings['fd_h']
@@ -71,37 +139,26 @@ class SophiaH(Module):
            self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
        generator = self.global_state['generator']

-        beta1, beta2, precond_scale, clip, eps = self.get_settings(
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', params=params, cls=NumberList)
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)

-        exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg', params=params, cls=TensorList)
+        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

-        closure = vars.closure
+        closure = var.closure
        assert closure is not None

        h = None
        if step % update_freq == 0:

-            grad=None
+            rgrad=None
            for i in range(n_samples):
                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

-                if hvp_method == 'autograd':
-                    if grad is None: grad = vars.get_grad(create_graph=True)
-                    assert grad is not None
-                    Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
-
-                elif hvp_method == 'forward':
-                    loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=vars.get_grad(), normalize=True)
-
-                elif hvp_method == 'central':
-                    loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
-
-                else:
-                    raise ValueError(hvp_method)
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)

                if h is None: h = Hvp
                else: torch._foreach_add_(h, Hvp)
@@ -109,11 +166,11 @@ class SophiaH(Module):
        assert h is not None
        if n_samples > 1: torch._foreach_div_(h, n_samples)

-        update = vars.get_update()
+        update = var.get_update()
        if 'inner' in self.children:
-            update = apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars)
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

-        vars.update = sophia_H(
+        var.update = sophia_H(
            tensors=TensorList(update),
            h=TensorList(h) if h is not None else None,
            exp_avg_=exp_avg,
@@ -126,4 +183,4 @@ class SophiaH(Module):
            eps=eps,
            step=step,
        )
-        return vars
+        return var
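
The new SophiaH docstring states that the optimizer step needs a closure accepting a ``backward`` argument, since the module re-evaluates the loss and gradients to compute Hessian-vector products. Below is a sketch of the closure shape this refers to, assuming the usual PyTorch closure-based training loop; `model`, `criterion`, `inputs` and `targets` are placeholders.

    def closure(backward=True):
        # re-evaluates the loss; gradients are only computed when requested
        loss = criterion(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    loss = opt.step(closure)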
torchzero/modules/projections/__init__.py

@@ -1,5 +1,3 @@
-from .projection import Projection
-from .fft import FFTProjection
-from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
-
+from .projection import ProjectionBase, VectorProjection, ScalarProjection
+from .cast import To, ViewAsReal
 # from .galore import GaLore
torchzero/modules/projections/cast.py

@@ -0,0 +1,51 @@
+import torch
+from .projection import ProjectionBase
+from ...core import Chainable
+
+class To(ProjectionBase):
+    """Cast modules to specified device and dtype"""
+    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device: torch.types.Device | None = None):
+        defaults = dict(dtype=dtype, device=device)
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        casted = []
+        for tensor, state, setting in zip(tensors, states, settings):
+            state['dtype'] = tensor.dtype
+            state['device'] = tensor.device
+            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
+            casted.append(tensor)
+        return casted
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        uncasted = []
+        for tensor, state in zip(projected_tensors, states):
+            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
+            uncasted.append(tensor)
+        return uncasted
+
+
+class ViewAsReal(ProjectionBase):
+    """View complex tensors as real tensors. Doesn't affect tensors that are already real."""
+    def __init__(self, modules: Chainable):
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        views = []
+        for tensor, state in zip(tensors, states):
+            is_complex = torch.is_complex(tensor)
+            state['is_complex'] = is_complex
+            if is_complex: tensor = torch.view_as_real(tensor)
+            views.append(tensor)
+        return views
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        un_views = []
+        for tensor, state in zip(projected_tensors, states):
+            if state['is_complex']: tensor = torch.view_as_complex(tensor)
+            un_views.append(tensor)
+        return un_views
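
The new `To` and `ViewAsReal` projections wrap other modules: `project` converts the tensors the wrapped modules see (casting dtype/device, or viewing complex tensors as real), and `unproject` restores the original representation afterwards. A usage sketch follows, assuming both classes are exposed under `tz.m` like the other modules in this diff and that `tz.m.Adam` is available; `model` and `complex_model` are placeholders.

    import torch
    import torchzero as tz

    # run Adam's accumulators in bfloat16, casting the update back to the parameter dtype
    opt = tz.Modular(
        model.parameters(),
        tz.m.To(tz.m.Adam(), dtype=torch.bfloat16),
        tz.m.LR(1e-3),
    )

    # let a real-valued optimizer act on complex parameters by viewing them as real tensors
    opt_c = tz.Modular(
        complex_model.parameters(),
        tz.m.ViewAsReal(tz.m.Adam()),
        tz.m.LR(1e-3),
    )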
torchzero/modules/projections/galore.py

@@ -6,5 +6,7 @@ from typing import Any, Literal

 import torch

-from ...core import Chainable, Module, Vars
-from .projection import Projection
+from ...core import Chainable, Module, Var
+from .projection import ProjectionBase
+
+# TODO