torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/soap.py
@@ -24,11 +24,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
     Projects the gradient to the eigenbases of the preconditioner.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
         else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
             permute_order = list(range(1, len(tensors.shape))) + [0]
             tensors = tensors.permute(permute_order)

@@ -40,8 +38,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
     Projects the gradient back to the original space.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
         else:
             permute_order = list(range(1, len(tensors.shape))) + [0]
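
Note that merging the two conditions in `project` and `project_back` is not a pure refactor: previously a `None` factor left the dimension order of `tensors` untouched, whereas the merged condition routes `None` through the `else` branch and rotates the dimensions like an empty factor. (The analogous merge in `get_orthogonal_matrix` and `get_orthogonal_matrix_QR` below goes the other way: a `None` entry used to be skipped and now appends an empty placeholder, keeping the returned lists aligned with `Q`.) A minimal standalone sketch of the control-flow difference, with illustrative names that are not part of the package:

    import torch

    def old_flow(tensors, Q):
        for mat in Q:
            if mat is None: continue        # None: dimension order unchanged
            if len(mat) > 0:
                tensors = torch.tensordot(tensors, mat, dims=[[0], [0]])
            else:
                tensors = tensors.permute(list(range(1, tensors.ndim)) + [0])
        return tensors

    def new_flow(tensors, Q):
        for mat in Q:
            if mat is not None and len(mat) > 0:
                tensors = torch.tensordot(tensors, mat, dims=[[0], [0]])
            else:                           # None now also rotates dimensions
                tensors = tensors.permute(list(range(1, tensors.ndim)) + [0])
        return tensors

    t = torch.randn(2, 3)
    assert old_flow(t, [None]).shape == (2, 3)  # skipped entirely
    assert new_flow(t, [None]).shape == (3, 2)  # permuted once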
@@ -59,8 +56,7 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     float_data = False
     original_type = original_device = None
     for m in mat:
-        if m is None: continue
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             continue
         if m.dtype != torch.float:
@@ -100,13 +96,11 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     float_data = False
     original_type = original_device = None
     for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             orth_matrix.append([])
             continue
+        assert o is not None
         if m.data.dtype != torch.float:
             original_type = m.data.dtype
             original_device = m.data.device
@@ -156,6 +150,24 @@ class SOAP(Transform):
             learning rate. Defaults to 1.
         bias_correction (bool, optional):
             enables adam bias correction. Defaults to True.
+
+    Examples:
+        SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
+
+        Stabilized SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SOAP(),
+                tz.m.NormalizeByEMA(max_ema_growth=1.2),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -187,7 +199,7 @@ class SOAP(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
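
The `apply` → `apply_tensors` rename follows the reworked `Transform` interface (`torchzero/core/transform.py`, +158 -51 in this release). Under the new hook name, a minimal custom transform might look like the sketch below; the `torchzero.core.Transform` import path and the bare `defaults={}` are assumptions inferred from this diff, not confirmed API:

    import torch
    from torchzero.core import Transform  # assumed import path

    class SignUpdate(Transform):
        """Illustrative transform: replaces each update tensor with its sign."""
        def __init__(self):
            super().__init__(defaults={}, uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # same signature as SOAP.apply_tensors above
            return [t.sign() for t in tensors]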
@@ -200,7 +212,7 @@ class SOAP(Transform):
             # initialize state on 1st step
             if 'GG' not in state:
                 state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
+                state["exp_avg_sq_projected"] = torch.zeros_like(t)

             if not precondition_1d and t.ndim <= 1:
                 state['GG'] = []
@@ -230,22 +242,20 @@ class SOAP(Transform):
             # exponential moving averages
             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
             exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]

             exp_avg.lerp_(t, 1-beta1)

             if t_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
             else:
-                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)

             # project exponential moving averages if they are accumulated unprojected
             exp_avg_projected = exp_avg
             if t_projected is not None:
                 exp_avg_projected = project(exp_avg, state['Q'])

-            exp_avg_sq_projected = exp_avg_sq
-
             denom = exp_avg_sq_projected.sqrt().add_(eps)
             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
@@ -273,6 +283,6 @@ class SOAP(Transform):
             if state['GG'] is not None:
                 update_soap_covariances_(t, state['GG'], shampoo_beta)
                 if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+                    state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])

         return updates
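
Taken together, these renames make the storage convention explicit: SOAP keeps its second moment in the eigenbasis of the Shampoo accumulators and only the first moment unprojected, rotating it on demand. A simplified single-matrix sketch of that convention (illustrative only, not the package's implementation):

    import torch

    def soap_step(grad, Q, exp_avg, exp_avg_sq_projected,
                  beta1=0.95, beta2=0.99, eps=1e-8):
        # rotate the gradient into the eigenbasis of the preconditioner
        grad_projected = Q[0].T @ grad @ Q[1]

        exp_avg.lerp_(grad, 1 - beta1)              # first moment, unprojected
        exp_avg_sq_projected.mul_(beta2).addcmul_(  # second moment, in eigenbasis
            grad_projected, grad_projected, value=1 - beta2)

        # rotate the first moment on demand, take an Adam-style step, rotate back
        exp_avg_projected = Q[0].T @ exp_avg @ Q[1]
        update_projected = exp_avg_projected / (exp_avg_sq_projected.sqrt() + eps)
        return Q[0] @ update_projected @ Q[1].T

    d0, d1 = 4, 3
    Q = [torch.linalg.qr(torch.randn(d0, d0)).Q, torch.linalg.qr(torch.randn(d1, d1)).Q]
    update = soap_step(torch.randn(d0, d1), Q, torch.zeros(d0, d1), torch.zeros(d0, d1))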
torchzero/modules/optimizers/sophia_h.py
@@ -35,6 +35,74 @@ def sophia_H(


 class SophiaH(Module):
+    """SophiaH optimizer from https://arxiv.org/abs/2305.14342
+
+    This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.
+
+    .. note::
+        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.96.
+        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
+        update_freq (int, optional):
+            frequency of updating the hessian diagonal estimate via a hessian-vector product. Defaults to 10.
+        precond_scale (float, optional):
+            scale of the preconditioner. Defaults to 1.
+        clip (float, optional):
+            clips update to (-clip, clip). Defaults to 1.
+        eps (float, optional):
+            clips the hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to a better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Using SophiaH:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(),
+                tz.m.LR(0.1)
+            )
+
+        The SophiaH preconditioner can be applied to the output of any other module by passing that module to the :code:`inner` argument.
+        Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+        SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+                tz.m.LR(0.1)
+            )
+    """
     def __init__(
         self,
         beta1: float = 0.96,
@@ -85,23 +153,12 @@ class SophiaH(Module):
         h = None
         if step % update_freq == 0:

-            grad=None
+            rgrad=None
             for i in range(n_samples):
                 u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

-                if hvp_method == 'autograd':
-                    if grad is None: grad = var.get_grad(create_graph=True)
-                    assert grad is not None
-                    Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
-
-                elif hvp_method == 'forward':
-                    loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=var.get_grad(), normalize=True)
-
-                elif hvp_method == 'central':
-                    loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
-
-                else:
-                    raise ValueError(hvp_method)
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)

                 if h is None: h = Hvp
                 else: torch._foreach_add_(h, Hvp)
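
The three `hvp_method` branches now live behind a single `self.Hvp` helper on `Module` (part of the +138 -6 change to `torchzero/core/module.py`). Judging only from this call site and the removed branches, the helper presumably centralizes something like the following dispatch; this is a hypothetical reconstruction, not the actual implementation:

    # hypothetical stand-in for Module.Hvp, reconstructed from the removed branches
    def Hvp(self, u, at_x0, var, rgrad, hvp_method, h, normalize, retain_grad):
        if hvp_method == 'autograd':
            if rgrad is None:
                rgrad = var.get_grad(create_graph=True)  # graph-retaining gradient
            Hvp = hvp(var.params, rgrad, u, retain_graph=retain_grad)
        elif hvp_method == 'forward':
            _, Hvp = hvp_fd_forward(var.closure, var.params, u, h=h,
                                    g_0=var.get_grad(), normalize=normalize)
        elif hvp_method == 'central':
            _, Hvp = hvp_fd_central(var.closure, var.params, u, h=h, normalize=normalize)
        else:
            raise ValueError(hvp_method)
        return Hvp, rgrad

Returning `rgrad` lets the caller reuse the graph-retaining gradient across the `n_samples` loop instead of rebuilding it for every sample.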
torchzero/modules/projections/__init__.py
@@ -1,5 +1,3 @@
-from .projection import Projection
-from .fft import FFTProjection
-from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
-
+from .projection import ProjectionBase, VectorProjection, ScalarProjection
+from .cast import To, ViewAsReal
 # from .galore import GaLore
torchzero/modules/projections/cast.py (new file)
@@ -0,0 +1,51 @@
+import torch
+from .projection import ProjectionBase
+from ...core import Chainable
+
+class To(ProjectionBase):
+    """Casts tensors to the specified device and dtype before running the wrapped modules, then casts back."""
+    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device: torch.types.Device | None = None):
+        defaults = dict(dtype=dtype, device=device)
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        casted = []
+        for tensor, state, setting in zip(tensors, states, settings):
+            state['dtype'] = tensor.dtype
+            state['device'] = tensor.device
+            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
+            casted.append(tensor)
+        return casted
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        uncasted = []
+        for tensor, state in zip(projected_tensors, states):
+            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
+            uncasted.append(tensor)
+        return uncasted
+
+
+class ViewAsReal(ProjectionBase):
+    """View complex tensors as real tensors while the wrapped modules run. Doesn't affect tensors that are already real."""
+    def __init__(self, modules: Chainable):
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        views = []
+        for tensor, state in zip(tensors, states):
+            is_complex = torch.is_complex(tensor)
+            state['is_complex'] = is_complex
+            if is_complex: tensor = torch.view_as_real(tensor)
+            views.append(tensor)
+        return views
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        un_views = []
+        for tensor, state in zip(projected_tensors, states):
+            if state['is_complex']: tensor = torch.view_as_complex(tensor)
+            un_views.append(tensor)
+        return un_views
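
Assuming `To` and `ViewAsReal` are re-exported under `tz.m` like the other modules shown in this release (the `torchzero/modules/__init__.py` change suggests so, but this is an assumption), usage would look something like the sketch below; `tz.m.Adam` is likewise assumed from `torchzero/modules/optimizers/adam.py`:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 10)

    # run the inner Adam update in bfloat16, cast back before the step is applied
    opt = tz.Modular(
        model.parameters(),
        tz.m.To(tz.m.Adam(), dtype=torch.bfloat16),
        tz.m.LR(1e-3),
    )

    # for complex parameters: the inner modules see real views instead
    opt = tz.Modular(
        model.parameters(),
        tz.m.ViewAsReal(tz.m.Adam()),
        tz.m.LR(1e-3),
    )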
torchzero/modules/projections/galore.py
@@ -7,4 +7,6 @@ from typing import Any, Literal
 import torch

 from ...core import Chainable, Module, Var
-from .projection import Projection
+from .projection import ProjectionBase
+
+# TODO