torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
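The largest structural change in 0.3.14 is the move of torchzero/modules/optimizers to torchzero/modules/adaptive, together with new conjugate_gradient, least_squares, restarts, termination, trust_region, variance_reduction and zeroth_order subpackages. A minimal, hedged sketch of what the rename means at the import level (subpackage paths only; whether individual classes are also re-exported from torchzero.modules is not visible in this diff):

```python
# 0.3.11: optimizer modules lived under torchzero.modules.optimizers
# from torchzero.modules import optimizers

# 0.3.14: the same files now live under torchzero.modules.adaptive,
# alongside the new subpackages listed above
from torchzero.modules import adaptive, trust_region, zeroth_order
```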
torchzero/utils/tensorlist.py CHANGED

@@ -11,17 +11,19 @@ in an optimizer when you have to create one from parameters on each step. The so
 it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
 """
 import builtins
-from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
 import math
 import operator
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from typing import Any, Literal, TypedDict, overload
-from typing_extensions import Self, TypeAlias, Unpack
 
 import torch
-from
-from .python_tools import zipmap, generic_ne
-from .numberlist import NumberList, as_numberlist, maybe_numberlist
+from typing_extensions import Self, TypeAlias, Unpack
 
+from .metrics import Metrics, evaluate_metric, calculate_metric_list
+from .numberlist import NumberList, as_numberlist, maybe_numberlist
+from .ops import where_
+from .python_tools import generic_ne, zipmap
 
 _Scalar = int | float | bool | complex
 _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
@@ -33,6 +35,7 @@ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
 _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
 
 Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
+
 class _NewTensorKwargs(TypedDict, total = False):
     memory_format: Any
     dtype: Any
@@ -325,9 +328,20 @@ class TensorList(list[torch.Tensor | Any]):
     def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
     def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
     def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
-
-
-        return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
+
+    def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
+        # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
+        if ord == 1: return self.global_sum()
+        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
+        if ord == torch.inf: return self.abs().global_max()
+        if ord == -torch.inf: return self.abs().global_min()
+        if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
+
+        return self.abs().pow_(ord).global_sum().pow(1/ord)
+
+    def global_metric(self, metric: Metrics) -> torch.Tensor:
+        return evaluate_metric(self, metric)
+
     def global_any(self): return builtins.any(self.any())
     def global_all(self): return builtins.all(self.all())
     def global_numel(self) -> int: return builtins.sum(self.numel())
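The new global_vector_norm computes the norm of the whole TensorList as if it were one flattened vector, and global_metric defers to the helpers in the new torchzero/utils/metrics.py. A minimal usage sketch (assuming TensorList is constructed directly from tensors, outside of any optimizer):

```python
import torch
from torchzero.utils.tensorlist import TensorList

tl = TensorList([torch.tensor([3.0]), torch.tensor([0.0, 4.0])])

# L2 norm over every element of every tensor: sqrt(3^2 + 0^2 + 4^2) = 5
print(tl.global_vector_norm(2))          # tensor(5.)

# infinity norm: largest absolute element across the whole list
print(tl.global_vector_norm(torch.inf))  # tensor(4.)
```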
@@ -358,31 +372,54 @@ class TensorList(list[torch.Tensor | Any]):
 
     def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
         return self.zipmap_args(torch.randint_like, low, high, **kwargs)
+
     def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         res = self.empty_like(**kwargs)
         res.uniform_(low, high, generator=generator)
         return res
+
     def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
         r = self.randn_like(generator=generator, **kwargs)
-        return (
+        return r.mul_(maybe_numberlist(radius) / r.global_vector_norm())
+
     def bernoulli(self, generator = None):
         return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
+
     def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be 0."""
         return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
+
     def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be -1."""
         return self.bernoulli_like(p, generator=generator, **kwargs).mul_(2).sub_(1)
 
-    def sample_like(self,
+    def sample_like(self, distribution: Distributions = 'normal', variance: "_Scalar | _ScalarSeq | Sequence | None" = None, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         """Sample around 0."""
-        if
+        if isinstance(variance, Sequence):
+            if all(v is None for v in variance): variance = None
+            else: variance = [v if v is not None else 1 for v in variance]
+
+        if distribution in ('normal', 'gaussian'):
+            ret = self.randn_like(generator=generator, **kwargs)
+            if variance is not None: ret *= variance
+            return ret
+
         if distribution == 'uniform':
-
-
-
-
-
+            b = 1
+            if variance is not None:
+                b = ((12 * maybe_numberlist(variance)) ** 0.5) / 2
+            return self.uniform_like(-b, b, generator=generator, **kwargs)
+
+        if distribution == 'sphere':
+            if variance is None: radius = 1
+            else: radius = maybe_numberlist(variance) * math.sqrt(self.global_numel())
+            return self.sphere_like(radius, generator=generator, **kwargs)
+
+        if distribution == 'rademacher':
+            ret = self.rademacher_like(generator=generator, **kwargs)
+            if variance is not None: ret *= variance
+            return ret
+
         raise ValueError(f'Unknow distribution {distribution}')
 
     def eq(self, other: _STOrSTSeq): return self.zipmap(torch.eq, other)
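sample_like now selects one of the *_like helpers above based on the distribution argument and optionally rescales by variance. A sketch of how it might be called (shapes are arbitrary; in-place arithmetic on the returned TensorList is assumed to behave like the element-wise ops used elsewhere in this file):

```python
import torch
from torchzero.utils.tensorlist import TensorList

params = TensorList([torch.zeros(10), torch.zeros(3, 3)])

gauss  = params.sample_like('normal', variance=0.5)   # scaled Gaussian noise
signs  = params.sample_like('rademacher')             # random +1/-1 entries
sphere = params.sample_like('sphere', variance=1.0)   # Gaussian sample rescaled to a sphere

print(gauss.global_numel(), signs.global_numel(), sphere.global_numel())  # 19 19 19
```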
@@ -504,6 +541,11 @@ class TensorList(list[torch.Tensor | Any]):
         torch._foreach_pow_(input, self)
         return self
 
+    def square(self): return self.__class__(torch._foreach_pow(self, 2))
+    def square_(self):
+        torch._foreach_pow_(self, 2)
+        return self
+
     def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
     def sqrt_(self):
         torch._foreach_sqrt_(self)
@@ -634,10 +676,12 @@ class TensorList(list[torch.Tensor | Any]):
         if dim is None: dim = ()
         return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
 
-    def norm(self, ord:
-        if isinstance(ord, str): return self.abs().mean()
+    def norm(self, ord: float, dtype=None):
         return self.__class__(torch._foreach_norm(self, ord, dtype))
 
+    def metric(self, metric: Metrics) -> "TensorList":
+        return calculate_metric_list(self, metric)
+
     def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
         if dim == 'global': return self._global_fn(keepdim, self.global_mean)
         return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
@@ -790,27 +834,27 @@ class TensorList(list[torch.Tensor | Any]):
         for t, o in zip(self, other): t.copysign_(o)
         return self
 
-    def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord:
+    def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
         if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
         if tensorwise:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.metric(ord)
+            norm_other = magnitude.metric(ord)
         else:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.global_metric(ord)
+            norm_other = magnitude.global_metric(ord)
 
         if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
 
         return self * (norm_other / norm_self.clip_(min=eps))
 
-    def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord:
+    def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
         if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
         if tensorwise:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.metric(ord)
+            norm_other = magnitude.metric(ord)
         else:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.global_metric(ord)
+            norm_other = magnitude.global_metric(ord)
 
         if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
 
@@ -905,14 +949,14 @@ class TensorList(list[torch.Tensor | Any]):
         if eps!=0: std.add_(eps)
         return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
 
-    def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """calculate multipler to clip self norm to min and max"""
         if tensorwise:
-            self_norm = self.
+            self_norm = self.metric(ord)
             self_norm.masked_fill_(self_norm == 0, 1)
 
         else:
-            self_norm = self.
+            self_norm = self.global_metric(ord)
             if self_norm == 0: return 1
 
         mul = 1
@@ -926,12 +970,12 @@ class TensorList(list[torch.Tensor | Any]):
 
         return mul
 
-    def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """clips norm of each tensor to (min, max) range"""
         if min is None and max is None: return self
         return self * self._clip_multiplier(min, max, tensorwise, ord)
 
-    def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """clips norm of each tensor to (min, max) range"""
         if min is None and max is None: return self
         return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
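graft, _clip_multiplier and clip_norm now route their norm computations through metric / global_metric, so ord is typed as Metrics rather than a plain float. A hedged sketch of the call sites, assuming the integer norm orders accepted by the old ord parameter remain valid Metrics values:

```python
import torch
from torchzero.utils.tensorlist import TensorList

direction = TensorList([torch.randn(4), torch.randn(2, 3)])
magnitude = TensorList([torch.randn(4), torch.randn(2, 3)])

# keep the direction of `direction` but take the global L2 norm of `magnitude`
grafted = direction.graft(magnitude, tensorwise=False, ord=2)

# clip each tensor's L2 norm into the [0.1, 1.0] range
clipped = direction.clip_norm(min=0.1, max=1.0, tensorwise=True, ord=2)
```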
@@ -990,6 +1034,15 @@ class TensorList(list[torch.Tensor | Any]):
     # """sets index in flattened view"""
     # return self.clone().flatset_(idx, value)
 
+    def flat_get(self, idx: int):
+        cur = 0
+        for tensor in self:
+            numel = tensor.numel()
+            if idx < cur + numel:
+                return tensor.view(-1)[cur-idx]
+            cur += numel
+        raise IndexError(idx)
+
     def flat_set_(self, idx: int, value: Any):
         """sets index in flattened view"""
         cur = 0
@@ -1065,10 +1118,19 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
     if isinstance(x, torch.Tensor): return x.numel()
     return x.global_numel()
 
+
+def generic_finfo(x: torch.Tensor | TensorList) -> torch.finfo:
+    if isinstance(x, torch.Tensor): return torch.finfo(x.dtype)
+    return torch.finfo(x[0].dtype)
+
 def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
     if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
     return torch.finfo(x[0].dtype).eps
 
+def generic_finfo_tiny(x: torch.Tensor | TensorList) -> float:
+    if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).tiny
+    return torch.finfo(x[0].dtype).tiny
+
 @overload
 def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
 @overload
@@ -1081,7 +1143,8 @@ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
     if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
     return x.global_vector_norm(ord)
 
-
+def generic_metric(x: torch.Tensor | TensorList, metric: Metrics) -> torch.Tensor:
+    return evaluate_metric(x, metric)
 
 @overload
 def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
@@ -1091,3 +1154,11 @@ def generic_randn_like(x: torch.Tensor | TensorList):
     if isinstance(x, torch.Tensor): return torch.randn_like(x)
     return x.randn_like()
 
+
+def generic_sum(x: torch.Tensor | TensorList) -> torch.Tensor:
+    if isinstance(x, torch.Tensor): return x.sum()
+    return x.global_sum()
+
+def generic_max(x: torch.Tensor | TensorList) -> torch.Tensor:
+    if isinstance(x, torch.Tensor): return x.max()
+    return x.global_max()
torchzero/utils/torch_tools.py CHANGED

@@ -7,10 +7,15 @@ import numpy as np
 import torch
 
 
-def totensor(x):
-    if
-
-
+def totensor(x, device=None, dtype=None):
+    if device is None and dtype is None:
+        if isinstance(x, torch.Tensor): return x
+        if isinstance(x, np.ndarray): return torch.from_numpy(x)
+        return torch.from_numpy(np.asarray(x))
+
+    if isinstance(x, torch.Tensor): return x.to(device=device, dtype=dtype)
+    if isinstance(x, np.ndarray): return torch.as_tensor(x, device=device, dtype=dtype)
+    return torch.as_tensor(np.asarray(x), device=device, dtype=dtype)
 
 def tonumpy(x):
     if isinstance(x, np.ndarray): return x
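totensor keeps its old zero-copy behaviour when no device/dtype is given and otherwise converts through torch.as_tensor; the generic_* helpers added to tensorlist.py accept either a single tensor or a TensorList. A small sketch, with import paths taken from the RECORD below:

```python
import numpy as np
import torch
from torchzero.utils.torch_tools import totensor
from torchzero.utils.tensorlist import TensorList, generic_sum, generic_vector_norm

a = totensor(np.arange(3))                           # no device/dtype: torch.from_numpy path
b = totensor([1.0, 2.0, 3.0], dtype=torch.float64)   # dtype given: torch.as_tensor path

print(generic_sum(torch.ones(4)))                                       # tensor(4.)
print(generic_vector_norm(TensorList([torch.ones(2), torch.ones(2)])))  # tensor(2.)
```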
torchzero-0.3.14.dist-info/METADATA ADDED

@@ -0,0 +1,14 @@
+Metadata-Version: 2.4
+Name: torchzero
+Version: 0.3.14
+Summary: Modular optimization library for PyTorch.
+Author-email: Ivan Nikishev <nkshv2@gmail.com>
+Project-URL: Homepage, https://github.com/inikishev/torchzero
+Project-URL: Repository, https://github.com/inikishev/torchzero
+Project-URL: Issues, https://github.com/inikishev/torchzero/isses
+Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+Requires-Dist: numpy
+Requires-Dist: typing_extensions
torchzero-0.3.14.dist-info/RECORD ADDED

@@ -0,0 +1,167 @@
+tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
+tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
+tests/test_opts.py,sha256=XsOTxiKJQJOUr47d1nzyQ1WFclBrhXYKQMdEVN66bWs,43687
+tests/test_tensorlist.py,sha256=pWXQE-vEq08EGJSKWgsTgo-7QjjkavOJ5BlWUm241qI,72434
+tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
+tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
+torchzero/__init__.py,sha256=aIH-cCTXnDr90cKUPhM8bv-uE69Hzjlf0jlYspYf0ZM,120
+torchzero/core/__init__.py,sha256=aYyQt-CHzWT6hGUt5KVjRZZr2lsX5I1XvbWpzaAv3VE,151
+torchzero/core/module.py,sha256=chn9NZWdgYekzVPDppvp2REMp4WugiMhOjQU3tys6ZU,40651
+torchzero/core/reformulation.py,sha256=jppgzXBtqdsc7ot6_Gr38vJbbrhG1Gs4vC32y7iB4BA,2387
+torchzero/core/transform.py,sha256=xRDpsZj0H1QcFdO-t2mNMNOYoqqnRHiI3K1YluWwCVk,17097
+torchzero/modules/__init__.py,sha256=3lGta9P0N3cWdVcruCBJ7uqu4DfLTPCKI_mlOZT6Z_o,615
+torchzero/modules/functional.py,sha256=E_d6hLT2_xdE-3AhQ4AthDYK5uZULbF10iHI09Z3_yk,7921
+torchzero/modules/adaptive/__init__.py,sha256=5L2dlEJV6HKBnYhgd7wo2yGi0WPd9qmpw9XS5wOQOq8,944
+torchzero/modules/adaptive/adagrad.py,sha256=0qXC5F4PuOsgLjRXQUWBoiq0AUixsvOP1uDbEeRIcNs,12531
+torchzero/modules/adaptive/adahessian.py,sha256=rWxgDiBMd6MK64mRjZwiudDF07is5AFj_MCEvBD7h8U,8670
+torchzero/modules/adaptive/adam.py,sha256=4lWSe__tdyRv0rfkUda1qa_NH36DIW3sd8td98bK6XI,3829
+torchzero/modules/adaptive/adan.py,sha256=Dt_gibyrGtWDIUCaSF6RFIxu2xwiF9fCfrpkoD-CaUM,2825
+torchzero/modules/adaptive/adaptive_heavyball.py,sha256=xQQw1Vx-NgQd_ouK14J1p5ijd5lEn3sAN0hVJAL0j8U,2024
+torchzero/modules/adaptive/aegd.py,sha256=_4ASgDX8__DPnnBE_RncnMqM4rItM7Eji4EZzSGGq5I,1876
+torchzero/modules/adaptive/esgd.py,sha256=DtrN3hZhGK4LgMxmcjQCBtEO-5hLZrAdnvps6f8WQ2A,6416
+torchzero/modules/adaptive/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
+torchzero/modules/adaptive/lmadagrad.py,sha256=rMs7vrgiwOJgWo-OZXkGu32X561edEhhaxwgbY2NTnk,7176
+torchzero/modules/adaptive/mars.py,sha256=iOkyY3r52btp7Cry7WN0AB4arpn4N9b_Hg6S55XC6Q8,2255
+torchzero/modules/adaptive/matrix_momentum.py,sha256=ZMSdGNSHgyUgtJwjKzK7PZDwMrPw5wZJVu0K4Xa-SpI,6693
+torchzero/modules/adaptive/msam.py,sha256=locqM2jiC3AbGCDCo6T40GF3iaVV2svrkzoi0hD2cJI,6663
+torchzero/modules/adaptive/muon.py,sha256=5Asgj03s6JXXrO-p5Qgn3D8bVwbDEqCq7hxNy4joQDE,10335
+torchzero/modules/adaptive/natural_gradient.py,sha256=5qRehh-iAeZ4hjfOR-gfsObbMsJZnipzCkG-yptkrH0,6349
+torchzero/modules/adaptive/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
+torchzero/modules/adaptive/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
+torchzero/modules/adaptive/rprop.py,sha256=VDnLPKxw8ECihyUeNVE8cyDll_Ut3k3_NqoLgpgxxLA,11818
+torchzero/modules/adaptive/sam.py,sha256=LnOPNZnIUsis0402RHnA-fTPkNM8baUR9HR50pF_BtM,5696
+torchzero/modules/adaptive/shampoo.py,sha256=r7V4I5_Ve1YVOS3HhO2k5cZvJT1lPHTVApV3iVJVceA,9711
+torchzero/modules/adaptive/soap.py,sha256=roQLBthNNNmTYgeJPi_LxZY-r4m6REeUo0_DZknYU50,10662
+torchzero/modules/adaptive/sophia_h.py,sha256=lSK8uVdOxBAhU2jE6fyIx1YgqEQyZG-Fv9o2TniAZzk,7179
+torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
+torchzero/modules/clipping/clipping.py,sha256=t98M3QKZKqXJ3_tzXXIiG4EOYMaHqLYrMZ-6zmRuy-k,14331
+torchzero/modules/clipping/ema_clipping.py,sha256=Ki0LPNUwPoE825A5rSE7SxGQMiI3nO3iwnjKQ486iaI,6611
+torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
+torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
+torchzero/modules/conjugate_gradient/cg.py,sha256=L49p_wBeyQ3pZmaPM1vLx7uWmZBKtIhMf7Uv2ggUbI4,14534
+torchzero/modules/experimental/__init__.py,sha256=blI-OhpQAC6-Ho1uxUq-t7Mm9CAMnNMXkBDmXul8tbc,729
+torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
+torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
+torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
+torchzero/modules/experimental/gradmin.py,sha256=hKTOG7tk6FnG8t-7OmTAhGTGSDdONzP1JvCRPRqaKt0,3740
+torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
+torchzero/modules/experimental/momentum.py,sha256=VqZc14EGVO_KUABPLRIBlvHdgg-64o-4heMQH0vW5vY,5233
+torchzero/modules/experimental/newton_solver.py,sha256=0HnDBlrBLvUgS4hWmkJqyw0M7UPFp3kU3SFq2xZVYhQ,5454
+torchzero/modules/experimental/newtonnewton.py,sha256=a0XXvlVe37z2MMcQ4TeGbbWX9OuYp_5b-21jq3o1z3E,3823
+torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
+torchzero/modules/experimental/scipy_newton_cg.py,sha256=8nKBaHHmqdU9F1kVPn2QVFUTx2_I8Jsfqxix1v-qoL0,4073
+torchzero/modules/experimental/spsa1.py,sha256=tytlOaKSyhvoElErI-MzhbL0fIm3K4d_M1Kpbpb9jbw,3622
+torchzero/modules/experimental/structural_projections.py,sha256=rxJFG5F23dOiK_8KqKyvoSMLWqAOXtVGHSwfRqH22Wg,4185
+torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
+torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
+torchzero/modules/grad_approximation/forward_gradient.py,sha256=pCOsvt4ZZtzsIlAeXLmpS8vNXYewA_Gh7uyz1_1yROs,4011
+torchzero/modules/grad_approximation/grad_approximator.py,sha256=x8vlweBrfJ6SmhMHvI_C8UZGzlS3AnmlulvqnSzm6iY,4437
+torchzero/modules/grad_approximation/rfdm.py,sha256=zgzZYJDAq6HsOMeHePUWlnnJWYRpqey77nHMaHh8140,19217
+torchzero/modules/higher_order/__init__.py,sha256=iaoIrmR9DJE9QHt9PeZNCWqIYDe-86h1IjkaumR4qF0,51
+torchzero/modules/higher_order/higher_order_newton.py,sha256=2r1wuhdi57pbo8akQE88O8R-Y79BtiwD1WQIShh1rjQ,12967
+torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
+torchzero/modules/least_squares/gn.py,sha256=23AB6AWAl5IuBj4Vd3boQ6ndk0pO3ovaF9EiY1a1XWs,5094
+torchzero/modules/line_search/__init__.py,sha256=mFWgkcgfMkL2NKj3CLbuwee3e8WHBOaXs-wtx3oTW58,216
+torchzero/modules/line_search/_polyinterp.py,sha256=qIhcLjOlpB6NHU0oiUGMncwQxWNfy8757orsbzjkp6s,10882
+torchzero/modules/line_search/adaptive.py,sha256=8Ip5F5PpsDLgg6TwB_E7zIZheycd78coRg4u7cpO3Cg,3795
+torchzero/modules/line_search/backtracking.py,sha256=Mhx8_UT_Mr1gASYHUorBJ38E4YlcM9LpW9YrJHYfLXU,9049
+torchzero/modules/line_search/line_search.py,sha256=lmtjr9Zpz9RYJXoYaJnpXkBSIdcN6DdwGKKXTCmcJNU,13294
+torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
+torchzero/modules/line_search/strong_wolfe.py,sha256=21IMtenhBlrw_edtwizzLZ01PF4-rU4M2oGOCUA2udc,14936
+torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
+torchzero/modules/misc/debug.py,sha256=6pFAGYANjCPGIZH_4ghpUYYTEsT5jr7PMB9KLuPP4p8,1532
+torchzero/modules/misc/escape.py,sha256=qfEdKLD5rejqrmvyHrI5BRQq8js9UF2-Axs_C0KFyWA,1866
+torchzero/modules/misc/gradient_accumulation.py,sha256=mBWa5CBCZwp4TrtOyjWI3VnHag4gum4WBM2WFhvHqW4,4891
+torchzero/modules/misc/homotopy.py,sha256=hihLETE4dNZ27zatqPR_qT3kGX-AXbC7oBWRDbFQo58,1939
+torchzero/modules/misc/misc.py,sha256=feI-IQlxhIoAbsSRTjE4SbGez1c2Uu9-WA_nkK7iiqQ,15411
+torchzero/modules/misc/multistep.py,sha256=RtDFIeTHu4RcERvlKEP4_10-lpRZOgbnBeSah92dQ7A,6323
+torchzero/modules/misc/regularization.py,sha256=SkQ0_Ybtv9IEGI9QGdvNZaja5bAyc1x-j_1gvYIVepI,6105
+torchzero/modules/misc/split.py,sha256=JcXVB4xk3h55YT2OAdepVsRoE1PD7bqX6NmJ2IxBgAI,4013
+torchzero/modules/misc/switch.py,sha256=p758heAnv-PkoslpafL35Yp7mlvPmDVSe1mWiuuD8Mk,3711
+torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
+torchzero/modules/momentum/averaging.py,sha256=vDW8tgGsEuBXF_BTUYHB_j--TIVam9j0nZdp_x8TkxY,3229
+torchzero/modules/momentum/cautious.py,sha256=x506a3lUETRpxPWqXLFJVFBH1gmLqIfqL5J-hFdEvOo,8051
+torchzero/modules/momentum/momentum.py,sha256=q3n0BvQURuSBzA9vn1ZrH-n7Nsr0AS-38VJuwraQPY0,4495
+torchzero/modules/ops/__init__.py,sha256=9UHaXs9aaKc0ewAhicTlDmj42bSC_vddMOD0eYuUj_8,1226
+torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
+torchzero/modules/ops/binary.py,sha256=2hV2oruaq5Asu4Ts8X8yiZQ-07fU0RGpRy3-vifXqXY,12151
+torchzero/modules/ops/higher_level.py,sha256=E76zgSHlhVpHLrXhnVwelIQFm1IKn0IFcVq7DOJw0es,9037
+torchzero/modules/ops/multi.py,sha256=YC3rBTmPRwF5aEPDNsyTK4J_JEAbmE7oBmF7W-VOV3A,8588
+torchzero/modules/ops/reduce.py,sha256=kALG7X8q02sWpo1skpXjS0r875gwq6mrhLZbFfYaZoA,6324
+torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
+torchzero/modules/ops/utility.py,sha256=_k9S59i6IYOzzfIQlToQ9mlDseTTAS_49wujUxMGXZo,4105
+torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
+torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
+torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
+torchzero/modules/projections/projection.py,sha256=itkkb2UmMqbdtWKjUUg6gbFJfCEIZAskC0HCvom-6sc,14084
+torchzero/modules/quasi_newton/__init__.py,sha256=HxXENs3O6nFRfCvUJhWPK9f8_A6iMwB6UF1Zold12UQ,515
+torchzero/modules/quasi_newton/damping.py,sha256=K1DVqqKiAs6-F3JQh5jlKNb79oJdObqnKWwHHRl6boQ,2813
+torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
+torchzero/modules/quasi_newton/lbfgs.py,sha256=fzCjV5YsLo_uJTVG3vosPHsvDc97mLKueK6fxOHLb8I,11195
+torchzero/modules/quasi_newton/lsr1.py,sha256=D3_yV5xtgklMlU4fAL1-sH82-1tNl3K2F12ZBZyLQGM,8512
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=-xUGPld8Y0MHwN6qsmDihLbUbulU0T1z8jf2mZhNpcE,44529
+torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
+torchzero/modules/restarts/restars.py,sha256=ZN8kzufkOKredI54HJ9cSANAYY2yN4gAggqAVtEBCyA,9084
+torchzero/modules/second_order/__init__.py,sha256=lTGccDNVwPuMevMeKi5O0a9cl24Rn9tk7VkC6jvlGYc,233
+torchzero/modules/second_order/multipoint.py,sha256=Ilzo0Ddd3iApegceu7cHSMGim9ZH5QS4-2uBtrKXC6k,8581
+torchzero/modules/second_order/newton.py,sha256=PAPbJzssx0Ji328BFOEzeJZPd3IubJTPHs6ZhqS_nW8,15663
+torchzero/modules/second_order/newton_cg.py,sha256=6KmYr-U5gmuTGoVZA8Gt2-rfLwnjNakjKkBKEuBffTk,16738
+torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
+torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
+torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
+torchzero/modules/smoothing/sampling.py,sha256=zI5bATytQmCqm_UgAQbfA9tNRgrZaKLfUb0B-kzKRHU,12867
+torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
+torchzero/modules/step_size/adaptive.py,sha256=HvffW3m1NnpMTpps0QjJTvbblSODxxWMBBFTbNwp0vM,14482
+torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
+torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
+torchzero/modules/termination/termination.py,sha256=BXU3R04caBc8rFJ4v_yJjgGi1X4iA11eYwlbiJfxexI,6637
+torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
+torchzero/modules/trust_region/cubic_regularization.py,sha256=gbKOR5zBo3t9i-sW23DCtTQwZrBubuFy_VuafrLaeUw,6718
+torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
+torchzero/modules/trust_region/levenberg_marquardt.py,sha256=s4XBXK8LwHuwyanOJFtgmOLkxEiBMhBVzch0J2_dIFk,5055
+torchzero/modules/trust_region/trust_cg.py,sha256=F6G0hTXv6Ry0swO_4fx3ecxYWxMr72nrmwJPbFDpqH4,4459
+torchzero/modules/trust_region/trust_region.py,sha256=eimCFViJSzoubrRmDluCon6mfcyT7PQA0yRPu4FlO2Q,12872
+torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
+torchzero/modules/variance_reduction/svrg.py,sha256=9pBjPY4EMkGyfj68gXqPi1GJIolUVl5zyNtlZInCKKo,8635
+torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
+torchzero/modules/weight_decay/weight_decay.py,sha256=Y7kE_j0GRF8ceJ9SS6qykQ8a23X2OTDCjJ9VklOQSEw,5415
+torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
+torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
+torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
+torchzero/modules/zeroth_order/cd.py,sha256=SwjwoAqX86-JnVHIwKAE7g_tqm0EvEUUNuLM4T5mKXE,4876
+torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
+torchzero/optim/root.py,sha256=gGtAJ9qBoSNV58EKzUGZ8J3lyKGUF8BEw34Zfprppdo,2273
+torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
+torchzero/optim/utility/split.py,sha256=kraPCLAewX2uLbD_9R2dIrcF-kpUuT9IcxPeVrAARvA,1672
+torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/optim/wrappers/directsearch.py,sha256=rimJIB2RrVzLpRPQKhzkrMQ4bTAEU3NEOT4pJQNIAHE,11309
+torchzero/optim/wrappers/fcmaes.py,sha256=jKmmBKEwguYiJdvTRmAp5JSilxcUhtpRoKlzmp-lyWE,4251
+torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
+torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
+torchzero/optim/wrappers/nlopt.py,sha256=RuLKretljjAjTZ9tXY3FyEDuB7mAboeGOQBupWfzPc4,8105
+torchzero/optim/wrappers/optuna.py,sha256=pIXkC5NVmEnUQ4jsGaz6Gv9uYOZM9rnxME4UGkeolsE,2393
+torchzero/optim/wrappers/scipy.py,sha256=A4yeQRdB9f65UrJ2g80NfqqMc6zUyr9js40TUESCHPg,21535
+torchzero/utils/__init__.py,sha256=7S4VRTkfS-0uI8HOR0EFIjiEcKrmYK7LEhTocIgki6c,1112
+torchzero/utils/compile.py,sha256=Dozox91tcShUJ3L320TTbJrcuA-l4WVegLAQujRqy94,5132
+torchzero/utils/derivatives.py,sha256=zJ0xyedvlIwgAYMa1F5BBfyrkvgjXy7v7evvl6QAlT0,17195
+torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
+torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
+torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
+torchzero/utils/optimizer.py,sha256=pzGIddJtsR_oH3mb3GOsWHi02Sb62ms4NCaWUOk0Clo,12470
+torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
+torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
+torchzero/utils/python_tools.py,sha256=QxLZ2PKgp4R2zI_C1qNOX_u4eIcuT0wNBpBM5YEIYuU,3428
+torchzero/utils/tensorlist.py,sha256=nIWBME3fUQPsr4buvtV3LaJgSXPEG_Xb58KAzfjwK-k,56064
+torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
+torchzero/utils/linalg/__init__.py,sha256=cNoTJOPeqbNn9l7_HAAen2rlehGS3DyY5SveInG3Stc,328
+torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
+torchzero/utils/linalg/linear_operator.py,sha256=uJUxvOVHpG3U3GNx61JGa_uM8GqzsNZmA_z7P0RwZ5E,12747
+torchzero/utils/linalg/matrix_funcs.py,sha256=BKQK_oIG35R6yGxU80eBG0VkyY2EgxywqbhvU7JhWm4,3109
+torchzero/utils/linalg/orthogonalize.py,sha256=BpuDiAPrsJMUpTNBMCntBNA8-O2nozLxY5ZbCfRlEFY,444
+torchzero/utils/linalg/qr.py,sha256=5tbPEV9I6X69r5ACWF9XeqjZTUtUql2145uoGjlJNDs,2517
+torchzero/utils/linalg/solve.py,sha256=R5lPTzHn2sgvRy4MRp-Ngl0sypSGLRLHJjf1oKKAJD0,14395
+torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
+torchzero-0.3.14.dist-info/METADATA,sha256=Pd9XeJLSPuNQtb-dVJvnIiDOX5Mdr9b9ihpfw6rxBpQ,565
+torchzero-0.3.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+torchzero-0.3.14.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
+torchzero-0.3.14.dist-info/RECORD,,
docs/source/conf.py DELETED

@@ -1,59 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# For the full list of built-in configuration values, see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Project information -----------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-import sys, os
-#sys.path.insert(0, os.path.abspath('.../src'))
-
-project = 'torchzero'
-copyright = '2025, Ivan Nikishev'
-author = 'Ivan Nikishev'
-
-# -- General configuration ---------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-
-# https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.autosectionlabel',
-    'sphinx.ext.githubpages',
-    'sphinx.ext.napoleon',
-    'autoapi.extension',
-    "myst_nb",
-
-    # 'sphinx_rtd_theme',
-]
-autosummary_generate = True
-autoapi_dirs = ['../../torchzero']
-autoapi_type = "python"
-# autoapi_ignore = ["*/tensorlist.py"]
-
-# https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
-autoapi_options = [
-    "members",
-    "undoc-members",
-    "show-inheritance",
-    "show-module-summary",
-    "imported-members",
-]
-
-
-templates_path = ['_templates']
-exclude_patterns = []
-
-# -- Options for HTML output -------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-
-#html_theme = 'alabaster'
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
-
-
-# OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
-source_suffix = ['.rst', '.md']
-master_doc = 'index'
docs/source/docstring template.py DELETED

@@ -1,46 +0,0 @@
-class MyModule:
-    """[One-line summary of the class].
-
-    [A more detailed description of the class, explaining its purpose, how it
-    works, and its typical use cases. You can use multiple paragraphs.]
-
-    .. note::
-        [Optional: Add important notes, warnings, or usage guidelines here.
-        For example, you could mention if a closure is required, discuss
-        stability, or highlight performance characteristics. Use the `.. note::`
-        directive to make it stand out in the documentation.]
-
-    Args:
-        param1 (type, optional):
-            [Description of the first parameter. Use :code:`backticks` for
-            inline code like variable names or specific values like ``"autograd"``.
-            Explain what the parameter does.] Defaults to [value].
-        param2 (type):
-            [Description of a mandatory parameter (no "optional" or "Defaults to").]
-        **kwargs:
-            [If you accept keyword arguments, describe what they are used for.]
-
-    Examples:
-        [A title or short sentence describing the first example]:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                ...
-            )
-
-        [A title or short sentence for a second, different example]:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                ...
-            )
-
-    References:
-        - [Optional: A citation for a relevant paper, book, or algorithm.]
-        - [Optional: A link to a blog post or website with more information.]
-
-    """