torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/{momentum/ema.py → ops/higher_level.py}
CHANGED

```diff
@@ -5,39 +5,16 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import
-from ..functional import
-
-
-
-
-
-
-
-
-        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
-        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
-        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
-        target (Target, optional): target to apply EMA to. Defaults to 'update'.
-    """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
-        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
-
-        exp_avg = unpack_states(states, tensors, 'exp_avg',
-            init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
-        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
-
-        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
-
-        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
-        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    centered_ema_sq_,
+    debias,
+    debias_second_momentum,
+    ema_,
+    ema_sq_,
+    sqrt_centered_ema_sq_,
+    sqrt_ema_sq_,
+)
 
 
 class EMASquared(Transform):
@@ -55,7 +32,7 @@ class EMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -83,7 +60,7 @@ class SqrtEMASquared(Transform):
 
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -123,7 +100,7 @@ class Debias(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -145,7 +122,7 @@ class Debias2(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -166,7 +143,7 @@ class CenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -200,7 +177,7 @@ class CenteredSqrtEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
```
torchzero/modules/ops/multi.py
CHANGED
```diff
@@ -3,15 +3,15 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal
 
 import torch
 
 from ...core import Chainable, Module, Target, Var, maybe_chain
-from ...utils import TensorList, tensorlist
+from ...utils import TensorList, tensorlist, Metrics
 
 
-class MultiOperation(Module, ABC):
+class MultiOperationBase(Module, ABC):
     """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -51,14 +51,15 @@ class MultiOperation(Module, ABC):
 
 
 
-class SubModules(MultiOperation):
+class SubModules(MultiOperationBase):
+    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
     def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
-        alpha = self.
+        alpha = self.defaults['alpha']
 
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
@@ -68,10 +69,12 @@ class SubModules(MultiOperation):
         else: torch._foreach_sub_(input, other, alpha=alpha)
         return input
 
-class DivModules(MultiOperation):
-
+class DivModules(MultiOperationBase):
+    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
         defaults = {}
-        super().__init__(defaults,
+        if other_first: super().__init__(defaults, other=other, input=input)
+        else: super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
     def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
@@ -82,7 +85,9 @@ class DivModules(MultiOperation):
         torch._foreach_div_(input, other)
         return input
 
-
+
+class PowModules(MultiOperationBase):
+    """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
     def __init__(self, input: Chainable | float, exponent: Chainable | float):
         defaults = {}
         super().__init__(defaults, input=input, exponent=exponent)
@@ -96,17 +101,22 @@ class PowModules(MultiOperation):
         torch._foreach_div_(input, exponent)
         return input
 
-class LerpModules(MultiOperation):
+class LerpModules(MultiOperationBase):
+    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+    """
     def __init__(self, input: Chainable, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, input=input, end=end)
 
     @torch.no_grad
     def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
-        torch._foreach_lerp_(input, end, weight=self.
+        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
         return input
 
-class ClipModules(MultiOperation):
+class ClipModules(MultiOperationBase):
+    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
         defaults = {}
         super().__init__(defaults, input=input, min=min, max=max)
@@ -116,22 +126,73 @@ class ClipModules(MultiOperation):
         return TensorList(input).clamp_(min=min, max=max)
 
 
-class GraftModules(MultiOperation):
-
+class GraftModules(MultiOperationBase):
+    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+    Args:
+        direction (Chainable): module to use the direction from
+        magnitude (Chainable): module to use the magnitude from
+        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+        ord (float, optional): norm order. Defaults to 2.
+        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+        strength (float, optional): strength of grafting. Defaults to 1.
+
+    Example:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+    Reference:
+        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+    """
+    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)
 
     @torch.no_grad
     def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
-        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
-
-
-    def __init__(self,
-
+class MultiplyByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
-    def transform(self, var,
-
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
+        if tensorwise:
+            n = TensorList(norm).metric(ord)
+        else:
+            n = TensorList(norm).global_metric(ord)
+
+        torch._foreach_mul_(input, n)
+        return input
+
+class DivideByModuleNorm(MultiOperationBase):
+    """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+    def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
+        defaults = dict(tensorwise=tensorwise, ord=ord)
+        super().__init__(defaults, input=input, norm=norm)
 
+    @torch.no_grad
+    def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+        tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
+        if tensorwise:
+            n = TensorList(norm).metric(ord)
+        else:
+            n = TensorList(norm).global_metric(ord)
+
+        torch._foreach_div_(input, n)
+        return input
```
torchzero/modules/ops/reduce.py
CHANGED
```diff
@@ -8,7 +8,7 @@ import torch
 from ...core import Chainable, Module, Target, Var, maybe_chain
 
 
-class ReduceOperation(Module, ABC):
+class ReduceOperationBase(Module, ABC):
     """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
         super().__init__(defaults=defaults)
@@ -46,7 +46,8 @@ class ReduceOperation(Module, ABC):
         var.update = transformed
         return var
 
-class Sum(ReduceOperation):
+class Sum(ReduceOperationBase):
+    """Outputs sum of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
         return sum
 
 class Mean(Sum):
+    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True
 
 
-class WeightedSum(ReduceOperation):
+class WeightedSum(ReduceOperationBase):
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
         weights = list(weights)
         if len(inputs) != len(weights):
             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -78,7 +81,7 @@ class WeightedSum(ReduceOperation):
     @torch.no_grad
     def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-        weights = self.
+        weights = self.defaults['weights']
         sum = cast(list, sorted_inputs[0])
         torch._foreach_mul_(sum, weights[0])
         if len(sorted_inputs) > 1:
@@ -91,9 +94,11 @@ class WeightedSum(ReduceOperation):
 
 
 class WeightedMean(WeightedSum):
+    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True
 
-class Median(ReduceOperation):
+class Median(ReduceOperationBase):
+    """Outputs median of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
@@ -106,7 +111,8 @@ class Median(ReduceOperation):
         res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
         return res
 
-class Prod(ReduceOperation):
+class Prod(ReduceOperationBase):
+    """Outputs product of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
@@ -120,7 +126,8 @@ class Prod(ReduceOperation):
 
         return prod
 
-class MaximumModules(ReduceOperation):
+class MaximumModules(ReduceOperationBase):
+    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
@@ -134,7 +141,8 @@ class MaximumModules(ReduceOperation):
 
         return maximum
 
-class MinimumModules(ReduceOperation):
+class MinimumModules(ReduceOperationBase):
+    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
```
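The reduction ops combine whole update rules elementwise rather than single tensors. A sketch of averaging two modules' updates, under the same `tz.m` re-export assumption as above:

```python
import torch.nn as nn
import torchzero as tz

model = nn.Linear(4, 2)

# Mean averages the per-parameter updates produced by Adam and Shampoo.
opt = tz.Modular(
    model.parameters(),
    tz.m.Mean(tz.m.Adam(), tz.m.Shampoo()),
    tz.m.LR(1e-2),
)
```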
torchzero/modules/ops/unary.py
CHANGED
```diff
@@ -6,76 +6,92 @@ from ...core import TensorwiseTransform, Target, Transform
 from ...utils import TensorList, unpack_dicts,unpack_states
 
 class UnaryLambda(Transform):
+    """Applies :code:`fn` to input tensors.
+
+    :code:`fn` must accept and return a list of tensors.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         return settings[0]['fn'](tensors)
 
 class UnaryParameterwiseLambda(TensorwiseTransform):
+    """Applies :code:`fn` to each input tensor.
+
+    :code:`fn` must accept and return a tensor.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        return
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        return setting['fn'](tensor)
 
 class CustomUnaryOperation(Transform):
+    """Applies :code:`getattr(tensor, name)` to each tensor
+    """
     def __init__(self, name: str, target: "Target" = 'update'):
         defaults = dict(name=name)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         return getattr(tensors, settings[0]['name'])()
 
 
 class Abs(Transform):
+    """Returns :code:`abs(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors
 
 class Sign(Transform):
+    """Returns :code:`sign(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors
 
 class Exp(Transform):
+    """Returns :code:`exp(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors
 
 class Sqrt(Transform):
+    """Returns :code:`sqrt(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors
 
 class Reciprocal(Transform):
+    """Returns :code:`1 / input`"""
     def __init__(self, eps = 0, target: "Target" = 'update'):
         defaults = dict(eps = eps)
         super().__init__(defaults, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors
 
 class Negate(Transform):
+    """Returns :code:`- input`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors
 
@@ -97,18 +113,18 @@ class NanToNum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
 class Rescale(Transform):
-    """
+    """Rescales input to :code:`(min, max)` range"""
     def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         min, max = unpack_dicts(settings, 'min','max')
         tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'
```
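The new docstrings make the lambda contract explicit: `UnaryLambda` receives the whole list of update tensors, while `UnaryParameterwiseLambda` gets one tensor at a time. A sketch of both, again assuming the `tz.m` re-export:

```python
import torchzero as tz

# list in, list out: fn sees every update tensor at once
clip_small = tz.m.UnaryLambda(lambda ts: [t.clamp_(-1, 1) for t in ts])

# tensor in, tensor out: fn is applied to each parameter's update separately
zero_mean = tz.m.UnaryParameterwiseLambda(lambda t: t - t.mean())
```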
torchzero/modules/ops/utility.py
CHANGED
```diff
@@ -4,38 +4,37 @@ import torch
 
 from ...core import Module, Target, Transform
 from ...utils.tensorlist import Distributions, TensorList
+from ...utils.linalg.linear_operator import ScaledIdentity
 
-
-
-    def __init__(self): super().__init__({}, uses_grad=False)
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
-
-class Grad(Module):
+class Clone(Module):
+    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [u.clone() for u in var.get_update()]
         return var
 
-class
+class Grad(Module):
+    """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [g.clone() for g in var.get_grad()]
         return var
 
-class
+class Params(Module):
+    """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [p.clone() for p in var.params]
         return var
 
 class Zeros(Module):
+    """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -44,6 +43,7 @@ class Zeros(Module):
         return var
 
 class Ones(Module):
+    """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -52,6 +52,7 @@ class Ones(Module):
         return var
 
 class Fill(Module):
+    """Outputs tensors filled with :code:`value`"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
@@ -62,18 +63,20 @@ class Fill(Module):
         return var
 
 class RandomSample(Module):
-
-
+    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
+    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
+        defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults)
 
     @torch.no_grad
     def step(self, var):
-
-
-        )
+        distribution = self.defaults['distribution']
+        variance = self.get_settings(var.params, 'variance')
+        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
         return var
 
 class Randn(Module):
+    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
     def __init__(self):
         super().__init__({})
 
@@ -83,6 +86,7 @@ class Randn(Module):
         return var
 
 class Uniform(Module):
+    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
@@ -94,19 +98,27 @@ class Uniform(Module):
         return var
 
 class GradToNone(Module):
+    """Sets :code:`grad` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.grad = None
         return var
 
 class UpdateToNone(Module):
+    """Sets :code:`update` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.update = None
         return var
 
 class Identity(Module):
+    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
     def __init__(self, *args, **kwargs): super().__init__()
     def step(self, var): return var
+    def get_H(self, var):
+        n = sum(p.numel() for p in var.params)
+        p = var.params[0]
+        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)
 
-
+Noop = Identity
+"""A placeholder identity operator that is argument-insensitive."""
```
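With `get_H`, `Identity` can now stand in for an identity Hessian in the new trust-region modules, and `Noop` survives only as an alias. A small sketch of the pass-through behavior (assuming `tz.m` re-exports both names):

```python
import torch.nn as nn
import torchzero as tz

model = nn.Linear(4, 2)

# Identity passes the update through untouched, so this behaves like
# plain SGD with lr=1e-2.
opt = tz.Modular(model.parameters(), tz.m.Identity(), tz.m.LR(1e-2))
assert tz.m.Noop is tz.m.Identity  # Noop = Identity per the hunk above
```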
torchzero/modules/projections/__init__.py
CHANGED

```diff
@@ -1,5 +1,3 @@
-from .projection import
-from .
-from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
-
+from .projection import ProjectionBase, VectorProjection, ScalarProjection
+from .cast import To, ViewAsReal
 # from .galore import GaLore
```