torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,1081 @@
|
|
|
1
|
+
# pyright: reportIncompatibleMethodOverride=false
|
|
2
|
+
r"""
|
|
3
|
+
TensorList is a data type that can be used to manipulate a sequence of tensors such as model parameters,
|
|
4
|
+
with the same methods that normal tensors have, plus some additional convenience features.
|
|
5
|
+
Whenever possible, I used _foreach methods and other tricks to speed up computation.
|
|
6
|
+
|
|
7
|
+
TensorList is similar to TensorDict (https://github.com/pytorch/tensordict).
|
|
8
|
+
If you want to get the most performance out of a collection of tensors, use TensorDict and lock it.
|
|
9
|
+
However, I found that *creating* a TensorDict is quite slow. In fact, it negates the benefits of using it
|
|
10
|
+
in an optimizer when you have to create one from parameters on each step. The solution could be to create
|
|
11
|
+
it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
|
|
12
|
+
"""
|
|
13
|
+
import builtins
|
|
14
|
+
from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
|
|
15
|
+
import math
|
|
16
|
+
import operator
|
|
17
|
+
from typing import Any, Literal, TypedDict, overload
|
|
18
|
+
from typing_extensions import Self, TypeAlias, Unpack
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from .ops import where_
|
|
22
|
+
from .python_tools import generic_eq, zipmap
|
|
23
|
+
from .numberlist import NumberList, as_numberlist, maybe_numberlist
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
_Scalar = int | float | bool | complex
|
|
27
|
+
_TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
|
|
28
|
+
_ScalarSeq = list[int] | list[float] | list[bool] | list[complex] | tuple[int] | tuple[float] | tuple[bool] | tuple[complex]
|
|
29
|
+
_ScalarSequence = Sequence[_Scalar] # i only check (list,tuple), its faster and safer
|
|
30
|
+
_STSeq = _TensorSeq | _ScalarSeq
|
|
31
|
+
_STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
|
|
32
|
+
|
|
33
|
+
_Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
|
|
34
|
+
|
|
35
|
+
Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
|
|
36
|
+
class _NewTensorKwargs(TypedDict, total = False):
|
|
37
|
+
memory_format: Any
|
|
38
|
+
dtype: Any
|
|
39
|
+
layout: Any
|
|
40
|
+
device: Any
|
|
41
|
+
pin_memory: bool
|
|
42
|
+
requires_grad: bool
|
|
43
|
+
|
|
44
|
+
# _foreach_methods = {attr.replace('_foreach_', ''):getattr(torch, attr) for attr in dir(torch) if attr.startswith('_foreach_')}
|
|
45
|
+
class _MethodCallerWithArgs:
|
|
46
|
+
"""Return a callable object that calls the given method on its operand.
|
|
47
|
+
|
|
48
|
+
This is similar to operator.methodcaller but args and kwargs are specificed in __call__.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
method (str): name of method to call.
|
|
52
|
+
"""
|
|
53
|
+
__slots__ = ('_name',)
|
|
54
|
+
def __init__(self, name: str):
|
|
55
|
+
self._name = name
|
|
56
|
+
|
|
57
|
+
def __call__(self, obj, *args, **kwargs):
|
|
58
|
+
return getattr(obj, self._name)(*args, **kwargs)
|
|
59
|
+
|
|
60
|
+
def __repr__(self):
|
|
61
|
+
return f'{self.__class__.__module__}.{self.__class__.__name__}({repr(self._name)})'
|
|
62
|
+
|
|
63
|
+
def __reduce__(self):
|
|
64
|
+
return self.__class__, self._name
|
|
65
|
+
|
|
66
|
+
def as_tensorlist(x):
    """Return ``x`` unchanged when it already is a TensorList (or a subclass);
    otherwise wrap it in a new ``TensorList``."""
    return x if isinstance(x, TensorList) else TensorList(x)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# tensorlist must subclass list
|
|
72
|
+
# UserList doesn't work with _foreach_xxx
|
|
73
|
+
class TensorList(list[torch.Tensor | Any]):
|
|
74
|
+
@classmethod
def complex(cls, real: _TensorSeq, imag: _TensorSeq):
    """Build a complex-valued list from paired real and imaginary tensors.

    Element ``k`` is ``torch.complex(real[k], imag[k])``. Pairing follows
    ``zip`` semantics, so it stops at the shorter of the two sequences.
    """
    pairs = zip(real, imag)
    return cls(torch.complex(re_part, im_part) for re_part, im_part in pairs)
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def device(self): return [i.device for i in self]
|
|
81
|
+
@property
|
|
82
|
+
def dtype(self): return [i.dtype for i in self]
|
|
83
|
+
@property
|
|
84
|
+
def requires_grad(self): return [i.requires_grad for i in self]
|
|
85
|
+
@property
|
|
86
|
+
def shape(self): return [i.shape for i in self]
|
|
87
|
+
def size(self, dim: int | None = None): return [i.size(dim) for i in self]
|
|
88
|
+
@property
|
|
89
|
+
def ndim(self): return [i.ndim for i in self]
|
|
90
|
+
def ndimension(self): return [i.ndimension() for i in self]
|
|
91
|
+
def numel(self): return [i.numel() for i in self]
|
|
92
|
+
|
|
93
|
+
@property
def grad(self):
    """New list holding each element's ``.grad`` (entries may be ``None``)."""
    return self.__class__(t.grad for t in self)

@property
def real(self):
    """New list holding each element's ``.real`` part."""
    return self.__class__(t.real for t in self)

@property
def imag(self):
    """New list holding each element's ``.imag`` part."""
    return self.__class__(t.imag for t in self)

def view_as_real(self):
    """``torch.view_as_real`` applied to every (complex) element."""
    return self.__class__(torch.view_as_real(t) for t in self)

def view_as_complex(self):
    """``torch.view_as_complex`` applied to every element."""
    return self.__class__(torch.view_as_complex(t) for t in self)
|
|
102
|
+
|
|
103
|
+
def type_as(self, other: torch.Tensor | _TensorSeq):
    """Apply ``Tensor.type_as`` to every element.

    A list/tuple ``other`` is paired with this list element-wise. Any other
    value (a single tensor) is used as the target for every element.
    """
    caster = _MethodCallerWithArgs('type_as')
    return self.zipmap(caster, other)
|
|
105
|
+
|
|
106
|
+
def view_as(self, other: torch.Tensor | Sequence[torch.Tensor]):
|
|
107
|
+
if isinstance(other, Sequence): return self.__class__(s.view_as(o) for s, o in zip(self, other))
|
|
108
|
+
return self.__class__(s.view_as(other) for s in self)
|
|
109
|
+
|
|
110
|
+
def fill_none(self, reference: Iterable[torch.Tensor]):
    """Return a new list in which every ``None`` element is replaced by
    ``torch.zeros_like`` of the tensor at the same position in ``reference``.

    Non-``None`` elements are reused as-is (not copied). Pairing follows
    ``zip`` semantics, so the result has the length of the shorter input.
    """
    pairs = zip(self, reference)
    return self.__class__(torch.zeros_like(ref) if cur is None else cur for cur, ref in pairs)

def fill_none_(self, reference: Iterable[torch.Tensor]):
    """In-place variant of ``fill_none``: each ``None`` entry of this list is
    replaced by zeros shaped like the matching ``reference`` tensor.

    Returns ``self``.
    """
    for idx, (cur, ref) in enumerate(zip(self, reference)):
        if cur is not None:
            continue
        self[idx] = torch.zeros_like(ref)
    return self
|
|
119
|
+
|
|
120
|
+
def get_grad(self):
    """New list with the ``.grad`` of every element whose gradient is not
    ``None`` (elements without a gradient are skipped)."""
    return self.__class__(t.grad for t in self if t.grad is not None)

def with_requires_grad(self, requires_grad = True):
    """New list with the elements whose ``requires_grad`` equals the argument."""
    wanted = requires_grad
    return self.__class__(t for t in self if t.requires_grad == wanted)

def with_grad(self):
    """New list with the elements that currently have a ``.grad``."""
    return self.__class__(t for t in self if t.grad is not None)

def ensure_grad_(self):
    """Give every element that requires grad but has no ``.grad`` a zero
    gradient (``torch.zeros_like``). Returns ``self``."""
    for t in self:
        if not t.requires_grad or t.grad is not None:
            continue
        t.grad = torch.zeros_like(t)
    return self
|
|
137
|
+
|
|
138
|
+
def accumulate_grad_(self, grads: _TensorSeq):
    """Add ``grads`` element-wise into the ``.grad`` of each element.

    An element without a gradient gets the incoming tensor assigned directly.
    It is not copied, so later in-place accumulation also modifies that tensor.
    Existing gradients are updated in place with ``add_``. Returns ``self``.
    """
    for t, g in zip(self, grads):
        if t.grad is not None:
            t.grad.add_(g)
        else:
            t.grad = g
    return self

def set_grad_(self, grads: _TensorSeq):
    """Assign ``grads`` element-wise to ``.grad``, replacing any existing
    gradient. Returns ``self``."""
    for t, g in zip(self, grads):
        t.grad = g
    return self

def zero_grad_(self, set_to_none = True):
    """Clear gradients. By default every ``.grad`` is set to ``None``. With
    ``set_to_none=False``, ``zero_()`` is called on the existing gradients
    instead. Returns ``self``."""
    if not set_to_none:
        self.get_grad().zero_()
        return self
    for t in self:
        t.grad = None
    return self
|
|
157
|
+
|
|
158
|
+
# Arithmetic operators. Each forwards to the TensorList method with the same
# meaning, so it accepts whatever that method accepts.
def __add__(self, other: _STOrSTSeq) -> Self:
    return self.add(other) # pyright: ignore[reportCallIssue,reportArgumentType]

def __radd__(self, other: _STOrSTSeq) -> Self:
    # addition commutes: other + self == self + other
    return self.add(other) # pyright: ignore[reportCallIssue,reportArgumentType]

def __iadd__(self, other: _STOrSTSeq) -> Self:
    return self.add_(other) # pyright: ignore[reportCallIssue,reportArgumentType]

def __sub__(self, other: "_Scalar | _STSeq") -> Self:
    return self.sub(other) # pyright: ignore[reportCallIssue,reportArgumentType]

def __rsub__(self, other: "_Scalar | _STSeq") -> Self:
    # other - self is computed as -(self - other)
    return self.sub(other).neg_() # pyright: ignore[reportCallIssue,reportArgumentType]

def __isub__(self, other: "_Scalar | _STSeq") -> Self:
    return self.sub_(other) # pyright: ignore[reportCallIssue,reportArgumentType]

def __mul__(self, other: _STOrSTSeq) -> Self:
    return self.mul(other)

def __rmul__(self, other: _STOrSTSeq) -> Self:
    # multiplication commutes: other * self == self * other
    return self.mul(other)

def __imul__(self, other: _STOrSTSeq) -> Self:
    return self.mul_(other)

def __truediv__(self, other: "_STOrSTSeq") -> Self:
    return self.div(other)

def __rtruediv__(self, other: "_STOrSTSeq") -> Self:
    # other / self is computed as other * (1 / self)
    return other * self.reciprocal()

def __itruediv__(self, other: "_STOrSTSeq") -> Self:
    return self.div_(other)

# Only the forward and in-place forms of // and % are defined;
# the reflected forms (__rfloordiv__, __rmod__) are not.
def __floordiv__(self, other: _STOrSTSeq):
    return self.floor_divide(other)

def __ifloordiv__(self, other: _STOrSTSeq):
    return self.floor_divide_(other)

def __mod__(self, other: _STOrSTSeq):
    return self.remainder(other)

def __imod__(self, other: _STOrSTSeq):
    return self.remainder_(other)

def __pow__(self, other: "_Scalar | _STSeq"):
    return self.pow(other)

def __rpow__(self, other: "_Scalar | _TensorSeq"):
    # other ** self, delegated to rpow
    return self.rpow(other)

def __ipow__(self, other: "_Scalar | _STSeq"):
    return self.pow_(other)

def __neg__(self):
    return self.neg()
|
|
187
|
+
|
|
188
|
+
# Rich comparisons return the element-wise results of eq/ne/lt/le/gt/ge
# instead of a single bool, as a plain list would.
def __eq__(self, other: _STOrSTSeq):
    return self.eq(other)

def __ne__(self, other: _STOrSTSeq):
    return self.ne(other)

def __lt__(self, other: _STOrSTSeq):
    return self.lt(other)

def __le__(self, other: _STOrSTSeq):
    return self.le(other)

def __gt__(self, other: _STOrSTSeq):
    return self.gt(other)

def __ge__(self, other: _STOrSTSeq):
    return self.ge(other)

# ~, &, | and ^ map to the logical (not bitwise) tensor operations.
def __invert__(self):
    return self.logical_not()

def __and__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_and(other)

def __iand__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_and_(other)

def __or__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_or(other)

def __ior__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_or_(other)

def __xor__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_xor(other)

def __ixor__(self, other: torch.Tensor | _TensorSeq):
    return self.logical_xor_(other)

def __bool__(self):
    # A list of tensors has no single truth value, so ``bool(tl)`` and
    # ``if tl:`` raise instead of testing for emptiness like a plain list.
    raise RuntimeError(f'Boolean value of {self.__class__.__name__} is ambiguous')
|
|
206
|
+
|
|
207
|
+
def map(self, fn: Callable[..., torch.Tensor], *args, **kwargs):
    """Return a new list holding ``fn(el, *args, **kwargs)`` for every element."""
    return self.__class__(fn(el, *args, **kwargs) for el in self)

def map_inplace_(self, fn: Callable[..., Any], *args, **kwargs):
    """Call ``fn(el, *args, **kwargs)`` on every element for its side effect
    (``fn`` is expected to modify ``el`` in place). Returns ``self``."""
    for el in self:
        fn(el, *args, **kwargs)
    return self

def filter(self, fn: Callable[..., bool], *args, **kwargs):
    """Return a new list with the elements for which ``fn(el, *args, **kwargs)``
    is truthy."""
    kept = [el for el in self if fn(el, *args, **kwargs)]
    return self.__class__(kept)
|
|
219
|
+
|
|
220
|
+
def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
|
|
221
|
+
"""If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
|
|
222
|
+
Otherwise applies `fn` to this TensorList and `other`.
|
|
223
|
+
Returns a new TensorList with return values of the callable."""
|
|
224
|
+
return zipmap(self, fn, other, *args, **kwargs)
|
|
225
|
+
|
|
226
|
+
def zipmap_inplace_(self, fn: Callable[..., Any], other: Any | list | tuple, *args, **kwargs):
|
|
227
|
+
"""If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
|
|
228
|
+
Otherwise applies `fn` to this TensorList and `other`.
|
|
229
|
+
The callable must modify elements in-place."""
|
|
230
|
+
if isinstance(other, (list, tuple)):
|
|
231
|
+
for i, j in zip(self, other): fn(i, j, *args, **kwargs)
|
|
232
|
+
else:
|
|
233
|
+
for i in self: fn(i, other, *args, **kwargs)
|
|
234
|
+
return self
|
|
235
|
+
|
|
236
|
+
def zipmap_args(self, fn: Callable[..., Any], *others, **kwargs):
|
|
237
|
+
"""If `args` is list/tuple, applies `fn` to this TensorList zipped with `others`.
|
|
238
|
+
Otherwise applies `fn` to this TensorList and `other`."""
|
|
239
|
+
others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
|
|
240
|
+
return self.__class__(fn(*z, **kwargs) for z in zip(self, *others))
|
|
241
|
+
|
|
242
|
+
def zipmap_args_inplace_(self, fn: Callable[..., Any], *others, **kwargs):
|
|
243
|
+
"""If `args` is list/tuple, applies `fn` to this TensorList zipped with `other`.
|
|
244
|
+
Otherwise applies `fn` to this TensorList and `other`.
|
|
245
|
+
The callable must modify elements in-place."""
|
|
246
|
+
others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
|
|
247
|
+
for z in zip(self, *others): fn(*z, **kwargs)
|
|
248
|
+
return self
|
|
249
|
+
|
|
250
|
+
def _foreach_apply(self, fn: Callable[[list[torch.Tensor]], list[torch.Tensor]], *args, **kwargs):
|
|
251
|
+
"""Applies a torch._foreach_xxx function to self and converts returned list back to TensorList or subclass."""
|
|
252
|
+
return self.__class__(fn(self), *args, **kwargs)
|
|
253
|
+
|
|
254
|
+
# def __getattr__(self, name: str) -> Callable:
|
|
255
|
+
# if name == '__torch_function__' or name == '_ipython_canary_method_should_not_exist_': raise AttributeError('who ???')
|
|
256
|
+
# if name in _foreach_methods:
|
|
257
|
+
# method = partial(self._foreach_apply, _foreach_methods[name])
|
|
258
|
+
# else:
|
|
259
|
+
# method = partial(self.map, MethodCallerWithArgs(name))
|
|
260
|
+
# setattr(self, name, method)
|
|
261
|
+
# return method
|
|
262
|
+
|
|
263
|
+
def to(self, *args, **kwargs):
    """``Tensor.to(*args, **kwargs)`` for every element."""
    return self.__class__(t.to(*args, **kwargs) for t in self)

def cuda(self):
    """``Tensor.cuda()`` for every element."""
    return self.__class__(t.cuda() for t in self)

def cpu(self):
    """``Tensor.cpu()`` for every element."""
    return self.__class__(t.cpu() for t in self)

def long(self):
    """``Tensor.long()`` for every element."""
    return self.__class__(t.long() for t in self)

def short(self):
    """``Tensor.short()`` for every element."""
    return self.__class__(t.short() for t in self)

def clone(self):
    """``Tensor.clone()`` for every element."""
    return self.__class__(t.clone() for t in self)

def detach(self):
    """``Tensor.detach()`` for every element."""
    return self.__class__(t.detach() for t in self)

def detach_(self):
    """``Tensor.detach_()`` on every element, in place. Returns ``self``."""
    for t in self:
        t.detach_()
    return self

def contiguous(self):
    """``Tensor.contiguous()`` for every element."""
    return self.__class__(t.contiguous() for t in self)

# Named ``as_float`` etc. because a method called ``float`` would shadow the
# builtin ``float`` in annotations inside the class body.
def as_float(self):
    """``Tensor.float()`` for every element."""
    return self.__class__(t.float() for t in self)

def as_bool(self):
    """``Tensor.bool()`` for every element."""
    return self.__class__(t.bool() for t in self)

def as_int(self):
    """``Tensor.int()`` for every element."""
    return self.__class__(t.int() for t in self)
|
|
279
|
+
|
|
280
|
+
def copy_(self, src: _TensorSeq, non_blocking = False):
|
|
281
|
+
"""Copies the elements from src tensors into self tensors."""
|
|
282
|
+
torch._foreach_copy_(self, src, non_blocking=non_blocking)
|
|
283
|
+
def set_(self, storage: Iterable[torch.Tensor | torch.types.Storage]):
|
|
284
|
+
"""Sets elements of this TensorList to the values of a list of tensors."""
|
|
285
|
+
for i, j in zip(self, storage): i.set_(j) # pyright:ignore[reportArgumentType]
|
|
286
|
+
return self
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def requires_grad_(self, mode: bool = True):
|
|
290
|
+
for e in self: e.requires_grad_(mode)
|
|
291
|
+
return self
|
|
292
|
+
|
|
293
|
+
def to_vec(self):
    """Concatenate ``self.ravel()`` into one tensor (presumably every element
    flattened to 1D -- ``ravel`` is defined elsewhere)."""
    return torch.cat(self.ravel())

def from_vec_(self, vec:torch.Tensor):
    """Overwrite every element, in place, with consecutive slices of ``vec``.

    Each element takes the next ``el.numel()`` entries of ``vec``. The slice is
    cast with ``type_as``, reshaped with ``view_as``, and installed via
    ``Tensor.set_``. ``vec`` must contain exactly as many values as all
    elements together. Returns ``self``.
    """
    offset = 0
    for el in self:
        end = offset + el.numel()
        el.set_(vec[offset:end].type_as(el).view_as(el)) # pyright:ignore[reportArgumentType]
        offset = end
    return self
|
|
303
|
+
|
|
304
|
+
def from_vec(self, vec:torch.Tensor):
    """Split a 1D tensor into a new list shaped like this one.

    Each element of the result holds the next ``el.numel()`` entries of
    ``vec``, cast with ``type_as(el)`` and reshaped with ``view_as(el)``. This
    list is left untouched. ``vec`` must contain exactly as many values as all
    elements together.

    The result has the same class as ``self`` (like the other constructor-style
    methods), so TensorList subclasses are preserved.
    """
    chunks = []
    offset = 0
    for el in self:
        end = offset + el.numel()
        chunks.append(vec[offset:end].type_as(el).view_as(el))
        offset = end
    return self.__class__(chunks)
|
|
314
|
+
|
|
315
|
+
# Reductions over all elements of all tensors. Most reduce each tensor first
# and combine the per-tensor results with a Python builtin. A single op on
# ``self.to_vec()`` (e.g. torch.sum) can be faster, but it needs a concatenated
# copy of everything, so it is less memory efficient. std/var/vector_norm do
# go through ``to_vec``.
def global_min(self) -> torch.Tensor:
    per_tensor = self.min()
    return builtins.min(per_tensor) # pyright:ignore[reportArgumentType]

def global_max(self) -> torch.Tensor:
    per_tensor = self.max()
    return builtins.max(per_tensor) # pyright:ignore[reportArgumentType]

def global_mean(self) -> torch.Tensor:
    total = self.global_sum()
    return total / self.global_numel()

def global_sum(self) -> torch.Tensor:
    per_tensor = self.sum()
    return builtins.sum(per_tensor) # pyright:ignore[reportArgumentType,reportReturnType]

def global_std(self) -> torch.Tensor:
    return torch.std(self.to_vec())

def global_var(self) -> torch.Tensor:
    return torch.var(self.to_vec())

def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
    flat = self.to_vec()
    return torch.linalg.vector_norm(flat, ord = ord) # pylint:disable = not-callable

def global_any(self):
    return builtins.any(self.any())

def global_all(self):
    return builtins.all(self.all())

def global_numel(self) -> int:
    return builtins.sum(self.numel())
|
|
327
|
+
|
|
328
|
+
# Factories mirroring torch.*_like: ``kwargs`` (dtype, device, ...) are
# forwarded to the torch function for every element.
def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]):
    return self.__class__(torch.empty_like(t, **kwargs) for t in self)

def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]):
    return self.__class__(torch.zeros_like(t, **kwargs) for t in self)

def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]):
    return self.__class__(torch.ones_like(t, **kwargs) for t in self)

def full_like(self, fill_value: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
    # A list/tuple fill_value is paired element-wise (via zipmap); a scalar is
    # used for every element.
    return self.zipmap(torch.full_like, other=fill_value, **kwargs)
|
|
334
|
+
|
|
335
|
+
def rand_like(self, generator=None, dtype: Any=None, device: Any=None, **kwargs):
    """Uniform random tensors shaped like every element.

    Without ``generator`` this is ``torch.rand_like`` per element. With a
    ``generator``, ``torch.rand`` is called with each element's shape, so the
    generator can be passed. dtype/device default to the element's own unless
    given explicitly.
    """
    if generator is None:
        return self.__class__(torch.rand_like(t, dtype=dtype, device=device, **kwargs) for t in self)
    return self.__class__(
        torch.rand(
            t.shape,
            generator=generator,
            dtype=dtype if dtype is not None else t.dtype,
            device=device if device is not None else t.device,
            **kwargs,
        )
        for t in self
    )

def randn_like(self, generator=None, dtype: Any=None, device: Any=None, **kwargs):
    """Standard-normal random tensors shaped like every element.

    Without ``generator`` this is ``torch.randn_like`` per element. With a
    ``generator``, ``torch.randn`` is called with each element's shape, so the
    generator can be passed. dtype/device default to the element's own unless
    given explicitly.
    """
    if generator is None:
        return self.__class__(torch.randn_like(t, dtype=dtype, device=device, **kwargs) for t in self)
    return self.__class__(
        torch.randn(
            t.shape,
            generator=generator,
            dtype=dtype if dtype is not None else t.dtype,
            device=device if device is not None else t.device,
            **kwargs,
        )
        for t in self
    )
|
|
351
|
+
|
|
352
|
+
def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
    """``torch.randint_like(el, low, high)`` for every element. ``low`` and
    ``high`` may each be a scalar or a per-element list/tuple."""
    return self.zipmap_args(torch.randint_like, low, high, **kwargs)

def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
    """New tensors shaped like every element, filled by ``uniform_(low, high)``."""
    out = self.empty_like(**kwargs)
    out.uniform_(low, high, generator=generator)
    return out

def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
    """Gaussian noise scaled by ``radius`` and divided by the noise's global
    2-norm. For a scalar ``radius`` this is a random point on a sphere of that
    radius, where the norm is taken over the whole list."""
    direction = self.randn_like(generator=generator, **kwargs)
    return (direction * radius) / direction.global_vector_norm()

def bernoulli(self, generator = None):
    """``torch.bernoulli`` of every element (values are used as probabilities)."""
    return self.__class__(torch.bernoulli(t, generator=generator) for t in self)

def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
    """Random 0/1 tensors shaped like every element; ``p`` is the probability
    of a 1 (scalar or per-element)."""
    probs = self.full_like(p, **kwargs)
    return self.__class__(torch.bernoulli(t, generator = generator) for t in probs)

def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
    """Random +1/-1 tensors shaped like every element; ``p`` is the probability
    of +1."""
    signs = self.bernoulli_like(p, generator=generator, **kwargs)
    return signs.mul_(2).sub_(1)
|
|
369
|
+
|
|
370
|
+
def sample_like(self, eps: "_Scalar | _ScalarSeq" = 1, distribution: "Distributions" = 'normal', generator=None, **kwargs: "Unpack[_NewTensorKwargs]"):
    """Sample random noise around 0, shaped like each tensor.

    Args:
        eps: noise scale, a scalar or one value per tensor.
            - ``'normal'`` / ``'gaussian'``: standard normal samples multiplied by ``eps``.
            - ``'uniform'``: uniform samples between ``-eps/2`` and ``eps/2``.
            - ``'sphere'``: ``eps`` is passed to ``sphere_like`` as the radius.
            - ``'rademacher'``: +-1 samples multiplied by ``eps``.
        distribution: one of the names above.
        generator: optional random generator forwarded to the sampler.
        **kwargs: tensor-creation options forwarded to the sampler.

    Raises:
        ValueError: if ``distribution`` is not recognised.
    """
    if distribution in ('normal', 'gaussian'):
        return self.randn_like(generator=generator, **kwargs) * eps

    if distribution == 'uniform':
        if isinstance(eps, (list, tuple)):
            return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs)
        return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)

    if distribution == 'sphere':
        return self.sphere_like(eps, generator=generator, **kwargs)

    if distribution == 'rademacher':
        return self.rademacher_like(generator=generator, **kwargs) * eps

    raise ValueError(f'Unknown distribution {distribution}')
|
|
380
|
+
|
|
381
|
+
def eq(self, other: _STOrSTSeq):
    """Elementwise ``self == other`` via ``zipmap``."""
    return self.zipmap(torch.eq, other)

def eq_(self, other: _STOrSTSeq):
    """In-place elementwise ``==`` using each tensor's ``eq_`` method."""
    caller = _MethodCallerWithArgs('eq_')
    return self.zipmap_inplace_(caller, other)

def ne(self, other: _STOrSTSeq):
    """Elementwise ``self != other`` via ``zipmap``."""
    return self.zipmap(torch.ne, other)

def ne_(self, other: _STOrSTSeq):
    """In-place elementwise ``!=`` using each tensor's ``ne_`` method."""
    caller = _MethodCallerWithArgs('ne_')
    return self.zipmap_inplace_(caller, other)

def lt(self, other: _STOrSTSeq):
    """Elementwise ``self < other`` via ``zipmap``."""
    return self.zipmap(torch.lt, other)

def lt_(self, other: _STOrSTSeq):
    """In-place elementwise ``<`` using each tensor's ``lt_`` method."""
    caller = _MethodCallerWithArgs('lt_')
    return self.zipmap_inplace_(caller, other)

def le(self, other: _STOrSTSeq):
    """Elementwise ``self <= other`` via ``zipmap``."""
    return self.zipmap(torch.le, other)

def le_(self, other: _STOrSTSeq):
    """In-place elementwise ``<=`` using each tensor's ``le_`` method."""
    caller = _MethodCallerWithArgs('le_')
    return self.zipmap_inplace_(caller, other)

def gt(self, other: _STOrSTSeq):
    """Elementwise ``self > other`` via ``zipmap``."""
    return self.zipmap(torch.gt, other)

def gt_(self, other: _STOrSTSeq):
    """In-place elementwise ``>`` using each tensor's ``gt_`` method."""
    caller = _MethodCallerWithArgs('gt_')
    return self.zipmap_inplace_(caller, other)

def ge(self, other: _STOrSTSeq):
    """Elementwise ``self >= other`` via ``zipmap``."""
    return self.zipmap(torch.ge, other)

def ge_(self, other: _STOrSTSeq):
    """In-place elementwise ``>=`` using each tensor's ``ge_`` method."""
    caller = _MethodCallerWithArgs('ge_')
    return self.zipmap_inplace_(caller, other)
|
|
393
|
+
|
|
394
|
+
def logical_and(self, other: torch.Tensor | _TensorSeq):
    """Elementwise logical AND via ``zipmap``."""
    return self.zipmap(torch.logical_and, other)

def logical_and_(self, other: torch.Tensor | _TensorSeq):
    """In-place elementwise logical AND."""
    caller = _MethodCallerWithArgs('logical_and_')
    return self.zipmap_inplace_(caller, other)

def logical_or(self, other: torch.Tensor | _TensorSeq):
    """Elementwise logical OR via ``zipmap``."""
    return self.zipmap(torch.logical_or, other)

def logical_or_(self, other: torch.Tensor | _TensorSeq):
    """In-place elementwise logical OR."""
    caller = _MethodCallerWithArgs('logical_or_')
    return self.zipmap_inplace_(caller, other)

def logical_xor(self, other: torch.Tensor | _TensorSeq):
    """Elementwise logical XOR via ``zipmap``."""
    return self.zipmap(torch.logical_xor, other)

def logical_xor_(self, other: torch.Tensor | _TensorSeq):
    """In-place elementwise logical XOR."""
    caller = _MethodCallerWithArgs('logical_xor_')
    return self.zipmap_inplace_(caller, other)

def logical_not(self):
    """Elementwise logical NOT of every tensor, as a new list."""
    return self.__class__([torch.logical_not(t) for t in self])

def logical_not_(self):
    """In-place elementwise logical NOT; returns self."""
    for t in self:
        t.logical_not_()
    return self
|
|
405
|
+
|
|
406
|
+
def equal(self, other: torch.Tensor | _TensorSeq):
    """Apply ``torch.equal`` pairwise via ``zipmap``.

    Each entry is True when the two tensors have the same size and elements,
    False otherwise.
    """
    comparator = torch.equal
    return self.zipmap(comparator, other)
|
|
409
|
+
|
|
410
|
+
@overload
def add(self, other: torch.Tensor | _TensorSeq, alpha: _Scalar = 1): ...
@overload
def add(self, other: _Scalar | _ScalarSeq): ...
def add(self, other: _STOrSTSeq, alpha: _Scalar = 1):
    """Out-of-place ``self + alpha * other``.

    ``alpha`` is only forwarded to ``torch._foreach_add`` when it is not 1; the
    scalar overload above takes no ``alpha``.
    """
    if alpha != 1:
        return self.__class__(torch._foreach_add(self, other, alpha = alpha)) # pyright:ignore[reportCallIssue,reportArgumentType]
    return self.__class__(torch._foreach_add(self, other))

@overload
def add_(self, other: torch.Tensor | _TensorSeq, alpha: _Scalar = 1): ...
@overload
def add_(self, other: _Scalar | _ScalarSeq): ...
def add_(self, other: _STOrSTSeq, alpha: _Scalar = 1):
    """In-place ``self += alpha * other``; returns self."""
    if alpha != 1:
        torch._foreach_add_(self, other, alpha = alpha) # pyright:ignore[reportCallIssue,reportArgumentType]
    else:
        torch._foreach_add_(self, other)
    return self
|
|
426
|
+
|
|
427
|
+
def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self.add(other)``, except that adding 0 returns ``self`` itself (no copy)."""
    return self if generic_eq(other, 0) else self.add(other)

def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``add_``, skipped entirely when ``other`` is 0; returns self."""
    return self if generic_eq(other, 0) else self.add_(other)
|
|
433
|
+
|
|
434
|
+
@overload
def sub(self, other: _TensorSeq, alpha: _Scalar = 1): ...
@overload
def sub(self, other: _Scalar | _ScalarSeq): ...
def sub(self, other: "_Scalar | _STSeq", alpha: _Scalar = 1):
    """Out-of-place ``self - alpha * other``; ``alpha`` is only forwarded when it is not 1."""
    if alpha != 1:
        return self.__class__(torch._foreach_sub(self, other, alpha = alpha)) # pyright:ignore[reportArgumentType]
    return self.__class__(torch._foreach_sub(self, other))

@overload
def sub_(self, other: _TensorSeq, alpha: _Scalar = 1): ...
@overload
def sub_(self, other: _Scalar | _ScalarSeq): ...
def sub_(self, other: "_Scalar | _STSeq", alpha: _Scalar = 1):
    """In-place ``self -= alpha * other``; returns self."""
    if alpha != 1:
        torch._foreach_sub_(self, other, alpha = alpha) # pyright:ignore[reportArgumentType]
    else:
        torch._foreach_sub_(self, other)
    return self
|
|
450
|
+
|
|
451
|
+
def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self.sub(other)``, except that subtracting 0 returns ``self`` itself (no copy)."""
    return self if generic_eq(other, 0) else self.sub(other)

def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``sub_``, skipped entirely when ``other`` is 0; returns self."""
    return self if generic_eq(other, 0) else self.sub_(other)
|
|
457
|
+
|
|
458
|
+
def neg(self):
    """Elementwise negation, returned as a new list."""
    negated = torch._foreach_neg(self)
    return self.__class__(negated)

def neg_(self):
    """Negate every tensor in place; returns self."""
    torch._foreach_neg_(self)
    return self
|
|
462
|
+
|
|
463
|
+
def mul(self, other: _STOrSTSeq):
    """Elementwise ``self * other`` as a new list."""
    product = torch._foreach_mul(self, other)
    return self.__class__(product)

def mul_(self, other: _STOrSTSeq):
    """In-place elementwise multiplication; returns self."""
    torch._foreach_mul_(self, other)
    return self
|
|
467
|
+
|
|
468
|
+
# TODO: benchmark
def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
    """``self * other``, short-circuiting when ``other`` is 1.

    In the short-circuit case, returns ``self.clone()`` if ``clone`` else ``self`` itself.
    """
    if not generic_eq(other, 1):
        return self * other
    return self.clone() if clone else self

def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``mul_``, skipped entirely when ``other`` is 1; returns self."""
    return self if generic_eq(other, 1) else self.mul_(other)
|
|
477
|
+
|
|
478
|
+
def div(self, other: _STOrSTSeq) -> Self:
    """Elementwise ``self / other`` as a new list."""
    quotient = torch._foreach_div(self, other)
    return self.__class__(quotient)

def div_(self, other: _STOrSTSeq):
    """In-place elementwise division; returns self."""
    torch._foreach_div_(self, other)
    return self
|
|
482
|
+
|
|
483
|
+
def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
    """``self / other``, except that dividing by 1 returns ``self`` itself (no copy)."""
    return self if generic_eq(other, 1) else self / other

def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
    """In-place ``div_``, skipped entirely when ``other`` is 1; returns self."""
    return self if generic_eq(other, 1) else self.div_(other)
|
|
489
|
+
|
|
490
|
+
def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
|
|
491
|
+
def pow_(self, exponent: "_Scalar | _STSeq"):
|
|
492
|
+
torch._foreach_pow_(self, exponent)
|
|
493
|
+
return self
|
|
494
|
+
|
|
495
|
+
def rpow(self, input: _Scalar | _TensorSeq): return self.__class__(torch._foreach_pow(input, self))
|
|
496
|
+
def rpow_(self, input: _TensorSeq):
|
|
497
|
+
torch._foreach_pow_(input, self)
|
|
498
|
+
return self
|
|
499
|
+
|
|
500
|
+
def sqrt(self):
    """Elementwise square root, returned as a new list."""
    roots = torch._foreach_sqrt(self)
    return self.__class__(roots)

def sqrt_(self):
    """In-place elementwise square root; returns self."""
    torch._foreach_sqrt_(self)
    return self
|
|
504
|
+
|
|
505
|
+
def remainder(self, other: _STOrSTSeq):
    """Elementwise ``torch.remainder`` via ``zipmap``."""
    return self.zipmap(torch.remainder, other)

def remainder_(self, other: _STOrSTSeq):
    """In-place elementwise remainder using each tensor's ``remainder_``."""
    caller = _MethodCallerWithArgs('remainder_')
    return self.zipmap_inplace_(caller, other)

def floor_divide(self, other: _STOrSTSeq):
    """Elementwise ``torch.floor_divide`` via ``zipmap``."""
    return self.zipmap(torch.floor_divide, other)

def floor_divide_(self, other: _STOrSTSeq):
    """In-place elementwise floor division using each tensor's ``floor_divide_``."""
    caller = _MethodCallerWithArgs('floor_divide_')
    return self.zipmap_inplace_(caller, other)
|
|
510
|
+
|
|
511
|
+
def reciprocal(self): return self.__class__(torch._foreach_reciprocal(self))
|
|
512
|
+
def reciprocal_(self):
|
|
513
|
+
torch._foreach_reciprocal_(self)
|
|
514
|
+
return self
|
|
515
|
+
|
|
516
|
+
def abs(self): return self.__class__(torch._foreach_abs(self))
|
|
517
|
+
def abs_(self):
|
|
518
|
+
torch._foreach_abs_(self)
|
|
519
|
+
return self
|
|
520
|
+
|
|
521
|
+
def sign(self): return self.__class__(torch._foreach_sign(self))
|
|
522
|
+
def sign_(self):
|
|
523
|
+
torch._foreach_sign_(self)
|
|
524
|
+
return self
|
|
525
|
+
|
|
526
|
+
def exp(self): return self.__class__(torch._foreach_exp(self))
|
|
527
|
+
def exp_(self):
|
|
528
|
+
torch._foreach_exp_(self)
|
|
529
|
+
return self
|
|
530
|
+
|
|
531
|
+
def signbit(self): return self.__class__(torch.signbit(i) for i in self)
|
|
532
|
+
|
|
533
|
+
def sin(self):
    """Elementwise sine, returned as a new list."""
    return self.__class__(torch._foreach_sin(self))

def sin_(self):
    """In-place elementwise sine; returns self."""
    torch._foreach_sin_(self)
    return self

def cos(self):
    """Elementwise cosine, returned as a new list."""
    return self.__class__(torch._foreach_cos(self))

def cos_(self):
    """In-place elementwise cosine; returns self."""
    torch._foreach_cos_(self)
    return self

def tan(self):
    """Elementwise tangent, returned as a new list."""
    return self.__class__(torch._foreach_tan(self))

def tan_(self):
    """In-place elementwise tangent; returns self."""
    torch._foreach_tan_(self)
    return self

def asin(self):
    """Elementwise arcsine, returned as a new list."""
    return self.__class__(torch._foreach_asin(self))

def asin_(self):
    """In-place elementwise arcsine; returns self."""
    torch._foreach_asin_(self)
    return self

def acos(self):
    """Elementwise arccosine, returned as a new list."""
    return self.__class__(torch._foreach_acos(self))

def acos_(self):
    """In-place elementwise arccosine; returns self."""
    torch._foreach_acos_(self)
    return self

def atan(self):
    """Elementwise arctangent, returned as a new list."""
    return self.__class__(torch._foreach_atan(self))

def atan_(self):
    """In-place elementwise arctangent; returns self."""
    torch._foreach_atan_(self)
    return self

def sinh(self):
    """Elementwise hyperbolic sine, returned as a new list."""
    return self.__class__(torch._foreach_sinh(self))

def sinh_(self):
    """In-place elementwise hyperbolic sine; returns self."""
    torch._foreach_sinh_(self)
    return self

def cosh(self):
    """Elementwise hyperbolic cosine, returned as a new list."""
    return self.__class__(torch._foreach_cosh(self))

def cosh_(self):
    """In-place elementwise hyperbolic cosine; returns self."""
    torch._foreach_cosh_(self)
    return self

def tanh(self):
    """Elementwise hyperbolic tangent, returned as a new list."""
    return self.__class__(torch._foreach_tanh(self))

def tanh_(self):
    """In-place elementwise hyperbolic tangent; returns self."""
    torch._foreach_tanh_(self)
    return self

def log(self):
    """Elementwise natural logarithm, returned as a new list."""
    return self.__class__(torch._foreach_log(self))

def log_(self):
    """In-place elementwise natural logarithm; returns self."""
    torch._foreach_log_(self)
    return self

def log10(self):
    """Elementwise base-10 logarithm, returned as a new list."""
    return self.__class__(torch._foreach_log10(self))

def log10_(self):
    """In-place elementwise base-10 logarithm; returns self."""
    torch._foreach_log10_(self)
    return self

def log2(self):
    """Elementwise base-2 logarithm, returned as a new list."""
    return self.__class__(torch._foreach_log2(self))

def log2_(self):
    """In-place elementwise base-2 logarithm; returns self."""
    torch._foreach_log2_(self)
    return self

def log1p(self):
    """Elementwise ``log(1 + x)``, returned as a new list."""
    return self.__class__(torch._foreach_log1p(self))

def log1p_(self):
    """In-place elementwise ``log(1 + x)``; returns self."""
    torch._foreach_log1p_(self)
    return self

def erf(self):
    """Elementwise error function, returned as a new list."""
    return self.__class__(torch._foreach_erf(self))

def erf_(self):
    """In-place elementwise error function; returns self."""
    torch._foreach_erf_(self)
    return self

def erfc(self):
    """Elementwise complementary error function, returned as a new list."""
    return self.__class__(torch._foreach_erfc(self))

def erfc_(self):
    """In-place elementwise complementary error function; returns self."""
    torch._foreach_erfc_(self)
    return self

def sigmoid(self):
    """Elementwise logistic sigmoid, returned as a new list."""
    return self.__class__(torch._foreach_sigmoid(self))

def sigmoid_(self):
    """In-place elementwise logistic sigmoid; returns self."""
    torch._foreach_sigmoid_(self)
    return self
|
|
612
|
+
|
|
613
|
+
def _global_fn(self, keepdim, fn, *args, **kwargs):
    """Evaluate ``fn(*args, **kwargs)`` for a ``dim='global'`` reduction.

    ``keepdim`` is accepted but currently unused: the check that rejected
    ``keepdim=True`` is commented out.
    """
    #if keepdim: raise ValueError('dim = global and keepdim = True')
    result = fn(*args, **kwargs)
    return result
|
|
617
|
+
|
|
618
|
+
def max(self, dim: _Dim = None, keepdim = False) -> Self | Any:
    """Per-tensor maximum.

    - ``dim=None`` and ``keepdim=False``: one fused ``torch._foreach_max`` call.
    - ``dim='global'``: delegates to ``self.global_max`` (``keepdim`` unused).
    - otherwise: ``amax`` over ``dim`` for each tensor, with ``dim=None`` passed as ``()``.
    """
    if dim is None and not keepdim:
        return self.__class__(torch._foreach_max(self))
    if dim == 'global':
        return self._global_fn(keepdim, self.global_max)
    reduce_dims = () if dim is None else dim
    return self.__class__(t.amax(dim=reduce_dims, keepdim=keepdim) for t in self)
|
|
623
|
+
|
|
624
|
+
def min(self, dim: _Dim = None, keepdim = False) -> Self | Any:
    """Per-tensor minimum.

    - ``dim=None`` and ``keepdim=False``: computed as ``-max(-x)`` with ``torch._foreach_max``.
    - ``dim='global'``: delegates to ``self.global_min`` (``keepdim`` unused).
    - otherwise: ``amin`` over ``dim`` for each tensor, with ``dim=None`` passed as ``()``.
    """
    if dim is None and not keepdim:
        return self.__class__(torch._foreach_max(self.neg())).neg_()
    if dim == 'global':
        return self._global_fn(keepdim, self.global_min)
    reduce_dims = () if dim is None else dim
    return self.__class__(t.amin(dim=reduce_dims, keepdim=keepdim) for t in self)
|
|
629
|
+
|
|
630
|
+
def norm(self, ord: _Scalar, dtype=None):
    """Per-tensor vector norm of order ``ord`` computed with ``torch._foreach_norm``."""
    norms = torch._foreach_norm(self, ord, dtype)
    return self.__class__(norms)
|
|
632
|
+
|
|
633
|
+
def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
    """Per-tensor mean along ``dim``; ``dim='global'`` delegates to ``self.global_mean``."""
    if dim != 'global':
        return self.__class__(t.mean(dim=dim, keepdim=keepdim) for t in self)
    return self._global_fn(keepdim, self.global_mean)

def sum(self, dim: _Dim = None, keepdim = False) -> Self | Any:
    """Per-tensor sum along ``dim``; ``dim='global'`` delegates to ``self.global_sum``."""
    if dim != 'global':
        return self.__class__(t.sum(dim=dim, keepdim=keepdim) for t in self)
    return self._global_fn(keepdim, self.global_sum)
|
|
640
|
+
|
|
641
|
+
def prod(self, dim = None, keepdim = False):
    """Per-tensor product of elements.

    Args:
        dim: the dimension to reduce (an int). ``None`` reduces over all elements.
        keepdim: keep reduced dimensions with size 1. With ``dim=None`` every
            dimension is kept as size 1.

    Returns:
        A new list with one product tensor per input tensor.
    """
    if dim is None:
        # torch's ``prod(dim, keepdim)`` overload only takes an int dim, so a full
        # reduction must go through the dim-less ``prod()`` overload instead.
        if keepdim:
            return self.__class__(t.prod().reshape([1] * t.ndim) for t in self)
        return self.__class__(t.prod() for t in self)
    return self.__class__(t.prod(dim=dim, keepdim=keepdim) for t in self)
|
|
642
|
+
|
|
643
|
+
def std(self, dim: _Dim = None, unbiased: bool = True, keepdim = False) -> Self | Any:
    """Per-tensor standard deviation along ``dim``; ``dim='global'`` delegates to ``self.global_std``."""
    if dim != 'global':
        return self.__class__(t.std(dim=dim, unbiased=unbiased, keepdim=keepdim) for t in self)
    return self._global_fn(keepdim, self.global_std)

def var(self, dim: _Dim = None, unbiased: bool = True, keepdim = False) -> Self | Any:
    """Per-tensor variance along ``dim``; ``dim='global'`` delegates to ``self.global_var``."""
    if dim != 'global':
        return self.__class__(t.var(dim=dim, unbiased=unbiased, keepdim=keepdim) for t in self)
    return self._global_fn(keepdim, self.global_var)
|
|
650
|
+
|
|
651
|
+
def median(self, dim=None, keepdim=False):
    """Per-tensor median values (indices are not returned).

    With ``dim=None`` the median of all elements is taken (``keepdim`` is unused);
    otherwise only the values part of torch's ``(values, indices)`` result is kept.
    """
    if dim is not None:
        return self.__class__(t.median(dim=dim, keepdim=keepdim)[0] for t in self)
    return self.__class__(t.median() for t in self)

def quantile(self, q, dim=None, keepdim=False, *, interpolation='linear',):
    """Per-tensor ``torch.quantile`` with the given ``q``, ``dim``, ``keepdim`` and ``interpolation``."""
    options = dict(q=q, dim=dim, keepdim=keepdim, interpolation=interpolation)
    return self.__class__(t.quantile(**options) for t in self)
|
|
659
|
+
|
|
660
|
+
def clamp_min(self, other: "_Scalar | _STSeq"):
    """Elementwise lower clamp, returned as a new list."""
    clamped = torch._foreach_clamp_min(self, other)
    return self.__class__(clamped)

def clamp_min_(self, other: "_Scalar | _STSeq"):
    """In-place elementwise lower clamp; returns self."""
    torch._foreach_clamp_min_(self, other)
    return self

def clamp_max(self, other: "_Scalar | _STSeq"):
    """Elementwise upper clamp, returned as a new list."""
    clamped = torch._foreach_clamp_max(self, other)
    return self.__class__(clamped)

def clamp_max_(self, other: "_Scalar | _STSeq"):
    """In-place elementwise upper clamp; returns self."""
    torch._foreach_clamp_max_(self, other)
    return self

def clamp(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """Out-of-place clamp to ``[min, max]``; a bound of None is skipped.

    When both bounds are None, ``self`` itself is returned.
    """
    out = self if min is None else self.clamp_min(min)
    if max is not None:
        out = out.clamp_max(max)
    return out

def clamp_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """In-place clamp to ``[min, max]``; a bound of None is skipped. Returns self."""
    if min is not None:
        self.clamp_min_(min)
    if max is not None:
        self.clamp_max_(max)
    return self

def clip(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """Alias of ``clamp``."""
    return self.clamp(min, max)

def clip_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """Alias of ``clamp_``."""
    return self.clamp_(min, max)
|
|
681
|
+
|
|
682
|
+
def clamp_magnitude(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """Clamp absolute values to ``[min, max]`` while keeping each element's sign (out-of-place).

    Zeros are treated as positive: ``sign(sign(x) + 0.5)`` is 1 for x >= 0 and -1 for x < 0,
    so a zero entry can be lifted to ``min`` instead of staying zero.
    """
    magnitude = self.abs().clamp_(min, max)
    direction = self.sign().add_(0.5).sign_()
    return magnitude * direction

def clamp_magnitude_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
    """In-place version of ``clamp_magnitude``; returns self."""
    # the sign must be captured before self is overwritten with its absolute value
    direction = self.sign().add_(0.5).sign_()
    return self.abs_().clamp_(min, max).mul_(direction)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def floor(self):
    """Elementwise floor, returned as a new list."""
    return self.__class__(torch._foreach_floor(self))

def floor_(self):
    """In-place elementwise floor; returns self."""
    torch._foreach_floor_(self)
    return self

def ceil(self):
    """Elementwise ceiling, returned as a new list."""
    return self.__class__(torch._foreach_ceil(self))

def ceil_(self):
    """In-place elementwise ceiling; returns self."""
    torch._foreach_ceil_(self)
    return self

def round(self):
    """Elementwise rounding, returned as a new list."""
    return self.__class__(torch._foreach_round(self))

def round_(self):
    """In-place elementwise rounding; returns self."""
    torch._foreach_round_(self)
    return self

def zero_(self):
    """Fill every tensor with zeros in place; returns self."""
    torch._foreach_zero_(self)
    return self
|
|
705
|
+
|
|
706
|
+
def lerp(self, tensors1: _TensorSeq, weight: "_Scalar | _ScalarSeq | _TensorSeq"):
    """Linear interpolation between self and tensors1: ``out = self + weight * (tensors1 - self)``."""
    interpolated = torch._foreach_lerp(self, tensors1, weight)
    return self.__class__(interpolated)

def lerp_(self, tensors1: _TensorSeq, weight: "_Scalar | _ScalarSeq | _TensorSeq"):
    """In-place linear interpolation: ``self = self + weight * (tensors1 - self)``. Returns self."""
    torch._foreach_lerp_(self, tensors1, weight)
    return self

def lerp_compat(self, tensors1: _TensorSeq, weight: "_STOrSTSeq"):
    """`lerp` built from add/sub/mul so that a scalar-sequence weight also works on pytorch versions before 2.6.

    `out = self + weight * (tensors1 - self)`."""
    delta = TensorList(tensors1) - self
    return self + weight * delta

def lerp_compat_(self, tensors1: _TensorSeq, weight: "_STOrSTSeq"):
    """In-place `lerp_compat`: `self = self + weight * (tensors1 - self)`; returns self."""
    step = TensorList(tensors1).sub(self).mul_(weight)
    return self.add_(step)
|
|
724
|
+
|
|
725
|
+
def addcmul(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
    """``self + value * tensors1 * tensor2`` elementwise, as a new list."""
    result = torch._foreach_addcmul(self, tensors1, tensor2, value)
    return self.__class__(result)

def addcmul_(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
    """In-place ``self += value * tensors1 * tensor2``; returns self."""
    torch._foreach_addcmul_(self, tensors1, tensor2, value)
    return self

def addcdiv(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
    """``self + value * tensors1 / tensor2`` elementwise, as a new list."""
    result = torch._foreach_addcdiv(self, tensors1, tensor2, value)
    return self.__class__(result)

def addcdiv_(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
    """In-place ``self += value * tensors1 / tensor2``; returns self."""
    torch._foreach_addcdiv_(self, tensors1, tensor2, value)
    return self

def uniform_(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator = None):
    """Fill every tensor in place with uniform samples between ``low`` and ``high``."""
    filler = _MethodCallerWithArgs('uniform_')
    return self.zipmap_args_inplace_(filler, low, high, generator = generator)
|
|
738
|
+
|
|
739
|
+
def maximum(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
    """Elementwise maximum with ``other``, returned as a new list."""
    larger = torch._foreach_maximum(self, other)
    return self.__class__(larger)

def maximum_(self, other: "_Scalar | _ScalarSeq | _TensorSeq"): # ruff: noqa F811
    """In-place elementwise maximum with ``other``; returns self."""
    torch._foreach_maximum_(self, other)
    return self

def minimum(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
    """Elementwise minimum with ``other``, returned as a new list."""
    smaller = torch._foreach_minimum(self, other)
    return self.__class__(smaller)

def minimum_(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
    """In-place elementwise minimum with ``other``; returns self."""
    torch._foreach_minimum_(self, other)
    return self

def squeeze(self, dim = None):
    """Squeeze each tensor (all size-1 dims, or only ``dim`` when given)."""
    if dim is not None:
        return self.__class__(t.squeeze(dim) for t in self)
    return self.__class__(t.squeeze() for t in self)

def squeeze_(self, dim = None):
    """In-place ``squeeze``; returns self."""
    extra = () if dim is None else (dim,)
    for t in self:
        t.squeeze_(*extra)
    return self
|
|
761
|
+
|
|
762
|
+
def conj(self):
    """Complex conjugate of every tensor (``Tensor.conj``)."""
    return self.__class__([t.conj() for t in self])

def nan_to_num(self, nan: "float | _ScalarSeq | None" = None, posinf: "float | _ScalarSeq | None" = None, neginf: "float | _ScalarSeq | None" = None):
    """Out-of-place ``torch.nan_to_num`` with optional per-tensor replacement values."""
    replace = torch.nan_to_num
    return self.zipmap_args(replace, nan, posinf, neginf)

def nan_to_num_(self, nan: "float | _ScalarSeq | None" = None, posinf: "float | _ScalarSeq | None" = None, neginf: "float | _ScalarSeq | None" = None):
    """In-place ``torch.nan_to_num_`` with optional per-tensor replacement values."""
    replace = torch.nan_to_num_
    return self.zipmap_args_inplace_(replace, nan, posinf, neginf)

def ravel(self):
    """Flatten every tensor with ``Tensor.ravel``."""
    return self.__class__([t.ravel() for t in self])

def view_flat(self):
    """Flat 1-D views of every tensor (``view(-1)``)."""
    return self.__class__([t.view(-1) for t in self])

def any(self):
    """Per-tensor ``Tensor.any()``."""
    return self.__class__([t.any() for t in self])

def all(self):
    """Per-tensor ``Tensor.all()``."""
    return self.__class__([t.all() for t in self])

def isfinite(self):
    """Elementwise finiteness mask of every tensor."""
    return self.__class__([t.isfinite() for t in self])
|
|
775
|
+
|
|
776
|
+
def fill(self, value: _STOrSTSeq):
    """Out-of-place fill of every tensor with ``value`` (``torch.fill`` through ``zipmap``)."""
    filler = torch.fill
    return self.zipmap(filler, other = value)

def fill_(self, value: _STOrSTSeq):
    """In-place fill of every tensor with ``value`` (``torch.fill_`` through ``zipmap_inplace_``)."""
    filler = torch.fill_
    return self.zipmap_inplace_(filler, other = value)
|
|
778
|
+
|
|
779
|
+
def copysign(self, other):
    """Magnitudes of self with the signs of the matching tensors in ``other``."""
    return self.__class__([mag.copysign(sgn) for mag, sgn in zip(self, other)])

def copysign_(self, other):
    """In-place ``copysign``; returns self."""
    for mag, sgn in zip(self, other):
        mag.copysign_(sgn)
    return self
|
|
784
|
+
|
|
785
|
+
def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
    """Return self rescaled so its norm matches the norm of ``magnitude`` (out-of-place).

    Args:
        magnitude: tensors whose norm is borrowed.
        tensorwise: match norms per tensor (``norm``) instead of using the global vector norm.
        ord: norm order.
        eps: lower bound applied to self's norm before dividing by it.
        strength: when not 1, the target norm becomes
            ``lerp(norm_magnitude, norm_self, 1 - strength)``.
    """
    mag = magnitude if isinstance(magnitude, TensorList) else TensorList(magnitude)

    if tensorwise:
        src_norm, dst_norm = self.norm(ord), mag.norm(ord)
    else:
        src_norm, dst_norm = self.global_vector_norm(ord), mag.global_vector_norm(ord)

    if not generic_eq(strength, 1):
        dst_norm.lerp_(src_norm, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]

    return self * (dst_norm / src_norm.clip_(min=eps))
|
|
797
|
+
|
|
798
|
+
def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
    """In-place version of ``graft``: rescale self so its norm matches ``magnitude``'s.

    Arguments have the same meaning as in ``graft``. Returns self.
    """
    mag = magnitude if isinstance(magnitude, TensorList) else TensorList(magnitude)

    if tensorwise:
        src_norm, dst_norm = self.norm(ord), mag.norm(ord)
    else:
        src_norm, dst_norm = self.global_vector_norm(ord), mag.global_vector_norm(ord)

    if not generic_eq(strength, 1):
        dst_norm.lerp_(src_norm, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]

    return self.mul_(dst_norm / src_norm.clip_(min=eps))
|
|
810
|
+
|
|
811
|
+
def _get_rescale_coeffs(self, min:"_Scalar | _ScalarSeq", max:"_Scalar | _ScalarSeq", dim: _Dim, eps):
    """Return ``(a, b)`` such that ``self * a + b`` maps self's [min, max] onto the target [min, max].

    ``eps`` is added to self's value range before dividing by it.
    """
    lo = self.min(dim=dim, keepdim=True)
    hi = self.max(dim=dim, keepdim=True)

    target_lo = maybe_numberlist(min)
    target_hi = maybe_numberlist(max)

    scale = (target_hi - target_lo) / (hi - lo).add_(eps)
    shift = target_lo - (scale * lo)
    return scale, shift
|
|
825
|
+
|
|
826
|
+
def rescale(self, min: "_Scalar | _ScalarSeq | None", max: "_Scalar | _ScalarSeq | None", dim: _Dim = None, eps=0.):
    """Affinely rescale each tensor to the (min, max) range (out-of-place).

    With only ``min`` given, values are shifted so the minimum equals ``min``; with only
    ``max`` given, shifted so the maximum equals ``max``; with neither, self is returned.
    """
    if min is None:
        if max is None:
            return self
        return self - (self.max(dim=dim, keepdim=True).sub_(max))
    if max is None:
        return self - (self.min(dim=dim, keepdim=True).sub_(min))

    scale, shift = self._get_rescale_coeffs(min=min, max=max, dim=dim, eps=eps)
    return (self*scale).add_(shift)

def rescale_(self, min: "_Scalar | _ScalarSeq | None", max: "_Scalar | _ScalarSeq | None", dim: _Dim = None, eps=0.):
    """In-place version of ``rescale``; returns self."""
    if min is None:
        if max is None:
            return self
        return self.sub_(self.max(dim=dim, keepdim=True).sub_(max))
    if max is None:
        return self.sub_(self.min(dim=dim, keepdim=True).sub_(min))

    scale, shift = self._get_rescale_coeffs(min=min, max=max, dim=dim, eps=eps)
    return self.mul_(scale).add_(shift)
|
|
847
|
+
|
|
848
|
+
def rescale_to_01(self, dim: _Dim = None, eps: float = 0):
    """Shift by the minimum and divide by the resulting maximum (plus ``eps``) to map into (0, 1)."""
    shifted = self - self.min(dim = dim, keepdim=True)
    peak = shifted.max(dim = dim, keepdim=True)
    if eps != 0:
        peak.add_(eps)
    return shifted.div_(peak)

def rescale_to_01_(self, dim: _Dim = None, eps: float = 0):
    """In-place version of ``rescale_to_01``; returns self."""
    self.sub_(self.min(dim = dim, keepdim=True))
    peak = self.max(dim = dim, keepdim=True)
    if eps != 0:
        peak.add_(eps)
    return self.div_(peak)
|
|
861
|
+
|
|
862
|
+
def normalize(self, mean: "_Scalar | _ScalarSeq | None", var: "_Scalar | _ScalarSeq | None", dim: _Dim = None): # pylint:disable=redefined-outer-name
    """Normalize each tensor to the given mean and variance (out-of-place).

    Args:
        mean: target mean (scalar or one value per tensor). If None, tensors are only
            divided by their std and multiplied by ``sqrt(var)`` (no centering).
        var: target variance (scalar or one value per tensor). If None, tensors are only
            shifted so their mean equals ``mean`` (spread unchanged).
        dim: dimension(s) over which mean and std are computed.

    Returns:
        A new list, or self unchanged when both ``mean`` and ``var`` are None.
    """
    if mean is None and var is None: return self

    if var is None:
        # shift to the requested mean rather than always centering at 0
        return (self - self.mean(dim = dim, keepdim = True)).add_(mean)

    if isinstance(var, Sequence): var_sqrt = [i**0.5 for i in var]
    else: var_sqrt = var ** 0.5

    if mean is None:
        # scale to the requested std rather than always to unit std
        return (self / self.std(dim = dim, keepdim = True)).mul_(var_sqrt)

    self_mean = self.mean(dim = dim, keepdim = True)
    self_std = self.std(dim = dim, keepdim = True)
    return (self - self_mean).div_(self_std).mul_(var_sqrt).add_(mean)
|
|
874
|
+
|
|
875
|
+
def normalize_(self, mean: "_Scalar | _ScalarSeq | None", var: "_Scalar | _ScalarSeq | None", dim: _Dim = None): # pylint:disable=redefined-outer-name
    """Normalize each tensor to the given mean and variance, in place.

    Args:
        mean: target mean (scalar or one value per tensor). If None, tensors are only
            divided by their std and multiplied by ``sqrt(var)`` (no centering).
        var: target variance (scalar or one value per tensor). If None, tensors are only
            shifted so their mean equals ``mean`` (spread unchanged).
        dim: dimension(s) over which mean and std are computed.

    Returns:
        self (modified in place on every path).
    """
    if mean is None and var is None: return self

    if var is None:
        # in-place shift to the requested mean (previously not in place and always zero-mean)
        return self.sub_(self.mean(dim = dim, keepdim = True)).add_(mean)

    if isinstance(var, Sequence): var_sqrt = [i**0.5 for i in var]
    else: var_sqrt = var ** 0.5

    if mean is None:
        # in-place scale to the requested std (previously not in place and always unit std)
        return self.div_(self.std(dim = dim, keepdim = True)).mul_(var_sqrt)

    # both statistics are computed before self is mutated
    self_mean = self.mean(dim = dim, keepdim = True)
    self_std = self.std(dim = dim, keepdim = True)
    return self.sub_(self_mean).div_(self_std).mul_(var_sqrt).add_(mean)
|
|
887
|
+
|
|
888
|
+
def znormalize(self, dim: _Dim = None, eps:float = 0):
    """Subtract the mean and divide by (std + eps) to get zero mean, unit variance (out-of-place)."""
    spread = self.std(dim = dim, keepdim = True)
    if eps != 0:
        spread.add_(eps)
    centered = self - self.mean(dim = dim, keepdim=True)
    return centered.div_(spread)

def znormalize_(self, dim: _Dim = None, eps:float = 0):
    """In-place version of ``znormalize``; returns self."""
    spread = self.std(dim = dim, keepdim = True)
    if eps != 0:
        spread.add_(eps)
    return self.sub_(self.mean(dim = dim, keepdim=True)).div_(spread)
|
|
899
|
+
|
|
900
|
+
def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
    """Compute the factor that brings self's norm into the ``[min, max]`` range.

    With ``tensorwise`` the factor is per tensor (zero norms are replaced by 1);
    otherwise it uses the global vector norm, and a zero global norm returns 1.
    """
    if tensorwise:
        norm = self.norm(ord)
        # avoid dividing by zero for all-zero tensors
        norm.masked_fill_(norm == 0, 1)
    else:
        norm = self.global_vector_norm(ord)
        if norm == 0:
            return 1

    multiplier = 1
    if min is not None:
        # only scale up: factors below 1 are clamped to 1
        multiplier *= generic_clamp(maybe_numberlist(min) / norm, min=1)
    if max is not None:
        # only scale down: factors above 1 are clamped to 1
        multiplier *= generic_clamp(maybe_numberlist(max) / norm, max=1)
    return multiplier
|
|
920
|
+
|
|
921
|
+
def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
    """Scale self so its norm lies in ``[min, max]`` (out-of-place); self is returned if both are None."""
    if min is None and max is None:
        return self
    factor = self._clip_multiplier(min, max, tensorwise, ord)
    return self * factor

def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
    """In-place version of ``clip_norm``; returns self."""
    if min is None and max is None:
        return self
    factor = self._clip_multiplier(min, max, tensorwise, ord)
    return self.mul_(factor)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def where(self, condition: "torch.Tensor | _TensorSeq", other: _STOrSTSeq):
    """Elementwise select: self where ``condition`` is true, ``other`` elsewhere (out-of-place)."""
    selector = _MethodCallerWithArgs('where')
    return self.zipmap_args(selector, condition, other)

def where_(self, condition: "torch.Tensor | _TensorSeq", other: "torch.Tensor | _TensorSeq"):
    """In-place elementwise select: keep self where ``condition`` is true, take ``other`` elsewhere."""
    return self.zipmap_args_inplace_(where_, condition, other)

def masked_fill(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
    """Out-of-place ``tensor[mask] = value`` with scalar value(s), via ``torch.masked_fill``."""
    return self.zipmap_args(torch.masked_fill, mask, fill_value)

def masked_fill_(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
    """In-place ``tensor[mask] = value`` with scalar value(s), via each tensor's ``masked_fill_``."""
    filler = _MethodCallerWithArgs('masked_fill_')
    return self.zipmap_args_inplace_(filler, mask, fill_value)

def select_set_(self, mask: _TensorSeq, value: _STOrSTSeq):
    """Per tensor, ``tensor[mask] = value``; a non-list/tuple value is used for every tensor."""
    values = value if isinstance(value, (list, tuple)) else [value] * len(self)
    for t, m, v in zip(self, mask, values):
        t[m] = v # pyright: ignore[reportArgumentType]

def masked_set_(self, mask: _TensorSeq, value: _TensorSeq):
    """Per tensor, ``tensor[mask] = value[mask]``."""
    for t, m, v in zip(self, mask, value):
        t[m] = v[m]
|
|
956
|
+
|
|
957
|
+
def select(self, idx: Any):
    """Index every tensor: ``t[idx]`` for all tensors, or ``t[i]`` pairwise when ``idx`` is a list/tuple."""
    if isinstance(idx, (list, tuple)):
        return self.__class__(t[i] for t, i in zip(self, idx))
    return self.__class__(t[idx] for t in self)
|
|
961
|
+
|
|
962
|
+
def dot(self, other: _TensorSeq):
    """Dot product over the whole tensorlist.

    Multiplies elementwise, then reduces with ``global_sum`` (presumably a single value
    summed across all tensors).
    """
    products = self * other
    return products.global_sum()
|
|
964
|
+
|
|
965
|
+
def tensorwise_dot(self, other: _TensorSeq):
    """Per-tensor dot products: elementwise product, then this class's ``sum`` (presumably one value per tensor)."""
    products = self * other
    return products.sum()
|
|
967
|
+
|
|
968
|
+
def swap_tensors(self, other: _TensorSeq):
    """Swap each tensor of self with the tensor at the same position in ``other``, using ``torch.utils.swap_tensors``.

    Pairs are formed with ``zip``, so surplus tensors in the longer sequence are left untouched.
    """
    pairs = zip(self, other)
    for mine, theirs in pairs:
        torch.utils.swap_tensors(mine, theirs)
|
|
971
|
+
|
|
972
|
+
def unbind_channels(self, dim=0):
    """Returns a new tensorlist where each tensor with 2+ dimensions is split into its slices along ``dim``.

    Tensors with fewer than 2 dimensions are kept as they are.
    """
    def _slices(t):
        return t.unbind(dim) if t.ndim >= 2 else (t,)

    return self.__class__(ch for t in self for ch in _slices(t))
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def flatiter(self) -> Generator[torch.Tensor]:
    """Yields every element of every tensor, in list order, each tensor flattened row-major.

    Elements are yielded as 0-d tensors. ``reshape`` is used instead of ``view`` so that
    non-contiguous tensors (e.g. transposed or ``unbind(dim=1)`` slices) also work. ``view``
    raises on those, and because this iterator only reads, a possible copy is harmless.
    """
    for tensor in self:
        yield from tensor.reshape(-1)
|
|
980
|
+
|
|
981
|
+
# def flatset(self, idx: int, value: Any):
|
|
982
|
+
# """sets index in flattened view"""
|
|
983
|
+
# return self.clone().flatset_(idx, value)
|
|
984
|
+
|
|
985
|
+
def flat_set_(self, idx: int, value: Any):
    """Sets element ``idx`` of the flattened view (all tensors flattened and concatenated in order) to ``value``, in place.

    Negative ``idx`` counts from the end, as for Python sequences.

    Returns:
        self.

    Raises:
        IndexError: if ``idx`` is out of range for the total number of elements.
    """
    flat_idx = idx
    if flat_idx < 0:
        # Python-style negative index: offset by the total element count.
        for tensor in self:
            flat_idx += tensor.numel()
    if flat_idx >= 0:
        cur = 0  # flat index of the first element of the current tensor
        for tensor in self:
            numel = tensor.numel()
            if flat_idx < cur + numel:
                # Local offset inside this tensor. `view` (not `reshape`) is required so
                # the write reaches the tensor's storage instead of a temporary copy.
                tensor.view(-1)[flat_idx - cur] = value
                return self
            cur += numel
    raise IndexError(idx)
|
|
995
|
+
|
|
996
|
+
def flat_set_lambda_(self, idx, fn):
    """Sets element ``idx`` of the flattened view (all tensors flattened and concatenated in order) to ``fn(current_value)``, in place.

    ``fn`` receives the current element as a 0-d tensor. Negative ``idx`` counts from the end,
    as for Python sequences.

    Returns:
        self.

    Raises:
        IndexError: if ``idx`` is out of range for the total number of elements.
    """
    flat_idx = idx
    if flat_idx < 0:
        # Python-style negative index: offset by the total element count.
        for tensor in self:
            flat_idx += tensor.numel()
    if flat_idx >= 0:
        cur = 0  # flat index of the first element of the current tensor
        for tensor in self:
            numel = tensor.numel()
            if flat_idx < cur + numel:
                # `view` (not `reshape`) so the write reaches the tensor's own storage.
                flat_view = tensor.view(-1)
                local = flat_idx - cur
                flat_view[local] = fn(flat_view[local])
                return self
            cur += numel
    raise IndexError(idx)
|
|
1007
|
+
|
|
1008
|
+
def __repr__(self):
    """Wrap the plain list representation in this class's name, e.g. ``TensorList([...])``."""
    cls_name = self.__class__.__name__
    inner = super().__repr__()
    return f"{cls_name}({inner})"
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
def stack(tensorlists: Iterable[TensorList], dim = 0):
    """Stacks corresponding tensors of the given tensorlists along ``dim``, returning one tensorlist.

    Tensors are grouped by position with ``zip``, so the result is as long as the shortest input.
    """
    grouped = zip(*tensorlists)
    return TensorList(torch.stack(group, dim=dim) for group in grouped)
|
|
1015
|
+
|
|
1016
|
+
def mean(tensorlists: Iterable[TensorList]):
    """Returns a tensorlist whose tensors are the elementwise mean of the corresponding tensors in ``tensorlists``."""
    return TensorList(torch.stack(group).mean(0) for group in zip(*tensorlists))
|
|
1022
|
+
|
|
1023
|
+
def median(tensorlists: Iterable[TensorList]):
    """Returns a tensorlist whose tensors are the elementwise median of the corresponding tensors in ``tensorlists``.

    Uses ``torch.median(dim=0)``, which returns the lower of the two middle values for an even count.
    """
    return TensorList(torch.stack(group).median(dim=0).values for group in zip(*tensorlists))
|
|
1029
|
+
|
|
1030
|
+
def quantile(tensorlists: Iterable[TensorList], q, interpolation = 'linear'):
    """Returns a tensorlist which is the ``q``-th quantile of given tensorlists, computed elementwise.

    Corresponding tensors (same position in every tensorlist) are stacked along a new dim 0
    and reduced with ``torch.quantile(..., dim=0)``.

    Args:
        tensorlists: tensorlists with matching structure; tensors are grouped with ``zip``,
            so extra tensors in longer tensorlists are ignored.
        q: quantile(s) in [0, 1], a float or tensor as accepted by ``torch.quantile``.
            A 1-D ``q`` adds a leading dimension of size ``len(q)`` to every result tensor.
        interpolation: forwarded to ``torch.quantile``
            ('linear', 'lower', 'higher', 'midpoint' or 'nearest').
    """
    res = TensorList()
    for tensors in zip(*tensorlists):
        res.append(torch.stack(tensors).quantile(q=q, dim=0, interpolation=interpolation))
    return res
|
|
1036
|
+
|
|
1037
|
+
def sum(tensorlists: Iterable[TensorList]):
    """Returns a tensorlist whose tensors are the elementwise sum of the corresponding tensors in ``tensorlists``.

    Note: this module-level name shadows the builtin ``sum`` within this module.
    """
    return TensorList(torch.stack(group).sum(0) for group in zip(*tensorlists))
|
|
1043
|
+
|
|
1044
|
+
def where(condition: TensorList, input: _STOrSTSeq, other: _STOrSTSeq):
    """``torch.where`` applied tensor-by-tensor over a tensorlist.

    ``input`` and ``other`` may each be a list/tuple with one entry per tensor, or a single
    scalar/tensor reused for every tensor. The result has the same class as ``condition``.
    """
    def _per_tensor(arg):
        if isinstance(arg, (list, tuple)):
            return arg
        return [arg] * len(condition)

    inputs = _per_tensor(input)
    others = _per_tensor(other)
    return condition.__class__(torch.where(c, i, o) for c, i, o in zip(condition, inputs, others))
|
|
1048
|
+
|
|
1049
|
+
def generic_clamp(x: Any, min=None,max=None) -> Any:
    """Clamps ``x`` to ``[min, max]``; a bound that is ``None`` is not applied.

    - tensor / TensorList: delegates to ``x.clamp(min, max)``.
    - list / tuple: clamps each element recursively and rebuilds the same container type.
    - anything else (e.g. a Python number): compared against the bounds that are set.

    TensorList is checked before list/tuple so its own ``clamp`` is used rather than the
    element-wise recursion.
    """
    if isinstance(x, (torch.Tensor, TensorList)): return x.clamp(min,max)
    if isinstance(x, (list, tuple)): return x.__class__(generic_clamp(i,min,max) for i in x)
    # Skip unset bounds: `x < None` raises TypeError in Python 3.
    if min is not None and x < min: return min
    if max is not None and x > max: return max
    return x
|
|
1055
|
+
|
|
1056
|
+
def generic_numel(x: torch.Tensor | TensorList) -> int:
    """Total number of elements: ``numel()`` for a tensor, ``global_numel()`` for a tensorlist."""
    return x.numel() if isinstance(x, torch.Tensor) else x.global_numel()
|
|
1059
|
+
|
|
1060
|
+
@overload
def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
@overload
def generic_zeros_like(x: TensorList) -> TensorList: ...
def generic_zeros_like(x: torch.Tensor | TensorList):
    """Zeros shaped like ``x``: ``torch.zeros_like`` for a tensor, ``x.zeros_like()`` for a tensorlist."""
    if not isinstance(x, torch.Tensor):
        return x.zeros_like()
    return torch.zeros_like(x)
|
|
1067
|
+
|
|
1068
|
+
def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
    """Vector norm of order ``ord``: ``torch.linalg.vector_norm`` for a tensor, ``global_vector_norm`` for a tensorlist."""
    if not isinstance(x, torch.Tensor):
        return x.global_vector_norm(ord)
    return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
@overload
def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
@overload
def generic_randn_like(x: TensorList) -> TensorList: ...
def generic_randn_like(x: torch.Tensor | TensorList):
    """Standard-normal samples shaped like ``x``: ``torch.randn_like`` for a tensor, ``x.randn_like()`` for a tensorlist."""
    if not isinstance(x, torch.Tensor):
        return x.randn_like()
    return torch.randn_like(x)
|
|
1081
|
+
|