torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,442 +1,336 @@
  from abc import ABC, abstractmethod
- from collections.abc import Iterable, Mapping, Sequence
- from typing import Any, Literal, final
+ from collections.abc import Mapping, Sequence
+ from operator import itemgetter
+ from typing import Any, final, cast, TYPE_CHECKING

  import torch

- from ..utils import TensorList, set_storage_, vec_to_tensors
- from .chain import Chain
- from .module import Chainable, Module
- from .var import Var
+ from .module import Module
+ from ..utils import vec_to_tensors, safe_dict_update_

- Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
+ if TYPE_CHECKING:
+     from .chain import Chainable
+     from .objective import Objective


- class Transform(Module, ABC):
-     """Base class for a transform.
-     This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.
+ class Transform(Module):
+     """``Transform`` is a ``Module`` with only optional children.

-     A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
-     It has two methods:
+     ``Transform`` if more flexible in that as long as there are no children, it can use a custom list of states
+     and settings instead of ``self.state`` and ``self.setting``.

-     - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
-       It may be called multiple times before ``apply_tensors``.
-     - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+     To use, subclass this and override ``update_states`` and ``apply_states``.
+     """
+     def __init__(self, defaults: dict[str, Any] | None = None, update_freq: int = 1, inner: "Chainable | None" = None):
+
+         # store update_freq in defaults so that it is scheduleable
+         if defaults is None: defaults = {}
+         safe_dict_update_(defaults, {"__update_freq": update_freq})
+
+         super().__init__(defaults)
+
+         self._objective = None
+         if inner is not None:
+             self.set_child("inner", inner)
+
+     # settings shouldn't mutate, so they are typed as Sequence[Mapping]
+     def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
+         """Updates ``states``. This should not modify ``objective.update``."""
+
+     @abstractmethod
+     def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
+         """Updates ``objective`` using ``states``."""
+
+     def _get_states_settings(self, objective: "Objective") -> tuple[list, tuple]:
+         # itemgetter is faster
+         # but need to make sure it returns a tuple, as if there is a single param, it returns the value
+         getter = itemgetter(*objective.params)
+         is_single = len(objective.params) == 1
+         states = getter(self.state)
+         settings = getter(self.settings)

-     Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+         if is_single:
+             states = [states, ]
+             settings = (settings, )

-     Transform can be applied to tensors corresponding to custom parameters
-     by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
-     parameters will be keys to store per-parameter states, so they should remain the same python objects.
+         else:
+             states = list(states) # itemgetter returns tuple
+
+         return states, settings
+
+     @final
+     def update(self, objective:"Objective"):
+         step = self.increment_counter("__step", 0)
+
+         if step % self.settings[objective.params[0]]["__update_freq"] == 0:
+             states, settings = self._get_states_settings(objective)
+             self.update_states(objective=objective, states=states, settings=settings)
+
+     @final
+     def apply(self, objective: "Objective"):

-     Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
-     ``transform_update`` and ``transform_apply``.
+         # inner step
+         if "inner" in self.children:
+             inner = self.children["inner"]
+             objective = inner.step(objective)

-     A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.
+         # apply and return
+         states, settings = self._get_states_settings(objective)
+         return self.apply_states(objective=objective, states=states, settings=settings)

-     Args:
-         defaults (dict[str,Any] | None): dict with default values.
-         uses_grad (bool):
-             Set this to True if `transform` method uses the `grad` argument. This will ensure
-             `grad` is always computed and can't be None. Otherwise set to False.
-         target (Target, optional):
-             what to set on var. Defaults to 'update'.

+
+ class TensorTransform(Transform):
+     """``TensorTransform`` is a ``Transform`` that doesn't use ``Objective``, instead it operates
+     on lists of tensors directly.
+
+     This has a ``concat_params`` setting which is used in quite a few modules, for example it is optional
+     in all full-matrix method like Quasi-Newton or full-matrix Adagrad.
+
+     To use, subclass this and override one of ``single_tensor_update`` or ``multi_tensor_update``,
+     and one of ``single_tensor_apply`` or ``multi_tensor_apply``.
+
+     For copying:
+
+     multi tensor:
+     ```
+     def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
+         ...
+     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+         ...
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         ...
+     ```
+
+     single tensor:
+
+     ```
+     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+         ...
+     def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+         ...
+     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+         ...
+     ```
      """
      def __init__(
          self,
-         defaults: dict[str,Any] | None,
+         defaults: dict[str, Any] | None = None,
+         update_freq: int = 1,
+         concat_params: bool = False,
          uses_grad: bool = False,
          uses_loss: bool = False,
-         concat_params: bool = False,
-         update_freq: int = 1,
-         inner: Chainable | None = None,
-         target: Target = 'update',
+         inner: "Chainable | None" = None,
      ):
-         super().__init__(defaults)
-         self._target: Target = target
+         super().__init__(defaults, update_freq=update_freq, inner=inner)
+
+         self._concat_params = concat_params
          self._uses_grad = uses_grad
          self._uses_loss = uses_loss
-         self._concat_params = concat_params
-         self._update_freq = update_freq
-         self._inner = inner
-         self._var = None

-     def update_tensors(
+     # ------------------------------- single tensor ------------------------------ #
+     def single_tensor_initialize(
          self,
-         tensors: list[torch.Tensor],
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
-         states: list[dict[str, Any]],
-         settings: Sequence[Mapping[str, Any]],
+         tensor: torch.Tensor,
+         param: torch.Tensor,
+         grad: torch.Tensor | None,
+         loss: torch.Tensor | None,
+         state: dict[str, Any],
+         setting: Mapping[str, Any],
      ) -> None:
-         """update function, this shouldn't be called directly. Updates this module."""
+         """initialize ``state`` before first ``update``.
+         """

-     @abstractmethod
-     def apply_tensors(
+     def single_tensor_update(
+         self,
+         tensor: torch.Tensor,
+         param: torch.Tensor,
+         grad: torch.Tensor | None,
+         loss: torch.Tensor | None,
+         state: dict[str, Any],
+         setting: Mapping[str, Any],
+     ) -> None:
+         """Updates ``state``. This should not modify ``tensor``.
+         """
+
+     def single_tensor_apply(
+         self,
+         tensor: torch.Tensor,
+         param: torch.Tensor,
+         grad: torch.Tensor | None,
+         loss: torch.Tensor | None,
+         state: dict[str, Any],
+         setting: Mapping[str, Any],
+     ) -> torch.Tensor:
+         """Updates ``tensor`` and returns it. This shouldn't modify ``state`` if possible.
+         """
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `single_tensor_apply`.")
+
+     # ------------------------------- multi tensor ------------------------------- #
+     def multi_tensor_initialize(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
+         loss: torch.Tensor | None,
          states: list[dict[str, Any]],
          settings: Sequence[Mapping[str, Any]],
-     ) -> Sequence[torch.Tensor]:
-         """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
-         If possible, this shouldn't modify the internal state of this transform."""
+     ) -> None:
+         """initialize ``states`` before first ``update``.
+         By default calls ``single_tensor_initialize`` on all tensors.
+         """
+         if grads is None:
+             grads = cast(list, [None] * len(tensors))

-     @final
-     @torch.no_grad
-     def update_transform(
+         for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+             self.single_tensor_initialize(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
+
+     def multi_tensor_update(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
+         loss: torch.Tensor | None,
          states: list[dict[str, Any]],
-         settings: Sequence[Mapping[str, Any]] | None,
+         settings: Sequence[Mapping[str, Any]],
      ) -> None:
-         """Updates this transform from an arbitrary sequence of tensors."""
-         if self._concat_params:
-             tensors = [torch.cat([t.ravel() for t in tensors])]
-             params = [torch.cat([p.ravel() for p in params])]
-             grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-         if settings is None:
-             settings = [self.defaults for _ in tensors]
-
-         step = self.global_state.get('__step', 0) # that way it gets reset correctly
-         self.global_state['__step'] = step + 1
-
-         num = len(tensors)
-         states = states[:num]
-         settings = settings[:num]
+         """Updates ``states``. This should not modify ``tensor``.
+         By default calls ``single_tensor_update`` on all tensors.
+         """

-         # update transform
-         if step % self._update_freq == 0:
-             self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+         if grads is None:
+             grads = cast(list, [None] * len(tensors))

-         # store for transform_apply
-         self.global_state["__tensors"] = tensors
-         self.global_state["__params"] = params
-         self.global_state["__grads"] = grads
+         for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+             self.single_tensor_update(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

-
-     @final
-     @torch.no_grad
-     def apply_transform(
+     def multi_tensor_apply(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
+         loss: torch.Tensor | None,
          states: list[dict[str, Any]],
-         settings: Sequence[Mapping[str, Any]] | None,
-     ) -> list[torch.Tensor]:
-         """Applies this transform to an arbitrary sequence of tensors.
-         This can be used after ``transform_update`` has been used at least once."""
-
-         if settings is None:
-             settings = [self.defaults for _ in tensors]
+         settings: Sequence[Mapping[str, Any]],
+     ) -> Sequence[torch.Tensor]:
+         """Updates ``tensors`` and returns it. This shouldn't modify ``state`` if possible.
+         By default calls ``single_tensor_apply`` on all tensors.
+         """

-         num = len(tensors)
-         states = states[:num]
-         settings = settings[:num]
+         if grads is None:
+             grads = cast(list, [None] * len(tensors))

-         un_tensors = tensors
-         un_params = params
-         un_grads = grads
+         ret = []
+         for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+             u = self.single_tensor_apply(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
+             ret.append(u)

-         tensors = self.global_state.pop("__tensors")
-         params = self.global_state.pop("__params")
-         grads = self.global_state.pop("__grads")
+         return ret

-         # step with inner
-         if self._inner is not None:
-             tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
-             if self._concat_params:
-                 tensors = [torch.cat([t.ravel() for t in tensors])]
+     def _get_grads_loss(self, objective: "Objective"):
+         """evaluates grads and loss only if needed"""

-         # apply transform
-         tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+         if self._uses_grad: grads = objective.get_grads()
+         else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used

-         if self._concat_params:
-             tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+         if self._uses_loss: loss = objective.get_loss(backward=True)
+         else: loss = None

-         return tensors
+         return grads, loss

-     def _get_keyed_states_settings(self, params: list[torch.Tensor]):
-         if self._concat_params:
-             p = params[0]
-             states = [self.state[p]]
-             settings = [self.settings[p]]
+     @torch.no_grad
+     def _get_cat_updates_params_grads(self, objective: "Objective", grads: list[torch.Tensor] | None):
+         assert self._concat_params

-         else:
-             states = []
-             settings = []
-             for p in params:
-                 states.append(self.state[p])
-                 settings.append(self.settings[p])
+         cat_updates = [torch.cat([u.ravel() for u in objective.get_updates()])]
+         cat_params = [torch.cat([p.ravel() for p in objective.params])]

-         return states, settings
+         if grads is None: cat_grads = None
+         else: cat_grads = [torch.cat([g.ravel() for g in grads])]

-     @final
-     @torch.no_grad
-     def keyed_transform_update(
-         self,
-         tensors: list[torch.Tensor],
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
-     ):
-         """`params` will be used as keys and need to always point to same tensor objects.`"""
-         states, settings = self._get_keyed_states_settings(params)
-         self.update_transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+         return cat_updates, cat_params, cat_grads

+     def _gather_tensors(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
+         """returns everything for ``multi_tensor_*``. Concatenates if ```self._concat_params``.
+         evaluates grads and loss if ``self._uses_grad`` and ``self._uses_loss``"""

-     @final
-     @torch.no_grad
-     def keyed_transform_apply(
-         self,
-         tensors: list[torch.Tensor],
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | float | None,
-     ):
-         """`params` will be used as keys and need to always point to same tensor objects.`"""
-         states, settings = self._get_keyed_states_settings(params)
-         return self.apply_transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+         # evaluate grads and loss if `self._uses_grad` and `self._uses_loss`
+         grads, loss = self._get_grads_loss(objective)

+         # gather all things
+         # concatenate everything to a vec if `self._concat_params`
+         if self._concat_params:
+             tensors, params, grads = self._get_cat_updates_params_grads(objective, grads)
+             states = [states[0]]; settings = [settings[0]]

-     def pre_step(self, var: Var) -> None:
-         """Logic to run pre-transform, this way transform has access to Var."""
-     def post_step(self, var: Var) -> None:
-         """Logic to run post-transform, this way transform has access to Var."""
+         # or take original values
+         else:
+             tensors=objective.get_updates()
+             params = objective.params

-     def update(self, var: Var):
-         if self._target != 'update':
-             raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
-                              f"With {self._target = } only `step` method can be used.")
+         return tensors, params, grads, loss, states, settings

-         # var may change, therefore current params and grads have to be extracted and passed explicitly
-         update = var.get_update() # this sets loss
-         if self._uses_grad: var.get_grad()
-         if self._uses_loss: var.get_loss(False)
-         params=var.params
-         self.pre_step(var)
+     @final
+     def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
+         tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
+
+         # initialize before the first update
+         num_updates = self.increment_counter("__num_updates", 0)
+         if num_updates == 0:
+             self.multi_tensor_initialize(
+                 tensors=tensors,
+                 params=params,
+                 grads=grads,
+                 loss=loss,
+                 states=states,
+                 settings=settings
+             )

          # update
-         self._var = var
-         self.keyed_transform_update(update, params, var.grad, var.loss)
-         self._var = None
-
-     def apply(self, var: Var):
-         if self._target != 'update':
-             raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
-                              f"With {self._target = } only `step` method can be used.")
+         self.multi_tensor_update(
+             tensors=tensors,
+             params=params,
+             grads=grads,
+             loss=loss,
+             states=states,
+             settings=settings
+         )

-         # var may change, therefore current params and grads have to be extracted and passed explicitly
-         update = var.get_update() # this sets loss
-         if self._uses_grad: var.get_grad()
-         if self._uses_loss: var.get_loss(False)
-         params=var.params
+     @final
+     def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
+         tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
+         # note: _gather tensors will re-cat again if `_concat_params`, this is necessary because objective
+         # may have been modified in functional logic, there is no way to know if that happened

          # apply
-         self._var = var
-         var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
-         self._var = None
-
-         self.post_step(var)
-         return var
-
-     def step(self, var: Var) -> Var:
-
-         # var may change, therefore current params and grads have to be extracted and passed explicitly
-         if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
-         if self._uses_grad or self._target == 'grad': var.get_grad()
-         if self._uses_loss: var.get_loss(False)
-         params=var.params
-         self.pre_step(var)
-         self._var = var
-
-         # ---------------------------------- update ---------------------------------- #
-         if self._target == 'update':
-             update = var.get_update()
-             self.keyed_transform_update(update, params, var.grad, var.loss)
-             var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
-             self._var = None
-             return var
-
-         # ----------------------------------- grad ----------------------------------- #
-         if self._target == 'grad':
-             grad = var.get_grad()
-             self.keyed_transform_update(grad, params, grad, var.loss)
-             var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
-             self._var = None
-             return var
-
-         # ------------------------------- params_direct ------------------------------ #
-         if self._target == 'params_direct':
-             self.keyed_transform_update(var.params, params, var.grad, var.loss)
-             new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
-             for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
-             self._var = None
-             return var
-
-         # ----------------------------- params_differnce ----------------------------- #
-         if self._target == 'params_difference':
-             p_clone = [p.clone() for p in var.params]
-             self.keyed_transform_update(p_clone, params, var.grad, var.loss)
-             new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
-             var.update = list(torch._foreach_sub(var.params, new_params))
-             self._var = None
-             return var
-
-         # ----------------------------- update_difference ---------------------------- #
-         if self._target == 'update_difference':
-             update = var.get_update()
-             u_clone = [u.clone() for u in update]
-             self.keyed_transform_update(u_clone, params, var.grad, var.loss)
-             new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
-             var.update = list(torch._foreach_sub(update, new_update))
-             self._var = None
-             return var
-
-         # ---------------------------------- closure --------------------------------- #
-         if self._target == 'closure':
-             original_closure = var.closure
-             if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-             params = var.params
-             parent_var = self._var
-             def transformed_closure(backward=True):
-                 if backward:
-                     loss = original_closure()
-                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
-                     self._var = parent_var
-                     self.keyed_transform_update(current_grad, params, var.grad, var.loss)
-                     transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
-                     self._var = None
-
-                     for p, g in zip(params, transformed_grad):
-                         p.grad = g
-
-                 else:
-                     loss = original_closure(False)
-
-                 return loss
-
-             var.closure = transformed_closure
-             self.post_step(var)
-             self._var = None
-             return var
-
-         # ---------------------------------- invalid --------------------------------- #
-         raise ValueError(f'Invalid target: {self._target}')
-
-
- class TensorwiseTransform(Transform, ABC):
-     """Base class for a parameter-wise transform.
-
-     This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.
-
-     Args:
-         defaults (dict[str,Any] | None): dict with default values.
-         uses_grad (bool):
-             Set this to True if `transform` method uses the `grad` argument. This will ensure
-             `grad` is always computed and can't be None. Otherwise set to False.
-         target (Target, optional):
-             what to set on var. Defaults to 'update'.
-     """
-     def __init__(
-         self,
-         defaults: dict[str,Any] | None,
-         uses_grad: bool = False,
-         uses_loss: bool = False,
-         concat_params: bool = False,
-         update_freq: int = 1,
-         inner: Chainable | None = None,
-         target: Target = 'update',
-     ):
-         super().__init__(
-             defaults=defaults,
-             uses_grad=uses_grad,
-             concat_params=concat_params,
-             update_freq=update_freq,
-             uses_loss=uses_loss,
-             inner=inner,
-             target=target,
+         ret = self.multi_tensor_apply(
+             tensors=tensors,
+             params=params,
+             grads=grads,
+             loss=loss,
+             states=states,
+             settings=settings
          )

-     def update_tensor(
-         self,
-         tensor: torch.Tensor,
-         param: torch.Tensor,
-         grad: torch.Tensor | None,
-         loss: torch.Tensor | float | None,
-         state: dict[str, Any],
-         setting: Mapping[str, Any],
-     ) -> None:
-         """Updates this transform. By default does nothing - if logic is in `apply` method."""
+         # uncat if needed and set objective.updates and return objective
+         if self._concat_params:
+             objective.updates = vec_to_tensors(ret[0], objective.params)

-     @abstractmethod
-     def apply_tensor(
-         self,
-         tensor: torch.Tensor,
-         param: torch.Tensor,
-         grad: torch.Tensor | None,
-         loss: torch.Tensor | float | None,
-         state: dict[str, Any],
-         setting: Mapping[str, Any],
-     ) -> torch.Tensor:
-         """Applies the update rule to `tensor`."""
+         else:
+             objective.updates = list(ret)

-     @final
-     def update_tensors(self, tensors, params, grads, loss, states, settings):
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             self.update_tensor(t, p, g, loss, state, setting)
+         return objective

-     @final
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         applied = []
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             applied.append(self.apply_tensor(t, p, g, loss, state, setting))
-         return applied
-
- def apply_transform(
-     tfm: Chainable,
-     tensors: list[torch.Tensor],
-     params: list[torch.Tensor],
-     grads: list[torch.Tensor] | None,
-     loss: torch.Tensor | float | None = None,
-     var: Var | None = None,
-     current_step: int = 0,
- ):
-     if var is None:
-         var = Var(params=params, closure=None, model=None, current_step=current_step)
-         var.loss = loss
-
-     if isinstance(tfm, Transform) and tfm._target == 'update':
-         if tfm._uses_grad and grads is None: grads = var.get_grad()
-         tfm.keyed_transform_update(tensors, params, grads, loss)
-         return list(tfm.keyed_transform_apply(tensors, params, grads, loss))
-
-     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
-     if isinstance(tfm, Sequence):
-         for module in tfm:
-             tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
-         return tensors
-
-     if isinstance(tfm, Module):
-         cvar = var.clone(clone_update=False)
-         cvar.update = tensors
-         cvar = tfm.step(cvar)
-         var.update_attrs_from_clone_(cvar)
-         assert cvar.update is not None
-         return cvar.update
-
-     raise TypeError(type(tfm))
+
+     # make sure _concat_params, _uses_grad and _uses_loss are saved in `state_dict`
+     def _extra_pack(self):
+         return {
+             "__concat_params": self._concat_params,
+             "__uses_grad": self._uses_grad,
+             "__uses_loss": self._uses_loss,
+         }
+
+     def _extra_unpack(self, d):
+         self._concat_params = d["__concat_params"]
+         self._uses_grad = d["__uses_grad"]
+         self._uses_loss = d["__uses_loss"]
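
For orientation, the sketch below shows how the two new base classes in this diff might be subclassed in 0.4.1. It is written only against the hook signatures and ``Objective`` accessors visible above (``get_updates()``, ``updates``, ``params``, per-parameter ``state``/``setting`` dicts). The ``GlobalNormClip`` and ``UpdateEMA`` classes and their update rules are illustrative assumptions made for this note, not modules shipped with torchzero.

```python
# Illustrative sketches only: written against the hook signatures shown in the
# diff of torchzero/core/transform.py above; these are not torchzero modules.
import torch

from torchzero.core.transform import Transform, TensorTransform


class GlobalNormClip(Transform):
    """Scales the whole update so its global norm does not exceed ``max_norm``."""

    def __init__(self, max_norm: float = 1.0):
        # values placed in ``defaults`` become per-parameter settings
        super().__init__(defaults={"max_norm": max_norm})

    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        max_norm = settings[0]["max_norm"]
        total_norm = torch.sqrt(sum(u.pow(2).sum() for u in updates))
        if total_norm > max_norm:
            scale = max_norm / total_norm
            objective.updates = [u * scale for u in updates]
        return objective


class UpdateEMA(TensorTransform):
    """Replaces each update tensor with an exponential moving average of past updates."""

    def __init__(self, beta: float = 0.9, update_freq: int = 1):
        super().__init__(defaults={"beta": beta}, update_freq=update_freq)

    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        # called once before the first update: allocate the per-parameter buffer
        state["ema"] = torch.zeros_like(tensor)

    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        # update internal state only; the incoming tensor is left untouched
        state["ema"].mul_(setting["beta"]).add_(tensor, alpha=1 - setting["beta"])

    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        # produce the transformed update without mutating state
        return state["ema"].clone()
```

Under the removed 0.3.15 API the same logic would have gone into ``Transform.apply_tensors`` or ``TensorwiseTransform.apply_tensor``, optionally rerouted with the ``target=`` argument; in 0.4.1 the ``target`` mechanism is gone and transforms always operate on ``objective.get_updates()`` (or on the tensors handed to the ``*_tensor_*`` hooks).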