pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
- pytme-0.2.4.data/scripts/preprocess.py +148 -0
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +97 -148
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +116 -61
- scripts/preprocessor_gui.py +15 -23
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +35 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +26 -12
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +33 -24
- tme/matching_exhaustive.py +39 -20
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +8 -2
- tme/orientations.py +26 -9
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +5 -4
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +210 -148
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.2.dist-info/RECORD +0 -74
- tme/matching_memory.py +0 -383
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
File without changes
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from tme.preprocessing import Compose
|
4
|
+
from tme.backends import backend as be
|
5
|
+
|
6
|
+
|
7
|
+
def mock_transform1(**kwargs):
    """Fake filter for Compose tests.

    Ignores ``kwargs`` and returns a 10x10 array of ones, flagged as a
    multiplicative filter.
    """
    payload = be.ones((10, 10))
    return dict(data=payload, is_multiplicative_filter=True)
|
9
|
+
|
10
|
+
|
11
|
+
def mock_transform2(**kwargs):
    """Fake filter for Compose tests.

    Ignores ``kwargs`` and returns a 10x10 array filled with 2, flagged as a
    multiplicative filter.
    """
    payload = be.ones((10, 10)) * 2
    return dict(data=payload, is_multiplicative_filter=True)
|
13
|
+
|
14
|
+
|
15
|
+
def mock_transform3(**kwargs):
    """Fake transform for Compose tests that returns metadata only (no ``data`` key)."""
    return dict(extra_info="test")
|
17
|
+
|
18
|
+
|
19
|
+
class TestCompose:
    """Tests for ``Compose``, which chains filter callables into one pipeline."""

    @pytest.fixture
    def compose_instance(self):
        # Two data-producing multiplicative transforms followed by one that
        # only returns metadata.
        pipeline = (mock_transform1, mock_transform2, mock_transform3)
        return Compose(pipeline)

    def test_init(self):
        pipeline = (mock_transform1, mock_transform2)
        assert Compose(pipeline).transforms == pipeline

    def test_call_empty_transforms(self):
        # An empty pipeline is expected to yield an empty result dict.
        assert Compose(())() == {}

    def test_call_single_transform(self):
        output = Compose((mock_transform1,))()
        assert "data" in output
        assert output.get("is_multiplicative_filter", False)
        assert be.allclose(output["data"], be.ones((10, 10)))

    def test_call_multiple_transforms(self, compose_instance):
        output = compose_instance()
        assert "data" in output
        # Metadata-only results from the last transform must not leak through.
        assert "extra_info" not in output
        expected = be.ones((10, 10)) * 2
        assert be.allclose(output["data"], expected)

    def test_multiplicative_filter_composition(self):
        output = Compose((mock_transform1, mock_transform2))()
        assert "data" in output
        # ones * (2 * ones) -> an array of twos.
        expected = be.ones((10, 10)) * 2
        assert be.allclose(output["data"], expected)

    @pytest.mark.parametrize(
        "kwargs", [{}, {"extra_param": "test"}, {"data": be.zeros((5, 5))}]
    )
    def test_call_with_kwargs(self, compose_instance, kwargs):
        output = compose_instance(**kwargs)
        assert "data" in output
        assert "extra_info" not in output

    def test_non_multiplicative_filter(self):
        def plain_transform(**kwargs):
            return {"data": be.ones((10, 10)) * 3, "is_multiplicative_filter": False}

        output = Compose((mock_transform1, plain_transform))()
        assert "data" in output
        # The test expects the non-multiplicative result to be taken as-is.
        assert be.allclose(output["data"], be.ones((10, 10)) * 3)

    def test_error_handling(self):
        def failing_transform(**kwargs):
            raise ValueError("Test error")

        pipeline = Compose((mock_transform1, failing_transform))
        # Exceptions raised by a transform must propagate out of the pipeline.
        with pytest.raises(ValueError, match="Test error"):
            pipeline()
|
@@ -0,0 +1,178 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
from typing import Tuple
|
4
|
+
|
5
|
+
from tme.backends import backend as be
|
6
|
+
from tme.preprocessing._utils import compute_fourier_shape
|
7
|
+
from tme.preprocessing.frequency_filters import BandPassFilter, LinearWhiteningFilter
|
8
|
+
|
9
|
+
|
10
|
+
class TestBandPassFilter:
    """Tests for ``BandPassFilter`` and its discrete / Gaussian mask builders."""

    # (shape, lowpass, highpass, sampling_rate) cases shared by the mask
    # builder tests. ``None`` disables the corresponding cutoff.
    _BANDPASS_CASES = [
        ((10, 10), 0.2, 0.8, 1),
        ((20, 20, 20), 0.1, 0.9, 2),
        ((30, 30), None, 0.5, 1),
        ((40, 40), 0.3, None, 0.5),
    ]

    @pytest.fixture
    def band_pass_filter(self):
        # Default-constructed filter; tests mutate attributes as needed.
        return BandPassFilter()

    @pytest.mark.parametrize(
        "shape, lowpass, highpass, sampling_rate", _BANDPASS_CASES
    )
    def test_discrete_bandpass(
        self,
        shape: Tuple[int, ...],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
    ):
        """Discrete mask matches the requested shape and lies within [0, 1]."""
        result = BandPassFilter.discrete_bandpass(
            shape, lowpass, highpass, sampling_rate
        )
        assert isinstance(result, type(be.ones((1,))))
        assert result.shape == shape
        assert np.all((result >= 0) & (result <= 1))

    @pytest.mark.parametrize(
        "shape, lowpass, highpass, sampling_rate", _BANDPASS_CASES
    )
    def test_gaussian_bandpass(
        self,
        shape: Tuple[int, ...],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
    ):
        """Gaussian mask matches the requested shape and lies within [0, 1]."""
        result = BandPassFilter.gaussian_bandpass(
            shape, lowpass, highpass, sampling_rate
        )
        assert isinstance(result, type(be.ones((1,))))
        assert result.shape == shape
        assert np.all((result >= 0) & (result <= 1))

    @pytest.mark.parametrize("use_gaussian", [True, False])
    @pytest.mark.parametrize("return_real_fourier", [True, False])
    @pytest.mark.parametrize("shape_is_real_fourier", [True, False])
    def test_call_method(
        self,
        band_pass_filter: BandPassFilter,
        use_gaussian: bool,
        return_real_fourier: bool,
        shape_is_real_fourier: bool,
    ):
        """Calling the filter returns the standard filter dict for every flag combination."""
        band_pass_filter.use_gaussian = use_gaussian
        band_pass_filter.return_real_fourier = return_real_fourier
        band_pass_filter.shape_is_real_fourier = shape_is_real_fourier

        result = band_pass_filter(shape=(10, 10), lowpass=0.2, highpass=0.8)

        assert isinstance(result, dict)
        assert "data" in result
        assert "sampling_rate" in result
        assert "is_multiplicative_filter" in result
        assert isinstance(result["data"], type(be.ones((1,))))
        assert result["is_multiplicative_filter"] is True

    def test_default_values(self, band_pass_filter: BandPassFilter):
        """A default-constructed filter exposes the documented defaults."""
        assert band_pass_filter.lowpass is None
        assert band_pass_filter.highpass is None
        assert band_pass_filter.sampling_rate == 1
        assert band_pass_filter.use_gaussian is True
        assert band_pass_filter.return_real_fourier is False
        assert band_pass_filter.shape_is_real_fourier is False

    @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
    def test_return_real_fourier(self, shape: Tuple[int, ...]):
        """With ``return_real_fourier=True`` the mask has the rfft output shape."""
        bpf = BandPassFilter(return_real_fourier=True)
        result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
        expected_shape = tuple(compute_fourier_shape(shape, False))
        assert result["data"].shape == expected_shape
|
90
|
+
|
91
|
+
|
92
|
+
class TestLinearWhiteningFilter:
    """Tests for ``LinearWhiteningFilter`` spectrum estimation and filtering."""

    @pytest.mark.parametrize(
        "shape, n_bins, batch_dimension",
        [
            ((10, 10), None, None),
            ((20, 20, 20), 15, 0),
            ((30, 30, 30), 20, 1),
            ((40, 40, 40, 40), 25, 2),
        ],
    )
    def test_compute_spectrum(
        self, shape: Tuple[int, ...], n_bins: int, batch_dimension: int
    ):
        """Radial spectrum is 1D in [0, 1]; bins match the non-batch rfft shape.

        ``n_bins`` and ``batch_dimension`` may be ``None`` in the parameter set.
        """
        data_rfft = be.fft.rfftn(be.random.random(shape))
        bins, radial_averages = LinearWhiteningFilter._compute_spectrum(
            data_rfft, n_bins, batch_dimension
        )
        # Expected bin-index shape: the rfft shape with the batch axis removed.
        data_shape = tuple(
            int(x) for i, x in enumerate(data_rfft.shape) if i != batch_dimension
        )

        assert isinstance(bins, np.ndarray)
        assert isinstance(radial_averages, np.ndarray)
        assert bins.shape == data_shape
        assert radial_averages.ndim == 1
        assert np.all(radial_averages >= 0) and np.all(radial_averages <= 1)

    @pytest.mark.parametrize("shape", ((10, 10), (21, 20, 31)))
    @pytest.mark.parametrize("shape_is_real_fourier", (False, True))
    @pytest.mark.parametrize("order", (1, 3))
    def test_interpolate_spectrum(
        self, shape: Tuple[int, ...], shape_is_real_fourier: bool, order: int
    ):
        """A 1D spectrum is interpolated onto a grid of exactly ``shape``."""
        spectrum = be.random.random(100)
        result = LinearWhiteningFilter()._interpolate_spectrum(
            spectrum, shape, shape_is_real_fourier, order
        )
        assert result.shape == tuple(shape)
        assert isinstance(result, np.ndarray)

    @pytest.mark.parametrize(
        "shape, n_bins, batch_dimension, order",
        [
            ((10, 10), None, None, 1),
            ((20, 20, 20), 15, 0, 2),
            ((30, 30, 30), 20, 1, None),
        ],
    )
    def test_call_method(
        self,
        shape: Tuple[int, ...],
        n_bins: int,
        batch_dimension: int,
        order: int,
    ):
        """Filtering real-space data yields a multiplicative rfft-shaped mask."""
        data = be.random.random(shape)
        result = LinearWhiteningFilter()(
            data=data, n_bins=n_bins, batch_dimension=batch_dimension, order=order
        )

        assert isinstance(result, dict)
        assert result.get("data", False) is not False
        assert result.get("is_multiplicative_filter", False)
        assert isinstance(result["data"], type(be.ones((1,))))
        # The mask covers the non-batch axes in half-spectrum (rfft) layout.
        data_shape = tuple(
            int(x) for i, x in enumerate(data.shape) if i != batch_dimension
        )
        assert result["data"].shape == tuple(compute_fourier_shape(data_shape, False))

    def test_call_method_with_data_rfft(self):
        """Passing a precomputed rfft yields a mask with the same shape."""
        shape = (30, 30, 30)
        data_rfft = be.fft.rfftn(be.random.random(shape))
        result = LinearWhiteningFilter()(data_rfft=data_rfft)

        assert isinstance(result, dict)
        assert result.get("data", False) is not False
        assert result.get("is_multiplicative_filter", False)
        assert isinstance(result["data"], type(be.ones((1,))))
        assert result["data"].shape == data_rfft.shape

    @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
    def test_filter_mask_range(self, shape: Tuple[int, ...]):
        """Whitening mask values lie within [0, 1]."""
        data = be.random.random(shape)
        result = LinearWhiteningFilter()(data=data)

        filter_mask = result["data"]
        assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
|
@@ -0,0 +1,136 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
from tme import Density, Structure, Preprocessor
|
5
|
+
|
6
|
+
|
7
|
+
class TestPreprocessor:
    """Smoke tests for the filters exposed by ``Preprocessor``.

    Most tests only check that a filter runs without raising on a density
    built from ``5khe.cif``; outputs are not inspected. Data paths are relative
    to the working directory, so the suite must run from the repository root.
    """

    def setup_method(self):
        self.density = Density.from_file(filename="tests/data/Raw/em_map.map")
        self.structure = Structure.from_file("tests/data/Structures/5khe.cif")
        # Place the structure on the same grid (origin, shape, sampling) as the
        # map so template/target filters receive compatible arrays.
        self.structure_density = Density.from_structure(
            filename_or_structure="tests/data/Structures/5khe.cif",
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )
        self.preprocessor = Preprocessor()

    def teardown_method(self):
        # Release every attribute created in setup_method so loaded data is not
        # kept alive on the instance after the test finishes.
        self.density = None
        self.structure = None
        self.structure_density = None
        self.preprocessor = None

    def test_initialization(self):
        _ = Preprocessor()

    def test_apply_method_error(self):
        # Non-string method names are rejected with TypeError ...
        with pytest.raises(TypeError):
            self.preprocessor.apply_method(method=None, parameters={})

        # ... and unknown method names with NotImplementedError.
        with pytest.raises(NotImplementedError):
            self.preprocessor.apply_method(method="None", parameters={})

    def test_method_to_id_error(self):
        with pytest.raises(TypeError):
            self.preprocessor.method_to_id(method=None, parameters={})

        with pytest.raises(NotImplementedError):
            self.preprocessor.method_to_id(method="None", parameters={})

    def test_method_to_id(self):
        ret = self.preprocessor.method_to_id(method="gaussian_filter", parameters={})
        assert isinstance(ret, str)

    @pytest.mark.parametrize("low_sigma,high_sigma", [(0, 1), (3, 5)])
    def test_difference_of_gaussian_filter(self, low_sigma, high_sigma):
        _ = self.preprocessor.difference_of_gaussian_filter(
            template=self.structure_density.data,
            low_sigma=low_sigma,
            high_sigma=high_sigma,
        )

    @pytest.mark.parametrize("smallest_size,largest_size", [(1, 10), (2, 20)])
    def test_bandpass_filter(self, smallest_size, largest_size):
        _ = self.preprocessor.bandpass_filter(
            template=self.structure_density.data,
            lowpass=smallest_size,
            highpass=largest_size,
            sampling_rate=1,
        )

    @pytest.mark.parametrize("lbd,sigma_range", [(1, (2, 4)), (20, (1, 6))])
    def test_local_gaussian_alignment_filter(self, lbd, sigma_range):
        _ = self.preprocessor.local_gaussian_alignment_filter(
            template=self.structure_density.data,
            target=self.density.data,
            lbd=lbd,
            sigma_range=sigma_range,
        )

    @pytest.mark.parametrize(
        "lbd,sigma_range,gaussian_sigma", [(1, (2, 4), 1), (20, (1, 6), 3)]
    )
    def test_local_gaussian_filter(self, lbd, sigma_range, gaussian_sigma):
        _ = self.preprocessor.local_gaussian_filter(
            template=self.structure_density.data,
            lbd=lbd,
            sigma_range=sigma_range,
            gaussian_sigma=gaussian_sigma,
        )

    @pytest.mark.parametrize(
        "edge_algorithm",
        ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"],
    )
    @pytest.mark.parametrize("reverse", [True, False])
    def test_edge_gaussian_filter(self, edge_algorithm, reverse):
        _ = self.preprocessor.edge_gaussian_filter(
            template=self.structure_density.data,
            edge_algorithm=edge_algorithm,
            reverse=reverse,
            sigma=3,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_mean_filter(self, width):
        _ = self.preprocessor.mean_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_kaiserb_filter(self, width):
        _ = self.preprocessor.kaiserb_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_blob_filter(self, width):
        _ = self.preprocessor.blob_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_hamming_filter(self, width):
        _ = self.preprocessor.hamming_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("rank", range(1, 9, 3))
    def test_rank_filter(self, rank):
        _ = self.preprocessor.rank_filter(
            template=self.structure_density.data,
            rank=rank,
        )

    @pytest.mark.parametrize("infinite_plane", [False, True])
    def test_continuous_wedge_mask(self, infinite_plane):
        _ = self.preprocessor.continuous_wedge_mask(
            start_tilt=50,
            stop_tilt=-40,
            shape=(50, 50, 50),
            infinite_plane=infinite_plane,
        )
|
@@ -0,0 +1,79 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
from tme.preprocessing._utils import (
|
5
|
+
fftfreqn,
|
6
|
+
centered_grid,
|
7
|
+
shift_fourier,
|
8
|
+
compute_fourier_shape,
|
9
|
+
crop_real_fourier,
|
10
|
+
compute_tilt_shape,
|
11
|
+
frequency_grid_at_angle,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class TestPreprocessUtils:
    """Tests for the grid and Fourier-shape helpers in ``tme.preprocessing._utils``."""

    @pytest.mark.parametrize("reduce_dim", (False, True))
    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_compute_tilt_shape(self, shape, reduce_dim):
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=0, reduce_dim=reduce_dim
        )
        if not reduce_dim:
            # Rank is preserved, but the opening axis (0) is expected to be 1.
            assert len(tilt_shape) == len(shape)
            assert tilt_shape[0] == 1
            return
        assert len(tilt_shape) == len(shape) - 1

    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_centered_grid(self, shape):
        grid = centered_grid(shape=shape)
        # One coordinate array per input axis.
        assert grid.shape[0] == len(shape)

        midpoint = tuple(int(size) // 2 for size in shape)
        for coords, mid in zip(grid, midpoint):
            assert coords[midpoint] == 0
            assert np.max(coords) <= mid

    @pytest.mark.parametrize("shape", ((10, 15, 30),))
    @pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
    @pytest.mark.parametrize("angle", (-5, 0, 5))
    @pytest.mark.parametrize("wedge", ((0, 1), (1, 0)))
    def test_frequency_grid_at_angle(self, shape, sampling_rate, angle, wedge):
        opening_axis, tilt_axis = wedge
        fgrid = frequency_grid_at_angle(
            shape=shape,
            angle=angle,
            sampling_rate=sampling_rate,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
        )
        expected_shape = compute_tilt_shape(
            shape, opening_axis=opening_axis, reduce_dim=True
        )
        assert fgrid.shape == tuple(expected_shape)

        upper_bound = np.sqrt(1 / sampling_rate * len(shape))
        assert fgrid.max() <= upper_bound

    @pytest.mark.parametrize("n", [10, 100, 1000])
    @pytest.mark.parametrize("sampling_rate", range(1, 4))
    def test_fftfreqn(self, n, sampling_rate):
        computed = fftfreqn(
            shape=(n,), sampling_rate=sampling_rate, compute_euclidean_norm=True
        )
        # Reference: absolute, centered 1D NumPy frequencies.
        reference = np.abs(np.fft.ifftshift(np.fft.fftfreq(n=n, d=sampling_rate)))
        assert np.allclose(computed, reference)

    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_crop_real_fourier(self, shape):
        data = np.random.rand(*shape)
        expected = tuple(compute_fourier_shape(data.shape, False))
        assert crop_real_fourier(data).shape == expected

    @pytest.mark.parametrize("real", (False, True))
    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_compute_fourier_shape(self, shape, real: bool):
        data = np.random.rand(*shape)
        transform = np.fft.rfftn if real else np.fft.fftn
        expected = tuple(compute_fourier_shape(data.shape, not real))
        assert transform(data).shape == expected

    def test_shift_fourier(self):
        signal = np.random.rand(10)
        assert np.allclose(shift_fourier(signal, False), np.fft.ifftshift(signal))
|