torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -11,17 +11,19 @@ in an optimizer when you have to create one from parameters on each step. The so
11
11
  it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
12
12
  """
13
13
  import builtins
14
- from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
15
14
  import math
16
15
  import operator
16
+ from abc import ABC, abstractmethod
17
+ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
17
18
  from typing import Any, Literal, TypedDict, overload
18
- from typing_extensions import Self, TypeAlias, Unpack
19
19
 
20
20
  import torch
21
- from .ops import where_
22
- from .python_tools import zipmap, generic_ne
23
- from .numberlist import NumberList, as_numberlist, maybe_numberlist
21
+ from typing_extensions import Self, TypeAlias, Unpack
24
22
 
23
+ from .metrics import Metrics, evaluate_metric, calculate_metric_list
24
+ from .numberlist import NumberList, as_numberlist, maybe_numberlist
25
+ from .ops import where_
26
+ from .python_tools import generic_ne, zipmap
25
27
 
26
28
  _Scalar = int | float | bool | complex
27
29
  _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
@@ -33,6 +35,7 @@ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
33
35
  _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
34
36
 
35
37
  Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
38
+
36
39
  class _NewTensorKwargs(TypedDict, total = False):
37
40
  memory_format: Any
38
41
  dtype: Any
@@ -325,9 +328,20 @@ class TensorList(list[torch.Tensor | Any]):
325
328
  def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
326
329
  def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
327
330
  def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
328
- def global_vector_norm(self, ord:float | Literal['mean_abs'] = 2) -> torch.Tensor:
329
- if ord == 'mean_abs': return self.abs().global_mean()
330
- return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
331
+
332
def global_vector_norm(self, ord: float = 2) -> torch.Tensor:
    """Vector ``ord``-norm of all tensors treated as one flattened vector.

    Computed without concatenating the tensors into a single vector.

    Supported ``ord``:
      * ``0``: number of non-zero elements, cast to the dtype of the first tensor.
      * ``inf`` / ``-inf``: max / min absolute value.
      * any other non-zero float ``p``: ``(sum |x|**p) ** (1/p)``.
    """
    # ord == 0 must be handled before the even-power branch: 0 % 2 == 0, and that
    # branch would then evaluate 1/0 and raise ZeroDivisionError.
    if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
    if ord == torch.inf: return self.abs().global_max()
    if ord == -torch.inf: return self.abs().global_min()
    # the L1 norm is the sum of absolute values, not of signed values
    if ord == 1: return self.abs().global_sum()
    # for even powers x**p == |x|**p on real tensors, so the abs() copy is skipped
    if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
    return self.abs().pow_(ord).global_sum().pow(1/ord)
341
+
342
def global_metric(self, metric: Metrics) -> torch.Tensor:
    """Reduce all tensors to a single scalar according to ``metric``.

    Thin wrapper that delegates to ``evaluate_metric``.
    """
    result = evaluate_metric(self, metric)
    return result
344
+
331
345
  def global_any(self): return builtins.any(self.any())
332
346
  def global_all(self): return builtins.all(self.all())
333
347
  def global_numel(self) -> int: return builtins.sum(self.numel())
@@ -358,31 +372,54 @@ class TensorList(list[torch.Tensor | Any]):
358
372
 
359
373
  def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
360
374
  return self.zipmap_args(torch.randint_like, low, high, **kwargs)
375
+
361
376
  def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
362
377
  res = self.empty_like(**kwargs)
363
378
  res.uniform_(low, high, generator=generator)
364
379
  return res
380
+
365
381
  def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
366
382
  r = self.randn_like(generator=generator, **kwargs)
367
- return (r * radius) / r.global_vector_norm()
383
+ return r.mul_(maybe_numberlist(radius) / r.global_vector_norm())
384
+
368
385
  def bernoulli(self, generator = None):
369
386
  return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
387
+
370
388
  def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
371
389
  """p is probability of a 1, other values will be 0."""
372
390
  return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
391
+
373
392
  def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
374
393
  """p is probability of a 1, other values will be -1."""
375
394
  return self.bernoulli_like(p, generator=generator, **kwargs).mul_(2).sub_(1)
376
395
 
377
- def sample_like(self, eps: "_Scalar | _ScalarSeq" = 1, distribution: Distributions = 'normal', generator=None, **kwargs: Unpack[_NewTensorKwargs]):
396
+ def sample_like(self, distribution: Distributions = 'normal', variance: "_Scalar | _ScalarSeq | Sequence | None" = None, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
378
397
  """Sample around 0."""
379
- if distribution in ('normal', 'gaussian'): return self.randn_like(generator=generator, **kwargs) * eps
398
+ if isinstance(variance, Sequence):
399
+ if all(v is None for v in variance): variance = None
400
+ else: variance = [v if v is not None else 1 for v in variance]
401
+
402
+ if distribution in ('normal', 'gaussian'):
403
+ ret = self.randn_like(generator=generator, **kwargs)
404
+ if variance is not None: ret *= variance
405
+ return ret
406
+
380
407
  if distribution == 'uniform':
381
- if isinstance(eps, (list,tuple)):
382
- return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs)
383
- return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)
384
- if distribution == 'sphere': return self.sphere_like(eps, generator=generator, **kwargs)
385
- if distribution == 'rademacher': return self.rademacher_like(generator=generator, **kwargs) * eps
408
+ b = 1
409
+ if variance is not None:
410
+ b = ((12 * maybe_numberlist(variance)) ** 0.5) / 2
411
+ return self.uniform_like(-b, b, generator=generator, **kwargs)
412
+
413
+ if distribution == 'sphere':
414
+ if variance is None: radius = 1
415
+ else: radius = maybe_numberlist(variance) * math.sqrt(self.global_numel())
416
+ return self.sphere_like(radius, generator=generator, **kwargs)
417
+
418
+ if distribution == 'rademacher':
419
+ ret = self.rademacher_like(generator=generator, **kwargs)
420
+ if variance is not None: ret *= variance
421
+ return ret
422
+
386
423
  raise ValueError(f'Unknow distribution {distribution}')
387
424
 
388
425
  def eq(self, other: _STOrSTSeq): return self.zipmap(torch.eq, other)
@@ -504,6 +541,11 @@ class TensorList(list[torch.Tensor | Any]):
504
541
  torch._foreach_pow_(input, self)
505
542
  return self
506
543
 
544
+ def square(self): return self.__class__(torch._foreach_pow(self, 2))
545
+ def square_(self):
546
+ torch._foreach_pow_(self, 2)
547
+ return self
548
+
507
549
  def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
508
550
  def sqrt_(self):
509
551
  torch._foreach_sqrt_(self)
@@ -634,10 +676,12 @@ class TensorList(list[torch.Tensor | Any]):
634
676
  if dim is None: dim = ()
635
677
  return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
636
678
 
637
- def norm(self, ord: _Scalar|Literal["mean_abs"], dtype=None):
638
- if isinstance(ord, str): return self.abs().mean()
679
+ def norm(self, ord: float, dtype=None):
639
680
  return self.__class__(torch._foreach_norm(self, ord, dtype))
640
681
 
682
def metric(self, metric: Metrics) -> "TensorList":
    """Per-tensor counterpart of ``global_metric``.

    Delegates to ``calculate_metric_list``, which yields one value per tensor.
    """
    per_tensor = calculate_metric_list(self, metric)
    return per_tensor
684
+
641
685
  def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
642
686
  if dim == 'global': return self._global_fn(keepdim, self.global_mean)
643
687
  return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
@@ -790,27 +834,27 @@ class TensorList(list[torch.Tensor | Any]):
790
834
  for t, o in zip(self, other): t.copysign_(o)
791
835
  return self
792
836
 
793
- def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
837
+ def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
794
838
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
795
839
  if tensorwise:
796
- norm_self = self.norm(ord)
797
- norm_other = magnitude.norm(ord)
840
+ norm_self = self.metric(ord)
841
+ norm_other = magnitude.metric(ord)
798
842
  else:
799
- norm_self = self.global_vector_norm(ord)
800
- norm_other = magnitude.global_vector_norm(ord)
843
+ norm_self = self.global_metric(ord)
844
+ norm_other = magnitude.global_metric(ord)
801
845
 
802
846
  if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
803
847
 
804
848
  return self * (norm_other / norm_self.clip_(min=eps))
805
849
 
806
- def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float | Literal['mean_abs'] = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
850
+ def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
807
851
  if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
808
852
  if tensorwise:
809
- norm_self = self.norm(ord)
810
- norm_other = magnitude.norm(ord)
853
+ norm_self = self.metric(ord)
854
+ norm_other = magnitude.metric(ord)
811
855
  else:
812
- norm_self = self.global_vector_norm(ord)
813
- norm_other = magnitude.global_vector_norm(ord)
856
+ norm_self = self.global_metric(ord)
857
+ norm_other = magnitude.global_metric(ord)
814
858
 
815
859
  if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
816
860
 
@@ -905,14 +949,14 @@ class TensorList(list[torch.Tensor | Any]):
905
949
  if eps!=0: std.add_(eps)
906
950
  return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
907
951
 
908
- def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
952
+ def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
909
953
  """calculate multipler to clip self norm to min and max"""
910
954
  if tensorwise:
911
- self_norm = self.norm(ord)
955
+ self_norm = self.metric(ord)
912
956
  self_norm.masked_fill_(self_norm == 0, 1)
913
957
 
914
958
  else:
915
- self_norm = self.global_vector_norm(ord)
959
+ self_norm = self.global_metric(ord)
916
960
  if self_norm == 0: return 1
917
961
 
918
962
  mul = 1
@@ -926,12 +970,12 @@ class TensorList(list[torch.Tensor | Any]):
926
970
 
927
971
  return mul
928
972
 
929
- def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
973
def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
    """Return a rescaled copy whose norm lies within ``[min, max]``.

    With ``tensorwise=True`` each tensor's ``ord`` metric is clipped independently.
    Otherwise the single global metric over all tensors is clipped.
    When both bounds are ``None``, ``self`` itself is returned (no copy).
    """
    if min is not None or max is not None:
        return self * self._clip_multiplier(min, max, tensorwise, ord)
    return self
933
977
 
934
- def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float|Literal["mean_abs"] = 2):
978
def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
    """In-place version of ``clip_norm``: rescale ``self`` so its norm lies within ``[min, max]``.

    With ``tensorwise=True`` each tensor's ``ord`` metric is clipped independently.
    Otherwise the single global metric over all tensors is clipped.
    Returns ``self``. Nothing is modified when both bounds are ``None``.
    """
    if min is not None or max is not None:
        return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
    return self
@@ -990,6 +1034,15 @@ class TensorList(list[torch.Tensor | Any]):
990
1034
  # """sets index in flattened view"""
991
1035
  # return self.clone().flatset_(idx, value)
992
1036
 
1037
def flat_get(self, idx: int):
    """Return the element at position ``idx`` of the flattened concatenation of all tensors.

    Tensors are visited in order, and each is flattened in row-major order.
    Negative indices count from the end, as with Python sequences.

    Raises:
        IndexError: if ``idx`` is outside ``[-numel, numel)``, where numel is the total element count.
    """
    original_idx = idx
    if idx < 0:
        idx += builtins.sum(t.numel() for t in self)
        if idx < 0: raise IndexError(original_idx)

    offset = 0
    for tensor in self:
        numel = tensor.numel()
        if idx < offset + numel:
            # position inside this tensor is relative to where it starts in the flat view;
            # reshape (unlike view) also works for non-contiguous tensors
            return tensor.reshape(-1)[idx - offset]
        offset += numel
    raise IndexError(original_idx)
1045
+
993
1046
  def flat_set_(self, idx: int, value: Any):
994
1047
  """sets index in flattened view"""
995
1048
  cur = 0
@@ -1065,10 +1118,19 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
1065
1118
  if isinstance(x, torch.Tensor): return x.numel()
1066
1119
  return x.global_numel()
1067
1120
 
1121
+
1122
def generic_finfo(x: torch.Tensor | TensorList) -> torch.finfo:
    """Floating-point info for the dtype of ``x``.

    For a TensorList, the dtype of its first tensor is used.
    """
    dtype = x.dtype if isinstance(x, torch.Tensor) else x[0].dtype
    return torch.finfo(dtype)
1125
+
1068
1126
  def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
1069
1127
  if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
1070
1128
  return torch.finfo(x[0].dtype).eps
1071
1129
 
1130
def generic_finfo_tiny(x: torch.Tensor | TensorList) -> float:
    """Smallest positive normal number (``finfo.tiny``) for the dtype of ``x``.

    For a TensorList, the dtype of its first tensor is used.
    """
    dtype = x.dtype if isinstance(x, torch.Tensor) else x[0].dtype
    return torch.finfo(dtype).tiny
1133
+
1072
1134
  @overload
1073
1135
  def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
1074
1136
  @overload
@@ -1081,7 +1143,8 @@ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
1081
1143
  if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
1082
1144
  return x.global_vector_norm(ord)
1083
1145
 
1084
-
1146
def generic_metric(x: torch.Tensor | TensorList, metric: Metrics) -> torch.Tensor:
    """Evaluate ``metric`` on a tensor or TensorList; delegates to ``evaluate_metric``."""
    value = evaluate_metric(x, metric)
    return value
1085
1148
 
1086
1149
  @overload
1087
1150
  def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
@@ -1091,3 +1154,11 @@ def generic_randn_like(x: torch.Tensor | TensorList):
1091
1154
  if isinstance(x, torch.Tensor): return torch.randn_like(x)
1092
1155
  return x.randn_like()
1093
1156
 
1157
+
1158
def generic_sum(x: torch.Tensor | TensorList) -> torch.Tensor:
    """Sum of all elements: ``Tensor.sum`` for a tensor, ``global_sum`` for a TensorList."""
    return x.sum() if isinstance(x, torch.Tensor) else x.global_sum()
1161
+
1162
def generic_max(x: torch.Tensor | TensorList) -> torch.Tensor:
    """Largest element: ``Tensor.max`` for a tensor, ``global_max`` for a TensorList."""
    return x.max() if isinstance(x, torch.Tensor) else x.global_max()
@@ -7,10 +7,15 @@ import numpy as np
7
7
  import torch
8
8
 
9
9
 
10
- def totensor(x):
11
- if isinstance(x, torch.Tensor): return x
12
- if isinstance(x, np.ndarray): return torch.from_numpy(x)
13
- return torch.from_numpy(np.asarray(x))
10
def totensor(x, device=None, dtype=None):
    """Convert ``x`` to a torch tensor.

    Without ``device`` and ``dtype``:
      * tensors are returned unchanged;
      * numpy arrays are wrapped with ``torch.from_numpy``, which shares memory;
      * anything else goes through ``np.asarray`` first.

    With ``device`` and/or ``dtype``:
      * tensors are converted with ``Tensor.to``;
      * other inputs are built with ``torch.as_tensor``.
    """
    plain = device is None and dtype is None

    if isinstance(x, torch.Tensor):
        return x if plain else x.to(device=device, dtype=dtype)

    array = x if isinstance(x, np.ndarray) else np.asarray(x)
    if plain:
        return torch.from_numpy(array)
    return torch.as_tensor(array, device=device, dtype=dtype)
14
19
 
15
20
  def tonumpy(x):
16
21
  if isinstance(x, np.ndarray): return x
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchzero
3
+ Version: 0.3.14
4
+ Summary: Modular optimization library for PyTorch.
5
+ Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/inikishev/torchzero
7
+ Project-URL: Repository, https://github.com/inikishev/torchzero
8
+ Project-URL: Issues, https://github.com/inikishev/torchzero/issues
9
+ Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: torch
13
+ Requires-Dist: numpy
14
+ Requires-Dist: typing_extensions
@@ -0,0 +1,167 @@
1
+ tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
2
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
+ tests/test_opts.py,sha256=XsOTxiKJQJOUr47d1nzyQ1WFclBrhXYKQMdEVN66bWs,43687
4
+ tests/test_tensorlist.py,sha256=pWXQE-vEq08EGJSKWgsTgo-7QjjkavOJ5BlWUm241qI,72434
5
+ tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
6
+ tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
7
+ torchzero/__init__.py,sha256=aIH-cCTXnDr90cKUPhM8bv-uE69Hzjlf0jlYspYf0ZM,120
8
+ torchzero/core/__init__.py,sha256=aYyQt-CHzWT6hGUt5KVjRZZr2lsX5I1XvbWpzaAv3VE,151
9
+ torchzero/core/module.py,sha256=chn9NZWdgYekzVPDppvp2REMp4WugiMhOjQU3tys6ZU,40651
10
+ torchzero/core/reformulation.py,sha256=jppgzXBtqdsc7ot6_Gr38vJbbrhG1Gs4vC32y7iB4BA,2387
11
+ torchzero/core/transform.py,sha256=xRDpsZj0H1QcFdO-t2mNMNOYoqqnRHiI3K1YluWwCVk,17097
12
+ torchzero/modules/__init__.py,sha256=3lGta9P0N3cWdVcruCBJ7uqu4DfLTPCKI_mlOZT6Z_o,615
13
+ torchzero/modules/functional.py,sha256=E_d6hLT2_xdE-3AhQ4AthDYK5uZULbF10iHI09Z3_yk,7921
14
+ torchzero/modules/adaptive/__init__.py,sha256=5L2dlEJV6HKBnYhgd7wo2yGi0WPd9qmpw9XS5wOQOq8,944
15
+ torchzero/modules/adaptive/adagrad.py,sha256=0qXC5F4PuOsgLjRXQUWBoiq0AUixsvOP1uDbEeRIcNs,12531
16
+ torchzero/modules/adaptive/adahessian.py,sha256=rWxgDiBMd6MK64mRjZwiudDF07is5AFj_MCEvBD7h8U,8670
17
+ torchzero/modules/adaptive/adam.py,sha256=4lWSe__tdyRv0rfkUda1qa_NH36DIW3sd8td98bK6XI,3829
18
+ torchzero/modules/adaptive/adan.py,sha256=Dt_gibyrGtWDIUCaSF6RFIxu2xwiF9fCfrpkoD-CaUM,2825
19
+ torchzero/modules/adaptive/adaptive_heavyball.py,sha256=xQQw1Vx-NgQd_ouK14J1p5ijd5lEn3sAN0hVJAL0j8U,2024
20
+ torchzero/modules/adaptive/aegd.py,sha256=_4ASgDX8__DPnnBE_RncnMqM4rItM7Eji4EZzSGGq5I,1876
21
+ torchzero/modules/adaptive/esgd.py,sha256=DtrN3hZhGK4LgMxmcjQCBtEO-5hLZrAdnvps6f8WQ2A,6416
22
+ torchzero/modules/adaptive/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
23
+ torchzero/modules/adaptive/lmadagrad.py,sha256=rMs7vrgiwOJgWo-OZXkGu32X561edEhhaxwgbY2NTnk,7176
24
+ torchzero/modules/adaptive/mars.py,sha256=iOkyY3r52btp7Cry7WN0AB4arpn4N9b_Hg6S55XC6Q8,2255
25
+ torchzero/modules/adaptive/matrix_momentum.py,sha256=ZMSdGNSHgyUgtJwjKzK7PZDwMrPw5wZJVu0K4Xa-SpI,6693
26
+ torchzero/modules/adaptive/msam.py,sha256=locqM2jiC3AbGCDCo6T40GF3iaVV2svrkzoi0hD2cJI,6663
27
+ torchzero/modules/adaptive/muon.py,sha256=5Asgj03s6JXXrO-p5Qgn3D8bVwbDEqCq7hxNy4joQDE,10335
28
+ torchzero/modules/adaptive/natural_gradient.py,sha256=5qRehh-iAeZ4hjfOR-gfsObbMsJZnipzCkG-yptkrH0,6349
29
+ torchzero/modules/adaptive/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
30
+ torchzero/modules/adaptive/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
31
+ torchzero/modules/adaptive/rprop.py,sha256=VDnLPKxw8ECihyUeNVE8cyDll_Ut3k3_NqoLgpgxxLA,11818
32
+ torchzero/modules/adaptive/sam.py,sha256=LnOPNZnIUsis0402RHnA-fTPkNM8baUR9HR50pF_BtM,5696
33
+ torchzero/modules/adaptive/shampoo.py,sha256=r7V4I5_Ve1YVOS3HhO2k5cZvJT1lPHTVApV3iVJVceA,9711
34
+ torchzero/modules/adaptive/soap.py,sha256=roQLBthNNNmTYgeJPi_LxZY-r4m6REeUo0_DZknYU50,10662
35
+ torchzero/modules/adaptive/sophia_h.py,sha256=lSK8uVdOxBAhU2jE6fyIx1YgqEQyZG-Fv9o2TniAZzk,7179
36
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
37
+ torchzero/modules/clipping/clipping.py,sha256=t98M3QKZKqXJ3_tzXXIiG4EOYMaHqLYrMZ-6zmRuy-k,14331
38
+ torchzero/modules/clipping/ema_clipping.py,sha256=Ki0LPNUwPoE825A5rSE7SxGQMiI3nO3iwnjKQ486iaI,6611
39
+ torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
40
+ torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
41
+ torchzero/modules/conjugate_gradient/cg.py,sha256=L49p_wBeyQ3pZmaPM1vLx7uWmZBKtIhMf7Uv2ggUbI4,14534
42
+ torchzero/modules/experimental/__init__.py,sha256=blI-OhpQAC6-Ho1uxUq-t7Mm9CAMnNMXkBDmXul8tbc,729
43
+ torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
44
+ torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
45
+ torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
46
+ torchzero/modules/experimental/gradmin.py,sha256=hKTOG7tk6FnG8t-7OmTAhGTGSDdONzP1JvCRPRqaKt0,3740
47
+ torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
48
+ torchzero/modules/experimental/momentum.py,sha256=VqZc14EGVO_KUABPLRIBlvHdgg-64o-4heMQH0vW5vY,5233
49
+ torchzero/modules/experimental/newton_solver.py,sha256=0HnDBlrBLvUgS4hWmkJqyw0M7UPFp3kU3SFq2xZVYhQ,5454
50
+ torchzero/modules/experimental/newtonnewton.py,sha256=a0XXvlVe37z2MMcQ4TeGbbWX9OuYp_5b-21jq3o1z3E,3823
51
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
52
+ torchzero/modules/experimental/scipy_newton_cg.py,sha256=8nKBaHHmqdU9F1kVPn2QVFUTx2_I8Jsfqxix1v-qoL0,4073
53
+ torchzero/modules/experimental/spsa1.py,sha256=tytlOaKSyhvoElErI-MzhbL0fIm3K4d_M1Kpbpb9jbw,3622
54
+ torchzero/modules/experimental/structural_projections.py,sha256=rxJFG5F23dOiK_8KqKyvoSMLWqAOXtVGHSwfRqH22Wg,4185
55
+ torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
56
+ torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
57
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=pCOsvt4ZZtzsIlAeXLmpS8vNXYewA_Gh7uyz1_1yROs,4011
58
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=x8vlweBrfJ6SmhMHvI_C8UZGzlS3AnmlulvqnSzm6iY,4437
59
+ torchzero/modules/grad_approximation/rfdm.py,sha256=zgzZYJDAq6HsOMeHePUWlnnJWYRpqey77nHMaHh8140,19217
60
+ torchzero/modules/higher_order/__init__.py,sha256=iaoIrmR9DJE9QHt9PeZNCWqIYDe-86h1IjkaumR4qF0,51
61
+ torchzero/modules/higher_order/higher_order_newton.py,sha256=2r1wuhdi57pbo8akQE88O8R-Y79BtiwD1WQIShh1rjQ,12967
62
+ torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
63
+ torchzero/modules/least_squares/gn.py,sha256=23AB6AWAl5IuBj4Vd3boQ6ndk0pO3ovaF9EiY1a1XWs,5094
64
+ torchzero/modules/line_search/__init__.py,sha256=mFWgkcgfMkL2NKj3CLbuwee3e8WHBOaXs-wtx3oTW58,216
65
+ torchzero/modules/line_search/_polyinterp.py,sha256=qIhcLjOlpB6NHU0oiUGMncwQxWNfy8757orsbzjkp6s,10882
66
+ torchzero/modules/line_search/adaptive.py,sha256=8Ip5F5PpsDLgg6TwB_E7zIZheycd78coRg4u7cpO3Cg,3795
67
+ torchzero/modules/line_search/backtracking.py,sha256=Mhx8_UT_Mr1gASYHUorBJ38E4YlcM9LpW9YrJHYfLXU,9049
68
+ torchzero/modules/line_search/line_search.py,sha256=lmtjr9Zpz9RYJXoYaJnpXkBSIdcN6DdwGKKXTCmcJNU,13294
69
+ torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
70
+ torchzero/modules/line_search/strong_wolfe.py,sha256=21IMtenhBlrw_edtwizzLZ01PF4-rU4M2oGOCUA2udc,14936
71
+ torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
72
+ torchzero/modules/misc/debug.py,sha256=6pFAGYANjCPGIZH_4ghpUYYTEsT5jr7PMB9KLuPP4p8,1532
73
+ torchzero/modules/misc/escape.py,sha256=qfEdKLD5rejqrmvyHrI5BRQq8js9UF2-Axs_C0KFyWA,1866
74
+ torchzero/modules/misc/gradient_accumulation.py,sha256=mBWa5CBCZwp4TrtOyjWI3VnHag4gum4WBM2WFhvHqW4,4891
75
+ torchzero/modules/misc/homotopy.py,sha256=hihLETE4dNZ27zatqPR_qT3kGX-AXbC7oBWRDbFQo58,1939
76
+ torchzero/modules/misc/misc.py,sha256=feI-IQlxhIoAbsSRTjE4SbGez1c2Uu9-WA_nkK7iiqQ,15411
77
+ torchzero/modules/misc/multistep.py,sha256=RtDFIeTHu4RcERvlKEP4_10-lpRZOgbnBeSah92dQ7A,6323
78
+ torchzero/modules/misc/regularization.py,sha256=SkQ0_Ybtv9IEGI9QGdvNZaja5bAyc1x-j_1gvYIVepI,6105
79
+ torchzero/modules/misc/split.py,sha256=JcXVB4xk3h55YT2OAdepVsRoE1PD7bqX6NmJ2IxBgAI,4013
80
+ torchzero/modules/misc/switch.py,sha256=p758heAnv-PkoslpafL35Yp7mlvPmDVSe1mWiuuD8Mk,3711
81
+ torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
82
+ torchzero/modules/momentum/averaging.py,sha256=vDW8tgGsEuBXF_BTUYHB_j--TIVam9j0nZdp_x8TkxY,3229
83
+ torchzero/modules/momentum/cautious.py,sha256=x506a3lUETRpxPWqXLFJVFBH1gmLqIfqL5J-hFdEvOo,8051
84
+ torchzero/modules/momentum/momentum.py,sha256=q3n0BvQURuSBzA9vn1ZrH-n7Nsr0AS-38VJuwraQPY0,4495
85
+ torchzero/modules/ops/__init__.py,sha256=9UHaXs9aaKc0ewAhicTlDmj42bSC_vddMOD0eYuUj_8,1226
86
+ torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
87
+ torchzero/modules/ops/binary.py,sha256=2hV2oruaq5Asu4Ts8X8yiZQ-07fU0RGpRy3-vifXqXY,12151
88
+ torchzero/modules/ops/higher_level.py,sha256=E76zgSHlhVpHLrXhnVwelIQFm1IKn0IFcVq7DOJw0es,9037
89
+ torchzero/modules/ops/multi.py,sha256=YC3rBTmPRwF5aEPDNsyTK4J_JEAbmE7oBmF7W-VOV3A,8588
90
+ torchzero/modules/ops/reduce.py,sha256=kALG7X8q02sWpo1skpXjS0r875gwq6mrhLZbFfYaZoA,6324
91
+ torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
92
+ torchzero/modules/ops/utility.py,sha256=_k9S59i6IYOzzfIQlToQ9mlDseTTAS_49wujUxMGXZo,4105
93
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
94
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
95
+ torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
96
+ torchzero/modules/projections/projection.py,sha256=itkkb2UmMqbdtWKjUUg6gbFJfCEIZAskC0HCvom-6sc,14084
97
+ torchzero/modules/quasi_newton/__init__.py,sha256=HxXENs3O6nFRfCvUJhWPK9f8_A6iMwB6UF1Zold12UQ,515
98
+ torchzero/modules/quasi_newton/damping.py,sha256=K1DVqqKiAs6-F3JQh5jlKNb79oJdObqnKWwHHRl6boQ,2813
99
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
100
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=fzCjV5YsLo_uJTVG3vosPHsvDc97mLKueK6fxOHLb8I,11195
101
+ torchzero/modules/quasi_newton/lsr1.py,sha256=D3_yV5xtgklMlU4fAL1-sH82-1tNl3K2F12ZBZyLQGM,8512
102
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=-xUGPld8Y0MHwN6qsmDihLbUbulU0T1z8jf2mZhNpcE,44529
103
+ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
104
+ torchzero/modules/restarts/restars.py,sha256=ZN8kzufkOKredI54HJ9cSANAYY2yN4gAggqAVtEBCyA,9084
105
+ torchzero/modules/second_order/__init__.py,sha256=lTGccDNVwPuMevMeKi5O0a9cl24Rn9tk7VkC6jvlGYc,233
106
+ torchzero/modules/second_order/multipoint.py,sha256=Ilzo0Ddd3iApegceu7cHSMGim9ZH5QS4-2uBtrKXC6k,8581
107
+ torchzero/modules/second_order/newton.py,sha256=PAPbJzssx0Ji328BFOEzeJZPd3IubJTPHs6ZhqS_nW8,15663
108
+ torchzero/modules/second_order/newton_cg.py,sha256=6KmYr-U5gmuTGoVZA8Gt2-rfLwnjNakjKkBKEuBffTk,16738
109
+ torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
110
+ torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
111
+ torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
112
+ torchzero/modules/smoothing/sampling.py,sha256=zI5bATytQmCqm_UgAQbfA9tNRgrZaKLfUb0B-kzKRHU,12867
113
+ torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
114
+ torchzero/modules/step_size/adaptive.py,sha256=HvffW3m1NnpMTpps0QjJTvbblSODxxWMBBFTbNwp0vM,14482
115
+ torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
116
+ torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
117
+ torchzero/modules/termination/termination.py,sha256=BXU3R04caBc8rFJ4v_yJjgGi1X4iA11eYwlbiJfxexI,6637
118
+ torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
119
+ torchzero/modules/trust_region/cubic_regularization.py,sha256=gbKOR5zBo3t9i-sW23DCtTQwZrBubuFy_VuafrLaeUw,6718
120
+ torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
121
+ torchzero/modules/trust_region/levenberg_marquardt.py,sha256=s4XBXK8LwHuwyanOJFtgmOLkxEiBMhBVzch0J2_dIFk,5055
122
+ torchzero/modules/trust_region/trust_cg.py,sha256=F6G0hTXv6Ry0swO_4fx3ecxYWxMr72nrmwJPbFDpqH4,4459
123
+ torchzero/modules/trust_region/trust_region.py,sha256=eimCFViJSzoubrRmDluCon6mfcyT7PQA0yRPu4FlO2Q,12872
124
+ torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
125
+ torchzero/modules/variance_reduction/svrg.py,sha256=9pBjPY4EMkGyfj68gXqPi1GJIolUVl5zyNtlZInCKKo,8635
126
+ torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
127
+ torchzero/modules/weight_decay/weight_decay.py,sha256=Y7kE_j0GRF8ceJ9SS6qykQ8a23X2OTDCjJ9VklOQSEw,5415
128
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
129
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
130
+ torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
131
+ torchzero/modules/zeroth_order/cd.py,sha256=SwjwoAqX86-JnVHIwKAE7g_tqm0EvEUUNuLM4T5mKXE,4876
132
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
133
+ torchzero/optim/root.py,sha256=gGtAJ9qBoSNV58EKzUGZ8J3lyKGUF8BEw34Zfprppdo,2273
134
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
135
+ torchzero/optim/utility/split.py,sha256=kraPCLAewX2uLbD_9R2dIrcF-kpUuT9IcxPeVrAARvA,1672
136
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
137
+ torchzero/optim/wrappers/directsearch.py,sha256=rimJIB2RrVzLpRPQKhzkrMQ4bTAEU3NEOT4pJQNIAHE,11309
138
+ torchzero/optim/wrappers/fcmaes.py,sha256=jKmmBKEwguYiJdvTRmAp5JSilxcUhtpRoKlzmp-lyWE,4251
139
+ torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
140
+ torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
141
+ torchzero/optim/wrappers/nlopt.py,sha256=RuLKretljjAjTZ9tXY3FyEDuB7mAboeGOQBupWfzPc4,8105
142
+ torchzero/optim/wrappers/optuna.py,sha256=pIXkC5NVmEnUQ4jsGaz6Gv9uYOZM9rnxME4UGkeolsE,2393
143
+ torchzero/optim/wrappers/scipy.py,sha256=A4yeQRdB9f65UrJ2g80NfqqMc6zUyr9js40TUESCHPg,21535
144
+ torchzero/utils/__init__.py,sha256=7S4VRTkfS-0uI8HOR0EFIjiEcKrmYK7LEhTocIgki6c,1112
145
+ torchzero/utils/compile.py,sha256=Dozox91tcShUJ3L320TTbJrcuA-l4WVegLAQujRqy94,5132
146
+ torchzero/utils/derivatives.py,sha256=zJ0xyedvlIwgAYMa1F5BBfyrkvgjXy7v7evvl6QAlT0,17195
147
+ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
148
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
149
+ torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
150
+ torchzero/utils/optimizer.py,sha256=pzGIddJtsR_oH3mb3GOsWHi02Sb62ms4NCaWUOk0Clo,12470
151
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
152
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
153
+ torchzero/utils/python_tools.py,sha256=QxLZ2PKgp4R2zI_C1qNOX_u4eIcuT0wNBpBM5YEIYuU,3428
154
+ torchzero/utils/tensorlist.py,sha256=nIWBME3fUQPsr4buvtV3LaJgSXPEG_Xb58KAzfjwK-k,56064
155
+ torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
156
+ torchzero/utils/linalg/__init__.py,sha256=cNoTJOPeqbNn9l7_HAAen2rlehGS3DyY5SveInG3Stc,328
157
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
158
+ torchzero/utils/linalg/linear_operator.py,sha256=uJUxvOVHpG3U3GNx61JGa_uM8GqzsNZmA_z7P0RwZ5E,12747
159
+ torchzero/utils/linalg/matrix_funcs.py,sha256=BKQK_oIG35R6yGxU80eBG0VkyY2EgxywqbhvU7JhWm4,3109
160
+ torchzero/utils/linalg/orthogonalize.py,sha256=BpuDiAPrsJMUpTNBMCntBNA8-O2nozLxY5ZbCfRlEFY,444
161
+ torchzero/utils/linalg/qr.py,sha256=5tbPEV9I6X69r5ACWF9XeqjZTUtUql2145uoGjlJNDs,2517
162
+ torchzero/utils/linalg/solve.py,sha256=R5lPTzHn2sgvRy4MRp-Ngl0sypSGLRLHJjf1oKKAJD0,14395
163
+ torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
164
+ torchzero-0.3.14.dist-info/METADATA,sha256=Pd9XeJLSPuNQtb-dVJvnIiDOX5Mdr9b9ihpfw6rxBpQ,565
165
+ torchzero-0.3.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
166
+ torchzero-0.3.14.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
167
+ torchzero-0.3.14.dist-info/RECORD,,
@@ -1,3 +1,2 @@
1
- docs
2
1
  tests
3
2
  torchzero
docs/source/conf.py DELETED
@@ -1,59 +0,0 @@
1
- # Configuration file for the Sphinx documentation builder.
2
- #
3
- # For the full list of built-in configuration values, see the documentation:
4
- # https://www.sphinx-doc.org/en/master/usage/configuration.html
5
-
6
- # -- Project information -----------------------------------------------------
7
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8
- import sys, os
9
- #sys.path.insert(0, os.path.abspath('.../src'))
10
-
11
- project = 'torchzero'
12
- copyright = '2025, Ivan Nikishev'
13
- author = 'Ivan Nikishev'
14
-
15
- # -- General configuration ---------------------------------------------------
16
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
17
-
18
- # https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
19
- extensions = [
20
- 'sphinx.ext.autodoc',
21
- 'sphinx.ext.autosummary',
22
- 'sphinx.ext.viewcode',
23
- 'sphinx.ext.autosectionlabel',
24
- 'sphinx.ext.githubpages',
25
- 'sphinx.ext.napoleon',
26
- 'autoapi.extension',
27
- "myst_nb",
28
-
29
- # 'sphinx_rtd_theme',
30
- ]
31
- autosummary_generate = True
32
- autoapi_dirs = ['../../torchzero']
33
- autoapi_type = "python"
34
- # autoapi_ignore = ["*/tensorlist.py"]
35
-
36
- # https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
37
- autoapi_options = [
38
- "members",
39
- "undoc-members",
40
- "show-inheritance",
41
- "show-module-summary",
42
- "imported-members",
43
- ]
44
-
45
-
46
- templates_path = ['_templates']
47
- exclude_patterns = []
48
-
49
- # -- Options for HTML output -------------------------------------------------
50
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
51
-
52
- #html_theme = 'alabaster'
53
- html_theme = 'sphinx_rtd_theme'
54
- html_static_path = ['_static']
55
-
56
-
57
- # OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
58
- source_suffix = ['.rst', '.md']
59
- master_doc = 'index'
@@ -1,46 +0,0 @@
1
# Docstring template: shows the section layout (summary, description,
# ``.. note::``, Args, Examples, References) to copy when documenting a class,
# presumably a torchzero module given the ``tz.Modular`` examples. The bracketed
# text is placeholder guidance to be replaced, not real documentation.
class MyModule:
    """[One-line summary of the class].

    [A more detailed description of the class, explaining its purpose, how it
    works, and its typical use cases. You can use multiple paragraphs.]

    .. note::
        [Optional: Add important notes, warnings, or usage guidelines here.
        For example, you could mention if a closure is required, discuss
        stability, or highlight performance characteristics. Use the `.. note::`
        directive to make it stand out in the documentation.]

    Args:
        param1 (type, optional):
            [Description of the first parameter. Use :code:`backticks` for
            inline code like variable names or specific values like ``"autograd"``.
            Explain what the parameter does.] Defaults to [value].
        param2 (type):
            [Description of a mandatory parameter (no "optional" or "Defaults to").]
        **kwargs:
            [If you accept keyword arguments, describe what they are used for.]

    Examples:
        [A title or short sentence describing the first example]:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                ...
            )

        [A title or short sentence for a second, different example]:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                ...
            )

    References:
        - [Optional: A citation for a relevant paper, book, or algorithm.]
        - [Optional: A link to a blog post or website with more information.]

    """
- """