torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
tests/test_tensorlist.py
ADDED
|
@@ -0,0 +1,1787 @@
|
|
|
1
|
+
# pylint: disable = redefined-outer-name, bad-indentation
|
|
2
|
+
"""note that a lot of this is written by Gemini for a better coverage."""
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from torchzero.utils.tensorlist import (
|
|
9
|
+
TensorList,
|
|
10
|
+
_MethodCallerWithArgs,
|
|
11
|
+
as_tensorlist,
|
|
12
|
+
generic_clamp,
|
|
13
|
+
mean,
|
|
14
|
+
median,
|
|
15
|
+
quantile,
|
|
16
|
+
stack,
|
|
17
|
+
)
|
|
18
|
+
from torchzero.utils.tensorlist import sum as tl_sum
|
|
19
|
+
from torchzero.utils.tensorlist import where as tl_where
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def randmask_like(x,device=None,dtype=None):
|
|
23
|
+
return torch.rand_like(x.float(), device=device,dtype=dtype) > 0.5
|
|
24
|
+
|
|
25
|
+
# Helper function for comparing TensorLists element-wise
|
|
26
|
+
def assert_tl_equal(tl1: TensorList, tl2: TensorList):
|
|
27
|
+
assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
|
|
28
|
+
for t1, t2 in zip(tl1, tl2):
|
|
29
|
+
if t1 is None and t2 is None:
|
|
30
|
+
continue
|
|
31
|
+
assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
|
|
32
|
+
assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
|
|
33
|
+
assert torch.equal(t1, t2), f"Tensors are not equal:\n{t1}\nvs\n{t2}"
|
|
34
|
+
|
|
35
|
+
def assert_tl_allclose(tl1: TensorList, tl2: TensorList, **kwargs):
|
|
36
|
+
assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
|
|
37
|
+
for t1, t2 in zip(tl1, tl2):
|
|
38
|
+
if t1 is None and t2 is None:
|
|
39
|
+
continue
|
|
40
|
+
assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
|
|
41
|
+
assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
|
|
42
|
+
assert torch.allclose(t1, t2, equal_nan=True, **kwargs), f"Tensors are not close:\n{t1}\nvs\n{t2}"
|
|
43
|
+
|
|
44
|
+
# --- Fixtures ---
|
|
45
|
+
|
|
46
|
+
@pytest.fixture
|
|
47
|
+
def simple_tensors() -> list[torch.Tensor]:
|
|
48
|
+
return [torch.randn(2, 3), torch.randn(5), torch.randn(1, 4, 2)]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.fixture
|
|
52
|
+
def big_tensors() -> list[torch.Tensor]:
|
|
53
|
+
return [torch.randn(20, 50), torch.randn(3,3,3,3,3,3), torch.randn(1000)]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.fixture
|
|
57
|
+
def simple_tl(simple_tensors) -> TensorList:
|
|
58
|
+
return TensorList(simple_tensors)
|
|
59
|
+
|
|
60
|
+
@pytest.fixture
|
|
61
|
+
def simple_tl_clone(simple_tl) -> TensorList:
|
|
62
|
+
return simple_tl.clone()
|
|
63
|
+
|
|
64
|
+
@pytest.fixture
|
|
65
|
+
def big_tl(big_tensors) -> TensorList:
|
|
66
|
+
return TensorList(big_tensors)
|
|
67
|
+
|
|
68
|
+
@pytest.fixture
|
|
69
|
+
def grad_tensors() -> list[torch.Tensor]:
|
|
70
|
+
return [
|
|
71
|
+
torch.randn(2, 2, requires_grad=True),
|
|
72
|
+
torch.randn(3, requires_grad=False),
|
|
73
|
+
torch.randn(1, 5, requires_grad=True)
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
@pytest.fixture
|
|
77
|
+
def grad_tl(grad_tensors) -> TensorList:
|
|
78
|
+
return TensorList(grad_tensors)
|
|
79
|
+
|
|
80
|
+
@pytest.fixture
|
|
81
|
+
def int_tensors() -> list[torch.Tensor]:
|
|
82
|
+
return [torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (5,))]
|
|
83
|
+
|
|
84
|
+
@pytest.fixture
|
|
85
|
+
def int_tl(int_tensors) -> TensorList:
|
|
86
|
+
return TensorList(int_tensors)
|
|
87
|
+
|
|
88
|
+
@pytest.fixture
|
|
89
|
+
def bool_tensors() -> list[torch.Tensor]:
|
|
90
|
+
return [torch.rand(2, 3) > 0.5, torch.rand(5) > 0.5]
|
|
91
|
+
|
|
92
|
+
@pytest.fixture
|
|
93
|
+
def bool_tl(bool_tensors) -> TensorList:
|
|
94
|
+
return TensorList(bool_tensors)
|
|
95
|
+
|
|
96
|
+
@pytest.fixture
|
|
97
|
+
def complex_tensors() -> list[torch.Tensor]:
|
|
98
|
+
return [torch.randn(2, 3, dtype=torch.complex64), torch.randn(5, dtype=torch.complex64)]
|
|
99
|
+
|
|
100
|
+
@pytest.fixture
|
|
101
|
+
def complex_tl(complex_tensors) -> TensorList:
|
|
102
|
+
return TensorList(complex_tensors)
|
|
103
|
+
|
|
104
|
+
# --- Test Cases ---
|
|
105
|
+
|
|
106
|
+
def test_initialization(simple_tensors):
|
|
107
|
+
tl = TensorList(simple_tensors)
|
|
108
|
+
assert isinstance(tl, TensorList)
|
|
109
|
+
assert isinstance(tl, list)
|
|
110
|
+
assert len(tl) == len(simple_tensors)
|
|
111
|
+
for i in range(len(tl)):
|
|
112
|
+
assert torch.equal(tl[i], simple_tensors[i])
|
|
113
|
+
|
|
114
|
+
def test_empty_initialization():
|
|
115
|
+
tl = TensorList()
|
|
116
|
+
assert isinstance(tl, TensorList)
|
|
117
|
+
assert len(tl) == 0
|
|
118
|
+
|
|
119
|
+
def test_as_tensorlist(simple_tensors, simple_tl: TensorList):
|
|
120
|
+
tl = as_tensorlist(simple_tensors)
|
|
121
|
+
assert isinstance(tl, TensorList)
|
|
122
|
+
assert_tl_equal(tl, simple_tl)
|
|
123
|
+
|
|
124
|
+
tl2 = as_tensorlist(simple_tl)
|
|
125
|
+
assert tl2 is simple_tl # Should return the same object if already TensorList
|
|
126
|
+
|
|
127
|
+
def test_complex_classmethod(simple_tensors):
|
|
128
|
+
real_tl = TensorList([t.float() for t in simple_tensors])
|
|
129
|
+
imag_tl = TensorList([torch.randn_like(t) for t in simple_tensors])
|
|
130
|
+
complex_tl = TensorList.complex(real_tl, imag_tl)
|
|
131
|
+
|
|
132
|
+
assert isinstance(complex_tl, TensorList)
|
|
133
|
+
assert len(complex_tl) == len(real_tl)
|
|
134
|
+
for i in range(len(complex_tl)):
|
|
135
|
+
assert complex_tl[i].dtype == torch.complex64 or complex_tl[i].dtype == torch.complex128
|
|
136
|
+
assert torch.equal(complex_tl[i].real, real_tl[i])
|
|
137
|
+
assert torch.equal(complex_tl[i].imag, imag_tl[i])
|
|
138
|
+
|
|
139
|
+
# --- Properties ---
|
|
140
|
+
def test_properties(simple_tl: TensorList, simple_tensors):
|
|
141
|
+
assert simple_tl.device == [t.device for t in simple_tensors]
|
|
142
|
+
assert simple_tl.dtype == [t.dtype for t in simple_tensors]
|
|
143
|
+
assert simple_tl.requires_grad == [t.requires_grad for t in simple_tensors]
|
|
144
|
+
assert simple_tl.shape == [t.shape for t in simple_tensors]
|
|
145
|
+
assert simple_tl.size() == [t.size() for t in simple_tensors]
|
|
146
|
+
assert simple_tl.size(0) == [t.size(0) for t in simple_tensors]
|
|
147
|
+
assert simple_tl.ndim == [t.ndim for t in simple_tensors]
|
|
148
|
+
assert simple_tl.ndimension() == [t.ndimension() for t in simple_tensors]
|
|
149
|
+
assert simple_tl.numel() == [t.numel() for t in simple_tensors]
|
|
150
|
+
|
|
151
|
+
def test_grad_property(grad_tl: TensorList, grad_tensors):
|
|
152
|
+
# Initially grads are None
|
|
153
|
+
assert all(g is None for g in grad_tl.grad)
|
|
154
|
+
|
|
155
|
+
# Set some grads
|
|
156
|
+
for i, t in enumerate(grad_tensors):
|
|
157
|
+
if t.requires_grad:
|
|
158
|
+
t.grad = torch.ones_like(t) * (i + 1)
|
|
159
|
+
|
|
160
|
+
grads = grad_tl.grad
|
|
161
|
+
assert isinstance(grads, TensorList)
|
|
162
|
+
assert len(grads) == len(grad_tl)
|
|
163
|
+
for i in range(len(grad_tl)):
|
|
164
|
+
if grad_tensors[i].requires_grad:
|
|
165
|
+
assert torch.equal(grads[i], torch.ones_like(grad_tensors[i]) * (i + 1))
|
|
166
|
+
else:
|
|
167
|
+
assert grads[i] is None # Accessing .grad on non-req-grad tensor returns None
|
|
168
|
+
|
|
169
|
+
def test_real_imag_properties(complex_tl: TensorList, complex_tensors):
|
|
170
|
+
real_part = complex_tl.real
|
|
171
|
+
imag_part = complex_tl.imag
|
|
172
|
+
assert isinstance(real_part, TensorList)
|
|
173
|
+
assert isinstance(imag_part, TensorList)
|
|
174
|
+
assert len(real_part) == len(complex_tl)
|
|
175
|
+
assert len(imag_part) == len(complex_tl)
|
|
176
|
+
for i in range(len(complex_tl)):
|
|
177
|
+
assert torch.equal(real_part[i], complex_tensors[i].real)
|
|
178
|
+
assert torch.equal(imag_part[i], complex_tensors[i].imag)
|
|
179
|
+
|
|
180
|
+
def test_view_as_real_complex(complex_tl: TensorList, complex_tensors):
|
|
181
|
+
real_view = complex_tl.view_as_real()
|
|
182
|
+
assert isinstance(real_view, TensorList)
|
|
183
|
+
assert len(real_view) == len(complex_tl)
|
|
184
|
+
for i in range(len(complex_tl)):
|
|
185
|
+
assert torch.equal(real_view[i], torch.view_as_real(complex_tensors[i]))
|
|
186
|
+
|
|
187
|
+
# Convert back
|
|
188
|
+
complex_view_again = real_view.view_as_complex()
|
|
189
|
+
assert_tl_equal(complex_view_again, complex_tl)
|
|
190
|
+
|
|
191
|
+
# --- Utility Methods ---
|
|
192
|
+
|
|
193
|
+
def test_type_as(simple_tl: TensorList, int_tl: TensorList):
|
|
194
|
+
int_casted_tl = simple_tl.type_as(int_tl[0]) # Cast like first int tensor
|
|
195
|
+
assert isinstance(int_casted_tl, TensorList)
|
|
196
|
+
assert all(t.dtype == int_tl[0].dtype for t in int_casted_tl)
|
|
197
|
+
|
|
198
|
+
float_casted_tl = int_tl.type_as(simple_tl) # Cast like corresponding float tensors
|
|
199
|
+
assert isinstance(float_casted_tl, TensorList)
|
|
200
|
+
assert all(t.dtype == s.dtype for t, s in zip(float_casted_tl, simple_tl))
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_fill_none(simple_tl: TensorList):
|
|
204
|
+
tl_with_none = TensorList([simple_tl[0], None, simple_tl[2]])
|
|
205
|
+
reference_tl = simple_tl.clone() # Use original shapes as reference
|
|
206
|
+
|
|
207
|
+
filled_tl = tl_with_none.fill_none(reference_tl)
|
|
208
|
+
assert isinstance(filled_tl, TensorList)
|
|
209
|
+
assert filled_tl[0] is tl_with_none[0] # Should keep existing tensor
|
|
210
|
+
assert torch.equal(filled_tl[1], torch.zeros_like(reference_tl[1]))
|
|
211
|
+
assert filled_tl[2] is tl_with_none[2]
|
|
212
|
+
# Check original is not modified
|
|
213
|
+
assert tl_with_none[1] is None
|
|
214
|
+
|
|
215
|
+
filled_tl_inplace = tl_with_none.fill_none_(reference_tl)
|
|
216
|
+
assert filled_tl_inplace is tl_with_none # Should return self
|
|
217
|
+
assert filled_tl_inplace[0] is simple_tl[0]
|
|
218
|
+
assert torch.equal(filled_tl_inplace[1], torch.zeros_like(reference_tl[1]))
|
|
219
|
+
assert filled_tl_inplace[2] is simple_tl[2]
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_get_grad(grad_tl: TensorList, grad_tensors):
|
|
223
|
+
# No grads initially
|
|
224
|
+
assert len(grad_tl.get_grad()) == 0
|
|
225
|
+
|
|
226
|
+
# Set grads only for tensors requiring grad
|
|
227
|
+
expected_grads = []
|
|
228
|
+
for i, t in enumerate(grad_tensors):
|
|
229
|
+
if t.requires_grad:
|
|
230
|
+
g = torch.rand_like(t)
|
|
231
|
+
t.grad = g
|
|
232
|
+
expected_grads.append(g)
|
|
233
|
+
|
|
234
|
+
retrieved_grads = grad_tl.get_grad()
|
|
235
|
+
assert isinstance(retrieved_grads, TensorList)
|
|
236
|
+
assert len(retrieved_grads) == len(expected_grads)
|
|
237
|
+
for rg, eg in zip(retrieved_grads, expected_grads):
|
|
238
|
+
assert torch.equal(rg, eg)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def test_with_requires_grad(grad_tl: TensorList, grad_tensors):
|
|
242
|
+
req_grad_true = grad_tl.with_requires_grad(True)
|
|
243
|
+
expected_true = [t for t in grad_tensors if t.requires_grad]
|
|
244
|
+
assert len(req_grad_true) == len(expected_true)
|
|
245
|
+
for rt, et in zip(req_grad_true, expected_true):
|
|
246
|
+
assert rt is et
|
|
247
|
+
|
|
248
|
+
req_grad_false = grad_tl.with_requires_grad(False)
|
|
249
|
+
expected_false = [t for t in grad_tensors if not t.requires_grad]
|
|
250
|
+
assert len(req_grad_false) == len(expected_false)
|
|
251
|
+
for rt, et in zip(req_grad_false, expected_false):
|
|
252
|
+
assert rt is et
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def test_with_grad(grad_tl: TensorList, grad_tensors):
|
|
256
|
+
assert len(grad_tl.with_grad()) == 0 # No grads set yet
|
|
257
|
+
|
|
258
|
+
# Set grads for tensors requiring grad
|
|
259
|
+
expected_with_grad = []
|
|
260
|
+
for i, t in enumerate(grad_tensors):
|
|
261
|
+
if t.requires_grad:
|
|
262
|
+
t.grad = torch.ones_like(t) * i
|
|
263
|
+
expected_with_grad.append(t)
|
|
264
|
+
|
|
265
|
+
has_grad_tl = grad_tl.with_grad()
|
|
266
|
+
assert isinstance(has_grad_tl, TensorList)
|
|
267
|
+
assert len(has_grad_tl) == len(expected_with_grad)
|
|
268
|
+
for hg, eg in zip(has_grad_tl, expected_with_grad):
|
|
269
|
+
assert hg is eg
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def test_ensure_grad_(grad_tl: TensorList, grad_tensors):
|
|
273
|
+
# Call ensure_grad_
|
|
274
|
+
grad_tl.ensure_grad_()
|
|
275
|
+
|
|
276
|
+
for t in grad_tl:
|
|
277
|
+
if t.requires_grad:
|
|
278
|
+
assert t.grad is not None
|
|
279
|
+
assert torch.equal(t.grad, torch.zeros_like(t))
|
|
280
|
+
else:
|
|
281
|
+
assert t.grad is None
|
|
282
|
+
|
|
283
|
+
# Call again, should not change existing zero grads
|
|
284
|
+
grad_tl.ensure_grad_()
|
|
285
|
+
for t in grad_tl:
|
|
286
|
+
if t.requires_grad:
|
|
287
|
+
assert t.grad is not None, 'this is a fixture'
|
|
288
|
+
assert torch.equal(t.grad, torch.zeros_like(t))
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def test_accumulate_grad_(grad_tl: TensorList, grad_tensors):
|
|
292
|
+
new_grads = TensorList([torch.rand_like(t) for t in grad_tensors])
|
|
293
|
+
new_grads_copy = new_grads.clone()
|
|
294
|
+
|
|
295
|
+
# First accumulation (grads are None or zero if ensure_grad_ was called)
|
|
296
|
+
grad_tl.accumulate_grad_(new_grads)
|
|
297
|
+
for t, ng in zip(grad_tl, new_grads_copy):
|
|
298
|
+
# if t.requires_grad:
|
|
299
|
+
assert t.grad is not None
|
|
300
|
+
assert torch.equal(t.grad, ng)
|
|
301
|
+
# else:
|
|
302
|
+
# assert t.grad is None # Should not create grad if requires_grad is False
|
|
303
|
+
|
|
304
|
+
# Second accumulation
|
|
305
|
+
new_grads_2 = TensorList([torch.rand_like(t) for t in grad_tensors])
|
|
306
|
+
expected_grads = TensorList([g + ng2 for t, g, ng2 in zip(grad_tensors, grad_tl.grad, new_grads_2)])
|
|
307
|
+
|
|
308
|
+
grad_tl.accumulate_grad_(new_grads_2)
|
|
309
|
+
for t, eg in zip(grad_tl, expected_grads):
|
|
310
|
+
assert t.grad is not None
|
|
311
|
+
assert torch.allclose(t.grad, eg)
|
|
312
|
+
# else:
|
|
313
|
+
# assert t.grad is None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def test_set_grad_(grad_tl: TensorList, grad_tensors):
|
|
317
|
+
# Set initial grads
|
|
318
|
+
initial_grads = TensorList([torch.ones_like(t) if t.requires_grad else None for t in grad_tensors])
|
|
319
|
+
grad_tl.set_grad_(initial_grads)
|
|
320
|
+
for t, ig in zip(grad_tl, initial_grads):
|
|
321
|
+
assert t.grad is ig
|
|
322
|
+
|
|
323
|
+
# Set new grads
|
|
324
|
+
new_grads = TensorList([torch.rand_like(t) * 2 if t.requires_grad else None for t in grad_tensors])
|
|
325
|
+
grad_tl.set_grad_(new_grads)
|
|
326
|
+
for t, ng in zip(grad_tl, new_grads):
|
|
327
|
+
assert t.grad is ng # Checks object identity for None, value for Tensors
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def test_zero_grad_(grad_tl: TensorList, grad_tensors):
|
|
331
|
+
# Set some grads
|
|
332
|
+
for t in grad_tl:
|
|
333
|
+
if t.requires_grad:
|
|
334
|
+
t.grad = torch.ones_like(t)
|
|
335
|
+
|
|
336
|
+
# Zero grads (set to None)
|
|
337
|
+
grad_tl.zero_grad_(set_to_none=True)
|
|
338
|
+
for t in grad_tl:
|
|
339
|
+
assert t.grad is None
|
|
340
|
+
|
|
341
|
+
# Set grads again
|
|
342
|
+
for t in grad_tl:
|
|
343
|
+
if t.requires_grad:
|
|
344
|
+
t.grad = torch.ones_like(t)
|
|
345
|
+
|
|
346
|
+
# Zero grads (set to zero)
|
|
347
|
+
grad_tl.zero_grad_(set_to_none=False)
|
|
348
|
+
for t in grad_tl:
|
|
349
|
+
if t.requires_grad:
|
|
350
|
+
assert t.grad is not None
|
|
351
|
+
assert torch.equal(t.grad, torch.zeros_like(t))
|
|
352
|
+
else:
|
|
353
|
+
assert t.grad is None # Should remain None if requires_grad is False
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
# --- Arithmetic Operators ---
|
|
357
|
+
|
|
358
|
+
@pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
|
|
359
|
+
@pytest.mark.parametrize("op, op_inplace, torch_op, foreach_op, foreach_op_inplace", [
|
|
360
|
+
('__add__', '__iadd__', torch.add, torch._foreach_add, torch._foreach_add_),
|
|
361
|
+
('__sub__', '__isub__', torch.sub, torch._foreach_sub, torch._foreach_sub_),
|
|
362
|
+
('__mul__', '__imul__', torch.mul, torch._foreach_mul, torch._foreach_mul_),
|
|
363
|
+
('__truediv__', '__itruediv__', torch.div, torch._foreach_div, torch._foreach_div_),
|
|
364
|
+
])
|
|
365
|
+
def test_arithmetic_ops(simple_tl: TensorList, simple_tl_clone: TensorList, other_type, op, op_inplace, torch_op, foreach_op, foreach_op_inplace):
|
|
366
|
+
if other_type == "scalar":
|
|
367
|
+
other = 2.5
|
|
368
|
+
other_list = [other] * len(simple_tl)
|
|
369
|
+
elif other_type == "list_scalar":
|
|
370
|
+
other = [1.0, 2.0, 3.0]
|
|
371
|
+
other_list = other
|
|
372
|
+
elif other_type == "tensorlist":
|
|
373
|
+
other = simple_tl_clone.clone().mul_(0.5) # Create a compatible TensorList
|
|
374
|
+
other_list = other
|
|
375
|
+
elif other_type == "list_tensor":
|
|
376
|
+
other = [t * 0.5 for t in simple_tl_clone] # Create a compatible list of tensors
|
|
377
|
+
other_list = other
|
|
378
|
+
else:
|
|
379
|
+
pytest.fail("Unknown other_type")
|
|
380
|
+
|
|
381
|
+
# --- Test out-of-place ---
|
|
382
|
+
op_func = getattr(simple_tl, op)
|
|
383
|
+
result_tl = op_func(other)
|
|
384
|
+
expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
|
|
385
|
+
|
|
386
|
+
assert isinstance(result_tl, TensorList)
|
|
387
|
+
assert_tl_allclose(result_tl, expected_tl)
|
|
388
|
+
# Ensure original is unchanged
|
|
389
|
+
assert_tl_equal(simple_tl, simple_tl_clone)
|
|
390
|
+
|
|
391
|
+
# Test foreach version directly for comparison (if applicable)
|
|
392
|
+
if op != '__sub__' or other_type != 'scalar': # _foreach_sub doesn't support scalar 'other' directly
|
|
393
|
+
if hasattr(torch, foreach_op.__name__):
|
|
394
|
+
expected_foreach = TensorList(foreach_op(simple_tl, other_list))
|
|
395
|
+
assert_tl_allclose(result_tl, expected_foreach)
|
|
396
|
+
|
|
397
|
+
# --- Test in-place ---
|
|
398
|
+
tl_copy = simple_tl.clone()
|
|
399
|
+
op_inplace_func = getattr(tl_copy, op_inplace)
|
|
400
|
+
result_inplace = op_inplace_func(other)
|
|
401
|
+
|
|
402
|
+
assert result_inplace is tl_copy # Should return self
|
|
403
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
404
|
+
|
|
405
|
+
# Test foreach_ inplace version directly
|
|
406
|
+
tl_copy_foreach = simple_tl.clone()
|
|
407
|
+
if op != '__sub__' or other_type != 'scalar': # _foreach_sub_ doesn't support scalar 'other' directly
|
|
408
|
+
if hasattr(torch, foreach_op_inplace.__name__):
|
|
409
|
+
foreach_op_inplace(tl_copy_foreach, other_list)
|
|
410
|
+
assert_tl_allclose(tl_copy_foreach, expected_tl)
|
|
411
|
+
|
|
412
|
+
# --- Test r-ops (if applicable) ---
|
|
413
|
+
if op in ['__add__', '__mul__']: # Commutative
|
|
414
|
+
rop_func = getattr(simple_tl, op.replace('__', '__r', 1))
|
|
415
|
+
result_rtl = rop_func(other)
|
|
416
|
+
assert_tl_allclose(result_rtl, expected_tl)
|
|
417
|
+
elif op == '__sub__': # Test rsub: other - self = -(self - other)
|
|
418
|
+
rop_func = getattr(simple_tl, '__rsub__')
|
|
419
|
+
result_rtl = rop_func(other)
|
|
420
|
+
expected_rtl = expected_tl.neg() # Note: self.sub(other).neg_() == other - self
|
|
421
|
+
assert_tl_allclose(result_rtl, expected_rtl)
|
|
422
|
+
elif op == '__truediv__': # Test rtruediv: other / self
|
|
423
|
+
if other_type in ["scalar", "list_scalar"]: # scalar / tensor or list<scalar> / list<tensor>
|
|
424
|
+
rop_func = getattr(simple_tl, '__rtruediv__')
|
|
425
|
+
result_rtl = rop_func(other)
|
|
426
|
+
expected_rtl = TensorList([o / t for t, o in zip(simple_tl, other_list)])
|
|
427
|
+
assert_tl_allclose(result_rtl, expected_rtl)
|
|
428
|
+
# rtruediv for tensorlist/list_tensor is not implemented directly
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
@pytest.mark.parametrize("op, torch_op", [
|
|
432
|
+
('__pow__', torch.pow),
|
|
433
|
+
('__floordiv__', torch.floor_divide),
|
|
434
|
+
('__mod__', torch.remainder),
|
|
435
|
+
])
|
|
436
|
+
@pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
|
|
437
|
+
def test_other_arithmetic_ops(simple_tl: TensorList, simple_tl_clone: TensorList, op, torch_op, other_type):
|
|
438
|
+
is_pow = op == '__pow__'
|
|
439
|
+
if other_type == "scalar":
|
|
440
|
+
other = 2 if is_pow else 2.5
|
|
441
|
+
other_list = [other] * len(simple_tl)
|
|
442
|
+
elif other_type == "list_scalar":
|
|
443
|
+
other = [2, 1, 3] if is_pow else [1.5, 2.5, 3.5]
|
|
444
|
+
other_list = other
|
|
445
|
+
elif other_type == "tensorlist":
|
|
446
|
+
other = simple_tl_clone.clone().abs_().add_(1).clamp_(max=3) if is_pow else simple_tl_clone.clone().mul_(0.5).add_(1)
|
|
447
|
+
other_list = other
|
|
448
|
+
elif other_type == "list_tensor":
|
|
449
|
+
other = [(t.abs() + 1).clamp(max=3) if is_pow else (t*0.5 + 1) for t in simple_tl_clone]
|
|
450
|
+
other_list = other
|
|
451
|
+
else:
|
|
452
|
+
pytest.fail("Unknown other_type")
|
|
453
|
+
|
|
454
|
+
# Test out-of-place
|
|
455
|
+
op_func = getattr(simple_tl, op)
|
|
456
|
+
result_tl = op_func(other)
|
|
457
|
+
expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
|
|
458
|
+
|
|
459
|
+
assert isinstance(result_tl, TensorList)
|
|
460
|
+
assert_tl_allclose(result_tl, expected_tl)
|
|
461
|
+
assert_tl_equal(simple_tl, simple_tl_clone) # Ensure original unchanged
|
|
462
|
+
|
|
463
|
+
# Test in-place (if exists)
|
|
464
|
+
op_inplace = op.replace('__', '__i', 1) + '_' # Standard naming convention adopted
|
|
465
|
+
if hasattr(simple_tl, op_inplace):
|
|
466
|
+
tl_copy = simple_tl.clone()
|
|
467
|
+
op_inplace_func = getattr(tl_copy, op_inplace)
|
|
468
|
+
result_inplace = op_inplace_func(other)
|
|
469
|
+
assert result_inplace is tl_copy
|
|
470
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
471
|
+
|
|
472
|
+
# Test rpow
|
|
473
|
+
if op == '__pow__' and other_type in ['scalar']:#, 'list_scalar']: # _foreach_pow doesn't support list of scalars as base
|
|
474
|
+
rop_func = getattr(simple_tl, '__rpow__')
|
|
475
|
+
result_rtl = rop_func(other)
|
|
476
|
+
expected_rtl = TensorList([torch_op(o, t) for t, o in zip(simple_tl, other_list)])
|
|
477
|
+
assert_tl_allclose(result_rtl, expected_rtl)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def test_negation(simple_tl: TensorList, simple_tl_clone):
|
|
481
|
+
neg_tl = -simple_tl
|
|
482
|
+
expected_tl = TensorList([-t for t in simple_tl])
|
|
483
|
+
assert_tl_allclose(neg_tl, expected_tl)
|
|
484
|
+
assert_tl_equal(simple_tl, simple_tl_clone) # Ensure original unchanged
|
|
485
|
+
|
|
486
|
+
neg_tl_inplace = simple_tl.neg_()
|
|
487
|
+
assert neg_tl_inplace is simple_tl
|
|
488
|
+
assert_tl_allclose(simple_tl, expected_tl)
|
|
489
|
+
|
|
490
|
+
# --- Comparison Operators ---
|
|
491
|
+
|
|
492
|
+
@pytest.mark.parametrize("op, torch_op", [
|
|
493
|
+
('__eq__', torch.eq),
|
|
494
|
+
('__ne__', torch.ne),
|
|
495
|
+
('__lt__', torch.lt),
|
|
496
|
+
('__le__', torch.le),
|
|
497
|
+
('__gt__', torch.gt),
|
|
498
|
+
('__ge__', torch.ge),
|
|
499
|
+
])
|
|
500
|
+
@pytest.mark.parametrize("other_type", ["scalar", "list_scalar", "tensorlist", "list_tensor"])
|
|
501
|
+
def test_comparison_ops(simple_tl: TensorList, op, torch_op, other_type):
|
|
502
|
+
if other_type == "scalar":
|
|
503
|
+
other = 0.0
|
|
504
|
+
other_list = [other] * len(simple_tl)
|
|
505
|
+
elif other_type == "list_scalar":
|
|
506
|
+
other = [-0.5, 0.0, 0.5]
|
|
507
|
+
other_list = other
|
|
508
|
+
elif other_type == "tensorlist":
|
|
509
|
+
other = simple_tl.clone().mul_(0.9)
|
|
510
|
+
other_list = other
|
|
511
|
+
elif other_type == "list_tensor":
|
|
512
|
+
other = [t * 0.9 for t in simple_tl]
|
|
513
|
+
other_list = other
|
|
514
|
+
else:
|
|
515
|
+
pytest.fail("Unknown other_type")
|
|
516
|
+
|
|
517
|
+
op_func = getattr(simple_tl, op)
|
|
518
|
+
result_tl = op_func(other)
|
|
519
|
+
expected_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_list)])
|
|
520
|
+
|
|
521
|
+
assert isinstance(result_tl, TensorList)
|
|
522
|
+
assert all(t.dtype == torch.bool for t in result_tl)
|
|
523
|
+
assert_tl_equal(result_tl, expected_tl)
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
# --- Logical Operators ---
|
|
527
|
+
|
|
528
|
+
@pytest.mark.parametrize("op, op_inplace, torch_op", [
|
|
529
|
+
('__and__', '__iand__', torch.logical_and),
|
|
530
|
+
('__or__', '__ior__', torch.logical_or),
|
|
531
|
+
('__xor__', '__ixor__', torch.logical_xor),
|
|
532
|
+
])
|
|
533
|
+
def test_logical_binary_ops(bool_tl: TensorList, op, op_inplace, torch_op):
|
|
534
|
+
other_tl = TensorList([randmask_like(t) for t in bool_tl])
|
|
535
|
+
other_list = list(other_tl) # Use list version for comparison
|
|
536
|
+
|
|
537
|
+
# Out-of-place
|
|
538
|
+
op_func = getattr(bool_tl, op)
|
|
539
|
+
result_tl = op_func(other_tl)
|
|
540
|
+
expected_tl = TensorList([torch_op(t, o) for t, o in zip(bool_tl, other_list)])
|
|
541
|
+
|
|
542
|
+
assert isinstance(result_tl, TensorList)
|
|
543
|
+
assert all(t.dtype == torch.bool for t in result_tl)
|
|
544
|
+
assert_tl_equal(result_tl, expected_tl)
|
|
545
|
+
|
|
546
|
+
# In-place
|
|
547
|
+
tl_copy = bool_tl.clone()
|
|
548
|
+
op_inplace_func = getattr(tl_copy, op_inplace) # Naming convention with _
|
|
549
|
+
result_inplace = op_inplace_func(other_tl)
|
|
550
|
+
assert result_inplace is tl_copy
|
|
551
|
+
assert_tl_equal(tl_copy, expected_tl)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def test_logical_not(bool_tl: TensorList):
|
|
555
|
+
# Out-of-place (~ operator maps to logical_not)
|
|
556
|
+
not_tl = ~bool_tl
|
|
557
|
+
expected_tl = TensorList([torch.logical_not(t) for t in bool_tl])
|
|
558
|
+
assert isinstance(not_tl, TensorList)
|
|
559
|
+
assert all(t.dtype == torch.bool for t in not_tl)
|
|
560
|
+
assert_tl_equal(not_tl, expected_tl)
|
|
561
|
+
|
|
562
|
+
# In-place
|
|
563
|
+
tl_copy = bool_tl.clone()
|
|
564
|
+
result_inplace = tl_copy.logical_not_()
|
|
565
|
+
assert result_inplace is tl_copy
|
|
566
|
+
assert_tl_equal(tl_copy, expected_tl)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def test_bool_raises(simple_tl: TensorList):
|
|
570
|
+
with pytest.raises(RuntimeError, match="Boolean value of TensorList is ambiguous"):
|
|
571
|
+
bool(simple_tl)
|
|
572
|
+
# Test with empty list
|
|
573
|
+
with pytest.raises(RuntimeError, match="Boolean value of TensorList is ambiguous"):
|
|
574
|
+
bool(TensorList())
|
|
575
|
+
|
|
576
|
+
# --- Map / Zipmap / Filter ---
|
|
577
|
+
|
|
578
|
+
def test_map(simple_tl: TensorList):
|
|
579
|
+
mapped_tl = simple_tl.map(torch.abs)
|
|
580
|
+
expected_tl = TensorList([torch.abs(t) for t in simple_tl])
|
|
581
|
+
assert_tl_allclose(mapped_tl, expected_tl)
|
|
582
|
+
|
|
583
|
+
def test_map_inplace_(simple_tl: TensorList):
|
|
584
|
+
tl_copy = simple_tl.clone()
|
|
585
|
+
result = tl_copy.map_inplace_(torch.abs_)
|
|
586
|
+
expected_tl = TensorList([torch.abs(t) for t in simple_tl]) # Calculate expected from original
|
|
587
|
+
assert result is tl_copy
|
|
588
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
589
|
+
|
|
590
|
+
def test_filter(simple_tl: TensorList):
|
|
591
|
+
# Filter tensors with more than 5 elements
|
|
592
|
+
filtered_tl = simple_tl.filter(lambda t: t.numel() > 5)
|
|
593
|
+
expected_tl = TensorList([t for t in simple_tl if t.numel() > 5])
|
|
594
|
+
assert len(filtered_tl) == len(expected_tl)
|
|
595
|
+
for ft, et in zip(filtered_tl, expected_tl):
|
|
596
|
+
assert ft is et # Should contain the original tensor objects
|
|
597
|
+
|
|
598
|
+
def test_zipmap(simple_tl: TensorList):
|
|
599
|
+
# Zipmap with another TensorList
|
|
600
|
+
other_tl = simple_tl.clone().mul_(0.5)
|
|
601
|
+
result_tl = simple_tl.zipmap(torch.add, other_tl)
|
|
602
|
+
expected_tl = TensorList([torch.add(t, o) for t, o in zip(simple_tl, other_tl)])
|
|
603
|
+
assert_tl_allclose(result_tl, expected_tl)
|
|
604
|
+
|
|
605
|
+
# Zipmap with a list of tensors
|
|
606
|
+
other_list = [t * 0.5 for t in simple_tl]
|
|
607
|
+
result_tl_list = simple_tl.zipmap(torch.add, other_list)
|
|
608
|
+
assert_tl_allclose(result_tl_list, expected_tl)
|
|
609
|
+
|
|
610
|
+
# Zipmap with a scalar
|
|
611
|
+
result_tl_scalar = simple_tl.zipmap(torch.add, 2.0)
|
|
612
|
+
expected_tl_scalar = TensorList([torch.add(t, 2.0) for t in simple_tl])
|
|
613
|
+
assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
|
|
614
|
+
|
|
615
|
+
# Zipmap with a list of scalars
|
|
616
|
+
other_scalars = [1.0, 2.0, 3.0]
|
|
617
|
+
result_tl_scalars = simple_tl.zipmap(torch.add, other_scalars)
|
|
618
|
+
expected_tl_scalars = TensorList([torch.add(t, s) for t, s in zip(simple_tl, other_scalars)])
|
|
619
|
+
assert_tl_allclose(result_tl_scalars, expected_tl_scalars)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def test_zipmap_inplace_(simple_tl: TensorList):
|
|
623
|
+
# Zipmap inplace with another TensorList
|
|
624
|
+
tl_copy = simple_tl.clone()
|
|
625
|
+
other_tl = simple_tl.clone().mul_(0.5)
|
|
626
|
+
result = tl_copy.zipmap_inplace_(_MethodCallerWithArgs('add_'), other_tl)
|
|
627
|
+
expected_tl = TensorList([torch.add(t, o) for t, o in zip(simple_tl, other_tl)])
|
|
628
|
+
assert result is tl_copy
|
|
629
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
630
|
+
|
|
631
|
+
# Zipmap inplace with a scalar
|
|
632
|
+
tl_copy_scalar = simple_tl.clone()
|
|
633
|
+
result_scalar = tl_copy_scalar.zipmap_inplace_(_MethodCallerWithArgs('add_'), 2.0)
|
|
634
|
+
expected_tl_scalar = TensorList([torch.add(t, 2.0) for t in simple_tl])
|
|
635
|
+
assert result_scalar is tl_copy_scalar
|
|
636
|
+
assert_tl_allclose(tl_copy_scalar, expected_tl_scalar)
|
|
637
|
+
|
|
638
|
+
# Zipmap inplace with list of scalars
|
|
639
|
+
tl_copy_scalars = simple_tl.clone()
|
|
640
|
+
other_scalars = [1.0, 2.0, 3.0]
|
|
641
|
+
result_scalars = tl_copy_scalars.zipmap_inplace_(_MethodCallerWithArgs('add_'), other_scalars)
|
|
642
|
+
expected_tl_scalars = TensorList([torch.add(t, s) for t, s in zip(simple_tl, other_scalars)])
|
|
643
|
+
assert result_scalars is tl_copy_scalars
|
|
644
|
+
assert_tl_allclose(tl_copy_scalars, expected_tl_scalars)
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def test_zipmap_args(simple_tl: TensorList):
|
|
648
|
+
other1 = simple_tl.clone().mul(0.5)
|
|
649
|
+
other2 = 2.0
|
|
650
|
+
other3 = [1, 2, 3]
|
|
651
|
+
# Test torch.lerp(input, end, weight) -> input + weight * (end - input)
|
|
652
|
+
# self = input, other1 = end, other2 = weight (scalar)
|
|
653
|
+
result_tl = simple_tl.zipmap_args(torch.lerp, other1, other2)
|
|
654
|
+
expected_tl = TensorList([torch.lerp(t, o1, other2) for t, o1 in zip(simple_tl, other1)])
|
|
655
|
+
assert_tl_allclose(result_tl, expected_tl)
|
|
656
|
+
|
|
657
|
+
# self = input, other1 = end, other3 = weight (list scalar)
|
|
658
|
+
result_tl_list = simple_tl.zipmap_args(torch.lerp, other1, other3)
|
|
659
|
+
expected_tl_list = TensorList([torch.lerp(t, o1, o3) for t, o1, o3 in zip(simple_tl, other1, other3)])
|
|
660
|
+
assert_tl_allclose(result_tl_list, expected_tl_list)
|
|
661
|
+
|
|
662
|
+
def test_zipmap_args_inplace_(simple_tl: TensorList):
|
|
663
|
+
tl_copy = simple_tl.clone()
|
|
664
|
+
other1 = simple_tl.clone().mul(0.5)
|
|
665
|
+
other2 = 0.5
|
|
666
|
+
# Test torch.addcmul_(tensor1, tensor2, value=1) -> self + value * tensor1 * tensor2
|
|
667
|
+
# self = self, other1 = tensor1, other1 (again) = tensor2, other2 = value
|
|
668
|
+
result_tl = tl_copy.zipmap_args_inplace_(_MethodCallerWithArgs('addcmul_'), other1, other1, value=other2)
|
|
669
|
+
expected_tl = TensorList([t.addcmul(o1, o1, value=other2) for t, o1 in zip(simple_tl.clone(), other1)]) # Need clone for calculation
|
|
670
|
+
assert result_tl is tl_copy
|
|
671
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
672
|
+
|
|
673
|
+
# --- Tensor Method Wrappers ---
|
|
674
|
+
|
|
675
|
+
@pytest.mark.parametrize("method_name, args", [
|
|
676
|
+
('clone', ()),
|
|
677
|
+
('detach', ()),
|
|
678
|
+
('contiguous', ()),
|
|
679
|
+
('cpu', ()),
|
|
680
|
+
('long', ()),
|
|
681
|
+
('short', ()),
|
|
682
|
+
('as_float', ()),
|
|
683
|
+
('as_int', ()),
|
|
684
|
+
('as_bool', ()),
|
|
685
|
+
('sqrt', ()),
|
|
686
|
+
('exp', ()),
|
|
687
|
+
('log', ()),
|
|
688
|
+
('sin', ()),
|
|
689
|
+
('cos', ()),
|
|
690
|
+
('abs', ()),
|
|
691
|
+
('neg', ()),
|
|
692
|
+
('reciprocal', ()),
|
|
693
|
+
('sign', ()),
|
|
694
|
+
('round', ()),
|
|
695
|
+
('floor', ()),
|
|
696
|
+
('ceil', ()),
|
|
697
|
+
('logical_not', ()), # Assuming input is boolean for this test
|
|
698
|
+
('ravel', ()),
|
|
699
|
+
('view_flat', ()),
|
|
700
|
+
('conj', ()), # Assuming input is complex for this test
|
|
701
|
+
('squeeze', ()),
|
|
702
|
+
('squeeze', (0,)), # Example with args
|
|
703
|
+
# Add more simple unary methods here...
|
|
704
|
+
])
|
|
705
|
+
def test_simple_unary_methods(simple_tl: TensorList, method_name, args):
|
|
706
|
+
tl_to_test = simple_tl
|
|
707
|
+
if method_name == 'logical_not':
|
|
708
|
+
tl_to_test = simple_tl.gt(0) # Create a boolean TL
|
|
709
|
+
elif method_name == 'conj':
|
|
710
|
+
tl_to_test = TensorList.complex(simple_tl, simple_tl) # Create complex
|
|
711
|
+
|
|
712
|
+
method = getattr(tl_to_test, method_name)
|
|
713
|
+
result_tl = method(*args)
|
|
714
|
+
|
|
715
|
+
method_names_map = {"as_float": "float", "as_int": "int", "as_bool": "bool", "view_flat": "ravel"}
|
|
716
|
+
tensor_method_name = method_name
|
|
717
|
+
if tensor_method_name in method_names_map: tensor_method_name = method_names_map[tensor_method_name]
|
|
718
|
+
expected_tl = TensorList([getattr(t, tensor_method_name)(*args) for t in tl_to_test])
|
|
719
|
+
|
|
720
|
+
assert isinstance(result_tl, TensorList)
|
|
721
|
+
# Need allclose for float results, equal for others
|
|
722
|
+
if any(t.is_floating_point() for t in expected_tl):
|
|
723
|
+
assert_tl_allclose(result_tl, expected_tl)
|
|
724
|
+
else:
|
|
725
|
+
assert_tl_equal(result_tl, expected_tl)
|
|
726
|
+
|
|
727
|
+
# Test inplace if available
|
|
728
|
+
method_inplace_name = method_name + '_'
|
|
729
|
+
if hasattr(tl_to_test, method_inplace_name) and hasattr(torch.Tensor, method_inplace_name):
|
|
730
|
+
tl_copy = tl_to_test.clone()
|
|
731
|
+
method_inplace = getattr(tl_copy, method_inplace_name)
|
|
732
|
+
result_inplace = method_inplace(*args)
|
|
733
|
+
assert result_inplace is tl_copy
|
|
734
|
+
if any(t.is_floating_point() for t in expected_tl):
|
|
735
|
+
assert_tl_allclose(tl_copy, expected_tl)
|
|
736
|
+
else:
|
|
737
|
+
assert_tl_equal(tl_copy, expected_tl)
|
|
738
|
+
|
|
739
|
+
def test_to(simple_tl: TensorList):
|
|
740
|
+
# Test changing dtype
|
|
741
|
+
float_tl = simple_tl.to(dtype=torch.float64)
|
|
742
|
+
assert all(t.dtype == torch.float64 for t in float_tl)
|
|
743
|
+
|
|
744
|
+
# Test changing device (if multiple devices available)
|
|
745
|
+
if torch.cuda.is_available():
|
|
746
|
+
cuda_tl = simple_tl.to(device='cuda')
|
|
747
|
+
assert all(t.device.type == 'cuda' for t in cuda_tl)
|
|
748
|
+
cpu_tl = cuda_tl.to('cpu')
|
|
749
|
+
assert all(t.device.type == 'cpu' for t in cpu_tl)
|
|
750
|
+
|
|
751
|
+
def test_copy_(simple_tl: TensorList):
|
|
752
|
+
src_tl = TensorList([torch.randn_like(t) for t in simple_tl])
|
|
753
|
+
tl_copy = simple_tl.clone()
|
|
754
|
+
tl_copy.copy_(src_tl)
|
|
755
|
+
assert_tl_equal(tl_copy, src_tl)
|
|
756
|
+
# Ensure src is unchanged
|
|
757
|
+
assert not torch.equal(src_tl[0], simple_tl[0]) # Verify src was different
|
|
758
|
+
|
|
759
|
+
def test_set_(simple_tl: TensorList):
|
|
760
|
+
src_tl = TensorList([torch.randn_like(t) for t in simple_tl])
|
|
761
|
+
tl_copy = simple_tl.clone()
|
|
762
|
+
tl_copy.set_(src_tl) # src_tl provides the storage/tensors
|
|
763
|
+
assert_tl_equal(tl_copy, src_tl)
|
|
764
|
+
# Note: set_ might have side effects on src_tl depending on PyTorch version/tensor types
|
|
765
|
+
|
|
766
|
+
def test_requires_grad_(grad_tl: TensorList):
|
|
767
|
+
grad_tl.requires_grad_(False)
|
|
768
|
+
assert grad_tl.requires_grad == [False] * len(grad_tl)
|
|
769
|
+
grad_tl.requires_grad_(True)
|
|
770
|
+
# This sets requires_grad=True for ALL tensors, unlike the initial fixture
|
|
771
|
+
assert grad_tl.requires_grad == [True] * len(grad_tl)
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
# --- Vectorization ---
|
|
775
|
+
|
|
776
|
+
def test_to_vec(simple_tl: TensorList):
|
|
777
|
+
vec = simple_tl.to_vec()
|
|
778
|
+
expected_vec = torch.cat([t.view(-1) for t in simple_tl])
|
|
779
|
+
assert torch.equal(vec, expected_vec)
|
|
780
|
+
|
|
781
|
+
def test_from_vec_(simple_tl: TensorList):
|
|
782
|
+
tl_copy = simple_tl.clone()
|
|
783
|
+
numel = simple_tl.global_numel()
|
|
784
|
+
new_vec = torch.arange(numel, dtype=simple_tl[0].dtype).float() # Use float for generality
|
|
785
|
+
|
|
786
|
+
result = tl_copy.from_vec_(new_vec)
|
|
787
|
+
assert result is tl_copy
|
|
788
|
+
|
|
789
|
+
current_pos = 0
|
|
790
|
+
for t_orig, t_modified in zip(simple_tl, tl_copy):
|
|
791
|
+
n = t_orig.numel()
|
|
792
|
+
expected_tensor = new_vec[current_pos : current_pos + n].view_as(t_orig)
|
|
793
|
+
assert torch.equal(t_modified, expected_tensor)
|
|
794
|
+
current_pos += n
|
|
795
|
+
|
|
796
|
+
def test_from_vec(simple_tl: TensorList):
|
|
797
|
+
tl_clone = simple_tl.clone() # Keep original safe
|
|
798
|
+
numel = simple_tl.global_numel()
|
|
799
|
+
new_vec = torch.arange(numel, dtype=simple_tl[0].dtype).float()
|
|
800
|
+
|
|
801
|
+
new_tl = simple_tl.from_vec(new_vec)
|
|
802
|
+
assert isinstance(new_tl, TensorList)
|
|
803
|
+
assert_tl_equal(simple_tl, tl_clone) # Original unchanged
|
|
804
|
+
|
|
805
|
+
current_pos = 0
|
|
806
|
+
for t_orig, t_new in zip(simple_tl, new_tl):
|
|
807
|
+
n = t_orig.numel()
|
|
808
|
+
expected_tensor = new_vec[current_pos : current_pos + n].view_as(t_orig)
|
|
809
|
+
assert torch.equal(t_new, expected_tensor)
|
|
810
|
+
current_pos += n
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
# --- Global Reductions ---
|
|
814
|
+
|
|
815
|
+
@pytest.mark.parametrize("global_method, vec_equiv_method", [
|
|
816
|
+
('global_min', 'min'),
|
|
817
|
+
('global_max', 'max'),
|
|
818
|
+
('global_sum', 'sum'),
|
|
819
|
+
('global_mean', 'mean'),
|
|
820
|
+
('global_std', 'std'),
|
|
821
|
+
('global_var', 'var'),
|
|
822
|
+
('global_any', 'any'),
|
|
823
|
+
('global_all', 'all'),
|
|
824
|
+
])
|
|
825
|
+
def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_method):
|
|
826
|
+
tl_to_test = simple_tl
|
|
827
|
+
if 'any' in global_method or 'all' in global_method:
|
|
828
|
+
tl_to_test = simple_tl.gt(0) # Need boolean input
|
|
829
|
+
|
|
830
|
+
global_method_func = getattr(tl_to_test, global_method)
|
|
831
|
+
result = global_method_func()
|
|
832
|
+
|
|
833
|
+
vec = tl_to_test.to_vec()
|
|
834
|
+
vec_equiv_func = getattr(vec, vec_equiv_method)
|
|
835
|
+
expected = vec_equiv_func()
|
|
836
|
+
|
|
837
|
+
if isinstance(result, bool): assert result == expected
|
|
838
|
+
else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def test_global_vector_norm(simple_tl: TensorList):
|
|
842
|
+
ord = 1.5
|
|
843
|
+
result = simple_tl.global_vector_norm(ord=ord)
|
|
844
|
+
vec = simple_tl.to_vec()
|
|
845
|
+
expected = torch.linalg.vector_norm(vec, ord=ord) # pylint:disable=not-callable
|
|
846
|
+
assert torch.allclose(result, expected)
|
|
847
|
+
|
|
848
|
+
def test_global_numel(simple_tl: TensorList):
|
|
849
|
+
result = simple_tl.global_numel()
|
|
850
|
+
expected = sum(t.numel() for t in simple_tl)
|
|
851
|
+
assert result == expected
|
|
852
|
+
|
|
853
|
+
# --- Like Creation Methods ---
|
|
854
|
+
|
|
855
|
+
@pytest.mark.parametrize("like_method, torch_equiv", [
|
|
856
|
+
('empty_like', torch.empty_like),
|
|
857
|
+
('zeros_like', torch.zeros_like),
|
|
858
|
+
('ones_like', torch.ones_like),
|
|
859
|
+
('rand_like', torch.rand_like),
|
|
860
|
+
('randn_like', torch.randn_like),
|
|
861
|
+
])
|
|
862
|
+
def test_simple_like_methods(simple_tl: TensorList, like_method, torch_equiv):
|
|
863
|
+
like_method_func = getattr(simple_tl, like_method)
|
|
864
|
+
result_tl = like_method_func()
|
|
865
|
+
|
|
866
|
+
assert isinstance(result_tl, TensorList)
|
|
867
|
+
assert len(result_tl) == len(simple_tl)
|
|
868
|
+
for res_t, orig_t in zip(result_tl, simple_tl):
|
|
869
|
+
assert res_t.shape == orig_t.shape
|
|
870
|
+
assert res_t.dtype == orig_t.dtype
|
|
871
|
+
assert res_t.device == orig_t.device
|
|
872
|
+
# Cannot easily check values for rand/randn/empty
|
|
873
|
+
|
|
874
|
+
# Test with kwargs (e.g., changing dtype)
|
|
875
|
+
result_tl_kw = like_method_func(dtype=torch.float64)
|
|
876
|
+
assert all(t.dtype == torch.float64 for t in result_tl_kw)
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
def test_full_like(simple_tl: TensorList):
|
|
880
|
+
# Scalar fill_value
|
|
881
|
+
fill_value_scalar = 5.0
|
|
882
|
+
result_tl_scalar = simple_tl.full_like(fill_value_scalar)
|
|
883
|
+
expected_tl_scalar = TensorList([torch.full_like(t, fill_value_scalar) for t in simple_tl])
|
|
884
|
+
assert_tl_equal(result_tl_scalar, expected_tl_scalar)
|
|
885
|
+
|
|
886
|
+
# List fill_value
|
|
887
|
+
fill_value_list = [1.0, 2.0, 3.0]
|
|
888
|
+
result_tl_list = simple_tl.full_like(fill_value_list)
|
|
889
|
+
expected_tl_list = TensorList([torch.full_like(t, fv) for t, fv in zip(simple_tl, fill_value_list)])
|
|
890
|
+
assert_tl_equal(result_tl_list, expected_tl_list)
|
|
891
|
+
|
|
892
|
+
# Test with kwargs
|
|
893
|
+
result_tl_kw = simple_tl.full_like(fill_value_scalar, dtype=torch.int)
|
|
894
|
+
assert all(t.dtype == torch.int for t in result_tl_kw)
|
|
895
|
+
assert all(torch.all(t == int(fill_value_scalar)) for t in result_tl_kw)
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
def test_randint_like(simple_tl: TensorList):
|
|
899
|
+
low = 0
|
|
900
|
+
high = 10
|
|
901
|
+
# Scalar low/high
|
|
902
|
+
result_tl_scalar = simple_tl.randint_like(low, high)
|
|
903
|
+
assert isinstance(result_tl_scalar, TensorList)
|
|
904
|
+
assert all(t.dtype == simple_tl[0].dtype for t in result_tl_scalar) # Default dtype
|
|
905
|
+
assert all(torch.all((t >= low) & (t < high)) for t in result_tl_scalar)
|
|
906
|
+
assert result_tl_scalar.shape == simple_tl.shape
|
|
907
|
+
|
|
908
|
+
# List low/high
|
|
909
|
+
low_list = [0, 5, 2]
|
|
910
|
+
high_list = [5, 15, 7]
|
|
911
|
+
result_tl_list = simple_tl.randint_like(low_list, high_list)
|
|
912
|
+
assert isinstance(result_tl_list, TensorList)
|
|
913
|
+
assert all(t.dtype == simple_tl[0].dtype for t in result_tl_list)
|
|
914
|
+
assert all(torch.all((t >= l) & (t < h)) for t, l, h in zip(result_tl_list, low_list, high_list))
|
|
915
|
+
assert result_tl_list.shape == simple_tl.shape
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
def test_uniform_like(simple_tl: TensorList):
|
|
919
|
+
# Default range (0, 1)
|
|
920
|
+
result_tl_default = simple_tl.uniform_like()
|
|
921
|
+
assert isinstance(result_tl_default, TensorList)
|
|
922
|
+
assert result_tl_default.shape == simple_tl.shape
|
|
923
|
+
assert all(t.dtype == simple_tl[i].dtype for i, t in enumerate(result_tl_default))
|
|
924
|
+
assert all(torch.all((t >= 0) & (t <= 1)) for t in result_tl_default) # Check range roughly
|
|
925
|
+
|
|
926
|
+
# Scalar low/high
|
|
927
|
+
low, high = -1.0, 1.0
|
|
928
|
+
result_tl_scalar = simple_tl.uniform_like(low, high)
|
|
929
|
+
assert all(torch.all((t >= low) & (t <= high)) for t in result_tl_scalar)
|
|
930
|
+
|
|
931
|
+
# List low/high
|
|
932
|
+
low_list = [-1, 0, -2]
|
|
933
|
+
high_list = [0, 1, -1]
|
|
934
|
+
result_tl_list = simple_tl.uniform_like(low_list, high_list)
|
|
935
|
+
assert all(torch.all((t >= l) & (t <= h)) for t, l, h in zip(result_tl_list, low_list, high_list))
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def test_sphere_like(simple_tl: TensorList):
|
|
939
|
+
radius = 5.0
|
|
940
|
+
result_tl_scalar = simple_tl.sphere_like(radius)
|
|
941
|
+
assert isinstance(result_tl_scalar, TensorList)
|
|
942
|
+
assert result_tl_scalar.shape == simple_tl.shape
|
|
943
|
+
assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(radius))
|
|
944
|
+
|
|
945
|
+
radius_list = [1.0, 10.0, 2.0]
|
|
946
|
+
result_tl_list = simple_tl.sphere_like(radius_list)
|
|
947
|
+
# Cannot easily check norm with list radius, just check type/shape
|
|
948
|
+
assert isinstance(result_tl_list, TensorList)
|
|
949
|
+
assert result_tl_list.shape == simple_tl.shape
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
def test_bernoulli_like(big_tl: TensorList):
|
|
953
|
+
p_scalar = 0.7
|
|
954
|
+
result_tl_scalar = big_tl.bernoulli_like(p_scalar)
|
|
955
|
+
assert isinstance(result_tl_scalar, TensorList)
|
|
956
|
+
assert result_tl_scalar.shape == big_tl.shape
|
|
957
|
+
assert all(t.dtype == big_tl[i].dtype for i, t in enumerate(result_tl_scalar)) # Should preserve dtype
|
|
958
|
+
assert all(torch.all((t == 0) | (t == 1)) for t in result_tl_scalar)
|
|
959
|
+
# Check mean is approximately p
|
|
960
|
+
assert abs(result_tl_scalar.to_vec().float().mean().item() - p_scalar) < 0.1 # Loose check
|
|
961
|
+
|
|
962
|
+
p_list = [0.2, 0.5, 0.8]
|
|
963
|
+
result_tl_list = big_tl.bernoulli_like(p_list)
|
|
964
|
+
assert isinstance(result_tl_list, TensorList)
|
|
965
|
+
assert result_tl_list.shape == big_tl.shape
|
|
966
|
+
|
|
967
|
+
|
|
968
|
+
def test_rademacher_like(big_tl: TensorList):
|
|
969
|
+
result_tl = big_tl.rademacher_like() # p=0.5 default
|
|
970
|
+
assert isinstance(result_tl, TensorList)
|
|
971
|
+
assert result_tl.shape == big_tl.shape
|
|
972
|
+
assert all(torch.all((t == -1) | (t == 1)) for t in result_tl)
|
|
973
|
+
|
|
974
|
+
# Check mean is approx 0
|
|
975
|
+
assert abs(result_tl.to_vec().float().mean().item()) < 0.1 # Loose check
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
@pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
|
|
979
|
+
def test_sample_like(simple_tl: TensorList, dist):
|
|
980
|
+
eps_scalar = 2.0
|
|
981
|
+
result_tl_scalar = simple_tl.sample_like(eps_scalar, distribution=dist)
|
|
982
|
+
assert isinstance(result_tl_scalar, TensorList)
|
|
983
|
+
assert result_tl_scalar.shape == simple_tl.shape
|
|
984
|
+
|
|
985
|
+
eps_list = [0.5, 1.0, 1.5]
|
|
986
|
+
result_tl_list = simple_tl.sample_like(eps_list, distribution=dist)
|
|
987
|
+
assert isinstance(result_tl_list, TensorList)
|
|
988
|
+
assert result_tl_list.shape == simple_tl.shape
|
|
989
|
+
|
|
990
|
+
# Basic checks based on distribution
|
|
991
|
+
if dist == 'uniform':
|
|
992
|
+
assert all(torch.all((t >= -eps_scalar/2) & (t <= eps_scalar/2)) for t in result_tl_scalar)
|
|
993
|
+
assert all(torch.all((t >= -e/2) & (t <= e/2)) for t, e in zip(result_tl_list, eps_list))
|
|
994
|
+
elif dist == 'sphere':
|
|
995
|
+
assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
|
|
996
|
+
# Cannot check list version easily
|
|
997
|
+
elif dist == 'rademacher':
|
|
998
|
+
assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
|
|
999
|
+
assert all(torch.all((t == -e) | (t == e)) for t, e in zip(result_tl_list, eps_list))
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
# --- Advanced Math Ops ---

def test_clamp(simple_tl: TensorList):
    min_val, max_val = -0.5, 0.5
    # Both min and max
    clamped_tl = simple_tl.clamp(min_val, max_val)
    expected_tl = TensorList([t.clamp(min_val, max_val) for t in simple_tl])
    assert_tl_allclose(clamped_tl, expected_tl)

    # Only min
    clamped_min_tl = simple_tl.clamp(min=min_val)
    expected_min_tl = TensorList([t.clamp(min=min_val) for t in simple_tl])
    assert_tl_allclose(clamped_min_tl, expected_min_tl)

    # Only max
    clamped_max_tl = simple_tl.clamp(max=max_val)
    expected_max_tl = TensorList([t.clamp(max=max_val) for t in simple_tl])
    assert_tl_allclose(clamped_max_tl, expected_max_tl)

    # List min/max
    min_list = [-1, -0.5, 0]
    max_list = [1, 0.5, 0.2]
    clamped_list_tl = simple_tl.clamp(min_list, max_list)
    expected_list_tl = TensorList([t.clamp(mn, mx) for t, mn, mx in zip(simple_tl, min_list, max_list)])
    assert_tl_allclose(clamped_list_tl, expected_list_tl)

    # Inplace
    tl_copy = simple_tl.clone()
    result = tl_copy.clamp_(min_val, max_val)
    assert result is tl_copy
    assert_tl_allclose(tl_copy, expected_tl)


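# clamp_magnitude clamps the absolute values into [min, max] while keeping each element's sign
# (zeros may be pushed up to the minimum magnitude); the assertions below check exactly that.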
def test_clamp_magnitude(simple_tl: TensorList):
    min_val, max_val = 0.2, 1.0
    tl_copy = simple_tl.clone()
    # Test non-zero case
    tl_copy[0][0,0] = 0.01 # ensure some small values
    tl_copy[1][0] = 10.0 # ensure some large values
    tl_copy[2][0,0,0] = 0.0 # test zero

    clamped_tl = tl_copy.clamp_magnitude(min_val, max_val)
    # Check magnitudes are clipped
    for t in clamped_tl:
        abs_t = t.abs()
        # Allow small tolerance for floating point issues near zero
        assert torch.all(abs_t >= min_val - 1e-6)
        assert torch.all(abs_t <= max_val + 1e-6)
    # Check signs are preserved (or zero remains zero)
    original_sign = tl_copy.sign()
    clamped_sign = clamped_tl.sign()
    # Zeros might become non-zero min magnitude, so compare non-zeros
    non_zero_mask = tl_copy.ne(0)
    for os, cs, nz in zip(original_sign, clamped_sign, non_zero_mask):
        assert torch.all(os[nz] == cs[nz])

    # Inplace
    tl_copy_inplace = tl_copy.clone()
    result = tl_copy_inplace.clamp_magnitude_(min_val, max_val)
    assert result is tl_copy_inplace
    assert_tl_allclose(tl_copy_inplace, clamped_tl)


def test_lerp(simple_tl: TensorList):
    tensors1 = simple_tl.clone().mul_(2)
    weight_scalar = 0.5
    result_tl_scalar = simple_tl.lerp(tensors1, weight_scalar)
    expected_tl_scalar = TensorList([torch.lerp(t, t1, weight_scalar) for t, t1 in zip(simple_tl, tensors1)])
    assert_tl_allclose(result_tl_scalar, expected_tl_scalar)

    weight_list = [0.1, 0.5, 0.9]
    result_tl_list = simple_tl.lerp(tensors1, weight_list)
    expected_tl_list = TensorList([torch.lerp(t, t1, w) for t, t1, w in zip(simple_tl, tensors1, weight_list)])
    assert_tl_allclose(result_tl_list, expected_tl_list)

    # Inplace
    tl_copy = simple_tl.clone()
    result_inplace = tl_copy.lerp_(tensors1, weight_scalar)
    assert result_inplace is tl_copy
    assert_tl_allclose(tl_copy, expected_tl_scalar)


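# lerp_compat computes the interpolation as t + weight * (t1 - t) instead of calling torch.lerp,
# so a plain sequence of per-tensor scalar weights also works; the test below checks it against
# that formula.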
def test_lerp_compat(simple_tl: TensorList):
    # Test specifically the scalar sequence case for compatibility fallback
    tensors1 = simple_tl.clone().mul_(2)
    weight_list = [0.1, 0.5, 0.9]
    result_tl_list = simple_tl.lerp_compat(tensors1, weight_list)
    expected_tl_list = TensorList([t + w * (t1 - t) for t, t1, w in zip(simple_tl, tensors1, weight_list)])
    assert_tl_allclose(result_tl_list, expected_tl_list)

    # Inplace
    tl_copy = simple_tl.clone()
    result_inplace = tl_copy.lerp_compat_(tensors1, weight_list)
    assert result_inplace is tl_copy
    assert_tl_allclose(tl_copy, expected_tl_list)


@pytest.mark.parametrize("op_name, torch_op", [
|
|
1100
|
+
('addcmul', torch.addcmul),
|
|
1101
|
+
('addcdiv', torch.addcdiv),
|
|
1102
|
+
])
|
|
1103
|
+
def test_addcops(simple_tl: TensorList, op_name, torch_op):
|
|
1104
|
+
tensors1 = simple_tl.clone().add(0.1)
|
|
1105
|
+
tensors2 = simple_tl.clone().mul(0.5)
|
|
1106
|
+
value_scalar = 2.0
|
|
1107
|
+
value_list = [1.0, 2.0, 3.0]
|
|
1108
|
+
|
|
1109
|
+
op_func = getattr(simple_tl, op_name)
|
|
1110
|
+
op_inplace_func = getattr(simple_tl, op_name + '_')
|
|
1111
|
+
|
|
1112
|
+
# Scalar value
|
|
1113
|
+
result_tl_scalar = op_func(tensors1, tensors2, value=value_scalar)
|
|
1114
|
+
expected_tl_scalar = TensorList([torch_op(t, t1, t2, value=value_scalar)
|
|
1115
|
+
for t, t1, t2 in zip(simple_tl, tensors1, tensors2)])
|
|
1116
|
+
assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
|
|
1117
|
+
|
|
1118
|
+
# List value
|
|
1119
|
+
result_tl_list = op_func(tensors1, tensors2, value=value_list)
|
|
1120
|
+
expected_tl_list = TensorList([torch_op(t, t1, t2, value=v)
|
|
1121
|
+
for t, t1, t2, v in zip(simple_tl, tensors1, tensors2, value_list)])
|
|
1122
|
+
assert_tl_allclose(result_tl_list, expected_tl_list)
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
# Inplace (scalar value)
|
|
1126
|
+
tl_copy_scalar = simple_tl.clone()
|
|
1127
|
+
op_inplace_func = getattr(tl_copy_scalar, op_name + '_')
|
|
1128
|
+
result_inplace_scalar = op_inplace_func(tensors1, tensors2, value=value_scalar)
|
|
1129
|
+
assert result_inplace_scalar is tl_copy_scalar
|
|
1130
|
+
assert_tl_allclose(tl_copy_scalar, expected_tl_scalar)
|
|
1131
|
+
|
|
1132
|
+
# Inplace (list value)
|
|
1133
|
+
tl_copy_list = simple_tl.clone()
|
|
1134
|
+
op_inplace_func = getattr(tl_copy_list, op_name + '_')
|
|
1135
|
+
result_inplace_list = op_inplace_func(tensors1, tensors2, value=value_list)
|
|
1136
|
+
assert result_inplace_list is tl_copy_list
|
|
1137
|
+
assert_tl_allclose(tl_copy_list, expected_tl_list)
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
@pytest.mark.parametrize("op_name, torch_op", [
|
|
1141
|
+
('maximum', torch.maximum),
|
|
1142
|
+
('minimum', torch.minimum),
|
|
1143
|
+
])
|
|
1144
|
+
def test_maximin(simple_tl: TensorList, op_name, torch_op):
|
|
1145
|
+
other_scalar = 0.0
|
|
1146
|
+
other_list_scalar = [-1.0, 0.0, 1.0]
|
|
1147
|
+
other_tl = simple_tl.clone().mul_(-1)
|
|
1148
|
+
|
|
1149
|
+
op_func = getattr(simple_tl, op_name)
|
|
1150
|
+
op_inplace_func = getattr(simple_tl, op_name + '_')
|
|
1151
|
+
|
|
1152
|
+
# Scalar other
|
|
1153
|
+
result_tl_scalar = op_func(other_scalar)
|
|
1154
|
+
expected_tl_scalar = TensorList([torch_op(t, torch.tensor(other_scalar, dtype=t.dtype, device=t.device)) for t in simple_tl])
|
|
1155
|
+
assert_tl_allclose(result_tl_scalar, expected_tl_scalar)
|
|
1156
|
+
|
|
1157
|
+
# List scalar other
|
|
1158
|
+
result_tl_list_scalar = op_func(other_list_scalar)
|
|
1159
|
+
expected_tl_list_scalar = TensorList([torch_op(t, torch.tensor(o, dtype=t.dtype, device=t.device)) for t, o in zip(simple_tl, other_list_scalar)])
|
|
1160
|
+
assert_tl_allclose(result_tl_list_scalar, expected_tl_list_scalar)
|
|
1161
|
+
|
|
1162
|
+
# TensorList other
|
|
1163
|
+
result_tl_tl = op_func(other_tl)
|
|
1164
|
+
expected_tl_tl = TensorList([torch_op(t, o) for t, o in zip(simple_tl, other_tl)])
|
|
1165
|
+
assert_tl_allclose(result_tl_tl, expected_tl_tl)
|
|
1166
|
+
|
|
1167
|
+
# Inplace (TensorList other)
|
|
1168
|
+
tl_copy = simple_tl.clone()
|
|
1169
|
+
op_inplace_func = getattr(tl_copy, op_name + '_')
|
|
1170
|
+
result_inplace = op_inplace_func(other_tl)
|
|
1171
|
+
assert result_inplace is tl_copy
|
|
1172
|
+
assert_tl_allclose(tl_copy, expected_tl_tl)
|
|
1173
|
+
|
|
1174
|
+
|
|
1175
|
+
def test_nan_to_num(simple_tl: TensorList):
    tl_with_nan = simple_tl.clone()
    tl_with_nan[0][0, 0] = float('nan')
    tl_with_nan[1][0] = float('inf')
    tl_with_nan[2][0, 0, 0] = float('-inf')

    # Default conversion
    result_default = tl_with_nan.nan_to_num()
    expected_default = TensorList([torch.nan_to_num(t) for t in tl_with_nan])
    assert_tl_equal(result_default, expected_default)

    # Custom values (scalar)
    nan, posinf, neginf = 0.0, 1e6, -1e6
    result_scalar = tl_with_nan.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
    expected_scalar = TensorList([torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf) for t in tl_with_nan])
    assert_tl_equal(result_scalar, expected_scalar)

    # Custom values (list)
    nan_list = [0.0, 1.0, 2.0]
    posinf_list = [1e5, 1e6, 1e7]
    neginf_list = [-1e5, -1e6, -1e7]
    result_list = tl_with_nan.nan_to_num(nan=nan_list, posinf=posinf_list, neginf=neginf_list)
    expected_list = TensorList([torch.nan_to_num(t, nan=n, posinf=p, neginf=ni)
                                for t, n, p, ni in zip(tl_with_nan, nan_list, posinf_list, neginf_list)])
    assert_tl_equal(result_list, expected_list)

    # Inplace
    tl_copy = tl_with_nan.clone()
    result_inplace = tl_copy.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
    assert result_inplace is tl_copy
    assert_tl_equal(tl_copy, expected_scalar)


# --- Reduction Ops ---

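# The reductions accept dim=None (reduce each tensor to a scalar), an integer dim (reduce each
# tensor along that dimension), or dim='global' (reduce over the concatenation of all tensors,
# i.e. over to_vec()); the test below compares each mode against the corresponding torch op.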
@pytest.mark.parametrize("reduction_method", ['mean', 'sum', 'min', 'max'])#, 'var', 'std', 'median', 'quantile'])
|
|
1210
|
+
@pytest.mark.parametrize("dim", [None, 0, 'global'])
|
|
1211
|
+
@pytest.mark.parametrize("keepdim", [False, True])
|
|
1212
|
+
def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
|
|
1213
|
+
if dim == 'global' and keepdim:
|
|
1214
|
+
# with pytest.raises(ValueError, match='dim = global and keepdim = True'):
|
|
1215
|
+
# getattr(simple_tl, reduction_method)(dim=dim, keepdim=keepdim)
|
|
1216
|
+
return
|
|
1217
|
+
# Quantile needs q
|
|
1218
|
+
q = 0.75
|
|
1219
|
+
if reduction_method == 'quantile':
|
|
1220
|
+
args = {'q': q, 'dim': dim, 'keepdim': keepdim}
|
|
1221
|
+
torch_args = {'q': q, 'dim': dim, 'keepdim': keepdim}
|
|
1222
|
+
if dim is None: # torch.quantile doesn't accept dim=None, needs integer dim
|
|
1223
|
+
torch_args['dim'] = 0 if simple_tl[0].ndim > 0 else None # Use dim 0 if possible
|
|
1224
|
+
if torch_args['dim'] is None: # Cannot test dim=None on 0-d tensor easily here
|
|
1225
|
+
pytest.skip("Cannot test quantile with dim=None on 0-d tensors easily")
|
|
1226
|
+
elif reduction_method == 'median':
|
|
1227
|
+
args = {'dim': dim, 'keepdim': keepdim}
|
|
1228
|
+
torch_args = {'dim': dim, 'keepdim': keepdim}
|
|
1229
|
+
if dim is None: # torch.median requires dim if tensor is not 1D
|
|
1230
|
+
# Skip complex multi-dim median check for None dim
|
|
1231
|
+
pytest.skip("Skipping median test with dim=None for simplicity")
|
|
1232
|
+
else:
|
|
1233
|
+
args = {'dim': dim, 'keepdim': keepdim}
|
|
1234
|
+
torch_args = {'dim': dim, 'keepdim': keepdim}
|
|
1235
|
+
|
|
1236
|
+
reduction_func = getattr(simple_tl, reduction_method)
|
|
1237
|
+
|
|
1238
|
+
# Skip if dim is invalid for a tensor
|
|
1239
|
+
if isinstance(dim, int):
|
|
1240
|
+
if any(dim >= t.ndim for t in simple_tl):
|
|
1241
|
+
pytest.skip(f"Dimension {dim} out of range for at least one tensor")
|
|
1242
|
+
|
|
1243
|
+
try:
|
|
1244
|
+
result = reduction_func(**args)
|
|
1245
|
+
except RuntimeError as e:
|
|
1246
|
+
# median/quantile might fail on certain dtypes, skip if so
|
|
1247
|
+
if "median" in reduction_method or "quantile" in reduction_method:
|
|
1248
|
+
pytest.skip(f"Skipping {reduction_method} due to dtype incompatibility: {e}")
|
|
1249
|
+
else: raise e
|
|
1250
|
+
|
|
1251
|
+
|
|
1252
|
+
if dim == 'global':
|
|
1253
|
+
vec = simple_tl.to_vec()
|
|
1254
|
+
if reduction_method == 'min': expected = vec.min()
|
|
1255
|
+
elif reduction_method == 'max': expected = vec.max()
|
|
1256
|
+
elif reduction_method == 'mean': expected = vec.mean()
|
|
1257
|
+
elif reduction_method == 'sum': expected = vec.sum()
|
|
1258
|
+
elif reduction_method == 'std': expected = vec.std()
|
|
1259
|
+
elif reduction_method == 'var': expected = vec.var()
|
|
1260
|
+
elif reduction_method == 'median': expected = vec.median()#.values # scalar tensor
|
|
1261
|
+
elif reduction_method == 'quantile': expected = vec.quantile(q)
|
|
1262
|
+
else:
|
|
1263
|
+
pytest.fail("Unknown global reduction")
|
|
1264
|
+
assert False, 'sus'
|
|
1265
|
+
assert torch.allclose(result, expected)
|
|
1266
|
+
else:
|
|
1267
|
+
expected_list = []
|
|
1268
|
+
for t in simple_tl:
|
|
1269
|
+
if reduction_method == 'min': torch_func = getattr(t, 'amin')
|
|
1270
|
+
elif reduction_method == 'max': torch_func = getattr(t, 'amax')
|
|
1271
|
+
else: torch_func = getattr(t, reduction_method)
|
|
1272
|
+
try:
|
|
1273
|
+
if reduction_method == 'median':
|
|
1274
|
+
# Median returns (values, indices), we only want values
|
|
1275
|
+
expected_val = torch_func(**torch_args)[0]
|
|
1276
|
+
elif reduction_method == 'quantile':
|
|
1277
|
+
expected_val = torch_func(**torch_args)
|
|
1278
|
+
# quantile might return scalar tensor if dim is None and keepdim=False
|
|
1279
|
+
# if dim is None and not keepdim: expected_val = expected_val.unsqueeze(0) if expected_val.ndim == 0 else expected_val
|
|
1280
|
+
|
|
1281
|
+
else:
|
|
1282
|
+
torch_args_copy = torch_args.copy()
|
|
1283
|
+
if reduction_method in ('min', 'max'):
|
|
1284
|
+
if 'dim' in torch_args_copy and torch_args_copy['dim'] is None: torch_args_copy['dim'] = ()
|
|
1285
|
+
expected_val = torch_func(**torch_args_copy)
|
|
1286
|
+
|
|
1287
|
+
# Handle cases where reduction reduces to scalar but we expect TL
|
|
1288
|
+
if not isinstance(expected_val, torch.Tensor): # e.g. min/max on scalar tensor
|
|
1289
|
+
expected_val = torch.tensor(expected_val, device=t.device, dtype=t.dtype)
|
|
1290
|
+
# if dim is None and not keepdim and expected_val.ndim==0:
|
|
1291
|
+
# expected_val = expected_val.unsqueeze(0) # Make it 1D for consistency in TL
|
|
1292
|
+
|
|
1293
|
+
|
|
1294
|
+
expected_list.append(expected_val)
|
|
1295
|
+
except RuntimeError as e:
|
|
1296
|
+
# Skip individual tensor if op not supported (e.g. std on int)
|
|
1297
|
+
if "std" in str(e) or "var" in str(e) or "mean" in str(e):
|
|
1298
|
+
pytest.skip(f"Skipping {reduction_method} on tensor due to dtype: {e}")
|
|
1299
|
+
else: raise e
|
|
1300
|
+
|
|
1301
|
+
expected_tl = TensorList(expected_list)
|
|
1302
|
+
assert isinstance(result, TensorList)
|
|
1303
|
+
assert len(result) == len(expected_tl)
|
|
1304
|
+
assert_tl_allclose(result, expected_tl, atol=1e-6) # Use allclose due to potential float variations
|
|
1305
|
+
|
|
1306
|
+
# --- Grafting, Rescaling, Normalizing, Clipping ---

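# graft rescales self so that its norm matches the norm of the other TensorList while keeping
# self's direction: per tensor when tensorwise=True, over the whole flattened vector otherwise.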
def test_graft(simple_tl: TensorList):
    magnitude_tl = simple_tl.clone().mul_(2.0) # Double the magnitude

    # Tensorwise graft
    grafted_tensorwise = simple_tl.graft(magnitude_tl, tensorwise=True, ord=2)
    original_norms = simple_tl.norm(ord=2)
    magnitude_norms = magnitude_tl.norm(ord=2)
    grafted_norms = grafted_tensorwise.norm(ord=2)
    # Check norms match the magnitude norms
    assert_tl_allclose(grafted_norms, magnitude_norms)
    # Check directions are preserved (allow for scaling factor)
    for g, o, onorm, mnorm in zip(grafted_tensorwise, simple_tl, original_norms, magnitude_norms):
        # Handle zero norm case
        if onorm > 1e-7 and mnorm > 1e-7:
            expected_g = o * (mnorm / onorm)
            assert torch.allclose(g, expected_g)
        elif mnorm <= 1e-7: # If magnitude is zero, graft should be zero
            assert torch.allclose(g, torch.zeros_like(g))
        # If original norm is zero but magnitude is non-zero, result is undefined/arbitrary direction?
        # Current implementation results in zero due to mul by zero tensor.

    # Global graft
    grafted_global = simple_tl.graft(magnitude_tl, tensorwise=False, ord=2)
    original_global_norm = simple_tl.global_vector_norm(ord=2)
    magnitude_global_norm = magnitude_tl.global_vector_norm(ord=2)
    grafted_global_norm = grafted_global.global_vector_norm(ord=2)
    # Check global norm matches
    assert torch.allclose(grafted_global_norm, magnitude_global_norm)
    # Check direction (overall vector) is preserved
    if original_global_norm > 1e-7 and magnitude_global_norm > 1e-7:
        expected_global_scale = magnitude_global_norm / original_global_norm
        expected_global_tl = simple_tl * expected_global_scale
        assert_tl_allclose(grafted_global, expected_global_tl)
    elif magnitude_global_norm <= 1e-7:
        assert torch.allclose(grafted_global.to_vec(), torch.zeros(simple_tl.global_numel()))

    # Test inplace
    tl_copy_t = simple_tl.clone()
    tl_copy_t.graft_(magnitude_tl, tensorwise=True, ord=2)
    assert_tl_allclose(tl_copy_t, grafted_tensorwise)

    tl_copy_g = simple_tl.clone()
    tl_copy_g.graft_(magnitude_tl, tensorwise=False, ord=2)
    assert_tl_allclose(tl_copy_g, grafted_global)


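# rescale linearly maps values into [min, max]; min/max can be scalars or per-tensor lists, and
# dim selects whether the min/max are taken per tensor, along a dimension, or globally.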
@pytest.mark.parametrize("dim", [None, 0, 'global'])
|
|
1356
|
+
def test_rescale(simple_tl: TensorList, dim):
|
|
1357
|
+
min_val, max_val = 0.0, 1.0
|
|
1358
|
+
min_list = [0.0, -1.0, 0.5]
|
|
1359
|
+
max_list = [1.0, 0.0, 1.5]
|
|
1360
|
+
eps = 1e-6
|
|
1361
|
+
|
|
1362
|
+
# if dim is 0 make sure it isn't len 1 dim
|
|
1363
|
+
if dim == 0:
|
|
1364
|
+
tensors = TensorList()
|
|
1365
|
+
for t in simple_tl:
|
|
1366
|
+
while t.shape[0] == 1: t = t[0]
|
|
1367
|
+
if t.ndim != 0: tensors.append(t)
|
|
1368
|
+
simple_tl = tensors
|
|
1369
|
+
|
|
1370
|
+
# Skip if dim is invalid for a tensor
|
|
1371
|
+
if isinstance(dim, int):
|
|
1372
|
+
if any(dim >= t.ndim for t in simple_tl):
|
|
1373
|
+
pytest.skip(f"Dimension {dim} out of range for at least one tensor")
|
|
1374
|
+
|
|
1375
|
+
# Rescale scalar
|
|
1376
|
+
rescaled_scalar = simple_tl.rescale(min_val, max_val, dim=dim, eps=eps)
|
|
1377
|
+
rescaled_scalar_min = rescaled_scalar.min(dim=dim if dim != 'global' else None)
|
|
1378
|
+
rescaled_scalar_max = rescaled_scalar.max(dim=dim if dim != 'global' else None)
|
|
1379
|
+
|
|
1380
|
+
if dim == 'global':
|
|
1381
|
+
assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
|
|
1382
|
+
assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
|
|
1383
|
+
else:
|
|
1384
|
+
assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-4)
|
|
1385
|
+
assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-4)
|
|
1386
|
+
|
|
1387
|
+
|
|
1388
|
+
# Rescale list
|
|
1389
|
+
rescaled_list = simple_tl.rescale(min_list, max_list, dim=dim, eps=eps)
|
|
1390
|
+
rescaled_list_min = rescaled_list.min(dim=dim if dim != 'global' else None)
|
|
1391
|
+
rescaled_list_max = rescaled_list.max(dim=dim if dim != 'global' else None)
|
|
1392
|
+
|
|
1393
|
+
if dim == 'global':
|
|
1394
|
+
# Global rescale with list min/max is tricky, check range contains target roughly
|
|
1395
|
+
global_min_rescaled = rescaled_list.global_min()
|
|
1396
|
+
global_max_rescaled = rescaled_list.global_max()
|
|
1397
|
+
# Cannot guarantee exact match due to single scaling factor 'a' and 'b'
|
|
1398
|
+
# Check if the range is approximately correct based on average target range?
|
|
1399
|
+
avg_min = sum(min_list)/len(min_list)
|
|
1400
|
+
avg_max = sum(max_list)/len(max_list)
|
|
1401
|
+
assert global_min_rescaled > avg_min - 1.0 # Loose check
|
|
1402
|
+
assert global_max_rescaled < avg_max + 1.0 # Loose check
|
|
1403
|
+
|
|
1404
|
+
else:
|
|
1405
|
+
assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-4)
|
|
1406
|
+
assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-4)
|
|
1407
|
+
|
|
1408
|
+
# Rescale to 01 helper
|
|
1409
|
+
rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
|
|
1410
|
+
rescaled_01_min = rescaled_01.min(dim=dim if dim != 'global' else None)
|
|
1411
|
+
rescaled_01_max = rescaled_01.max(dim=dim if dim != 'global' else None)
|
|
1412
|
+
if dim == 'global':
|
|
1413
|
+
assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
|
|
1414
|
+
assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
|
|
1415
|
+
else:
|
|
1416
|
+
assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-4)
|
|
1417
|
+
assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-4)
|
|
1418
|
+
|
|
1419
|
+
|
|
1420
|
+
# Test inplace
|
|
1421
|
+
tl_copy = simple_tl.clone()
|
|
1422
|
+
tl_copy.rescale_(min_val, max_val, dim=dim, eps=eps)
|
|
1423
|
+
assert_tl_allclose(tl_copy, rescaled_scalar)
|
|
1424
|
+
|
|
1425
|
+
tl_copy_01 = simple_tl.clone()
|
|
1426
|
+
tl_copy_01.rescale_to_01_(dim=dim, eps=eps)
|
|
1427
|
+
assert_tl_allclose(tl_copy_01, rescaled_01)
|
|
1428
|
+
|
|
1429
|
+
|
|
1430
|
+
@pytest.mark.parametrize("dim", [None, 0, 'global'])
|
|
1431
|
+
def test_normalize(big_tl: TensorList, dim):
|
|
1432
|
+
simple_tl = big_tl # can't be bothered t renamed
|
|
1433
|
+
|
|
1434
|
+
mean_val, var_val = 0.0, 1.0
|
|
1435
|
+
mean_list = [0.0, 1.0, -0.5]
|
|
1436
|
+
var_list = [1.0, 0.5, 2.0] # Variance > 0
|
|
1437
|
+
|
|
1438
|
+
# if dim is 0 make sure it isn't len 1 dim
|
|
1439
|
+
if dim == 0:
|
|
1440
|
+
tensors = TensorList()
|
|
1441
|
+
for t in simple_tl:
|
|
1442
|
+
while t.shape[0] == 1: t = t[0]
|
|
1443
|
+
if t.ndim != 0: tensors.append(t)
|
|
1444
|
+
simple_tl = tensors
|
|
1445
|
+
|
|
1446
|
+
# Skip if dim is invalid for a tensor
|
|
1447
|
+
if isinstance(dim, int):
|
|
1448
|
+
if any(dim >= t.ndim for t in simple_tl):
|
|
1449
|
+
pytest.skip(f"Dimension {dim} out of range for at least one tensor")
|
|
1450
|
+
|
|
1451
|
+
# Normalize scalar mean/var (z-normalize essentially)
|
|
1452
|
+
normalized_scalar = simple_tl.normalize(mean_val, var_val, dim=dim)
|
|
1453
|
+
normalized_scalar_mean = normalized_scalar.mean(dim=dim if dim != 'global' else None)
|
|
1454
|
+
normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)
|
|
1455
|
+
|
|
1456
|
+
if dim == 'global':
|
|
1457
|
+
assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-4)
|
|
1458
|
+
assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-4)
|
|
1459
|
+
else:
|
|
1460
|
+
assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-4)
|
|
1461
|
+
assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-4)
|
|
1462
|
+
|
|
1463
|
+
# Normalize list mean/var
|
|
1464
|
+
normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
|
|
1465
|
+
normalized_list_mean = normalized_list.mean(dim=dim if dim != 'global' else None)
|
|
1466
|
+
normalized_list_var = normalized_list.var(dim=dim if dim != 'global' else None)
|
|
1467
|
+
|
|
1468
|
+
if dim == 'global':
|
|
1469
|
+
global_mean_rescaled = normalized_list.global_mean()
|
|
1470
|
+
global_var_rescaled = normalized_list.global_var()
|
|
1471
|
+
avg_mean = sum(mean_list)/len(mean_list)
|
|
1472
|
+
avg_var = sum(var_list)/len(var_list)
|
|
1473
|
+
# Cannot guarantee exact match due to single scaling factor 'a' and 'b'
|
|
1474
|
+
assert global_mean_rescaled - 0.6 < torch.tensor(avg_mean) < global_mean_rescaled + 0.6
|
|
1475
|
+
assert global_var_rescaled - 0.6 < torch.tensor(avg_var) < global_var_rescaled + 0.6
|
|
1476
|
+
# assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
|
|
1477
|
+
# assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
|
|
1478
|
+
else:
|
|
1479
|
+
assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-4)
|
|
1480
|
+
assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-4)
|
|
1481
|
+
|
|
1482
|
+
# Z-normalize helper
|
|
1483
|
+
znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
|
|
1484
|
+
znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
|
|
1485
|
+
znorm_var = znorm.var(dim=dim if dim != 'global' else None)
|
|
1486
|
+
if dim == 'global':
|
|
1487
|
+
assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-4)
|
|
1488
|
+
assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-4)
|
|
1489
|
+
else:
|
|
1490
|
+
assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-4)
|
|
1491
|
+
assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-4)
|
|
1492
|
+
|
|
1493
|
+
|
|
1494
|
+
# Test inplace
|
|
1495
|
+
tl_copy = simple_tl.clone()
|
|
1496
|
+
tl_copy.normalize_(mean_val, var_val, dim=dim)
|
|
1497
|
+
assert_tl_allclose(tl_copy, normalized_scalar)
|
|
1498
|
+
|
|
1499
|
+
tl_copy_z = simple_tl.clone()
|
|
1500
|
+
tl_copy_z.znormalize_(dim=dim, eps=1e-10)
|
|
1501
|
+
assert_tl_allclose(tl_copy_z, znorm)
|
|
1502
|
+
|
|
1503
|
+
|
|
1504
|
+
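# clip_norm clamps the norm into [min, max]: per tensor when tensorwise=True, or the norm of the
# whole flattened TensorList when tensorwise=False.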
@pytest.mark.parametrize("tensorwise", [True, False])
|
|
1505
|
+
def test_clip_norm(simple_tl: TensorList, tensorwise):
|
|
1506
|
+
min_val, max_val = 0.5, 1.5
|
|
1507
|
+
min_list = [0.2, 0.7, 1.0]
|
|
1508
|
+
max_list = [1.0, 1.2, 2.0]
|
|
1509
|
+
ord = 2
|
|
1510
|
+
|
|
1511
|
+
# Clip scalar min/max
|
|
1512
|
+
clipped_scalar = simple_tl.clip_norm(min_val, max_val, tensorwise=tensorwise, ord=ord)
|
|
1513
|
+
if tensorwise:
|
|
1514
|
+
clipped_scalar_norms = clipped_scalar.norm(ord=ord)
|
|
1515
|
+
assert all(torch.all((n >= min_val - 1e-6) & (n <= max_val + 1e-6)) for n in clipped_scalar_norms)
|
|
1516
|
+
else:
|
|
1517
|
+
clipped_scalar_global_norm = clipped_scalar.global_vector_norm(ord=ord)
|
|
1518
|
+
assert min_val - 1e-6 <= clipped_scalar_global_norm <= max_val + 1e-6
|
|
1519
|
+
|
|
1520
|
+
# Clip list min/max
|
|
1521
|
+
clipped_list = simple_tl.clip_norm(min_list, max_list, tensorwise=tensorwise, ord=ord)
|
|
1522
|
+
if tensorwise:
|
|
1523
|
+
clipped_list_norms = clipped_list.norm(ord=ord)
|
|
1524
|
+
assert all(torch.all((n >= mn - 1e-6) & (n <= mx + 1e-6)) for n, mn, mx in zip(clipped_list_norms, min_list, max_list))
|
|
1525
|
+
else:
|
|
1526
|
+
# Global clip with list min/max is tricky, multiplier is complex
|
|
1527
|
+
# Just check type and shape
|
|
1528
|
+
assert isinstance(clipped_list, TensorList)
|
|
1529
|
+
assert clipped_list.shape == simple_tl.shape
|
|
1530
|
+
|
|
1531
|
+
|
|
1532
|
+
# Test inplace
|
|
1533
|
+
tl_copy = simple_tl.clone()
|
|
1534
|
+
tl_copy.clip_norm_(min_val, max_val, tensorwise=tensorwise, ord=ord)
|
|
1535
|
+
assert_tl_allclose(tl_copy, clipped_scalar)
|
|
1536
|
+
|
|
1537
|
+
|
|
1538
|
+
# --- Indexing and Masking ---

def test_where(simple_tl: TensorList):
    condition_tl = simple_tl.gt(0)
    other_scalar = -1.0
    other_list_scalar = [-1.0, -2.0, -3.0]
    other_tl = simple_tl.clone().mul_(-1)

    # Scalar other
    result_scalar = simple_tl.where(condition_tl, other_scalar)
    expected_scalar = TensorList([torch.where(c, t, torch.tensor(other_scalar, dtype=t.dtype, device=t.device))
                                  for t, c in zip(simple_tl, condition_tl)])
    assert_tl_allclose(result_scalar, expected_scalar)

    # List scalar other
    result_list_scalar = simple_tl.where(condition_tl, other_list_scalar)
    expected_list_scalar = TensorList([torch.where(c, t, torch.tensor(o, dtype=t.dtype, device=t.device))
                                       for t, c, o in zip(simple_tl, condition_tl, other_list_scalar)])
    assert_tl_allclose(result_list_scalar, expected_list_scalar)

    # TensorList other
    result_tl = simple_tl.where(condition_tl, other_tl)
    expected_tl = TensorList([torch.where(c, t, o) for t, c, o in zip(simple_tl, condition_tl, other_tl)])
    assert_tl_allclose(result_tl, expected_tl)

    # Test module-level where function
    result_module = tl_where(condition_tl, simple_tl, other_tl)
    assert_tl_allclose(result_module, expected_tl)

    # Test inplace where_ (needs TensorList other)
    tl_copy = simple_tl.clone()
    result_inplace = tl_copy.where_(condition_tl, other_tl)
    assert result_inplace is tl_copy
    assert_tl_allclose(tl_copy, expected_tl)


def test_masked_fill(simple_tl: TensorList):
    mask_tl = simple_tl.lt(0)
    fill_value_scalar = 99.0
    fill_value_list = [11.0, 22.0, 33.0]

    # Scalar fill
    result_scalar = simple_tl.masked_fill(mask_tl, fill_value_scalar)
    expected_scalar = TensorList([t.masked_fill(m, fill_value_scalar) for t, m in zip(simple_tl, mask_tl)])
    assert_tl_allclose(result_scalar, expected_scalar)

    # List fill
    result_list = simple_tl.masked_fill(mask_tl, fill_value_list)
    expected_list = TensorList([t.masked_fill(m, fv) for t, m, fv in zip(simple_tl, mask_tl, fill_value_list)])
    assert_tl_allclose(result_list, expected_list)

    # Test inplace
    tl_copy = simple_tl.clone()
    result_inplace = tl_copy.masked_fill_(mask_tl, fill_value_scalar)
    assert result_inplace is tl_copy
    assert_tl_allclose(tl_copy, expected_scalar)


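# select_set_ writes a scalar (or one scalar per tensor) into the masked positions, while
# masked_set_ (tested further below) copies the masked positions from another TensorList.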
def test_select_set_(simple_tl: TensorList):
    mask_tl = simple_tl.gt(0.5)
    value_scalar = -1.0
    value_list_scalar = [-1.0, -2.0, -3.0]
    value_tl = simple_tl.clone().mul_(0.1)

    # Set with scalar value
    tl_copy_scalar = simple_tl.clone()
    tl_copy_scalar.select_set_(mask_tl, value_scalar)
    expected_scalar = simple_tl.clone()
    for t, m in zip(expected_scalar, mask_tl): t[m] = value_scalar
    assert_tl_allclose(tl_copy_scalar, expected_scalar)

    # Set with list of scalar values
    tl_copy_list_scalar = simple_tl.clone()
    tl_copy_list_scalar.select_set_(mask_tl, value_list_scalar)
    expected_list_scalar = simple_tl.clone()
    for t, m, v in zip(expected_list_scalar, mask_tl, value_list_scalar): t[m] = v
    assert_tl_allclose(tl_copy_list_scalar, expected_list_scalar)

    # Set with TensorList value
    # no, that's masked_set_
    # tl_copy_tl = simple_tl.clone()
    # tl_copy_tl.select_set_(mask_tl, value_tl)
    # expected_tl = simple_tl.clone()
    # for t, m, v in zip(expected_tl, mask_tl, value_tl): t[m] = v[m] # Select from value tensor too
    # assert_tl_allclose(tl_copy_tl, expected_tl)


def test_masked_set_(simple_tl: TensorList):
    mask_tl = simple_tl.gt(0.5)
    value_tl = simple_tl.clone().mul_(0.1)

    tl_copy = simple_tl.clone()
    tl_copy.masked_set_(mask_tl, value_tl)
    expected = simple_tl.clone()
    for t, m, v in zip(expected, mask_tl, value_tl): t[m] = v[m] # masked_set_ semantics
    assert_tl_allclose(tl_copy, expected)


def test_select(simple_tl: TensorList):
    # Select with integer
    idx_int = 0
    result_int = simple_tl.select(idx_int)
    expected_int = TensorList([t[idx_int] for t in simple_tl])
    assert_tl_equal(result_int, expected_int)

    # Select with slice
    idx_slice = slice(0, 1)
    result_slice = simple_tl.select(idx_slice)
    expected_slice = TensorList([t[idx_slice] for t in simple_tl])
    assert_tl_equal(result_slice, expected_slice)

    # Select with list of indices (per tensor)
    idx_list = [0, slice(1, 3), (0, slice(None), 0)] # Different index for each tensor
    result_list = simple_tl.select(idx_list)
    expected_list = TensorList([t[i] for t, i in zip(simple_tl, idx_list)])
    assert_tl_equal(result_list, expected_list)


# --- Miscellaneous ---

def test_dot(simple_tl: TensorList):
    other_tl = simple_tl.clone().mul_(0.5)
    result = simple_tl.dot(other_tl)
    expected = (simple_tl * other_tl).global_sum()
    assert torch.allclose(result, expected)


def test_swap_tensors(simple_tl: TensorList):
    tl1 = simple_tl.clone()
    tl2 = simple_tl.clone().mul_(2)
    tl1_orig_copy = tl1.clone()
    tl2_orig_copy = tl2.clone()

    tl1.swap_tensors(tl2)

    # Check tl1 now has tl2's original data and vice versa
    assert_tl_equal(tl1, tl2_orig_copy)
    assert_tl_equal(tl2, tl1_orig_copy)


def test_unbind_channels(simple_tl: TensorList):
    # Make sure at least one tensor has >1 dim 0 size
    simple_tl[0] = torch.randn(3, 4, 5)
    simple_tl[1] = torch.randn(2, 6)
    simple_tl[2] = torch.randn(1) # Keep a 1D tensor

    unbound_tl = simple_tl.unbind_channels(dim=0)

    expected_list = []
    for t in simple_tl:
        if t.ndim >= 2:
            expected_list.extend(list(t.unbind(dim=0)))
        else:
            expected_list.append(t)
    expected_tl = TensorList(expected_list)

    assert_tl_equal(unbound_tl, expected_tl)


def test_flatiter(simple_tl: TensorList):
    iterator = simple_tl.flatiter()
    all_elements = list(iterator)
    expected_elements = list(simple_tl.to_vec())

    assert len(all_elements) == len(expected_elements)
    for el, exp_el in zip(all_elements, expected_elements):
        # flatiter yields scalar tensors
        assert isinstance(el, torch.Tensor)
        assert el.ndim == 0
        assert torch.equal(el, exp_el)


def test_repr(simple_tl: TensorList):
    representation = repr(simple_tl)
    assert representation.startswith("TensorList([")
    assert representation.endswith("])")
    # Check if tensor representations are inside
    assert "tensor(" in representation


# --- Module Level Functions ---

def test_stack(simple_tl: TensorList):
    tl1 = simple_tl.clone()
    tl2 = simple_tl.clone() * 2
    tl3 = simple_tl.clone() * 3

    stacked_tl = stack([tl1, tl2, tl3], dim=0)
    expected_tl = TensorList([torch.stack([t1, t2, t3], dim=0)
                              for t1, t2, t3 in zip(tl1, tl2, tl3)])
    assert_tl_equal(stacked_tl, expected_tl)

    stacked_tl_dim1 = stack([tl1, tl2, tl3], dim=1)
    expected_tl_dim1 = TensorList([torch.stack([t1, t2, t3], dim=1)
                                   for t1, t2, t3 in zip(tl1, tl2, tl3)])
    assert_tl_equal(stacked_tl_dim1, expected_tl_dim1)


def test_mean_median_sum_quantile_module(simple_tl: TensorList):
    tl1 = simple_tl.clone()
    tl2 = simple_tl.clone() * 2.5
    tl3 = simple_tl.clone() * -1.0
    tensors = [tl1, tl2, tl3]

    # Mean
    mean_res = mean(tensors)
    expected_mean = stack(tensors, dim=0).mean(dim=0)
    assert_tl_allclose(mean_res, expected_mean)

    # Sum
    sum_res = tl_sum(tensors)
    expected_sum = stack(tensors, dim=0).sum(dim=0)
    assert_tl_allclose(sum_res, expected_sum)

    # Median
    median_res = median(tensors)
    # Stack and get median values (result is named tuple)
    expected_median_vals = stack(tensors, dim=0).median(dim=0)
    expected_median = TensorList(expected_median_vals)
    assert_tl_allclose(median_res, expected_median)

    # Quantile
    q = 0.25
    quantile_res = quantile(tensors, q=q)
    expected_quantile_vals = stack(tensors, dim=0).quantile(q=q, dim=0)
    expected_quantile = TensorList(list(expected_quantile_vals))
    assert_tl_allclose(quantile_res, expected_quantile)


# --- Test _MethodCallerWithArgs ---
def test_method_caller_with_args():
    caller = _MethodCallerWithArgs('add')
    t = torch.tensor([1, 2])
    result = caller(t, 5) # t.add(5)
    assert torch.equal(result, torch.tensor([6, 7]))

    result_kw = caller(t, other=10, alpha=2) # t.add(other=10, alpha=2)
    assert torch.equal(result_kw, torch.tensor([21, 22])) # 1 + 2*10, 2 + 2*10


# --- Test generic_clamp ---
def test_generic_clamp():
    assert generic_clamp(5, min=0, max=10) == 5
    assert generic_clamp(-5, min=0, max=10) == 0
    assert generic_clamp(15, min=0, max=10) == 10
    assert generic_clamp(torch.tensor([-5, 5, 15]), min=0, max=10).equal(torch.tensor([0, 5, 10]))

    tl = TensorList([torch.tensor([-5, 5, 15]), torch.tensor([1, 12])])
    clamped_tl = generic_clamp(tl, min=0, max=10)
    expected_tl = TensorList([torch.tensor([0, 5, 10]), torch.tensor([1, 10])])
    assert_tl_equal(clamped_tl, expected_tl)