torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ from typing_extensions import Self, TypeAlias, Unpack
19
19
 
20
20
  import torch
21
21
  from .ops import where_
22
- from .python_tools import generic_eq, zipmap
22
+ from .python_tools import zipmap, generic_ne
23
23
  from .numberlist import NumberList, as_numberlist, maybe_numberlist
24
24
 
25
25
 
@@ -217,6 +217,12 @@ class TensorList(list[torch.Tensor | Any]):
217
217
  """Returns a TensorList with all elements for which `fn` returned True."""
218
218
  return self.__class__(i for i in self if fn(i, *args, **kwargs))
219
219
 
220
+ def filter_by_list(self, s: Sequence[bool]):
221
+ """returns a new TensorList with all elements where corresponding elements in :code:`s` are True."""
222
+ if len(self) != len(s):
223
+ raise ValueError(f"{len(self) = }, {len(s) = }")
224
+ return self.__class__(i for i, boolean in zip(self, s) if boolean)
225
+
220
226
  def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
221
227
  """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
222
228
  Otherwise applies `fn` to this TensorList and `other`.
@@ -319,7 +325,8 @@ class TensorList(list[torch.Tensor | Any]):
319
325
  def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
320
326
  def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
321
327
  def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
322
- def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
328
+ def global_vector_norm(self, ord:float | Literal['mean_abs'] = 2) -> torch.Tensor:
329
+ if ord == 'mean_abs': return self.abs().global_mean()
323
330
  return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
324
331
  def global_any(self): return builtins.any(self.any())
325
332
  def global_all(self): return builtins.all(self.all())
@@ -425,11 +432,11 @@ class TensorList(list[torch.Tensor | Any]):
425
432
  return self
426
433
 
427
434
  def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
428
- if generic_eq(other, 0): return self
429
- return self.add(other)
435
+ if generic_ne(other, 0): return self.add(other)
436
+ return self
430
437
  def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
431
- if generic_eq(other, 0): return self
432
- return self.add_(other)
438
+ if generic_ne(other, 0): return self.add_(other)
439
+ return self
433
440
 
434
441
  @overload
435
442
  def sub(self, other: _TensorSeq, alpha: _Scalar = 1): ...
@@ -449,11 +456,11 @@ class TensorList(list[torch.Tensor | Any]):
449
456
  return self
450
457
 
451
458
  def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
452
- if generic_eq(other, 0): return self
453
- return self.sub(other)
459
+ if generic_ne(other, 0): return self.sub(other)
460
+ return self
454
461
  def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
455
- if generic_eq(other, 0): return self
456
- return self.sub_(other)
462
+ if generic_ne(other, 0): return self.sub_(other)
463
+ return self
457
464
 
458
465
  def neg(self): return self.__class__(torch._foreach_neg(self))
459
466
  def neg_(self):
@@ -467,13 +474,13 @@ class TensorList(list[torch.Tensor | Any]):
467
474
 
468
475
  # TODO: benchmark
469
476
  def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
470
- if generic_eq(other, 1):
471
- if clone: return self.clone()
472
- return self
473
- return self * other
477
+ if generic_ne(other, 1):
478
+ return self * other
479
+ if clone: return self.clone()
480
+ return self
474
481
  def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
475
- if generic_eq(other, 1): return self
476
- return self.mul_(other)
482
+ if generic_ne(other, 1): return self.mul_(other)
483
+ return self
477
484
 
478
485
  def div(self, other: _STOrSTSeq) -> Self: return self.__class__(torch._foreach_div(self, other))
479
486
  def div_(self, other: _STOrSTSeq):
@@ -481,11 +488,11 @@ class TensorList(list[torch.Tensor | Any]):
481
488
  return self
482
489
 
483
490
  def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
484
- if generic_eq(other, 1): return self
485
- return self / other
491
+ if generic_ne(other, 1): return self / other
492
+ return self
486
493
  def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
487
- if generic_eq(other, 1): return self
488
- return self.div_(other)
494
+ if generic_ne(other, 1): return self.div_(other)
495
+ return self
489
496
 
490
497
  def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
491
498
  def pow_(self, exponent: "_Scalar | _STSeq"):
@@ -627,7 +634,8 @@ class TensorList(list[torch.Tensor | Any]):
627
634
  if dim is None: dim = ()
628
635
  return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
629
636
 
630
- def norm(self, ord: _Scalar, dtype=None):
637
+ def norm(self, ord: _Scalar|Literal["mean_abs"], dtype=None):
638
+ if isinstance(ord, str): return self.abs().mean()
631
639
  return self.__class__(torch._foreach_norm(self, ord, dtype))
632
640
 
633
641
  def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
@@ -782,7 +790,7 @@ class TensorList(list[torch.Tensor | Any]):
782
790
  for t, o in zip(self, other): t.copysign_(o)
783
791
  return self
784
792
 
785
- def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
793
+ def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
786
794
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
787
795
  if tensorwise:
788
796
  norm_self = self.norm(ord)
@@ -791,11 +799,11 @@ class TensorList(list[torch.Tensor | Any]):
791
799
  norm_self = self.global_vector_norm(ord)
792
800
  norm_other = magnitude.global_vector_norm(ord)
793
801
 
794
- if not generic_eq(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
802
+ if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
795
803
 
796
804
  return self * (norm_other / norm_self.clip_(min=eps))
797
805
 
798
- def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
806
+ def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
799
807
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
800
808
  if tensorwise:
801
809
  norm_self = self.norm(ord)
@@ -804,7 +812,7 @@ class TensorList(list[torch.Tensor | Any]):
804
812
  norm_self = self.global_vector_norm(ord)
805
813
  norm_other = magnitude.global_vector_norm(ord)
806
814
 
807
- if not generic_eq(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
815
+ if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
808
816
 
809
817
  return self.mul_(norm_other / norm_self.clip_(min=eps))
810
818
 
@@ -897,7 +905,7 @@ class TensorList(list[torch.Tensor | Any]):
897
905
  if eps!=0: std.add_(eps)
898
906
  return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
899
907
 
900
- def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
908
+ def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
901
909
  """calculate multipler to clip self norm to min and max"""
902
910
  if tensorwise:
903
911
  self_norm = self.norm(ord)
@@ -918,12 +926,12 @@ class TensorList(list[torch.Tensor | Any]):
918
926
 
919
927
  return mul
920
928
 
921
- def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
929
+ def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
922
930
  """clips norm of each tensor to (min, max) range"""
923
931
  if min is None and max is None: return self
924
932
  return self * self._clip_multiplier(min, max, tensorwise, ord)
925
933
 
926
- def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
934
+ def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
927
935
  """clips norm of each tensor to (min, max) range"""
928
936
  if min is None and max is None: return self
929
937
  return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
@@ -1057,6 +1065,10 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
1057
1065
  if isinstance(x, torch.Tensor): return x.numel()
1058
1066
  return x.global_numel()
1059
1067
 
1068
+ def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
1069
+ if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
1070
+ return torch.finfo(x[0].dtype).eps
1071
+
1060
1072
  @overload
1061
1073
  def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
1062
1074
  @overload
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.3.9
3
+ Version: 0.3.11
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  License: MIT License
@@ -45,8 +45,6 @@ Dynamic: license-file
45
45
 
46
46
  `torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.
47
47
 
48
- NOTE: torchzero is in active development, currently docs are in a state of flux.
49
-
50
48
  ## Installation
51
49
 
52
50
  ```bash
@@ -113,31 +111,21 @@ for epoch in range(100):
113
111
  `torchzero` provides a huge number of various modules:
114
112
 
115
113
  * **Optimizers**: Optimization algorithms.
116
- * `Adam`.
117
- * `Shampoo`.
118
- * `SOAP` (my current recommendation).
119
- * `Muon`.
120
- * `SophiaH`.
121
- * `Adagrad` and `FullMatrixAdagrad`.
122
- * `Lion`.
123
- * `RMSprop`.
124
- * `OrthoGrad`.
125
- * `Rprop`.
114
+ * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.
126
115
 
127
116
  Additionally many other optimizers can be easily defined via modules:
128
117
  * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
129
118
  * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
130
119
  * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
131
- * Full matrix version of any diagonal optimizer, like Adam: `tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))`
120
+ * Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
132
121
  * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`
133
122
 
134
123
  * **Momentum**:
135
- * `NAG`: Nesterov Accelerated Gradient.
136
124
  * `HeavyBall`: Classic momentum (Polyak's momentum).
125
+ * `NAG`: Nesterov Accelerated Gradient.
137
126
  * `EMA`: Exponential moving average.
138
- * `Averaging` (`Medianveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
127
+ * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
139
128
  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
140
- * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.
141
129
 
142
130
  * **Stabilization**: Gradient stabilization techniques.
143
131
  * `ClipNorm`: Clips gradient L2 norm.
@@ -154,31 +142,42 @@ for epoch in range(100):
154
142
 
155
143
  * **Second order**: Second order methods.
156
144
  * `Newton`: Classic Newton's method.
157
- * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
145
+ * `InverseFreeNewton`: Inverse-free version of Newton's method.
146
+ * `NewtonCG`: Matrix-free Newton's method with conjugate gradient or minimal residual solvers.
147
+ * `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
158
148
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
159
- * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
149
+ * `NystromPCG`: NewtonCG with Nyström preconditioning.
150
+ * `HigherOrderNewton`: Higher order Newton's method with trust region.
160
151
 
161
152
  * **Quasi-Newton**: Approximate second-order optimization methods.
162
153
  * `LBFGS`: Limited-memory BFGS.
163
154
  * `LSR1`: Limited-memory SR1.
164
155
  * `OnlineLBFGS`: Online LBFGS.
165
- * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
166
- * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
156
+ * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
157
+ * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
158
+ * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
159
+
160
+ * **Trust Region** Trust region can work with exact hessian or any of the quasi-newton methods (L-BFGS support is WIP)
161
+ * `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
162
+ * `CubicRegularization`: Cubic regularization, works better with exact hessian.
167
163
 
168
164
  * **Line Search**:
169
165
  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
170
166
  * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
171
167
  * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
172
- * `TrustRegion`: First order trust region method.
173
168
 
174
169
  * **Learning Rate**:
175
170
  * `LR`: Controls learning rate and adds support for LR schedulers.
176
- * `PolyakStepSize`: Polyak's method.
177
- * `Warmup`: Learning rate warmup.
171
+ * `PolyakStepSize`: Polyak's subgradient method.
172
+ * `BarzilaiBorwein`: Barzilai-Borwein step-size.
173
+ * `Warmup`, `WarmupNormClip`: Learning rate warmup.
178
174
 
179
175
  * **Projections**: This can implement things like GaLore but I haven't done that yet.
180
- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
181
- * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
176
+ <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
177
+ * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
178
+ This is WIP
179
+ * `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
180
+ * `ViewAsReal`: use this if you have complex parameters.
182
181
 
183
182
  * **Smoothing**: Smoothing-based optimization methods.
184
183
  * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
@@ -194,6 +193,8 @@ for epoch in range(100):
194
193
 
195
194
  * **Experimental**: various horrible atrocities
196
195
 
196
+ A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
197
+
197
198
  ## Advanced Usage
198
199
 
199
200
  ### Closure
@@ -312,20 +313,21 @@ not in the module itself. Also both per-parameter settings and state are stored
312
313
 
313
314
  ```python
314
315
  import torch
315
- from torchzero.core import Module, Vars
316
+ from torchzero.core import Module, Var
316
317
 
317
318
  class HeavyBall(Module):
318
319
  def __init__(self, momentum: float = 0.9, dampening: float = 0):
319
320
  defaults = dict(momentum=momentum, dampening=dampening)
320
321
  super().__init__(defaults)
321
322
 
322
- def step(self, vars: Vars):
323
- # a module takes a Vars object, modifies it or creates a new one, and returns it
324
- # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
323
+ def step(self, var: Var):
324
+ # Var object holds all attributes used for optimization - parameters, gradient, update, etc.
325
+ # a module takes a Var object, modifies it or creates a new one, and returns it
326
+ # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
325
327
  # for now we are only interested in update, and we will apply the heavyball rule to it.
326
328
 
327
- params = vars.params
328
- update = vars.get_update() # list of tensors
329
+ params = var.params
330
+ update = var.get_update() # list of tensors
329
331
 
330
332
  exp_avg_list = []
331
333
  for p, u in zip(params, update):
@@ -346,34 +348,57 @@ class HeavyBall(Module):
346
348
  # and it is part of self.state
347
349
  exp_avg_list.append(buf.clone())
348
350
 
349
- # set new update to vars
350
- vars.update = exp_avg_list
351
- return vars
351
+ # set new update to var
352
+ var.update = exp_avg_list
353
+ return var
352
354
  ```
353
355
 
354
- There are a some specialized base modules that make it much easier to implement some specific things.
356
+ More in-depth guide will be available in the documentation in the future.
357
+
358
+ ## Other stuff
355
359
 
356
- * `GradApproximator` for gradient approximations
357
- * `LineSearch` for line searches
358
- * `Preconditioner` for preconditioners
359
- * `Projection` for projections like GaLore or into fourier domain.
360
- * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
361
- * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
360
+ There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.
362
361
 
363
- The documentation on how to actually use them is to write itself in the near future.
362
+ ---
364
363
 
365
- ## License
364
+ ### Scipy
366
365
 
367
- This project is licensed under the MIT License
366
+ #### torchzero.optim.wrappers.scipy.ScipyMinimize
368
367
 
369
- ## Project Links
368
+ A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.
370
369
 
371
- TODO (there are docs but from very old version)
370
+ #### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute
372
371
 
373
- ## Other stuff
372
+ Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
373
+
374
+ ---
375
+
376
+ ### NLOpt
377
+
378
+ #### torchzero.optim.wrappers.nlopt.NLOptWrapper
374
379
 
375
- There are also wrappers providing `torch.optim.Optimizer` interface for for `scipy.optimize`, NLOpt and Nevergrad.
380
+ A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
381
+
382
+ ---
383
+
384
+ ### Nevergrad
385
+
386
+ #### torchzero.optim.wrappers.nevergrad.NevergradWrapper
387
+
388
+ A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
389
+
390
+ ---
391
+
392
+ ### fast-cma-es
393
+
394
+ #### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
395
+
396
+ A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
397
+
398
+ # License
399
+
400
+ This project is licensed under the MIT License
376
401
 
377
- They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
402
+ # Project Links
378
403
 
379
- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
404
+ The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
@@ -0,0 +1,159 @@
1
+ docs/source/conf.py,sha256=Kd0Uyu6WnhSHEyTbOEjxoaUg4sAu0AxN19raSARtltE,1883
2
+ docs/source/docstring template.py,sha256=lIf4Jdkxd-Vr0vOuL9IOTCMOxw5ENsmZDLXKv1eO9ns,1585
3
+ tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
4
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
5
+ tests/test_opts.py,sha256=pAeyDIT0Q4SXBZqR9W_IUjwAEBcMnYr3zE0N4R0xn8w,42509
6
+ tests/test_tensorlist.py,sha256=SwzLKLrs2ppMtm_7UrfTDTlD-ObZd7JQ_FNHbp059tc,72460
7
+ tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
8
+ tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
9
+ torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
10
+ torchzero/core/__init__.py,sha256=Zib_4is13LFAabp_7VU8QXZpQEEZGzsH94vgRI0HxAg,150
11
+ torchzero/core/module.py,sha256=BfU4YMjwLrwcz24XAfL-cZx05cESIimViKUStJKBEHM,32872
12
+ torchzero/core/transform.py,sha256=sBgEyQVm141v99lnosusNIMWaReuWKuMyzkJha_WwKg,16440
13
+ torchzero/modules/__init__.py,sha256=0Gk6XK32FKxtiW9rh-0Plql2dghHn3Ms1F-Ymn4oVzw,386
14
+ torchzero/modules/functional.py,sha256=hmJaxB7U9X9nsT1Z5aPSqsw5HsQfL2ns1YS8AWdul6c,6948
15
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
16
+ torchzero/modules/clipping/clipping.py,sha256=6d-LPCI4zqlcV9fXK8rtRLiReyt8lMeQhmt1gsqNljs,14897
17
+ torchzero/modules/clipping/ema_clipping.py,sha256=PNUTvixvc0wdjtWzja6pEzXbNpyXtGxj_H15umWx4zc,6608
18
+ torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
19
+ torchzero/modules/experimental/__init__.py,sha256=qV-VaBnRsLFtv6T6R9Imkd1G81QR4O-9_kDbCAwJXeY,1464
20
+ torchzero/modules/experimental/absoap.py,sha256=U3nLAV_vxl6HjJhqi8FlK8K6AMLoiZ-deykEshhnCC0,9916
21
+ torchzero/modules/experimental/adadam.py,sha256=PARjM2kRmJ7ifYsI83tADKCuvSZYAoT2vR4Gj2aZ-SA,4103
22
+ torchzero/modules/experimental/adamY.py,sha256=Rr9vXjFPWTfIHnnhGQAfVAQnfANNgcrFm_R8vJsU1to,4043
23
+ torchzero/modules/experimental/adam_lambertw.py,sha256=FXZiTJKVRbXSu9-_boZGYoCqBlh2035mwsagq75qyeA,5323
24
+ torchzero/modules/experimental/adaptive_step_size.py,sha256=OJseQX9sd9F58pMC5JbVNm7PtovMXL4sMwQg3jooVtg,3494
25
+ torchzero/modules/experimental/adasoap.py,sha256=vcgWEgDdqmgimt5bGgvznCnxkkathGO0engd1xo7M4s,7491
26
+ torchzero/modules/experimental/cosine.py,sha256=0Cc42Wd1sMrjm-YxmpcwCCsGpLv3H83rL-XAtrgZhb4,9155
27
+ torchzero/modules/experimental/cubic_adam.py,sha256=wHJKm9bO24Xvtwunz_1Kz7mGi_C-syupixiDaBnYx2Q,2787
28
+ torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
29
+ torchzero/modules/experimental/dct.py,sha256=Iv8ZxGhTOIm3NHS4zxoFG9K9BEwtrJqsKApctiIjnxg,2463
30
+ torchzero/modules/experimental/eigendescent.py,sha256=Pdz7QUbM3pD3DTsTC0nZ0AfOe2pj-WVPPkbnw8lDZ3c,4725
31
+ torchzero/modules/experimental/etf.py,sha256=ul167I1qAbYeTmTPG_WFLLlE1MEsNXxVsTWd9s2YC9g,6125
32
+ torchzero/modules/experimental/exp_adam.py,sha256=yhR5-NGflbEJrSAe0ps4xgAM-eFI-gAdS6cgZIJDgaI,4100
33
+ torchzero/modules/experimental/expanded_lbfgs.py,sha256=M58cCaeLZXGqZwyaeGhi-UAyCsnnJvLAYIZ64r0tQNE,5649
34
+ torchzero/modules/experimental/fft.py,sha256=YEUKdAXNX8BCZYXKV5uWWU8aTlGjpFTUSpIEwIG-_fM,3050
35
+ torchzero/modules/experimental/gradmin.py,sha256=UixSLdca4ekYHOipEivdXfBAV-uEL9TZm5nCFXVaNco,3684
36
+ torchzero/modules/experimental/hnewton.py,sha256=_Gv4O2x0qYBxGtkCuYuzL21VuI5wTn1sTEegk17d6X4,3036
37
+ torchzero/modules/experimental/modular_lbfgs.py,sha256=d40yRi6NN2Au7-UQ1akMkET0PWhEFAhGKAYoQBDmqFQ,10671
38
+ torchzero/modules/experimental/newton_solver.py,sha256=3dZ7FG-2vGxJKkFF9P2LCs-LI_epcvZbyNtJOtw47pg,3055
39
+ torchzero/modules/experimental/newtonnewton.py,sha256=cRL4dKsDAN8tHPyHQkLbTGxkHfemCU6re-n4odV3Ik4,3324
40
+ torchzero/modules/experimental/parabolic_search.py,sha256=2GgE4cq5QkJYZprADIplQfbPWRJRGFmToYTScJkR0tg,6328
41
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
42
+ torchzero/modules/experimental/structural_projections.py,sha256=lrySQZOq7VhL_VqU7dIJRsypxA16cUliQYkj5-N2B2I,4187
43
+ torchzero/modules/experimental/subspace_preconditioners.py,sha256=RdG-RoPF6AiFVphrVlb6egNyYI0e_eHoENUWqKJ4icQ,5170
44
+ torchzero/modules/experimental/tensor_adagrad.py,sha256=y29i6BGXwv9lwrTRDzq2YRSngQmfZnreRIeH1NGzpBo,1572
45
+ torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
46
+ torchzero/modules/grad_approximation/fdm.py,sha256=K_D0fKwspg21Opo2xTG4I34gLDmcaYBp5NUzlaQnjxQ,4490
47
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=AoezoYxXii2gKpIGO7BOZkLb2weYwxrWAKpHL7hrW9Y,4313
48
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=HO-XaNRF3ZwMduBP02V0oabmSRgqmDGPlKkWfDVDPW8,4740
49
+ torchzero/modules/grad_approximation/rfdm.py,sha256=omarcZyMgJomJwxQ_b7ulE6eK6aW3JP_Sh-jcX5DhR4,23434
50
+ torchzero/modules/higher_order/__init__.py,sha256=W94CY8K1NFxs9TPi415UssKVKz5MV_bH9adax1uZsYM,50
51
+ torchzero/modules/higher_order/higher_order_newton.py,sha256=_v5v0WY07CvZn9QPIS89FxEZ2tNfd8Bkamt1o12_mLQ,12255
52
+ torchzero/modules/line_search/__init__.py,sha256=9ja1Dspfuzu9UxGbU5-t0bFeBcdwoX9Fl_aSMR-AXnQ,219
53
+ torchzero/modules/line_search/adaptive.py,sha256=Uj7lAIzpgy89ddlwA4VcEEIfcNJSbGA5HH3ncuzHrTU,2926
54
+ torchzero/modules/line_search/backtracking.py,sha256=dyXgfrIJ_IO7W4p8GqJNPc4r_igU4X4ljLCLNKyY2Tw,8246
55
+ torchzero/modules/line_search/line_search.py,sha256=_u59XYFkRsIKuT1H4Bz7qAHr3Ldzxbup71OeqDGxMfs,9724
56
+ torchzero/modules/line_search/polynomial.py,sha256=KlK0d9qaphxS0s8B5rlt-yIUYNuV-5O24STcx4vN2Ic,9056
57
+ torchzero/modules/line_search/scipy.py,sha256=eGplW1L8kQKdRbt9PPpvZ6MMekDq5KsjurhSpN9QCnY,2301
58
+ torchzero/modules/line_search/strong_wolfe.py,sha256=F5962HTHdPWgvWHwnUofCqFxfKsCu5p8Ic-aRbn7wVg,8458
59
+ torchzero/modules/misc/__init__.py,sha256=cZpMkZQubuzquhFZV-yELrDMznqhhCibmr0CBOR0ZpU,693
60
+ torchzero/modules/misc/debug.py,sha256=iuWg5egoMnG6y3Cyd423xS7BRVYiwZq9575d7A7U3Dg,1652
61
+ torchzero/modules/misc/escape.py,sha256=1XgNmT4pOptaXHSWEONkUPpcYnIujm5gdK6n_-zmw20,1821
62
+ torchzero/modules/misc/gradient_accumulation.py,sha256=6yXRUxD_f3Zfx83UyCvPJ-56XN4GJjEQcNIDlvFtuuY,2590
63
+ torchzero/modules/misc/misc.py,sha256=VTQZAcfQBo2yudy1u1lyHhmaAmQlxzVcZTHcXXnUeTM,13470
64
+ torchzero/modules/misc/multistep.py,sha256=rAPCALSHXjVNxR8d1CA3RFP_xnN6j5KksjB6yl8vtng,5585
65
+ torchzero/modules/misc/regularization.py,sha256=R8ya7HEF2MLtcAr7GS9IjXwJ4xh0lJWMdWMIRfwL42s,6279
66
+ torchzero/modules/misc/split.py,sha256=ebc95OZjC-Vs73JeTkL--eZrtKijg7lPN0hmD0Whfxc,3195
67
+ torchzero/modules/misc/switch.py,sha256=72mfY_uIVyTllwuR21_K7QC8IQFP7JMKzH4K2nAx0Wc,3726
68
+ torchzero/modules/momentum/__init__.py,sha256=tI2I5zSQB7aTwEn371wvUTy2O2n_-KVCafjBv-OMsYE,545
69
+ torchzero/modules/momentum/averaging.py,sha256=gZRjHb443HuFF03p3Oh2rfgh2Qu8sJBxc_8NR-ircaA,3241
70
+ torchzero/modules/momentum/cautious.py,sha256=QP3Sqc8nMb3xTDDDfGwFn5AWvN4EI5U-CCcZb-F5oX0,8266
71
+ torchzero/modules/momentum/ema.py,sha256=9OdMF20RYnEkwe9Xu2dCAAiI0qY2MQvhS87bKP7ptTI,10755
72
+ torchzero/modules/momentum/experimental.py,sha256=WnM9FUKPxyFNiKU6Ip7wqqYxHbXuaMKOcLjjomfENb4,6916
73
+ torchzero/modules/momentum/matrix_momentum.py,sha256=gZeTJZbhgixCOkE9Jyowtva58hl5vsH9iTqGC54FWFs,8047
74
+ torchzero/modules/momentum/momentum.py,sha256=Yx35jtbLb1syVFcTiNSoZPoUPmdsUy3QpoNWcN4sC9w,2664
75
+ torchzero/modules/ops/__init__.py,sha256=1q9CBo6OXWXDgyjvKKTlG0EdP4ASIvkWFXtd6LOuU88,1083
76
+ torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
77
+ torchzero/modules/ops/binary.py,sha256=mIeaa3v5Bk7mwzSTC0jGMLhKf-Ujg6aFbSia2yo-3JQ,12199
78
+ torchzero/modules/ops/multi.py,sha256=DpabTYj0sic5dmosnmj7lgIX3dbmcgl0h9XfzKpbaus,8918
79
+ torchzero/modules/ops/reduce.py,sha256=uLCq493hFy_Ib22GjIKtMHTTObK3RDmubGHTVqgFgg8,6339
80
+ torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
81
+ torchzero/modules/ops/utility.py,sha256=9Skxkt4RO79OBdw95wOKhqKN2RMdZg9emO7R9q2d5oU,3767
82
+ torchzero/modules/optimizers/__init__.py,sha256=IJaLoZ39rbB4GSW9rLKrfSCh5FsAkFy2ww5MhJ6MYnE,817
83
+ torchzero/modules/optimizers/adagrad.py,sha256=p-DWbhGuuogldiFPNxxQfJ8AA5Tsd4UwGOIyX7GT0WE,5892
84
+ torchzero/modules/optimizers/adahessian.py,sha256=vOJfwGi7ypfi7vifCMJfGew-McdGJKQM3TmkT-OUgI0,8682
85
+ torchzero/modules/optimizers/adam.py,sha256=SkJ7UJ1BOAgfregmzYDFo_3cgPNke_RK9B58hOal_Zg,3954
86
+ torchzero/modules/optimizers/adan.py,sha256=aOG6KGLU4oHYeQn3JB-A4NQ-279QpHA7firY3kkhFR4,3311
87
+ torchzero/modules/optimizers/adaptive_heavyball.py,sha256=DnkWHA0GBLIKCq8nWh76fZA6PnJ3eKsJDBXWKnZ_uIs,2127
88
+ torchzero/modules/optimizers/esgd.py,sha256=WXwYPA-qTA_QW9h4NDwNaly9gbi1uvMQ-5fSuLqnPkQ,6413
89
+ torchzero/modules/optimizers/ladagrad.py,sha256=HQb7LuZnG8SvS8JWqu7JJz_owlkyT-fnqeICrJBQxbc,7314
90
+ torchzero/modules/optimizers/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
91
+ torchzero/modules/optimizers/mars.py,sha256=7tr32x2eQNu8ZVQAPnLIkM2kkYp7S57uiDywTdqy1uY,2710
92
+ torchzero/modules/optimizers/msam.py,sha256=nvoo6smewR3hiCCymZQiB3DlCvLBGxfxlovJF2bwwsc,6588
93
+ torchzero/modules/optimizers/muon.py,sha256=AZKpmkVUjukXtI7Pb9PKDEeycreLF6qYlIMSbV_9IuA,10463
94
+ torchzero/modules/optimizers/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
95
+ torchzero/modules/optimizers/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
96
+ torchzero/modules/optimizers/rprop.py,sha256=nFpnqcXevGkUcPWERDX9gsiBCGgOi4pyPFloL68zwPY,11984
97
+ torchzero/modules/optimizers/sam.py,sha256=yEhXAS3v62nhAvs63RZ80VfZ93MaQ0cyMQziFdy6e2U,5711
98
+ torchzero/modules/optimizers/shampoo.py,sha256=m_XOvo2Eb1HP8QqYFPsT0rgczJ8HqKjh67QmtaY9dVg,9544
99
+ torchzero/modules/optimizers/soap.py,sha256=MXQ8fdBzLyFtgW34fnmY3hQqv3q4QwEthho9kK-72VE,11305
100
+ torchzero/modules/optimizers/sophia_h.py,sha256=dgQwjij5R4zdESYoKhc4BMhb6dKkDuEvjlL4bDdeQtw,7213
101
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
102
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
103
+ torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
104
+ torchzero/modules/projections/projection.py,sha256=PU2e9LNfVMnNrXnBDt-hdr5pVtl0TpgiB4b92WUguSs,14005
105
+ torchzero/modules/quasi_newton/__init__.py,sha256=guTCpbAffZyupnThdPxAsLULAmPP3vdPaNfPCe9KB9Y,854
106
+ torchzero/modules/quasi_newton/cg.py,sha256=HCfza5UInco7_hYT8s3duNRTmBdjbw5jscWLKNUiS8w,14453
107
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=bMvIcWifYlJX83UtXFESMw7OdA4AO7tJwlHZwkc5wx0,6555
108
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=BmE5sOFLFoJDlpoSphM5VowMgt7wtEFihbLkdylDXhM,10638
109
+ torchzero/modules/quasi_newton/lsr1.py,sha256=a19a9ABqMiTVJmXe6Woc0sJ1kkhQa3Y6QDouaUNnPt0,7873
110
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=hKJ9Irmh2pKNfB7Wen4MrDfMrbvzp00FTcPlpFvJLDU,48582
111
+ torchzero/modules/quasi_newton/trust_region.py,sha256=cxOEDeZ8ZhG_w7QXGYnTsF-t5g5zZ39q9Uxb2IXWgAY,15213
112
+ torchzero/modules/second_order/__init__.py,sha256=Trje1qM65yp8WWzuRm-tMTRqfKi4wpI7f8yyZWjhPCw,152
113
+ torchzero/modules/second_order/newton.py,sha256=94LGrQo5Q8aC5DI9S6RSXF0stVcgWzq3JnE9l_BsVUw,12875
114
+ torchzero/modules/second_order/newton_cg.py,sha256=l8FX9vQSVCSkpk5a-M2wEBBjQoODF-T07GFW_tjJxkM,14890
115
+ torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
116
+ torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
117
+ torchzero/modules/smoothing/gaussian.py,sha256=iTsWlMNHuDLoxPRIsm2pAb5cS8OqdRJwCsw-vUTVmpE,7887
118
+ torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
119
+ torchzero/modules/step_size/__init__.py,sha256=Z8NpB9RYIXhcNx11NWixa7mORPiT4nI1mKQGA7JfC6g,122
120
+ torchzero/modules/step_size/adaptive.py,sha256=3qQr1aaPYEJlkiDSQbuVQ_OVkOq-W4LL7PkHFFgwP2c,4845
121
+ torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
122
+ torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
123
+ torchzero/modules/weight_decay/weight_decay.py,sha256=2MhWRyryplDtB61QyKN7KqBa3mEkhtqXhij8LGR-mYA,5464
124
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
125
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
126
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
127
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
128
+ torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
129
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
130
+ torchzero/optim/wrappers/directsearch.py,sha256=GQ2nzy9ADqbV_QUMN3IaYecZ0Pzx_3mAasSB4fryTBE,11362
131
+ torchzero/optim/wrappers/fcmaes.py,sha256=o_FchMtDsrEj9XRonHHeyVHPAXTHaU244SzlldgEzLg,4250
132
+ torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
133
+ torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
134
+ torchzero/optim/wrappers/nlopt.py,sha256=AaVEKfjbrt5DFION44_-g-jQAoVi4lCvBBPU5UDGO9Q,8151
135
+ torchzero/optim/wrappers/optuna.py,sha256=ZZ66aXEypSJMVomphbzHNJnmIOyXS9tqE89YZBPpIuo,2331
136
+ torchzero/optim/wrappers/scipy.py,sha256=Td1AvpLDEPqPVW6IpHbkVW4CpNiUU9r_eyc3qJVHZAY,19352
137
+ torchzero/utils/__init__.py,sha256=4JMKzF3qICE9PSfgXAwb3cPswM5f1JUutWwviev2-0k,875
138
+ torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
139
+ torchzero/utils/derivatives.py,sha256=IIn4stpMMJxYmGKh1JCH4Gha_a4w8Z5G04uVz2BwMP4,16995
140
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
141
+ torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
142
+ torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
143
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
144
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
145
+ torchzero/utils/python_tools.py,sha256=NEyDVJfLBbdwh5m49qiOdIr0NffZRqKhaJ-cktviD1o,3243
146
+ torchzero/utils/tensorlist.py,sha256=WvjhPzGbgRySAsUBFQ7b-39V9rm7jbR1VOeYZQXiiKw,53925
147
+ torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
148
+ torchzero/utils/linalg/__init__.py,sha256=tsUt20_rbA_3pV6NK7yCkGoX1l0D9ayMKwZeySsYxHw,291
149
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
150
+ torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
151
+ torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
152
+ torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
153
+ torchzero/utils/linalg/solve.py,sha256=JF0i_eJTBRKCs7CONUOV7coPjE46NC5nMaz2JotrvSE,11232
154
+ torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
155
+ torchzero-0.3.11.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
156
+ torchzero-0.3.11.dist-info/METADATA,sha256=Czo-sKnlVxQ75MhY3D61oD8lusASV0ez_l697dyJBNc,15797
157
+ torchzero-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
158
+ torchzero-0.3.11.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
159
+ torchzero-0.3.11.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.8.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5