torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff reflects the content of publicly released versions of the package as published to its public registry, and is provided for informational purposes only.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,27 @@
  """Modules that perform averaging over a history of past updates."""
  from collections import deque
  from collections.abc import Sequence
- from typing import Any, Literal, cast
+ from typing import Any

  import torch

- from ...core import TensorwiseTransform, Target
+ from ...core import TensorTransform
  from ...utils import tolist


- class Averaging(TensorwiseTransform):
+ class Averaging(TensorTransform):
  """Average of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, history_size: int, target: Target = 'update'):
+ def __init__(self, history_size: int):
  defaults = dict(history_size=history_size)
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  history_size = setting['history_size']
  if 'history' not in state:
  state['history'] = deque(maxlen=history_size)
@@ -34,19 +34,19 @@ class Averaging(TensorwiseTransform):

  return average / len(history)

- class WeightedAveraging(TensorwiseTransform):
+ class WeightedAveraging(TensorTransform):
  """Weighted average of past ``len(weights)`` updates.

  Args:
  weights (Sequence[float]): a sequence of weights from oldest to newest.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
+ def __init__(self, weights: Sequence[float] | torch.Tensor | Any):
  defaults = dict(weights = tolist(weights))
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  weights = setting['weights']

  if 'history' not in state:
@@ -68,19 +68,19 @@ class WeightedAveraging(TensorwiseTransform):
  return average


- class MedianAveraging(TensorwiseTransform):
+ class MedianAveraging(TensorTransform):
  """Median of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, history_size: int, target: Target = 'update'):
+ def __init__(self, history_size: int,):
  defaults = dict(history_size = history_size)
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  history_size = setting['history_size']

  if 'history' not in state:
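The three hunks above (apparently from `torchzero/modules/momentum/averaging.py`, item 97 in the file list) show the recurring 0.4.x pattern: `TensorwiseTransform` becomes `TensorTransform`, the `target=`/`uses_grad=` constructor arguments are dropped, and the per-tensor hook `apply_tensor` is renamed to `single_tensor_apply`. Below is a minimal sketch of a custom transform written against the renamed hooks; the `ClampUpdate` class, its `limit` setting, and the public import path `torchzero.core` are illustrative assumptions rather than anything shipped in the package.

```python
import torch
from torchzero.core import TensorTransform  # assumed public re-export of the class renamed above

class ClampUpdate(TensorTransform):
    """Hypothetical transform: clamps each update tensor to [-limit, limit]."""
    def __init__(self, limit: float = 1.0):
        defaults = dict(limit=limit)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        # `setting` carries this parameter's resolved defaults, mirroring the hunks above
        return tensor.clamp_(-setting['limit'], setting['limit'])
```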
@@ -5,7 +5,7 @@ from typing import Literal

  import torch

- from ...core import Target, Transform, Module, Chainable
+ from ...core import TensorTransform, Module, Chainable
  from ...utils import NumberList, TensorList, unpack_dicts


@@ -36,7 +36,7 @@ def cautious_(
  tensors_ -= tensors_.mul(2).mul_(mask.logical_not_())
  return tensors_

- class Cautious(Transform):
+ class Cautious(TensorTransform):
  """Negates update for parameters where update and gradient sign is inconsistent.
  Optionally normalizes the update by the number of parameters that are not masked.
  This is meant to be used after any momentum-based modules.
@@ -57,7 +57,7 @@ class Cautious(Transform):
  Cautious Adam

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.Adam(),
  tz.m.Cautious(),
@@ -79,12 +79,12 @@ class Cautious(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
  return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

- class UpdateGradientSignConsistency(Transform):
+ class UpdateGradientSignConsistency(TensorTransform):
  """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

  Args:
@@ -98,7 +98,7 @@ class UpdateGradientSignConsistency(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  normalize, eps = itemgetter('normalize', 'eps')(settings[0])

@@ -108,7 +108,7 @@ class UpdateGradientSignConsistency(Transform):
  return mask

  class IntermoduleCautious(Module):
- """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.
+ """Negaties update on :code:`main` module where it's sign doesn't match with output of ``compare`` module.

  Args:
  main (Chainable): main module or sequence of modules whose update will be cautioned.
@@ -137,29 +137,32 @@ class IntermoduleCautious(Module):
  self.set_child('main', main)
  self.set_child('compare', compare)

+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
+
  @torch.no_grad
- def step(self, var):
+ def step(self, objective):
  main = self.children['main']
  compare = self.children['compare']

- main_var = main.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(main_var)
+ main_var = main.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(main_var)

- compare_var = compare.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(compare_var)
+ compare_var = compare.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(compare_var)

  mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
- var.update = cautious_(
- TensorList(main_var.get_update()),
- TensorList(compare_var.get_update()),
+ objective.updates = cautious_(
+ TensorList(main_var.get_updates()),
+ TensorList(compare_var.get_updates()),
  normalize=normalize,
  mode=mode,
  eps=eps,
  )

- return var
+ return objective

- class ScaleByGradCosineSimilarity(Transform):
+ class ScaleByGradCosineSimilarity(TensorTransform):
  """Multiplies the update by cosine similarity with gradient.
  If cosine similarity is negative, naturally the update will be negated as well.

@@ -170,7 +173,7 @@ class ScaleByGradCosineSimilarity(Transform):

  Scaled Adam
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.Adam(),
  tz.m.ScaleByGradCosineSimilarity(),
@@ -186,7 +189,7 @@ class ScaleByGradCosineSimilarity(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  eps = settings[0]['eps']
  tensors = TensorList(tensors)
@@ -196,8 +199,8 @@ class ScaleByGradCosineSimilarity(Transform):
  return tensors.mul_(cos_sim)

  class ScaleModulesByCosineSimilarity(Module):
- """Scales the output of :code:`main` module by it's cosine similarity to the output
- of :code:`compare` module.
+ """Scales the output of ``main`` module by it's cosine similarity to the output
+ of ``compare`` module.

  Args:
  main (Chainable): main module or sequence of modules whose update will be scaled.
@@ -208,7 +211,7 @@ class ScaleModulesByCosineSimilarity(Module):

  Adam scaled by similarity to RMSprop
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.ScaleModulesByCosineSimilarity(
  main = tz.m.Adam(),
@@ -230,22 +233,25 @@ class ScaleModulesByCosineSimilarity(Module):
  self.set_child('main', main)
  self.set_child('compare', compare)

+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
+
  @torch.no_grad
- def step(self, var):
+ def step(self, objective):
  main = self.children['main']
  compare = self.children['compare']

- main_var = main.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(main_var)
+ main_var = main.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(main_var)

- compare_var = compare.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(compare_var)
+ compare_var = compare.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(compare_var)

- m = TensorList(main_var.get_update())
- c = TensorList(compare_var.get_update())
+ m = TensorList(main_var.get_updates())
+ c = TensorList(compare_var.get_updates())
  eps = self.defaults['eps']

  cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

- var.update = m.mul_(cos_sim)
- return var
+ objective.updates = m.mul_(cos_sim)
+ return objective
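The docstring examples in these hunks (apparently from `torchzero/modules/momentum/cautious.py`, item 98 in the file list) switch the entry point from `tz.Modular` to `tz.Optimizer`. A hedged usage sketch mirroring the "Cautious Adam" example above; the model, the `LR` step-size module, and the learning-rate value are placeholders, not prescribed by the diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Cautious(),
    tz.m.LR(1e-2),  # assumed step-size module, cf. torchzero/modules/step_size/lr.py in the file list
)

# one illustrative step, using the usual torch.optim-style interface
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
opt.step()
```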
@@ -4,12 +4,12 @@ from typing import Literal

  import torch

- from ...core import Target, Transform
+ from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import debias, ema_
+ from ..opt_utils import debias, ema_


- class EMA(Transform):
+ class EMA(TensorTransform):
  """Maintains an exponential moving average of update.

  Args:
@@ -20,12 +20,12 @@ class EMA(Transform):
  ema_init (str, optional): initial values for the EMA, "zeros" or "update".
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
  defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1

  debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
@@ -53,8 +53,8 @@ class HeavyBall(EMA):
  ema_init (str, optional): initial values for the EMA, "zeros" or "update".
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
- super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+ super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)

  def nag_(
  tensors_: TensorList,
@@ -74,7 +74,7 @@ def nag_(
  return tensors_


- class NAG(Transform):
+ class NAG(TensorTransform):
  """Nesterov accelerated gradient method (nesterov momentum).

  Args:
@@ -84,12 +84,12 @@ class NAG(Transform):
  whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
+ def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False):
  defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
  lerp = self.settings[params[0]]['lerp']

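These hunks (apparently from `torchzero/modules/momentum/momentum.py`, item 99 in the file list) drop the `target=` keyword from the `EMA`, `HeavyBall`, and `NAG` constructors and rename the multi-tensor hook to `multi_tensor_apply`. A sketch of constructing the momentum transforms under the new signatures, assuming they are exported under `tz.m` like the other modules in the docstrings above:

```python
import torchzero as tz

# `target=` is gone from these constructors in 0.4.x; remaining keywords follow the hunks above
ema = tz.m.EMA(momentum=0.9, dampening=0.0, debiased=True, lerp=True, ema_init='zeros')
hb = tz.m.HeavyBall(momentum=0.9)
nag = tz.m.NAG(momentum=0.9, dampening=0.0, lerp=False)
```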
@@ -12,8 +12,8 @@ from .binary import (
  CopyMagnitude,
  CopySign,
  Div,
- Graft,
- GraftToUpdate,
+ GraftInputToOutput,
+ GraftInputToOutput,
  GramSchimdt,
  Maximum,
  Minimum,
@@ -21,7 +21,7 @@ from .binary import (
  Pow,
  RCopySign,
  RDiv,
- RGraft,
+ GraftOutputToInput,
  RPow,
  RSub,
  Sub,
@@ -38,7 +38,7 @@ from .higher_level import (
  from .multi import (
  ClipModules,
  DivModules,
- GraftModules,
+ Graft,
  LerpModules,
  MultiOperationBase,
  PowModules,
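The import-list hunks above (apparently from `torchzero/modules/ops/__init__.py`, item 100 in the file list) rename the grafting operators: the binary `Graft`/`GraftToUpdate` become `GraftInputToOutput`, `RGraft` becomes `GraftOutputToInput`, and the multi-module `GraftModules` takes over the bare `Graft` name. In import form, assuming these names remain re-exported from `torchzero.modules.ops`:

```python
# 0.3.15 names (no longer importable in 0.4.1):
#   from torchzero.modules.ops import Graft, GraftToUpdate, RGraft, GraftModules

# 0.4.1 equivalents, per the renamed exports in the hunks above:
from torchzero.modules.ops import Graft, GraftInputToOutput, GraftOutputToInput
```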
@@ -1,90 +1,90 @@
  import torch

- from ...core import Target, Transform
+ from ...core import TensorTransform
  from ...utils import TensorList, unpack_states

- class AccumulateSum(Transform):
+ class AccumulateSum(TensorTransform):
  """Accumulates sum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  sum = unpack_states(states, tensors, 'sum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return sum.add_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMean(Transform):
+ class AccumulateMean(TensorTransform):
  """Accumulates mean of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1
  mean = unpack_states(states, tensors, 'mean', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)

- class AccumulateProduct(Transform):
+ class AccumulateProduct(TensorTransform):
  """Accumulates product of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0, target = 'update',):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  prod = unpack_states(states, tensors, 'prod', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return prod.mul_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMaximum(Transform):
+ class AccumulateMaximum(TensorTransform):
  """Accumulates maximum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return maximum.maximum_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMinimum(Transform):
+ class AccumulateMinimum(TensorTransform):
  """Accumulates minimum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
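The accumulator hunk above (apparently from `torchzero/modules/ops/accumulate.py`, item 101 in the file list) shows the multi-tensor side of the rename: `apply_tensors` becomes `multi_tensor_apply` and receives per-parameter `states` and `settings` lists. A minimal sketch of a custom transform against that hook, under the same assumptions as the earlier sketch; the class and its `'sq_sum'` state key are hypothetical, and the explicit per-tensor loop sidesteps the `TensorList`/`unpack_states` helpers the package itself uses.

```python
import torch
from torchzero.core import TensorTransform  # assumed public import path

class AccumulateSquaredSum(TensorTransform):
    """Hypothetical transform: running, optionally decayed, sum of squared updates."""
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        out = []
        for t, state, setting in zip(tensors, states, settings):
            if 'sq_sum' not in state:
                state['sq_sum'] = torch.zeros_like(t)
            acc = state['sq_sum'].add_(t * t)      # accumulate squared update in per-parameter state
            out.append(acc * (1 - setting['decay']))
        return out
```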