torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -2,78 +2,78 @@ from collections import deque
 
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import Module, Transform
 from ...utils.tensorlist import Distributions, TensorList
-from ...utils.linalg.linear_operator import ScaledIdentity
+from ...linalg.linear_operator import ScaledIdentity
 
 class Clone(Module):
     """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [u.clone() for u in var.get_update()]
-        return var
+    def apply(self, objective):
+        objective.updates = [u.clone() for u in objective.get_updates()]
+        return objective
 
 class Grad(Module):
     """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [g.clone() for g in var.get_grad()]
-        return var
+    def apply(self, objective):
+        objective.updates = [g.clone() for g in objective.get_grads()]
+        return objective
 
 class Params(Module):
     """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [p.clone() for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [p.clone() for p in objective.params]
+        return objective
 
 class Zeros(Module):
     """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.zeros_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.zeros_like(p) for p in objective.params]
+        return objective
 
 class Ones(Module):
     """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.ones_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.ones_like(p) for p in objective.params]
+        return objective
 
 class Fill(Module):
-    """Outputs tensors filled with :code:`value`"""
+    """Outputs tensors filled with ``value``"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
+        return objective
 
 class RandomSample(Module):
-    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
+    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
     def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
         defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
+    def apply(self, objective):
         distribution = self.defaults['distribution']
-        variance = self.get_settings(var.params, 'variance')
-        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
-        return var
+        variance = self.get_settings(objective.params, 'variance')
+        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
+        return objective
 
 class Randn(Module):
     """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
@@ -81,43 +81,44 @@ class Randn(Module):
         super().__init__({})
 
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.randn_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.randn_like(p) for p in objective.params]
+        return objective
 
 class Uniform(Module):
-    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
+    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        low,high = self.get_settings(var.params, 'low','high')
-        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
-        return var
+    def apply(self, objective):
+        low,high = self.get_settings(objective.params, 'low','high')
+        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
+        return objective
 
 class GradToNone(Module):
-    """Sets :code:`grad` attribute to None on :code:`var`."""
+    """Sets ``grad`` attribute to None on ``objective``."""
     def __init__(self): super().__init__()
-    def step(self, var):
-        var.grad = None
-        return var
+    def apply(self, objective):
+        objective.grads = None
+        return objective
 
 class UpdateToNone(Module):
-    """Sets :code:`update` attribute to None on :code:`var`."""
+    """Sets ``update`` attribute to None on ``var``."""
     def __init__(self): super().__init__()
-    def step(self, var):
-        var.update = None
-        return var
+    def apply(self, objective):
+        objective.updates = None
+        return objective
 
 class Identity(Module):
     """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
     def __init__(self, *args, **kwargs): super().__init__()
-    def step(self, var): return var
-    def get_H(self, var):
-        n = sum(p.numel() for p in var.params)
-        p = var.params[0]
+    def update(self, objective): pass
+    def apply(self, objective): return objective
+    def get_H(self, objective):
+        n = sum(p.numel() for p in objective.params)
+        p = objective.params[0]
         return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)
 
 Noop = Identity
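
The pattern in the two hunks above, ``step(self, var)`` becoming ``apply(self, objective)`` and ``var.update`` becoming ``objective.updates``, repeats across most modules in this release. The sketch below is illustrative only, written against the 0.4.0-style API exactly as it appears in these hunks; the ``NegateUpdate`` class is hypothetical and not part of torchzero.

    import torch
    from torchzero.core import Module

    class NegateUpdate(Module):
        """Hypothetical module: flips the sign of the current update."""
        def __init__(self):
            super().__init__({})  # no per-parameter settings, mirroring Clone/Grad above

        @torch.no_grad
        def apply(self, objective):
            # 0.3.x modules wrote to var.update inside step(self, var);
            # 0.4.0 modules write to objective.updates inside apply(self, objective)
            objective.updates = [-u for u in objective.get_updates()]
            return objective
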
@@ -6,7 +6,7 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from .projection import ProjectionBase
 
 # TODO
@@ -8,7 +8,7 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from ...utils import set_storage_, vec_to_tensors
 
 
@@ -80,7 +80,7 @@ class _FakeProjectedClosure:
 class ProjectionBase(Module, ABC):
     """
     Base class for projections.
-    This is an abstract class, to use it, subclass it and override `project` and `unproject`.
+    This is an abstract class, to use it, subclass it and override ``project`` and ``unproject``.
 
     Args:
         modules (Chainable): modules that will be applied in the projected domain.
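
As the ``ProjectionBase`` docstring above says, a concrete projection subclasses it and overrides ``project`` and ``unproject``. The toy subclass below is a hypothetical sketch built only from the ``self.project(...)`` and ``self.unproject(...)`` call sites visible in the hunks that follow; constructor arguments and registration details are omitted, and none of this code is from the package itself.

    from torchzero.modules.projections.projection import ProjectionBase

    class ScaleProjection(ProjectionBase):
        """Hypothetical projection: optimize in a domain where every tensor is doubled."""

        def project(self, tensors, params, grads, loss, states, settings, current):
            # 'current' tells the projection whether it is looking at
            # 'params', 'grads' or 'update' (see the _project helper below)
            return [t * 2 for t in tensors]

        def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
            # exact inverse of project, so unproject(project(x)) == x
            return [t / 2 for t in projected_tensors]
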
@@ -150,8 +150,8 @@ class ProjectionBase(Module, ABC):
     """
 
     @torch.no_grad
-    def step(self, var: Var):
-        params = var.params
+    def apply(self, objective: Objective):
+        params = objective.params
         settings = [self.settings[p] for p in params]
 
         def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
@@ -159,16 +159,16 @@ class ProjectionBase(Module, ABC):
             return list(self.project(
                 tensors=tensors,
                 params=params,
-                grads=var.grad,
-                loss=var.loss,
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
             ))
 
-        projected_var = var.clone(clone_update=False, parent=var)
+        projected_obj = objective.clone(clone_updates=False, parent=objective)
 
-        closure = var.closure
+        closure = objective.closure
 
         # if this is True, update and grad were projected simultaneously under current="grads"
         # so update will have to be unprojected with current="grads"
@@ -179,9 +179,9 @@ class ProjectionBase(Module, ABC):
         # but if it has already been computed, it should be projected
         if self._project_params and closure is not None:
 
-            if self._project_update and var.update is not None:
+            if self._project_update and objective.updates is not None:
                 # project update only if it already exists
-                projected_var.update = _project(var.update, current='update')
+                projected_obj.updates = _project(objective.updates, current='update')
 
             else:
                 # update will be set to gradients on var.get_grad()
@@ -189,43 +189,43 @@ class ProjectionBase(Module, ABC):
                 update_is_grad = True
 
             # project grad only if it already exists
-            if self._project_grad and var.grad is not None:
-                projected_var.grad = _project(var.grad, current='grads')
+            if self._project_grad and objective.grads is not None:
+                projected_obj.grads = _project(objective.grads, current='grads')
 
         # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
-                if var.update is None:
+                if objective.updates is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
-                    grad = var.get_grad()
-                    projected_var.grad = _project(grad, current='grads')
-                    projected_var.update = [g.clone() for g in projected_var.grad]
-                    del var.update
+                    grad = objective.get_grads()
+                    projected_obj.grads = _project(grad, current='grads')
+                    projected_obj.updates = [g.clone() for g in projected_obj.grads]
+                    del objective.updates
                     update_is_grad = True
 
                 else:
                     # update exists so it needs to be projected
-                    update = var.get_update()
-                    projected_var.update = _project(update, current='update')
-                    del update, var.update
+                    update = objective.get_updates()
+                    projected_obj.updates = _project(update, current='update')
+                    del update, objective.updates
 
-            if self._project_grad and projected_var.grad is None:
+            if self._project_grad and projected_obj.grads is None:
                 # projected_vars.grad may have been projected simultaneously with update
                 # but if that didn't happen, it is projected here
-                grad = var.get_grad()
-                projected_var.grad = _project(grad, current='grads')
+                grad = objective.get_grads()
+                projected_obj.grads = _project(grad, current='grads')
 
 
         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in var.params]
-            projected_params = _project(var.params, current='params')
+            original_params = [p.clone() for p in objective.params]
+            projected_params = _project(objective.params, current='params')
 
         else:
             # make fake params for correct shapes and state storage
             # they reuse update or grad storage for memory efficiency
-            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
+            projected_params = projected_obj.updates if projected_obj.updates is not None else projected_obj.grads
             assert projected_params is not None
 
         if self._projected_params is None:
@@ -245,8 +245,8 @@ class ProjectionBase(Module, ABC):
             return list(self.unproject(
                 projected_tensors=projected_tensors,
                 params=params,
-                grads=var.grad,
-                loss=var.loss,
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
@@ -254,19 +254,19 @@ class ProjectionBase(Module, ABC):
 
         # project closure
         if self._project_params:
-            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+            projected_obj.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                                                              params=params, projected_params=projected_params)
 
         elif closure is not None:
-            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+            projected_obj.closure = _FakeProjectedClosure(closure, project_fn=_project,
                                                           params=params, fake_params=projected_params)
 
         else:
-            projected_var.closure = None
+            projected_obj.closure = None
 
         # ----------------------------------- step ----------------------------------- #
-        projected_var.params = projected_params
-        projected_var = self.children['modules'].step(projected_var)
+        projected_obj.params = projected_params
+        projected_obj = self.children['modules'].apply(projected_obj)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -275,24 +275,24 @@ class ProjectionBase(Module, ABC):
             set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
 
         # --------------------------------- unproject -------------------------------- #
-        unprojected_var = projected_var.clone(clone_update=False)
-        unprojected_var.closure = var.closure
-        unprojected_var.params = var.params
-        unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent
+        unprojected_obj = projected_obj.clone(clone_updates=False)
+        unprojected_obj.closure = objective.closure
+        unprojected_obj.params = objective.params
+        unprojected_obj.grads = objective.grads # this may also be set by projected_var since it has var as parent
 
         if self._project_update:
-            assert projected_var.update is not None
-            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
-            del projected_var.update
+            assert projected_obj.updates is not None
+            unprojected_obj.updates = _unproject(projected_obj.updates, current='grads' if update_is_grad else 'update')
+            del projected_obj.updates
 
-        del projected_var
+        del projected_obj
 
         # original params are stored if params are projected
         if original_params is not None:
-            for p, o in zip(unprojected_var.params, original_params):
+            for p, o in zip(unprojected_obj.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return unprojected_var
+        return unprojected_obj
 
 
 
@@ -29,3 +29,5 @@ from .quasi_newton import (
     ShorR,
     ThomasOptimalMethod,
 )
+
+from .sg2 import SG2, SPSA2
@@ -4,7 +4,7 @@ from typing import Literal, Protocol, overload
 import torch
 
 from ...utils import TensorList
-from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
+from ...linalg.linear_operator import DenseInverse, LinearOperator
 from ..functional import safe_clip
 
 
@@ -4,9 +4,9 @@ from typing import overload
 
 import torch
 
-from ...core import Chainable, Transform
+from ...core import Chainable, TensorTransform
 from ...utils import TensorList, as_tensorlist, unpack_states
-from ...utils.linalg.linear_operator import LinearOperator
+from ...linalg.linear_operator import LinearOperator
 from ..functional import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
@@ -154,7 +154,7 @@ class LBFGSLinearOperator(LinearOperator):
         return (n, n)
 
 
-class LBFGS(Transform):
+class LBFGS(TensorTransform):
     """Limited-memory BFGS algorithm. A line search or trust region is recommended.
 
     Args:
@@ -226,7 +226,7 @@ class LBFGS(Transform):
             sy_tol=sy_tol,
             damping = damping,
         )
-        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
+        super().__init__(defaults, inner=inner, update_freq=update_freq)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -249,7 +249,7 @@ class LBFGS(Transform):
        self.global_state.pop('step', None)
 
    @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
@@ -311,14 +311,14 @@
             y_history.append(y)
             sy_history.append(sy)
 
-    def get_H(self, var=...):
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         sy_history = self.global_state['sy_history']
         return LBFGSLinearOperator(s_history, y_history, sy_history)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']
 
         tensors = as_tensorlist(tensors)
@@ -4,9 +4,9 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Module, Transform, Var, apply_transform
+from ...core import Chainable, Module, TensorTransform, Objective, step
 from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
-from ...utils.linalg.linear_operator import LinearOperator
+from ...linalg.linear_operator import LinearOperator
 from ..functional import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
@@ -76,7 +76,7 @@ class LSR1LinearOperator(LinearOperator):
         return (n, n)
 
 
-class LSR1(Transform):
+class LSR1(TensorTransform):
     """Limited-memory SR1 algorithm. A line search or trust region is recommended.
 
     Args:
@@ -146,7 +146,7 @@ class LSR1(Transform):
             gtol_restart=gtol_restart,
             damping = damping,
         )
-        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
+        super().__init__(defaults, inner=inner, update_freq=update_freq)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -167,7 +167,7 @@ class LSR1(Transform):
         self.global_state.pop('step', None)
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         p = as_tensorlist(params)
         g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
@@ -225,13 +225,13 @@ class LSR1(Transform):
             s_history.append(s)
             y_history.append(y)
 
-    def get_H(self, var=...):
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         return LSR1LinearOperator(s_history, y_history)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']
 
         tensors = as_tensorlist(tensors)
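
The LBFGS and LSR1 hunks above show the other recurring rename in 0.4.0: ``Transform`` subclasses become ``TensorTransform``, the ``update_tensors`` / ``apply_tensors`` pair becomes ``multi_tensor_update`` / ``multi_tensor_apply``, and ``uses_grad=False`` disappears from ``super().__init__``. A hypothetical minimal transform patterned on those hunks, shown only as a sketch of the new naming, would look roughly like this:

    import torch
    from torchzero.core import TensorTransform

    class GradNormLogger(TensorTransform):
        """Hypothetical transform: records the global norm of the incoming tensors."""
        def __init__(self):
            super().__init__()  # no defaults, mirroring GradientCorrection further down

        @torch.no_grad
        def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
            # update internal state from the incoming tensors
            self.global_state['norm'] = torch.cat([t.reshape(-1) for t in tensors]).norm()

        @torch.no_grad
        def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
            # produce the transformed tensors; here a plain pass-through
            return tensors
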
@@ -5,9 +5,9 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module, TensorwiseTransform, Transform
+from ...core import Chainable, Module, TensorTransform, Transform
 from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
-from ...utils.linalg import linear_operator
+from ...linalg import linear_operator
 from ..functional import initial_step_size, safe_clip
 
 
@@ -17,7 +17,7 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)
 
-class HessianUpdateStrategy(TensorwiseTransform, ABC):
+class HessianUpdateStrategy(TensorTransform, ABC):
     """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
 
     This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
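
``HessianUpdateStrategy`` keeps its role here, only its base class changes to ``TensorTransform``: as its docstring says, a concrete method overrides ``update_H`` and/or ``update_B``. The toy subclass below is a hypothetical sketch using the ``update_H`` signature visible in the ShorR hunk further down; the SR1-style formula is a standard textbook update, and treating ``H``, ``s`` and ``y`` as dense tensors is my assumption, not the package's code.

    import torch
    from torchzero.modules.quasi_newton.quasi_newton import HessianUpdateStrategy

    class ToySR1(HessianUpdateStrategy):
        """Hypothetical quasi-newton method: plain SR1 update with a safeguard."""

        def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
            # SR1-style update: H <- H + v v^T / (v^T y)  with  v = s - H y
            v = s - H @ y
            denom = v.dot(y)
            if denom.abs() < 1e-12:   # skip the update when it is numerically unsafe
                return H
            return H + torch.outer(v, v) / denom
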
@@ -157,7 +157,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         else: P *= init_scale
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
         p = param.view(-1); g = tensor.view(-1)
         inverse = setting['inverse']
         M_key = 'H' if inverse else 'B'
@@ -223,7 +223,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         state['f_prev'] = loss
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         step = state['step']
 
         if setting['scale_first'] and step == 1:
@@ -250,8 +250,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             self.global_state.clear()
             return tensor.mul_(initial_step_size(tensor))
 
-    def get_H(self, var):
-        param = var.params[0]
+    def get_H(self, objective):
+        param = objective.params[0]
         state = self.state[param]
         settings = self.settings[param]
         if "B" in state:
@@ -1005,7 +1005,7 @@ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
     return g - (y * (s.dot(g) / sy))
 
 
-class GradientCorrection(Transform):
+class GradientCorrection(TensorTransform):
     """
     Estimates gradient at minima along search direction assuming function is quadratic.
 
@@ -1027,9 +1027,9 @@ class GradientCorrection(Transform):
 
     """
     def __init__(self):
-        super().__init__(None, uses_grad=False)
+        super().__init__()
 
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         if 'p_prev' not in states[0]:
             p_prev = unpack_states(states, tensors, 'p_prev', init=params)
             g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
@@ -1182,16 +1182,19 @@ class ShorR(HessianUpdateStrategy):
     """Shor’s r-algorithm.
 
     Note:
-        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
-        Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
-        so setting ``a_init`` in the line search is recommended.
+        - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling, so setting ``a_init`` in the line search is recommended.
+
+        - The line search should try to overstep by a little, therefore it can help to multiply direction given by a line search by some value slightly larger than 1 such as 1.2.
 
     References:
-        Shor, N. Z. (1985) Minimization Methods for Non-differentiable Functions. New York: Springer.
+        Those are the original references, but neither seem to be available online:
+        - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.
+
+        - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.
 
-        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
+        An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).
 
-        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
+        Reference by Skokov, V. A. describes a more efficient formula which can be found here [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.](https://camo.ici.ro/books/thesis/th.pdf)
     """
 
     def __init__(
@@ -1229,3 +1232,9 @@
 
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         return shor_r_(H=H, y=y, alpha=setting['alpha'])
+
+
+# Todd, Michael J. "The symmetric rank-one quasi-Newton method is a space-dilation subgradient algorithm." Operations research letters 5.5 (1986): 217-219.
+# TODO
+
+# Sorensen, D. C. "The q-superlinear convergence of a collinear scaling algorithm for unconstrained optimization." SIAM Journal on Numerical Analysis 17.1 (1980): 84-114.
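
Finally, a usage sketch for the line-search advice in the new ``ShorR`` docstring above. Everything outside ``tz.m.ShorR``, ``tz.m.StrongWolfe`` and ``tz.m.Mul`` is an assumption on my part: the ``tz.Modular`` wrapper name and the ``closure(backward=True)`` convention do not appear anywhere in this diff, so check the package documentation before copying this.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(                      # assumed entry point for chaining modules
        model.parameters(),
        tz.m.ShorR(),
        tz.m.StrongWolfe(a_init="quadratic", fallback=True),  # from the docstring note
        tz.m.Mul(1.2),                     # slight overstep, as the note suggests
    )

    X, y = torch.randn(64, 10), torch.randn(64, 1)

    def closure(backward=True):            # closure convention assumed
        loss = (model(X) - y).square().mean()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(20):
        opt.step(closure)
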