torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/smoothing/laplacian.py

@@ -4,7 +4,7 @@ from collections.abc import Iterable
  import torch

  from ...utils.tensorlist import TensorList
- from ...core import Transform, Target
+ from ...core import TensorTransform


  def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
  v[-1] = 1
  return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

- class LaplacianSmoothing(Transform):
+ class LaplacianSmoothing(TensorTransform):
  """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.

  Args:
@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
  what to set on var.

  Examples:
- Laplacian Smoothing Gradient Descent optimizer as in the paper
+ Laplacian Smoothing Gradient Descent optimizer as in the paper

- .. code-block:: python
+ ```python

- opt = tz.Modular(
- model.parameters(),
- tz.m.LaplacianSmoothing(),
- tz.m.LR(1e-2),
- )
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LaplacianSmoothing(),
+ tz.m.LR(1e-2),
+ )
+ ```

  Reference:
  Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

  """
- def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
+ def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
  defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)
  # precomputed denominator for when layerwise=False
  self.global_state['full_denominator'] = None


  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  layerwise = settings[0]['layerwise']

  # layerwise laplacian smoothing
torchzero/modules/smoothing/sampling.py

@@ -7,14 +7,15 @@ from typing import Literal, cast

  import torch

- from ...core import Chainable, Modular, Module, Var
+ from ...core import Chainable, Modular, Module, Objective
  from ...core.reformulation import Reformulation
  from ...utils import Distributions, NumberList, TensorList
  from ..termination import TerminationCriteriaBase, make_termination_criteria


- def _reset_except_self(optimizer: Modular, var: Var, self: Module):
- for m in optimizer.unrolled_modules:
+ def _reset_except_self(objective: Objective, modules, self: Module):
+ assert objective.modular is not None
+ for m in objective.modular.flat_modules:
  if m is not self:
  m.reset()

@@ -98,15 +99,15 @@ class GradientSampling(Reformulation):
  self.set_child('termination', make_termination_criteria(extra=termination))

  @torch.no_grad
- def pre_step(self, var):
- params = TensorList(var.params)
+ def pre_step(self, objective):
+ params = TensorList(objective.params)

  fixed = self.defaults['fixed']

  # check termination criteria
  if 'termination' in self.children:
  termination = cast(TerminationCriteriaBase, self.children['termination'])
- if termination.should_terminate(var):
+ if termination.should_terminate(objective):

  # decay sigmas
  states = [self.state[p] for p in params]
@@ -118,7 +119,7 @@ class GradientSampling(Reformulation):

  # reset on sigmas decay
  if self.defaults['reset_on_termination']:
- var.post_step_hooks.append(partial(_reset_except_self, self=self))
+ objective.post_step_hooks.append(partial(_reset_except_self, self=self))

  # clear perturbations
  self.global_state.pop('perts', None)
@@ -136,7 +137,7 @@ class GradientSampling(Reformulation):

  self.global_state['perts'] = perts
  @torch.no_grad
- def closure(self, backward, closure, params, var):
+ def closure(self, backward, closure, params, objective):
  params = TensorList(params)
  loss_agg = None
  grad_agg = None
@@ -160,7 +161,7 @@ class GradientSampling(Reformulation):

  # evaluate at x_0
  if include_x0:
- f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+ f_0 = objective.get_loss(backward=backward)

  isfinite = math.isfinite(f_0)
  if isfinite:
@@ -168,7 +169,7 @@ class GradientSampling(Reformulation):
  loss_agg = f_0

  if backward:
- g_0 = var.get_grad()
+ g_0 = objective.get_grads()
  if isfinite: grad_agg = g_0

  # evaluate at x_0 + p for each perturbation
torchzero/modules/step_size/adaptive.py

@@ -5,9 +5,9 @@ from typing import Any, Literal

  import torch

- from ...core import Chainable, Transform
+ from ...core import Chainable, TensorTransform
  from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
- from ...utils.linalg.linear_operator import ScaledIdentity
+ from ...linalg.linear_operator import ScaledIdentity
  from ..functional import epsilon_step_size

  def _acceptable_alpha(alpha, param:torch.Tensor):
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
  return False
  return True

- def _get_H(self: Transform, var):
+ def _get_H(self: TensorTransform, var):
  n = sum(p.numel() for p in var.params)
  p = var.params[0]
  alpha = self.global_state.get('alpha', 1)
@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
  return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)


- class PolyakStepSize(Transform):
+ class PolyakStepSize(TensorTransform):
  """Polyak's subgradient method with known or unknown f*.

  Args:
@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
  super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)

  @torch.no_grad
- def update_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
  assert grads is not None and loss is not None
  tensors = TensorList(tensors)
  grads = TensorList(grads)
@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
  self.global_state['alpha'] = alpha

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alpha = self.global_state.get('alpha', 1)
  if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))

  torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
  return tensors

- def get_H(self, var):
- return _get_H(self, var)
+ def get_H(self, objective):
+ return _get_H(self, objective)


  def _bb_short(s: TensorList, y: TensorList, sy, eps):
@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
  return None
  return (short * long) ** 0.5

- class BarzilaiBorwein(Transform):
+ class BarzilaiBorwein(TensorTransform):
  """Barzilai-Borwein step size method.

  Args:
@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
  self.global_state['reset'] = True

  @torch.no_grad
- def update_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
  prev_p.copy_(params)
  prev_g.copy_(g)

- def get_H(self, var):
- return _get_H(self, var)
+ def get_H(self, objective):
+ return _get_H(self, objective)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alpha = self.global_state.get('alpha', None)

  if not _acceptable_alpha(alpha, tensors[0]):
@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
  return tensors


- class BBStab(Transform):
+ class BBStab(TensorTransform):
  """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

  This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
@@ -228,7 +228,7 @@ class BBStab(Transform):
  self.global_state['reset'] = True

  @torch.no_grad
- def update_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

@@ -287,11 +287,11 @@ class BBStab(Transform):
  prev_p.copy_(params)
  prev_g.copy_(g)

- def get_H(self, var):
- return _get_H(self, var)
+ def get_H(self, objective):
+ return _get_H(self, objective)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alpha = self.global_state.get('alpha', None)

  if not _acceptable_alpha(alpha, tensors[0]):
@@ -301,7 +301,7 @@ class BBStab(Transform):
  return tensors


- class AdGD(Transform):
+ class AdGD(TensorTransform):
  """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
  def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
  defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
@@ -313,7 +313,7 @@ class AdGD(Transform):
  self.global_state['reset'] = True

  @torch.no_grad
- def update_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
  variant = settings[0]['variant']
  theta_0 = 0 if variant == 1 else 1/3
  theta = self.global_state.get('theta', theta_0)
@@ -371,7 +371,7 @@ class AdGD(Transform):
  prev_g.copy_(g)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alpha = self.global_state.get('alpha', None)

  if not _acceptable_alpha(alpha, tensors[0]):
@@ -383,5 +383,5 @@ class AdGD(Transform):
  torch._foreach_mul_(tensors, alpha)
  return tensors

- def get_H(self, var):
- return _get_H(self, var)
+ def get_H(self, objective):
+ return _get_H(self, objective)
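For reference, the `_bb_short` and `_bb_geom` helpers touched above correspond to the classical Barzilai-Borwein step sizes; the formulas below are standard background, not part of this diff. With s_k = x_k - x_{k-1} and y_k = g_k - g_{k-1}:

```latex
\alpha_k^{\mathrm{long}}  = \frac{s_k^\top s_k}{s_k^\top y_k}, \qquad
\alpha_k^{\mathrm{short}} = \frac{s_k^\top y_k}{y_k^\top y_k}, \qquad
\alpha_k^{\mathrm{geom}}  = \sqrt{\alpha_k^{\mathrm{long}}\,\alpha_k^{\mathrm{short}}}
```

The geometric-mean variant matches the `return (short * long) ** 0.5` line in `_bb_geom` above.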
torchzero/modules/step_size/lr.py

@@ -2,7 +2,7 @@
  import torch
  import random

- from ...core import Transform
+ from ...core import TensorTransform
  from ...utils import NumberList, TensorList, generic_ne, unpack_dicts

  def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
  return tensors * lr
  return tensors

- class LR(Transform):
+ class LR(TensorTransform):
  """Learning rate. Adding this module also adds support for LR schedulers."""
  def __init__(self, lr: float):
  defaults=dict(lr=lr)
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

- class StepSize(Transform):
+ class StepSize(TensorTransform):
  """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
  def __init__(self, step_size: float, key = 'step_size'):
  defaults={"key": key, key: step_size}
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)


@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
  if step > steps: return end_lr
  return start_lr + (end_lr - start_lr) * (step / steps)

- class Warmup(Transform):
- """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+ class Warmup(TensorTransform):
+ """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.

  Args:
  steps (int, optional): number of steps to perform warmup for. Defaults to 100.
@@ -64,7 +64,7 @@ class Warmup(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
  num_steps = settings[0]['steps']
  step = self.global_state.get('step', 0)
@@ -77,7 +77,7 @@ class Warmup(Transform):
  self.global_state['step'] = step + 1
  return tensors

- class WarmupNormClip(Transform):
+ class WarmupNormClip(TensorTransform):
  """Warmup via clipping of the update norm.

  Args:
@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
  num_steps = settings[0]['steps']
  step = self.global_state.get('step', 0)
@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
  return tensors


- class RandomStepSize(Transform):
- """Uses random global or layer-wise step size from `low` to `high`.
+ class RandomStepSize(TensorTransform):
+ """Uses random global or layer-wise step size from ``low`` to ``high``.

  Args:
  low (float, optional): minimum learning rate. Defaults to 0.
@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  s = settings[0]
  parameterwise = s['parameterwise']

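The same mechanical rename runs through the smoothing and step-size modules above: `Transform` becomes `TensorTransform`, `apply_tensors`/`update_tensors` become `multi_tensor_apply`/`multi_tensor_update`, and the `uses_grad=False`/`target=` constructor arguments are dropped. A minimal sketch of what a downstream custom module could look like against the 0.4.0 names, assuming `TensorTransform` is exported from `torchzero.core` as the relative imports above suggest; the class name and the `factor` setting are made up for illustration:

```python
import torch
from torchzero.core import TensorTransform  # assumed public export, mirroring `from ...core import TensorTransform` above


class MulByFactor(TensorTransform):
    """Hypothetical module: scales the incoming update by a fixed factor."""

    def __init__(self, factor: float = 0.5):
        # per-parameter settings are passed as a plain defaults dict, as in LR above
        super().__init__(dict(factor=factor))

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # `settings` is a list of per-parameter dicts, same access pattern as LR/StepSize
        torch._foreach_mul_(tensors, [s['factor'] for s in settings])
        return tensors
```

Such a module would be chained like the built-in ones, e.g. `tz.Modular(model.parameters(), MulByFactor(0.9), tz.m.LR(1e-2))`.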
torchzero/modules/termination/termination.py

@@ -1,11 +1,11 @@
  import time
  from abc import ABC, abstractmethod
  from collections.abc import Sequence
- from typing import cast
+ from typing import cast, final

  import torch

- from ...core import Module, Var
+ from ...core import Module, Objective
  from ...utils import Metrics, TensorList, safe_dict_update_, tofloat


@@ -16,14 +16,15 @@ class TerminationCriteriaBase(Module):
  super().__init__(defaults)

  @abstractmethod
- def termination_criteria(self, var: Var) -> bool:
+ def termination_criteria(self, objective: Objective) -> bool:
  ...

- def should_terminate(self, var: Var) -> bool:
+ @final
+ def should_terminate(self, objective: Objective) -> bool:
  n_bad = self.global_state.get('_n_bad', 0)
  n = self.defaults['_n']

- if self.termination_criteria(var):
+ if self.termination_criteria(objective):
  n_bad += 1
  if n_bad >= n:
  self.global_state['_n_bad'] = 0
@@ -36,12 +37,12 @@ class TerminationCriteriaBase(Module):
  return False


- def update(self, var):
- var.should_terminate = self.should_terminate(var)
- if var.should_terminate: self.global_state['_n_bad'] = 0
+ def update(self, objective):
+ objective.should_terminate = self.should_terminate(objective)
+ if objective.should_terminate: self.global_state['_n_bad'] = 0

- def apply(self, var):
- return var
+ def apply(self, objective):
+ return objective


  class TerminateAfterNSteps(TerminationCriteriaBase):
@@ -49,7 +50,7 @@ class TerminateAfterNSteps(TerminationCriteriaBase):
  defaults = dict(steps=steps)
  super().__init__(defaults)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

@@ -61,16 +62,17 @@ class TerminateAfterNEvaluations(TerminationCriteriaBase):
  defaults = dict(maxevals=maxevals)
  super().__init__(defaults)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  maxevals = self.defaults['maxevals']
- return var.modular.num_evaluations >= maxevals
+ assert objective.modular is not None
+ return objective.modular.num_evaluations >= maxevals

  class TerminateAfterNSeconds(TerminationCriteriaBase):
  def __init__(self, seconds:float, sec_fn = time.time):
  defaults = dict(seconds=seconds, sec_fn=sec_fn)
  super().__init__(defaults)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  max_seconds = self.defaults['seconds']
  sec_fn = self.defaults['sec_fn']

@@ -88,10 +90,10 @@ class TerminateByGradientNorm(TerminationCriteriaBase):
  defaults = dict(tol=tol, ord=ord)
  super().__init__(defaults, n=n)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  tol = self.defaults['tol']
  ord = self.defaults['ord']
- return TensorList(var.get_grad()).global_metric(ord) <= tol
+ return TensorList(objective.get_grads()).global_metric(ord) <= tol


  class TerminateByUpdateNorm(TerminationCriteriaBase):
@@ -100,20 +102,20 @@ class TerminateByUpdateNorm(TerminationCriteriaBase):
  defaults = dict(tol=tol, ord=ord)
  super().__init__(defaults, n=n)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  step = self.global_state.get('step', 0)
  self.global_state['step'] = step + 1

  tol = self.defaults['tol']
  ord = self.defaults['ord']

- p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+ p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
  if step == 0:
- p_prev.copy_(var.params)
+ p_prev.copy_(objective.params)
  return False

- should_terminate = (p_prev - var.params).global_metric(ord) <= tol
- p_prev.copy_(var.params)
+ should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
+ p_prev.copy_(objective.params)
  return should_terminate


@@ -122,10 +124,10 @@ class TerminateOnNoImprovement(TerminationCriteriaBase):
  defaults = dict(tol=tol)
  super().__init__(defaults, n=n)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  tol = self.defaults['tol']

- f = tofloat(var.get_loss(False))
+ f = tofloat(objective.get_loss(False))
  if 'f_min' not in self.global_state:
  self.global_state['f_min'] = f
  return False
@@ -141,9 +143,9 @@ class TerminateOnLossReached(TerminationCriteriaBase):
  defaults = dict(value=value)
  super().__init__(defaults)

- def termination_criteria(self, var):
+ def termination_criteria(self, objective):
  value = self.defaults['value']
- return var.get_loss(False) <= value
+ return objective.get_loss(False) <= value

  class TerminateAny(TerminationCriteriaBase):
  def __init__(self, *criteria: TerminationCriteriaBase):
@@ -151,9 +153,9 @@ class TerminateAny(TerminationCriteriaBase):

  self.set_children_sequence(criteria)

- def termination_criteria(self, var: Var) -> bool:
+ def termination_criteria(self, objective: Objective) -> bool:
  for c in self.get_children_sequence():
- if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
+ if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True

  return False

@@ -163,9 +165,9 @@ class TerminateAll(TerminationCriteriaBase):

  self.set_children_sequence(criteria)

- def termination_criteria(self, var: Var) -> bool:
+ def termination_criteria(self, objective: Objective) -> bool:
  for c in self.get_children_sequence():
- if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
+ if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False

  return True

@@ -173,7 +175,7 @@ class TerminateNever(TerminationCriteriaBase):
  def __init__(self):
  super().__init__()

- def termination_criteria(self, var): return False
+ def termination_criteria(self, objective): return False

  def make_termination_criteria(
  ftol: float | None = None,
torchzero/modules/trust_region/cubic_regularization.py

@@ -5,7 +5,7 @@ import torch

  from ...core import Chainable, Module
  from ...utils import TensorList, vec_to_tensors
- from ...utils.linalg.linear_operator import LinearOperator
+ from ...linalg.linear_operator import LinearOperator
  from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


@@ -58,7 +58,7 @@ def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_param
  for _ in range(it_max):
  r_try = (r_min + r_max) / 2
  lam = r_try * M
- s_lam = H.add_diagonal(lam).solve(g).neg()
+ s_lam = H.solve_plus_diag(g, lam).neg()
  # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
  solver_it += 1
  crit = conv_criterion(s_lam, r_try)
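Besides the `solve_plus_diag` change, the only edit here is the relocation of the linear-operator utilities from `torchzero.utils.linalg` to the new top-level `torchzero.linalg` package (see the `torchzero/linalg/` entries in the file listing). A hypothetical downstream migration, assuming the same names are simply re-exported from the new location:

```python
# 0.3.14 (old location):
# from torchzero.utils.linalg.linear_operator import LinearOperator, ScaledIdentity

# 0.4.0 (new location, matching the rewritten imports in this diff):
from torchzero.linalg.linear_operator import LinearOperator, ScaledIdentity
```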
torchzero/modules/trust_region/levenberg_marquardt.py

@@ -2,7 +2,7 @@
  import torch

  from ...core import Chainable, Module
- from ...utils.linalg import linear_operator
+ from ...linalg import linear_operator
  from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


@@ -32,38 +32,31 @@ class LevenbergMarquardt(TrustRegionBase):
  max_attempts (max_attempts, optional):
  maximum number of trust region size size reductions per step. A zero update vector is returned when
  this limit is exceeded. Defaults to 10.
+ adaptive (bool, optional):
+ if True, trust radius is multiplied by square root of gradient norm.
  fallback (bool, optional):
  if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
  be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
  inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.

- Examples:
- Gauss-Newton with Levenberg-Marquardt trust-region
+ ### Examples:

- .. code-block:: python
+ Gauss-Newton with Levenberg-Marquardt trust-region

- opt = tz.Modular(
- model.parameters(),
- tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
- )
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+ )
+ ```

- LM-SR1
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
- )
-
- First order trust region (hessian is assumed to be identity)
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.LevenbergMarquardt(tz.m.Identity()),
- )
+ LM-SR1
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+ )
+ ```

  """
  def __init__(
@@ -78,11 +71,12 @@ class LevenbergMarquardt(TrustRegionBase):
  max_attempts: int = 10,
  radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
  y: float = 0,
+ adaptive: bool = False,
  fallback: bool = False,
  update_freq: int = 1,
  inner: Chainable | None = None,
  ):
- defaults = dict(y=y, fallback=fallback)
+ defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
  super().__init__(
  defaults=defaults,
  hess_module=hess_module,
@@ -103,6 +97,7 @@ class LevenbergMarquardt(TrustRegionBase):

  def trust_solve(self, f, g, H, radius, params, closure, settings):
  y = settings['y']
+ adaptive = settings["adaptive"]

  if isinstance(H, linear_operator.DenseInverse):
  if settings['fallback']:
@@ -117,12 +112,14 @@ class LevenbergMarquardt(TrustRegionBase):
  )

  reg = 1/radius
+ if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()
+
  if y == 0:
- return H.add_diagonal(reg).solve(g)
+ return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]

  diag = H.diagonal()
  diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
  if y != 1: diag = (diag*y) + (1-y)
- return H.add_diagonal(diag*reg).solve(g)
+ return H.solve_plus_diag(g, diag*reg)

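The new `adaptive` option added in this hunk scales the regularizer by the square root of the gradient norm (`reg = (1/radius) * ||g||**0.5`, per the code above). A usage sketch adapted from the docstring example in this file; `adaptive=True` is the new argument and everything else is unchanged (this assumes the usual `import torchzero as tz` alias and an existing `model`, as in the docstrings):

```python
import torchzero as tz

# Gauss-Newton with a Levenberg-Marquardt trust region; adaptive=True enables
# the gradient-norm-scaled regularization introduced in 0.4.0.
opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.GaussNewton(), adaptive=True),
)
```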