torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/reformulation.py (new file)
@@ -0,0 +1,65 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+
+ import torch
+
+ from .module import Chainable, Modular, Module, Var
+
+
+ class Reformulation(Module, ABC):
+     def __init__(self, defaults: dict | None, modules: Chainable | None):
+         super().__init__(defaults)
+
+         if modules is not None:
+             self.set_child("modules", modules)
+
+     @abstractmethod
+     def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+         """
+         returns (loss, gradient), if backward is False then gradient can be None.
+
+         If evaluating original loss/gradient at x_0, set them to ``var``.
+         """
+
+     def pre_step(self, var: Var) -> Var | None:
+         """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+
+     def step(self, var):
+         ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
+         if isinstance(ret, Var): var = ret
+
+         if var.closure is None: raise RuntimeError("Reformulation requires closure")
+         params, closure = var.params, var.closure
+
+         # step with children
+         if 'modules' in self.children:
+
+             # make a reformulated closure
+             def modified_closure(backward=True):
+                 loss, grad = self.closure(backward, closure, params, var)
+
+                 if grad is not None:
+                     for p,g in zip(params, grad):
+                         p.grad = g
+
+                 return loss
+
+             # set it to a new Var object
+             modified_var = var.clone(clone_update=False)
+             modified_var.closure = modified_closure
+
+             # step with child
+             modules = self.children['modules']
+             modified_var = modules.step(modified_var)
+
+             # modified_var.loss and grad refers to loss and grad of a modified objective
+             # so we only take the update
+             var.update = modified_var.update
+
+         # or just evaluate new closure and set to update
+         else:
+             loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
+             if grad is not None: var.update = list(grad)
+
+         return var
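The `Reformulation` base above wraps the user-supplied closure so that any chained `modules` step on a modified objective, while only the resulting update is written back to `var`. For orientation, a subclass only needs to implement the abstract `closure`, which returns `(loss, gradient)`. The sketch below is hypothetical: the `L2Penalized` name and the penalty reformulation are illustrative, not part of the package, and it only assumes the interface shown in the diff.

```python
import torch
from torchzero.core.reformulation import Reformulation

class L2Penalized(Reformulation):
    """Hypothetical sketch: reformulates f(x) as f(x) + (mu/2) * ||x||^2."""

    def __init__(self, mu: float = 1e-2, modules=None):
        super().__init__(defaults=dict(mu=mu), modules=modules)

    def closure(self, backward, closure, params, var):
        mu = self.defaults["mu"]
        if backward:
            loss = closure()  # the original closure populates p.grad
            grad = [(p.grad if p.grad is not None else torch.zeros_like(p)) + mu * p
                    for p in params]
        else:
            loss = closure(False)
            grad = None  # allowed when backward is False
        penalty = 0.5 * mu * sum(p.pow(2).sum() for p in params)
        return loss + penalty, grad
```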
torchzero/core/transform.py
@@ -1,18 +1,36 @@
  from abc import ABC, abstractmethod
- from collections.abc import Iterable, Sequence, Mapping
+ from collections.abc import Iterable, Mapping, Sequence
  from typing import Any, Literal, final

  import torch

- from ..utils import set_storage_, TensorList, vec_to_tensors
- from .module import Module, Var, Chain, Chainable
+ from ..utils import TensorList, set_storage_, vec_to_tensors
+ from .module import Chain, Chainable, Module, Var

  Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

+
  class Transform(Module, ABC):
-     """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.
+     """Base class for a transform.
+     This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.

      A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
+     It has two methods:
+
+     - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
+     It may be called multiple times before ``apply_tensors``.
+     - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+
+     Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+
+     Transform can be applied to tensors corresponding to custom parameters
+     by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
+     parameters will be keys to store per-parameter states, so they should remain the same python objects.
+
+     Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
+     ``transform_update`` and ``transform_apply``.
+
+     A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.

      Args:
          defaults (dict[str,Any] | None): dict with default values.
@@ -21,63 +39,63 @@ class Transform(Module, ABC):
              `grad` is always computed and can't be None. Otherwise set to False.
          target (Target, optional):
              what to set on var. Defaults to 'update'.
+
      """
      def __init__(
          self,
          defaults: dict[str,Any] | None,
-         uses_grad: bool,
+         uses_grad: bool = False,
+         uses_loss: bool = False,
          concat_params: bool = False,
          update_freq: int = 1,
-         scale_first: bool = False,
          inner: Chainable | None = None,
          target: Target = 'update',
      ):
          super().__init__(defaults)
          self._target: Target = target
          self._uses_grad = uses_grad
+         self._uses_loss = uses_loss
          self._concat_params = concat_params
          self._update_freq = update_freq
-         self._scale_first = scale_first
          self._inner = inner
+         self._var = None

-     def update(
+     def update_tensors(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | None,
+         loss: torch.Tensor | float | None,
          states: list[dict[str, Any]],
          settings: Sequence[Mapping[str, Any]],
      ) -> None:
-         """Updates this transform. By default does nothing - if logic is in `apply` method."""
+         """update function, this shouldn't be called directly. Updates this module."""

      @abstractmethod
-     def apply(
+     def apply_tensors(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | None,
+         loss: torch.Tensor | float | None,
          states: list[dict[str, Any]],
          settings: Sequence[Mapping[str, Any]],
      ) -> Sequence[torch.Tensor]:
-         """Applies the update rule to `tensors`."""
+         """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+         If possible, this shouldn't modify the internal state of this transform."""

      @final
      @torch.no_grad
-     def transform(
+     def transform_update(
          self,
          tensors: list[torch.Tensor],
          params: list[torch.Tensor],
          grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | None,
+         loss: torch.Tensor | float | None,
          states: list[dict[str, Any]],
          settings: Sequence[Mapping[str, Any]] | None,
-     ) -> list[torch.Tensor]:
-         """Applies this transform to an arbitrary sequence of tensors."""
-         un_tensors = tensors
-         un_params = params
-         un_grads = grads
+     ) -> None:
+         """Updates this transform from an arbitrary sequence of tensors."""
          if self._concat_params:
              tensors = [torch.cat([t.ravel() for t in tensors])]
              params = [torch.cat([p.ravel() for p in params])]
@@ -86,53 +104,67 @@ class Transform(Module, ABC):
          if settings is None:
              settings = [self.defaults for _ in tensors]

-         step = self.global_state.get('__step', 0)
+         step = self.global_state.get('__step', 0) # that way it gets reset correctly
+         self.global_state['__step'] = step + 1
+
          num = len(tensors)
          states = states[:num]
          settings = settings[:num]

-         update_freq = self._update_freq
-         scale_first = self._scale_first
-         scale_factor = 1
+         # update transform
+         if step % self._update_freq == 0:
+             self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)

-         # scaling factor for 1st step
-         if scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+         # store for transform_apply
+         self.global_state["__tensors"] = tensors
+         self.global_state["__params"] = params
+         self.global_state["__grads"] = grads

-         # update transform
-         if step % update_freq == 0:
-             self.update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+     @final
+     @torch.no_grad
+     def transform_apply(
+         self,
+         tensors: list[torch.Tensor],
+         params: list[torch.Tensor],
+         grads: list[torch.Tensor] | None,
+         loss: torch.Tensor | float | None,
+         states: list[dict[str, Any]],
+         settings: Sequence[Mapping[str, Any]] | None,
+     ) -> list[torch.Tensor]:
+         """Applies this transform to an arbitrary sequence of tensors.
+         This can be used after ``transform_update`` has been used at least once."""
+
+         if settings is None:
+             settings = [self.defaults for _ in tensors]
+
+         num = len(tensors)
+         states = states[:num]
+         settings = settings[:num]
+
+         un_tensors = tensors
+         un_params = params
+         un_grads = grads
+
+         tensors = self.global_state.pop("__tensors")
+         params = self.global_state.pop("__params")
+         grads = self.global_state.pop("__grads")

          # step with inner
          if self._inner is not None:
-             tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+             tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
              if self._concat_params:
                  tensors = [torch.cat([t.ravel() for t in tensors])]

          # apply transform
-         tensors = list(self.apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
-
-         # scale initial step, when preconditioner might not have been applied
-         if scale_first and step == 0:
-             torch._foreach_mul_(tensors, scale_factor)
+         tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))

-         self.global_state['__step'] = step + 1
          if self._concat_params:
              tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
-         return tensors

+         return tensors

-     @torch.no_grad
-     def keyed_transform(
-         self,
-         tensors: list[torch.Tensor],
-         params: list[torch.Tensor],
-         grads: list[torch.Tensor] | None,
-         loss: torch.Tensor | None,
-     ):
-         """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+     def _get_keyed_states_settings(self, params: list[torch.Tensor]):
          if self._concat_params:
              p = params[0]
              states = [self.state[p]]
@@ -145,42 +177,128 @@ class Transform(Module, ABC):
                  states.append(self.state[p])
                  settings.append(self.settings[p])

-         return self.transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+         return states, settings
+
+     @final
+     @torch.no_grad
+     def keyed_transform_update(
+         self,
+         tensors: list[torch.Tensor],
+         params: list[torch.Tensor],
+         grads: list[torch.Tensor] | None,
+         loss: torch.Tensor | float | None,
+     ):
+         """`params` will be used as keys and need to always point to same tensor objects.`"""
+         states, settings = self._get_keyed_states_settings(params)
+         self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+     @final
+     @torch.no_grad
+     def keyed_transform_apply(
+         self,
+         tensors: list[torch.Tensor],
+         params: list[torch.Tensor],
+         grads: list[torch.Tensor] | None,
+         loss: torch.Tensor | float | None,
+     ):
+         """`params` will be used as keys and need to always point to same tensor objects.`"""
+         states, settings = self._get_keyed_states_settings(params)
+         return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+     def pre_step(self, var: Var) -> None:
+         """Logic to run pre-transform, this way transform has access to Var."""
+     def post_step(self, var: Var) -> None:
+         """Logic to run post-transform, this way transform has access to Var."""
+
+     def update(self, var: Var):
+         if self._target != 'update':
+             raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                              f"With {self._target = } only `step` method can be used.")

-     def step(self, var: Var) -> Var:
          # var may change, therefore current params and grads have to be extracted and passed explicitly
+         update = var.get_update() # this sets loss
          if self._uses_grad: var.get_grad()
+         if self._uses_loss: var.get_loss(False)
          params=var.params
+         self.pre_step(var)
+
+         # update
+         self._var = var
+         self.keyed_transform_update(update, params, var.grad, var.loss)
+         self._var = None
+
+     def apply(self, var: Var):
+         if self._target != 'update':
+             raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                              f"With {self._target = } only `step` method can be used.")
+
+         # var may change, therefore current params and grads have to be extracted and passed explicitly
+         update = var.get_update() # this sets loss
+         if self._uses_grad: var.get_grad()
+         if self._uses_loss: var.get_loss(False)
+         params=var.params
+
+         # apply
+         self._var = var
+         var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+         self._var = None
+
+         self.post_step(var)
+         return var
+
+     def step(self, var: Var) -> Var:
+
+         # var may change, therefore current params and grads have to be extracted and passed explicitly
+         if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+         if self._uses_grad or self._target == 'grad': var.get_grad()
+         if self._uses_loss: var.get_loss(False)
+         params=var.params
+         self.pre_step(var)
+         self._var = var

          # ---------------------------------- update ---------------------------------- #
          if self._target == 'update':
              update = var.get_update()
-             var.update = list(self.keyed_transform(update, params, var.grad, var.loss))
+             self.keyed_transform_update(update, params, var.grad, var.loss)
+             var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+             self._var = None
              return var

          # ----------------------------------- grad ----------------------------------- #
          if self._target == 'grad':
              grad = var.get_grad()
-             var.grad = list(self.keyed_transform(grad, params, grad, var.loss))
+             self.keyed_transform_update(grad, params, grad, var.loss)
+             var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+             self._var = None
              return var

          # ------------------------------- params_direct ------------------------------ #
          if self._target == 'params_direct':
-             new_params = self.keyed_transform(var.params, params, var.grad, var.loss)
+             self.keyed_transform_update(var.params, params, var.grad, var.loss)
+             new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
              for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+             self._var = None
              return var

          # ----------------------------- params_differnce ----------------------------- #
          if self._target == 'params_difference':
-             new_params = tuple(self.keyed_transform([p.clone() for p in var.params], params, var.grad, var.loss))
+             p_clone = [p.clone() for p in var.params]
+             self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+             new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
              var.update = list(torch._foreach_sub(var.params, new_params))
+             self._var = None
              return var

          # ----------------------------- update_difference ---------------------------- #
          if self._target == 'update_difference':
              update = var.get_update()
-             new_update = tuple(self.keyed_transform([u.clone() for u in update], params, var.grad, var.loss))
+             u_clone = [u.clone() for u in update]
+             self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+             new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
              var.update = list(torch._foreach_sub(update, new_update))
+             self._var = None
              return var

          # ---------------------------------- closure --------------------------------- #
@@ -189,11 +307,17 @@ class Transform(Module, ABC):
              if original_closure is None: raise ValueError('Target = "closure", but closure is None')

              params = var.params
+             parent_var = self._var
              def transformed_closure(backward=True):
                  if backward:
                      loss = original_closure()
                      current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                     transformed_grad = list(self.keyed_transform(current_grad, params, var.grad, var.loss))
+
+                     self._var = parent_var
+                     self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                     transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
+                     self._var = None
+
                      for p, g in zip(params, transformed_grad):
                          p.grad = g

@@ -203,6 +327,8 @@ class Transform(Module, ABC):
                  return loss

              var.closure = transformed_closure
+             self.post_step(var)
+             self._var = None
              return var

          # ---------------------------------- invalid --------------------------------- #
@@ -212,7 +338,7 @@
  class TensorwiseTransform(Transform, ABC):
      """Base class for a parameter-wise transform.

-     This is an abstract class, to use it, subclass it and override `transform`.
+     This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.

      Args:
          defaults (dict[str,Any] | None): dict with default values.
@@ -225,10 +351,10 @@ class TensorwiseTransform(Transform, ABC):
      def __init__(
          self,
          defaults: dict[str,Any] | None,
-         uses_grad: bool,
+         uses_grad: bool = False,
+         uses_loss: bool = False,
          concat_params: bool = False,
          update_freq: int = 1,
-         scale_first: bool = False,
          inner: Chainable | None = None,
          target: Target = 'update',
      ):
@@ -237,7 +363,7 @@ class TensorwiseTransform(Transform, ABC):
              uses_grad=uses_grad,
              concat_params=concat_params,
              update_freq=update_freq,
-             scale_first=scale_first,
+             uses_loss=uses_loss,
              inner=inner,
              target=target,
          )
@@ -247,9 +373,9 @@ class TensorwiseTransform(Transform, ABC):
          tensor: torch.Tensor,
          param: torch.Tensor,
          grad: torch.Tensor | None,
-         loss: torch.Tensor | None,
+         loss: torch.Tensor | float | None,
          state: dict[str, Any],
-         settings: Mapping[str, Any],
+         setting: Mapping[str, Any],
      ) -> None:
          """Updates this transform. By default does nothing - if logic is in `apply` method."""

@@ -259,20 +385,20 @@ class TensorwiseTransform(Transform, ABC):
          tensor: torch.Tensor,
          param: torch.Tensor,
          grad: torch.Tensor | None,
-         loss: torch.Tensor | None,
+         loss: torch.Tensor | float | None,
          state: dict[str, Any],
-         settings: Mapping[str, Any],
+         setting: Mapping[str, Any],
      ) -> torch.Tensor:
          """Applies the update rule to `tensor`."""

      @final
-     def update(self, tensors, params, grads, loss, states, settings):
+     def update_tensors(self, tensors, params, grads, loss, states, settings):
          if grads is None: grads = [None]*len(tensors)
          for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
              self.update_tensor(t, p, g, loss, state, setting)

      @final
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          applied = []
          if grads is None: grads = [None]*len(tensors)
          for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
@@ -284,7 +410,7 @@ def apply_transform(
      tensors: list[torch.Tensor],
      params: list[torch.Tensor],
      grads: list[torch.Tensor] | None,
-     loss: torch.Tensor | None = None,
+     loss: torch.Tensor | float | None = None,
      var: Var | None = None,
      current_step: int = 0,
  ):
@@ -292,9 +418,10 @@ def apply_transform(
          var = Var(params=params, closure=None, model=None, current_step=current_step)
          var.loss = loss

-     if isinstance(tfm, Transform):
+     if isinstance(tfm, Transform) and tfm._target == 'update':
          if tfm._uses_grad and grads is None: grads = var.get_grad()
-         return list(tfm.keyed_transform(tensors, params, grads, loss))
+         tfm.keyed_transform_update(tensors, params, grads, loss)
+         return list(tfm.keyed_transform_apply(tensors, params, grads, loss))

      if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
      if isinstance(tfm, Sequence):
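The old single `transform` / `keyed_transform` entry point is split into update/apply pairs, so a subclass now separates state accumulation (`update_tensors`) from producing an output (`apply_tensors`). Below is a minimal sketch against the signatures shown above; the `EMASquareScale` name and its scaling rule are illustrative, not part of torchzero.

```python
import torch
from torchzero.core.transform import Transform

class EMASquareScale(Transform):
    """Hypothetical sketch: divide each update by the root of an EMA of its square."""

    def __init__(self, beta: float = 0.99, eps: float = 1e-8):
        # uses_grad / uses_loss now default to False, so only defaults are needed here
        super().__init__(defaults=dict(beta=beta, eps=eps))

    def update_tensors(self, tensors, params, grads, loss, states, settings):
        # accumulate per-tensor state only; the tensors themselves are not modified
        for t, state, setting in zip(tensors, states, settings):
            ema = state.setdefault("ema_sq", torch.zeros_like(t))
            ema.mul_(setting["beta"]).addcmul_(t, t, value=1 - setting["beta"])

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # read the state and return transformed tensors without mutating the state
        return [t / (state["ema_sq"].sqrt() + setting["eps"])
                for t, state, setting in zip(tensors, states, settings)]
```

Per the docstring above, such a transform can also be driven manually on an arbitrary list of tensors through `keyed_transform_update` / `keyed_transform_apply`, with the parameter tensors serving as keys for the per-parameter state.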
torchzero/modules/__init__.py
@@ -1,14 +1,23 @@
+ from . import experimental
  from .clipping import *
+ from .conjugate_gradient import *
  from .grad_approximation import *
+ from .higher_order import *
+ from .least_squares import *
  from .line_search import *
- from .lr import *
+ from .misc import *
  from .momentum import *
  from .ops import *
- from .optimizers import *
+ from .adaptive import *
  from .projections import *
  from .quasi_newton import *
+ from .second_order import *
  from .smoothing import *
+ from .step_size import *
+ from .termination import *
+ from .trust_region import *
+ from .variance_reduction import *
  from .weight_decay import *
  from .wrappers import *
- from .second_order import *
- from .higher_order import *
+ from .restarts import *
+ from .zeroth_order import *
torchzero/modules/adaptive/__init__.py (new file)
@@ -0,0 +1,30 @@
+ from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm
+
+ # from .curveball import CurveBall
+ # from .spectral import SpectralPreconditioner
+ from .adahessian import AdaHessian
+ from .adam import Adam
+ from .adan import Adan
+ from .adaptive_heavyball import AdaptiveHeavyBall
+ from .aegd import AEGD
+ from .esgd import ESGD
+ from .lmadagrad import LMAdagrad
+ from .lion import Lion
+ from .mars import MARSCorrection
+ from .matrix_momentum import MatrixMomentum
+ from .msam import MSAM, MSAMObjective
+ from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+ from .natural_gradient import NaturalGradient
+ from .orthograd import OrthoGrad, orthograd_
+ from .rmsprop import RMSprop
+ from .rprop import (
+     BacktrackOnSignChange,
+     Rprop,
+     ScaleLRBySignChange,
+     SignConsistencyLRs,
+     SignConsistencyMask,
+ )
+ from .sam import ASAM, SAM
+ from .shampoo import Shampoo
+ from .soap import SOAP
+ from .sophia_h import SophiaH
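Taken together with the file moves in the list above, the former `torchzero.modules.optimizers` and `torchzero.modules.lr` packages are gone; their contents now live under `torchzero.modules.adaptive` and `torchzero.modules.step_size`, and `torchzero/modules/__init__.py` star-imports both. A hedged sketch of what that presumably means for imports (the exact re-export behaviour is assumed from the diffs, not verified):

```python
# New location for the former "optimizers" modules.
from torchzero.modules.adaptive import Adam, SOAP

# Presumably still reachable at the package level via the star imports shown above.
from torchzero.modules import Shampoo, SophiaH
```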