torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/reduce.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any, cast
 
 import torch
 
-from ...core import Chainable, Module, Target,
+from ...core import Chainable, Module, Target, Var, maybe_chain
 
 
 class ReduceOperation(Module, ABC):
@@ -26,25 +26,25 @@ class ReduceOperation(Module, ABC):
            raise ValueError('At least one operand must be a module')
 
    @abstractmethod
-    def transform(self,
+    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError
 
    @torch.no_grad
-    def step(self,
+    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
        for i, v in enumerate(self.operands):
            if f'operand_{i}' in self.children:
                v: Module
-
-                processed_operands[i] =
-
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[i] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(
-
-        return
+        transformed = self.transform(var, *processed_operands)
+        var.update = transformed
+        return var
 
 class Sum(ReduceOperation):
    USE_MEAN = False
@@ -52,7 +52,7 @@ class Sum(ReduceOperation):
        super().__init__({}, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        sum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
@@ -76,9 +76,9 @@ class WeightedSum(ReduceOperation):
        super().__init__(defaults=defaults, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-        weights = self.settings[
+        weights = self.settings[var.params[0]]['weights']
        sum = cast(list, sorted_inputs[0])
        torch._foreach_mul_(sum, weights[0])
        if len(sorted_inputs) > 1:
@@ -98,7 +98,7 @@ class Median(ReduceOperation):
        super().__init__({}, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        res = []
        lists = [i for i in inputs if isinstance(i, list)]
        floats = [i for i in inputs if isinstance(i, (int,float))]
@@ -111,7 +111,7 @@ class Prod(ReduceOperation):
        super().__init__({}, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        prod = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
@@ -125,7 +125,7 @@ class MaximumModules(ReduceOperation):
        super().__init__({}, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        maximum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
@@ -139,7 +139,7 @@ class MinimumModules(ReduceOperation):
        super().__init__({}, *inputs)
 
    @torch.no_grad
-    def transform(self,
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        minimum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
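The reduce operations now receive the new `Var` state object in `transform` and `step`. As a rough illustration only (not code from the package), a custom reduction written against the 0.3.10 signatures might look like the sketch below; the `Mean` class and the exact import paths are assumptions inferred from this diff.

```python
# Hypothetical sketch: a custom reduction against the 0.3.10 signatures shown above.
# Import paths are assumed from the package layout in this diff.
import torch
from torchzero.core import Var
from torchzero.modules.ops.reduce import ReduceOperation

class Mean(ReduceOperation):
    """Averages the tensor-list updates produced by the module operands."""
    def __init__(self, *inputs):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        lists = [i for i in inputs if isinstance(i, list)]  # module operands arrive as tensor lists
        # average elementwise across operands, one output tensor per parameter
        return [torch.mean(torch.stack(ts), dim=0) for ts in zip(*lists)]
```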
torchzero/modules/ops/split.py
CHANGED
@@ -3,46 +3,46 @@ from typing import cast
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Var
 
 
 def _split(
    module: Module,
    idxs,
    params,
-
+    var: Var,
 ):
    split_params = [p for i,p in enumerate(params) if i in idxs]
 
    split_grad = None
-    if
-        split_grad = [g for i,g in enumerate(
+    if var.grad is not None:
+        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
 
    split_update = None
-    if
-        split_update = [u for i,u in enumerate(
+    if var.update is not None:
+        split_update = [u for i,u in enumerate(var.update) if i in idxs]
 
-
-
-
-
+    split_var = var.clone(clone_update=False)
+    split_var.params = split_params
+    split_var.grad = split_grad
+    split_var.update = split_update
 
-
+    split_var = module.step(split_var)
 
-    if (
-
+    if (var.grad is None) and (split_var.grad is not None):
+        var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
-    if
+    if split_var.update is not None:
 
-        if
-            if
-            else:
+        if var.update is None:
+            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+            else: var.update = [g.clone() for g in var.grad]
 
-        for idx, u in zip(idxs,
-
+        for idx, u in zip(idxs, split_var.update):
+            var.update[idx] = u
 
-
-    return
+    var.update_attrs_from_clone_(split_var)
+    return var
 
 class Split(Module):
    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
@@ -53,9 +53,9 @@ class Split(Module):
        if true is not None: self.set_child('true', true)
        if false is not None: self.set_child('false', false)
 
-    def step(self,
+    def step(self, var):
 
-        params =
+        params = var.params
        filter = self.settings[params[0]]['filter']
 
        true_idxs = []
@@ -66,10 +66,10 @@ class Split(Module):
 
        if 'true' in self.children:
            true = self.children['true']
-
+            var = _split(true, idxs=true_idxs, params=params, var=var)
 
        if 'false' in self.children:
            false = self.children['false']
-
+            var = _split(false, idxs=false_idxs, params=params, var=var)
 
-        return
+        return var
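For orientation, `Split.step` partitions the parameter indices with the `filter` setting, and `_split` then builds a sub-`Var` restricted to those indices, steps the child on it, and writes the result back into `var.update`. A standalone sketch of the index partition it performs:

```python
# Standalone illustration of the index partition done in Split.step above.
import torch

params = [torch.zeros(10, 10), torch.zeros(10), torch.zeros(3, 4)]
filter_fn = lambda p: p.ndim >= 2          # corresponds to the 'filter' setting

true_idxs  = [i for i, p in enumerate(params) if filter_fn(p)]       # -> [0, 2]
false_idxs = [i for i, p in enumerate(params) if not filter_fn(p)]   # -> [1]
# _split(...) builds a sub-Var holding only the params/grad/update at those indices,
# runs the child module on it, then copies the child's updates back into var.update.
```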
torchzero/modules/ops/switch.py
CHANGED
@@ -23,16 +23,16 @@ class Alternate(Module):
            self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
    @torch.no_grad
-    def step(self,
+    def step(self, var):
        # get current module
        current_module_idx = self.global_state.setdefault('current_module_idx', 0)
        module = self.children[f'module_{current_module_idx}']
 
        # step
-
+        var = module.step(var.clone(clone_update=False))
 
        # number of steps until next module
-        steps = self.settings[
+        steps = self.settings[var.params[0]]['steps']
        if isinstance(steps, int): steps = [steps]*len(self.children)
 
        if 'steps_to_next' not in self.global_state:
@@ -51,7 +51,7 @@ class Alternate(Module):
 
            self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
-        return
+        return var
 
 class Switch(Alternate):
    """switch to next module after some steps"""
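`Alternate` keeps a `current_module_idx` and a `steps_to_next` counter in `global_state` and hands each step to one child at a time. A rough standalone illustration of the kind of round-robin schedule this produces (the module's exact bookkeeping may differ):

```python
# Standalone sketch of a per-module step schedule like the one Alternate tracks above.
steps = [2, 3]                  # per-module step budgets (the 'steps' setting)
schedule = []
idx, to_next = 0, steps[0]
for _ in range(10):
    schedule.append(idx)        # child module used on this optimizer step
    to_next -= 1
    if to_next == 0:            # budget exhausted: move to the next child
        idx = (idx + 1) % len(steps)
        to_next = steps[idx]
print(schedule)                 # [0, 0, 1, 1, 1, 0, 0, 1, 1, 1]
```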
torchzero/modules/ops/unary.py
CHANGED
@@ -3,7 +3,7 @@ from collections import deque
 import torch
 
 from ...core import TensorwiseTransform, Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_dicts,unpack_states
 
 class UnaryLambda(Transform):
    def __init__(self, fn, target: "Target" = 'update'):
@@ -11,8 +11,8 @@ class UnaryLambda(Transform):
        super().__init__(defaults=defaults, uses_grad=False, target=target)
 
    @torch.no_grad
-    def
-        return
+    def apply(self, tensors, params, grads, loss, states, settings):
+        return settings[0]['fn'](tensors)
 
 class UnaryParameterwiseLambda(TensorwiseTransform):
    def __init__(self, fn, target: "Target" = 'update'):
@@ -20,8 +20,8 @@ class UnaryParameterwiseLambda(TensorwiseTransform):
        super().__init__(uses_grad=False, defaults=defaults, target=target)
 
    @torch.no_grad
-    def
-        return
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        return settings['fn'](tensor)
 
 class CustomUnaryOperation(Transform):
    def __init__(self, name: str, target: "Target" = 'update'):
@@ -29,35 +29,35 @@ class CustomUnaryOperation(Transform):
        super().__init__(defaults=defaults, uses_grad=False, target=target)
 
    @torch.no_grad
-    def
-        return getattr(tensors,
+    def apply(self, tensors, params, grads, loss, states, settings):
+        return getattr(tensors, settings[0]['name'])()
 
 
 class Abs(Transform):
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_abs_(tensors)
        return tensors
 
 class Sign(Transform):
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sign_(tensors)
        return tensors
 
 class Exp(Transform):
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_exp_(tensors)
        return tensors
 
 class Sqrt(Transform):
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sqrt_(tensors)
        return tensors
 
@@ -66,8 +66,8 @@ class Reciprocal(Transform):
        defaults = dict(eps = eps)
        super().__init__(defaults, uses_grad=False, target=target)
    @torch.no_grad
-    def
-        eps =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        eps = [s['eps'] for s in settings]
        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
        torch._foreach_reciprocal_(tensors)
        return tensors
@@ -75,7 +75,7 @@ class Reciprocal(Transform):
 class Negate(Transform):
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors
 
@@ -97,8 +97,8 @@ class NanToNum(Transform):
        super().__init__(defaults, uses_grad=False, target=target)
 
    @torch.no_grad
-    def
-        nan, posinf, neginf =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
 class Rescale(Transform):
@@ -108,8 +108,8 @@ class Rescale(Transform):
        super().__init__(defaults, uses_grad=False, target=target)
 
    @torch.no_grad
-    def
-        min,max =
-        tensorwise =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        min, max = unpack_dicts(settings, 'min','max')
+        tensorwise = settings[0]['tensorwise']
        dim = None if tensorwise else 'global'
-        return TensorList(tensors).rescale(min=min, max=max, eps=
+        return TensorList(tensors).rescale(min=min, max=max, eps=settings[0]['eps'], dim=dim)
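All of the elementwise transforms above now implement `apply(self, tensors, params, grads, loss, states, settings)`, where `settings` is a per-tensor list of settings dicts and `states` a matching list of state dicts. A hedged sketch of a custom transform written against that signature (the `Scale` class is an assumption, not part of the package):

```python
# Hypothetical sketch of a custom Transform against the apply() signature shown above.
# Assumes torchzero.core.Transform is importable as in this diff.
import torch
from torchzero.core import Transform

class Scale(Transform):
    """Multiplies each update tensor by a per-parameter 'factor' setting."""
    def __init__(self, factor: float = 0.5, target="update"):
        super().__init__(dict(factor=factor), uses_grad=False, target=target)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        factors = [s['factor'] for s in settings]   # one settings dict per tensor
        torch._foreach_mul_(tensors, factors)
        return tensors
```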
torchzero/modules/ops/utility.py
CHANGED
@@ -9,47 +9,47 @@ from ...utils.tensorlist import Distributions, TensorList
 class Clone(Transform):
    def __init__(self): super().__init__({}, uses_grad=False)
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
 
 class Grad(Module):
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [g.clone() for g in var.get_grad()]
+        return var
 
 class Params(Module):
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [p.clone() for p in var.params]
+        return var
 
 class Update(Module):
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [u.clone() for u in var.get_update()]
+        return var
 
 class Zeros(Module):
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [torch.zeros_like(p) for p in var.params]
+        return var
 
 class Ones(Module):
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [torch.ones_like(p) for p in var.params]
+        return var
 
 class Fill(Module):
    def __init__(self, value: float):
@@ -57,9 +57,9 @@ class Fill(Module):
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
+        return var
 
 class RandomSample(Module):
    def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
@@ -67,20 +67,20 @@ class RandomSample(Module):
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self,
-
-        eps=self.
+    def step(self, var):
+        var.update = TensorList(var.params).sample_like(
+            eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
        )
-        return
+        return var
 
 class Randn(Module):
    def __init__(self):
        super().__init__({})
 
    @torch.no_grad
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = [torch.randn_like(p) for p in var.params]
+        return var
 
 class Uniform(Module):
    def __init__(self, low: float, high: float):
@@ -88,25 +88,25 @@ class Uniform(Module):
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self,
-        low,high = self.get_settings('low','high'
-
-        return
+    def step(self, var):
+        low,high = self.get_settings(var.params, 'low','high')
+        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
+        return var
 
 class GradToNone(Module):
    def __init__(self): super().__init__()
-    def step(self,
-
-        return
+    def step(self, var):
+        var.grad = None
+        return var
 
 class UpdateToNone(Module):
    def __init__(self): super().__init__()
-    def step(self,
-
-        return
+    def step(self, var):
+        var.update = None
+        return var
 
 class Identity(Module):
    def __init__(self, *args, **kwargs): super().__init__()
-    def step(self,
+    def step(self, var): return var
 
 NoOp = Identity
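The utility modules above all follow the same `step(self, var)` pattern: read or build a tensor list, assign it to `var.update`, and return `var`. A hedged sketch of a minimal custom module in that style (the `ClampUpdate` class is an assumption for illustration):

```python
# Hypothetical sketch of a minimal Module written against the step(var) pattern
# used by Grad/Params/Zeros above; import path assumed from this diff.
import torch
from torchzero.core import Module

class ClampUpdate(Module):
    """Clamps every update tensor to [-1, 1] before it is applied."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, var):
        var.update = [u.clamp_(-1, 1) for u in var.get_update()]
        return var
```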
torchzero/modules/optimizers/adagrad.py
CHANGED
@@ -1,18 +1,17 @@
 from operator import itemgetter
+from typing import Literal
 
 import torch
-
 from ...core import (
    Chainable,
    Module,
-    Preconditioner,
    Target,
-
+    TensorwiseTransform,
    Transform,
-
-
+    Var,
+    apply_transform,
 )
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ...utils.linalg import matrix_power_eigh
 from ..functional import add_power_, lerp_power_, root
 
@@ -31,7 +30,6 @@ def adagrad_(
    inner: Module | None = None,
    params: list[torch.Tensor] | None = None,
    grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
    """returns `tensors_`"""
    clr = alpha / (1 + step * lr_decay)
@@ -40,7 +38,7 @@ def adagrad_(
 
    if inner is not None:
        assert params is not None
-        tensors_ = TensorList(
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
@@ -79,19 +77,19 @@ class Adagrad(Transform):
            self.set_child('inner', inner)
 
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        lr_decay,alpha,eps =
+        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
 
-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(
+        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
 
-        sq_sum =
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
 
        # initialize accumulator on 1st step
        if step == 1:
-            sq_sum.set_(tensors.full_like(
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
 
        return adagrad_(
            tensors,
@@ -107,40 +105,51 @@ class Adagrad(Transform):
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,
-            vars=vars,
        )
 
 
 
-class FullMatrixAdagrad(
-    def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=False, update_freq=1, inner: Chainable | None = None):
-        defaults = dict(beta=beta, decay=decay)
+class FullMatrixAdagrad(TensorwiseTransform):
+    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=False, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
 
    @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
        G = tensor.ravel()
        GG = torch.outer(G, G)
        decay = settings['decay']
        beta = settings['beta']
-
-
+        init = settings['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
        if decay is not None: state['GG'].mul_(decay)
 
        if beta is not None: state['GG'].lerp_(GG, 1-beta)
        else: state['GG'].add_(GG)
 
    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
        GG = state['GG']
+        sqrt = settings['sqrt']
 
        if tensor.numel() == 1:
-
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
 
        try:
-            B = matrix_power_eigh(GG, -1/2)
+            if sqrt: B = matrix_power_eigh(GG, -1/2)
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
        except torch.linalg.LinAlgError:
-
+            scale = 1 / tensor.abs().max()
+            return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
 
        return (B @ tensor.ravel()).view_as(tensor)
 
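For intuition, the reworked `FullMatrixAdagrad` accumulates the outer product of the flattened gradient into `state['GG']` and, when `sqrt` is enabled, preconditions the current update with the inverse matrix square root of that accumulator (via `matrix_power_eigh`); with `sqrt=False` it solves `GG x = g` instead. A standalone numerical sketch of the square-root path (not package code):

```python
# Standalone illustration of the full-matrix Adagrad preconditioning used in
# apply_tensor above: accumulate G G^T over steps, then apply (sum G G^T)^(-1/2) g.
import torch

torch.manual_seed(0)
d = 4
GG = torch.eye(d)                    # the 'identity' init shown in update_tensor
for _ in range(10):                  # accumulate outer products of past gradients
    g = torch.randn(d)
    GG += torch.outer(g, g)

g = torch.randn(d)                   # current gradient
L, Q = torch.linalg.eigh(GG)         # same idea as matrix_power_eigh(GG, -1/2)
B = Q @ torch.diag(L.rsqrt()) @ Q.T
update = B @ g                       # whitened / preconditioned step
print(update)
```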