torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
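The listing shows the 0.4.x restructuring: `torchzero/utils/linalg/*` moves to a top-level `torchzero/linalg` package (rows 19, 22, 23, 185), `modules/functional.py` becomes `modules/opt_utils.py` (row 108), and the hunks below rename the modular entry point from `tz.Modular` to `tz.Optimizer`. A hedged migration sketch, assuming the modules are importable at the paths shown in the listing (public re-exports may differ):

```python
import torch
import torchzero as tz

# new import path after the utils/linalg -> linalg move (row 19 above);
# in 0.3.x this lived at torchzero.utils.linalg.linear_operator
from torchzero.linalg.linear_operator import Sketched

model = torch.nn.Linear(4, 4)

# the modular optimizer entry point is now tz.Optimizer (the hunks below rename tz.Modular)
opt = tz.Optimizer(model.parameters(), tz.m.LR(1e-2))
```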
torchzero/modules/second_order/rsn.py

@@ -5,46 +5,49 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import Distributions, TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import Sketched
-from .newton import _newton_step
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import vec_to_tensors_
+from ...linalg.linear_operator import Sketched
+
+from .newton import _newton_update_state_, _newton_solve
 
 
 def _qr_orthonormalize(A:torch.Tensor):
     m,n = A.shape
     if m < n:
         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
         return q.T
-    else:
-        q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
-        return q
+
+    q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+    return q
+
 
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
-def _gaussian_sketch(m, n, dtype, device, generator):
-    return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
+def _rademacher_sketch(m, n, dtype, device, generator):
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    return rademacher.mul_(1 / math.sqrt(m))
 
-class RSN(Module):
-    """Randomized Subspace Newton. Performs a Newton step in a random subspace.
+class SubspaceNewton(Transform):
+    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
     Args:
         sketch_size (int):
             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
         sketch_type (str, optional):
+            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-            - "gaussian" - random gaussian (not orthonormal) basis.
-            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
-            - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
+            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
+            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
         hvp_method (str, optional):
            How to compute hessian-matrix product:
-            - "batched_autograd" - uses batched autograd
+            - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
 
-            . Defaults to "batched".
+            . Defaults to "batched_autograd".
         h (float, optional): finite difference step size. Defaults to 1e-2.
         use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
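The hunk above swaps the Gaussian sketch helper for `_rademacher_sketch`. As a quick standalone illustration of why the docstring calls the scaled Rademacher basis "approximately orthonormal" in high dimensions (this is not the module's code, just the same construction):

```python
import math
import torch

def rademacher_sketch(m, n, generator=None):
    # entries are +-1 with equal probability, scaled by 1/sqrt(m)
    r = torch.bernoulli(torch.full((m, n), 0.5), generator=generator).mul_(2).sub_(1)
    return r.mul_(1 / math.sqrt(m))

S = rademacher_sketch(10_000, 8)
# S.T @ S is close to the 8x8 identity (off-diagonal entries ~ 1/sqrt(m)),
# which is why an explicit QR, as in the "orthonormal" option, is optional here
print((S.T @ S - torch.eye(8)).abs().max())
```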
@@ -67,7 +70,7 @@ class RSN(Module):
 
     RSN with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RSN(),
         tz.m.Backtracking()
@@ -76,7 +79,7 @@ class RSN(Module):
 
     RSN with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.RSN()),
     )
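The docstring examples above only switch the entry point from `tz.Modular` to `tz.Optimizer`. Combined with the class rename, usage in 0.4.x presumably looks like the sketch below; note that the export name `tz.m.SubspaceNewton` is an assumption (the docstring shown in the diff still writes `tz.m.RSN()`, which may remain as an alias):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Newton step restricted to a small sketched subspace, globalized by a backtracking line search
opt = tz.Optimizer(
    model.parameters(),
    tz.m.SubspaceNewton(sketch_size=10),  # assumed export name, see note above
    tz.m.Backtracking(),
)
```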
@@ -91,137 +94,141 @@ class RSN(Module):
     def __init__(
         self,
         sketch_size: int,
-        sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
         damping:float=0,
-        hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
-        h: float = 1e-2,
-        use_lstsq: bool = True,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool = False,
+        use_lstsq: bool = True,
+        hvp_method: HVPMethod = "batched_autograd",
+        h: float = 1e-2,
         seed: int | None = None,
         inner: Chainable | None = None,
     ):
-        defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def update(self, var):
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        if step % self.defaults['update_freq'] == 0:
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        generator = self.get_generator(params[0].device, fs["seed"])
 
-            closure = var.closure
-            if closure is None:
-                raise RuntimeError("RSN requires closure")
-            params = var.params
-            generator = self.get_generator(params[0].device, self.defaults["seed"])
+        ndim = sum(p.numel() for p in params)
 
-            ndim = sum(p.numel() for p in params)
+        device=params[0].device
+        dtype=params[0].dtype
 
-            device=params[0].device
-            dtype=params[0].dtype
+        # sample sketch matrix S: (ndim, sketch_size)
+        sketch_size = min(fs["sketch_size"], ndim)
+        sketch_type = fs["sketch_type"]
+        hvp_method = fs["hvp_method"]
 
-            # sample sketch matrix S: (ndim, sketch_size)
-            sketch_size = min(self.defaults["sketch_size"], ndim)
-            sketch_type = self.defaults["sketch_type"]
-            hvp_method = self.defaults["hvp_method"]
+        if sketch_type == "rademacher":
+            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-            if sketch_type in ('normal', 'gaussian'):
-                S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+        elif sketch_type == 'orthonormal':
+            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-            elif sketch_type == 'orthonormal':
-                S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+        elif sketch_type == 'common_directions':
+            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
 
-            elif sketch_type == 'common_directions':
-                # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
-                g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-                g = torch.cat([t.ravel() for t in g_list])
-
-                # initialize directions deque
-                if "directions" not in self.global_state:
+            # initialize directions deque
+            if "directions" not in self.global_state:
 
+                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                if g_norm < torch.finfo(g.dtype).tiny * 2:
+                    g = torch.randn_like(g)
                     g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-                    if g_norm < torch.finfo(g.dtype).tiny * 2:
-                        g = torch.randn_like(g)
-                        g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-
-                    self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
-                    S = self.global_state["directions"][0].unsqueeze(1)
-
-                # add new steepest descent direction orthonormal to existing columns
-                else:
-                    S = torch.stack(tuple(self.global_state["directions"]), dim=1)
-                    p = g - S @ (S.T @ g)
-                    p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-                    if p_norm > torch.finfo(p.dtype).tiny * 2:
-                        p = p / p_norm
-                        self.global_state["directions"].append(p)
-                        S = torch.cat([S, p.unsqueeze(1)], dim=1)
-
-            elif sketch_type == "mixed":
-                g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-                g = torch.cat([t.ravel() for t in g_list])
-
-                if "slow_ema" not in self.global_state:
-                    self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
-                    self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
-
-                slow_ema = self.global_state["slow_ema"]
-                fast_ema = self.global_state["fast_ema"]
-                slow_ema.lerp_(g, 0.001)
-                fast_ema.lerp_(g, 0.1)
-
-                S = torch.stack([g, slow_ema, fast_ema], dim=1)
-                if sketch_size > 3:
-                    S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
-                    S = torch.cat([S, S_random], dim=1)
-
-                S = _qr_orthonormalize(S)
 
+                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                S = self.global_state["directions"][0].unsqueeze(1)
+
+            # add new steepest descent direction orthonormal to existing columns
             else:
-                raise ValueError(f'Unknown sketch_type {sketch_type}')
+                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                p = g - S @ (S.T @ g)
+                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                if p_norm > torch.finfo(p.dtype).tiny * 2:
+                    p = p / p_norm
+                    self.global_state["directions"].append(p)
+                    S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+        elif sketch_type == "mixed":
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
+
+            # initialize state
+            if "slow_ema" not in self.global_state:
+                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["p_prev"] = torch.randn_like(g)
+
+            # previous update direction
+            p_cur = torch.cat([t.ravel() for t in params])
+            prev_dir = p_cur - self.global_state["p_prev"]
+            self.global_state["p_prev"] = p_cur
+
+            # EMAs
+            slow_ema = self.global_state["slow_ema"]
+            fast_ema = self.global_state["fast_ema"]
+            slow_ema.lerp_(g, 0.001)
+            fast_ema.lerp_(g, 0.1)
+
+            # form and orthogonalize sketching matrix
+            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
+            if sketch_size > 4:
+                S_random = torch.randn(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
+                S = torch.cat([S, S_random], dim=1)
+
+            S = _qr_orthonormalize(S)
+
+        else:
+            raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+        # form sketched hessian
+        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
+                                                 hvp_method=fs["hvp_method"], h=fs["h"])
+        H_sketched = S.T @ HS
+
+        # update state
+        _newton_update_state_(
+            state = self.global_state,
+            H = H_sketched,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = fs["precompute_inverse"],
+            use_lstsq = fs["use_lstsq"]
 
-            # form sketched hessian
-            HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
-            H_sketched = S.T @ HS
+        )
 
-            self.global_state["H_sketched"] = H_sketched
-            self.global_state["S"] = S
+        self.global_state["S"] = S
 
-    def apply(self, var):
-        S: torch.Tensor = self.global_state["S"]
-        d_proj = _newton_step(
-            var=var,
-            H=self.global_state["H_sketched"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj = lambda g: S.T @ g
-        )
-        d = S @ d_proj
-        var.update = vec_to_tensors(d, var.params)
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        fs = settings[0]
 
-        return var
+        S = self.global_state["S"]
+        b = torch.cat([t.ravel() for t in updates])
+        b_proj = S.T @ b
 
-    def get_H(self, var=...):
-        eigval_fn = self.defaults["eigval_fn"]
-        H_sketched: torch.Tensor = self.global_state["H_sketched"]
-        S: torch.Tensor = self.global_state["S"]
+        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
-        if eigval_fn is not None:
-            try:
-                L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
-                L: torch.Tensor = eigval_fn(L)
-                H_sketched = Q @ L.diag_embed() @ Q.mH
+        d = S @ d_proj
+        vec_to_tensors_(d, updates)
+        return objective
 
-            except torch.linalg.LinAlgError:
-                pass
+    def get_H(self, objective=...):
+        if "H" in self.global_state:
+            H_sketched = self.global_state["H"]
 
+        else:
+            L = self.global_state["L"]
+            Q = self.global_state["Q"]
+            H_sketched = Q @ L.diag_embed() @ Q.mH
+
+        S: torch.Tensor = self.global_state["S"]
         return Sketched(S, H_sketched)
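The rewrite routes the linear algebra through the shared `_newton_update_state_` / `_newton_solve` helpers, but the underlying computation is still the projected Newton step: sketch the Hessian to SᵀHS, solve in the subspace, and lift back with S. A standalone sketch of that computation with an explicit Hessian (illustration only, not the module's code, which builds SᵀHS from Hessian-vector products):

```python
import torch

def subspace_newton_step(H, g, S, damping=0.0):
    # project the problem into the subspace spanned by the columns of S
    H_proj = S.T @ H @ S                      # (k, k) sketched Hessian (S.T @ HS in the module)
    if damping:
        H_proj = H_proj + damping * torch.eye(H_proj.shape[0], dtype=H.dtype)
    b_proj = S.T @ g                          # projected gradient / update
    d_proj = torch.linalg.solve(H_proj, b_proj)  # the module may instead use lstsq or a precomputed inverse
    return S @ d_proj                         # lift the step back to the full space

n, k = 50, 5
A = torch.randn(n, n)
H = A @ A.T + torch.eye(n)                    # SPD test Hessian
g = torch.randn(n)
S, _ = torch.linalg.qr(torch.randn(n, k))     # orthonormal sketch, k columns
d = subspace_newton_step(H, g, S)
```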
torchzero/modules/smoothing/laplacian.py

@@ -4,7 +4,7 @@ from collections.abc import Iterable
 import torch
 
 from ...utils.tensorlist import TensorList
-from ...core import Transform, Target
+from ...core import TensorTransform
 
 
 def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     v[-1] = 1
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
-class LaplacianSmoothing(Transform):
+class LaplacianSmoothing(TensorTransform):
     """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
     Args:
@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
            what to set on var.
 
     Examples:
-        Laplacian Smoothing Gradient Descent optimizer as in the paper
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
 
-        .. code-block:: python
+        ```python
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LaplacianSmoothing(),
-                tz.m.LR(1e-2),
-            )
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.LaplacianSmoothing(),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
     """
-    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
+    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
         defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
         # precomputed denominator for when layerwise=False
         self.global_state['full_denominator'] = None
 
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         layerwise = settings[0]['layerwise']
 
         # layerwise laplacian smoothing
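For reference, the smoothing itself solves a circulant system (I − σL)d = g with an FFT: the denominator precomputed above is `1 - sigma * fft(v)`, where v is the discrete Laplacian stencil (the visible line sets `v[-1] = 1`; presumably `v[0] = -2` and `v[1] = 1`, as in the cited paper). A minimal standalone sketch of that idea under those assumptions, not the module's exact code:

```python
import torch

def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Solve (I - sigma * L) d = g for a 1-D circulant Laplacian L via FFT."""
    flat = g.reshape(-1)
    v = torch.zeros_like(flat)
    v[0] = -2.0
    v[1] = 1.0
    v[-1] = 1.0
    denom = 1 - sigma * torch.fft.fft(v)          # matches _precompute_denominator above
    d = torch.fft.ifft(torch.fft.fft(flat) / denom).real
    return d.view_as(g)

g = torch.randn(100)
smoothed = laplacian_smooth(g, sigma=1.0)
# the denominator is >= 1 elementwise, so smoothing never increases the gradient norm
assert torch.linalg.vector_norm(smoothed) <= torch.linalg.vector_norm(g)
```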
torchzero/modules/smoothing/sampling.py

@@ -7,14 +7,14 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable, Modular, Module, Var
+from ...core import Chainable, Optimizer, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
-def _reset_except_self(optimizer: Modular, var: Var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective: Objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()
 
@@ -98,15 +98,15 @@ class GradientSampling(Reformulation):
         self.set_child('termination', make_termination_criteria(extra=termination))
 
     @torch.no_grad
-    def pre_step(self, var):
-        params = TensorList(var.params)
+    def pre_step(self, objective):
+        params = TensorList(objective.params)
 
         fixed = self.defaults['fixed']
 
         # check termination criteria
         if 'termination' in self.children:
             termination = cast(TerminationCriteriaBase, self.children['termination'])
-            if termination.should_terminate(var):
+            if termination.should_terminate(objective):
 
                 # decay sigmas
                 states = [self.state[p] for p in params]
@@ -118,7 +118,7 @@ class GradientSampling(Reformulation):
 
                 # reset on sigmas decay
                 if self.defaults['reset_on_termination']:
-                    var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
                 # clear perturbations
                 self.global_state.pop('perts', None)
@@ -136,7 +136,7 @@ class GradientSampling(Reformulation):
         self.global_state['perts'] = perts
 
     @torch.no_grad
-    def closure(self, backward, closure, params, var):
+    def closure(self, backward, closure, params, objective):
         params = TensorList(params)
         loss_agg = None
         grad_agg = None
@@ -160,7 +160,7 @@ class GradientSampling(Reformulation):
 
         # evaluate at x_0
        if include_x0:
-            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+            f_0 = objective.get_loss(backward=backward)
 
            isfinite = math.isfinite(f_0)
            if isfinite:
@@ -168,7 +168,7 @@ class GradientSampling(Reformulation):
                loss_agg = f_0
 
            if backward:
-                g_0 = var.get_grad()
+                g_0 = objective.get_grads()
                if isfinite: grad_agg = g_0
 
         # evaluate at x_0 + p for each perturbation
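As the `loss_agg`/`grad_agg` aggregation above suggests, `GradientSampling` reformulates the objective by averaging loss (and gradient, when `backward=True`) over perturbed copies of the parameters, optionally including the unperturbed point. A simplified standalone sketch of that idea (hypothetical helper, not the module's closure-based implementation):

```python
import torch

def sampled_loss_and_grad(loss_fn, params, sigma=0.1, n_samples=4, include_x0=True):
    """Average loss and gradient of `loss_fn(params)` over Gaussian perturbations."""
    perts = [[torch.randn_like(p) * sigma for p in params] for _ in range(n_samples)]
    points = ([list(params)] if include_x0 else []) + [
        [p + e for p, e in zip(params, pert)] for pert in perts
    ]
    loss_agg = 0.0
    grad_agg = [torch.zeros_like(p) for p in params]
    for x in points:
        x = [t.detach().requires_grad_(True) for t in x]
        loss = loss_fn(x)
        grads = torch.autograd.grad(loss, x)
        loss_agg += loss.item()
        for g_sum, g in zip(grad_agg, grads):
            g_sum += g                      # accumulate in place
    n = len(points)
    return loss_agg / n, [g / n for g in grad_agg]

# usage: a toy quadratic objective
params = [torch.zeros(3)]
loss, grads = sampled_loss_and_grad(lambda xs: (xs[0] ** 2).sum(), params)
```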
torchzero/modules/step_size/adaptive.py

@@ -5,10 +5,10 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Transform
+from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
-from ...utils.linalg.linear_operator import ScaledIdentity
-from ..functional import epsilon_step_size
+from ...linalg.linear_operator import ScaledIdentity
+from ..opt_utils import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):
     finfo = torch.finfo(param.dtype)
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def _get_H(self: Transform, var):
+def _get_scaled_identity_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)
@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
     return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
 
 
-class PolyakStepSize(Transform):
+class PolyakStepSize(TensorTransform):
     """Polyak's subgradient method with known or unknown f*.
 
     Args:
@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
         super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         assert grads is not None and loss is not None
         tensors = TensorList(tensors)
         grads = TensorList(grads)
@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
         self.global_state['alpha'] = alpha
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', 1)
         if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
 
         torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
         return tensors
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_scaled_identity_H(self, objective)
 
 
def _bb_short(s: TensorList, y: TensorList, sy, eps):
@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
         return None
     return (short * long) ** 0.5
 
-class BarzilaiBorwein(Transform):
+class BarzilaiBorwein(TensorTransform):
     """Barzilai-Borwein step size method.
 
     Args:
@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
         return tensors
 
 
-class BBStab(Transform):
+class BBStab(TensorTransform):
     """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
 
     This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
@@ -228,7 +228,7 @@ class BBStab(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -287,11 +287,11 @@ class BBStab(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -301,7 +301,7 @@ class BBStab(Transform):
         return tensors
 
 
-class AdGD(Transform):
+class AdGD(TensorTransform):
     """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
     def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
         defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
@@ -313,7 +313,7 @@ class AdGD(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         variant = settings[0]['variant']
         theta_0 = 0 if variant == 1 else 1/3
         theta = self.global_state.get('theta', theta_0)
@@ -371,7 +371,7 @@ class AdGD(Transform):
         prev_g.copy_(g)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -383,5 +383,5 @@ class AdGD(Transform):
         torch._foreach_mul_(tensors, alpha)
         return tensors
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_scaled_identity_H(self, objective)
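The `_bb_short` / `_bb_geom` helpers visible in this diff compute the classic Barzilai-Borwein step sizes from the parameter and gradient differences s = x_k − x_{k−1}, y = g_k − g_{k−1}. For reference, a standalone sketch of those formulas on flat tensors (not the module's TensorList implementation):

```python
import torch

def bb_step_sizes(s: torch.Tensor, y: torch.Tensor, eps: float = 1e-12):
    """Barzilai-Borwein step sizes from s = x_k - x_{k-1}, y = g_k - g_{k-1}."""
    sy = s.dot(y)
    if sy <= eps:                 # curvature condition failed; caller should fall back
        return None
    long = s.dot(s) / sy          # "long" BB1 step
    short = sy / y.dot(y)         # "short" BB2 step
    geom = (short * long) ** 0.5  # geometric mean, matching `_bb_geom` above
    return long, short, geom
```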