torchzero-0.3.10-py3-none-any.whl → torchzero-0.3.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -5,39 +5,16 @@ from typing import Literal
  import torch

  from ...core import Target, Transform
- from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
- from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
-
-
- class EMA(Transform):
-     """Maintains an exponential moving average of update.
-
-     Args:
-         momentum (float, optional): momentum (beta). Defaults to 0.9.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
-         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
-         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
-         target (Target, optional): target to apply EMA to. Defaults to 'update'.
-     """
-     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
-         defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
-
-         exp_avg = unpack_states(states, tensors, 'exp_avg',
-             init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
-         momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
-
-         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
-
-         if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
-         else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+     centered_ema_sq_,
+     debias,
+     debias_second_momentum,
+     ema_,
+     ema_sq_,
+     sqrt_centered_ema_sq_,
+     sqrt_ema_sq_,
+ )


  class EMASquared(Transform):
@@ -55,7 +32,7 @@ class EMASquared(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
          beta = NumberList(s['beta'] for s in settings)

@@ -83,7 +60,7 @@ class SqrtEMASquared(Transform):


      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

          amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -123,7 +100,7 @@ class Debias(Transform):
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

          pow = settings[0]['pow']
@@ -145,7 +122,7 @@ class Debias2(Transform):
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

          pow = settings[0]['pow']
@@ -166,7 +143,7 @@ class CenteredEMASquared(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
          beta = NumberList(s['beta'] for s in settings)

@@ -200,7 +177,7 @@ class CenteredSqrtEMASquared(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

          amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
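Across the hunks above the per-tensor hook on Transform is renamed from apply to apply_tensors; the argument list stays the same. A minimal sketch of how a downstream Transform subclass would adapt, based only on the signatures visible in this diff (the MyScale module itself is hypothetical, not part of torchzero):

    import torch
    from torchzero.core import Transform

    class MyScale(Transform):
        """Hypothetical transform that multiplies the update by a per-parameter factor."""
        def __init__(self, alpha: float = 0.5, target = 'update'):
            super().__init__(dict(alpha=alpha), uses_grad=False, target=target)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):  # was `apply` in 0.3.10
            # settings is a per-parameter list of dicts, same convention as in the hunks above
            torch._foreach_mul_(tensors, [s['alpha'] for s in settings])
            return tensors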
@@ -3,15 +3,15 @@
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
- from typing import Any
+ from typing import Any, Literal

  import torch

  from ...core import Chainable, Module, Target, Var, maybe_chain
- from ...utils import TensorList, tensorlist
+ from ...utils import TensorList, tensorlist, Metrics


- class MultiOperation(Module, ABC):
+ class MultiOperationBase(Module, ABC):
      """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
      def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
          super().__init__(defaults=defaults)
@@ -51,14 +51,15 @@ class MultiOperation(Module, ABC):



- class SubModules(MultiOperation):
+ class SubModules(MultiOperationBase):
+     """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
      def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
          defaults = dict(alpha=alpha)
          super().__init__(defaults, input=input, other=other)

      @torch.no_grad
      def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
-         alpha = self.settings[var.params[0]]['alpha']
+         alpha = self.defaults['alpha']

          if isinstance(input, (int,float)):
              assert isinstance(other, list)
@@ -68,10 +69,12 @@ class SubModules(MultiOperation):
          else: torch._foreach_sub_(input, other, alpha=alpha)
          return input

- class DivModules(MultiOperation):
-     def __init__(self, input: Chainable | float, other: Chainable | float):
+ class DivModules(MultiOperationBase):
+     """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+     def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
          defaults = {}
-         super().__init__(defaults, input=input, other=other)
+         if other_first: super().__init__(defaults, other=other, input=input)
+         else: super().__init__(defaults, input=input, other=other)

      @torch.no_grad
      def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
@@ -82,7 +85,9 @@ class DivModules(MultiOperation):
          torch._foreach_div_(input, other)
          return input

- class PowModules(MultiOperation):
+
+ class PowModules(MultiOperationBase):
+     """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
      def __init__(self, input: Chainable | float, exponent: Chainable | float):
          defaults = {}
          super().__init__(defaults, input=input, exponent=exponent)
@@ -96,17 +101,22 @@ class PowModules(MultiOperation):
          torch._foreach_div_(input, exponent)
          return input

- class LerpModules(MultiOperation):
+ class LerpModules(MultiOperationBase):
+     """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+     The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+     """
      def __init__(self, input: Chainable, end: Chainable, weight: float):
          defaults = dict(weight=weight)
          super().__init__(defaults, input=input, end=end)

      @torch.no_grad
      def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
-         torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
+         torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
          return input

- class ClipModules(MultiOperation):
+ class ClipModules(MultiOperationBase):
+     """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
      def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
          defaults = {}
          super().__init__(defaults, input=input, min=min, max=max)
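LerpModules blends the outputs of two modules according to the formula in its new docstring. An illustrative sketch only, assuming Grad, Adam and LR are exposed under tz.m in the same way as in the GraftModules docstring example in the next hunk:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 2)  # any model

    # output = Grad + 0.3 * (Adam - Grad), i.e. blend 30% of the Adam direction into the raw gradient
    opt = tz.Modular(
        model.parameters(),
        tz.m.LerpModules(input=tz.m.Grad(), end=tz.m.Adam(), weight=0.3),
        tz.m.LR(1e-2),
    )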
@@ -116,22 +126,73 @@ class ClipModules(MultiOperation):
          return TensorList(input).clamp_(min=min, max=max)


- class GraftModules(MultiOperation):
-     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
+ class GraftModules(MultiOperationBase):
+     """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+     Args:
+         direction (Chainable): module to use the direction from
+         magnitude (Chainable): module to use the magnitude from
+         tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+         ord (float, optional): norm order. Defaults to 2.
+         eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+         strength (float, optional): strength of grafting. Defaults to 1.
+
+     Example:
+         Shampoo grafted to Adam
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.GraftModules(
+                     direction = tz.m.Shampoo(),
+                     magnitude = tz.m.Adam(),
+                 ),
+                 tz.m.LR(1e-3)
+             )
+
+     Reference:
+         Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+     """
+     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
          defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
          super().__init__(defaults, direction=direction, magnitude=magnitude)

      @torch.no_grad
      def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
-         tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
+         tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
          return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

-
- class Where(MultiOperation):
-     def __init__(self, condition: Chainable, input: Chainable | float, other: Chainable | float):
-         super().__init__({}, condition=condition, input=input, other=other)
+ class MultiplyByModuleNorm(MultiOperationBase):
+     """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
+         defaults = dict(tensorwise=tensorwise, ord=ord)
+         super().__init__(defaults, input=input, norm=norm)

      @torch.no_grad
-     def transform(self, var, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
-         return tensorlist.where(TensorList(condition).as_bool(), input, other)
+     def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
+         if tensorwise:
+             n = TensorList(norm).metric(ord)
+         else:
+             n = TensorList(norm).global_metric(ord)
+
+         torch._foreach_mul_(input, n)
+         return input
+
+ class DivideByModuleNorm(MultiOperationBase):
+     """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
+         defaults = dict(tensorwise=tensorwise, ord=ord)
+         super().__init__(defaults, input=input, norm=norm)

+     @torch.no_grad
+     def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
+         if tensorwise:
+             n = TensorList(norm).metric(ord)
+         else:
+             n = TensorList(norm).global_metric(ord)
+
+         torch._foreach_div_(input, n)
+         return input
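The new MultiplyByModuleNorm and DivideByModuleNorm operations follow the same calling pattern as GraftModules. An illustrative sketch, assuming the tz.Modular / tz.m namespace from the GraftModules docstring example above; this particular combination is not taken from the package docs:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 2)  # any model

    # scale the Adam direction by the global norm of the raw gradient
    opt = tz.Modular(
        model.parameters(),
        tz.m.MultiplyByModuleNorm(input=tz.m.Adam(), norm=tz.m.Grad(), tensorwise=False),
        tz.m.LR(1e-3),
    )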
@@ -8,7 +8,7 @@ import torch
  from ...core import Chainable, Module, Target, Var, maybe_chain


- class ReduceOperation(Module, ABC):
+ class ReduceOperationBase(Module, ABC):
      """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
      def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
          super().__init__(defaults=defaults)
@@ -46,7 +46,8 @@ class ReduceOperation(Module, ABC):
          var.update = transformed
          return var

- class Sum(ReduceOperation):
+ class Sum(ReduceOperationBase):
+     """Outputs sum of :code:`inputs` that can be modules or numbers."""
      USE_MEAN = False
      def __init__(self, *inputs: Chainable | float):
          super().__init__({}, *inputs)
@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
          return sum

  class Mean(Sum):
+     """Outputs a mean of :code:`inputs` that can be modules or numbers."""
      USE_MEAN = True


- class WeightedSum(ReduceOperation):
+ class WeightedSum(ReduceOperationBase):
      USE_MEAN = False
      def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+         """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
          weights = list(weights)
          if len(inputs) != len(weights):
              raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -78,7 +81,7 @@ class WeightedSum(ReduceOperation):
      @torch.no_grad
      def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
          sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-         weights = self.settings[var.params[0]]['weights']
+         weights = self.defaults['weights']
          sum = cast(list, sorted_inputs[0])
          torch._foreach_mul_(sum, weights[0])
          if len(sorted_inputs) > 1:
@@ -91,9 +94,11 @@


  class WeightedMean(WeightedSum):
+     """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
      USE_MEAN = True

- class Median(ReduceOperation):
+ class Median(ReduceOperationBase):
+     """Outputs median of :code:`inputs` that can be modules or numbers."""
      def __init__(self, *inputs: Chainable | float):
          super().__init__({}, *inputs)

@@ -106,7 +111,8 @@ class Median(ReduceOperation):
              res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
          return res

- class Prod(ReduceOperation):
+ class Prod(ReduceOperationBase):
+     """Outputs product of :code:`inputs` that can be modules or numbers."""
      def __init__(self, *inputs: Chainable | float):
          super().__init__({}, *inputs)

@@ -120,7 +126,8 @@

          return prod

- class MaximumModules(ReduceOperation):
+ class MaximumModules(ReduceOperationBase):
+     """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
      def __init__(self, *inputs: Chainable | float):
          super().__init__({}, *inputs)

@@ -134,7 +141,8 @@

          return maximum

- class MinimumModules(ReduceOperation):
+ class MinimumModules(ReduceOperationBase):
+     """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
      def __init__(self, *inputs: Chainable | float):
          super().__init__({}, *inputs)

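The reduction operations above combine the outputs of several modules (or constants). One possible composition, again assuming the tz.m namespace from the GraftModules example; the modules and weights are purely illustrative:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 2)  # any model

    # 0.7 * (Adam direction) + 0.3 * (sign-of-gradient direction)
    opt = tz.Modular(
        model.parameters(),
        tz.m.WeightedSum(tz.m.Adam(), tz.m.Sign(), weights=[0.7, 0.3]),
        tz.m.LR(1e-3),
    )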
@@ -6,76 +6,92 @@ from ...core import TensorwiseTransform, Target, Transform
  from ...utils import TensorList, unpack_dicts,unpack_states

  class UnaryLambda(Transform):
+     """Applies :code:`fn` to input tensors.
+
+     :code:`fn` must accept and return a list of tensors.
+     """
      def __init__(self, fn, target: "Target" = 'update'):
          defaults = dict(fn=fn)
          super().__init__(defaults=defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          return settings[0]['fn'](tensors)

  class UnaryParameterwiseLambda(TensorwiseTransform):
+     """Applies :code:`fn` to each input tensor.
+
+     :code:`fn` must accept and return a tensor.
+     """
      def __init__(self, fn, target: "Target" = 'update'):
          defaults = dict(fn=fn)
          super().__init__(uses_grad=False, defaults=defaults, target=target)

      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         return settings['fn'](tensor)
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         return setting['fn'](tensor)

  class CustomUnaryOperation(Transform):
+     """Applies :code:`getattr(tensor, name)` to each tensor
+     """
      def __init__(self, name: str, target: "Target" = 'update'):
          defaults = dict(name=name)
          super().__init__(defaults=defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          return getattr(tensors, settings[0]['name'])()


  class Abs(Transform):
+     """Returns :code:`abs(input)`"""
      def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          torch._foreach_abs_(tensors)
          return tensors

  class Sign(Transform):
+     """Returns :code:`sign(input)`"""
      def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          torch._foreach_sign_(tensors)
          return tensors

  class Exp(Transform):
+     """Returns :code:`exp(input)`"""
      def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          torch._foreach_exp_(tensors)
          return tensors

  class Sqrt(Transform):
+     """Returns :code:`sqrt(input)`"""
      def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          torch._foreach_sqrt_(tensors)
          return tensors

  class Reciprocal(Transform):
+     """Returns :code:`1 / input`"""
      def __init__(self, eps = 0, target: "Target" = 'update'):
          defaults = dict(eps = eps)
          super().__init__(defaults, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          eps = [s['eps'] for s in settings]
          if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
          torch._foreach_reciprocal_(tensors)
          return tensors

  class Negate(Transform):
+     """Returns :code:`- input`"""
      def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          torch._foreach_neg_(tensors)
          return tensors

@@ -97,18 +113,18 @@ class NanToNum(Transform):
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
          return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]

  class Rescale(Transform):
-     """rescale update to (min, max) range"""
+     """Rescales input to :code`(min, max)` range"""
      def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
          defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          min, max = unpack_dicts(settings, 'min','max')
          tensorwise = settings[0]['tensorwise']
          dim = None if tensorwise else 'global'
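Per the new UnaryLambda docstring, fn receives and must return a list of tensors (UnaryParameterwiseLambda works on one tensor at a time instead). A small illustrative sketch under the same tz.m assumption as above:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 2)  # any model

    # clamp every update tensor to [-1, 1] before the learning rate is applied
    opt = tz.Modular(
        model.parameters(),
        tz.m.UnaryLambda(lambda tensors: [t.clamp_(-1, 1) for t in tensors]),
        tz.m.LR(1e-2),
    )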
@@ -4,38 +4,37 @@ import torch

  from ...core import Module, Target, Transform
  from ...utils.tensorlist import Distributions, TensorList
+ from ...utils.linalg.linear_operator import ScaledIdentity

-
- class Clone(Transform):
-     def __init__(self): super().__init__({}, uses_grad=False)
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
-
- class Grad(Module):
+ class Clone(Module):
+     """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
      def __init__(self):
          super().__init__({})
      @torch.no_grad
      def step(self, var):
-         var.update = [g.clone() for g in var.get_grad()]
+         var.update = [u.clone() for u in var.get_update()]
          return var

- class Params(Module):
+ class Grad(Module):
+     """Outputs the gradient"""
      def __init__(self):
          super().__init__({})
      @torch.no_grad
      def step(self, var):
-         var.update = [p.clone() for p in var.params]
+         var.update = [g.clone() for g in var.get_grad()]
          return var

- class Update(Module):
+ class Params(Module):
+     """Outputs parameters"""
      def __init__(self):
          super().__init__({})
      @torch.no_grad
      def step(self, var):
-         var.update = [u.clone() for u in var.get_update()]
+         var.update = [p.clone() for p in var.params]
          return var

  class Zeros(Module):
+     """Outputs zeros"""
      def __init__(self):
          super().__init__({})
      @torch.no_grad
@@ -44,6 +43,7 @@ class Zeros(Module):
          return var

  class Ones(Module):
+     """Outputs ones"""
      def __init__(self):
          super().__init__({})
      @torch.no_grad
@@ -52,6 +52,7 @@ class Ones(Module):
          return var

  class Fill(Module):
+     """Outputs tensors filled with :code:`value`"""
      def __init__(self, value: float):
          defaults = dict(value=value)
          super().__init__(defaults)
@@ -62,18 +63,20 @@ class Fill(Module):
          return var

  class RandomSample(Module):
-     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
-         defaults = dict(eps=eps, distribution=distribution)
+     """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
+     def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
+         defaults = dict(distribution=distribution, variance=variance)
          super().__init__(defaults)

      @torch.no_grad
      def step(self, var):
-         var.update = TensorList(var.params).sample_like(
-             eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
-         )
+         distribution = self.defaults['distribution']
+         variance = self.get_settings(var.params, 'variance')
+         var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
          return var

  class Randn(Module):
+     """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
      def __init__(self):
          super().__init__({})

@@ -83,6 +86,7 @@ class Randn(Module):
          return var

  class Uniform(Module):
+     """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
      def __init__(self, low: float, high: float):
          defaults = dict(low=low, high=high)
          super().__init__(defaults)
@@ -94,19 +98,27 @@
          return var

  class GradToNone(Module):
+     """Sets :code:`grad` attribute to None on :code:`var`."""
      def __init__(self): super().__init__()
      def step(self, var):
          var.grad = None
          return var

  class UpdateToNone(Module):
+     """Sets :code:`update` attribute to None on :code:`var`."""
      def __init__(self): super().__init__()
      def step(self, var):
          var.update = None
          return var

  class Identity(Module):
+     """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
      def __init__(self, *args, **kwargs): super().__init__()
      def step(self, var): return var
+     def get_H(self, var):
+         n = sum(p.numel() for p in var.params)
+         p = var.params[0]
+         return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

- NoOp = Identity
+ Noop = Identity
+ """A placeholder identity operator that is argument-insensitive."""
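Two user-visible changes in this file: RandomSample drops its eps argument in favour of variance, and the NoOp alias becomes Noop. A migration sketch based only on the signatures shown in this hunk, assuming these classes are exported under tz.m like the other modules:

    import torchzero as tz

    # 0.3.10
    # sampler = tz.m.RandomSample(eps=1, distribution='normal')
    # noop = tz.m.NoOp()

    # 0.3.13
    sampler = tz.m.RandomSample(distribution='normal', variance=1.0)
    noop = tz.m.Noop()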
@@ -1,5 +1,3 @@
- from .projection import Projection
- from .fft import FFTProjection
- from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
-
+ from .projection import ProjectionBase, VectorProjection, ScalarProjection
+ from .cast import To, ViewAsReal
  # from .galore import GaLore
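Code that imported the old projection names needs updating to the new exports. A sketch of the change implied by this __init__.py hunk; FFTProjection and the structural projections now live under torchzero.modules.experimental per the file list above, and their new public import paths are not shown in this diff:

    # 0.3.10
    # from torchzero.modules.projections import Projection, FFTProjection

    # 0.3.13
    from torchzero.modules.projections import ProjectionBase, VectorProjection, ScalarProjection, To, ViewAsReal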