torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/reduce.py
CHANGED

@@ -5,10 +5,10 @@ from typing import Any, cast
 
 import torch
 
-from ...core import Chainable, Module, Target, …
+from ...core import Chainable, Module, Target, Var, maybe_chain
 
 
-class ReduceOperation(Module, ABC):
+class ReduceOperationBase(Module, ABC):
     """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
     def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
         super().__init__(defaults=defaults)

@@ -26,33 +26,34 @@ class ReduceOperation(Module, ABC):
             raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self, …
+    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
     @torch.no_grad
-    def step(self, …
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
         for i, v in enumerate(self.operands):
             if f'operand_{i}' in self.children:
                 v: Module
-…
-                processed_operands[i] = …
-…
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[i] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(…
-…
-        return …
+        transformed = self.transform(var, *processed_operands)
+        var.update = transformed
+        return var
 
-class Sum(ReduceOperation):
+class Sum(ReduceOperationBase):
+    """Outputs sum of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         sum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:

@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
         return sum
 
 class Mean(Sum):
+    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True
 
 
-class WeightedSum(ReduceOperation):
+class WeightedSum(ReduceOperationBase):
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
         weights = list(weights)
         if len(inputs) != len(weights):
             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')

@@ -76,9 +79,9 @@ class WeightedSum(ReduceOperation):
         super().__init__(defaults=defaults, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-        weights = self.settings[…
+        weights = self.settings[var.params[0]]['weights']
         sum = cast(list, sorted_inputs[0])
         torch._foreach_mul_(sum, weights[0])
         if len(sorted_inputs) > 1:

@@ -91,14 +94,16 @@ class WeightedSum(ReduceOperation):
 
 
 class WeightedMean(WeightedSum):
+    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
     USE_MEAN = True
 
-class Median(ReduceOperation):
+class Median(ReduceOperationBase):
+    """Outputs median of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         res = []
         lists = [i for i in inputs if isinstance(i, list)]
         floats = [i for i in inputs if isinstance(i, (int,float))]

@@ -106,12 +111,13 @@ class Median(ReduceOperation):
             res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
         return res
 
-class Prod(ReduceOperation):
+class Prod(ReduceOperationBase):
+    """Outputs product of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         prod = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:

@@ -120,12 +126,13 @@ class Prod(ReduceOperation):
 
         return prod
 
-class MaximumModules(ReduceOperation):
+class MaximumModules(ReduceOperationBase):
+    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         maximum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:

@@ -134,12 +141,13 @@ class MaximumModules(ReduceOperation):
 
         return maximum
 
-class MinimumModules(…
+class MinimumModules(ReduceOperationBase):
+    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, …
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         minimum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
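The hunks above rename the reduction base class to ReduceOperationBase and route everything through a Var object: step clones the incoming var for each module operand, collects their updates, and hands them to transform(var, *operands). Below is a minimal sketch of a custom reduction written against that signature; the GeometricMean class and the exact import paths are illustrative assumptions, not part of the released package.

import torch

from torchzero.core import Chainable                           # assumed public path, from `from ...core import` above
from torchzero.modules.ops.reduce import ReduceOperationBase   # module path taken from this diff

class GeometricMean(ReduceOperationBase):
    """Hypothetical reduction: elementwise geometric mean of module operands."""
    def __init__(self, *inputs: Chainable):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var, *inputs):
        # module operands arrive as lists of tensors (their computed updates);
        # numeric operands are ignored in this sketch for brevity
        lists = [i for i in inputs if isinstance(i, list)]
        out = [torch.ones_like(t) for t in lists[0]]
        for operand in lists:
            torch._foreach_mul_(out, operand)                # accumulate the elementwise product
        return [t.abs().pow_(1 / len(lists)) for t in out]   # n-th root of |product|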
torchzero/modules/ops/unary.py
CHANGED

@@ -3,79 +3,95 @@ from collections import deque
 import torch
 
 from ...core import TensorwiseTransform, Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_dicts,unpack_states
 
 class UnaryLambda(Transform):
+    """Applies :code:`fn` to input tensors.
+
+    :code:`fn` must accept and return a list of tensors.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        return …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return settings[0]['fn'](tensors)
 
 class UnaryParameterwiseLambda(TensorwiseTransform):
+    """Applies :code:`fn` to each input tensor.
+
+    :code:`fn` must accept and return a tensor.
+    """
     def __init__(self, fn, target: "Target" = 'update'):
         defaults = dict(fn=fn)
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def …
-        return …
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        return setting['fn'](tensor)
 
 class CustomUnaryOperation(Transform):
+    """Applies :code:`getattr(tensor, name)` to each tensor
+    """
     def __init__(self, name: str, target: "Target" = 'update'):
         defaults = dict(name=name)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        return getattr(tensors, …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return getattr(tensors, settings[0]['name'])()
 
 
 class Abs(Transform):
+    """Returns :code:`abs(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors
 
 class Sign(Transform):
+    """Returns :code:`sign(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors
 
 class Exp(Transform):
+    """Returns :code:`exp(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors
 
 class Sqrt(Transform):
+    """Returns :code:`sqrt(input)`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors
 
 class Reciprocal(Transform):
+    """Returns :code:`1 / input`"""
     def __init__(self, eps = 0, target: "Target" = 'update'):
         defaults = dict(eps = eps)
         super().__init__(defaults, uses_grad=False, target=target)
     @torch.no_grad
-    def …
-        eps = …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors
 
 class Negate(Transform):
+    """Returns :code:`- input`"""
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors
 

@@ -97,19 +113,19 @@ class NanToNum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        nan, posinf, neginf = …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
 class Rescale(Transform):
-    """…
+    """Rescales input to :code`(min, max)` range"""
     def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def …
-        min,max = …
-        tensorwise = …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        min, max = unpack_dicts(settings, 'min','max')
+        tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'
-        return TensorList(tensors).rescale(min=min, max=max, eps=…
+        return TensorList(tensors).rescale(min=min, max=max, eps=settings[0]['eps'], dim=dim)
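Every transform in this file now implements the apply_tensors(self, tensors, params, grads, loss, states, settings) hook (or apply_tensor for the tensorwise variant), with per-parameter options read from settings. A sketch of a new elementwise transform against that hook is below; the Tanh class itself is illustrative and the import path is inferred from the file's own imports.

import torch

from torchzero.core import Transform  # assumed path, from `from ...core import ... Transform` above

class Tanh(Transform):
    """Hypothetical transform that returns tanh(input)."""
    def __init__(self, target: str = 'update'):
        super().__init__({}, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_tanh_(tensors)  # in-place, mirroring Abs/Sign/Exp/Sqrt above
        return tensors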
torchzero/modules/ops/utility.py
CHANGED

@@ -6,107 +6,115 @@ from ...core import Module, Target, Transform
 from ...utils.tensorlist import Distributions, TensorList
 
 
-class Clone(…
-…
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars): return [t.clone() for t in tensors]
-
-class Grad(Module):
+class Clone(Module):
+    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [u.clone() for u in var.get_update()]
+        return var
 
-class …
+class Grad(Module):
+    """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [g.clone() for g in var.get_grad()]
+        return var
 
-class …
+class Params(Module):
+    """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [p.clone() for p in var.params]
+        return var
 
 class Zeros(Module):
+    """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [torch.zeros_like(p) for p in var.params]
+        return var
 
 class Ones(Module):
+    """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [torch.ones_like(p) for p in var.params]
+        return var
 
 class Fill(Module):
+    """Outputs tensors filled with :code:`value`"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
+        return var
 
 class RandomSample(Module):
+    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
         defaults = dict(eps=eps, distribution=distribution)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, …
-…
-            eps=self.…
+    def step(self, var):
+        var.update = TensorList(var.params).sample_like(
+            eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
         )
-        return …
+        return var
 
 class Randn(Module):
+    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
     def __init__(self):
         super().__init__({})
 
     @torch.no_grad
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = [torch.randn_like(p) for p in var.params]
+        return var
 
 class Uniform(Module):
+    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, …
-        low,high = self.get_settings('low','high'…
-…
-        return …
+    def step(self, var):
+        low,high = self.get_settings(var.params, 'low','high')
+        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
+        return var
 
 class GradToNone(Module):
+    """Sets :code:`grad` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.grad = None
+        return var
 
 class UpdateToNone(Module):
+    """Sets :code:`update` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
-    def step(self, …
-…
-        return …
+    def step(self, var):
+        var.update = None
+        return var
 
 class Identity(Module):
+    """A placeholder identity operator that is argument-insensitive."""
     def __init__(self, *args, **kwargs): super().__init__()
-    def step(self, …
+    def step(self, var): return var
 
-NoOp = Identity
+NoOp = Identity
+"""A placeholder identity operator that is argument-insensitive."""
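The utility modules above all follow the same pattern after this change: step(self, var) reads var.params or var.get_update(), writes var.update, and returns var, with per-parameter options looked up via self.settings[p]. Below is a small sketch of a custom module built on that pattern; the Noise class and its std option are illustrative assumptions, and the import path is inferred from the imports shown above.

import torch

from torchzero.core import Module  # assumed path

class Noise(Module):
    """Hypothetical module that perturbs the current update with Gaussian noise."""
    def __init__(self, std: float = 1e-3):
        super().__init__(dict(std=std))

    @torch.no_grad
    def step(self, var):
        stds = [self.settings[p]['std'] for p in var.params]   # per-parameter setting lookup
        var.update = [u + torch.randn_like(u) * s for u, s in zip(var.get_update(), stds)]
        return var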
torchzero/modules/optimizers/__init__.py
CHANGED

@@ -1,7 +1,18 @@
 from .adagrad import Adagrad, FullMatrixAdagrad
+
+# from .curveball import CurveBall
+# from .spectral import SpectralPreconditioner
+from .adahessian import AdaHessian
 from .adam import Adam
+from .adan import Adan
+from .adaptive_heavyball import AdaptiveHeavyBall
+from .esgd import ESGD
+from .ladagrad import LMAdagrad
 from .lion import Lion
+from .mars import MARSCorrection
+from .msam import MSAM, MSAMObjective
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .orthograd import OrthoGrad, orthograd_
 from .rmsprop import RMSprop
 from .rprop import (
     BacktrackOnSignChange,

@@ -10,9 +21,7 @@ from .rprop import (
     SignConsistencyLRs,
     SignConsistencyMask,
 )
+from .sam import ASAM, SAM
 from .shampoo import Shampoo
 from .soap import SOAP
-from .orthograd import OrthoGrad, orthograd_
 from .sophia_h import SophiaH
-# from .curveball import CurveBall
-# from .spectral import SpectralPreconditioner
torchzero/modules/optimizers/adagrad.py
CHANGED

@@ -1,18 +1,17 @@
 from operator import itemgetter
+from typing import Literal
 
 import torch
-
 from ...core import (
     Chainable,
     Module,
-    Preconditioner,
     Target,
-    …
+    TensorwiseTransform,
     Transform,
-    …
-    …
+    Var,
+    apply_transform,
 )
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ...utils.linalg import matrix_power_eigh
 from ..functional import add_power_, lerp_power_, root
 

@@ -26,12 +25,12 @@ def adagrad_(
     step: int,
     pow: float = 2,
     use_sqrt: bool = True,
+    divide: bool = False,
 
     # inner args
     inner: Module | None = None,
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
     """returns `tensors_`"""
     clr = alpha / (1 + step * lr_decay)

@@ -40,7 +39,9 @@ def adagrad_(
 
     if inner is not None:
         assert params is not None
-        tensors_ = TensorList(…
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
 
     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)

@@ -50,7 +51,9 @@ def adagrad_(
 
 
 class Adagrad(Transform):
-    """Adagrad, divides by sum of past squares of gradients
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to :code:`torch.optim.Adagrad`.
 
     Args:
         lr_decay (float, optional): learning rate decay. Defaults to 0.

@@ -69,29 +72,30 @@ class Adagrad(Transform):
         alpha: float = 1,
         pow: float = 2,
         use_sqrt: bool = True,
+        divide: bool=False,
         inner: Chainable | None = None,
     ):
         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt)
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
         super().__init__(defaults=defaults, uses_grad=False)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def …
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        lr_decay,alpha,eps = …
+        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
 
-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(…
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
 
-        sq_sum = …
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
 
         # initialize accumulator on 1st step
         if step == 1:
-            sq_sum.set_(tensors.full_like(…
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
 
         return adagrad_(
             tensors,

@@ -102,45 +106,60 @@ class Adagrad(Transform):
             step=self.global_state["step"],
             pow=pow,
             use_sqrt=use_sqrt,
+            divide=divide,
 
             # inner args
             inner=self.children.get("inner", None),
             params=params,
             grads=grads,
-            vars=vars,
         )
 
 
 
-class FullMatrixAdagrad(…
-    def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=…
-        defaults = dict(beta=beta, decay=decay)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+class FullMatrixAdagrad(TensorwiseTransform):
+    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, …
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         G = tensor.ravel()
         GG = torch.outer(G, G)
-        decay = …
-        beta = …
-…
-…
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
         if decay is not None: state['GG'].mul_(decay)
 
         if beta is not None: state['GG'].lerp_(GG, 1-beta)
         else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, …
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         GG = state['GG']
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        if divide: GG = GG/state.get('i', 1)
 
         if tensor.numel() == 1:
-…
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
 
         try:
-            B = matrix_power_eigh(GG, -1/2)
+            if sqrt: B = matrix_power_eigh(GG, -1/2)
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
         except torch.linalg.LinAlgError:
-…
+            scale = 1 / tensor.abs().max()
+            return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
 
         return (B @ tensor.ravel()).view_as(tensor)
 
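For reference, the preconditioning that FullMatrixAdagrad.update_tensor and apply_tensor perform above is: accumulate G as the sum of gradient outer products g g^T (optionally decayed, blended via beta, or averaged via divide), then apply G^{-1/2} to the gradient when sqrt=True, or solve G u = g when sqrt=False. Below is a standalone sketch of one such step written against plain PyTorch for illustration only; it is not the library's code path.

import torch

def full_matrix_adagrad_step(g: torch.Tensor, GG: torch.Tensor, sqrt: bool = True):
    """g: flattened gradient; GG: running sum of outer products (e.g. identity-initialized)."""
    GG = GG + torch.outer(g, g)                      # accumulate g g^T
    if sqrt:
        # G^{-1/2} g via an eigendecomposition, analogous to matrix_power_eigh(GG, -1/2)
        eigvals, eigvecs = torch.linalg.eigh(GG)
        inv_root = eigvecs @ torch.diag(eigvals.clamp(min=1e-12).rsqrt()) @ eigvecs.T
        update = inv_root @ g
    else:
        # precondition with G^{-1} instead: solve G u = g
        update = torch.linalg.solve(GG, g)
    return update, GG

# toy usage: 'identity' init as in the new `init` option, then one preconditioned step
GG = torch.eye(3)
g = torch.tensor([0.5, -1.0, 2.0])
u, GG = full_matrix_adagrad_step(g, GG)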