torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/momentum/cautious.py

@@ -48,16 +48,25 @@ class Cautious(Transform):
  eps (float, optional): epsilon for normalization. Defaults to 1e-6.
  mode (str, optional):
  what to do with updates with inconsistent signs.
+ - "zero" - set them to zero (as in paper)
+ - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+ - "backtrack" - negate them

- "zero" - set them to zero (as in paper)
+ ## Examples:

- "grad" - set them to the gradient
+ Cautious Adam

- "backtrack" - negate them (same as using update magnitude and gradient sign)
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.Adam(),
+ tz.m.Cautious(),
+ tz.m.LR(1e-2)
+ )
+ ```

- reference
- *Cautious Optimizers: Improving Training with One Line of Code.
- Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+ References:
+ Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
  """

  def __init__(
@@ -70,7 +79,7 @@ class Cautious(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
  return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
@@ -89,7 +98,7 @@ class UpdateGradientSignConsistency(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  normalize, eps = itemgetter('normalize', 'eps')(settings[0])

@@ -109,12 +118,9 @@ class IntermoduleCautious(Module):
  eps (float, optional): epsilon for normalization. Defaults to 1e-6.
  mode (str, optional):
  what to do with updates with inconsistent signs.
-
- "zero" - set them to zero (as in paper)
-
- "grad" - set them to the gradient
-
- "backtrack" - negate them (same as using update magnitude and gradient sign)
+ - "zero" - set them to zero (as in paper)
+ - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+ - "backtrack" - negate them
  """
  def __init__(
  self,
@@ -142,7 +148,7 @@ class IntermoduleCautious(Module):
  compare_var = compare.step(var.clone(clone_update=True))
  var.update_attrs_from_clone_(compare_var)

- mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+ mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
  var.update = cautious_(
  TensorList(main_var.get_update()),
  TensorList(compare_var.get_update()),
@@ -159,6 +165,18 @@ class ScaleByGradCosineSimilarity(Transform):

  Args:
  eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+ ## Examples:
+
+ Scaled Adam
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.Adam(),
+ tz.m.ScaleByGradCosineSimilarity(),
+ tz.m.LR(1e-2)
+ )
+ ```
  """
  def __init__(
  self,
@@ -168,12 +186,12 @@ class ScaleByGradCosineSimilarity(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  eps = settings[0]['eps']
  tensors = TensorList(tensors)
  grads = TensorList(grads)
- cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
+ cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)

  return tensors.mul_(cos_sim)

@@ -185,6 +203,20 @@ class ScaleModulesByCosineSimilarity(Module):
  main (Chainable): main module or sequence of modules whose update will be scaled.
  compare (Chainable): module or sequence of modules to compare to
  eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+ ## Examples:
+
+ Adam scaled by similarity to RMSprop
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.ScaleModulesByCosineSimilarity(
+ main = tz.m.Adam(),
+ compare = tz.m.RMSprop(0.999, debiased=True),
+ ),
+ tz.m.LR(1e-2)
+ )
+ ```
  """
  def __init__(
  self,
@@ -211,9 +243,9 @@ class ScaleModulesByCosineSimilarity(Module):

  m = TensorList(main_var.get_update())
  c = TensorList(compare_var.get_update())
- eps = self.settings[var.params[0]]['eps']
+ eps = self.defaults['eps']

- cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
+ cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

  var.update = m.mul_(cos_sim)
  return var
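The "zero" mode documented above corresponds to the masking rule from the referenced paper: update components whose sign disagrees with the gradient are dropped, and the surviving components are rescaled so the overall magnitude is roughly preserved. A minimal single-tensor sketch of that rule in plain PyTorch (independent of torchzero's `cautious_` helper, whose exact normalization is not shown in this diff):

```python
import torch

def cautious_zero(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # keep only components where the update and the gradient agree in sign ("zero" mode)
    mask = (update * grad > 0).to(update.dtype)
    # rescale so the mean mask value is ~1, in the spirit of the paper's normalization
    mask = mask / mask.mean().clamp(min=eps)
    return update * mask
```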
torchzero/modules/momentum/momentum.py

@@ -1,10 +1,44 @@
+ from collections import deque
+ from operator import itemgetter
  from typing import Literal

  import torch

  from ...core import Target, Transform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from .ema import EMA
+ from ..functional import debias, ema_
+
+
+ class EMA(Transform):
+ """Maintains an exponential moving average of update.
+
+ Args:
+ momentum (float, optional): momentum (beta). Defaults to 0.9.
+ dampening (float, optional): momentum dampening. Defaults to 0.
+ debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+ lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+ ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+ target (Target, optional): target to apply EMA to. Defaults to 'update'.
+ """
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+ defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+ super().__init__(defaults, uses_grad=False, target=target)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+ debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+
+ exp_avg = unpack_states(states, tensors, 'exp_avg',
+ init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+ momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
+
+ exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
+
+ if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+ else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+


  class HeavyBall(EMA):
@@ -55,9 +89,10 @@ class NAG(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
  lerp = self.settings[params[0]]['lerp']

  momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
  return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
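The EMA transform added to momentum.py composes like the other modules shown in this diff's docstring examples. A usage sketch, assuming it is exported as `tz.m.EMA` (the export is not visible in this diff) and using a stand-in model:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # hypothetical stand-in for any nn.Module

# EMA of the raw update (heavy-ball style momentum), followed by a fixed learning rate
opt = tz.Modular(
    model.parameters(),
    tz.m.EMA(momentum=0.9, debiased=True),  # debiased like Adam's first moment
    tz.m.LR(1e-2),
)
```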
torchzero/modules/ops/__init__.py

@@ -7,7 +7,7 @@ from .accumulate import (
  )
  from .binary import (
  Add,
- BinaryOperation,
+ BinaryOperationBase,
  Clip,
  CopyMagnitude,
  CopySign,
@@ -27,37 +27,20 @@ from .binary import (
  Sub,
  Threshold,
  )
- from .debug import PrintShape, PrintUpdate
- from .misc import (
- DivByLoss,
- Dropout,
- FillLoss,
- GradientAccumulation,
- GradSign,
- GraftGradToUpdate,
- GraftToGrad,
- GraftToParams,
- LastAbsoluteRatio,
- LastDifference,
- LastGradDifference,
- LastProduct,
- LastRatio,
- MulByLoss,
- Multistep,
- NegateOnLossIncrease,
- NoiseSign,
- Previous,
- Relative,
- Sequential,
- UpdateSign,
- WeightDropout,
+ from .higher_level import (
+ CenteredEMASquared,
+ CenteredSqrtEMASquared,
+ Debias,
+ Debias2,
+ EMASquared,
+ SqrtEMASquared,
  )
  from .multi import (
  ClipModules,
  DivModules,
  GraftModules,
  LerpModules,
- MultiOperation,
+ MultiOperationBase,
  PowModules,
  SubModules,
  )
@@ -66,13 +49,11 @@ from .reduce import (
  Mean,
  MinimumModules,
  Prod,
- ReduceOperation,
+ ReduceOperationBase,
  Sum,
  WeightedMean,
  WeightedSum,
  )
- from .split import Split
- from .switch import Alternate, Switch
  from .unary import (
  Abs,
  CustomUnaryOperation,
@@ -91,13 +72,12 @@ from .utility import (
  Grad,
  GradToNone,
  Identity,
- NoOp,
+ Noop,
  Ones,
  Params,
  Randn,
  RandomSample,
  Uniform,
- Update,
  UpdateToNone,
  Zeros,
  )
torchzero/modules/ops/accumulate.py

@@ -1,11 +1,7 @@
- from collections import deque
- from operator import itemgetter
- from typing import Literal
-
  import torch

  from ...core import Target, Transform
- from ...utils import TensorList, NumberList, unpack_states, unpack_dicts
+ from ...utils import TensorList, unpack_states

  class AccumulateSum(Transform):
  """Accumulates sum of all past updates.
@@ -19,7 +15,7 @@ class AccumulateSum(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  sum = unpack_states(states, tensors, 'sum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return sum.add_(tensors).lazy_mul(decay, clone=True)
@@ -36,7 +32,7 @@ class AccumulateMean(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1
  mean = unpack_states(states, tensors, 'mean', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
@@ -54,7 +50,7 @@ class AccumulateProduct(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  prod = unpack_states(states, tensors, 'prod', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return prod.mul_(tensors).lazy_mul(decay, clone=True)
@@ -71,7 +67,7 @@ class AccumulateMaximum(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
@@ -88,7 +84,7 @@ class AccumulateMinimum(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
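All five accumulators above follow the same pattern: a persistent per-parameter buffer is updated in place, and the value returned each step is the buffer scaled by `1 - decay` (the list comprehension computes `1 - s['decay']`). A standalone single-tensor sketch of the AccumulateSum arithmetic in plain PyTorch, outside the torchzero Transform API:

```python
import torch

def accumulate_sum_step(buffer: torch.Tensor, update: torch.Tensor, decay: float) -> torch.Tensor:
    # the buffer keeps the raw running sum of all updates (sum.add_(tensors))
    buffer.add_(update)
    # the returned update is the buffer scaled by (1 - decay); the buffer itself stays unscaled
    return buffer * (1 - decay)
```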
torchzero/modules/ops/binary.py

@@ -1,5 +1,4 @@
  #pyright: reportIncompatibleMethodOverride=false
- """"""
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
@@ -11,7 +10,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
  from ...utils import TensorList, tensorlist


- class BinaryOperation(Module, ABC):
+ class BinaryOperationBase(Module, ABC):
  """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
  super().__init__(defaults=defaults)
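The docstring above spells out the subclassing contract: operands (numbers or modules) are passed to `__init__` as keyword arguments, and subclasses override `transform`, which receives the current update plus each operand's resolved value. A hypothetical subclass modelled on the `Add`/`Sub` classes later in this diff (the class name and behaviour are illustrative, not part of the package; the import path is assumed from the file layout):

```python
import torch

from torchzero.core import Chainable
from torchzero.modules.ops.binary import BinaryOperationBase  # assumed import path

class AbsMaximum(BinaryOperationBase):
    """Hypothetical op: elementwise maximum of |update| and |other(update)|."""
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)  # operand modules/values are keyword arguments

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        # `other` arrives either as the literal number or as the operand module's output tensors
        if isinstance(other, (int, float)):
            return [u.abs().clamp_(min=abs(other)) for u in update]
        return [torch.maximum(u.abs(), o.abs()) for u, o in zip(update, other)]
```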
@@ -47,29 +46,41 @@ class BinaryOperation(Module, ABC):
  return var


- class Add(BinaryOperation):
+ class Add(BinaryOperationBase):
+ """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
- else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
+ if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
+ else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
  return update

- class Sub(BinaryOperation):
+ class Sub(BinaryOperationBase):
+ """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
- else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
+ if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
+ else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
  return update

- class RSub(BinaryOperation):
+ class RSub(BinaryOperationBase):
+ """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -77,7 +88,11 @@ class RSub(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other - TensorList(update)

- class Mul(BinaryOperation):
+ class Mul(BinaryOperationBase):
+ """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -86,7 +101,11 @@ class Mul(BinaryOperation):
  torch._foreach_mul_(update, other)
  return update

- class Div(BinaryOperation):
+ class Div(BinaryOperationBase):
+ """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -95,7 +114,11 @@ class Div(BinaryOperation):
  torch._foreach_div_(update, other)
  return update

- class RDiv(BinaryOperation):
+ class RDiv(BinaryOperationBase):
+ """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -103,7 +126,11 @@ class RDiv(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other / TensorList(update)

- class Pow(BinaryOperation):
+ class Pow(BinaryOperationBase):
+ """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+ If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+ """
  def __init__(self, exponent: Chainable | float):
  super().__init__({}, exponent=exponent)

@@ -112,7 +139,11 @@ class Pow(BinaryOperation):
  torch._foreach_pow_(update, exponent)
  return update

- class RPow(BinaryOperation):
+ class RPow(BinaryOperationBase):
+ """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -122,7 +153,11 @@ class RPow(BinaryOperation):
  torch._foreach_pow_(other, update)
  return other

- class Lerp(BinaryOperation):
+ class Lerp(BinaryOperationBase):
+ """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+ The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+ """
  def __init__(self, end: Chainable, weight: float):
  defaults = dict(weight=weight)
  super().__init__(defaults, end=end)
@@ -132,7 +167,8 @@ class Lerp(BinaryOperation):
  torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
  return update

- class CopySign(BinaryOperation):
+ class CopySign(BinaryOperationBase):
+ """Returns tensors with sign copied from :code:`other(tensors)`."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -140,7 +176,8 @@ class CopySign(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  return [u.copysign_(o) for u, o in zip(update, other)]

- class RCopySign(BinaryOperation):
+ class RCopySign(BinaryOperationBase):
+ """Returns :code:`other(tensors)` with sign copied from tensors."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -149,7 +186,11 @@ class RCopySign(BinaryOperation):
  return [o.copysign_(u) for u, o in zip(update, other)]
  CopyMagnitude = RCopySign

- class Clip(BinaryOperation):
+ class Clip(BinaryOperationBase):
+ """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+
+ If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+ """
  def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
  super().__init__({}, min=min, max=max)

@@ -157,8 +198,11 @@ class Clip(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
  return TensorList(update).clamp_(min=min, max=max)

- class MirroredClip(BinaryOperation):
- """clip by -value, value"""
+ class MirroredClip(BinaryOperationBase):
+ """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+ If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+ """
  def __init__(self, value: float | Chainable):
  super().__init__({}, value=value)

@@ -167,19 +211,19 @@ class MirroredClip(BinaryOperation):
  min = -value if isinstance(value, (int,float)) else [-v for v in value]
  return TensorList(update).clamp_(min=min, max=value)

- class Graft(BinaryOperation):
- """use direction from update and magnitude from `magnitude` module"""
+ class Graft(BinaryOperationBase):
+ """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
  def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
  super().__init__(defaults, magnitude=magnitude)

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
  return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

- class RGraft(BinaryOperation):
- """use direction from `direction` module and magnitude from update"""
+ class RGraft(BinaryOperationBase):
+ """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

  def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
@@ -187,12 +231,13 @@ class RGraft(BinaryOperation):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
  return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

  GraftToUpdate = RGraft

- class Maximum(BinaryOperation):
+ class Maximum(BinaryOperationBase):
+ """Outputs :code:`maximum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -201,7 +246,8 @@ class Maximum(BinaryOperation):
  torch._foreach_maximum_(update, other)
  return update

- class Minimum(BinaryOperation):
+ class Minimum(BinaryOperationBase):
+ """Outputs :code:`minimum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -211,26 +257,27 @@ class Minimum(BinaryOperation):
  return update


- class GramSchimdt(BinaryOperation):
- """makes update orthonormal to `other`"""
+ class GramSchimdt(BinaryOperationBase):
+ """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  update = TensorList(update); other = TensorList(other)
- return update - (other*update) / ((other*other) + 1e-8)
+ min = torch.finfo(update[0].dtype).tiny * 2
+ return update - (other*update) / (other*other).clip(min=min)


- class Threshold(BinaryOperation):
- """update above/below threshold, value at and below"""
+ class Threshold(BinaryOperationBase):
+ """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
  def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
  defaults = dict(update_above=update_above)
  super().__init__(defaults, threshold=threshold, value=value)

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
- update_above = self.settings[var.params[0]]['update_above']
+ update_above = self.defaults['update_above']
  update = TensorList(update)
  if update_above:
  if isinstance(value, list): return update.where_(update>threshold, value)