torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -77,8 +77,11 @@ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_
     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)

 _FD_FUNCS = {
+    "forward": _forward2,
     "forward2": _forward2,
+    "backward": _backward2,
     "backward2": _backward2,
+    "central": _central2,
     "central2": _central2,
     "central3": _central2, # they are the same
     "forward3": _forward3,
@@ -88,19 +91,42 @@ _FD_FUNCS = {


 class FDM(GradApproximator):
-    """Approximate gradients via finite difference method
+    """Approximate gradients via finite difference method.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

     Args:
         h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         target (GradTarget, optional): what to set on var. Defaults to 'closure'.
+
+    Examples:
+        plain FDM:
+
+        ```python
+        fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+        ```
+
+        Any gradient-based method can use FDM-estimated gradients.
+        ```python
+        fdm_ncg = tz.Modular(
+            model.parameters(),
+            tz.m.FDM(),
+            # set hvp_method to "forward" so that it
+            # uses gradient difference instead of autograd
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
     """
-    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
+    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
        defaults = dict(h=h, formula=formula)
        super().__init__(defaults, target=target)

    @torch.no_grad
-    def approximate(self, closure, params, loss, var):
+    def approximate(self, closure, params, loss):
        grads = []
        loss_approx = None

@@ -112,7 +138,7 @@ class FDM(GradApproximator):
             h = settings['h']
             fd_fn = _FD_FUNCS[settings['formula']]

-            p_flat = p.view(-1); g_flat = g.view(-1)
+            p_flat = p.ravel(); g_flat = g.ravel()
             for i in range(len(p_flat)):
                 loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                 g_flat[i] = d
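The fdm.py hunks above add plain spellings ("forward", "backward", "central") as aliases for the existing two-point formulas and switch the constructor default from 'central2' to the equivalent 'central'. For reference, the two-point central formula estimates each partial derivative as (f(x + h·eᵢ) − f(x − h·eᵢ)) / (2h). The sketch below is a minimal standalone illustration of that estimator, not the library's implementation; the helper name `central_difference_grad` and the toy objective are made up for the example.

```python
import torch

def central_difference_grad(f, x: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    """Estimate df/dx elementwise with the two-point central formula (f(x+h) - f(x-h)) / (2h)."""
    g = torch.empty_like(x)
    x_flat, g_flat = x.ravel(), g.ravel()   # same flattening trick as in the hunk above
    for i in range(x_flat.numel()):
        orig = x_flat[i].item()
        x_flat[i] = orig + h
        f_plus = f(x)
        x_flat[i] = orig - h
        f_minus = f(x)
        x_flat[i] = orig                    # restore the original parameter value
        g_flat[i] = (f_plus - f_minus) / (2 * h)
    return g

# toy check against the analytic gradient of f(x) = sum(x^2), which is 2x
x = torch.tensor([1.0, -2.0, 3.0])
print(central_difference_grad(lambda t: (t ** 2).sum().item(), x))  # ≈ tensor([ 2., -4.,  6.])
```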
@@ -4,14 +4,21 @@ from typing import Any, Literal

 import torch

-from ...utils import Distributions, NumberList, TensorList, generic_eq
+from ...utils import Distributions, NumberList, TensorList
 from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
 from .grad_approximator import GradApproximator, GradTarget
 from .rfdm import RandomizedFDM


 class ForwardGradient(RandomizedFDM):
-    """Forward gradient method, same as randomized finite difference but directional derivative is estimated via autograd (as jacobian vector product)
+    """Forward gradient method.
+
+    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+

     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
@@ -24,6 +31,9 @@ class ForwardGradient(RandomizedFDM):
             how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
     """
     PRE_MULTIPLY_BY_H = False
     def __init__(
@@ -41,7 +51,7 @@ class ForwardGradient(RandomizedFDM):
         self.defaults['jvp_method'] = jvp_method

    @torch.no_grad
-    def approximate(self, closure, params, loss, var):
+    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

@@ -57,7 +67,9 @@ class ForwardGradient(RandomizedFDM):
        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
-            if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator)
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)
+
            else: prt = TensorList(prt)

            if jvp_method == 'autograd':
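The forward_gradient.py hunks drop the unused `generic_eq` import, rewrite the docstring, sample perturbations with `variance=1`, and remove the `var` argument from `approximate`. The underlying idea: a forward-mode jacobian-vector product gives the directional derivative ∇f·v for a random direction v, and (∇f·v)·v is an unbiased gradient estimate when E[vvᵀ] = I. Below is a minimal standalone sketch of that idea, assuming PyTorch ≥ 2.0 for `torch.func.jvp`; the helper name `forward_gradient` and the toy objective are illustrative, not library code.

```python
import torch
from torch.func import jvp

def forward_gradient(f, params: torch.Tensor, generator=None) -> torch.Tensor:
    """Single-sample forward-gradient estimate: (∇f·v) v for a random direction v."""
    # unit-variance Gaussian direction, matching the variance=1 default added above
    v = torch.randn(params.shape, generator=generator)
    # one forward-mode pass yields the loss and the directional derivative ∇f·v
    loss, directional_derivative = jvp(f, (params,), (v,))
    # since E[v vᵀ] = I, (∇f·v) v is an unbiased estimate of ∇f
    return directional_derivative * v

# toy example: f(x) = sum(x^2) has gradient 2x; averaging many samples approaches it
x = torch.tensor([1.0, -2.0, 3.0])
est = torch.stack([forward_gradient(lambda t: (t ** 2).sum(), x) for _ in range(2000)]).mean(0)
print(est)  # roughly tensor([ 2., -4.,  6.]), up to sampling noise
```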
@@ -14,28 +14,69 @@ class GradApproximator(Module, ABC):
     """Base class for gradient approximations.
     This is an abstract class, to use it, subclass it and override `approximate`.

+    GradientApproximator modifies the closure to evaluate the estimated gradients,
+    and further closure-based modules will use the modified closure.
+
     Args:
         defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
         target (str, optional):
             whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
+
+    Example:
+
+        Basic SPSA method implementation.
+        ```python
+        class SPSA(GradApproximator):
+            def __init__(self, h=1e-3):
+                defaults = dict(h=h)
+                super().__init__(defaults)
+
+            @torch.no_grad
+            def approximate(self, closure, params, loss):
+                perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                # evaluate params + perturbation
+                torch._foreach_add_(params, perturbation)
+                loss_plus = closure(False)
+
+                # evaluate params - perturbation
+                torch._foreach_sub_(params, perturbation)
+                torch._foreach_sub_(params, perturbation)
+                loss_minus = closure(False)
+
+                # restore original params
+                torch._foreach_add_(params, perturbation)
+
+                # calculate SPSA gradients
+                spsa_grads = []
+                for p, pert in zip(params, perturbation):
+                    settings = self.settings[p]
+                    h = settings['h']
+                    d = (loss_plus - loss_minus) / (2*(h**2))
+                    spsa_grads.append(pert * d)
+
+                # returns tuple: (grads, loss, loss_approx)
+                # loss must be with initial parameters
+                # since we only evaluated loss with perturbed parameters
+                # we only have loss_approx
+                return spsa_grads, None, loss_plus
+        ```
     """
     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target

     @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, var: Var) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
-        """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
+        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self, var: Var) -> Var | None:
+    def pre_step(self, var: Var) -> None:
        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return var

    @torch.no_grad
    def step(self, var):
-        ret = self.pre_step(var)
-        if isinstance(ret, Var): var = ret
+        self.pre_step(var)

        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
        params, closure, loss = var.params, var.closure, var.loss
@@ -45,9 +86,9 @@ class GradApproximator(Module, ABC):
        def approx_closure(backward=True):
            if backward:
                # set loss to None because closure might be evaluated at different points
-                grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, var=var)
+                grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                for p, g in zip(params, grad): p.grad = g
-                return l if l is not None else l_approx
+                return l if l is not None else closure(False)
            return closure(False)

        var.closure = approx_closure
@@ -55,7 +96,7 @@ class GradApproximator(Module, ABC):

        # if var.grad is not None:
        #     warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, var=var)
+        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
        if loss_approx is not None: var.loss_approx = loss_approx
        if loss is not None: var.loss = var.loss_approx = loss
        if self._target == 'grad': var.grad = list(grad)
@@ -63,4 +104,4 @@ class GradApproximator(Module, ABC):
        else: raise ValueError(self._target)
        return var

-_FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
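The grad_approximator.py hunks change the subclass contract: `approximate` no longer receives `var`, `pre_step` now mutates `var` in place and returns None, the wrapped closure falls back to re-evaluating the true loss when the approximation returns none, and `_FD_Formula` gains the new alias and formula names. Below is a minimal sketch of a custom approximator written against the 0.3.13 contract. The class name `ForwardDifference` and its per-element forward-difference logic are illustrative only (the library's own implementation is the FDM module shown earlier), and the import path is an assumption based on the file list above.

```python
import torch
# assumed import path, taken from the file list above
from torchzero.modules.grad_approximation.grad_approximator import GradApproximator

class ForwardDifference(GradApproximator):
    """Illustrative per-element forward-difference approximator using the 0.3.13 contract."""
    def __init__(self, h: float = 1e-3):
        super().__init__(dict(h=h))

    def pre_step(self, var):
        # 0.3.13 contract: mutate `var` in place and return None
        # (0.3.10 also allowed returning a modified Var)
        pass

    @torch.no_grad
    def approximate(self, closure, params, loss):
        # 0.3.13 signature: no `var` argument
        if loss is None:
            loss = closure(False)            # loss at the unperturbed parameters
        grads = []
        for p in params:
            h = self.settings[p]['h']
            g = torch.empty_like(p)
            p_flat, g_flat = p.ravel(), g.ravel()
            for i in range(p_flat.numel()):
                orig = p_flat[i].item()
                p_flat[i] = orig + h
                g_flat[i] = (closure(False) - loss) / h
                p_flat[i] = orig             # restore the original value
            grads.append(g)
        # (grads, loss, loss_approx); loss was measured at the original parameters
        return grads, loss, loss
```

Like FDM in the docstring example earlier, such a module would presumably sit first in a `tz.Modular(...)` chain so that the modules after it consume the approximated gradients.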