pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.5__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/match_template.py +8 -8
- {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocess.py +22 -6
- {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocessor_gui.py +9 -14
- {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/METADATA +1 -1
- pytme-0.2.5.dist-info/RECORD +119 -0
- {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/WHEEL +1 -1
- {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/top_level.txt +1 -0
- scripts/match_template.py +8 -8
- scripts/preprocess.py +22 -6
- scripts/preprocessor_gui.py +9 -14
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/backends/jax_backend.py +3 -9
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +14 -10
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +14 -12
- tme/matching_exhaustive.py +17 -15
- tme/matching_optimization.py +215 -208
- tme/matching_utils.py +1 -0
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +0 -2
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +198 -117
- tme/preprocessor.py +24 -246
- tme/structure.py +22 -22
- pytme-0.2.3.dist-info/RECORD +0 -75
- tme/matching_memory.py +0 -383
- {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/postprocess.py +0 -0
- {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/LICENSE +0 -0
- {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,332 @@
|
|
1
|
+
/* Pybind extensions for template matching score space analyzers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
|
15
|
+
|
16
|
+
namespace py = pybind11;
|
17
|
+
|
18
|
+
template <typename T>
|
19
|
+
void absolute_minimum_deviation(
|
20
|
+
py::array_t<T, py::array::c_style> coordinates,
|
21
|
+
py::array_t<T, py::array::c_style> output) {
|
22
|
+
auto coordinates_data = coordinates.data();
|
23
|
+
auto output_data = output.mutable_data();
|
24
|
+
int n = coordinates.shape(0);
|
25
|
+
int k = coordinates.shape(1);
|
26
|
+
int ik, jk, in, jn;
|
27
|
+
|
28
|
+
for (int i = 0; i < n; ++i) {
|
29
|
+
ik = i * k;
|
30
|
+
in = i * n;
|
31
|
+
for (int j = i + 1; j < n; ++j) {
|
32
|
+
jk = j * k;
|
33
|
+
jn = j * n;
|
34
|
+
T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
|
35
|
+
for (int p = 1; p < k; ++p) {
|
36
|
+
min_distance = std::min(min_distance,
|
37
|
+
std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
|
38
|
+
}
|
39
|
+
output_data[in + j] = min_distance;
|
40
|
+
output_data[jn + i] = min_distance;
|
41
|
+
}
|
42
|
+
output_data[in + i] = 0;
|
43
|
+
}
|
44
|
+
}
|
45
|
+
|
46
|
+
template <typename T>
|
47
|
+
std::pair<double, std::pair<int, int>> max_euclidean_distance(
|
48
|
+
py::array_t<T, py::array::c_style> coordinates) {
|
49
|
+
auto coordinates_data = coordinates.data();
|
50
|
+
int n = coordinates.shape(0);
|
51
|
+
int k = coordinates.shape(1);
|
52
|
+
|
53
|
+
double distance = 0.0;
|
54
|
+
double difference = 0.0;
|
55
|
+
double max_distance = -1;
|
56
|
+
double squared_distances = 0.0;
|
57
|
+
|
58
|
+
int ik, jk;
|
59
|
+
int max_i = -1, max_j = -1;
|
60
|
+
|
61
|
+
for (int i = 0; i < n; ++i) {
|
62
|
+
ik = i * k;
|
63
|
+
for (int j = i + 1; j < n; ++j) {
|
64
|
+
jk = j * k;
|
65
|
+
squared_distances = 0.0;
|
66
|
+
for (int p = 0; p < k; ++p) {
|
67
|
+
difference = static_cast<double>(
|
68
|
+
coordinates_data[ik + p] - coordinates_data[jk + p]
|
69
|
+
);
|
70
|
+
squared_distances += (difference * difference);
|
71
|
+
}
|
72
|
+
distance = std::sqrt(squared_distances);
|
73
|
+
if (distance > max_distance) {
|
74
|
+
max_distance = distance;
|
75
|
+
max_i = i;
|
76
|
+
max_j = j;
|
77
|
+
}
|
78
|
+
}
|
79
|
+
}
|
80
|
+
|
81
|
+
return std::make_pair(max_distance, std::make_pair(max_i, max_j));
|
82
|
+
}
|
83
|
+
|
84
|
+
|
85
|
+
template <typename T>
|
86
|
+
inline py::array_t<int, py::array::c_style> find_candidate_indices(
|
87
|
+
py::array_t<T, py::array::c_style> coordinates,
|
88
|
+
T min_distance) {
|
89
|
+
auto coordinates_data = coordinates.data();
|
90
|
+
int n = coordinates.shape(0);
|
91
|
+
int k = coordinates.shape(1);
|
92
|
+
int ik, jk;
|
93
|
+
|
94
|
+
std::vector<int> candidate_indices;
|
95
|
+
candidate_indices.reserve(n / 2);
|
96
|
+
candidate_indices.push_back(0);
|
97
|
+
|
98
|
+
for (int i = 1; i < n; ++i) {
|
99
|
+
bool is_candidate = true;
|
100
|
+
ik = i * k;
|
101
|
+
for (int candidate_index : candidate_indices) {
|
102
|
+
jk = candidate_index * k;
|
103
|
+
T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
|
104
|
+
for (int p = 1; p < k; ++p) {
|
105
|
+
distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
|
106
|
+
}
|
107
|
+
distance = std::sqrt(distance);
|
108
|
+
if (distance <= min_distance) {
|
109
|
+
is_candidate = false;
|
110
|
+
break;
|
111
|
+
}
|
112
|
+
}
|
113
|
+
if (is_candidate) {
|
114
|
+
candidate_indices.push_back(i);
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
|
119
|
+
auto output_data = output.mutable_data();
|
120
|
+
|
121
|
+
for (size_t i = 0; i < candidate_indices.size(); ++i) {
|
122
|
+
output_data[i] = candidate_indices[i];
|
123
|
+
}
|
124
|
+
|
125
|
+
return output;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <typename T>
|
129
|
+
py::array_t<T, py::array::c_style> find_candidate_coordinates(
|
130
|
+
py::array_t<T, py::array::c_style> coordinates,
|
131
|
+
T min_distance) {
|
132
|
+
|
133
|
+
py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
|
134
|
+
coordinates, min_distance);
|
135
|
+
auto candidate_indices_data = candidate_indices_array.data();
|
136
|
+
int num_candidates = candidate_indices_array.shape(0);
|
137
|
+
int k = coordinates.shape(1);
|
138
|
+
auto coordinates_data = coordinates.data();
|
139
|
+
|
140
|
+
py::array_t<T, py::array::c_style> output({num_candidates, k});
|
141
|
+
auto output_data = output.mutable_data();
|
142
|
+
|
143
|
+
for (int i = 0; i < num_candidates; ++i) {
|
144
|
+
int candidate_index = candidate_indices_data[i] * k;
|
145
|
+
std::copy(
|
146
|
+
coordinates_data + candidate_index,
|
147
|
+
coordinates_data + candidate_index + k,
|
148
|
+
output_data + i * k
|
149
|
+
);
|
150
|
+
}
|
151
|
+
|
152
|
+
return output;
|
153
|
+
}
|
154
|
+
|
155
|
+
template <typename U, typename T>
|
156
|
+
py::dict max_index_by_label(
|
157
|
+
py::array_t<U, py::array::c_style> labels,
|
158
|
+
py::array_t<T, py::array::c_style> scores
|
159
|
+
) {
|
160
|
+
|
161
|
+
const U* labels_ptr = labels.data();
|
162
|
+
const T* scores_ptr = scores.data();
|
163
|
+
|
164
|
+
std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
|
165
|
+
|
166
|
+
U label;
|
167
|
+
T score;
|
168
|
+
for (ssize_t i = 0; i < labels.size(); ++i) {
|
169
|
+
label = labels_ptr[i];
|
170
|
+
score = scores_ptr[i];
|
171
|
+
|
172
|
+
auto it = max_scores.insert({label, {score, i}});
|
173
|
+
|
174
|
+
if (score > it.first->second.first) {
|
175
|
+
it.first->second = {score, i};
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
py::dict ret;
|
180
|
+
for (auto& item: max_scores) {
|
181
|
+
ret[py::cast(item.first)] = py::cast(item.second.second);
|
182
|
+
}
|
183
|
+
|
184
|
+
return ret;
|
185
|
+
}
|
186
|
+
|
187
|
+
|
188
|
+
template <typename T>
|
189
|
+
py::tuple online_statistics(
|
190
|
+
py::array_t<T, py::array::c_style> arr,
|
191
|
+
unsigned long long int n = 0,
|
192
|
+
double rmean = 0,
|
193
|
+
double ssqd = 0,
|
194
|
+
T reference = 0) {
|
195
|
+
|
196
|
+
auto in = arr.data();
|
197
|
+
int size = arr.size();
|
198
|
+
|
199
|
+
T max_value = std::numeric_limits<T>::lowest();
|
200
|
+
T min_value = std::numeric_limits<T>::max();
|
201
|
+
double delta, delta_prime;
|
202
|
+
|
203
|
+
unsigned long long int nbetter_or_equal = 0;
|
204
|
+
|
205
|
+
for(int i = 0; i < size; i++){
|
206
|
+
n++;
|
207
|
+
delta = in[i] - rmean;
|
208
|
+
rmean += delta / n;
|
209
|
+
delta_prime = in[i] - rmean;
|
210
|
+
ssqd += delta * delta_prime;
|
211
|
+
|
212
|
+
max_value = std::max(in[i], max_value);
|
213
|
+
min_value = std::min(in[i], min_value);
|
214
|
+
if (in[i] >= reference)
|
215
|
+
nbetter_or_equal++;
|
216
|
+
}
|
217
|
+
|
218
|
+
return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
|
219
|
+
}
|
220
|
+
|
221
|
+
PYBIND11_MODULE(extensions, m) {
|
222
|
+
|
223
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
|
224
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
|
225
|
+
py::arg("coordinates"), py::arg("output"));
|
226
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
|
227
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
|
228
|
+
py::arg("coordinates"), py::arg("output"));
|
229
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
|
230
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
|
231
|
+
py::arg("coordinates"), py::arg("output"));
|
232
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
|
233
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
|
234
|
+
py::arg("coordinates"), py::arg("output"));
|
235
|
+
|
236
|
+
|
237
|
+
m.def("max_euclidean_distance", max_euclidean_distance<double>,
|
238
|
+
"Identify pair of points with maximal euclidean distance (float64).",
|
239
|
+
py::arg("coordinates"));
|
240
|
+
m.def("max_euclidean_distance", max_euclidean_distance<float>,
|
241
|
+
"Identify pair of points with maximal euclidean distance (float32).",
|
242
|
+
py::arg("coordinates"));
|
243
|
+
m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
|
244
|
+
"Identify pair of points with maximal euclidean distance (int64).",
|
245
|
+
py::arg("coordinates"));
|
246
|
+
m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
|
247
|
+
"Identify pair of points with maximal euclidean distance (int32).",
|
248
|
+
py::arg("coordinates"));
|
249
|
+
|
250
|
+
|
251
|
+
m.def("find_candidate_indices", &find_candidate_indices<double>,
|
252
|
+
"Finds candidate indices with minimum distance (float64).",
|
253
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
254
|
+
m.def("find_candidate_indices", &find_candidate_indices<float>,
|
255
|
+
"Finds candidate indices with minimum distance (float32).",
|
256
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
257
|
+
m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
|
258
|
+
"Finds candidate indices with minimum distance (int64).",
|
259
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
260
|
+
m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
|
261
|
+
"Finds candidate indices with minimum distance (int32).",
|
262
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
263
|
+
|
264
|
+
|
265
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
|
266
|
+
"Finds candidate coordinates with minimum distance (float64).",
|
267
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
268
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
|
269
|
+
"Finds candidate coordinates with minimum distance (float32).",
|
270
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
271
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
|
272
|
+
"Finds candidate coordinates with minimum distance (int64).",
|
273
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
274
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
|
275
|
+
"Finds candidate coordinates with minimum distance (int32).",
|
276
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
277
|
+
|
278
|
+
|
279
|
+
m.def("max_index_by_label", &max_index_by_label<double, double>,
|
280
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
281
|
+
m.def("max_index_by_label", &max_index_by_label<double, float>,
|
282
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
283
|
+
m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
|
284
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
285
|
+
m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
|
286
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
287
|
+
|
288
|
+
m.def("max_index_by_label", &max_index_by_label<float, double>,
|
289
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
290
|
+
m.def("max_index_by_label", &max_index_by_label<float, float>,
|
291
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
292
|
+
m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
|
293
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
294
|
+
m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
|
295
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
296
|
+
|
297
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
|
298
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
299
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
|
300
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
301
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
|
302
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
303
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
|
304
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
305
|
+
|
306
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
|
307
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
308
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
|
309
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
310
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
|
311
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
312
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
|
313
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
314
|
+
|
315
|
+
|
316
|
+
m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
|
317
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
318
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
319
|
+
"Compute running online statistics on a numpy array.");
|
320
|
+
m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
|
321
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
322
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
323
|
+
"Compute running online statistics on a numpy array.");
|
324
|
+
m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
|
325
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
326
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
327
|
+
"Compute running online statistics on a numpy array.");
|
328
|
+
m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
|
329
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
330
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
331
|
+
"Compute running online statistics on a numpy array.");
|
332
|
+
}
|
tme/matching_data.py
CHANGED
@@ -450,7 +450,7 @@ class MatchingData:
|
|
450
450
|
template_shape: NDArray,
|
451
451
|
batch_mask: NDArray = None,
|
452
452
|
pad_fourier: bool = False,
|
453
|
-
) -> Tuple[Tuple, Tuple, Tuple]:
|
453
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
454
454
|
"""
|
455
455
|
Determines an efficient shape for Fourier transforms considering zero-padding.
|
456
456
|
"""
|
@@ -478,12 +478,8 @@ class MatchingData:
|
|
478
478
|
shape_diff = np.multiply(
|
479
479
|
np.subtract(target_shape, template_shape), 1 - batch_mask
|
480
480
|
)
|
481
|
-
|
482
|
-
|
483
|
-
"Template is larger than target and padding is turned off. Consider "
|
484
|
-
"swapping them or activate padding. Correcting the shift for now."
|
485
|
-
)
|
486
|
-
|
481
|
+
shape_mask = shape_diff < 0
|
482
|
+
if np.sum(shape_mask):
|
487
483
|
shape_shift = np.divide(shape_diff, 2)
|
488
484
|
offset = np.mod(shape_diff, 2)
|
489
485
|
if pad_fourier:
|
@@ -491,14 +487,20 @@ class MatchingData:
|
|
491
487
|
offset,
|
492
488
|
np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
|
493
489
|
)
|
494
|
-
|
495
|
-
|
490
|
+
else:
|
491
|
+
warnings.warn(
|
492
|
+
"Template is larger than target and padding is turned off. Consider "
|
493
|
+
"swapping them or activate padding. Correcting the shift for now."
|
494
|
+
)
|
495
|
+
shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
|
496
496
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
497
497
|
|
498
498
|
fourier_shift = tuple(fourier_shift.astype(int))
|
499
499
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
500
500
|
|
501
|
-
def fourier_padding(
|
501
|
+
def fourier_padding(
|
502
|
+
self, pad_fourier: bool = False
|
503
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
502
504
|
"""
|
503
505
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
504
506
|
|
@@ -510,8 +512,8 @@ class MatchingData:
|
|
510
512
|
|
511
513
|
Returns
|
512
514
|
-------
|
513
|
-
Tuple[tuple of int, tuple of int, tuple of int]
|
514
|
-
Tuple with
|
515
|
+
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
516
|
+
Tuple with convolution, forward FT, inverse FT shape and corresponding shift.
|
515
517
|
"""
|
516
518
|
return self._fourier_padding(
|
517
519
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
tme/matching_exhaustive.py
CHANGED
@@ -82,19 +82,20 @@ def _setup_template_filter_apply_target_filter(
|
|
82
82
|
fastt_shape, fastt_ft_shape = fast_shape, filter_shape
|
83
83
|
if filter_template and not pad_template_filter:
|
84
84
|
# FFT shape acrobatics for faster filter application
|
85
|
-
_, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
)
|
92
|
-
|
93
|
-
|
94
|
-
)
|
95
|
-
|
96
|
-
|
97
|
-
)
|
85
|
+
# _, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
+
# target_shape=be.to_numpy_array(matching_data._template.shape),
|
87
|
+
# template_shape=be.to_numpy_array(
|
88
|
+
# [1 for _ in matching_data._template.shape]
|
89
|
+
# ),
|
90
|
+
# pad_fourier=False,
|
91
|
+
# )
|
92
|
+
fastt_shape = matching_data._template.shape
|
93
|
+
# matching_data.template = be.reverse(
|
94
|
+
# be.topleft_pad(matching_data.template, fastt_shape)
|
95
|
+
# )
|
96
|
+
# matching_data.template_mask = be.reverse(
|
97
|
+
# be.topleft_pad(matching_data.template_mask, fastt_shape)
|
98
|
+
# )
|
98
99
|
matching_data._set_matching_dimension(
|
99
100
|
target_dims=matching_data._target_dims,
|
100
101
|
template_dims=matching_data._template_dims,
|
@@ -207,7 +208,7 @@ def scan(
|
|
207
208
|
|
208
209
|
Examples
|
209
210
|
--------
|
210
|
-
Schematically,
|
211
|
+
Schematically, :py:meth:`scan` is identical to :py:meth:`scan_subsets`,
|
211
212
|
with the distinction that the objects contained in ``matching_data`` are not
|
212
213
|
split and the search is only parallelized over angles.
|
213
214
|
Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
|
@@ -399,7 +400,8 @@ def scan_subsets(
|
|
399
400
|
The template matching procedure is determined by ``matching_setup`` and
|
400
401
|
``matching_score``, which are unique to each score. In the following,
|
401
402
|
we will be using the `FLCSphericalMask` score, which is composed of
|
402
|
-
:py:meth:`flcSphericalMask_setup` and
|
403
|
+
:py:meth:`tme.matching_scores.flcSphericalMask_setup` and
|
404
|
+
:py:meth:`tme.matching_scores.corr_scoring`
|
403
405
|
|
404
406
|
>>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
|
405
407
|
>>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
|