torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two published versions of the package as they appear in the public registry. It is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/ops/reduce.py
@@ -0,0 +1,149 @@
+"""Reduction operations (Sum, Mean, Prod, Maximum and others)."""
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from typing import Any, cast
+
+import torch
+
+from ...core import Chainable, Module, Vars
+
+
+class ReduceOperation(Module, ABC):
+    """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class; subclass it and override the `transform` method to use it."""
+    def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
+        super().__init__(defaults=defaults)
+
+        self.operands = []
+        for i, v in enumerate(operands):
+            if isinstance(v, (Module, Sequence)):
+                self.set_child(f'operand_{i}', v)
+                self.operands.append(self.children[f'operand_{i}'])
+            else:
+                self.operands.append(v)
+
+        if not self.children:
+            raise ValueError('At least one operand must be a module')
+
+    @abstractmethod
+    def transform(self, vars: Vars, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+        """Applies the operation to the operands."""
+        raise NotImplementedError
+
+    @torch.no_grad
+    def step(self, vars: Vars) -> Vars:
+        # pass a cloned update to all module operands
+        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
+
+        for i, v in enumerate(self.operands):
+            if f'operand_{i}' in self.children:
+                v: Module
+                updated_vars = v.step(vars.clone(clone_update=True))
+                processed_operands[i] = updated_vars.get_update()
+                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc. if this module calculated them
+
+        transformed = self.transform(vars, *processed_operands)
+        vars.update = transformed
+        return vars
+
+
+class Sum(ReduceOperation):
+    USE_MEAN = False
+    def __init__(self, *inputs: Chainable | float):
+        super().__init__({}, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        # sort floats after tensor lists so the accumulator is always a tensor list
+        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+        sum = cast(list, sorted_inputs[0])
+        for v in sorted_inputs[1:]:
+            torch._foreach_add_(sum, v)
+
+        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
+        return sum
+
+
+class Mean(Sum):
+    USE_MEAN = True
+
+
+class WeightedSum(ReduceOperation):
+    USE_MEAN = False
+    def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+        weights = list(weights)
+        if len(inputs) != len(weights):
+            raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
+        defaults = dict(weights=weights)
+        super().__init__(defaults, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        weights = self.settings[vars.params[0]]['weights']
+        # sort floats after tensor lists, keeping each weight paired with its input
+        order = sorted(range(len(inputs)), key=lambda i: isinstance(inputs[i], float))
+        sorted_inputs = [inputs[i] for i in order]
+        weights = [weights[i] for i in order]
+
+        sum = cast(list, sorted_inputs[0])
+        torch._foreach_mul_(sum, weights[0])
+        for v, w in zip(sorted_inputs[1:], weights[1:]):
+            if isinstance(v, (int, float)): torch._foreach_add_(sum, v * w)
+            else: torch._foreach_add_(sum, v, alpha=w)
+
+        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
+        return sum
+
+
+class WeightedMean(WeightedSum):
+    USE_MEAN = True
+
+
+class Median(ReduceOperation):
+    def __init__(self, *inputs: Chainable | float):
+        super().__init__({}, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        res = []
+        lists = [i for i in inputs if isinstance(i, list)]
+        floats = [i for i in inputs if isinstance(i, (int, float))]
+        for tensors in zip(*lists):
+            stacked = torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats))
+            res.append(stacked.median(dim=0).values) # median with dim returns (values, indices)
+        return res
+
+
+class Prod(ReduceOperation):
+    def __init__(self, *inputs: Chainable | float):
+        super().__init__({}, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+        prod = cast(list, sorted_inputs[0])
+        for v in sorted_inputs[1:]:
+            torch._foreach_mul_(prod, v)
+        return prod
+
+
+class MaximumModules(ReduceOperation):
+    def __init__(self, *inputs: Chainable | float):
+        super().__init__({}, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+        maximum = cast(list, sorted_inputs[0])
+        for v in sorted_inputs[1:]:
+            torch._foreach_maximum_(maximum, v)
+        return maximum
+
+
+class MinimumModules(ReduceOperation):
+    def __init__(self, *inputs: Chainable | float):
+        super().__init__({}, *inputs)
+
+    @torch.no_grad
+    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+        minimum = cast(list, sorted_inputs[0])
+        for v in sorted_inputs[1:]:
+            torch._foreach_minimum_(minimum, v)
+        return minimum
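
The reduction classes above all rely on one trick: operands are sorted with `key=lambda x: isinstance(x, float)`, so tensor lists come first (False sorts before True) and the accumulator is always a tensor list into which scalar operands can be folded with fused `torch._foreach_*` calls. A standalone sketch of that pattern (`reduce_sum` is an illustrative helper, not part of torchzero):

    import torch

    def reduce_sum(*operands):
        # tensor lists sort before floats because False < True
        ordered = sorted(operands, key=lambda x: isinstance(x, float))
        acc = [t.clone() for t in ordered[0]]   # accumulator is always a tensor list
        for v in ordered[1:]:
            torch._foreach_add_(acc, v)         # v is either a tensor list or a scalar
        return acc

    a = [torch.ones(3), torch.ones(2, 2)]
    b = [torch.full((3,), 2.0), torch.full((2, 2), 2.0)]
    out = reduce_sum(a, b, 0.5)                 # elementwise 1 + 2 + 0.5
    assert torch.allclose(out[0], torch.full((3,), 3.5))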
torchzero/modules/ops/split.py
@@ -0,0 +1,75 @@
+from collections.abc import Callable
+from typing import cast
+
+import torch
+
+from ...core import Chainable, Module, Vars
+
+
+def _split(
+    module: Module,
+    idxs,
+    params,
+    vars: Vars,
+):
+    split_params = [p for i, p in enumerate(params) if i in idxs]
+
+    split_grad = None
+    if vars.grad is not None:
+        split_grad = [g for i, g in enumerate(vars.grad) if i in idxs]
+
+    split_update = None
+    if vars.update is not None:
+        split_update = [u for i, u in enumerate(vars.update) if i in idxs]
+
+    # step the child module on the selected slice of params/grad/update
+    split_vars = vars.clone(clone_update=False)
+    split_vars.params = split_params
+    split_vars.grad = split_grad
+    split_vars.update = split_update
+
+    split_vars = module.step(split_vars)
+
+    if (vars.grad is None) and (split_vars.grad is not None):
+        vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+    # scatter the child's update back into the full update by index
+    if split_vars.update is not None:
+
+        if vars.update is None:
+            if vars.grad is None: vars.update = [cast(torch.Tensor, None) for _ in vars.params]
+            else: vars.update = [g.clone() for g in vars.grad]
+
+        for idx, u in zip(idxs, split_vars.update):
+            vars.update[idx] = u
+
+    vars.update_attrs_from_clone_(split_vars)
+    return vars
+
+
+class Split(Module):
+    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
+    def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
+        defaults = dict(filter=filter)
+        super().__init__(defaults)
+
+        if true is not None: self.set_child('true', true)
+        if false is not None: self.set_child('false', false)
+
+    def step(self, vars):
+        params = vars.params
+        filter = self.settings[params[0]]['filter']
+
+        true_idxs = []
+        false_idxs = []
+        for i, p in enumerate(params):
+            if filter(p): true_idxs.append(i)
+            else: false_idxs.append(i)
+
+        if 'true' in self.children:
+            vars = _split(self.children['true'], idxs=true_idxs, params=params, vars=vars)
+
+        if 'false' in self.children:
+            vars = _split(self.children['false'], idxs=false_idxs, params=params, vars=vars)
+
+        return vars
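
`_split` does all of its work with index bookkeeping: select a slice of params/grad/update, step the child module on the slice, then scatter the child's update back into the full list. The same bookkeeping in isolation (everything below is illustrative; the 0.1-scaling stands in for the child module's step):

    import torch

    params = [torch.randn(10, 10), torch.randn(10), torch.randn(5, 5)]
    update = [torch.ones_like(p) for p in params]

    # Split(filter=lambda p: p.ndim >= 2, ...) partitions indices like this:
    true_idxs = [i for i, p in enumerate(params) if p.ndim >= 2]   # [0, 2]

    branch_in = [update[i] for i in true_idxs]    # slice handed to the child
    branch_out = [u * 0.1 for u in branch_in]     # stand-in for the child's step
    for idx, u in zip(true_idxs, branch_out):     # scatter back by index
        update[idx] = u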
torchzero/modules/ops/switch.py
@@ -0,0 +1,68 @@
+from collections.abc import Iterable
+
+import torch
+
+from ...core import Chainable, Module
+
+
+class Alternate(Module):
+    """Alternate between stepping with each of `modules`."""
+    LOOP = True
+    def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
+        if isinstance(steps, Iterable):
+            steps = list(steps)
+            if len(steps) != len(modules):
+                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+
+        self.set_children_sequence(modules)
+        self.global_state['current_module_idx'] = 0
+        self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
+
+    @torch.no_grad
+    def step(self, vars):
+        # get the current module
+        current_module_idx = self.global_state.setdefault('current_module_idx', 0)
+        module = self.children[f'module_{current_module_idx}']
+
+        # step
+        vars = module.step(vars.clone(clone_update=False))
+
+        # number of steps until the next module
+        steps = self.settings[vars.params[0]]['steps']
+        if isinstance(steps, int): steps = [steps] * len(self.children)
+
+        if 'steps_to_next' not in self.global_state:
+            self.global_state['steps_to_next'] = steps[0]
+
+        self.global_state['steps_to_next'] -= 1
+
+        # switch to the next module
+        if self.global_state['steps_to_next'] == 0:
+            self.global_state['current_module_idx'] += 1
+
+            # loop back to the first module (or keep using the last module on Switch)
+            if self.global_state['current_module_idx'] > len(self.children) - 1:
+                if self.LOOP: self.global_state['current_module_idx'] = 0
+                else: self.global_state['current_module_idx'] = len(self.children) - 1
+
+            self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
+
+        return vars
+
+
+class Switch(Alternate):
+    """Switch to the next module after a given number of steps."""
+    LOOP = False
+    def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
+        if isinstance(steps, Iterable):
+            steps = list(steps)
+            if len(steps) != len(modules) - 1:
+                raise ValueError(f"steps must have one element per switch, i.e. one less than modules, got {len(modules) = }, {len(steps) = }")
+            steps.append(1)
+
+        super().__init__(*modules, steps=steps)
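
`Alternate` keeps two pieces of global state: the index of the active module and a countdown that reloads from the per-module step counts; `Switch` is the same machine with wrap-around disabled (and an implicit final count of 1). The schedule this produces, reduced to plain Python for illustration:

    steps = [2, 1, 3]                      # per-module step budgets
    idx, to_next = 0, steps[0]
    schedule = []
    for _ in range(12):
        schedule.append(idx)               # "step with module idx"
        to_next -= 1
        if to_next == 0:
            idx = (idx + 1) % len(steps)   # LOOP=True wraps; Switch clamps to the last module
            to_next = steps[idx]
    print(schedule)                        # [0, 0, 1, 2, 2, 2, 0, 0, 1, 2, 2, 2]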
torchzero/modules/ops/unary.py
@@ -0,0 +1,115 @@
+import torch
+
+from ...core import TensorwiseTransform, Target, Transform
+from ...utils import TensorList
+
+
+class UnaryLambda(Transform):
+    """Applies `fn` to the update; `fn` takes and returns a list of tensors."""
+    def __init__(self, fn, target: "Target" = 'update'):
+        defaults = dict(fn=fn)
+        super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        return self.settings[params[0]]['fn'](tensors)
+
+
+class UnaryParameterwiseLambda(TensorwiseTransform):
+    """Applies `fn` to each update tensor separately."""
+    def __init__(self, fn, target: "Target" = 'update'):
+        defaults = dict(fn=fn)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        return self.settings[param]['fn'](tensor)
+
+
+class CustomUnaryOperation(Transform):
+    """Calls the method `name` of the update."""
+    def __init__(self, name: str, target: "Target" = 'update'):
+        defaults = dict(name=name)
+        super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        return getattr(tensors, self.settings[params[0]]['name'])()
+
+
+class Abs(Transform):
+    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        torch._foreach_abs_(tensors)
+        return tensors
+
+
+class Sign(Transform):
+    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        torch._foreach_sign_(tensors)
+        return tensors
+
+
+class Exp(Transform):
+    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        torch._foreach_exp_(tensors)
+        return tensors
+
+
+class Sqrt(Transform):
+    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        torch._foreach_sqrt_(tensors)
+        return tensors
+
+
+class Reciprocal(Transform):
+    def __init__(self, eps=0, target: "Target" = 'update'):
+        defaults = dict(eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        eps = self.get_settings('eps', params=params)
+        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
+        torch._foreach_reciprocal_(tensors)
+        return tensors
+
+
+class Negate(Transform):
+    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        torch._foreach_neg_(tensors)
+        return tensors
+
+
+class NanToNum(Transform):
+    """Convert `nan`, `inf` and `-inf` to numbers.
+
+    Args:
+        nan (optional): the value to replace NaNs with. Default is zero.
+        posinf (optional): if a Number, the value to replace positive infinity values with.
+            If None, positive infinity values are replaced with the greatest finite value
+            representable by input's dtype. Default is None.
+        neginf (optional): if a Number, the value to replace negative infinity values with.
+            If None, negative infinity values are replaced with the lowest finite value
+            representable by input's dtype. Default is None.
+    """
+    def __init__(self, nan=None, posinf=None, neginf=None, target: "Target" = 'update'):
+        defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        nan, posinf, neginf = self.get_settings('nan', 'posinf', 'neginf', params=params)
+        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
+
+
+class Rescale(Transform):
+    """Rescales the update to the `(min, max)` range."""
+    def __init__(self, min: float, max: float, tensorwise: bool = False, eps: float = 1e-8, target: "Target" = 'update'):
+        defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        min, max = self.get_settings('min', 'max', params=params)
+        tensorwise = self.settings[params[0]]['tensorwise']
+        dim = None if tensorwise else 'global'
+        return TensorList(tensors).rescale(min=min, max=max, eps=self.settings[params[0]]['eps'], dim=dim)
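
Each transform above is a thin wrapper around one in-place `torch._foreach_*` call, which applies the elementwise op across the whole update list in a few fused kernel launches instead of one launch per tensor. The same effect outside the framework:

    import torch

    update = [torch.randn(4), torch.randn(3, 3)]
    torch._foreach_abs_(update)    # what Abs does to the update list
    torch._foreach_sqrt_(update)   # what Sqrt does; safe here since values are now non-negative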
torchzero/modules/ops/utility.py
@@ -0,0 +1,112 @@
+import torch
+
+from ...core import Module, Transform
+from ...utils.tensorlist import Distributions, TensorList
+
+
+class Clone(Transform):
+    def __init__(self): super().__init__({}, uses_grad=False)
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars): return [t.clone() for t in tensors]
+
+
+class Grad(Module):
+    def __init__(self):
+        super().__init__({})
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [g.clone() for g in vars.get_grad()]
+        return vars
+
+
+class Params(Module):
+    def __init__(self):
+        super().__init__({})
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [p.clone() for p in vars.params]
+        return vars
+
+
+class Update(Module):
+    def __init__(self):
+        super().__init__({})
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [u.clone() for u in vars.get_update()]
+        return vars
+
+
+class Zeros(Module):
+    def __init__(self):
+        super().__init__({})
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [torch.zeros_like(p) for p in vars.params]
+        return vars
+
+
+class Ones(Module):
+    def __init__(self):
+        super().__init__({})
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [torch.ones_like(p) for p in vars.params]
+        return vars
+
+
+class Fill(Module):
+    def __init__(self, value: float):
+        defaults = dict(value=value)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [torch.full_like(p, self.settings[p]['value']) for p in vars.params]
+        return vars
+
+
+class RandomSample(Module):
+    def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
+        defaults = dict(eps=eps, distribution=distribution)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = TensorList(vars.params).sample_like(
+            eps=self.get_settings('eps', params=vars.params), distribution=self.settings[vars.params[0]]['distribution']
+        )
+        return vars
+
+
+class Randn(Module):
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, vars):
+        vars.update = [torch.randn_like(p) for p in vars.params]
+        return vars
+
+
+class Uniform(Module):
+    def __init__(self, low: float, high: float):
+        defaults = dict(low=low, high=high)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, vars):
+        low, high = self.get_settings('low', 'high', params=vars.params)
+        vars.update = [torch.empty_like(t).uniform_(l, h) for t, l, h in zip(vars.params, low, high)]
+        return vars
+
+
+class GradToNone(Module):
+    def __init__(self): super().__init__()
+    def step(self, vars):
+        vars.grad = None
+        return vars
+
+
+class UpdateToNone(Module):
+    def __init__(self): super().__init__()
+    def step(self, vars):
+        vars.update = None
+        return vars
+
+
+class Identity(Module):
+    def __init__(self, *args, **kwargs): super().__init__()
+    def step(self, vars): return vars
+
+
+NoOp = Identity
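
Unlike the transforms in unary.py, these modules are update sources: each one overwrites (or clears) `vars.update` and leaves everything else untouched. What the main ones produce per parameter, sketched without the framework:

    import torch

    params = [torch.randn(3), torch.randn(2, 2)]
    zeros = [torch.zeros_like(p) for p in params]       # Zeros
    fill = [torch.full_like(p, 0.5) for p in params]    # Fill(0.5)
    noise = [torch.randn_like(p) for p in params]       # Randn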
torchzero/modules/optimizers/__init__.py
@@ -1,10 +1,18 @@
-from .
-from .rprop import
+from .adagrad import Adagrad, FullMatrixAdagrad
+from .adam import Adam
+from .lion import Lion
+from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .rmsprop import RMSprop
+from .rprop import (
+    BacktrackOnSignChange,
+    Rprop,
+    ScaleLRBySignChange,
+    SignConsistencyLRs,
+    SignConsistencyMask,
+)
+from .shampoo import Shampoo
+from .soap import SOAP
+from .orthograd import OrthoGrad, orthograd_
+from .sophia_h import SophiaH
+# from .curveball import CurveBall
+# from .spectral import SpectralPreconditioner