torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/tensorlist.py
CHANGED

@@ -11,17 +11,19 @@ in an optimizer when you have to create one from parameters on each step. The so
 it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
 """
 import builtins
-from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
 import math
 import operator
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from typing import Any, Literal, TypedDict, overload
-from typing_extensions import Self, TypeAlias, Unpack
 
 import torch
-from
-from .python_tools import generic_eq, zipmap
-from .numberlist import NumberList, as_numberlist, maybe_numberlist
+from typing_extensions import Self, TypeAlias, Unpack
 
+from .metrics import Metrics, evaluate_metric, calculate_metric_list
+from .numberlist import NumberList, as_numberlist, maybe_numberlist
+from .ops import where_
+from .python_tools import generic_ne, zipmap
 
 _Scalar = int | float | bool | complex
 _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
@@ -33,6 +35,7 @@ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
 _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
 
 Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
+
 class _NewTensorKwargs(TypedDict, total = False):
     memory_format: Any
     dtype: Any
@@ -217,6 +220,12 @@ class TensorList(list[torch.Tensor | Any]):
         """Returns a TensorList with all elements for which `fn` returned True."""
         return self.__class__(i for i in self if fn(i, *args, **kwargs))
 
+    def filter_by_list(self, s: Sequence[bool]):
+        """returns a new TensorList with all elements where corresponding elements in :code:`s` are True."""
+        if len(self) != len(s):
+            raise ValueError(f"{len(self) = }, {len(s) = }")
+        return self.__class__(i for i, boolean in zip(self, s) if boolean)
+
     def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
         """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
         Otherwise applies `fn` to this TensorList and `other`.
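A note on the hunk above: `filter_by_list` selects elements of a TensorList by a boolean mask of the same length. A minimal sketch of how it might be called (illustration only, not code from the package; it assumes a `TensorList` can be built directly from a list of tensors):

```python
import torch
from torchzero.utils.tensorlist import TensorList

tl = TensorList([torch.zeros(3), torch.ones(2, 2), torch.ones(5)])
kept = tl.filter_by_list([True, False, True])  # keeps the 1st and 3rd tensors
assert len(kept) == 2
```
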
@@ -319,8 +328,20 @@ class TensorList(list[torch.Tensor | Any]):
     def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
     def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
     def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
+
     def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
-        return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
+        # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
+        if ord == 1: return self.global_sum()
+        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
+        if ord == torch.inf: return self.abs().global_max()
+        if ord == -torch.inf: return self.abs().global_min()
+        if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
+
+        return self.abs().pow_(ord).global_sum().pow(1/ord)
+
+    def global_metric(self, metric: Metrics) -> torch.Tensor:
+        return evaluate_metric(self, metric)
+
     def global_any(self): return builtins.any(self.any())
     def global_all(self): return builtins.all(self.all())
     def global_numel(self) -> int: return builtins.sum(self.numel())
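The reworked `global_vector_norm` above computes the norm over the whole TensorList as if it were one flattened vector, with special cases for ord 0, 1, even orders and infinity, and the new `global_metric` defers to `torchzero.utils.metrics`. A hedged usage sketch (assumed usage, not from the package):

```python
import torch
from torchzero.utils.tensorlist import TensorList

grads = TensorList([torch.randn(10, 10), torch.randn(5)])
l2 = grads.global_vector_norm(2)            # one scalar over all 105 elements
linf = grads.global_vector_norm(torch.inf)  # falls through to the abs().global_max() branch
```
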
@@ -351,31 +372,54 @@ class TensorList(list[torch.Tensor | Any]):
 
     def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
         return self.zipmap_args(torch.randint_like, low, high, **kwargs)
+
     def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         res = self.empty_like(**kwargs)
         res.uniform_(low, high, generator=generator)
         return res
+
     def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
         r = self.randn_like(generator=generator, **kwargs)
-        return (
+        return r.mul_(maybe_numberlist(radius) / r.global_vector_norm())
+
     def bernoulli(self, generator = None):
         return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
+
     def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be 0."""
         return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
+
     def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be -1."""
         return self.bernoulli_like(p, generator=generator, **kwargs).mul_(2).sub_(1)
 
-    def sample_like(self,
+    def sample_like(self, distribution: Distributions = 'normal', variance: "_Scalar | _ScalarSeq | Sequence | None" = None, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         """Sample around 0."""
-        if
+        if isinstance(variance, Sequence):
+            if all(v is None for v in variance): variance = None
+            else: variance = [v if v is not None else 1 for v in variance]
+
+        if distribution in ('normal', 'gaussian'):
+            ret = self.randn_like(generator=generator, **kwargs)
+            if variance is not None: ret *= variance
+            return ret
+
         if distribution == 'uniform':
-
-
-
-
-
+            b = 1
+            if variance is not None:
+                b = ((12 * maybe_numberlist(variance)) ** 0.5) / 2
+            return self.uniform_like(-b, b, generator=generator, **kwargs)
+
+        if distribution == 'sphere':
+            if variance is None: radius = 1
+            else: radius = maybe_numberlist(variance) * math.sqrt(self.global_numel())
+            return self.sphere_like(radius, generator=generator, **kwargs)
+
+        if distribution == 'rademacher':
+            ret = self.rademacher_like(generator=generator, **kwargs)
+            if variance is not None: ret *= variance
+            return ret
+
         raise ValueError(f'Unknow distribution {distribution}')
 
     def eq(self, other: _STOrSTSeq): return self.zipmap(torch.eq, other)
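The rewritten `sample_like` above now takes the distribution and an optional `variance` directly in its signature. A short sketch of the intended call pattern (assumed usage, not from the package):

```python
import torch
from torchzero.utils.tensorlist import TensorList

params = TensorList([torch.zeros(4, 4), torch.zeros(8)])
noise = params.sample_like('normal', variance=1e-2)  # gaussian branch, scaled by `variance`
probe = params.sample_like('rademacher')             # +/-1 entries via rademacher_like
```
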
@@ -425,11 +469,11 @@ class TensorList(list[torch.Tensor | Any]):
         return self
 
     def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 0): return self.add(other)
+        return self
     def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 0): return self.add_(other)
+        return self
 
     @overload
     def sub(self, other: _TensorSeq, alpha: _Scalar = 1): ...
@@ -449,11 +493,11 @@ class TensorList(list[torch.Tensor | Any]):
         return self
 
     def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 0): return self.sub(other)
+        return self
     def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 0): return self.sub_(other)
+        return self
 
     def neg(self): return self.__class__(torch._foreach_neg(self))
     def neg_(self):
@@ -467,13 +511,13 @@ class TensorList(list[torch.Tensor | Any]):
 
     # TODO: benchmark
     def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
-        if
-
-
-        return self
+        if generic_ne(other, 1):
+            return self * other
+        if clone: return self.clone()
+        return self
     def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 1): return self.mul_(other)
+        return self
 
     def div(self, other: _STOrSTSeq) -> Self: return self.__class__(torch._foreach_div(self, other))
     def div_(self, other: _STOrSTSeq):
@@ -481,11 +525,11 @@ class TensorList(list[torch.Tensor | Any]):
         return self
 
     def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 1): return self / other
+        return self
     def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
-        if
-        return self
+        if generic_ne(other, 1): return self.div_(other)
+        return self
 
     def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
     def pow_(self, exponent: "_Scalar | _STSeq"):
@@ -497,6 +541,11 @@ class TensorList(list[torch.Tensor | Any]):
         torch._foreach_pow_(input, self)
         return self
 
+    def square(self): return self.__class__(torch._foreach_pow(self, 2))
+    def square_(self):
+        torch._foreach_pow_(self, 2)
+        return self
+
     def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
     def sqrt_(self):
         torch._foreach_sqrt_(self)
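The `lazy_*` arithmetic above now short-circuits through `generic_ne`, so multiplying by exactly 1 or adding exactly 0 returns `self` without touching the tensors. A sketch (assumed usage, not from the package):

```python
import torch
from torchzero.utils.tensorlist import TensorList

update = TensorList([torch.randn(3), torch.randn(2, 2)])
update.lazy_mul_(1.0)  # generic_ne(1.0, 1) is falsy, so this is a no-op
update.lazy_mul_(0.5)  # actually multiplies in place
update.lazy_add_(0)    # also a no-op
```
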
@@ -627,9 +676,12 @@ class TensorList(list[torch.Tensor | Any]):
         if dim is None: dim = ()
         return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
 
-    def norm(self, ord:
+    def norm(self, ord: float, dtype=None):
         return self.__class__(torch._foreach_norm(self, ord, dtype))
 
+    def metric(self, metric: Metrics) -> "TensorList":
+        return calculate_metric_list(self, metric)
+
     def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
         if dim == 'global': return self._global_fn(keepdim, self.global_mean)
         return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
@@ -782,29 +834,29 @@ class TensorList(list[torch.Tensor | Any]):
         for t, o in zip(self, other): t.copysign_(o)
         return self
 
-    def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord:
+    def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
         if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
         if tensorwise:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.metric(ord)
+            norm_other = magnitude.metric(ord)
         else:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.global_metric(ord)
+            norm_other = magnitude.global_metric(ord)
 
-        if
+        if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
 
         return self * (norm_other / norm_self.clip_(min=eps))
 
-    def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord:
+    def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: Metrics = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
         if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
         if tensorwise:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.metric(ord)
+            norm_other = magnitude.metric(ord)
         else:
-            norm_self = self.
-            norm_other = magnitude.
+            norm_self = self.global_metric(ord)
+            norm_other = magnitude.global_metric(ord)
 
-        if
+        if generic_ne(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
 
         return self.mul_(norm_other / norm_self.clip_(min=eps))
 
@@ -897,14 +949,14 @@ class TensorList(list[torch.Tensor | Any]):
         if eps!=0: std.add_(eps)
         return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
 
-    def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """calculate multipler to clip self norm to min and max"""
         if tensorwise:
-            self_norm = self.
+            self_norm = self.metric(ord)
             self_norm.masked_fill_(self_norm == 0, 1)
 
         else:
-            self_norm = self.
+            self_norm = self.global_metric(ord)
             if self_norm == 0: return 1
 
         mul = 1
@@ -918,12 +970,12 @@ class TensorList(list[torch.Tensor | Any]):
 
         return mul
 
-    def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """clips norm of each tensor to (min, max) range"""
         if min is None and max is None: return self
         return self * self._clip_multiplier(min, max, tensorwise, ord)
 
-    def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:
+    def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:Metrics = 2):
         """clips norm of each tensor to (min, max) range"""
         if min is None and max is None: return self
         return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
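`graft` and `clip_norm` above now accept any `Metrics` value for `ord` and route through `metric`/`global_metric` instead of a hard-coded norm. A hedged sketch of grafting one update's magnitude onto another's direction and then clipping (assumed usage, with a plain `ord=2` standing in for a Metrics value):

```python
import torch
from torchzero.utils.tensorlist import TensorList

direction = TensorList([torch.randn(10), torch.randn(4, 4)])
magnitude = TensorList([torch.randn(10), torch.randn(4, 4)])

direction.graft_(magnitude, tensorwise=True, ord=2)    # keep direction, borrow per-tensor norms
direction.clip_norm_(max=1.0, tensorwise=True, ord=2)  # then cap each tensor's norm at 1
```
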
@@ -982,6 +1034,15 @@ class TensorList(list[torch.Tensor | Any]):
     # """sets index in flattened view"""
     # return self.clone().flatset_(idx, value)
 
+    def flat_get(self, idx: int):
+        cur = 0
+        for tensor in self:
+            numel = tensor.numel()
+            if idx < cur + numel:
+                return tensor.view(-1)[cur-idx]
+            cur += numel
+        raise IndexError(idx)
+
     def flat_set_(self, idx: int, value: Any):
         """sets index in flattened view"""
         cur = 0
@@ -1057,6 +1118,19 @@ def generic_numel(x: torch.Tensor | TensorList) -> int:
     if isinstance(x, torch.Tensor): return x.numel()
     return x.global_numel()
 
+
+def generic_finfo(x: torch.Tensor | TensorList) -> torch.finfo:
+    if isinstance(x, torch.Tensor): return torch.finfo(x.dtype)
+    return torch.finfo(x[0].dtype)
+
+def generic_finfo_eps(x: torch.Tensor | TensorList) -> float:
+    if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).eps
+    return torch.finfo(x[0].dtype).eps
+
+def generic_finfo_tiny(x: torch.Tensor | TensorList) -> float:
+    if isinstance(x, torch.Tensor): return torch.finfo(x.dtype).tiny
+    return torch.finfo(x[0].dtype).tiny
+
 @overload
 def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
 @overload
@@ -1069,7 +1143,8 @@ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
     if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
     return x.global_vector_norm(ord)
 
-
+def generic_metric(x: torch.Tensor | TensorList, metric: Metrics) -> torch.Tensor:
+    return evaluate_metric(x, metric)
 
 @overload
 def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
@@ -1079,3 +1154,11 @@ def generic_randn_like(x: torch.Tensor | TensorList):
     if isinstance(x, torch.Tensor): return torch.randn_like(x)
     return x.randn_like()
 
+
+def generic_sum(x: torch.Tensor | TensorList) -> torch.Tensor:
+    if isinstance(x, torch.Tensor): return x.sum()
+    return x.global_sum()
+
+def generic_max(x: torch.Tensor | TensorList) -> torch.Tensor:
+    if isinstance(x, torch.Tensor): return x.max()
+    return x.global_max()

torchzero/utils/torch_tools.py
CHANGED

@@ -7,10 +7,15 @@ import numpy as np
 import torch
 
 
-def totensor(x):
-    if
-
-
+def totensor(x, device=None, dtype=None):
+    if device is None and dtype is None:
+        if isinstance(x, torch.Tensor): return x
+        if isinstance(x, np.ndarray): return torch.from_numpy(x)
+        return torch.from_numpy(np.asarray(x))
+
+    if isinstance(x, torch.Tensor): return x.to(device=device, dtype=dtype)
+    if isinstance(x, np.ndarray): return torch.as_tensor(x, device=device, dtype=dtype)
+    return torch.as_tensor(np.asarray(x), device=device, dtype=dtype)
 
 def tonumpy(x):
     if isinstance(x, np.ndarray): return x
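The `totensor` rewrite above adds optional `device`/`dtype` arguments: the old zero-copy `from_numpy` path is kept when neither is given, otherwise conversion goes through `torch.as_tensor`/`Tensor.to`. A small sketch (assumed usage, not from the package):

```python
import numpy as np
import torch
from torchzero.utils.torch_tools import totensor

a = totensor(np.arange(4))                          # from_numpy path, shares memory
b = totensor([1.0, 2.0, 3.0], dtype=torch.float32)  # as_tensor path with a dtype
c = totensor(torch.ones(2), device='cpu', dtype=torch.float64)  # Tensor.to path
```
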
torchzero-0.3.13.dist-info/METADATA
ADDED

@@ -0,0 +1,14 @@
+Metadata-Version: 2.4
+Name: torchzero
+Version: 0.3.13
+Summary: Modular optimization library for PyTorch.
+Author-email: Ivan Nikishev <nkshv2@gmail.com>
+Project-URL: Homepage, https://github.com/inikishev/torchzero
+Project-URL: Repository, https://github.com/inikishev/torchzero
+Project-URL: Issues, https://github.com/inikishev/torchzero/isses
+Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+Requires-Dist: numpy
+Requires-Dist: typing_extensions
torchzero-0.3.13.dist-info/RECORD
ADDED

@@ -0,0 +1,166 @@
+tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
+tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
+tests/test_opts.py,sha256=aT6-RbyUhWlIhMPi-ihqZtfiiYk0eT9vEIxMyxvwVOk,44059
+tests/test_tensorlist.py,sha256=pWXQE-vEq08EGJSKWgsTgo-7QjjkavOJ5BlWUm241qI,72434
+tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
+tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
+torchzero/__init__.py,sha256=aIH-cCTXnDr90cKUPhM8bv-uE69Hzjlf0jlYspYf0ZM,120
+torchzero/core/__init__.py,sha256=aYyQt-CHzWT6hGUt5KVjRZZr2lsX5I1XvbWpzaAv3VE,151
+torchzero/core/module.py,sha256=o4MZnJbMpk-F1qm-XvbguE_W_9a0sO2Mb8iDUZQ42B4,40511
+torchzero/core/reformulation.py,sha256=jppgzXBtqdsc7ot6_Gr38vJbbrhG1Gs4vC32y7iB4BA,2387
+torchzero/core/transform.py,sha256=xRDpsZj0H1QcFdO-t2mNMNOYoqqnRHiI3K1YluWwCVk,17097
+torchzero/modules/__init__.py,sha256=3lGta9P0N3cWdVcruCBJ7uqu4DfLTPCKI_mlOZT6Z_o,615
+torchzero/modules/functional.py,sha256=E_d6hLT2_xdE-3AhQ4AthDYK5uZULbF10iHI09Z3_yk,7921
+torchzero/modules/adaptive/__init__.py,sha256=5L2dlEJV6HKBnYhgd7wo2yGi0WPd9qmpw9XS5wOQOq8,944
+torchzero/modules/adaptive/adagrad.py,sha256=0qXC5F4PuOsgLjRXQUWBoiq0AUixsvOP1uDbEeRIcNs,12531
+torchzero/modules/adaptive/adahessian.py,sha256=rWxgDiBMd6MK64mRjZwiudDF07is5AFj_MCEvBD7h8U,8670
+torchzero/modules/adaptive/adam.py,sha256=4lWSe__tdyRv0rfkUda1qa_NH36DIW3sd8td98bK6XI,3829
+torchzero/modules/adaptive/adan.py,sha256=Dt_gibyrGtWDIUCaSF6RFIxu2xwiF9fCfrpkoD-CaUM,2825
+torchzero/modules/adaptive/adaptive_heavyball.py,sha256=xQQw1Vx-NgQd_ouK14J1p5ijd5lEn3sAN0hVJAL0j8U,2024
+torchzero/modules/adaptive/aegd.py,sha256=_4ASgDX8__DPnnBE_RncnMqM4rItM7Eji4EZzSGGq5I,1876
+torchzero/modules/adaptive/esgd.py,sha256=DtrN3hZhGK4LgMxmcjQCBtEO-5hLZrAdnvps6f8WQ2A,6416
+torchzero/modules/adaptive/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
+torchzero/modules/adaptive/lmadagrad.py,sha256=rMs7vrgiwOJgWo-OZXkGu32X561edEhhaxwgbY2NTnk,7176
+torchzero/modules/adaptive/mars.py,sha256=iOkyY3r52btp7Cry7WN0AB4arpn4N9b_Hg6S55XC6Q8,2255
+torchzero/modules/adaptive/matrix_momentum.py,sha256=ZMSdGNSHgyUgtJwjKzK7PZDwMrPw5wZJVu0K4Xa-SpI,6693
+torchzero/modules/adaptive/msam.py,sha256=locqM2jiC3AbGCDCo6T40GF3iaVV2svrkzoi0hD2cJI,6663
+torchzero/modules/adaptive/muon.py,sha256=5Asgj03s6JXXrO-p5Qgn3D8bVwbDEqCq7hxNy4joQDE,10335
+torchzero/modules/adaptive/natural_gradient.py,sha256=5qRehh-iAeZ4hjfOR-gfsObbMsJZnipzCkG-yptkrH0,6349
+torchzero/modules/adaptive/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
+torchzero/modules/adaptive/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
+torchzero/modules/adaptive/rprop.py,sha256=VDnLPKxw8ECihyUeNVE8cyDll_Ut3k3_NqoLgpgxxLA,11818
+torchzero/modules/adaptive/sam.py,sha256=LnOPNZnIUsis0402RHnA-fTPkNM8baUR9HR50pF_BtM,5696
+torchzero/modules/adaptive/shampoo.py,sha256=r7V4I5_Ve1YVOS3HhO2k5cZvJT1lPHTVApV3iVJVceA,9711
+torchzero/modules/adaptive/soap.py,sha256=roQLBthNNNmTYgeJPi_LxZY-r4m6REeUo0_DZknYU50,10662
+torchzero/modules/adaptive/sophia_h.py,sha256=lSK8uVdOxBAhU2jE6fyIx1YgqEQyZG-Fv9o2TniAZzk,7179
+torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
+torchzero/modules/clipping/clipping.py,sha256=t98M3QKZKqXJ3_tzXXIiG4EOYMaHqLYrMZ-6zmRuy-k,14331
+torchzero/modules/clipping/ema_clipping.py,sha256=Ki0LPNUwPoE825A5rSE7SxGQMiI3nO3iwnjKQ486iaI,6611
+torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
+torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
+torchzero/modules/conjugate_gradient/cg.py,sha256=iAIiIyfM5hWeFH6-gxx8y-5olY0ED4DpnbLzXa9dke4,14492
+torchzero/modules/experimental/__init__.py,sha256=blI-OhpQAC6-Ho1uxUq-t7Mm9CAMnNMXkBDmXul8tbc,729
+torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
+torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
+torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
+torchzero/modules/experimental/gradmin.py,sha256=hKTOG7tk6FnG8t-7OmTAhGTGSDdONzP1JvCRPRqaKt0,3740
+torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
+torchzero/modules/experimental/momentum.py,sha256=VqZc14EGVO_KUABPLRIBlvHdgg-64o-4heMQH0vW5vY,5233
+torchzero/modules/experimental/newton_solver.py,sha256=0HnDBlrBLvUgS4hWmkJqyw0M7UPFp3kU3SFq2xZVYhQ,5454
+torchzero/modules/experimental/newtonnewton.py,sha256=a0XXvlVe37z2MMcQ4TeGbbWX9OuYp_5b-21jq3o1z3E,3823
+torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
+torchzero/modules/experimental/scipy_newton_cg.py,sha256=8nKBaHHmqdU9F1kVPn2QVFUTx2_I8Jsfqxix1v-qoL0,4073
+torchzero/modules/experimental/structural_projections.py,sha256=rxJFG5F23dOiK_8KqKyvoSMLWqAOXtVGHSwfRqH22Wg,4185
+torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
+torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
+torchzero/modules/grad_approximation/forward_gradient.py,sha256=v7v5maMC_Vak7N5w-LjIH6FIrkQUt7MvbR0PLprsmTI,4338
+torchzero/modules/grad_approximation/grad_approximator.py,sha256=x8vlweBrfJ6SmhMHvI_C8UZGzlS3AnmlulvqnSzm6iY,4437
+torchzero/modules/grad_approximation/rfdm.py,sha256=PA0FlGVaNHxstUqPqG1wMuIF4i9NpCqPpp6-1vxCboA,23635
+torchzero/modules/higher_order/__init__.py,sha256=iaoIrmR9DJE9QHt9PeZNCWqIYDe-86h1IjkaumR4qF0,51
+torchzero/modules/higher_order/higher_order_newton.py,sha256=2r1wuhdi57pbo8akQE88O8R-Y79BtiwD1WQIShh1rjQ,12967
+torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
+torchzero/modules/least_squares/gn.py,sha256=23AB6AWAl5IuBj4Vd3boQ6ndk0pO3ovaF9EiY1a1XWs,5094
+torchzero/modules/line_search/__init__.py,sha256=mFWgkcgfMkL2NKj3CLbuwee3e8WHBOaXs-wtx3oTW58,216
+torchzero/modules/line_search/_polyinterp.py,sha256=qIhcLjOlpB6NHU0oiUGMncwQxWNfy8757orsbzjkp6s,10882
+torchzero/modules/line_search/adaptive.py,sha256=8Ip5F5PpsDLgg6TwB_E7zIZheycd78coRg4u7cpO3Cg,3795
+torchzero/modules/line_search/backtracking.py,sha256=Mhx8_UT_Mr1gASYHUorBJ38E4YlcM9LpW9YrJHYfLXU,9049
+torchzero/modules/line_search/line_search.py,sha256=lmtjr9Zpz9RYJXoYaJnpXkBSIdcN6DdwGKKXTCmcJNU,13294
+torchzero/modules/line_search/scipy.py,sha256=gQMi6IYrnDvYsZWIO_cELhg_VZutIQJAELHHlLyu2fg,2286
+torchzero/modules/line_search/strong_wolfe.py,sha256=lXcfOzg4kU0RGTe7GnVWuDPs2YWAjHH9vwRMDEGz4Mg,15054
+torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
+torchzero/modules/misc/debug.py,sha256=6pFAGYANjCPGIZH_4ghpUYYTEsT5jr7PMB9KLuPP4p8,1532
+torchzero/modules/misc/escape.py,sha256=qfEdKLD5rejqrmvyHrI5BRQq8js9UF2-Axs_C0KFyWA,1866
+torchzero/modules/misc/gradient_accumulation.py,sha256=mBWa5CBCZwp4TrtOyjWI3VnHag4gum4WBM2WFhvHqW4,4891
+torchzero/modules/misc/homotopy.py,sha256=hihLETE4dNZ27zatqPR_qT3kGX-AXbC7oBWRDbFQo58,1939
+torchzero/modules/misc/misc.py,sha256=feI-IQlxhIoAbsSRTjE4SbGez1c2Uu9-WA_nkK7iiqQ,15411
+torchzero/modules/misc/multistep.py,sha256=RtDFIeTHu4RcERvlKEP4_10-lpRZOgbnBeSah92dQ7A,6323
+torchzero/modules/misc/regularization.py,sha256=SkQ0_Ybtv9IEGI9QGdvNZaja5bAyc1x-j_1gvYIVepI,6105
+torchzero/modules/misc/split.py,sha256=JcXVB4xk3h55YT2OAdepVsRoE1PD7bqX6NmJ2IxBgAI,4013
+torchzero/modules/misc/switch.py,sha256=p758heAnv-PkoslpafL35Yp7mlvPmDVSe1mWiuuD8Mk,3711
+torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
+torchzero/modules/momentum/averaging.py,sha256=vDW8tgGsEuBXF_BTUYHB_j--TIVam9j0nZdp_x8TkxY,3229
+torchzero/modules/momentum/cautious.py,sha256=x506a3lUETRpxPWqXLFJVFBH1gmLqIfqL5J-hFdEvOo,8051
+torchzero/modules/momentum/momentum.py,sha256=q3n0BvQURuSBzA9vn1ZrH-n7Nsr0AS-38VJuwraQPY0,4495
+torchzero/modules/ops/__init__.py,sha256=9UHaXs9aaKc0ewAhicTlDmj42bSC_vddMOD0eYuUj_8,1226
+torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
+torchzero/modules/ops/binary.py,sha256=2hV2oruaq5Asu4Ts8X8yiZQ-07fU0RGpRy3-vifXqXY,12151
+torchzero/modules/ops/higher_level.py,sha256=E76zgSHlhVpHLrXhnVwelIQFm1IKn0IFcVq7DOJw0es,9037
+torchzero/modules/ops/multi.py,sha256=YC3rBTmPRwF5aEPDNsyTK4J_JEAbmE7oBmF7W-VOV3A,8588
+torchzero/modules/ops/reduce.py,sha256=kALG7X8q02sWpo1skpXjS0r875gwq6mrhLZbFfYaZoA,6324
+torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
+torchzero/modules/ops/utility.py,sha256=_k9S59i6IYOzzfIQlToQ9mlDseTTAS_49wujUxMGXZo,4105
+torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
+torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
+torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
+torchzero/modules/projections/projection.py,sha256=itkkb2UmMqbdtWKjUUg6gbFJfCEIZAskC0HCvom-6sc,14084
+torchzero/modules/quasi_newton/__init__.py,sha256=HxXENs3O6nFRfCvUJhWPK9f8_A6iMwB6UF1Zold12UQ,515
+torchzero/modules/quasi_newton/damping.py,sha256=K1DVqqKiAs6-F3JQh5jlKNb79oJdObqnKWwHHRl6boQ,2813
+torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
+torchzero/modules/quasi_newton/lbfgs.py,sha256=fzCjV5YsLo_uJTVG3vosPHsvDc97mLKueK6fxOHLb8I,11195
+torchzero/modules/quasi_newton/lsr1.py,sha256=D3_yV5xtgklMlU4fAL1-sH82-1tNl3K2F12ZBZyLQGM,8512
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=-xUGPld8Y0MHwN6qsmDihLbUbulU0T1z8jf2mZhNpcE,44529
+torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
+torchzero/modules/restarts/restars.py,sha256=A3fpTIbfpZCEUq9csPckdcsXQtaL0Le5UY3ZfKzxVSs,8971
+torchzero/modules/second_order/__init__.py,sha256=lTGccDNVwPuMevMeKi5O0a9cl24Rn9tk7VkC6jvlGYc,233
+torchzero/modules/second_order/multipoint.py,sha256=Ilzo0Ddd3iApegceu7cHSMGim9ZH5QS4-2uBtrKXC6k,8581
+torchzero/modules/second_order/newton.py,sha256=PAPbJzssx0Ji328BFOEzeJZPd3IubJTPHs6ZhqS_nW8,15663
+torchzero/modules/second_order/newton_cg.py,sha256=zavattL2z-IjWRT_AdwV5h7BGtQnrBzMTtTyt9xjZ-I,17363
+torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
+torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
+torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
+torchzero/modules/smoothing/sampling.py,sha256=zI5bATytQmCqm_UgAQbfA9tNRgrZaKLfUb0B-kzKRHU,12867
+torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
+torchzero/modules/step_size/adaptive.py,sha256=HvffW3m1NnpMTpps0QjJTvbblSODxxWMBBFTbNwp0vM,14482
+torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
+torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
+torchzero/modules/termination/termination.py,sha256=BXU3R04caBc8rFJ4v_yJjgGi1X4iA11eYwlbiJfxexI,6637
+torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
+torchzero/modules/trust_region/cubic_regularization.py,sha256=gbKOR5zBo3t9i-sW23DCtTQwZrBubuFy_VuafrLaeUw,6718
+torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
+torchzero/modules/trust_region/levenberg_marquardt.py,sha256=Ibyf3ivcGR9sPkD5COXi7dRk3PSOfyTlI1W8ISAHNa8,5039
+torchzero/modules/trust_region/trust_cg.py,sha256=UdQxNx7jf_CxyioRtJ92z35QU5HDbI22xpgd-4pW7V8,4297
+torchzero/modules/trust_region/trust_region.py,sha256=eimCFViJSzoubrRmDluCon6mfcyT7PQA0yRPu4FlO2Q,12872
+torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
+torchzero/modules/variance_reduction/svrg.py,sha256=9pBjPY4EMkGyfj68gXqPi1GJIolUVl5zyNtlZInCKKo,8635
+torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
+torchzero/modules/weight_decay/weight_decay.py,sha256=Y7kE_j0GRF8ceJ9SS6qykQ8a23X2OTDCjJ9VklOQSEw,5415
+torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
+torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
+torchzero/modules/zeroth_order/__init__.py,sha256=1_6wNrytru7tEHXzRXmL4AnK39ILPgf8FMVtF_YmAYU,30
+torchzero/modules/zeroth_order/cd.py,sha256=6NL_xe56w1RbPPgxcggnQnD9eWNq7ZrhZjv4bZwq2Ms,14951
+torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
+torchzero/optim/root.py,sha256=gGtAJ9qBoSNV58EKzUGZ8J3lyKGUF8BEw34Zfprppdo,2273
+torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
+torchzero/optim/utility/split.py,sha256=kraPCLAewX2uLbD_9R2dIrcF-kpUuT9IcxPeVrAARvA,1672
+torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/optim/wrappers/directsearch.py,sha256=rimJIB2RrVzLpRPQKhzkrMQ4bTAEU3NEOT4pJQNIAHE,11309
+torchzero/optim/wrappers/fcmaes.py,sha256=jKmmBKEwguYiJdvTRmAp5JSilxcUhtpRoKlzmp-lyWE,4251
+torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
+torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
+torchzero/optim/wrappers/nlopt.py,sha256=RuLKretljjAjTZ9tXY3FyEDuB7mAboeGOQBupWfzPc4,8105
+torchzero/optim/wrappers/optuna.py,sha256=pIXkC5NVmEnUQ4jsGaz6Gv9uYOZM9rnxME4UGkeolsE,2393
+torchzero/optim/wrappers/scipy.py,sha256=A4yeQRdB9f65UrJ2g80NfqqMc6zUyr9js40TUESCHPg,21535
+torchzero/utils/__init__.py,sha256=7S4VRTkfS-0uI8HOR0EFIjiEcKrmYK7LEhTocIgki6c,1112
+torchzero/utils/compile.py,sha256=Dozox91tcShUJ3L320TTbJrcuA-l4WVegLAQujRqy94,5132
+torchzero/utils/derivatives.py,sha256=zJ0xyedvlIwgAYMa1F5BBfyrkvgjXy7v7evvl6QAlT0,17195
+torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
+torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
+torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
+torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
+torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
+torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
+torchzero/utils/python_tools.py,sha256=kdiGk-I0Q-GpIVu3pCROkWvUHiDgzsagLgEsTzZplQw,3427
+torchzero/utils/tensorlist.py,sha256=nIWBME3fUQPsr4buvtV3LaJgSXPEG_Xb58KAzfjwK-k,56064
+torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
+torchzero/utils/linalg/__init__.py,sha256=cNoTJOPeqbNn9l7_HAAen2rlehGS3DyY5SveInG3Stc,328
+torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
+torchzero/utils/linalg/linear_operator.py,sha256=uJUxvOVHpG3U3GNx61JGa_uM8GqzsNZmA_z7P0RwZ5E,12747
+torchzero/utils/linalg/matrix_funcs.py,sha256=BKQK_oIG35R6yGxU80eBG0VkyY2EgxywqbhvU7JhWm4,3109
+torchzero/utils/linalg/orthogonalize.py,sha256=BpuDiAPrsJMUpTNBMCntBNA8-O2nozLxY5ZbCfRlEFY,444
+torchzero/utils/linalg/qr.py,sha256=5tbPEV9I6X69r5ACWF9XeqjZTUtUql2145uoGjlJNDs,2517
+torchzero/utils/linalg/solve.py,sha256=R5lPTzHn2sgvRy4MRp-Ngl0sypSGLRLHJjf1oKKAJD0,14395
+torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
+torchzero-0.3.13.dist-info/METADATA,sha256=onWv9DCY_mvI2vm-1MYrkRfTfJvWDLKDpuNGZO1ill0,565
+torchzero-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+torchzero-0.3.13.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
+torchzero-0.3.13.dist-info/RECORD,,
docs/source/conf.py
DELETED

@@ -1,57 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# For the full list of built-in configuration values, see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Project information -----------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-import sys, os
-#sys.path.insert(0, os.path.abspath('.../src'))
-
-project = 'torchzero'
-copyright = '2024, Ivan Nikishev'
-author = 'Ivan Nikishev'
-
-# -- General configuration ---------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-
-# https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.autosectionlabel',
-    'sphinx.ext.githubpages',
-    'sphinx.ext.napoleon',
-    'autoapi.extension',
-    # 'sphinx_rtd_theme',
-]
-autosummary_generate = True
-autoapi_dirs = ['../../src']
-autoapi_type = "python"
-# autoapi_ignore = ["*/tensorlist.py"]
-
-# https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
-autoapi_options = [
-    "members",
-    "undoc-members",
-    "show-inheritance",
-    "show-module-summary",
-    "imported-members",
-]
-
-
-templates_path = ['_templates']
-exclude_patterns = []
-
-# -- Options for HTML output -------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-
-#html_theme = 'alabaster'
-html_theme = 'furo'
-html_static_path = ['_static']
-
-
-# OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
-source_suffix = ['.rst', '.md']
-master_doc = 'index'
|