torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
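The hunks below (from torchzero/modules/ops/multi.py, reduce.py and unary.py, judging by the class names against the file list) illustrate the two renames that run through most of this release: the per-step state object `Var` becomes `Objective`, and the `Transform`/`TensorwiseTransform` bases are replaced by `TensorTransform`. The following is a minimal sketch of how import paths shift for downstream code; the exact set of top-level re-exports is an assumption based only on the imports and file moves visible in this diff.

```python
# Hedged sketch of the 0.3.14 -> 0.4.0 import changes suggested by this diff.
# Only names that literally appear in the hunks or file list are grounded;
# whether they are all importable exactly like this is assumed, not verified.

# 0.3.14 (removed names, per the "-" lines below)
# from torchzero.core import Var, Transform, TensorwiseTransform, Target

# 0.4.0 (new names, per the "+" lines below)
from torchzero.core import Objective, Module, TensorTransform, Chainable

# utils/linalg was promoted to a top-level subpackage (see the moved files above)
from torchzero import linalg
```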
@@ -7,8 +7,8 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module, Target, Var, maybe_chain
-from ...utils import TensorList, tensorlist, Metrics
+from ...core import Chainable, Module, Objective
+from ...utils import TensorList, Metrics
 
 
 class MultiOperationBase(Module, ABC):
@@ -29,36 +29,39 @@ class MultiOperationBase(Module, ABC):
             raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var: Var) -> Var:
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-                updated_var = v.step(var.clone(clone_update=True))
-                processed_operands[k] = updated_var.get_update()
-                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[k] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(var, **processed_operands)
-        var.update = transformed
-        return var
+        transformed = self.transform(objective, **processed_operands)
+        objective.updates = transformed
+        return objective
 
 
 
 class SubModules(MultiOperationBase):
-    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
+    """Calculates ``input - other``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         alpha = self.defaults['alpha']
 
         if isinstance(input, (int,float)):
@@ -70,14 +73,14 @@ class SubModules(MultiOperationBase):
         return input
 
 class DivModules(MultiOperationBase):
-    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+    """Calculates ``input / other``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
         defaults = {}
         if other_first: super().__init__(defaults, other=other, input=input)
         else: super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
             return input / TensorList(other)
@@ -87,13 +90,13 @@ class DivModules(MultiOperationBase):
 
 
 class PowModules(MultiOperationBase):
-    """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
+    """Calculates ``input ** exponent``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, exponent: Chainable | float):
         defaults = {}
         super().__init__(defaults, input=input, exponent=exponent)
 
     @torch.no_grad
-    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(exponent, list)
             return input ** TensorList(exponent)
@@ -102,32 +105,32 @@ class PowModules(MultiOperationBase):
         return input
 
 class LerpModules(MultiOperationBase):
-    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+    """Does a linear interpolation of ``input(tensors)`` and ``end(tensors)`` based on a scalar ``weight``.
 
-    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+    The output is given by ``output = input(tensors) + weight * (end(tensors) - input(tensors))``
     """
     def __init__(self, input: Chainable, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, input=input, end=end)
 
     @torch.no_grad
-    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
         torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
         return input
 
 class ClipModules(MultiOperationBase):
-    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
+    """Calculates ``input(tensors).clip(min, max)``. ``min`` and ``max`` can be numbers or modules."""
     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
         defaults = {}
         super().__init__(defaults, input=input, min=min, max=max)
 
     @torch.no_grad
-    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
         return TensorList(input).clamp_(min=min, max=max)
 
 
-class GraftModules(MultiOperationBase):
-    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+class Graft(MultiOperationBase):
+    """Outputs ``direction`` output rescaled to have the same norm as ``magnitude`` output.
 
     Args:
         direction (Chainable): module to use the direction from
@@ -137,40 +140,40 @@ class GraftModules(MultiOperationBase):
         eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
         strength (float, optional): strength of grafting. Defaults to 1.
 
-    Example:
-        Shampoo grafted to Adam
-
-        .. code-block:: python
+    ### Example:
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.GraftModules(
-                    direction = tz.m.Shampoo(),
-                    magnitude = tz.m.Adam(),
-                ),
-                tz.m.LR(1e-3)
-            )
+    Shampoo grafted to Adam
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.GraftModules(
+            direction = tz.m.Shampoo(),
+            magnitude = tz.m.Adam(),
+        ),
+        tz.m.LR(1e-3)
+    )
+    ```
 
     Reference:
-        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+        [Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.](https://arxiv.org/pdf/2002.11803)
     """
     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+    def transform(self, objective, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
         tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
 class MultiplyByModuleNorm(MultiOperationBase):
-    """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+    """Outputs ``input`` multiplied by norm of the ``norm`` output."""
     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
-    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+    def transform(self, objective, input: list[torch.Tensor], norm:list[torch.Tensor]):
         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
             n = TensorList(norm).metric(ord)
@@ -181,13 +184,13 @@ class MultiplyByModuleNorm(MultiOperationBase):
         return input
 
 class DivideByModuleNorm(MultiOperationBase):
-    """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+    """Outputs ``input`` divided by norm of the ``norm`` output."""
     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
-    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+    def transform(self, objective, input: list[torch.Tensor], norm:list[torch.Tensor]):
         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
             n = TensorList(norm).metric(ord)
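For code that defines its own modules, the hunks above spell out the new `Objective`-based `step()` contract that replaces the old `Var`-based one: clone the objective for child modules, read updates with `get_updates()`, and write them back to `objective.updates`. Below is a minimal sketch of a custom module under that contract; `NegateUpdate` is a hypothetical example and the `Module` constructor signature is assumed, not shown in this diff.

```python
import torch
from torchzero.core import Module, Objective

class NegateUpdate(Module):
    """Hypothetical module following the 0.4.0 step() contract shown above."""

    def __init__(self):
        # a defaults dict is assumed to be accepted here, mirroring the subclasses above
        super().__init__(defaults={})

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # read the current update tensors, modify them, and write them back,
        # as MultiOperationBase.step does with objective.updates
        updates = objective.get_updates()
        torch._foreach_neg_(updates)
        objective.updates = updates
        return objective
```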
@@ -5,7 +5,7 @@ from typing import Any, cast
 
 import torch
 
-from ...core import Chainable, Module, Target, Var, maybe_chain
+from ...core import Chainable, Module, Objective, maybe_chain
 
 
 class ReduceOperationBase(Module, ABC):
@@ -26,34 +26,37 @@ class ReduceOperationBase(Module, ABC):
             raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var: Var) -> Var:
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
         for i, v in enumerate(self.operands):
             if f'operand_{i}' in self.children:
                 v: Module
-                updated_var = v.step(var.clone(clone_update=True))
-                processed_operands[i] = updated_var.get_update()
-                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[i] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(var, *processed_operands)
-        var.update = transformed
-        return var
+        transformed = self.transform(objective, *processed_operands)
+        objective.updates = transformed
+        return objective
 
 class Sum(ReduceOperationBase):
-    """Outputs sum of :code:`inputs` that can be modules or numbers."""
+    """Outputs sum of ``inputs`` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         sum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -64,14 +67,14 @@ class Sum(ReduceOperationBase):
         return sum
 
 class Mean(Sum):
-    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
+    """Outputs a mean of ``inputs`` that can be modules or numbers."""
     USE_MEAN = True
 
 
 class WeightedSum(ReduceOperationBase):
+    """Outputs a weighted sum of ``inputs`` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
-        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
         weights = list(weights)
         if len(inputs) != len(weights):
             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -79,7 +82,7 @@ class WeightedSum(ReduceOperationBase):
         super().__init__(defaults=defaults, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         weights = self.defaults['weights']
         sum = cast(list, sorted_inputs[0])
@@ -94,16 +97,16 @@
 
 
 class WeightedMean(WeightedSum):
-    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
+    """Outputs weighted mean of ``inputs`` that can be modules or numbers."""
     USE_MEAN = True
 
 class Median(ReduceOperationBase):
-    """Outputs median of :code:`inputs` that can be modules or numbers."""
+    """Outputs median of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         res = []
         lists = [i for i in inputs if isinstance(i, list)]
         floats = [i for i in inputs if isinstance(i, (int,float))]
@@ -112,12 +115,12 @@ class Median(ReduceOperationBase):
         return res
 
 class Prod(ReduceOperationBase):
-    """Outputs product of :code:`inputs` that can be modules or numbers."""
+    """Outputs product of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         prod = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -127,12 +130,12 @@ class Prod(ReduceOperationBase):
         return prod
 
 class MaximumModules(ReduceOperationBase):
-    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
+    """Outputs elementwise maximum of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         maximum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -142,12 +145,12 @@ class MaximumModules(ReduceOperationBase):
         return maximum
 
 class MinimumModules(ReduceOperationBase):
-    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
+    """Outputs elementwise minimum of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         minimum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -2,102 +2,102 @@ from collections import deque
 
 import torch
 
-from ...core import TensorwiseTransform, Target, Transform
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_dicts,unpack_states
 
-class UnaryLambda(Transform):
-    """Applies :code:`fn` to input tensors.
+class UnaryLambda(TensorTransform):
+    """Applies ``fn`` to input tensors.
 
-    :code:`fn` must accept and return a list of tensors.
+    ``fn`` must accept and return a list of tensors.
     """
-    def __init__(self, fn, target: "Target" = 'update'):
+    def __init__(self, fn):
         defaults = dict(fn=fn)
-        super().__init__(defaults=defaults, uses_grad=False, target=target)
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return settings[0]['fn'](tensors)
 
-class UnaryParameterwiseLambda(TensorwiseTransform):
-    """Applies :code:`fn` to each input tensor.
+class UnaryParameterwiseLambda(TensorTransform):
+    """Applies ``fn`` to each input tensor.
 
-    :code:`fn` must accept and return a tensor.
+    ``fn`` must accept and return a tensor.
     """
-    def __init__(self, fn, target: "Target" = 'update'):
+    def __init__(self, fn):
         defaults = dict(fn=fn)
-        super().__init__(uses_grad=False, defaults=defaults, target=target)
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         return setting['fn'](tensor)
 
-class CustomUnaryOperation(Transform):
-    """Applies :code:`getattr(tensor, name)` to each tensor
+class CustomUnaryOperation(TensorTransform):
+    """Applies ``getattr(tensor, name)`` to each tensor
     """
-    def __init__(self, name: str, target: "Target" = 'update'):
+    def __init__(self, name: str):
        defaults = dict(name=name)
-        super().__init__(defaults=defaults, uses_grad=False, target=target)
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return getattr(tensors, settings[0]['name'])()
 
 
-class Abs(Transform):
-    """Returns :code:`abs(input)`"""
-    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+class Abs(TensorTransform):
+    """Returns ``abs(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors
 
-class Sign(Transform):
-    """Returns :code:`sign(input)`"""
-    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+class Sign(TensorTransform):
+    """Returns ``sign(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors
 
-class Exp(Transform):
-    """Returns :code:`exp(input)`"""
-    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+class Exp(TensorTransform):
+    """Returns ``exp(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors
 
-class Sqrt(Transform):
-    """Returns :code:`sqrt(input)`"""
-    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+class Sqrt(TensorTransform):
+    """Returns ``sqrt(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors
 
-class Reciprocal(Transform):
-    """Returns :code:`1 / input`"""
-    def __init__(self, eps = 0, target: "Target" = 'update'):
+class Reciprocal(TensorTransform):
+    """Returns ``1 / input``"""
+    def __init__(self, eps = 0):
         defaults = dict(eps = eps)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors
 
-class Negate(Transform):
-    """Returns :code:`- input`"""
-    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+class Negate(TensorTransform):
+    """Returns ``- input``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors
 
 
-class NanToNum(Transform):
-    """Convert `nan`, `inf` and `-inf` to numbers.
+class NanToNum(TensorTransform):
+    """Convert ``nan``, ``inf`` and ``-inf`` to numbers.
 
     Args:
         nan (optional): the value to replace NaNs with. Default is zero.
@@ -108,23 +108,23 @@ class NanToNum(Transform):
         If None, negative infinity values are replaced with the lowest finite value
         representable by input's dtype. Default is None.
     """
-    def __init__(self, nan=None, posinf=None, neginf=None, target: "Target" = 'update'):
+    def __init__(self, nan=None, posinf=None, neginf=None):
         defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
-class Rescale(Transform):
-    """Rescales input to :code`(min, max)` range"""
-    def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
+class Rescale(TensorTransform):
+    """Rescales input to ``(min, max)`` range"""
+    def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8):
         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         min, max = unpack_dicts(settings, 'min','max')
         tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'
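The unary ops above also show the pattern for user-defined transforms in 0.4.0: subclass `TensorTransform` and override `multi_tensor_apply` (or `single_tensor_apply` for per-tensor logic) instead of the old `apply_tensors`/`apply_tensor` hooks. A minimal sketch follows; `ClipValue` is a hypothetical example, not part of torchzero.

```python
import torch
from torchzero.core import TensorTransform

class ClipValue(TensorTransform):
    """Hypothetical transform written against the renamed hooks above."""

    def __init__(self, value: float = 1.0):
        defaults = dict(value=value)
        super().__init__(defaults)  # mirrors Reciprocal/NanToNum above

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # settings holds one dict of hyperparameters per tensor, as in the classes above
        for t, s in zip(tensors, settings):
            t.clamp_(-s['value'], s['value'])
        return tensors
```

Presumably such a transform chains into `tz.Modular` like the built-in ops, e.g. `tz.Modular(model.parameters(), ClipValue(0.5), tz.m.LR(1e-2))`, following the grafting example in the hunks above.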