torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/clipping/ema_clipping.py

@@ -5,7 +5,7 @@ from collections.abc import Iterable, Sequence
  import torch

  from ...core import Module, Target, Transform, apply_transform, Chainable
- from ...utils import NumberList, TensorList, generic_eq, unpack_dicts, unpack_states
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states

  class ClipNormByEMA(Transform):
  """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ord (float, optional): order of the norm. Defaults to 2.
  eps (float, optional): epsilon for division. Defaults to 1e-6.
- tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+ tensorwise (bool, optional):
+ if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
  max_ema_growth (float | None, optional):
- if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+ if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
  """
@@ -29,12 +29,13 @@ class ClipNormByEMA(Transform):
  tensorwise:bool=True,
  max_ema_growth: float | None = 1.5,
  ema_init: Literal['zeros', 'update'] = 'zeros',
+ inner: Chainable | None = None,
  ):
  defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults, inner=inner)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
  tensors = TensorList(tensors)
  ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])

@@ -78,7 +80,12 @@ class ClipNormByEMA(Transform):
  if self.NORMALIZE: denom.clip_(min=eps[0])
  else: denom.clip_(min=1)

- tensors.div_(denom)
+ self.global_state['denom'] = denom
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ denom = self.global_state.pop('denom')
+ torch._foreach_div_(tensors, denom)
  return tensors

  class NormalizeByEMA(ClipNormByEMA):
@@ -88,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ord (float, optional): order of the norm. Defaults to 2.
  eps (float, optional): epsilon for division. Defaults to 1e-6.
- tensorwise (bool, optional): whether to calculate norm separately for each layer, or global norm for all layers. Defaults to True.
+ tensorwise (bool, optional):
+ if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
  max_ema_growth (float | None, optional):
- if specified, exponential moving average norm can grow but at most this value per step. Defaults to 1.5.
+ if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
  """
@@ -99,28 +107,30 @@ class NormalizeByEMA(ClipNormByEMA):
  # TODO Centralize by EMA?

  class ClipValueByEMA(Transform):
- """Clips magnitude of update to be no larger than magnitude of an exponential moving average of past (unclipped) updates.
+ """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

  Args:
  beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
  ema_init (str, optional):
  How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
- ema_tfm (Chainable | None, optional): optional modules applied to exponential moving average before clipping by it. Defaults to None.
+ ema_tfm (Chainable | None, optional):
+ optional modules applied to exponential moving average before clipping by it. Defaults to None.
  """
  def __init__(
  self,
  beta=0.99,
  ema_init: Literal['zeros', 'update'] = 'zeros',
  ema_tfm:Chainable | None=None,
+ inner: Chainable | None = None,
  ):
  defaults = dict(beta=beta, ema_init=ema_init)
- super().__init__(defaults, uses_grad=False)
+ super().__init__(defaults, inner=inner)

  if ema_tfm is not None:
  self.set_child('ema_tfm', ema_tfm)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
  ema_init = itemgetter('ema_init')(settings[0])

  beta = unpack_dicts(settings, 'beta', cls=NumberList)
@@ -129,8 +139,12 @@ class ClipValueByEMA(Transform):
  ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
  ema.lerp_(tensors.abs(), 1-beta)

+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ tensors = TensorList(tensors)
+ ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
  if 'ema_tfm' in self.children:
- ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))
+ ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))

  tensors.clip_(-ema, ema)
  return tensors
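
The recurring change across these clipping modules splits the old single `apply` hook into `update_tensors` (which updates internal state) and `apply_tensors` (which produces the output), with an optional `inner` child applied in between. A minimal sketch of a custom Transform written against that split, using only names that appear in the hunks above; the exact base-class behaviour is an assumption, not documented API:

import torch
from torchzero.core import Transform
from torchzero.utils import TensorList, unpack_states

class ClipValueByAbsEMA(Transform):
    """Hypothetical example: clip update magnitudes by an EMA of their absolute values."""
    def __init__(self, beta: float = 0.99, inner=None):
        super().__init__(dict(beta=beta), inner=inner)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        # state update only; clipping happens in apply_tensors, possibly after `inner` runs
        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
        ema.lerp_(TensorList(tensors).abs(), 1 - settings[0]['beta'])

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
        return TensorList(tensors).clip_(-ema, ema)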
torchzero/modules/clipping/growth_clipping.py

@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
  bounds the tracked multiplicative clipping decay to prevent collapse to 0.
  Next update is at most :code:`max(previous update * mul, max_decay)`.
  Defaults to 2.
- target (Target, optional): what to set on var.. Defaults to "update".
+ target (Target, optional): what to set on var. Defaults to "update".
  """
  def __init__(
  self,
@@ -30,11 +30,11 @@ class ClipValueGrowth(TensorwiseTransform):
  target: Target = "update",
  ):
  defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

- def apply_tensor(self, tensor, param, grad, loss, state, settings):
- add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(settings)
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
  add: float | None

  if add is None and mul is None:
@@ -120,7 +120,8 @@ class ClipNormGrowth(Transform):

  Args:
  add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
- mul (float | None, optional): multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
+ mul (float | None, optional):
+ multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
  min_value (float | None, optional):
  minimum value for multiplicative clipping to prevent collapse to 0.
  Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
@@ -144,11 +145,11 @@ class ClipNormGrowth(Transform):
  target: Target = "update",
  ):
  defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, target=target)

- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  parameterwise = settings[0]['parameterwise']
  tensors = TensorList(tensors)
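
For context, `ClipNormGrowth` bounds how fast the update norm may grow between consecutive steps (next norm at most `previous norm * mul`, or `previous norm + add`). A sketch of how it might be chained after another module, in the style of the usage example further down this diff; `tz.m.Adam` and `tz.m.LR` are assumed names based on the optimizers and step_size modules in the file list, not verified API:

import torchzero as tz

# hypothetical: Adam update, then limit norm growth to 1.5x per step, then scale by lr
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.ClipNormGrowth(add=None, mul=1.5),
    tz.m.LR(1e-3),
)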
torchzero/modules/experimental/__init__.py

@@ -1,24 +1,41 @@
+ """This submodule contains various untested experimental modules, some of them are to be moved out of experimental when properly tested, some are to remain here forever or to be deleted depending on the degree of their usefulness."""
  from .absoap import ABSOAP
  from .adadam import Adadam
+ from .adam_lambertw import AdamLambertW
  from .adamY import AdamY
+ from .adaptive_step_size import AdaptiveStepSize
  from .adasoap import AdaSOAP
+ from .cosine import (
+ AdaptiveDifference,
+ AdaptiveDifferenceEMA,
+ CosineDebounce,
+ CosineMomentum,
+ CosineStepSize,
+ ScaledAdaptiveDifference,
+ )
+ from .cubic_adam import CubicAdam
  from .curveball import CurveBall
+
+ # from dct import DCTProjection
  from .eigendescent import EigenDescent
  from .etf import (
  ExponentialTrajectoryFit,
  ExponentialTrajectoryFitV2,
  PointwiseExponential,
  )
+ from .exp_adam import ExpAdam
+ from .expanded_lbfgs import ExpandedLBFGS
+ from .fft import FFTProjection
  from .gradmin import GradMin
+ from .hnewton import HNewton
+ from .modular_lbfgs import ModularLBFGS
  from .newton_solver import NewtonSolver
  from .newtonnewton import NewtonNewton
+ from .parabolic_search import CubicParabolaSearch, ParabolaSearch
  from .reduce_outward_lr import ReduceOutwardLR
- from .soapy import SOAPY
- from .spectral import SpectralPreconditioner
- from .structured_newton import StructuredNewton
+ from .structural_projections import BlockPartition, TensorizeProjection
  from .subspace_preconditioners import (
  HistorySubspacePreconditioning,
  RandomSubspacePreconditioning,
  )
- from .tada import TAda
- from .diagonal_higher_order_newton import DiagonalHigherOrderNewton
+ from .tensor_adagrad import TensorAdagrad
torchzero/modules/experimental/absoap.py

@@ -24,7 +24,10 @@ def update_absoap_covariances_(

  Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
  class ABSOAP(Transform):
- """SOAP but with some extra options for testing. Please note that this is experimental and isn't guaranteed to work.
+ """SOAP but with some extra options for testing.
+
+ .. warning::
+ This module is just for testing my stupid ideas.

  Args:
  scale_by_s - whether to scale y by s
@@ -94,7 +97,7 @@ class ABSOAP(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  updates = []
  # update preconditioners
  for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
torchzero/modules/experimental/adadam.py

@@ -10,7 +10,7 @@ from ..functional import (
  ema_,
  sqrt_ema_sq_,
  )
- from ..lr.lr import lazy_lr
+ from ..step_size.lr import lazy_lr
  from ..momentum.experimental import sqrt_nag_ema_sq_
  from ..momentum.momentum import nag_

@@ -50,7 +50,13 @@ def adadam_(
  return None

  class Adadam(Module):
- """Adam with a diagonally preconditioned preconditioner. Please note that this is experimental and isn't guaranteed to work."""
+ """Adam with a diagonally preconditioned preconditioner.
+
+ Verdict: I haven't tested this yet.
+
+ .. warning::
+ Experimental.
+ """
  def __init__(
  self,
  beta1: float = 0.9,
torchzero/modules/experimental/adamY.py

@@ -10,7 +10,7 @@ from ..functional import (
  ema_,
  sqrt_ema_sq_,
  )
- from ..lr.lr import lazy_lr
+ from ..step_size.lr import lazy_lr
  from ..momentum.experimental import sqrt_nag_ema_sq_
  from ..momentum.momentum import nag_

@@ -62,7 +62,13 @@ def adamy_(
  return None

  class AdamY(Module):
- """Adam but uses scaled gradient differences for second momentum. Please note that this is experimental and isn't guaranteed to work."""
+ """Adam but uses scaled gradient differences for second momentum.
+
+ Verdict: I haven't tested this yet.
+
+ .. warning::
+ Experimental.
+ """
  def __init__(
  self,
  beta1: float = 0.9,
torchzero/modules/experimental/adam_lambertw.py

@@ -0,0 +1,149 @@
+ from operator import itemgetter
+ from functools import partial
+ import math
+ import torch
+
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+ debias, debiased_step_size,
+ ema_,
+ sqrt_ema_sq_,
+ )
+ from ..step_size.lr import lazy_lr
+ from ..momentum.experimental import sqrt_nag_ema_sq_
+ from ..momentum.momentum import nag_
+
+
+ def _lambertw_newton_raphson(x: TensorList, iterations=5):
+ # z = torch.zeros_like(x)
+ # mask_neg = x < 0
+ # mask_pos = ~mask_neg
+
+ # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
+
+ # x_neg = x[mask_neg]
+ # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
+ # z[mask_neg] = z_neg
+
+ # x is always positive
+ z = (x+1).log_()
+ for _ in range(iterations):
+ exp_z = z.exp()
+ numerator = z * exp_z - x
+ denominator = exp_z * (z + 1.0) + 1e-8
+ delta = numerator / denominator
+ z -= delta
+ return z
+
+ # https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
+ def _lambertw_winitzki(x: TensorList):
+ x_log1p = x.log1p()
+ return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
+
+
+ def adam_lambertw_(
+ tensors: TensorList,
+ exp_avg_: TensorList,
+ exp_avg_xpx_: TensorList,
+ alpha: float | NumberList,
+ beta1: float | NumberList,
+ beta2: float | NumberList,
+ eps: float | NumberList,
+ step: int,
+ pow: float = 2,
+ debiased: bool = True,
+ max_exp_avg_xpx_: TensorList | None = None,
+ iterations: int | None = 5,
+
+ # inner args
+ inner: Module | None = None,
+ params: list[torch.Tensor] | None = None,
+ grads: list[torch.Tensor] | None = None,
+ ):
+ """Returns new tensors."""
+ tensors_abs = tensors.abs().clip_(max=20)
+ tensors_xpx = tensors_abs.pow_(tensors_abs)
+ exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
+
+ if max_exp_avg_xpx_ is not None:
+ max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
+ exp_avg_xpx_ = max_exp_avg_xpx_
+
+ if inner is not None:
+ assert params is not None
+ tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+ exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+ if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+ if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
+ else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
+
+ return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
+
+ class AdamLambertW(Transform):
+ """Adam but uses abs x^x and LambertW instead of square and sqrt.
+ The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
+
+ Args:
+ beta1 (float, optional): momentum. Defaults to 0.9.
+ beta2 (float, optional): second momentum. Defaults to 0.999.
+ eps (float, optional): epsilon. Defaults to 1e-8.
+ alpha (float, optional): learning rate. Defaults to 1.
+ amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+ pow (float, optional): power used in second momentum power and root. Defaults to 2.
+ debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+ iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
+ """
+ def __init__(
+ self,
+ beta1: float = 0.9,
+ beta2: float = 0.999,
+ eps: float = 1e-8,
+ amsgrad: bool = False,
+ alpha: float = 1.,
+ pow: float = 2,
+ debiased: bool = True,
+ iterations: int | None = 5,
+ inner: Chainable | None = None
+ ):
+ defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
+ super().__init__(defaults, uses_grad=False)
+
+ if inner is not None: self.set_child('inner', inner)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+ beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+ amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
+
+ if amsgrad:
+ exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
+ else:
+ exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
+ max_exp_avg_xpx = None
+
+ return adam_lambertw_(
+ tensors=TensorList(tensors),
+ exp_avg_=exp_avg,
+ exp_avg_xpx_=exp_avg_xpx,
+ alpha=alpha,
+ beta1=beta1,
+ beta2=beta2,
+ eps=eps,
+ step=step,
+ pow=pow,
+ debiased=debiased,
+ max_exp_avg_xpx_=max_exp_avg_xpx,
+ iterations=iterations,
+
+ # inner args
+ inner=self.children.get("inner", None),
+ params=params,
+ grads=grads,
+ )
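
The Newton-Raphson helper in this new file computes the Lambert W function W(x), the inverse of z ↦ z·e^z, for the non-negative EMA values. A standalone sketch of the same iteration on a plain tensor (not library code), useful for sanity-checking that the result satisfies z·e^z ≈ x:

import torch

# same Newton-Raphson iteration as _lambertw_newton_raphson above, on a plain tensor
def lambertw_nr(x: torch.Tensor, iterations: int = 5) -> torch.Tensor:
    z = torch.log(x + 1.0)  # initial guess, fine for x >= 0
    for _ in range(iterations):
        exp_z = torch.exp(z)
        z = z - (z * exp_z - x) / (exp_z * (z + 1.0) + 1e-8)
    return z

x = torch.tensor([0.5, 1.0, 5.0])
w = lambertw_nr(x)
print(w * torch.exp(w))  # prints values close to [0.5, 1.0, 5.0]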
torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py}

@@ -2,12 +2,16 @@ from operator import itemgetter

  import torch

- from .line_search import LineSearch
+ from ..line_search import LineSearchBase

- class TrustRegion(LineSearch):
- """Basic first order trust region method. Re-evaluates the function after stepping, if value decreased sufficiently,
- step size is increased. If value increased, step size is decreased. This is prone to collapsing.
+ class AdaptiveStepSize(LineSearchBase):
+ """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
+ step size is increased. If value increased, step size is decreased.
+
+ .. note::
+ This works well in some cases, but it is often prone to collapsing.
+ For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.

  Args:
  nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
@@ -18,6 +22,19 @@ class TrustRegion(LineSearch):
  adaptive (bool, optional):
  If enabled, when multiple consecutive steps have been successful or unsuccessful,
  the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
+
+
+ Examples:
+ Adagrad with trust region:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adagrad(),
+ tz.m.TrustRegion()
+ )
+
  """
  def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
  defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
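
Given the rename to `AdaptiveStepSize`, exported from `torchzero.modules.experimental` per the `__init__.py` hunk above, the equivalent of the docstring example under the new name would presumably look as follows; the import path is taken from the diff, while the rest of the setup mirrors the example shown above:

import torchzero as tz
from torchzero.modules.experimental import AdaptiveStepSize

# Adagrad update followed by first-order step size adaptation
opt = tz.Modular(
    model.parameters(),
    tz.m.Adagrad(),
    AdaptiveStepSize(nplus=1.5, nminus=0.75),
)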
torchzero/modules/experimental/adasoap.py

@@ -33,9 +33,14 @@ def update_adasoap_covariances_(

  class AdaSOAP(Transform):
- """SOAP with diagonally preconditioned GG^Ts. Please note that this is experimental and isn't guaranteed to work.
+ """SOAP with diagonally preconditioned GG^Ts.
+
+ .. warning::
+ Experimental.

  precond_beta - beta for GG^T squares
+
+ Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
  """
  def __init__(
  self,
@@ -71,7 +76,7 @@ class AdaSOAP(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  updates = []
  # update preconditioners
  for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):