torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/tensorlist.py DELETED
@@ -1,826 +0,0 @@
1
- r"""
2
- TensorList is a data type that can be used to manipulate a sequence of tensors such as model parameters,
3
- with the same methods that normal tensors have, plus some additional convenience features.
4
- Whenever possible, I used _foreach methods and other tricks to speed up computation.
5
-
6
- TensorList is similar to TensorDict (https://github.com/pytorch/tensordict).
7
- If you want to get the most performance out of a collection of tensors, use TensorDict and lock it.
8
- However, I found that *creating* a TensorDict is quite slow. In fact, it negates the benefits of using it
9
- in an optimizer when you have to create one from parameters on each step. The solution could be to create
10
- it once beforehand, but then you won't be able to easily support parameter groups and per-parameter states.
11
- """
12
- import builtins
13
- from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
14
- import math
15
- import operator
16
- from typing import Any, Literal, TypedDict
17
- from typing_extensions import Self, TypeAlias, Unpack
18
-
19
- import torch
20
-
21
- _Scalar = int | float | bool | complex
22
- _AnyTensor = torch.Tensor | torch.nn.Parameter
23
- _TensorSequence = list[_AnyTensor] | tuple[_AnyTensor, ...]
24
- _ScalarSequence: TypeAlias = "list[_Scalar] | tuple[_Scalar] | TensorList"
25
- _STSequence: TypeAlias = "_TensorSequence | _ScalarSequence"
26
- _STOrSTSequence: TypeAlias = "_Scalar | torch.Tensor | torch.nn.Parameter | _STSequence"
27
-
28
- Distributions = Literal['normal', 'uniform', 'sphere', 'rademacher']
29
- class _NewTensorKwargs(TypedDict, total = False):
30
- memory_format: Any
31
- dtype: Any
32
- layout: Any
33
- device: Any
34
- pin_memory: bool
35
- requires_grad: bool
36
-
37
- # _foreach_methods = {attr.replace('_foreach_', ''):getattr(torch, attr) for attr in dir(torch) if attr.startswith('_foreach_')}
38
- class _MethodCallerWithArgs:
39
- """Return a callable object that calls the given method on its operand.
40
-
41
- This is similar to operator.methodcaller, but args and kwargs are specified in __call__.
42
-
43
- Args:
44
- method (str): name of method to call.
45
- """
46
- __slots__ = ('_name',)
47
- def __init__(self, name: str):
48
- self._name = name
49
-
50
- def __call__(self, obj, *args, **kwargs):
51
- return getattr(obj, self._name)(*args, **kwargs)
52
-
53
- def __repr__(self):
54
- return f'{self.__class__.__module__}.{self.__class__.__name__}({repr(self._name)})'
55
-
56
- def __reduce__(self):
57
- return self.__class__, (self._name,)
58
-
59
-
60
- def maximum_(input:torch.Tensor, other: torch.Tensor):
61
- """in-place maximum"""
62
- return torch.maximum(input, other, out = input)
63
-
64
- def where_(input: torch.Tensor, condition: torch.Tensor, other: torch.Tensor):
65
- """in-place where"""
66
- return torch.where(condition, input, other, out = input)
67
-
68
- # tensorlist must subclass list
69
- # UserList doesn't work with _foreach_xxx
70
- class TensorList(list[torch.Tensor | Any]):
71
- @classmethod
72
- def complex(cls, real: _TensorSequence, imag: _TensorSequence):
73
- """Create a complex TensorList from real and imaginary tensor sequences."""
74
- return cls(torch.complex(r, i) for r, i in zip(real, imag))
75
-
76
- @property
77
- def device(self): return [i.device for i in self]
78
- @property
79
- def dtype(self): return [i.dtype for i in self]
80
- @property
81
- def requires_grad(self): return [i.requires_grad for i in self]
82
- @property
83
- def shape(self): return [i.shape for i in self]
84
- def size(self, dim: int | None = None): return [i.size(dim) for i in self]
85
- @property
86
- def ndim(self): return [i.ndim for i in self]
87
- def ndimension(self): return [i.ndimension() for i in self]
88
- def numel(self): return [i.numel() for i in self]
89
-
90
- @property
91
- def grad(self): return self.__class__(i.grad for i in self)
92
- @property
93
- def real(self): return self.__class__(i.real for i in self)
94
- @property
95
- def imag(self): return self.__class__(i.imag for i in self)
96
-
97
- def view_as_real(self): return self.__class__(torch.view_as_real(i) for i in self)
98
- def view_as_complex(self): return self.__class__(torch.view_as_complex(i) for i in self)
99
-
100
- def type_as(self, other: torch.Tensor | _TensorSequence):
101
- return self.zipmap(_MethodCallerWithArgs('type_as'), other)
102
-
103
- def to_real_views(self):
104
- """Turns all complex tensors into real views, ignoring non-complex tensors, and sets an attribute `_tl_is_complex` to True or False,
105
- which the `from_real_views` method can use to convert real views back into complex tensors."""
106
- tl = TensorList()
107
- for p in self:
108
- if torch.is_complex(p):
109
- p._tl_is_complex = True # type:ignore
110
- tl.append(torch.view_as_real(p))
111
- else:
112
- p._tl_is_complex = False # type:ignore
113
- tl.append(p)
114
- return tl
115
-
116
- def from_real_views(self):
117
- """undoes `to_real_views`."""
118
- return self.__class__(torch.view_as_complex(p) if p._tl_is_complex else p for p in self) # type:ignore
119
-
120
- def get_existing_grads(self):
121
- """Returns all gradients that are not None."""
122
- return self.__class__(i.grad for i in self if i.grad is not None)
123
-
124
- def with_requires_grad(self, requires_grad = True):
125
- """Returns all tensors with requires_grad set to the given value."""
126
- return self.__class__(i for i in self if i.requires_grad == requires_grad)
127
-
128
- def with_grad(self):
129
- """returns all tensors whose .grad is not None"""
130
- return self.__class__(i for i in self if i.grad is not None)
131
-
132
- def ensure_grad_(self):
133
- """For each element, if grad is None and it requires grad, sets grad to zeroes."""
134
- for i in self:
135
- if i.requires_grad and i.grad is None: i.grad = torch.zeros_like(i)
136
- return self
137
-
138
- def accumulate_grad_(self, grads: _TensorSequence):
139
- """Creates grad if it is None, otherwise adds to existing grad."""
140
- for i, g in zip(self, grads):
141
- if i.grad is None: i.grad = g
142
- else: i.grad.add_(g)
143
- return self
144
-
145
- def set_grad_(self, grads: _TensorSequence):
146
- """Sets grad to the given sequence, overwrites grad that already exists."""
147
- for i, g in zip(self, grads): i.grad = g
148
- return self
149
-
150
- def zero_grad_(self, set_to_none = True):
151
- """Set all grads to None or zeroes."""
152
- if set_to_none:
153
- for p in self: p.grad = None
154
- else:
155
- self.get_existing_grads().zero_()
156
- return self
157
-
158
- def __add__(self, other: _STOrSTSequence) -> Self: return self.add(other) # type:ignore
159
- def __radd__(self, other: _STOrSTSequence) -> Self: return self.add(other)
160
- def __iadd__(self, other: _STOrSTSequence) -> Self: return self.add_(other) # type:ignore
161
-
162
- def __sub__(self, other: "_Scalar | _STSequence") -> Self: return self.sub(other)
163
- def __rsub__(self, other: "_Scalar | _STSequence") -> Self: return self.sub(other).neg_()
164
- def __isub__(self, other: "_Scalar | _STSequence") -> Self: return self.sub_(other)
165
-
166
- def __mul__(self, other: _STOrSTSequence) -> Self: return self.mul(other) # type:ignore
167
- def __rmul__(self, other: _STOrSTSequence) -> Self: return self.mul(other) # type:ignore
168
- def __imul__(self, other: _STOrSTSequence) -> Self: return self.mul_(other) # type:ignore
169
-
170
- def __truediv__(self, other: "_Scalar | _STSequence") -> Self: return self.div(other)
171
- def __rtruediv__(self, other: "_Scalar | _STSequence") -> Self: return other * self.reciprocal() # type:ignore
172
- def __itruediv__(self, other: "_Scalar | _STSequence") -> Self: return self.div_(other)
173
-
174
- def __floordiv__(self, other: _STOrSTSequence): return self.floor_divide(other)
175
- #def __rfloordiv__(self, other: "TensorList"): return other.floor_divide(self)
176
- def __ifloordiv__(self, other: _STOrSTSequence): return self.floor_divide_(other)
177
-
178
- def __mod__(self, other: _STOrSTSequence): return self.remainder(other)
179
- #def __rmod__(self, other: STOrSTSequence): return self.remainder(other)
180
- def __imod__(self, other: _STOrSTSequence):return self.remainder_(other)
181
-
182
- def __pow__(self, other: "_Scalar | _STSequence"): return self.pow(other)
183
- def __rpow__(self, other: "_Scalar | _TensorSequence"): return self.rpow(other)
184
- def __ipow__(self, other: "_Scalar | _STSequence"): return self.pow_(other)
185
-
186
- def __neg__(self): return self.neg()
187
-
188
- def __eq__(self, other: _STOrSTSequence): return self.eq(other) # type:ignore
189
- def __ne__(self, other: _STOrSTSequence): return self.ne(other) # type:ignore
190
- def __lt__(self, other: _STOrSTSequence): return self.lt(other) # type:ignore
191
- def __le__(self, other: _STOrSTSequence): return self.le(other) # type:ignore
192
- def __gt__(self, other: _STOrSTSequence): return self.gt(other) # type:ignore
193
- def __ge__(self, other: _STOrSTSequence): return self.ge(other) # type:ignore
194
-
195
- def __invert__(self): return self.logical_not()
196
-
197
- def __and__(self, other: torch.Tensor | _TensorSequence): return self.logical_and(other)
198
- def __iand__(self, other: torch.Tensor | _TensorSequence): return self.logical_and_(other)
199
- def __or__(self, other: torch.Tensor | _TensorSequence): return self.logical_or(other)
200
- def __ior__(self, other: torch.Tensor | _TensorSequence): return self.logical_or_(other)
201
- def __xor__(self, other: torch.Tensor | _TensorSequence): return self.logical_xor(other)
202
- def __ixor__(self, other: torch.Tensor | _TensorSequence): return self.logical_xor_(other)
203
-
204
- def __bool__(self):
205
- raise RuntimeError(f'Boolean value of {self.__class__.__name__} is ambiguous')
206
-
207
- def map(self, fn: Callable[..., torch.Tensor], *args, **kwargs):
208
- """Applies `fn` to all elements of this TensorList
209
- and returns a new TensorList with return values of the callable."""
210
- return self.__class__(fn(i, *args, **kwargs) for i in self)
211
- def map_inplace_(self, fn: Callable[..., Any], *args, **kwargs):
212
- """Applies an in-place `fn` to all elements of this TensorList."""
213
- for i in self: fn(i, *args, **kwargs)
214
- return self
215
-
216
- def filter(self, fn: Callable[..., bool], *args, **kwargs):
217
- """Returns a TensorList with all elements for which `fn` returned True."""
218
- return self.__class__(i for i in self if fn(i, *args, **kwargs))
219
-
220
- def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
221
- """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
222
- Otherwise applies `fn` to this TensorList and `other`.
223
- Returns a new TensorList with return values of the callable."""
224
- if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
225
- return self.__class__(fn(i, other, *args, **kwargs) for i in self)
226
-
227
- def zipmap_inplace_(self, fn: Callable[..., Any], other: Any | list | tuple, *args, **kwargs):
228
- """If `other` is list/tuple, applies `fn` to this TensorList zipped with `other`.
229
- Otherwise applies `fn` to this TensorList and `other`.
230
- The callable must modify elements in-place."""
231
- if isinstance(other, (list, tuple)):
232
- for i, j in zip(self, other): fn(i, j, *args, **kwargs)
233
- else:
234
- for i in self: fn(i, other, *args, **kwargs)
235
- return self
236
-
237
- def zipmap_args(self, fn: Callable[..., Any], *others, **kwargs):
238
- """If `args` is list/tuple, applies `fn` to this TensorList zipped with `others`.
239
- Otherwise applies `fn` to this TensorList and `other`."""
240
- others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
241
- return self.__class__(fn(*z, **kwargs) for z in zip(self, *others))
242
-
243
- def zipmap_args_inplace_(self, fn: Callable[..., Any], *others, **kwargs):
244
- """If `args` is list/tuple, applies `fn` to this TensorList zipped with `other`.
245
- Otherwise applies `fn` to this TensorList and `other`.
246
- The callable must modify elements in-place."""
247
- others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
248
- for z in zip(self, *others): fn(*z, **kwargs)
249
- return self
250
-
251
- def _foreach_apply(self, fn: Callable[[list[torch.Tensor]], list[torch.Tensor]], *args, **kwargs):
252
- """Applies a torch._foreach_xxx function to self and converts returned list back to TensorList or subclass."""
253
- return self.__class__(fn(self, *args, **kwargs))
254
-
255
- # def __getattr__(self, name: str) -> Callable:
256
- # if name == '__torch_function__' or name == '_ipython_canary_method_should_not_exist_': raise AttributeError('who ???')
257
- # if name in _foreach_methods:
258
- # method = partial(self._foreach_apply, _foreach_methods[name])
259
- # else:
260
- # method = partial(self.map, MethodCallerWithArgs(name))
261
- # setattr(self, name, method)
262
- # return method
263
-
264
- def to(self, *args, **kwargs): return self.__class__(i.to(*args, **kwargs) for i in self)
265
- def cuda(self): return self.__class__(i.cuda() for i in self)
266
- def cpu(self): return self.__class__(i.cpu() for i in self)
267
- def long(self): return self.__class__(i.long() for i in self)
268
- def short(self): return self.__class__(i.short() for i in self)
269
- def clone(self): return self.__class__(i.clone() for i in self)
270
- def detach(self): return self.__class__(i.detach() for i in self)
271
- def detach_(self): return self.__class__(i.detach_() for i in self)
272
-
273
- # apparently I can't use float for typing if I call a method "float"
274
- def as_float(self): return self.__class__(i.float() for i in self)
275
- def as_bool(self): return self.__class__(i.bool() for i in self)
276
- def as_int(self): return self.__class__(i.int() for i in self)
277
-
278
- def copy_(self, src: _TensorSequence, non_blocking = False):
279
- """Copies the elements from src tensors into self tensors."""
280
- torch._foreach_copy_(self, src, non_blocking=non_blocking)
281
- def set_(self, storage: Iterable[torch.Tensor | torch.types.Storage]):
282
- """Sets elements of this TensorList to the values of a list of tensors."""
283
- for i, j in zip(self, storage): i.set_(j) # type:ignore
284
- return self
285
-
286
-
287
- def requires_grad_(self, mode: bool = True):
288
- for e in self: e.requires_grad_(mode)
289
- return self
290
-
291
- def to_vec(self): return torch.cat(self.ravel())
292
- def from_vec_(self, vec:torch.Tensor):
293
- """Sets elements of this TensorList to the values of a 1D tensor.
294
- The length of the tensor must be equal to the total number of elements in this TensorList."""
295
- cur = 0
296
- for el in self:
297
- numel = el.numel()
298
- el.set_(vec[cur:cur + numel].view_as(el)) # type:ignore
299
- cur += numel
300
- return self
301
-
302
- def from_vec(self, vec:torch.Tensor):
303
- """Creates a new TensorList from this TensorList but with values from a 1D tensor.
304
- The length of the tensor must be equal to the total number of elements in this TensorList."""
305
- res = []
306
- cur = 0
307
- for el in self:
308
- numel = el.numel()
309
- res.append(vec[cur:cur + numel].view_as(el)) # type:ignore
310
- cur += numel
311
- return TensorList(res)
312
-
313
- def total_min(self) -> torch.Tensor:
314
- return torch.min(self.to_vec())
315
- def total_max(self) -> torch.Tensor:
316
- return torch.max(self.to_vec())
317
- def total_mean(self) -> torch.Tensor:
318
- return torch.mean(self.to_vec())
319
- def total_sum(self) -> torch.Tensor:
320
- return torch.sum(self.to_vec())
321
- def total_vector_norm(self, ord:float = 2) -> torch.Tensor:
322
- return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
323
- def total_any(self):
324
- return self.to_vec().any()
325
- def total_all(self):
326
- return self.to_vec().all()
327
- def total_numel(self):
328
- return builtins.sum(self.numel())
329
-
330
- def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.empty_like(i, **kwargs) for i in self)
331
- def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.zeros_like(i, **kwargs) for i in self)
332
- def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.ones_like(i, **kwargs) for i in self)
333
- def full_like(self, fill_value: "_Scalar | _ScalarSequence", **kwargs: Unpack[_NewTensorKwargs]):
334
- #return self.__class__(torch.full_like(i, fill_value=fill_value, **kwargs) for i in self)
335
- return self.zipmap(torch.full_like, other=fill_value, **kwargs)
336
-
337
- def rand_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.rand_like(i, **kwargs) for i in self)
338
- def randn_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.randn_like(i, **kwargs) for i in self)
339
-
340
- def randint_like(self, low: "_Scalar | _ScalarSequence", high: "_Scalar | _ScalarSequence", **kwargs: Unpack[_NewTensorKwargs]):
341
- return self.zipmap_args(torch.randint_like, low, high, **kwargs)
342
- def uniform_like(self, low: "_Scalar | _ScalarSequence" = 0, high: "_Scalar | _ScalarSequence" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
343
- res = self.empty_like(**kwargs)
344
- res.uniform_(low, high, generator=generator)
345
- return res
346
- def sphere_like(self, radius: "_Scalar | _ScalarSequence", **kwargs: Unpack[_NewTensorKwargs]) -> Self:
347
- r = self.randn_like(**kwargs)
348
- return (r * radius) / r.total_vector_norm() # type:ignore
349
- def bernoulli(self, generator = None):
350
- return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
351
- def bernoulli_like(self, p: "_Scalar | _ScalarSequence" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
352
- """p is probability of a 1, other values will be 0."""
353
- return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
354
- def rademacher_like(self, p: "_Scalar | _ScalarSequence" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
355
- """p is probability of a 1, other values will be -1."""
356
- return self.bernoulli_like(p, generator=generator, **kwargs) * 2 - 1
357
-
358
- def sample_like(self, eps: "_Scalar | _ScalarSequence" = 1, distribution: Distributions = 'normal', generator=None, **kwargs: Unpack[_NewTensorKwargs]):
359
- """Sample around 0."""
360
- if distribution == 'normal': return self.randn_like(**kwargs) * eps # TODO: generator
361
- if distribution == 'uniform':
362
- if isinstance(eps, (list,tuple)):
363
- return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs) # type:ignore
364
- return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)
365
- if distribution == 'sphere': return self.sphere_like(eps, **kwargs)
366
- if distribution == 'rademacher': return self.rademacher_like(generator=generator, **kwargs) * eps
367
- raise ValueError(f'Unknown distribution {distribution}')
368
-
369
- def eq(self, other: _STOrSTSequence): return self.zipmap(torch.eq, other)
370
- def eq_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('eq_'), other)
371
- def ne(self, other: _STOrSTSequence): return self.zipmap(torch.ne, other)
372
- def ne_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('ne_'), other)
373
- def lt(self, other: _STOrSTSequence): return self.zipmap(torch.lt, other)
374
- def lt_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('lt_'), other)
375
- def le(self, other: _STOrSTSequence): return self.zipmap(torch.le, other)
376
- def le_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('le_'), other)
377
- def gt(self, other: _STOrSTSequence): return self.zipmap(torch.gt, other)
378
- def gt_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('gt_'), other)
379
- def ge(self, other: _STOrSTSequence): return self.zipmap(torch.ge, other)
380
- def ge_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('ge_'), other)
381
-
382
- def logical_and(self, other: torch.Tensor | _TensorSequence): return self.zipmap(torch.logical_and, other)
383
- def logical_and_(self, other: torch.Tensor | _TensorSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_and_'), other)
384
- def logical_or(self, other: torch.Tensor | _TensorSequence): return self.zipmap(torch.logical_or, other)
385
- def logical_or_(self, other: torch.Tensor | _TensorSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_or_'), other)
386
- def logical_xor(self, other: torch.Tensor | _TensorSequence): return self.zipmap(torch.logical_xor, other)
387
- def logical_xor_(self, other: torch.Tensor | _TensorSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('logical_xor_'), other)
388
-
389
- def logical_not(self): return self.__class__(torch.logical_not(i) for i in self)
390
- def logical_not_(self):
391
- for i in self: i.logical_not_()
392
- return self
393
-
394
- def equal(self, other: torch.Tensor | _TensorSequence):
395
- """returns TensorList of boolean values, True if two tensors have the same size and elements, False otherwise."""
396
- return self.zipmap(torch.equal, other)
397
-
398
- def add(self, other: _STOrSTSequence, alpha: _Scalar = 1):
399
- if alpha == 1: return self.__class__(torch._foreach_add(self, other))
400
- return self.__class__(torch._foreach_add(self, other, alpha = alpha)) # type:ignore
401
- def add_(self, other: _STOrSTSequence, alpha: _Scalar = 1):
402
- if alpha == 1: torch._foreach_add_(self, other)
403
- else: torch._foreach_add_(self, other, alpha = alpha) # type:ignore
404
- return self
405
-
406
-
407
- def sub(self, other: "_Scalar | _STSequence", alpha: _Scalar = 1):
408
- if alpha == 1: return self.__class__(torch._foreach_sub(self, other))
409
- return self.__class__(torch._foreach_sub(self, other, alpha = alpha)) # type:ignore
410
- def sub_(self, other: "_Scalar | _STSequence", alpha: _Scalar = 1):
411
- if alpha == 1: torch._foreach_sub_(self, other)
412
- else: torch._foreach_sub_(self, other, alpha = alpha) # type:ignore
413
- return self
414
-
415
- def neg(self): return self.__class__(torch._foreach_neg(self))
416
- def neg_(self):
417
- torch._foreach_neg_(self)
418
- return self
419
-
420
- def mul(self, other: _STOrSTSequence): return self.__class__(torch._foreach_mul(self, other))
421
- def mul_(self, other: _STOrSTSequence):
422
- torch._foreach_mul_(self, other)
423
- return self
424
-
425
- def div(self, other: _STOrSTSequence) -> Self: return self.__class__(torch._foreach_div(self, other))
426
- def div_(self, other: _STOrSTSequence):
427
- torch._foreach_div_(self, other)
428
- return self
429
-
430
- def pow(self, exponent: "_Scalar | _STSequence"): return self.__class__(torch._foreach_pow(self, exponent))
431
- def pow_(self, exponent: "_Scalar | _STSequence"):
432
- torch._foreach_pow_(self, exponent)
433
- return self
434
-
435
- def rpow(self, input: _Scalar | _TensorSequence): return self.__class__(torch._foreach_pow(input, self))
436
- def rpow_(self, input: _TensorSequence):
437
- torch._foreach_pow_(input, self)
438
- return self
439
-
440
- def sqrt(self): return self.__class__(torch._foreach_sqrt(self))
441
- def sqrt_(self):
442
- torch._foreach_sqrt_(self)
443
- return self
444
-
445
- def remainder(self, other: _STOrSTSequence): return self.zipmap(torch.remainder, other)
446
- def remainder_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('remainder_'), other)
447
-
448
- def floor_divide(self, other: _STOrSTSequence): return self.zipmap(torch.floor_divide, other)
449
- def floor_divide_(self, other: _STOrSTSequence): return self.zipmap_inplace_(_MethodCallerWithArgs('floor_divide_'), other)
450
-
451
- def reciprocal(self): return self.__class__(torch._foreach_reciprocal(self))
452
- def reciprocal_(self):
453
- torch._foreach_reciprocal_(self)
454
- return self
455
-
456
- def abs(self): return self.__class__(torch._foreach_abs(self))
457
- def abs_(self):
458
- torch._foreach_abs_(self)
459
- return self
460
-
461
- def sign(self): return self.__class__(torch._foreach_sign(self))
462
- def sign_(self):
463
- torch._foreach_sign_(self)
464
- return self
465
-
466
- def signbit(self): return self.__class__(torch.signbit(i) for i in self)
467
-
468
- def sin(self): return self.__class__(torch._foreach_sin(self))
469
- def sin_(self):
470
- torch._foreach_sin_(self)
471
- return self
472
-
473
- def cos(self): return self.__class__(torch._foreach_cos(self))
474
- def cos_(self):
475
- torch._foreach_cos_(self)
476
- return self
477
-
478
- def tan(self): return self.__class__(torch._foreach_tan(self))
479
- def tan_(self):
480
- torch._foreach_tan_(self)
481
- return self
482
-
483
- def asin(self): return self.__class__(torch._foreach_asin(self))
484
- def asin_(self):
485
- torch._foreach_asin_(self)
486
- return self
487
-
488
- def acos(self): return self.__class__(torch._foreach_acos(self))
489
- def acos_(self):
490
- torch._foreach_acos_(self)
491
- return self
492
-
493
- def atan(self): return self.__class__(torch._foreach_atan(self))
494
- def atan_(self):
495
- torch._foreach_atan_(self)
496
- return self
497
-
498
- def sinh(self): return self.__class__(torch._foreach_sinh(self))
499
- def sinh_(self):
500
- torch._foreach_sinh_(self)
501
- return self
502
-
503
- def cosh(self): return self.__class__(torch._foreach_cosh(self))
504
- def cosh_(self):
505
- torch._foreach_cosh_(self)
506
- return self
507
-
508
- def tanh(self): return self.__class__(torch._foreach_tanh(self))
509
- def tanh_(self):
510
- torch._foreach_tanh_(self)
511
- return self
512
-
513
- def log(self): return self.__class__(torch._foreach_log(self))
514
- def log_(self):
515
- torch._foreach_log_(self)
516
- return self
517
-
518
- def log10(self): return self.__class__(torch._foreach_log10(self))
519
- def log10_(self):
520
- torch._foreach_log10_(self)
521
- return self
522
-
523
- def log2(self): return self.__class__(torch._foreach_log2(self))
524
- def log2_(self):
525
- torch._foreach_log2_(self)
526
- return self
527
-
528
- def log1p(self): return self.__class__(torch._foreach_log1p(self))
529
- def log1p_(self):
530
- torch._foreach_log1p_(self)
531
- return self
532
-
533
- def erf(self): return self.__class__(torch._foreach_erf(self))
534
- def erf_(self):
535
- torch._foreach_erf_(self)
536
- return self
537
-
538
- def erfc(self): return self.__class__(torch._foreach_erfc(self))
539
- def erfc_(self):
540
- torch._foreach_erfc_(self)
541
- return self
542
-
543
- def max(self, dim = None, keepdim = False):
544
- if dim is None and not keepdim: return self.__class__(torch._foreach_max(self))
545
- return self.__class__(i.max(dim=dim, keepdim=keepdim) for i in self)
546
-
547
- def min(self, dim = None, keepdim = False):
548
- if dim is None and not keepdim: return self.__class__(torch._foreach_max(self.neg())).neg()
549
- return self.__class__(i.min(dim=dim, keepdim=keepdim) for i in self)
550
-
551
- def norm(self, ord: _Scalar, dtype=None):
552
- return self.__class__(torch._foreach_norm(self, ord, dtype))
553
-
554
- def mean(self, dim = None, keepdim = False): return self.__class__(i.mean(dim=dim, keepdim=keepdim) for i in self)
555
- def sum(self, dim = None, keepdim = False): return self.__class__(i.sum(dim=dim, keepdim=keepdim) for i in self)
556
- def prod(self, dim = None, keepdim = False): return self.__class__(i.prod(dim=dim, keepdim=keepdim) for i in self)
557
- def std(self, dim = None, keepdim = False): return self.__class__(i.std(dim=dim, keepdim=keepdim) for i in self)
558
-
559
- def clamp_min(self, other: "_Scalar | _STSequence"): return self.__class__(torch._foreach_clamp_min(self, other))
560
- def clamp_min_(self, other: "_Scalar | _STSequence"):
561
- torch._foreach_clamp_min_(self, other)
562
- return self
563
- def clamp_max(self, other: "_Scalar | _STSequence"): return self.__class__(torch._foreach_clamp_max(self, other))
564
- def clamp_max_(self, other: "_Scalar | _STSequence"):
565
- torch._foreach_clamp_max_(self, other)
566
- return self
567
-
568
- def clamp(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None):
569
- l = self
570
- if min is not None: l = l.clamp_min(min)
571
- if max is not None: l = l.clamp_max(max)
572
- return l
573
- def clamp_(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None):
574
- if min is not None: self.clamp_min_(min)
575
- if max is not None: self.clamp_max_(max)
576
- return self
577
-
578
- def clip(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None): return self.clamp(min,max)
579
- def clip_(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None): return self.clamp_(min,max)
580
-
581
- def clamp_magnitude(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None):
582
- return self.abs().clamp_(min, max) * self.sign().add_(0.5).sign_() # this prevents zeros
583
- def clamp_magnitude_(self, min: "_Scalar | _STSequence | None" = None, max: "_Scalar | _STSequence | None" = None):
584
- sign = self.sign().add_(0.5).sign_()
585
- return self.abs_().clamp_(min, max).mul_(sign)
586
-
587
-
588
- def floor(self): return self.__class__(torch._foreach_floor(self))
589
- def floor_(self):
590
- torch._foreach_floor_(self)
591
- return self
592
- def ceil(self): return self.__class__(torch._foreach_ceil(self))
593
- def ceil_(self):
594
- torch._foreach_ceil_(self)
595
- return self
596
- def round(self): return self.__class__(torch._foreach_round(self))
597
- def round_(self):
598
- torch._foreach_round_(self)
599
- return self
600
-
601
- def zero_(self):
602
- torch._foreach_zero_(self)
603
- return self
604
-
605
- def lerp(self, tensors1: _TensorSequence, weight: "_Scalar | _TensorSequence"):
606
- """linear interpolation of between self and tensors1. `out = self + weight * (tensors1 - self)`."""
607
- return self.__class__(torch._foreach_lerp(self, tensors1, weight))
608
- def lerp_(self, tensors1: _TensorSequence, weight: "_Scalar | _TensorSequence"):
609
- """linear interpolation of between self and tensors1. `out = self + weight * (tensors1 - self)`."""
610
- torch._foreach_lerp_(self, tensors1, weight)
611
- return self
612
-
613
- def lerp_compat(self, tensors1: _TensorSequence, weight: "_STOrSTSequence"):
614
- """`lerp` but supports python number sequence as weight and implemented through other operations
615
-
616
- `out = self + weight * (tensors1 - self)`."""
617
- return self + weight * (TensorList(tensors1) - self)
618
- def lerp_compat_(self, tensors1: _TensorSequence, weight: "_STOrSTSequence"):
619
- """`lerp_` but supports python number sequence as weight and implemented through other operations
620
-
621
- `out = self + weight * (tensors1 - self)`."""
622
- return self.add_(TensorList(tensors1).sub(self).mul_(weight))
623
-
624
- def addcmul(self, tensors1: _TensorSequence, tensor2: _TensorSequence, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
625
- return self.__class__(torch._foreach_addcmul(self, tensors1, tensor2, value))
626
- def addcmul_(self, tensors1: _TensorSequence, tensor2: _TensorSequence, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
627
- torch._foreach_addcmul_(self, tensors1, tensor2, value)
628
- return self
629
- def addcdiv(self, tensors1: _TensorSequence, tensor2: _TensorSequence, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
630
- return self.__class__(torch._foreach_addcdiv(self, tensors1, tensor2, value))
631
- def addcdiv_(self, tensors1: _TensorSequence, tensor2: _TensorSequence, value: "_Scalar | Sequence[_Scalar] | torch.Tensor" = 1):
632
- torch._foreach_addcdiv_(self, tensors1, tensor2, value)
633
- return self
634
-
635
- def uniform_(self, low: "_Scalar | _ScalarSequence" = 0, high: "_Scalar | _ScalarSequence" = 1, generator = None):
636
- return self.zipmap_args_inplace_(_MethodCallerWithArgs('uniform_'), low, high, generator = generator)
637
-
638
- def maximum(self, other: torch.Tensor | _TensorSequence): return self.zipmap(torch.maximum, other = other)
639
- def maximum_(self, other: torch.Tensor | _TensorSequence): return self.zipmap_inplace_(maximum_, other = other)
640
-
641
- def squeeze(self, dim = None):
642
- return self.__class__(i.squeeze(dim) for i in self)
643
-
644
- def squeeze_(self, dim = None):
645
- for i in self: i.squeeze_(dim)
646
- return self
647
-
648
- def conj(self): return self.__class__(i.conj() for i in self)
649
-
650
- def nan_to_num_(self,nan: float | None = None,posinf: float | None = None,neginf: float | None = None):
651
- for i in self: torch.nan_to_num_(i, nan = nan, posinf = posinf, neginf = neginf)
652
- return self
653
-
654
- def ravel(self): return self.__class__(i.ravel() for i in self)
655
-
656
- def any(self): return self.__class__(i.any() for i in self)
657
- def all(self): return self.__class__(i.all() for i in self)
658
- def isfinite(self): return self.__class__(i.isfinite() for i in self)
659
-
660
- def fill(self, value: _STOrSTSequence): return self.zipmap(torch.fill, other = value)
661
- def fill_(self, value: _STOrSTSequence): return self.zipmap_inplace_(torch.fill_, other = value)
662
-
663
- def copysign(self, other):
664
- return self.__class__(t.copysign(o) for t, o in zip(self, other))
665
- def copysign_(self, other):
666
- for t, o in zip(self, other): t.copysign_(o)
667
- return self
668
-
669
- def graft(self, other: "_TensorSequence", tensorwise=False, eps = 1e-6):
670
- if not isinstance(other, TensorList): other = TensorList(other)
671
- if tensorwise:
672
- norm_self = self.norm(2)
673
- norm_other = other.norm(2)
674
- else:
675
- norm_self = self.total_vector_norm(2)
676
- norm_other = other.total_vector_norm(2)
677
-
678
- return self * (norm_other / norm_self.clip_(min=eps)) # type:ignore
679
-
680
- def graft_(self, other: "_TensorSequence", tensorwise=False, eps = 1e-6):
681
- if not isinstance(other, TensorList): other = TensorList(other)
682
- if tensorwise:
683
- norm_self = self.norm(2)
684
- norm_other = other.norm(2)
685
- else:
686
- norm_self = self.total_vector_norm(2)
687
- norm_other = other.total_vector_norm(2)
688
-
689
- return self.mul_(norm_other / norm_self.clip_(min=eps)) # type:ignore
690
-
691
-
692
- def where(self, condition: "torch.Tensor | _TensorSequence", other: _STOrSTSequence):
693
- """self where condition is true other otherwise"""
694
- return self.zipmap_args(_MethodCallerWithArgs('where'), condition, other)
695
- def where_(self, condition: "torch.Tensor | _TensorSequence", other: "torch.Tensor | _TensorSequence"):
696
- """self where condition is true other otherwise"""
697
- return self.zipmap_args_inplace_(where_, condition, other)
698
-
699
- def masked_fill(self, mask: "torch.Tensor | _TensorSequence", fill_value: "_Scalar | _ScalarSequence"):
700
- """Same as tensor[mask] = value (not in-place), where value must be scalar/scalars"""
701
- return self.zipmap_args(torch.masked_fill, mask, fill_value)
702
- def masked_fill_(self, mask: "torch.Tensor | _TensorSequence", fill_value: "_Scalar | _ScalarSequence"):
703
- """Same as tensor[mask] = value, where value must be scalar/scalars"""
704
- return self.zipmap_args_inplace_(_MethodCallerWithArgs('masked_fill_'), mask, fill_value)
705
-
706
- def select_set_(self, mask: _TensorSequence, value: _STOrSTSequence):
707
- """Same as tensor[mask] = value"""
708
- if not isinstance(value, (list,tuple)): value = [value]*len(self) # type:ignore
709
- for tensor, m, v in zip(self, mask, value): # type:ignore
710
- print(tensor, m, v)
711
- tensor[m] = v
712
-
713
- def masked_set_(self, mask: _TensorSequence, value: _TensorSequence):
714
- """Same as tensor[mask] = value[mask]"""
715
- for tensor, m, v in zip(self, mask, value):
716
- tensor[m] = v[m]
717
-
718
- def select(self, idx: Any):
719
- """same as tensor[idx]"""
720
- if not isinstance(idx, (list,tuple)): return self.__class__(t[idx] for t in self)
721
- return self.__class__(t[i] for t,i in zip(self, idx))
722
-
723
- def swap_tensors(self, other: _TensorSequence):
724
- for s, o in zip(self, other):
725
- torch.utils.swap_tensors(s, o)
726
-
727
- def unbind_channels(self, dim=0):
728
- """returns a new tensorlist where tensors with 2 or more dimensions are split into slices along 1st dimension"""
729
- return self.__class__(ch for t in self for ch in (t.unbind(dim) if t.ndim >= 2 else (t,)) )
730
-
731
- def flatiter(self) -> Generator[torch.Tensor]:
732
- for tensor in self:
733
- yield from tensor.view(-1)
734
-
735
- def __repr__(self):
736
- return f"{self.__class__.__name__}({super().__repr__()})"
737
-
738
-
739
- def _alpha_add(x, other, alpha):
740
- return x + other * alpha
741
-
742
- class NumberList(TensorList):
743
- """TensorList subclass for python numbers.
744
- Note that this only supports basic arithmetic operations that are overloaded.
745
-
746
- Can't use a numpy array because _foreach methods do not work with it."""
747
- # remove torch.Tensor from return values
748
- def __getitem__(self, i) -> Any:
749
- return super().__getitem__(i)
750
-
751
- def __iter__(self) -> Iterator[Any]:
752
- return super().__iter__()
753
-
754
- def _set_to_method_result_(self, method: str, *args, **kwargs):
755
- """Sets each element of the tensorlist to the result of calling the specified method on the corresponding element.
756
- This is used to support/mimic in-place operations."""
757
- res = getattr(self, method)(*args, **kwargs)
758
- for i,v in enumerate(res): self[i] = v
759
- return self
760
-
761
- def add(self, other: _STOrSTSequence, alpha: _Scalar = 1):
762
- if alpha == 1: return self.zipmap(operator.add, other=other)
763
- return self.zipmap(_alpha_add, other=other, alpha = alpha)
764
- def add_(self, other: _STOrSTSequence, alpha: _Scalar = 1):
765
- raise ValueError('dont use in-place operations on NumberList')
766
- # return self._set_to_method_result_('add', other, alpha = alpha)
767
-
768
- def sub(self, other: "_Scalar | _STSequence", alpha: _Scalar = 1):
769
- if alpha == 1: return self.zipmap(operator.sub, other=other)
770
- return self.zipmap(_alpha_add, other=other, alpha = -alpha)
771
-
772
- def __rsub__(self, other: "_Scalar | _STSequence") -> Self:
773
- # avoids in-place neg
774
- return self.sub(other).neg()
775
-
776
- def sub_(self, other: "_Scalar | _STSequence", alpha: _Scalar = 1):
777
- raise ValueError('dont use in-place operations on NumberList')
778
- # return self._set_to_method_result_('sub', other, alpha = alpha)
779
-
780
- def neg(self): return self.__class__(-i for i in self)
781
- def neg_(self):
782
- raise ValueError('dont use in-place operations on NumberList')
783
- # return self._set_to_method_result_('neg')
784
-
785
- def mul(self, other: _STOrSTSequence): return self.zipmap(operator.mul, other=other)
786
- def mul_(self, other: _STOrSTSequence):
787
- raise ValueError('dont use in-place operations on NumberList')
788
- # return self._set_to_method_result_('mul', other)
789
-
790
- def div(self, other: _STOrSTSequence) -> Self: return self.zipmap(operator.truediv, other=other)
791
- def div_(self, other: _STOrSTSequence):
792
- raise ValueError('dont use in-place operations on NumberList')
793
- # return self._set_to_method_result_('div', other)
794
-
795
- def pow(self, exponent: "_Scalar | _STSequence"): return self.zipmap(math.pow, other=exponent)
796
- def pow_(self, exponent: "_Scalar | _STSequence"):
797
- raise ValueError('dont use in-place operations on NumberList')
798
- # return self._set_to_method_result_('pow_', exponent)
799
-
800
- def __rtruediv__(self, other: "_Scalar | _STSequence"):
801
- # overriding because TensorList implements this through reciprocal
802
- if isinstance(other, (tuple,list)): return self.__class__(o / i for i, o in zip(self, other))
803
- return self.__class__(other / i for i in self)
804
-
805
- def map(self, fn: Callable[..., Any], *args, **kwargs):
806
- """Applies `fn` to all elements of this NumberList
807
- and returns a new TensorList with return values of the callable."""
808
- return super().map(fn, *args, **kwargs)
809
-
810
- def stack(tensorlists: Iterable[TensorList], dim = 0):
811
- """Returns a tensorlist with the same elements as the input tensorlists, but stacked along the specified dimension."""
812
- return TensorList(torch.stack(i, dim = dim) for i in zip(*tensorlists))
813
-
814
- def mean(tensorlists: Iterable[TensorList]):
815
- """Returns a tensorlist which is the mean of given tensorlists."""
816
- return stack(tensorlists).mean(0)
817
-
818
- def sum(tensorlists: Iterable[TensorList]):
819
- """Returns a tensorlist which is the sum of given tensorlists."""
820
- return stack(tensorlists).sum(0)
821
-
822
-
823
- def where(condition: TensorList, input: _STOrSTSequence, other: _STOrSTSequence):
824
- """Where but for a tensorlist."""
825
- args = [i if isinstance(i, (list, tuple)) else [i]*len(condition) for i in (input, other)]
826
- return condition.__class__(torch.where(*z) for z in zip(condition, *args))
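For orientation, the deleted torchzero/tensorlist.py above was the 0.1.8 home of the TensorList container; in 0.3.1 an expanded version lives at torchzero/utils/tensorlist.py (see the file list). The sketch below is illustrative only, not part of either package: it assumes torchzero 0.1.8 is installed, uses a throwaway nn.Linear model, and exercises only methods visible in the deleted file (get_existing_grads, zeros_like, mul_, add_, sub_, to_vec, total_vector_norm).

import torch
from torch import nn
from torchzero.tensorlist import TensorList  # 0.1.8 import path; removed in 0.3.1

# throwaway model and loss, purely for illustration
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).square().mean()
loss.backward()

params = TensorList(model.parameters())   # wrap per-parameter tensors
grads = params.get_existing_grads()       # gradients that are not None
buf = grads.zeros_like()                  # per-parameter momentum buffers

with torch.no_grad():
    buf.mul_(0.9).add_(grads)             # _foreach-backed in-place ops
    params.sub_(buf, alpha=1e-2)          # SGD-with-momentum style update

vec = params.to_vec()                     # flatten all parameters into one 1D tensor
print(params.total_vector_norm())         # global L2 norm across all parameters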