pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/match_template.py +8 -8
- {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocess.py +22 -6
- {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +9 -14
- {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/METADATA +1 -1
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- scripts/match_template.py +8 -8
- scripts/preprocess.py +22 -6
- scripts/preprocessor_gui.py +9 -14
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/backends/jax_backend.py +8 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +11 -4
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +11 -9
- tme/matching_exhaustive.py +10 -8
- tme/matching_utils.py +1 -0
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +0 -2
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +202 -118
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.3.dist-info/RECORD +0 -75
- tme/matching_memory.py +0 -383
- {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/postprocess.py +0 -0
- {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,332 @@
|
|
1
|
+
/* Pybind extensions for template matching score space analyzers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
|
15
|
+
|
16
|
+
namespace py = pybind11;
|
17
|
+
|
18
|
+
template <typename T>
|
19
|
+
void absolute_minimum_deviation(
|
20
|
+
py::array_t<T, py::array::c_style> coordinates,
|
21
|
+
py::array_t<T, py::array::c_style> output) {
|
22
|
+
auto coordinates_data = coordinates.data();
|
23
|
+
auto output_data = output.mutable_data();
|
24
|
+
int n = coordinates.shape(0);
|
25
|
+
int k = coordinates.shape(1);
|
26
|
+
int ik, jk, in, jn;
|
27
|
+
|
28
|
+
for (int i = 0; i < n; ++i) {
|
29
|
+
ik = i * k;
|
30
|
+
in = i * n;
|
31
|
+
for (int j = i + 1; j < n; ++j) {
|
32
|
+
jk = j * k;
|
33
|
+
jn = j * n;
|
34
|
+
T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
|
35
|
+
for (int p = 1; p < k; ++p) {
|
36
|
+
min_distance = std::min(min_distance,
|
37
|
+
std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
|
38
|
+
}
|
39
|
+
output_data[in + j] = min_distance;
|
40
|
+
output_data[jn + i] = min_distance;
|
41
|
+
}
|
42
|
+
output_data[in + i] = 0;
|
43
|
+
}
|
44
|
+
}
|
45
|
+
|
46
|
+
template <typename T>
|
47
|
+
std::pair<double, std::pair<int, int>> max_euclidean_distance(
|
48
|
+
py::array_t<T, py::array::c_style> coordinates) {
|
49
|
+
auto coordinates_data = coordinates.data();
|
50
|
+
int n = coordinates.shape(0);
|
51
|
+
int k = coordinates.shape(1);
|
52
|
+
|
53
|
+
double distance = 0.0;
|
54
|
+
double difference = 0.0;
|
55
|
+
double max_distance = -1;
|
56
|
+
double squared_distances = 0.0;
|
57
|
+
|
58
|
+
int ik, jk;
|
59
|
+
int max_i = -1, max_j = -1;
|
60
|
+
|
61
|
+
for (int i = 0; i < n; ++i) {
|
62
|
+
ik = i * k;
|
63
|
+
for (int j = i + 1; j < n; ++j) {
|
64
|
+
jk = j * k;
|
65
|
+
squared_distances = 0.0;
|
66
|
+
for (int p = 0; p < k; ++p) {
|
67
|
+
difference = static_cast<double>(
|
68
|
+
coordinates_data[ik + p] - coordinates_data[jk + p]
|
69
|
+
);
|
70
|
+
squared_distances += (difference * difference);
|
71
|
+
}
|
72
|
+
distance = std::sqrt(squared_distances);
|
73
|
+
if (distance > max_distance) {
|
74
|
+
max_distance = distance;
|
75
|
+
max_i = i;
|
76
|
+
max_j = j;
|
77
|
+
}
|
78
|
+
}
|
79
|
+
}
|
80
|
+
|
81
|
+
return std::make_pair(max_distance, std::make_pair(max_i, max_j));
|
82
|
+
}
|
83
|
+
|
84
|
+
|
85
|
+
template <typename T>
|
86
|
+
inline py::array_t<int, py::array::c_style> find_candidate_indices(
|
87
|
+
py::array_t<T, py::array::c_style> coordinates,
|
88
|
+
T min_distance) {
|
89
|
+
auto coordinates_data = coordinates.data();
|
90
|
+
int n = coordinates.shape(0);
|
91
|
+
int k = coordinates.shape(1);
|
92
|
+
int ik, jk;
|
93
|
+
|
94
|
+
std::vector<int> candidate_indices;
|
95
|
+
candidate_indices.reserve(n / 2);
|
96
|
+
candidate_indices.push_back(0);
|
97
|
+
|
98
|
+
for (int i = 1; i < n; ++i) {
|
99
|
+
bool is_candidate = true;
|
100
|
+
ik = i * k;
|
101
|
+
for (int candidate_index : candidate_indices) {
|
102
|
+
jk = candidate_index * k;
|
103
|
+
T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
|
104
|
+
for (int p = 1; p < k; ++p) {
|
105
|
+
distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
|
106
|
+
}
|
107
|
+
distance = std::sqrt(distance);
|
108
|
+
if (distance <= min_distance) {
|
109
|
+
is_candidate = false;
|
110
|
+
break;
|
111
|
+
}
|
112
|
+
}
|
113
|
+
if (is_candidate) {
|
114
|
+
candidate_indices.push_back(i);
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
|
119
|
+
auto output_data = output.mutable_data();
|
120
|
+
|
121
|
+
for (size_t i = 0; i < candidate_indices.size(); ++i) {
|
122
|
+
output_data[i] = candidate_indices[i];
|
123
|
+
}
|
124
|
+
|
125
|
+
return output;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <typename T>
|
129
|
+
py::array_t<T, py::array::c_style> find_candidate_coordinates(
|
130
|
+
py::array_t<T, py::array::c_style> coordinates,
|
131
|
+
T min_distance) {
|
132
|
+
|
133
|
+
py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
|
134
|
+
coordinates, min_distance);
|
135
|
+
auto candidate_indices_data = candidate_indices_array.data();
|
136
|
+
int num_candidates = candidate_indices_array.shape(0);
|
137
|
+
int k = coordinates.shape(1);
|
138
|
+
auto coordinates_data = coordinates.data();
|
139
|
+
|
140
|
+
py::array_t<T, py::array::c_style> output({num_candidates, k});
|
141
|
+
auto output_data = output.mutable_data();
|
142
|
+
|
143
|
+
for (int i = 0; i < num_candidates; ++i) {
|
144
|
+
int candidate_index = candidate_indices_data[i] * k;
|
145
|
+
std::copy(
|
146
|
+
coordinates_data + candidate_index,
|
147
|
+
coordinates_data + candidate_index + k,
|
148
|
+
output_data + i * k
|
149
|
+
);
|
150
|
+
}
|
151
|
+
|
152
|
+
return output;
|
153
|
+
}
|
154
|
+
|
155
|
+
template <typename U, typename T>
|
156
|
+
py::dict max_index_by_label(
|
157
|
+
py::array_t<U, py::array::c_style> labels,
|
158
|
+
py::array_t<T, py::array::c_style> scores
|
159
|
+
) {
|
160
|
+
|
161
|
+
const U* labels_ptr = labels.data();
|
162
|
+
const T* scores_ptr = scores.data();
|
163
|
+
|
164
|
+
std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
|
165
|
+
|
166
|
+
U label;
|
167
|
+
T score;
|
168
|
+
for (ssize_t i = 0; i < labels.size(); ++i) {
|
169
|
+
label = labels_ptr[i];
|
170
|
+
score = scores_ptr[i];
|
171
|
+
|
172
|
+
auto it = max_scores.insert({label, {score, i}});
|
173
|
+
|
174
|
+
if (score > it.first->second.first) {
|
175
|
+
it.first->second = {score, i};
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
py::dict ret;
|
180
|
+
for (auto& item: max_scores) {
|
181
|
+
ret[py::cast(item.first)] = py::cast(item.second.second);
|
182
|
+
}
|
183
|
+
|
184
|
+
return ret;
|
185
|
+
}
|
186
|
+
|
187
|
+
|
188
|
+
template <typename T>
|
189
|
+
py::tuple online_statistics(
|
190
|
+
py::array_t<T, py::array::c_style> arr,
|
191
|
+
unsigned long long int n = 0,
|
192
|
+
double rmean = 0,
|
193
|
+
double ssqd = 0,
|
194
|
+
T reference = 0) {
|
195
|
+
|
196
|
+
auto in = arr.data();
|
197
|
+
int size = arr.size();
|
198
|
+
|
199
|
+
T max_value = std::numeric_limits<T>::lowest();
|
200
|
+
T min_value = std::numeric_limits<T>::max();
|
201
|
+
double delta, delta_prime;
|
202
|
+
|
203
|
+
unsigned long long int nbetter_or_equal = 0;
|
204
|
+
|
205
|
+
for(int i = 0; i < size; i++){
|
206
|
+
n++;
|
207
|
+
delta = in[i] - rmean;
|
208
|
+
rmean += delta / n;
|
209
|
+
delta_prime = in[i] - rmean;
|
210
|
+
ssqd += delta * delta_prime;
|
211
|
+
|
212
|
+
max_value = std::max(in[i], max_value);
|
213
|
+
min_value = std::min(in[i], min_value);
|
214
|
+
if (in[i] >= reference)
|
215
|
+
nbetter_or_equal++;
|
216
|
+
}
|
217
|
+
|
218
|
+
return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
|
219
|
+
}
|
220
|
+
|
221
|
+
/* Python module definition. Every function is registered once per supported
 * dtype; the overloads are registered in the order float64, float32, int64,
 * int32 (and labels-major, scores-minor for max_index_by_label). */
PYBIND11_MODULE(extensions, m) {
    // --- absolute_minimum_deviation -------------------------------------
    m.def("absolute_minimum_deviation",
          &absolute_minimum_deviation<double>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
          py::arg("coordinates"),
          py::arg("output"));
    m.def("absolute_minimum_deviation",
          &absolute_minimum_deviation<float>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
          py::arg("coordinates"),
          py::arg("output"));
    m.def("absolute_minimum_deviation",
          &absolute_minimum_deviation<int64_t>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
          py::arg("coordinates"),
          py::arg("output"));
    m.def("absolute_minimum_deviation",
          &absolute_minimum_deviation<int32_t>,
          "Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
          py::arg("coordinates"),
          py::arg("output"));

    // --- max_euclidean_distance -----------------------------------------
    m.def("max_euclidean_distance",
          &max_euclidean_distance<double>,
          "Identify pair of points with maximal euclidean distance (float64).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance",
          &max_euclidean_distance<float>,
          "Identify pair of points with maximal euclidean distance (float32).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance",
          &max_euclidean_distance<int64_t>,
          "Identify pair of points with maximal euclidean distance (int64).",
          py::arg("coordinates"));
    m.def("max_euclidean_distance",
          &max_euclidean_distance<int32_t>,
          "Identify pair of points with maximal euclidean distance (int32).",
          py::arg("coordinates"));

    // --- find_candidate_indices -----------------------------------------
    m.def("find_candidate_indices",
          &find_candidate_indices<double>,
          "Finds candidate indices with minimum distance (float64).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_indices",
          &find_candidate_indices<float>,
          "Finds candidate indices with minimum distance (float32).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_indices",
          &find_candidate_indices<int64_t>,
          "Finds candidate indices with minimum distance (int64).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_indices",
          &find_candidate_indices<int32_t>,
          "Finds candidate indices with minimum distance (int32).",
          py::arg("coordinates"),
          py::arg("min_distance"));

    // --- find_candidate_coordinates -------------------------------------
    m.def("find_candidate_coordinates",
          &find_candidate_coordinates<double>,
          "Finds candidate coordinates with minimum distance (float64).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_coordinates",
          &find_candidate_coordinates<float>,
          "Finds candidate coordinates with minimum distance (float32).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_coordinates",
          &find_candidate_coordinates<int64_t>,
          "Finds candidate coordinates with minimum distance (int64).",
          py::arg("coordinates"),
          py::arg("min_distance"));
    m.def("find_candidate_coordinates",
          &find_candidate_coordinates<int32_t>,
          "Finds candidate coordinates with minimum distance (int32).",
          py::arg("coordinates"),
          py::arg("min_distance"));

    // --- max_index_by_label: float64 labels -----------------------------
    m.def("max_index_by_label",
          &max_index_by_label<double, double>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<double, float>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<double, int64_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<double, int32_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));

    // --- max_index_by_label: float32 labels -----------------------------
    m.def("max_index_by_label",
          &max_index_by_label<float, double>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<float, float>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<float, int64_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<float, int32_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));

    // --- max_index_by_label: int64 labels -------------------------------
    m.def("max_index_by_label",
          &max_index_by_label<int64_t, double>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int64_t, float>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int64_t, int64_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int64_t, int32_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));

    // --- max_index_by_label: int32 labels -------------------------------
    m.def("max_index_by_label",
          &max_index_by_label<int32_t, double>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int32_t, float>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int32_t, int64_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));
    m.def("max_index_by_label",
          &max_index_by_label<int32_t, int32_t>,
          "Maximum value by label",
          py::arg("labels"),
          py::arg("scores"));

    // --- online_statistics ----------------------------------------------
    m.def("online_statistics",
          &online_statistics<double>,
          py::arg("arr"),
          py::arg("n") = 0,
          py::arg("rmean") = 0,
          py::arg("ssqd") = 0,
          py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics",
          &online_statistics<float>,
          py::arg("arr"),
          py::arg("n") = 0,
          py::arg("rmean") = 0,
          py::arg("ssqd") = 0,
          py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics",
          &online_statistics<int64_t>,
          py::arg("arr"),
          py::arg("n") = 0,
          py::arg("rmean") = 0,
          py::arg("ssqd") = 0,
          py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
    m.def("online_statistics",
          &online_statistics<int32_t>,
          py::arg("arr"),
          py::arg("n") = 0,
          py::arg("rmean") = 0,
          py::arg("ssqd") = 0,
          py::arg("reference") = 0,
          "Compute running online statistics on a numpy array.");
}
|
tme/matching_data.py
CHANGED
@@ -450,7 +450,7 @@ class MatchingData:
|
|
450
450
|
template_shape: NDArray,
|
451
451
|
batch_mask: NDArray = None,
|
452
452
|
pad_fourier: bool = False,
|
453
|
-
) -> Tuple[Tuple, Tuple, Tuple]:
|
453
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
454
454
|
"""
|
455
455
|
Determines an efficient shape for Fourier transforms considering zero-padding.
|
456
456
|
"""
|
@@ -479,11 +479,6 @@ class MatchingData:
|
|
479
479
|
np.subtract(target_shape, template_shape), 1 - batch_mask
|
480
480
|
)
|
481
481
|
if np.sum(shape_diff < 0):
|
482
|
-
warnings.warn(
|
483
|
-
"Template is larger than target and padding is turned off. Consider "
|
484
|
-
"swapping them or activate padding. Correcting the shift for now."
|
485
|
-
)
|
486
|
-
|
487
482
|
shape_shift = np.divide(shape_diff, 2)
|
488
483
|
offset = np.mod(shape_diff, 2)
|
489
484
|
if pad_fourier:
|
@@ -491,6 +486,11 @@ class MatchingData:
|
|
491
486
|
offset,
|
492
487
|
np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
|
493
488
|
)
|
489
|
+
else:
|
490
|
+
warnings.warn(
|
491
|
+
"Template is larger than target and padding is turned off. Consider "
|
492
|
+
"swapping them or activate padding. Correcting the shift for now."
|
493
|
+
)
|
494
494
|
|
495
495
|
shape_shift = np.add(shape_shift, offset)
|
496
496
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
@@ -498,7 +498,9 @@ class MatchingData:
|
|
498
498
|
fourier_shift = tuple(fourier_shift.astype(int))
|
499
499
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
500
500
|
|
501
|
-
def fourier_padding(
|
501
|
+
def fourier_padding(
|
502
|
+
self, pad_fourier: bool = False
|
503
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
502
504
|
"""
|
503
505
|
Computes efficient shape for Fourier transforms and potential associated shifts.
|
504
506
|
|
@@ -510,8 +512,8 @@ class MatchingData:
|
|
510
512
|
|
511
513
|
Returns
|
512
514
|
-------
|
513
|
-
Tuple[tuple of int, tuple of int, tuple of int]
|
514
|
-
Tuple with
|
515
|
+
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
516
|
+
Tuple with convolution, forward FT, inverse FT shape and corresponding shift.
|
515
517
|
"""
|
516
518
|
return self._fourier_padding(
|
517
519
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
tme/matching_exhaustive.py
CHANGED
@@ -82,13 +82,14 @@ def _setup_template_filter_apply_target_filter(
|
|
82
82
|
fastt_shape, fastt_ft_shape = fast_shape, filter_shape
|
83
83
|
if filter_template and not pad_template_filter:
|
84
84
|
# FFT shape acrobatics for faster filter application
|
85
|
-
_, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
)
|
85
|
+
# _, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
+
# target_shape=be.to_numpy_array(matching_data._template.shape),
|
87
|
+
# template_shape=be.to_numpy_array(
|
88
|
+
# [1 for _ in matching_data._template.shape]
|
89
|
+
# ),
|
90
|
+
# pad_fourier=False,
|
91
|
+
# )
|
92
|
+
fastt_shape = matching_data._template.shape
|
92
93
|
matching_data.template = be.reverse(
|
93
94
|
be.topleft_pad(matching_data.template, fastt_shape)
|
94
95
|
)
|
@@ -399,7 +400,8 @@ def scan_subsets(
|
|
399
400
|
The template matching procedure is determined by ``matching_setup`` and
|
400
401
|
``matching_score``, which are unique to each score. In the following,
|
401
402
|
we will be using the `FLCSphericalMask` score, which is composed of
|
402
|
-
:py:meth:`flcSphericalMask_setup` and
|
403
|
+
:py:meth:`tme.matching_scores.flcSphericalMask_setup` and
|
404
|
+
:py:meth:`tme.matching_scores.corr_scoring`
|
403
405
|
|
404
406
|
>>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
|
405
407
|
>>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
|
tme/matching_utils.py
CHANGED
tme/preprocessing/_utils.py
CHANGED
@@ -19,8 +19,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
|
|
19
19
|
"""
|
20
20
|
Given an opening_axis, computes the shape of the remaining dimensions.
|
21
21
|
|
22
|
-
Parameters
|
23
|
-
|
22
|
+
Parameters
|
23
|
+
----------
|
24
24
|
shape : Tuple[int]
|
25
25
|
The shape of the input array.
|
26
26
|
opening_axis : int
|
@@ -28,8 +28,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
|
|
28
28
|
reduce_dim : bool, optional (default=False)
|
29
29
|
Whether to reduce the dimensionality after tilting.
|
30
30
|
|
31
|
-
Returns
|
32
|
-
|
31
|
+
Returns
|
32
|
+
-------
|
33
33
|
Tuple[int]
|
34
34
|
The shape of the array after tilting.
|
35
35
|
"""
|
@@ -44,13 +44,13 @@ def centered_grid(shape: Tuple[int]) -> NDArray:
|
|
44
44
|
"""
|
45
45
|
Generate an integer valued grid centered around size // 2
|
46
46
|
|
47
|
-
Parameters
|
48
|
-
|
47
|
+
Parameters
|
48
|
+
----------
|
49
49
|
shape : Tuple[int]
|
50
50
|
The shape of the grid.
|
51
51
|
|
52
|
-
Returns
|
53
|
-
|
52
|
+
Returns
|
53
|
+
-------
|
54
54
|
NDArray
|
55
55
|
The centered grid.
|
56
56
|
"""
|
@@ -70,8 +70,8 @@ def frequency_grid_at_angle(
|
|
70
70
|
"""
|
71
71
|
Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
|
72
72
|
|
73
|
-
Parameters
|
74
|
-
|
73
|
+
Parameters
|
74
|
+
----------
|
75
75
|
shape : Tuple[int]
|
76
76
|
The shape of the grid.
|
77
77
|
angle : float
|
@@ -128,8 +128,8 @@ def fftfreqn(
|
|
128
128
|
"""
|
129
129
|
Generate the n-dimensional discrete Fourier transform sample frequencies.
|
130
130
|
|
131
|
-
Parameters
|
132
|
-
|
131
|
+
Parameters
|
132
|
+
----------
|
133
133
|
shape : Tuple[int]
|
134
134
|
The shape of the data.
|
135
135
|
sampling_rate : float or Tuple[float]
|
@@ -180,8 +180,8 @@ def crop_real_fourier(data: BackendArray) -> BackendArray:
|
|
180
180
|
"""
|
181
181
|
Crop the real part of a Fourier transform.
|
182
182
|
|
183
|
-
Parameters
|
184
|
-
|
183
|
+
Parameters
|
184
|
+
----------
|
185
185
|
data : BackendArray
|
186
186
|
The Fourier transformed data.
|
187
187
|
|
@@ -20,7 +20,6 @@ class ComposableFilter(ABC):
|
|
20
20
|
|
21
21
|
Parameters
|
22
22
|
----------
|
23
|
-
|
24
23
|
*args : tuple
|
25
24
|
Variable length argument list.
|
26
25
|
**kwargs : dict
|
@@ -28,7 +27,6 @@ class ComposableFilter(ABC):
|
|
28
27
|
|
29
28
|
Returns
|
30
29
|
-------
|
31
|
-
|
32
30
|
Dict
|
33
31
|
A dictionary representing the result of the filtering operation.
|
34
32
|
"""
|
tme/preprocessing/compose.py
CHANGED
@@ -17,13 +17,13 @@ class Compose:
|
|
17
17
|
This class allows composing multiple transformations together. Each transformation
|
18
18
|
is expected to be a callable that accepts keyword arguments and returns metadata.
|
19
19
|
|
20
|
-
Parameters
|
21
|
-
|
20
|
+
Parameters
|
21
|
+
----------
|
22
22
|
transforms : Tuple[object]
|
23
23
|
A tuple containing transformation objects.
|
24
24
|
|
25
|
-
Returns
|
26
|
-
|
25
|
+
Returns
|
26
|
+
-------
|
27
27
|
Dict
|
28
28
|
Metadata resulting from the composed transformations.
|
29
29
|
|
@@ -18,12 +18,10 @@ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_
|
|
18
18
|
|
19
19
|
class BandPassFilter:
|
20
20
|
"""
|
21
|
-
|
22
|
-
either by directly specifying the frequency cutoffs (discrete_bandpass) or
|
23
|
-
by using Gaussian functions (gaussian_bandpass).
|
21
|
+
Generate bandpass filters in Fourier space.
|
24
22
|
|
25
|
-
Parameters
|
26
|
-
|
23
|
+
Parameters
|
24
|
+
----------
|
27
25
|
lowpass : float, optional
|
28
26
|
The lowpass cutoff, defaults to None.
|
29
27
|
highpass : float, optional
|
@@ -67,8 +65,8 @@ class BandPassFilter:
|
|
67
65
|
"""
|
68
66
|
Generate a bandpass filter using discrete frequency cutoffs.
|
69
67
|
|
70
|
-
Parameters
|
71
|
-
|
68
|
+
Parameters
|
69
|
+
----------
|
72
70
|
shape : tuple of int
|
73
71
|
The shape of the bandpass filter.
|
74
72
|
lowpass : float
|
@@ -84,8 +82,8 @@ class BandPassFilter:
|
|
84
82
|
**kwargs : dict
|
85
83
|
Additional keyword arguments.
|
86
84
|
|
87
|
-
Returns
|
88
|
-
|
85
|
+
Returns
|
86
|
+
-------
|
89
87
|
BackendArray
|
90
88
|
The bandpass filter in Fourier space.
|
91
89
|
"""
|
@@ -98,17 +96,18 @@ class BandPassFilter:
|
|
98
96
|
shape_is_real_fourier=shape_is_real_fourier,
|
99
97
|
compute_euclidean_norm=True,
|
100
98
|
)
|
101
|
-
|
102
|
-
|
103
|
-
highpass = 1e10 if highpass is None else highpass
|
99
|
+
grid = be.to_backend_array(grid)
|
100
|
+
sampling_rate = be.to_backend_array(sampling_rate)
|
104
101
|
|
105
102
|
highcut = grid.max()
|
106
|
-
if lowpass
|
107
|
-
highcut =
|
108
|
-
lowcut = np.max(2 * sampling_rate / highpass)
|
103
|
+
if lowpass is not None:
|
104
|
+
highcut = be.max(2 * sampling_rate / lowpass)
|
109
105
|
|
110
|
-
|
106
|
+
lowcut = 0
|
107
|
+
if highpass is not None:
|
108
|
+
lowcut = be.max(2 * sampling_rate / highpass)
|
111
109
|
|
110
|
+
bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
|
112
111
|
bandpass_filter = shift_fourier(
|
113
112
|
data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
|
114
113
|
)
|
@@ -129,10 +128,10 @@ class BandPassFilter:
|
|
129
128
|
**kwargs,
|
130
129
|
) -> BackendArray:
|
131
130
|
"""
|
132
|
-
Generate a bandpass filter using
|
131
|
+
Generate a bandpass filter using Gaussians.
|
133
132
|
|
134
|
-
Parameters
|
135
|
-
|
133
|
+
Parameters
|
134
|
+
----------
|
136
135
|
shape : tuple of int
|
137
136
|
The shape of the bandpass filter.
|
138
137
|
lowpass : float
|
@@ -148,8 +147,8 @@ class BandPassFilter:
|
|
148
147
|
**kwargs : dict
|
149
148
|
Additional keyword arguments.
|
150
149
|
|
151
|
-
Returns
|
152
|
-
|
150
|
+
Returns
|
151
|
+
-------
|
153
152
|
BackendArray
|
154
153
|
The bandpass filter in Fourier space.
|
155
154
|
"""
|
@@ -216,15 +215,13 @@ class BandPassFilter:
|
|
216
215
|
|
217
216
|
class LinearWhiteningFilter:
|
218
217
|
"""
|
219
|
-
|
220
|
-
apply linear whitening to the Fourier coefficients.
|
218
|
+
Compute Fourier power spectrums and perform whitening.
|
221
219
|
|
222
|
-
Parameters
|
223
|
-
|
220
|
+
Parameters
|
221
|
+
----------
|
224
222
|
**kwargs : Dict, optional
|
225
223
|
Additional keyword arguments.
|
226
224
|
|
227
|
-
|
228
225
|
References
|
229
226
|
----------
|
230
227
|
.. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
|
@@ -243,10 +240,10 @@ class LinearWhiteningFilter:
|
|
243
240
|
data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
|
244
241
|
) -> Tuple[BackendArray, BackendArray]:
|
245
242
|
"""
|
246
|
-
Compute the spectrum of the input data.
|
243
|
+
Compute the power spectrum of the input data.
|
247
244
|
|
248
|
-
Parameters
|
249
|
-
|
245
|
+
Parameters
|
246
|
+
----------
|
250
247
|
data_rfft : BackendArray
|
251
248
|
The Fourier transform of the input data.
|
252
249
|
n_bins : int, optional
|
@@ -254,8 +251,8 @@ class LinearWhiteningFilter:
|
|
254
251
|
batch_dimension : int, optional
|
255
252
|
Batch dimension to average over.
|
256
253
|
|
257
|
-
Returns
|
258
|
-
|
254
|
+
Returns
|
255
|
+
-------
|
259
256
|
bins : BackendArray
|
260
257
|
Array containing the bin indices for the spectrum.
|
261
258
|
radial_averages : BackendArray
|
@@ -330,8 +327,8 @@ class LinearWhiteningFilter:
|
|
330
327
|
"""
|
331
328
|
Apply linear whitening to the data and return the result.
|
332
329
|
|
333
|
-
Parameters
|
334
|
-
|
330
|
+
Parameters
|
331
|
+
----------
|
335
332
|
data : BackendArray, optional
|
336
333
|
The input data, defaults to None.
|
337
334
|
data_rfft : BackendArray, optional
|
@@ -345,8 +342,8 @@ class LinearWhiteningFilter:
|
|
345
342
|
**kwargs : Dict
|
346
343
|
Additional keyword arguments.
|
347
344
|
|
348
|
-
Returns
|
349
|
-
|
345
|
+
Returns
|
346
|
+
-------
|
350
347
|
Dict
|
351
348
|
Filter data and associated parameters.
|
352
349
|
"""
|