torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/regularization.py
@@ -0,0 +1,171 @@
+from collections import deque
+from collections.abc import Iterable
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+class Dropout(Transform):
+    """Applies dropout to the update.
+
+    For each weight the update to that weight has :code:`p` probability to be set to 0.
+    This can be used to implement gradient dropout or update dropout depending on placement.
+
+    Args:
+        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+    Examples:
+        Gradient dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Dropout(0.5),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Update dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Dropout(0.5),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']
+
+        if graft:
+            target_norm = tensors.global_vector_norm()
+            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+def _bernoulli_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be 0."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator)
+
+class WeightDropout(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.
+
+    Args:
+        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+    """
+    def __init__(self, p: float = 0.5, graft: bool = True):
+        defaults = dict(p=p, graft=graft, use_dropout=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
+
+        # create masks
+        mask = []
+        for p, m in zip(params, mask):
+            prob = self.settings[p]['p']
+            use_dropout = self.settings[p]['use_dropout']
+            if use_dropout: mask.append(_bernoulli_like(p, prob))
+            else: mask.append(torch.ones_like(p))
+
+        @torch.no_grad
+        def dropout_closure(backward=True):
+            orig_params = params.clone()
+            params.mul_(mask)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.copy_(orig_params)
+            return loss
+
+        var.closure = dropout_closure
+        return var
+
+
+class PerturbWeights(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.
+
+    Args:
+        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+    """
+    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+
+        # create perturbations
+        perts = []
+        for p in params:
+            settings = self.settings[p]
+            if not settings['perturb']:
+                perts.append(torch.zeros_like(p))
+                continue
+
+            alpha = settings['alpha']
+            if settings['relative']:
+                alpha *= p.abs().mean()
+
+            distribution = self.settings[p]['distribution'].lower()
+            if distribution in ('normal', 'gaussian'):
+                perts.append(torch.randn_like(p).mul_(alpha))
+            elif distribution == 'uniform':
+                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
+            elif distribution == 'sphere':
+                r = torch.randn_like(p)
+                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+            else:
+                raise ValueError(distribution)
+
+        @torch.no_grad
+        def perturbed_closure(backward=True):
+            params.add_(perts)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.sub_(perts)
+            return loss
+
+        var.closure = perturbed_closure
+        return var
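The Dropout docstring above already shows gradient and update dropout; WeightDropout and PerturbWeights instead rewrite the closure, so they sit in front of whatever consumes the gradient. A minimal usage sketch follows, assuming both classes are re-exported under `tz.m` like the other modules in this release (the hunk itself does not confirm the export), that `tz.Modular` follows the `torch.optim.Optimizer` API, and using the `closure(backward=True)` convention these hunks call into; both modules raise a RuntimeError if no closure is supplied to the step:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 2)
    X, y = torch.randn(64, 10), torch.randn(64, 2)

    opt = tz.Modular(
        model.parameters(),
        tz.m.WeightDropout(p=0.3),        # evaluate loss/grads with ~30% of weights zeroed
        # tz.m.PerturbWeights(alpha=0.1), # or: evaluate at randomly perturbed weights
        tz.m.Adam(),
        tz.m.LR(1e-3),
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)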
torchzero/modules/misc/split.py
@@ -0,0 +1,103 @@
+from collections.abc import Callable
+from typing import cast
+
+import torch
+
+from ...core import Chainable, Module, Var
+
+
+def _split(
+    module: Module,
+    idxs,
+    params,
+    var: Var,
+):
+    split_params = [p for i,p in enumerate(params) if i in idxs]
+
+    split_grad = None
+    if var.grad is not None:
+        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
+
+    split_update = None
+    if var.update is not None:
+        split_update = [u for i,u in enumerate(var.update) if i in idxs]
+
+    split_var = var.clone(clone_update=False)
+    split_var.params = split_params
+    split_var.grad = split_grad
+    split_var.update = split_update
+
+    split_var = module.step(split_var)
+
+    if (var.grad is None) and (split_var.grad is not None):
+        var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+    if split_var.update is not None:
+
+        if var.update is None:
+            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+            else: var.update = [g.clone() for g in var.grad]
+
+        for idx, u in zip(idxs, split_var.update):
+            var.update[idx] = u
+
+    var.update_attrs_from_clone_(split_var)
+    return var
+
+class Split(Module):
+    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
+
+    Args:
+        filter (Callable[[torch.Tensor], bool]): a function that takes in a parameter tensor and returns a boolean value.
+        true (Chainable | None): modules that are applied to tensors where :code:`filter` returned True.
+        false (Chainable | None): modules that are applied to tensors where :code:`filter` returned False.
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+
+    """
+    def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
+        defaults = dict(filter=filter)
+        super().__init__(defaults)
+
+        if true is not None: self.set_child('true', true)
+        if false is not None: self.set_child('false', false)
+
+    def step(self, var):
+
+        params = var.params
+        filter = self.settings[params[0]]['filter']
+
+        true_idxs = []
+        false_idxs = []
+        for i,p in enumerate(params):
+            if filter(p): true_idxs.append(i)
+            else: false_idxs.append(i)
+
+        if 'true' in self.children:
+            true = self.children['true']
+            var = _split(true, idxs=true_idxs, params=params, var=var)
+
+        if 'false' in self.children:
+            false = self.children['false']
+            var = _split(false, idxs=false_idxs, params=params, var=var)
+
+        return var
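In `_split` above only the indices routed to a child module get their update overwritten, so passing `false=None` appears to leave the non-matching parameters' incoming update untouched. A sketch of using that to orthogonalize only the 2D+ updates, assuming `Split`, `Orthogonalize`, `Adam` and `LR` are exposed under `tz.m` as in the docstring example:

    import torch
    import torchzero as tz

    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2))

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Split(
            filter=lambda t: t.ndim >= 2,   # weight matrices only
            true=tz.m.Orthogonalize(),      # orthogonalize their Adam updates
            false=None,                     # biases keep the plain Adam update
        ),
        tz.m.LR(1e-2),
    )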
torchzero/modules/{ops → misc}/switch.py
@@ -7,7 +7,28 @@ from ...core import Chainable, Module
 
 
 class Alternate(Module):
-    """
+    """Alternates between stepping with :code:`modules`.
+
+    That is, first step is performed with 1st module, second step with second module, etc.
+
+    Args:
+        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+    Examples:
+        Alternate between Adam, SignSGD and RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Alternate(
+                    tz.m.Adam(),
+                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                    tz.m.RMSprop(),
+                ),
+                tz.m.LR(1e-3),
+            )
+    """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
         if isinstance(steps, Iterable):
@@ -23,16 +44,16 @@ class Alternate(Module):
         self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         # get current module
         current_module_idx = self.global_state.setdefault('current_module_idx', 0)
         module = self.children[f'module_{current_module_idx}']
 
         # step
-
+        var = module.step(var.clone(clone_update=False))
 
         # number of steps until next module
-        steps = self.settings[
+        steps = self.settings[var.params[0]]['steps']
         if isinstance(steps, int): steps = [steps]*len(self.children)
 
         if 'steps_to_next' not in self.global_state:
@@ -51,17 +72,37 @@ class Alternate(Module):
 
         self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
-        return
+        return var
 
 class Switch(Alternate):
-    """
+    """After :code:`steps` steps switches to the next module.
+
+    Args:
+        steps (int | Iterable[int]): Number of steps to perform with each module.
+
+    Examples:
+        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Switch(
+                    [tz.m.Adam(), tz.m.LR(1e-3)],
+                    [tz.m.LBFGS(), tz.m.Backtracking()],
+                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                    steps = (1000, 2000)
+                )
+            )
+    """
+
     LOOP = False
     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
 
         if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules) - 1:
-                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")
 
            steps.append(1)
 
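The new `steps` argument documented above takes either a single count or one count per module, and with `LOOP = True` the schedule cycles. A short sketch under that reading, using only modules that already appear in the docstring examples and assuming the usual `tz.m` exports:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 2)

    opt = tz.Modular(
        model.parameters(),
        tz.m.Alternate(
            tz.m.Adam(),     # 5 steps with Adam...
            tz.m.SignSGD(),  # ...then 1 step with SignSGD, then repeat
            steps=(5, 1),
        ),
        tz.m.LR(1e-3),
    )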
torchzero/modules/momentum/__init__.py
@@ -11,4 +11,4 @@ from .experimental import CoordinateMomentum
 # from .matrix_momentum import MatrixMomentum
 
 from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
torchzero/modules/momentum/averaging.py
@@ -1,3 +1,4 @@
+"""Modules that perform averaging over a history of past updates."""
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Literal, cast
@@ -9,14 +10,19 @@ from ...utils import tolist
 
 
 class Averaging(TensorwiseTransform):
+    """Average of past :code:`history_size` updates.
+
+    Args:
+        history_size (int): Number of past updates to average
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, history_size: int, target: Target = 'update'):
         defaults = dict(history_size=history_size)
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def
-        history_size =
-        state = self.state[param]
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
             state['average'] = torch.zeros_like(tensor)
@@ -29,15 +35,19 @@ class Averaging(TensorwiseTransform):
         return average / len(history)
 
 class WeightedAveraging(TensorwiseTransform):
-    """
+    """Weighted average of past :code:`len(weights)` updates.
+
+    Args:
+        weights (Sequence[float]): a sequence of weights from oldest to newest.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
         defaults = dict(weights = tolist(weights))
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def
-        weights =
-        state = self.state[param]
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        weights = setting['weights']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=len(weights))
@@ -59,14 +69,19 @@ class WeightedAveraging(TensorwiseTransform):
 
 
 class MedianAveraging(TensorwiseTransform):
+    """Median of past :code:`history_size` updates.
+
+    Args:
+        history_size (int): Number of past updates to average
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, history_size: int, target: Target = 'update'):
         defaults = dict(history_size = history_size)
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def
-        history_size =
-        state = self.state[param]
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
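The three transforms above differ only in how the stored history is reduced (mean, weighted mean, median). A hedged usage sketch, assuming they are exported under `tz.m`; by analogy with the Dropout examples earlier in this diff, placing them after Adam smooths the update, placing them before smooths the gradient:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 2)

    # mean of the last 10 Adam updates
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Averaging(history_size=10),
        tz.m.LR(1e-3),
    )

    # weighted mean, weights given from oldest to newest
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightedAveraging(weights=[0.1, 0.2, 0.3, 0.4]),
        tz.m.LR(1e-3),
    )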