torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0

torchzero/modules/second_order/rsn.py
@@ -5,9 +5,10 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import Distributions, TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import Sketched
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Sketched
+
 from .newton import _newton_step
 
 def _qr_orthonormalize(A:torch.Tensor):
@@ -15,9 +16,9 @@ def _qr_orthonormalize(A:torch.Tensor):
     if m < n:
         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
         return q.T
-    else:
-        q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
-        return q
+
+    q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+    return q
 
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
@@ -25,26 +26,31 @@ def _orthonormal_sketch(m, n, dtype, device, generator):
 def _gaussian_sketch(m, n, dtype, device, generator):
     return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
 
-class RSN(Module):
-    """Randomized Subspace Newton. Performs a Newton step in a random subspace.
+def _rademacher_sketch(m, n, dtype, device, generator):
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    return rademacher.mul_(1 / math.sqrt(m))
+
+class SubspaceNewton(Transform):
+    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
     Args:
         sketch_size (int):
            size of the random sketch. This many hessian-vector products will need to be evaluated each step.
        sketch_type (str, optional):
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
+           - "rademacher" - approximately orthonormal scaled random rademacher basis.
            - "gaussian" - random gaussian (not orthonormal) basis.
            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
-           - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
+           - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
            How to compute hessian-matrix product:
-           - "batched" - uses batched autograd
+           - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
 
-           . Defaults to "batched".
+           . Defaults to "batched_autograd".
        h (float, optional): finite difference step size. Defaults to 1e-2.
        use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
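
For context, a minimal standalone check (not part of the package) of why the new `"rademacher"` sketch is a cheap stand-in for an orthonormal one: with entries ±1/√m every column of S has exactly unit norm and distinct columns are orthogonal in expectation, so SᵀS concentrates around the identity when m ≫ n. The helper below mirrors `_rademacher_sketch` from the hunk above.

```python
import math
import torch

def rademacher_sketch(m: int, n: int, generator=None) -> torch.Tensor:
    # entries are +-1/sqrt(m), mirroring _rademacher_sketch in the hunk above
    r = torch.bernoulli(torch.full((m, n), 0.5), generator=generator).mul_(2).sub_(1)
    return r / math.sqrt(m)

S = rademacher_sketch(10_000, 8)
gram = S.T @ S                              # (8, 8); diagonal is exactly 1
print((gram - torch.eye(8)).abs().max())    # off-diagonal entries are O(1/sqrt(m))
```
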
@@ -93,7 +99,7 @@ class RSN(Module):
         sketch_size: int,
         sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
         damping:float=0,
-        hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
+        hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
         use_lstsq: bool = True,
         update_freq: int = 1,
@@ -102,115 +108,119 @@ class RSN(Module):
         seed: int | None = None,
         inner: Chainable | None = None,
     ):
-        defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def update(self, var):
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        if step % self.defaults['update_freq'] == 0:
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        generator = self.get_generator(params[0].device, fs["seed"])
 
-            closure = var.closure
-            if closure is None:
-                raise RuntimeError("RSN requires closure")
-            params = var.params
-            generator = self.get_generator(params[0].device, self.defaults["seed"])
+        ndim = sum(p.numel() for p in params)
 
-            ndim = sum(p.numel() for p in params)
+        device=params[0].device
+        dtype=params[0].dtype
 
-            device=params[0].device
-            dtype=params[0].dtype
+        # sample sketch matrix S: (ndim, sketch_size)
+        sketch_size = min(fs["sketch_size"], ndim)
+        sketch_type = fs["sketch_type"]
+        hvp_method = fs["hvp_method"]
 
-            # sample sketch matrix S: (ndim, sketch_size)
-            sketch_size = min(self.defaults["sketch_size"], ndim)
-            sketch_type = self.defaults["sketch_type"]
-            hvp_method = self.defaults["hvp_method"]
+        if sketch_type in ('normal', 'gaussian'):
+            S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-            if sketch_type in ('normal', 'gaussian'):
-                S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+        elif sketch_type == "rademacher":
+            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-            elif sketch_type == 'orthonormal':
-                S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+        elif sketch_type == 'orthonormal':
+            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-            elif sketch_type == 'common_directions':
-                # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
-                g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-                g = torch.cat([t.ravel() for t in g_list])
+        elif sketch_type == 'common_directions':
+            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
 
-                # initialize directions deque
-                if "directions" not in self.global_state:
+            # initialize directions deque
+            if "directions" not in self.global_state:
 
+                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                if g_norm < torch.finfo(g.dtype).tiny * 2:
+                    g = torch.randn_like(g)
                     g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-                    if g_norm < torch.finfo(g.dtype).tiny * 2:
-                        g = torch.randn_like(g)
-                        g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-
-                    self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
-                    S = self.global_state["directions"][0].unsqueeze(1)
-
-                # add new steepest descent direction orthonormal to existing columns
-                else:
-                    S = torch.stack(tuple(self.global_state["directions"]), dim=1)
-                    p = g - S @ (S.T @ g)
-                    p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-                    if p_norm > torch.finfo(p.dtype).tiny * 2:
-                        p = p / p_norm
-                        self.global_state["directions"].append(p)
-                        S = torch.cat([S, p.unsqueeze(1)], dim=1)
-
-            elif sketch_type == "mixed":
-                g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-                g = torch.cat([t.ravel() for t in g_list])
-
-                if "slow_ema" not in self.global_state:
-                    self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
-                    self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
-
-                slow_ema = self.global_state["slow_ema"]
-                fast_ema = self.global_state["fast_ema"]
-                slow_ema.lerp_(g, 0.001)
-                fast_ema.lerp_(g, 0.1)
-
-                S = torch.stack([g, slow_ema, fast_ema], dim=1)
-                if sketch_size > 3:
-                    S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
-                    S = torch.cat([S, S_random], dim=1)
-
-                S = _qr_orthonormalize(S)
-
-            else:
-                raise ValueError(f'Unknown sketch_type {sketch_type}')
-
-            # form sketched hessian
-            HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
-            H_sketched = S.T @ HS
 
-            self.global_state["H_sketched"] = H_sketched
-            self.global_state["S"] = S
+                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                S = self.global_state["directions"][0].unsqueeze(1)
 
-    def apply(self, var):
+            # add new steepest descent direction orthonormal to existing columns
+            else:
+                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                p = g - S @ (S.T @ g)
+                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                if p_norm > torch.finfo(p.dtype).tiny * 2:
+                    p = p / p_norm
+                    self.global_state["directions"].append(p)
+                    S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+        elif sketch_type == "mixed":
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
+
+            # initialize state
+            if "slow_ema" not in self.global_state:
+                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["p_prev"] = torch.randn_like(g)
+
+            # previous update direction
+            p_cur = torch.cat([t.ravel() for t in params])
+            prev_dir = p_cur - self.global_state["p_prev"]
+            self.global_state["p_prev"] = p_cur
+
+            # EMAs
+            slow_ema = self.global_state["slow_ema"]
+            fast_ema = self.global_state["fast_ema"]
+            slow_ema.lerp_(g, 0.001)
+            fast_ema.lerp_(g, 0.1)
+
+            # form and orthogonalize sketching matrix
+            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
+            if sketch_size > 4:
+                S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
+                S = torch.cat([S, S_random], dim=1)
+
+            S = _qr_orthonormalize(S)
+
+        else:
+            raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+        # form sketched hessian
+        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
+                                                 hvp_method=fs["hvp_method"], h=fs["h"])
+        H_sketched = S.T @ HS
+
+        self.global_state["H_sketched"] = H_sketched
+        self.global_state["S"] = S
+
+    def apply_states(self, objective, states, settings):
         S: torch.Tensor = self.global_state["S"]
+
         d_proj = _newton_step(
-            var=var,
+            objective=objective,
             H=self.global_state["H_sketched"],
             damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
             H_tfm=self.defaults["H_tfm"],
             eigval_fn=self.defaults["eigval_fn"],
             use_lstsq=self.defaults["use_lstsq"],
             g_proj = lambda g: S.T @ g
         )
-        d = S @ d_proj
-        var.update = vec_to_tensors(d, var.params)
 
-        return var
+        d = S @ d_proj
+        objective.updates = vec_to_tensors(d, objective.params)
+        return objective
 
-    def get_H(self, var=...):
+    def get_H(self, objective=...):
         eigval_fn = self.defaults["eigval_fn"]
         H_sketched: torch.Tensor = self.global_state["H_sketched"]
         S: torch.Tensor = self.global_state["S"]
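
As a usage note: `RSN` is renamed to `SubspaceNewton` and now subclasses `Transform`, receiving the objective through `update_states`/`apply_states` instead of `update(var)`/`apply(var)`. Below is a hypothetical sketch of plugging it into a modular optimizer, following the `tz.Modular` / `tz.m` pattern from the `LaplacianSmoothing` docstring further down; the export name `tz.m.SubspaceNewton` and the closure convention are assumptions, not confirmed by this diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# hypothetical composition; assumes the module is exported as tz.m.SubspaceNewton
opt = tz.Modular(
    model.parameters(),
    tz.m.SubspaceNewton(sketch_size=16, sketch_type="mixed"),
    tz.m.LR(1e-1),
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # each step evaluates sketch_size Hessian-vector products
```
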
torchzero/modules/smoothing/laplacian.py
@@ -4,7 +4,7 @@ from collections.abc import Iterable
 import torch
 
 from ...utils.tensorlist import TensorList
-from ...core import Transform, Target
+from ...core import TensorTransform
 
 
 def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     v[-1] = 1
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
-class LaplacianSmoothing(Transform):
+class LaplacianSmoothing(TensorTransform):
     """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
     Args:
@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
            what to set on var.
 
     Examples:
-        Laplacian Smoothing Gradient Descent optimizer as in the paper
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
 
-        .. code-block:: python
+        ```python
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LaplacianSmoothing(),
-                tz.m.LR(1e-2),
-            )
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LaplacianSmoothing(),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
    """
-    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
+    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
        defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
        # precomputed denominator for when layerwise=False
        self.global_state['full_denominator'] = None
 
 
    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        layerwise = settings[0]['layerwise']
 
        # layerwise laplacian smoothing
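
For reference, a self-contained sketch (not the package code) of the FFT solve that `_precompute_denominator` above supports: Laplacian smoothing replaces a gradient g by the solution of (I − σL)u = g, where L is the circulant second-difference stencil v = (−2, 1, 0, …, 0, 1), which in Fourier space becomes an elementwise division by `1 − σ·fft(v)`. The full stencil and the layerwise/`min_numel` handling are assumptions based on the referenced Osher et al. paper.

```python
import torch

def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Solve (I - sigma * L) u = g for a flattened gradient g via FFT."""
    v = torch.zeros_like(g)
    v[0], v[1], v[-1] = -2.0, 1.0, 1.0          # circulant Laplacian stencil (assumed)
    denom = 1 - sigma * torch.fft.fft(v)        # cf. _precompute_denominator above
    return torch.fft.ifft(torch.fft.fft(g) / denom).real

g = torch.randn(1024)
print(g.std(), laplacian_smooth(g).std())       # high-frequency noise is damped
```
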
torchzero/modules/smoothing/sampling.py
@@ -7,14 +7,15 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable, Modular, Module, Var
+from ...core import Chainable, Modular, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
-def _reset_except_self(optimizer: Modular, var: Var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()
 
@@ -98,15 +99,15 @@ class GradientSampling(Reformulation):
         self.set_child('termination', make_termination_criteria(extra=termination))
 
     @torch.no_grad
-    def pre_step(self, var):
-        params = TensorList(var.params)
+    def pre_step(self, objective):
+        params = TensorList(objective.params)
 
         fixed = self.defaults['fixed']
 
         # check termination criteria
         if 'termination' in self.children:
             termination = cast(TerminationCriteriaBase, self.children['termination'])
-            if termination.should_terminate(var):
+            if termination.should_terminate(objective):
 
                 # decay sigmas
                 states = [self.state[p] for p in params]
@@ -118,7 +119,7 @@ class GradientSampling(Reformulation):
 
                 # reset on sigmas decay
                 if self.defaults['reset_on_termination']:
-                    var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
                 # clear perturbations
                 self.global_state.pop('perts', None)
@@ -136,7 +137,7 @@ class GradientSampling(Reformulation):
         self.global_state['perts'] = perts
 
     @torch.no_grad
-    def closure(self, backward, closure, params, var):
+    def closure(self, backward, closure, params, objective):
         params = TensorList(params)
         loss_agg = None
         grad_agg = None
@@ -160,7 +161,7 @@
 
         # evaluate at x_0
         if include_x0:
-            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+            f_0 = objective.get_loss(backward=backward)
 
             isfinite = math.isfinite(f_0)
             if isfinite:
@@ -168,7 +169,7 @@
                 loss_agg = f_0
 
                 if backward:
-                    g_0 = var.get_grad()
+                    g_0 = objective.get_grads()
                     if isfinite: grad_agg = g_0
 
         # evaluate at x_0 + p for each perturbation
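
The `closure` override above aggregates the loss and gradient over the sampled perturbations, optionally including the unperturbed point x_0. Below is a minimal standalone sketch of that idea, assuming torchzero's closure convention where `closure()` runs the backward pass; it omits the sigma decay, termination, and non-finite handling the module performs.

```python
import torch

def sampled_loss_and_grads(closure, params, sigma=1e-2, n_samples=4, include_x0=True):
    """Average loss and gradients of `closure` over gaussian perturbations of `params`."""
    perts = [[torch.randn_like(p) * sigma for p in params] for _ in range(n_samples)]
    if include_x0:
        perts.append([torch.zeros_like(p) for p in params])   # the unperturbed point

    loss_agg = 0.0
    grad_agg = [torch.zeros_like(p) for p in params]

    for pert in perts:
        with torch.no_grad():
            for p, e in zip(params, pert): p.add_(e)           # move to x_0 + e
        for p in params: p.grad = None
        loss_agg += float(closure())                           # closure runs backward
        for acc, p in zip(grad_agg, params):
            if p.grad is not None: acc.add_(p.grad)
        with torch.no_grad():
            for p, e in zip(params, pert): p.sub_(e)           # restore x_0

    n = len(perts)
    return loss_agg / n, [g / n for g in grad_agg]
```
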
torchzero/modules/step_size/adaptive.py
@@ -5,9 +5,9 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Transform
+from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
-from ...utils.linalg.linear_operator import ScaledIdentity
+from ...linalg.linear_operator import ScaledIdentity
 from ..functional import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def _get_H(self: Transform, var):
+def _get_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)
@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
     return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
 
 
-class PolyakStepSize(Transform):
+class PolyakStepSize(TensorTransform):
     """Polyak's subgradient method with known or unknown f*.
 
     Args:
@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
         super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         assert grads is not None and loss is not None
         tensors = TensorList(tensors)
         grads = TensorList(grads)
@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
         self.global_state['alpha'] = alpha
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', 1)
         if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
 
         torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
         return tensors
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
 
 def _bb_short(s: TensorList, y: TensorList, sy, eps):
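
For reference, the classical step size behind `PolyakStepSize`: with a known optimal value f*, the step along the (sub)gradient g is α = (f(x) − f*) / ‖g‖². A one-line sketch follows; the module's handling of unknown f*, the `alpha` multiplier, and the fallbacks live in the source.

```python
import torch

def polyak_alpha(loss: float, f_star: float, grads: list[torch.Tensor], eps: float = 1e-12) -> float:
    # classical Polyak step size: (f(x) - f*) / ||g||^2
    gg = sum(g.pow(2).sum() for g in grads)
    return float((loss - f_star) / (gg + eps))
```
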
@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
         return None
     return (short * long) ** 0.5
 
-class BarzilaiBorwein(Transform):
+class BarzilaiBorwein(TensorTransform):
     """Barzilai-Borwein step size method.
 
     Args:
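
For reference, the two Barzilai-Borwein step sizes that `_bb_short`, `_bb_geom`, and `BarzilaiBorwein` above refer to, with s = x_k − x_{k−1} and y = g_k − g_{k−1}: the long step sᵀs / sᵀy, the short step sᵀy / yᵀy, and their geometric mean. A minimal sketch, ignoring the curvature and fallback checks the module performs.

```python
import torch

def bb_step_sizes(s: torch.Tensor, y: torch.Tensor):
    """Barzilai-Borwein step sizes from parameter difference s and gradient difference y."""
    sy = s.dot(y)
    long_step = s.dot(s) / sy               # "BB1", long step
    short_step = sy / y.dot(y)              # "BB2", short step
    geom = (long_step * short_step).sqrt()  # geometric mean, cf. _bb_geom
    return long_step, short_step, geom
```
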
@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
         return tensors
 
 
-class BBStab(Transform):
+class BBStab(TensorTransform):
     """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
 
     This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
@@ -228,7 +228,7 @@ class BBStab(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -287,11 +287,11 @@ class BBStab(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -301,7 +301,7 @@ class BBStab(Transform):
         return tensors
 
 
-class AdGD(Transform):
+class AdGD(TensorTransform):
     """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
     def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
         defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
@@ -313,7 +313,7 @@ class AdGD(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         variant = settings[0]['variant']
         theta_0 = 0 if variant == 1 else 1/3
         theta = self.global_state.get('theta', theta_0)
@@ -371,7 +371,7 @@ class AdGD(Transform):
         prev_g.copy_(g)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -383,5 +383,5 @@ class AdGD(Transform):
         torch._foreach_mul_(tensors, alpha)
         return tensors
 
-    def get_H(self, var):
-        return _get_H(self, var)
+    def get_H(self, objective):
+        return _get_H(self, objective)
torchzero/modules/step_size/lr.py
@@ -2,7 +2,7 @@
 import torch
 import random
 
-from ...core import Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
 
 def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
         return tensors * lr
     return tensors
 
-class LR(Transform):
+class LR(TensorTransform):
     """Learning rate. Adding this module also adds support for LR schedulers."""
     def __init__(self, lr: float):
         defaults=dict(lr=lr)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
 
-class StepSize(Transform):
+class StepSize(TensorTransform):
     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
     def __init__(self, step_size: float, key = 'step_size'):
         defaults={"key": key, key: step_size}
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
 
 
@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
     if step > steps: return end_lr
     return start_lr + (end_lr - start_lr) * (step / steps)
 
-class Warmup(Transform):
-    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+class Warmup(TensorTransform):
+    """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.
 
     Args:
         steps (int, optional): number of steps to perform warmup for. Defaults to 100.
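
A hypothetical composition of the warmup multiplier from `_warmup_lr` above with `LR`, following the `tz.Modular` pattern used elsewhere in these docstrings; the keyword names match the settings keys in the hunk (`start_lr`, `end_lr`, `steps`), but the exact constructor signature and the `tz.m.Warmup` export are assumptions. At the midpoint of a 0 → 1 warmup over 100 steps the multiplier is 0.5, so the effective step size below ramps linearly from 0 to 1e-2.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# hypothetical composition; assumes Warmup is exported under tz.m
opt = tz.Modular(
    model.parameters(),
    tz.m.Warmup(start_lr=0.0, end_lr=1.0, steps=100),  # multiplier ramps 0 -> 1
    tz.m.LR(1e-2),                                     # final learning rate
)
```
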
@@ -64,7 +64,7 @@ class Warmup(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -77,7 +77,7 @@ class Warmup(Transform):
         self.global_state['step'] = step + 1
         return tensors
 
-class WarmupNormClip(Transform):
+class WarmupNormClip(TensorTransform):
     """Warmup via clipping of the update norm.
 
     Args:
@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
         return tensors
 
 
-class RandomStepSize(Transform):
-    """Uses random global or layer-wise step size from `low` to `high`.
+class RandomStepSize(TensorTransform):
+    """Uses random global or layer-wise step size from ``low`` to ``high``.
 
     Args:
         low (float, optional): minimum learning rate. Defaults to 0.
@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         s = settings[0]
         parameterwise = s['parameterwise']