torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/utils/tensorlist.py
CHANGED
|
@@ -19,7 +19,7 @@ from typing_extensions import Self, TypeAlias, Unpack
|
|
|
19
19
|
|
|
20
20
|
import torch
|
|
21
21
|
from .ops import where_
|
|
22
|
-
from .python_tools import
|
|
22
|
+
from .python_tools import zipmap, generic_ne
|
|
23
23
|
from .numberlist import NumberList, as_numberlist, maybe_numberlist
|
|
24
24
|
|
|
25
25
|
|
|
@@ -217,6 +217,12 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
217
217
|
"""Returns a TensorList with all elements for which `fn` returned True."""
|
|
218
218
|
return self.__class__(i for i in self if fn(i, *args, **kwargs))
|
|
219
219
|
|
|
220
|
+
def filter_by_list(self, s: Sequence[bool]):
|
|
221
|
+
"""returns a new TensorList with all elements where corresponding elements in :code:`s` are True."""
|
|
222
|
+
if len(self) != len(s):
|
|
223
|
+
raise ValueError(f"{len(self) = }, {len(s) = }")
|
|
224
|
+
return self.__class__(i for i, boolean in zip(self, s) if boolean)
|
|
225
|
+
|
|
220
226
|
def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
|
|
221
227
|
"""If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
|
|
222
228
|
Otherwise applies `fn` to this TensorList and `other`.
|
|
@@ -319,7 +325,8 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
319
325
|
def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
|
|
320
326
|
def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
|
|
321
327
|
def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
|
|
322
|
-
def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
|
|
328
|
+
def global_vector_norm(self, ord:float | Literal['mean_abs'] = 2) -> torch.Tensor:
|
|
329
|
+
if ord == 'mean_abs': return self.abs().global_mean()
|
|
323
330
|
return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
|
|
324
331
|
def global_any(self): return builtins.any(self.any())
|
|
325
332
|
def global_all(self): return builtins.all(self.all())
|
|
@@ -425,11 +432,11 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
425
432
|
return self
|
|
426
433
|
|
|
427
434
|
def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
428
|
-
if
|
|
429
|
-
return self
|
|
435
|
+
if generic_ne(other, 0): return self.add(other)
|
|
436
|
+
return self
|
|
430
437
|
def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
431
|
-
if
|
|
432
|
-
return self
|
|
438
|
+
if generic_ne(other, 0): return self.add_(other)
|
|
439
|
+
return self
|
|
433
440
|
|
|
434
441
|
@overload
|
|
435
442
|
def sub(self, other: _TensorSeq, alpha: _Scalar = 1): ...
|
|
@@ -449,11 +456,11 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
449
456
|
return self
|
|
450
457
|
|
|
451
458
|
def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
452
|
-
if
|
|
453
|
-
return self
|
|
459
|
+
if generic_ne(other, 0): return self.sub(other)
|
|
460
|
+
return self
|
|
454
461
|
def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
455
|
-
if
|
|
456
|
-
return self
|
|
462
|
+
if generic_ne(other, 0): return self.sub_(other)
|
|
463
|
+
return self
|
|
457
464
|
|
|
458
465
|
def neg(self): return self.__class__(torch._foreach_neg(self))
|
|
459
466
|
def neg_(self):
|
|
@@ -467,13 +474,13 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
467
474
|
|
|
468
475
|
# TODO: benchmark
|
|
469
476
|
def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
|
|
470
|
-
if
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
return self
|
|
477
|
+
if generic_ne(other, 1):
|
|
478
|
+
return self * other
|
|
479
|
+
if clone: return self.clone()
|
|
480
|
+
return self
|
|
474
481
|
def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
475
|
-
if
|
|
476
|
-
return self
|
|
482
|
+
if generic_ne(other, 1): return self.mul_(other)
|
|
483
|
+
return self
|
|
477
484
|
|
|
478
485
|
def div(self, other: _STOrSTSeq) -> Self: return self.__class__(torch._foreach_div(self, other))
|
|
479
486
|
def div_(self, other: _STOrSTSeq):
|
|
@@ -481,11 +488,11 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
481
488
|
return self
|
|
482
489
|
|
|
483
490
|
def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
484
|
-
if
|
|
485
|
-
return self
|
|
491
|
+
if generic_ne(other, 1): return self / other
|
|
492
|
+
return self
|
|
486
493
|
def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
|
|
487
|
-
if
|
|
488
|
-
return self
|
|
494
|
+
if generic_ne(other, 1): return self.div_(other)
|
|
495
|
+
return self
|
|
489
496
|
|
|
490
497
|
def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
|
|
491
498
|
def pow_(self, exponent: "_Scalar | _STSeq"):
|
|
@@ -627,7 +634,8 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
627
634
|
if dim is None: dim = ()
|
|
628
635
|
return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
|
|
629
636
|
|
|
630
|
-
def norm(self, ord: _Scalar, dtype=None):
|
|
637
|
+
def norm(self, ord: _Scalar|Literal["mean_abs"], dtype=None):
|
|
638
|
+
if isinstance(ord, str): return self.abs().mean()
|
|
631
639
|
return self.__class__(torch._foreach_norm(self, ord, dtype))
|
|
632
640
|
|
|
633
641
|
def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
|
|
@@ -782,7 +790,7 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
782
790
|
for t, o in zip(self, other): t.copysign_(o)
|
|
783
791
|
return self
|
|
784
792
|
|
|
785
|
-
def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
|
|
793
|
+
def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
|
|
786
794
|
if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
|
|
787
795
|
if tensorwise:
|
|
788
796
|
norm_self = self.norm(ord)
|
|
@@ -791,11 +799,11 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
791
799
|
norm_self = self.global_vector_norm(ord)
|
|
792
800
|
norm_other = magnitude.global_vector_norm(ord)
|
|
793
801
|
|
|
794
|
-
if
|
|
802
|
+
if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
|
|
795
803
|
|
|
796
804
|
return self * (norm_other / norm_self.clip_(min=eps))
|
|
797
805
|
|
|
798
|
-
def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
|
|
806
|
+
def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
|
|
799
807
|
if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
|
|
800
808
|
if tensorwise:
|
|
801
809
|
norm_self = self.norm(ord)
|
|
@@ -804,7 +812,7 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
804
812
|
norm_self = self.global_vector_norm(ord)
|
|
805
813
|
norm_other = magnitude.global_vector_norm(ord)
|
|
806
814
|
|
|
807
|
-
if
|
|
815
|
+
if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
|
|
808
816
|
|
|
809
817
|
return self.mul_(norm_other / norm_self.clip_(min=eps))
|
|
810
818
|
|
|
@@ -897,7 +905,7 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
897
905
|
if eps!=0: std.add_(eps)
|
|
898
906
|
return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
|
|
899
907
|
|
|
900
|
-
def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
|
|
908
|
+
def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
|
|
901
909
|
"""calculate multipler to clip self norm to min and max"""
|
|
902
910
|
if tensorwise:
|
|
903
911
|
self_norm = self.norm(ord)
|
|
@@ -918,12 +926,12 @@ class TensorList(list[torch.Tensor | Any]):
|
|
|
918
926
|
|
|
919
927
|
return mul
|
|
920
928
|
|
|
921
|
-
def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
|
|
929
|
+
def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
|
|
922
930
|
"""clips norm of each tensor to (min, max) range"""
|
|
923
931
|
if min is None and max is None: return self
|
|
924
932
|
return self * self._clip_multiplier(min, max, tensorwise, ord)
|
|
925
933
|
|
|
926
|
-
def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
|
|
934
|
+
def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
|
|
927
935
|
"""clips norm of each tensor to (min, max) range"""
|
|
928
936
|
if min is None and max is None: return self
|
|
929
937
|
return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
|
|
@@ -1057,6 +1065,10 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
|
|
|
1057
1065
|
if isinstance(x, torch.Tensor): return x.numel()
|
|
1058
1066
|
return x.global_numel()
|
|
1059
1067
|
|
|
1068
|
+
def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
|
|
1069
|
+
if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
|
|
1070
|
+
return torch.finfo(x[0].dtype).eps
|
|
1071
|
+
|
|
1060
1072
|
@overload
|
|
1061
1073
|
def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
|
|
1062
1074
|
@overload
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torchzero
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.11
|
|
4
4
|
Summary: Modular optimization library for PyTorch.
|
|
5
5
|
Author-email: Ivan Nikishev <nkshv2@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -45,8 +45,6 @@ Dynamic: license-file
|
|
|
45
45
|
|
|
46
46
|
`torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.
|
|
47
47
|
|
|
48
|
-
NOTE: torchzero is in active development, currently docs are in a state of flux.
|
|
49
|
-
|
|
50
48
|
## Installation
|
|
51
49
|
|
|
52
50
|
```bash
|
|
@@ -113,31 +111,21 @@ for epoch in range(100):
|
|
|
113
111
|
`torchzero` provides a huge number of various modules:
|
|
114
112
|
|
|
115
113
|
* **Optimizers**: Optimization algorithms.
|
|
116
|
-
* `Adam`.
|
|
117
|
-
* `Shampoo`.
|
|
118
|
-
* `SOAP` (my current recommendation).
|
|
119
|
-
* `Muon`.
|
|
120
|
-
* `SophiaH`.
|
|
121
|
-
* `Adagrad` and `FullMatrixAdagrad`.
|
|
122
|
-
* `Lion`.
|
|
123
|
-
* `RMSprop`.
|
|
124
|
-
* `OrthoGrad`.
|
|
125
|
-
* `Rprop`.
|
|
114
|
+
* `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.
|
|
126
115
|
|
|
127
116
|
Additionally many other optimizers can be easily defined via modules:
|
|
128
117
|
* Grams: `[tz.m.Adam(), tz.m.GradSign()]`
|
|
129
118
|
* LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
|
|
130
119
|
* Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
|
|
131
|
-
*
|
|
120
|
+
* Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
|
|
132
121
|
* Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`
|
|
133
122
|
|
|
134
123
|
* **Momentum**:
|
|
135
|
-
* `NAG`: Nesterov Accelerated Gradient.
|
|
136
124
|
* `HeavyBall`: Classic momentum (Polyak's momentum).
|
|
125
|
+
* `NAG`: Nesterov Accelerated Gradient.
|
|
137
126
|
* `EMA`: Exponential moving average.
|
|
138
|
-
* `Averaging` (`
|
|
127
|
+
* `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
|
|
139
128
|
* `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
|
|
140
|
-
* `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.
|
|
141
129
|
|
|
142
130
|
* **Stabilization**: Gradient stabilization techniques.
|
|
143
131
|
* `ClipNorm`: Clips gradient L2 norm.
|
|
@@ -154,31 +142,42 @@ for epoch in range(100):
|
|
|
154
142
|
|
|
155
143
|
* **Second order**: Second order methods.
|
|
156
144
|
* `Newton`: Classic Newton's method.
|
|
157
|
-
* `
|
|
145
|
+
* `InverseFreeNewton`: Inverse-free version of Newton's method.
|
|
146
|
+
* `NewtonCG`: Matrix-free Newton's method with conjugate gradient or minimal residual solvers.
|
|
147
|
+
* `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
|
|
158
148
|
* `NystromSketchAndSolve`: Nyström sketch-and-solve method.
|
|
159
|
-
* `NystromPCG`: NewtonCG with Nyström preconditioning
|
|
149
|
+
* `NystromPCG`: NewtonCG with Nyström preconditioning.
|
|
150
|
+
* `HigherOrderNewton`: Higher order Newton's method with trust region.
|
|
160
151
|
|
|
161
152
|
* **Quasi-Newton**: Approximate second-order optimization methods.
|
|
162
153
|
* `LBFGS`: Limited-memory BFGS.
|
|
163
154
|
* `LSR1`: Limited-memory SR1.
|
|
164
155
|
* `OnlineLBFGS`: Online LBFGS.
|
|
165
|
-
* `BFGS`, `SR1`, `
|
|
166
|
-
* `
|
|
156
|
+
* `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
|
|
157
|
+
* `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
|
|
158
|
+
* `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
|
|
159
|
+
|
|
160
|
+
* **Trust Region**: Trust region can work with the exact Hessian or any of the quasi-Newton methods (L-BFGS support is WIP).
|
|
161
|
+
* `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
|
|
162
|
+
* `CubicRegularization`: Cubic regularization, works better with exact hessian.
|
|
167
163
|
|
|
168
164
|
* **Line Search**:
|
|
169
165
|
* `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
|
|
170
166
|
* `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
|
|
171
167
|
* `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
|
|
172
|
-
* `TrustRegion`: First order trust region method.
|
|
173
168
|
|
|
174
169
|
* **Learning Rate**:
|
|
175
170
|
* `LR`: Controls learning rate and adds support for LR schedulers.
|
|
176
|
-
* `PolyakStepSize`: Polyak's method.
|
|
177
|
-
* `
|
|
171
|
+
* `PolyakStepSize`: Polyak's subgradient method.
|
|
172
|
+
* `BarzilaiBorwein`: Barzilai-Borwein step-size.
|
|
173
|
+
* `Warmup`, `WarmupNormCLip`: Learning rate warmup.
|
|
178
174
|
|
|
179
175
|
* **Projections**: This can implement things like GaLore but I haven't done that yet.
|
|
180
|
-
* `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
|
|
181
|
-
* `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
|
|
176
|
+
<!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
|
|
177
|
+
* `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
|
|
178
|
+
This is WIP
|
|
179
|
+
* `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
|
|
180
|
+
* `ViewAsReal`: use this if you have complex parameters.
|
|
182
181
|
|
|
183
182
|
* **Smoothing**: Smoothing-based optimization methods.
|
|
184
183
|
* `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
|
|
@@ -194,6 +193,8 @@ for epoch in range(100):
|
|
|
194
193
|
|
|
195
194
|
* **Experimental**: various horrible atrocities
|
|
196
195
|
|
|
196
|
+
A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
|
|
197
|
+
|
|
197
198
|
## Advanced Usage
|
|
198
199
|
|
|
199
200
|
### Closure
|
|
@@ -312,20 +313,21 @@ not in the module itself. Also both per-parameter settings and state are stored
|
|
|
312
313
|
|
|
313
314
|
```python
|
|
314
315
|
import torch
|
|
315
|
-
from torchzero.core import Module,
|
|
316
|
+
from torchzero.core import Module, Var
|
|
316
317
|
|
|
317
318
|
class HeavyBall(Module):
|
|
318
319
|
def __init__(self, momentum: float = 0.9, dampening: float = 0):
|
|
319
320
|
defaults = dict(momentum=momentum, dampening=dampening)
|
|
320
321
|
super().__init__(defaults)
|
|
321
322
|
|
|
322
|
-
def step(self,
|
|
323
|
-
#
|
|
324
|
-
#
|
|
323
|
+
def step(self, var: Var):
|
|
324
|
+
# Var object holds all attributes used for optimization - parameters, gradient, update, etc.
|
|
325
|
+
# a module takes a Var object, modifies it or creates a new one, and returns it
|
|
326
|
+
# Var has a bunch of attributes, including parameters, gradients, update, closure, loss
|
|
325
327
|
# for now we are only interested in update, and we will apply the heavyball rule to it.
|
|
326
328
|
|
|
327
|
-
params =
|
|
328
|
-
update =
|
|
329
|
+
params = var.params
|
|
330
|
+
update = var.get_update() # list of tensors
|
|
329
331
|
|
|
330
332
|
exp_avg_list = []
|
|
331
333
|
for p, u in zip(params, update):
|
|
@@ -346,34 +348,57 @@ class HeavyBall(Module):
|
|
|
346
348
|
# and it is part of self.state
|
|
347
349
|
exp_avg_list.append(buf.clone())
|
|
348
350
|
|
|
349
|
-
# set new update to
|
|
350
|
-
|
|
351
|
-
return
|
|
351
|
+
# set new update to var
|
|
352
|
+
var.update = exp_avg_list
|
|
353
|
+
return var
|
|
352
354
|
```
|
|
353
355
|
|
|
354
|
-
|
|
356
|
+
A more in-depth guide will be available in the documentation in the future.
|
|
357
|
+
|
|
358
|
+
## Other stuff
|
|
355
359
|
|
|
356
|
-
|
|
357
|
-
* `LineSearch` for line searches
|
|
358
|
-
* `Preconditioner` for preconditioners
|
|
359
|
-
* `Projection` for projections like GaLore or into fourier domain.
|
|
360
|
-
* `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
|
|
361
|
-
* `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
|
|
360
|
+
There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.
|
|
362
361
|
|
|
363
|
-
|
|
362
|
+
---
|
|
364
363
|
|
|
365
|
-
|
|
364
|
+
### Scipy
|
|
366
365
|
|
|
367
|
-
|
|
366
|
+
#### torchzero.optim.wrappers.scipy.ScipyMinimize
|
|
368
367
|
|
|
369
|
-
|
|
368
|
+
A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.
|
|
370
369
|
|
|
371
|
-
|
|
370
|
+
#### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute
|
|
372
371
|
|
|
373
|
-
|
|
372
|
+
Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
|
|
373
|
+
|
|
374
|
+
---
|
|
375
|
+
|
|
376
|
+
### NLOpt
|
|
377
|
+
|
|
378
|
+
#### torchzero.optim.wrappers.nlopt.NLOptWrapper
|
|
374
379
|
|
|
375
|
-
|
|
380
|
+
A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
|
|
381
|
+
|
|
382
|
+
---
|
|
383
|
+
|
|
384
|
+
### Nevergrad
|
|
385
|
+
|
|
386
|
+
#### torchzero.optim.wrappers.nevergrad.NevergradWrapper
|
|
387
|
+
|
|
388
|
+
A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
|
|
389
|
+
|
|
390
|
+
---
|
|
391
|
+
|
|
392
|
+
### fast-cma-es
|
|
393
|
+
|
|
394
|
+
#### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
|
|
395
|
+
|
|
396
|
+
A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
|
|
397
|
+
|
|
398
|
+
# License
|
|
399
|
+
|
|
400
|
+
This project is licensed under the MIT License
|
|
376
401
|
|
|
377
|
-
|
|
402
|
+
# Project Links
|
|
378
403
|
|
|
379
|
-
|
|
404
|
+
The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
docs/source/conf.py,sha256=Kd0Uyu6WnhSHEyTbOEjxoaUg4sAu0AxN19raSARtltE,1883
|
|
2
|
+
docs/source/docstring template.py,sha256=lIf4Jdkxd-Vr0vOuL9IOTCMOxw5ENsmZDLXKv1eO9ns,1585
|
|
3
|
+
tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
|
|
4
|
+
tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
|
|
5
|
+
tests/test_opts.py,sha256=pAeyDIT0Q4SXBZqR9W_IUjwAEBcMnYr3zE0N4R0xn8w,42509
|
|
6
|
+
tests/test_tensorlist.py,sha256=SwzLKLrs2ppMtm_7UrfTDTlD-ObZd7JQ_FNHbp059tc,72460
|
|
7
|
+
tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
|
|
8
|
+
tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
|
|
9
|
+
torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
|
|
10
|
+
torchzero/core/__init__.py,sha256=Zib_4is13LFAabp_7VU8QXZpQEEZGzsH94vgRI0HxAg,150
|
|
11
|
+
torchzero/core/module.py,sha256=BfU4YMjwLrwcz24XAfL-cZx05cESIimViKUStJKBEHM,32872
|
|
12
|
+
torchzero/core/transform.py,sha256=sBgEyQVm141v99lnosusNIMWaReuWKuMyzkJha_WwKg,16440
|
|
13
|
+
torchzero/modules/__init__.py,sha256=0Gk6XK32FKxtiW9rh-0Plql2dghHn3Ms1F-Ymn4oVzw,386
|
|
14
|
+
torchzero/modules/functional.py,sha256=hmJaxB7U9X9nsT1Z5aPSqsw5HsQfL2ns1YS8AWdul6c,6948
|
|
15
|
+
torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
|
|
16
|
+
torchzero/modules/clipping/clipping.py,sha256=6d-LPCI4zqlcV9fXK8rtRLiReyt8lMeQhmt1gsqNljs,14897
|
|
17
|
+
torchzero/modules/clipping/ema_clipping.py,sha256=PNUTvixvc0wdjtWzja6pEzXbNpyXtGxj_H15umWx4zc,6608
|
|
18
|
+
torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
|
|
19
|
+
torchzero/modules/experimental/__init__.py,sha256=qV-VaBnRsLFtv6T6R9Imkd1G81QR4O-9_kDbCAwJXeY,1464
|
|
20
|
+
torchzero/modules/experimental/absoap.py,sha256=U3nLAV_vxl6HjJhqi8FlK8K6AMLoiZ-deykEshhnCC0,9916
|
|
21
|
+
torchzero/modules/experimental/adadam.py,sha256=PARjM2kRmJ7ifYsI83tADKCuvSZYAoT2vR4Gj2aZ-SA,4103
|
|
22
|
+
torchzero/modules/experimental/adamY.py,sha256=Rr9vXjFPWTfIHnnhGQAfVAQnfANNgcrFm_R8vJsU1to,4043
|
|
23
|
+
torchzero/modules/experimental/adam_lambertw.py,sha256=FXZiTJKVRbXSu9-_boZGYoCqBlh2035mwsagq75qyeA,5323
|
|
24
|
+
torchzero/modules/experimental/adaptive_step_size.py,sha256=OJseQX9sd9F58pMC5JbVNm7PtovMXL4sMwQg3jooVtg,3494
|
|
25
|
+
torchzero/modules/experimental/adasoap.py,sha256=vcgWEgDdqmgimt5bGgvznCnxkkathGO0engd1xo7M4s,7491
|
|
26
|
+
torchzero/modules/experimental/cosine.py,sha256=0Cc42Wd1sMrjm-YxmpcwCCsGpLv3H83rL-XAtrgZhb4,9155
|
|
27
|
+
torchzero/modules/experimental/cubic_adam.py,sha256=wHJKm9bO24Xvtwunz_1Kz7mGi_C-syupixiDaBnYx2Q,2787
|
|
28
|
+
torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
|
|
29
|
+
torchzero/modules/experimental/dct.py,sha256=Iv8ZxGhTOIm3NHS4zxoFG9K9BEwtrJqsKApctiIjnxg,2463
|
|
30
|
+
torchzero/modules/experimental/eigendescent.py,sha256=Pdz7QUbM3pD3DTsTC0nZ0AfOe2pj-WVPPkbnw8lDZ3c,4725
|
|
31
|
+
torchzero/modules/experimental/etf.py,sha256=ul167I1qAbYeTmTPG_WFLLlE1MEsNXxVsTWd9s2YC9g,6125
|
|
32
|
+
torchzero/modules/experimental/exp_adam.py,sha256=yhR5-NGflbEJrSAe0ps4xgAM-eFI-gAdS6cgZIJDgaI,4100
|
|
33
|
+
torchzero/modules/experimental/expanded_lbfgs.py,sha256=M58cCaeLZXGqZwyaeGhi-UAyCsnnJvLAYIZ64r0tQNE,5649
|
|
34
|
+
torchzero/modules/experimental/fft.py,sha256=YEUKdAXNX8BCZYXKV5uWWU8aTlGjpFTUSpIEwIG-_fM,3050
|
|
35
|
+
torchzero/modules/experimental/gradmin.py,sha256=UixSLdca4ekYHOipEivdXfBAV-uEL9TZm5nCFXVaNco,3684
|
|
36
|
+
torchzero/modules/experimental/hnewton.py,sha256=_Gv4O2x0qYBxGtkCuYuzL21VuI5wTn1sTEegk17d6X4,3036
|
|
37
|
+
torchzero/modules/experimental/modular_lbfgs.py,sha256=d40yRi6NN2Au7-UQ1akMkET0PWhEFAhGKAYoQBDmqFQ,10671
|
|
38
|
+
torchzero/modules/experimental/newton_solver.py,sha256=3dZ7FG-2vGxJKkFF9P2LCs-LI_epcvZbyNtJOtw47pg,3055
|
|
39
|
+
torchzero/modules/experimental/newtonnewton.py,sha256=cRL4dKsDAN8tHPyHQkLbTGxkHfemCU6re-n4odV3Ik4,3324
|
|
40
|
+
torchzero/modules/experimental/parabolic_search.py,sha256=2GgE4cq5QkJYZprADIplQfbPWRJRGFmToYTScJkR0tg,6328
|
|
41
|
+
torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
|
|
42
|
+
torchzero/modules/experimental/structural_projections.py,sha256=lrySQZOq7VhL_VqU7dIJRsypxA16cUliQYkj5-N2B2I,4187
|
|
43
|
+
torchzero/modules/experimental/subspace_preconditioners.py,sha256=RdG-RoPF6AiFVphrVlb6egNyYI0e_eHoENUWqKJ4icQ,5170
|
|
44
|
+
torchzero/modules/experimental/tensor_adagrad.py,sha256=y29i6BGXwv9lwrTRDzq2YRSngQmfZnreRIeH1NGzpBo,1572
|
|
45
|
+
torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
|
|
46
|
+
torchzero/modules/grad_approximation/fdm.py,sha256=K_D0fKwspg21Opo2xTG4I34gLDmcaYBp5NUzlaQnjxQ,4490
|
|
47
|
+
torchzero/modules/grad_approximation/forward_gradient.py,sha256=AoezoYxXii2gKpIGO7BOZkLb2weYwxrWAKpHL7hrW9Y,4313
|
|
48
|
+
torchzero/modules/grad_approximation/grad_approximator.py,sha256=HO-XaNRF3ZwMduBP02V0oabmSRgqmDGPlKkWfDVDPW8,4740
|
|
49
|
+
torchzero/modules/grad_approximation/rfdm.py,sha256=omarcZyMgJomJwxQ_b7ulE6eK6aW3JP_Sh-jcX5DhR4,23434
|
|
50
|
+
torchzero/modules/higher_order/__init__.py,sha256=W94CY8K1NFxs9TPi415UssKVKz5MV_bH9adax1uZsYM,50
|
|
51
|
+
torchzero/modules/higher_order/higher_order_newton.py,sha256=_v5v0WY07CvZn9QPIS89FxEZ2tNfd8Bkamt1o12_mLQ,12255
|
|
52
|
+
torchzero/modules/line_search/__init__.py,sha256=9ja1Dspfuzu9UxGbU5-t0bFeBcdwoX9Fl_aSMR-AXnQ,219
|
|
53
|
+
torchzero/modules/line_search/adaptive.py,sha256=Uj7lAIzpgy89ddlwA4VcEEIfcNJSbGA5HH3ncuzHrTU,2926
|
|
54
|
+
torchzero/modules/line_search/backtracking.py,sha256=dyXgfrIJ_IO7W4p8GqJNPc4r_igU4X4ljLCLNKyY2Tw,8246
|
|
55
|
+
torchzero/modules/line_search/line_search.py,sha256=_u59XYFkRsIKuT1H4Bz7qAHr3Ldzxbup71OeqDGxMfs,9724
|
|
56
|
+
torchzero/modules/line_search/polynomial.py,sha256=KlK0d9qaphxS0s8B5rlt-yIUYNuV-5O24STcx4vN2Ic,9056
|
|
57
|
+
torchzero/modules/line_search/scipy.py,sha256=eGplW1L8kQKdRbt9PPpvZ6MMekDq5KsjurhSpN9QCnY,2301
|
|
58
|
+
torchzero/modules/line_search/strong_wolfe.py,sha256=F5962HTHdPWgvWHwnUofCqFxfKsCu5p8Ic-aRbn7wVg,8458
|
|
59
|
+
torchzero/modules/misc/__init__.py,sha256=cZpMkZQubuzquhFZV-yELrDMznqhhCibmr0CBOR0ZpU,693
|
|
60
|
+
torchzero/modules/misc/debug.py,sha256=iuWg5egoMnG6y3Cyd423xS7BRVYiwZq9575d7A7U3Dg,1652
|
|
61
|
+
torchzero/modules/misc/escape.py,sha256=1XgNmT4pOptaXHSWEONkUPpcYnIujm5gdK6n_-zmw20,1821
|
|
62
|
+
torchzero/modules/misc/gradient_accumulation.py,sha256=6yXRUxD_f3Zfx83UyCvPJ-56XN4GJjEQcNIDlvFtuuY,2590
|
|
63
|
+
torchzero/modules/misc/misc.py,sha256=VTQZAcfQBo2yudy1u1lyHhmaAmQlxzVcZTHcXXnUeTM,13470
|
|
64
|
+
torchzero/modules/misc/multistep.py,sha256=rAPCALSHXjVNxR8d1CA3RFP_xnN6j5KksjB6yl8vtng,5585
|
|
65
|
+
torchzero/modules/misc/regularization.py,sha256=R8ya7HEF2MLtcAr7GS9IjXwJ4xh0lJWMdWMIRfwL42s,6279
|
|
66
|
+
torchzero/modules/misc/split.py,sha256=ebc95OZjC-Vs73JeTkL--eZrtKijg7lPN0hmD0Whfxc,3195
|
|
67
|
+
torchzero/modules/misc/switch.py,sha256=72mfY_uIVyTllwuR21_K7QC8IQFP7JMKzH4K2nAx0Wc,3726
|
|
68
|
+
torchzero/modules/momentum/__init__.py,sha256=tI2I5zSQB7aTwEn371wvUTy2O2n_-KVCafjBv-OMsYE,545
|
|
69
|
+
torchzero/modules/momentum/averaging.py,sha256=gZRjHb443HuFF03p3Oh2rfgh2Qu8sJBxc_8NR-ircaA,3241
|
|
70
|
+
torchzero/modules/momentum/cautious.py,sha256=QP3Sqc8nMb3xTDDDfGwFn5AWvN4EI5U-CCcZb-F5oX0,8266
|
|
71
|
+
torchzero/modules/momentum/ema.py,sha256=9OdMF20RYnEkwe9Xu2dCAAiI0qY2MQvhS87bKP7ptTI,10755
|
|
72
|
+
torchzero/modules/momentum/experimental.py,sha256=WnM9FUKPxyFNiKU6Ip7wqqYxHbXuaMKOcLjjomfENb4,6916
|
|
73
|
+
torchzero/modules/momentum/matrix_momentum.py,sha256=gZeTJZbhgixCOkE9Jyowtva58hl5vsH9iTqGC54FWFs,8047
|
|
74
|
+
torchzero/modules/momentum/momentum.py,sha256=Yx35jtbLb1syVFcTiNSoZPoUPmdsUy3QpoNWcN4sC9w,2664
|
|
75
|
+
torchzero/modules/ops/__init__.py,sha256=1q9CBo6OXWXDgyjvKKTlG0EdP4ASIvkWFXtd6LOuU88,1083
|
|
76
|
+
torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
|
|
77
|
+
torchzero/modules/ops/binary.py,sha256=mIeaa3v5Bk7mwzSTC0jGMLhKf-Ujg6aFbSia2yo-3JQ,12199
|
|
78
|
+
torchzero/modules/ops/multi.py,sha256=DpabTYj0sic5dmosnmj7lgIX3dbmcgl0h9XfzKpbaus,8918
|
|
79
|
+
torchzero/modules/ops/reduce.py,sha256=uLCq493hFy_Ib22GjIKtMHTTObK3RDmubGHTVqgFgg8,6339
|
|
80
|
+
torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
|
|
81
|
+
torchzero/modules/ops/utility.py,sha256=9Skxkt4RO79OBdw95wOKhqKN2RMdZg9emO7R9q2d5oU,3767
|
|
82
|
+
torchzero/modules/optimizers/__init__.py,sha256=IJaLoZ39rbB4GSW9rLKrfSCh5FsAkFy2ww5MhJ6MYnE,817
|
|
83
|
+
torchzero/modules/optimizers/adagrad.py,sha256=p-DWbhGuuogldiFPNxxQfJ8AA5Tsd4UwGOIyX7GT0WE,5892
|
|
84
|
+
torchzero/modules/optimizers/adahessian.py,sha256=vOJfwGi7ypfi7vifCMJfGew-McdGJKQM3TmkT-OUgI0,8682
|
|
85
|
+
torchzero/modules/optimizers/adam.py,sha256=SkJ7UJ1BOAgfregmzYDFo_3cgPNke_RK9B58hOal_Zg,3954
|
|
86
|
+
torchzero/modules/optimizers/adan.py,sha256=aOG6KGLU4oHYeQn3JB-A4NQ-279QpHA7firY3kkhFR4,3311
|
|
87
|
+
torchzero/modules/optimizers/adaptive_heavyball.py,sha256=DnkWHA0GBLIKCq8nWh76fZA6PnJ3eKsJDBXWKnZ_uIs,2127
|
|
88
|
+
torchzero/modules/optimizers/esgd.py,sha256=WXwYPA-qTA_QW9h4NDwNaly9gbi1uvMQ-5fSuLqnPkQ,6413
|
|
89
|
+
torchzero/modules/optimizers/ladagrad.py,sha256=HQb7LuZnG8SvS8JWqu7JJz_owlkyT-fnqeICrJBQxbc,7314
|
|
90
|
+
torchzero/modules/optimizers/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
|
|
91
|
+
torchzero/modules/optimizers/mars.py,sha256=7tr32x2eQNu8ZVQAPnLIkM2kkYp7S57uiDywTdqy1uY,2710
|
|
92
|
+
torchzero/modules/optimizers/msam.py,sha256=nvoo6smewR3hiCCymZQiB3DlCvLBGxfxlovJF2bwwsc,6588
|
|
93
|
+
torchzero/modules/optimizers/muon.py,sha256=AZKpmkVUjukXtI7Pb9PKDEeycreLF6qYlIMSbV_9IuA,10463
|
|
94
|
+
torchzero/modules/optimizers/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
|
|
95
|
+
torchzero/modules/optimizers/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
|
|
96
|
+
torchzero/modules/optimizers/rprop.py,sha256=nFpnqcXevGkUcPWERDX9gsiBCGgOi4pyPFloL68zwPY,11984
|
|
97
|
+
torchzero/modules/optimizers/sam.py,sha256=yEhXAS3v62nhAvs63RZ80VfZ93MaQ0cyMQziFdy6e2U,5711
|
|
98
|
+
torchzero/modules/optimizers/shampoo.py,sha256=m_XOvo2Eb1HP8QqYFPsT0rgczJ8HqKjh67QmtaY9dVg,9544
|
|
99
|
+
torchzero/modules/optimizers/soap.py,sha256=MXQ8fdBzLyFtgW34fnmY3hQqv3q4QwEthho9kK-72VE,11305
|
|
100
|
+
torchzero/modules/optimizers/sophia_h.py,sha256=dgQwjij5R4zdESYoKhc4BMhb6dKkDuEvjlL4bDdeQtw,7213
|
|
101
|
+
torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
|
|
102
|
+
torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
|
|
103
|
+
torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
|
|
104
|
+
torchzero/modules/projections/projection.py,sha256=PU2e9LNfVMnNrXnBDt-hdr5pVtl0TpgiB4b92WUguSs,14005
|
|
105
|
+
torchzero/modules/quasi_newton/__init__.py,sha256=guTCpbAffZyupnThdPxAsLULAmPP3vdPaNfPCe9KB9Y,854
|
|
106
|
+
torchzero/modules/quasi_newton/cg.py,sha256=HCfza5UInco7_hYT8s3duNRTmBdjbw5jscWLKNUiS8w,14453
|
|
107
|
+
torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=bMvIcWifYlJX83UtXFESMw7OdA4AO7tJwlHZwkc5wx0,6555
|
|
108
|
+
torchzero/modules/quasi_newton/lbfgs.py,sha256=BmE5sOFLFoJDlpoSphM5VowMgt7wtEFihbLkdylDXhM,10638
|
|
109
|
+
torchzero/modules/quasi_newton/lsr1.py,sha256=a19a9ABqMiTVJmXe6Woc0sJ1kkhQa3Y6QDouaUNnPt0,7873
|
|
110
|
+
torchzero/modules/quasi_newton/quasi_newton.py,sha256=hKJ9Irmh2pKNfB7Wen4MrDfMrbvzp00FTcPlpFvJLDU,48582
|
|
111
|
+
torchzero/modules/quasi_newton/trust_region.py,sha256=cxOEDeZ8ZhG_w7QXGYnTsF-t5g5zZ39q9Uxb2IXWgAY,15213
|
|
112
|
+
torchzero/modules/second_order/__init__.py,sha256=Trje1qM65yp8WWzuRm-tMTRqfKi4wpI7f8yyZWjhPCw,152
|
|
113
|
+
torchzero/modules/second_order/newton.py,sha256=94LGrQo5Q8aC5DI9S6RSXF0stVcgWzq3JnE9l_BsVUw,12875
|
|
114
|
+
torchzero/modules/second_order/newton_cg.py,sha256=l8FX9vQSVCSkpk5a-M2wEBBjQoODF-T07GFW_tjJxkM,14890
|
|
115
|
+
torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
|
|
116
|
+
torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
|
|
117
|
+
torchzero/modules/smoothing/gaussian.py,sha256=iTsWlMNHuDLoxPRIsm2pAb5cS8OqdRJwCsw-vUTVmpE,7887
|
|
118
|
+
torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
|
|
119
|
+
torchzero/modules/step_size/__init__.py,sha256=Z8NpB9RYIXhcNx11NWixa7mORPiT4nI1mKQGA7JfC6g,122
|
|
120
|
+
torchzero/modules/step_size/adaptive.py,sha256=3qQr1aaPYEJlkiDSQbuVQ_OVkOq-W4LL7PkHFFgwP2c,4845
|
|
121
|
+
torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
|
|
122
|
+
torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
|
|
123
|
+
torchzero/modules/weight_decay/weight_decay.py,sha256=2MhWRyryplDtB61QyKN7KqBa3mEkhtqXhij8LGR-mYA,5464
|
|
124
|
+
torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
|
|
125
|
+
torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
|
|
126
|
+
torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
|
|
127
|
+
torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
|
|
128
|
+
torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
|
|
129
|
+
torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
130
|
+
torchzero/optim/wrappers/directsearch.py,sha256=GQ2nzy9ADqbV_QUMN3IaYecZ0Pzx_3mAasSB4fryTBE,11362
|
|
131
|
+
torchzero/optim/wrappers/fcmaes.py,sha256=o_FchMtDsrEj9XRonHHeyVHPAXTHaU244SzlldgEzLg,4250
|
|
132
|
+
torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
|
|
133
|
+
torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
|
|
134
|
+
torchzero/optim/wrappers/nlopt.py,sha256=AaVEKfjbrt5DFION44_-g-jQAoVi4lCvBBPU5UDGO9Q,8151
|
|
135
|
+
torchzero/optim/wrappers/optuna.py,sha256=ZZ66aXEypSJMVomphbzHNJnmIOyXS9tqE89YZBPpIuo,2331
|
|
136
|
+
torchzero/optim/wrappers/scipy.py,sha256=Td1AvpLDEPqPVW6IpHbkVW4CpNiUU9r_eyc3qJVHZAY,19352
|
|
137
|
+
torchzero/utils/__init__.py,sha256=4JMKzF3qICE9PSfgXAwb3cPswM5f1JUutWwviev2-0k,875
|
|
138
|
+
torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
|
|
139
|
+
torchzero/utils/derivatives.py,sha256=IIn4stpMMJxYmGKh1JCH4Gha_a4w8Z5G04uVz2BwMP4,16995
|
|
140
|
+
torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
|
|
141
|
+
torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
|
|
142
|
+
torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
|
|
143
|
+
torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
|
|
144
|
+
torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
|
|
145
|
+
torchzero/utils/python_tools.py,sha256=NEyDVJfLBbdwh5m49qiOdIr0NffZRqKhaJ-cktviD1o,3243
|
|
146
|
+
torchzero/utils/tensorlist.py,sha256=WvjhPzGbgRySAsUBFQ7b-39V9rm7jbR1VOeYZQXiiKw,53925
|
|
147
|
+
torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
|
|
148
|
+
torchzero/utils/linalg/__init__.py,sha256=tsUt20_rbA_3pV6NK7yCkGoX1l0D9ayMKwZeySsYxHw,291
|
|
149
|
+
torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
|
|
150
|
+
torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
|
|
151
|
+
torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
|
|
152
|
+
torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
|
|
153
|
+
torchzero/utils/linalg/solve.py,sha256=JF0i_eJTBRKCs7CONUOV7coPjE46NC5nMaz2JotrvSE,11232
|
|
154
|
+
torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
|
|
155
|
+
torchzero-0.3.11.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
|
|
156
|
+
torchzero-0.3.11.dist-info/METADATA,sha256=Czo-sKnlVxQ75MhY3D61oD8lusASV0ez_l697dyJBNc,15797
|
|
157
|
+
torchzero-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
158
|
+
torchzero-0.3.11.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
|
|
159
|
+
torchzero-0.3.11.dist-info/RECORD,,
|