torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182):
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -11,17 +11,19 @@ in an optimizer when you have to create one from parameters on each step. The so
11
11
  it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
12
12
  """
13
13
  import builtins
14
- from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
15
14
  import math
16
15
  import operator
16
+ from abc import ABC, abstractmethod
17
+ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
17
18
  from typing import Any, Literal, TypedDict, overload
18
- from typing_extensions import Self, TypeAlias, Unpack
19
19
 
20
20
  import torch
21
- from .ops import where_
22
- from .python_tools import generic_eq, zipmap
23
- from .numberlist import NumberList, as_numberlist, maybe_numberlist
21
+ from typing_extensions import Self, TypeAlias, Unpack
24
22
 
23
+ from .metrics import Metrics, evaluate_metric, calculate_metric_list
24
+ from .numberlist import NumberList, as_numberlist, maybe_numberlist
25
+ from .ops import where_
26
+ from .python_tools import generic_ne, zipmap
25
27
 
26
28
  _Scalar = int | float | bool | complex
27
29
  _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
@@ -33,6 +35,7 @@ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
33
35
  _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
34
36
 
35
37
  Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
38
+
36
39
  class _NewTensorKwargs(TypedDict, total = False):
37
40
  memory_format: Any
38
41
  dtype: Any
@@ -217,6 +220,12 @@ class TensorList(list[torch.Tensor | Any]):
217
220
  """Returns a TensorList with all elements for which `fn` returned True."""
218
221
  return self.__class__(i for i in self if fn(i, *args, **kwargs))
219
222
 
def filter_by_list(self, s: Sequence[bool]):
    """Return a new list holding the elements whose matching flag in ``s`` is truthy.

    Raises:
        ValueError: if ``s`` and this list differ in length.
    """
    if len(s) != len(self):
        raise ValueError(f"{len(self) = }, {len(s) = }")
    kept = (item for item, keep in zip(self, s) if keep)
    return self.__class__(kept)
228
+
220
229
  def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
221
230
  """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
222
231
  Otherwise applies `fn` to this TensorList and `other`.
@@ -319,8 +328,20 @@ class TensorList(list[torch.Tensor | Any]):
def global_sum(self) -> torch.Tensor:
    """Single sum over the whole list: folds the results of ``self.sum()`` with ``builtins.sum``."""
    partial_sums = self.sum()
    return builtins.sum(partial_sums) # pyright:ignore[reportArgumentType,reportReturnType]

def global_std(self) -> torch.Tensor:
    """Standard deviation computed over ``self.to_vec()``."""
    flat = self.to_vec()
    return torch.std(flat)

def global_var(self) -> torch.Tensor:
    """Variance computed over ``self.to_vec()``."""
    flat = self.to_vec()
    return torch.var(flat)
331
+
def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
    """Vector norm of every element of every tensor, as if they formed one flat vector.

    Intended to match ``torch.linalg.vector_norm(self.to_vec(), ord=ord)`` without building
    the concatenated vector:

    - ``ord=inf`` / ``-inf``: largest / smallest absolute value;
    - ``ord=0``: number of non-zero elements, cast to the first tensor's dtype;
    - any other ``ord``: ``sum(|x| ** ord) ** (1 / ord)``.
    """
    # Special orders first: ord=0 must not reach the `ord % 2 == 0` branch (pow(1/0) would raise).
    if ord == torch.inf: return self.abs().global_max()
    if ord == -torch.inf: return self.abs().global_min()
    if ord == 0:
        # count non-zeros per tensor explicitly instead of relying on list `!=` semantics
        return builtins.sum(torch.count_nonzero(t) for t in self).to(self[0].dtype)
    # the L1 norm is the sum of *absolute* values
    if ord == 1: return self.abs().global_sum()
    # even powers are already non-negative, so the abs() pass can be skipped
    if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
    return self.abs().pow_(ord).global_sum().pow(1/ord)
341
+
def global_metric(self, metric: Metrics) -> torch.Tensor:
    """Evaluate ``metric`` over the whole list by delegating to ``evaluate_metric``."""
    value = evaluate_metric(self, metric)
    return value
344
+
def global_any(self):
    """Python-level ``any`` over the results of ``self.any()``."""
    per_tensor = self.any()
    return builtins.any(per_tensor)

def global_all(self):
    """Python-level ``all`` over the results of ``self.all()``."""
    per_tensor = self.all()
    return builtins.all(per_tensor)

def global_numel(self) -> int:
    """Total element count: ``builtins.sum`` over the results of ``self.numel()``."""
    counts = self.numel()
    return builtins.sum(counts)
@@ -351,31 +372,54 @@ class TensorList(list[torch.Tensor | Any]):
351
372
 
def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
    """Random integers in ``[low, high)`` shaped like each tensor (bounds may be per-tensor sequences)."""
    sampler = torch.randint_like
    return self.zipmap_args(sampler, low, high, **kwargs)

def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
    """Fresh list shaped like ``self`` filled in place with ``uniform_(low, high)`` samples."""
    out = self.empty_like(**kwargs)
    out.uniform_(low, high, generator=generator)
    return out
380
+
def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
    """Random direction with global vector norm ``radius``: Gaussian samples rescaled in place."""
    samples = self.randn_like(generator=generator, **kwargs)
    scale = maybe_numberlist(radius) / samples.global_vector_norm()
    return samples.mul_(scale)
384
+
def bernoulli(self, generator = None):
    """Draw 0/1 samples, treating each element of ``self`` as the probability of a 1."""
    draws = (torch.bernoulli(probs, generator=generator) for probs in self)
    return self.__class__(draws)
387
+
def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
    """p is probability of a 1, other values will be 0."""
    probs = self.full_like(p, **kwargs)
    return self.__class__(torch.bernoulli(prob, generator = generator) for prob in probs)

def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
    """p is probability of a 1, other values will be -1."""
    bits = self.bernoulli_like(p, generator=generator, **kwargs)
    # map {0, 1} -> {-1, 1} in place
    return bits.mul_(2).sub_(1)
369
395
 
def sample_like(self, distribution: Distributions = 'normal', variance: "_Scalar | _ScalarSeq | Sequence | None" = None, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
    """Sample random values around 0 with the same structure as this list.

    Args:
        distribution: ``'normal'`` (alias ``'gaussian'``), ``'uniform'``, ``'sphere'`` or ``'rademacher'``.
        variance: ``None``, a scalar, or a sequence with one value per tensor. ``None`` entries of a
            sequence are replaced by 1; a sequence made only of ``None`` counts as ``None``.
        generator: optional random generator passed through to the sampling helpers.
        **kwargs: tensor-creation options forwarded to the sampling helpers.

    Raises:
        ValueError: for an unrecognised ``distribution``.

    NOTE(review): ``variance`` is interpreted differently per distribution. 'normal' and
    'rademacher' multiply the samples by it and 'sphere' uses ``variance * sqrt(numel)`` as the
    radius, so there it acts like a standard deviation; 'uniform' picks bounds whose variance
    equals it. With ``variance=None``, 'uniform' draws from [-1, 1] and 'sphere' has radius 1.
    Confirm this is intended.
    """
    if isinstance(variance, Sequence):
        variance = None if all(v is None for v in variance) else [1 if v is None else v for v in variance]

    if distribution in ('normal', 'gaussian'):
        samples = self.randn_like(generator=generator, **kwargs)
        if variance is not None: samples *= variance
        return samples

    if distribution == 'uniform':
        # U(-b, b) has variance b**2 / 3, so b = sqrt(12 * variance) / 2 yields the requested variance
        half_width = 1 if variance is None else ((12 * maybe_numberlist(variance)) ** 0.5) / 2
        return self.uniform_like(-half_width, half_width, generator=generator, **kwargs)

    if distribution == 'sphere':
        radius = 1 if variance is None else maybe_numberlist(variance) * math.sqrt(self.global_numel())
        return self.sphere_like(radius, generator=generator, **kwargs)

    if distribution == 'rademacher':
        samples = self.rademacher_like(generator=generator, **kwargs)
        if variance is not None: samples *= variance
        return samples

    raise ValueError(f'Unknow distribution {distribution}')
380
424
 
def eq(self, other: _STOrSTSeq):
    """Element-wise ``torch.eq`` against ``other`` (scalar, tensor, or per-tensor sequence) via ``zipmap``."""
    comparator = torch.eq
    return self.zipmap(comparator, other)
@@ -425,11 +469,11 @@ class TensorList(list[torch.Tensor | Any]):
425
469
  return self
426
470
 
def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self.add(other)``, except that ``self`` itself is returned when ``other`` equals 0."""
    return self.add(other) if generic_ne(other, 0) else self

def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``add_(other)``, skipped when ``other`` equals 0; always returns ``self``."""
    return self.add_(other) if generic_ne(other, 0) else self
433
477
 
@overload
def sub(self, other: _TensorSeq, alpha: _Scalar = 1):
    """Typing overload of ``sub`` for a tensor sequence ``other``; the runtime body is the non-overloaded ``sub``."""
    ...
@@ -449,11 +493,11 @@ class TensorList(list[torch.Tensor | Any]):
449
493
  return self
450
494
 
def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self.sub(other)``, except that ``self`` itself is returned when ``other`` equals 0."""
    return self.sub(other) if generic_ne(other, 0) else self

def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``sub_(other)``, skipped when ``other`` equals 0; always returns ``self``."""
    return self.sub_(other) if generic_ne(other, 0) else self
457
501
 
def neg(self):
    """Return a new list with every tensor negated."""
    negated = torch._foreach_neg(self)
    return self.__class__(negated)
459
503
  def neg_(self):
@@ -467,13 +511,13 @@ class TensorList(list[torch.Tensor | Any]):
467
511
 
# TODO: benchmark
def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
    """``self * other``; when ``other`` equals 1, return ``self`` (or ``self.clone()`` if ``clone``)."""
    if not generic_ne(other, 1):
        return self.clone() if clone else self
    return self * other

def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``mul_(other)``, skipped when ``other`` equals 1; always returns ``self``."""
    return self.mul_(other) if generic_ne(other, 1) else self
477
521
 
def div(self, other: _STOrSTSeq) -> Self:
    """Return a new list with every tensor divided by ``other`` (scalar, tensor, or per-tensor sequence)."""
    quotients = torch._foreach_div(self, other)
    return self.__class__(quotients)
479
523
  def div_(self, other: _STOrSTSeq):
@@ -481,11 +525,11 @@ class TensorList(list[torch.Tensor | Any]):
481
525
  return self
482
526
 
def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self / other``, except that ``self`` itself is returned when ``other`` equals 1."""
    return self / other if generic_ne(other, 1) else self

def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``div_(other)``, skipped when ``other`` equals 1; always returns ``self``."""
    return self.div_(other) if generic_ne(other, 1) else self
489
533
 
490
534
  def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
491
535
  def pow_(self, exponent: "_Scalar | _STSeq"):
@@ -497,6 +541,11 @@ class TensorList(list[torch.Tensor | Any]):
497
541
  torch._foreach_pow_(input, self)
498
542
  return self
499
543
 
544
+ def square(self): return self.__class__(torch._foreach_pow(self, 2))
545
+ def square_(self):
546
+ torch._foreach_pow_(self, 2)
547
+ return self
548
+
500
549
  def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
501
550
  def sqrt_(self):
502
551
  torch._foreach_sqrt_(self)
@@ -627,9 +676,12 @@ class TensorList(list[torch.Tensor | Any]):
627
676
  if dim is None: dim = ()
628
677
  return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
629
678
 
630
- def norm(self, ord: _Scalar, dtype=None):
679
+ def norm(self, ord: float, dtype=None):
631
680
  return self.__class__(torch._foreach_norm(self, ord, dtype))
632
681
 
def metric(self, metric: Metrics) -> "TensorList":
    """Per-tensor values of ``metric``, delegated to ``calculate_metric_list``."""
    values = calculate_metric_list(self, metric)
    return values
684
+
def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
    """Mean of each tensor over ``dim``; ``dim='global'`` delegates to ``global_mean`` via ``_global_fn``."""
    if dim != 'global':
        return self.__class__(t.mean(dim=dim, keepdim=keepdim) for t in self)
    return self._global_fn(keepdim, self.global_mean)
@@ -782,29 +834,29 @@ class TensorList(list[torch.Tensor | Any]):
782
834
  for t, o in zip(self, other): t.copysign_(o)
783
835
  return self
784
836
 
def _graft_multiplier(self, magnitude: "_TensorSeq", tensorwise: bool, ord: Metrics, eps, strength):
    """Scale factor(s) that give ``self`` the magnitude of ``magnitude``.

    Shared by ``graft`` and ``graft_``, in the same way ``clip_norm``/``clip_norm_`` share
    ``_clip_multiplier``. With ``tensorwise=True`` the factors come from the per-tensor
    ``metric``; otherwise from the single ``global_metric`` value.
    """
    if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
    if tensorwise:
        norm_self = self.metric(ord)
        norm_other = magnitude.metric(ord)
    else:
        norm_self = self.global_metric(ord)
        norm_other = magnitude.global_metric(ord)

    # strength < 1 moves the target magnitude back towards self's own (strength=0 -> unchanged)
    if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]

    # eps guards against dividing by a zero magnitude
    return norm_other / norm_self.clip_(min=eps)

def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
    """Return a copy of ``self`` rescaled to the ``ord`` magnitude of ``magnitude``.

    Args:
        magnitude: tensors whose magnitude is transferred; converted to a TensorList if needed.
        tensorwise: match magnitudes per tensor instead of over the whole list.
        ord: metric used to measure magnitude.
        eps: lower clip on ``self``'s magnitude before dividing.
        strength: 1 fully adopts ``magnitude``'s size, 0 keeps ``self``'s own.
    """
    return self * self._graft_multiplier(magnitude, tensorwise, ord, eps, strength)

def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
    """In-place version of ``graft``; rescales ``self`` and returns it."""
    return self.mul_(self._graft_multiplier(magnitude, tensorwise, ord, eps, strength))
810
862
 
@@ -897,14 +949,14 @@ class TensorList(list[torch.Tensor | Any]):
897
949
  if eps!=0: std.add_(eps)
898
950
  return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
899
951
 
900
- def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
952
+ def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
901
953
  """calculate multipler to clip self norm to min and max"""
902
954
  if tensorwise:
903
- self_norm = self.norm(ord)
955
+ self_norm = self.metric(ord)
904
956
  self_norm.masked_fill_(self_norm == 0, 1)
905
957
 
906
958
  else:
907
- self_norm = self.global_vector_norm(ord)
959
+ self_norm = self.global_metric(ord)
908
960
  if self_norm == 0: return 1
909
961
 
910
962
  mul = 1
@@ -918,12 +970,12 @@ class TensorList(list[torch.Tensor | Any]):
918
970
 
919
971
  return mul
920
972
 
def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
    """Return a rescaled copy whose ``ord`` metric lies within ``[min, max]``.

    With ``tensorwise=True`` each tensor is clipped by its own metric, otherwise by the metric
    of the whole list. If both bounds are ``None``, ``self`` is returned unchanged (not a copy).
    """
    if min is not None or max is not None:
        return self * self._clip_multiplier(min, max, tensorwise, ord)
    return self

def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
    """In-place version of ``clip_norm``; always returns ``self``."""
    if min is not None or max is not None:
        return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
    return self
@@ -982,6 +1034,15 @@ class TensorList(list[torch.Tensor | Any]):
982
1034
  # """sets index in flattened view"""
983
1035
  # return self.clone().flatset_(idx, value)
984
1036
 
def flat_get(self, idx: int):
    """Return the element at position ``idx`` of the flattened concatenation of all tensors.

    Tensors are visited in list order and each is flattened row-major. ``idx`` therefore
    addresses the same element as ``torch.cat([t.reshape(-1) for t in self])[idx]``, without
    concatenating. Negative indices count from the end, as for Python sequences.

    Raises:
        IndexError: if ``idx`` is out of range.
    """
    orig_idx = idx
    if idx < 0:
        # Python-style negative indexing relative to the total element count
        idx += builtins.sum(t.numel() for t in self)
    if idx >= 0:
        start = 0
        for tensor in self:
            numel = tensor.numel()
            if idx < start + numel:
                # offset inside this tensor; reshape also works for non-contiguous tensors
                return tensor.reshape(-1)[idx - start]
            start += numel
    raise IndexError(orig_idx)
1045
+
985
1046
  def flat_set_(self, idx: int, value: Any):
986
1047
  """sets index in flattened view"""
987
1048
  cur = 0
@@ -1057,6 +1118,19 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
1057
1118
  if isinstance(x, torch.Tensor): return x.numel()
1058
1119
  return x.global_numel()
1059
1120
 
1121
+
def generic_finfo(x: torch.Tensor | TensorList) -> torch.finfo:
    """``torch.finfo`` for ``x``'s dtype; for a TensorList the first tensor's dtype is used."""
    dtype = x.dtype if isinstance(x, torch.Tensor) else x[0].dtype
    return torch.finfo(dtype)

def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
    """Machine epsilon of ``x``'s floating dtype (resolved as in ``generic_finfo``)."""
    return generic_finfo(x).eps

def generic_finfo_tiny(x: torch.Tensor | TensorList) -> float:
    """Smallest positive normal number of ``x``'s floating dtype (resolved as in ``generic_finfo``)."""
    return generic_finfo(x).tiny
1133
+
@overload
def generic_zeros_like(x: torch.Tensor) -> torch.Tensor:
    """Typing overload: a plain tensor input yields a plain tensor."""
    ...
1062
1136
  @overload
@@ -1069,7 +1143,8 @@ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
1069
1143
  if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
1070
1144
  return x.global_vector_norm(ord)
1071
1145
 
1072
-
def generic_metric(x: torch.Tensor | TensorList, metric: Metrics) -> torch.Tensor:
    """Evaluate ``metric`` on a tensor or TensorList by delegating to ``evaluate_metric``."""
    value = evaluate_metric(x, metric)
    return value
1073
1148
 
@overload
def generic_randn_like(x: torch.Tensor) -> torch.Tensor:
    """Typing overload: a plain tensor input yields a plain tensor."""
    ...
@@ -1079,3 +1154,11 @@ def generic_randn_like(x: torch.Tensor | TensorList):
1079
1154
  if isinstance(x, torch.Tensor): return torch.randn_like(x)
1080
1155
  return x.randn_like()
1081
1156
 
1157
+
def generic_sum(x: torch.Tensor | TensorList) -> torch.Tensor:
    """Sum of all elements: ``Tensor.sum`` for tensors, ``global_sum`` for TensorLists."""
    return x.global_sum() if not isinstance(x, torch.Tensor) else x.sum()

def generic_max(x: torch.Tensor | TensorList) -> torch.Tensor:
    """Maximum element: ``Tensor.max`` for tensors, ``global_max`` for TensorLists."""
    return x.global_max() if not isinstance(x, torch.Tensor) else x.max()
@@ -7,10 +7,15 @@ import numpy as np
7
7
  import torch
8
8
 
9
9
 
def totensor(x, device=None, dtype=None):
    """Convert ``x`` (tensor, numpy array, or array-like) to a ``torch.Tensor``.

    Without ``device``/``dtype`` nothing is forced: tensors are returned as-is and numpy data is
    wrapped with ``torch.from_numpy`` (sharing memory). Otherwise the result is moved/cast to the
    requested ``device``/``dtype``. Other inputs go through ``np.asarray`` first, so their dtype
    follows numpy's inference (e.g. Python floats become float64).
    """
    if isinstance(x, torch.Tensor):
        if device is None and dtype is None: return x
        return x.to(device=device, dtype=dtype)

    arr = x if isinstance(x, np.ndarray) else np.asarray(x)
    if device is None and dtype is None:
        return torch.from_numpy(arr)
    return torch.as_tensor(arr, device=device, dtype=dtype)
14
19
 
15
20
  def tonumpy(x):
16
21
  if isinstance(x, np.ndarray): return x
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchzero
3
+ Version: 0.3.13
4
+ Summary: Modular optimization library for PyTorch.
5
+ Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/inikishev/torchzero
7
+ Project-URL: Repository, https://github.com/inikishev/torchzero
8
+ Project-URL: Issues, https://github.com/inikishev/torchzero/issues
9
+ Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: torch
13
+ Requires-Dist: numpy
14
+ Requires-Dist: typing_extensions
@@ -0,0 +1,166 @@
1
+ tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
2
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
+ tests/test_opts.py,sha256=aT6-RbyUhWlIhMPi-ihqZtfiiYk0eT9vEIxMyxvwVOk,44059
4
+ tests/test_tensorlist.py,sha256=pWXQE-vEq08EGJSKWgsTgo-7QjjkavOJ5BlWUm241qI,72434
5
+ tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
6
+ tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
7
+ torchzero/__init__.py,sha256=aIH-cCTXnDr90cKUPhM8bv-uE69Hzjlf0jlYspYf0ZM,120
8
+ torchzero/core/__init__.py,sha256=aYyQt-CHzWT6hGUt5KVjRZZr2lsX5I1XvbWpzaAv3VE,151
9
+ torchzero/core/module.py,sha256=o4MZnJbMpk-F1qm-XvbguE_W_9a0sO2Mb8iDUZQ42B4,40511
10
+ torchzero/core/reformulation.py,sha256=jppgzXBtqdsc7ot6_Gr38vJbbrhG1Gs4vC32y7iB4BA,2387
11
+ torchzero/core/transform.py,sha256=xRDpsZj0H1QcFdO-t2mNMNOYoqqnRHiI3K1YluWwCVk,17097
12
+ torchzero/modules/__init__.py,sha256=3lGta9P0N3cWdVcruCBJ7uqu4DfLTPCKI_mlOZT6Z_o,615
13
+ torchzero/modules/functional.py,sha256=E_d6hLT2_xdE-3AhQ4AthDYK5uZULbF10iHI09Z3_yk,7921
14
+ torchzero/modules/adaptive/__init__.py,sha256=5L2dlEJV6HKBnYhgd7wo2yGi0WPd9qmpw9XS5wOQOq8,944
15
+ torchzero/modules/adaptive/adagrad.py,sha256=0qXC5F4PuOsgLjRXQUWBoiq0AUixsvOP1uDbEeRIcNs,12531
16
+ torchzero/modules/adaptive/adahessian.py,sha256=rWxgDiBMd6MK64mRjZwiudDF07is5AFj_MCEvBD7h8U,8670
17
+ torchzero/modules/adaptive/adam.py,sha256=4lWSe__tdyRv0rfkUda1qa_NH36DIW3sd8td98bK6XI,3829
18
+ torchzero/modules/adaptive/adan.py,sha256=Dt_gibyrGtWDIUCaSF6RFIxu2xwiF9fCfrpkoD-CaUM,2825
19
+ torchzero/modules/adaptive/adaptive_heavyball.py,sha256=xQQw1Vx-NgQd_ouK14J1p5ijd5lEn3sAN0hVJAL0j8U,2024
20
+ torchzero/modules/adaptive/aegd.py,sha256=_4ASgDX8__DPnnBE_RncnMqM4rItM7Eji4EZzSGGq5I,1876
21
+ torchzero/modules/adaptive/esgd.py,sha256=DtrN3hZhGK4LgMxmcjQCBtEO-5hLZrAdnvps6f8WQ2A,6416
22
+ torchzero/modules/adaptive/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
23
+ torchzero/modules/adaptive/lmadagrad.py,sha256=rMs7vrgiwOJgWo-OZXkGu32X561edEhhaxwgbY2NTnk,7176
24
+ torchzero/modules/adaptive/mars.py,sha256=iOkyY3r52btp7Cry7WN0AB4arpn4N9b_Hg6S55XC6Q8,2255
25
+ torchzero/modules/adaptive/matrix_momentum.py,sha256=ZMSdGNSHgyUgtJwjKzK7PZDwMrPw5wZJVu0K4Xa-SpI,6693
26
+ torchzero/modules/adaptive/msam.py,sha256=locqM2jiC3AbGCDCo6T40GF3iaVV2svrkzoi0hD2cJI,6663
27
+ torchzero/modules/adaptive/muon.py,sha256=5Asgj03s6JXXrO-p5Qgn3D8bVwbDEqCq7hxNy4joQDE,10335
28
+ torchzero/modules/adaptive/natural_gradient.py,sha256=5qRehh-iAeZ4hjfOR-gfsObbMsJZnipzCkG-yptkrH0,6349
29
+ torchzero/modules/adaptive/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
30
+ torchzero/modules/adaptive/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
31
+ torchzero/modules/adaptive/rprop.py,sha256=VDnLPKxw8ECihyUeNVE8cyDll_Ut3k3_NqoLgpgxxLA,11818
32
+ torchzero/modules/adaptive/sam.py,sha256=LnOPNZnIUsis0402RHnA-fTPkNM8baUR9HR50pF_BtM,5696
33
+ torchzero/modules/adaptive/shampoo.py,sha256=r7V4I5_Ve1YVOS3HhO2k5cZvJT1lPHTVApV3iVJVceA,9711
34
+ torchzero/modules/adaptive/soap.py,sha256=roQLBthNNNmTYgeJPi_LxZY-r4m6REeUo0_DZknYU50,10662
35
+ torchzero/modules/adaptive/sophia_h.py,sha256=lSK8uVdOxBAhU2jE6fyIx1YgqEQyZG-Fv9o2TniAZzk,7179
36
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
37
+ torchzero/modules/clipping/clipping.py,sha256=t98M3QKZKqXJ3_tzXXIiG4EOYMaHqLYrMZ-6zmRuy-k,14331
38
+ torchzero/modules/clipping/ema_clipping.py,sha256=Ki0LPNUwPoE825A5rSE7SxGQMiI3nO3iwnjKQ486iaI,6611
39
+ torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
40
+ torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
41
+ torchzero/modules/conjugate_gradient/cg.py,sha256=iAIiIyfM5hWeFH6-gxx8y-5olY0ED4DpnbLzXa9dke4,14492
42
+ torchzero/modules/experimental/__init__.py,sha256=blI-OhpQAC6-Ho1uxUq-t7Mm9CAMnNMXkBDmXul8tbc,729
43
+ torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
44
+ torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
45
+ torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
46
+ torchzero/modules/experimental/gradmin.py,sha256=hKTOG7tk6FnG8t-7OmTAhGTGSDdONzP1JvCRPRqaKt0,3740
47
+ torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
48
+ torchzero/modules/experimental/momentum.py,sha256=VqZc14EGVO_KUABPLRIBlvHdgg-64o-4heMQH0vW5vY,5233
49
+ torchzero/modules/experimental/newton_solver.py,sha256=0HnDBlrBLvUgS4hWmkJqyw0M7UPFp3kU3SFq2xZVYhQ,5454
50
+ torchzero/modules/experimental/newtonnewton.py,sha256=a0XXvlVe37z2MMcQ4TeGbbWX9OuYp_5b-21jq3o1z3E,3823
51
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
52
+ torchzero/modules/experimental/scipy_newton_cg.py,sha256=8nKBaHHmqdU9F1kVPn2QVFUTx2_I8Jsfqxix1v-qoL0,4073
53
+ torchzero/modules/experimental/structural_projections.py,sha256=rxJFG5F23dOiK_8KqKyvoSMLWqAOXtVGHSwfRqH22Wg,4185
54
+ torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
55
+ torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
56
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=v7v5maMC_Vak7N5w-LjIH6FIrkQUt7MvbR0PLprsmTI,4338
57
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=x8vlweBrfJ6SmhMHvI_C8UZGzlS3AnmlulvqnSzm6iY,4437
58
+ torchzero/modules/grad_approximation/rfdm.py,sha256=PA0FlGVaNHxstUqPqG1wMuIF4i9NpCqPpp6-1vxCboA,23635
59
+ torchzero/modules/higher_order/__init__.py,sha256=iaoIrmR9DJE9QHt9PeZNCWqIYDe-86h1IjkaumR4qF0,51
60
+ torchzero/modules/higher_order/higher_order_newton.py,sha256=2r1wuhdi57pbo8akQE88O8R-Y79BtiwD1WQIShh1rjQ,12967
61
+ torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
62
+ torchzero/modules/least_squares/gn.py,sha256=23AB6AWAl5IuBj4Vd3boQ6ndk0pO3ovaF9EiY1a1XWs,5094
63
+ torchzero/modules/line_search/__init__.py,sha256=mFWgkcgfMkL2NKj3CLbuwee3e8WHBOaXs-wtx3oTW58,216
64
+ torchzero/modules/line_search/_polyinterp.py,sha256=qIhcLjOlpB6NHU0oiUGMncwQxWNfy8757orsbzjkp6s,10882
65
+ torchzero/modules/line_search/adaptive.py,sha256=8Ip5F5PpsDLgg6TwB_E7zIZheycd78coRg4u7cpO3Cg,3795
66
+ torchzero/modules/line_search/backtracking.py,sha256=Mhx8_UT_Mr1gASYHUorBJ38E4YlcM9LpW9YrJHYfLXU,9049
67
+ torchzero/modules/line_search/line_search.py,sha256=lmtjr9Zpz9RYJXoYaJnpXkBSIdcN6DdwGKKXTCmcJNU,13294
68
+ torchzero/modules/line_search/scipy.py,sha256=gQMi6IYrnDvYsZWIO_cELhg_VZutIQJAELHHlLyu2fg,2286
69
+ torchzero/modules/line_search/strong_wolfe.py,sha256=lXcfOzg4kU0RGTe7GnVWuDPs2YWAjHH9vwRMDEGz4Mg,15054
70
+ torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
71
+ torchzero/modules/misc/debug.py,sha256=6pFAGYANjCPGIZH_4ghpUYYTEsT5jr7PMB9KLuPP4p8,1532
72
+ torchzero/modules/misc/escape.py,sha256=qfEdKLD5rejqrmvyHrI5BRQq8js9UF2-Axs_C0KFyWA,1866
73
+ torchzero/modules/misc/gradient_accumulation.py,sha256=mBWa5CBCZwp4TrtOyjWI3VnHag4gum4WBM2WFhvHqW4,4891
74
+ torchzero/modules/misc/homotopy.py,sha256=hihLETE4dNZ27zatqPR_qT3kGX-AXbC7oBWRDbFQo58,1939
75
+ torchzero/modules/misc/misc.py,sha256=feI-IQlxhIoAbsSRTjE4SbGez1c2Uu9-WA_nkK7iiqQ,15411
76
+ torchzero/modules/misc/multistep.py,sha256=RtDFIeTHu4RcERvlKEP4_10-lpRZOgbnBeSah92dQ7A,6323
77
+ torchzero/modules/misc/regularization.py,sha256=SkQ0_Ybtv9IEGI9QGdvNZaja5bAyc1x-j_1gvYIVepI,6105
78
+ torchzero/modules/misc/split.py,sha256=JcXVB4xk3h55YT2OAdepVsRoE1PD7bqX6NmJ2IxBgAI,4013
79
+ torchzero/modules/misc/switch.py,sha256=p758heAnv-PkoslpafL35Yp7mlvPmDVSe1mWiuuD8Mk,3711
80
+ torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
81
+ torchzero/modules/momentum/averaging.py,sha256=vDW8tgGsEuBXF_BTUYHB_j--TIVam9j0nZdp_x8TkxY,3229
82
+ torchzero/modules/momentum/cautious.py,sha256=x506a3lUETRpxPWqXLFJVFBH1gmLqIfqL5J-hFdEvOo,8051
83
+ torchzero/modules/momentum/momentum.py,sha256=q3n0BvQURuSBzA9vn1ZrH-n7Nsr0AS-38VJuwraQPY0,4495
84
+ torchzero/modules/ops/__init__.py,sha256=9UHaXs9aaKc0ewAhicTlDmj42bSC_vddMOD0eYuUj_8,1226
85
+ torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
86
+ torchzero/modules/ops/binary.py,sha256=2hV2oruaq5Asu4Ts8X8yiZQ-07fU0RGpRy3-vifXqXY,12151
87
+ torchzero/modules/ops/higher_level.py,sha256=E76zgSHlhVpHLrXhnVwelIQFm1IKn0IFcVq7DOJw0es,9037
88
+ torchzero/modules/ops/multi.py,sha256=YC3rBTmPRwF5aEPDNsyTK4J_JEAbmE7oBmF7W-VOV3A,8588
89
+ torchzero/modules/ops/reduce.py,sha256=kALG7X8q02sWpo1skpXjS0r875gwq6mrhLZbFfYaZoA,6324
90
+ torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
91
+ torchzero/modules/ops/utility.py,sha256=_k9S59i6IYOzzfIQlToQ9mlDseTTAS_49wujUxMGXZo,4105
92
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
93
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
94
+ torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
95
+ torchzero/modules/projections/projection.py,sha256=itkkb2UmMqbdtWKjUUg6gbFJfCEIZAskC0HCvom-6sc,14084
96
+ torchzero/modules/quasi_newton/__init__.py,sha256=HxXENs3O6nFRfCvUJhWPK9f8_A6iMwB6UF1Zold12UQ,515
97
+ torchzero/modules/quasi_newton/damping.py,sha256=K1DVqqKiAs6-F3JQh5jlKNb79oJdObqnKWwHHRl6boQ,2813
98
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
99
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=fzCjV5YsLo_uJTVG3vosPHsvDc97mLKueK6fxOHLb8I,11195
100
+ torchzero/modules/quasi_newton/lsr1.py,sha256=D3_yV5xtgklMlU4fAL1-sH82-1tNl3K2F12ZBZyLQGM,8512
101
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=-xUGPld8Y0MHwN6qsmDihLbUbulU0T1z8jf2mZhNpcE,44529
102
+ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
103
+ torchzero/modules/restarts/restars.py,sha256=A3fpTIbfpZCEUq9csPckdcsXQtaL0Le5UY3ZfKzxVSs,8971
104
+ torchzero/modules/second_order/__init__.py,sha256=lTGccDNVwPuMevMeKi5O0a9cl24Rn9tk7VkC6jvlGYc,233
105
+ torchzero/modules/second_order/multipoint.py,sha256=Ilzo0Ddd3iApegceu7cHSMGim9ZH5QS4-2uBtrKXC6k,8581
106
+ torchzero/modules/second_order/newton.py,sha256=PAPbJzssx0Ji328BFOEzeJZPd3IubJTPHs6ZhqS_nW8,15663
107
+ torchzero/modules/second_order/newton_cg.py,sha256=zavattL2z-IjWRT_AdwV5h7BGtQnrBzMTtTyt9xjZ-I,17363
108
+ torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
109
+ torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
110
+ torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
111
+ torchzero/modules/smoothing/sampling.py,sha256=zI5bATytQmCqm_UgAQbfA9tNRgrZaKLfUb0B-kzKRHU,12867
112
+ torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
113
+ torchzero/modules/step_size/adaptive.py,sha256=HvffW3m1NnpMTpps0QjJTvbblSODxxWMBBFTbNwp0vM,14482
114
+ torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
115
+ torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
116
+ torchzero/modules/termination/termination.py,sha256=BXU3R04caBc8rFJ4v_yJjgGi1X4iA11eYwlbiJfxexI,6637
117
+ torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
118
+ torchzero/modules/trust_region/cubic_regularization.py,sha256=gbKOR5zBo3t9i-sW23DCtTQwZrBubuFy_VuafrLaeUw,6718
119
+ torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
120
+ torchzero/modules/trust_region/levenberg_marquardt.py,sha256=Ibyf3ivcGR9sPkD5COXi7dRk3PSOfyTlI1W8ISAHNa8,5039
121
+ torchzero/modules/trust_region/trust_cg.py,sha256=UdQxNx7jf_CxyioRtJ92z35QU5HDbI22xpgd-4pW7V8,4297
122
+ torchzero/modules/trust_region/trust_region.py,sha256=eimCFViJSzoubrRmDluCon6mfcyT7PQA0yRPu4FlO2Q,12872
123
+ torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
124
+ torchzero/modules/variance_reduction/svrg.py,sha256=9pBjPY4EMkGyfj68gXqPi1GJIolUVl5zyNtlZInCKKo,8635
125
+ torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
126
+ torchzero/modules/weight_decay/weight_decay.py,sha256=Y7kE_j0GRF8ceJ9SS6qykQ8a23X2OTDCjJ9VklOQSEw,5415
127
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
128
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
129
+ torchzero/modules/zeroth_order/__init__.py,sha256=1_6wNrytru7tEHXzRXmL4AnK39ILPgf8FMVtF_YmAYU,30
130
+ torchzero/modules/zeroth_order/cd.py,sha256=6NL_xe56w1RbPPgxcggnQnD9eWNq7ZrhZjv4bZwq2Ms,14951
131
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
132
+ torchzero/optim/root.py,sha256=gGtAJ9qBoSNV58EKzUGZ8J3lyKGUF8BEw34Zfprppdo,2273
133
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
134
+ torchzero/optim/utility/split.py,sha256=kraPCLAewX2uLbD_9R2dIrcF-kpUuT9IcxPeVrAARvA,1672
135
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
136
+ torchzero/optim/wrappers/directsearch.py,sha256=rimJIB2RrVzLpRPQKhzkrMQ4bTAEU3NEOT4pJQNIAHE,11309
137
+ torchzero/optim/wrappers/fcmaes.py,sha256=jKmmBKEwguYiJdvTRmAp5JSilxcUhtpRoKlzmp-lyWE,4251
138
+ torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
139
+ torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
140
+ torchzero/optim/wrappers/nlopt.py,sha256=RuLKretljjAjTZ9tXY3FyEDuB7mAboeGOQBupWfzPc4,8105
141
+ torchzero/optim/wrappers/optuna.py,sha256=pIXkC5NVmEnUQ4jsGaz6Gv9uYOZM9rnxME4UGkeolsE,2393
142
+ torchzero/optim/wrappers/scipy.py,sha256=A4yeQRdB9f65UrJ2g80NfqqMc6zUyr9js40TUESCHPg,21535
143
+ torchzero/utils/__init__.py,sha256=7S4VRTkfS-0uI8HOR0EFIjiEcKrmYK7LEhTocIgki6c,1112
144
+ torchzero/utils/compile.py,sha256=Dozox91tcShUJ3L320TTbJrcuA-l4WVegLAQujRqy94,5132
145
+ torchzero/utils/derivatives.py,sha256=zJ0xyedvlIwgAYMa1F5BBfyrkvgjXy7v7evvl6QAlT0,17195
146
+ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
147
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
148
+ torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
149
+ torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
150
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
151
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
152
+ torchzero/utils/python_tools.py,sha256=kdiGk-I0Q-GpIVu3pCROkWvUHiDgzsagLgEsTzZplQw,3427
153
+ torchzero/utils/tensorlist.py,sha256=nIWBME3fUQPsr4buvtV3LaJgSXPEG_Xb58KAzfjwK-k,56064
154
+ torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
155
+ torchzero/utils/linalg/__init__.py,sha256=cNoTJOPeqbNn9l7_HAAen2rlehGS3DyY5SveInG3Stc,328
156
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
157
+ torchzero/utils/linalg/linear_operator.py,sha256=uJUxvOVHpG3U3GNx61JGa_uM8GqzsNZmA_z7P0RwZ5E,12747
158
+ torchzero/utils/linalg/matrix_funcs.py,sha256=BKQK_oIG35R6yGxU80eBG0VkyY2EgxywqbhvU7JhWm4,3109
159
+ torchzero/utils/linalg/orthogonalize.py,sha256=BpuDiAPrsJMUpTNBMCntBNA8-O2nozLxY5ZbCfRlEFY,444
160
+ torchzero/utils/linalg/qr.py,sha256=5tbPEV9I6X69r5ACWF9XeqjZTUtUql2145uoGjlJNDs,2517
161
+ torchzero/utils/linalg/solve.py,sha256=R5lPTzHn2sgvRy4MRp-Ngl0sypSGLRLHJjf1oKKAJD0,14395
162
+ torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
163
+ torchzero-0.3.13.dist-info/METADATA,sha256=onWv9DCY_mvI2vm-1MYrkRfTfJvWDLKDpuNGZO1ill0,565
164
+ torchzero-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
165
+ torchzero-0.3.13.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
166
+ torchzero-0.3.13.dist-info/RECORD,,
@@ -1,3 +1,2 @@
1
- docs
2
1
  tests
3
2
  torchzero
docs/source/conf.py DELETED
@@ -1,57 +0,0 @@
1
- # Configuration file for the Sphinx documentation builder.
2
- #
3
- # For the full list of built-in configuration values, see the documentation:
4
- # https://www.sphinx-doc.org/en/master/usage/configuration.html
5
-
6
- # -- Project information -----------------------------------------------------
7
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8
- import sys, os
9
- #sys.path.insert(0, os.path.abspath('.../src'))
10
-
11
- project = 'torchzero'
12
- copyright = '2024, Ivan Nikishev'
13
- author = 'Ivan Nikishev'
14
-
15
- # -- General configuration ---------------------------------------------------
16
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
17
-
18
- # https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
19
- extensions = [
20
- 'sphinx.ext.autodoc',
21
- 'sphinx.ext.autosummary',
22
- 'sphinx.ext.viewcode',
23
- 'sphinx.ext.autosectionlabel',
24
- 'sphinx.ext.githubpages',
25
- 'sphinx.ext.napoleon',
26
- 'autoapi.extension',
27
- # 'sphinx_rtd_theme',
28
- ]
29
- autosummary_generate = True
30
- autoapi_dirs = ['../../src']
31
- autoapi_type = "python"
32
- # autoapi_ignore = ["*/tensorlist.py"]
33
-
34
- # https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
35
- autoapi_options = [
36
- "members",
37
- "undoc-members",
38
- "show-inheritance",
39
- "show-module-summary",
40
- "imported-members",
41
- ]
42
-
43
-
44
- templates_path = ['_templates']
45
- exclude_patterns = []
46
-
47
- # -- Options for HTML output -------------------------------------------------
48
- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
49
-
50
- #html_theme = 'alabaster'
51
- html_theme = 'furo'
52
- html_static_path = ['_static']
53
-
54
-
55
- # OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
56
- source_suffix = ['.rst', '.md']
57
- master_doc = 'index'