torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/binary.py
CHANGED
@@ -1,5 +1,4 @@
 #pyright: reportIncompatibleMethodOverride=false
-""""""
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
@@ -11,7 +10,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist


-class BinaryOperation(Module, ABC):
+class BinaryOperationBase(Module, ABC):
     """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -47,7 +46,11 @@ class BinaryOperation(Module, ABC):
         return var


-class Add(BinaryOperation):
+class Add(BinaryOperationBase):
+    """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
@@ -58,7 +61,11 @@ class Add(BinaryOperation):
         else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update

-class Sub(BinaryOperation):
+class Sub(BinaryOperationBase):
+    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
@@ -69,7 +76,11 @@ class Sub(BinaryOperation):
         else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update

-class RSub(BinaryOperation):
+class RSub(BinaryOperationBase):
+    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

@@ -77,7 +88,11 @@ class RSub(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)

-class Mul(BinaryOperation):
+class Mul(BinaryOperationBase):
+    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

@@ -86,7 +101,11 @@ class Mul(BinaryOperation):
         torch._foreach_mul_(update, other)
         return update

-class Div(BinaryOperation):
+class Div(BinaryOperationBase):
+    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

@@ -95,7 +114,11 @@ class Div(BinaryOperation):
         torch._foreach_div_(update, other)
         return update

-class RDiv(BinaryOperation):
+class RDiv(BinaryOperationBase):
+    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

@@ -103,7 +126,11 @@ class RDiv(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)

-class Pow(BinaryOperation):
+class Pow(BinaryOperationBase):
+    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+    """
     def __init__(self, exponent: Chainable | float):
         super().__init__({}, exponent=exponent)

@@ -112,7 +139,11 @@ class Pow(BinaryOperation):
         torch._foreach_pow_(update, exponent)
         return update

-class RPow(BinaryOperation):
+class RPow(BinaryOperationBase):
+    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

@@ -122,7 +153,11 @@ class RPow(BinaryOperation):
         torch._foreach_pow_(other, update)
         return other

-class Lerp(BinaryOperation):
+class Lerp(BinaryOperationBase):
+    """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+    """
     def __init__(self, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, end=end)
@@ -132,7 +167,8 @@ class Lerp(BinaryOperation):
         torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
         return update

-class CopySign(BinaryOperation):
+class CopySign(BinaryOperationBase):
+    """Returns tensors with sign copied from :code:`other(tensors)`."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

@@ -140,7 +176,8 @@ class CopySign(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]

-class RCopySign(BinaryOperation):
+class RCopySign(BinaryOperationBase):
+    """Returns :code:`other(tensors)` with sign copied from tensors."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

@@ -149,7 +186,11 @@ class RCopySign(BinaryOperation):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign

-class Clip(BinaryOperation):
+class Clip(BinaryOperationBase):
+    """Clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max` can be None, numbers or modules.
+
+    If :code:`min` and :code:`max` are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+    """
     def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
         super().__init__({}, min=min, max=max)

@@ -157,8 +198,11 @@ class Clip(BinaryOperation):
     def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)

-class MirroredClip(BinaryOperation):
-    """clip
+class MirroredClip(BinaryOperationBase):
+    """Clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+    If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+    """
     def __init__(self, value: float | Chainable):
         super().__init__({}, value=value)

@@ -167,8 +211,8 @@ class MirroredClip(BinaryOperation):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)

-class Graft(BinaryOperation):
-    """
+class Graft(BinaryOperationBase):
+    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
     def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, magnitude=magnitude)
@@ -178,8 +222,8 @@ class Graft(BinaryOperation):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

-class RGraft(BinaryOperation):
-    """
+class RGraft(BinaryOperationBase):
+    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

     def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
@@ -192,7 +236,8 @@ class RGraft(BinaryOperation):

 GraftToUpdate = RGraft

-class Maximum(BinaryOperation):
+class Maximum(BinaryOperationBase):
+    """Outputs :code:`maximum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

@@ -201,7 +246,8 @@ class Maximum(BinaryOperation):
         torch._foreach_maximum_(update, other)
         return update

-class Minimum(BinaryOperation):
+class Minimum(BinaryOperationBase):
+    """Outputs :code:`minimum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

@@ -211,8 +257,8 @@ class Minimum(BinaryOperation):
         return update


-class GramSchimdt(BinaryOperation):
-    """
+class GramSchimdt(BinaryOperationBase):
+    """Outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

@@ -222,8 +268,8 @@ class GramSchimdt(BinaryOperation):
         return update - (other*update) / ((other*other) + 1e-8)


-class Threshold(BinaryOperation):
-    """
+class Threshold(BinaryOperationBase):
+    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
     def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
         defaults = dict(update_above=update_above)
         super().__init__(defaults, threshold=threshold, value=value)
torchzero/modules/ops/multi.py
CHANGED
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal

 import torch

@@ -11,7 +11,7 @@ from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist


-class MultiOperation(Module, ABC):
+class MultiOperationBase(Module, ABC):
     """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -51,7 +51,8 @@



-class SubModules(MultiOperation):
+class SubModules(MultiOperationBase):
+    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, input=input, other=other)
@@ -68,10 +69,12 @@ class SubModules(MultiOperation):
         else: torch._foreach_sub_(input, other, alpha=alpha)
         return input

-class DivModules(MultiOperation):
-    def __init__(self, input: Chainable | float, other: Chainable | float):
+class DivModules(MultiOperationBase):
+    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
         defaults = {}
-        super().__init__(defaults, input=input, other=other)
+        if other_first: super().__init__(defaults, other=other, input=input)
+        else: super().__init__(defaults, input=input, other=other)

     @torch.no_grad
     def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
@@ -82,7 +85,9 @@ class DivModules(MultiOperation):
         torch._foreach_div_(input, other)
         return input

-class PowModules(MultiOperation):
+
+class PowModules(MultiOperationBase):
+    """Calculates :code:`input ** exponent`. :code:`input` and :code:`exponent` can be numbers or modules."""
     def __init__(self, input: Chainable | float, exponent: Chainable | float):
         defaults = {}
         super().__init__(defaults, input=input, exponent=exponent)
@@ -96,7 +101,11 @@ class PowModules(MultiOperation):
         torch._foreach_div_(input, exponent)
         return input

-class LerpModules(MultiOperation):
+class LerpModules(MultiOperationBase):
+    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+    """
     def __init__(self, input: Chainable, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, input=input, end=end)
@@ -106,7 +115,8 @@ class LerpModules(MultiOperation):
         torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
         return input

-class ClipModules(MultiOperation):
+class ClipModules(MultiOperationBase):
+    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
         defaults = {}
         super().__init__(defaults, input=input, min=min, max=max)
@@ -116,7 +126,34 @@ class ClipModules(MultiOperation):
         return TensorList(input).clamp_(min=min, max=max)


-class GraftModules(MultiOperation):
+class GraftModules(MultiOperationBase):
+    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+    Args:
+        direction (Chainable): module to use the direction from
+        magnitude (Chainable): module to use the magnitude from
+        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+        ord (float, optional): norm order. Defaults to 2.
+        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+        strength (float, optional): strength of grafting. Defaults to 1.
+
+    Example:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+    Reference:
+        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+    """
     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)
@@ -126,12 +163,36 @@ class GraftModules(MultiOperation):
         tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

-
-
-    def __init__(self,
-
+class MultiplyByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)

     @torch.no_grad
-    def transform(self, var,
-
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+        if tensorwise:
+            if ord == 'mean_abs': n = [t.mean() for t in torch._foreach_abs(norm)]
+            else: n = torch._foreach_norm(norm, ord)
+        else: n = TensorList(norm).global_vector_norm(ord)
+
+        torch._foreach_mul_(input, n)
+        return input
+
+class DivideByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)

+    @torch.no_grad
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+        if tensorwise:
+            if ord == 'mean_abs': n = [t.mean().clip(min=1e-8) for t in torch._foreach_abs(norm)]
+            else: n = torch._foreach_clamp_min(torch._foreach_norm(norm, ord), 1e-8)
+        else: n = TensorList(norm).global_vector_norm(ord).clip(min=1e-8)
+
+        torch._foreach_div_(input, n)
+        return input
torchzero/modules/ops/reduce.py
CHANGED
@@ -8,7 +8,7 @@ import torch
 from ...core import Chainable, Module, Target, Var, maybe_chain


-class ReduceOperation(Module, ABC):
+class ReduceOperationBase(Module, ABC):
     """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -46,7 +46,8 @@ class ReduceOperation(Module, ABC):
         var.update = transformed
         return var

-class Sum(ReduceOperation):
+class Sum(ReduceOperationBase):
+    """Outputs sum of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
         return sum

 class Mean(Sum):
+    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True


-class WeightedSum(ReduceOperation):
+class WeightedSum(ReduceOperationBase):
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
         weights = list(weights)
         if len(inputs) != len(weights):
             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -91,9 +94,11 @@ class WeightedSum(ReduceOperation):


 class WeightedMean(WeightedSum):
+    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True

-class Median(ReduceOperation):
+class Median(ReduceOperationBase):
+    """Outputs median of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)

@@ -106,7 +111,8 @@ class Median(ReduceOperation):
             res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
         return res

-class Prod(ReduceOperation):
+class Prod(ReduceOperationBase):
+    """Outputs product of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)

@@ -120,7 +126,8 @@ class Prod(ReduceOperation):

         return prod

-class MaximumModules(ReduceOperation):
+class MaximumModules(ReduceOperationBase):
+    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)

@@ -134,7 +141,8 @@ class MaximumModules(ReduceOperation):

         return maximum

-class MinimumModules(ReduceOperation):
+class MinimumModules(ReduceOperationBase):
+    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)

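
The reductions take any mix of modules and numbers as positional inputs and combine their outputs elementwise. A sketch, again assuming the tz.m export path shown in the GraftModules docstring:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 2)  # placeholder model

    # Average the Adam and Shampoo updates elementwise before applying the learning rate.
    opt = tz.Modular(
        model.parameters(),
        tz.m.Mean(tz.m.Adam(), tz.m.Shampoo()),
        tz.m.LR(1e-3),
    )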
torchzero/modules/ops/unary.py
CHANGED
@@ -6,76 +6,92 @@ from ...core import TensorwiseTransform, Target, Transform
 from ...utils import TensorList, unpack_dicts,unpack_states

 class UnaryLambda(Transform):
+    """Applies :code:`fn` to input tensors.
+
+    :code:`fn` must accept and return a list of tensors.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(defaults=defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         return settings[0]['fn'](tensors)

 class UnaryParameterwiseLambda(TensorwiseTransform):
+    """Applies :code:`fn` to each input tensor.
+
+    :code:`fn` must accept and return a tensor.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(uses_grad=False, defaults=defaults, target=target)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        return
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        return setting['fn'](tensor)

 class CustomUnaryOperation(Transform):
+    """Applies :code:`getattr(tensor, name)` to each tensor
+    """
     def __init__(self, name: str, target: "Target" = 'update'):
         defaults = dict(name=name)
         super().__init__(defaults=defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         return getattr(tensors, settings[0]['name'])()


 class Abs(Transform):
+    """Returns :code:`abs(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors

 class Sign(Transform):
+    """Returns :code:`sign(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors

 class Exp(Transform):
+    """Returns :code:`exp(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors

 class Sqrt(Transform):
+    """Returns :code:`sqrt(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors

 class Reciprocal(Transform):
+    """Returns :code:`1 / input`"""
     def __init__(self, eps = 0, target: "Target" = 'update'):
         defaults = dict(eps = eps)
         super().__init__(defaults, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors

 class Negate(Transform):
+    """Returns :code:`- input`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors

@@ -97,18 +113,18 @@ class NanToNum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]

 class Rescale(Transform):
-    """
+    """Rescales input to :code:`(min, max)` range"""
     def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         min, max = unpack_dicts(settings, 'min','max')
         tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'