torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
  from collections import ChainMap, defaultdict
  from collections.abc import Callable, Iterable, MutableMapping, Sequence
  from operator import itemgetter
- from typing import Any, final, overload, Literal
+ from typing import Any, final, overload, Literal, cast

  import torch

@@ -16,6 +16,7 @@ from ..utils import (
  )
  from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
  from ..utils.python_tools import flatten
+ from ..utils.linalg.linear_operator import LinearOperator


  def _closure_backward(closure, params, retain_graph, create_graph):
@@ -33,11 +34,9 @@ def _closure_backward(closure, params, retain_graph, create_graph):
  # ----------------------------------- var ----------------------------------- #
  class Var:
  """
- Holds the state and context passed between optimizer modules during a step.
+ Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
+ Modules take in a ``Var`` object, modify and it is passed to the next module.

- This class acts as a mutable container for information relevant to the current
- optimization step, such as parameters, gradients, loss, and the computed update.
- Modules read from and write to this object to coordinate their actions.
  """
  def __init__(
  self,
@@ -45,6 +44,10 @@ class Var:
  closure: Callable | None,
  model: torch.nn.Module | None,
  current_step: int,
+ parent: "Var | None" = None,
+ modular: "Modular | None" = None,
+ loss: torch.Tensor | None = None,
+ storage: dict | None = None,
  ):
  self.params: list[torch.Tensor] = params
  """List of all parameters with requires_grad = True."""
@@ -56,19 +59,31 @@ class Var:
  """torch.nn.Module object of the model, None if it wasn't specified."""

  self.current_step: int = current_step
- """global current step, starts at 0"""
+ """global current step, starts at 0. This may not correspond to module current step,
+ for example a module may step every 10 global steps."""
+
+ self.parent: "Var | None" = parent
+ """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
+ Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
+ e.g. when projecting."""
+
+ self.modular: "Modular" = cast(Modular, modular)
+ """Modular optimizer object that created this ``Var``."""

  self.update: list[torch.Tensor] | None = None
  """
- current update, at the end this is subtracted from model parameters unless it is None.
+ current update. Update is assumed to be a transformed gradient, therefore it is subtracted.

  If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
+
+ At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
+ gradient will be used and calculated if needed.
  """

  self.grad: list[torch.Tensor] | None = None
- """gradient with current parameters. If closure is not None, this is set to None and can be calculated if needed."""
+ """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""

- self.loss: torch.Tensor | Any | None = None
+ self.loss: torch.Tensor | Any | None = loss
  """loss with current parameters."""

  self.loss_approx: torch.Tensor | Any | None = None
@@ -77,24 +92,28 @@ class Var:

  self.post_step_hooks: list[Callable[[Modular, Var]]] = []
  """list of functions to be called after optimizer step.
- The signature is:

- .. code:: py
+ This attribute should always be modified in-place (using ``append`` or ``extend``).

- def hook(optimizer: Modular, var: Vars): ...
+ The signature is:

+ ```python
+ def hook(optimizer: Modular, var: Vars): ...
+ ```
  """

  self.is_last: bool = False
  """
  Indicates that current module is either last or next-to-last before a learning rate module.
  This is always False if current module has children or is a child.
+ This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
  """

  self.nested_is_last: bool = False
  """
  Indicates that current module is either last or next-to-last before a learning rate module, for modules
- that have children.
+ that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
+ a child of a last module may also receive ``var.nested_is_last=True``.
  """

  self.last_module_lrs: list[float] | None = None
@@ -105,19 +124,30 @@ class Var:
  """

  self.stop: bool = False
- """if True, all following modules will be skipped."""
+ """if True, all following modules will be skipped.
+ If this module is a child, it only affects modules at the same level (in the same Chain)."""

  self.skip_update: bool = False
- """if True, the parameters will not be updated"""
+ """if True, the parameters will not be updated."""

- self.storage: dict = {}
- """Storage for any other data, such as hessian estimates, etc"""
+ # self.storage: dict = {}
+ # """Storage for any other data, such as hessian estimates, etc."""

- def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
- """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
- Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
+ self.attrs: dict = {}
+ """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
+
+ if storage is None: storage = {}
+ self.storage: dict = storage
+ """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""

+ self.should_terminate: bool | None = None
+ """termination criteria, Modular.should_terminate is set to this after each step if not None"""
+
+ def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
+ Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
  if self.loss is None:
+
  if self.closure is None: raise RuntimeError("closure is None")
  if backward:
  with torch.enable_grad():
@@ -128,7 +158,10 @@ class Var:
  # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
  # it is technically a more correct approach for when some parameters conditionally receive gradients
  # and in this case it shouldn't be slower.
- self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+ # next time closure() is called, it will set grad to None.
+ # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
  else:
  self.loss = self.loss_approx = self.closure(False)

@@ -143,11 +176,24 @@ class Var:
  closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
  )
  self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+ # set parent grad
+ if self.parent is not None:
+ # the way projections/split work, they make a new closure which evaluates original
+ # closure and projects the gradient, and set it as their var.closure.
+ # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
+ # and we set it to parent var here.
+ if self.parent.loss is None: self.parent.loss = self.loss
+ if self.parent.grad is None and backward:
+ if all(p.grad is None for p in self.parent.params):
+ warnings.warn("Parent grad is None after backward.")
+ self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
+
  return self.loss # type:ignore

  def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
  """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
- :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
+ ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
  if self.grad is None:
  if self.closure is None: raise RuntimeError("closure is None")
  self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -156,15 +202,21 @@ class Var:
  return self.grad

  def get_update(self) -> list[torch.Tensor]:
- """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
- Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
+ Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
  Do not call this at perturbed parameters."""
  if self.update is None: self.update = [g.clone() for g in self.get_grad()]
  return self.update

- def clone(self, clone_update: bool):
- """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
- copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
+ def clone(self, clone_update: bool, parent: "Var | None" = None):
+ """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
+
+ Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
+
+ Setting ``parent`` is only if clone's parameters are something different,
+ while clone's closure referes to the same objective but with a "view" on parameters.
+ """
+ copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)

  if clone_update and self.update is not None:
  copy.update = [u.clone() for u in self.update]
@@ -174,10 +226,16 @@ class Var:
  copy.grad = self.grad
  copy.loss = self.loss
  copy.loss_approx = self.loss_approx
+ copy.closure = self.closure
  copy.post_step_hooks = self.post_step_hooks
  copy.stop = self.stop
  copy.skip_update = self.skip_update

+ copy.modular = self.modular
+ copy.attrs = self.attrs
+ copy.storage = self.storage
+ copy.should_terminate = self.should_terminate
+
  return copy

  def update_attrs_from_clone_(self, var: "Var"):
@@ -186,11 +244,15 @@ class Var:
  object. This propagates any newly computed loss or gradient values
  from the child's context back to the parent `Vars` if the parent
  didn't have them computed already.
+
+ Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
+ if the child updates them, the update will affect the parent too.
  """
  if self.loss is None: self.loss = var.loss
  if self.loss_approx is None: self.loss_approx = var.loss_approx
  if self.grad is None: self.grad = var.grad
- self.storage.update(var.storage)
+
+ if var.should_terminate is not None: self.should_terminate = var.should_terminate

  def zero_grad(self, set_to_none=True):
  if set_to_none:
@@ -201,6 +263,7 @@ class Var:

  # endregion

+
  # region Module
  # ---------------------------------- module ---------------------------------- #
  class Module(ABC):
@@ -313,17 +376,16 @@ class Module(ABC):

  If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

- .. code:: py
-
- exp_avg = self.state_vals("exp_avg")
- # returns cls (by default TensorList)
+ ```python
+ exp_avg = self.state_vals("exp_avg")
+ # returns cls (by default TensorList)

- exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
- # returns list of cls
-
- exp_avg = self.state_vals(["exp_avg"])
- # always returns a list of cls, even if got a single key
+ exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
+ # returns list of cls

+ exp_avg = self.state_vals(["exp_avg"])
+ # always returns a list of cls, even if got a single key
+ ```

  Args:
  *keys (str):
@@ -402,7 +464,8 @@ class Module(ABC):
  }
  return state_dict

- def load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+ def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+ """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
  # load state
  state = state_dict['state']
  self.state.clear()
@@ -421,7 +484,7 @@ class Module(ABC):

  # children
  for k, v in state_dict['children']:
- if k in self.children: self.children[k].load_state_dict(v, id_to_tensor)
+ if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
  else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')

  # extra info
@@ -429,37 +492,68 @@ class Module(ABC):

  # ---------------------------- OVERRIDABLE METHODS --------------------------- #
  def step(self, var: Var) -> Var:
- """performs a step, returns new var but may update it in-place."""
+ """performs a step, returns new ``var`` but may update it in-place."""
  self.update(var)
  return self.apply(var)

  def update(self, var:Var) -> Any:
- """Updates the internal state of this module. This should not modify `var.update`.
+ """Updates the internal state of this module. This should not modify ``var.update``.

  Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
- such as ::code::`tz.m.Online`.
+ such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
  """

  def apply(self, var: Var) -> Var:
- """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
- raise NotImplementedError(f"{self} doesn't implement the `apply` method.")
+ """Applies this module to ``var.get_update()``.
+ This should not modify the internal state of this module if possible.
+
+ Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+ such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+ """
+ return self.step(var)
+
+ def get_H(self, var: Var) -> LinearOperator | None:
+ """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
+ The hessian approximation is assumed to be for all parameters concatenated to a vector."""
+ # if this method is not defined it searches in children
+ # this should be overwritten to return None if child params are different from this modules params
+ H = None
+ for k,v in self.children.items():
+ H_v = v.get_H(var)
+
+ if (H is not None) and (H_v is not None):
+ raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
+
+ if H_v is not None: H = H_v
+
+ return H

  def reset(self):
- """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
- # no complex logic is allowed there because this is overridden by many modules
- # where super().reset() shouldn't be called
+ """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
  self.state.clear()
  self.global_state.clear()
+ for c in self.children.values(): c.reset()

  def reset_for_online(self):
- """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
+ """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
+ which may become inaccurate due to mini-batching.
+
+ ``Online`` module calls ``reset_for_online``,
+ then it calls ``update`` with previous parameters,
+ then it calls ``update`` with current parameters,
+ and then ``apply``.
+ """
  for c in self.children.values(): c.reset_for_online()

  def _extra_pack(self):
+ """extra information to store in state_dict of this optimizer.
+ Will be passed to ``_extra_unpack`` when loading the state_dict."""
  return {}

  def _extra_unpack(self, x):
- pass
+ """``_extra_pack`` return will be passed to this method when loading state_dict.
+ This method is called after loading the rest of the state dict"""
+


  # ------------------------------ HELPER METHODS ------------------------------ #
@@ -474,30 +568,33 @@ class Module(ABC):
  h: float,
  normalize: bool,
  retain_grad: bool,
- ):
+ ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
  """
- Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+ Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
+ possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
+ Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``

  Single sample example:

- .. code:: py
-
- Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+ ```python
+ Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+ ```

  Multiple samples example:

- .. code:: py
+ ```python
+ D = None
+ rgrad = None
+ for i in range(n_samples):
+ v = [torch.randn_like(p) for p in params]
+ Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)

- D = None
- rgrad = None
- for i in range(n_samples):
- v = [torch.randn_like(p) for p in params]
- Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+ if D is None: D = Hvp
+ else: torch._foreach_add_(D, Hvp)

- if D is None: D = Hvp
- else: torch._foreach_add_(D, Hvp)
+ if n_samples > 1: torch._foreach_div_(D, n_samples)
+ ```

- if n_samples > 1: torch._foreach_div_(D, n_samples)
  Args:
  v (Sequence[torch.Tensor]): vector in hessian-vector product
  at_x0 (bool): whether this is being called at original or perturbed parameters.
@@ -533,6 +630,14 @@ class Module(ABC):

  return Hvp, rgrad

+ def get_generator(self, device: torch.types.Device, seed: int | None):
+ if seed is None: return None
+
+ if 'generator' not in self.global_state:
+ self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
+
+ return self.global_state['generator']
+
  # endregion

  Chainable = Module | Sequence[Module]
@@ -555,7 +660,7 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
  # ---------------------------------- Modular --------------------------------- #

  class _EvalCounterClosure:
- """keeps track of how many times closure has been evaluated"""
+ """keeps track of how many times closure has been evaluated, and sets closure return"""
  __slots__ = ("modular", "closure")
  def __init__(self, modular: "Modular", closure):
  self.modular = modular
@@ -565,8 +670,14 @@ class _EvalCounterClosure:
  if self.closure is None:
  raise RuntimeError("One of the modules requires closure to be passed to the step method")

+ v = self.closure(*args, **kwargs)
+
+ # set closure return on 1st evaluation
+ if self.modular._closure_return is None:
+ self.modular._closure_return = v
+
  self.modular.num_evaluations += 1
- return self.closure(*args, **kwargs)
+ return v

  # have to inherit from Modular to support lr schedulers
  # although Accelerate doesn't work due to converting param_groups to a dict
@@ -584,6 +695,7 @@ class Modular(torch.optim.Optimizer):
  param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

  def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+ if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
  self.model: torch.nn.Module | None = None
  """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
  if isinstance(params, torch.nn.Module):
@@ -617,18 +729,34 @@ class Modular(torch.optim.Optimizer):
  for m in self.unrolled_modules: defaults.update(m.defaults)
  super().__init__(param_groups, defaults=defaults)

- # note - this is what super init does:
+ # note - this is what super().__init__(param_groups, defaults=defaults) does:

  # self.defaults = defaults
  # for param_group in param_groups:
  # self.add_param_group(param_group)

+ # add_param_group adds a ChainMap where defaults are lowest priority,
+ # and entries specifed in param_groups or scheduler are higher priority.
+ # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
+ # in each module, settings passed to that module by calling set_param_groups are highest priority
+
  self.current_step = 0
  """global step counter for the optimizer."""

  self.num_evaluations = 0
  """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""

+ # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+ # we want to return original loss so this attribute is used
+ self._closure_return = None
+ """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
+
+ self.attrs = {}
+ """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+ self.should_terminate = False
+ """is set to True by termination criteria modules."""
+
  def add_param_group(self, param_group: dict[str, Any]):
  proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
  self.param_groups.append(ChainMap(proc_param_group, self.defaults))
@@ -673,10 +801,13 @@ class Modular(torch.optim.Optimizer):

  id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
  for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
- m.load_state_dict(sd, id_to_tensor)
+ m._load_state_dict(sd, id_to_tensor)
+

+ def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+ # clear closure return from previous step
+ self._closure_return = None

- def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
  # propagate global per-parameter setting overrides
  for g in self.param_groups:
  settings = dict(g.maps[0]) # ignore defaults
@@ -689,16 +820,17 @@ class Modular(torch.optim.Optimizer):

  # create var
  params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
- var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
+ var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)

  # if closure is None, assume backward has been called and gather grads
  if closure is None:
  var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
  self.num_evaluations += 1

+ n_modules = len(self.modules)
+ if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
  last_module = self.modules[-1]
  last_lr = last_module.defaults.get('lr', None)
- n_modules = len(self.modules)

  # step
  for i, module in enumerate(self.modules):
@@ -718,11 +850,17 @@ class Modular(torch.optim.Optimizer):
  with torch.no_grad():
  torch._foreach_sub_(params, var.get_update())

+ # update attributes
+ self.attrs.update(var.attrs)
+ if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+ # hooks
  for hook in var.post_step_hooks:
  hook(self, var)

  self.current_step += 1
- return var.loss if var.loss is not None else var.loss_approx
+ #return var.loss if var.loss is not None else var.loss_approx
+ return self._closure_return

  def __repr__(self):
  return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -738,6 +876,21 @@ class Chain(Module):
  for i, module in enumerate(flat_modules):
  self.set_child(f'module_{i}', module)

+ def update(self, var):
+ # note here that `update` and `apply` shouldn't be used directly
+ # as it will update all modules, and then apply all modules
+ # it is used in specific cases like Chain as trust region hessian module
+ for i in range(len(self.children)):
+ self.children[f'module_{i}'].update(var)
+ if var.stop: break
+ return var
+
+ def apply(self, var):
+ for i in range(len(self.children)):
+ var = self.children[f'module_{i}'].apply(var)
+ if var.stop: break
+ return var
+
  def step(self, var):
  for i in range(len(self.children)):
  var = self.children[f'module_{i}'].step(var)
@@ -748,7 +901,7 @@ class Chain(Module):
  s = self.__class__.__name__
  if self.children:
  if s == 'Chain': s = 'C' # to shorten it
- s = f'{s}({", ".join(str(m) for m in self.children.values())}'
+ s = f'{s}({", ".join(str(m) for m in self.children.values())})'
  return s

  def maybe_chain(*modules: Chainable) -> Module:
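
The reworked `Modular.step` above now accepts a precomputed `loss`, forwards any extra keyword arguments into `var.storage`, returns the first closure evaluation of the step via `_closure_return`, and exposes module-set state through `Modular.attrs` and `Modular.should_terminate`. Below is a minimal usage sketch of that surface; the module names under `tz.m` and the `closure(backward=True)` convention are assumptions for illustration, not code taken from this diff.

```python
# Hedged sketch: exercises the Modular.step signature introduced in this diff.
# tz.m.LBFGS / tz.m.Backtracking are assumed module names, used only for illustration.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = tz.Modular(model.parameters(), tz.m.LBFGS(), tz.m.Backtracking())

def closure(backward=True):
    # torchzero-style closure: returns the loss and, when backward=True,
    # zeroes the grads and calls backward itself.
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    loss = opt.step(closure)   # returns the first closure evaluation of this step
    if opt.should_terminate:   # set by termination-criteria modules, if any are present
        break
```

Extra keyword arguments passed to `step` would end up in `var.storage` for modules that read it, per the `storage=kwargs` line above.
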
torchzero/core/reformulation.py ADDED
@@ -0,0 +1,65 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+
+ import torch
+
+ from .module import Chainable, Modular, Module, Var
+
+
+ class Reformulation(Module, ABC):
+ def __init__(self, defaults: dict | None, modules: Chainable | None):
+ super().__init__(defaults)
+
+ if modules is not None:
+ self.set_child("modules", modules)
+
+ @abstractmethod
+ def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+ """
+ returns (loss, gradient), if backward is False then gradient can be None.
+
+ If evaluating original loss/gradient at x_0, set them to ``var``.
+ """
+
+ def pre_step(self, var: Var) -> Var | None:
+ """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+ evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+
+ def step(self, var):
+ ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
+ if isinstance(ret, Var): var = ret
+
+ if var.closure is None: raise RuntimeError("Reformulation requires closure")
+ params, closure = var.params, var.closure
+
+ # step with children
+ if 'modules' in self.children:
+
+ # make a reformulated closure
+ def modified_closure(backward=True):
+ loss, grad = self.closure(backward, closure, params, var)
+
+ if grad is not None:
+ for p,g in zip(params, grad):
+ p.grad = g
+
+ return loss
+
+ # set it to a new Var object
+ modified_var = var.clone(clone_update=False)
+ modified_var.closure = modified_closure
+
+ # step with child
+ modules = self.children['modules']
+ modified_var = modules.step(modified_var)
+
+ # modified_var.loss and grad refers to loss and grad of a modified objective
+ # so we only take the update
+ var.update = modified_var.update
+
+ # or just evaluate new closure and set to update
+ else:
+ loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
+ if grad is not None: var.update = list(grad)
+
+ return var
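
The new `Reformulation` base class swaps the objective seen by its child modules: a subclass implements `closure(backward, closure, params, var)` returning `(loss, gradient)`, and `step` wraps that into a modified closure whose gradient is written to `p.grad`. Here is a minimal sketch of a subclass against this interface; the class name, the squared-objective reformulation, and the import path are assumptions for illustration only.

```python
# Hedged sketch of a Reformulation subclass that minimizes f(x)**2 instead of f(x).
# Import path assumed from the file location shown in this diff.
import torch
from torchzero.core.reformulation import Reformulation


class SquaredObjective(Reformulation):
    def __init__(self, modules=None):
        super().__init__(defaults=None, modules=modules)

    def closure(self, backward, closure, params, var):
        if not backward:
            # loss only; per the interface above, the gradient may be None here
            return closure(False) ** 2, None

        with torch.enable_grad():
            loss = closure(False)        # original objective, no backward
            sq_loss = loss ** 2          # reformulated objective
            grads = torch.autograd.grad(sq_loss, params, allow_unused=True)

        grads = [g if g is not None else torch.zeros_like(p) for g, p in zip(grads, params)]
        return sq_loss, grads
```

Used with children, the wrapped modules step on the reformulated objective and only their update is taken; used standalone, the reformulated gradient becomes `var.update` directly (the `else` branch above).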