pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/METADATA +1 -1
  5. pytme-0.2.5.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +226 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +3 -9
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +14 -10
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +14 -12
  61. tme/matching_exhaustive.py +17 -15
  62. tme/matching_optimization.py +215 -208
  63. tme/matching_utils.py +1 -0
  64. tme/preprocessing/_utils.py +14 -14
  65. tme/preprocessing/composable_filter.py +0 -2
  66. tme/preprocessing/compose.py +4 -4
  67. tme/preprocessing/frequency_filters.py +32 -35
  68. tme/preprocessing/tilt_series.py +198 -117
  69. tme/preprocessor.py +24 -246
  70. tme/structure.py +22 -22
  71. pytme-0.2.3.dist-info/RECORD +0 -75
  72. tme/matching_memory.py +0 -383
  73. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/estimate_ram_usage.py +0 -0
  74. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/postprocess.py +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/LICENSE +0 -0
  76. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/entry_points.txt +0 -0
File without changes
@@ -0,0 +1,76 @@
1
+ import pytest
2
+
3
+ from tme.preprocessing import Compose
4
+ from tme.backends import backend as be
5
+
6
+
7
def mock_transform1(**kwargs):
    """Mock transform: ignore ``kwargs`` and return a (10, 10) array of ones.

    The returned dict flags the array as a multiplicative filter.
    """
    ones = be.ones((10, 10))
    return dict(data=ones, is_multiplicative_filter=True)
9
+
10
+
11
def mock_transform2(**kwargs):
    """Mock transform: ignore ``kwargs`` and return a (10, 10) array of twos.

    The returned dict flags the array as a multiplicative filter.
    """
    twos = be.ones((10, 10)) * 2
    return dict(data=twos, is_multiplicative_filter=True)
13
+
14
+
15
def mock_transform3(**kwargs):
    """Mock transform: ignore ``kwargs`` and return only an ``extra_info`` entry."""
    return dict(extra_info="test")
17
+
18
+
19
class TestCompose:
    """Tests for ``Compose`` driven by the module-level mock transforms."""

    @pytest.fixture
    def compose_instance(self):
        """Provide a ``Compose`` built from all three mock transforms."""
        pipeline = (mock_transform1, mock_transform2, mock_transform3)
        return Compose(pipeline)

    def test_init(self):
        """``transforms`` equals the tuple handed to the constructor."""
        expected = (mock_transform1, mock_transform2)
        assert Compose(expected).transforms == expected

    def test_call_empty_transforms(self):
        """Calling a ``Compose`` without transforms yields an empty dict."""
        assert Compose(())() == {}

    def test_call_single_transform(self):
        """A single transform's data and multiplicative flag are returned."""
        output = Compose((mock_transform1,))()
        assert "data" in output
        assert output.get("is_multiplicative_filter", False)
        expected = be.ones((10, 10))
        assert be.allclose(output["data"], expected)

    def test_call_multiple_transforms(self, compose_instance):
        """The three-transform result has data of all twos and no ``extra_info``."""
        output = compose_instance()
        assert "data" in output
        assert "extra_info" not in output
        expected = be.ones((10, 10)) * 2
        assert be.allclose(output["data"], expected)

    def test_multiplicative_filter_composition(self):
        """Two multiplicative mocks (ones and twos) yield data of all twos."""
        output = Compose((mock_transform1, mock_transform2))()
        assert "data" in output
        expected = be.ones((10, 10)) * 2
        assert be.allclose(output["data"], expected)

    @pytest.mark.parametrize(
        "kwargs",
        [
            {},
            {"extra_param": "test"},
            {"data": be.zeros((5, 5))},
        ],
    )
    def test_call_with_kwargs(self, compose_instance, kwargs):
        """Keyword arguments at call time leave ``data`` in and ``extra_info`` out."""
        output = compose_instance(**kwargs)
        assert "data" in output
        assert "extra_info" not in output

    def test_non_multiplicative_filter(self):
        """A trailing non-multiplicative transform's data (all threes) is the result."""

        def threes_transform(**kwargs):
            return {"data": be.ones((10, 10)) * 3, "is_multiplicative_filter": False}

        output = Compose((mock_transform1, threes_transform))()
        assert "data" in output
        expected = be.ones((10, 10)) * 3
        assert be.allclose(output["data"], expected)

    def test_error_handling(self):
        """A ``ValueError`` raised inside a transform reaches the caller."""

        def failing_transform(**kwargs):
            raise ValueError("Test error")

        pipeline = Compose((mock_transform1, failing_transform))
        with pytest.raises(ValueError, match="Test error"):
            pipeline()
@@ -0,0 +1,178 @@
1
+ import pytest
2
+ import numpy as np
3
+ from typing import Tuple
4
+
5
+ from tme.backends import backend as be
6
+ from tme.preprocessing._utils import compute_fourier_shape
7
+ from tme.preprocessing.frequency_filters import BandPassFilter, LinearWhiteningFilter
8
+
9
+
10
class TestBandPassFilter:
    """Tests for ``BandPassFilter`` mask construction and its call interface."""

    # (shape, lowpass, highpass, sampling_rate) cases shared by the discrete and
    # Gaussian mask tests; the last two cases pass ``None`` for one cutoff.
    _BANDPASS_CASES = [
        ((10, 10), 0.2, 0.8, 1),
        ((20, 20, 20), 0.1, 0.9, 2),
        ((30, 30), None, 0.5, 1),
        ((40, 40), 0.3, None, 0.5),
    ]

    @staticmethod
    def _assert_valid_mask(mask, shape):
        """Assert ``mask`` is a backend array of ``shape`` with values in [0, 1]."""
        assert isinstance(mask, type(be.ones((1,))))
        assert mask.shape == shape
        assert np.all((mask >= 0) & (mask <= 1))

    @pytest.fixture
    def band_pass_filter(self):
        """Provide a ``BandPassFilter`` constructed with default arguments."""
        return BandPassFilter()

    @pytest.mark.parametrize(
        "shape, lowpass, highpass, sampling_rate", _BANDPASS_CASES
    )
    def test_discrete_bandpass(
        self,
        shape: Tuple[int, ...],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
    ):
        """``discrete_bandpass`` returns a mask of ``shape`` bounded by 0 and 1."""
        mask = BandPassFilter.discrete_bandpass(
            shape, lowpass, highpass, sampling_rate
        )
        self._assert_valid_mask(mask, shape)

    @pytest.mark.parametrize(
        "shape, lowpass, highpass, sampling_rate", _BANDPASS_CASES
    )
    def test_gaussian_bandpass(
        self,
        shape: Tuple[int, ...],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
    ):
        """``gaussian_bandpass`` returns a mask of ``shape`` bounded by 0 and 1."""
        mask = BandPassFilter.gaussian_bandpass(
            shape, lowpass, highpass, sampling_rate
        )
        self._assert_valid_mask(mask, shape)

    @pytest.mark.parametrize("use_gaussian", [True, False])
    @pytest.mark.parametrize("return_real_fourier", [True, False])
    @pytest.mark.parametrize("shape_is_real_fourier", [True, False])
    def test_call_method(
        self,
        band_pass_filter: BandPassFilter,
        use_gaussian: bool,
        return_real_fourier: bool,
        shape_is_real_fourier: bool,
    ):
        """Calling the filter returns a dict with the expected keys and types."""
        band_pass_filter.use_gaussian = use_gaussian
        band_pass_filter.return_real_fourier = return_real_fourier
        band_pass_filter.shape_is_real_fourier = shape_is_real_fourier

        result = band_pass_filter(shape=(10, 10), lowpass=0.2, highpass=0.8)

        assert isinstance(result, dict)
        for key in ("data", "sampling_rate", "is_multiplicative_filter"):
            assert key in result
        assert isinstance(result["data"], type(be.ones((1,))))
        assert result["is_multiplicative_filter"] is True

    def test_default_values(self, band_pass_filter: BandPassFilter):
        """A default-constructed filter exposes the expected attribute defaults."""
        assert band_pass_filter.lowpass is None
        assert band_pass_filter.highpass is None
        assert band_pass_filter.sampling_rate == 1
        assert band_pass_filter.use_gaussian is True
        assert band_pass_filter.return_real_fourier is False
        assert band_pass_filter.shape_is_real_fourier is False

    @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
    def test_return_real_fourier(self, shape: Tuple[int, ...]):
        """With ``return_real_fourier=True`` the data shape matches
        ``compute_fourier_shape(shape, False)``."""
        bpf = BandPassFilter(return_real_fourier=True)
        result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
        expected_shape = tuple(compute_fourier_shape(shape, False))
        assert result["data"].shape == expected_shape
90
+
91
+
92
class TestLinearWhiteningFilter:
    """Tests for ``LinearWhiteningFilter`` helpers and its call interface."""

    @pytest.mark.parametrize(
        "shape, n_bins, batch_dimension",
        [
            ((10, 10), None, None),
            ((20, 20, 20), 15, 0),
            ((30, 30, 30), 20, 1),
            ((40, 40, 40, 40), 25, 2),
        ],
    )
    def test_compute_spectrum(
        self, shape: Tuple[int, ...], n_bins: int, batch_dimension: int
    ):
        """``_compute_spectrum`` returns per-voxel bins and 1D averages in [0, 1]."""
        data_rfft = be.fft.rfftn(be.random.random(shape))
        bins, radial_averages = LinearWhiteningFilter._compute_spectrum(
            data_rfft, n_bins, batch_dimension
        )
        # Expected bin shape: the transform's shape without the batch axis.
        data_shape = tuple(
            int(x) for i, x in enumerate(data_rfft.shape) if i != batch_dimension
        )

        assert isinstance(bins, np.ndarray)
        assert isinstance(radial_averages, np.ndarray)
        assert bins.shape == data_shape
        assert radial_averages.ndim == 1
        assert np.all(radial_averages >= 0) and np.all(radial_averages <= 1)

    @pytest.mark.parametrize("shape", ((10, 10), (21, 20, 31)))
    @pytest.mark.parametrize("shape_is_real_fourier", (False, True))
    @pytest.mark.parametrize("order", (1, 3))
    def test_interpolate_spectrum(
        self, shape: Tuple[int, ...], shape_is_real_fourier: bool, order: int
    ):
        """``_interpolate_spectrum`` returns a NumPy array of the requested shape."""
        spectrum = be.random.random(100)
        result = LinearWhiteningFilter()._interpolate_spectrum(
            spectrum, shape, shape_is_real_fourier, order
        )
        assert result.shape == tuple(shape)
        assert isinstance(result, np.ndarray)

    @pytest.mark.parametrize(
        "shape, n_bins, batch_dimension, order",
        [
            ((10, 10), None, None, 1),
            ((20, 20, 20), 15, 0, 2),
            ((30, 30, 30), 20, 1, None),
        ],
    )
    def test_call_method(
        self,
        shape: Tuple[int, ...],
        n_bins: int,
        batch_dimension: int,
        order: int,
    ):
        """Calling the filter on real-space data returns a multiplicative mask
        shaped like ``compute_fourier_shape`` of the non-batch axes."""
        data = be.random.random(shape)
        result = LinearWhiteningFilter()(
            data=data, n_bins=n_bins, batch_dimension=batch_dimension, order=order
        )

        assert isinstance(result, dict)
        assert result.get("data", False) is not False
        assert result.get("is_multiplicative_filter", False)
        assert isinstance(result["data"], type(be.ones((1,))))
        # The batch axis is excluded before computing the expected Fourier shape.
        data_shape = tuple(
            int(x) for i, x in enumerate(data.shape) if i != batch_dimension
        )
        assert result["data"].shape == tuple(compute_fourier_shape(data_shape, False))

    def test_call_method_with_data_rfft(self):
        """Passing ``data_rfft`` yields a mask with the same shape as the input."""
        shape = (30, 30, 30)
        data_rfft = be.fft.rfftn(be.random.random(shape))
        result = LinearWhiteningFilter()(data_rfft=data_rfft)

        assert isinstance(result, dict)
        assert result.get("data", False) is not False
        assert result.get("is_multiplicative_filter", False)
        assert isinstance(result["data"], type(be.ones((1,))))
        assert result["data"].shape == data_rfft.shape

    @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
    def test_filter_mask_range(self, shape: Tuple[int, ...]):
        """All values of the returned filter mask lie in [0, 1]."""
        data = be.random.random(shape)
        result = LinearWhiteningFilter()(data=data)

        filter_mask = result["data"]
        assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
@@ -0,0 +1,136 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from tme import Density, Structure, Preprocessor
5
+
6
+
7
class TestPreprocessor:
    """Smoke tests for ``Preprocessor``.

    Apart from the error and ``method_to_id`` tests, each test only checks
    that the call completes without raising.
    """

    def setup_method(self):
        """Load the test map and structure and create a ``Preprocessor``."""
        structure_path = "tests/data/Structures/5khe.cif"
        self.density = Density.from_file(filename="tests/data/Raw/em_map.map")
        self.structure = Structure.from_file(structure_path)
        # Sample the structure on the same grid as the loaded density.
        self.structure_density = Density.from_structure(
            filename_or_structure=structure_path,
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )
        self.preprocessor = Preprocessor()

    def teardown_method(self):
        """Drop every object created in ``setup_method``."""
        self.density = None
        self.structure = None
        self.structure_density = None
        self.preprocessor = None

    def test_initialization(self):
        """``Preprocessor`` can be constructed without arguments."""
        Preprocessor()

    def test_apply_method_error(self):
        """``apply_method`` raises ``TypeError`` for ``None`` and
        ``NotImplementedError`` for the string ``"None"``."""
        with pytest.raises(TypeError):
            self.preprocessor.apply_method(method=None, parameters={})

        with pytest.raises(NotImplementedError):
            self.preprocessor.apply_method(method="None", parameters={})

    def test_method_to_id_error(self):
        """``method_to_id`` raises ``TypeError`` for ``None`` and
        ``NotImplementedError`` for the string ``"None"``."""
        with pytest.raises(TypeError):
            self.preprocessor.method_to_id(method=None, parameters={})

        with pytest.raises(NotImplementedError):
            self.preprocessor.method_to_id(method="None", parameters={})

    def test_method_to_id(self):
        """``method_to_id`` returns a string for ``gaussian_filter``."""
        ret = self.preprocessor.method_to_id(method="gaussian_filter", parameters={})
        assert isinstance(ret, str)

    @pytest.mark.parametrize("low_sigma,high_sigma", [(0, 1), (3, 5)])
    def test_difference_of_gaussian_filter(self, low_sigma, high_sigma):
        """``difference_of_gaussian_filter`` runs on the structure density."""
        self.preprocessor.difference_of_gaussian_filter(
            template=self.structure_density.data,
            low_sigma=low_sigma,
            high_sigma=high_sigma,
        )

    @pytest.mark.parametrize("smallest_size,largest_size", [(1, 10), (2, 20)])
    def test_bandpass_filter(self, smallest_size, largest_size):
        """``bandpass_filter`` runs on the structure density."""
        self.preprocessor.bandpass_filter(
            template=self.structure_density.data,
            lowpass=smallest_size,
            highpass=largest_size,
            sampling_rate=1,
        )

    @pytest.mark.parametrize("lbd,sigma_range", [(1, (2, 4)), (20, (1, 6))])
    def test_local_gaussian_alignment_filter(self, lbd, sigma_range):
        """``local_gaussian_alignment_filter`` runs with template and target."""
        self.preprocessor.local_gaussian_alignment_filter(
            template=self.structure_density.data,
            target=self.density.data,
            lbd=lbd,
            sigma_range=sigma_range,
        )

    @pytest.mark.parametrize(
        "lbd,sigma_range,gaussian_sigma", [(1, (2, 4), 1), (20, (1, 6), 3)]
    )
    def test_local_gaussian_filter(self, lbd, sigma_range, gaussian_sigma):
        """``local_gaussian_filter`` runs on the structure density."""
        self.preprocessor.local_gaussian_filter(
            template=self.structure_density.data,
            lbd=lbd,
            sigma_range=sigma_range,
            gaussian_sigma=gaussian_sigma,
        )

    @pytest.mark.parametrize(
        "edge_algorithm",
        ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"],
    )
    @pytest.mark.parametrize("reverse", [True, False])
    def test_edge_gaussian_filter(self, edge_algorithm, reverse):
        """``edge_gaussian_filter`` runs for each listed edge algorithm."""
        self.preprocessor.edge_gaussian_filter(
            template=self.structure_density.data,
            edge_algorithm=edge_algorithm,
            reverse=reverse,
            sigma=3,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_mean_filter(self, width):
        """``mean_filter`` runs for widths 1, 4 and 7."""
        self.preprocessor.mean_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_kaiserb_filter(self, width):
        """``kaiserb_filter`` runs for widths 1, 4 and 7."""
        self.preprocessor.kaiserb_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_blob_filter(self, width):
        """``blob_filter`` runs for widths 1, 4 and 7."""
        self.preprocessor.blob_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("width", range(1, 9, 3))
    def test_hamming_filter(self, width):
        """``hamming_filter`` runs for widths 1, 4 and 7."""
        self.preprocessor.hamming_filter(
            template=self.structure_density.data,
            width=width,
        )

    @pytest.mark.parametrize("rank", range(1, 9, 3))
    def test_rank_filter(self, rank):
        """``rank_filter`` runs for ranks 1, 4 and 7."""
        self.preprocessor.rank_filter(
            template=self.structure_density.data,
            rank=rank,
        )

    @pytest.mark.parametrize("infinite_plane", [False, True])
    def test_continuous_wedge_mask(self, infinite_plane):
        """``continuous_wedge_mask`` runs for a (50, 50, 50) shape."""
        self.preprocessor.continuous_wedge_mask(
            start_tilt=50,
            stop_tilt=-40,
            shape=(50, 50, 50),
            infinite_plane=infinite_plane,
        )
@@ -0,0 +1,79 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from tme.preprocessing._utils import (
5
+ fftfreqn,
6
+ centered_grid,
7
+ shift_fourier,
8
+ compute_fourier_shape,
9
+ crop_real_fourier,
10
+ compute_tilt_shape,
11
+ frequency_grid_at_angle,
12
+ )
13
+
14
+
15
class TestPreprocessUtils:
    """Tests for the helper functions in ``tme.preprocessing._utils``."""

    @pytest.mark.parametrize("reduce_dim", (False, True))
    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_compute_tilt_shape(self, shape, reduce_dim):
        """With ``reduce_dim`` one dimension is dropped; otherwise axis 0 has size 1."""
        result = compute_tilt_shape(
            shape=shape, opening_axis=0, reduce_dim=reduce_dim
        )
        n_dims = len(shape)
        if not reduce_dim:
            assert len(result) == n_dims
            assert result[0] == 1
        else:
            assert len(result) == n_dims - 1

    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_centered_grid(self, shape):
        """Each coordinate plane is zero at the center index and bounded by it."""
        grid = centered_grid(shape=shape)
        n_planes = grid.shape[0]
        assert n_planes == len(shape)
        midpoint = tuple(int(extent) // 2 for extent in shape)
        for axis in range(n_planes):
            plane = grid[axis]
            assert plane[midpoint] == 0
            assert np.max(plane) <= midpoint[axis]

    @pytest.mark.parametrize("shape", ((10, 15, 30),))
    @pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
    @pytest.mark.parametrize("angle", (-5, 0, 5))
    @pytest.mark.parametrize("wedge", ((0, 1), (1, 0)))
    def test_frequency_grid_at_angle(self, shape, sampling_rate, angle, wedge):
        """The grid has the reduced tilt shape and a bounded maximum value."""
        opening_axis, tilt_axis = wedge
        grid = frequency_grid_at_angle(
            shape=shape,
            angle=angle,
            sampling_rate=sampling_rate,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
        )
        expected_shape = tuple(
            compute_tilt_shape(shape, opening_axis=opening_axis, reduce_dim=True)
        )
        assert grid.shape == expected_shape
        upper_bound = np.sqrt(1 / sampling_rate * len(shape))
        assert grid.max() <= upper_bound

    @pytest.mark.parametrize("n", [10, 100, 1000])
    @pytest.mark.parametrize("sampling_rate", range(1, 4))
    def test_fftfreqn(self, n, sampling_rate):
        """1D ``fftfreqn`` norm equals the absolute ifftshifted ``np.fft.fftfreq``."""
        observed = fftfreqn(
            shape=(n,), sampling_rate=sampling_rate, compute_euclidean_norm=True
        )
        reference = np.abs(np.fft.ifftshift(np.fft.fftfreq(n=n, d=sampling_rate)))
        assert np.allclose(observed, reference)

    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_crop_real_fourier(self, shape):
        """The cropped shape equals ``compute_fourier_shape(data.shape, False)``."""
        data = np.random.rand(*shape)
        expected_shape = tuple(compute_fourier_shape(data.shape, False))
        assert crop_real_fourier(data).shape == expected_shape

    @pytest.mark.parametrize("real", (False, True))
    @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
    def test_compute_fourier_shape(self, shape, real: bool):
        """``compute_fourier_shape`` matches the output shape of NumPy's (r)fftn."""
        data = np.random.rand(*shape)
        if real:
            transform = np.fft.rfftn
        else:
            transform = np.fft.fftn
        expected_shape = tuple(compute_fourier_shape(data.shape, not real))
        assert transform(data).shape == expected_shape

    def test_shift_fourier(self):
        """``shift_fourier(data, False)`` matches ``np.fft.ifftshift`` on 1D input."""
        data = np.random.rand(10)
        reference = np.fft.ifftshift(data)
        assert np.allclose(shift_fourier(data, False), reference)