torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/optim/modular.py
DELETED
@@ -1,148 +0,0 @@
-from collections import abc
-import warnings
-from inspect import cleandoc
-import torch
-from typing import Any
-
-from ..core import OptimizerModule, TensorListOptimizer, OptimizationVars, _Chain, _Chainable
-from ..utils.python_tools import flatten
-
-def _unroll_modules(flat_modules: list[OptimizerModule], nested) -> list[OptimizerModule]:
-    """returns a list of all modules, including all nested ones"""
-    unrolled = []
-    for m in flat_modules:
-        unrolled.append(m)
-        if len(m.children) > 0:
-            unrolled.extend(_unroll_modules(list(m.children.values()), nested=True))
-        if nested:
-            if m.next_module is not None:
-                unrolled.extend(_unroll_modules([m.next_module], nested=True))
-    return unrolled
-
-
-class Modular(TensorListOptimizer):
-    """Creates a modular optimizer by chaining together a sequence of optimizer modules.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        *modules (Iterable[OptimizerModule] | OptimizerModule):
-            A sequence of optimizer modules to chain together. This argument will be flattened."""
-    def __init__(self, params, *modules: _Chainable):
-        flat_modules = flatten(modules)
-        self.modules: list[OptimizerModule] = flat_modules
-        self.chain = _Chain(flat_modules)
-
-        # save unrolled modules and make sure there is only 1 LR module.
-        self.unrolled_modules = _unroll_modules(flat_modules, nested=False)
-        num_lr_modules = len([m for m in self.unrolled_modules if m.IS_LR_MODULE])
-        if num_lr_modules > 1:
-            warnings.warn(cleandoc(
-                f"""More then 1 lr modules have been added.
-                This may lead to incorrect behaviour with learning rate scheduling and per-parameter learning rates.
-                Make sure there is a single `LR` module, use `Alpha` module instead of it where needed.
-                \nList of modules: {self.unrolled_modules}; \nlist of lr modules: {[m for m in self.unrolled_modules if m.IS_LR_MODULE]}"""
-            ))
-
-        if isinstance(params, torch.nn.Module):
-            self.model = params
-            params = list(params.parameters())
-        else:
-            self.model = None
-            params = list(params)
-
-        # if there is an `lr` setting, make sure there is an LR module that can use it
-        for p in params:
-            if isinstance(p, dict):
-                if 'lr' in p:
-                    if num_lr_modules == 0:
-                        warnings.warn(cleandoc(
-                            """Passed "lr" setting in a parameter group, but there is no LR module that can use that setting.
-                            Add an `LR` module to make per-layer "lr" setting work."""
-                        ))
-
-        super().__init__(params, {})
-        self.chain._initialize_(params, set_passed_params=True)
-
-        # run post-init hooks
-        for module in self.unrolled_modules:
-            for hook in module.post_init_hooks:
-                hook(self, module)
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['__self__'] = super().state_dict()
-        for i,v in enumerate(self.unrolled_modules):
-            state_dict[str(i)] = v.state_dict()
-        return state_dict
-
-    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        super().load_state_dict(state_dict['__self__'])
-        for i,v in enumerate(self.unrolled_modules):
-            if str(i) in state_dict:
-                v.load_state_dict(state_dict[str(i)])
-            else:
-                warnings.warn(f"Tried to load state dict for {i}th module: {v.__class__.__name__}, but it is not present in state_dict with {list(state_dict.keys()) = }")
-
-    def get_lr_module(self, last=True) -> OptimizerModule:
-        """
-        Retrieves the module in the chain that controls the learning rate.
-
-        This method is useful for setting up a learning rate scheduler. By default, it retrieves the last module in the chain
-        that has an `lr` group parameter.
-
-        Args:
-            last (bool, optional):
-                If multiple modules have an `lr` parameter, this argument controls which one is returned.
-                - If `True` (default), the last module is returned.
-                - If `False`, the first module is returned.
-
-        Returns:
-            OptimizerModule: The module that controls the learning rate.
-
-        Raises:
-            ValueError: If no modules in the chain have an `lr` parameter. To fix this, add an `LR` module.
-
-        Example:
-
-            .. code:: py
-                from torch.optim.lr_scheduler import OneCycleLR
-                import torchzero as tz
-
-                opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])
-                lr_scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps = 1000, cycle_momentum=False)
-
-        """
-        modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
-        for m in modules:
-            if 'lr' in m.param_groups[0]: return m
-
-        raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} support and `lr` parameter. The easiest way to fix is is to add an `LR(1)` module at the end.')
-
-    def get_module_by_name(self, name: str | type, last=True) -> OptimizerModule:
-        """Returns the first or last module in the chain that matches the provided name or type.
-
-        Args:
-            name (str | type): the name (as a string) or the type of the module to search for.
-            last (bool, optional):
-                If multiple modules match, this argument controls which one is returned.
-                - If `True` (default), the last matching module is returned.
-                - If `False`, the first matching module is returned.
-
-        Returns:
-            OptimizerModule: The matching optimizer module.
-
-        Raises:
-            ValueError: If no modules in the chain match the provided name or type.
-        """
-        modules = list(reversed(self.unrolled_modules)) if last else self.unrolled_modules
-        for m in modules:
-            if isinstance(name, str) and m.__class__.__name__ == name: return m
-            if isinstance(name, type) and isinstance(m, name): return m
-
-        raise ValueError(f'No modules out of {", ".join(m.__class__.__name__ for m in modules)} match "{name}".')
-
-    def step(self, closure=None): # type:ignore
-        vars = OptimizationVars(closure, self.model)
-        res = self.chain.step(vars)
-        for hook in vars.post_step_hooks: hook(self, vars)
-        return res
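The `get_lr_module` docstring above shows how this removed `Modular` class was driven in 0.1.8. Below is that example expanded into a minimal sketch against the 0.1.8 API; the model, the training closure, and the loop are illustrative, and the closure contract shown is the standard PyTorch one (the 0.3.1 `tz.Modular` API differs).

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR
import torchzero as tz

model = torch.nn.Linear(10, 1)  # illustrative model

# chain: RMSProp update rule -> the single lr-owning LR module -> directional Newton refinement
opt = tz.Modular(model.parameters(), [tz.m.RMSProp(), tz.m.LR(1e-2), tz.m.DirectionalNewton()])

# schedulers attach to the module that owns the `lr` group setting
scheduler = OneCycleLR(opt.get_lr_module(), max_lr=1e-1, total_steps=1000, cycle_momentum=False)

def closure():
    # standard PyTorch closure; some modules re-evaluate it during the step
    opt.zero_grad()
    loss = model(torch.randn(16, 10)).pow(2).mean()
    loss.backward()
    return loss

for _ in range(1000):
    opt.step(closure)
    scheduler.step()
```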
torchzero/optim/quasi_newton/__init__.py
DELETED
@@ -1 +0,0 @@
-from .directional_newton import DirectionalNewton
torchzero/optim/quasi_newton/directional_newton.py
DELETED
@@ -1,58 +0,0 @@
-from ...modules import (
-    SGD,
-)
-from ...modules import DirectionalNewton as _DirectionalNewton, LR
-from ..modular import Modular
-
-
-class DirectionalNewton(Modular):
-    """Minimizes a parabola in the direction of the gradient (or update if momentum or weight decay is enabled)
-    via one additional forward pass, and uses another forward pass to make sure it didn't overstep.
-    So in total this performs three forward passes and one backward.
-
-    First forward and backward pass is used to calculate the value and gradient at initial parameters.
-    Then a gradient descent step is performed with `lr` learning rate, and loss is recalculated
-    with new parameters. A quadratic is fitted to two points and gradient,
-    if it has positive curvature, this makes a step towards the minimum, and checks if lr decreased
-    with an additional forward pass.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional):
-            learning rate. Since you shouldn't put this module after LR(), you have to specify
-            the learning rate in this argument. Defaults to 1e-2.
-        max_dist (float | None, optional):
-            maximum distance to step when minimizing quadratic.
-            If minimum is further than this distance, minimization is not performed. Defaults to 1e4.
-        validate_step (bool, optional):
-            uses an additional forward pass to check
-            if step towards the minimum actually decreased the loss. Defaults to True.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-
-    Note:
-        While lr scheduling is supported, this uses lr of the first parameter for all parameters.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-4,
-        max_dist: float | None = 1e5,
-        validate_step: bool = True,
-        momentum: float = 0,
-        dampening: float = 0,
-        weight_decay: float = 0,
-        nesterov: bool = False,
-
-    ):
-
-        modules = [
-            SGD(momentum=momentum,dampening=dampening,weight_decay=weight_decay,nesterov=nesterov),
-            LR(lr),
-            _DirectionalNewton(max_dist, validate_step)
-        ]
-        super().__init__(params, modules)
-
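Based on the constructor above, this preset was just the `SGD -> LR -> DirectionalNewton` chain wrapped in `Modular`. A hedged usage sketch, assuming the module path shown in this diff and a standard PyTorch closure (the exact 0.1.8 closure contract is not visible here):

```python
import torch
# module path as it appears in this diff (removed in 0.3.1)
from torchzero.optim.quasi_newton.directional_newton import DirectionalNewton

model = torch.nn.Linear(4, 1)

# internally builds Modular([SGD(momentum, dampening, weight_decay, nesterov), LR(lr), _DirectionalNewton(...)])
opt = DirectionalNewton(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

def closure():
    # one backward pass; the optimizer adds two extra forward passes
    # (quadratic fit + overstep check), per the docstring above
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    return loss

opt.step(closure)
```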
torchzero/optim/second_order/__init__.py
DELETED
@@ -1 +0,0 @@
-from .newton import ExactNewton
torchzero/optim/second_order/newton.py
DELETED
@@ -1,94 +0,0 @@
-from typing import Any, Literal
-
-import torch
-
-from ...modules import (
-    LR,
-    ClipNorm,
-    FallbackLinearSystemSolvers,
-    LinearSystemSolvers,
-    LineSearches,
-    get_line_search,
-)
-from ...modules import ExactNewton as _ExactNewton
-from ..modular import Modular
-
-
-class ExactNewton(Modular):
-    """Peforms an exact Newton step using batched autograd. Note that torch.func would be way more efficient
-    but much more restrictive to what operations are allowed (I will add it at some point).
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. Defaults to 1.
-        tikhonov (float, optional):
-            tikhonov regularization (constant value added to the diagonal of the hessian). Defaults to 0.
-        solver (LinearSystemSolvers, optional):
-            solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-        fallback (FallbackLinearSystemSolvers, optional):
-            what to do if solver fails. Defaults to "safe_diag"
-            (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-        max_norm (float, optional):
-            clips the newton step to L2 norm to avoid instability by giant steps.
-            A mauch better way is to use trust region methods. I haven't implemented any
-            but you can use `tz.optim.wrappers.scipy.ScipyMinimize` with one of the trust region methods.
-            Defaults to None.
-        validate (bool, optional):
-            validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-            If not, undo the step and perform a gradient descent step.
-        tol (float, optional):
-            only has effect if `validate` is enabled.
-            If loss increased by `loss * tol`, perform gradient descent step.
-            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-        gd_lr (float, optional):
-            only has effect if `validate` is enabled.
-            Gradient descent step learning rate. Defaults to 1e-2.
-        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to None.
-        batched_hessian (bool, optional):
-            whether to use experimental pytorch vmap-vectorized hessian calculation. As per pytorch docs,
-            should be faster, but this feature being experimental, there may be performance cliffs.
-            Defaults to True.
-        diag (False, optional):
-            only use the diagonal of the hessian. This will still calculate the full hessian!
-            This is mainly useful for benchmarking.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1,
-        tikhonov: float | Literal['eig'] = 0.0,
-        solver: LinearSystemSolvers = "cholesky_lu",
-        fallback: FallbackLinearSystemSolvers = "safe_diag",
-        max_norm: float | None = None,
-        validate=False,
-        tol: float = 1,
-        gd_lr = 1e-2,
-        line_search: LineSearches | None = None,
-        batched_hessian = True,
-
-        diag: bool = False,
-    ):
-        modules: list[Any] = [
-            _ExactNewton(
-                tikhonov=tikhonov,
-                batched_hessian=batched_hessian,
-                solver=solver,
-                fallback=fallback,
-                validate=validate,
-                tol = tol,
-                gd_lr=gd_lr,
-                diag = diag,
-            ),
-        ]
-
-        if max_norm is not None:
-            modules.append(ClipNorm(max_norm))
-
-        modules.append(LR(lr))
-
-        if line_search is not None:
-            modules.append(get_line_search(line_search))
-
-        super().__init__(params, modules)
-
-
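Per the constructor above, `ExactNewton` assembled `_ExactNewton`, an optional `ClipNorm(max_norm)`, `LR(lr)`, and an optional line search. A hedged sketch, assuming the module path in this diff and a standard PyTorch-style closure:

```python
import torch
# module path as it appears in this diff (removed in 0.3.1)
from torchzero.optim.second_order.newton import ExactNewton

model = torch.nn.Linear(4, 1)

# Newton step per the docstring: solve Hx = g with cholesky (LU fallback),
# with tikhonov added to the hessian diagonal, clip the step to L2 norm 10, scale by lr
opt = ExactNewton(model.parameters(), lr=1.0, tikhonov=1e-4, max_norm=10.0)

def closure():
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    return loss

opt.step(closure)
```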
torchzero/optim/zeroth_order/fdm.py
DELETED
@@ -1,87 +0,0 @@
-from typing import Literal
-
-import torch
-
-from ...modules import FDM as _FDM, WrapClosure, SGD, WeightDecay, LR
-from ...modules.gradient_approximation._fd_formulas import _FD_Formulas
-from ..modular import Modular
-
-
-class FDM(Modular):
-    """Gradient approximation via finite difference.
-
-    This performs `n + 1` evaluations per step with `forward` and `backward` formulas,
-    and `2 * n` with `central` formula, where n is the number of parameters.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. Defaults to 1e-3.
-        eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-        formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-        n_points (T.Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        eps: float = 1e-3,
-        formula: _FD_Formulas = "forward",
-        n_points: Literal[2, 3] = 2,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-
-    ):
-        modules: list = [
-            _FDM(eps = eps, formula=formula, n_points=n_points),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            LR(lr),
-
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class FDMWrapper(Modular):
-    """Gradient approximation via finite difference. This wraps any other optimizer.
-    This also supports optimizers that perform multiple gradient evaluations per step, like LBFGS.
-
-    Exaple:
-    ```
-    lbfgs = torch.optim.LBFGS(params, lr = 1)
-    fdm = FDMWrapper(optimizer = lbfgs)
-    ```
-
-    This performs n+1 evaluations per step with `forward` and `backward` formulas,
-    and 2*n with `central` formula.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        optimizer (torch.optim.Optimizer): optimizer that will perform optimization using FDM-approximated gradients.
-        eps (float, optional): finite difference epsilon. Defaults to 1e-3.
-        formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
-        n_points (T.Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
-    """
-    def __init__(
-        self,
-        optimizer: torch.optim.Optimizer,
-        eps: float = 1e-3,
-        formula: _FD_Formulas = "forward",
-        n_points: Literal[2, 3] = 2,
-    ):
-        modules = [
-            _FDM(eps = eps, formula=formula, n_points=n_points, target = 'closure'),
-            WrapClosure(optimizer)
-        ]
-        # some optimizers have `eps` setting in param groups too.
-        # it should not be passed to FDM
-        super().__init__([p for g in optimizer.param_groups.copy() for p in g['params']], modules)
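The `FDMWrapper` docstring above already shows the intended pairing with `torch.optim.LBFGS`. A slightly fuller sketch, assuming the module path in this diff and that the closure only needs to return the loss, since gradients are approximated by finite differences (the exact 0.1.8 closure contract is not visible here):

```python
import torch
# module path as it appears in this diff (removed in 0.3.1)
from torchzero.optim.zeroth_order.fdm import FDM, FDMWrapper

params = [torch.nn.Parameter(torch.randn(10))]

# standalone: n + 1 loss evaluations per step with the "forward" formula (2 * n with "central")
opt = FDM(params, lr=1e-3, eps=1e-3, formula="forward")

# or feed FDM-approximated gradients to an existing optimizer, e.g. LBFGS as in the docstring
lbfgs = torch.optim.LBFGS(params, lr=1)
fdm = FDMWrapper(optimizer=lbfgs, eps=1e-3)

def closure():
    # loss value only; gradients come from finite differences
    return (params[0] ** 2).sum()

opt.step(closure)
fdm.step(closure)
```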
torchzero/optim/zeroth_order/newton_fdm.py
DELETED
@@ -1,146 +0,0 @@
-from typing import Any, Literal
-import torch
-
-from ...modules import (LR, FallbackLinearSystemSolvers,
-                        LinearSystemSolvers, LineSearches, ClipNorm)
-from ...modules import NewtonFDM as _NewtonFDM, get_line_search
-from ...modules.experimental.subspace import Proj2Masks, ProjRandom, Subspace
-from ..modular import Modular
-
-
-class NewtonFDM(Modular):
-    """Newton method with gradient and hessian approximated via finite difference.
-
-    This performs approximately `4 * n^2 + 1` evaluations per step;
-    if `diag` is True, performs `n * 2 + 1` evaluations per step.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate.
-        eps (float, optional): epsilon for finite difference.
-            Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-        diag (bool, optional): whether to only approximate diagonal elements of the hessian.
-            This also ignores `solver` if True. Defaults to False.
-        solver (LinearSystemSolvers, optional):
-            solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-        fallback (FallbackLinearSystemSolvers, optional):
-            what to do if solver fails. Defaults to "safe_diag"
-            (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-        validate (bool, optional):
-            validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-            If not, undo the step and perform a gradient descent step.
-        tol (float, optional):
-            only has effect if `validate` is enabled.
-            If loss increased by `loss * tol`, perform gradient descent step.
-            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-        gd_lr (float, optional):
-            only has effect if `validate` is enabled.
-            Gradient descent step learning rate. Defaults to 1e-2.
-        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to 'brent'.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1,
-        eps: float = 1e-2,
-        diag=False,
-        solver: LinearSystemSolvers = "cholesky_lu",
-        fallback: FallbackLinearSystemSolvers = "safe_diag",
-        max_norm: float | None = None,
-        validate=False,
-        tol: float = 2,
-        gd_lr = 1e-2,
-        line_search: LineSearches | None = 'brent',
-    ):
-        modules: list[Any] = [
-            _NewtonFDM(eps = eps, diag = diag, solver=solver, fallback=fallback, validate=validate, tol=tol, gd_lr=gd_lr),
-        ]
-
-        if max_norm is not None:
-            modules.append(ClipNorm(max_norm))
-
-        modules.append(LR(lr))
-
-        if line_search is not None:
-            modules.append(get_line_search(line_search))
-
-        super().__init__(params, modules)
-
-
-class RandomSubspaceNewtonFDM(Modular):
-    """This projects the parameters into a smaller dimensional subspace,
-    making approximating the hessian via finite difference feasible.
-
-    This performs approximately `4 * subspace_ndim^2 + 1` evaluations per step;
-    if `diag` is True, performs `subspace_ndim * 2 + 1` evaluations per step.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        subspace_ndim (float, optional): number of random subspace dimensions.
-        lr (float, optional): learning rate.
-        eps (float, optional): epsilon for finite difference.
-            Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-        diag (bool, optional): whether to only approximate diagonal elements of the hessian.
-        solver (LinearSystemSolvers, optional):
-            solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-        fallback (FallbackLinearSystemSolvers, optional):
-            what to do if solver fails. Defaults to "safe_diag"
-            (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-        validate (bool, optional):
-            validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-            If not, undo the step and perform a gradient descent step.
-        tol (float, optional):
-            only has effect if `validate` is enabled.
-            If loss increased by `loss * tol`, perform gradient descent step.
-            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-        gd_lr (float, optional):
-            only has effect if `validate` is enabled.
-            Gradient descent step learning rate. Defaults to 1e-2.
-        line_search (OptimizerModule | None, optional): line search module, can be None. Defaults to BacktrackingLS().
-        randomize_every (float, optional): generates new random projections every n steps. Defaults to 1.
-    """
-    def __init__(
-        self,
-        params,
-        subspace_ndim: int = 3,
-        lr: float = 1,
-        eps: float = 1e-2,
-        diag=False,
-        solver: LinearSystemSolvers = "cholesky_lu",
-        fallback: FallbackLinearSystemSolvers = "safe_diag",
-        max_norm: float | None = None,
-        validate=False,
-        tol: float = 2,
-        gd_lr = 1e-2,
-        line_search: LineSearches | None = 'brent',
-        randomize_every: int = 1,
-    ):
-        if subspace_ndim == 1: projections = [ProjRandom(1)]
-        else:
-            projections: list[Any] = [Proj2Masks(subspace_ndim//2)]
-            if subspace_ndim % 2 == 1: projections.append(ProjRandom(1))
-
-        modules: list[Any] = [
-            Subspace(
-                modules = _NewtonFDM(
-                    eps = eps,
-                    diag = diag,
-                    solver=solver,
-                    fallback=fallback,
-                    validate=validate,
-                    tol=tol,
-                    gd_lr=gd_lr
-                ),
-                projections = projections,
-                update_every=randomize_every),
-        ]
-        if max_norm is not None:
-            modules.append(ClipNorm(max_norm))
-
-        modules.append(LR(lr))
-
-        if line_search is not None:
-            modules.append(get_line_search(line_search))
-
-        super().__init__(params, modules)
-
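The evaluation counts quoted in the two docstrings above are the main practical difference between these removed classes. A hedged sketch, assuming the module path in this diff and a loss-only closure:

```python
import torch
# module path as it appears in this diff (removed in 0.3.1)
from torchzero.optim.zeroth_order.newton_fdm import NewtonFDM, RandomSubspaceNewtonFDM

params = [torch.nn.Parameter(torch.randn(6))]  # n = 6 parameters

# full-space variant: roughly 4 * n^2 + 1 = 145 loss evaluations per step for n = 6
full = NewtonFDM(params, lr=1, eps=1e-2)

# subspace variant: roughly 4 * subspace_ndim^2 + 1 = 37 evaluations per step
# for subspace_ndim = 3, regardless of n
sub = RandomSubspaceNewtonFDM(params, subspace_ndim=3, lr=1, eps=1e-2)

def closure():
    # loss value only; gradient and hessian are approximated by finite differences
    return (params[0] ** 2).sum()

full.step(closure)
sub.step(closure)
```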