torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
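
The hunks shown below cover the four torchzero/modules/ops files (binary.py, multi.py, reduce.py, unary.py): the abstract bases are renamed to BinaryOperationBase, MultiOperationBase and ReduceOperationBase, and most concrete operation modules gain docstrings. For orientation, here is a minimal sketch of how such modules are composed, adapted from the GraftModules docstring added in this release; the tz.Modular / tz.m entry points are taken from that docstring, and the model is a placeholder:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)  # placeholder model for illustration

    # Shampoo's direction grafted onto Adam's magnitude, followed by a fixed learning rate,
    # mirroring the example in the GraftModules docstring from this diff.
    opt = tz.Modular(
        model.parameters(),
        tz.m.GraftModules(
            direction = tz.m.Shampoo(),
            magnitude = tz.m.Adam(),
        ),
        tz.m.LR(1e-3),
    )
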
torchzero/modules/ops/binary.py
@@ -1,5 +1,4 @@
  #pyright: reportIncompatibleMethodOverride=false
- """"""
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
@@ -11,7 +10,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
  from ...utils import TensorList, tensorlist


- class BinaryOperation(Module, ABC):
+ class BinaryOperationBase(Module, ABC):
  """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -47,7 +46,11 @@ class BinaryOperation(Module, ABC):
  return var


- class Add(BinaryOperation):
+ class Add(BinaryOperationBase):
+ """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)
@@ -58,7 +61,11 @@ class Add(BinaryOperation):
  else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
  return update

- class Sub(BinaryOperation):
+ class Sub(BinaryOperationBase):
+ """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)
@@ -69,7 +76,11 @@ class Sub(BinaryOperation):
  else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
  return update

- class RSub(BinaryOperation):
+ class RSub(BinaryOperationBase):
+ """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -77,7 +88,11 @@ class RSub(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other - TensorList(update)

- class Mul(BinaryOperation):
+ class Mul(BinaryOperationBase):
+ """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -86,7 +101,11 @@ class Mul(BinaryOperation):
  torch._foreach_mul_(update, other)
  return update

- class Div(BinaryOperation):
+ class Div(BinaryOperationBase):
+ """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -95,7 +114,11 @@ class Div(BinaryOperation):
  torch._foreach_div_(update, other)
  return update

- class RDiv(BinaryOperation):
+ class RDiv(BinaryOperationBase):
+ """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -103,7 +126,11 @@ class RDiv(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other / TensorList(update)

- class Pow(BinaryOperation):
+ class Pow(BinaryOperationBase):
+ """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+ If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+ """
  def __init__(self, exponent: Chainable | float):
  super().__init__({}, exponent=exponent)

@@ -112,7 +139,11 @@ class Pow(BinaryOperation):
  torch._foreach_pow_(update, exponent)
  return update

- class RPow(BinaryOperation):
+ class RPow(BinaryOperationBase):
+ """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)

@@ -122,7 +153,11 @@ class RPow(BinaryOperation):
  torch._foreach_pow_(other, update)
  return other

- class Lerp(BinaryOperation):
+ class Lerp(BinaryOperationBase):
+ """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+ The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+ """
  def __init__(self, end: Chainable, weight: float):
  defaults = dict(weight=weight)
  super().__init__(defaults, end=end)
@@ -132,7 +167,8 @@ class Lerp(BinaryOperation):
  torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
  return update

- class CopySign(BinaryOperation):
+ class CopySign(BinaryOperationBase):
+ """Returns tensors with sign copied from :code:`other(tensors)`."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -140,7 +176,8 @@ class CopySign(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  return [u.copysign_(o) for u, o in zip(update, other)]

- class RCopySign(BinaryOperation):
+ class RCopySign(BinaryOperationBase):
+ """Returns :code:`other(tensors)` with sign copied from tensors."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -149,7 +186,11 @@ class RCopySign(BinaryOperation):
  return [o.copysign_(u) for u, o in zip(update, other)]
  CopyMagnitude = RCopySign

- class Clip(BinaryOperation):
+ class Clip(BinaryOperationBase):
+ """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+
+ If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+ """
  def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
  super().__init__({}, min=min, max=max)

@@ -157,8 +198,11 @@ class Clip(BinaryOperation):
  def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
  return TensorList(update).clamp_(min=min, max=max)

- class MirroredClip(BinaryOperation):
- """clip by -value, value"""
+ class MirroredClip(BinaryOperationBase):
+ """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+ If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+ """
  def __init__(self, value: float | Chainable):
  super().__init__({}, value=value)

@@ -167,8 +211,8 @@ class MirroredClip(BinaryOperation):
  min = -value if isinstance(value, (int,float)) else [-v for v in value]
  return TensorList(update).clamp_(min=min, max=value)

- class Graft(BinaryOperation):
- """use direction from update and magnitude from `magnitude` module"""
+ class Graft(BinaryOperationBase):
+ """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
  def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
  super().__init__(defaults, magnitude=magnitude)
@@ -178,8 +222,8 @@ class Graft(BinaryOperation):
  tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
  return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

- class RGraft(BinaryOperation):
- """use direction from `direction` module and magnitude from update"""
+ class RGraft(BinaryOperationBase):
+ """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

  def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
@@ -192,7 +236,8 @@ class RGraft(BinaryOperation):

  GraftToUpdate = RGraft

- class Maximum(BinaryOperation):
+ class Maximum(BinaryOperationBase):
+ """Outputs :code:`maximum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -201,7 +246,8 @@ class Maximum(BinaryOperation):
  torch._foreach_maximum_(update, other)
  return update

- class Minimum(BinaryOperation):
+ class Minimum(BinaryOperationBase):
+ """Outputs :code:`minimum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -211,8 +257,8 @@
  return update


- class GramSchimdt(BinaryOperation):
- """makes update orthonormal to `other`"""
+ class GramSchimdt(BinaryOperationBase):
+ """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)

@@ -222,8 +268,8 @@ class GramSchimdt(BinaryOperation):
  return update - (other*update) / ((other*other) + 1e-8)


- class Threshold(BinaryOperation):
- """update above/below threshold, value at and below"""
+ class Threshold(BinaryOperationBase):
+ """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
  def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
  defaults = dict(update_above=update_above)
  super().__init__(defaults, threshold=threshold, value=value)
torchzero/modules/ops/multi.py
@@ -3,7 +3,7 @@
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
- from typing import Any
+ from typing import Any, Literal

  import torch

@@ -11,7 +11,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
  from ...utils import TensorList, tensorlist


- class MultiOperation(Module, ABC):
+ class MultiOperationBase(Module, ABC):
  """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -51,7 +51,8 @@ class MultiOperation(Module, ABC):



- class SubModules(MultiOperation):
+ class SubModules(MultiOperationBase):
+ """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
  def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, input=input, other=other)
@@ -68,10 +69,12 @@ class SubModules(MultiOperation):
  else: torch._foreach_sub_(input, other, alpha=alpha)
  return input

- class DivModules(MultiOperation):
- def __init__(self, input: Chainable | float, other: Chainable | float):
+ class DivModules(MultiOperationBase):
+ """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+ def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
  defaults = {}
- super().__init__(defaults, input=input, other=other)
+ if other_first: super().__init__(defaults, other=other, input=input)
+ else: super().__init__(defaults, input=input, other=other)

  @torch.no_grad
  def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
@@ -82,7 +85,9 @@ class DivModules(MultiOperation):
  torch._foreach_div_(input, other)
  return input

- class PowModules(MultiOperation):
+
+ class PowModules(MultiOperationBase):
+ """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
  def __init__(self, input: Chainable | float, exponent: Chainable | float):
  defaults = {}
  super().__init__(defaults, input=input, exponent=exponent)
@@ -96,7 +101,11 @@ class PowModules(MultiOperation):
  torch._foreach_div_(input, exponent)
  return input

- class LerpModules(MultiOperation):
+ class LerpModules(MultiOperationBase):
+ """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+ The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+ """
  def __init__(self, input: Chainable, end: Chainable, weight: float):
  defaults = dict(weight=weight)
  super().__init__(defaults, input=input, end=end)
@@ -106,7 +115,8 @@ class LerpModules(MultiOperation):
  torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
  return input

- class ClipModules(MultiOperation):
+ class ClipModules(MultiOperationBase):
+ """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
  def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
  defaults = {}
  super().__init__(defaults, input=input, min=min, max=max)
@@ -116,7 +126,34 @@ class ClipModules(MultiOperation):
  return TensorList(input).clamp_(min=min, max=max)


- class GraftModules(MultiOperation):
+ class GraftModules(MultiOperationBase):
+ """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+ Args:
+ direction (Chainable): module to use the direction from
+ magnitude (Chainable): module to use the magnitude from
+ tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+ ord (float, optional): norm order. Defaults to 2.
+ eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+ strength (float, optional): strength of grafting. Defaults to 1.
+
+ Example:
+ Shampoo grafted to Adam
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.GraftModules(
+ direction = tz.m.Shampoo(),
+ magnitude = tz.m.Adam(),
+ ),
+ tz.m.LR(1e-3)
+ )
+
+ Reference:
+ Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+ """
  def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
  super().__init__(defaults, direction=direction, magnitude=magnitude)
@@ -126,12 +163,36 @@ class GraftModules(MultiOperation):
  tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
  return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

-
- class Where(MultiOperation):
- def __init__(self, condition: Chainable, input: Chainable | float, other: Chainable | float):
- super().__init__({}, condition=condition, input=input, other=other)
+ class MultiplyByModuleNorm(MultiOperationBase):
+ """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ defaults = dict(tensorwise=tensorwise, ord=ord)
+ super().__init__(defaults, input=input, norm=norm)

  @torch.no_grad
- def transform(self, var, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
- return tensorlist.where(TensorList(condition).as_bool(), input, other)
+ def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ if tensorwise:
+ if ord == 'mean_abs': n = [t.mean() for t in torch._foreach_abs(norm)]
+ else: n = torch._foreach_norm(norm, ord)
+ else: n = TensorList(norm).global_vector_norm(ord)
+
+ torch._foreach_mul_(input, n)
+ return input
+
+ class DivideByModuleNorm(MultiOperationBase):
+ """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ defaults = dict(tensorwise=tensorwise, ord=ord)
+ super().__init__(defaults, input=input, norm=norm)

+ @torch.no_grad
+ def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ if tensorwise:
+ if ord == 'mean_abs': n = [t.mean().clip(min=1e-8) for t in torch._foreach_abs(norm)]
+ else: n = torch._foreach_clamp_min(torch._foreach_norm(norm, ord), 1e-8)
+ else: n = TensorList(norm).global_vector_norm(ord).clip(min=1e-8)
+
+ torch._foreach_div_(input, n)
+ return input
torchzero/modules/ops/reduce.py
@@ -8,7 +8,7 @@ import torch
  from ...core import Chainable, Module, Target, Var, maybe_chain


- class ReduceOperation(Module, ABC):
+ class ReduceOperationBase(Module, ABC):
  """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -46,7 +46,8 @@ class ReduceOperation(Module, ABC):
  var.update = transformed
  return var

- class Sum(ReduceOperation):
+ class Sum(ReduceOperationBase):
+ """Outputs sum of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = False
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
  return sum

  class Mean(Sum):
+ """Outputs a mean of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = True


- class WeightedSum(ReduceOperation):
+ class WeightedSum(ReduceOperationBase):
  USE_MEAN = False
  def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+ """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
  weights = list(weights)
  if len(inputs) != len(weights):
  raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -91,9 +94,11 @@ class WeightedSum(ReduceOperation):


  class WeightedMean(WeightedSum):
+ """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = True

- class Median(ReduceOperation):
+ class Median(ReduceOperationBase):
+ """Outputs median of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)

@@ -106,7 +111,8 @@ class Median(ReduceOperation):
  res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
  return res

- class Prod(ReduceOperation):
+ class Prod(ReduceOperationBase):
+ """Outputs product of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)

@@ -120,7 +126,8 @@ class Prod(ReduceOperation):

  return prod

- class MaximumModules(ReduceOperation):
+ class MaximumModules(ReduceOperationBase):
+ """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)

@@ -134,7 +141,8 @@ class MaximumModules(ReduceOperation):

  return maximum

- class MinimumModules(ReduceOperation):
+ class MinimumModules(ReduceOperationBase):
+ """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)

torchzero/modules/ops/unary.py
@@ -6,76 +6,92 @@ from ...core import TensorwiseTransform, Target, Transform
  from ...utils import TensorList, unpack_dicts,unpack_states

  class UnaryLambda(Transform):
+ """Applies :code:`fn` to input tensors.
+
+ :code:`fn` must accept and return a list of tensors.
+ """
  def __init__(self, fn, target: "Target" = 'update'):
  defaults = dict(fn=fn)
  super().__init__(defaults=defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  return settings[0]['fn'](tensors)

  class UnaryParameterwiseLambda(TensorwiseTransform):
+ """Applies :code:`fn` to each input tensor.
+
+ :code:`fn` must accept and return a tensor.
+ """
  def __init__(self, fn, target: "Target" = 'update'):
  defaults = dict(fn=fn)
  super().__init__(uses_grad=False, defaults=defaults, target=target)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, settings):
- return settings['fn'](tensor)
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ return setting['fn'](tensor)

  class CustomUnaryOperation(Transform):
+ """Applies :code:`getattr(tensor, name)` to each tensor
+ """
  def __init__(self, name: str, target: "Target" = 'update'):
  defaults = dict(name=name)
  super().__init__(defaults=defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  return getattr(tensors, settings[0]['name'])()


  class Abs(Transform):
+ """Returns :code:`abs(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_abs_(tensors)
  return tensors

  class Sign(Transform):
+ """Returns :code:`sign(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_sign_(tensors)
  return tensors

  class Exp(Transform):
+ """Returns :code:`exp(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_exp_(tensors)
  return tensors

  class Sqrt(Transform):
+ """Returns :code:`sqrt(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_sqrt_(tensors)
  return tensors

  class Reciprocal(Transform):
+ """Returns :code:`1 / input`"""
  def __init__(self, eps = 0, target: "Target" = 'update'):
  defaults = dict(eps = eps)
  super().__init__(defaults, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  eps = [s['eps'] for s in settings]
  if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
  torch._foreach_reciprocal_(tensors)
  return tensors

  class Negate(Transform):
+ """Returns :code:`- input`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_neg_(tensors)
  return tensors

@@ -97,18 +113,18 @@ class NanToNum(Transform):
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
  return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]

  class Rescale(Transform):
- """rescale update to (min, max) range"""
+ """Rescales input to :code`(min, max)` range"""
  def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
  defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  min, max = unpack_dicts(settings, 'min','max')
  tensorwise = settings[0]['tensorwise']
  dim = None if tensorwise else 'global'
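
The apply → apply_tensors rename above applies uniformly to these elementwise transforms. As a rough usage sketch (an assumption, not taken from the diff: it presumes UnaryLambda is exposed under tz.m like the modules in the GraftModules example), a custom elementwise transform could be slotted into a chain the same way:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 2)  # placeholder model for illustration

    # Hypothetical chain: clamp every update tensor elementwise, then apply a fixed LR.
    # UnaryLambda's fn takes and returns a list of tensors, per the docstring added above.
    opt = tz.Modular(
        model.parameters(),
        tz.m.UnaryLambda(lambda tensors: [t.clamp(-1.0, 1.0) for t in tensors]),
        tz.m.LR(1e-2),
    )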