torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
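For orientation, here is a minimal usage sketch of the TensorList container that the new test suite shown below exercises. The sketch is not part of the package diff; it only uses calls that appear in those tests:

    import torch
    from torchzero.utils.tensorlist import TensorList

    # a TensorList is a list of tensors with list-wise (foreach-style) operations
    params = TensorList([torch.randn(2, 3), torch.randn(5)])
    update = params.clone().mul_(0.5)            # in-place elementwise op on every tensor
    params -= update                             # arithmetic operators apply across the whole list
    vec = params.to_vec()                        # flatten all tensors into a single 1-D vector
    params.from_vec_(vec)                        # write a flat vector back into the original shapes
    assert params.global_numel() == vec.numel()  # total element count across the list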
tests/test_tensorlist.py
@@ -0,0 +1,1787 @@
1
+ # pylint: disable = redefined-outer-name, bad-indentation
2
+ """note that a lot of this is written by Gemini for a better coverage."""
3
+ import math
4
+
5
+ import pytest
6
+ import torch
7
+
8
+ from torchzero.utils.tensorlist import (
9
+ TensorList,
10
+ _MethodCallerWithArgs,
11
+ as_tensorlist,
12
+ generic_clamp,
13
+ mean,
14
+ median,
15
+ quantile,
16
+ stack,
17
+ )
18
+ from torchzero.utils.tensorlist import sum as tl_sum
19
+ from torchzero.utils.tensorlist import where as tl_where
20
+
21
+
22
+ def randmask_like(x,device=None,dtype=None):
23
+ return torch.rand_like(x.float(), device=device,dtype=dtype) > 0.5
24
+
25
+ # Helper function for comparing TensorLists element-wise
26
+ def assert_tl_equal(tl1: TensorList, tl2: TensorList):
27
+ assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
28
+ for t1, t2 in zip(tl1, tl2):
29
+ if t1 is None and t2 is None:
30
+ continue
31
+ assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
32
+ assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
33
+ assert torch.equal(t1, t2), f"Tensors are not equal:\n{t1}\nvs\n{t2}"
34
+
35
+ def assert_tl_allclose(tl1: TensorList, tl2: TensorList, **kwargs):
36
+ assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
37
+ for t1, t2 in zip(tl1, tl2):
38
+ if t1 is None and t2 is None:
39
+ continue
40
+ assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
41
+ assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
42
+ assert torch.allclose(t1, t2, equal_nan=True, **kwargs), f"Tensors are not close:\n{t1}\nvs\n{t2}"
43
+
44
+ # --- Fixtures ---
45
+
46
+ @pytest.fixture
47
+ def simple_tensors() -> list[torch.Tensor]:
48
+ return [torch.randn(2, 3), torch.randn(5), torch.randn(1, 4, 2)]
49
+
50
+
51
+ @pytest.fixture
52
+ def big_tensors() -> list[torch.Tensor]:
53
+ return [torch.randn(20, 50), torch.randn(3,3,3,3,3,3), torch.randn(1000)]
54
+
55
+
56
+ @pytest.fixture
57
+ def simple_tl(simple_tensors) -> TensorList:
58
+ return TensorList(simple_tensors)
59
+
60
+ @pytest.fixture
61
+ def simple_tl_clone(simple_tl) -> TensorList:
62
+ return simple_tl.clone()
63
+
64
+ @pytest.fixture
65
+ def big_tl(big_tensors) -> TensorList:
66
+ return TensorList(big_tensors)
67
+
68
+ @pytest.fixture
69
+ def grad_tensors() -> list[torch.Tensor]:
70
+ return [
71
+ torch.randn(2, 2, requires_grad=True),
72
+ torch.randn(3, requires_grad=False),
73
+ torch.randn(1, 5, requires_grad=True)
74
+ ]
75
+
76
+ @pytest.fixture
77
+ def grad_tl(grad_tensors) -> TensorList:
78
+ return TensorList(grad_tensors)
79
+
80
+ @pytest.fixture
81
+ def int_tensors() -> list[torch.Tensor]:
82
+ return [torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (5,))]
83
+
84
+ @pytest.fixture
85
+ def int_tl(int_tensors) -> TensorList:
86
+ return TensorList(int_tensors)
87
+
88
+ @pytest.fixture
89
+ def bool_tensors() -> list[torch.Tensor]:
90
+ return [torch.rand(2, 3) > 0.5, torch.rand(5) > 0.5]
91
+
92
+ @pytest.fixture
93
+ def bool_tl(bool_tensors) -> TensorList:
94
+ return TensorList(bool_tensors)
95
+
96
+ @pytest.fixture
97
+ def complex_tensors() -> list[torch.Tensor]:
98
+ return [torch.randn(2, 3, dtype=torch.complex64), torch.randn(5, dtype=torch.complex64)]
99
+
100
+ @pytest.fixture
101
+ def complex_tl(complex_tensors) -> TensorList:
102
+ return TensorList(complex_tensors)
103
+
104
+ # --- Test Cases ---
105
+
106
+ def test_initialization(simple_tensors):
107
+ tl = TensorList(simple_tensors)
108
+ assert isinstance(tl, TensorList)
109
+ assert isinstance(tl, list)
110
+ assert len(tl) == len(simple_tensors)
111
+ for i in range(len(tl)):
112
+ assert torch.equal(tl[i], simple_tensors[i])
113
+
114
+ def test_empty_initialization():
115
+ tl = TensorList()
116
+ assert isinstance(tl, TensorList)
117
+ assert len(tl) == 0
118
+
119
+ def test_as_tensorlist(simple_tensors, simple_tl: TensorList):
120
+ tl = as_tensorlist(simple_tensors)
121
+ assert isinstance(tl, TensorList)
122
+ assert_tl_equal(tl, simple_tl)
123
+
124
+ tl2 = as_tensorlist(simple_tl)
125
+ assert tl2 is simple_tl # Should return the same object if already TensorList
126
+
127
+ def test_complex_classmethod(simple_tensors):
128
+ real_tl = TensorList([t.float() for t in simple_tensors])
129
+ imag_tl = TensorList([torch.randn_like(t) for t in simple_tensors])
130
+ complex_tl = TensorList.complex(real_tl, imag_tl)
131
+
132
+ assert isinstance(complex_tl, TensorList)
133
+ assert len(complex_tl) == len(real_tl)
134
+ for i in range(len(complex_tl)):
135
+ assert complex_tl[i].dtype == torch.complex64 or complex_tl[i].dtype == torch.complex128
136
+ assert torch.equal(complex_tl[i].real, real_tl[i])
137
+ assert torch.equal(complex_tl[i].imag, imag_tl[i])
138
+
139
+ # --- Properties ---
140
+ def test_properties(simple_tl: TensorList, simple_tensors):
141
+ assert simple_tl.device == [t.device for t in simple_tensors]
142
+ assert simple_tl.dtype == [t.dtype for t in simple_tensors]
143
+ assert simple_tl.requires_grad == [t.requires_grad for t in simple_tensors]
144
+ assert simple_tl.shape == [t.shape for t in simple_tensors]
145
+ assert simple_tl.size() == [t.size() for t in simple_tensors]
146
+ assert simple_tl.size(0) == [t.size(0) for t in simple_tensors]
147
+ assert simple_tl.ndim == [t.ndim for t in simple_tensors]
148
+ assert simple_tl.ndimension() == [t.ndimension() for t in simple_tensors]
149
+ assert simple_tl.numel() == [t.numel() for t in simple_tensors]
150
+
151
+ def test_grad_property(grad_tl: TensorList, grad_tensors):
152
+ # Initially grads are None
153
+ assert all(g is None for g in grad_tl.grad)
154
+
155
+ # Set some grads
156
+ for i, t in enumerate(grad_tensors):
157
+ if t.requires_grad:
158
+ t.grad = torch.ones_like(t) * (i + 1)
159
+
160
+ grads = grad_tl.grad
161
+ assert isinstance(grads, TensorList)
162
+ assert len(grads) == len(grad_tl)
163
+ for i in range(len(grad_tl)):
164
+ if grad_tensors[i].requires_grad:
165
+ assert torch.equal(grads[i], torch.ones_like(grad_tensors[i]) * (i + 1))
166
+ else:
167
+ assert grads[i] is None # Accessing .grad on non-req-grad tensor returns None
168
+
169
+ def test_real_imag_properties(complex_tl: TensorList, complex_tensors):
170
+ real_part = complex_tl.real
171
+ imag_part = complex_tl.imag
172
+ assert isinstance(real_part, TensorList)
173
+ assert isinstance(imag_part, TensorList)
174
+ assert len(real_part) == len(complex_tl)
175
+ assert len(imag_part) == len(complex_tl)
176
+ for i in range(len(complex_tl)):
177
+ assert torch.equal(real_part[i], complex_tensors[i].real)
178
+ assert torch.equal(imag_part[i], complex_tensors[i].imag)
179
+
180
+ def test_view_as_real_complex(complex_tl: TensorList, complex_tensors):
181
+ real_view = complex_tl.view_as_real()
182
+ assert isinstance(real_view, TensorList)
183
+ assert len(real_view) == len(complex_tl)
184
+ for i in range(len(complex_tl)):
185
+ assert torch.equal(real_view[i], torch.view_as_real(complex_tensors[i]))
186
+
187
+ # Convert back
188
+ complex_view_again = real_view.view_as_complex()
189
+ assert_tl_equal(complex_view_again, complex_tl)
190
+
191
+ # --- Utility Methods ---
192
+
193
+ def test_type_as(simple_tl: TensorList, int_tl: TensorList):
194
+ int_casted_tl = simple_tl.type_as(int_tl[0]) # Cast like first int tensor
195
+ assert isinstance(int_casted_tl, TensorList)
196
+ assert all(t.dtype == int_tl[0].dtype for t in int_casted_tl)
197
+
198
+ float_casted_tl = int_tl.type_as(simple_tl) # Cast like corresponding float tensors
199
+ assert isinstance(float_casted_tl, TensorList)
200
+ assert all(t.dtype == s.dtype for t, s in zip(float_casted_tl, simple_tl))
201
+
202
+
203
+ def test_fill_none(simple_tl: TensorList):
204
+ tl_with_none = TensorList([simple_tl[0], None, simple_tl[2]])
205
+ reference_tl = simple_tl.clone() # Use original shapes as reference
206
+
207
+ filled_tl = tl_with_none.fill_none(reference_tl)
208
+ assert isinstance(filled_tl, TensorList)
209
+ assert filled_tl[0] is tl_with_none[0] # Should keep existing tensor
210
+ assert torch.equal(filled_tl[1], torch.zeros_like(reference_tl[1]))
211
+ assert filled_tl[2] is tl_with_none[2]
212
+ # Check original is not modified
213
+ assert tl_with_none[1] is None
214
+
215
+ filled_tl_inplace = tl_with_none.fill_none_(reference_tl)
216
+ assert filled_tl_inplace is tl_with_none # Should return self
217
+ assert filled_tl_inplace[0] is simple_tl[0]
218
+ assert torch.equal(filled_tl_inplace[1], torch.zeros_like(reference_tl[1]))
219
+ assert filled_tl_inplace[2] is simple_tl[2]
220
+
221
+
222
+ def test_get_grad(grad_tl: TensorList, grad_tensors):
223
+ # No grads initially
224
+ assert len(grad_tl.get_grad()) == 0
225
+
226
+ # Set grads only for tensors requiring grad
227
+ expected_grads = []
228
+ for i, t in enumerate(grad_tensors):
229
+ if t.requires_grad:
230
+ g = torch.rand_like(t)
231
+ t.grad = g
232
+ expected_grads.append(g)
233
+
234
+ retrieved_grads = grad_tl.get_grad()
235
+ assert isinstance(retrieved_grads, TensorList)
236
+ assert len(retrieved_grads) == len(expected_grads)
237
+ for rg, eg in zip(retrieved_grads, expected_grads):
238
+ assert torch.equal(rg, eg)
239
+
240
+
241
+ def test_with_requires_grad(grad_tl: TensorList, grad_tensors):
242
+ req_grad_true = grad_tl.with_requires_grad(True)
243
+ expected_true = [t for t in grad_tensors if t.requires_grad]
244
+ assert len(req_grad_true) == len(expected_true)
245
+ for rt, et in zip(req_grad_true, expected_true):
246
+ assert rt is et
247
+
248
+ req_grad_false = grad_tl.with_requires_grad(False)
249
+ expected_false = [t for t in grad_tensors if not t.requires_grad]
250
+ assert len(req_grad_false) == len(expected_false)
251
+ for rt, et in zip(req_grad_false, expected_false):
252
+ assert rt is et
253
+
254
+
255
+ def test_with_grad(grad_tl: TensorList, grad_tensors):
256
+ assert len(grad_tl.with_grad()) == 0 # No grads set yet
257
+
258
+ # Set grads for tensors requiring grad
259
+ expected_with_grad = []
260
+ for i, t in enumerate(grad_tensors):
261
+ if t.requires_grad:
262
+ t.grad = torch.ones_like(t) * i
263
+ expected_with_grad.append(t)
264
+
265
+ has_grad_tl = grad_tl.with_grad()
266
+ assert isinstance(has_grad_tl, TensorList)
267
+ assert len(has_grad_tl) == len(expected_with_grad)
268
+ for hg, eg in zip(has_grad_tl, expected_with_grad):
269
+ assert hg is eg
270
+
271
+
272
+ def test_ensure_grad_(grad_tl: TensorList, grad_tensors):
273
+ # Call ensure_grad_
274
+ grad_tl.ensure_grad_()
275
+
276
+ for t in grad_tl:
277
+ if t.requires_grad:
278
+ assert t.grad is not None
279
+ assert torch.equal(t.grad, torch.zeros_like(t))
280
+ else:
281
+ assert t.grad is None
282
+
283
+ # Call again, should not change existing zero grads
284
+ grad_tl.ensure_grad_()
285
+ for t in grad_tl:
286
+ if t.requires_grad:
287
+ assert t.grad is not None, 'this is a fixture'
288
+ assert torch.equal(t.grad, torch.zeros_like(t))
289
+
290
+
291
+ def test_accumulate_grad_(grad_tl: TensorList, grad_tensors):
292
+ new_grads = TensorList([torch.rand_like(t) for t in grad_tensors])
293
+ new_grads_copy = new_grads.clone()
294
+
295
+ # First accumulation (grads are None or zero if ensure_grad_ was called)
296
+ grad_tl.accumulate_grad_(new_grads)
297
+ for t, ng in zip(grad_tl, new_grads_copy):
298
+ # if t.requires_grad:
299
+ assert t.grad is not None
300
+ assert torch.equal(t.grad, ng)
301
+ # else:
302
+ # assert t.grad is None # Should not create grad if requires_grad is False
303
+
304
+ # Second accumulation
305
+ new_grads_2 = TensorList([torch.rand_like(t) for t in grad_tensors])
306
+ expected_grads = TensorList([g + ng2 for t, g, ng2 in zip(grad_tensors, grad_tl.grad, new_grads_2)])
307
+
308
+ grad_tl.accumulate_grad_(new_grads_2)
309
+ for t, eg in zip(grad_tl, expected_grads):
310
+ assert t.grad is not None
311
+ assert torch.allclose(t.grad, eg)
312
+ # else:
313
+ # assert t.grad is None
314
+
315
+
316
+ def test_set_grad_(grad_tl: TensorList, grad_tensors):
317
+ # Set initial grads
318
+ initial_grads = TensorList([torch.ones_like(t) if t.requires_grad else None for t in grad_tensors])
319
+ grad_tl.set_grad_(initial_grads)
320
+ for t, ig in zip(grad_tl, initial_grads):
321
+ assert t.grad is ig
322
+
323
+ # Set new grads
324
+ new_grads = TensorList([torch.rand_like(t) * 2 if t.requires_grad else None for t in grad_tensors])
325
+ grad_tl.set_grad_(new_grads)
326
+ for t, ng in zip(grad_tl, new_grads):
327
+ assert t.grad is ng # Checks object identity for None, value for Tensors
328
+
329
+
330
+ def test_zero_grad_(grad_tl: TensorList, grad_tensors):
331
+ # Set some grads
332
+ for t in grad_tl:
333
+ if t.requires_grad:
334
+ t.grad = torch.ones_like(t)
335
+
336
+ # Zero grads (set to None)
337
+ grad_tl.zero_grad_(set_to_none=True)
338
+ for t in grad_tl:
339
+ assert t.grad is None
340
+
341
+ # Set grads again
342
+ for t in grad_tl:
343
+ if t.requires_grad:
344
+ t.grad = torch.ones_like(t)
345
+
346
+ # Zero grads (set to zero)
347
+ grad_tl.zero_grad_(set_to_none=False)
348
+ for t in grad_tl:
349
+ if t.requires_grad:
350
+ assert t.grad is not None
351
+ assert torch.equal(t.grad, torch.zeros_like(t))
352
+ else:
353
+ assert t.grad is None # Should remain None if requires_grad is False
354
+
355
+
356
+ # --- Arithmetic Operators ---
357
+
358
+ @pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
359
+ @pytest.mark.parametrize("op, op_inplace, torch_op, foreach_op, foreach_op_inplace", [
360
+ ('__add__', '__iadd__', torch.add, torch._foreach_add, torch._foreach_add_),
361
+ ('__sub__', '__isub__', torch.sub, torch._foreach_sub, torch._foreach_sub_),
362
+ ('__mul__', '__imul__', torch.mul, torch._foreach_mul, torch._foreach_mul_),
363
+ ('__truediv__', '__itruediv__', torch.div, torch._foreach_div, torch._foreach_div_),
364
+ ])
365
+ def test_arithmetic_ops(simple_tl: TensorList, simple_tl_clone: TensorList, other_type, op, op_inplace, torch_op, foreach_op, foreach_op_inplace):
366
+ if other_type == "scalar":
367
+ other = 2.5
368
+ other_list = [other] * len(simple_tl)
369
+ elif other_type == "list_scalar":
370
+ other = [1.0, 2.0, 3.0]
371
+ other_list = other
372
+ elif other_type == "tensorlist":
373
+ other = simple_tl_clone.clone().mul_(0.5) # Create a compatible TensorList
374
+ other_list = other
375
+ elif other_type == "list_tensor":
376
+ other = [t * 0.5 for t in simple_tl_clone] # Create a compatible list of tensors
377
+ other_list = other
378
+ else:
379
+ pytest.fail("Unknown other_type")
380
+
381
+ # --- Test out-of-place ---
382
+ op_func = getattr(simple_tl, op)
383
+ result_tl = op_func(other)
384
+ expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
385
+
386
+ assert isinstance(result_tl, TensorList)
387
+ assert_tl_allclose(result_tl, expected_tl)
388
+ # Ensure original is unchanged
389
+ assert_tl_equal(simple_tl, simple_tl_clone)
390
+
391
+ # Test foreach version directly for comparison (if applicable)
392
+ if op != '__sub__' or other_type != 'scalar': # _foreach_sub doesn't support scalar 'other' directly
393
+ if hasattr(torch, foreach_op.__name__):
394
+ expected_foreach = TensorList(foreach_op(simple_tl, other_list))
395
+ assert_tl_allclose(result_tl, expected_foreach)
396
+
397
+ # --- Test in-place ---
398
+ tl_copy = simple_tl.clone()
399
+ op_inplace_func = getattr(tl_copy, op_inplace)
400
+ result_inplace = op_inplace_func(other)
401
+
402
+ assert result_inplace is tl_copy # Should return self
403
+ assert_tl_allclose(tl_copy, expected_tl)
404
+
405
+ # Test foreach_ inplace version directly
406
+ tl_copy_foreach = simple_tl.clone()
407
+ if op != '__sub__' or other_type != 'scalar': # _foreach_sub_ doesn't support scalar 'other' directly
408
+ if hasattr(torch, foreach_op_inplace.__name__):
409
+ foreach_op_inplace(tl_copy_foreach, other_list)
410
+ assert_tl_allclose(tl_copy_foreach, expected_tl)
411
+
412
+ # --- Test r-ops (if applicable) ---
413
+ if op in ['__add__', '__mul__']: # Commutative
414
+ rop_func = getattr(simple_tl, op.replace('__', '__r', 1))
415
+ result_rtl = rop_func(other)
416
+ assert_tl_allclose(result_rtl, expected_tl)
417
+ elif op == '__sub__': # Test rsub: other - self = -(self - other)
418
+ rop_func = getattr(simple_tl, '__rsub__')
419
+ result_rtl = rop_func(other)
420
+ expected_rtl = expected_tl.neg() # Note: self.sub(other).neg_() == other - self
421
+ assert_tl_allclose(result_rtl, expected_rtl)
422
+ elif op == '__truediv__': # Test rtruediv: other / self
423
+ if other_type in ["scalar", "list_scalar"]: # scalar / tensor or list<scalar> / list<tensor>
424
+ rop_func = getattr(simple_tl, '__rtruediv__')
425
+ result_rtl = rop_func(other)
426
+ expected_rtl = TensorList([o / t for t, o in zip(simple_tl, other_list)])
427
+ assert_tl_allclose(result_rtl, expected_rtl)
428
+ # rtruediv for tensorlist/list_tensor is not implemented directly
429
+
430
+
431
+ @pytest.mark.parametrize("op, torch_op", [
432
+ ('__pow__', torch.pow),
433
+ ('__floordiv__', torch.floor_divide),
434
+ ('__mod__', torch.remainder),
435
+ ])
436
+ @pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
437
+ def test_other_arithmetic_ops(simple_tl: TensorList, simple_tl_clone: TensorList, op, torch_op, other_type):
438
+ is_pow = op == '__pow__'
439
+ if other_type == "scalar":
440
+ other = 2 if is_pow else 2.5
441
+ other_list = [other] * len(simple_tl)
442
+ elif other_type == "list_scalar":
443
+ other = [2, 1, 3] if is_pow else [1.5, 2.5, 3.5]
444
+ other_list = other
445
+ elif other_type == "tensorlist":
446
+ other = simple_tl_clone.clone().abs_().add_(1).clamp_(max=3) if is_pow else simple_tl_clone.clone().mul_(0.5).add_(1)
447
+ other_list = other
448
+ elif other_type == "list_tensor":
449
+ other = [(t.abs() + 1).clamp(max=3) if is_pow else (t*0.5 + 1) for t in simple_tl_clone]
450
+ other_list = other
451
+ else:
452
+ pytest.fail("Unknown other_type")
453
+
454
+ # Test out-of-place
455
+ op_func = getattr(simple_tl, op)
456
+ result_tl = op_func(other)
457
+ expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
458
+
459
+ assert isinstance(result_tl, TensorList)
460
+ assert_tl_allclose(result_tl, expected_tl)
461
+ assert_tl_equal(simple_tl, simple_tl_clone) # Ensure original unchanged
462
+
463
+ # Test in-place (if exists)
464
+ op_inplace = op.replace('__', '__i', 1) + '_' # Standard naming convention adopted
465
+ if hasattr(simple_tl, op_inplace):
466
+ tl_copy = simple_tl.clone()
467
+ op_inplace_func = getattr(tl_copy, op_inplace)
468
+ result_inplace = op_inplace_func(other)
469
+ assert result_inplace is tl_copy
470
+ assert_tl_allclose(tl_copy, expected_tl)
471
+
472
+ # Test rpow
473
+ if op == '__pow__' and other_type in ['scalar']:#, 'list_scalar']: # _foreach_pow doesn't support list of scalars as base
474
+ rop_func = getattr(simple_tl, '__rpow__')
475
+ result_rtl = rop_func(other)
476
+ expected_rtl = TensorList([torch_op(o, t) for t, o in zip(simple_tl, other_list)])
477
+ assert_tl_allclose(result_rtl, expected_rtl)
478
+
479
+
480
+ def test_negation(simple_tl: TensorList, simple_tl_clone):
481
+ neg_tl = -simple_tl
482
+ expected_tl = TensorList([-t for t in simple_tl])
483
+ assert_tl_allclose(neg_tl, expected_tl)
484
+ assert_tl_equal(simple_tl, simple_tl_clone) # Ensure original unchanged
485
+
486
+ neg_tl_inplace = simple_tl.neg_()
487
+ assert neg_tl_inplace is simple_tl
488
+ assert_tl_allclose(simple_tl, expected_tl)
489
+
490
+ # --- Comparison Operators ---
491
+
492
+ @pytest.mark.parametrize("op, torch_op", [
493
+ ('__eq__', torch.eq),
494
+ ('__ne__', torch.ne),
495
+ ('__lt__', torch.lt),
496
+ ('__le__', torch.le),
497
+ ('__gt__', torch.gt),
498
+ ('__ge__', torch.ge),
499
+ ])
500
+ @pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
501
+ def test_comparison_ops(simple_tl: TensorList, op, torch_op, other_type):
502
+ if other_type == "scalar":
503
+ other = 0.0
504
+ other_list = [other] * len(simple_tl)
505
+ elif other_type == "list_scalar":
506
+ other = [-0.5, 0.0, 0.5]
507
+ other_list = other
508
+ elif other_type == "tensorlist":
509
+ other = simple_tl.clone().mul_(0.9)
510
+ other_list = other
511
+ elif other_type == "list_tensor":
512
+ other = [t * 0.9 for t in simple_tl]
513
+ other_list = other
514
+ else:
515
+ pytest.fail("Unknown other_type")
516
+
517
+ op_func = getattr(simple_tl, op)
518
+ result_tl = op_func(other)
519
+ expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
520
+
521
+ assert isinstance(result_tl, TensorList)
522
+ assert all(t.dtype == torch.bool for t in result_tl)
523
+ assert_tl_equal(result_tl, expected_tl)
524
+
525
+
526
+ # --- Logical Operators ---
527
+
528
+ @pytest.mark.parametrize("op, op_inplace, torch_op", [
529
+ ('__and__', '__iand__', torch.logical_and),
530
+ ('__or__', '__ior__', torch.logical_or),
531
+ ('__xor__', '__ixor__', torch.logical_xor),
532
+ ])
533
+ def test_logical_binary_ops(bool_tl: TensorList, op, op_inplace, torch_op):
534
+ other_tl = TensorList([randmask_like(t) for t in bool_tl])
535
+ other_list = list(other_tl) # Use list version for comparison
536
+
537
+ # Out-of-place
538
+ op_func = getattr(bool_tl, op)
539
+ result_tl = op_func(other_tl)
540
+ expected_tl = TensorList([torch_op(t, o) for t, o in zip(bool_tl, other_list)])
541
+
542
+ assert isinstance(result_tl, TensorList)
543
+ assert all(t.dtype == torch.bool for t in result_tl)
544
+ assert_tl_equal(result_tl, expected_tl)
545
+
546
+ # In-place
547
+ tl_copy = bool_tl.clone()
548
+ op_inplace_func = getattr(tl_copy, op_inplace) # Naming convention with _
549
+ result_inplace = op_inplace_func(other_tl)
550
+ assert result_inplace is tl_copy
551
+ assert_tl_equal(tl_copy, expected_tl)
552
+
553
+
554
+ def test_logical_not(bool_tl: TensorList):
555
+ # Out-of-place (~ operator maps to logical_not)
556
+ not_tl = ~bool_tl
557
+ expected_tl = TensorList([torch.logical_not(t) for t in bool_tl])
558
+ assert isinstance(not_tl, TensorList)
559
+ assert all(t.dtype == torch.bool for t in not_tl)
560
+ assert_tl_equal(not_tl, expected_tl)
561
+
562
+ # In-place
563
+ tl_copy = bool_tl.clone()
564
+ result_inplace = tl_copy.logical_not_()
565
+ assert result_inplace is tl_copy
566
+ assert_tl_equal(tl_copy, expected_tl)
567
+
568
+
569
+ def test_bool_raises(simple_tl: TensorList):
570
+ with pytest.raises(RuntimeError, match="Boolean value of TensorList is ambiguous"):
571
+ bool(simple_tl)
572
+ # Test with empty list
573
+ with pytest.raises(RuntimeError, match="Boolean value of TensorList is ambiguous"):
574
+ bool(TensorList())
575
+
576
+ # --- Map / Zipmap / Filter ---
577
+
578
+ def test_map(simple_tl: TensorList):
579
+ mapped_tl = simple_tl.map(torch.abs)
580
+ expected_tl = TensorList([torch.abs(t) for t in simple_tl])
581
+ assert_tl_allclose(mapped_tl, expected_tl)
582
+
583
+ def test_map_inplace_(simple_tl: TensorList):
584
+ tl_copy = simple_tl.clone()
585
+ result = tl_copy.map_inplace_(torch.abs_)
586
+ expected_tl = TensorList([torch.abs(t) for t in simple_tl]) # Calculate expected from original
587
+ assert result is tl_copy
588
+ assert_tl_allclose(tl_copy, expected_tl)
589
+
590
+ def test_filter(simple_tl: TensorList):
591
+ # Filter tensors with more than 5 elements
592
+ filtered_tl = simple_tl.filter(lambda t: t.numel() > 5)
593
+ expected_tl = TensorList([t for t in simple_tl if t.numel() > 5])
594
+ assert len(filtered_tl) == len(expected_tl)
595
+ for ft, et in zip(filtered_tl, expected_tl):
596
+ assert ft is et # Should contain the original tensor objects
597
+
598
+ def test_zipmap(simple_tl: TensorList):
599
+ # Zipmap with another TensorList
600
+ other_tl = simple_tl.clone().mul_(0.5)
601
+ result_tl = simple_tl.zipmap(torch.add, other_tl)
602
+ expected_tl = TensorList([torch.add(t, o) for t, o in zip(simple_tl, other_tl)])
603
+ assert_tl_allclose(result_tl, expected_tl)
604
+
605
+ # Zipmap with a list of tensors
606
+ other_list = [t * 0.5 for t in simple_tl]
607
+ result_tl_list = simple_tl.zipmap(torch.add, other_list)
608
+ assert_tl_allclose(result_tl_list, expected_tl)
609
+
610
+ # Zipmap with a scalar
611
+ result_tl_scalar = simple_tl.zipmap(torch.add, 2.0)
612
+ expected_tl_scalar = TensorList([torch.add(t, 2.0) for t in simple_tl])
613
+ assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
614
+
615
+ # Zipmap with a list of scalars
616
+ other_scalars = [1.0, 2.0, 3.0]
617
+ result_tl_scalars = simple_tl.zipmap(torch.add, other_scalars)
618
+ expected_tl_scalars = TensorList([torch.add(t, s) for t, s in zip(simple_tl, other_scalars)])
619
+ assert_tl_allclose(result_tl_scalars, expected_tl_scalars)
620
+
621
+
622
+ def test_zipmap_inplace_(simple_tl: TensorList):
623
+ # Zipmap inplace with another TensorList
624
+ tl_copy = simple_tl.clone()
625
+ other_tl = simple_tl.clone().mul_(0.5)
626
+ result = tl_copy.zipmap_inplace_(_MethodCallerWithArgs('add_'), other_tl)
627
+ expected_tl = TensorList([torch.add(t, o) for t, o in zip(simple_tl, other_tl)])
628
+ assert result is tl_copy
629
+ assert_tl_allclose(tl_copy, expected_tl)
630
+
631
+ # Zipmap inplace with a scalar
632
+ tl_copy_scalar = simple_tl.clone()
633
+ result_scalar = tl_copy_scalar.zipmap_inplace_(_MethodCallerWithArgs('add_'), 2.0)
634
+ expected_tl_scalar = TensorList([torch.add(t, 2.0) for t in simple_tl])
635
+ assert result_scalar is tl_copy_scalar
636
+ assert_tl_allclose(tl_copy_scalar, expected_tl_scalar)
637
+
638
+ # Zipmap inplace with list of scalars
639
+ tl_copy_scalars = simple_tl.clone()
640
+ other_scalars = [1.0, 2.0, 3.0]
641
+ result_scalars = tl_copy_scalars.zipmap_inplace_(_MethodCallerWithArgs('add_'), other_scalars)
642
+ expected_tl_scalars = TensorList([torch.add(t, s) for t, s in zip(simple_tl, other_scalars)])
643
+ assert result_scalars is tl_copy_scalars
644
+ assert_tl_allclose(tl_copy_scalars, expected_tl_scalars)
645
+
646
+
647
+ def test_zipmap_args(simple_tl: TensorList):
648
+ other1 = simple_tl.clone().mul(0.5)
649
+ other2 = 2.0
650
+ other3 = [1, 2, 3]
651
+ # Test torch.lerp(input, end, weight) -> input + weight * (end - input)
652
+ # self = input, other1 = end, other2 = weight (scalar)
653
+ result_tl = simple_tl.zipmap_args(torch.lerp, other1, other2)
654
+ expected_tl = TensorList([torch.lerp(t, o1, other2) for t, o1 in zip(simple_tl, other1)])
655
+ assert_tl_allclose(result_tl, expected_tl)
656
+
657
+ # self = input, other1 = end, other3 = weight (list scalar)
658
+ result_tl_list = simple_tl.zipmap_args(torch.lerp, other1, other3)
659
+ expected_tl_list = TensorList([torch.lerp(t, o1, o3) for t, o1, o3 in zip(simple_tl, other1, other3)])
660
+ assert_tl_allclose(result_tl_list, expected_tl_list)
661
+
662
+ def test_zipmap_args_inplace_(simple_tl: TensorList):
663
+ tl_copy = simple_tl.clone()
664
+ other1 = simple_tl.clone().mul(0.5)
665
+ other2 = 0.5
666
+ # Test torch.addcmul_(tensor1, tensor2, value=1) -> self + value * tensor1 * tensor2
667
+ # self = self, other1 = tensor1, other1 (again) = tensor2, other2 = value
668
+ result_tl = tl_copy.zipmap_args_inplace_(_MethodCallerWithArgs('addcmul_'), other1, other1, value=other2)
669
+ expected_tl = TensorList([t.addcmul(o1, o1, value=other2) for t, o1 in zip(simple_tl.clone(), other1)]) # Need clone for calculation
670
+ assert result_tl is tl_copy
671
+ assert_tl_allclose(tl_copy, expected_tl)
672
+
673
+ # --- Tensor Method Wrappers ---
674
+
675
+ @pytest.mark.parametrize("method_name, args", [
676
+ ('clone', ()),
677
+ ('detach', ()),
678
+ ('contiguous', ()),
679
+ ('cpu', ()),
680
+ ('long', ()),
681
+ ('short', ()),
682
+ ('as_float', ()),
683
+ ('as_int', ()),
684
+ ('as_bool', ()),
685
+ ('sqrt', ()),
686
+ ('exp', ()),
687
+ ('log', ()),
688
+ ('sin', ()),
689
+ ('cos', ()),
690
+ ('abs', ()),
691
+ ('neg', ()),
692
+ ('reciprocal', ()),
693
+ ('sign', ()),
694
+ ('round', ()),
695
+ ('floor', ()),
696
+ ('ceil', ()),
697
+ ('logical_not', ()), # Assuming input is boolean for this test
698
+ ('ravel', ()),
699
+ ('view_flat', ()),
700
+ ('conj', ()), # Assuming input is complex for this test
701
+ ('squeeze', ()),
702
+ ('squeeze', (0,)), # Example with args
703
+ # Add more simple unary methods here...
704
+ ])
705
+ def test_simple_unary_methods(simple_tl: TensorList, method_name, args):
706
+ tl_to_test = simple_tl
707
+ if method_name == 'logical_not':
708
+ tl_to_test = simple_tl.gt(0) # Create a boolean TL
709
+ elif method_name == 'conj':
710
+ tl_to_test = TensorList.complex(simple_tl, simple_tl) # Create complex
711
+
712
+ method = getattr(tl_to_test, method_name)
713
+ result_tl = method(*args)
714
+
715
+ method_names_map = {"as_float": "float", "as_int": "int", "as_bool": "bool", "view_flat": "ravel"}
716
+ tensor_method_name = method_name
717
+ if tensor_method_name in method_names_map: tensor_method_name = method_names_map[tensor_method_name]
718
+ expected_tl = TensorList([getattr(t, tensor_method_name)(*args) for t in tl_to_test])
719
+
720
+ assert isinstance(result_tl, TensorList)
721
+ # Need allclose for float results, equal for others
722
+ if any(t.is_floating_point() for t in expected_tl):
723
+ assert_tl_allclose(result_tl, expected_tl)
724
+ else:
725
+ assert_tl_equal(result_tl, expected_tl)
726
+
727
+ # Test inplace if available
728
+ method_inplace_name = method_name + '_'
729
+ if hasattr(tl_to_test, method_inplace_name) and hasattr(torch.Tensor, method_inplace_name):
730
+ tl_copy = tl_to_test.clone()
731
+ method_inplace = getattr(tl_copy, method_inplace_name)
732
+ result_inplace = method_inplace(*args)
733
+ assert result_inplace is tl_copy
734
+ if any(t.is_floating_point() for t in expected_tl):
735
+ assert_tl_allclose(tl_copy, expected_tl)
736
+ else:
737
+ assert_tl_equal(tl_copy, expected_tl)
738
+
739
+ def test_to(simple_tl: TensorList):
740
+ # Test changing dtype
741
+ float_tl = simple_tl.to(dtype=torch.float64)
742
+ assert all(t.dtype == torch.float64 for t in float_tl)
743
+
744
+ # Test changing device (if multiple devices available)
745
+ if torch.cuda.is_available():
746
+ cuda_tl = simple_tl.to(device='cuda')
747
+ assert all(t.device.type == 'cuda' for t in cuda_tl)
748
+ cpu_tl = cuda_tl.to('cpu')
749
+ assert all(t.device.type == 'cpu' for t in cpu_tl)
750
+
751
+ def test_copy_(simple_tl: TensorList):
752
+ src_tl = TensorList([torch.randn_like(t) for t in simple_tl])
753
+ tl_copy = simple_tl.clone()
754
+ tl_copy.copy_(src_tl)
755
+ assert_tl_equal(tl_copy, src_tl)
756
+ # Ensure src is unchanged
757
+ assert not torch.equal(src_tl[0], simple_tl[0]) # Verify src was different
758
+
759
+ def test_set_(simple_tl: TensorList):
760
+ src_tl = TensorList([torch.randn_like(t) for t in simple_tl])
761
+ tl_copy = simple_tl.clone()
762
+ tl_copy.set_(src_tl) # src_tl provides the storage/tensors
763
+ assert_tl_equal(tl_copy, src_tl)
764
+ # Note: set_ might have side effects on src_tl depending on PyTorch version/tensor types
765
+
766
+ def test_requires_grad_(grad_tl: TensorList):
767
+ grad_tl.requires_grad_(False)
768
+ assert grad_tl.requires_grad == [False] * len(grad_tl)
769
+ grad_tl.requires_grad_(True)
770
+ # This sets requires_grad=True for ALL tensors, unlike the initial fixture
771
+ assert grad_tl.requires_grad == [True] * len(grad_tl)
772
+
773
+
774
+ # --- Vectorization ---
775
+
776
+ def test_to_vec(simple_tl: TensorList):
777
+ vec = simple_tl.to_vec()
778
+ expected_vec = torch.cat([t.view(-1) for t in simple_tl])
779
+ assert torch.equal(vec, expected_vec)
780
+
781
+ def test_from_vec_(simple_tl: TensorList):
782
+ tl_copy = simple_tl.clone()
783
+ numel = simple_tl.global_numel()
784
+ new_vec = torch.arange(numel, dtype=simple_tl[0].dtype).float() # Use float for generality
785
+
786
+ result = tl_copy.from_vec_(new_vec)
787
+ assert result is tl_copy
788
+
789
+ current_pos = 0
790
+ for t_orig, t_modified in zip(simple_tl, tl_copy):
791
+ n = t_orig.numel()
792
+ expected_tensor = new_vec[current_pos : current_pos + n].view_as(t_orig)
793
+ assert torch.equal(t_modified, expected_tensor)
794
+ current_pos += n
795
+
796
+ def test_from_vec(simple_tl: TensorList):
797
+ tl_clone = simple_tl.clone() # Keep original safe
798
+ numel = simple_tl.global_numel()
799
+ new_vec = torch.arange(numel, dtype=simple_tl[0].dtype).float()
800
+
801
+ new_tl = simple_tl.from_vec(new_vec)
802
+ assert isinstance(new_tl, TensorList)
803
+ assert_tl_equal(simple_tl, tl_clone) # Original unchanged
804
+
805
+ current_pos = 0
806
+ for t_orig, t_new in zip(simple_tl, new_tl):
807
+ n = t_orig.numel()
808
+ expected_tensor = new_vec[current_pos : current_pos + n].view_as(t_orig)
809
+ assert torch.equal(t_new, expected_tensor)
810
+ current_pos += n
811
+
812
+
813
+ # --- Global Reductions ---
814
+
815
+ @pytest.mark.parametrize("global_method, vec_equiv_method", [
816
+ ('global_min', 'min'),
817
+ ('global_max', 'max'),
818
+ ('global_sum', 'sum'),
819
+ ('global_mean', 'mean'),
820
+ ('global_std', 'std'),
821
+ ('global_var', 'var'),
822
+ ('global_any', 'any'),
823
+ ('global_all', 'all'),
824
+ ])
825
+ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_method):
826
+ tl_to_test = simple_tl
827
+ if 'any' in global_method or 'all' in global_method:
828
+ tl_to_test = simple_tl.gt(0) # Need boolean input
829
+
830
+ global_method_func = getattr(tl_to_test, global_method)
831
+ result = global_method_func()
832
+
833
+ vec = tl_to_test.to_vec()
834
+ vec_equiv_func = getattr(vec, vec_equiv_method)
835
+ expected = vec_equiv_func()
836
+
837
+ if isinstance(result, bool): assert result == expected
838
+ else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
839
+
840
+
841
+ def test_global_vector_norm(simple_tl: TensorList):
842
+ ord = 1.5
843
+ result = simple_tl.global_vector_norm(ord=ord)
844
+ vec = simple_tl.to_vec()
845
+ expected = torch.linalg.vector_norm(vec, ord=ord) # pylint:disable=not-callable
846
+ assert torch.allclose(result, expected)
847
+
848
+ def test_global_numel(simple_tl: TensorList):
849
+ result = simple_tl.global_numel()
850
+ expected = sum(t.numel() for t in simple_tl)
851
+ assert result == expected
852
+
853
+ # --- Like Creation Methods ---
854
+
855
+ @pytest.mark.parametrize("like_method, torch_equiv", [
856
+ ('empty_like', torch.empty_like),
857
+ ('zeros_like', torch.zeros_like),
858
+ ('ones_like', torch.ones_like),
859
+ ('rand_like', torch.rand_like),
860
+ ('randn_like', torch.randn_like),
861
+ ])
862
+ def test_simple_like_methods(simple_tl: TensorList, like_method, torch_equiv):
863
+ like_method_func = getattr(simple_tl, like_method)
864
+ result_tl = like_method_func()
865
+
866
+ assert isinstance(result_tl, TensorList)
867
+ assert len(result_tl) == len(simple_tl)
868
+ for res_t, orig_t in zip(result_tl, simple_tl):
869
+ assert res_t.shape == orig_t.shape
870
+ assert res_t.dtype == orig_t.dtype
871
+ assert res_t.device == orig_t.device
872
+ # Cannot easily check values for rand/randn/empty
873
+
874
+ # Test with kwargs (e.g., changing dtype)
875
+ result_tl_kw = like_method_func(dtype=torch.float64)
876
+ assert all(t.dtype == torch.float64 for t in result_tl_kw)
877
+
878
+
879
+ def test_full_like(simple_tl: TensorList):
880
+ # Scalar fill_value
881
+ fill_value_scalar = 5.0
882
+ result_tl_scalar = simple_tl.full_like(fill_value_scalar)
883
+ expected_tl_scalar = TensorList([torch.full_like(t, fill_value_scalar) for t in simple_tl])
884
+ assert_tl_equal(result_tl_scalar, expected_tl_scalar)
885
+
886
+ # List fill_value
887
+ fill_value_list = [1.0, 2.0, 3.0]
888
+ result_tl_list = simple_tl.full_like(fill_value_list)
889
+ expected_tl_list = TensorList([torch.full_like(t, fv) for t, fv in zip(simple_tl, fill_value_list)])
890
+ assert_tl_equal(result_tl_list, expected_tl_list)
891
+
892
+ # Test with kwargs
893
+ result_tl_kw = simple_tl.full_like(fill_value_scalar, dtype=torch.int)
894
+ assert all(t.dtype == torch.int for t in result_tl_kw)
895
+ assert all(torch.all(t == int(fill_value_scalar)) for t in result_tl_kw)
896
+
897
+
898
+ def test_randint_like(simple_tl: TensorList):
899
+ low = 0
900
+ high = 10
901
+ # Scalar low/high
902
+ result_tl_scalar = simple_tl.randint_like(low, high)
903
+ assert isinstance(result_tl_scalar, TensorList)
904
+ assert all(t.dtype == simple_tl[0].dtype for t in result_tl_scalar) # Default dtype
905
+ assert all(torch.all((t >= low) & (t < high)) for t in result_tl_scalar)
906
+ assert result_tl_scalar.shape == simple_tl.shape
907
+
908
+ # List low/high
909
+ low_list = [0, 5, 2]
910
+ high_list = [5, 15, 7]
911
+ result_tl_list = simple_tl.randint_like(low_list, high_list)
912
+ assert isinstance(result_tl_list, TensorList)
913
+ assert all(t.dtype == simple_tl[0].dtype for t in result_tl_list)
914
+ assert all(torch.all((t >= l) & (t < h)) for t, l, h in zip(result_tl_list, low_list, high_list))
915
+ assert result_tl_list.shape == simple_tl.shape
916
+
917
+
918
+ def test_uniform_like(simple_tl: TensorList):
919
+ # Default range (0, 1)
920
+ result_tl_default = simple_tl.uniform_like()
921
+ assert isinstance(result_tl_default, TensorList)
922
+ assert result_tl_default.shape == simple_tl.shape
923
+ assert all(t.dtype == simple_tl[i].dtype for i, t in enumerate(result_tl_default))
924
+ assert all(torch.all((t >= 0) & (t <= 1)) for t in result_tl_default) # Check range roughly
925
+
926
+ # Scalar low/high
927
+ low, high = -1.0, 1.0
928
+ result_tl_scalar = simple_tl.uniform_like(low, high)
929
+ assert all(torch.all((t >= low) & (t <= high)) for t in result_tl_scalar)
930
+
931
+ # List low/high
932
+ low_list = [-1, 0, -2]
933
+ high_list = [0, 1, -1]
934
+ result_tl_list = simple_tl.uniform_like(low_list, high_list)
935
+ assert all(torch.all((t >= l) & (t <= h)) for t, l, h in zip(result_tl_list, low_list, high_list))
936
+
937
+
938
+ def test_sphere_like(simple_tl: TensorList):
939
+ radius = 5.0
940
+ result_tl_scalar = simple_tl.sphere_like(radius)
941
+ assert isinstance(result_tl_scalar, TensorList)
942
+ assert result_tl_scalar.shape == simple_tl.shape
943
+ assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(radius))
944
+
945
+ radius_list = [1.0, 10.0, 2.0]
946
+ result_tl_list = simple_tl.sphere_like(radius_list)
947
+ # Cannot easily check norm with list radius, just check type/shape
948
+ assert isinstance(result_tl_list, TensorList)
949
+ assert result_tl_list.shape == simple_tl.shape
950
+
951
+
952
+ def test_bernoulli_like(big_tl: TensorList):
953
+ p_scalar = 0.7
954
+ result_tl_scalar = big_tl.bernoulli_like(p_scalar)
955
+ assert isinstance(result_tl_scalar, TensorList)
956
+ assert result_tl_scalar.shape == big_tl.shape
957
+ assert all(t.dtype == big_tl[i].dtype for i, t in enumerate(result_tl_scalar)) # Should preserve dtype
958
+ assert all(torch.all((t == 0) | (t == 1)) for t in result_tl_scalar)
959
+ # Check mean is approximately p
960
+ assert abs(result_tl_scalar.to_vec().float().mean().item() - p_scalar) < 0.1 # Loose check
961
+
962
+ p_list = [0.2, 0.5, 0.8]
963
+ result_tl_list = big_tl.bernoulli_like(p_list)
964
+ assert isinstance(result_tl_list, TensorList)
965
+ assert result_tl_list.shape == big_tl.shape
966
+
967
+
968
+ def test_rademacher_like(big_tl: TensorList):
969
+ result_tl = big_tl.rademacher_like() # p=0.5 default
970
+ assert isinstance(result_tl, TensorList)
971
+ assert result_tl.shape == big_tl.shape
972
+ assert all(torch.all((t == -1) | (t == 1)) for t in result_tl)
973
+
974
+ # Check mean is approx 0
975
+ assert abs(result_tl.to_vec().float().mean().item()) < 0.1 # Loose check
976
+
977
+
978
+ @pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
979
+ def test_sample_like(simple_tl: TensorList, dist):
980
+ eps_scalar = 2.0
981
+ result_tl_scalar = simple_tl.sample_like(eps_scalar, distribution=dist)
982
+ assert isinstance(result_tl_scalar, TensorList)
983
+ assert result_tl_scalar.shape == simple_tl.shape
984
+
985
+ eps_list = [0.5, 1.0, 1.5]
986
+ result_tl_list = simple_tl.sample_like(eps_list, distribution=dist)
987
+ assert isinstance(result_tl_list, TensorList)
988
+ assert result_tl_list.shape == simple_tl.shape
989
+
990
+ # Basic checks based on distribution
991
+ if dist == 'uniform':
992
+ assert all(torch.all((t >= -eps_scalar/2) & (t <= eps_scalar/2)) for t in result_tl_scalar)
993
+ assert all(torch.all((t >= -e/2) & (t <= e/2)) for t, e in zip(result_tl_list, eps_list))
994
+ elif dist == 'sphere':
995
+ assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
996
+ # Cannot check list version easily
997
+ elif dist == 'rademacher':
998
+ assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
999
+ assert all(torch.all((t == -e) | (t == e)) for t, e in zip(result_tl_list, eps_list))
1000
+
1001
+
1002
+ # --- Advanced Math Ops ---
1003
+
1004
+ def test_clamp(simple_tl: TensorList):
1005
+ min_val, max_val = -0.5, 0.5
1006
+ # Both min and max
1007
+ clamped_tl = simple_tl.clamp(min_val, max_val)
1008
+ expected_tl = TensorList([t.clamp(min_val, max_val) for t in simple_tl])
1009
+ assert_tl_allclose(clamped_tl, expected_tl)
1010
+
1011
+ # Only min
1012
+ clamped_min_tl = simple_tl.clamp(min=min_val)
1013
+ expected_min_tl = TensorList([t.clamp(min=min_val) for t in simple_tl])
1014
+ assert_tl_allclose(clamped_min_tl, expected_min_tl)
1015
+
1016
+ # Only max
1017
+ clamped_max_tl = simple_tl.clamp(max=max_val)
1018
+ expected_max_tl = TensorList([t.clamp(max=max_val) for t in simple_tl])
1019
+ assert_tl_allclose(clamped_max_tl, expected_max_tl)
1020
+
1021
+ # List min/max
1022
+ min_list = [-1, -0.5, 0]
1023
+ max_list = [1, 0.5, 0.2]
1024
+ clamped_list_tl = simple_tl.clamp(min_list, max_list)
1025
+ expected_list_tl = TensorList([t.clamp(mn, mx) for t, mn, mx in zip(simple_tl, min_list, max_list)])
1026
+ assert_tl_allclose(clamped_list_tl, expected_list_tl)
1027
+
1028
+ # Inplace
1029
+ tl_copy = simple_tl.clone()
1030
+ result = tl_copy.clamp_(min_val, max_val)
1031
+ assert result is tl_copy
1032
+ assert_tl_allclose(tl_copy, expected_tl)
1033
+
1034
+
1035
+ def test_clamp_magnitude(simple_tl: TensorList):
1036
+ min_val, max_val = 0.2, 1.0
1037
+ tl_copy = simple_tl.clone()
1038
+ # Test non-zero case
1039
+ tl_copy[0][0,0] = 0.01 # ensure some small values
1040
+ tl_copy[1][0] = 10.0 # ensure some large values
1041
+ tl_copy[2][0,0,0] = 0.0 # test zero
1042
+
1043
+ clamped_tl = tl_copy.clamp_magnitude(min_val, max_val)
1044
+ # Check magnitudes are clipped
1045
+ for t in clamped_tl:
1046
+ abs_t = t.abs()
1047
+ # Allow small tolerance for floating point issues near zero
1048
+ assert torch.all(abs_t >= min_val - 1e-6)
1049
+ assert torch.all(abs_t <= max_val + 1e-6)
1050
+ # Check signs are preserved (or zero remains zero)
1051
+ original_sign = tl_copy.sign()
1052
+ clamped_sign = clamped_tl.sign()
1053
+ # Zeros might become non-zero min magnitude, so compare non-zeros
1054
+ non_zero_mask = tl_copy.ne(0)
1055
+ for os, cs, nz in zip(original_sign, clamped_sign, non_zero_mask):
1056
+ assert torch.all(os[nz] == cs[nz])
1057
+
1058
+ # Inplace
1059
+ tl_copy_inplace = tl_copy.clone()
1060
+ result = tl_copy_inplace.clamp_magnitude_(min_val, max_val)
1061
+ assert result is tl_copy_inplace
1062
+ assert_tl_allclose(tl_copy_inplace, clamped_tl)
1063
+
1064
+
1065
+ def test_lerp(simple_tl: TensorList):
1066
+ tensors1 = simple_tl.clone().mul_(2)
1067
+ weight_scalar = 0.5
1068
+ result_tl_scalar = simple_tl.lerp(tensors1, weight_scalar)
1069
+ expected_tl_scalar = TensorList([torch.lerp(t, t1, weight_scalar) for t, t1 in zip(simple_tl, tensors1)])
1070
+ assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
1071
+
1072
+ weight_list = [0.1, 0.5, 0.9]
1073
+ result_tl_list = simple_tl.lerp(tensors1, weight_list)
1074
+ expected_tl_list = TensorList([torch.lerp(t, t1, w) for t, t1, w in zip(simple_tl, tensors1, weight_list)])
1075
+ assert_tl_allclose(result_tl_list, expected_tl_list)
1076
+
1077
+ # Inplace
1078
+ tl_copy = simple_tl.clone()
1079
+ result_inplace = tl_copy.lerp_(tensors1, weight_scalar)
1080
+ assert result_inplace is tl_copy
1081
+ assert_tl_allclose(tl_copy, expected_tl_scalar)
1082
+
1083
+
1084
+ def test_lerp_compat(simple_tl: TensorList):
1085
+ # Test specifically the scalar sequence case for compatibility fallback
1086
+ tensors1 = simple_tl.clone().mul_(2)
1087
+ weight_list = [0.1, 0.5, 0.9]
1088
+ result_tl_list = simple_tl.lerp_compat(tensors1, weight_list)
1089
+ expected_tl_list = TensorList([t + w * (t1 - t) for t, t1, w in zip(simple_tl, tensors1, weight_list)])
1090
+ assert_tl_allclose(result_tl_list, expected_tl_list)
1091
+
1092
+ # Inplace
1093
+ tl_copy = simple_tl.clone()
1094
+ result_inplace = tl_copy.lerp_compat_(tensors1, weight_list)
1095
+ assert result_inplace is tl_copy
1096
+ assert_tl_allclose(tl_copy, expected_tl_list)
1097
+
1098
+
1099
+ @pytest.mark.parametrize("op_name, torch_op", [
1100
+ ('addcmul', torch.addcmul),
1101
+ ('addcdiv', torch.addcdiv),
1102
+ ])
1103
+ def test_addcops(simple_tl: TensorList, op_name, torch_op):
1104
+ tensors1 = simple_tl.clone().add(0.1)
1105
+ tensors2 = simple_tl.clone().mul(0.5)
1106
+ value_scalar = 2.0
1107
+ value_list = [1.0, 2.0, 3.0]
1108
+
1109
+ op_func = getattr(simple_tl, op_name)
1110
+ op_inplace_func = getattr(simple_tl, op_name + '_')
1111
+
1112
+ # Scalar value
1113
+ result_tl_scalar = op_func(tensors1, tensors2, value=value_scalar)
1114
+ expected_tl_scalar = TensorList([torch_op(t, t1, t2, value=value_scalar)
1115
+ for t, t1, t2 in zip(simple_tl, tensors1, tensors2)])
1116
+ assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
1117
+
1118
+ # List value
1119
+ result_tl_list = op_func(tensors1, tensors2, value=value_list)
1120
+ expected_tl_list = TensorList([torch_op(t, t1, t2, value=v)
1121
+ for t, t1, t2, v in zip(simple_tl, tensors1, tensors2, value_list)])
1122
+ assert_tl_allclose(result_tl_list, expected_tl_list)
1123
+
1124
+
1125
+ # Inplace (scalar value)
1126
+ tl_copy_scalar = simple_tl.clone()
1127
+ op_inplace_func = getattr(tl_copy_scalar, op_name + '_')
1128
+ result_inplace_scalar = op_inplace_func(tensors1, tensors2, value=value_scalar)
1129
+ assert result_inplace_scalar is tl_copy_scalar
1130
+ assert_tl_allclose(tl_copy_scalar, expected_tl_scalar)
1131
+
1132
+ # Inplace (list value)
1133
+ tl_copy_list = simple_tl.clone()
1134
+ op_inplace_func = getattr(tl_copy_list, op_name + '_')
1135
+ result_inplace_list = op_inplace_func(tensors1, tensors2, value=value_list)
1136
+ assert result_inplace_list is tl_copy_list
1137
+ assert_tl_allclose(tl_copy_list, expected_tl_list)
1138
+
1139
+
1140
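For reference, the per-tensor operations parametrized above, shown on plain tensors: addcmul(t, t1, t2, value=v) computes t + v * t1 * t2, and addcdiv uses t1 / t2 instead of the product. Illustration only.

    import torch

    t  = torch.tensor([1.0, 2.0])
    t1 = torch.tensor([3.0, 4.0])
    t2 = torch.tensor([2.0, 2.0])
    print(torch.addcmul(t, t1, t2, value=2.0))  # tensor([13., 18.])  = t + 2 * t1 * t2
    print(torch.addcdiv(t, t1, t2, value=2.0))  # tensor([4., 6.])    = t + 2 * t1 / t2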
+ @pytest.mark.parametrize("op_name, torch_op", [
1141
+ ('maximum', torch.maximum),
1142
+ ('minimum', torch.minimum),
1143
+ ])
1144
+ def test_maximin(simple_tl: TensorList, op_name, torch_op):
1145
+ other_scalar = 0.0
1146
+ other_list_scalar = [-1.0, 0.0, 1.0]
1147
+ other_tl = simple_tl.clone().mul_(-1)
1148
+
1149
+ op_func = getattr(simple_tl, op_name)
1150
+ op_inplace_func = getattr(simple_tl, op_name + '_')
1151
+
1152
+ # Scalar other
1153
+ result_tl_scalar = op_func(other_scalar)
1154
+ expected_tl_scalar = TensorList([torch_op(t, torch.tensor(other_scalar, dtype=t.dtype, device=t.device)) for t in simple_tl])
1155
+ assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
1156
+
1157
+ # List scalar other
1158
+ result_tl_list_scalar = op_func(other_list_scalar)
1159
+ expected_tl_list_scalar = TensorList([torch_op(t, torch.tensor(o, dtype=t.dtype, device=t.device)) for t, o in zip(simple_tl, other_list_scalar)])
1160
+ assert_tl_allclose(result_tl_list_scalar, expected_tl_list_scalar)
1161
+
1162
+ # TensorList other
1163
+ result_tl_tl = op_func(other_tl)
1164
+ expected_tl_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_tl)])
1165
+ assert_tl_allclose(result_tl_tl, expected_tl_tl)
1166
+
1167
+ # Inplace (TensorList other)
1168
+ tl_copy = simple_tl.clone()
1169
+ op_inplace_func = getattr(tl_copy, op_name + '_')
1170
+ result_inplace = op_inplace_func(other_tl)
1171
+ assert result_inplace is tl_copy
1172
+ assert_tl_allclose(tl_copy, expected_tl_tl)
1173
+
1174
+
1175
+ def test_nan_to_num(simple_tl: TensorList):
1176
+ tl_with_nan = simple_tl.clone()
1177
+ tl_with_nan[0][0, 0] = float('nan')
1178
+ tl_with_nan[1][0] = float('inf')
1179
+ tl_with_nan[2][0, 0, 0] = float('-inf')
1180
+
1181
+ # Default conversion
1182
+ result_default = tl_with_nan.nan_to_num()
1183
+ expected_default = TensorList([torch.nan_to_num(t) for t in tl_with_nan])
1184
+ assert_tl_equal(result_default, expected_default)
1185
+
1186
+ # Custom values (scalar)
1187
+ nan, posinf, neginf = 0.0, 1e6, -1e6
1188
+ result_scalar = tl_with_nan.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
1189
+ expected_scalar = TensorList([torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf) for t in tl_with_nan])
1190
+ assert_tl_equal(result_scalar, expected_scalar)
1191
+
1192
+ # Custom values (list)
1193
+ nan_list = [0.0, 1.0, 2.0]
1194
+ posinf_list = [1e5, 1e6, 1e7]
1195
+ neginf_list = [-1e5, -1e6, -1e7]
1196
+ result_list = tl_with_nan.nan_to_num(nan=nan_list, posinf=posinf_list, neginf=neginf_list)
1197
+ expected_list = TensorList([torch.nan_to_num(t, nan=n, posinf=p, neginf=ni)
1198
+ for t, n, p, ni in zip(tl_with_nan, nan_list, posinf_list, neginf_list)])
1199
+ assert_tl_equal(result_list, expected_list)
1200
+
1201
+ # Inplace
1202
+ tl_copy = tl_with_nan.clone()
1203
+ result_inplace = tl_copy.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
1204
+ assert result_inplace is tl_copy
1205
+ assert_tl_equal(tl_copy, expected_scalar)
1206
+
1207
+ # --- Reduction Ops ---
1208
+
1209
+ @pytest.mark.parametrize("reduction_method", ['mean', 'sum', 'min', 'max'])#, 'var', 'std', 'median', 'quantile'])
1210
+ @pytest.mark.parametrize("dim", [None, 0, 'global'])
1211
+ @pytest.mark.parametrize("keepdim", [False, True])
1212
+ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
1213
+ if dim == 'global' and keepdim:
1214
+ # a global reduction with keepdim=True is not a supported combination, so skip it
+ # with pytest.raises(ValueError, match='dim = global and keepdim = True'):
+ # getattr(simple_tl, reduction_method)(dim=dim, keepdim=keepdim)
+ return
1217
+ # Quantile needs q
1218
+ q = 0.75
1219
+ if reduction_method == 'quantile':
1220
+ args = {'q': q, 'dim': dim, 'keepdim': keepdim}
1221
+ torch_args = {'q': q, 'dim': dim, 'keepdim': keepdim}
1222
+ if dim is None: # torch.quantile doesn't accept dim=None, needs integer dim
1223
+ torch_args['dim'] = 0 if simple_tl[0].ndim > 0 else None # Use dim 0 if possible
1224
+ if torch_args['dim'] is None: # Cannot test dim=None on 0-d tensor easily here
1225
+ pytest.skip("Cannot test quantile with dim=None on 0-d tensors easily")
1226
+ elif reduction_method == 'median':
1227
+ args = {'dim': dim, 'keepdim': keepdim}
1228
+ torch_args = {'dim': dim, 'keepdim': keepdim}
1229
+ if dim is None: # torch.median requires dim if tensor is not 1D
1230
+ # Skip complex multi-dim median check for None dim
1231
+ pytest.skip("Skipping median test with dim=None for simplicity")
1232
+ else:
1233
+ args = {'dim': dim, 'keepdim': keepdim}
1234
+ torch_args = {'dim': dim, 'keepdim': keepdim}
1235
+
1236
+ reduction_func = getattr(simple_tl, reduction_method)
1237
+
1238
+ # Skip if dim is invalid for a tensor
1239
+ if isinstance(dim, int):
1240
+ if any(dim >= t.ndim for t in simple_tl):
1241
+ pytest.skip(f"Dimension {dim} out of range for at least one tensor")
1242
+
1243
+ try:
1244
+ result = reduction_func(**args)
1245
+ except RuntimeError as e:
1246
+ # median/quantile might fail on certain dtypes, skip if so
1247
+ if "median" in reduction_method or "quantile" in reduction_method:
1248
+ pytest.skip(f"Skipping {reduction_method} due to dtype incompatibility: {e}")
1249
+ else: raise e
1250
+
1251
+
1252
+ if dim == 'global':
1253
+ vec = simple_tl.to_vec()
1254
+ if reduction_method == 'min': expected = vec.min()
1255
+ elif reduction_method == 'max': expected = vec.max()
1256
+ elif reduction_method == 'mean': expected = vec.mean()
1257
+ elif reduction_method == 'sum': expected = vec.sum()
1258
+ elif reduction_method == 'std': expected = vec.std()
1259
+ elif reduction_method == 'var': expected = vec.var()
1260
+ elif reduction_method == 'median': expected = vec.median()#.values # scalar tensor
1261
+ elif reduction_method == 'quantile': expected = vec.quantile(q)
1262
+ else:
1263
+ pytest.fail("Unknown global reduction")
1264
+ assert torch.allclose(result, expected)
1266
+ else:
1267
+ expected_list = []
1268
+ for t in simple_tl:
1269
+ if reduction_method == 'min': torch_func = getattr(t, 'amin')
1270
+ elif reduction_method == 'max': torch_func = getattr(t, 'amax')
1271
+ else: torch_func = getattr(t, reduction_method)
1272
+ try:
1273
+ if reduction_method == 'median':
1274
+ # Median returns (values, indices), we only want values
1275
+ expected_val = torch_func(**torch_args)[0]
1276
+ elif reduction_method == 'quantile':
1277
+ expected_val = torch_func(**torch_args)
1278
+ # quantile might return scalar tensor if dim is None and keepdim=False
1279
+ # if dim is None and not keepdim: expected_val = expected_val.unsqueeze(0) if expected_val.ndim == 0 else expected_val
1280
+
1281
+ else:
1282
+ torch_args_copy = torch_args.copy()
1283
+ if reduction_method in ('min', 'max'):
1284
+ if 'dim' in torch_args_copy and torch_args_copy['dim'] is None: torch_args_copy['dim'] = ()
1285
+ expected_val = torch_func(**torch_args_copy)
1286
+
1287
+ # Handle cases where reduction reduces to scalar but we expect TL
1288
+ if not isinstance(expected_val, torch.Tensor): # e.g. min/max on scalar tensor
1289
+ expected_val = torch.tensor(expected_val, device=t.device, dtype=t.dtype)
1290
+ # if dim is None and not keepdim and expected_val.ndim==0:
1291
+ # expected_val = expected_val.unsqueeze(0) # Make it 1D for consistency in TL
1292
+
1293
+
1294
+ expected_list.append(expected_val)
1295
+ except RuntimeError as e:
1296
+ # Skip individual tensor if op not supported (e.g. std on int)
1297
+ if "std" in str(e) or "var" in str(e) or "mean" in str(e):
1298
+ pytest.skip(f"Skipping {reduction_method} on tensor due to dtype: {e}")
1299
+ else: raise e
1300
+
1301
+ expected_tl = TensorList(expected_list)
1302
+ assert isinstance(result, TensorList)
1303
+ assert len(result) == len(expected_tl)
1304
+ assert_tl_allclose(result, expected_tl, atol=1e-6) # Use allclose due to potential float variations
1305
+
1306
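For reference, a plain-torch illustration of the two reduction modes the test above distinguishes: per-tensor reductions (one result per tensor) versus dim='global' reductions over the flattened concatenation of all tensors. The names below are illustrative, not the package API.

    import torch

    tensors = [torch.tensor([1.0, 2.0]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]

    # per-tensor reduction: one scalar per tensor
    per_tensor_means = [t.mean() for t in tensors]            # [tensor(1.5000), tensor(4.5000)]

    # 'global' reduction: reduce the flattened concatenation, as to_vec() is used above
    flat = torch.cat([t.reshape(-1) for t in tensors])
    global_mean = flat.mean()                                  # tensor(3.5000)
    print(per_tensor_means, global_mean)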
+ # --- Grafting, Rescaling, Normalizing, Clipping ---
1307
+
1308
+ def test_graft(simple_tl: TensorList):
1309
+ magnitude_tl = simple_tl.clone().mul_(2.0) # Double the magnitude
1310
+
1311
+ # Tensorwise graft
1312
+ grafted_tensorwise = simple_tl.graft(magnitude_tl, tensorwise=True, ord=2)
1313
+ original_norms = simple_tl.norm(ord=2)
1314
+ magnitude_norms = magnitude_tl.norm(ord=2)
1315
+ grafted_norms = grafted_tensorwise.norm(ord=2)
1316
+ # Check norms match the magnitude norms
1317
+ assert_tl_allclose(grafted_norms, magnitude_norms)
1318
+ # Check directions are preserved (allow for scaling factor)
1319
+ for g, o, onorm, mnorm in zip(grafted_tensorwise, simple_tl, original_norms, magnitude_norms):
1320
+ # Handle zero norm case
1321
+ if onorm > 1e-7 and mnorm > 1e-7:
1322
+ expected_g = o * (mnorm / onorm)
1323
+ assert torch.allclose(g, expected_g)
1324
+ elif mnorm <= 1e-7: # If magnitude is zero, graft should be zero
1325
+ assert torch.allclose(g, torch.zeros_like(g))
1326
+ # If the original norm is zero but the magnitude norm is non-zero, the direction is undefined;
+ # the current implementation multiplies by a zero tensor, so the result is zero.
1328
+
1329
+ # Global graft
1330
+ grafted_global = simple_tl.graft(magnitude_tl, tensorwise=False, ord=2)
1331
+ original_global_norm = simple_tl.global_vector_norm(ord=2)
1332
+ magnitude_global_norm = magnitude_tl.global_vector_norm(ord=2)
1333
+ grafted_global_norm = grafted_global.global_vector_norm(ord=2)
1334
+ # Check global norm matches
1335
+ assert torch.allclose(grafted_global_norm, magnitude_global_norm)
1336
+ # Check direction (overall vector) is preserved
1337
+ if original_global_norm > 1e-7 and magnitude_global_norm > 1e-7:
1338
+ expected_global_scale = magnitude_global_norm / original_global_norm
1339
+ expected_global_tl = simple_tl * expected_global_scale
1340
+ assert_tl_allclose(grafted_global, expected_global_tl)
1341
+ elif magnitude_global_norm <= 1e-7:
1342
+ assert torch.allclose(grafted_global.to_vec(), torch.zeros(simple_tl.global_numel()))
1343
+
1344
+
1345
+ # Test inplace
1346
+ tl_copy_t = simple_tl.clone()
1347
+ tl_copy_t.graft_(magnitude_tl, tensorwise=True, ord=2)
1348
+ assert_tl_allclose(tl_copy_t, grafted_tensorwise)
1349
+
1350
+ tl_copy_g = simple_tl.clone()
1351
+ tl_copy_g.graft_(magnitude_tl, tensorwise=False, ord=2)
1352
+ assert_tl_allclose(tl_copy_g, grafted_global)
1353
+
1354
+
1355
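For reference, a plain-torch sketch of the tensorwise grafting relation the assertions above rely on: keep this tensor's direction and borrow the other tensor's norm. This illustrates the checked behaviour and is not torchzero's implementation.

    import torch

    def graft_sketch(direction: torch.Tensor, magnitude: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        # illustrative sketch: rescale `direction` so its norm equals that of `magnitude`
        d_norm = torch.linalg.vector_norm(direction)
        m_norm = torch.linalg.vector_norm(magnitude)
        return direction * (m_norm / (d_norm + eps))

    a = torch.tensor([3.0, 4.0])    # norm 5
    b = torch.tensor([0.0, 10.0])   # norm 10
    print(graft_sketch(a, b))       # tensor([6., 8.]) -- direction of a, norm of b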
+ @pytest.mark.parametrize("dim", [None, 0, 'global'])
1356
+ def test_rescale(simple_tl: TensorList, dim):
1357
+ min_val, max_val = 0.0, 1.0
1358
+ min_list = [0.0, -1.0, 0.5]
1359
+ max_list = [1.0, 0.0, 1.5]
1360
+ eps = 1e-6
1361
+
1362
+ # if dim is 0, make sure no tensor has a size-1 leading dimension
1363
+ if dim == 0:
1364
+ tensors = TensorList()
1365
+ for t in simple_tl:
1366
+ while t.shape[0] == 1: t = t[0]
1367
+ if t.ndim != 0: tensors.append(t)
1368
+ simple_tl = tensors
1369
+
1370
+ # Skip if dim is invalid for a tensor
1371
+ if isinstance(dim, int):
1372
+ if any(dim >= t.ndim for t in simple_tl):
1373
+ pytest.skip(f"Dimension {dim} out of range for at least one tensor")
1374
+
1375
+ # Rescale scalar
1376
+ rescaled_scalar = simple_tl.rescale(min_val, max_val, dim=dim, eps=eps)
1377
+ rescaled_scalar_min = rescaled_scalar.min(dim=dim if dim != 'global' else None)
1378
+ rescaled_scalar_max = rescaled_scalar.max(dim=dim if dim != 'global' else None)
1379
+
1380
+ if dim == 'global':
1381
+ assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
1382
+ assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
1383
+ else:
1384
+ assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-4)
1385
+ assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-4)
1386
+
1387
+
1388
+ # Rescale list
1389
+ rescaled_list = simple_tl.rescale(min_list, max_list, dim=dim, eps=eps)
1390
+ rescaled_list_min = rescaled_list.min(dim=dim if dim != 'global' else None)
1391
+ rescaled_list_max = rescaled_list.max(dim=dim if dim != 'global' else None)
1392
+
1393
+ if dim == 'global':
1394
+ # Global rescale with list min/max is tricky, check range contains target roughly
1395
+ global_min_rescaled = rescaled_list.global_min()
1396
+ global_max_rescaled = rescaled_list.global_max()
1397
+ # A single global scale and shift cannot match per-tensor min/max targets exactly,
+ # so only check that the resulting range is roughly consistent with the average targets.
1399
+ avg_min = sum(min_list)/len(min_list)
1400
+ avg_max = sum(max_list)/len(max_list)
1401
+ assert global_min_rescaled > avg_min - 1.0 # Loose check
1402
+ assert global_max_rescaled < avg_max + 1.0 # Loose check
1403
+
1404
+ else:
1405
+ assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-4)
1406
+ assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-4)
1407
+
1408
+ # Rescale to 01 helper
1409
+ rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
1410
+ rescaled_01_min = rescaled_01.min(dim=dim if dim != 'global' else None)
1411
+ rescaled_01_max = rescaled_01.max(dim=dim if dim != 'global' else None)
1412
+ if dim == 'global':
1413
+ assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
1414
+ assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
1415
+ else:
1416
+ assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-4)
1417
+ assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-4)
1418
+
1419
+
1420
+ # Test inplace
1421
+ tl_copy = simple_tl.clone()
1422
+ tl_copy.rescale_(min_val, max_val, dim=dim, eps=eps)
1423
+ assert_tl_allclose(tl_copy, rescaled_scalar)
1424
+
1425
+ tl_copy_01 = simple_tl.clone()
1426
+ tl_copy_01.rescale_to_01_(dim=dim, eps=eps)
1427
+ assert_tl_allclose(tl_copy_01, rescaled_01)
1428
+
1429
+
1430
+ @pytest.mark.parametrize("dim", [None, 0, 'global'])
1431
+ def test_normalize(big_tl: TensorList, dim):
1432
+ simple_tl = big_tl # reuse the simple_tl name so the test body reads the same as the other tests
1433
+
1434
+ mean_val, var_val = 0.0, 1.0
1435
+ mean_list = [0.0, 1.0, -0.5]
1436
+ var_list = [1.0, 0.5, 2.0] # Variance > 0
1437
+
1438
+ # if dim is 0, make sure no tensor has a size-1 leading dimension
1439
+ if dim == 0:
1440
+ tensors = TensorList()
1441
+ for t in simple_tl:
1442
+ while t.shape[0] == 1: t = t[0]
1443
+ if t.ndim != 0: tensors.append(t)
1444
+ simple_tl = tensors
1445
+
1446
+ # Skip if dim is invalid for a tensor
1447
+ if isinstance(dim, int):
1448
+ if any(dim >= t.ndim for t in simple_tl):
1449
+ pytest.skip(f"Dimension {dim} out of range for at least one tensor")
1450
+
1451
+ # Normalize scalar mean/var (z-normalize essentially)
1452
+ normalized_scalar = simple_tl.normalize(mean_val, var_val, dim=dim)
1453
+ normalized_scalar_mean = normalized_scalar.mean(dim=dim if dim != 'global' else None)
1454
+ normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)
1455
+
1456
+ if dim == 'global':
1457
+ assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-4)
1458
+ assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-4)
1459
+ else:
1460
+ assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-4)
1461
+ assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-4)
1462
+
1463
+ # Normalize list mean/var
1464
+ normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
1465
+ normalized_list_mean = normalized_list.mean(dim=dim if dim != 'global' else None)
1466
+ normalized_list_var = normalized_list.var(dim=dim if dim != 'global' else None)
1467
+
1468
+ if dim == 'global':
1469
+ global_mean_rescaled = normalized_list.global_mean()
1470
+ global_var_rescaled = normalized_list.global_var()
1471
+ avg_mean = sum(mean_list)/len(mean_list)
1472
+ avg_var = sum(var_list)/len(var_list)
1473
+ # Cannot guarantee exact match due to single scaling factor 'a' and 'b'
1474
+ assert global_mean_rescaled - 0.6 < torch.tensor(avg_mean) < global_mean_rescaled + 0.6
1475
+ assert global_var_rescaled - 0.6 < torch.tensor(avg_var) < global_var_rescaled + 0.6
1476
+ # assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
1477
+ # assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
1478
+ else:
1479
+ assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-4)
1480
+ assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-4)
1481
+
1482
+ # Z-normalize helper
1483
+ znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
1484
+ znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
1485
+ znorm_var = znorm.var(dim=dim if dim != 'global' else None)
1486
+ if dim == 'global':
1487
+ assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-4)
1488
+ assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-4)
1489
+ else:
1490
+ assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-4)
1491
+ assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-4)
1492
+
1493
+
1494
+ # Test inplace
1495
+ tl_copy = simple_tl.clone()
1496
+ tl_copy.normalize_(mean_val, var_val, dim=dim)
1497
+ assert_tl_allclose(tl_copy, normalized_scalar)
1498
+
1499
+ tl_copy_z = simple_tl.clone()
1500
+ tl_copy_z.znormalize_(dim=dim, eps=1e-10)
1501
+ assert_tl_allclose(tl_copy_z, znorm)
1502
+
1503
+
1504
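For reference, a plain-torch sketch of the per-tensor z-normalization property the helper assertions above check (zero mean, unit variance). Illustration only.

    import torch

    def znormalize_sketch(t: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        # illustrative sketch: subtract the mean and divide by the standard deviation
        return (t - t.mean()) / (t.var().sqrt() + eps)

    x = torch.randn(100) * 3 + 5
    z = znormalize_sketch(x)
    print(z.mean().item(), z.var().item())  # approximately 0.0 and 1.0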
+ @pytest.mark.parametrize("tensorwise", [True, False])
1505
+ def test_clip_norm(simple_tl: TensorList, tensorwise):
1506
+ min_val, max_val = 0.5, 1.5
1507
+ min_list = [0.2, 0.7, 1.0]
1508
+ max_list = [1.0, 1.2, 2.0]
1509
+ ord = 2
1510
+
1511
+ # Clip scalar min/max
1512
+ clipped_scalar = simple_tl.clip_norm(min_val, max_val, tensorwise=tensorwise, ord=ord)
1513
+ if tensorwise:
1514
+ clipped_scalar_norms = clipped_scalar.norm(ord=ord)
1515
+ assert all(torch.all((n >= min_val - 1e-6) & (n <= max_val + 1e-6)) for n in clipped_scalar_norms)
1516
+ else:
1517
+ clipped_scalar_global_norm = clipped_scalar.global_vector_norm(ord=ord)
1518
+ assert min_val - 1e-6 <= clipped_scalar_global_norm <= max_val + 1e-6
1519
+
1520
+ # Clip list min/max
1521
+ clipped_list = simple_tl.clip_norm(min_list, max_list, tensorwise=tensorwise, ord=ord)
1522
+ if tensorwise:
1523
+ clipped_list_norms = clipped_list.norm(ord=ord)
1524
+ assert all(torch.all((n >= mn - 1e-6) & (n <= mx + 1e-6)) for n, mn, mx in zip(clipped_list_norms, min_list, max_list))
1525
+ else:
1526
+ # Global clip with list min/max is tricky, multiplier is complex
1527
+ # Just check type and shape
1528
+ assert isinstance(clipped_list, TensorList)
1529
+ assert clipped_list.shape == simple_tl.shape
1530
+
1531
+
1532
+ # Test inplace
1533
+ tl_copy = simple_tl.clone()
1534
+ tl_copy.clip_norm_(min_val, max_val, tensorwise=tensorwise, ord=ord)
1535
+ assert_tl_allclose(tl_copy, clipped_scalar)
1536
+
1537
+
1538
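For reference, a plain-torch sketch of the norm clipping verified above: rescale a tensor so its norm lands inside [min, max]. This is illustrative only, not the package implementation.

    import torch

    def clip_norm_sketch(t: torch.Tensor, min_norm: float, max_norm: float) -> torch.Tensor:
        # illustrative sketch: clamp the norm and rescale; zero tensors are left unchanged
        n = torch.linalg.vector_norm(t)
        if n == 0:
            return t
        return t * (n.clamp(min_norm, max_norm) / n)

    x = torch.tensor([3.0, 4.0])  # norm 5
    print(torch.linalg.vector_norm(clip_norm_sketch(x, 0.5, 1.5)))  # tensor(1.5000)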
+ # --- Indexing and Masking ---
1539
+
1540
+ def test_where(simple_tl: TensorList):
1541
+ condition_tl = simple_tl.gt(0)
1542
+ other_scalar = -1.0
1543
+ other_list_scalar = [-1.0, -2.0, -3.0]
1544
+ other_tl = simple_tl.clone().mul_(-1)
1545
+
1546
+ # Scalar other
1547
+ result_scalar = simple_tl.where(condition_tl, other_scalar)
1548
+ expected_scalar = TensorList([torch.where(c, t, torch.tensor(other_scalar, dtype=t.dtype, device=t.device))
1549
+ for t, c in zip(simple_tl, condition_tl)])
1550
+ assert_tl_allclose(result_scalar, expected_scalar)
1551
+
1552
+ # List scalar other
1553
+ result_list_scalar = simple_tl.where(condition_tl, other_list_scalar)
1554
+ expected_list_scalar = TensorList([torch.where(c, t, torch.tensor(o, dtype=t.dtype, device=t.device))
1555
+ for t, c, o in zip(simple_tl, condition_tl, other_list_scalar)])
1556
+ assert_tl_allclose(result_list_scalar, expected_list_scalar)
1557
+
1558
+
1559
+ # TensorList other
1560
+ result_tl = simple_tl.where(condition_tl, other_tl)
1561
+ expected_tl = TensorList([torch.where(c, t, o) for t, c, o in zip(simple_tl, condition_tl, other_tl)])
1562
+ assert_tl_allclose(result_tl, expected_tl)
1563
+
1564
+ # Test module-level where function
1565
+ result_module = tl_where(condition_tl, simple_tl, other_tl)
1566
+ assert_tl_allclose(result_module, expected_tl)
1567
+
1568
+
1569
+ # Test inplace where_ (needs TensorList other)
1570
+ tl_copy = simple_tl.clone()
1571
+ result_inplace = tl_copy.where_(condition_tl, other_tl)
1572
+ assert result_inplace is tl_copy
1573
+ assert_tl_allclose(tl_copy, expected_tl)
1574
+
1575
+
1576
+ def test_masked_fill(simple_tl: TensorList):
1577
+ mask_tl = simple_tl.lt(0)
1578
+ fill_value_scalar = 99.0
1579
+ fill_value_list = [11.0, 22.0, 33.0]
1580
+
1581
+ # Scalar fill
1582
+ result_scalar = simple_tl.masked_fill(mask_tl, fill_value_scalar)
1583
+ expected_scalar = TensorList([t.masked_fill(m, fill_value_scalar) for t, m in zip(simple_tl, mask_tl)])
1584
+ assert_tl_allclose(result_scalar, expected_scalar)
1585
+
1586
+ # List fill
1587
+ result_list = simple_tl.masked_fill(mask_tl, fill_value_list)
1588
+ expected_list = TensorList([t.masked_fill(m, fv) for t, m, fv in zip(simple_tl, mask_tl, fill_value_list)])
1589
+ assert_tl_allclose(result_list, expected_list)
1590
+
1591
+ # Test inplace
1592
+ tl_copy = simple_tl.clone()
1593
+ result_inplace = tl_copy.masked_fill_(mask_tl, fill_value_scalar)
1594
+ assert result_inplace is tl_copy
1595
+ assert_tl_allclose(tl_copy, expected_scalar)
1596
+
1597
+
1598
+ def test_select_set_(simple_tl: TensorList):
1599
+ mask_tl = simple_tl.gt(0.5)
1600
+ value_scalar = -1.0
1601
+ value_list_scalar = [-1.0, -2.0, -3.0]
1602
+ value_tl = simple_tl.clone().mul_(0.1)
1603
+
1604
+ # Set with scalar value
1605
+ tl_copy_scalar = simple_tl.clone()
1606
+ tl_copy_scalar.select_set_(mask_tl, value_scalar)
1607
+ expected_scalar = simple_tl.clone()
1608
+ for t, m in zip(expected_scalar, mask_tl): t[m] = value_scalar
1609
+ assert_tl_allclose(tl_copy_scalar, expected_scalar)
1610
+
1611
+ # Set with list of scalar values
1612
+ tl_copy_list_scalar = simple_tl.clone()
1613
+ tl_copy_list_scalar.select_set_(mask_tl, value_list_scalar)
1614
+ expected_list_scalar = simple_tl.clone()
1615
+ for t, m, v in zip(expected_list_scalar, mask_tl, value_list_scalar): t[m] = v
1616
+ assert_tl_allclose(tl_copy_list_scalar, expected_list_scalar)
1617
+
1618
+ # Set with TensorList value
1619
+ # no, setting with a TensorList value is masked_set_ (covered by test_masked_set_ below)
1620
+ # tl_copy_tl = simple_tl.clone()
1621
+ # tl_copy_tl.select_set_(mask_tl, value_tl)
1622
+ # expected_tl = simple_tl.clone()
1623
+ # for t, m, v in zip(expected_tl, mask_tl, value_tl): t[m] = v[m] # Select from value tensor too
1624
+ # assert_tl_allclose(tl_copy_tl, expected_tl)
1625
+
1626
+
1627
+ def test_masked_set_(simple_tl: TensorList):
1628
+ mask_tl = simple_tl.gt(0.5)
1629
+ value_tl = simple_tl.clone().mul_(0.1)
1630
+
1631
+ tl_copy = simple_tl.clone()
1632
+ tl_copy.masked_set_(mask_tl, value_tl)
1633
+ expected = simple_tl.clone()
1634
+ for t, m, v in zip(expected, mask_tl, value_tl): t[m] = v[m] # masked_set_ semantics
1635
+ assert_tl_allclose(tl_copy, expected)
1636
+
1637
+
1638
+ def test_select(simple_tl: TensorList):
1639
+ # Select with integer
1640
+ idx_int = 0
1641
+ result_int = simple_tl.select(idx_int)
1642
+ expected_int = TensorList([t[idx_int] for t in simple_tl])
1643
+ assert_tl_equal(result_int, expected_int)
1644
+
1645
+ # Select with slice
1646
+ idx_slice = slice(0, 1)
1647
+ result_slice = simple_tl.select(idx_slice)
1648
+ expected_slice = TensorList([t[idx_slice] for t in simple_tl])
1649
+ assert_tl_equal(result_slice, expected_slice)
1650
+
1651
+ # Select with list of indices (per tensor)
1652
+ idx_list = [0, slice(1, 3), (0, slice(None), 0)] # Different index for each tensor
1653
+ result_list = simple_tl.select(idx_list)
1654
+ expected_list = TensorList([t[i] for t, i in zip(simple_tl, idx_list)])
1655
+ assert_tl_equal(result_list, expected_list)
1656
+
1657
+ # --- Miscellaneous ---
1658
+
1659
+ def test_dot(simple_tl: TensorList):
1660
+ other_tl = simple_tl.clone().mul_(0.5)
1661
+ result = simple_tl.dot(other_tl)
1662
+ expected = (simple_tl * other_tl).global_sum()
1663
+ assert torch.allclose(result, expected)
1664
+
1665
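For reference, the identity test_dot relies on, written on plain tensors: the dot product of two tensor lists is the sum over all tensors of the elementwise products. Illustration only.

    import torch

    a = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
    b = [torch.tensor([4.0, 5.0]), torch.tensor([6.0])]
    dot = sum((x * y).sum() for x, y in zip(a, b))  # 1*4 + 2*5 + 3*6
    print(dot)                                      # tensor(32.)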
+ def test_swap_tensors(simple_tl: TensorList):
1666
+ tl1 = simple_tl.clone()
1667
+ tl2 = simple_tl.clone().mul_(2)
1668
+ tl1_orig_copy = tl1.clone()
1669
+ tl2_orig_copy = tl2.clone()
1670
+
1671
+ tl1.swap_tensors(tl2)
1672
+
1673
+ # Check tl1 now has tl2's original data and vice versa
1674
+ assert_tl_equal(tl1, tl2_orig_copy)
1675
+ assert_tl_equal(tl2, tl1_orig_copy)
1676
+
1677
+
1678
+ def test_unbind_channels(simple_tl: TensorList):
1679
+ # Make sure at least one tensor has >1 dim 0 size
1680
+ simple_tl[0] = torch.randn(3, 4, 5)
1681
+ simple_tl[1] = torch.randn(2, 6)
1682
+ simple_tl[2] = torch.randn(1) # Keep a 1D tensor
1683
+
1684
+ unbound_tl = simple_tl.unbind_channels(dim=0)
1685
+
1686
+ expected_list = []
1687
+ for t in simple_tl:
1688
+ if t.ndim >= 2:
1689
+ expected_list.extend(list(t.unbind(dim=0)))
1690
+ else:
1691
+ expected_list.append(t)
1692
+ expected_tl = TensorList(expected_list)
1693
+
1694
+ assert_tl_equal(unbound_tl, expected_tl)
1695
+
1696
+
1697
+ def test_flatiter(simple_tl: TensorList):
1698
+ iterator = simple_tl.flatiter()
1699
+ all_elements = list(iterator)
1700
+ expected_elements = list(simple_tl.to_vec())
1701
+
1702
+ assert len(all_elements) == len(expected_elements)
1703
+ for el, exp_el in zip(all_elements, expected_elements):
1704
+ # flatiter yields scalar tensors
1705
+ assert isinstance(el, torch.Tensor)
1706
+ assert el.ndim == 0
1707
+ assert torch.equal(el, exp_el)
1708
+
1709
+
1710
+ def test_repr(simple_tl: TensorList):
1711
+ representation = repr(simple_tl)
1712
+ assert representation.startswith("TensorList([")
1713
+ assert representation.endswith("])")
1714
+ # Check if tensor representations are inside
1715
+ assert "tensor(" in representation
1716
+
1717
+
1718
+ # --- Module Level Functions ---
1719
+
1720
+ def test_stack(simple_tl: TensorList):
1721
+ tl1 = simple_tl.clone()
1722
+ tl2 = simple_tl.clone() * 2
1723
+ tl3 = simple_tl.clone() * 3
1724
+
1725
+ stacked_tl = stack([tl1, tl2, tl3], dim=0)
1726
+ expected_tl = TensorList([torch.stack([t1, t2, t3], dim=0)
1727
+ for t1, t2, t3 in zip(tl1, tl2, tl3)])
1728
+ assert_tl_equal(stacked_tl, expected_tl)
1729
+
1730
+ stacked_tl_dim1 = stack([tl1, tl2, tl3], dim=1)
1731
+ expected_tl_dim1 = TensorList([torch.stack([t1, t2, t3], dim=1)
1732
+ for t1, t2, t3 in zip(tl1, tl2, tl3)])
1733
+ assert_tl_equal(stacked_tl_dim1, expected_tl_dim1)
1734
+
1735
+
1736
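For reference, a plain-torch sketch of the per-tensor stacking behaviour checked above: stacking several lists of matching tensors stacks the corresponding tensors one by one. Illustrative only, not the package API.

    import torch

    lists = [
        [torch.zeros(2), torch.zeros(3, 3)],
        [torch.ones(2), torch.ones(3, 3)],
    ]
    # group matching tensors across the lists, then stack each group along dim 0
    stacked = [torch.stack(group, dim=0) for group in zip(*lists)]
    print([t.shape for t in stacked])  # [torch.Size([2, 2]), torch.Size([2, 3, 3])]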
+ def test_mean_median_sum_quantile_module(simple_tl: TensorList):
1737
+ tl1 = simple_tl.clone()
1738
+ tl2 = simple_tl.clone() * 2.5
1739
+ tl3 = simple_tl.clone() * -1.0
1740
+ tensors = [tl1, tl2, tl3]
1741
+
1742
+ # Mean
1743
+ mean_res = mean(tensors)
1744
+ expected_mean = stack(tensors, dim=0).mean(dim=0)
1745
+ assert_tl_allclose(mean_res, expected_mean)
1746
+
1747
+ # Sum
1748
+ sum_res = tl_sum(tensors)
1749
+ expected_sum = stack(tensors, dim=0).sum(dim=0)
1750
+ assert_tl_allclose(sum_res, expected_sum)
1751
+
1752
+ # Median
1753
+ median_res = median(tensors)
1754
+ # Stack and get median values (result is named tuple)
1755
+ expected_median_vals = stack(tensors, dim=0).median(dim=0)
1756
+ expected_median = TensorList(expected_median_vals)
1757
+ assert_tl_allclose(median_res, expected_median)
1758
+
1759
+ # Quantile
1760
+ q = 0.25
1761
+ quantile_res = quantile(tensors, q=q)
1762
+ expected_quantile_vals = stack(tensors, dim=0).quantile(q=q, dim=0)
1763
+ expected_quantile = TensorList(list(expected_quantile_vals))
1764
+ assert_tl_allclose(quantile_res, expected_quantile)
1765
+
1766
+
1767
+ # --- Test _MethodCallerWithArgs ---
1768
+ def test_method_caller_with_args():
1769
+ caller = _MethodCallerWithArgs('add')
1770
+ t = torch.tensor([1, 2])
1771
+ result = caller(t, 5) # t.add(5)
1772
+ assert torch.equal(result, torch.tensor([6, 7]))
1773
+
1774
+ result_kw = caller(t, other=10, alpha=2) # t.add(other=10, alpha=2)
1775
+ assert torch.equal(result_kw, torch.tensor([21, 22])) # 1 + 2*10, 2 + 2*10
1776
+
1777
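For reference, a minimal sketch of what a method-caller-with-arguments helper can look like; the real _MethodCallerWithArgs may differ. This only illustrates the call pattern the test above exercises: resolve a method by name and forward call-time arguments.

    import torch

    class MethodCallerWithArgsSketch:
        # illustrative sketch, analogous to operator.methodcaller but taking arguments at call time
        def __init__(self, name: str):
            self.name = name

        def __call__(self, obj, *args, **kwargs):
            return getattr(obj, self.name)(*args, **kwargs)

    caller = MethodCallerWithArgsSketch('add')
    print(caller(torch.tensor([1, 2]), 5))                   # tensor([6, 7])
    print(caller(torch.tensor([1, 2]), other=10, alpha=2))   # tensor([21, 22])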
+ # --- Test generic_clamp ---
1778
+ def test_generic_clamp():
1779
+ assert generic_clamp(5, min=0, max=10) == 5
1780
+ assert generic_clamp(-5, min=0, max=10) == 0
1781
+ assert generic_clamp(15, min=0, max=10) == 10
1782
+ assert generic_clamp(torch.tensor([-5, 5, 15]), min=0, max=10).equal(torch.tensor([0, 5, 10]))
1783
+
1784
+ tl = TensorList([torch.tensor([-5, 5, 15]), torch.tensor([1, 12])])
1785
+ clamped_tl = generic_clamp(tl, min=0, max=10)
1786
+ expected_tl = TensorList([torch.tensor([0, 5, 10]), torch.tensor([1, 10])])
1787
+ assert_tl_equal(clamped_tl, expected_tl)