torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/transform.py
@@ -1,18 +1,36 @@
  from abc import ABC, abstractmethod
- from collections.abc import Iterable, Sequence, Mapping
+ from collections.abc import Iterable, Mapping, Sequence
  from typing import Any, Literal, final

  import torch

- from ..utils import set_storage_, TensorList, vec_to_tensors
- from .module import Module, Var, Chain, Chainable
+ from ..utils import TensorList, set_storage_, vec_to_tensors
+ from .module import Chain, Chainable, Module, Var

  Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

+
  class Transform(Module, ABC):
-     """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.
+     """Base class for a transform.
+     This is an abstract class, to use it, subclass it and override ``update_tensors`` and ``apply_tensors`` methods.

      A transform is a module that can also be applied manually to an arbitrary sequence of tensors.
+     It has two methods:
+
+     - ``update_tensors`` updates the internal state of this transform, it doesn't modify tensors. \
+       It may be called multiple times before ``apply_tensors``.
+     - ``apply_tensors`` applies this transform to tensors, without modifying the internal state if possible.
+
+     Alternatively, if update-apply structure doesn't make sense for a transform, all logic can be defined within ``apply_tensors``.
+
+     Transform can be applied to tensors corresponding to custom parameters
+     by calling ``keyed_transform_update`` and ``keyed_transform_apply``,
+     parameters will be keys to store per-parameter states, so they should remain the same python objects.
+
+     Alternatively you can manually create a list of state dictionaries per each tensor and pass it to
+     ``transform_update`` and ``transform_apply``.
+
+     A transform can modify the closure instead of directly modifying update by passing ``target="closure"``.

      Args:
          defaults (dict[str,Any] | None): dict with default values.
@@ -21,6 +39,7 @@ class Transform(Module, ABC):
              `grad` is always computed and can't be None. Otherwise set to False.
          target (Target, optional):
              what to set on var. Defaults to 'update'.
+
      """
      def __init__(
          self,
@@ -29,7 +48,6 @@ class Transform(Module, ABC):
          uses_loss: bool = False,
          concat_params: bool = False,
          update_freq: int = 1,
-         scale_first: bool = False,
          inner: Chainable | None = None,
          target: Target = 'update',
      ):
@@ -39,8 +57,8 @@ class Transform(Module, ABC):
          self._uses_loss = uses_loss
          self._concat_params = concat_params
          self._update_freq = update_freq
-         self._scale_first = scale_first
          self._inner = inner
+         self._var = None

      def update_tensors(
          self,
@@ -93,14 +111,6 @@ class Transform(Module, ABC):
              states = states[:num]
              settings = settings[:num]

-         scale_factor = 1
-
-         # scaling factor for 1st step
-         if self._scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
-
          # update transform
          if step % self._update_freq == 0:
              self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
@@ -109,7 +119,6 @@ class Transform(Module, ABC):
          self.global_state["__tensors"] = tensors
          self.global_state["__params"] = params
          self.global_state["__grads"] = grads
-         self.global_state["__scale_factor"] = scale_factor


      @final
@@ -140,23 +149,19 @@ class Transform(Module, ABC):
          tensors = self.global_state.pop("__tensors")
          params = self.global_state.pop("__params")
          grads = self.global_state.pop("__grads")
-         scale_factor = self.global_state.pop("__scale_factor")

          # step with inner
          if self._inner is not None:
-             tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+             tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads, var=self._var)
              if self._concat_params:
                  tensors = [torch.cat([t.ravel() for t in tensors])]

          # apply transform
          tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))

-         # scale initial step, when preconditioner might not have been applied
-         if self._scale_first and self.global_state['__step'] == 1:
-             torch._foreach_mul_(tensors, scale_factor)
-
          if self._concat_params:
              tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+
          return tensors

      def _get_keyed_states_settings(self, params: list[torch.Tensor]):
@@ -220,7 +225,9 @@ class Transform(Module, ABC):
          self.pre_step(var)

          # update
+         self._var = var
          self.keyed_transform_update(update, params, var.grad, var.loss)
+         self._var = None

      def apply(self, var: Var):
          if self._target != 'update':
@@ -234,7 +241,10 @@ class Transform(Module, ABC):
          params=var.params

          # apply
+         self._var = var
          var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+         self._var = None
+
          self.post_step(var)
          return var

@@ -246,12 +256,14 @@ class Transform(Module, ABC):
          if self._uses_loss: var.get_loss(False)
          params=var.params
          self.pre_step(var)
+         self._var = var

          # ---------------------------------- update ---------------------------------- #
          if self._target == 'update':
              update = var.get_update()
              self.keyed_transform_update(update, params, var.grad, var.loss)
              var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
+             self._var = None
              return var

          # ----------------------------------- grad ----------------------------------- #
@@ -259,6 +271,7 @@ class Transform(Module, ABC):
              grad = var.get_grad()
              self.keyed_transform_update(grad, params, grad, var.loss)
              var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
+             self._var = None
              return var

          # ------------------------------- params_direct ------------------------------ #
@@ -266,6 +279,7 @@ class Transform(Module, ABC):
              self.keyed_transform_update(var.params, params, var.grad, var.loss)
              new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
              for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+             self._var = None
              return var

          # ----------------------------- params_differnce ----------------------------- #
@@ -274,6 +288,7 @@ class Transform(Module, ABC):
              self.keyed_transform_update(p_clone, params, var.grad, var.loss)
              new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
              var.update = list(torch._foreach_sub(var.params, new_params))
+             self._var = None
              return var

          # ----------------------------- update_difference ---------------------------- #
@@ -283,6 +298,7 @@ class Transform(Module, ABC):
              self.keyed_transform_update(u_clone, params, var.grad, var.loss)
              new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
              var.update = list(torch._foreach_sub(update, new_update))
+             self._var = None
              return var

          # ---------------------------------- closure --------------------------------- #
@@ -291,12 +307,17 @@ class Transform(Module, ABC):
          if original_closure is None: raise ValueError('Target = "closure", but closure is None')

          params = var.params
+         parent_var = self._var
          def transformed_closure(backward=True):
              if backward:
                  loss = original_closure()
                  current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                 self._var = parent_var
                  self.keyed_transform_update(current_grad, params, var.grad, var.loss)
                  transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
+                 self._var = None
+
                  for p, g in zip(params, transformed_grad):
                      p.grad = g

@@ -307,6 +328,7 @@ class Transform(Module, ABC):

          var.closure = transformed_closure
          self.post_step(var)
+         self._var = None
          return var

          # ---------------------------------- invalid --------------------------------- #
@@ -316,7 +338,7 @@ class Transform(Module, ABC):
  class TensorwiseTransform(Transform, ABC):
      """Base class for a parameter-wise transform.

-     This is an abstract class, to use it, subclass it and override `transform`.
+     This is an abstract class, to use it, subclass it and override `update_tensor` and `apply_tensor`.

      Args:
          defaults (dict[str,Any] | None): dict with default values.
@@ -333,7 +355,6 @@ class TensorwiseTransform(Transform, ABC):
          uses_loss: bool = False,
          concat_params: bool = False,
          update_freq: int = 1,
-         scale_first: bool = False,
          inner: Chainable | None = None,
          target: Target = 'update',
      ):
@@ -342,7 +363,6 @@ class TensorwiseTransform(Transform, ABC):
              uses_grad=uses_grad,
              concat_params=concat_params,
              update_freq=update_freq,
-             scale_first=scale_first,
              uses_loss=uses_loss,
              inner=inner,
              target=target,
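The rewritten ``Transform`` docstring above splits the API into ``update_tensors`` and ``apply_tensors``. As a rough, unofficial sketch of what a subclass looks like under the new API (the ``apply_tensors`` signature is taken from the ``Adagrad`` code later in this diff; everything else is an assumption, not part of the release):

```python
# Illustrative only -- not shipped in 0.3.13. Signatures follow the
# apply_tensors(...) override visible in torchzero/modules/adaptive/adagrad.py.
import torch
from torchzero.core import Transform

class SignTransform(Transform):
    """Hypothetical transform that replaces each update tensor with its sign."""
    def __init__(self):
        # no per-parameter settings, and the gradient itself is not needed
        super().__init__(defaults={}, uses_grad=False)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # the docstring allows putting all logic in apply_tensors when an
        # update/apply split doesn't make sense for the transform
        return [t.sign() for t in tensors]
```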
torchzero/modules/__init__.py
@@ -1,15 +1,23 @@
+ from . import experimental
  from .clipping import *
+ from .conjugate_gradient import *
  from .grad_approximation import *
+ from .higher_order import *
+ from .least_squares import *
  from .line_search import *
- from .step_size import *
+ from .misc import *
  from .momentum import *
  from .ops import *
- from .optimizers import *
+ from .adaptive import *
  from .projections import *
  from .quasi_newton import *
+ from .second_order import *
  from .smoothing import *
+ from .step_size import *
+ from .termination import *
+ from .trust_region import *
+ from .variance_reduction import *
  from .weight_decay import *
  from .wrappers import *
- from .second_order import *
- from .higher_order import *
- from .misc import *
+ from .restarts import *
+ from .zeroth_order import *
torchzero/modules/adaptive/__init__.py
@@ -1,4 +1,4 @@
- from .adagrad import Adagrad, FullMatrixAdagrad
+ from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm

  # from .curveball import CurveBall
  # from .spectral import SpectralPreconditioner
@@ -6,12 +6,15 @@ from .adahessian import AdaHessian
  from .adam import Adam
  from .adan import Adan
  from .adaptive_heavyball import AdaptiveHeavyBall
+ from .aegd import AEGD
  from .esgd import ESGD
- from .ladagrad import LMAdagrad
+ from .lmadagrad import LMAdagrad
  from .lion import Lion
  from .mars import MARSCorrection
+ from .matrix_momentum import MatrixMomentum
  from .msam import MSAM, MSAMObjective
  from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+ from .natural_gradient import NaturalGradient
  from .orthograd import OrthoGrad, orthograd_
  from .rmsprop import RMSprop
  from .rprop import (
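Together with the ``optimizers → adaptive`` renames in the file list, these two ``__init__.py`` changes move the adaptive optimizers to a new package path. A quick, illustrative check using only names that appear in the ``__init__`` above (a sketch of the 0.3.13 layout, not release documentation):

```python
# Paths below come from the adaptive/__init__.py diff above.
from torchzero.modules.adaptive import (
    AdagradNorm,      # new export from adagrad.py
    AEGD,             # new module (aegd.py)
    LMAdagrad,        # renamed file: ladagrad.py -> lmadagrad.py
    MatrixMomentum,   # moved here from modules/momentum/matrix_momentum.py
    NaturalGradient,  # new module (natural_gradient.py)
)

# The old package path torchzero.modules.optimizers was renamed to
# torchzero.modules.adaptive, so direct imports of the old path would fail.
```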
torchzero/modules/adaptive/adagrad.py
@@ -0,0 +1,356 @@
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+ from ...core import (
+     Chainable,
+     Module,
+     Target,
+     TensorwiseTransform,
+     Transform,
+     Var,
+     apply_transform,
+ )
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ...utils.linalg import matrix_power_eigh
+ from ..functional import add_power_, lerp_power_, root, epsilon_step_size
+ from ...utils.linalg.linear_operator import Dense
+
+ def adagrad_(
+     tensors_: TensorList,
+     sq_sum_: TensorList,
+     alpha: float | NumberList,
+     lr_decay: float | NumberList,
+     eps: float | NumberList,
+     step: int,
+     pow: float = 2,
+     use_sqrt: bool = True,
+     divide: bool = False,
+
+     decay: float | None = None,
+     beta: float | None = None,
+
+     # inner args
+     inner: Module | None = None,
+     params: list[torch.Tensor] | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     """returns `tensors_`"""
+     clr = alpha / (1 + step * lr_decay)
+
+     if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+     else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
+     if decay is not None:
+         sq_sum_.mul_(1-decay)
+
+     if inner is not None:
+         assert params is not None
+         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+     if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
+     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+     return tensors_
+
+
+
+ class Adagrad(Transform):
+     """Adagrad, divides by sum of past squares of gradients.
+
+     This implementation is identical to ``torch.optim.Adagrad``.
+
+     Args:
+         lr_decay (float, optional): learning rate decay. Defaults to 0.
+         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+         eps (float, optional): division epsilon. Defaults to 1e-10.
+         alpha (float, optional): step size. Defaults to 1.
+         pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         lr_decay: float = 0,
+         initial_accumulator_value: float = 0,
+         eps: float = 1e-10,
+         alpha: float = 1,
+         pow: float = 2,
+         use_sqrt: bool = True,
+         divide: bool=False,
+         beta:float | None = None,
+         decay: float | None = None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                         eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+         super().__init__(defaults=defaults, uses_grad=False)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+         pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
+
+         sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
+
+         # initialize accumulator on 1st step
+         if step == 1:
+             sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
+
+         return adagrad_(
+             tensors,
+             sq_sum_=sq_sum,
+             alpha=alpha,
+             lr_decay=lr_decay,
+             eps=eps,
+             step=step,
+             pow=pow,
+             use_sqrt=use_sqrt,
+             divide=divide,
+
+             beta = self.defaults["beta"],
+             decay = self.defaults["decay"],
+             # inner args
+             inner=self.children.get("inner", None),
+             params=params,
+             grads=grads,
+         )
+
+
+ def lerp(start, end, weight):
+     return start + weight * (end - start)
+
+ def adagrad_norm_(
+     tensors_: TensorList,
+     accumulator: float | torch.Tensor,
+     alpha: float | NumberList,
+     lr_decay: float | NumberList,
+     eps: float | NumberList,
+     step: int,
+     use_sqrt: bool = True,
+     divide: bool = False,
+
+     decay: float | None = None,
+     beta: float | None = None,
+
+     # inner args
+     inner: Module | None = None,
+     params: list[torch.Tensor] | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     """returns `tensors_`"""
+     clr = alpha / (1 + step * lr_decay)
+
+     gg = tensors_.dot(tensors_)
+
+     if beta is None or step == 1: accumulator += gg
+     else: accumulator = lerp(accumulator, gg, 1-beta)
+
+     if decay is not None:
+         accumulator *= 1-decay
+
+     if inner is not None:
+         assert params is not None
+         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+     if divide: accumulator = accumulator / max(step, 1)
+
+     if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
+     else: tensors_.div_(eps + accumulator).mul_(clr)
+
+     return tensors_, accumulator
+
+ class AdagradNorm(Transform):
+     """Adagrad-Norm, divides by sum of past means of squares of gradients.
+
+     Args:
+         lr_decay (float, optional): learning rate decay. Defaults to 0.
+         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+         eps (float, optional): division epsilon. Defaults to 1e-10.
+         alpha (float, optional): step size. Defaults to 1.
+         pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         lr_decay: float = 0,
+         initial_accumulator_value: float = 0,
+         eps: float = 1e-10,
+         alpha: float = 1,
+         pow: float = 2,
+         use_sqrt: bool = True,
+         divide: bool=False,
+         beta:float | None = None,
+         decay: float | None = None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                         eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+         super().__init__(defaults=defaults, uses_grad=False)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+         use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])
+
+         accumulator = self.global_state.get("accumulator", initial_accumulator_value)
+
+         d, self.global_state["accumulator"] = adagrad_norm_(
+             tensors,
+             accumulator=accumulator,
+             alpha=alpha,
+             lr_decay=lr_decay,
+             eps=eps,
+             step=step,
+             use_sqrt=use_sqrt,
+             divide=divide,
+
+             beta = self.defaults["beta"],
+             decay = self.defaults["decay"],
+             # inner args
+             inner=self.children.get("inner", None),
+             params=params,
+             grads=grads,
+         )
+
+         return d
+
+
+ class FullMatrixAdagrad(TensorwiseTransform):
+     """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).
+
+     Note:
+         A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.
+
+     Args:
+         beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
+         decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
+         sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
+         concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
+         precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+         init (Literal[str], optional):
+             how to initialize the accumulator.
+             - "identity" - with identity matrix (default).
+             - "zeros" - with zero matrix.
+             - "ones" - with matrix of ones.
+             -"GGT" - with the first outer product
+         divide (bool, optional): whether to divide the accumulator by number of gradients in it. Defaults to False.
+         inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.
+
+     ## Examples:
+
+     Plain full-matrix adagrad
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.FullMatrixAdagrd(),
+         tz.m.LR(1e-2),
+     )
+     ```
+
+     Full-matrix RMSprop
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.FullMatrixAdagrad(beta=0.99),
+         tz.m.LR(1e-2),
+     )
+     ```
+
+     Full-matrix Adam
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
+         tz.m.Debias(0.9, 0.999),
+         tz.m.LR(1e-2),
+     )
+     ```
+     """
+     def __init__(
+         self,
+         beta: float | None = None,
+         decay: float | None = None,
+         sqrt: bool = True,
+         concat_params=True,
+         precond_freq: int = 1,
+         init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
+         reg: float = 1e-12,
+         divide: bool = False,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner,)
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         G = tensor.ravel()
+         GG = torch.outer(G, G)
+         decay = setting['decay']
+         beta = setting['beta']
+         init = setting['init']
+
+         if 'GG' not in state:
+             if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+             elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+             elif init == 'ones': state['GG'] = torch.ones_like(GG)
+             elif init == 'GGT': state['GG'] = GG.clone()
+             else: raise ValueError(init)
+         if decay is not None: state['GG'].mul_(decay)
+
+         if beta is not None: state['GG'].lerp_(GG, 1-beta)
+         else: state['GG'].add_(GG)
+         state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         step = state.get('step', 0)
+         state['step'] = step + 1
+
+         GG: torch.Tensor = state['GG']
+         sqrt = setting['sqrt']
+         divide = setting['divide']
+         precond_freq = setting['precond_freq']
+         reg = setting['reg']
+
+         if divide: GG = GG/state.get('i', 1)
+
+         if reg != 0:
+             GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)
+
+         if tensor.numel() == 1:
+             GG = GG.squeeze()
+             if sqrt: return tensor / GG.sqrt()
+             return tensor / GG
+
+         try:
+             if sqrt:
+                 if "B" not in state or step % precond_freq == 0:
+                     B = state["B"] = matrix_power_eigh(GG, -1/2)
+                 else:
+                     B = state["B"]
+
+             else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
+         except torch.linalg.LinAlgError:
+             # fallback to diagonal AdaGrad
+             denom = GG.diagonal()
+             if sqrt: denom = denom.sqrt()
+             return tensor.div_(denom + max(reg, 1e-12))
+
+         return (B @ tensor.ravel()).view_as(tensor)
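For comparison with the ``FullMatrixAdagrad`` docstring examples above, here is a minimal usage sketch of the new ``AdagradNorm`` module, assuming the same ``import torchzero as tz`` alias those examples rely on (illustrative only, not taken from the release docs):

```python
import torch
import torchzero as tz  # alias assumed, matching the tz.m.* examples above

model = torch.nn.Linear(10, 1)

# AdagradNorm keeps a single scalar accumulator of squared gradient norms and
# divides the whole update by its (optional) square root; LR sets the step size.
opt = tz.Modular(
    model.parameters(),
    tz.m.AdagradNorm(),
    tz.m.LR(1e-2),
)
```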