torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,440 +1,336 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, final
+from collections.abc import Mapping, Sequence
+from operator import itemgetter
+from typing import Any, final, cast, TYPE_CHECKING
 
 import torch
 
-from ..utils import TensorList, set_storage_, vec_to_tensors
-from .module import Chain, Chainable, Module, Var
+from .module import Module
+from ..utils import vec_to_tensors, safe_dict_update_
 
-Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
+if TYPE_CHECKING:
+    from .chain import Chainable
+    from .objective import Objective
 
 
-class Transform(Module, ABC):
-    """Base class for a transform.
-    This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.
+class Transform(Module):
+    """``Transform`` is a ``Module`` with only optional children.
 
-    A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
-    It has two methods:
+    ``Transform`` if more flexible in that as long as there are no children, it can use a custom list of states
+    and settings instead of ``self.state`` and ``self.setting``.
 
-    - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
-      It may be called multiple times before ``apply_tensors``.
-    - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+    To use, subclass this and override ``update_states`` and ``apply_states``.
+    """
+    def __init__(self, defaults: dict[str, Any] | None = None, update_freq: int = 1, inner: "Chainable | None" = None):
+
+        # store update_freq in defaults so that it is scheduleable
+        if defaults is None: defaults = {}
+        safe_dict_update_(defaults, {"__update_freq": update_freq})
+
+        super().__init__(defaults)
+
+        self._objective = None
+        if inner is not None:
+            self.set_child("inner", inner)
+
+    # settings shouldn't mutate, so they are typed as Sequence[Mapping]
+    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
+        """Updates ``states``. This should not modify ``objective.update``."""
+
+    @abstractmethod
+    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
+        """Updates ``objective`` using ``states``."""
+
+    def _get_states_settings(self, objective: "Objective") -> tuple[list, tuple]:
+        # itemgetter is faster
+        # but need to make sure it returns a tuple, as if there is a single param, it returns the value
+        getter = itemgetter(*objective.params)
+        is_single = len(objective.params) == 1
+        states = getter(self.state)
+        settings = getter(self.settings)
 
-    Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+        if is_single:
+            states = [states, ]
+            settings = (settings, )
 
-    Transform can be applied to tensors corresponding to custom parameters
-    by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
-    parameters will be keys to store per-parameter states, so they should remain the same python objects.
+        else:
+            states = list(states) # itemgetter returns tuple
+
+        return states, settings
+
+    @final
+    def update(self, objective:"Objective"):
+        step = self.increment_counter("__step", 0)
+
+        if step % self.settings[objective.params[0]]["__update_freq"] == 0:
+            states, settings = self._get_states_settings(objective)
+            self.update_states(objective=objective, states=states, settings=settings)
+
+    @final
+    def apply(self, objective: "Objective"):
 
-    Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
-    ``transform_update`` and ``transform_apply``.
+        # inner step
+        if "inner" in self.children:
+            inner = self.children["inner"]
+            objective = inner.step(objective)
 
-    A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.
+        # apply and return
+        states, settings = self._get_states_settings(objective)
+        return self.apply_states(objective=objective, states=states, settings=settings)
 
-    Args:
-        defaults (dict[str,Any] | None): dict with default values.
-        uses_grad (bool):
-            Set this to True if `transform` method uses the `grad` argument. This will ensure
-            `grad` is always computed and can't be None. Otherwise set to False.
-        target (Target, optional):
-            what to set on var. Defaults to 'update'.
 
+
+class TensorTransform(Transform):
+    """``TensorTransform`` is a ``Transform`` that doesn't use ``Objective``, instead it operates
+    on lists of tensors directly.
+
+    This has a ``concat_params`` setting which is used in quite a few modules, for example it is optional
+    in all full-matrix method like Quasi-Newton or full-matrix Adagrad.
+
+    To use, subclass this and override one of ``single_tensor_update`` or ``multi_tensor_update``,
+    and one of ``single_tensor_apply`` or ``multi_tensor_apply``.
+
+    For copying:
+
+    multi tensor:
+    ```
+    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
+        ...
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        ...
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        ...
+    ```
+
+    single tensor:
+
+    ```
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        ...
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        ...
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        ...
+    ```
     """
     def __init__(
         self,
-        defaults: dict[str,Any] | None,
+        defaults: dict[str, Any] | None = None,
+        update_freq: int = 1,
+        concat_params: bool = False,
         uses_grad: bool = False,
         uses_loss: bool = False,
-        concat_params: bool = False,
-        update_freq: int = 1,
-        inner: Chainable | None = None,
-        target: Target = 'update',
+        inner: "Chainable | None" = None,
     ):
-        super().__init__(defaults)
-        self._target: Target = target
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
+
+        self._concat_params = concat_params
         self._uses_grad = uses_grad
         self._uses_loss = uses_loss
-        self._concat_params = concat_params
-        self._update_freq = update_freq
-        self._inner = inner
-        self._var = None
 
-    def update_tensors(
+    # ------------------------------- single tensor ------------------------------ #
+    def single_tensor_initialize(
         self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
-        states: list[dict[str, Any]],
-        settings: Sequence[Mapping[str, Any]],
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
     ) -> None:
-        """update function, this shouldn't be called directly. Updates this module."""
+        """initialize ``state`` before first ``update``.
+        """
 
-    @abstractmethod
-    def apply_tensors(
+    def single_tensor_update(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
+    ) -> None:
+        """Updates ``state``. This should not modify ``tensor``.
+        """
+
+    def single_tensor_apply(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
+    ) -> torch.Tensor:
+        """Updates ``tensor`` and returns it. This shouldn't modify ``state`` if possible.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `single_tensor_apply`.")
+
+    # ------------------------------- multi tensor ------------------------------- #
+    def multi_tensor_initialize(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
+        loss: torch.Tensor | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
-    ) -> Sequence[torch.Tensor]:
-        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
-        If possible, this shouldn't modify the internal state of this transform."""
+    ) -> None:
+        """initialize ``states`` before first ``update``.
+        By default calls ``single_tensor_initialize`` on all tensors.
+        """
+        if grads is None:
+            grads = cast(list, [None] * len(tensors))
 
-    @final
-    @torch.no_grad
-    def transform_update(
+        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+            self.single_tensor_initialize(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
+
+    def multi_tensor_update(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
+        loss: torch.Tensor | None,
         states: list[dict[str, Any]],
-        settings: Sequence[Mapping[str, Any]] | None,
+        settings: Sequence[Mapping[str, Any]],
     ) -> None:
-        """Updates this transform from an arbitrary sequence of tensors."""
-        if self._concat_params:
-            tensors = [torch.cat([t.ravel() for t in tensors])]
-            params = [torch.cat([p.ravel() for p in params])]
-            grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-        if settings is None:
-            settings = [self.defaults for _ in tensors]
-
-        step = self.global_state.get('__step', 0) # that way it gets reset correctly
-        self.global_state['__step'] = step + 1
-
-        num = len(tensors)
-        states = states[:num]
-        settings = settings[:num]
+        """Updates ``states``. This should not modify ``tensor``.
+        By default calls ``single_tensor_update`` on all tensors.
+        """
 
-        # update transform
-        if step % self._update_freq == 0:
-            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+        if grads is None:
+            grads = cast(list, [None] * len(tensors))
 
-        # store for transform_apply
-        self.global_state["__tensors"] = tensors
-        self.global_state["__params"] = params
-        self.global_state["__grads"] = grads
+        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+            self.single_tensor_update(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
 
-
-    @final
-    @torch.no_grad
-    def transform_apply(
+    def multi_tensor_apply(
         self,
         tensors: list[torch.Tensor],
        params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
+        loss: torch.Tensor | None,
         states: list[dict[str, Any]],
-        settings: Sequence[Mapping[str, Any]] | None,
-    ) -> list[torch.Tensor]:
-        """Applies this transform to an arbitrary sequence of tensors.
-        This can be used after ``transform_update`` has been used at least once."""
-
-        if settings is None:
-            settings = [self.defaults for _ in tensors]
+        settings: Sequence[Mapping[str, Any]],
+    ) -> Sequence[torch.Tensor]:
+        """Updates ``tensors`` and returns it. This shouldn't modify ``state`` if possible.
+        By default calls ``single_tensor_apply`` on all tensors.
+        """
 
-        num = len(tensors)
-        states = states[:num]
-        settings = settings[:num]
+        if grads is None:
+            grads = cast(list, [None] * len(tensors))
 
-        un_tensors = tensors
-        un_params = params
-        un_grads = grads
+        ret = []
+        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
+            u = self.single_tensor_apply(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
+            ret.append(u)
 
-        tensors = self.global_state.pop("__tensors")
-        params = self.global_state.pop("__params")
-        grads = self.global_state.pop("__grads")
+        return ret
 
-        # step with inner
-        if self._inner is not None:
-            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
-        if self._concat_params:
-            tensors = [torch.cat([t.ravel() for t in tensors])]
+    def _get_grads_loss(self, objective: "Objective"):
+        """evaluates grads and loss only if needed"""
 
-        # apply transform
-        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+        if self._uses_grad: grads = objective.get_grads()
+        else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used
 
-        if self._concat_params:
-            tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+        if self._uses_loss: loss = objective.get_loss(backward=False)
+        else: loss = None
 
-        return tensors
+        return grads, loss
 
-    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
-        if self._concat_params:
-            p = params[0]
-            states = [self.state[p]]
-            settings = [self.settings[p]]
+    @torch.no_grad
+    def _get_cat_updates_params_grads(self, objective: "Objective", grads: list[torch.Tensor] | None):
+        assert self._concat_params
 
-        else:
-            states = []
-            settings = []
-            for p in params:
-                states.append(self.state[p])
-                settings.append(self.settings[p])
+        cat_updates = [torch.cat([u.ravel() for u in objective.get_updates()])]
+        cat_params = [torch.cat([p.ravel() for p in objective.params])]
 
-        return states, settings
+        if grads is None: cat_grads = None
+        else: cat_grads = [torch.cat([g.ravel() for g in grads])]
 
-    @final
-    @torch.no_grad
-    def keyed_transform_update(
-        self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
-    ):
-        """`params` will be used as keys and need to always point to same tensor objects.`"""
-        states, settings = self._get_keyed_states_settings(params)
-        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+        return cat_updates, cat_params, cat_grads
 
+    def _gather_tensors(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
+        """returns everything for ``multi_tensor_*``. Concatenates if ```self._concat_params``.
+        evaluates grads and loss if ``self._uses_grad`` and ``self._uses_loss``"""
 
-    @final
-    @torch.no_grad
-    def keyed_transform_apply(
-        self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | float | None,
-    ):
-        """`params` will be used as keys and need to always point to same tensor objects.`"""
-        states, settings = self._get_keyed_states_settings(params)
-        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+        # evaluate grads and loss if `self._uses_grad` and `self._uses_loss`
+        grads, loss = self._get_grads_loss(objective)
 
+        # gather all things
+        # concatenate everything to a vec if `self._concat_params`
+        if self._concat_params:
+            tensors, params, grads = self._get_cat_updates_params_grads(objective, grads)
+            states = [states[0]]; settings = [settings[0]]
 
-    def pre_step(self, var: Var) -> None:
-        """Logic to run pre-transform, this way transform has access to Var."""
-    def post_step(self, var: Var) -> None:
-        """Logic to run post-transform, this way transform has access to Var."""
+        # or take original values
+        else:
+            tensors=objective.get_updates()
+            params = objective.params
 
-    def update(self, var: Var):
-        if self._target != 'update':
-            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
-                             f"With {self._target = } only `step` method can be used.")
+        return tensors, params, grads, loss, states, settings
 
-        # var may change, therefore current params and grads have to be extracted and passed explicitly
-        update = var.get_update() # this sets loss
-        if self._uses_grad: var.get_grad()
-        if self._uses_loss: var.get_loss(False)
-        params=var.params
-        self.pre_step(var)
+    @final
+    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
+        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
+
+        # initialize before the first update
+        num_updates = self.increment_counter("__num_updates", 0)
+        if num_updates == 0:
+            self.multi_tensor_initialize(
+                tensors=tensors,
+                params=params,
+                grads=grads,
+                loss=loss,
+                states=states,
+                settings=settings
+            )
 
         # update
-        self._var = var
-        self.keyed_transform_update(update, params, var.grad, var.loss)
-        self._var = None
-
-    def apply(self, var: Var):
-        if self._target != 'update':
-            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
-                             f"With {self._target = } only `step` method can be used.")
+        self.multi_tensor_update(
+            tensors=tensors,
+            params=params,
+            grads=grads,
+            loss=loss,
+            states=states,
+            settings=settings
+        )
 
-        # var may change, therefore current params and grads have to be extracted and passed explicitly
-        update = var.get_update() # this sets loss
-        if self._uses_grad: var.get_grad()
-        if self._uses_loss: var.get_loss(False)
-        params=var.params
+    @final
+    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
+        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
+        # note: _gather tensors will re-cat again if `_concat_params`, this is necessary because objective
+        # may have been modified in functional logic, there is no way to know if that happened
 
         # apply
-        self._var = var
-        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
-        self._var = None
-
-        self.post_step(var)
-        return var
-
-    def step(self, var: Var) -> Var:
-
-        # var may change, therefore current params and grads have to be extracted and passed explicitly
-        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
-        if self._uses_grad or self._target == 'grad': var.get_grad()
-        if self._uses_loss: var.get_loss(False)
-        params=var.params
-        self.pre_step(var)
-        self._var = var
-
-        # ---------------------------------- update ---------------------------------- #
-        if self._target == 'update':
-            update = var.get_update()
-            self.keyed_transform_update(update, params, var.grad, var.loss)
-            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
-            self._var = None
-            return var
-
-        # ----------------------------------- grad ----------------------------------- #
-        if self._target == 'grad':
-            grad = var.get_grad()
-            self.keyed_transform_update(grad, params, grad, var.loss)
-            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
-            self._var = None
-            return var
-
-        # ------------------------------- params_direct ------------------------------ #
-        if self._target == 'params_direct':
-            self.keyed_transform_update(var.params, params, var.grad, var.loss)
-            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
-            for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
-            self._var = None
-            return var
-
-        # ----------------------------- params_differnce ----------------------------- #
-        if self._target == 'params_difference':
-            p_clone = [p.clone() for p in var.params]
-            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
-            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
-            var.update = list(torch._foreach_sub(var.params, new_params))
-            self._var = None
-            return var
-
-        # ----------------------------- update_difference ---------------------------- #
-        if self._target == 'update_difference':
-            update = var.get_update()
-            u_clone = [u.clone() for u in update]
-            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
-            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
-            var.update = list(torch._foreach_sub(update, new_update))
-            self._var = None
-            return var
-
-        # ---------------------------------- closure --------------------------------- #
-        if self._target == 'closure':
-            original_closure = var.closure
-            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-            params = var.params
-            parent_var = self._var
-            def transformed_closure(backward=True):
-                if backward:
-                    loss = original_closure()
-                    current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
-                    self._var = parent_var
-                    self.keyed_transform_update(current_grad, params, var.grad, var.loss)
-                    transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
-                    self._var = None
-
-                    for p, g in zip(params, transformed_grad):
-                        p.grad = g
-
-                else:
-                    loss = original_closure(False)
-
-                return loss
-
-            var.closure = transformed_closure
-            self.post_step(var)
-            self._var = None
-            return var
-
-        # ---------------------------------- invalid --------------------------------- #
-        raise ValueError(f'Invalid target: {self._target}')
-
-
-class TensorwiseTransform(Transform, ABC):
-    """Base class for a parameter-wise transform.
-
-    This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.
-
-    Args:
-        defaults (dict[str,Any] | None): dict with default values.
-        uses_grad (bool):
-            Set this to True if `transform` method uses the `grad` argument. This will ensure
-            `grad` is always computed and can't be None. Otherwise set to False.
-        target (Target, optional):
-            what to set on var. Defaults to 'update'.
-    """
-    def __init__(
-        self,
-        defaults: dict[str,Any] | None,
-        uses_grad: bool = False,
-        uses_loss: bool = False,
-        concat_params: bool = False,
-        update_freq: int = 1,
-        inner: Chainable | None = None,
-        target: Target = 'update',
-    ):
-        super().__init__(
-            defaults=defaults,
-            uses_grad=uses_grad,
-            concat_params=concat_params,
-            update_freq=update_freq,
-            uses_loss=uses_loss,
-            inner=inner,
-            target=target,
+        ret = self.multi_tensor_apply(
+            tensors=tensors,
+            params=params,
+            grads=grads,
+            loss=loss,
+            states=states,
+            settings=settings
         )
 
-    def update_tensor(
-        self,
-        tensor: torch.Tensor,
-        param: torch.Tensor,
-        grad: torch.Tensor | None,
-        loss: torch.Tensor | float | None,
-        state: dict[str, Any],
-        setting: Mapping[str, Any],
-    ) -> None:
-        """Updates this transform. By default does nothing - if logic is in `apply` method."""
+        # uncat if needed and set objective.updates and return objective
+        if self._concat_params:
+            objective.updates = vec_to_tensors(ret[0], objective.params)
 
-    @abstractmethod
-    def apply_tensor(
-        self,
-        tensor: torch.Tensor,
-        param: torch.Tensor,
-        grad: torch.Tensor | None,
-        loss: torch.Tensor | float | None,
-        state: dict[str, Any],
-        setting: Mapping[str, Any],
-    ) -> torch.Tensor:
-        """Applies the update rule to `tensor`."""
+        else:
+            objective.updates = list(ret)
 
-    @final
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            self.update_tensor(t, p, g, loss, state, setting)
+        return objective
 
-    @final
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        applied = []
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            applied.append(self.apply_tensor(t, p, g, loss, state, setting))
-        return applied
-
-def apply_transform(
-    tfm: Chainable,
-    tensors: list[torch.Tensor],
-    params: list[torch.Tensor],
-    grads: list[torch.Tensor] | None,
-    loss: torch.Tensor | float | None = None,
-    var: Var | None = None,
-    current_step: int = 0,
-):
-    if var is None:
-        var = Var(params=params, closure=None, model=None, current_step=current_step)
-        var.loss = loss
-
-    if isinstance(tfm, Transform) and tfm._target == 'update':
-        if tfm._uses_grad and grads is None: grads = var.get_grad()
-        tfm.keyed_transform_update(tensors, params, grads, loss)
-        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))
-
-    if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
-    if isinstance(tfm, Sequence):
-        for module in tfm:
-            tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
-        return tensors
-
-    if isinstance(tfm, Module):
-        cvar = var.clone(clone_update=False)
-        cvar.update = tensors
-        cvar = tfm.step(cvar)
-        var.update_attrs_from_clone_(cvar)
-        assert cvar.update is not None
-        return cvar.update
-
-    raise TypeError(type(tfm))
+
+    # make sure _concat_params, _uses_grad and _uses_loss are saved in `state_dict`
+    def _extra_pack(self):
+        return {
+            "__concat_params": self._concat_params,
+            "__uses_grad": self._uses_grad,
+            "__uses_loss": self._uses_loss,
+        }
+
+    def _extra_unpack(self, d):
+        self._concat_params = d["__concat_params"]
+        self._uses_grad = d["__uses_grad"]
+        self._uses_loss = d["__uses_loss"]