torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
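Note on the linear-algebra helpers: the rename entries above (torchzero/{utils/linalg → linalg}/linear_operator.py, qr.py, solve.py, benchmark.py, plus the removal of torchzero/utils/linalg/__init__.py and the new torchzero/linalg/__init__.py) show the linalg utilities moving from torchzero.utils.linalg to a new top-level torchzero.linalg package. A minimal sketch of the import change downstream code would presumably need, assuming the moved submodules keep their file names:

    # 0.3.15 layout (torchzero/utils/linalg/...)
    from torchzero.utils.linalg import qr, solve, svd

    # 0.4.0 layout (torchzero/linalg/...), per the renames in the file list above
    from torchzero.linalg import qr, solve, svd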
torchzero/modules/momentum/averaging.py
@@ -1,27 +1,27 @@
  """Modules that perform averaging over a history of past updates."""
  from collections import deque
  from collections.abc import Sequence
- from typing import Any, Literal, cast
+ from typing import Any

  import torch

- from ...core import TensorwiseTransform, Target
+ from ...core import TensorTransform
  from ...utils import tolist


- class Averaging(TensorwiseTransform):
+ class Averaging(TensorTransform):
  """Average of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, history_size: int, target: Target = 'update'):
+ def __init__(self, history_size: int):
  defaults = dict(history_size=history_size)
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  history_size = setting['history_size']
  if 'history' not in state:
  state['history'] = deque(maxlen=history_size)
@@ -34,19 +34,19 @@ class Averaging(TensorwiseTransform):

  return average / len(history)

- class WeightedAveraging(TensorwiseTransform):
+ class WeightedAveraging(TensorTransform):
  """Weighted average of past ``len(weights)`` updates.

  Args:
  weights (Sequence[float]): a sequence of weights from oldest to newest.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
+ def __init__(self, weights: Sequence[float] | torch.Tensor | Any):
  defaults = dict(weights = tolist(weights))
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  weights = setting['weights']

  if 'history' not in state:
@@ -68,19 +68,19 @@ class WeightedAveraging(TensorwiseTransform):
  return average


- class MedianAveraging(TensorwiseTransform):
+ class MedianAveraging(TensorTransform):
  """Median of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, history_size: int, target: Target = 'update'):
+ def __init__(self, history_size: int,):
  defaults = dict(history_size = history_size)
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  history_size = setting['history_size']

  if 'history' not in state:
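The averaging.py hunks above show the 0.4.0 pattern for per-tensor transforms: the base class is TensorTransform instead of TensorwiseTransform, the per-tensor hook is single_tensor_apply instead of apply_tensor, and the uses_grad=/target= constructor arguments are gone. A minimal sketch of a custom per-tensor transform written against the interface shown in this diff (the ClampUpdate class and its bound setting are hypothetical, not part of torchzero):

    import torch
    from torchzero.core import TensorTransform

    class ClampUpdate(TensorTransform):
        """Clamps each update tensor to [-bound, bound] (illustration only)."""
        def __init__(self, bound: float = 1.0):
            defaults = dict(bound=bound)
            super().__init__(defaults=defaults)

        @torch.no_grad
        def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
            # per-parameter hyperparameters arrive via `setting`;
            # per-parameter buffers would live in `state`, as in Averaging above
            return tensor.clamp_(-setting['bound'], setting['bound'])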
torchzero/modules/momentum/cautious.py
@@ -5,7 +5,7 @@ from typing import Literal

  import torch

- from ...core import Target, Transform, Module, Chainable
+ from ...core import TensorTransform, Module, Chainable
  from ...utils import NumberList, TensorList, unpack_dicts


@@ -36,7 +36,7 @@ def cautious_(
  tensors_ -= tensors_.mul(2).mul_(mask.logical_not_())
  return tensors_

- class Cautious(Transform):
+ class Cautious(TensorTransform):
  """Negates update for parameters where update and gradient sign is inconsistent.
  Optionally normalizes the update by the number of parameters that are not masked.
  This is meant to be used after any momentum-based modules.
@@ -79,12 +79,12 @@ class Cautious(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
  return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

- class UpdateGradientSignConsistency(Transform):
+ class UpdateGradientSignConsistency(TensorTransform):
  """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

  Args:
@@ -98,7 +98,7 @@ class UpdateGradientSignConsistency(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  normalize, eps = itemgetter('normalize', 'eps')(settings[0])

@@ -108,7 +108,7 @@ class UpdateGradientSignConsistency(Transform):
  return mask

  class IntermoduleCautious(Module):
- """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.
+ """Negaties update on :code:`main` module where it's sign doesn't match with output of ``compare`` module.

  Args:
  main (Chainable): main module or sequence of modules whose update will be cautioned.
@@ -137,29 +137,32 @@ class IntermoduleCautious(Module):
  self.set_child('main', main)
  self.set_child('compare', compare)

+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
+
  @torch.no_grad
- def step(self, var):
+ def step(self, objective):
  main = self.children['main']
  compare = self.children['compare']

- main_var = main.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(main_var)
+ main_var = main.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(main_var)

- compare_var = compare.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(compare_var)
+ compare_var = compare.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(compare_var)

  mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
- var.update = cautious_(
- TensorList(main_var.get_update()),
- TensorList(compare_var.get_update()),
+ objective.updates = cautious_(
+ TensorList(main_var.get_updates()),
+ TensorList(compare_var.get_updates()),
  normalize=normalize,
  mode=mode,
  eps=eps,
  )

- return var
+ return objective

- class ScaleByGradCosineSimilarity(Transform):
+ class ScaleByGradCosineSimilarity(TensorTransform):
  """Multiplies the update by cosine similarity with gradient.
  If cosine similarity is negative, naturally the update will be negated as well.

@@ -186,7 +189,7 @@ class ScaleByGradCosineSimilarity(Transform):
  super().__init__(defaults, uses_grad=True)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  eps = settings[0]['eps']
  tensors = TensorList(tensors)
@@ -196,8 +199,8 @@ class ScaleByGradCosineSimilarity(Transform):
  return tensors.mul_(cos_sim)

  class ScaleModulesByCosineSimilarity(Module):
- """Scales the output of :code:`main` module by it's cosine similarity to the output
- of :code:`compare` module.
+ """Scales the output of ``main`` module by it's cosine similarity to the output
+ of ``compare`` module.

  Args:
  main (Chainable): main module or sequence of modules whose update will be scaled.
@@ -230,22 +233,25 @@ class ScaleModulesByCosineSimilarity(Module):
  self.set_child('main', main)
  self.set_child('compare', compare)

+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
+
  @torch.no_grad
- def step(self, var):
+ def step(self, objective):
  main = self.children['main']
  compare = self.children['compare']

- main_var = main.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(main_var)
+ main_var = main.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(main_var)

- compare_var = compare.step(var.clone(clone_update=True))
- var.update_attrs_from_clone_(compare_var)
+ compare_var = compare.step(objective.clone(clone_updates=True))
+ objective.update_attrs_from_clone_(compare_var)

- m = TensorList(main_var.get_update())
- c = TensorList(compare_var.get_update())
+ m = TensorList(main_var.get_updates())
+ c = TensorList(compare_var.get_updates())
  eps = self.defaults['eps']

  cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

- var.update = m.mul_(cos_sim)
- return var
+ objective.updates = m.mul_(cos_sim)
+ return objective
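Across the cautious.py hunks the per-step state object passed through module.step() is renamed from var to objective, and its update-related members are pluralized. A summary of the renames as they appear in this diff, written as a Python comment block (nothing beyond what the hunks show is assumed):

    # 0.3.15                               # 0.4.0 (as seen in the hunks above)
    # def step(self, var):                 # def step(self, objective):
    # var.clone(clone_update=True)         # objective.clone(clone_updates=True)
    # var.get_update()                     # objective.get_updates()
    # var.update = ...                     # objective.updates = ...
    # var.update_attrs_from_clone_(x)      # objective.update_attrs_from_clone_(x)   (unchanged)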
torchzero/modules/momentum/momentum.py
@@ -4,12 +4,12 @@ from typing import Literal

  import torch

- from ...core import Target, Transform
+ from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ..functional import debias, ema_


- class EMA(Transform):
+ class EMA(TensorTransform):
  """Maintains an exponential moving average of update.

  Args:
@@ -20,12 +20,12 @@ class EMA(Transform):
  ema_init (str, optional): initial values for the EMA, "zeros" or "update".
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
  defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1

  debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
@@ -53,8 +53,8 @@ class HeavyBall(EMA):
  ema_init (str, optional): initial values for the EMA, "zeros" or "update".
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
- super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+ super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)

  def nag_(
  tensors_: TensorList,
@@ -74,7 +74,7 @@ def nag_(
  return tensors_


- class NAG(Transform):
+ class NAG(TensorTransform):
  """Nesterov accelerated gradient method (nesterov momentum).

  Args:
@@ -84,12 +84,12 @@ class NAG(Transform):
  whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
  target (Target, optional): target to apply EMA to. Defaults to 'update'.
  """
- def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
+ def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False):
  defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
  lerp = self.settings[params[0]]['lerp']
 
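EMA, HeavyBall and NAG above illustrate the multi-tensor variant of the new base class: they subclass TensorTransform, call super().__init__(defaults, uses_grad=False), and implement multi_tensor_apply over lists of tensors rather than single_tensor_apply. A minimal sketch of a custom multi-tensor transform in that style (ScaleByStep and its alpha setting are hypothetical, not part of torchzero; TensorList, settings[0] and self.global_state are used exactly as in the hunks above):

    import torch
    from torchzero.core import TensorTransform
    from torchzero.utils import TensorList

    class ScaleByStep(TensorTransform):
        """Scales the update by alpha / step (illustration only)."""
        def __init__(self, alpha: float = 1.0):
            defaults = dict(alpha=alpha)
            super().__init__(defaults, uses_grad=False)

        @torch.no_grad
        def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
            # a shared step counter kept in self.global_state, as EMA does above
            step = self.global_state['step'] = self.global_state.get('step', 0) + 1
            alpha = settings[0]['alpha']
            # per-parameter buffers would instead be fetched with unpack_states(states, tensors, ...)
            return TensorList(tensors).mul_(alpha / step)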
torchzero/modules/ops/__init__.py
@@ -12,8 +12,8 @@ from .binary import (
  CopyMagnitude,
  CopySign,
  Div,
- Graft,
- GraftToUpdate,
+ GraftInputToOutput,
+ GraftInputToOutput,
  GramSchimdt,
  Maximum,
  Minimum,
@@ -21,7 +21,7 @@ from .binary import (
  Pow,
  RCopySign,
  RDiv,
- RGraft,
+ GraftOutputToInput,
  RPow,
  RSub,
  Sub,
@@ -38,7 +38,7 @@ from .higher_level import (
  from .multi import (
  ClipModules,
  DivModules,
- GraftModules,
+ Graft,
  LerpModules,
  MultiOperationBase,
  PowModules,
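The ops/__init__.py hunks record a rename of the grafting operators: Graft and GraftToUpdate from .binary become GraftInputToOutput, RGraft becomes GraftOutputToInput, and the multi-module GraftModules from .multi takes over the bare name Graft. A sketch of the corresponding import change for downstream code, using only the names shown in these hunks:

    # 0.3.15
    from torchzero.modules.ops import Graft, RGraft, GraftModules

    # 0.4.0, per the renames above; note that the old name Graft now refers to
    # what used to be GraftModules, so a plain search-and-replace is not enough
    from torchzero.modules.ops import GraftInputToOutput, GraftOutputToInput, Graft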
torchzero/modules/ops/accumulate.py
@@ -1,90 +1,90 @@
  import torch

- from ...core import Target, Transform
+ from ...core import TensorTransform
  from ...utils import TensorList, unpack_states

- class AccumulateSum(Transform):
+ class AccumulateSum(TensorTransform):
  """Accumulates sum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  sum = unpack_states(states, tensors, 'sum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return sum.add_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMean(Transform):
+ class AccumulateMean(TensorTransform):
  """Accumulates mean of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1
  mean = unpack_states(states, tensors, 'mean', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)

- class AccumulateProduct(Transform):
+ class AccumulateProduct(TensorTransform):
  """Accumulates product of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0, target = 'update',):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  prod = unpack_states(states, tensors, 'prod', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return prod.mul_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMaximum(Transform):
+ class AccumulateMaximum(TensorTransform):
  """Accumulates maximum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return maximum.maximum_(tensors).lazy_mul(decay, clone=True)

- class AccumulateMinimum(Transform):
+ class AccumulateMinimum(TensorTransform):
  """Accumulates minimum of all past updates.

  Args:
  decay (float, optional): decays the accumulator. Defaults to 0.
  target (Target, optional): target. Defaults to 'update'.
  """
- def __init__(self, decay: float = 0, target: Target = 'update',):
+ def __init__(self, decay: float = 0):
  defaults = dict(decay=decay)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
  decay = [1-s['decay'] for s in settings]
  return minimum.minimum_(tensors).lazy_mul(decay, clone=True)