torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,9 @@ from typing import Literal, overload
 import torch
 from scipy.sparse.linalg import LinearOperator, gcrotmk

-from ...core import Chainable, Module, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-from ...utils.linalg.solve import cg, minres
+from ...core import Chainable, Module, step
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import hvp_fd_central, hvp_fd_forward


 class ScipyNewtonCG(Module):
@@ -14,7 +13,7 @@ class ScipyNewtonCG(Module):
     def __init__(
         self,
         solver = gcrotmk,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        hvp_method: Literal["fd_forward", "fd_central", "autograd"] = "autograd",
         h: float = 1e-3,
         warm_start=False,
         inner: Chainable | None = None,
@@ -33,47 +32,47 @@ class ScipyNewtonCG(Module):
         self._kwargs = kwargs

     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        solver = settings['solver']
-        h = settings['h']
-        warm_start = settings['warm_start']
+        fs = self.settings[params[0]]
+        hvp_method = fs['hvp_method']
+        solver = fs['solver']
+        h = fs['h']
+        warm_start = fs['warm_start']

         self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         device = params[0].device; dtype=params[0].dtype
         if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
+            grad = objective.get_grads(create_graph=True)

             def H_mm(x_np):
                 self._num_hvps_last_step += 1
                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
                 with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                    Hvp = TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))
                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

         else:

             with torch.enable_grad():
-                grad = var.get_grad()
+                grad = objective.get_grads()

             if hvp_method == 'forward':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

             elif hvp_method == 'central':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

             else:
@@ -83,10 +82,8 @@ class ScipyNewtonCG(Module):
         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore

         # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -98,8 +95,8 @@ class ScipyNewtonCG(Module):
         if warm_start:
             self.global_state['x_prev'] = x_np

-        var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+        objective.updates = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)

         self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective

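Taken together, the ScipyNewtonCG hunks above are a rename to the new `Objective`-based API; the underlying pattern is unchanged: expose a torch Hessian-vector product to SciPy as a `LinearOperator` and solve the Newton system with an iterative solver such as `gcrotmk`. A minimal self-contained sketch of that pattern (illustrative only, not the torchzero module):

```python
import torch
from scipy.sparse.linalg import LinearOperator, gcrotmk

# a small smooth objective with a positive-definite Hessian
p = torch.randn(10, requires_grad=True)
loss = (p ** 4).sum() + (p ** 2).sum()
(g,) = torch.autograd.grad(loss, p, create_graph=True)

def H_mv(x_np):
    # Hessian-vector product via double backward, returned as a numpy vector
    x = torch.as_tensor(x_np, dtype=p.dtype)
    (Hx,) = torch.autograd.grad(g, p, grad_outputs=x, retain_graph=True)
    return Hx.detach().numpy()

H = LinearOperator(shape=(10, 10), matvec=H_mv)
direction, info = gcrotmk(H, g.detach().numpy())  # solves H d = g for the Newton direction
```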
@@ -38,15 +38,15 @@ class SPSA1(GradApproximator):
         super().__init__(defaults, target=target)


-    def pre_step(self, var):
+    def pre_step(self, objective):

         if self.defaults['pre_generate']:

-            params = TensorList(var.params)
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])

             n_samples = self.defaults['n_samples']
-            h = self.get_settings(var.params, 'h')
+            h = self.get_settings(objective.params, 'h')

             perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
             torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -1,11 +1,8 @@
 import math
-
-import numpy as np
 import torch

 from ...core import Chainable
-from ...utils import vec_to_tensors, TensorList
-from ..adaptive.shampoo import _merge_small_dims
+from ...utils import vec_to_tensors
 from ..projections import ProjectionBase


@@ -30,7 +30,7 @@ def debiased_step_size(
     pow: float = 2,
     alpha: float | NumberList = 1,
 ):
-    """returns multiplier to step size"""
+    """returns multiplier to step size, step starts from 1"""
     if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
     if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)

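For reference, the usual Adam-style bias-correction multiplier that a function with this signature (`beta1`, `beta2`, `step`, `pow`, `alpha`) conventionally computes looks like the sketch below; the exact torchzero implementation is not reproduced here.

```python
def debiased_step_size_sketch(step, beta1=0.9, beta2=0.999, pow=2, alpha=1.0):
    # step starts from 1, matching the updated docstring above
    second = (1 - beta2 ** step) ** (1 / pow)  # second-moment correction (sqrt when pow == 2)
    first = 1 - beta1 ** step                  # first-moment correction
    return alpha * second / first
```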
@@ -52,11 +52,11 @@ class ForwardGradient(RandomizedFDM):
         params = TensorList(params)
         loss_approx = None

-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        jvp_method = settings['jvp_method']
-        h = settings['h']
-        distribution = settings['distribution']
+        fs = self.settings[params[0]]
+        n_samples = fs['n_samples']
+        jvp_method = fs['jvp_method']
+        h = fs['h']
+        distribution = fs['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
         generator = self.get_generator(params[0].device, self.defaults['seed'])
@@ -74,10 +74,10 @@ class ForwardGradient(RandomizedFDM):
                 loss, d = jvp(partial(closure, False), params=params, tangent=prt)

             elif jvp_method == 'forward':
-                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, normalize=True, h=h)
+                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, h=h)

             elif jvp_method == 'central':
-                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, normalize=True, h=h)
+                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, h=h)

             else: raise ValueError(jvp_method)

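The ForwardGradient hunks above only rename the settings dict and drop the `normalize=True` argument from the finite-difference JVP helpers; the estimator itself is the standard forward-gradient trick. A minimal sketch of that trick using `torch.func.jvp` (illustrative, not the torchzero API):

```python
import torch

def f(x):
    return (x ** 2).sum()

x = torch.randn(5)
v = torch.randn_like(x)                   # random tangent; E[v v^T] = I for Gaussian or Rademacher
loss, d = torch.func.jvp(f, (x,), (v,))   # d = <grad f(x), v>, computed in forward mode
g_est = d * v                             # unbiased estimate of grad f(x)
```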
@@ -5,7 +5,7 @@ from typing import Any, Literal

 import torch

-from ...core import Module, Var
+from ...core import Module, Objective

 GradTarget = Literal['update', 'grad', 'closure']
 _Scalar = torch.Tensor | float
@@ -62,24 +62,25 @@ class GradApproximator(Module, ABC):
         return spsa_grads, None, loss_plus
     ```
     """
-    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+    def __init__(self, defaults: dict[str, Any] | None = None, return_approx_loss:bool=False, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target
+        self._return_approx_loss = return_approx_loss

     @abstractmethod
     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
         """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self, var: Var) -> None:
+    def pre_step(self, objective: Objective) -> None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

     @torch.no_grad
-    def step(self, var):
-        self.pre_step(var)
+    def update(self, objective):
+        self.pre_step(objective)

-        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
-        params, closure, loss = var.params, var.closure, var.loss
+        if objective.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = objective.params, objective.closure, objective.loss

         if self._target == 'closure':

@@ -88,20 +89,26 @@ class GradApproximator(Module, ABC):
                 # set loss to None because closure might be evaluated at different points
                 grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                 for p, g in zip(params, grad): p.grad = g
-                return l if l is not None else closure(False)
+                if l is not None: return l
+                if self._return_approx_loss and l_approx is not None: return l_approx
+                return closure(False)
+
             return closure(False)

-            var.closure = approx_closure
-            return var
+            objective.closure = approx_closure
+            return

         # if var.grad is not None:
         #     warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
-        if loss_approx is not None: var.loss_approx = loss_approx
-        if loss is not None: var.loss = var.loss_approx = loss
-        if self._target == 'grad': var.grad = list(grad)
-        elif self._target == 'update': var.update = list(grad)
+        grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+        if loss_approx is not None: objective.loss_approx = loss_approx
+        if loss is not None: objective.loss = objective.loss_approx = loss
+        if self._target == 'grad': objective.grads = list(grad)
+        elif self._target == 'update': objective.updates = list(grad)
         else: raise ValueError(self._target)
-        return var
+        return
+
+    def apply(self, objective):
+        return objective

 _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
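Per the contract shown above, a subclass only implements `approximate`, returning `(grads, loss, loss_approx)` and restoring any parameters it perturbed. Below is a deliberately simple coordinate-wise central-difference subclass written only to illustrate that contract; the class name and details are hypothetical, not part of torchzero.

```python
import torch

class CentralFD(GradApproximator):  # hypothetical illustrative subclass
    def __init__(self, h: float = 1e-3):
        super().__init__(defaults=dict(h=h))

    @torch.no_grad
    def approximate(self, closure, params, loss):
        h = self.defaults['h']
        grads, f_last = [], None
        for p in params:
            g = torch.zeros_like(p)
            flat_p, flat_g = p.view(-1), g.view(-1)
            for i in range(flat_p.numel()):
                orig = flat_p[i].clone()
                flat_p[i] = orig + h; f_plus = closure(False)
                flat_p[i] = orig - h; f_minus = closure(False)
                flat_p[i] = orig  # restore the parameter we perturbed
                flat_g[i] = (f_plus - f_minus) / (2 * h)
                f_last = f_plus
            grads.append(g)
        # no exact loss at the unperturbed point; report the last evaluation as loss_approx
        return grads, None, f_last
```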
@@ -176,7 +176,7 @@ class RandomizedFDM(GradApproximator):
     ```py
     spsa = tz.Modular(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
         tz.m.LR(1e-2)
     )
     ```
@@ -187,7 +187,7 @@ class RandomizedFDM(GradApproximator):
     ```
     rdsa = tz.Modular(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
         tz.m.LR(1e-2)
     )
     ```
@@ -223,23 +223,24 @@ class RandomizedFDM(GradApproximator):
         n_samples: int = 1,
         formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         seed: int | None | torch.Generator = None,
         target: GradTarget = "closure",
     ):
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)


-    def pre_step(self, var):
-        h = self.get_settings(var.params, 'h')
+    def pre_step(self, objective):
+        h = self.get_settings(objective.params, 'h')
         pre_generate = self.defaults['pre_generate']

         if pre_generate:
             n_samples = self.defaults['n_samples']
             distribution = self.defaults['distribution']

-            params = TensorList(var.params)
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])
             perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

@@ -346,11 +347,12 @@ class RDSA(RandomizedFDM):
         n_samples: int = 1,
         formula: _FD_Formula = "central2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class GaussianSmoothing(RandomizedFDM):
     """
@@ -380,11 +382,12 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples: int = 100,
         formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
@@ -406,10 +409,10 @@ class MeZO(GradApproximator):
     """

     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
-                 distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+                 distribution: Distributions = 'rademacher', return_approx_loss: bool = False, target: GradTarget = 'closure'):

         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)

     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
         prt = TensorList(params).sample_like(
@@ -419,19 +422,19 @@ class MeZO(GradApproximator):
         )
         return prt

-    def pre_step(self, var):
-        h = NumberList(self.settings[p]['h'] for p in var.params)
+    def pre_step(self, objective):
+        h = NumberList(self.settings[p]['h'] for p in objective.params)

         n_samples = self.defaults['n_samples']
         distribution = self.defaults['distribution']

-        step = var.current_step
+        step = objective.current_step

         # create functions that generate a deterministic perturbation from seed based on current step
         prt_fns = []
         for i in range(n_samples):

-            prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
+            prt_fn = partial(self._seeded_perturbation, params=objective.params, distribution=distribution, seed=1_000_000*step + i, h=h)
             prt_fns.append(prt_fn)

         self.global_state['prt_fns'] = prt_fns
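The MeZO hunks keep the memory-saving idea intact: perturbations are never stored, only a seed derived from the current step and sample index, and the same perturbation is regenerated on demand. A small standalone sketch of that idea (illustrative only):

```python
import torch

def rademacher(shape, seed):
    gen = torch.Generator().manual_seed(seed)
    return torch.randint(0, 2, shape, generator=gen).float() * 2 - 1

step, i = 7, 0
seed = 1_000_000 * step + i          # same seeding scheme as in the diff above
z1 = rademacher((4,), seed)
z2 = rademacher((4,), seed)
assert torch.equal(z1, z2)           # identical perturbation regenerated, nothing stored
```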
@@ -1,28 +1,31 @@
 import torch
-from ...core import Module

-from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...core import Chainable, Module, step
+from ...linalg import linear_operator
 from ...utils import vec_to_tensors
-from ...utils.linalg import linear_operator
+from ...utils.derivatives import flatten_jacobian, jacobian_wrt
+
+
 class SumOfSquares(Module):
     """Sets loss to be the sum of squares of values returned by the closure.

     This is meant to be used to test least squares methods against ordinary minimization methods.

     To use this, the closure should return a vector of values to minimize sum of squares of.
-    Please add the `backward` argument, it will always be False but it is required.
+    Please add the ``backward`` argument, it will always be False but it is required.
     """
     def __init__(self):
         super().__init__()

     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def update(self, objective):
+        closure = objective.closure

         if closure is not None:
+
             def sos_closure(backward=True):
                 if backward:
-                    var.zero_grad()
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False)
                         loss = loss.pow(2).sum()
@@ -32,16 +35,13 @@ class SumOfSquares(Module):
                     loss = closure(False)
                     return loss.pow(2).sum()

-            var.closure = sos_closure
-
-            if var.loss is not None:
-                var.loss = var.loss.pow(2).sum()
+            objective.closure = sos_closure

-            if var.loss_approx is not None:
-                var.loss_approx = var.loss_approx.pow(2).sum()
-
-        return var
+            if objective.loss is not None:
+                objective.loss = objective.loss.pow(2).sum()

+            if objective.loss_approx is not None:
+                objective.loss_approx = objective.loss_approx.pow(2).sum()

 class GaussNewton(Module):
     """Gauss-newton method.
@@ -101,35 +101,45 @@ class GaussNewton(Module):
     print(f'{losses.mean() = }')
     ```
     """
-    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+    def __init__(self, reg:float = 1e-8, batched:bool=True, inner: Chainable | None = None):
         super().__init__(defaults=dict(batched=batched, reg=reg))
+        if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def update(self, var):
-        params = var.params
+    def update(self, objective):
+        params = objective.params
         batched = self.defaults['batched']

-        closure = var.closure
+        closure = objective.closure
         assert closure is not None

         # gauss newton direction
         with torch.enable_grad():
-            f = var.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
-            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+            r = objective.get_loss(backward=False) # nresiduals
+            assert isinstance(r, torch.Tensor)
+            J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+        objective.loss = r.pow(2).sum()
+
+        J = self.global_state["J"] = flatten_jacobian(J_list) # (nresiduals, ndim)
+        Jr = J.T @ r.detach() # (ndim)
+
+        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
+        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+            self.global_state["Jr"] = Jr

-        var.loss = f.pow(2).sum()
+        else:
+            self.global_state["r"] = r

-        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
-        Gtf = G.T @ f.detach() # (ndim)
-        self.global_state["Gtf"] = Gtf
-        var.grad = vec_to_tensors(Gtf, var.params)
+        objective.grads = vec_to_tensors(Jr, objective.params)

         # set closure to calculate sum of squares for line searches etc
-        if var.closure is not None:
+        if objective.closure is not None:
             def sos_closure(backward=True):
                 if backward:
-                    var.zero_grad()
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False).pow(2).sum()
                         loss.backward()
@@ -138,24 +148,62 @@ class GaussNewton(Module):
                     loss = closure(False).pow(2).sum()
                     return loss

-            var.closure = sos_closure
+            objective.closure = sos_closure

     @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
         reg = self.defaults['reg']

-        G = self.global_state['G']
-        Gtf = self.global_state['Gtf']
+        J: torch.Tensor = self.global_state['J']
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+
+            # (J^T J)v = J^T r
+            Jr: torch.Tensor = self.global_state['Jr']
+
+            # inner step
+            if "inner" in self.children:
+
+                # var.grad is set to unflattened Jr
+                assert objective.grads is not None
+                objective = self.inner_step("inner", objective, must_exist=True)
+                Jr_list = objective.get_updates()
+                Jr = torch.cat([t.ravel() for t in Jr_list])
+
+            JJ = J.T @ J # (ndim, ndim)
+            if reg != 0:
+                JJ.add_(torch.eye(JJ.size(0), device=JJ.device, dtype=JJ.dtype).mul_(reg))
+
+            if nresiduals >= ndim:
+                v, info = torch.linalg.solve_ex(JJ, Jr) # pylint:disable=not-callable
+            else:
+                v = torch.linalg.lstsq(JJ, Jr).solution # pylint:disable=not-callable
+
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective
+
+        else:
+            # solve (J J^T)z = r and set v = J^T z
+            # derivation
+            # we need (J^T J)v = J^T r
+            # suppose z is solution to (G G^T)z = r, and v = J^T z
+            # if v = J^T z, then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+            # therefore with our presuppositions (J^T J)v = J^T r
+
+            # also this gives a minimum norm solution
+
+            r = self.global_state['r']

-        GtG = G.T @ G # (ndim, ndim)
-        if reg != 0:
-            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+            JJT = J @ J.T # (nresiduals, nresiduals)
+            if reg != 0:
+                JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

-        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+            z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+            v = J.T @ z

-        var.update = vec_to_tensors(v, var.params)
-        return var
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective

-    def get_H(self, var):
-        G = self.global_state['G']
-        return linear_operator.AtA(G)
+    def get_H(self, objective=...):
+        J = self.global_state['J']
+        return linear_operator.AtA(J)
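The comments in the new `apply` branch derive why, with fewer residuals than parameters, one can solve the small system (J Jᵀ)z = r and take v = Jᵀz instead of forming the normal equations. A quick standalone numerical check of that identity (not torchzero code):

```python
import torch

torch.manual_seed(0)
nresiduals, ndim = 3, 8                 # underdetermined case
J = torch.randn(nresiduals, ndim)
r = torch.randn(nresiduals)

z = torch.linalg.solve(J @ J.T, r)      # small (nresiduals x nresiduals) system
v = J.T @ z                             # candidate Gauss-Newton step

# v satisfies the normal equations (J^T J) v = J^T r up to round-off
print(torch.allclose(J.T @ (J @ v), J.T @ r, atol=1e-5))
```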
@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):

         # # directional derivative
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))

         # scale init
         init_scale = self.global_state.get('init_scale', 1)
@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):

         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))

         # scale beta
         beta = beta * self.global_state['beta_scale']
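In both backtracking hunks, `d` is the directional derivative of the loss along the proposed step, and it feeds the sufficient-decrease test. A generic Armijo backtracking sketch using that quantity (illustrative, not the torchzero line search):

```python
import torch

def armijo_backtracking(f, x, grad, update, c=1e-4, beta=0.5, max_iter=50):
    d = -(grad * update).sum()                    # directional derivative along -update
    f0 = f(x)
    t = 1.0
    for _ in range(max_iter):
        if f(x - t * update) <= f0 + c * t * d:   # sufficient-decrease condition
            return t
        t *= beta
    return t

f = lambda x: (x ** 2).sum()
x = torch.tensor([3.0, -2.0])
g = 2 * x                                         # gradient of f at x
step_size = armijo_backtracking(f, x, g, update=g)
```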