pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
Binary file
@@ -0,0 +1,332 @@
1
+ /* Pybind extensions for template matching score space analyzers.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ */
7
+
8
+ #include <vector>
9
+ #include <iostream>
10
+ #include <limits>
11
+
12
+ #include <pybind11/stl.h>
13
+ #include <pybind11/numpy.h>
14
+ #include <pybind11/pybind11.h>
15
+
16
+ namespace py = pybind11;
17
+
18
+ template <typename T>
19
+ void absolute_minimum_deviation(
20
+ py::array_t<T, py::array::c_style> coordinates,
21
+ py::array_t<T, py::array::c_style> output) {
22
+ auto coordinates_data = coordinates.data();
23
+ auto output_data = output.mutable_data();
24
+ int n = coordinates.shape(0);
25
+ int k = coordinates.shape(1);
26
+ int ik, jk, in, jn;
27
+
28
+ for (int i = 0; i < n; ++i) {
29
+ ik = i * k;
30
+ in = i * n;
31
+ for (int j = i + 1; j < n; ++j) {
32
+ jk = j * k;
33
+ jn = j * n;
34
+ T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
35
+ for (int p = 1; p < k; ++p) {
36
+ min_distance = std::min(min_distance,
37
+ std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
38
+ }
39
+ output_data[in + j] = min_distance;
40
+ output_data[jn + i] = min_distance;
41
+ }
42
+ output_data[in + i] = 0;
43
+ }
44
+ }
45
+
46
+ template <typename T>
47
+ std::pair<double, std::pair<int, int>> max_euclidean_distance(
48
+ py::array_t<T, py::array::c_style> coordinates) {
49
+ auto coordinates_data = coordinates.data();
50
+ int n = coordinates.shape(0);
51
+ int k = coordinates.shape(1);
52
+
53
+ double distance = 0.0;
54
+ double difference = 0.0;
55
+ double max_distance = -1;
56
+ double squared_distances = 0.0;
57
+
58
+ int ik, jk;
59
+ int max_i = -1, max_j = -1;
60
+
61
+ for (int i = 0; i < n; ++i) {
62
+ ik = i * k;
63
+ for (int j = i + 1; j < n; ++j) {
64
+ jk = j * k;
65
+ squared_distances = 0.0;
66
+ for (int p = 0; p < k; ++p) {
67
+ difference = static_cast<double>(
68
+ coordinates_data[ik + p] - coordinates_data[jk + p]
69
+ );
70
+ squared_distances += (difference * difference);
71
+ }
72
+ distance = std::sqrt(squared_distances);
73
+ if (distance > max_distance) {
74
+ max_distance = distance;
75
+ max_i = i;
76
+ max_j = j;
77
+ }
78
+ }
79
+ }
80
+
81
+ return std::make_pair(max_distance, std::make_pair(max_i, max_j));
82
+ }
83
+
84
+
85
+ template <typename T>
86
+ inline py::array_t<int, py::array::c_style> find_candidate_indices(
87
+ py::array_t<T, py::array::c_style> coordinates,
88
+ T min_distance) {
89
+ auto coordinates_data = coordinates.data();
90
+ int n = coordinates.shape(0);
91
+ int k = coordinates.shape(1);
92
+ int ik, jk;
93
+
94
+ std::vector<int> candidate_indices;
95
+ candidate_indices.reserve(n / 2);
96
+ candidate_indices.push_back(0);
97
+
98
+ for (int i = 1; i < n; ++i) {
99
+ bool is_candidate = true;
100
+ ik = i * k;
101
+ for (int candidate_index : candidate_indices) {
102
+ jk = candidate_index * k;
103
+ T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
104
+ for (int p = 1; p < k; ++p) {
105
+ distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
106
+ }
107
+ distance = std::sqrt(distance);
108
+ if (distance <= min_distance) {
109
+ is_candidate = false;
110
+ break;
111
+ }
112
+ }
113
+ if (is_candidate) {
114
+ candidate_indices.push_back(i);
115
+ }
116
+ }
117
+
118
+ py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
119
+ auto output_data = output.mutable_data();
120
+
121
+ for (size_t i = 0; i < candidate_indices.size(); ++i) {
122
+ output_data[i] = candidate_indices[i];
123
+ }
124
+
125
+ return output;
126
+ }
127
+
128
+ template <typename T>
129
+ py::array_t<T, py::array::c_style> find_candidate_coordinates(
130
+ py::array_t<T, py::array::c_style> coordinates,
131
+ T min_distance) {
132
+
133
+ py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
134
+ coordinates, min_distance);
135
+ auto candidate_indices_data = candidate_indices_array.data();
136
+ int num_candidates = candidate_indices_array.shape(0);
137
+ int k = coordinates.shape(1);
138
+ auto coordinates_data = coordinates.data();
139
+
140
+ py::array_t<T, py::array::c_style> output({num_candidates, k});
141
+ auto output_data = output.mutable_data();
142
+
143
+ for (int i = 0; i < num_candidates; ++i) {
144
+ int candidate_index = candidate_indices_data[i] * k;
145
+ std::copy(
146
+ coordinates_data + candidate_index,
147
+ coordinates_data + candidate_index + k,
148
+ output_data + i * k
149
+ );
150
+ }
151
+
152
+ return output;
153
+ }
154
+
155
+ template <typename U, typename T>
156
+ py::dict max_index_by_label(
157
+ py::array_t<U, py::array::c_style> labels,
158
+ py::array_t<T, py::array::c_style> scores
159
+ ) {
160
+
161
+ const U* labels_ptr = labels.data();
162
+ const T* scores_ptr = scores.data();
163
+
164
+ std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
165
+
166
+ U label;
167
+ T score;
168
+ for (ssize_t i = 0; i < labels.size(); ++i) {
169
+ label = labels_ptr[i];
170
+ score = scores_ptr[i];
171
+
172
+ auto it = max_scores.insert({label, {score, i}});
173
+
174
+ if (score > it.first->second.first) {
175
+ it.first->second = {score, i};
176
+ }
177
+ }
178
+
179
+ py::dict ret;
180
+ for (auto& item: max_scores) {
181
+ ret[py::cast(item.first)] = py::cast(item.second.second);
182
+ }
183
+
184
+ return ret;
185
+ }
186
+
187
+
188
+ template <typename T>
189
+ py::tuple online_statistics(
190
+ py::array_t<T, py::array::c_style> arr,
191
+ unsigned long long int n = 0,
192
+ double rmean = 0,
193
+ double ssqd = 0,
194
+ T reference = 0) {
195
+
196
+ auto in = arr.data();
197
+ int size = arr.size();
198
+
199
+ T max_value = std::numeric_limits<T>::lowest();
200
+ T min_value = std::numeric_limits<T>::max();
201
+ double delta, delta_prime;
202
+
203
+ unsigned long long int nbetter_or_equal = 0;
204
+
205
+ for(int i = 0; i < size; i++){
206
+ n++;
207
+ delta = in[i] - rmean;
208
+ rmean += delta / n;
209
+ delta_prime = in[i] - rmean;
210
+ ssqd += delta * delta_prime;
211
+
212
+ max_value = std::max(in[i], max_value);
213
+ min_value = std::min(in[i], min_value);
214
+ if (in[i] >= reference)
215
+ nbetter_or_equal++;
216
+ }
217
+
218
+ return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
219
+ }
220
+
221
// Python module definition. Each function is registered once per supported
// element type under the same Python name, which makes them overloads; the
// registrations below are listed in the order float64, float32, int64, int32.
PYBIND11_MODULE(extensions, m) {

    // Pairwise minimum absolute per-axis deviation, written into ``output``.
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
        "Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
        py::arg("coordinates"), py::arg("output"));
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
        "Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
        py::arg("coordinates"), py::arg("output"));
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
        "Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
        py::arg("coordinates"), py::arg("output"));
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
        "Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
        py::arg("coordinates"), py::arg("output"));


    // Largest Euclidean distance between any two points and the pair's indices.
    m.def("max_euclidean_distance", max_euclidean_distance<double>,
        "Identify pair of points with maximal euclidean distance (float64).",
        py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<float>,
        "Identify pair of points with maximal euclidean distance (float32).",
        py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
        "Identify pair of points with maximal euclidean distance (int64).",
        py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
        "Identify pair of points with maximal euclidean distance (int32).",
        py::arg("coordinates"));


    // Greedy distance-based selection, returning row indices.
    m.def("find_candidate_indices", &find_candidate_indices<double>,
        "Finds candidate indices with minimum distance (float64).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<float>,
        "Finds candidate indices with minimum distance (float32).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
        "Finds candidate indices with minimum distance (int64).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
        "Finds candidate indices with minimum distance (int32).",
        py::arg("coordinates"), py::arg("min_distance"));


    // Same selection as above, returning the selected coordinates instead.
    m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
        "Finds candidate coordinates with minimum distance (float64).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
        "Finds candidate coordinates with minimum distance (float32).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
        "Finds candidate coordinates with minimum distance (int64).",
        py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
        "Finds candidate coordinates with minimum distance (int32).",
        py::arg("coordinates"), py::arg("min_distance"));


    // max_index_by_label is registered for every (label type, score type)
    // combination. First block: float64 labels.
    m.def("max_index_by_label", &max_index_by_label<double, double>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, float>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));

    // float32 labels.
    m.def("max_index_by_label", &max_index_by_label<float, double>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, float>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));

    // int64 labels.
    m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));

    // int32 labels.
    m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
        "Maximum value by label", py::arg("labels"), py::arg("scores"));


    // Running statistics; all state arguments default to zero so the first
    // call can be made with the array alone.
    m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
        py::arg("n") = 0, py::arg("rmean") = 0,
        py::arg("ssqd") = 0, py::arg("reference") = 0,
        "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
        py::arg("n") = 0, py::arg("rmean") = 0,
        py::arg("ssqd") = 0, py::arg("reference") = 0,
        "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
        py::arg("n") = 0, py::arg("rmean") = 0,
        py::arg("ssqd") = 0, py::arg("reference") = 0,
        "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
        py::arg("n") = 0, py::arg("rmean") = 0,
        py::arg("ssqd") = 0, py::arg("reference") = 0,
        "Compute running online statistics on a numpy array.");
}
@@ -0,0 +1,6 @@
1
+ from .ctf import CTF
2
+ from .compose import Compose, ComposableFilter
3
+ from .bandpass import BandPassFilter
4
+ from .whitening import LinearWhiteningFilter
5
+ from .wedge import Wedge, WedgeReconstructed
6
+ from .reconstruction import ReconstructFromTilt
tme/filters/_utils.py ADDED
@@ -0,0 +1,311 @@
1
+ """ Utilities for the generation of frequency grids.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, List, Dict
9
+
10
+ import numpy as np
11
+
12
+ from ..backends import backend as be
13
+ from ..backends import NumpyFFTWBackend
14
+ from ..types import BackendArray, NDArray
15
+ from ..rotations import euler_to_rotationmatrix
16
+
17
+
18
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Compute the shape that remains once ``opening_axis`` has been collapsed.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the input array.
    opening_axis : int
        Index of the axis to collapse. The index is compared literally against
        the positions in ``shape``, so a value matching no position (such as
        None) leaves the shape untouched.
    reduce_dim : bool, optional
        If True, the collapsed axis is dropped from the result. Otherwise it
        is kept with length one. Defaults to False.

    Returns
    -------
    Tuple[int]
        The shape with ``opening_axis`` either set to one or removed.
    """
    if reduce_dim:
        return tuple(
            size for axis, size in enumerate(shape) if axis != opening_axis
        )
    return tuple(
        1 if axis == opening_axis else size for axis, size in enumerate(shape)
    )
41
+
42
+
43
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build a dense integer index grid whose origin sits at ``size // 2``.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.

    Returns
    -------
    NDArray
        Array of shape ``(len(shape), *shape)``. Entry ``[d]`` holds, for
        every grid point, its index along axis ``d`` minus
        ``shape[d] // 2``.
    """
    axes = []
    for size in shape:
        # Shift each axis so that index size // 2 maps to zero.
        axes.append(np.arange(size) - size // 2)

    return np.array(np.meshgrid(*axes, indexing="ij"))
61
+
62
+
63
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute the magnitude of the frequency grid of a plane tilted by ``angle``.

    For ``angle == 0`` the grid is the Euclidean norm returned by
    :py:meth:`fftfreqn` for the shape without ``opening_axis``. Otherwise the
    index grid of the plane is rotated by an aspect-ratio corrected angle
    around ``tilt_axis``, scaled by ``sampling_rate * shape`` and reduced to
    its Euclidean norm.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        The tilt angle in degrees.
    sampling_rate : Tuple[float]
        The sampling rate. It is repeated ``len(shape) // size`` times, so a
        single value is broadcast to all dimensions.
    opening_axis : int, optional
        The axis that is collapsed to obtain the plane, defaults to None.
    tilt_axis : int, optional
        The axis around which the plane is tilted, defaults to None. Both axes
        are required when ``angle`` is non-zero.

    Returns
    -------
    NDArray
        The frequency grid.
    """
    rates = np.array(sampling_rate)
    rates = np.repeat(rates, len(shape) // rates.size)

    tilt_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )

    if angle == 0:
        # Untilted plane: drop the opening axis and use the plain norm grid.
        plane_rates = compute_tilt_shape(
            shape=rates, opening_axis=opening_axis, reduce_dim=True
        )
        plane_shape = tuple(x for x in tilt_shape if x != 1)
        return fftfreqn(
            plane_shape,
            sampling_rate=plane_rates,
            compute_euclidean_norm=True,
        )

    # Correct the tilt angle for non-cubic boxes before rotating.
    aspect_ratio = shape[opening_axis] / shape[tilt_axis]
    corrected = np.degrees(np.arctan(np.tan(np.radians(angle)) * aspect_ratio))

    angles = np.zeros(len(shape))
    angles[tilt_axis] = corrected
    rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))

    index_grid = fftfreqn(tilt_shape, sampling_rate=None)
    index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
    norm = np.multiply(rates, shape).astype(int)

    # Transposing moves the component axis last so it broadcasts against norm.
    index_grid = np.divide(index_grid.T, norm).T
    index_grid = np.squeeze(index_grid)
    return np.linalg.norm(index_grid, axis=(0))
125
+
126
+
127
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
    return_sparse_grid: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier transform sample frequencies.

    Each axis holds ``(arange(size) - size // 2) / norm``, where ``norm`` is
    ``int(size * sampling_rate)`` or one if ``sampling_rate`` is None.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate, either shared or one value per axis. If None, the
        unnormalized index grid is returned.
    compute_euclidean_norm : bool, optional
        Whether to return the Euclidean norm over all axes instead of the
        per-axis frequencies, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, defaults to
        False. If True, the last axis starts at zero and is normalized by
        ``(shape[-1] - 1) * 2 * sampling_rate``.
    return_sparse_grid : bool, optional
        Whether to return the list of broadcastable per-axis grids instead of
        a dense stacked array, defaults to False. Ignored if
        ``compute_euclidean_norm`` is True.

    Returns
    -------
    NDArray
        The sample frequencies. A dense array of shape ``(len(shape), *shape)``
        by default, an array of shape ``shape`` if ``compute_euclidean_norm``
        is True, or a list of per-axis arrays if ``return_sparse_grid`` is True.
    """
    # There is no real need to have these operations on GPU right now
    np_be = NumpyFFTWBackend()
    norm = np_be.full(len(shape), fill_value=1, dtype=np_be._float_dtype)
    center = np_be.astype(np_be.divide(shape, 2), np_be._int_dtype)
    if sampling_rate is not None:
        norm = np_be.astype(np_be.multiply(shape, sampling_rate), int)

    if shape_is_real_fourier:
        center[-1], norm[-1] = 0, 1
        if sampling_rate is not None:
            # sampling_rate may be a scalar or a per-axis sequence; the last
            # axis needs a single scalar rate in either case.
            last_rate = np.ravel(sampling_rate)[-1]
            norm[-1] = (shape[-1] - 1) * 2 * last_rate

    grids = []
    for i, x in enumerate(shape):
        # Shape that is x along axis i and one elsewhere, for broadcasting.
        baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
        grid = (np_be.arange(x, dtype=np_be._int_dtype) - center[i]) / norm[i]
        grid = np_be.astype(grid, np_be._float_dtype)
        grids.append(np_be.reshape(grid, baseline_dims))

    if compute_euclidean_norm:
        grids = sum(np_be.square(x) for x in grids)
        grids = np_be.sqrt(grids, out=grids)
        return grids

    if return_sparse_grid:
        return grids

    # Multiplying with an array of ones expands each sparse grid to full size.
    grid_flesh = np_be.full(shape, fill_value=1, dtype=np_be._float_dtype)
    grids = np_be.stack(tuple(grid * grid_flesh for grid in grids))

    return grids
184
+
185
+
186
def crop_real_fourier(data: BackendArray) -> BackendArray:
    """
    Keep the first ``n // 2 + 1`` entries along the last axis of ``data``.

    For an input whose last axis has length ``n`` this is the number of
    elements a real-valued Fourier transform stores along that axis.

    Parameters
    ----------
    data : BackendArray
        The Fourier transformed data.

    Returns
    -------
    BackendArray
        The cropped data.
    """
    n_keep = data.shape[-1] // 2 + 1
    return data[..., :n_keep]
202
+
203
+
204
def compute_fourier_shape(
    shape: Tuple[int], shape_is_real_fourier: bool = False
) -> List[int]:
    """
    Compute the shape of the real Fourier transform of an array.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the data.
    shape_is_real_fourier : bool, optional
        Whether ``shape`` already corresponds to a real Fourier transform,
        defaults to False. If True, ``shape`` is returned unchanged.

    Returns
    -------
    List[int]
        ``shape`` with the last axis replaced by ``shape[-1] // 2 + 1``. If
        ``shape_is_real_fourier`` is True, the input object itself.
    """
    if not shape_is_real_fourier:
        fourier_shape = [int(size) for size in shape]
        fourier_shape[-1] = fourier_shape[-1] // 2 + 1
        return fourier_shape
    return shape
212
+
213
+
214
def shift_fourier(
    data: BackendArray, shape_is_real_fourier: bool = False
) -> BackendArray:
    """
    Circularly shift ``data`` by roughly half its size along every axis.

    The shift per axis is ``int(size / 2 + size % 2)`` as computed by the
    backend's ``divide``, ``mod`` and ``add``. Numpy arrays are processed with
    a numpy backend, everything else with the currently active backend.

    Parameters
    ----------
    data : BackendArray
        The array to shift.
    shape_is_real_fourier : bool, optional
        Whether ``data`` is laid out like a real Fourier transform, defaults
        to False. If True, the last axis is not shifted.

    Returns
    -------
    BackendArray
        The shifted array.
    """
    comp = NumpyFFTWBackend() if isinstance(data, np.ndarray) else be

    shape = comp.to_backend_array(data.shape)
    half = comp.divide(shape, 2)
    parity = comp.mod(shape, 2)
    shift = [int(x) for x in comp.add(half, parity)]
    if shape_is_real_fourier:
        shift[-1] = 0

    axes = tuple(range(len(shift)))
    return comp.roll(data, shift, axes)
228
+
229
+
230
def create_reconstruction_filter(
    filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
):
    """Create a reconstruction filter of given filter_type.

    Parameters
    ----------
    filter_shape : tuple of int
        Shape of the returned filter.
    filter_type: str
        The type of created filter, compared case-insensitively. Available
        options are:

        +---------------+----------------------------------------------------+
        | ram-lak       | Returns |w|                                        |
        +---------------+----------------------------------------------------+
        | ramp-cont     | Principles of Computerized Tomographic Imaging Avin|
        |               | ash C. Kak and Malcolm Slaney Chap 3 Eq. 61 [1]_   |
        +---------------+----------------------------------------------------+
        | ramp          | Like ramp-cont but considering tilt angles         |
        +---------------+----------------------------------------------------+
        | shepp-logan   | |w| * sinc(|w| / 2) [2]_                           |
        +---------------+----------------------------------------------------+
        | cosine        | |w| * cos(|w| * pi / 2) [2]_                       |
        +---------------+----------------------------------------------------+
        | hamming       | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_            |
        +---------------+----------------------------------------------------+
    kwargs: Dict
        Keyword arguments for particular filter_types. The 'ramp' filter
        requires ``tilt_angles``, given in degrees.

    Returns
    -------
    NDArray
        Reconstruction filter. The 'ramp' filter is one-dimensional with
        length ``filter_shape[0]``.

    Raises
    ------
    ValueError
        If ``filter_type`` is not supported, or if 'ramp' is requested
        without ``tilt_angles``.

    References
    ----------
    .. [1]  Principles of Computerized Tomographic Imaging Avinash C. Kak and Malcolm Slaney Chap 3 Eq. 61
    .. [2]  https://odlgroup.github.io/odl/index.html
    """
    kind = str(filter_type).lower()
    freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)

    # Filters that are plain functions of the frequency magnitude.
    if kind == "ram-lak":
        return np.copy(freq)
    if kind == "shepp-logan":
        return freq * np.sinc(freq / 2)
    if kind == "cosine":
        return freq * np.cos(freq * np.pi / 2)
    if kind == "hamming":
        return freq * (0.54 + 0.46 * np.cos(freq * np.pi))

    if kind == "ramp":
        tilt_angles = kwargs.get("tilt_angles", False)
        if tilt_angles is False:
            raise ValueError("'ramp' filter requires specifying tilt angles.")
        size = filter_shape[0]
        ramp = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
        # Scale by the smallest angular step and clip the result at one.
        min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
        ramp *= min_increment * size
        np.fmin(ramp, 1, out=ramp)
        return ramp

    if kind == "ramp-cont":
        ndim = len(filter_shape)
        kernel = None
        for axis, size in enumerate(filter_shape):
            odd_terms = np.concatenate(
                (
                    np.arange(1, size / 2 + 1, 2, dtype=int),
                    np.arange(size / 2 - 1, 0, -2, dtype=int),
                )
            )
            # Real-space kernel: 0.25 at the origin, -1 / (pi * n)^2 at odd
            # offsets and zero elsewhere.
            profile = np.zeros(size)
            profile[0] = 0.25
            profile[1::2] = -1 / (np.pi * odd_terms) ** 2
            profile = profile.reshape(
                tuple(size if i == axis else 1 for i in range(ndim))
            )
            # The n-dimensional kernel is the outer product of the profiles.
            kernel = profile if kernel is None else kernel * profile
        return 2 * np.fft.fftshift(np.real(np.fft.fftn(kernel)))

    raise ValueError("Unsupported filter type")