pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/METADATA +1 -1
  5. pytme-0.2.4.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +276 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +8 -7
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +11 -4
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +11 -9
  61. tme/matching_exhaustive.py +10 -8
  62. tme/matching_utils.py +1 -0
  63. tme/preprocessing/_utils.py +14 -14
  64. tme/preprocessing/composable_filter.py +0 -2
  65. tme/preprocessing/compose.py +4 -4
  66. tme/preprocessing/frequency_filters.py +32 -35
  67. tme/preprocessing/tilt_series.py +202 -118
  68. tme/preprocessor.py +24 -246
  69. tme/structure.py +14 -14
  70. pytme-0.2.3.dist-info/RECORD +0 -75
  71. tme/matching_memory.py +0 -383
  72. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  73. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/postprocess.py +0 -0
  74. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,332 @@
1
+ /* Pybind extensions for template matching score space analyzers.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ */
7
+
8
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
15
+
16
+ namespace py = pybind11;
17
+
18
+ template <typename T>
19
+ void absolute_minimum_deviation(
20
+ py::array_t<T, py::array::c_style> coordinates,
21
+ py::array_t<T, py::array::c_style> output) {
22
+ auto coordinates_data = coordinates.data();
23
+ auto output_data = output.mutable_data();
24
+ int n = coordinates.shape(0);
25
+ int k = coordinates.shape(1);
26
+ int ik, jk, in, jn;
27
+
28
+ for (int i = 0; i < n; ++i) {
29
+ ik = i * k;
30
+ in = i * n;
31
+ for (int j = i + 1; j < n; ++j) {
32
+ jk = j * k;
33
+ jn = j * n;
34
+ T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
35
+ for (int p = 1; p < k; ++p) {
36
+ min_distance = std::min(min_distance,
37
+ std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
38
+ }
39
+ output_data[in + j] = min_distance;
40
+ output_data[jn + i] = min_distance;
41
+ }
42
+ output_data[in + i] = 0;
43
+ }
44
+ }
45
+
46
+ template <typename T>
47
+ std::pair<double, std::pair<int, int>> max_euclidean_distance(
48
+ py::array_t<T, py::array::c_style> coordinates) {
49
+ auto coordinates_data = coordinates.data();
50
+ int n = coordinates.shape(0);
51
+ int k = coordinates.shape(1);
52
+
53
+ double distance = 0.0;
54
+ double difference = 0.0;
55
+ double max_distance = -1;
56
+ double squared_distances = 0.0;
57
+
58
+ int ik, jk;
59
+ int max_i = -1, max_j = -1;
60
+
61
+ for (int i = 0; i < n; ++i) {
62
+ ik = i * k;
63
+ for (int j = i + 1; j < n; ++j) {
64
+ jk = j * k;
65
+ squared_distances = 0.0;
66
+ for (int p = 0; p < k; ++p) {
67
+ difference = static_cast<double>(
68
+ coordinates_data[ik + p] - coordinates_data[jk + p]
69
+ );
70
+ squared_distances += (difference * difference);
71
+ }
72
+ distance = std::sqrt(squared_distances);
73
+ if (distance > max_distance) {
74
+ max_distance = distance;
75
+ max_i = i;
76
+ max_j = j;
77
+ }
78
+ }
79
+ }
80
+
81
+ return std::make_pair(max_distance, std::make_pair(max_i, max_j));
82
+ }
83
+
84
+
85
+ template <typename T>
86
+ inline py::array_t<int, py::array::c_style> find_candidate_indices(
87
+ py::array_t<T, py::array::c_style> coordinates,
88
+ T min_distance) {
89
+ auto coordinates_data = coordinates.data();
90
+ int n = coordinates.shape(0);
91
+ int k = coordinates.shape(1);
92
+ int ik, jk;
93
+
94
+ std::vector<int> candidate_indices;
95
+ candidate_indices.reserve(n / 2);
96
+ candidate_indices.push_back(0);
97
+
98
+ for (int i = 1; i < n; ++i) {
99
+ bool is_candidate = true;
100
+ ik = i * k;
101
+ for (int candidate_index : candidate_indices) {
102
+ jk = candidate_index * k;
103
+ T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
104
+ for (int p = 1; p < k; ++p) {
105
+ distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
106
+ }
107
+ distance = std::sqrt(distance);
108
+ if (distance <= min_distance) {
109
+ is_candidate = false;
110
+ break;
111
+ }
112
+ }
113
+ if (is_candidate) {
114
+ candidate_indices.push_back(i);
115
+ }
116
+ }
117
+
118
+ py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
119
+ auto output_data = output.mutable_data();
120
+
121
+ for (size_t i = 0; i < candidate_indices.size(); ++i) {
122
+ output_data[i] = candidate_indices[i];
123
+ }
124
+
125
+ return output;
126
+ }
127
+
128
+ template <typename T>
129
+ py::array_t<T, py::array::c_style> find_candidate_coordinates(
130
+ py::array_t<T, py::array::c_style> coordinates,
131
+ T min_distance) {
132
+
133
+ py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
134
+ coordinates, min_distance);
135
+ auto candidate_indices_data = candidate_indices_array.data();
136
+ int num_candidates = candidate_indices_array.shape(0);
137
+ int k = coordinates.shape(1);
138
+ auto coordinates_data = coordinates.data();
139
+
140
+ py::array_t<T, py::array::c_style> output({num_candidates, k});
141
+ auto output_data = output.mutable_data();
142
+
143
+ for (int i = 0; i < num_candidates; ++i) {
144
+ int candidate_index = candidate_indices_data[i] * k;
145
+ std::copy(
146
+ coordinates_data + candidate_index,
147
+ coordinates_data + candidate_index + k,
148
+ output_data + i * k
149
+ );
150
+ }
151
+
152
+ return output;
153
+ }
154
+
155
+ template <typename U, typename T>
156
+ py::dict max_index_by_label(
157
+ py::array_t<U, py::array::c_style> labels,
158
+ py::array_t<T, py::array::c_style> scores
159
+ ) {
160
+
161
+ const U* labels_ptr = labels.data();
162
+ const T* scores_ptr = scores.data();
163
+
164
+ std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
165
+
166
+ U label;
167
+ T score;
168
+ for (ssize_t i = 0; i < labels.size(); ++i) {
169
+ label = labels_ptr[i];
170
+ score = scores_ptr[i];
171
+
172
+ auto it = max_scores.insert({label, {score, i}});
173
+
174
+ if (score > it.first->second.first) {
175
+ it.first->second = {score, i};
176
+ }
177
+ }
178
+
179
+ py::dict ret;
180
+ for (auto& item: max_scores) {
181
+ ret[py::cast(item.first)] = py::cast(item.second.second);
182
+ }
183
+
184
+ return ret;
185
+ }
186
+
187
+
188
+ template <typename T>
189
+ py::tuple online_statistics(
190
+ py::array_t<T, py::array::c_style> arr,
191
+ unsigned long long int n = 0,
192
+ double rmean = 0,
193
+ double ssqd = 0,
194
+ T reference = 0) {
195
+
196
+ auto in = arr.data();
197
+ int size = arr.size();
198
+
199
+ T max_value = std::numeric_limits<T>::lowest();
200
+ T min_value = std::numeric_limits<T>::max();
201
+ double delta, delta_prime;
202
+
203
+ unsigned long long int nbetter_or_equal = 0;
204
+
205
+ for(int i = 0; i < size; i++){
206
+ n++;
207
+ delta = in[i] - rmean;
208
+ rmean += delta / n;
209
+ delta_prime = in[i] - rmean;
210
+ ssqd += delta * delta_prime;
211
+
212
+ max_value = std::max(in[i], max_value);
213
+ min_value = std::min(in[i], min_value);
214
+ if (in[i] >= reference)
215
+ nbetter_or_equal++;
216
+ }
217
+
218
+ return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
219
+ }
220
+
221
+ PYBIND11_MODULE(extensions, m) {
222
+
223
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
224
+ "Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
225
+ py::arg("coordinates"), py::arg("output"));
226
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
227
+ "Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
228
+ py::arg("coordinates"), py::arg("output"));
229
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
230
+ "Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
231
+ py::arg("coordinates"), py::arg("output"));
232
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
233
+ "Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
234
+ py::arg("coordinates"), py::arg("output"));
235
+
236
+
237
+ m.def("max_euclidean_distance", max_euclidean_distance<double>,
238
+ "Identify pair of points with maximal euclidean distance (float64).",
239
+ py::arg("coordinates"));
240
+ m.def("max_euclidean_distance", max_euclidean_distance<float>,
241
+ "Identify pair of points with maximal euclidean distance (float32).",
242
+ py::arg("coordinates"));
243
+ m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
244
+ "Identify pair of points with maximal euclidean distance (int64).",
245
+ py::arg("coordinates"));
246
+ m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
247
+ "Identify pair of points with maximal euclidean distance (int32).",
248
+ py::arg("coordinates"));
249
+
250
+
251
+ m.def("find_candidate_indices", &find_candidate_indices<double>,
252
+ "Finds candidate indices with minimum distance (float64).",
253
+ py::arg("coordinates"), py::arg("min_distance"));
254
+ m.def("find_candidate_indices", &find_candidate_indices<float>,
255
+ "Finds candidate indices with minimum distance (float32).",
256
+ py::arg("coordinates"), py::arg("min_distance"));
257
+ m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
258
+ "Finds candidate indices with minimum distance (int64).",
259
+ py::arg("coordinates"), py::arg("min_distance"));
260
+ m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
261
+ "Finds candidate indices with minimum distance (int32).",
262
+ py::arg("coordinates"), py::arg("min_distance"));
263
+
264
+
265
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
266
+ "Finds candidate coordinates with minimum distance (float64).",
267
+ py::arg("coordinates"), py::arg("min_distance"));
268
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
269
+ "Finds candidate coordinates with minimum distance (float32).",
270
+ py::arg("coordinates"), py::arg("min_distance"));
271
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
272
+ "Finds candidate coordinates with minimum distance (int64).",
273
+ py::arg("coordinates"), py::arg("min_distance"));
274
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
275
+ "Finds candidate coordinates with minimum distance (int32).",
276
+ py::arg("coordinates"), py::arg("min_distance"));
277
+
278
+
279
+ m.def("max_index_by_label", &max_index_by_label<double, double>,
280
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
281
+ m.def("max_index_by_label", &max_index_by_label<double, float>,
282
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
283
+ m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
284
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
285
+ m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
286
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
287
+
288
+ m.def("max_index_by_label", &max_index_by_label<float, double>,
289
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
290
+ m.def("max_index_by_label", &max_index_by_label<float, float>,
291
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
292
+ m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
293
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
294
+ m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
295
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
296
+
297
+ m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
298
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
299
+ m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
300
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
301
+ m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
302
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
303
+ m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
304
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
305
+
306
+ m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
307
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
308
+ m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
309
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
310
+ m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
311
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
312
+ m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
313
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
314
+
315
+
316
+ m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
317
+ py::arg("n") = 0, py::arg("rmean") = 0,
318
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
319
+ "Compute running online statistics on a numpy array.");
320
+ m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
321
+ py::arg("n") = 0, py::arg("rmean") = 0,
322
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
323
+ "Compute running online statistics on a numpy array.");
324
+ m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
325
+ py::arg("n") = 0, py::arg("rmean") = 0,
326
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
327
+ "Compute running online statistics on a numpy array.");
328
+ m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
329
+ py::arg("n") = 0, py::arg("rmean") = 0,
330
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
331
+ "Compute running online statistics on a numpy array.");
332
+ }
tme/matching_data.py CHANGED
@@ -450,7 +450,7 @@ class MatchingData:
450
450
  template_shape: NDArray,
451
451
  batch_mask: NDArray = None,
452
452
  pad_fourier: bool = False,
453
- ) -> Tuple[Tuple, Tuple, Tuple]:
453
+ ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
454
454
  """
455
455
  Determines an efficient shape for Fourier transforms considering zero-padding.
456
456
  """
@@ -479,11 +479,6 @@ class MatchingData:
479
479
  np.subtract(target_shape, template_shape), 1 - batch_mask
480
480
  )
481
481
  if np.sum(shape_diff < 0):
482
- warnings.warn(
483
- "Template is larger than target and padding is turned off. Consider "
484
- "swapping them or activate padding. Correcting the shift for now."
485
- )
486
-
487
482
  shape_shift = np.divide(shape_diff, 2)
488
483
  offset = np.mod(shape_diff, 2)
489
484
  if pad_fourier:
@@ -491,6 +486,11 @@ class MatchingData:
491
486
  offset,
492
487
  np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
493
488
  )
489
+ else:
490
+ warnings.warn(
491
+ "Template is larger than target and padding is turned off. Consider "
492
+ "swapping them or activate padding. Correcting the shift for now."
493
+ )
494
494
 
495
495
  shape_shift = np.add(shape_shift, offset)
496
496
  fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
@@ -498,7 +498,9 @@ class MatchingData:
498
498
  fourier_shift = tuple(fourier_shift.astype(int))
499
499
  return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
500
500
 
501
- def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
501
+ def fourier_padding(
502
+ self, pad_fourier: bool = False
503
+ ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
502
504
  """
503
505
  Computes efficient shape for Fourier transforms and potential associated shifts.
504
506
 
@@ -510,8 +512,8 @@ class MatchingData:
510
512
 
511
513
  Returns
512
514
  -------
513
- Tuple[tuple of int, tuple of int, tuple of int]
514
- Tuple with real and complex Fourier transform shape, and corresponding shift.
515
+ Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
516
+ Tuple with convolution, forward FT, inverse FT shape and corresponding shift.
515
517
  """
516
518
  return self._fourier_padding(
517
519
  target_shape=be.to_numpy_array(self._output_target_shape),
@@ -82,13 +82,14 @@ def _setup_template_filter_apply_target_filter(
82
82
  fastt_shape, fastt_ft_shape = fast_shape, filter_shape
83
83
  if filter_template and not pad_template_filter:
84
84
  # FFT shape acrobatics for faster filter application
85
- _, fastt_shape, _, _ = matching_data._fourier_padding(
86
- target_shape=be.to_numpy_array(matching_data._template.shape),
87
- template_shape=be.to_numpy_array(
88
- [1 for _ in matching_data._template.shape]
89
- ),
90
- pad_fourier=False,
91
- )
85
+ # _, fastt_shape, _, _ = matching_data._fourier_padding(
86
+ # target_shape=be.to_numpy_array(matching_data._template.shape),
87
+ # template_shape=be.to_numpy_array(
88
+ # [1 for _ in matching_data._template.shape]
89
+ # ),
90
+ # pad_fourier=False,
91
+ # )
92
+ fastt_shape = matching_data._template.shape
92
93
  matching_data.template = be.reverse(
93
94
  be.topleft_pad(matching_data.template, fastt_shape)
94
95
  )
@@ -399,7 +400,8 @@ def scan_subsets(
399
400
  The template matching procedure is determined by ``matching_setup`` and
400
401
  ``matching_score``, which are unique to each score. In the following,
401
402
  we will be using the `FLCSphericalMask` score, which is composed of
402
- :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
403
+ :py:meth:`tme.matching_scores.flcSphericalMask_setup` and
404
+ :py:meth:`tme.matching_scores.corr_scoring`
403
405
 
404
406
  >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
405
407
  >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
tme/matching_utils.py CHANGED
@@ -645,6 +645,7 @@ def get_rotation_matrices(
645
645
  dets = np.linalg.det(ret)
646
646
  neg_dets = dets < 0
647
647
  ret[neg_dets, :, -1] *= -1
648
+ ret[0] = np.eye(dim, dtype = ret.dtype)
648
649
  return ret
649
650
 
650
651
 
@@ -19,8 +19,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
19
19
  """
20
20
  Given an opening_axis, computes the shape of the remaining dimensions.
21
21
 
22
- Parameters:
23
- -----------
22
+ Parameters
23
+ ----------
24
24
  shape : Tuple[int]
25
25
  The shape of the input array.
26
26
  opening_axis : int
@@ -28,8 +28,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
28
28
  reduce_dim : bool, optional (default=False)
29
29
  Whether to reduce the dimensionality after tilting.
30
30
 
31
- Returns:
32
- --------
31
+ Returns
32
+ -------
33
33
  Tuple[int]
34
34
  The shape of the array after tilting.
35
35
  """
@@ -44,13 +44,13 @@ def centered_grid(shape: Tuple[int]) -> NDArray:
44
44
  """
45
45
  Generate an integer valued grid centered around size // 2
46
46
 
47
- Parameters:
48
- -----------
47
+ Parameters
48
+ ----------
49
49
  shape : Tuple[int]
50
50
  The shape of the grid.
51
51
 
52
- Returns:
53
- --------
52
+ Returns
53
+ -------
54
54
  NDArray
55
55
  The centered grid.
56
56
  """
@@ -70,8 +70,8 @@ def frequency_grid_at_angle(
70
70
  """
71
71
  Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
72
72
 
73
- Parameters:
74
- -----------
73
+ Parameters
74
+ ----------
75
75
  shape : Tuple[int]
76
76
  The shape of the grid.
77
77
  angle : float
@@ -128,8 +128,8 @@ def fftfreqn(
128
128
  """
129
129
  Generate the n-dimensional discrete Fourier transform sample frequencies.
130
130
 
131
- Parameters:
132
- -----------
131
+ Parameters
132
+ ----------
133
133
  shape : Tuple[int]
134
134
  The shape of the data.
135
135
  sampling_rate : float or Tuple[float]
@@ -180,8 +180,8 @@ def crop_real_fourier(data: BackendArray) -> BackendArray:
180
180
  """
181
181
  Crop the real part of a Fourier transform.
182
182
 
183
- Parameters:
184
- -----------
183
+ Parameters
184
+ ----------
185
185
  data : BackendArray
186
186
  The Fourier transformed data.
187
187
 
@@ -20,7 +20,6 @@ class ComposableFilter(ABC):
20
20
 
21
21
  Parameters
22
22
  ----------
23
-
24
23
  *args : tuple
25
24
  Variable length argument list.
26
25
  **kwargs : dict
@@ -28,7 +27,6 @@ class ComposableFilter(ABC):
28
27
 
29
28
  Returns
30
29
  -------
31
-
32
30
  Dict
33
31
  A dictionary representing the result of the filtering operation.
34
32
  """
@@ -17,13 +17,13 @@ class Compose:
17
17
  This class allows composing multiple transformations together. Each transformation
18
18
  is expected to be a callable that accepts keyword arguments and returns metadata.
19
19
 
20
- Parameters:
21
- -----------
20
+ Parameters
21
+ ----------
22
22
  transforms : Tuple[object]
23
23
  A tuple containing transformation objects.
24
24
 
25
- Returns:
26
- --------
25
+ Returns
26
+ -------
27
27
  Dict
28
28
  Metadata resulting from the composed transformations.
29
29
 
@@ -18,12 +18,10 @@ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_
18
18
 
19
19
  class BandPassFilter:
20
20
  """
21
- This class provides methods to generate bandpass filters in Fourier space,
22
- either by directly specifying the frequency cutoffs (discrete_bandpass) or
23
- by using Gaussian functions (gaussian_bandpass).
21
+ Generate bandpass filters in Fourier space.
24
22
 
25
- Parameters:
26
- -----------
23
+ Parameters
24
+ ----------
27
25
  lowpass : float, optional
28
26
  The lowpass cutoff, defaults to None.
29
27
  highpass : float, optional
@@ -67,8 +65,8 @@ class BandPassFilter:
67
65
  """
68
66
  Generate a bandpass filter using discrete frequency cutoffs.
69
67
 
70
- Parameters:
71
- -----------
68
+ Parameters
69
+ ----------
72
70
  shape : tuple of int
73
71
  The shape of the bandpass filter.
74
72
  lowpass : float
@@ -84,8 +82,8 @@ class BandPassFilter:
84
82
  **kwargs : dict
85
83
  Additional keyword arguments.
86
84
 
87
- Returns:
88
- --------
85
+ Returns
86
+ -------
89
87
  BackendArray
90
88
  The bandpass filter in Fourier space.
91
89
  """
@@ -98,17 +96,18 @@ class BandPassFilter:
98
96
  shape_is_real_fourier=shape_is_real_fourier,
99
97
  compute_euclidean_norm=True,
100
98
  )
101
-
102
- lowpass = 0 if lowpass is None else lowpass
103
- highpass = 1e10 if highpass is None else highpass
99
+ grid = be.to_backend_array(grid)
100
+ sampling_rate = be.to_backend_array(sampling_rate)
104
101
 
105
102
  highcut = grid.max()
106
- if lowpass > 0:
107
- highcut = np.max(2 * sampling_rate / lowpass)
108
- lowcut = np.max(2 * sampling_rate / highpass)
103
+ if lowpass is not None:
104
+ highcut = be.max(2 * sampling_rate / lowpass)
109
105
 
110
- bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
106
+ lowcut = 0
107
+ if highpass is not None:
108
+ lowcut = be.max(2 * sampling_rate / highpass)
111
109
 
110
+ bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
112
111
  bandpass_filter = shift_fourier(
113
112
  data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
114
113
  )
@@ -129,10 +128,10 @@ class BandPassFilter:
129
128
  **kwargs,
130
129
  ) -> BackendArray:
131
130
  """
132
- Generate a bandpass filter using Gaussian functions.
131
+ Generate a bandpass filter using Gaussians.
133
132
 
134
- Parameters:
135
- -----------
133
+ Parameters
134
+ ----------
136
135
  shape : tuple of int
137
136
  The shape of the bandpass filter.
138
137
  lowpass : float
@@ -148,8 +147,8 @@ class BandPassFilter:
148
147
  **kwargs : dict
149
148
  Additional keyword arguments.
150
149
 
151
- Returns:
152
- --------
150
+ Returns
151
+ -------
153
152
  BackendArray
154
153
  The bandpass filter in Fourier space.
155
154
  """
@@ -216,15 +215,13 @@ class BandPassFilter:
216
215
 
217
216
  class LinearWhiteningFilter:
218
217
  """
219
- This class provides methods to compute the spectrum of the input data and
220
- apply linear whitening to the Fourier coefficients.
218
+ Compute Fourier power spectrums and perform whitening.
221
219
 
222
- Parameters:
223
- -----------
220
+ Parameters
221
+ ----------
224
222
  **kwargs : Dict, optional
225
223
  Additional keyword arguments.
226
224
 
227
-
228
225
  References
229
226
  ----------
230
227
  .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
@@ -243,10 +240,10 @@ class LinearWhiteningFilter:
243
240
  data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
244
241
  ) -> Tuple[BackendArray, BackendArray]:
245
242
  """
246
- Compute the spectrum of the input data.
243
+ Compute the power spectrum of the input data.
247
244
 
248
- Parameters:
249
- -----------
245
+ Parameters
246
+ ----------
250
247
  data_rfft : BackendArray
251
248
  The Fourier transform of the input data.
252
249
  n_bins : int, optional
@@ -254,8 +251,8 @@ class LinearWhiteningFilter:
254
251
  batch_dimension : int, optional
255
252
  Batch dimension to average over.
256
253
 
257
- Returns:
258
- --------
254
+ Returns
255
+ -------
259
256
  bins : BackendArray
260
257
  Array containing the bin indices for the spectrum.
261
258
  radial_averages : BackendArray
@@ -330,8 +327,8 @@ class LinearWhiteningFilter:
330
327
  """
331
328
  Apply linear whitening to the data and return the result.
332
329
 
333
- Parameters:
334
- -----------
330
+ Parameters
331
+ ----------
335
332
  data : BackendArray, optional
336
333
  The input data, defaults to None.
337
334
  data_rfft : BackendArray, optional
@@ -345,8 +342,8 @@ class LinearWhiteningFilter:
345
342
  **kwargs : Dict
346
343
  Additional keyword arguments.
347
344
 
348
- Returns:
349
- --------
345
+ Returns
346
+ -------
350
347
  Dict
351
348
  Filter data and associated parameters.
352
349
  """