torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
@@ -0,0 +1,1081 @@
1
+ # pyright: reportIncompatibleMethodOverride=false
2
+ r"""
3
+ TensorList is a data type that can be used to manipulate a sequence of tensors such as model parameters,
4
+ with the same methods that normal tensors have, plus some additional convenience features.
5
+ Whenever possible, I used _foreach methods and other tricks to speed up computation.
6
+
7
+ TensorList is similar to TensorDict (https://github.com/pytorch/tensordict).
8
+ If you want to get the most performance out of a collection of tensors, use TensorDict and lock it.
9
+ However, I found that *creating* a TensorDict is quite slow. In fact, it negates the benefits of using it
10
+ in an optimizer when you have to create one from parameters on each step. The solution could be to create
11
+ it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
12
+ """
13
+ import builtins
14
+ from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
15
+ import math
16
+ import operator
17
+ from typing import Any, Literal, TypedDict, overload
18
+ from typing_extensions import Self, TypeAlias, Unpack
19
+
20
+ import torch
21
+ from .ops import where_
22
+ from .python_tools import generic_eq, zipmap
23
+ from .numberlist import NumberList, as_numberlist, maybe_numberlist
24
+
25
+
26
+ _Scalar = int | float | bool | complex
27
+ _TensorSeq = list[torch.Tensor] | tuple[torch.Tensor, ...]
28
+ _ScalarSeq = list[int] | list[float] | list[bool] | list[complex] | tuple[int] | tuple[float] | tuple[bool] | tuple[complex]
29
+ _ScalarSequence = Sequence[_Scalar] # I only check (list, tuple), it's faster and safer
30
+ _STSeq = _TensorSeq | _ScalarSeq
31
+ _STOrSTSeq = _Scalar | torch.Tensor | _ScalarSeq | _TensorSeq
32
+
33
+ _Dim = int | list[int] | tuple[int,...] | Literal['global'] | None
34
+
35
+ Distributions = Literal['normal', 'gaussian', 'uniform', 'sphere', 'rademacher']
36
+ class _NewTensorKwargs(TypedDict, total = False):
37
+ memory_format: Any
38
+ dtype: Any
39
+ layout: Any
40
+ device: Any
41
+ pin_memory: bool
42
+ requires_grad: bool
43
+
44
+ # _foreach_methods = {attr.replace('_foreach_', ''):getattr(torch, attr) for attr in dir(torch) if attr.startswith('_foreach_')}
45
+ class _MethodCallerWithArgs:
46
+ """Return a callable object that calls the given method on its operand.
47
+
48
+ This is similar to operator.methodcaller, but args and kwargs are specified in __call__.
49
+
50
+ Args:
51
+ method (str): name of method to call.
52
+ """
53
+ __slots__ = ('_name',)
54
+ def __init__(self, name: str):
55
+ self._name = name
56
+
57
+ def __call__(self, obj, *args, **kwargs):
58
+ return getattr(obj, self._name)(*args, **kwargs)
59
+
60
+ def __repr__(self):
61
+ return f'{self.__class__.__module__}.{self.__class__.__name__}({repr(self._name)})'
62
+
63
+ def __reduce__(self):
64
+ return self.__class__, (self._name,)  # __reduce__ expects the constructor args as a tuple
65
+
66
+ def as_tensorlist(x):
67
+ if isinstance(x, TensorList): return x
68
+ return TensorList(x)
69
+
70
+
71
+ # tensorlist must subclass list
72
+ # UserList doesn't work with _foreach_xxx
73
+ class TensorList(list[torch.Tensor | Any]):
74
+ @classmethod
75
+ def complex(cls, real: _TensorSeq, imag: _TensorSeq):
76
+ """Create a complex TensorList from real and imaginary tensor sequences."""
77
+ return cls(torch.complex(r, i) for r, i in zip(real, imag))
78
+
79
+ @property
80
+ def device(self): return [i.device for i in self]
81
+ @property
82
+ def dtype(self): return [i.dtype for i in self]
83
+ @property
84
+ def requires_grad(self): return [i.requires_grad for i in self]
85
+ @property
86
+ def shape(self): return [i.shape for i in self]
87
+ def size(self, dim: int | None = None): return [i.size(dim) for i in self]
88
+ @property
89
+ def ndim(self): return [i.ndim for i in self]
90
+ def ndimension(self): return [i.ndimension() for i in self]
91
+ def numel(self): return [i.numel() for i in self]
92
+
93
+ @property
94
+ def grad(self): return self.__class__(i.grad for i in self)
95
+ @property
96
+ def real(self): return self.__class__(i.real for i in self)
97
+ @property
98
+ def imag(self): return self.__class__(i.imag for i in self)
99
+
100
+ def view_as_real(self): return self.__class__(torch.view_as_real(i) for i in self)
101
+ def view_as_complex(self): return self.__class__(torch.view_as_complex(i) for i in self)
102
+
103
+ def type_as(self, other: torch.Tensor | _TensorSeq):
104
+ return self.zipmap(_MethodCallerWithArgs('type_as'), other)
105
+
106
+ def view_as(self, other: torch.Tensor | Sequence[torch.Tensor]):
107
+ if isinstance(other, Sequence): return self.__class__(s.view_as(o) for s, o in zip(self, other))
108
+ return self.__class__(s.view_as(other) for s in self)
109
+
110
+ def fill_none(self, reference: Iterable[torch.Tensor]):
111
+ """all None values are replaced with zeros of the same shape as corresponding `reference` tensor."""
112
+ return self.__class__(t if t is not None else torch.zeros_like(r) for t,r in zip(self, reference))
113
+
114
+ def fill_none_(self, reference: Iterable[torch.Tensor]):
115
+ """all None values are replaced with zeros of the same shape as corresponding `reference` tensor."""
116
+ for i, (t,r) in enumerate(zip(self, reference)):
117
+ if t is None: self[i] = torch.zeros_like(r)
118
+ return self
119
+
120
+ def get_grad(self):
121
+ """Returns all gradients that are not None."""
122
+ return self.__class__(i.grad for i in self if i.grad is not None)
123
+
124
+ def with_requires_grad(self, requires_grad = True):
125
+ """Returns all tensors with requires_grad set to the given value."""
126
+ return self.__class__(i for i in self if i.requires_grad == requires_grad)
127
+
128
+ def with_grad(self):
129
+ """returns all tensors whose .grad is not None"""
130
+ return self.__class__(i for i in self if i.grad is not None)
131
+
132
+ def ensure_grad_(self):
133
+ """For each element, if grad is None and it requires grad, sets grad to zeroes."""
134
+ for i in self:
135
+ if i.requires_grad and i.grad is None: i.grad = torch.zeros_like(i)
136
+ return self
137
+
138
+ def accumulate_grad_(self, grads: _TensorSeq):
139
+ """Creates grad if it is None, otherwise adds to existing grad."""
140
+ for i, g in zip(self, grads):
141
+ if i.grad is None: i.grad = g
142
+ else: i.grad.add_(g)
143
+ return self
144
+
145
+ def set_grad_(self, grads: _TensorSeq):
146
+ """Assings grad attributes to the given sequence, replaces grad that already exists."""
147
+ for i, g in zip(self, grads): i.grad = g
148
+ return self
149
+
150
+ def zero_grad_(self, set_to_none = True):
151
+ """Set all grads to None or zeroes."""
152
+ if set_to_none:
153
+ for p in self: p.grad = None
154
+ else:
155
+ self.get_grad().zero_()
156
+ return self
157
+
158
+ def __add__(self, other: _STOrSTSeq) -> Self: return self.add(other) # pyright: ignore[reportCallIssue,reportArgumentType]
159
+ def __radd__(self, other: _STOrSTSeq) -> Self: return self.add(other) # pyright: ignore[reportCallIssue,reportArgumentType]
160
+ def __iadd__(self, other: _STOrSTSeq) -> Self: return self.add_(other) # pyright: ignore[reportCallIssue,reportArgumentType]
161
+
162
+ def __sub__(self, other: "_Scalar | _STSeq") -> Self: return self.sub(other) # pyright: ignore[reportCallIssue,reportArgumentType]
163
+ def __rsub__(self, other: "_Scalar | _STSeq") -> Self: return self.sub(other).neg_() # pyright: ignore[reportCallIssue,reportArgumentType]
164
+ def __isub__(self, other: "_Scalar | _STSeq") -> Self: return self.sub_(other) # pyright: ignore[reportCallIssue,reportArgumentType]
165
+
166
+ def __mul__(self, other: _STOrSTSeq) -> Self: return self.mul(other)
167
+ def __rmul__(self, other: _STOrSTSeq) -> Self: return self.mul(other)
168
+ def __imul__(self, other: _STOrSTSeq) -> Self: return self.mul_(other)
169
+
170
+ def __truediv__(self, other: "_STOrSTSeq") -> Self: return self.div(other)
171
+ def __rtruediv__(self, other: "_STOrSTSeq") -> Self: return other * self.reciprocal()
172
+ def __itruediv__(self, other: "_STOrSTSeq") -> Self: return self.div_(other)
173
+
174
+ def __floordiv__(self, other: _STOrSTSeq): return self.floor_divide(other)
175
+ #def __rfloordiv__(self, other: "TensorList"): return other.floor_divide(self)
176
+ def __ifloordiv__(self, other: _STOrSTSeq): return self.floor_divide_(other)
177
+
178
+ def __mod__(self, other: _STOrSTSeq): return self.remainder(other)
179
+ #def __rmod__(self, other: STOrSTSequence): return self.remainder(other)
180
+ def __imod__(self, other: _STOrSTSeq):return self.remainder_(other)
181
+
182
+ def __pow__(self, other: "_Scalar | _STSeq"): return self.pow(other)
183
+ def __rpow__(self, other: "_Scalar | _TensorSeq"): return self.rpow(other)
184
+ def __ipow__(self, other: "_Scalar | _STSeq"): return self.pow_(other)
185
+
186
+ def __neg__(self): return self.neg()
187
+
188
+ def __eq__(self, other: _STOrSTSeq): return self.eq(other)
189
+ def __ne__(self, other: _STOrSTSeq): return self.ne(other)
190
+ def __lt__(self, other: _STOrSTSeq): return self.lt(other)
191
+ def __le__(self, other: _STOrSTSeq): return self.le(other)
192
+ def __gt__(self, other: _STOrSTSeq): return self.gt(other)
193
+ def __ge__(self, other: _STOrSTSeq): return self.ge(other)
194
+
195
+ def __invert__(self): return self.logical_not()
196
+
197
+ def __and__(self, other: torch.Tensor | _TensorSeq): return self.logical_and(other)
198
+ def __iand__(self, other: torch.Tensor | _TensorSeq): return self.logical_and_(other)
199
+ def __or__(self, other: torch.Tensor | _TensorSeq): return self.logical_or(other)
200
+ def __ior__(self, other: torch.Tensor | _TensorSeq): return self.logical_or_(other)
201
+ def __xor__(self, other: torch.Tensor | _TensorSeq): return self.logical_xor(other)
202
+ def __ixor__(self, other: torch.Tensor | _TensorSeq): return self.logical_xor_(other)
203
+
204
+ def __bool__(self):
205
+ raise RuntimeError(f'Boolean value of {self.__class__.__name__} is ambiguous')
206
+
207
+ def map(self, fn: Callable[..., torch.Tensor], *args, **kwargs):
208
+ """Applies `fn` to all elements of this TensorList
209
+ and returns a new TensorList with return values of the callable."""
210
+ return self.__class__(fn(i, *args, **kwargs) for i in self)
211
+ def map_inplace_(self, fn: Callable[..., Any], *args, **kwargs):
212
+ """Applies an in-place `fn` to all elements of this TensorList."""
213
+ for i in self: fn(i, *args, **kwargs)
214
+ return self
215
+
216
+ def filter(self, fn: Callable[..., bool], *args, **kwargs):
217
+ """Returns a TensorList with all elements for which `fn` returned True."""
218
+ return self.__class__(i for i in self if fn(i, *args, **kwargs))
219
+
220
+ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
221
+ """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
222
+ Otherwise applies `fn` to this TensorList and `other`.
223
+ Returns a new TensorList with return values of the callable."""
224
+ return zipmap(self, fn, other, *args, **kwargs)
225
+
226
+ def zipmap_inplace_(self, fn: Callable[..., Any], other: Any | list | tuple, *args, **kwargs):
227
+ """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
228
+ Otherwise applies `fn` to this TensorList and `other`.
229
+ The callable must modify elements in-place."""
230
+ if isinstance(other, (list, tuple)):
231
+ for i, j in zip(self, other): fn(i, j, *args, **kwargs)
232
+ else:
233
+ for i in self: fn(i, other, *args, **kwargs)
234
+ return self
235
+
236
+ def zipmap_args(self, fn: Callable[..., Any], *others, **kwargs):
237
+ """If `args` is list/tuple, applies `fn` to this TensorList zipped with `others`.
238
+ is zipped with this TensorList, while any other value is broadcast to all elements."""
239
+ others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
240
+ return self.__class__(fn(*z, **kwargs) for z in zip(self, *others))
241
+
242
+ def zipmap_args_inplace_(self, fn: Callable[..., Any], *others, **kwargs):
243
+ """If `args` is list/tuple, applies `fn` to this TensorList zipped with `other`.
244
+ is zipped with this TensorList, while any other value is broadcast to all elements.
245
+ The callable must modify elements in-place."""
246
+ others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
247
+ for z in zip(self, *others): fn(*z, **kwargs)
248
+ return self
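A brief sketch of the broadcasting rule described in the docstrings above (the values are made up):

    tl = TensorList([torch.zeros(2), torch.zeros(3)])
    tl.zipmap(torch.add, [1.0, 2.0])   # list/tuple: adds 1 to the first tensor, 2 to the second
    tl.zipmap(torch.add, 5.0)          # anything else: adds 5 to every tensor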
249
+
250
+ def _foreach_apply(self, fn: Callable[[list[torch.Tensor]], list[torch.Tensor]], *args, **kwargs):
251
+ """Applies a torch._foreach_xxx function to self and converts returned list back to TensorList or subclass."""
252
+ return self.__class__(fn(self, *args, **kwargs))
253
+
254
+ # def __getattr__(self, name: str) -> Callable:
255
+ # if name == '__torch_function__' or name == '_ipython_canary_method_should_not_exist_': raise AttributeError('who ???')
256
+ # if name in _foreach_methods:
257
+ # method = partial(self._foreach_apply, _foreach_methods[name])
258
+ # else:
259
+ # method = partial(self.map, MethodCallerWithArgs(name))
260
+ # setattr(self, name, method)
261
+ # return method
262
+
263
+ def to(self, *args, **kwargs): return self.__class__(i.to(*args, **kwargs) for i in self)
264
+ def cuda(self): return self.__class__(i.cuda() for i in self)
265
+ def cpu(self): return self.__class__(i.cpu() for i in self)
266
+ def long(self): return self.__class__(i.long() for i in self)
267
+ def short(self): return self.__class__(i.short() for i in self)
268
+ def clone(self): return self.__class__(i.clone() for i in self)
269
+ def detach(self): return self.__class__(i.detach() for i in self)
270
+ def detach_(self):
271
+ for i in self: i.detach_()
272
+ return self
273
+ def contiguous(self): return self.__class__(i.contiguous() for i in self)
274
+
275
+ # apparently I can't use float for typing if I call a method "float"
276
+ def as_float(self): return self.__class__(i.float() for i in self)
277
+ def as_bool(self): return self.__class__(i.bool() for i in self)
278
+ def as_int(self): return self.__class__(i.int() for i in self)
279
+
280
+ def copy_(self, src: _TensorSeq, non_blocking = False):
281
+ """Copies the elements from src tensors into self tensors."""
282
+ torch._foreach_copy_(self, src, non_blocking=non_blocking)
283
+ def set_(self, storage: Iterable[torch.Tensor | torch.types.Storage]):
284
+ """Sets elements of this TensorList to the values of a list of tensors."""
285
+ for i, j in zip(self, storage): i.set_(j) # pyright:ignore[reportArgumentType]
286
+ return self
287
+
288
+
289
+ def requires_grad_(self, mode: bool = True):
290
+ for e in self: e.requires_grad_(mode)
291
+ return self
292
+
293
+ def to_vec(self): return torch.cat(self.ravel())
294
+ def from_vec_(self, vec:torch.Tensor):
295
+ """Sets elements of this TensorList to the values of a 1D tensor.
296
+ The length of the tensor must be equal to the total number of elements in this TensorList."""
297
+ cur = 0
298
+ for el in self:
299
+ numel = el.numel()
300
+ el.set_(vec[cur:cur + numel].type_as(el).view_as(el)) # pyright:ignore[reportArgumentType]
301
+ cur += numel
302
+ return self
303
+
304
+ def from_vec(self, vec:torch.Tensor):
305
+ """Creates a new TensorList from this TensorList but with values from a 1D tensor.
306
+ The length of the tensor must be equal to the total number of elements in this TensorList."""
307
+ res = []
308
+ cur = 0
309
+ for el in self:
310
+ numel = el.numel()
311
+ res.append(vec[cur:cur + numel].type_as(el).view_as(el))
312
+ cur += numel
313
+ return TensorList(res)
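A short round-trip sketch for to_vec / from_vec / from_vec_ (the shapes are arbitrary):

    tl = TensorList([torch.arange(6.).reshape(2, 3), torch.ones(4)])
    vec = tl.to_vec()              # 1D tensor with tl.global_numel() == 10 elements
    restored = tl.from_vec(vec)    # new TensorList with the original shapes
    tl.from_vec_(vec * 2)          # overwrites tl in place from the doubled vector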
314
+
315
+ # using single operation on a vec, e.g. torch.sum(self.to_vec()) can be faster but its less memory efficient
316
+ def global_min(self) -> torch.Tensor: return builtins.min(self.min()) # pyright:ignore[reportArgumentType]
317
+ def global_max(self) -> torch.Tensor: return builtins.max(self.max()) # pyright:ignore[reportArgumentType]
318
+ def global_mean(self) -> torch.Tensor: return self.global_sum()/self.global_numel()
319
+ def global_sum(self) -> torch.Tensor: return builtins.sum(self.sum()) # pyright:ignore[reportArgumentType,reportReturnType]
320
+ def global_std(self) -> torch.Tensor: return torch.std(self.to_vec())
321
+ def global_var(self) -> torch.Tensor: return torch.var(self.to_vec())
322
+ def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
323
+ return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
324
+ def global_any(self): return builtins.any(self.any())
325
+ def global_all(self): return builtins.all(self.all())
326
+ def global_numel(self) -> int: return builtins.sum(self.numel())
327
+
328
+ def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.empty_like(i, **kwargs) for i in self)
329
+ def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.zeros_like(i, **kwargs) for i in self)
330
+ def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.ones_like(i, **kwargs) for i in self)
331
+ def full_like(self, fill_value: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
332
+ #return self.__class__(torch.full_like(i, fill_value=fill_value, **kwargs) for i in self)
333
+ return self.zipmap(torch.full_like, other=fill_value, **kwargs)
334
+
335
+ def rand_like(self, generator=None, dtype: Any=None, device: Any=None, **kwargs):
336
+ if generator is not None:
337
+ return self.__class__(torch.rand(t.shape, generator=generator,
338
+ dtype=t.dtype if dtype is None else dtype,
339
+ device=t.device if device is None else device, **kwargs) for t in self)
340
+
341
+ return self.__class__(torch.rand_like(i, dtype=dtype, device=device, **kwargs) for i in self)
342
+
343
+ def randn_like(self, generator=None, dtype: Any=None, device: Any=None, **kwargs):
344
+
345
+ if generator is not None:
346
+ return self.__class__(torch.randn(t.shape, generator=generator,
347
+ dtype=t.dtype if dtype is None else dtype,
348
+ device=t.device if device is None else device, **kwargs) for t in self)
349
+
350
+ return self.__class__(torch.randn_like(i, dtype=dtype, device=device, **kwargs) for i in self)
351
+
352
+ def randint_like(self, low: "_Scalar | _ScalarSeq", high: "_Scalar | _ScalarSeq", **kwargs: Unpack[_NewTensorKwargs]):
353
+ return self.zipmap_args(torch.randint_like, low, high, **kwargs)
354
+ def uniform_like(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
355
+ res = self.empty_like(**kwargs)
356
+ res.uniform_(low, high, generator=generator)
357
+ return res
358
+ def sphere_like(self, radius: "_Scalar | _ScalarSeq", generator=None, **kwargs: Unpack[_NewTensorKwargs]) -> Self:
359
+ r = self.randn_like(generator=generator, **kwargs)
360
+ return (r * radius) / r.global_vector_norm()
361
+ def bernoulli(self, generator = None):
362
+ return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
363
+ def bernoulli_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
364
+ """p is probability of a 1, other values will be 0."""
365
+ return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
366
+ def rademacher_like(self, p: "_Scalar | _ScalarSeq" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
367
+ """p is probability of a 1, other values will be -1."""
368
+ return self.bernoulli_like(p, generator=generator, **kwargs).mul_(2).sub_(1)
369
+
370
+ def sample_like(self, eps: "_Scalar | _ScalarSeq" = 1, distribution: Distributions = 'normal', generator=None, **kwargs: Unpack[_NewTensorKwargs]):
371
+ """Sample around 0."""
372
+ if distribution in ('normal', 'gaussian'): return self.randn_like(generator=generator, **kwargs) * eps
373
+ if distribution == 'uniform':
374
+ if isinstance(eps, (list,tuple)):
375
+ return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs)
376
+ return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)
377
+ if distribution == 'sphere': return self.sphere_like(eps, generator=generator, **kwargs)
378
+ if distribution == 'rademacher': return self.rademacher_like(generator=generator, **kwargs) * eps
379
+ raise ValueError(f'Unknown distribution {distribution}')
380
+
381
+ def eq(self, other: _STOrSTSeq): return self.zipmap(torch.eq, other)
382
+ def eq_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('eq_'), other)
383
+ def ne(self, other: _STOrSTSeq): return self.zipmap(torch.ne, other)
384
+ def ne_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('ne_'), other)
385
+ def lt(self, other: _STOrSTSeq): return self.zipmap(torch.lt, other)
386
+ def lt_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('lt_'), other)
387
+ def le(self, other: _STOrSTSeq): return self.zipmap(torch.le, other)
388
+ def le_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('le_'), other)
389
+ def gt(self, other: _STOrSTSeq): return self.zipmap(torch.gt, other)
390
+ def gt_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('gt_'), other)
391
+ def ge(self, other: _STOrSTSeq): return self.zipmap(torch.ge, other)
392
+ def ge_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('ge_'), other)
393
+
394
+ def logical_and(self, other: torch.Tensor | _TensorSeq): return self.zipmap(torch.logical_and, other)
395
+ def logical_and_(self, other: torch.Tensor | _TensorSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_and_'), other)
396
+ def logical_or(self, other: torch.Tensor | _TensorSeq): return self.zipmap(torch.logical_or, other)
397
+ def logical_or_(self, other: torch.Tensor | _TensorSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_or_'), other)
398
+ def logical_xor(self, other: torch.Tensor | _TensorSeq): return self.zipmap(torch.logical_xor, other)
399
+ def logical_xor_(self, other: torch.Tensor | _TensorSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_xor_'), other)
400
+
401
+ def logical_not(self): return self.__class__(torch.logical_not(i) for i in self)
402
+ def logical_not_(self):
403
+ for i in self: i.logical_not_()
404
+ return self
405
+
406
+ def equal(self, other: torch.Tensor | _TensorSeq):
407
+ """returns TensorList of boolean values, True if two tensors have the same size and elements, False otherwise."""
408
+ return self.zipmap(torch.equal, other)
409
+
410
+ @overload
411
+ def add(self, other: torch.Tensor | _TensorSeq, alpha: _Scalar = 1): ...
412
+ @overload
413
+ def add(self, other: _Scalar | _ScalarSeq): ...
414
+ def add(self, other: _STOrSTSeq, alpha: _Scalar = 1):
415
+ if alpha == 1: return self.__class__(torch._foreach_add(self, other))
416
+ return self.__class__(torch._foreach_add(self, other, alpha = alpha)) # pyright:ignore[reportCallIssue,reportArgumentType]
417
+
418
+ @overload
419
+ def add_(self, other: torch.Tensor | _TensorSeq, alpha: _Scalar = 1): ...
420
+ @overload
421
+ def add_(self, other: _Scalar | _ScalarSeq): ...
422
+ def add_(self, other: _STOrSTSeq, alpha: _Scalar = 1):
423
+ if alpha == 1: torch._foreach_add_(self, other)
424
+ else: torch._foreach_add_(self, other, alpha = alpha) # pyright:ignore[reportCallIssue,reportArgumentType]
425
+ return self
426
+
427
+ def lazy_add(self, other: int | float | list[int | float] | tuple[int | float]):
428
+ if generic_eq(other, 0): return self
429
+ return self.add(other)
430
+ def lazy_add_(self, other: int | float | list[int | float] | tuple[int | float]):
431
+ if generic_eq(other, 0): return self
432
+ return self.add_(other)
433
+
434
+ @overload
435
+ def sub(self, other: _TensorSeq, alpha: _Scalar = 1): ...
436
+ @overload
437
+ def sub(self, other: _Scalar | _ScalarSeq): ...
438
+ def sub(self, other: "_Scalar | _STSeq", alpha: _Scalar = 1):
439
+ if alpha == 1: return self.__class__(torch._foreach_sub(self, other))
440
+ return self.__class__(torch._foreach_sub(self, other, alpha = alpha)) # pyright:ignore[reportArgumentType]
441
+
442
+ @overload
443
+ def sub_(self, other: _TensorSeq, alpha: _Scalar = 1): ...
444
+ @overload
445
+ def sub_(self, other: _Scalar | _ScalarSeq): ...
446
+ def sub_(self, other: "_Scalar | _STSeq", alpha: _Scalar = 1):
447
+ if alpha == 1: torch._foreach_sub_(self, other)
448
+ else: torch._foreach_sub_(self, other, alpha = alpha) # pyright:ignore[reportArgumentType]
449
+ return self
450
+
451
+ def lazy_sub(self, other: int | float | list[int | float] | tuple[int | float]):
452
+ if generic_eq(other, 0): return self
453
+ return self.sub(other)
454
+ def lazy_sub_(self, other: int | float | list[int | float] | tuple[int | float]):
455
+ if generic_eq(other, 0): return self
456
+ return self.sub_(other)
457
+
458
+ def neg(self): return self.__class__(torch._foreach_neg(self))
459
+ def neg_(self):
460
+ torch._foreach_neg_(self)
461
+ return self
462
+
463
+ def mul(self, other: _STOrSTSeq): return self.__class__(torch._foreach_mul(self, other))
464
+ def mul_(self, other: _STOrSTSeq):
465
+ torch._foreach_mul_(self, other)
466
+ return self
467
+
468
+ # TODO: benchmark
469
+ def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
470
+ if generic_eq(other, 1):
471
+ if clone: return self.clone()
472
+ return self
473
+ return self * other
474
+ def lazy_mul_(self, other: int | float | list[int | float] | tuple[int | float]):
475
+ if generic_eq(other, 1): return self
476
+ return self.mul_(other)
477
+
478
+ def div(self, other: _STOrSTSeq) -> Self: return self.__class__(torch._foreach_div(self, other))
479
+ def div_(self, other: _STOrSTSeq):
480
+ torch._foreach_div_(self, other)
481
+ return self
482
+
483
+ def lazy_div(self, other: int | float | list[int | float] | tuple[int | float]):
484
+ if generic_eq(other, 1): return self
485
+ return self / other
486
+ def lazy_div_(self, other: int | float | list[int | float] | tuple[int | float]):
487
+ if generic_eq(other, 1): return self
488
+ return self.div_(other)
489
+
490
+ def pow(self, exponent: "_Scalar | _STSeq"): return self.__class__(torch._foreach_pow(self, exponent))
491
+ def pow_(self, exponent: "_Scalar | _STSeq"):
492
+ torch._foreach_pow_(self, exponent)
493
+ return self
494
+
495
+ def rpow(self, input: _Scalar | _TensorSeq): return self.__class__(torch._foreach_pow(input, self))
496
+ def rpow_(self, input: _TensorSeq):
497
+ self.copy_(torch._foreach_pow(input, self))  # write input ** self into self rather than mutating input
498
+ return self
499
+
500
+ def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
501
+ def sqrt_(self):
502
+ torch._foreach_sqrt_(self)
503
+ return self
504
+
505
+ def remainder(self, other: _STOrSTSeq): return self.zipmap(torch.remainder, other)
506
+ def remainder_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('remainder_'), other)
507
+
508
+ def floor_divide(self, other: _STOrSTSeq): return self.zipmap(torch.floor_divide, other)
509
+ def floor_divide_(self, other: _STOrSTSeq): return self.zipmap_inplace_(_MethodCallerWithArgs('floor_divide_'), other)
510
+
511
+ def reciprocal(self): return self.__class__(torch._foreach_reciprocal(self))
512
+ def reciprocal_(self):
513
+ torch._foreach_reciprocal_(self)
514
+ return self
515
+
516
+ def abs(self): return self.__class__(torch._foreach_abs(self))
517
+ def abs_(self):
518
+ torch._foreach_abs_(self)
519
+ return self
520
+
521
+ def sign(self): return self.__class__(torch._foreach_sign(self))
522
+ def sign_(self):
523
+ torch._foreach_sign_(self)
524
+ return self
525
+
526
+ def exp(self): return self.__class__(torch._foreach_exp(self))
527
+ def exp_(self):
528
+ torch._foreach_exp_(self)
529
+ return self
530
+
531
+ def signbit(self): return self.__class__(torch.signbit(i) for i in self)
532
+
533
+ def sin(self): return self.__class__(torch._foreach_sin(self))
534
+ def sin_(self):
535
+ torch._foreach_sin_(self)
536
+ return self
537
+
538
+ def cos(self): return self.__class__(torch._foreach_cos(self))
539
+ def cos_(self):
540
+ torch._foreach_cos_(self)
541
+ return self
542
+
543
+ def tan(self): return self.__class__(torch._foreach_tan(self))
544
+ def tan_(self):
545
+ torch._foreach_tan_(self)
546
+ return self
547
+
548
+ def asin(self): return self.__class__(torch._foreach_asin(self))
549
+ def asin_(self):
550
+ torch._foreach_asin_(self)
551
+ return self
552
+
553
+ def acos(self): return self.__class__(torch._foreach_acos(self))
554
+ def acos_(self):
555
+ torch._foreach_acos_(self)
556
+ return self
557
+
558
+ def atan(self): return self.__class__(torch._foreach_atan(self))
559
+ def atan_(self):
560
+ torch._foreach_atan_(self)
561
+ return self
562
+
563
+ def sinh(self): return self.__class__(torch._foreach_sinh(self))
564
+ def sinh_(self):
565
+ torch._foreach_sinh_(self)
566
+ return self
567
+
568
+ def cosh(self): return self.__class__(torch._foreach_cosh(self))
569
+ def cosh_(self):
570
+ torch._foreach_cosh_(self)
571
+ return self
572
+
573
+ def tanh(self): return self.__class__(torch._foreach_tanh(self))
574
+ def tanh_(self):
575
+ torch._foreach_tanh_(self)
576
+ return self
577
+
578
+ def log(self): return self.__class__(torch._foreach_log(self))
579
+ def log_(self):
580
+ torch._foreach_log_(self)
581
+ return self
582
+
583
+ def log10(self): return self.__class__(torch._foreach_log10(self))
584
+ def log10_(self):
585
+ torch._foreach_log10_(self)
586
+ return self
587
+
588
+ def log2(self): return self.__class__(torch._foreach_log2(self))
589
+ def log2_(self):
590
+ torch._foreach_log2_(self)
591
+ return self
592
+
593
+ def log1p(self): return self.__class__(torch._foreach_log1p(self))
594
+ def log1p_(self):
595
+ torch._foreach_log1p_(self)
596
+ return self
597
+
598
+ def erf(self): return self.__class__(torch._foreach_erf(self))
599
+ def erf_(self):
600
+ torch._foreach_erf_(self)
601
+ return self
602
+
603
+ def erfc(self): return self.__class__(torch._foreach_erfc(self))
604
+ def erfc_(self):
605
+ torch._foreach_erfc_(self)
606
+ return self
607
+
608
+ def sigmoid(self): return self.__class__(torch._foreach_sigmoid(self))
609
+ def sigmoid_(self):
610
+ torch._foreach_sigmoid_(self)
611
+ return self
612
+
613
+ def _global_fn(self, keepdim, fn, *args, **kwargs):
614
+ """checks that keepdim is False and returns fn(*args, **kwargs)"""
615
+ #if keepdim: raise ValueError('dim = global and keepdim = True')
616
+ return fn(*args, **kwargs)
617
+
618
+ def max(self, dim: _Dim = None, keepdim = False) -> Self | Any:
619
+ if dim is None and not keepdim: return self.__class__(torch._foreach_max(self))
620
+ if dim == 'global': return self._global_fn(keepdim, self.global_max)
621
+ if dim is None: dim = ()
622
+ return self.__class__(i.amax(dim=dim, keepdim=keepdim) for i in self)
623
+
624
+ def min(self, dim: _Dim = None, keepdim = False) -> Self | Any:
625
+ if dim is None and not keepdim: return self.__class__(torch._foreach_max(self.neg())).neg_()
626
+ if dim == 'global': return self._global_fn(keepdim, self.global_min)
627
+ if dim is None: dim = ()
628
+ return self.__class__(i.amin(dim=dim, keepdim=keepdim) for i in self)
629
+
630
+ def norm(self, ord: _Scalar, dtype=None):
631
+ return self.__class__(torch._foreach_norm(self, ord, dtype))
632
+
633
+ def mean(self, dim: _Dim = None, keepdim = False) -> Self | Any:
634
+ if dim == 'global': return self._global_fn(keepdim, self.global_mean)
635
+ return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
636
+
637
+ def sum(self, dim: _Dim = None, keepdim = False) -> Self | Any:
638
+ if dim == 'global': return self._global_fn(keepdim, self.global_sum)
639
+ return self.__class__(i.sum(dim=dim, keepdim=keepdim) for i in self)
640
+
641
+ def prod(self, dim = None, keepdim = False): return self.__class__(i.prod(dim=dim, keepdim=keepdim) for i in self)
642
+
643
+ def std(self, dim: _Dim = None, unbiased: bool = True, keepdim = False) -> Self | Any:
644
+ if dim == 'global': return self._global_fn(keepdim, self.global_std)
645
+ return self.__class__(i.std(dim=dim, unbiased=unbiased, keepdim=keepdim) for i in self)
646
+
647
+ def var(self, dim: _Dim = None, unbiased: bool = True, keepdim = False) -> Self | Any:
648
+ if dim == 'global': return self._global_fn(keepdim, self.global_var)
649
+ return self.__class__(i.var(dim=dim, unbiased=unbiased, keepdim=keepdim) for i in self)
650
+
651
+ def median(self, dim=None, keepdim=False):
652
+ """note this doesn't return indices"""
653
+ # median returns tensor or namedtuple (values, indices)
654
+ if dim is None: return self.__class__(i.median() for i in self)
655
+ return self.__class__(i.median(dim=dim, keepdim=keepdim)[0] for i in self)
656
+
657
+ def quantile(self, q, dim=None, keepdim=False, *, interpolation='linear',):
658
+ return self.__class__(i.quantile(q=q, dim=dim, keepdim=keepdim, interpolation=interpolation) for i in self)
659
+
660
+ def clamp_min(self, other: "_Scalar | _STSeq"): return self.__class__(torch._foreach_clamp_min(self, other))
661
+ def clamp_min_(self, other: "_Scalar | _STSeq"):
662
+ torch._foreach_clamp_min_(self, other)
663
+ return self
664
+ def clamp_max(self, other: "_Scalar | _STSeq"): return self.__class__(torch._foreach_clamp_max(self, other))
665
+ def clamp_max_(self, other: "_Scalar | _STSeq"):
666
+ torch._foreach_clamp_max_(self, other)
667
+ return self
668
+
669
+ def clamp(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
670
+ l = self
671
+ if min is not None: l = l.clamp_min(min)
672
+ if max is not None: l = l.clamp_max(max)
673
+ return l
674
+ def clamp_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
675
+ if min is not None: self.clamp_min_(min)
676
+ if max is not None: self.clamp_max_(max)
677
+ return self
678
+
679
+ def clip(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None): return self.clamp(min,max)
680
+ def clip_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None): return self.clamp_(min,max)
681
+
682
+ def clamp_magnitude(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
683
+ return self.abs().clamp_(min, max) * self.sign().add_(0.5).sign_() # this prevents zeros
684
+ def clamp_magnitude_(self, min: "_Scalar | _STSeq | None" = None, max: "_Scalar | _STSeq | None" = None):
685
+ sign = self.sign().add_(0.5).sign_()
686
+ return self.abs_().clamp_(min, max).mul_(sign)
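The sign().add_(0.5).sign_() chain is used because plain sign() maps zero entries to 0, which would erase them when multiplied back in; shifting by 0.5 first makes zeros count as positive:

    s = torch.tensor([-2., 0., 3.]).sign()   # tensor([-1., 0., 1.])
    s.add_(0.5).sign_()                      # tensor([-1., 1., 1.]) -- zeros now get sign +1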
687
+
688
+
689
+ def floor(self): return self.__class__(torch._foreach_floor(self))
690
+ def floor_(self):
691
+ torch._foreach_floor_(self)
692
+ return self
693
+ def ceil(self): return self.__class__(torch._foreach_ceil(self))
694
+ def ceil_(self):
695
+ torch._foreach_ceil_(self)
696
+ return self
697
+ def round(self): return self.__class__(torch._foreach_round(self))
698
+ def round_(self):
699
+ torch._foreach_round_(self)
700
+ return self
701
+
702
+ def zero_(self):
703
+ torch._foreach_zero_(self)
704
+ return self
705
+
706
+ def lerp(self, tensors1: _TensorSeq, weight: "_Scalar | _ScalarSeq | _TensorSeq"):
707
+ """linear interpolation of between self and tensors1. `out = self + weight * (tensors1 - self)`."""
708
+ return self.__class__(torch._foreach_lerp(self, tensors1, weight))
709
+ def lerp_(self, tensors1: _TensorSeq, weight: "_Scalar | _ScalarSeq | _TensorSeq"):
710
+ """linear interpolation of between self and tensors1. `out = self + weight * (tensors1 - self)`."""
711
+ torch._foreach_lerp_(self, tensors1, weight)
712
+ return self
713
+
714
+ def lerp_compat(self, tensors1: _TensorSeq, weight: "_STOrSTSeq"):
715
+ """`lerp` but support scalar sequence weight on pytorch versions before 2.6
716
+
717
+ `out = self + weight * (tensors1 - self)`."""
718
+ return self + weight * (TensorList(tensors1) - self)
719
+ def lerp_compat_(self, tensors1: _TensorSeq, weight: "_STOrSTSeq"):
720
+ """`lerp_` but support scalar sequence weight on previous pytorch versions before 2.6
721
+
722
+ `out = self + weight * (tensors1 - self)`."""
723
+ return self.add_(TensorList(tensors1).sub(self).mul_(weight))
724
+
725
+ def addcmul(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
726
+ return self.__class__(torch._foreach_addcmul(self, tensors1, tensor2, value))
727
+ def addcmul_(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
728
+ torch._foreach_addcmul_(self, tensors1, tensor2, value)
729
+ return self
730
+ def addcdiv(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
731
+ return self.__class__(torch._foreach_addcdiv(self, tensors1, tensor2, value))
732
+ def addcdiv_(self, tensors1: _TensorSeq, tensor2: _TensorSeq, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
733
+ torch._foreach_addcdiv_(self, tensors1, tensor2, value)
734
+ return self
735
+
736
+ def uniform_(self, low: "_Scalar | _ScalarSeq" = 0, high: "_Scalar | _ScalarSeq" = 1, generator = None):
737
+ return self.zipmap_args_inplace_(_MethodCallerWithArgs('uniform_'), low, high, generator = generator)
738
+
739
+ def maximum(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
740
+ return self.__class__(torch._foreach_maximum(self, other))
741
+ def maximum_(self, other: "_Scalar | _ScalarSeq | _TensorSeq"): # ruff: noqa F811
742
+ torch._foreach_maximum_(self, other)
743
+ return self
744
+
745
+ def minimum(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
746
+ return self.__class__(torch._foreach_minimum(self, other))
747
+ def minimum_(self, other: "_Scalar | _ScalarSeq | _TensorSeq"):
748
+ torch._foreach_minimum_(self, other)
749
+ return self
750
+
751
+ def squeeze(self, dim = None):
752
+ if dim is None: return self.__class__(i.squeeze() for i in self)
753
+ return self.__class__(i.squeeze(dim) for i in self)
754
+
755
+ def squeeze_(self, dim = None):
756
+ if dim is None:
757
+ for i in self: i.squeeze_()
758
+ else:
759
+ for i in self: i.squeeze_(dim)
760
+ return self
761
+
762
+ def conj(self): return self.__class__(i.conj() for i in self)
763
+
764
+ def nan_to_num(self, nan: "float | _ScalarSeq | None" = None, posinf: "float | _ScalarSeq | None" = None, neginf: "float | _ScalarSeq | None" = None):
765
+ return self.zipmap_args(torch.nan_to_num, nan, posinf, neginf)
766
+ def nan_to_num_(self, nan: "float | _ScalarSeq | None" = None, posinf: "float | _ScalarSeq | None" = None, neginf: "float | _ScalarSeq | None" = None):
767
+ return self.zipmap_args_inplace_(torch.nan_to_num_, nan, posinf, neginf)
768
+
769
+ def ravel(self): return self.__class__(i.ravel() for i in self)
770
+ def view_flat(self): return self.__class__(i.view(-1) for i in self)
771
+
772
+ def any(self): return self.__class__(i.any() for i in self)
773
+ def all(self): return self.__class__(i.all() for i in self)
774
+ def isfinite(self): return self.__class__(i.isfinite() for i in self)
775
+
776
+ def fill(self, value: _STOrSTSeq): return self.zipmap(torch.fill, other = value)
777
+ def fill_(self, value: _STOrSTSeq): return self.zipmap_inplace_(torch.fill_, other = value)
778
+
779
+ def copysign(self, other):
780
+ return self.__class__(t.copysign(o) for t, o in zip(self, other))
781
+ def copysign_(self, other):
782
+ for t, o in zip(self, other): t.copysign_(o)
783
+ return self
784
+
785
+ def graft(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
786
+ if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
787
+ if tensorwise:
788
+ norm_self = self.norm(ord)
789
+ norm_other = magnitude.norm(ord)
790
+ else:
791
+ norm_self = self.global_vector_norm(ord)
792
+ norm_other = magnitude.global_vector_norm(ord)
793
+
794
+ if not generic_eq(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
795
+
796
+ return self * (norm_other / norm_self.clip_(min=eps))
797
+
798
+ def graft_(self, magnitude: "_TensorSeq", tensorwise=False, ord: float = 2, eps = 1e-6, strength: float | _ScalarSeq = 1):
799
+ if not isinstance(magnitude, TensorList): magnitude = TensorList(magnitude)
800
+ if tensorwise:
801
+ norm_self = self.norm(ord)
802
+ norm_other = magnitude.norm(ord)
803
+ else:
804
+ norm_self = self.global_vector_norm(ord)
805
+ norm_other = magnitude.global_vector_norm(ord)
806
+
807
+ if not generic_eq(strength, 1): norm_other.lerp_(norm_self, 1-maybe_numberlist(strength)) # pyright:ignore[reportCallIssue,reportArgumentType]
808
+
809
+ return self.mul_(norm_other / norm_self.clip_(min=eps))
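A hedged sketch of what grafting does: the result keeps the direction of self and takes its norm from magnitude (the tensors are made up):

    update = TensorList([torch.full((2, 2), 2.0)])   # norm 4
    target = TensorList([torch.ones(2, 2)])          # norm 2
    grafted = update.graft(target, tensorwise=True)  # same direction as update, norm rescaled to 2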
810
+
811
+ def _get_rescale_coeffs(self, min:"_Scalar | _ScalarSeq", max:"_Scalar | _ScalarSeq", dim: _Dim, eps):
812
+ self_min = self.min(dim=dim, keepdim=True)
813
+ self_max = self.max(dim=dim, keepdim=True)
814
+
815
+ # target range difference (diff)
816
+ min = maybe_numberlist(min)
817
+ max = maybe_numberlist(max)
818
+ diff = max - min
819
+ target_min = min
820
+ source_range = (self_max - self_min).add_(eps)
821
+ a = diff / source_range
822
+ b = target_min - (a * self_min)
823
+
824
+ return a, b
825
+
826
+ def rescale(self, min: "_Scalar | _ScalarSeq | None", max: "_Scalar | _ScalarSeq | None", dim: _Dim = None, eps=0.):
827
+ """rescales each tensor to (min, max) range"""
828
+ if min is None and max is None: return self
829
+ if max is None:
830
+ assert min is not None
831
+ return self - (self.min(dim=dim, keepdim=True).sub_(min))
832
+ if min is None: return self - (self.max(dim=dim, keepdim=True).sub_(max))
833
+
834
+ a,b = self._get_rescale_coeffs(min=min, max=max, dim=dim, eps=eps)
835
+ return (self*a).add_(b)
836
+
837
+ def rescale_(self, min: "_Scalar | _ScalarSeq | None", max: "_Scalar | _ScalarSeq | None", dim: _Dim = None, eps=0.):
838
+ """rescales each tensor to (min, max) range"""
839
+ if min is None and max is None: return self
840
+ if max is None:
841
+ assert min is not None
842
+ return self.sub_(self.min(dim=dim, keepdim=True).sub_(min))
843
+ if min is None: return self.sub_(self.max(dim=dim, keepdim=True).sub_(max))
844
+
845
+ a,b = self._get_rescale_coeffs(min=min, max=max, dim=dim, eps=eps)
846
+ return (self.mul_(a)).add_(b)
847
+
848
+ def rescale_to_01(self, dim: _Dim = None, eps: float = 0):
849
+ """faster method to rescale to (0, 1) range"""
850
+ res = self - self.min(dim = dim, keepdim=True)
851
+ max = res.max(dim = dim, keepdim=True)
852
+ if eps != 0: max.add_(eps)
853
+ return res.div_(max)
854
+
855
+ def rescale_to_01_(self, dim: _Dim = None, eps: float = 0):
856
+ """faster method to rescale to (0, 1) range"""
857
+ self.sub_(self.min(dim = dim, keepdim=True))
858
+ max = self.max(dim = dim, keepdim=True)
859
+ if eps != 0: max.add_(eps)
860
+ return self.div_(max)
861
+
862
+ def normalize(self, mean: "_Scalar | _ScalarSeq | None", var: "_Scalar | _ScalarSeq | None", dim: _Dim = None): # pylint:disable=redefined-outer-name
863
+ """normalizes to mean and variance"""
864
+ if mean is None and var is None: return self
865
+ if mean is None: return self / self.std(dim = dim, keepdim = True)
866
+ if var is None: return self - self.mean(dim = dim, keepdim = True)
867
+ self_mean = self.mean(dim = dim, keepdim = True)
868
+ self_std = self.std(dim = dim, keepdim = True)
869
+
870
+ if isinstance(var, Sequence): var_sqrt = [i**0.5 for i in var]
871
+ else: var_sqrt = var ** 0.5
872
+
873
+ return (self - self_mean).div_(self_std).mul_(var_sqrt).add_(mean)
874
+
875
+ def normalize_(self, mean: "_Scalar | _ScalarSeq | None", var: "_Scalar | _ScalarSeq | None", dim: _Dim = None): # pylint:disable=redefined-outer-name
876
+ """normalizes to mean and variance"""
877
+ if mean is None and var is None: return self
878
+ if mean is None: return self.div_(self.std(dim = dim, keepdim = True))
879
+ if var is None: return self.sub_(self.mean(dim = dim, keepdim = True))
880
+ self_mean = self.mean(dim = dim, keepdim = True)
881
+ self_std = self.std(dim = dim, keepdim = True)
882
+
883
+ if isinstance(var, Sequence): var_sqrt = [i**0.5 for i in var]
884
+ else: var_sqrt = var ** 0.5
885
+
886
+ return self.sub_(self_mean).div_(self_std).mul_(var_sqrt).add_(mean)
887
+
888
+ def znormalize(self, dim: _Dim = None, eps:float = 0):
889
+ """faster method to normalize to 0 mean and 1 variance"""
890
+ std = self.std(dim = dim, keepdim = True)
891
+ if eps!=0: std.add_(eps)
892
+ return (self - self.mean(dim = dim, keepdim=True)).div_(std)
893
+
894
+ def znormalize_(self, dim: _Dim = None, eps:float = 0):
895
+ """faster method to normalize to 0 mean and 1 variance"""
896
+ std = self.std(dim = dim, keepdim = True)
897
+ if eps!=0: std.add_(eps)
898
+ return self.sub_(self.mean(dim = dim, keepdim=True)).div_(std)
899
+
900
+ def _clip_multiplier(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
901
+ """calculate multipler to clip self norm to min and max"""
902
+ if tensorwise:
903
+ self_norm = self.norm(ord)
904
+ self_norm.masked_fill_(self_norm == 0, 1)
905
+
906
+ else:
907
+ self_norm = self.global_vector_norm(ord)
908
+ if self_norm == 0: return 1
909
+
910
+ mul = 1
911
+ if min is not None:
912
+ mul_to_min = generic_clamp(maybe_numberlist(min) / self_norm, min=1)
913
+ mul *= mul_to_min
914
+
915
+ if max is not None:
916
+ mul_to_max = generic_clamp(maybe_numberlist(max) / self_norm, max=1)
917
+ mul *= mul_to_max
918
+
919
+ return mul
920
+
921
+ def clip_norm(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
922
+ """clips norm of each tensor to (min, max) range"""
923
+ if min is None and max is None: return self
924
+ return self * self._clip_multiplier(min, max, tensorwise, ord)
925
+
926
+ def clip_norm_(self, min: "_Scalar | _ScalarSeq | None"= None, max: "_Scalar | _ScalarSeq | None" = None, tensorwise: bool = True, ord:float = 2):
927
+ """clips norm of each tensor to (min, max) range"""
928
+ if min is None and max is None: return self
929
+ return self.mul_(self._clip_multiplier(min, max, tensorwise, ord))
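For example, the in-place variant can serve as a stand-in for global gradient-norm clipping (the model is hypothetical):

    grads = TensorList(model.parameters()).get_grad()
    grads.clip_norm_(max=1.0, tensorwise=False)   # rescales so the global 2-norm is at most 1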
930
+
931
+
932
+ def where(self, condition: "torch.Tensor | _TensorSeq", other: _STOrSTSeq):
933
+ """self where condition is true other otherwise"""
934
+ return self.zipmap_args(_MethodCallerWithArgs('where'), condition, other)
935
+ def where_(self, condition: "torch.Tensor | _TensorSeq", other: "torch.Tensor | _TensorSeq"):
936
+ """self where condition is true other otherwise"""
937
+ return self.zipmap_args_inplace_(where_, condition, other)
938
+
939
+ def masked_fill(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
940
+ """Same as tensor[mask] = value (not in-place), where value must be scalar/scalars"""
941
+ return self.zipmap_args(torch.masked_fill, mask, fill_value)
942
+ def masked_fill_(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
943
+ """Same as tensor[mask] = value, where value must be scalar/scalars"""
944
+ return self.zipmap_args_inplace_(_MethodCallerWithArgs('masked_fill_'), mask, fill_value)
945
+
946
+ def select_set_(self, mask: _TensorSeq, value: _STOrSTSeq):
947
+ """Same as tensor[mask] = value"""
948
+ list_value = value if isinstance(value, (list,tuple)) else [value]*len(self)
949
+ for tensor, m, v in zip(self, mask, list_value):
950
+ tensor[m] = v # pyright: ignore[reportArgumentType]
951
+
952
+ def masked_set_(self, mask: _TensorSeq, value: _TensorSeq):
953
+ """Same as tensor[mask] = value[mask]"""
954
+ for tensor, m, v in zip(self, mask, value):
955
+ tensor[m] = v[m]
956
+
957
+ def select(self, idx: Any):
958
+ """same as tensor[idx]"""
959
+ if not isinstance(idx, (list,tuple)): return self.__class__(t[idx] for t in self)
960
+ return self.__class__(t[i] for t,i in zip(self, idx))
961
+
962
+ def dot(self, other: _TensorSeq):
963
+ return (self * other).global_sum()
964
+
965
+ def tensorwise_dot(self, other: _TensorSeq):
966
+ return (self * other).sum()
967
+
968
+ def swap_tensors(self, other: _TensorSeq):
969
+ for s, o in zip(self, other):
970
+ torch.utils.swap_tensors(s, o)
971
+
972
+ def unbind_channels(self, dim=0):
973
+ """returns a new tensorlist where tensors with 2 or more dimensions are split into slices along 1st dimension"""
974
+ return self.__class__(ch for t in self for ch in (t.unbind(dim) if t.ndim >= 2 else (t,)) )
975
+
976
+
977
+ def flatiter(self) -> Generator[torch.Tensor]:
978
+ for tensor in self:
979
+ yield from tensor.view(-1)
980
+
981
+ # def flatset(self, idx: int, value: Any):
982
+ # """sets index in flattened view"""
983
+ # return self.clone().flatset_(idx, value)
984
+
985
+ def flat_set_(self, idx: int, value: Any):
986
+ """sets index in flattened view"""
987
+ cur = 0
988
+ for tensor in self:
989
+ numel = tensor.numel()
990
+ if idx < cur + numel:
991
+ tensor.view(-1)[idx - cur] = value  # offset of idx within this tensor
992
+ return self
993
+ cur += numel
994
+ raise IndexError(idx)
995
+
996
+ def flat_set_lambda_(self, idx, fn):
997
+ """sets index in flattened view to return of fn(current_value)"""
998
+ cur = 0
999
+ for tensor in self:
1000
+ numel = tensor.numel()
1001
+ if idx < cur + numel:
1002
+ flat_view = tensor.view(-1)
1003
+ flat_view[idx - cur] = fn(flat_view[idx - cur])  # offset of idx within this tensor
1004
+ return self
1005
+ cur += numel
1006
+ raise IndexError(idx)
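A small sketch of the flat-index semantics used by the two methods above (the offset within a tensor is idx - cur):

    tl = TensorList([torch.zeros(2, 2), torch.zeros(3)])   # 7 elements when flattened
    tl.flat_set_(5, 1.0)                      # index 5 falls in the second tensor, at offset 1
    tl.flat_set_lambda_(0, lambda v: v + 3)   # first element of the first tensor becomes 3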
1007
+
1008
+ def __repr__(self):
1009
+ return f"{self.__class__.__name__}({super().__repr__()})"
1010
+
1011
+
1012
+ def stack(tensorlists: Iterable[TensorList], dim = 0):
1013
+ """Returns a tensorlist with the same elements as the input tensorlists, but stacked along the specified dimension."""
1014
+ return TensorList(torch.stack(i, dim = dim) for i in zip(*tensorlists))
1015
+
1016
+ def mean(tensorlists: Iterable[TensorList]):
1017
+ """Returns a tensorlist which is the mean of given tensorlists."""
1018
+ res = TensorList()
1019
+ for tensors in zip(*tensorlists):
1020
+ res.append(torch.stack(tensors).mean(0))
1021
+ return res
1022
+
1023
+ def median(tensorlists: Iterable[TensorList]):
1024
+ """Returns a tensorlist which is the median of given tensorlists."""
1025
+ res = TensorList()
1026
+ for tensors in zip(*tensorlists):
1027
+ res.append(torch.stack(tensors).median(0)[0])
1028
+ return res
1029
+
1030
+ def quantile(tensorlists: Iterable[TensorList], q, interpolation = 'linear'):
1031
+ """Returns a tensorlist which is the median of given tensorlists."""
1032
+ res = TensorList()
1033
+ for tensors in zip(*tensorlists):
1034
+ res.append(torch.stack(tensors).quantile(q=q, dim=0, interpolation=interpolation))
1035
+ return res
1036
+
1037
+ def sum(tensorlists: Iterable[TensorList]):
1038
+ """Returns a tensorlist which is the sum of given tensorlists."""
1039
+ res = TensorList()
1040
+ for tensors in zip(*tensorlists):
1041
+ res.append(torch.stack(tensors).sum(0))
1042
+ return res
1043
+
1044
+ def where(condition: TensorList, input: _STOrSTSeq, other: _STOrSTSeq):
1045
+ """Where but for a tensorlist."""
1046
+ args = [i if isinstance(i, (list, tuple)) else [i]*len(condition) for i in (input, other)]
1047
+ return condition.__class__(torch.where(*z) for z in zip(condition, *args))
1048
+
1049
+ def generic_clamp(x: Any, min=None,max=None) -> Any:
1050
+ if isinstance(x, (torch.Tensor, TensorList)): return x.clamp(min,max)
1051
+ if isinstance(x, (list, tuple)): return x.__class__(generic_clamp(i,min,max) for i in x)
1052
+ if min is not None and x < min: return min
1053
+ if max is not None and x > max: return max
1054
+ return x
1055
+
1056
+ def generic_numel(x: torch.Tensor | TensorList) -> int:
1057
+ if isinstance(x, torch.Tensor): return x.numel()
1058
+ return x.global_numel()
1059
+
1060
+ @overload
1061
+ def generic_zeros_like(x: torch.Tensor) -> torch.Tensor: ...
1062
+ @overload
1063
+ def generic_zeros_like(x: TensorList) -> TensorList: ...
1064
+ def generic_zeros_like(x: torch.Tensor | TensorList):
1065
+ if isinstance(x, torch.Tensor): return torch.zeros_like(x)
1066
+ return x.zeros_like()
1067
+
1068
+ def generic_vector_norm(x: torch.Tensor | TensorList, ord=2) -> torch.Tensor:
1069
+ if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=ord) # pylint:disable=not-callable
1070
+ return x.global_vector_norm(ord)
1071
+
1072
+
1073
+
1074
+ @overload
1075
+ def generic_randn_like(x: torch.Tensor) -> torch.Tensor: ...
1076
+ @overload
1077
+ def generic_randn_like(x: TensorList) -> TensorList: ...
1078
+ def generic_randn_like(x: torch.Tensor | TensorList):
1079
+ if isinstance(x, torch.Tensor): return torch.randn_like(x)
1080
+ return x.randn_like()
1081
+