torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -11,17 +11,19 @@ in an optimizer when you have to create one from parameters on each step. The so
11
11
  it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
12
12
  """
13
13
  import builtins
14
- from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
15
14
  import math
16
15
  import operator
16
+ from abc import ABC, abstractmethod
17
+ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
17
18
  from typing import Any, Literal, TypedDict, overload
18
- from typing_extensions import Self, TypeAlias, Unpack
19
19
 
20
20
  import torch
21
- from .ops import where_
22
- from .python_tools import zipmap, generic_ne
23
- from .numberlist import NumberList, as_numberlist, maybe_numberlist
21
+ from typing_extensions import Self, TypeAlias, Unpack
24
22
 
23
+ from .metrics import Metrics, evaluate_metric, calculate_metric_list
24
+ from .numberlist import NumberList, as_numberlist, maybe_numberlist
25
+ from .ops import where_
26
+ from .python_tools import generic_ne, zipmap
25
27
 
26
28
  _Scalar = int | float | bool | complex
27
29
  _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
@@ -33,6 +35,7 @@ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
33
35
  _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
34
36
 
35
37
  Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
38
+
36
39
  class _NewTensorKwargs(TypedDict, total = False):
37
40
  memory_format: Any
38
41
  dtype: Any
@@ -325,9 +328,20 @@ class TensorList(list[torch.Tensor | Any]):
325
328
  def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
326
329
  def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
327
330
  def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
328
- def global_vector_norm(self, ord:float | Literal['mean_abs'] = 2) -> torch.Tensor:
329
- if ord == 'mean_abs': return self.abs().global_mean()
330
- return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
331
+
332
+ def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
333
+ # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
334
+ if ord == 1: return self.global_sum()
335
+ if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
336
+ if ord == torch.inf: return self.abs().global_max()
337
+ if ord == -torch.inf: return self.abs().global_min()
338
+ if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
339
+
340
+ return self.abs().pow_(ord).global_sum().pow(1/ord)
341
+
342
+ def global_metric(self, metric: Metrics) -> torch.Tensor:
343
+ return evaluate_metric(self, metric)
344
+
331
345
  def global_any(self): return builtins.any(self.any())
332
346
  def global_all(self): return builtins.all(self.all())
333
347
  def global_numel(self) -> int: return builtins.sum(self.numel())
@@ -358,31 +372,54 @@ class TensorList(list[torch.Tensor | Any]):
358
372
 
359
373
  def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
360
374
  return self.zipmap_args(torch.randint_like, low, high, **kwargs)
375
+
361
376
  def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
362
377
  res = self.empty_like(**kwargs)
363
378
  res.uniform_(low, high, generator=generator)
364
379
  return res
380
+
365
381
  def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
366
382
  r = self.randn_like(generator=generator, **kwargs)
367
- return (r * radius) / r.global_vector_norm()
383
+ return r.mul_(maybe_numberlist(radius) / r.global_vector_norm())
384
+
368
385
  def bernoulli(self, generator = None):
369
386
  return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
387
+
370
388
  def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
371
389
  """p is probability of a 1, other values will be 0."""
372
390
  return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
391
+
373
392
  def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
374
393
  """p is probability of a 1, other values will be -1."""
375
394
  return self.bernoulli_like(p, generator=generator, **kwargs).mul_(2).sub_(1)
376
395
 
377
- def sample_like(self, eps: "_Scalar | _ScalarSeq" = 1, distribution: Distributions = 'normal', generator=None, **kwargs: Unpack[_NewTensorKwargs]):
396
+ def sample_like(self, distribution: Distributions = 'normal', variance: "_Scalar | _ScalarSeq | Sequence | None" = None, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
378
397
  """Sample around 0."""
379
- if distribution in ('normal', 'gaussian'): return self.randn_like(generator=generator, **kwargs) * eps
398
+ if isinstance(variance, Sequence):
399
+ if all(v is None for v in variance): variance = None
400
+ else: variance = [v if v is not None else 1 for v in variance]
401
+
402
+ if distribution in ('normal', 'gaussian'):
403
+ ret = self.randn_like(generator=generator, **kwargs)
404
+ if variance is not None: ret *= variance
405
+ return ret
406
+
380
407
  if distribution == 'uniform':
381
- if isinstance(eps, (list,tuple)):
382
- return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs)
383
- return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)
384
- if distribution == 'sphere': return self.sphere_like(eps, generator=generator, **kwargs)
385
- if distribution == 'rademacher': return self.rademacher_like(generator=generator, **kwargs) * eps
408
+ b = 1
409
+ if variance is not None:
410
+ b = ((12 * maybe_numberlist(variance)) ** 0.5) / 2
411
+ return self.uniform_like(-b, b, generator=generator, **kwargs)
412
+
413
+ if distribution == 'sphere':
414
+ if variance is None: radius = 1
415
+ else: radius = maybe_numberlist(variance) * math.sqrt(self.global_numel())
416
+ return self.sphere_like(radius, generator=generator, **kwargs)
417
+
418
+ if distribution == 'rademacher':
419
+ ret = self.rademacher_like(generator=generator, **kwargs)
420
+ if variance is not None: ret *= variance
421
+ return ret
422
+
386
423
  raise ValueError(f'Unknow distribution {distribution}')
387
424
 
388
425
  def eq(self, other: _STOrSTSeq): return self.zipmap(torch.eq, other)
@@ -504,6 +541,11 @@ class TensorList(list[torch.Tensor | Any]):
504
541
  torch._foreach_pow_(input, self)
505
542
  return self
506
543
 
544
+ def square(self): return self.__class__(torch._foreach_pow(self, 2))
545
+ def square_(self):
546
+ torch._foreach_pow_(self, 2)
547
+ return self
548
+
507
549
  def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
508
550
  def sqrt_(self):
509
551
  torch._foreach_sqrt_(self)
@@ -634,10 +676,12 @@ class TensorList(list[torch.Tensor | Any]):
634
676
  if dim is None: dim = ()
635
677
  return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
636
678
 
637
- def norm(self, ord: _Scalar|Literal["mean_abs"], dtype=None):
638
- if isinstance(ord, str): return self.abs().mean()
679
+ def norm(self, ord: float, dtype=None):
639
680
  return self.__class__(torch._foreach_norm(self, ord, dtype))
640
681
 
682
+ def metric(self, metric: Metrics) -> "TensorList":
683
+ return calculate_metric_list(self, metric)
684
+
641
685
  def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
642
686
  if dim == 'global': return self._global_fn(keepdim, self.global_mean)
643
687
  return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
@@ -790,27 +834,27 @@ class TensorList(list[torch.Tensor | Any]):
790
834
  for t, o in zip(self, other): t.copysign_(o)
791
835
  return self
792
836
 
793
- def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
837
+ def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
794
838
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
795
839
  if tensorwise:
796
- norm_self = self.norm(ord)
797
- norm_other = magnitude.norm(ord)
840
+ norm_self = self.metric(ord)
841
+ norm_other = magnitude.metric(ord)
798
842
  else:
799
- norm_self = self.global_vector_norm(ord)
800
- norm_other = magnitude.global_vector_norm(ord)
843
+ norm_self = self.global_metric(ord)
844
+ norm_other = magnitude.global_metric(ord)
801
845
 
802
846
  if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
803
847
 
804
848
  return self * (norm_other / norm_self.clip_(min=eps))
805
849
 
806
- def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
850
+ def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
807
851
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
808
852
  if tensorwise:
809
- norm_self = self.norm(ord)
810
- norm_other = magnitude.norm(ord)
853
+ norm_self = self.metric(ord)
854
+ norm_other = magnitude.metric(ord)
811
855
  else:
812
- norm_self = self.global_vector_norm(ord)
813
- norm_other = magnitude.global_vector_norm(ord)
856
+ norm_self = self.global_metric(ord)
857
+ norm_other = magnitude.global_metric(ord)
814
858
 
815
859
  if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
816
860
 
@@ -905,14 +949,14 @@ class TensorList(list[torch.Tensor | Any]):
905
949
  if eps!=0: std.add_(eps)
906
950
  return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
907
951
 
908
- def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
952
+ def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
909
953
  """calculate multipler to clip self norm to min and max"""
910
954
  if tensorwise:
911
- self_norm = self.norm(ord)
955
+ self_norm = self.metric(ord)
912
956
  self_norm.masked_fill_(self_norm == 0, 1)
913
957
 
914
958
  else:
915
- self_norm = self.global_vector_norm(ord)
959
+ self_norm = self.global_metric(ord)
916
960
  if self_norm == 0: return 1
917
961
 
918
962
  mul = 1
@@ -926,12 +970,12 @@ class TensorList(list[torch.Tensor | Any]):
926
970
 
927
971
  return mul
928
972
 
929
- def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
973
+ def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
930
974
  """clips norm of each tensor to (min, max) range"""
931
975
  if min is None and max is None: return self
932
976
  return self * self._clip_multiplier(min, max, tensorwise, ord)
933
977
 
934
- def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
978
+ def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
935
979
  """clips norm of each tensor to (min, max) range"""
936
980
  if min is None and max is None: return self
937
981
  return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
@@ -990,6 +1034,15 @@ class TensorList(list[torch.Tensor | Any]):
990
1034
  # """sets index in flattened view"""
991
1035
  # return self.clone().flatset_(idx, value)
992
1036
 
1037
+ def flat_get(self, idx: int):
1038
+ cur = 0
1039
+ for tensor in self:
1040
+ numel = tensor.numel()
1041
+ if idx < cur + numel:
1042
+ return tensor.view(-1)[cur-idx]
1043
+ cur += numel
1044
+ raise IndexError(idx)
1045
+
993
1046
  def flat_set_(self, idx: int, value: Any):
994
1047
  """sets index in flattened view"""
995
1048
  cur = 0
@@ -1065,10 +1118,19 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
1065
1118
  if isinstance(x, torch.Tensor): return x.numel()
1066
1119
  return x.global_numel()
1067
1120
 
1121
+
1122
+ def generic_finfo(x: torch.Tensor | TensorList) -> torch.finfo:
1123
+ if isinstance(x, torch.Tensor): return torch.finfo(x.dtype)
1124
+ return torch.finfo(x[0].dtype)
1125
+
1068
1126
  def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
1069
1127
  if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
1070
1128
  return torch.finfo(x[0].dtype).eps
1071
1129
 
1130
+ def generic_finfo_tiny(x: torch.Tensor | TensorList) -> float:
1131
+ if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).tiny
1132
+ return torch.finfo(x[0].dtype).tiny
1133
+
1072
1134
  @overload
1073
1135
  def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
1074
1136
  @overload
@@ -1081,7 +1143,8 @@ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
1081
1143
  if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
1082
1144
  return x.global_vector_norm(ord)
1083
1145
 
1084
-
1146
+ def generic_metric(x: torch.Tensor | TensorList, metric: Metrics) -> torch.Tensor:
1147
+ return evaluate_metric(x, metric)
1085
1148
 
1086
1149
  @overload
1087
1150
  def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
@@ -1091,3 +1154,11 @@ def generic_randn_like(x: torch.Tensor | TensorList):
1091
1154
  if isinstance(x, torch.Tensor): return torch.randn_like(x)
1092
1155
  return x.randn_like()
1093
1156
 
1157
+
1158
+ def generic_sum(x: torch.Tensor | TensorList) -> torch.Tensor:
1159
+ if isinstance(x, torch.Tensor): return x.sum()
1160
+ return x.global_sum()
1161
+
1162
+ def generic_max(x: torch.Tensor | TensorList) -> torch.Tensor:
1163
+ if isinstance(x, torch.Tensor): return x.max()
1164
+ return x.global_max()
@@ -7,10 +7,15 @@ import numpy as np
7
7
  import torch
8
8
 
9
9
 
10
- def totensor(x):
11
- if isinstance(x, torch.Tensor): return x
12
- if isinstance(x, np.ndarray): return torch.from_numpy(x)
13
- return torch.from_numpy(np.asarray(x))
10
+ def totensor(x, device=None, dtype=None):
11
+ if device is None and dtype is None:
12
+ if isinstance(x, torch.Tensor): return x
13
+ if isinstance(x, np.ndarray): return torch.from_numpy(x)
14
+ return torch.from_numpy(np.asarray(x))
15
+
16
+ if isinstance(x, torch.Tensor): return x.to(device=device, dtype=dtype)
17
+ if isinstance(x, np.ndarray): return torch.as_tensor(x, device=device, dtype=dtype)
18
+ return torch.as_tensor(np.asarray(x), device=device, dtype=dtype)
14
19
 
15
20
  def tonumpy(x):
16
21
  if isinstance(x, np.ndarray): return x
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchzero
3
+ Version: 0.3.13
4
+ Summary: Modular optimization library for PyTorch.
5
+ Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/inikishev/torchzero
7
+ Project-URL: Repository, https://github.com/inikishev/torchzero
8
+ Project-URL: Issues, https://github.com/inikishev/torchzero/isses
9
+ Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: torch
13
+ Requires-Dist: numpy
14
+ Requires-Dist: typing_extensions
@@ -0,0 +1,166 @@
1
+ tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
2
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
+ tests/test_opts.py,sha256=aT6-RbyUhWlIhMPi-ihqZtfiiYk0eT9vEIxMyxvwVOk,44059
4
+ tests/test_tensorlist.py,sha256=pWXQE-vEq08EGJSKWgsTgo-7QjjkavOJ5BlWUm241qI,72434
5
+ tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
6
+ tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
7
+ torchzero/__init__.py,sha256=aIH-cCTXnDr90cKUPhM8bv-uE69Hzjlf0jlYspYf0ZM,120
8
+ torchzero/core/__init__.py,sha256=aYyQt-CHzWT6hGUt5KVjRZZr2lsX5I1XvbWpzaAv3VE,151
9
+ torchzero/core/module.py,sha256=o4MZnJbMpk-F1qm-XvbguE_W_9a0sO2Mb8iDUZQ42B4,40511
10
+ torchzero/core/reformulation.py,sha256=jppgzXBtqdsc7ot6_Gr38vJbbrhG1Gs4vC32y7iB4BA,2387
11
+ torchzero/core/transform.py,sha256=xRDpsZj0H1QcFdO-t2mNMNOYoqqnRHiI3K1YluWwCVk,17097
12
+ torchzero/modules/__init__.py,sha256=3lGta9P0N3cWdVcruCBJ7uqu4DfLTPCKI_mlOZT6Z_o,615
13
+ torchzero/modules/functional.py,sha256=E_d6hLT2_xdE-3AhQ4AthDYK5uZULbF10iHI09Z3_yk,7921
14
+ torchzero/modules/adaptive/__init__.py,sha256=5L2dlEJV6HKBnYhgd7wo2yGi0WPd9qmpw9XS5wOQOq8,944
15
+ torchzero/modules/adaptive/adagrad.py,sha256=0qXC5F4PuOsgLjRXQUWBoiq0AUixsvOP1uDbEeRIcNs,12531
16
+ torchzero/modules/adaptive/adahessian.py,sha256=rWxgDiBMd6MK64mRjZwiudDF07is5AFj_MCEvBD7h8U,8670
17
+ torchzero/modules/adaptive/adam.py,sha256=4lWSe__tdyRv0rfkUda1qa_NH36DIW3sd8td98bK6XI,3829
18
+ torchzero/modules/adaptive/adan.py,sha256=Dt_gibyrGtWDIUCaSF6RFIxu2xwiF9fCfrpkoD-CaUM,2825
19
+ torchzero/modules/adaptive/adaptive_heavyball.py,sha256=xQQw1Vx-NgQd_ouK14J1p5ijd5lEn3sAN0hVJAL0j8U,2024
20
+ torchzero/modules/adaptive/aegd.py,sha256=_4ASgDX8__DPnnBE_RncnMqM4rItM7Eji4EZzSGGq5I,1876
21
+ torchzero/modules/adaptive/esgd.py,sha256=DtrN3hZhGK4LgMxmcjQCBtEO-5hLZrAdnvps6f8WQ2A,6416
22
+ torchzero/modules/adaptive/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
23
+ torchzero/modules/adaptive/lmadagrad.py,sha256=rMs7vrgiwOJgWo-OZXkGu32X561edEhhaxwgbY2NTnk,7176
24
+ torchzero/modules/adaptive/mars.py,sha256=iOkyY3r52btp7Cry7WN0AB4arpn4N9b_Hg6S55XC6Q8,2255
25
+ torchzero/modules/adaptive/matrix_momentum.py,sha256=ZMSdGNSHgyUgtJwjKzK7PZDwMrPw5wZJVu0K4Xa-SpI,6693
26
+ torchzero/modules/adaptive/msam.py,sha256=locqM2jiC3AbGCDCo6T40GF3iaVV2svrkzoi0hD2cJI,6663
27
+ torchzero/modules/adaptive/muon.py,sha256=5Asgj03s6JXXrO-p5Qgn3D8bVwbDEqCq7hxNy4joQDE,10335
28
+ torchzero/modules/adaptive/natural_gradient.py,sha256=5qRehh-iAeZ4hjfOR-gfsObbMsJZnipzCkG-yptkrH0,6349
29
+ torchzero/modules/adaptive/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
30
+ torchzero/modules/adaptive/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
31
+ torchzero/modules/adaptive/rprop.py,sha256=VDnLPKxw8ECihyUeNVE8cyDll_Ut3k3_NqoLgpgxxLA,11818
32
+ torchzero/modules/adaptive/sam.py,sha256=LnOPNZnIUsis0402RHnA-fTPkNM8baUR9HR50pF_BtM,5696
33
+ torchzero/modules/adaptive/shampoo.py,sha256=r7V4I5_Ve1YVOS3HhO2k5cZvJT1lPHTVApV3iVJVceA,9711
34
+ torchzero/modules/adaptive/soap.py,sha256=roQLBthNNNmTYgeJPi_LxZY-r4m6REeUo0_DZknYU50,10662
35
+ torchzero/modules/adaptive/sophia_h.py,sha256=lSK8uVdOxBAhU2jE6fyIx1YgqEQyZG-Fv9o2TniAZzk,7179
36
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
37
+ torchzero/modules/clipping/clipping.py,sha256=t98M3QKZKqXJ3_tzXXIiG4EOYMaHqLYrMZ-6zmRuy-k,14331
38
+ torchzero/modules/clipping/ema_clipping.py,sha256=Ki0LPNUwPoE825A5rSE7SxGQMiI3nO3iwnjKQ486iaI,6611
39
+ torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
40
+ torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
41
+ torchzero/modules/conjugate_gradient/cg.py,sha256=iAIiIyfM5hWeFH6-gxx8y-5olY0ED4DpnbLzXa9dke4,14492
42
+ torchzero/modules/experimental/__init__.py,sha256=blI-OhpQAC6-Ho1uxUq-t7Mm9CAMnNMXkBDmXul8tbc,729
43
+ torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
44
+ torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
45
+ torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
46
+ torchzero/modules/experimental/gradmin.py,sha256=hKTOG7tk6FnG8t-7OmTAhGTGSDdONzP1JvCRPRqaKt0,3740
47
+ torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
48
+ torchzero/modules/experimental/momentum.py,sha256=VqZc14EGVO_KUABPLRIBlvHdgg-64o-4heMQH0vW5vY,5233
49
+ torchzero/modules/experimental/newton_solver.py,sha256=0HnDBlrBLvUgS4hWmkJqyw0M7UPFp3kU3SFq2xZVYhQ,5454
50
+ torchzero/modules/experimental/newtonnewton.py,sha256=a0XXvlVe37z2MMcQ4TeGbbWX9OuYp_5b-21jq3o1z3E,3823
51
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
52
+ torchzero/modules/experimental/scipy_newton_cg.py,sha256=8nKBaHHmqdU9F1kVPn2QVFUTx2_I8Jsfqxix1v-qoL0,4073
53
+ torchzero/modules/experimental/structural_projections.py,sha256=rxJFG5F23dOiK_8KqKyvoSMLWqAOXtVGHSwfRqH22Wg,4185
54
+ torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
55
+ torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
56
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=v7v5maMC_Vak7N5w-LjIH6FIrkQUt7MvbR0PLprsmTI,4338
57
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=x8vlweBrfJ6SmhMHvI_C8UZGzlS3AnmlulvqnSzm6iY,4437
58
+ torchzero/modules/grad_approximation/rfdm.py,sha256=PA0FlGVaNHxstUqPqG1wMuIF4i9NpCqPpp6-1vxCboA,23635
59
+ torchzero/modules/higher_order/__init__.py,sha256=iaoIrmR9DJE9QHt9PeZNCWqIYDe-86h1IjkaumR4qF0,51
60
+ torchzero/modules/higher_order/higher_order_newton.py,sha256=2r1wuhdi57pbo8akQE88O8R-Y79BtiwD1WQIShh1rjQ,12967
61
+ torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
62
+ torchzero/modules/least_squares/gn.py,sha256=23AB6AWAl5IuBj4Vd3boQ6ndk0pO3ovaF9EiY1a1XWs,5094
63
+ torchzero/modules/line_search/__init__.py,sha256=mFWgkcgfMkL2NKj3CLbuwee3e8WHBOaXs-wtx3oTW58,216
64
+ torchzero/modules/line_search/_polyinterp.py,sha256=qIhcLjOlpB6NHU0oiUGMncwQxWNfy8757orsbzjkp6s,10882
65
+ torchzero/modules/line_search/adaptive.py,sha256=8Ip5F5PpsDLgg6TwB_E7zIZheycd78coRg4u7cpO3Cg,3795
66
+ torchzero/modules/line_search/backtracking.py,sha256=Mhx8_UT_Mr1gASYHUorBJ38E4YlcM9LpW9YrJHYfLXU,9049
67
+ torchzero/modules/line_search/line_search.py,sha256=lmtjr9Zpz9RYJXoYaJnpXkBSIdcN6DdwGKKXTCmcJNU,13294
68
+ torchzero/modules/line_search/scipy.py,sha256=gQMi6IYrnDvYsZWIO_cELhg_VZutIQJAELHHlLyu2fg,2286
69
+ torchzero/modules/line_search/strong_wolfe.py,sha256=lXcfOzg4kU0RGTe7GnVWuDPs2YWAjHH9vwRMDEGz4Mg,15054
70
+ torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
71
+ torchzero/modules/misc/debug.py,sha256=6pFAGYANjCPGIZH_4ghpUYYTEsT5jr7PMB9KLuPP4p8,1532
72
+ torchzero/modules/misc/escape.py,sha256=qfEdKLD5rejqrmvyHrI5BRQq8js9UF2-Axs_C0KFyWA,1866
73
+ torchzero/modules/misc/gradient_accumulation.py,sha256=mBWa5CBCZwp4TrtOyjWI3VnHag4gum4WBM2WFhvHqW4,4891
74
+ torchzero/modules/misc/homotopy.py,sha256=hihLETE4dNZ27zatqPR_qT3kGX-AXbC7oBWRDbFQo58,1939
75
+ torchzero/modules/misc/misc.py,sha256=feI-IQlxhIoAbsSRTjE4SbGez1c2Uu9-WA_nkK7iiqQ,15411
76
+ torchzero/modules/misc/multistep.py,sha256=RtDFIeTHu4RcERvlKEP4_10-lpRZOgbnBeSah92dQ7A,6323
77
+ torchzero/modules/misc/regularization.py,sha256=SkQ0_Ybtv9IEGI9QGdvNZaja5bAyc1x-j_1gvYIVepI,6105
78
+ torchzero/modules/misc/split.py,sha256=JcXVB4xk3h55YT2OAdepVsRoE1PD7bqX6NmJ2IxBgAI,4013
79
+ torchzero/modules/misc/switch.py,sha256=p758heAnv-PkoslpafL35Yp7mlvPmDVSe1mWiuuD8Mk,3711
80
+ torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
81
+ torchzero/modules/momentum/averaging.py,sha256=vDW8tgGsEuBXF_BTUYHB_j--TIVam9j0nZdp_x8TkxY,3229
82
+ torchzero/modules/momentum/cautious.py,sha256=x506a3lUETRpxPWqXLFJVFBH1gmLqIfqL5J-hFdEvOo,8051
83
+ torchzero/modules/momentum/momentum.py,sha256=q3n0BvQURuSBzA9vn1ZrH-n7Nsr0AS-38VJuwraQPY0,4495
84
+ torchzero/modules/ops/__init__.py,sha256=9UHaXs9aaKc0ewAhicTlDmj42bSC_vddMOD0eYuUj_8,1226
85
+ torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
86
+ torchzero/modules/ops/binary.py,sha256=2hV2oruaq5Asu4Ts8X8yiZQ-07fU0RGpRy3-vifXqXY,12151
87
+ torchzero/modules/ops/higher_level.py,sha256=E76zgSHlhVpHLrXhnVwelIQFm1IKn0IFcVq7DOJw0es,9037
88
+ torchzero/modules/ops/multi.py,sha256=YC3rBTmPRwF5aEPDNsyTK4J_JEAbmE7oBmF7W-VOV3A,8588
89
+ torchzero/modules/ops/reduce.py,sha256=kALG7X8q02sWpo1skpXjS0r875gwq6mrhLZbFfYaZoA,6324
90
+ torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
91
+ torchzero/modules/ops/utility.py,sha256=_k9S59i6IYOzzfIQlToQ9mlDseTTAS_49wujUxMGXZo,4105
92
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
93
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
94
+ torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
95
+ torchzero/modules/projections/projection.py,sha256=itkkb2UmMqbdtWKjUUg6gbFJfCEIZAskC0HCvom-6sc,14084
96
+ torchzero/modules/quasi_newton/__init__.py,sha256=HxXENs3O6nFRfCvUJhWPK9f8_A6iMwB6UF1Zold12UQ,515
97
+ torchzero/modules/quasi_newton/damping.py,sha256=K1DVqqKiAs6-F3JQh5jlKNb79oJdObqnKWwHHRl6boQ,2813
98
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
99
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=fzCjV5YsLo_uJTVG3vosPHsvDc97mLKueK6fxOHLb8I,11195
100
+ torchzero/modules/quasi_newton/lsr1.py,sha256=D3_yV5xtgklMlU4fAL1-sH82-1tNl3K2F12ZBZyLQGM,8512
101
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=-xUGPld8Y0MHwN6qsmDihLbUbulU0T1z8jf2mZhNpcE,44529
102
+ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
103
+ torchzero/modules/restarts/restars.py,sha256=A3fpTIbfpZCEUq9csPckdcsXQtaL0Le5UY3ZfKzxVSs,8971
104
+ torchzero/modules/second_order/__init__.py,sha256=lTGccDNVwPuMevMeKi5O0a9cl24Rn9tk7VkC6jvlGYc,233
105
+ torchzero/modules/second_order/multipoint.py,sha256=Ilzo0Ddd3iApegceu7cHSMGim9ZH5QS4-2uBtrKXC6k,8581
106
+ torchzero/modules/second_order/newton.py,sha256=PAPbJzssx0Ji328BFOEzeJZPd3IubJTPHs6ZhqS_nW8,15663
107
+ torchzero/modules/second_order/newton_cg.py,sha256=zavattL2z-IjWRT_AdwV5h7BGtQnrBzMTtTyt9xjZ-I,17363
108
+ torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
109
+ torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
110
+ torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
111
+ torchzero/modules/smoothing/sampling.py,sha256=zI5bATytQmCqm_UgAQbfA9tNRgrZaKLfUb0B-kzKRHU,12867
112
+ torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
113
+ torchzero/modules/step_size/adaptive.py,sha256=HvffW3m1NnpMTpps0QjJTvbblSODxxWMBBFTbNwp0vM,14482
114
+ torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
115
+ torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
116
+ torchzero/modules/termination/termination.py,sha256=BXU3R04caBc8rFJ4v_yJjgGi1X4iA11eYwlbiJfxexI,6637
117
+ torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
118
+ torchzero/modules/trust_region/cubic_regularization.py,sha256=gbKOR5zBo3t9i-sW23DCtTQwZrBubuFy_VuafrLaeUw,6718
119
+ torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
120
+ torchzero/modules/trust_region/levenberg_marquardt.py,sha256=Ibyf3ivcGR9sPkD5COXi7dRk3PSOfyTlI1W8ISAHNa8,5039
121
+ torchzero/modules/trust_region/trust_cg.py,sha256=UdQxNx7jf_CxyioRtJ92z35QU5HDbI22xpgd-4pW7V8,4297
122
+ torchzero/modules/trust_region/trust_region.py,sha256=eimCFViJSzoubrRmDluCon6mfcyT7PQA0yRPu4FlO2Q,12872
123
+ torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
124
+ torchzero/modules/variance_reduction/svrg.py,sha256=9pBjPY4EMkGyfj68gXqPi1GJIolUVl5zyNtlZInCKKo,8635
125
+ torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
126
+ torchzero/modules/weight_decay/weight_decay.py,sha256=Y7kE_j0GRF8ceJ9SS6qykQ8a23X2OTDCjJ9VklOQSEw,5415
127
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
128
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
129
+ torchzero/modules/zeroth_order/__init__.py,sha256=1_6wNrytru7tEHXzRXmL4AnK39ILPgf8FMVtF_YmAYU,30
130
+ torchzero/modules/zeroth_order/cd.py,sha256=6NL_xe56w1RbPPgxcggnQnD9eWNq7ZrhZjv4bZwq2Ms,14951
131
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
132
+ torchzero/optim/root.py,sha256=gGtAJ9qBoSNV58EKzUGZ8J3lyKGUF8BEw34Zfprppdo,2273
133
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
134
+ torchzero/optim/utility/split.py,sha256=kraPCLAewX2uLbD_9R2dIrcF-kpUuT9IcxPeVrAARvA,1672
135
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
136
+ torchzero/optim/wrappers/directsearch.py,sha256=rimJIB2RrVzLpRPQKhzkrMQ4bTAEU3NEOT4pJQNIAHE,11309
137
+ torchzero/optim/wrappers/fcmaes.py,sha256=jKmmBKEwguYiJdvTRmAp5JSilxcUhtpRoKlzmp-lyWE,4251
138
+ torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
139
+ torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
140
+ torchzero/optim/wrappers/nlopt.py,sha256=RuLKretljjAjTZ9tXY3FyEDuB7mAboeGOQBupWfzPc4,8105
141
+ torchzero/optim/wrappers/optuna.py,sha256=pIXkC5NVmEnUQ4jsGaz6Gv9uYOZM9rnxME4UGkeolsE,2393
142
+ torchzero/optim/wrappers/scipy.py,sha256=A4yeQRdB9f65UrJ2g80NfqqMc6zUyr9js40TUESCHPg,21535
143
+ torchzero/utils/__init__.py,sha256=7S4VRTkfS-0uI8HOR0EFIjiEcKrmYK7LEhTocIgki6c,1112
144
+ torchzero/utils/compile.py,sha256=Dozox91tcShUJ3L320TTbJrcuA-l4WVegLAQujRqy94,5132
145
+ torchzero/utils/derivatives.py,sha256=zJ0xyedvlIwgAYMa1F5BBfyrkvgjXy7v7evvl6QAlT0,17195
146
+ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
147
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
148
+ torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
149
+ torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
150
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
151
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
152
+ torchzero/utils/python_tools.py,sha256=kdiGk-I0Q-GpIVu3pCROkWvUHiDgzsagLgEsTzZplQw,3427
153
+ torchzero/utils/tensorlist.py,sha256=nIWBME3fUQPsr4buvtV3LaJgSXPEG_Xb58KAzfjwK-k,56064
154
+ torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
155
+ torchzero/utils/linalg/__init__.py,sha256=cNoTJOPeqbNn9l7_HAAen2rlehGS3DyY5SveInG3Stc,328
156
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
157
+ torchzero/utils/linalg/linear_operator.py,sha256=uJUxvOVHpG3U3GNx61JGa_uM8GqzsNZmA_z7P0RwZ5E,12747
158
+ torchzero/utils/linalg/matrix_funcs.py,sha256=BKQK_oIG35R6yGxU80eBG0VkyY2EgxywqbhvU7JhWm4,3109
159
+ torchzero/utils/linalg/orthogonalize.py,sha256=BpuDiAPrsJMUpTNBMCntBNA8-O2nozLxY5ZbCfRlEFY,444
160
+ torchzero/utils/linalg/qr.py,sha256=5tbPEV9I6X69r5ACWF9XeqjZTUtUql2145uoGjlJNDs,2517
161
+ torchzero/utils/linalg/solve.py,sha256=R5lPTzHn2sgvRy4MRp-Ngl0sypSGLRLHJjf1oKKAJD0,14395
162
+ torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
163
+ torchzero-0.3.13.dist-info/METADATA,sha256=onWv9DCY_mvI2vm-1MYrkRfTfJvWDLKDpuNGZO1ill0,565
164
+ torchzero-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
165
+ torchzero-0.3.13.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
166
+ torchzero-0.3.13.dist-info/RECORD,,
@@ -1,3 +1,2 @@
1
- docs
2
1
  tests
3
2
  torchzero
docs/source/conf.py DELETED
@@ -1,59 +0,0 @@
1
- # Configuration file for the Sphinx documentation builder.
2
- #
3
- # For the full list of built-in configuration values, see the documentation:
4
- # https://www.sphinx-doc.org/en/master/usage/configuration.html
5
-
6
- # -- Project information -----------------------------------------------------
7
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8
- import sys, os
9
- #sys.path.insert(0, os.path.abspath('.../src'))
10
-
11
- project = 'torchzero'
12
- copyright = '2025, Ivan Nikishev'
13
- author = 'Ivan Nikishev'
14
-
15
- # -- General configuration ---------------------------------------------------
16
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
17
-
18
- # https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
19
- extensions = [
20
- 'sphinx.ext.autodoc',
21
- 'sphinx.ext.autosummary',
22
- 'sphinx.ext.viewcode',
23
- 'sphinx.ext.autosectionlabel',
24
- 'sphinx.ext.githubpages',
25
- 'sphinx.ext.napoleon',
26
- 'autoapi.extension',
27
- "myst_nb",
28
-
29
- # 'sphinx_rtd_theme',
30
- ]
31
- autosummary_generate = True
32
- autoapi_dirs = ['../../torchzero']
33
- autoapi_type = "python"
34
- # autoapi_ignore = ["*/tensorlist.py"]
35
-
36
- # https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
37
- autoapi_options = [
38
- "members",
39
- "undoc-members",
40
- "show-inheritance",
41
- "show-module-summary",
42
- "imported-members",
43
- ]
44
-
45
-
46
- templates_path = ['_templates']
47
- exclude_patterns = []
48
-
49
- # -- Options for HTML output -------------------------------------------------
50
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
51
-
52
- #html_theme = 'alabaster'
53
- html_theme = 'sphinx_rtd_theme'
54
- html_static_path = ['_static']
55
-
56
-
57
- # OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
58
- source_suffix = ['.rst', '.md']
59
- master_doc = 'index'
@@ -1,46 +0,0 @@
1
- class MyModule:
2
- """[One-line summary of the class].
3
-
4
- [A more detailed description of the class, explaining its purpose, how it
5
- works, and its typical use cases. You can use multiple paragraphs.]
6
-
7
- .. note::
8
- [Optional: Add important notes, warnings, or usage guidelines here.
9
- For example, you could mention if a closure is required, discuss
10
- stability, or highlight performance characteristics. Use the `.. note::`
11
- directive to make it stand out in the documentation.]
12
-
13
- Args:
14
- param1 (type, optional):
15
- [Description of the first parameter. Use :code:`backticks` for
16
- inline code like variable names or specific values like ``"autograd"``.
17
- Explain what the parameter does.] Defaults to [value].
18
- param2 (type):
19
- [Description of a mandatory parameter (no "optional" or "Defaults to").]
20
- **kwargs:
21
- [If you accept keyword arguments, describe what they are used for.]
22
-
23
- Examples:
24
- [A title or short sentence describing the first example]:
25
-
26
- .. code-block:: python
27
-
28
- opt = tz.Modular(
29
- model.parameters(),
30
- ...
31
- )
32
-
33
- [A title or short sentence for a second, different example]:
34
-
35
- .. code-block:: python
36
-
37
- opt = tz.Modular(
38
- model.parameters(),
39
- ...
40
- )
41
-
42
- References:
43
- - [Optional: A citation for a relevant paper, book, or algorithm.]
44
- - [Optional: A link to a blog post or website with more information.]
45
-
46
- """