pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
Binary file
|
@@ -0,0 +1,332 @@
|
|
1
|
+
/* Pybind extensions for template matching score space analyzers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <vector>
#include <iostream>
#include <limits>
#include <cmath>
#include <cstdlib>
#include <cstdint>
#include <utility>
#include <algorithm>
#include <stdexcept>
#include <unordered_map>

#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
|
15
|
+
|
16
|
+
namespace py = pybind11;
|
17
|
+
|
18
|
+
template <typename T>
|
19
|
+
void absolute_minimum_deviation(
|
20
|
+
py::array_t<T, py::array::c_style> coordinates,
|
21
|
+
py::array_t<T, py::array::c_style> output) {
|
22
|
+
auto coordinates_data = coordinates.data();
|
23
|
+
auto output_data = output.mutable_data();
|
24
|
+
int n = coordinates.shape(0);
|
25
|
+
int k = coordinates.shape(1);
|
26
|
+
int ik, jk, in, jn;
|
27
|
+
|
28
|
+
for (int i = 0; i < n; ++i) {
|
29
|
+
ik = i * k;
|
30
|
+
in = i * n;
|
31
|
+
for (int j = i + 1; j < n; ++j) {
|
32
|
+
jk = j * k;
|
33
|
+
jn = j * n;
|
34
|
+
T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
|
35
|
+
for (int p = 1; p < k; ++p) {
|
36
|
+
min_distance = std::min(min_distance,
|
37
|
+
std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
|
38
|
+
}
|
39
|
+
output_data[in + j] = min_distance;
|
40
|
+
output_data[jn + i] = min_distance;
|
41
|
+
}
|
42
|
+
output_data[in + i] = 0;
|
43
|
+
}
|
44
|
+
}
|
45
|
+
|
46
|
+
template <typename T>
|
47
|
+
std::pair<double, std::pair<int, int>> max_euclidean_distance(
|
48
|
+
py::array_t<T, py::array::c_style> coordinates) {
|
49
|
+
auto coordinates_data = coordinates.data();
|
50
|
+
int n = coordinates.shape(0);
|
51
|
+
int k = coordinates.shape(1);
|
52
|
+
|
53
|
+
double distance = 0.0;
|
54
|
+
double difference = 0.0;
|
55
|
+
double max_distance = -1;
|
56
|
+
double squared_distances = 0.0;
|
57
|
+
|
58
|
+
int ik, jk;
|
59
|
+
int max_i = -1, max_j = -1;
|
60
|
+
|
61
|
+
for (int i = 0; i < n; ++i) {
|
62
|
+
ik = i * k;
|
63
|
+
for (int j = i + 1; j < n; ++j) {
|
64
|
+
jk = j * k;
|
65
|
+
squared_distances = 0.0;
|
66
|
+
for (int p = 0; p < k; ++p) {
|
67
|
+
difference = static_cast<double>(
|
68
|
+
coordinates_data[ik + p] - coordinates_data[jk + p]
|
69
|
+
);
|
70
|
+
squared_distances += (difference * difference);
|
71
|
+
}
|
72
|
+
distance = std::sqrt(squared_distances);
|
73
|
+
if (distance > max_distance) {
|
74
|
+
max_distance = distance;
|
75
|
+
max_i = i;
|
76
|
+
max_j = j;
|
77
|
+
}
|
78
|
+
}
|
79
|
+
}
|
80
|
+
|
81
|
+
return std::make_pair(max_distance, std::make_pair(max_i, max_j));
|
82
|
+
}
|
83
|
+
|
84
|
+
|
85
|
+
template <typename T>
|
86
|
+
inline py::array_t<int, py::array::c_style> find_candidate_indices(
|
87
|
+
py::array_t<T, py::array::c_style> coordinates,
|
88
|
+
T min_distance) {
|
89
|
+
auto coordinates_data = coordinates.data();
|
90
|
+
int n = coordinates.shape(0);
|
91
|
+
int k = coordinates.shape(1);
|
92
|
+
int ik, jk;
|
93
|
+
|
94
|
+
std::vector<int> candidate_indices;
|
95
|
+
candidate_indices.reserve(n / 2);
|
96
|
+
candidate_indices.push_back(0);
|
97
|
+
|
98
|
+
for (int i = 1; i < n; ++i) {
|
99
|
+
bool is_candidate = true;
|
100
|
+
ik = i * k;
|
101
|
+
for (int candidate_index : candidate_indices) {
|
102
|
+
jk = candidate_index * k;
|
103
|
+
T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
|
104
|
+
for (int p = 1; p < k; ++p) {
|
105
|
+
distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
|
106
|
+
}
|
107
|
+
distance = std::sqrt(distance);
|
108
|
+
if (distance <= min_distance) {
|
109
|
+
is_candidate = false;
|
110
|
+
break;
|
111
|
+
}
|
112
|
+
}
|
113
|
+
if (is_candidate) {
|
114
|
+
candidate_indices.push_back(i);
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
|
119
|
+
auto output_data = output.mutable_data();
|
120
|
+
|
121
|
+
for (size_t i = 0; i < candidate_indices.size(); ++i) {
|
122
|
+
output_data[i] = candidate_indices[i];
|
123
|
+
}
|
124
|
+
|
125
|
+
return output;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <typename T>
|
129
|
+
py::array_t<T, py::array::c_style> find_candidate_coordinates(
|
130
|
+
py::array_t<T, py::array::c_style> coordinates,
|
131
|
+
T min_distance) {
|
132
|
+
|
133
|
+
py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
|
134
|
+
coordinates, min_distance);
|
135
|
+
auto candidate_indices_data = candidate_indices_array.data();
|
136
|
+
int num_candidates = candidate_indices_array.shape(0);
|
137
|
+
int k = coordinates.shape(1);
|
138
|
+
auto coordinates_data = coordinates.data();
|
139
|
+
|
140
|
+
py::array_t<T, py::array::c_style> output({num_candidates, k});
|
141
|
+
auto output_data = output.mutable_data();
|
142
|
+
|
143
|
+
for (int i = 0; i < num_candidates; ++i) {
|
144
|
+
int candidate_index = candidate_indices_data[i] * k;
|
145
|
+
std::copy(
|
146
|
+
coordinates_data + candidate_index,
|
147
|
+
coordinates_data + candidate_index + k,
|
148
|
+
output_data + i * k
|
149
|
+
);
|
150
|
+
}
|
151
|
+
|
152
|
+
return output;
|
153
|
+
}
|
154
|
+
|
155
|
+
template <typename U, typename T>
|
156
|
+
py::dict max_index_by_label(
|
157
|
+
py::array_t<U, py::array::c_style> labels,
|
158
|
+
py::array_t<T, py::array::c_style> scores
|
159
|
+
) {
|
160
|
+
|
161
|
+
const U* labels_ptr = labels.data();
|
162
|
+
const T* scores_ptr = scores.data();
|
163
|
+
|
164
|
+
std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
|
165
|
+
|
166
|
+
U label;
|
167
|
+
T score;
|
168
|
+
for (ssize_t i = 0; i < labels.size(); ++i) {
|
169
|
+
label = labels_ptr[i];
|
170
|
+
score = scores_ptr[i];
|
171
|
+
|
172
|
+
auto it = max_scores.insert({label, {score, i}});
|
173
|
+
|
174
|
+
if (score > it.first->second.first) {
|
175
|
+
it.first->second = {score, i};
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
py::dict ret;
|
180
|
+
for (auto& item: max_scores) {
|
181
|
+
ret[py::cast(item.first)] = py::cast(item.second.second);
|
182
|
+
}
|
183
|
+
|
184
|
+
return ret;
|
185
|
+
}
|
186
|
+
|
187
|
+
|
188
|
+
template <typename T>
|
189
|
+
py::tuple online_statistics(
|
190
|
+
py::array_t<T, py::array::c_style> arr,
|
191
|
+
unsigned long long int n = 0,
|
192
|
+
double rmean = 0,
|
193
|
+
double ssqd = 0,
|
194
|
+
T reference = 0) {
|
195
|
+
|
196
|
+
auto in = arr.data();
|
197
|
+
int size = arr.size();
|
198
|
+
|
199
|
+
T max_value = std::numeric_limits<T>::lowest();
|
200
|
+
T min_value = std::numeric_limits<T>::max();
|
201
|
+
double delta, delta_prime;
|
202
|
+
|
203
|
+
unsigned long long int nbetter_or_equal = 0;
|
204
|
+
|
205
|
+
for(int i = 0; i < size; i++){
|
206
|
+
n++;
|
207
|
+
delta = in[i] - rmean;
|
208
|
+
rmean += delta / n;
|
209
|
+
delta_prime = in[i] - rmean;
|
210
|
+
ssqd += delta * delta_prime;
|
211
|
+
|
212
|
+
max_value = std::max(in[i], max_value);
|
213
|
+
min_value = std::min(in[i], min_value);
|
214
|
+
if (in[i] >= reference)
|
215
|
+
nbetter_or_equal++;
|
216
|
+
}
|
217
|
+
|
218
|
+
return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
|
219
|
+
}
|
220
|
+
|
221
|
+
// Python module `extensions`. Each function is registered once per supported
// dtype. pybind11 first tries all overloads without implicit conversion and
// only then retries with conversion, in registration order. Arguments that
// must match exactly are therefore marked noconvert(), and floating-point
// defaults are spelled as Python floats so the no-conversion pass can
// succeed.
PYBIND11_MODULE(extensions, m) {

    // `output` is written in place. A converted temporary copy (dtype mismatch
    // or non-contiguous input) would silently discard the result.
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
          py::arg("coordinates"), py::arg("output").noconvert());
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
          py::arg("coordinates"), py::arg("output").noconvert());
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
          py::arg("coordinates"), py::arg("output").noconvert());
    m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
          py::arg("coordinates"), py::arg("output").noconvert());

    m.def("max_euclidean_distance", max_euclidean_distance<double>,
          "Identify pair of points with maximal euclidean distance (float64).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<float>,
          "Identify pair of points with maximal euclidean distance (float32).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
          "Identify pair of points with maximal euclidean distance (int64).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
          "Identify pair of points with maximal euclidean distance (int32).",
          py::arg("coordinates"));

    m.def("find_candidate_indices", &find_candidate_indices<double>,
          "Finds candidate indices with minimum distance (float64).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<float>,
          "Finds candidate indices with minimum distance (float32).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
          "Finds candidate indices with minimum distance (int64).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
          "Finds candidate indices with minimum distance (int32).",
          py::arg("coordinates"), py::arg("min_distance"));

    m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
          "Finds candidate coordinates with minimum distance (float64).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
          "Finds candidate coordinates with minimum distance (float32).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
          "Finds candidate coordinates with minimum distance (int64).",
          py::arg("coordinates"), py::arg("min_distance"));
    m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
          "Finds candidate coordinates with minimum distance (int32).",
          py::arg("coordinates"), py::arg("min_distance"));

    // Label dtype (first template argument) x score dtype (second).
    m.def("max_index_by_label", &max_index_by_label<double, double>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, float>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));

    m.def("max_index_by_label", &max_index_by_label<float, double>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, float>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));

    m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));

    m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));
    m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
          "Maximum value by label", py::arg("labels"), py::arg("scores"));

    // rmean/ssqd are double parameters. An int default (0) is rejected by the
    // no-conversion pass, which used to route every default-argument call to
    // the float64 overload and convert float32/int arrays to a float64 copy.
    m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
          py::arg("n") = 0, py::arg("rmean") = 0.0,
          py::arg("ssqd") = 0.0, py::arg("reference") = 0.0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
          py::arg("n") = 0, py::arg("rmean") = 0.0,
          py::arg("ssqd") = 0.0, py::arg("reference") = 0.0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
          py::arg("n") = 0, py::arg("rmean") = 0.0,
          py::arg("ssqd") = 0.0, py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
          py::arg("n") = 0, py::arg("rmean") = 0.0,
          py::arg("ssqd") = 0.0, py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
}
|
tme/filters/__init__.py
ADDED
tme/filters/_utils.py
ADDED
@@ -0,0 +1,311 @@
|
|
1
|
+
""" Utilities for the generation of frequency grids.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple, List, Dict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from ..backends import backend as be
|
13
|
+
from ..backends import NumpyFFTWBackend
|
14
|
+
from ..types import BackendArray, NDArray
|
15
|
+
from ..rotations import euler_to_rotationmatrix
|
16
|
+
|
17
|
+
|
18
|
+
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Collapse the ``opening_axis`` of ``shape``.

    Parameters
    ----------
    shape : Tuple[int]
        The input shape. Any iterable of per-axis values works; this module
        also applies it to per-axis sampling rates.
    opening_axis : int
        Index of the axis to collapse. It is compared against non-negative
        positions only, so ``None`` or a negative index matches no axis and
        the values are returned unchanged.
    reduce_dim : bool, optional (default=False)
        If True, remove the axis entirely; otherwise keep it with size 1.

    Returns
    -------
    Tuple[int]
        The collapsed shape.
    """
    if reduce_dim:
        return tuple(
            size for axis, size in enumerate(shape) if axis != opening_axis
        )
    return tuple(
        1 if axis == opening_axis else size for axis, size in enumerate(shape)
    )
|
41
|
+
|
42
|
+
|
43
|
+
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build integer coordinates relative to the center ``size // 2`` of each axis.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.

    Returns
    -------
    NDArray
        Array of shape ``(len(shape), *shape)``; entry ``i`` holds the offset
        along axis ``i`` (``ij`` indexing).
    """
    offsets = []
    for size in shape:
        offsets.append(np.arange(size) - size // 2)
    return np.array(np.meshgrid(*offsets, indexing="ij"))
|
61
|
+
|
62
|
+
|
63
|
+
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute frequency magnitudes over ``shape`` with ``opening_axis`` collapsed,
    optionally for a plane rotated by ``angle``.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        Rotation angle in degrees. Exactly zero takes a direct path without
        any rotation.
    sampling_rate : Tuple[float]
        Sampling rate, either a scalar or one value per axis of ``shape``.
    opening_axis : int, optional
        Axis reduced to size one before the grid is built, defaults to None.
        Required when ``angle`` is non-zero.
    tilt_axis : int, optional
        Position in the Euler angle vector that receives ``angle`` before that
        vector is rolled by ``opening_axis - 1``, defaults to None. Required
        when ``angle`` is non-zero.

    Returns
    -------
    NDArray
        Euclidean norm of the frequency vectors, with all axes of size one
        (including the collapsed opening axis) removed.
    """
    sampling_rate = np.array(sampling_rate)
    # Broadcast a scalar sampling rate to one value per axis.
    sampling_rate = np.repeat(sampling_rate, len(shape) // sampling_rate.size)
    tilt_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )

    if angle == 0:
        reduced_rate = compute_tilt_shape(
            shape=sampling_rate, opening_axis=opening_axis, reduce_dim=True
        )
        return fftfreqn(
            tuple(size for size in tilt_shape if size != 1),
            sampling_rate=reduced_rate,
            compute_euclidean_norm=True,
        )

    # Rescale the angle by the ratio of the opening- and tilt-axis lengths;
    # presumably this accounts for non-cubic boxes (TODO confirm).
    aspect_ratio = shape[opening_axis] / shape[tilt_axis]
    angle = np.degrees(np.arctan(np.tan(np.radians(angle)) * aspect_ratio))

    euler_angles = np.zeros(len(shape))
    euler_angles[tilt_axis] = angle
    rotation_matrix = euler_to_rotationmatrix(np.roll(euler_angles, opening_axis - 1))

    # Rotate integer offsets first, then scale them to frequencies per axis.
    coordinates = fftfreqn(tilt_shape, sampling_rate=None)
    coordinates = np.einsum("ij,j...->i...", rotation_matrix, coordinates)
    box_extent = np.multiply(sampling_rate, shape).astype(int)
    coordinates = np.divide(coordinates.T, box_extent).T

    return np.linalg.norm(np.squeeze(coordinates), axis=(0))
|
125
|
+
|
126
|
+
|
127
|
+
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
    return_sparse_grid: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier transform sample frequencies.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate, either a scalar or one value per axis. If None, the
        returned values are integer offsets from the center rather than
        frequencies.
    compute_euclidean_norm : bool, optional
        Return the Euclidean norm of the frequency vectors, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, i.e. the
        last axis starts at the zero frequency, defaults to False.
    return_sparse_grid : bool, optional
        Return one broadcastable array per axis instead of a dense stacked
        grid, defaults to False. Ignored if compute_euclidean_norm is True.

    Returns
    -------
    NDArray
        Frequencies with the zero frequency at index ``shape[i] // 2`` of each
        axis (index 0 of the last axis if shape_is_real_fourier). A dense grid
        has shape ``(len(shape), *shape)``, the norm has shape ``shape``, and
        the sparse grid is a list of arrays.

    Notes
    -----
    ``shape * sampling_rate`` is truncated to an integer before dividing, so
    frequencies are exact only when that product is integral.
    """
    # There is no real need to have these operations on GPU right now
    np_be = NumpyFFTWBackend()
    norm = np_be.full(len(shape), fill_value=1, dtype=np_be._float_dtype)
    center = np_be.astype(np_be.divide(shape, 2), np_be._int_dtype)
    if sampling_rate is not None:
        norm = np_be.astype(np_be.multiply(shape, sampling_rate), int)

    if shape_is_real_fourier:
        center[-1], norm[-1] = 0, 1
        if sampling_rate is not None:
            # Only the last axis' rate applies here. Multiplying by a per-axis
            # sampling rate would produce a sequence that cannot be stored in
            # the single element norm[-1].
            last_axis_rate = np.ravel(sampling_rate)[-1]
            norm[-1] = (shape[-1] - 1) * 2 * last_axis_rate

    grids = []
    for i, x in enumerate(shape):
        baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
        grid = (np_be.arange(x, dtype=np_be._int_dtype) - center[i]) / norm[i]
        grid = np_be.astype(grid, np_be._float_dtype)
        grids.append(np_be.reshape(grid, baseline_dims))

    if compute_euclidean_norm:
        grids = sum(np_be.square(x) for x in grids)
        grids = np_be.sqrt(grids, out=grids)
        return grids

    if return_sparse_grid:
        return grids

    # Broadcast each per-axis vector to the full shape before stacking.
    grid_flesh = np_be.full(shape, fill_value=1, dtype=np_be._float_dtype)
    grids = np_be.stack(tuple(grid * grid_flesh for grid in grids))

    return grids
|
184
|
+
|
185
|
+
|
186
|
+
def crop_real_fourier(data: BackendArray) -> BackendArray:
    """
    Keep the first ``n // 2 + 1`` entries along the last axis.

    This reduces a full spectrum to the half-spectrum length used by real
    Fourier transforms.

    Parameters
    ----------
    data : BackendArray
        Array whose last axis is cropped. All leading axes are kept.

    Returns
    -------
    BackendArray
        A slice of ``data`` with last-axis length ``n // 2 + 1``.
    """
    last_axis_length = data.shape[-1]
    return data[..., : last_axis_length // 2 + 1]
|
202
|
+
|
203
|
+
|
204
|
+
def compute_fourier_shape(
    shape: Tuple[int], shape_is_real_fourier: bool = False
) -> List[int]:
    """
    Return the shape of the real Fourier transform of an array of ``shape``.

    Parameters
    ----------
    shape : Tuple[int]
        Real-space shape.
    shape_is_real_fourier : bool, optional
        If True, ``shape`` already describes a real Fourier transform and is
        returned unchanged (same object, same type). Defaults to False.

    Returns
    -------
    List[int]
        ``shape`` as a list of ints with the last axis set to ``n // 2 + 1``.
    """
    if shape_is_real_fourier:
        return shape
    fourier_shape = list(map(int, shape))
    fourier_shape[-1] = fourier_shape[-1] // 2 + 1
    return fourier_shape
|
212
|
+
|
213
|
+
|
214
|
+
def shift_fourier(
    data: BackendArray, shape_is_real_fourier: bool = False
) -> BackendArray:
    """
    Roll every axis of ``data`` by ``ceil(n / 2)``.

    This is equivalent to ``numpy.fft.ifftshift`` over all axes: the zero
    frequency moves from the center to index 0.

    Parameters
    ----------
    data : BackendArray
        Array to shift. numpy arrays are rolled with numpy; any other array
        type uses the active backend.
    shape_is_real_fourier : bool, optional
        If True, the last axis is not shifted, since a real Fourier transform
        already starts at the zero frequency. Defaults to False.

    Returns
    -------
    BackendArray
        The shifted array.
    """
    comp = NumpyFFTWBackend() if isinstance(data, np.ndarray) else be

    # ceil(n / 2) computed on host integers. The shape need not be moved to a
    # (possibly device-resident) backend array and back just to get ints.
    shift = [int(size) - int(size) // 2 for size in data.shape]
    if shape_is_real_fourier:
        shift[-1] = 0

    return comp.roll(data, shift, tuple(range(len(shift))))
|
228
|
+
|
229
|
+
|
230
|
+
def _ramp_cont_filter(filter_shape: Tuple[int]):
    """Fourier transform of the separable spatial ramp kernel of Kak & Slaney, Chap 3 Eq. 61."""
    kernel, ndim = None, len(filter_shape)
    for dim, size in enumerate(filter_shape):
        # Distance of every sample from the origin on a periodic grid. Only
        # odd distances carry weight. Deriving them from the distance rather
        # than from the sample index keeps this correct for every size,
        # including odd sizes and sizes not divisible by four.
        offset = np.arange(size)
        distance = np.minimum(offset, size - offset)
        ret1d = np.zeros(size)
        ret1d[0] = 0.25
        odd = distance % 2 == 1
        ret1d[odd] = -1 / (np.pi * distance[odd]) ** 2

        ret1d = ret1d.reshape(tuple(size if i == dim else 1 for i in range(ndim)))
        kernel = ret1d if kernel is None else kernel * ret1d
    return 2 * np.fft.fftshift(np.real(np.fft.fftn(kernel)))


def create_reconstruction_filter(
    filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
):
    """Create a reconstruction filter of given filter_type.

    Parameters
    ----------
    filter_shape : tuple of int
        Shape of the returned filter.
    filter_type: str
        The type of created filter (case-insensitive), available options are:

        +---------------+----------------------------------------------------+
        | ram-lak       | Returns |w|                                        |
        +---------------+----------------------------------------------------+
        | ramp-cont     | Principles of Computerized Tomographic Imaging,    |
        |               | Avinash C. Kak and Malcolm Slaney, Chap 3 Eq. 61   |
        |               | [1]_                                               |
        +---------------+----------------------------------------------------+
        | ramp          | Like ramp-cont but considering tilt angles         |
        +---------------+----------------------------------------------------+
        | shepp-logan   | |w| * sinc(|w| / 2) [2]_                           |
        +---------------+----------------------------------------------------+
        | cosine        | |w| * cos(|w| * pi / 2) [2]_                       |
        +---------------+----------------------------------------------------+
        | hamming       | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_            |
        +---------------+----------------------------------------------------+

        Here |w| is the frequency magnitude returned by ``fftfreqn`` with
        ``sampling_rate=0.5``.
    kwargs: Dict
        Keyword arguments for particular filter_types. ``ramp`` requires
        ``tilt_angles``, a sequence of at least two tilt angles in degrees.

    Returns
    -------
    NDArray
        Reconstruction filter with the zero frequency at the center. ``ramp``
        returns a one-dimensional filter of length ``filter_shape[0]``; all
        other types return an array of ``filter_shape``.

    Raises
    ------
    ValueError
        If filter_type is unsupported, or ``ramp`` is requested without at
        least two tilt angles.

    References
    ----------
    .. [1] Principles of Computerized Tomographic Imaging Avinash C. Kak and Malcolm Slaney Chap 3 Eq. 61
    .. [2] https://odlgroup.github.io/odl/index.html
    """
    filter_type = str(filter_type).lower()

    if filter_type == "ramp-cont":
        return _ramp_cont_filter(filter_shape)

    if filter_type == "ramp":
        tilt_angles = kwargs.get("tilt_angles", None)
        if tilt_angles is None or tilt_angles is False:
            raise ValueError("'ramp' filter requires specifying tilt angles.")
        tilt_angles = np.ravel(np.asarray(tilt_angles, dtype=float))
        # The filter is scaled by the smallest angular increment, which needs
        # at least two angles to be defined.
        if tilt_angles.size < 2:
            raise ValueError("'ramp' filter requires at least two tilt angles.")

        size = filter_shape[0]
        ret = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
        min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
        ret *= min_increment * size
        np.fmin(ret, 1, out=ret)
        return ret

    if filter_type not in ("ram-lak", "shepp-logan", "cosine", "hamming"):
        raise ValueError("Unsupported filter type")

    # Only the remaining filter types are functions of the frequency magnitude.
    freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)
    if filter_type == "ram-lak":
        return freq
    if filter_type == "shepp-logan":
        return freq * np.sinc(freq / 2)
    if filter_type == "cosine":
        return freq * np.cos(freq * np.pi / 2)
    return freq * (0.54 + 0.46 * np.cos(freq * np.pi))
|