torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/utils/optimizer.py
ADDED
@@ -0,0 +1,284 @@
+ from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
+ from typing import Any, Literal, TypeVar, overload
+
+ import torch
+
+ from .tensorlist import TensorList
+ from .numberlist import NumberList
+ from .torch_tools import tofloat, totensor
+
+ ListLike = TypeVar('ListLike', bound=MutableSequence)
+
+ ParamFilter = Literal["has_grad", "requires_grad", "all"] | Callable[[torch.Tensor], bool]
+ def _param_filter(param: torch.Tensor, mode: ParamFilter):
+     if callable(mode): return mode(param)
+     if mode == 'has_grad': return param.grad is not None
+     if mode == 'requires_grad': return param.requires_grad
+     if mode == 'all': return True
+     raise ValueError(f"Unknown mode {mode}")
+
+ def get_params(
+     param_groups: Iterable[Mapping[str, Any]],
+     mode: ParamFilter = 'requires_grad',
+     cls: type[ListLike] = TensorList,
+ ) -> ListLike:
+     return cls(p for g in param_groups for p in g['params'] if _param_filter(p, mode)) # type:ignore[reportCallIssue]
+
+
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: str, *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> ListLike: ...
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: list[str] | tuple[str,...], *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> list[ListLike]: ...
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: str, key2: str, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> list[ListLike]: ...
+
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]],
+                    key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+     # single key, return single cls
+     if isinstance(key, str) and key2 is None:
+         values = cls()
+         for group in param_groups:
+             num_params = len([p for p in group['params'] if _param_filter(p, mode)])
+             if num_params > 0:
+                 group_value = group[key]
+                 values.extend(group_value for _ in range(num_params))
+         return values
+
+     # multiple keys
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls() for _ in keys]
+     for group in param_groups:
+         num_params = len([p for p in group['params'] if _param_filter(p, mode)])
+         if num_params > 0:
+             for i,key in enumerate(keys):
+                 group_value = group[key]
+                 values[i].extend(group_value for _ in range(num_params))
+     return values
+
+ _InitLiterals = Literal['param', 'grad']
+ Init = _InitLiterals | Any | list[_InitLiterals | Any] | tuple[_InitLiterals | Any]
+
+ def _make_initial_state_value(param: torch.Tensor, init: Init, i: int | None):
+     if callable(init): return init(param)
+     if isinstance(init, torch.Tensor): return init.detach().clone()
+
+     if isinstance(init, str):
+         if init in ('param','params'): return param.detach().clone()
+         if init in ('grad', 'grads'):
+             if param.grad is None: raise RuntimeError('init is set to "grad", but param.grad is None')
+             return param.grad.detach().clone()
+
+     if isinstance(init, (list,tuple)):
+         if i is None: raise RuntimeError(f'init is per-parameter ({type(init)}) but parameter index i is None')
+         return _make_initial_state_value(param, init[i], None)
+
+     return init
+
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str, *,
+                    must_exist: bool = False, init: Init = torch.zeros_like,
+                    cls: type[ListLike] = list) -> ListLike: ...
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: list[str] | tuple[str,...], *,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> list[ListLike]: ...
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str, key2: str, *keys: str,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> list[ListLike]: ...
+
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+     # single key, return single cls
+     if isinstance(key, str) and key2 is None:
+         values = cls()
+         for i, param in enumerate(params):
+             s = state[param]
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 s[key] = _make_initial_state_value(param, init, i)
+             values.append(s[key])
+         return values
+
+     # multiple keys
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls() for _ in keys]
+     for i, param in enumerate(params):
+         s = state[param]
+         for k_i, key in enumerate(keys):
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                 s[key] = _make_initial_state_value(param, k_init, i)
+             values[k_i].append(s[key])
+
+     return values
+
+
+
+ def loss_at_params(closure, params: Iterable[torch.Tensor],
+                    new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
+     params = TensorList(params)
+
+     old_params = params.clone() if restore else None
+
+     if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
+         # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
+         # when restoring params will be set back to old_params so its fine
+         if restore: params.set_(new_params)
+         else: params.copy_(new_params) # type:ignore
+
+     else:
+         new_params = totensor(new_params)
+         params.from_vec_(new_params)
+
+     if backward: loss = closure()
+     else: loss = closure(False)
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return tofloat(loss)
+
+ def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
+     params = TensorList(params)
+     old_params = params.clone() if restore else None
+     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
+     grad = params.ensure_grad_().grad
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return loss, grad
+
+ def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
+     return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
+
+ def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
+     params = TensorList(params)
+     old_params = params.clone() if restore else None
+     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
+     grad = params.ensure_grad_().grad.to_vec()
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return loss, grad
+
+ def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
+     return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
+
+
+
+ class Optimizer(torch.optim.Optimizer):
+     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
+
+     Args:
+         params (iterable): an iterable of :class:`torch.Tensor` s or
+             :class:`dict` s. Specifies what Tensors should be optimized.
+         defaults (dict | None): a dict containing default values of optimization
+             options (used when a parameter group doesn't specify them).
+     """
+     def __init__(self, params, defaults: dict[str, Any] | None = None, **_defaults):
+         if defaults is None: defaults = {}
+         defaults.update(_defaults)
+
+         super().__init__(params, defaults)
+         self.global_state = self.state[self.param_groups[0]['params'][0]]
+         """state of 1st parameter, can be used as global state which is how L-BFGS uses it in pytorch, and there is some kind of good reason to do it like that"""
+
+     def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
+         return get_params(self.param_groups, mode, cls)
+
+     @overload
+     def group_vals(self, key: str, *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike: ...
+     @overload
+     def group_vals(self, key: list[str] | tuple[str,...], *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
+     @overload
+     def group_vals(self, key: str, key2: str, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
+
+     def group_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike | list[ListLike]:
+         return get_group_vals(self.param_groups, key, key2, *keys, mode = mode, cls = cls) # pyright:ignore[reportArgumentType]
+
+
+     @overload
+     def state_vals(self, key: str, *,
+                    init: Init = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> ListLike: ...
+     @overload
+     def state_vals(self, key: list[str] | tuple[str,...], *,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
+     @overload
+     def state_vals(self, key: str, key2: str, *keys: str,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
+
+     def state_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> ListLike | list[ListLike]:
+
+         if isinstance(mode, (list,tuple)): params = mode
+         else: params = self.get_params(mode)
+
+         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
+
+     def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
+         return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
+
+     def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
+         return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
+
+     def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
+         return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
+
+     def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
+         return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
+
+     def grad_vec_at_params(self, closure, params: Any, restore=False):
+         return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]
+
+
+ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
+     if set_to_none:
+         for p in params:
+             p.grad = None
+
+     else:
+         grads = [p.grad for p in params if p.grad is not None]
+         for grad in grads:
+             # taken from torch.optim.Optimizer.zero_grad
+             if grad.grad_fn is not None:
+                 grad.detach_()
+             else:
+                 grad.requires_grad_(False)
+
+         torch._foreach_zero_(grads)
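
The file above introduces an `Optimizer` convenience base class whose `get_params`, `group_vals` and `state_vals` helpers return the trainable parameters, per-parameter hyperparameter values, and lazily initialized per-parameter state in one call each. Below is a minimal, hypothetical sketch of a subclass using that API; it assumes the file is importable as `torchzero.utils.optimizer`, and `ToyMomentumSGD` and its hyperparameter names are illustrative, not anything shipped in the package.

import torch
from torchzero.utils.optimizer import Optimizer  # assumed import path, based on the hunk above

class ToyMomentumSGD(Optimizer):
    def __init__(self, params, lr: float = 1e-2, momentum: float = 0.9):
        # extra keyword defaults are folded into `defaults` by Optimizer.__init__
        super().__init__(params, lr=lr, momentum=momentum)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        params = self.get_params()                     # TensorList of params with requires_grad=True
        lrs, moms = self.group_vals('lr', 'momentum')  # per-parameter hyperparameter lists
        bufs = self.state_vals('momentum_buffer')      # per-parameter state, zeros_like by default
        for p, buf, lr, m in zip(params, bufs, lrs, moms):
            if p.grad is None: continue
            buf.mul_(m).add_(p.grad)                   # update momentum buffer in place
            p.sub_(buf, alpha=lr)                      # descent step
        return loss

Constructed as `opt = ToyMomentumSGD(model.parameters(), lr=1e-2)`, it behaves like any other `torch.optim.Optimizer`.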
torchzero/utils/optuna_tools.py
ADDED
@@ -0,0 +1,40 @@
+ import optuna
+
+ from ..core import Chain, Module
+
+ from ..modules import (
+     EMA,
+     NAG,
+     Cautious,
+     ClipNorm,
+     ClipNormGrowth,
+     ClipValue,
+     ClipValueGrowth,
+     Debias,
+     Normalize,
+ )
+
+
+ def get_momentum(trial: optuna.Trial, prefix: str, conditional: bool=True) -> list[Module]:
+     cond = trial.suggest_categorical(f'{prefix}_use_momentum', [True,False]) if conditional else True
+     if cond:
+         beta = trial.suggest_float(f'{prefix}_beta', -1, 2)
+         dampening = trial.suggest_float(f'{prefix}_dampening', -1, 2)
+         lerp = trial.suggest_categorical(f'{prefix}_use_lerp', [True, False])
+         nag = trial.suggest_categorical(f'{prefix}_use_NAG', [True, False])
+         debiased = trial.suggest_categorical(f'{prefix}_debiased', [True, False])
+         if nag:
+             m = NAG(beta, dampening, lerp)
+             if debiased: m = Chain(m, Debias(beta1=beta))
+         else:
+             m = EMA(beta, dampening, debiased=debiased, lerp=lerp)
+         return [m]
+     return []
+
+ def get_clip_value(trial: optuna.Trial, prefix: str, conditional: bool=True) -> list[Module]:
+     cond = trial.suggest_categorical(f'{prefix}_use_clip_value', [True,False]) if conditional else True
+     if cond:
+         return [ClipValue(value = trial.suggest_float(f'{prefix}_clip_value', 0, 10))]
+     return []
+
+
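
The new `optuna_tools` helpers sample optional momentum and clipping modules from an Optuna trial and return them as a list of torchzero `Module`s, so an objective can toggle whole update-rule components on and off. A hypothetical sketch of how they might be called follows; the import path is assumed from the file location, and the objective only demonstrates the sampling, because turning the module list into an optimizer depends on torchzero's modular API, which is not part of this hunk.

import optuna
from torchzero.utils.optuna_tools import get_momentum, get_clip_value  # assumed import path

def objective(trial: optuna.Trial) -> float:
    # assemble a list of torchzero Modules from the trial's suggestions
    modules = get_clip_value(trial, 'pre') + get_momentum(trial, 'main')
    # ... insert `modules` into a torchzero optimizer chain, train, and return a validation metric;
    # here we only demonstrate the sampling, so return a dummy score
    return float(len(modules))

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=5)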
torchzero/utils/params.py
ADDED
@@ -0,0 +1,149 @@
+ from typing import Any
+ from collections.abc import Sequence, Iterable, Mapping
+ import warnings
+ import torch, numpy as np
+
+
+
+ Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Mapping[str, Any]]
+
+ def _validate_params_are_unique_(params: Sequence[torch.Tensor]):
+     # this is from pytorch add_param_group
+     if len(params) != len(set(params)):
+         warnings.warn(
+             "optimizer contains a parameter group with duplicate parameters; "
+             "in future, this will cause an error; "
+             "see github.com/pytorch/pytorch/issues/40967 for more information",
+             stacklevel=3,
+         )
+
+ def _validate_param_is_differentiable_(tensor: torch.Tensor | Any):
+     """Checks that param is torch.Tensor and isn't a leaf parameter unless differentiable is True, otherwise this raises, this is taken from torch.optim.Optimizer."""
+     if not (tensor.is_leaf or tensor.retains_grad):
+         raise ValueError("can't optimize a non-leaf Tensor")
+
+ def _validate_at_least_one_param_requires_grad_(params: Iterable[torch.Tensor]):
+     params = list(params)
+     if not any(p.requires_grad for p in params):
+         warnings.warn(
+             "Parameter group contains no parameters which require gradients. "
+             "Note for gradient-free optimizers, they still only optimize parameters with requires_grad=True, "
+             "so if needed, use `with torch.no_grad():` context instead.", stacklevel=3)
+
+
+
+ def _copy_param_groups(param_groups: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """copies param_groups, doesn't copy the tensors."""
+     new_param_group = []
+
+     for g in param_groups:
+         assert isinstance(g, dict)
+         g_copy = g.copy()
+
+         for k in ('params', 'updates', 'grads'):
+             if k in g_copy:
+                 assert isinstance(g_copy[k], list)
+                 g_copy[k] = g_copy[k].copy()
+
+         new_param_group.append(g_copy)
+
+     return new_param_group
+
+ def _process_param_group_(param_group: dict[str, Any]) -> dict[str, Any]:
+     """makes sure `param_group["params"]` is a list of tensors, and sets `param_group["param_names"]` if params are named."""
+     if 'params' not in param_group: raise KeyError("Param group doesn't have a `params` key.")
+
+     if isinstance(param_group['params'], torch.Tensor): param_group['params'] = [param_group['params']]
+
+     tensors: list[torch.Tensor] = []
+     names: list[str] | None = []
+
+     for p in param_group['params']:
+         if isinstance(p, torch.Tensor):
+             tensors.append(p)
+
+         elif isinstance(p, tuple):
+             if len(p) != 2:
+                 raise ValueError(f'named_parameters must be a tuple of (name, tensor), got length {len(p)} tuple')
+             if (not isinstance(p[0], str)) or (not isinstance(p[1], torch.Tensor)):
+                 raise ValueError(f'named_parameters must be a tuple of (name, tensor), got {[type(a) for a in p]}')
+             names.append(p[0])
+             tensors.append(p[1])
+
+         else:
+             raise ValueError(f'Parameters must be tensors or tuples (name, tensor), got parameter of type {type(p)}')
+
+     if len(tensors) == 0: warnings.warn('got an empty parameter group')
+
+     param_group['params'] = tensors
+
+     if len(names) != 0:
+         if len(names) != len(tensors):
+             raise ValueError(f"Number of parameters {len(tensors)} doesn't match number of names {len(names)}")
+         param_group['param_names'] = names
+
+     return param_group
+
+ def _make_param_groups(params: Params, differentiable: bool) -> list[dict[str, Any]]:
+     params = list(params)
+
+     param_groups: list[dict[str, Any]] = [dict(p) for p in params if isinstance(p, Mapping)]
+     tensors = [p for p in params if isinstance(p, torch.Tensor)]
+     named_tensors = [p for p in params if isinstance(p, tuple)]
+
+     if len(tensors) != 0: param_groups.append({"params": tensors})
+     if len(named_tensors) != 0: param_groups.append({"params": named_tensors})
+
+     # process param_groups
+     for g in param_groups:
+         _process_param_group_(g)
+
+     # validate
+     all_params = [p for g in param_groups for p in g['params']]
+     _validate_params_are_unique_(all_params)
+     _validate_at_least_one_param_requires_grad_(all_params)
+     if not differentiable:
+         for p in all_params: _validate_param_is_differentiable_(p)
+
+     return param_groups
+
+ def _add_defaults_to_param_groups_(param_groups: list[dict[str, Any]], defaults: dict[str, Any]) -> list[dict[str, Any]]:
+     for group in param_groups:
+         for k, v in defaults.items():
+             if k not in group:
+                 group[k] = v
+     return param_groups
+
+ def _add_updates_grads_to_param_groups_(param_groups: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     for group in param_groups:
+         if 'updates' in group: raise ValueError('updates in group')
+         group['updates'] = [None for _ in group['params']]
+
+         if 'grads' in group: raise ValueError('grads in group')
+         group['grads'] = [None for _ in group['params']]
+
+     return param_groups
+
+
+ def _set_update_and_grad_(
+     param_groups: list[dict[str, Any]],
+     updates: list[torch.Tensor] | None,
+     grads: list[torch.Tensor] | None,
+ ) -> list[dict[str, Any]]:
+     if updates is None and grads is None: return param_groups
+
+     updates_iter = iter(updates) if updates is not None else None
+     grads_iter = iter(grads) if grads is not None else None
+
+     for group in param_groups:
+         group_params = group['params']
+         group_updates = group['updates']
+         group_grads = group['grads']
+
+         for i, param in enumerate(group_params):
+             if not param.requires_grad: continue
+             if updates_iter is not None: group_updates[i] = next(updates_iter)
+             if grads_iter is not None: group_grads[i] = next(grads_iter)
+
+     return param_groups
+
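
These private helpers appear to form a preprocessing pipeline for parameter groups: normalize the input into groups, fill in defaults, then attach per-parameter `updates`/`grads` slots. The sketch below strings them together in that assumed order; it imports underscore-prefixed names purely for illustration, since inside the package they are presumably called by the core optimizer machinery rather than by users.

import torch
from torchzero.utils.params import (  # assumed import path for the new file above
    _make_param_groups, _add_defaults_to_param_groups_, _add_updates_grads_to_param_groups_,
)

model = torch.nn.Linear(4, 2)

# accepts tensors, (name, tensor) pairs, or dicts; named params end up under "param_names"
groups = _make_param_groups(model.named_parameters(), differentiable=False)

# fill in per-group hyperparameters that a group does not override
groups = _add_defaults_to_param_groups_(groups, {'lr': 1e-3, 'weight_decay': 0.0})

# attach per-parameter 'updates'/'grads' slots used later by _set_update_and_grad_
groups = _add_updates_grads_to_param_groups_(groups)

print(groups[0]['param_names'], groups[0]['lr'])  # ['weight', 'bias'] 0.001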
torchzero/utils/python_tools.py
CHANGED
@@ -1,25 +1,40 @@
- import functools
- import operator
- from typing import Any, TypeVar
- from collections.abc import Iterable
-
-
-
- def _flatten_no_check(iterable: Iterable) -> list[Any]:
-     """Flatten an iterable of iterables, returns a flattened list. Note that if `iterable` is not Iterable, this will return `[iterable]`."""
-     if isinstance(iterable, Iterable):
-         return [a for i in iterable for a in _flatten_no_check(i)]
-     return [iterable]
-
- def flatten(iterable: Iterable) -> list[Any]:
-     """Flatten an iterable of iterables, returns a flattened list. If `iterable` is not iterable, raises a TypeError."""
-     if isinstance(iterable, Iterable): return [a for i in iterable for a in _flatten_no_check(i)]
-     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
-
- X = TypeVar("X")
- # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
- def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
-     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
-     return functools.reduce(operator.iconcat, x, [])
-
-
+ import functools
+ import operator
+ from typing import Any, TypeVar
+ from collections.abc import Iterable, Callable
+ from collections import UserDict
+
+
+ def _flatten_no_check(iterable: Iterable) -> list[Any]:
+     """Flatten an iterable of iterables, returns a flattened list. Note that if `iterable` is not Iterable, this will return `[iterable]`."""
+     if isinstance(iterable, Iterable) and not isinstance(iterable, str):
+         return [a for i in iterable for a in _flatten_no_check(i)]
+     return [iterable]
+
+ def flatten(iterable: Iterable) -> list[Any]:
+     """Flatten an iterable of iterables, returns a flattened list. If `iterable` is not iterable, raises a TypeError."""
+     if isinstance(iterable, Iterable): return [a for i in iterable for a in _flatten_no_check(i)]
+     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
+
+ X = TypeVar("X")
+ # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+ def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
+     return functools.reduce(operator.iconcat, x, [])
+
+ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+     """generic equals function that supports scalars and lists of numbers"""
+     if isinstance(x, (int,float)):
+         if isinstance(y, (int,float)): return x==y
+         return all(i==x for i in y)
+     if isinstance(y, (int,float)):
+         return all(i==y for i in x)
+     return all(i==j for i,j in zip(x,y))
+
+ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
+     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
+     Otherwise applies `fn` to this sequence and `other`.
+     Returns a new sequence with return values of the callable."""
+     if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
+     return self.__class__(fn(i, other, *args, **kwargs) for i in self)
+
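
The changed `python_tools` module now skips strings when flattening, and adds `generic_eq` and `zipmap`. Since `zipmap` takes `self` as its first argument, it looks intended to be attached to sequence classes such as `TensorList`; the `ZippableList` class below is a hypothetical stand-in used only to demonstrate the behavior, and the import path is assumed from the file location.

import operator
from torchzero.utils.python_tools import flatten, generic_eq, zipmap  # assumed import path

print(flatten([[1, 2], [3, [4, 'ab']]]))    # [1, 2, 3, 4, 'ab'] -- strings are no longer exploded
print(generic_eq(2, [2, 2, 2]))             # True: scalar compared elementwise against a list
print(generic_eq([1, 2], [1, 3]))           # False

class ZippableList(list):
    zipmap = zipmap                         # reuse the free function as a method

a = ZippableList([1, 2, 3])
print(a.zipmap(operator.add, [10, 20, 30]))  # [11, 22, 33] -- elementwise when `other` is a list
print(a.zipmap(operator.mul, 2))             # [2, 4, 6]    -- broadcast when `other` is a scalar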