torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
3
3
  from collections import ChainMap, defaultdict
4
4
  from collections.abc import Callable, Iterable, MutableMapping, Sequence
5
5
  from operator import itemgetter
6
- from typing import Any, final, overload, Literal
6
+ from typing import Any, final, overload, Literal, cast
7
7
 
8
8
  import torch
9
9
 
@@ -16,6 +16,7 @@ from ..utils import (
16
16
  )
17
17
  from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
18
18
  from ..utils.python_tools import flatten
19
+ from ..utils.linalg.linear_operator import LinearOperator
19
20
 
20
21
 
21
22
  def _closure_backward(closure, params, retain_graph, create_graph):
@@ -33,11 +34,9 @@ def _closure_backward(closure, params, retain_graph, create_graph):
33
34
  # ----------------------------------- var ----------------------------------- #
34
35
  class Var:
35
36
  """
36
- Holds the state and context passed between optimizer modules during a step.
37
+ Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
38
+ Modules take in a ``Var`` object, modify it, and pass it to the next module.
37
39
 
38
- This class acts as a mutable container for information relevant to the current
39
- optimization step, such as parameters, gradients, loss, and the computed update.
40
- Modules read from and write to this object to coordinate their actions.
41
40
  """
42
41
  def __init__(
43
42
  self,
@@ -45,6 +44,10 @@ class Var:
45
44
  closure: Callable | None,
46
45
  model: torch.nn.Module | None,
47
46
  current_step: int,
47
+ parent: "Var | None" = None,
48
+ modular: "Modular | None" = None,
49
+ loss: torch.Tensor | None = None,
50
+ storage: dict | None = None,
48
51
  ):
49
52
  self.params: list[torch.Tensor] = params
50
53
  """List of all parameters with requires_grad = True."""
@@ -56,19 +59,31 @@ class Var:
56
59
  """torch.nn.Module object of the model, None if it wasn't specified."""
57
60
 
58
61
  self.current_step: int = current_step
59
- """global current step, starts at 0"""
62
+ """global current step, starts at 0. This may not correspond to module current step,
63
+ for example a module may step every 10 global steps."""
64
+
65
+ self.parent: "Var | None" = parent
66
+ """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
67
+ Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
68
+ e.g. when projecting."""
69
+
70
+ self.modular: "Modular" = cast(Modular, modular)
71
+ """Modular optimizer object that created this ``Var``."""
60
72
 
61
73
  self.update: list[torch.Tensor] | None = None
62
74
  """
63
- current update, at the end this is subtracted from model parameters unless it is None.
75
+ current update. The update is assumed to be a transformed gradient, so it is subtracted from the parameters.
64
76
 
65
77
  If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
78
+
79
+ At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
80
+ the gradient will be used, and calculated if needed.
66
81
  """
67
82
 
68
83
  self.grad: list[torch.Tensor] | None = None
69
- """gradient with current parameters. If closure is not None, this is set to None and can be calculated if needed."""
84
+ """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
70
85
 
71
- self.loss: torch.Tensor | Any | None = None
86
+ self.loss: torch.Tensor | Any | None = loss
72
87
  """loss with current parameters."""
73
88
 
74
89
  self.loss_approx: torch.Tensor | Any | None = None
@@ -77,24 +92,28 @@ class Var:
77
92
 
78
93
  self.post_step_hooks: list[Callable[[Modular, Var]]] = []
79
94
  """list of functions to be called after optimizer step.
80
- The signature is:
81
95
 
82
- .. code:: py
96
+ This attribute should always be modified in-place (using ``append`` or ``extend``).
83
97
 
84
- def hook(optimizer: Modular, var: Vars): ...
98
+ The signature is:
85
99
 
100
+ ```python
101
+ def hook(optimizer: Modular, var: Var): ...
102
+ ```
86
103
  """
87
104
 
88
105
  self.is_last: bool = False
89
106
  """
90
107
  Indicates that current module is either last or next-to-last before a learning rate module.
91
108
  This is always False if current module has children or is a child.
109
+ This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
92
110
  """
93
111
 
94
112
  self.nested_is_last: bool = False
95
113
  """
96
114
  Indicates that current module is either last or next-to-last before a learning rate module, for modules
97
- that have children.
115
+ that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
116
+ a child of a last module may also receive ``var.nested_is_last=True``.
98
117
  """
99
118
 
100
119
  self.last_module_lrs: list[float] | None = None
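To make the hook contract above concrete, here is a minimal sketch of registering a post-step hook; the helper name is illustrative and not part of this diff:

```python
from torchzero.core.module import Modular, Var

def attach_logging_hook(var: Var) -> None:
    """Illustrative helper: registers a hook that runs after the full optimizer step."""
    def hook(optimizer: Modular, v: Var) -> None:
        # called after the parameters have been updated
        print(f"step {optimizer.current_step}: loss={v.loss}")

    # post_step_hooks must be modified in-place (append/extend), as noted above
    var.post_step_hooks.append(hook)
```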
@@ -105,19 +124,30 @@ class Var:
105
124
  """
106
125
 
107
126
  self.stop: bool = False
108
- """if True, all following modules will be skipped."""
127
+ """if True, all following modules will be skipped.
128
+ If this module is a child, it only affects modules at the same level (in the same Chain)."""
109
129
 
110
130
  self.skip_update: bool = False
111
- """if True, the parameters will not be updated"""
131
+ """if True, the parameters will not be updated."""
112
132
 
113
- self.storage: dict = {}
114
- """Storage for any other data, such as hessian estimates, etc"""
133
+ # self.storage: dict = {}
134
+ # """Storage for any other data, such as hessian estimates, etc."""
115
135
 
116
- def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
117
- """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
118
- Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
136
+ self.attrs: dict = {}
137
+ """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
138
+
139
+ if storage is None: storage = {}
140
+ self.storage: dict = storage
141
+ """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
142
+
143
+ self.should_terminate: bool | None = None
144
+ """termination criteria, Modular.should_terminate is set to this after each step if not None"""
119
145
 
146
+ def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
147
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
148
+ Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
120
149
  if self.loss is None:
150
+
121
151
  if self.closure is None: raise RuntimeError("closure is None")
122
152
  if backward:
123
153
  with torch.enable_grad():
@@ -128,7 +158,10 @@ class Var:
128
158
  # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
129
159
  # it is technically a more correct approach for when some parameters conditionally receive gradients
130
160
  # and in this case it shouldn't be slower.
131
- self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
161
+
162
+ # next time closure() is called, it will set grad to None.
163
+ # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
164
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
132
165
  else:
133
166
  self.loss = self.loss_approx = self.closure(False)
134
167
 
@@ -143,11 +176,24 @@ class Var:
143
176
  closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
144
177
  )
145
178
  self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
179
+
180
+ # set parent grad
181
+ if self.parent is not None:
182
+ # the way projections/split work, they make a new closure which evaluates original
183
+ # closure and projects the gradient, and set it as their var.closure.
184
+ # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
185
+ # and we set it to parent var here.
186
+ if self.parent.loss is None: self.parent.loss = self.loss
187
+ if self.parent.grad is None and backward:
188
+ if all(p.grad is None for p in self.parent.params):
189
+ warnings.warn("Parent grad is None after backward.")
190
+ self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
191
+
146
192
  return self.loss # type:ignore
147
193
 
148
194
  def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
149
195
  """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
150
- :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
196
+ ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
151
197
  if self.grad is None:
152
198
  if self.closure is None: raise RuntimeError("closure is None")
153
199
  self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
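For context, ``get_loss``/``get_grad`` assume the closure convention used throughout the package: the closure accepts a ``backward`` flag, and only when it is True does it set grads to ``None`` and call ``backward()``. A hedged sketch (model and data are placeholders):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

def closure(backward: bool = True):
    loss = F.mse_loss(model(X), y)
    if backward:
        # grads are set to None before recomputing, matching the behaviour assumed above
        model.zero_grad(set_to_none=True)
        loss.backward()
    return loss

# var.get_loss(backward=False) evaluates closure(False);
# var.get_grad() triggers a backward pass and fills var.grad (and parent.grad, if a parent is set).
```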
@@ -156,15 +202,21 @@ class Var:
156
202
  return self.grad
157
203
 
158
204
  def get_update(self) -> list[torch.Tensor]:
159
- """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
160
- Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
205
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
206
+ Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
161
207
  Do not call this at perturbed parameters."""
162
208
  if self.update is None: self.update = [g.clone() for g in self.get_grad()]
163
209
  return self.update
164
210
 
165
- def clone(self, clone_update: bool):
166
- """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
167
- copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
211
+ def clone(self, clone_update: bool, parent: "Var | None" = None):
212
+ """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
213
+
214
+ Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
215
+
216
+ Set ``parent`` only if the clone's parameters are different from the parent's,
217
+ while the clone's closure refers to the same objective but with a "view" on the parameters.
218
+ """
219
+ copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
168
220
 
169
221
  if clone_update and self.update is not None:
170
222
  copy.update = [u.clone() for u in self.update]
@@ -174,10 +226,16 @@ class Var:
174
226
  copy.grad = self.grad
175
227
  copy.loss = self.loss
176
228
  copy.loss_approx = self.loss_approx
229
+ copy.closure = self.closure
177
230
  copy.post_step_hooks = self.post_step_hooks
178
231
  copy.stop = self.stop
179
232
  copy.skip_update = self.skip_update
180
233
 
234
+ copy.modular = self.modular
235
+ copy.attrs = self.attrs
236
+ copy.storage = self.storage
237
+ copy.should_terminate = self.should_terminate
238
+
181
239
  return copy
182
240
 
183
241
  def update_attrs_from_clone_(self, var: "Var"):
@@ -186,11 +244,15 @@ class Var:
186
244
  object. This propagates any newly computed loss or gradient values
187
245
  from the child's context back to the parent `Vars` if the parent
188
246
  didn't have them computed already.
247
+
248
+ Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
249
+ if the child updates them, the update will affect the parent too.
189
250
  """
190
251
  if self.loss is None: self.loss = var.loss
191
252
  if self.loss_approx is None: self.loss_approx = var.loss_approx
192
253
  if self.grad is None: self.grad = var.grad
193
- self.storage.update(var.storage)
254
+
255
+ if var.should_terminate is not None: self.should_terminate = var.should_terminate
194
256
 
195
257
  def zero_grad(self, set_to_none=True):
196
258
  if set_to_none:
@@ -201,6 +263,7 @@ class Var:
201
263
 
202
264
  # endregion
203
265
 
266
+
204
267
  # region Module
205
268
  # ---------------------------------- module ---------------------------------- #
206
269
  class Module(ABC):
@@ -313,17 +376,16 @@ class Module(ABC):
313
376
 
314
377
  If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.
315
378
 
316
- .. code:: py
317
-
318
- exp_avg = self.state_vals("exp_avg")
319
- # returns cls (by default TensorList)
320
-
321
- exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
322
- # returns list of cls
379
+ ```python
380
+ exp_avg = self.state_vals("exp_avg")
381
+ # returns cls (by default TensorList)
323
382
 
324
- exp_avg = self.state_vals(["exp_avg"])
325
- # always returns a list of cls, even if got a single key
383
+ exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
384
+ # returns list of cls
326
385
 
386
+ exp_avg = self.state_vals(["exp_avg"])
387
+ # always returns a list of cls, even if got a single key
388
+ ```
327
389
 
328
390
  Args:
329
391
  *keys (str):
@@ -402,7 +464,8 @@ class Module(ABC):
402
464
  }
403
465
  return state_dict
404
466
 
405
- def load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
467
+ def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
468
+ """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
406
469
  # load state
407
470
  state = state_dict['state']
408
471
  self.state.clear()
@@ -421,7 +484,7 @@ class Module(ABC):
421
484
 
422
485
  # children
423
486
  for k, v in state_dict['children']:
424
- if k in self.children: self.children[k].load_state_dict(v, id_to_tensor)
487
+ if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
425
488
  else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')
426
489
 
427
490
  # extra info
@@ -429,37 +492,72 @@ class Module(ABC):
429
492
 
430
493
  # ---------------------------- OVERRIDABLE METHODS --------------------------- #
431
494
  def step(self, var: Var) -> Var:
432
- """performs a step, returns new var but may update it in-place."""
495
+ """performs a step, returns new ``var`` but may update it in-place."""
433
496
  self.update(var)
434
497
  return self.apply(var)
435
498
 
436
499
  def update(self, var:Var) -> Any:
437
- """Updates the internal state of this module. This should not modify `var.update`.
500
+ """Updates the internal state of this module. This should not modify ``var.update``.
438
501
 
439
502
  Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
440
- such as ::code::`tz.m.Online`.
503
+ such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
441
504
  """
442
505
 
443
506
  def apply(self, var: Var) -> Var:
444
- """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
445
- raise NotImplementedError(f"{self} doesn't implement the `apply` method.")
507
+ """Applies this module to ``var.get_update()``.
508
+ This should not modify the internal state of this module if possible.
509
+
510
+ Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
511
+ such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
512
+ """
513
+ return self.step(var)
514
+
515
+ def get_H(self, var: Var) -> LinearOperator | None:
516
+ """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
517
+ The hessian approximation is assumed to be for all parameters concatenated to a vector."""
518
+ # if this method is not overridden, it searches in the children
519
+ # it should be overridden to return None if a child's params are different from this module's params
520
+ H = None
521
+ for k,v in self.children.items():
522
+ H_v = v.get_H(var)
523
+
524
+ if (H is not None) and (H_v is not None):
525
+ raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
526
+
527
+ if H_v is not None: H = H_v
528
+
529
+ return H
446
530
 
447
531
  def reset(self):
448
- """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
449
- # no complex logic is allowed there because this is overridden by many modules
450
- # where super().reset() shouldn't be called
532
+ """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
451
533
  self.state.clear()
534
+
535
+ generator = self.global_state.get("generator", None)
452
536
  self.global_state.clear()
537
+ if generator is not None: self.global_state["generator"] = generator
538
+
539
+ for c in self.children.values(): c.reset()
453
540
 
454
541
  def reset_for_online(self):
455
- """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
542
+ """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
543
+ which may become inaccurate due to mini-batching.
544
+
545
+ ``Online`` module calls ``reset_for_online``,
546
+ then it calls ``update`` with previous parameters,
547
+ then it calls ``update`` with current parameters,
548
+ and then ``apply``.
549
+ """
456
550
  for c in self.children.values(): c.reset_for_online()
457
551
 
458
552
  def _extra_pack(self):
553
+ """extra information to store in state_dict of this optimizer.
554
+ Will be passed to ``_extra_unpack`` when loading the state_dict."""
459
555
  return {}
460
556
 
461
557
  def _extra_unpack(self, x):
462
- pass
558
+ """``_extra_pack`` return will be passed to this method when loading state_dict.
559
+ This method is called after loading the rest of the state dict"""
560
+
463
561
 
464
562
 
465
563
  # ------------------------------ HELPER METHODS ------------------------------ #
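To make the ``update``/``apply`` split described above concrete, a hedged sketch of a custom module that keeps an exponential moving average of the incoming update. The class is illustrative; it assumes ``Module.__init__`` takes a defaults dict (as ``Reformulation`` does) and that ``self.state`` is a per-parameter dict, as used by ``state_vals``:

```python
import torch
from torchzero.core.module import Module, Var

class EMAOfUpdate(Module):
    """Illustrative module: maintains an EMA of the incoming update and applies it."""
    def __init__(self, beta: float = 0.9):
        super().__init__({"beta": beta})  # assumption: defaults are passed as a dict

    def update(self, var: Var):
        # update internal state only; var.update is not modified here
        beta = self.defaults["beta"]
        for p, u in zip(var.params, var.get_update()):
            ema = self.state[p].setdefault("ema", torch.zeros_like(p))
            ema.mul_(beta).add_(u, alpha=1 - beta)

    def apply(self, var: Var) -> Var:
        # read the state and write the transformed update back to var
        var.update = [self.state[p]["ema"].clone() for p in var.params]
        return var
```

With both methods defined, meta-modules such as ``tz.m.Online`` can call ``update`` and ``apply`` separately, while plain usage goes through the default ``step`` (update followed by apply).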
@@ -474,30 +572,33 @@ class Module(ABC):
474
572
  h: float,
475
573
  normalize: bool,
476
574
  retain_grad: bool,
477
- ):
575
+ ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
478
576
  """
479
- Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
577
+ Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
578
+ possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
579
+ The gradient is set on ``var`` automatically if ``at_x0``; you can always access it with ``var.get_grad()``.
480
580
 
481
581
  Single sample example:
482
582
 
483
- .. code:: py
484
-
485
- Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
583
+ ```python
584
+ Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
585
+ ```
486
586
 
487
587
  Multiple samples example:
488
588
 
489
- .. code:: py
589
+ ```python
590
+ D = None
591
+ rgrad = None
592
+ for i in range(n_samples):
593
+ v = [torch.randn_like(p) for p in params]
594
+ Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
490
595
 
491
- D = None
492
- rgrad = None
493
- for i in range(n_samples):
494
- v = [torch.randn_like(p) for p in params]
495
- Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
596
+ if D is None: D = Hvp
597
+ else: torch._foreach_add_(D, Hvp)
496
598
 
497
- if D is None: D = Hvp
498
- else: torch._foreach_add_(D, Hvp)
599
+ if n_samples > 1: torch._foreach_div_(D, n_samples)
600
+ ```
499
601
 
500
- if n_samples > 1: torch._foreach_div_(D, n_samples)
501
602
  Args:
502
603
  v (Sequence[torch.Tensor]): vector in hessian-vector product
503
604
  at_x0 (bool): whether this is being called at original or perturbed parameters.
@@ -533,6 +634,14 @@ class Module(ABC):
533
634
 
534
635
  return Hvp, rgrad
535
636
 
637
+ def get_generator(self, device: torch.types.Device, seed: int | None):
638
+ if seed is None: return None
639
+
640
+ if 'generator' not in self.global_state:
641
+ self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
642
+
643
+ return self.global_state['generator']
644
+
536
645
  # endregion
537
646
 
538
647
  Chainable = Module | Sequence[Module]
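A short sketch of how the new ``get_generator`` helper might be used for reproducible noise inside a module; the module and its settings are hypothetical:

```python
import torch
from torchzero.core.module import Module, Var

class SeededPerturbation(Module):
    """Illustrative module: adds seeded Gaussian noise to the update."""
    def __init__(self, sigma: float = 1e-3, seed: int | None = 0):
        super().__init__({"sigma": sigma, "seed": seed})  # assumption: defaults are passed as a dict

    def apply(self, var: Var) -> Var:
        update = var.get_update()
        # the generator is created once, cached in global_state, and survives reset()
        generator = self.get_generator(update[0].device, self.defaults["seed"])
        for u in update:
            noise = torch.randn(u.shape, device=u.device, dtype=u.dtype, generator=generator)
            u.add_(noise, alpha=self.defaults["sigma"])
        return var
```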
@@ -555,7 +664,7 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
555
664
  # ---------------------------------- Modular --------------------------------- #
556
665
 
557
666
  class _EvalCounterClosure:
558
- """keeps track of how many times closure has been evaluated"""
667
+ """keeps track of how many times closure has been evaluated, and sets closure return"""
559
668
  __slots__ = ("modular", "closure")
560
669
  def __init__(self, modular: "Modular", closure):
561
670
  self.modular = modular
@@ -565,8 +674,14 @@ class _EvalCounterClosure:
565
674
  if self.closure is None:
566
675
  raise RuntimeError("One of the modules requires closure to be passed to the step method")
567
676
 
677
+ v = self.closure(*args, **kwargs)
678
+
679
+ # set closure return on 1st evaluation
680
+ if self.modular._closure_return is None:
681
+ self.modular._closure_return = v
682
+
568
683
  self.modular.num_evaluations += 1
569
- return self.closure(*args, **kwargs)
684
+ return v
570
685
 
571
686
  # have to inherit from Modular to support lr schedulers
572
687
  # although Accelerate doesn't work due to converting param_groups to a dict
@@ -584,6 +699,7 @@ class Modular(torch.optim.Optimizer):
584
699
  param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
585
700
 
586
701
  def __init__(self, params: Params | torch.nn.Module, *modules: Module):
702
+ if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
587
703
  self.model: torch.nn.Module | None = None
588
704
  """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
589
705
  if isinstance(params, torch.nn.Module):
@@ -617,18 +733,34 @@ class Modular(torch.optim.Optimizer):
617
733
  for m in self.unrolled_modules: defaults.update(m.defaults)
618
734
  super().__init__(param_groups, defaults=defaults)
619
735
 
620
- # note - this is what super init does:
736
+ # note - this is what super().__init__(param_groups, defaults=defaults) does:
621
737
 
622
738
  # self.defaults = defaults
623
739
  # for param_group in param_groups:
624
740
  # self.add_param_group(param_group)
625
741
 
742
+ # add_param_group adds a ChainMap where defaults are lowest priority,
743
+ # and entries specified in param_groups or by a scheduler are higher priority.
744
+ # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
745
+ # in each module, settings passed to that module by calling set_param_groups are highest priority
746
+
626
747
  self.current_step = 0
627
748
  """global step counter for the optimizer."""
628
749
 
629
750
  self.num_evaluations = 0
630
751
  """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
631
752
 
753
+ # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
754
+ # we want to return original loss so this attribute is used
755
+ self._closure_return = None
756
+ """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
757
+
758
+ self.attrs = {}
759
+ """custom attributes that can be set by modules, for example EMA of weights or best so far"""
760
+
761
+ self.should_terminate = False
762
+ """is set to True by termination criteria modules."""
763
+
632
764
  def add_param_group(self, param_group: dict[str, Any]):
633
765
  proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
634
766
  self.param_groups.append(ChainMap(proc_param_group, self.defaults))
@@ -673,10 +805,13 @@ class Modular(torch.optim.Optimizer):
673
805
 
674
806
  id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
675
807
  for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
676
- m.load_state_dict(sd, id_to_tensor)
808
+ m._load_state_dict(sd, id_to_tensor)
677
809
 
678
810
 
679
- def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
811
+ def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
812
+ # clear closure return from previous step
813
+ self._closure_return = None
814
+
680
815
  # propagate global per-parameter setting overrides
681
816
  for g in self.param_groups:
682
817
  settings = dict(g.maps[0]) # ignore defaults
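For reference, a hedged usage sketch of the extended ``step(closure, loss=None, **kwargs)`` signature; ``tz.m.Adam`` and ``tz.m.LR`` are placeholders for whatever chain of modules is actually configured:

```python
import torch
import torch.nn.functional as F
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# a Modular optimizer is built from a chain of modules (placeholder modules shown)
opt = tz.Modular(model, tz.m.Adam(), tz.m.LR(1e-2))

def closure(backward: bool = True):
    loss = F.mse_loss(model(X), y)
    if backward:
        model.zero_grad(set_to_none=True)
        loss.backward()
    return loss

for _ in range(100):
    # extra keyword arguments end up in var.storage; the return value is the
    # first closure evaluation of this step (_closure_return)
    loss = opt.step(closure)
```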
@@ -689,16 +824,17 @@ class Modular(torch.optim.Optimizer):
689
824
 
690
825
  # create var
691
826
  params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
692
- var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
827
+ var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
693
828
 
694
829
  # if closure is None, assume backward has been called and gather grads
695
830
  if closure is None:
696
831
  var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
697
832
  self.num_evaluations += 1
698
833
 
834
+ n_modules = len(self.modules)
835
+ if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
699
836
  last_module = self.modules[-1]
700
837
  last_lr = last_module.defaults.get('lr', None)
701
- n_modules = len(self.modules)
702
838
 
703
839
  # step
704
840
  for i, module in enumerate(self.modules):
@@ -718,11 +854,17 @@ class Modular(torch.optim.Optimizer):
718
854
  with torch.no_grad():
719
855
  torch._foreach_sub_(params, var.get_update())
720
856
 
857
+ # update attributes
858
+ self.attrs.update(var.attrs)
859
+ if var.should_terminate is not None: self.should_terminate = var.should_terminate
860
+
861
+ # hooks
721
862
  for hook in var.post_step_hooks:
722
863
  hook(self, var)
723
864
 
724
865
  self.current_step += 1
725
- return var.loss if var.loss is not None else var.loss_approx
866
+ #return var.loss if var.loss is not None else var.loss_approx
867
+ return self._closure_return
726
868
 
727
869
  def __repr__(self):
728
870
  return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -738,6 +880,21 @@ class Chain(Module):
738
880
  for i, module in enumerate(flat_modules):
739
881
  self.set_child(f'module_{i}', module)
740
882
 
883
+ def update(self, var):
884
+ # note here that `update` and `apply` shouldn't be used directly
885
+ # as they would first update all modules and only then apply all modules
886
+ # they are used in specific cases, e.g. when a Chain is used as a trust region hessian module
887
+ for i in range(len(self.children)):
888
+ self.children[f'module_{i}'].update(var)
889
+ if var.stop: break
890
+ return var
891
+
892
+ def apply(self, var):
893
+ for i in range(len(self.children)):
894
+ var = self.children[f'module_{i}'].apply(var)
895
+ if var.stop: break
896
+ return var
897
+
741
898
  def step(self, var):
742
899
  for i in range(len(self.children)):
743
900
  var = self.children[f'module_{i}'].step(var)
@@ -748,7 +905,7 @@ class Chain(Module):
748
905
  s = self.__class__.__name__
749
906
  if self.children:
750
907
  if s == 'Chain': s = 'C' # to shorten it
751
- s = f'{s}({", ".join(str(m) for m in self.children.values())}'
908
+ s = f'{s}({", ".join(str(m) for m in self.children.values())})'
752
909
  return s
753
910
 
754
911
  def maybe_chain(*modules: Chainable) -> Module:
@@ -0,0 +1,65 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Callable, Sequence
3
+
4
+ import torch
5
+
6
+ from .module import Chainable, Modular, Module, Var
7
+
8
+
9
+ class Reformulation(Module, ABC):
10
+ def __init__(self, defaults: dict | None, modules: Chainable | None):
11
+ super().__init__(defaults)
12
+
13
+ if modules is not None:
14
+ self.set_child("modules", modules)
15
+
16
+ @abstractmethod
17
+ def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
18
+ """
19
+ returns ``(loss, gradient)``; if ``backward`` is False then gradient can be None.
20
+
21
+ If evaluating original loss/gradient at x_0, set them to ``var``.
22
+ """
23
+
24
+ def pre_step(self, var: Var) -> Var | None:
25
+ """This runs once before each step, whereas `closure` may run multiple times per step if further modules
26
+ evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
27
+
28
+ def step(self, var):
29
+ ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
30
+ if isinstance(ret, Var): var = ret
31
+
32
+ if var.closure is None: raise RuntimeError("Reformulation requires closure")
33
+ params, closure = var.params, var.closure
34
+
35
+ # step with children
36
+ if 'modules' in self.children:
37
+
38
+ # make a reformulated closure
39
+ def modified_closure(backward=True):
40
+ loss, grad = self.closure(backward, closure, params, var)
41
+
42
+ if grad is not None:
43
+ for p,g in zip(params, grad):
44
+ p.grad = g
45
+
46
+ return loss
47
+
48
+ # set it to a new Var object
49
+ modified_var = var.clone(clone_update=False)
50
+ modified_var.closure = modified_closure
51
+
52
+ # step with child
53
+ modules = self.children['modules']
54
+ modified_var = modules.step(modified_var)
55
+
56
+ # modified_var.loss and grad refer to the loss and grad of the modified objective
57
+ # so we only take the update
58
+ var.update = modified_var.update
59
+
60
+ # or just evaluate new closure and set to update
61
+ else:
62
+ loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
63
+ if grad is not None: var.update = list(grad)
64
+
65
+ return var
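To illustrate the new ``Reformulation`` base class, a minimal hypothetical subclass that optimizes a rescaled objective ``scale * f(x)``; it is only a sketch of the ``closure`` contract, not a module shipped in this release:

```python
import torch
from torchzero.core.reformulation import Reformulation

class ScaledObjective(Reformulation):
    """Hypothetical reformulation: minimizes scale * f(x) instead of f(x)."""
    def __init__(self, scale: float = 0.5, modules=None):
        super().__init__(defaults={"scale": scale}, modules=modules)

    def closure(self, backward, closure, params, var):
        scale = self.defaults["scale"]
        if not backward:
            return closure(False) * scale, None

        # evaluate the original closure with a graph, then rescale loss and gradient
        with torch.enable_grad():
            loss = closure(False)
            grads = torch.autograd.grad(loss, params, allow_unused=True)
        grads = [g * scale if g is not None else torch.zeros_like(p)
                 for g, p in zip(grads, params)]
        return loss * scale, grads
```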