torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/binary.py
CHANGED
@@ -1,5 +1,4 @@
 #pyright: reportIncompatibleMethodOverride=false
-""""""
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
@@ -7,11 +6,11 @@ from typing import Any

 import torch

-from ...core import Chainable, Module, Target,
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist


-class
+class BinaryOperationBase(Module, ABC):
     """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -26,211 +25,258 @@ class BinaryOperation(Module, ABC):
             self.operands[k] = v

     @abstractmethod
-    def transform(self,
+    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError

     @torch.no_grad
-    def step(self,
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-
-                processed_operands[k] =
-
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

-        transformed = self.transform(
-
-        return
+        transformed = self.transform(var, update=var.get_update(), **processed_operands)
+        var.update = list(transformed)
+        return var


-class Add(
+class Add(BinaryOperationBase):
+    """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)

     @torch.no_grad
-    def transform(self,
-        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[
-        else: torch._foreach_add_(update, other, alpha=self.settings[
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update

-class Sub(
+class Sub(BinaryOperationBase):
+    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+    """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)

     @torch.no_grad
-    def transform(self,
-        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[
-        else: torch._foreach_sub_(update, other, alpha=self.settings[
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update

-class RSub(
+class RSub(BinaryOperationBase):
+    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)

-class Mul(
+class Mul(BinaryOperationBase):
+    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_mul_(update, other)
         return update

-class Div(
+class Div(BinaryOperationBase):
+    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_div_(update, other)
         return update

-class RDiv(
+class RDiv(BinaryOperationBase):
+    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)

-class Pow(
+class Pow(BinaryOperationBase):
+    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+    """
     def __init__(self, exponent: Chainable | float):
         super().__init__({}, exponent=exponent)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
         torch._foreach_pow_(update, exponent)
         return update

-class RPow(
+class RPow(BinaryOperationBase):
+    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
         torch._foreach_pow_(other, update)
         return other

-class Lerp(
+class Lerp(BinaryOperationBase):
+    """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+    """
     def __init__(self, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, end=end)

     @torch.no_grad
-    def transform(self,
-        torch._foreach_lerp_(update, end, weight=self.get_settings('weight'
+    def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
         return update

-class CopySign(
+class CopySign(BinaryOperationBase):
+    """Returns tensors with sign copied from :code:`other(tensors)`."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]

-class RCopySign(
+class RCopySign(BinaryOperationBase):
+    """Returns :code:`other(tensors)` with sign copied from tensors."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign

-class Clip(
+class Clip(BinaryOperationBase):
+    """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+
+    If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+    """
     def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
         super().__init__({}, min=min, max=max)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)

-class MirroredClip(
-    """clip
+class MirroredClip(BinaryOperationBase):
+    """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+    If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+    """
     def __init__(self, value: float | Chainable):
         super().__init__({}, value=value)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], value: float | list[torch.Tensor]):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)

-class Graft(
-    """
+class Graft(BinaryOperationBase):
+    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
     def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, magnitude=magnitude)

     @torch.no_grad
-    def transform(self,
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[
+    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

-class RGraft(
-    """
+class RGraft(BinaryOperationBase):
+    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

     def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, direction=direction)

     @torch.no_grad
-    def transform(self,
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[
+    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

 GraftToUpdate = RGraft

-class Maximum(
+class Maximum(BinaryOperationBase):
+    """Outputs :code:`maximum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_maximum_(update, other)
         return update

-class Minimum(
+class Minimum(BinaryOperationBase):
+    """Outputs :code:`minimum(tensors, other(tensors))`"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_minimum_(update, other)
         return update


-class GramSchimdt(
-    """
+class GramSchimdt(BinaryOperationBase):
+    """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
         return update - (other*update) / ((other*other) + 1e-8)


-class Threshold(
-    """
+class Threshold(BinaryOperationBase):
+    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
     def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
         defaults = dict(update_above=update_above)
         super().__init__(defaults, threshold=threshold, value=value)

     @torch.no_grad
-    def transform(self,
-        update_above = self.settings[
+    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+        update_above = self.settings[var.params[0]]['update_above']
         update = TensorList(update)
         if update_above:
             if isinstance(value, list): return update.where_(update>threshold, value)
torchzero/modules/ops/multi.py
CHANGED
@@ -3,15 +3,15 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal

 import torch

-from ...core import Chainable, Module, Target,
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist


-class
+class MultiOperationBase(Module, ABC):
     """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -29,36 +29,37 @@ class MultiOperation(Module, ABC):
             raise ValueError('At least one operand must be a module')

     @abstractmethod
-    def transform(self,
+    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError

     @torch.no_grad
-    def step(self,
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-
-                processed_operands[k] =
-
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

-        transformed = self.transform(
-
-        return
+        transformed = self.transform(var, **processed_operands)
+        var.update = transformed
+        return var



-class SubModules(
+class SubModules(MultiOperationBase):
+    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, input=input, other=other)

     @torch.no_grad
-    def transform(self,
-        alpha = self.settings[
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        alpha = self.settings[var.params[0]]['alpha']

         if isinstance(input, (int,float)):
             assert isinstance(other, list)
@@ -68,13 +69,15 @@ class SubModules(MultiOperation):
         else: torch._foreach_sub_(input, other, alpha=alpha)
         return input

-class DivModules(
-
+class DivModules(MultiOperationBase):
+    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
         defaults = {}
-        super().__init__(defaults,
+        if other_first: super().__init__(defaults, other=other, input=input)
+        else: super().__init__(defaults, input=input, other=other)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
             return input / TensorList(other)
@@ -82,13 +85,15 @@ class DivModules(MultiOperation):
         torch._foreach_div_(input, other)
         return input

-
+
+class PowModules(MultiOperationBase):
+    """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
     def __init__(self, input: Chainable | float, exponent: Chainable | float):
         defaults = {}
         super().__init__(defaults, input=input, exponent=exponent)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(exponent, list)
             return input ** TensorList(exponent)
@@ -96,42 +101,98 @@ class PowModules(MultiOperation):
         torch._foreach_div_(input, exponent)
         return input

-class LerpModules(
+class LerpModules(MultiOperationBase):
+    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+    """
     def __init__(self, input: Chainable, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, input=input, end=end)

     @torch.no_grad
-    def transform(self,
-        torch._foreach_lerp_(input, end, weight=self.settings[
+    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+        torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
         return input

-class ClipModules(
+class ClipModules(MultiOperationBase):
+    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
         defaults = {}
         super().__init__(defaults, input=input, min=min, max=max)

     @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
         return TensorList(input).clamp_(min=min, max=max)


-class GraftModules(
+class GraftModules(MultiOperationBase):
+    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+    Args:
+        direction (Chainable): module to use the direction from
+        magnitude (Chainable): module to use the magnitude from
+        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+        ord (float, optional): norm order. Defaults to 2.
+        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+        strength (float, optional): strength of grafting. Defaults to 1.
+
+    Example:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+    Reference:
+        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+    """
     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)

     @torch.no_grad
-    def transform(self,
-        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[
+    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

-
-
-    def __init__(self,
-
+class MultiplyByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)

     @torch.no_grad
-    def transform(self,
-
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+        if tensorwise:
+            if ord == 'mean_abs': n = [t.mean() for t in torch._foreach_abs(norm)]
+            else: n = torch._foreach_norm(norm, ord)
+        else: n = TensorList(norm).global_vector_norm(ord)
+
+        torch._foreach_mul_(input, n)
+        return input
+
+class DivideByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)

+    @torch.no_grad
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+        if tensorwise:
+            if ord == 'mean_abs': n = [t.mean().clip(min=1e-8) for t in torch._foreach_abs(norm)]
+            else: n = torch._foreach_clamp_min(torch._foreach_norm(norm, ord), 1e-8)
+        else: n = TensorList(norm).global_vector_norm(ord).clip(min=1e-8)
+
+        torch._foreach_div_(input, n)
+        return input