torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
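Note on the path moves above: the linear-algebra helpers relocate from torchzero/utils/linalg/ to a new top-level torchzero/linalg/ package, and torchzero/modules/functional.py becomes torchzero/modules/opt_utils.py. A minimal sketch of the corresponding downstream import change, assuming the submodules keep their names at the new location (only the file moves themselves are confirmed by this diff):

# torchzero 0.3.15 (old layout)
# from torchzero.utils.linalg import qr, solve

# torchzero 0.4.1 (new layout, per the moves listed above)
from torchzero.linalg import qr, solve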
@@ -6,8 +6,8 @@ from typing import Any
 
 import torch
 
-from ...core import Chainable, Module, Target, Var, maybe_chain
-from ...utils import TensorList, tensorlist
+from ...core import Chainable, Module, Objective
+from ...utils import TensorList
 
 
 class BinaryOperationBase(Module, ABC):
@@ -25,263 +25,264 @@ class BinaryOperationBase(ABC):
             self.operands[k] = v
 
     @abstractmethod
-    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
+    def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var: Var) -> Var:
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-                updated_var = v.step(var.clone(clone_update=True))
-                processed_operands[k] = updated_var.get_update()
-                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[k] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(var, update=var.get_update(), **processed_operands)
-        var.update = list(transformed)
-        return var
+        transformed = self.transform(objective, update=objective.get_updates(), **processed_operands)
+        objective.updates = list(transformed)
+        return objective
 
 
 class Add(BinaryOperationBase):
-    """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+    """Add ``other`` to tensors. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+    If ``other`` is a module, this calculates ``tensors + other(tensors)``
     """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
         else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class Sub(BinaryOperationBase):
-    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+    """Subtract ``other`` from tensors. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+    If ``other`` is a module, this calculates :code:`tensors - other(tensors)`
     """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
         else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class RSub(BinaryOperationBase):
-    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+    """Subtract tensors from ``other``. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+    If ``other`` is a module, this calculates ``other(tensors) - tensors``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)
 
 class Mul(BinaryOperationBase):
-    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+    """Multiply tensors by ``other``. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+    If ``other`` is a module, this calculates ``tensors * other(tensors)``
    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_mul_(update, other)
         return update
 
 class Div(BinaryOperationBase):
-    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+    """Divide tensors by ``other``. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+    If ``other`` is a module, this calculates ``tensors / other(tensors)``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_div_(update, other)
         return update
 
 class RDiv(BinaryOperationBase):
-    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+    """Divide ``other`` by tensors. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+    If ``other`` is a module, this calculates ``other(tensors) / tensors``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)
 
 class Pow(BinaryOperationBase):
-    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+    """Take tensors to the power of ``exponent``. ``exponent`` can be a number or a module.
 
-    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+    If ``exponent`` is a module, this calculates ``tensors ^ exponent(tensors)``
     """
     def __init__(self, exponent: Chainable | float):
         super().__init__({}, exponent=exponent)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
         torch._foreach_pow_(update, exponent)
         return update
 
 class RPow(BinaryOperationBase):
-    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+    """Take ``other`` to the power of tensors. ``other`` can be a number or a module.
 
-    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+    If ``other`` is a module, this calculates ``other(tensors) ^ tensors``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
         torch._foreach_pow_(other, update)
         return other
 
 class Lerp(BinaryOperationBase):
-    """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+    """Does a linear interpolation of tensors and ``end`` module based on a scalar ``weight``.
 
-    The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+    The output is given by ``output = tensors + weight * (end(tensors) - tensors)``
     """
     def __init__(self, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, end=end)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
-        torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
+    def transform(self, objective, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings(objective.params, 'weight'))
         return update
 
 class CopySign(BinaryOperationBase):
-    """Returns tensors with sign copied from :code:`other(tensors)`."""
+    """Returns tensors with sign copied from ``other(tensors)``."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]
 
 class RCopySign(BinaryOperationBase):
-    """Returns :code:`other(tensors)` with sign copied from tensors."""
+    """Returns ``other(tensors)`` with sign copied from tensors."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign
 
 class Clip(BinaryOperationBase):
-    """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+    """clip tensors to be in ``(min, max)`` range. ``min`` and ``max`: can be None, numbers or modules.
 
-    If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+    If ``min`` and ``max`` are modules, this calculates ``tensors.clip(min(tensors), max(tensors))``.
     """
     def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
         super().__init__({}, min=min, max=max)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
+    def transform(self, objective, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)
 
 class MirroredClip(BinaryOperationBase):
-    """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+    """clip tensors to be in ``(-value, value)`` range. ``value`` can be a number or a module.
 
-    If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+    If ``value`` is a module, this calculates ``tensors.clip(-value(tensors), value(tensors))``
     """
     def __init__(self, value: float | Chainable):
         super().__init__({}, value=value)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], value: float | list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], value: float | list[torch.Tensor]):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)
 
-class Graft(BinaryOperationBase):
-    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
+class GraftInputToOutput(BinaryOperationBase):
+    """Outputs ``tensors`` rescaled to have the same norm as ``magnitude(tensors)``."""
     def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
-class RGraft(BinaryOperationBase):
-    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""
+class GraftOutputToInput(BinaryOperationBase):
+    """Outputs ``magnitude(tensors)`` rescaled to have the same norm as ``tensors``"""
 
     def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, direction=direction)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], direction: list[torch.Tensor]):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
-GraftToUpdate = RGraft
-
 class Maximum(BinaryOperationBase):
-    """Outputs :code:`maximum(tensors, other(tensors))`"""
+    """Outputs ``maximum(tensors, other(tensors))``"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_maximum_(update, other)
         return update
 
 class Minimum(BinaryOperationBase):
-    """Outputs :code:`minimum(tensors, other(tensors))`"""
+    """Outputs ``minimum(tensors, other(tensors))``"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_minimum_(update, other)
         return update
 
 
 class GramSchimdt(BinaryOperationBase):
-    """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
+    """outputs tensors made orthogonal to ``other(tensors)`` via Gram-Schmidt."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
         min = torch.finfo(update[0].dtype).tiny * 2
         return update - (other*update) / (other*other).clip(min=min)
 
 
 class Threshold(BinaryOperationBase):
-    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
+    """Outputs tensors thresholded such that values above ``threshold`` are set to ``value``."""
     def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
         defaults = dict(update_above=update_above)
         super().__init__(defaults, threshold=threshold, value=value)
 
     @torch.no_grad
-    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+    def transform(self, objective, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
         update_above = self.defaults['update_above']
         update = TensorList(update)
         if update_above:
-            if isinstance(value, list): return update.where_(update>threshold, value)
+            if isinstance(value, list): return update.where(update>threshold, value)
             return update.masked_fill_(update<=threshold, value)
 
-        if isinstance(value, list): return update.where_(update<threshold, value)
+        if isinstance(value, list): return update.where(update<threshold, value)
         return update.masked_fill_(update>=threshold, value)
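The hunks above (likely torchzero/modules/ops/binary.py, entry 102 in the file list) show the recurring 0.4.1 API change: the per-step state object is renamed from Var to Objective, clone_update/get_update become clone_updates/get_updates, and BinaryOperationBase now stubs out update/apply and does all of its work in step/transform. A minimal sketch of a custom operand module written against the new transform signature; WeightedAdd is a hypothetical example, and the import paths are assumptions based on where BinaryOperationBase and Chainable appear in this diff rather than documented public API:

# Hypothetical custom binary operation for torchzero 0.4.1 (sketch, import paths assumed).
import torch

from torchzero.core import Chainable
from torchzero.modules.ops.binary import BinaryOperationBase


class WeightedAdd(BinaryOperationBase):
    """Hypothetical: computes ``tensors + weight * other(tensors)``."""

    def __init__(self, other: Chainable | float, weight: float = 0.5):
        super().__init__(dict(weight=weight), other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        # ``other`` arrives either as the raw number passed to __init__ or as the
        # output tensors of the operand module, already stepped by the base class.
        weight = self.defaults['weight']
        if isinstance(other, (int, float)):
            torch._foreach_add_(update, other * weight)
        else:
            torch._foreach_add_(update, other, alpha=weight)
        return update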
@@ -4,9 +4,9 @@ from typing import Literal
 
 import torch
 
-from ...core import Target, Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import (
+from ..opt_utils import (
     centered_ema_sq_,
     debias,
     debias_second_momentum,
@@ -17,7 +17,7 @@ from ..functional import (
 )
 
 
-class EMASquared(Transform):
+class EMASquared(TensorTransform):
     """Maintains an exponential moving average of squared updates.
 
     Args:
@@ -29,10 +29,10 @@ class EMASquared(Transform):
 
     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -44,7 +44,7 @@ class EMASquared(Transform):
 
         return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
 
-class SqrtEMASquared(Transform):
+class SqrtEMASquared(TensorTransform):
     """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.
 
     Args:
@@ -56,11 +56,11 @@ class SqrtEMASquared(Transform):
     SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -83,7 +83,7 @@ class SqrtEMASquared(Transform):
         )
 
 
-class Debias(Transform):
+class Debias(TensorTransform):
     """Multiplies the update by an Adam debiasing term based first and/or second momentum.
 
     Args:
@@ -95,12 +95,12 @@ class Debias(Transform):
         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
+    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2):
         defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -108,7 +108,7 @@ class Debias(Transform):
 
         return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
 
-class Debias2(Transform):
+class Debias2(TensorTransform):
     """Multiplies the update by an Adam debiasing term based on the second momentum.
 
     Args:
@@ -117,19 +117,19 @@ class Debias2(Transform):
         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
+    def __init__(self, beta: float = 0.999, pow: float = 2,):
         defaults = dict(beta=beta, pow=pow)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
         beta = NumberList(s['beta'] for s in settings)
         return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
 
-class CenteredEMASquared(Transform):
+class CenteredEMASquared(TensorTransform):
     """Maintains a centered exponential moving average of squared updates. This also maintains an additional
     exponential moving average of un-squared updates, square of which is subtracted from the EMA.
 
@@ -143,7 +143,7 @@ class CenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -162,7 +162,7 @@ class CenteredEMASquared(Transform):
             pow=pow,
         ).clone()
 
-class CenteredSqrtEMASquared(Transform):
+class CenteredSqrtEMASquared(TensorTransform):
     """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
     This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.
 
@@ -177,7 +177,7 @@ class CenteredSqrtEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
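This second group of hunks (the EMA/debias transforms) shows the other recurring 0.4.1 change: Transform becomes TensorTransform, the per-tensor hook apply_tensors is renamed to multi_tensor_apply, and most target=/uses_grad= constructor arguments are dropped. A minimal sketch of a transform written against the renamed hook; GlobalScale is a hypothetical example, and the import of TensorTransform from torchzero.core is inferred from the ``from ...core import TensorTransform`` line in this diff:

# Hypothetical transform for torchzero 0.4.1 (sketch; only the hook and constructor
# signatures mirror the classes in the hunks above, everything else is an assumption).
import torch

from torchzero.core import TensorTransform


class GlobalScale(TensorTransform):
    """Hypothetical: multiplies every incoming update tensor by a fixed factor."""

    def __init__(self, factor: float = 0.1):
        super().__init__(dict(factor=factor))

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # ``settings`` is a per-parameter list of settings dicts, as used by the
        # classes above (e.g. ``settings[0]['pow']``).
        factor = settings[0]['factor']
        torch._foreach_mul_(tensors, factor)
        return tensors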