torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
@@ -1,18 +1,18 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Sequence
-from typing import Any, Literal
+from collections.abc import Iterable, Sequence, Mapping
+from typing import Any, Literal, final

 import torch

-from ..utils import set_storage_
-from .module import Module, Vars, Chain, Chainable
+from ..utils import set_storage_, TensorList, vec_to_tensors
+from .module import Module, Var, Chain, Chainable

 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.

-    This is an abstract class, to use it, subclass it and override `transform`.
+    A transform is a module that can also be applied manually to an arbitrary sequence of tensors.

     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -20,62 +20,283 @@ class Transform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on vars. Defaults to 'update'.
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._uses_loss = uses_loss
+        self._concat_params = concat_params
+        self._update_freq = update_freq
+        self._scale_first = scale_first
+        self._inner = inner
+
+    def update_tensors(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> None:
+        """update function, this shouldn't be called directly. Updates this module."""

     @abstractmethod
-    def transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> Iterable[torch.Tensor]:
-        """applies the update rule to `target`."""
+    def apply_tensors(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> Sequence[torch.Tensor]:
+        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+        If possible, this shouldn't modify the internal state of this transform."""
+
+    @final
+    @torch.no_grad
+    def transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> None:
+        """Updates this transform from an arbitrary sequence of tensors."""
+        if self._concat_params:
+            tensors = [torch.cat([t.ravel() for t in tensors])]
+            params = [torch.cat([p.ravel() for p in params])]
+            grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        step = self.global_state.get('__step', 0) # that way it gets reset correctly
+        self.global_state['__step'] = step + 1
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        scale_factor = 1
+
+        # scaling factor for 1st step
+        if self._scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+
+        # update transform
+        if step % self._update_freq == 0:
+            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # store for transform_apply
+        self.global_state["__tensors"] = tensors
+        self.global_state["__params"] = params
+        self.global_state["__grads"] = grads
+        self.global_state["__scale_factor"] = scale_factor
+
+
+    @final
+    @torch.no_grad
+    def transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors.
+        This can be used after ``transform_update`` has been used at least once."""
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+
+        tensors = self.global_state.pop("__tensors")
+        params = self.global_state.pop("__params")
+        grads = self.global_state.pop("__grads")
+        scale_factor = self.global_state.pop("__scale_factor")
+
+        # step with inner
+        if self._inner is not None:
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            if self._concat_params:
+                tensors = [torch.cat([t.ravel() for t in tensors])]
+
+        # apply transform
+        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+
+        # scale initial step, when preconditioner might not have been applied
+        if self._scale_first and self.global_state['__step'] == 1:
+            torch._foreach_mul_(tensors, scale_factor)
+
+        if self._concat_params:
+            tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+        return tensors
+
+    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
+        if self._concat_params:
+            p = params[0]
+            states = [self.state[p]]
+            settings = [self.settings[p]]

-    def step(self, vars: Vars) -> Vars:
-        # vars may change, therefore current params and grads have to be extracted and passed explicitly
-        if self._uses_grad: vars.get_grad()
-        params=vars.params; grad = vars.grad
+        else:
+            states = []
+            settings = []
+            for p in params:
+                states.append(self.state[p])
+                settings.append(self.settings[p])
+
+        return states, settings
+
+    @final
+    @torch.no_grad
+    def keyed_transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    @final
+    @torch.no_grad
+    def keyed_transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    def pre_step(self, var: Var) -> None:
+        """Logic to run pre-transform, this way transform has access to Var."""
+    def post_step(self, var: Var) -> None:
+        """Logic to run post-transform, this way transform has access to Var."""
+
+    def update(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)
+
+        # update
+        self.keyed_transform_update(update, params, var.grad, var.loss)
+
+    def apply(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+
+        # apply
+        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self.post_step(var)
+        return var
+
+    def step(self, var: Var) -> Var:
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+        if self._uses_grad or self._target == 'grad': var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
-            vars.update = list(self.transform(vars.get_update(), params, grad, vars))
-            return vars
+            update = var.get_update()
+            self.keyed_transform_update(update, params, var.grad, var.loss)
+            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+            return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
-            vars.grad = list(self.transform(vars.get_grad(), params, grad, vars))
-            return vars
+            grad = var.get_grad()
+            self.keyed_transform_update(grad, params, grad, var.loss)
+            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+            return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-            new_params = self.transform(vars.params, params, grad, vars)
-            for p, new_p in zip(vars.params, new_params): set_storage_(p, new_p)
-            return vars
+            self.keyed_transform_update(var.params, params, var.grad, var.loss)
+            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
+            for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-            new_params = tuple(self.transform([p.clone() for p in vars.params], params, grad, vars))
-            vars.update = list(torch._foreach_sub(vars.params, new_params))
-            return vars
+            p_clone = [p.clone() for p in var.params]
+            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(var.params, new_params))
+            return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
-            update = vars.get_update()
-            new_update = tuple(self.transform([u.clone() for u in update], params, grad, vars))
-            vars.update = list(torch._foreach_sub(update, new_update))
-            return vars
+            update = var.get_update()
+            u_clone = [u.clone() for u in update]
+            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(update, new_update))
+            return var

         # ---------------------------------- closure --------------------------------- #
         if self._target == 'closure':
-            original_closure = vars.closure
+            original_closure = var.closure
             if original_closure is None: raise ValueError('Target = "closure", but closure is None')

-            params = vars.params
+            params = var.params
             def transformed_closure(backward=True):
                 if backward:
                     loss = original_closure()
                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = list(self.transform(current_grad, params, grad, vars))
+                    self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                    transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
                     for p, g in zip(params, transformed_grad):
                         p.grad = g

@@ -84,14 +305,15 @@ class Transform(Module, ABC):

                 return loss

-            vars.closure = transformed_closure
-            return vars
+            var.closure = transformed_closure
+            self.post_step(var)
+            return var

         # ---------------------------------- invalid --------------------------------- #
         raise ValueError(f'Invalid target: {self._target}')


-class TensorwiseTransform(Module, ABC):
+class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.

     This is an abstract class, to use it, subclass it and override `transform`.
@@ -102,151 +324,97 @@ class TensorwiseTransform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on vars. Defaults to 'update'.
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
-        super().__init__(defaults)
-        self._target: Target = target
-        self._uses_grad: bool = uses_grad
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
+        super().__init__(
+            defaults=defaults,
+            uses_grad=uses_grad,
+            concat_params=concat_params,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            uses_loss=uses_loss,
+            inner=inner,
+            target=target,
+        )
+
+    def update_tensor(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def transform(
+    def apply_tensor(
         self,
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        vars: Vars,
+        loss: torch.Tensor | float | None,
+        state: dict[str, Any],
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
-        """applies the update rule to `target`"""
-
-    def step(self, vars: Vars) -> Vars:
-        params = vars.params
-        if self._uses_grad and vars.grad is None: vars.get_grad()
-
-        # ---------------------------------- update ---------------------------------- #
-        if self._target == 'update':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                # settings = self.settings[p] # couldn't make typing work with this
-                #, self.transform(target=u, param=p, grad=g, vars=vars, **{k:settings[k] for k in self.defaults})
-                transformed_update.append(self.transform(tensor=u, param=p, grad=g, vars=vars))
-
-            vars.update = transformed_update
-            return vars
-
-        # ----------------------------------- grad ----------------------------------- #
-        if self._target == 'grad':
-            grad = vars.get_grad()
-            transformed_grad = []
-
-            for p, g in zip(params, grad):
-                transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-            vars.grad = transformed_grad
-            return vars
-
-        # ------------------------------- params_direct ------------------------------ #
-        if self._target == 'params_direct':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-
-            for p, g in zip(params, grad):
-                set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
-
-            return vars
-
-        # ----------------------------- params_difference ---------------------------- #
-        if self._target == 'params_difference':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_params = []
-
-            for p, g in zip(params, grad):
-                transformed_params.append(
-                    self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(params, transformed_params))
-            return vars
-
-        # ----------------------------- update_difference ---------------------------- #
-        if self._target == 'update_difference':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                transformed_update.append(
-                    self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(update, transformed_update))
-            return vars
-
-        # ---------------------------------- closure --------------------------------- #
-        if self._target == 'closure':
-            original_closure = vars.closure
-            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-            params = vars.params
-            def transformed_closure(backward=True):
-                if backward:
-                    loss = original_closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = []
-
-                    for p, g in zip(params, grad):
-                        transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-                    for p, g in zip(params, transformed_grad):
-                        p.grad = g
-
-                else:
-                    loss = original_closure(False)
-
-                return loss
-
-            vars.closure = transformed_closure
-            return vars
-
-        # ---------------------------------- invalid --------------------------------- #
-        raise ValueError(f'Invalid target: {self._target}')
-
-
-
-def apply(
+        """Applies the update rule to `tensor`."""
+
+    @final
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            self.update_tensor(t, p, g, loss, state, setting)
+
+    @final
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        applied = []
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            applied.append(self.apply_tensor(t, p, g, loss, state, setting))
+        return applied
+
+def apply_transform(
     tfm: Chainable,
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-    vars: Vars | None = None,
+    loss: torch.Tensor | float | None = None,
+    var: Var | None = None,
     current_step: int = 0,
 ):
-    if vars is None: vars = Vars(params=params, closure=None, model=None, current_step=current_step)
-    if isinstance(tfm, Transform):
-        if tfm._uses_grad and grads is None: grads = vars.get_grad()
-        return list(tfm.transform(tensors, params, grads, vars))
-
-    if isinstance(tfm, TensorwiseTransform):
-        grads_list = grads
-        if grads_list is None:
-            if tfm._uses_grad: grads_list = vars.get_grad()
-            else: grads_list = [None] * len(tensors)
-        return [tfm.transform(t, p, g, vars) for t,p,g in zip(tensors,params,grads_list)]
+    if var is None:
+        var = Var(params=params, closure=None, model=None, current_step=current_step)
+        var.loss = loss
+
+    if isinstance(tfm, Transform) and tfm._target == 'update':
+        if tfm._uses_grad and grads is None: grads = var.get_grad()
+        tfm.keyed_transform_update(tensors, params, grads, loss)
+        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
         for module in tfm:
-            tensors = apply(module, tensors=tensors, params=params, grads=grads, vars=vars)
+            tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
         return tensors

     if isinstance(tfm, Module):
-        cvars = vars.clone(clone_update=False)
-        cvars.update = tensors
-        cvars = tfm.step(cvars)
-        vars.update_attrs_from_clone_(cvars)
-        assert cvars.update is not None
-        return cvars.update
+        cvar = var.clone(clone_update=False)
+        cvar.update = tensors
+        cvar = tfm.step(cvar)
+        var.update_attrs_from_clone_(cvar)
+        assert cvar.update is not None
+        return cvar.update

     raise TypeError(type(tfm))
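The split of the old single `transform` hook into update/apply pairs (`update_tensors`/`apply_tensors`, and the per-tensor `update_tensor`/`apply_tensor` on `TensorwiseTransform`) is easiest to read from a small subclass. Below is a minimal sketch written against the signatures shown in this diff; it is not code from the package, and the import path and the `SignSketch` name are assumptions.

import torch
from torchzero.core.transform import TensorwiseTransform  # import path assumed

class SignSketch(TensorwiseTransform):
    """Hypothetical transform: replaces each update tensor with alpha * sign(update)."""
    def __init__(self, alpha: float = 1e-2):
        # per-parameter defaults are passed to the base class; uses_grad/uses_loss keep their False defaults
        super().__init__(defaults=dict(alpha=alpha))

    # update_tensor is inherited and does nothing here, so all logic sits in apply_tensor
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        # `setting` carries the per-parameter settings, `state` the per-parameter state dict
        return tensor.sign().mul_(setting['alpha'])

Used as a module it goes through the normal `step` path; applied manually to arbitrary tensors, the sequence described by the docstrings above is presumably `keyed_transform_update` followed by `keyed_transform_apply`.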
torchzero/modules/__init__.py
@@ -1,7 +1,7 @@
 from .clipping import *
 from .grad_approximation import *
 from .line_search import *
-from .lr import *
+from .step_size import *
 from .momentum import *
 from .ops import *
 from .optimizers import *
@@ -11,3 +11,5 @@ from .smoothing import *
 from .weight_decay import *
 from .wrappers import *
 from .second_order import *
+from .higher_order import *
+from .misc import *
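Because everything is re-exported through wildcard imports, the top-level `torchzero.modules` namespace presumably stays stable across the `lr` → `step_size` rename (see items 113-115 and 143-145 in the file list); only code importing the removed subpackage directly would need updating. A hypothetical before/after, where `LR` is an illustrative placeholder symbol, not checked against the package:

# 0.3.9: importing from the subpackage this release removes
from torchzero.modules.lr import LR          # `LR` is an assumed example symbol
# 0.3.11: the same content now lives under step_size
from torchzero.modules.step_size import LR   # `LR` is an assumed example symbol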