torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
tests/test_vars.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import pytest
2
2
  import torch
3
- from torchzero.core.module import Vars
3
+ from torchzero.core.module import Var
4
4
  from torchzero.utils.tensorlist import TensorList
5
5
 
6
6
  @torch.no_grad
7
- def test_vars_get_loss():
7
+ def test_var_get_loss():
8
8
 
9
9
  # ---------------------------- test that it works ---------------------------- #
10
10
  params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
26
26
  assert not loss.requires_grad, "loss requires grad with backward=False"
27
27
  return loss
28
28
 
29
- vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
29
+ var = Var(params=params, closure=closure_1, model=None, current_step=0)
30
30
 
31
- assert vars.loss is None, vars.loss
31
+ assert var.loss is None, var.loss
32
32
 
33
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
33
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
34
34
  assert evaluated, evaluated
35
- assert loss is vars.loss
36
- assert vars.loss == 4.0
37
- assert vars.loss_approx == 4.0
38
- assert vars.grad is None, vars.grad
35
+ assert loss is var.loss
36
+ assert var.loss == 4.0
37
+ assert var.loss_approx == 4.0
38
+ assert var.grad is None, var.grad
39
39
 
40
40
  # reevaluate, which should just return already evaluated loss
41
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
42
- assert vars.grad is None, vars.grad
41
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
42
+ assert var.grad is None, var.grad
43
43
 
44
44
 
45
45
  # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@ def test_vars_get_loss():
61
61
  assert not loss.requires_grad, "loss requires grad with backward=False"
62
62
  return loss
63
63
 
64
- vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
65
- assert vars.grad is None, vars.grad
66
- assert (loss := vars.get_loss(backward=True)) == 6.0, loss
67
- assert vars.grad is not None
68
- assert vars.grad[0] == 2.0, vars.grad
64
+ var = Var(params=params, closure=closure_2, model=None, current_step=0)
65
+ assert var.grad is None, var.grad
66
+ assert (loss := var.get_loss(backward=True)) == 6.0, loss
67
+ assert var.grad is not None
68
+ assert var.grad[0] == 2.0, var.grad
69
69
 
70
70
  # reevaluate, which should just return already evaluated loss
71
- assert (loss := vars.get_loss(backward=True)) == 6.0, loss
72
- assert vars.grad[0] == 2.0, vars.grad
71
+ assert (loss := var.get_loss(backward=True)) == 6.0, loss
72
+ assert var.grad[0] == 2.0, var.grad
73
73
 
74
74
  # get grad, which should just return already evaluated grad
75
- assert (grad := vars.get_grad())[0] == 2.0, grad
76
- assert grad is vars.grad, grad
75
+ assert (grad := var.get_grad())[0] == 2.0, grad
76
+ assert grad is var.grad, grad
77
77
 
78
78
  # get update, which should create and return cloned grad
79
- assert vars.update is None
80
- assert (update := vars.get_update())[0] == 2.0, update
81
- assert update is vars.update
82
- assert update is not vars.grad
83
- assert vars.grad is not None
84
- assert update[0] == vars.grad[0]
79
+ assert var.update is None
80
+ assert (update := var.get_update())[0] == 2.0, update
81
+ assert update is var.update
82
+ assert update is not var.grad
83
+ assert var.grad is not None
84
+ assert update[0] == var.grad[0]
85
85
 
86
86
  @torch.no_grad
87
- def test_vars_get_grad():
87
+ def test_var_get_grad():
88
88
  params = [torch.tensor(2.0, requires_grad=True)]
89
89
  evaluated = False
90
90
 
@@ -103,20 +103,20 @@ def test_vars_get_grad():
103
103
  assert not loss.requires_grad, "loss requires grad with backward=False"
104
104
  return loss
105
105
 
106
- vars = Vars(params=params, closure=closure, model=None, current_step=0)
107
- assert (grad := vars.get_grad())[0] == 4.0, grad
108
- assert grad is vars.grad
106
+ var = Var(params=params, closure=closure, model=None, current_step=0)
107
+ assert (grad := var.get_grad())[0] == 4.0, grad
108
+ assert grad is var.grad
109
109
 
110
- assert vars.loss == 4.0
111
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
112
- assert (loss := vars.get_loss(backward=True)) == 4.0, loss
113
- assert vars.loss_approx == 4.0
110
+ assert var.loss == 4.0
111
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
112
+ assert (loss := var.get_loss(backward=True)) == 4.0, loss
113
+ assert var.loss_approx == 4.0
114
114
 
115
- assert vars.update is None, vars.update
116
- assert (update := vars.get_update())[0] == 4.0, update
115
+ assert var.update is None, var.update
116
+ assert (update := var.get_update())[0] == 4.0, update
117
117
 
118
118
  @torch.no_grad
119
- def test_vars_get_update():
119
+ def test_var_get_update():
120
120
  params = [torch.tensor(2.0, requires_grad=True)]
121
121
  evaluated = False
122
122
 
@@ -135,27 +135,28 @@ def test_vars_get_update():
135
135
  assert not loss.requires_grad, "loss requires grad with backward=False"
136
136
  return loss
137
137
 
138
- vars = Vars(params=params, closure=closure, model=None, current_step=0)
139
- assert vars.update is None, vars.update
140
- assert (update := vars.get_update())[0] == 4.0, update
141
- assert update is vars.update
138
+ var = Var(params=params, closure=closure, model=None, current_step=0)
139
+ assert var.update is None, var.update
140
+ assert (update := var.get_update())[0] == 4.0, update
141
+ assert update is var.update
142
142
 
143
- assert (grad := vars.get_grad())[0] == 4.0, grad
144
- assert grad is vars.grad
143
+ assert (grad := var.get_grad())[0] == 4.0, grad
144
+ assert grad is var.grad
145
145
  assert grad is not update
146
146
 
147
- assert vars.loss == 4.0
148
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
149
- assert (loss := vars.get_loss(backward=True)) == 4.0, loss
150
- assert vars.loss_approx == 4.0
147
+ assert var.loss == 4.0
148
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
149
+ assert (loss := var.get_loss(backward=True)) == 4.0, loss
150
+ assert var.loss_approx == 4.0
151
151
 
152
- assert (update := vars.get_update())[0] == 4.0, update
152
+ assert (update := var.get_update())[0] == 4.0, update
153
153
 
154
154
 
155
- def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
155
+ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
156
156
  for k,v in v1.__dict__.items():
157
157
  if not k.startswith('__'):
158
158
  # if k == 'post_step_hooks': continue
159
+ if k == 'storage': continue
159
160
  if k == 'update' and clone_update:
160
161
  if v1.update is None or v2.update is None:
161
162
  assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
@@ -165,20 +166,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
165
166
  else:
166
167
  assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
167
168
 
168
- def test_vars_clone():
169
+ def test_var_clone():
169
170
  model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
170
171
  def closure(backward): return 1
171
- vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
172
+ var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
172
173
 
173
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
174
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
174
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
175
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
175
176
 
176
- vars.grad = TensorList(torch.randn(5))
177
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
178
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
177
+ var.grad = TensorList(torch.randn(5))
178
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
179
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
179
180
 
180
- vars.update = TensorList(torch.randn(5) * 2)
181
- vars.loss = torch.randn(1)
182
- vars.loss_approx = vars.loss
183
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
184
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
181
+ var.update = TensorList(torch.randn(5) * 2)
182
+ var.loss = torch.randn(1)
183
+ var.loss_approx = var.loss
184
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
185
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
@@ -1,3 +1,2 @@
1
- from .module import Vars, Module, Modular, Chain, maybe_chain, Chainable
2
- from .transform import Transform, TensorwiseTransform, Target, apply
3
- from .preconditioner import Preconditioner, TensorwisePreconditioner
1
+ from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
2
+ from .transform import Transform, TensorwiseTransform, Target, apply_transform
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
3
3
  from collections import ChainMap, defaultdict
4
4
  from collections.abc import Callable, Iterable, MutableMapping, Sequence
5
5
  from operator import itemgetter
6
- from typing import Any, final, overload
6
+ from typing import Any, final, overload, Literal
7
7
 
8
8
  import torch
9
9
 
@@ -14,6 +14,7 @@ from ..utils import (
14
14
  _make_param_groups,
15
15
  get_state_vals,
16
16
  )
17
+ from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
17
18
  from ..utils.python_tools import flatten
18
19
 
19
20
 
@@ -29,8 +30,8 @@ def _closure_backward(closure, params, retain_graph, create_graph):
29
30
  return loss
30
31
 
31
32
  # region Vars
32
- # ----------------------------------- vars ----------------------------------- #
33
- class Vars:
33
+ # ----------------------------------- var ----------------------------------- #
34
+ class Var:
34
35
  """
35
36
  Holds the state and context passed between optimizer modules during a step.
36
37
 
@@ -74,13 +75,13 @@ class Vars:
74
75
  """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
75
76
  whereas some other modules require loss strictly at current point."""
76
77
 
77
- self.post_step_hooks: list[Callable[[Modular, Vars]]] = []
78
+ self.post_step_hooks: list[Callable[[Modular, Var]]] = []
78
79
  """list of functions to be called after optimizer step.
79
80
  The signature is:
80
81
 
81
82
  .. code:: py
82
83
 
83
- def hook(optimizer: Modular, vars: Vars): ...
84
+ def hook(optimizer: Modular, var: Var): ...
84
85
 
85
86
  """
86
87
 
@@ -109,8 +110,11 @@ class Vars:
109
110
  self.skip_update: bool = False
110
111
  """if True, the parameters will not be updated"""
111
112
 
113
+ self.storage: dict = {}
114
+ """Storage for any other data, such as hessian estimates, etc"""
115
+
112
116
  def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
113
- """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`vars.loss`.
117
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
114
118
  Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
115
119
 
116
120
  if self.loss is None:
@@ -143,7 +147,7 @@ class Vars:
143
147
 
144
148
  def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
145
149
  """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
146
- :code:`vars.grad` and potentially :code:`vars.loss`. Do not call this at perturbed parameters."""
150
+ :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
147
151
  if self.grad is None:
148
152
  if self.closure is None: raise RuntimeError("closure is None")
149
153
  self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +156,15 @@ class Vars:
152
156
  return self.grad
153
157
 
154
158
  def get_update(self) -> list[torch.Tensor]:
155
- """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`vars.update`.
156
- Computing the gradients may assign :code:`vars.grad` and :code:`vars.loss` if they haven't been computed.
159
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
160
+ Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
157
161
  Do not call this at perturbed parameters."""
158
162
  if self.update is None: self.update = [g.clone() for g in self.get_grad()]
159
163
  return self.update
160
164
 
161
165
  def clone(self, clone_update: bool):
162
166
  """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
163
- copy = Vars(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
167
+ copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
164
168
 
165
169
  if clone_update and self.update is not None:
166
170
  copy.update = [u.clone() for u in self.update]
@@ -176,16 +180,17 @@ class Vars:
176
180
 
177
181
  return copy
178
182
 
179
- def update_attrs_from_clone_(self, vars: "Vars"):
183
+ def update_attrs_from_clone_(self, var: "Var"):
180
184
  """Updates attributes of this `Vars` instance from a cloned instance.
181
185
  Typically called after a child module has processed a cloned `Vars`
182
186
  object. This propagates any newly computed loss or gradient values
183
187
  from the child's context back to the parent `Vars` if the parent
184
188
  didn't have them computed already.
185
189
  """
186
- if self.loss is None: self.loss = vars.loss
187
- if self.loss_approx is None: self.loss_approx = vars.loss_approx
188
- if self.grad is None: self.grad = vars.grad
190
+ if self.loss is None: self.loss = var.loss
191
+ if self.loss_approx is None: self.loss_approx = var.loss_approx
192
+ if self.grad is None: self.grad = var.grad
193
+ self.storage.update(var.storage)
189
194
 
190
195
  def zero_grad(self, set_to_none=True):
191
196
  if set_to_none:
@@ -269,36 +274,36 @@ class Module(ABC):
269
274
  return s
270
275
 
271
276
  @overload
272
- def get_settings(self, key: str, *,
273
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike: ...
277
+ def get_settings(self, params: Sequence[torch.Tensor], key: str, *,
278
+ cls: type[ListLike] = list) -> ListLike: ...
274
279
  @overload
275
- def get_settings(self, key: list[str] | tuple[str,...], *,
276
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
280
+ def get_settings(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
281
+ cls: type[ListLike] = list) -> list[ListLike]: ...
277
282
  @overload
278
- def get_settings(self, key: str, key2: str, *keys: str,
279
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
283
+ def get_settings(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
284
+ cls: type[ListLike] = list) -> list[ListLike]: ...
280
285
 
281
- def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
282
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike | list[ListLike]:
286
+ def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
287
+ *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
283
288
  # if isinstance(params, Vars): params = params.params
284
289
  return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
285
290
 
286
291
 
287
292
  @overload
288
- def get_state(self, key: str, *,
289
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init = torch.zeros_like,
293
+ def get_state(self, params: Sequence[torch.Tensor], key: str, *,
294
+ must_exist: bool = False, init: Init = torch.zeros_like,
290
295
  cls: type[ListLike] = list) -> ListLike: ...
291
296
  @overload
292
- def get_state(self, key: list[str] | tuple[str,...], *,
293
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
297
+ def get_state(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
298
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
294
299
  cls: type[ListLike] = list) -> list[ListLike]: ...
295
300
  @overload
296
- def get_state(self, key: str, key2: str, *keys: str,
297
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
301
+ def get_state(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
302
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
298
303
  cls: type[ListLike] = list) -> list[ListLike]: ...
299
304
 
300
- def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
301
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
305
+ def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
306
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
302
307
  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
303
308
  """Returns values of per-parameter state for a given key.
304
309
  If key doesn't exist, create it with inits.
@@ -358,6 +363,26 @@ class Module(ABC):
358
363
  # # if isinstance(params, Vars): params = params.params
359
364
  # return itemgetter(*keys)(self.settings[params[0]])
360
365
 
366
+ def clear_state_keys(self, *keys:str):
367
+ for s in self.state.values():
368
+ for k in keys:
369
+ if k in s: del s[k]
370
+
371
+ @overload
372
+ def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
373
+ @overload
374
+ def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
375
+ def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
376
+ if isinstance(keys, str):
377
+ for p,v in zip(params, values):
378
+ state = self.state[p]
379
+ state[keys] = v
380
+ return
381
+
382
+ for p, *p_v in zip(params, *values):
383
+ state = self.state[p]
384
+ for k,v in zip(keys, p_v): state[k] = v
385
+
361
386
  def state_dict(self):
362
387
  """state dict"""
363
388
  packed_state = {id(k):v for k,v in self.state.items()}
@@ -403,23 +428,111 @@ class Module(ABC):
403
428
  self._extra_unpack(state_dict['extra'])
404
429
 
405
430
  # ---------------------------- OVERRIDABLE METHODS --------------------------- #
406
- @abstractmethod
407
- def step(self, vars: Vars) -> Vars:
408
- """performs a step, returns new vars but may update them in-place."""
431
+ def step(self, var: Var) -> Var:
432
+ """performs a step, returns new var but may update it in-place."""
433
+ self.update(var)
434
+ return self.apply(var)
435
+
436
+ def update(self, var:Var) -> Any:
437
+ """Updates the internal state of this module. This should not modify `var.update`.
438
+
439
+ Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
440
+ such as :code:`tz.m.Online`.
441
+ """
442
+
443
+ def apply(self, var: Var) -> Var:
444
+ """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
445
+ raise NotImplementedError(f"{self} doesn't implement the `apply` method.")
409
446
 
410
447
  def reset(self):
411
- """Resets the internal state of the module (e.g. momentum)."""
448
+ """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
412
449
  # no complex logic is allowed there because this is overridden by many modules
413
450
  # where super().reset() shouldn't be called
414
451
  self.state.clear()
415
452
  self.global_state.clear()
416
453
 
454
+ def reset_for_online(self):
455
+ """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
456
+ for c in self.children.values(): c.reset_for_online()
457
+
417
458
  def _extra_pack(self):
418
459
  return {}
419
460
 
420
461
  def _extra_unpack(self, x):
421
462
  pass
422
463
 
464
+
465
+ # ------------------------------ HELPER METHODS ------------------------------ #
466
+ @torch.no_grad
467
+ def Hvp(
468
+ self,
469
+ v: Sequence[torch.Tensor],
470
+ at_x0: bool,
471
+ var: Var,
472
+ rgrad: Sequence[torch.Tensor] | None,
473
+ hvp_method: Literal['autograd', 'forward', 'central'],
474
+ h: float,
475
+ normalize: bool,
476
+ retain_grad: bool,
477
+ ):
478
+ """
479
+ Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to ``var`` automatically if ``at_x0``; you can always access it with ``var.get_grad()``.
480
+
481
+ Single sample example:
482
+
483
+ .. code:: py
484
+
485
+ Hvp, _ = self.Hvp(v, at_x0=True, rgrad=None, ..., retain_grad=False)
486
+
487
+ Multiple samples example:
488
+
489
+ .. code:: py
490
+
491
+ D = None
492
+ rgrad = None
493
+ for i in range(n_samples):
494
+ v = [torch.randn_like(p) for p in params]
495
+ Hvp, rgrad = self.Hvp(v, at_x0=True, rgrad=rgrad, ..., retain_grad=i < n_samples-1)
496
+
497
+ if D is None: D = Hvp
498
+ else: torch._foreach_add_(D, Hvp)
499
+
500
+ if n_samples > 1: torch._foreach_div_(D, n_samples)
501
+ Args:
502
+ v (Sequence[torch.Tensor]): vector in hessian-vector product
503
+ at_x0 (bool): whether this is being called at original or perturbed parameters.
504
+ var (Var): Var
505
+ rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
506
+ hvp_method (str): hvp method.
507
+ h (float): finite difference step size
508
+ normalize (bool): whether to normalize v for finite difference
509
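+ retain_grad (bool): whether to retain the autograd graph after computing the product (forwarded as ``retain_graph`` when ``hvp_method="autograd"``).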
510
+ """
511
+ # get grad
512
+ if rgrad is None and hvp_method in ('autograd', 'forward'):
513
+ if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
514
+ else:
515
+ if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
516
+ with torch.enable_grad():
517
+ loss = var.closure()
518
+ rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
519
+
520
+ if hvp_method == 'autograd':
521
+ assert rgrad is not None
522
+ Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
523
+
524
+ elif hvp_method == 'forward':
525
+ assert rgrad is not None
526
+ loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
527
+
528
+ elif hvp_method == 'central':
529
+ loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
530
+
531
+ else:
532
+ raise ValueError(hvp_method)
533
+
534
+ return Hvp, rgrad
535
+
423
536
  # endregion
424
537
 
425
538
  Chainable = Module | Sequence[Module]
@@ -440,6 +553,21 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
440
553
 
441
554
  # region Modular
442
555
  # ---------------------------------- Modular --------------------------------- #
556
+
557
class _EvalCounterClosure:
    """Callable wrapper that counts how often the wrapped closure is evaluated.

    Each call bumps ``modular.num_evaluations`` by one and then forwards all
    positional and keyword arguments to the wrapped closure, returning its
    result unchanged.

    Raises:
        RuntimeError: on call, if the wrapped closure is ``None``. The counter
            is not incremented in that case.
    """
    __slots__ = ("modular", "closure")

    def __init__(self, modular: "Modular", closure):
        self.modular, self.closure = modular, closure

    def __call__(self, *args, **kwargs):
        fn = self.closure
        if fn is None:
            raise RuntimeError("One of the modules requires closure to be passed to the step method")

        # count the evaluation before dispatching to the closure
        self.modular.num_evaluations += 1
        return fn(*args, **kwargs)
570
+
443
571
  # have to inherit from torch.optim.Optimizer to support lr schedulers
444
572
  # although Accelerate doesn't work due to converting param_groups to a dict
445
573
  class Modular(torch.optim.Optimizer):
@@ -496,7 +624,10 @@ class Modular(torch.optim.Optimizer):
496
624
  # self.add_param_group(param_group)
497
625
 
498
626
  self.current_step = 0
499
- """The global step counter for the optimizer."""
627
+ """global step counter for the optimizer."""
628
+
629
+ self.num_evaluations = 0
630
+ """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
500
631
 
501
632
  def add_param_group(self, param_group: dict[str, Any]):
502
633
  proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -556,13 +687,14 @@ class Modular(torch.optim.Optimizer):
556
687
  if not p.requires_grad: continue
557
688
  for map in self._per_parameter_global_settings[p]: map.update(settings)
558
689
 
559
- # create vars
690
+ # create var
560
691
  params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
561
- vars = Vars(params=params, closure=closure, model=self.model, current_step=self.current_step)
692
+ var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
562
693
 
563
694
  # if closure is None, assume backward has been called and gather grads
564
695
  if closure is None:
565
- vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
696
+ var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
697
+ self.num_evaluations += 1
566
698
 
567
699
  last_module = self.modules[-1]
568
700
  last_lr = last_module.defaults.get('lr', None)
@@ -570,27 +702,27 @@ class Modular(torch.optim.Optimizer):
570
702
 
571
703
  # step
572
704
  for i, module in enumerate(self.modules):
573
- if i!=0: vars = vars.clone(clone_update=False)
705
+ if i!=0: var = var.clone(clone_update=False)
574
706
 
575
707
  # last module, or next to last module before lr
576
708
  if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
577
- if module.children: vars.nested_is_last = True
578
- else: vars.is_last = True
579
- if last_lr is not None: vars.last_module_lrs = last_module.get_settings('lr', params=vars.params)
709
+ if module.children: var.nested_is_last = True
710
+ else: var.is_last = True
711
+ if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
580
712
 
581
- vars = module.step(vars)
582
- if vars.stop: break
713
+ var = module.step(var)
714
+ if var.stop: break
583
715
 
584
716
  # apply update
585
- if not vars.skip_update:
717
+ if not var.skip_update:
586
718
  with torch.no_grad():
587
- torch._foreach_sub_(params, vars.get_update())
719
+ torch._foreach_sub_(params, var.get_update())
588
720
 
589
- for hook in vars.post_step_hooks:
590
- hook(self, vars)
721
+ for hook in var.post_step_hooks:
722
+ hook(self, var)
591
723
 
592
724
  self.current_step += 1
593
- return vars.loss if vars.loss is not None else vars.loss_approx
725
+ return var.loss if var.loss is not None else var.loss_approx
594
726
 
595
727
def __repr__(self):
    """Return ``Modular(...)`` with the ``str`` of each module, comma-separated."""
    joined = ", ".join(map(str, self.modules))
    return f'Modular({joined})'
@@ -606,11 +738,11 @@ class Chain(Module):
606
738
  for i, module in enumerate(flat_modules):
607
739
  self.set_child(f'module_{i}', module)
608
740
 
609
def step(self, var):
    """Pass ``var`` through the children ``module_0``, ``module_1``, ... in order.

    Each child's ``step`` receives the var returned by the previous child.
    Iteration ends early as soon as a returned var has a truthy ``stop``.

    Returns:
        The var returned by the last child that was stepped, or the input
        ``var`` unchanged when there are no children.
    """
    n_children = len(self.children)
    index = 0
    while index < n_children:
        var = self.children[f'module_{index}'].step(var)
        if var.stop:
            return var
        index += 1
    return var
614
746
 
615
747
  def __repr__(self):
616
748
  s = self.__class__.__name__