pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tme/analyzer/peaks.py ADDED
@@ -0,0 +1,953 @@
1
+ """ Implements classes to analyze outputs from exhaustive template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from functools import wraps
9
+ from abc import ABC, abstractmethod
10
+ from typing import Tuple, List, Dict, Generator
11
+
12
+ import numpy as np
13
+ from skimage.feature import peak_local_max
14
+ from skimage.registration._phase_cross_correlation import _upsampled_dft
15
+
16
+ from ._utils import score_to_cart
17
+ from ..backends import backend as be
18
+ from ..matching_utils import split_shape
19
+ from ..types import BackendArray, NDArray
20
+ from ..rotations import euler_to_rotationmatrix
21
+
22
+ __all__ = [
23
+ "PeakCaller",
24
+ "PeakCallerSort",
25
+ "PeakCallerMaximumFilter",
26
+ "PeakCallerFast",
27
+ "PeakCallerRecursiveMasking",
28
+ "PeakCallerScipy",
29
+ "PeakClustering",
30
+ "filter_points",
31
+ "filter_points_indices",
32
+ ]
33
+
34
+ PeakType = Tuple[BackendArray, BackendArray]
35
+
36
+
37
def _filter_bucket(
    coordinates: BackendArray, min_distance: Tuple[float], scores: BackendArray = None
) -> BackendArray:
    """
    Reduce points to one representative per grid bucket.

    Coordinates are shifted to start at zero, divided by ``min_distance`` and
    truncated to integers. Points that end up with identical integer bucket
    coordinates are considered duplicates.

    Parameters
    ----------
    coordinates : BackendArray
        Point coordinates with shape (n, d).
    min_distance : tuple of float
        Bucket edge length the coordinates are divided by.
    scores : BackendArray, optional
        Score of each point (n,). If provided, the highest scoring point of
        each bucket is retained.

    Returns
    -------
    BackendArray
        Indices into ``coordinates`` of the retained points.
    """
    coordinates = be.subtract(coordinates, be.min(coordinates, axis=0))
    bucket_indices = be.astype(be.divide(coordinates, min_distance), int)

    # Flatten the d-dimensional bucket coordinates with mixed-radix strides,
    # i.e. the stride of axis k is the product of the extents of axes 0..k-1.
    # Raising each extent to the power of its axis index does not yield unique
    # identifiers, e.g. extents (10, 3) would map (3, 0) and (0, 1) to 3.
    extent = np.asarray(be.to_numpy_array(be.max(bucket_indices, axis=0))) + 1
    strides = np.ones(extent.shape[0], dtype=np.int64)
    strides[1:] = np.cumprod(extent[:-1])
    multiplier = be.astype(be.to_backend_array(strides), bucket_indices.dtype)

    bucket_indices = be.multiply(bucket_indices, multiplier, out=bucket_indices)
    flattened_indices = be.sum(bucket_indices, axis=1)

    if scores is not None:
        _, inverse_indices = be.unique(flattened_indices, return_inverse=True)

        # Scores are normalized to [0, 1) so that subtracting them from the
        # integer bucket identifier can not move a point into another bucket
        scores = be.subtract(scores, be.min(scores))
        scores = be.divide(scores, be.max(scores) + 0.1, out=scores)
        scores = be.subtract(inverse_indices, scores)

        # Ascending sort places the best point of each bucket first in its run
        indices = be.argsort(scores)
        sorted_buckets = inverse_indices[indices]
        mask = sorted_buckets[1:] != sorted_buckets[:-1]
        mask = be.concatenate((be.full((1,), fill_value=1, dtype=mask.dtype), mask))
        return indices[mask]

    # NOTE(review): assumes be.unique(return_index=True) mirrors numpy.unique
    _, unique_indices = be.unique(flattened_indices, return_index=True)
    return unique_indices[be.argsort(unique_indices)]
64
+
65
+
66
def filter_points_indices(
    coordinates: BackendArray,
    min_distance: float,
    bucket_cutoff: int = 1e5,
    batch_dims: Tuple[int] = None,
    scores: BackendArray = None,
) -> BackendArray:
    """
    Determine indices of points that respect a minimum distance constraint.

    Parameters
    ----------
    coordinates : BackendArray
        Point coordinates with shape (n, d).
    min_distance : float
        Minimum distance between retained points. Values smaller or equal
        to zero disable filtering.
    bucket_cutoff : int, optional
        NumPy inputs with fewer points than this are handled by the compiled
        ``find_candidate_indices`` routine, everything else by bucketing.
    batch_dims : tuple of int, optional
        Axes of ``coordinates`` that index batches. Their values are scaled
        by ``2 * min_distance`` to separate points of different batches.
    scores : BackendArray, optional
        Score of each point (n,). If provided, points are processed in
        order of decreasing score.

    Returns
    -------
    BackendArray
        Indices into ``coordinates`` of the retained points. An empty tuple
        is returned if no coordinates were provided.
    """
    from ..extensions import find_candidate_indices

    if min_distance <= 0:
        return be.arange(coordinates.shape[0])

    n_coords = coordinates.shape[0]
    if n_coords == 0:
        return ()

    if batch_dims is not None:
        # Work on a copy, the caller's coordinates are left untouched
        coordinates_new = be.zeros(coordinates.shape, coordinates.dtype)
        coordinates_new[:] = coordinates
        coordinates_new[..., batch_dims] = be.astype(
            coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
        )
        coordinates = coordinates_new

    if isinstance(coordinates, np.ndarray) and n_coords < bucket_cutoff:
        if scores is None:
            # Previously this result was discarded and an O(n^2) distance
            # matrix was computed instead.
            # NOTE(review): assumes find_candidate_indices returns indices
            # of retained points, as the scored path below already relies on.
            return find_candidate_indices(coordinates, min_distance)

        sorted_indices = np.argsort(-scores)
        indices = find_candidate_indices(coordinates[sorted_indices], min_distance)
        return sorted_indices[indices]

    # Covers non-NumPy arrays and n_coords >= bucket_cutoff. The boundary
    # n_coords == bucket_cutoff used to fall through to the quadratic path.
    return _filter_bucket(coordinates, min_distance, scores)
105
+
106
+
107
def filter_points(
    coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
) -> BackendArray:
    """
    Remove points that violate a minimum distance constraint.

    Parameters
    ----------
    coordinates : NDArray
        Point coordinates with shape (n, d).
    min_distance : tuple of int
        Minimum distance between retained points, forwarded to
        :py:meth:`filter_points_indices`.
    batch_dims : tuple of int, optional
        Axes of ``coordinates`` that index batches.

    Returns
    -------
    BackendArray
        Subset of ``coordinates`` that passed the filter.
    """
    # batch_dims has to be passed by keyword. The third positional parameter
    # of filter_points_indices is bucket_cutoff, not batch_dims.
    unique_indices = filter_points_indices(
        coordinates, min_distance, batch_dims=batch_dims
    )
    return coordinates[unique_indices]
113
+
114
+
115
def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> Generator:
    """
    Generate slices that iterate over the batch dimensions of an array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array to iterate over.
    batch_dims : tuple of int, optional
        Axes to iterate over element-wise. Negative axes are counted from
        the end. If None, a single slice covering the whole array is yielded.

    Yields
    ------
    tuple
        Tuple of slices selecting one batch element and tuple of integers
        with the start offset of that selection along each axis. Non-batch
        axes use ``slice(None)`` and an offset of zero.
    """
    if batch_dims is None:
        yield (tuple(slice(None) for _ in shape), tuple(0 for _ in shape))
        return None

    from itertools import product

    ndim = len(shape)
    dims = tuple(int(dim) % ndim for dim in batch_dims)

    # The first batch dimension varies slowest, matching nested for-loops
    for combination in product(*(range(shape[dim]) for dim in dims)):
        # Associate each index with its axis explicitly, so the order in
        # which batch_dims was specified can not mix up axes
        index_of = dict(zip(dims, combination))

        slice_list, offset_list = [], []
        for axis in range(ndim):
            if axis in index_of:
                index = index_of[axis]
                slice_list.append(slice(index, index + 1))
                offset_list.append(index)
            else:
                slice_list.append(slice(None))
                offset_list.append(0)
        yield (tuple(slice_list), tuple(offset_list))
142
+
143
+
144
+ class PeakCaller(ABC):
145
+ """
146
+ Base class for peak calling algorithms.
147
+
148
+ Parameters
149
+ ----------
150
+ shape : tuple of int
151
+ Score space shape. Used to determine dimension of peak calling problem.
152
+ num_peaks : int, optional
153
+ Number of candidate peaks to consider.
154
+ min_distance : int, optional
155
+ Minimum distance between peaks, 1 by default
156
+ min_boundary_distance : int, optional
157
+ Minimum distance to array boundaries, 0 by default.
158
+ min_score : float, optional
159
+ Minimum score from which to consider peaks.
160
+ max_score : float, optional
161
+ Maximum score upon which to consider peaks.
162
+ batch_dims : int, optional
163
+ Peak calling batch dimensions.
164
+ **kwargs
165
+ Optional keyword arguments.
166
+
167
+ Raises
168
+ ------
169
+ ValueError
170
+ If num_peaks is less than or equal to zero.
171
+ If min_distances is less than zero.
172
+ """
173
+
174
def __init__(
    self,
    shape: int,
    num_peaks: int = 1000,
    min_distance: int = 1,
    min_boundary_distance: int = 0,
    min_score: float = None,
    max_score: float = None,
    batch_dims: Tuple[int] = None,
    shm_handler: object = None,
    **kwargs,
):
    """
    Validate the arguments and allocate the internal peak store.

    The store holds ``num_peaks`` rows, where ``len(shape)`` determines the
    number of coordinates per translation and the size of each rotation
    matrix. ``shm_handler`` is accepted but not used here.
    """
    if num_peaks <= 0:
        raise ValueError("num_peaks has to be larger than 0.")
    if min_distance < 0:
        raise ValueError("min_distance has to be non-negative.")
    if min_boundary_distance < 0:
        raise ValueError("min_boundary_distance has to be non-negative.")

    n_dims = len(shape)
    float_dtype = be._float_dtype

    # Placeholder rows: translation -1, zero rotation, zero score and details
    self.translations = be.full(
        (num_peaks, n_dims), fill_value=-1, dtype=be._int_dtype
    )
    self.rotations = be.full(
        (num_peaks, n_dims, n_dims), fill_value=0, dtype=float_dtype
    )
    self.scores = be.full((num_peaks,), fill_value=0, dtype=float_dtype)
    self.details = be.full((num_peaks,), fill_value=0, dtype=float_dtype)

    self.num_peaks = int(num_peaks)
    self.min_distance = int(min_distance)
    self.min_boundary_distance = int(min_boundary_distance)

    if batch_dims is None:
        self.batch_dims = None
    else:
        self.batch_dims = tuple(int(x) for x in batch_dims)

    self.min_score = min_score
    self.max_score = max_score

    # Postprocessing arguments, None if not provided
    for name in ("fourier_shift", "convolution_mode", "targetshape", "templateshape"):
        setattr(self, name, kwargs.get(name, None))
218
+
219
def __iter__(self) -> Generator:
    """
    Yield translations, rotations, scores and details of the stored
    candidates, in that order, after conversion via ``be.to_cpu_array``.

    The converted arrays are also kept in ``self.peak_list``.
    """
    stored = (self.translations, self.rotations, self.scores, self.details)
    self.peak_list = [be.to_cpu_array(entry) for entry in stored]
    for entry in self.peak_list:
        yield entry
231
+
232
def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
    """
    Compute a boolean mask of peaks that satisfy all filter criteria.

    Parameters
    ----------
    peaks : BackendArray
        Peak coordinates with shape (n, d).
    scores : BackendArray
        Score space the peaks index into.

    Returns
    -------
    BackendArray or None
        Boolean mask with shape (n,). None if there are no peaks or
        none of them passed the filters.
    """
    if not len(peaks):
        return None

    valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
    if self.min_boundary_distance > 0:
        upper_limit = be.subtract(
            be.to_backend_array(scores.shape), self.min_boundary_distance
        )
        valid_peaks = be.multiply(
            peaks < upper_limit,
            peaks >= self.min_boundary_distance,
        )
        # Batch axes are exempt from the boundary constraint
        if self.batch_dims is not None:
            valid_peaks[..., self.batch_dims] = True

        valid_peaks = be.sum(valid_peaks, axis=1) == peaks.shape[1]

    # Score thresholds and nan removal (nan != nan)
    peak_scores = scores[tuple(peaks.T)]
    valid_peaks = be.multiply(peak_scores == peak_scores, valid_peaks)
    if self.min_score is not None:
        valid_peaks = be.multiply(peak_scores >= self.min_score, valid_peaks)

    if self.max_score is not None:
        valid_peaks = be.multiply(peak_scores <= self.max_score, valid_peaks)

    if be.sum(valid_peaks) == 0:
        return None

    # Ensure consistent upper limit of input peaks for _update step
    if (
        be.sum(valid_peaks) > self.num_peaks
        or peak_scores.shape[0] > self.num_peaks
    ):
        # Rank only peaks that passed the filters above. Ranking all peaks
        # and rebuilding the mask from scratch would re-admit peaks rejected
        # by the boundary or nan filter.
        candidate_indices = be.arange(peak_scores.shape[0])[valid_peaks]
        top_candidates = self._top_peaks(
            peaks[valid_peaks],
            scores=peak_scores[valid_peaks],
            num_peaks=2 * self.num_peaks,
        )
        valid_peaks = be.full(peak_scores.shape, 0, bool)
        valid_peaks[candidate_indices[top_candidates]] = True

    if valid_peaks.shape[0] != peaks.shape[0]:
        return None
    return valid_peaks
282
+
283
def _apply_over_batch(func):
    """
    Decorator that turns ``func`` into a generator over batches.

    ``func`` is invoked once per entry generated by :py:meth:`batchify` for
    ``scores.shape`` and ``self.batch_dims``. Each call receives the matching
    subset of ``scores`` and its ``batch_offset``.
    """

    @wraps(func)
    def wrapper(self, scores, rotation_matrix, **kwargs):
        batches = batchify(scores.shape, self.batch_dims)
        for subset, batch_offset in batches:
            result = func(
                self,
                scores=scores[subset],
                rotation_matrix=rotation_matrix,
                batch_offset=batch_offset,
                **kwargs,
            )
            yield result

    return wrapper
296
+
297
@_apply_over_batch
def _call_peaks(self, scores, rotation_matrix, batch_offset=None, **kwargs):
    """
    Run :py:meth:`call_peaks` on one batch and shift the resulting peak
    positions by ``batch_offset``.

    Returns a tuple of integer peak positions and peak details, or
    ``(None, None)`` if no peak positions were reported.
    """
    called = self.call_peaks(
        scores=scores,
        rotation_matrix=rotation_matrix,
        min_score=self.min_score,
        max_score=self.max_score,
        batch_offset=batch_offset,
        **kwargs,
    )
    peak_positions, peak_details = called
    if peak_positions is None:
        return None, None

    peak_positions = be.to_backend_array(peak_positions)
    if batch_offset is not None:
        batch_offset = be.to_backend_array(batch_offset)
        peak_positions = be.add(peak_positions, batch_offset, out=peak_positions)

    return be.astype(peak_positions, int), peak_details
317
+
318
def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
    """
    Update the internal parameter store based on input array.

    Parameters
    ----------
    scores : BackendArray
        Score space data.
    rotation_matrix : BackendArray
        Rotation matrix used to obtain the score array.
    **kwargs
        Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
        ``min_score``, ``max_score`` and ``batch_offset`` are set internally
        and ignored if given here.
    """
    # Drop keys that are supplied internally to avoid duplicate keyword
    # arguments, then forward the remainder as documented.
    reserved = ("min_score", "max_score", "batch_offset")
    call_kwargs = {k: v for k, v in kwargs.items() if k not in reserved}

    for ret in self._call_peaks(
        scores=scores, rotation_matrix=rotation_matrix, **call_kwargs
    ):
        peak_positions, peak_details = ret
        if peak_positions is None:
            continue

        valid_peaks = self._get_peak_mask(peaks=peak_positions, scores=scores)
        if valid_peaks is None:
            continue

        peak_positions = peak_positions[valid_peaks]
        peak_scores = scores[tuple(peak_positions.T)]
        if peak_details is not None:
            peak_details = peak_details[valid_peaks]
        else:
            # No details were reported for this batch, use -1 as placeholder
            peak_details = be.full(peak_scores.shape, fill_value=-1)

        # Every peak of this call shares the same rotation matrix
        rotations = be.repeat(
            rotation_matrix.reshape(1, *rotation_matrix.shape),
            peak_positions.shape[0],
            axis=0,
        )

        self._update(
            peak_positions=peak_positions,
            peak_details=peak_details,
            peak_scores=peak_scores,
            rotations=rotations,
        )

    return None
362
+
363
@abstractmethod
def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
    """
    Identify peaks in a score array. Implemented by subclasses.

    Parameters
    ----------
    scores : BackendArray
        Array of scores to search for peaks.
    **kwargs : dict
        Additional keyword arguments for the concrete implementation.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Peak coordinates and the details associated with each peak.
    """
380
+
381
@classmethod
def merge(cls, candidates: List[List] = None, **kwargs) -> Tuple:
    """
    Merge multiple instances of :py:class:`PeakCaller`.

    Parameters
    ----------
    candidates : list of lists
        Obtained by invoking list on the generator returned by __iter__.
        Empty entries are skipped.
    **kwargs
        Optional keyword arguments passed to the class constructor. If
        ``shape`` is missing, its dimension is inferred from the translations
        of the first non-empty candidate. ``offset`` is added to the peak
        positions of every candidate if provided.

    Returns
    -------
    Tuple
        Tuple of translation, rotation, score and details of candidates.

    Raises
    ------
    ValueError
        If ``shape`` is not given and there is no non-empty candidate
        to infer it from.
    """
    candidates = [] if candidates is None else candidates

    if "shape" not in kwargs:
        # Skip empty candidates here as well, consistent with the loop below
        reference = next((x for x in candidates if len(x) != 0), None)
        if reference is None:
            raise ValueError(
                "shape is required if no non-empty candidates are provided."
            )
        kwargs["shape"] = tuple(1 for _ in range(reference[0].shape[1]))

    base = cls(**kwargs)
    offset = kwargs.get("offset", None)
    for candidate in candidates:
        if len(candidate) == 0:
            continue
        peak_positions, rotations, peak_scores, peak_details = candidate
        base._update(
            peak_positions=be.to_backend_array(peak_positions),
            peak_details=be.to_backend_array(peak_details),
            peak_scores=be.to_backend_array(peak_scores),
            rotations=be.to_backend_array(rotations),
            offset=offset,
        )
    return tuple(base)
414
+
415
@staticmethod
def oversample_peaks(
    scores: BackendArray, peak_positions: BackendArray, oversampling_factor: int = 8
):
    """
    Refine peak positions to subpixel precision.

    Parameters
    ----------
    scores : BackendArray
        The d-dimensional score space.
    peak_positions : BackendArray
        Peak coordinates to refine with shape (n, d), where n is the
        number of peaks and d the dimensionality of the score space.
    oversampling_factor : int, optional
        Oversampling factor of the Fourier transform. Defaults to 8.

    Returns
    -------
    BackendArray
        Refined peak coordinates with shape (n, d).

    Notes
    -----
    The scores are upsampled in a small region around each peak and the
    peak is moved to the maximum of that region. The accuracy of the
    refinement scales with 1 / oversampling_factor.

    References
    ----------
    .. [1]  https://scikit-image.org/docs/stable/api/skimage.registration.html
    .. [2]  Manuel Guizar-Sicairos, Samuel T. Thurman, and
            James R. Fienup, “Efficient subpixel image registration
            algorithms,” Optics Letters 33, 156-158 (2008).
            DOI:10.1364/OL.33.000156

    """
    score_array = be.to_numpy_array(scores)
    refined = be.to_numpy_array(peak_positions)

    refined = np.round(
        np.divide(np.multiply(refined, oversampling_factor), oversampling_factor)
    )
    region_size = np.ceil(oversampling_factor * 1.5)
    dftshift = np.round(region_size / 2.0)
    region_offsets = np.subtract(dftshift, np.multiply(refined, oversampling_factor))

    # The transform of the scores is shared by all peaks
    scores_ft = np.fft.fftn(score_array).conj()
    for index, axis_offsets in enumerate(region_offsets):
        upsampled = _upsampled_dft(
            data=scores_ft,
            upsampled_region_size=region_size,
            upsample_factor=oversampling_factor,
            axis_offsets=axis_offsets,
        ).conj()

        flat_maximum = np.argmax(np.abs(upsampled))
        maxima = np.unravel_index(flat_maximum, upsampled.shape)
        shift = np.divide(np.subtract(maxima, dftshift), oversampling_factor)
        refined[index] = np.add(refined[index], shift)

    return be.to_backend_array(refined)
487
+
488
def _top_peaks(self, positions, scores, num_peaks: int = None):
    """
    Determine the indices of the highest scoring peaks.

    Parameters
    ----------
    positions : BackendArray
        Peak coordinates (n, d). Only used to assign peaks to batches.
    scores : BackendArray
        Score of each peak (n,).
    num_peaks : int, optional
        Maximum number of peaks to select, per batch if batch dimensions
        are defined. Defaults to the number of scores.

    Returns
    -------
    BackendArray
        Indices into ``scores`` of the selected peaks.
    """
    num_peaks = be.size(scores) if not num_peaks else num_peaks

    if self.batch_dims is None:
        top_n = min(be.size(scores), num_peaks)
        top_scores, *_ = be.topk_indices(scores, top_n)
        return top_scores

    # Not very performant but fairly robust
    batch_indices = positions[..., self.batch_dims]
    batch_indices = be.subtract(batch_indices, be.min(batch_indices, axis=0))

    # Flatten multi-dimensional batch coordinates with mixed-radix strides,
    # i.e. the stride of axis k is the product of the extents of axes 0..k-1.
    # Raising each extent to the power of its axis index can assign the
    # same identifier to different batches.
    extent = np.asarray(be.to_numpy_array(be.max(batch_indices, axis=0))) + 1
    strides = np.ones(extent.shape[0], dtype=np.int64)
    strides[1:] = np.cumprod(extent[:-1])
    multiplier = be.astype(be.to_backend_array(strides), batch_indices.dtype)

    batch_indices = be.multiply(batch_indices, multiplier, out=batch_indices)
    batch_indices = be.sum(batch_indices, axis=1)
    unique_indices, batch_counts = be.unique(batch_indices, return_counts=True)
    total_indices = be.arange(scores.shape[0])
    batch_indices = [total_indices[batch_indices == x] for x in unique_indices]

    # Select the top scoring peaks of each batch separately
    top_scores = be.concatenate(
        [
            total_indices[indices][
                be.topk_indices(scores[indices], min(y, num_peaks))
            ]
            for indices, y in zip(batch_indices, batch_counts)
        ]
    )
    return top_scores
517
+
518
def _update(
    self,
    peak_positions: BackendArray,
    peak_details: BackendArray,
    peak_scores: BackendArray,
    rotations: BackendArray,
    offset: BackendArray = None,
):
    """
    Merge new peaks into the internal parameter store.

    Parameters
    ----------
    peak_positions : BackendArray
        Peak coordinates (n, d).
    peak_details : BackendArray
        Details associated with each peak (n, ).
    peak_scores: BackendArray
        Score of each peak (n,).
    rotations: BackendArray
        Rotation matrix of each peak (n, d, d).
    offset : BackendArray, optional
        Translation offset added to peak_positions, e.g. from splitting, (d, ).
    """
    if offset is not None:
        offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
        peak_positions = be.add(peak_positions, offset, out=peak_positions)

    # Stored and new candidates compete against each other
    all_positions = be.concatenate((self.translations, peak_positions))
    all_rotations = be.concatenate((self.rotations, rotations))
    all_scores = be.concatenate((self.scores, peak_scores))
    all_details = be.concatenate((self.details, peak_details))

    # topk filtering after distances yields more distributed peak calls
    spaced = filter_points_indices(
        coordinates=all_positions,
        min_distance=self.min_distance,
        batch_dims=self.batch_dims,
        scores=all_scores,
    )
    best = self._top_peaks(
        all_positions[spaced, :], all_scores[spaced], self.num_peaks
    )
    final_order = spaced[best]

    self.translations = all_positions[final_order, :]
    self.rotations = all_rotations[final_order, :]
    self.scores = all_scores[final_order]
    self.details = all_details[final_order]
568
+
569
def _postprocess(self, **kwargs):
    """
    Convert stored translations via ``score_to_cart`` and drop all
    candidates it reports as invalid. Returns the instance itself.
    """
    if len(self.translations):
        positions, valid_peaks = score_to_cart(self.translations, **kwargs)

        self.translations = positions[valid_peaks]
        self.rotations = self.rotations[valid_peaks]
        self.scores = self.scores[valid_peaks]
        self.details = self.details[valid_peaks]

    return self
580
+
581
+
582
class PeakCallerSort(PeakCaller):
    """
    A :py:class:`PeakCaller` subclass that picks the ``num_peaks``
    highest scores as candidate peaks.
    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Return the coordinates of at most ``num_peaks`` top scoring
        elements of ``scores`` and None for the peak details.
        """
        flattened = scores.reshape(-1)
        n_select = min(self.num_peaks, be.size(flattened))

        flat_indices, *_ = be.topk_indices(flattened, n_select)

        # One index array per axis, stacked and transposed to shape (n, d)
        per_axis = be.unravel_index(flat_indices, scores.shape)
        peaks = be.transpose(be.stack(per_axis))
        return peaks, None
598
+
599
+
600
class PeakCallerMaximumFilter(PeakCaller):
    """
    Find local maxima by applying a maximum filter and enforcing a distance
    constraint subsequently. This is similar to the strategy implemented in
    :obj:`skimage.feature.peak_local_max`.
    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Return the coordinates reported by ``be.max_filter_coordinates``
        for ``scores`` and ``min_distance``, and None for the peak details.
        """
        peaks = be.max_filter_coordinates(scores, self.min_distance)
        return peaks, None
609
+
610
+
611
class PeakCallerFast(PeakCaller):
    """
    Subdivides the score space into squares with edge length ``min_distance``
    and determines the maximum value for each. In a second pass, all local maxima
    that are not the local maxima in a ``min_distance`` square centered around them
    are removed.

    """

    def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
        """
        Call peaks in the score space.

        Parameters
        ----------
        scores : BackendArray
            Score array.
        **kwargs : dict
            Ignored.

        Returns
        -------
        Tuple[BackendArray, BackendArray]
            Array of peak coordinates and None for the peak details.
            ``(None, None)`` if no candidate could be determined.
        """
        # A min_distance of zero would divide by zero below and create
        # empty windows in the second pass, treat it as a distance of one
        distance = max(self.min_distance, 1)

        splits = {i: x // distance for i, x in enumerate(scores.shape)}
        slices = split_shape(scores.shape, splits)

        # First pass: position of the maximum within each subvolume
        coordinates = be.to_backend_array(
            [
                be.unravel_index(be.argmax(scores[subvol]), scores[subvol].shape)
                for subvol in slices
            ]
        )
        # Has to precede any indexing with coordinates. The caller unpacks
        # the return value, hence a tuple is required.
        if coordinates.shape[0] == 0:
            return None, None

        offset = be.to_backend_array(
            [tuple(x.start for x in subvol) for subvol in slices]
        )
        be.add(coordinates, offset, out=coordinates)
        coordinates = coordinates[be.argsort(-scores[tuple(coordinates.T)])]

        # Second pass: keep maxima that dominate the window around them
        starts = be.maximum(coordinates - distance, 0)
        stops = be.minimum(coordinates + distance, scores.shape)
        slices_list = [
            tuple(slice(*coord) for coord in zip(start_row, stop_row))
            for start_row, stop_row in zip(starts, stops)
        ]

        keep = [
            score_subvol >= be.max(scores[subvol])
            for subvol, score_subvol in zip(slices_list, scores[tuple(coordinates.T)])
        ]
        coordinates = coordinates[keep,]

        return coordinates, None
656
+
657
+
658
+ class PeakCallerRecursiveMasking(PeakCaller):
659
+ """
660
+ Identifies peaks iteratively by selecting the top score and masking
661
+ a region around it.
662
+ """
663
+
664
def call_peaks(
    self,
    scores: BackendArray,
    rotation_matrix: BackendArray,
    mask: BackendArray = None,
    min_score: float = None,
    rotations: BackendArray = None,
    rotation_mapping: Dict = None,
    **kwargs,
) -> PeakType:
    """
    Call peaks in the score space.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores.
    rotation_matrix : BackendArray
        Rotation matrix.
    mask : BackendArray, optional
        Mask array, by default None.
    rotations : BackendArray, optional
        Rotation space array, by default None.
    rotation_mapping : Dict optional
        Dictionary mapping values in rotations to Euler angles.
        By default None
    min_score : float
        Minimum score value to consider. If provided, supersedes limit given
        by :py:attr:`PeakCaller.num_peaks`.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Array of peak coordinates and peak details.

    Notes
    -----
    By default, scores are masked using a box derived from self.min_distance.
    If mask is provided, the masking is delegated to
    :py:meth:`_mask_scores_rotate`. If rotations and rotation_mapping is
    provided, the respective rotation will be applied to the mask, otherwise
    rotation_matrix is used.
    """
    coordinates, masking_function = [], self._mask_scores_rotate

    if mask is None:
        masking_function = self._mask_scores_box
        shape = tuple(self.min_distance for _ in range(scores.ndim))
        mask = be.zeros(shape, dtype=be._float_dtype)

    rotated_template = be.zeros(mask.shape, dtype=mask.dtype)

    peak_limit = self.num_peaks
    if min_score is not None:
        peak_limit = be.size(scores)
    else:
        # Below every score, so only peak_limit terminates the loop
        min_score = be.min(scores) - 1

    # Masking is performed on a copy to leave the input scores untouched
    scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
    scores_copy[:] = scores

    while True:
        # A single argmax pass per iteration. The original computed it
        # twice and discarded the first result.
        peak = be.unravel_index(
            indices=be.argmax(scores_copy), shape=scores_copy.shape
        )
        if scores_copy[tuple(peak)] < min_score:
            break

        coordinates.append(peak)

        current_rotation_matrix = self._get_rotation_matrix(
            peak=peak,
            rotation_space=rotations,
            rotation_mapping=rotation_mapping,
            rotation_matrix=rotation_matrix,
        )

        masking_function(
            scores=scores_copy,
            rotation_matrix=current_rotation_matrix,
            peak=peak,
            mask=mask,
            rotated_template=rotated_template,
        )

        if len(coordinates) >= peak_limit:
            break

    peaks = be.to_backend_array(coordinates)
    return peaks, None
754
+
755
@staticmethod
def _get_rotation_matrix(
    peak: BackendArray,
    rotation_space: BackendArray,
    rotation_mapping: BackendArray,
    rotation_matrix: BackendArray,
) -> BackendArray:
    """
    Determine the rotation matrix associated with a peak.

    Parameters
    ----------
    peak : BackendArray
        Peak coordinates.
    rotation_space : BackendArray
        Array holding a rotation identifier for each position.
    rotation_mapping : Dict
        Dictionary mapping values in rotation_space to rotations.
    rotation_matrix : BackendArray
        Fallback if rotation_space or rotation_mapping is None.

    Returns
    -------
    BackendArray
        Rotation matrix.
    """
    has_lookup = rotation_space is not None and rotation_mapping is not None
    if not has_lookup:
        return rotation_matrix

    rotation_key = rotation_space[tuple(peak)]
    rotation = rotation_mapping[rotation_key]

    # TODO: Newer versions of rotation mapping contain rotation matrices not angles
    if rotation.ndim == 2:
        return rotation

    angles = be.to_numpy_array(rotation)
    return be.to_backend_array(euler_to_rotationmatrix(angles))
792
+
793
@staticmethod
def _mask_scores_box(
    scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
) -> None:
    """
    Set scores in a box around a peak to zero, in place.

    The box extends ``mask.shape`` elements from the peak in both
    directions along each axis and is clipped to the bounds of ``scores``.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores, modified in place.
    peak : BackendArray
        Peak coordinates.
    mask : BackendArray
        Mask array, only its shape is used.
    """
    extent = mask.shape
    lower = be.astype(be.maximum(be.subtract(peak, extent), 0), int)
    upper = be.astype(be.minimum(be.add(peak, extent), scores.shape), int)

    box = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    scores[box] = 0
    return None
815
+
816
@staticmethod
def _mask_scores_rotate(
    scores: BackendArray,
    peak: BackendArray,
    mask: BackendArray,
    rotated_template: BackendArray,
    rotation_matrix: BackendArray,
    **kwargs: Dict,
) -> None:
    """
    Mask scores around a peak using a rotated copy of the mask.

    The mask is rotated by ``rotation_matrix`` into ``rotated_template`` and
    the window of ``scores`` centered on ``peak`` is multiplied by the
    condition ``rotated mask <= 0.1``. Scores covered by the rotated mask are
    therefore set to zero, all other scores in the window are kept.
    ``scores`` and ``rotated_template`` are both modified in place.

    Parameters
    ----------
    scores : BackendArray
        Data array of scores, modified in place.
    peak : BackendArray
        Peak coordinates.
    mask : BackendArray
        Mask array.
    rotated_template : BackendArray
        Scratch array the rotated mask is written to. It is zeroed before
        every use. Presumably of the same shape as ``mask``, since it is
        sliced with indices derived from ``mask.shape``.
    rotation_matrix : BackendArray
        Rotation matrix.
    **kwargs : Dict
        Ignored. Accepted so all masking functions share a call signature.
    """
    # Split the mask extent into the part before and after the peak. Odd
    # extents put the extra element on the right. Casts go through be.astype,
    # like the casts below, instead of relying on an ``.astype`` array method
    # that not every array type provides.
    left_pad = be.astype(be.divide(mask.shape, 2), int)
    right_pad = be.add(left_pad, be.astype(be.mod(mask.shape, 2), int))

    # Unclipped window of the mask in score coordinates.
    score_start = be.subtract(peak, left_pad)
    score_stop = be.add(peak, right_pad)

    # Portion of the mask that remains once the window is clipped to the
    # bounds of the score array.
    template_start = be.subtract(be.maximum(score_start, 0), score_start)
    template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
    template_stop = be.subtract(mask.shape, template_stop)

    # Clip the score window itself.
    score_start = be.maximum(score_start, 0)
    score_stop = be.minimum(score_stop, scores.shape)
    score_start = be.astype(score_start, int)
    score_stop = be.astype(score_stop, int)

    template_start = be.astype(template_start, int)
    template_stop = be.astype(template_stop, int)
    coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
    coords_template = tuple(
        slice(*pos) for pos in zip(template_start, template_stop)
    )

    # Reset the scratch buffer so no values from a previous peak remain.
    rotated_template.fill(0)
    be.rigid_transform(
        arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
    )

    # Keep scores only where the interpolated rotated mask is at most 0.1.
    scores[coords_score] = be.multiply(
        scores[coords_score], (rotated_template[coords_template] <= 0.1)
    )
872
+
873
+
874
class PeakCallerScipy(PeakCaller):
    """
    Peak caller that delegates local maximum detection to
    :obj:`skimage.feature.peak_local_max`.
    """

    def call_peaks(
        self, scores: BackendArray, min_score: float = None, **kwargs
    ) -> PeakType:
        """
        Call peaks in the score space.

        Parameters
        ----------
        scores : BackendArray
            Data array of scores.
        min_score : float, optional
            Passed as ``threshold_abs``. When given, the number of returned
            peaks is no longer limited by ``self.num_peaks``.
        **kwargs
            Ignored.

        Returns
        -------
        PeakType
            Tuple of peak coordinates as backend array and None.
        """
        score_array = be.to_numpy_array(scores)

        configured_limit = self.num_peaks
        peak_limit = configured_limit if min_score is None else np.inf

        # peak_local_max runs on the squeezed array, remember which axes
        # survive squeezing so coordinates can be mapped back afterwards.
        kept_axes = tuple(
            axis for axis, extent in enumerate(score_array.shape) if extent != 1
        )
        local_maxima = peak_local_max(
            np.squeeze(score_array),
            num_peaks=peak_limit,
            min_distance=self.min_distance,
            threshold_abs=min_score,
        )

        # Singleton axes keep coordinate zero in the expanded result.
        expanded = np.zeros(
            (local_maxima.shape[0], score_array.ndim), dtype=local_maxima.dtype
        )
        expanded[..., kept_axes] = local_maxima
        return be.to_backend_array(expanded), None
898
+
899
+
900
class PeakClustering(PeakCallerSort):
    """
    Use DBScan clustering to identify more reliable peaks.
    """

    def __init__(
        self,
        num_peaks: int = 1000,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_peaks : int, optional
            Number of peaks forwarded to the parent class, 1000 by default.
        **kwargs
            Forwarded to the parent class. ``min_distance`` is always
            overwritten with zero.
        """
        kwargs["min_distance"] = 0
        super().__init__(num_peaks=num_peaks, **kwargs)

    @classmethod
    def merge(cls, **kwargs) -> tuple:
        """
        Merge multiple instances of Analyzer.

        The merged peaks are clustered with DBSCAN and only the highest
        scoring peak of every cluster is retained. Peaks labelled as noise
        by DBSCAN are discarded.

        Parameters
        ----------
        **kwargs
            Optional keyword arguments passed to :py:meth:`PeakCaller.merge`.

        Returns
        -------
        tuple
            Peaks, rotations, scores and details of the cluster
            representatives, in that order.
        """
        from sklearn.cluster import DBSCAN
        from ..extensions import max_index_by_label

        # Use the scores returned by the parent merge. They were previously
        # overwritten with the third coordinate of each peak, which ranked
        # clusters by position and failed on data with fewer than three axes.
        peaks, rotations, scores, details = super().merge(**kwargs)

        clusters = DBSCAN(eps=np.finfo(float).eps, min_samples=8).fit(peaks)
        labels = clusters.labels_.astype(int)

        # NOTE(review): assumes max_index_by_label accepts the score array
        # as returned by the parent merge - confirm dtype requirements.
        label_max = max_index_by_label(labels=labels, scores=scores)

        # DBSCAN assigns label -1 to noise, which is not a cluster.
        label_max.pop(-1, None)

        keep = np.zeros(peaks.shape[0], dtype=bool)
        keep[list(label_max.values())] = True

        peaks = peaks[keep,]
        rotations = rotations[keep,]
        scores = scores[keep]
        details = details[keep]

        return peaks, rotations, scores, details