torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in the supported public registries. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
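A note for users upgrading: the `torchzero/{utils/linalg → linalg}` moves and the `modules/{functional.py → opt_utils.py}` rename above imply changed import paths in 0.4.1. The sketch below is inferred from the file listing only and has not been checked against the 0.4.1 public API, so treat the exact paths and names as assumptions.

# Inferred from the renames in the listing above; verify against the installed 0.4.1 package.
# 0.3.15 (old paths):
#   from torchzero.utils.linalg import qr, solve, linear_operator
#   from torchzero.modules.functional import debias
# 0.4.1 (new paths):
from torchzero.linalg import qr, solve, linear_operator   # utils/linalg is now the top-level torchzero.linalg subpackage
from torchzero.modules.opt_utils import debias             # modules/functional.py was renamed to modules/opt_utils.py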
torchzero/modules/ops/binary.py
CHANGED
@@ -6,8 +6,8 @@ from typing import Any
 
 import torch
 
-from ...core import Chainable, Module,
-from ...utils import TensorList
+from ...core import Chainable, Module, Objective
+from ...utils import TensorList
 
 
 class BinaryOperationBase(Module, ABC):
@@ -25,263 +25,264 @@ class BinaryOperationBase(Module, ABC):
                 self.operands[k] = v
 
     @abstractmethod
-    def transform(self,
+    def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-
-                processed_operands[k] =
-
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[k] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(
-
-        return
+        transformed = self.transform(objective, update=objective.get_updates(), **processed_operands)
+        objective.updates = list(transformed)
+        return objective
 
 
 class Add(BinaryOperationBase):
-    """Add
+    """Add ``other`` to tensors. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``tensors + other(tensors)``
     """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
         else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class Sub(BinaryOperationBase):
-    """Subtract
+    """Subtract ``other`` from tensors. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates :code:`tensors - other(tensors)`
     """
     def __init__(self, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
         else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
         return update
 
 class RSub(BinaryOperationBase):
-    """Subtract tensors from
+    """Subtract tensors from ``other``. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``other(tensors) - tensors``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)
 
 class Mul(BinaryOperationBase):
-    """Multiply tensors by
+    """Multiply tensors by ``other``. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``tensors * other(tensors)``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_mul_(update, other)
         return update
 
 class Div(BinaryOperationBase):
-    """Divide tensors by
+    """Divide tensors by ``other``. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``tensors / other(tensors)``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_div_(update, other)
         return update
 
 class RDiv(BinaryOperationBase):
-    """Divide
+    """Divide ``other`` by tensors. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``other(tensors) / tensors``
    """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)
 
 class Pow(BinaryOperationBase):
-    """Take tensors to the power of
+    """Take tensors to the power of ``exponent``. ``exponent`` can be a number or a module.
 
-    If
+    If ``exponent`` is a module, this calculates ``tensors ^ exponent(tensors)``
     """
     def __init__(self, exponent: Chainable | float):
         super().__init__({}, exponent=exponent)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
         torch._foreach_pow_(update, exponent)
         return update
 
 class RPow(BinaryOperationBase):
-    """Take
+    """Take ``other`` to the power of tensors. ``other`` can be a number or a module.
 
-    If
+    If ``other`` is a module, this calculates ``other(tensors) ^ tensors``
     """
     def __init__(self, other: Chainable | float):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
         torch._foreach_pow_(other, update)
         return other
 
 class Lerp(BinaryOperationBase):
-    """Does a linear interpolation of tensors and
+    """Does a linear interpolation of tensors and ``end`` module based on a scalar ``weight``.
 
-    The output is given by
+    The output is given by ``output = tensors + weight * (end(tensors) - tensors)``
     """
     def __init__(self, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, end=end)
 
     @torch.no_grad
-    def transform(self,
-        torch._foreach_lerp_(update, end, weight=self.get_settings(
+    def transform(self, objective, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings(objective.params, 'weight'))
         return update
 
 class CopySign(BinaryOperationBase):
-    """Returns tensors with sign copied from
+    """Returns tensors with sign copied from ``other(tensors)``."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]
 
 class RCopySign(BinaryOperationBase):
-    """Returns
+    """Returns ``other(tensors)`` with sign copied from tensors."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign
 
 class Clip(BinaryOperationBase):
-    """clip tensors to be in
+    """clip tensors to be in ``(min, max)`` range. ``min`` and ``max`: can be None, numbers or modules.
 
-    If
+    If ``min`` and ``max`` are modules, this calculates ``tensors.clip(min(tensors), max(tensors))``.
     """
     def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
         super().__init__({}, min=min, max=max)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)
 
 class MirroredClip(BinaryOperationBase):
-    """clip tensors to be in
+    """clip tensors to be in ``(-value, value)`` range. ``value`` can be a number or a module.
 
-    If
+    If ``value`` is a module, this calculates ``tensors.clip(-value(tensors), value(tensors))``
     """
     def __init__(self, value: float | Chainable):
         super().__init__({}, value=value)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], value: float | list[torch.Tensor]):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)
 
-class
-    """Outputs tensors rescaled to have the same norm as
+class GraftInputToOutput(BinaryOperationBase):
+    """Outputs ``tensors`` rescaled to have the same norm as ``magnitude(tensors)``."""
     def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
-class
-    """Outputs
+class GraftOutputToInput(BinaryOperationBase):
+    """Outputs ``magnitude(tensors)`` rescaled to have the same norm as ``tensors``"""
 
     def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
         super().__init__(defaults, direction=direction)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], direction: list[torch.Tensor]):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
-GraftToUpdate = RGraft
-
 class Maximum(BinaryOperationBase):
-    """Outputs
+    """Outputs ``maximum(tensors, other(tensors))``"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_maximum_(update, other)
         return update
 
 class Minimum(BinaryOperationBase):
-    """Outputs
+    """Outputs ``minimum(tensors, other(tensors))``"""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_minimum_(update, other)
         return update
 
 
 class GramSchimdt(BinaryOperationBase):
-    """outputs tensors made orthogonal to
+    """outputs tensors made orthogonal to ``other(tensors)`` via Gram-Schmidt."""
     def __init__(self, other: Chainable):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
         min = torch.finfo(update[0].dtype).tiny * 2
         return update - (other*update) / (other*other).clip(min=min)
 
 
 class Threshold(BinaryOperationBase):
-    """Outputs tensors thresholded such that values above
+    """Outputs tensors thresholded such that values above ``threshold`` are set to ``value``."""
     def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
         defaults = dict(update_above=update_above)
         super().__init__(defaults, threshold=threshold, value=value)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
         update_above = self.defaults['update_above']
         update = TensorList(update)
         if update_above:
-            if isinstance(value, list): return update.
+            if isinstance(value, list): return update.where(update>threshold, value)
             return update.masked_fill_(update<=threshold, value)
 
-        if isinstance(value, list): return update.
+        if isinstance(value, list): return update.where(update<threshold, value)
         return update.masked_fill_(update>=threshold, value)
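The diff above shows the 0.4.1 contract for binary operation modules: ``transform`` now receives an ``Objective`` plus the current update, and the base class ``step`` resolves module operands before calling it. The sketch below is a hypothetical subclass written against exactly those signatures; the class name ``WeightedAdd`` and its behaviour are illustrative and not part of the package.

import torch
from torchzero.core import Chainable
from torchzero.modules.ops.binary import BinaryOperationBase

class WeightedAdd(BinaryOperationBase):
    """Hypothetical example: adds ``weight * other(tensors)`` to the update."""
    def __init__(self, other: Chainable | float, weight: float = 0.5):
        defaults = dict(weight=weight)
        super().__init__(defaults, other=other)  # ``other`` may be a number or a module operand

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        # ``update`` is the incoming update; ``other`` is either the raw number or the
        # operand module's output, already computed by ``BinaryOperationBase.step``.
        w = self.defaults['weight']
        if isinstance(other, (int, float)):
            torch._foreach_add_(update, other * w)
        else:
            torch._foreach_add_(update, other, alpha=w)
        return update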
torchzero/modules/ops/higher_level.py
CHANGED
@@ -4,9 +4,9 @@ from typing import Literal
 
 import torch
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..
+from ..opt_utils import (
     centered_ema_sq_,
     debias,
     debias_second_momentum,
@@ -17,7 +17,7 @@ from ..functional import (
 )
 
 
-class EMASquared(
+class EMASquared(TensorTransform):
     """Maintains an exponential moving average of squared updates.
 
     Args:
@@ -29,10 +29,10 @@ class EMASquared(Transform):
 
     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -44,7 +44,7 @@ class EMASquared(Transform):
 
         return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
 
-class SqrtEMASquared(
+class SqrtEMASquared(TensorTransform):
     """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.
 
     Args:
@@ -56,11 +56,11 @@ class SqrtEMASquared(Transform):
     SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-        super().__init__(defaults
+        super().__init__(defaults)
 
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -83,7 +83,7 @@ class SqrtEMASquared(Transform):
         )
 
 
-class Debias(
+class Debias(TensorTransform):
     """Multiplies the update by an Adam debiasing term based first and/or second momentum.
 
     Args:
@@ -95,12 +95,12 @@ class Debias(Transform):
         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2
+    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2):
         defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -108,7 +108,7 @@ class Debias(Transform):
 
         return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
 
-class Debias2(
+class Debias2(TensorTransform):
     """Multiplies the update by an Adam debiasing term based on the second momentum.
 
     Args:
@@ -117,19 +117,19 @@ class Debias2(Transform):
         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
         target (Target, optional): target. Defaults to 'update'.
     """
-    def __init__(self, beta: float = 0.999, pow: float = 2,
+    def __init__(self, beta: float = 0.999, pow: float = 2,):
         defaults = dict(beta=beta, pow=pow)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
         beta = NumberList(s['beta'] for s in settings)
         return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
 
-class CenteredEMASquared(
+class CenteredEMASquared(TensorTransform):
     """Maintains a centered exponential moving average of squared updates. This also maintains an additional
     exponential moving average of un-squared updates, square of which is subtracted from the EMA.
 
@@ -143,7 +143,7 @@ class CenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -162,7 +162,7 @@ class CenteredEMASquared(Transform):
             pow=pow,
         ).clone()
 
-class CenteredSqrtEMASquared(
+class CenteredSqrtEMASquared(TensorTransform):
     """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
     This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.
 
@@ -177,7 +177,7 @@ class CenteredSqrtEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])