torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/utils/optimizer.py
CHANGED

@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
 from typing import Any, Literal, TypeVar, overload
 
@@ -132,65 +133,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     return values
 
 
-
-def loss_at_params(closure, params: Iterable[torch.Tensor],
-                   new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-    params = TensorList(params)
-
-    old_params = params.clone() if restore else None
-
-    if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
-        # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
-        # when restoring params will be set back to old_params so its fine
-        if restore: params.set_(new_params)
-        else: params.copy_(new_params) # type:ignore
-
-    else:
-        new_params = totensor(new_params)
-        params.from_vec_(new_params)
-
-    if backward: loss = closure()
-    else: loss = closure(False)
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return tofloat(loss)
-
-def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad.to_vec()
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-
-
-class Optimizer(torch.optim.Optimizer):
+class Optimizer(torch.optim.Optimizer, ABC):
     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
 
     Args:
@@ -251,21 +194,10 @@ class Optimizer(torch.optim.Optimizer):
 
         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
 
-    def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-        return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
-
-    def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
-        return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
-        return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
-
-    def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
-        return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_vec_at_params(self, closure, params: Any, restore=False):
-        return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]
 
+    # shut up pylance
+    @abstractmethod
+    def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
 
 def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     if set_to_none:
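Since `Optimizer` now inherits from `ABC` and declares `step` abstract, every subclass has to implement `step(closure)`. A minimal sketch of what a conforming subclass could look like; the `SignSGD` class is hypothetical and assumes the constructor signature is inherited unchanged from `torch.optim.Optimizer` (the `closure(False)`-skips-backward convention is visible in the deleted `loss_at_params` above):

```python
import torch
from torchzero.utils.optimizer import Optimizer  # module path as in this diff

class SignSGD(Optimizer):
    """Hypothetical example subclass; `step(closure)` must now be overridden."""
    def __init__(self, params, lr: float = 1e-2):
        # assumes torch.optim.Optimizer's (params, defaults) constructor is kept
        super().__init__(params, dict(lr=lr))

    def step(self, closure):
        loss = closure()  # closure() backpropagates; closure(False) would skip backward
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.sub_(p.grad.sign(), alpha=group["lr"])
        return loss
```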
@@ -281,4 +213,53 @@ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     else:
         grad.requires_grad_(False)
 
-    torch._foreach_zero_(grads)
+    torch._foreach_zero_(grads)
+
+
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, *,
+                  must_exist: bool = False, init: Init = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: list[str] | tuple[str,...], *,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, key2: str, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+    # single key, return single cls
+    if isinstance(key, str) and key2 is None:
+        values = cls()
+        for i,s in enumerate(states):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                s[key] = _make_initial_state_value(tensors[i], init, i)
+            values.append(s[key])
+        return values
+
+    # multiple keys
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls() for _ in keys]
+    for i,s in enumerate(states):
+        for k_i, key in enumerate(keys):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                s[key] = _make_initial_state_value(tensors[i], k_init, i)
+            values[k_i].append(s[key])
+
+    return values
+
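The new `unpack_states` helper replaces the deleted `*_at_params` functions as the convenience layer of this file: it gathers (and lazily initialises) per-tensor state values. A usage sketch against the code above; the tensors and state keys are made up for illustration:

```python
import torch
from torchzero.utils.optimizer import unpack_states  # module path as in this diff

params = [torch.randn(3), torch.randn(4, 5)]
states = [{} for _ in params]  # one mutable state dict per tensor

# single key -> one list; missing entries are created with init (default torch.zeros_like)
exp_avg = unpack_states(states, params, "exp_avg")

# several keys -> list of lists; a per-key init can be passed as a sequence
exp_avg, exp_avg_sq = unpack_states(states, params, "exp_avg", "exp_avg_sq",
                                    init=[torch.zeros_like, torch.ones_like])

# must_exist=True raises KeyError instead of initialising a missing key
# unpack_states(states, params, "momentum_buffer", must_exist=True)
```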
torchzero/utils/python_tools.py
CHANGED

@@ -1,7 +1,7 @@
 import functools
 import operator
-from typing import Any, TypeVar
-from collections.abc import Iterable, Callable
+from typing import Any, TypeVar, overload
+from collections.abc import Iterable, Callable, Mapping, MutableSequence
 from collections import UserDict
 
 
@@ -17,8 +17,8 @@ def flatten(iterable: Iterable) -> list[Any]:
     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
 
 X = TypeVar("X")
-# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
-def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
+# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
+def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
     return functools.reduce(operator.iconcat, x, [])
 
@@ -38,3 +38,16 @@ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
     return self.__class__(fn(i, other, *args, **kwargs) for i in self)
 
+ListLike = TypeVar('ListLike', bound=MutableSequence)
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, *, cls:type[ListLike]=list) -> ListLike: ...
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str, *keys:str, cls:type[ListLike]=list) -> list[ListLike]: ...
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None = None, *keys:str, cls:type[ListLike]=list) -> ListLike | list[ListLike]:
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
+    if len(values) == 1: return values[0]
+    return values
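`unpack_dicts` is the simpler sibling of `unpack_states`: it pulls existing values out of a sequence of dicts with no lazy initialisation (a missing key raises `KeyError`). A usage sketch; the settings dicts are made up for illustration:

```python
from torchzero.utils.python_tools import unpack_dicts  # module path as in this diff

settings = [
    {"momentum": 0.9, "dampening": 0.0},
    {"momentum": 0.5, "dampening": 0.1},
]

momentum = unpack_dicts(settings, "momentum")                          # [0.9, 0.5]
momentum, dampening = unpack_dicts(settings, "momentum", "dampening")  # two lists
```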
{torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.8
+Version: 0.3.10
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -157,13 +157,14 @@ for epoch in range(100):
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
 * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+* `HigherOrderNewton`: Higher order Newton's method with trust region.
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
   * `LSR1`: Limited-memory SR1.
   * `OnlineLBFGS`: Online LBFGS.
-  * `BFGS`, `SR1`, `
-  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+  * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
 
 * **Line Search**:
   * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
@@ -312,20 +313,20 @@ not in the module itself. Also both per-parameter settings and state are stored
 
 ```python
 import torch
-from torchzero.core import Module, 
+from torchzero.core import Module, Var
 
 class HeavyBall(Module):
     def __init__(self, momentum: float = 0.9, dampening: float = 0):
         defaults = dict(momentum=momentum, dampening=dampening)
         super().__init__(defaults)
 
-    def step(self, 
-        # a module takes a 
-        # 
+    def step(self, var: Var):
+        # a module takes a Var object, modifies it or creates a new one, and returns it
+        # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
         # for now we are only interested in update, and we will apply the heavyball rule to it.
 
-        params = 
-        update = 
+        params = var.params
+        update = var.get_update() # list of tensors
 
         exp_avg_list = []
         for p, u in zip(params, update):
@@ -346,16 +347,15 @@ class HeavyBall(Module):
             # and it is part of self.state
             exp_avg_list.append(buf.clone())
 
-        # set new update to 
-
-        return 
+        # set new update to var
+        var.update = exp_avg_list
+        return var
 ```
 
 There are a some specialized base modules that make it much easier to implement some specific things.
 
 * `GradApproximator` for gradient approximations
 * `LineSearch` for line searches
-* `Preconditioner` for preconditioners
 * `Projection` for projections like GaLore or into fourier domain.
 * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
 * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
@@ -376,4 +376,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci
 
 They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
 
-Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
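The README's `HeavyBall` example is only partially visible in the hunks above: the loop body that maintains the momentum buffer falls between them. A hypothetical completion for readers following along; the buffer initialisation and the exact update rule here are assumptions, not code taken from the package (the per-parameter `self.settings[p]` / `self.state[p]` access pattern matches the deleted `preconditioner.py` below):

```python
import torch
from torchzero.core import Module, Var

class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9, dampening: float = 0):
        defaults = dict(momentum=momentum, dampening=dampening)
        super().__init__(defaults)

    def step(self, var: Var):
        params = var.params
        update = var.get_update()  # list of tensors

        exp_avg_list = []
        for p, u in zip(params, update):
            # per-parameter settings and state are stored on the module, keyed by parameter
            momentum = self.settings[p]["momentum"]
            dampening = self.settings[p]["dampening"]
            state = self.state[p]

            if "exp_avg" not in state:
                state["exp_avg"] = u.clone()                 # assumed initialisation
            buf = state["exp_avg"]
            buf.mul_(momentum).add_(u, alpha=1 - dampening)  # assumed heavy-ball rule

            exp_avg_list.append(buf.clone())

        var.update = exp_avg_list
        return var
```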
torchzero-0.3.10.dist-info/RECORD
ADDED

@@ -0,0 +1,139 @@
+docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
+tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
+tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
+tests/test_opts.py,sha256=VSko5fUuACo_y6iab_akke0gMhCUEEUJ9ahpBqWBoM4,41715
+tests/test_tensorlist.py,sha256=SwzLKLrs2ppMtm_7UrfTDTlD-ObZd7JQ_FNHbp059tc,72460
+tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
+tests/test_vars.py,sha256=MqCJXrbj-C75APm1heykzcEWewinihlSjekkYDx-TFk,6726
+torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
+torchzero/core/__init__.py,sha256=Zib_4is13LFAabp_7VU8QXZpQEEZGzsH94vgRI0HxAg,150
+torchzero/core/module.py,sha256=Yfzn48dDbxYZJLpWnLYFIbqBb4sB3GekSZ7QGIplYAg,27525
+torchzero/core/transform.py,sha256=yK1wYgp03THzRN9y_f9-5q2nonEZMa0CfDFAdOxnqEU,11778
+torchzero/modules/__init__.py,sha256=8C73_dFzfWUWhii1UF86FUy8x75RPiAVLAm4sLTikBg,359
+torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
+torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
+torchzero/modules/clipping/clipping.py,sha256=XKFKvzNgsvuYUmvHyulE6PkZv_aeLQjp0CgtFj0013s,12516
+torchzero/modules/clipping/ema_clipping.py,sha256=MGouZEN0BorliHAZhue0afhC3AhZJ6wrnwBRzDTHjX4,5978
+torchzero/modules/clipping/growth_clipping.py,sha256=50c1YOUPVL8eWzH6zJINnNP68oiZkDcq7rR6HnWjVFc,6674
+torchzero/modules/experimental/__init__.py,sha256=zxxNKPZHnkVnx1ZjKNX_nkV4Wc_EdODM6qJGn7Pgb3w,766
+torchzero/modules/experimental/absoap.py,sha256=-KwQXmI12hvHbMGPHM0APAxDQztlFhlSOG55KK6PvpI,9901
+torchzero/modules/experimental/adadam.py,sha256=o0KPLaF4J7L_Ty71RNgsysk6IEuC4DRE5nGQkGIP_dA,4078
+torchzero/modules/experimental/adamY.py,sha256=LZabWX_vccDaG6_UVZl9ALJ-3nCZu-NEygJQ_Bwzel8,4018
+torchzero/modules/experimental/adasoap.py,sha256=XtxEvBWYdcqfWnQqOFa_-SrOwd_nXHzLftiw-YXDACQ,7408
+torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
+torchzero/modules/experimental/diagonal_higher_order_newton.py,sha256=u4-a5qJ_97XiZUDlClE2cASkBsx_NTJNPk6BWWybiqE,7158
+torchzero/modules/experimental/eigendescent.py,sha256=0cM1p4rYbrpwBNXgBEMblVyX0xBWTzojSC1EsUnXH6k,4707
+torchzero/modules/experimental/etf.py,sha256=FsLOCmQf24PPoRf5wsRUjVqk32uW9uTzaf1ERjFxRK8,5744
+torchzero/modules/experimental/gradmin.py,sha256=UixSLdca4ekYHOipEivdXfBAV-uEL9TZm5nCFXVaNco,3684
+torchzero/modules/experimental/newton_solver.py,sha256=3dZ7FG-2vGxJKkFF9P2LCs-LI_epcvZbyNtJOtw47pg,3055
+torchzero/modules/experimental/newtonnewton.py,sha256=QCGnY_CFo0i_NUB7D-6ezeNpG6wLkTD5lHBiakFIqbM,3033
+torchzero/modules/experimental/reduce_outward_lr.py,sha256=VFjcTpmLwpfhUR8u_5rbzPgHVR6K3fvti7jVy1DnsYU,1300
+torchzero/modules/experimental/soapy.py,sha256=7qsh9Y9U9oeQDwuDSVqnz71AD0nUYY3q0XN2XoMFWaw,6721
+torchzero/modules/experimental/spectral.py,sha256=SN7tToIpmna0IZ1NgObvqEbO48NnVbwqRwKi8ROsb7s,7374
+torchzero/modules/experimental/structured_newton.py,sha256=CWfVJ2LPZUuz1bMnlgOM6tlYPd2etjgLDIcyAfAG_y8,3464
+torchzero/modules/experimental/subspace_preconditioners.py,sha256=9Tl1PCN9crFUvVn6343GHoI3kv6CVnUWP1dfhwUvAFU,5130
+torchzero/modules/experimental/tada.py,sha256=84YcLhG34CbWq84L-AUj-A4uxpzdIVayaARHRm2f9b8,1564
+torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
+torchzero/modules/grad_approximation/fdm.py,sha256=cUgy98Bz0Br4q6ViNxn6EVOZX2jE0nDXVZLUGhxpDcA,3589
+torchzero/modules/grad_approximation/forward_gradient.py,sha256=cNgx8kc8r0fWj8xdU2b85W3fenNDQZKuIsJLM3UzSig,3867
+torchzero/modules/grad_approximation/grad_approximator.py,sha256=TODFUwBgTmjfbnO6Sc833fnvYzYaqqYTEba_13s-qOI,2906
+torchzero/modules/grad_approximation/rfdm.py,sha256=VsRlf95JnG6HdlIsJANcfJjMk7c_B9a5-fH9dSTBA10,11328
+torchzero/modules/higher_order/__init__.py,sha256=W94CY8K1NFxs9TPi415UssKVKz5MV_bH9adax1uZsYM,50
+torchzero/modules/higher_order/higher_order_newton.py,sha256=BwiSlcGobam04SgWFcB1p_-TSuzu2rWgGVnmvP6Si9k,9751
+torchzero/modules/line_search/__init__.py,sha256=nkOUPLe88wE91ICEhprl2pJsvaKtbI3KzYOdT83AGsg,253
+torchzero/modules/line_search/backtracking.py,sha256=ZgeLAYqrw-6BeEGp8wmOgFoLtUKROF7w7LpAREe0xZU,7704
+torchzero/modules/line_search/line_search.py,sha256=CfOENZgAPSdyv1wvSbhw6gdpfbQnXGdOnLsq29wjvzU,7229
+torchzero/modules/line_search/scipy.py,sha256=SvDCZ1DPOLZcSeOFvf3tXAf1ty-9qRVfGFMWVF5q708,2293
+torchzero/modules/line_search/strong_wolfe.py,sha256=xOU4XFekh4TIepm9ztJTYpcGucEMPwAeb_cDK4Rp0ho,7620
+torchzero/modules/line_search/trust_region.py,sha256=xUZApOTW4uXFBk_Uq_YBktiXcoSAKdDc6O5vjTwquGw,3101
+torchzero/modules/lr/__init__.py,sha256=kh2k_tma-oTOALR6AlD5XHdTPSMgU4A04Oa0hAqrEpI,89
+torchzero/modules/lr/adaptive.py,sha256=6s06Gvu1UmoT89hrMkXvJWHkEOMNcy5mMiyxy3V9lQs,3904
+torchzero/modules/lr/lr.py,sha256=1gU2QzMA5PV2KkzOkxxrZZKGcz-Kbjyp7WNurOM36ys,2655
+torchzero/modules/momentum/__init__.py,sha256=pSD7vxu8PySrYOSHQMi3C9heYdcQr8y6WC_rwMybZm0,544
+torchzero/modules/momentum/averaging.py,sha256=NmRodxsSekEDGIuFGDYOvJL-WkdMN3YF-naBdtfjxx8,3247
+torchzero/modules/momentum/cautious.py,sha256=JuaFYfyf9S3nTcqeZz5ylXKepqi0eqglOAQ0uNG0eT8,7373
+torchzero/modules/momentum/ema.py,sha256=qJV__nIbcD9e8qvwbvsATnYkQrdnmMiA91ju52IqSxw,10699
+torchzero/modules/momentum/experimental.py,sha256=eYnP6NmBDegwX9XC_dYMJP3vquBpM1LyQc03v3vW6-8,6900
+torchzero/modules/momentum/matrix_momentum.py,sha256=LR12UugXM8ocwTB8zBYpt03oZeZU0cb0UoFR6qO34V8,6818
+torchzero/modules/momentum/momentum.py,sha256=4Pgk-3HM7Av_ILT6oXtvnM1CB1yit8AkFnYWLvnUAqs,2655
+torchzero/modules/ops/__init__.py,sha256=hxMZFSXX7xvitXkuBiYykVGX3p03Xprm_QA2CMg4eW8,1601
+torchzero/modules/ops/accumulate.py,sha256=yKNgw8ZsaVRPjuzPzLJOvALkjik0aWx30Eu91FefRoA,3741
+torchzero/modules/ops/binary.py,sha256=98jyjkJ8BPuSH-mb4g2BnFi6UzvRZRf__Pt-jnD3pNU,9690
+torchzero/modules/ops/debug.py,sha256=zueWyNVvaJmxRD8QG8m_ys9jc7zRfSr8GAuxqz5dDTI,851
+torchzero/modules/ops/misc.py,sha256=GmnKDjMXaTUjPcC5e7Jftk6k2NQ0Ivv4ceUApxMckIQ,15978
+torchzero/modules/ops/multi.py,sha256=T1aVaRr6bLWvjoj1cyxaDncROypT6rmmmji8mvBHczo,5713
+torchzero/modules/ops/reduce.py,sha256=reGvusJyCzM8VdHbWyJRYFePPBXfVP0jZeXIEKGIJGc,5668
+torchzero/modules/ops/split.py,sha256=eM4Qsz6pMNF22bk3NF2rtvyxSOt9U-EyYxMAyjvTrMQ,2265
+torchzero/modules/ops/switch.py,sha256=ddsxq4bsH86iWW6mMdcQw3c0mU1s2FA-PRZpVOia7PY,2506
+torchzero/modules/ops/unary.py,sha256=3ysDHXs6snsQNBj3c288BT8G6T30Nvo0QM3PcdfQ2ww,4888
+torchzero/modules/ops/utility.py,sha256=8XFjQO4ghCmGD2H-lYTgaBzik_9pB3Uxt7xCxQrv5Ig,3080
+torchzero/modules/optimizers/__init__.py,sha256=BbT2nhIt4p74t1cO8ziQgzqZHaLvyuleXQbccugd06M,554
+torchzero/modules/optimizers/adagrad.py,sha256=NHpWcnIRM2LyPnNtDVTdluX4n1qqqWs9IHpFD8uYeLo,5500
+torchzero/modules/optimizers/adam.py,sha256=u6ieXHn_lHZozmGiKhSA73pApI83eeTNIyOrxBTFL1o,4009
+torchzero/modules/optimizers/lion.py,sha256=4yy6d0SLpGXndu8NCuYhdsNshMEYhONu_FPYXdupA_s,1119
+torchzero/modules/optimizers/muon.py,sha256=exbp7wVpIryiOxmbf9RAfZ9a6XXuOWTUqdjn-i57Fq4,9628
+torchzero/modules/optimizers/orthograd.py,sha256=cN5g7OusfeUlh0jn0kjkvpcVjqR01eGoi9WK1sSPnug,2021
+torchzero/modules/optimizers/rmsprop.py,sha256=jM5ohfABYUljy2RrtG_bY9PMHNzSkROYjqFPxnlXE6o,4309
+torchzero/modules/optimizers/rprop.py,sha256=d0R8UR-f9Pb64VrsJegrCyteLYa5TAmgObjgirqLaBo,11030
+torchzero/modules/optimizers/shampoo.py,sha256=hmfgPghzmjmba3PH1vLzaz0lOvLiIX9rCKrT71YZb40,8420
+torchzero/modules/optimizers/soap.py,sha256=7adybqncrkt31rNveQwXp8eeZKWf0LDhC5wt7GbmDcM,11052
+torchzero/modules/optimizers/sophia_h.py,sha256=He9YrHeaQhiz4CJm-3H_d_M07MGTsP663v8wx4BlaZI,4273
+torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
+torchzero/modules/projections/dct.py,sha256=0tswjgta3mE5D5Yjw9mJWqPBDga0OIe3lKlwd1AXASc,2369
+torchzero/modules/projections/fft.py,sha256=wNDZP5-3b2-bND3qH2yvX3SqFaljbLkPTQ1gUnlH5fU,2955
+torchzero/modules/projections/galore.py,sha256=etaG2gxazxuDEu-e2r7lKIKMTPEGGS5Vi7HXccmD3kY,241
+torchzero/modules/projections/projection.py,sha256=QUV_Gi6QlPiWEmcc7rwucr2yuYwYFGvSRUAT4uucqMY,10049
+torchzero/modules/projections/structural.py,sha256=f8-72zViXJ6S2gxDagkrrul9IaOPsYXZmX8sFLYkxCc,5683
+torchzero/modules/quasi_newton/__init__.py,sha256=Yc-NV__cJCiYLr2BZG4VsYa3VVq4gCxBMcirQEXSNIo,630
+torchzero/modules/quasi_newton/cg.py,sha256=lvmwJNTR7AEcpDIvpcLnMrZrOLwNld8GFAC19CcTKoY,11661
+torchzero/modules/quasi_newton/lbfgs.py,sha256=BDiv3f7qN8-Nhs8LqtWwk7Wwv88NtXXYle5WwKeekm4,9198
+torchzero/modules/quasi_newton/lsr1.py,sha256=A0Pstikb6JrQbwM5RZjLw9WJEHiMRy3PsPF1_iLkrK4,6053
+torchzero/modules/quasi_newton/olbfgs.py,sha256=Tz2eubiN7OXGN1mbXT4VKPd9kynpXzcLas7mrvBax-k,8333
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=4hRII9GFE5MzNtXkHH_T1hEJ1T8T4-Q4A4MXlhf64mc,25142
+torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
+torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=oLbJ96sl-2XBwLbJrrTZiLJIhKhTPOD6-wny7hbSno4,10767
+torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
+torchzero/modules/second_order/newton.py,sha256=ZYIcLpifcOHL_KRC6YoNs-MJQKM39urXUQzReWnWeXE,6583
+torchzero/modules/second_order/newton_cg.py,sha256=YAEAD_8YU_H8Y-o6JI0Ywgk-kpAQOFBQm2Bjzaz9Bjs,2865
+torchzero/modules/second_order/nystrom.py,sha256=aM6dlDv7znGYNXZgKN6B6AhZ1Tpp01JMs83B1hcXE3w,6061
+torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
+torchzero/modules/smoothing/gaussian.py,sha256=KbCgRXGntdPbt4-ojalrHkniYgYXk2294b-2C4MIFi8,6109
+torchzero/modules/smoothing/laplacian.py,sha256=Vp2EnCQhyfGc3CbyOLc6_ZiVx_jvnOISf9vlHkIH4Jo,4998
+torchzero/modules/weight_decay/__init__.py,sha256=j2Vq3DDxLYIPJmXWgAJ6dL-rXzcDEZxxvhJqRT3H0-U,95
+torchzero/modules/weight_decay/weight_decay.py,sha256=UFL9W5w5nzTZGWvCwyGLe9UWBKN8FTClme1Klt7XZPw,3034
+torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
+torchzero/modules/wrappers/optim_wrapper.py,sha256=-wNI-fN8eaMSkvPIcPa34yxH0St5aLn7jaaLeh2DUsM,3569
+torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
+torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
+torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
+torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/optim/wrappers/directsearch.py,sha256=Y2-7Sy4mYRPXPh0FTlsY_XOk5pCGjZsnbrlWCPZNp6A,10141
+torchzero/optim/wrappers/fcmaes.py,sha256=TQvIktXV8ldy6smBX-S7ZcQEbSmSZyj567TuYShbvJg,3513
+torchzero/optim/wrappers/mads.py,sha256=lC7edtrFS37PgmX7z9-eoqw6prl0k5BDB4NVBVQXJWE,2945
+torchzero/optim/wrappers/nevergrad.py,sha256=qslMb-4_kfjU3Dd0UbbzE2SdLViil3Qjo2v0FtPE3Fg,4000
+torchzero/optim/wrappers/nlopt.py,sha256=AaVEKfjbrt5DFION44_-g-jQAoVi4lCvBBPU5UDGO9Q,8151
+torchzero/optim/wrappers/optuna.py,sha256=YN1I3rzsi20A9963pWNWd7W75FkxalVb5z5fCRQeWA0,2280
+torchzero/optim/wrappers/scipy.py,sha256=pR26v8v0a-o2u0sbsKXpZ9JUKqXMaaI8gGLI8xYx3-s,19239
+torchzero/utils/__init__.py,sha256=7beAjXvnmBQoy5hwYHY_PBUtrrbYb9Z7-KrYgfcFkPE,844
+torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
+torchzero/utils/derivatives.py,sha256=sAVd0Q1xmIPpo_AxRuoow66Hy_3goX_9o3lQK_1TyW0,16909
+torchzero/utils/numberlist.py,sha256=cbG0UsSb9WCRxVhw8sd7Yf0bDy_gSqtghiJtkUxIO6U,6139
+torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
+torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
+torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
+torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
+torchzero/utils/python_tools.py,sha256=T5W7MpR7pNXiWSVw7gj-UuE9Ch0p9LRWuUZfg9Vtb-I,2794
+torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
+torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
+torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
+torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
+torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
+torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
+torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
+torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
+torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
+torchzero-0.3.10.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
+torchzero-0.3.10.dist-info/METADATA,sha256=_J7AbrIa-nD6UWbuydCwxAnSpKcC9O1Vp_rM896ZkYQ,14081
+torchzero-0.3.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+torchzero-0.3.10.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
+torchzero-0.3.10.dist-info/RECORD,,
torchzero/core/preconditioner.py
DELETED

@@ -1,138 +0,0 @@
-from abc import ABC, abstractmethod
-from collections import ChainMap, defaultdict
-from collections.abc import Mapping, Sequence
-from typing import Any, overload, final
-
-import torch
-
-from .module import Module, Chainable, Vars
-from .transform import apply, Transform, Target
-from ..utils import TensorList, vec_to_tensors
-
-class Preconditioner(Transform):
-    """Abstract class for a preconditioner."""
-    def __init__(
-        self,
-        defaults: dict | None,
-        uses_grad: bool,
-        concat_params: bool = False,
-        update_freq: int = 1,
-        scale_first: bool = False,
-        inner: Chainable | None = None,
-        target: Target = "update",
-    ):
-        if defaults is None: defaults = {}
-        defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
-        super().__init__(defaults, uses_grad=uses_grad, target=target)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @abstractmethod
-    def update(self, tensors: list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
-        """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
-
-    @abstractmethod
-    def apply(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
-        """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
-
-
-    def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('__step', 0)
-        states = [self.state[p] for p in params]
-        settings = [self.settings[p] for p in params]
-        global_settings = settings[0]
-        update_freq = global_settings['__update_freq']
-
-        scale_first = global_settings['__scale_first']
-        scale_factor = 0
-        if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS was too unstable
-            # I switched to norm
-            tensors = TensorList(tensors)
-            scale_factor = tensors.abs().global_mean().clip(min=1)
-
-        # update preconditioner
-        if step % update_freq == 0:
-            self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-        # step with inner
-        if 'inner' in self.children:
-            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-
-        # apply preconditioner
-        tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-        # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
-            torch._foreach_div_(tensors, scale_factor)
-
-        self.global_state['__step'] = step + 1
-        return tensors
-
-    def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('__step', 0)
-        tensors_vec = torch.cat([t.ravel() for t in tensors])
-        params_vec = torch.cat([p.ravel() for p in params])
-        grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-        states = [self.state[params[0]]]
-        settings = [self.settings[params[0]]]
-        global_settings = settings[0]
-        update_freq = global_settings['__update_freq']
-
-        scale_first = global_settings['__scale_first']
-        scale_factor = 0
-        if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS was too unstable
-            scale_factor = tensors_vec.abs().mean().clip(min=1)
-
-        # update preconditioner
-        if step % update_freq == 0:
-            self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
-
-        # step with inner
-        if 'inner' in self.children:
-            tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-            tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
-
-        # apply preconditioner
-        tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
-
-        # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
-            tensors_vec /= scale_factor
-
-        tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-        self.global_state['__step'] = step + 1
-        return tensors
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        concat_params = self.settings[params[0]]['__concat_params']
-        if concat_params: return self._concat_transform(tensors, params, grads, vars)
-        return self._tensor_wise_transform(tensors, params, grads, vars)
-
-class TensorwisePreconditioner(Preconditioner, ABC):
-    @abstractmethod
-    def update_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
-        """update preconditioner with `tensor`"""
-
-    @abstractmethod
-    def apply_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
-        """apply preconditioner to `tensor`"""
-
-    @final
-    def update(self, tensors, params, grads, states, settings):
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            self.update_tensor(t, p, g, state, setting)
-
-    @final
-    def apply(self, tensors, params, grads, states, settings):
-        preconditioned = []
-        if grads is None: grads = [None]*len(tensors)
-        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-            preconditioned.append(self.apply_tensor(t, p, g, state, setting))
-        return preconditioned
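For reference, subclassing the removed `TensorwisePreconditioner` in 0.3.8 only required the two per-tensor hooks. A minimal sketch against the deleted code above; the `AbsDiagonalPreconditioner` example is hypothetical, and this API no longer exists in 0.3.10 (the README's base-module list drops `Preconditioner` accordingly):

```python
import torch
from torchzero.core.preconditioner import TensorwisePreconditioner  # 0.3.8 only

class AbsDiagonalPreconditioner(TensorwisePreconditioner):
    """Hypothetical example: divides each update by an accumulated |update| diagonal."""
    def __init__(self, update_freq: int = 1):
        super().__init__(defaults=None, uses_grad=False, update_freq=update_freq)

    def update_tensor(self, tensor, param, grad, state, settings):
        # internal state goes into this parameter's state dict
        if "diag" not in state: state["diag"] = torch.zeros_like(tensor)
        state["diag"].add_(tensor.abs())

    def apply_tensor(self, tensor, param, grad, state, settings):
        return tensor / state["diag"].clamp(min=1e-12)
```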