pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/analyzer.py CHANGED
@@ -4,54 +4,68 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- from time import sleep
8
- from typing import Tuple, List, Dict
9
- from abc import ABC, abstractmethod
10
7
  from contextlib import nullcontext
11
- from multiprocessing import RawValue, Manager, Lock
8
+ from abc import ABC, abstractmethod
9
+ from multiprocessing import Manager, Lock
10
+ from typing import Tuple, List, Dict, Generator
12
11
 
13
12
  import numpy as np
14
- from numpy.typing import NDArray
15
- from scipy.stats import entropy
16
13
  from sklearn.cluster import DBSCAN
17
14
  from skimage.feature import peak_local_max
18
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
19
- from .extensions import max_index_by_label, online_statistics, find_candidate_indices
16
+
17
+ from .backends import backend as be
18
+ from .types import BackendArray, NDArray
19
+ from .extensions import max_index_by_label, find_candidate_indices
20
20
  from .matching_utils import (
21
- split_numpy_array_slices,
21
+ split_shape,
22
22
  array_to_memmap,
23
23
  generate_tempfile_name,
24
24
  euler_to_rotationmatrix,
25
25
  apply_convolution_mode,
26
26
  )
27
- from .backends import backend
28
27
 
28
+ PeakType = Tuple[BackendArray, BackendArray]
29
29
 
30
- def filter_points_indices_bucket(
31
- coordinates: NDArray, min_distance: Tuple[int]
32
- ) -> NDArray:
33
- coordinates = backend.subtract(coordinates, backend.min(coordinates, axis=0))
34
- bucket_indices = backend.astype(backend.divide(coordinates, min_distance), int)
35
- multiplier = backend.power(
36
- backend.max(bucket_indices, axis=0) + 1, backend.arange(bucket_indices.shape[1])
30
+
31
+ def _filter_bucket(coordinates: BackendArray, min_distance: Tuple[int]) -> BackendArray:
32
+ coordinates = be.subtract(coordinates, be.min(coordinates, axis=0))
33
+ bucket_indices = be.astype(be.divide(coordinates, min_distance), int)
34
+ multiplier = be.power(
35
+ be.max(bucket_indices, axis=0) + 1, be.arange(bucket_indices.shape[1])
37
36
  )
38
- backend.multiply(bucket_indices, multiplier, out=bucket_indices)
39
- flattened_indices = backend.sum(bucket_indices, axis=1)
40
- _, unique_indices = backend.unique(flattened_indices, return_index=True)
41
- unique_indices = unique_indices[backend.argsort(unique_indices)]
37
+ be.multiply(bucket_indices, multiplier, out=bucket_indices)
38
+ flattened_indices = be.sum(bucket_indices, axis=1)
39
+ _, unique_indices = be.unique(flattened_indices, return_index=True)
40
+ unique_indices = unique_indices[be.argsort(unique_indices)]
42
41
  return unique_indices
43
42
 
44
43
 
45
44
  def filter_points_indices(
46
- coordinates: NDArray, min_distance: float, bucket_cutoff: int = 1e4
47
- ) -> NDArray:
45
+ coordinates: BackendArray,
46
+ min_distance: float,
47
+ bucket_cutoff: int = 1e4,
48
+ batch_dims: Tuple[int] = None,
49
+ ) -> BackendArray:
48
50
  if min_distance <= 0:
49
- return backend.arange(coordinates.shape[0])
51
+ return be.arange(coordinates.shape[0])
52
+ if coordinates.shape[0] == 0:
53
+ return ()
54
+
55
+ if batch_dims is not None:
56
+ coordinates_new = be.zeros(coordinates.shape, coordinates.dtype)
57
+ coordinates_new[:] = coordinates
58
+ coordinates_new[..., batch_dims] = be.astype(
59
+ coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
60
+ )
61
+ coordinates = coordinates_new
50
62
 
51
63
  if isinstance(coordinates, np.ndarray):
52
64
  return find_candidate_indices(coordinates, min_distance)
53
- elif coordinates.shape[0] > bucket_cutoff:
54
- return filter_points_indices_bucket(coordinates, min_distance)
65
+ elif coordinates.shape[0] > bucket_cutoff or not isinstance(
66
+ coordinates, np.ndarray
67
+ ):
68
+ return _filter_bucket(coordinates, min_distance)
55
69
  distances = np.linalg.norm(coordinates[:, None] - coordinates, axis=-1)
56
70
  distances = np.tril(distances)
57
71
  keep = np.sum(distances > min_distance, axis=1)
@@ -59,8 +73,10 @@ def filter_points_indices(
59
73
  return indices[keep == indices]
60
74
 
61
75
 
62
- def filter_points(coordinates: NDArray, min_distance: Tuple[int]) -> NDArray:
63
- unique_indices = filter_points_indices(coordinates, min_distance)
76
+ def filter_points(
77
+ coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
78
+ ) -> BackendArray:
79
+ unique_indices = filter_points_indices(coordinates, min_distance, batch_dims)
64
80
  coordinates = coordinates[unique_indices]
65
81
  return coordinates
66
82
 
@@ -77,8 +93,15 @@ class PeakCaller(ABC):
77
93
  Minimum distance between peaks.
78
94
  min_boundary_distance : int, optional
79
95
  Minimum distance to array boundaries.
96
+ batch_dims : int, optional
97
+ Peak calling batch dimensions.
98
+ minimum_score : float
99
+ Minimum score from which to consider peaks. If provided, superseeds limits
100
+ presented by :py:attr:`PeakCaller.number_of_peaks`.
101
+ maximum_score : float
102
+ Maximum score upon which to consider peaks,
80
103
  **kwargs
81
- Additional keyword arguments.
104
+ Optional keyword arguments.
82
105
 
83
106
  Raises
84
107
  ------
@@ -92,12 +115,11 @@ class PeakCaller(ABC):
92
115
  number_of_peaks: int = 1000,
93
116
  min_distance: int = 1,
94
117
  min_boundary_distance: int = 0,
118
+ batch_dims: Tuple[int] = None,
119
+ minimum_score: float = None,
120
+ maximum_score: float = None,
95
121
  **kwargs,
96
122
  ):
97
- number_of_peaks = int(number_of_peaks)
98
- min_distance, min_boundary_distance = int(min_distance), int(
99
- min_boundary_distance
100
- )
101
123
  if number_of_peaks <= 0:
102
124
  raise ValueError(
103
125
  f"number_of_peaks has to be larger than 0, got {number_of_peaks}"
@@ -110,9 +132,15 @@ class PeakCaller(ABC):
110
132
  )
111
133
 
112
134
  self.peak_list = []
113
- self.min_distance = min_distance
114
- self.min_boundary_distance = min_boundary_distance
115
- self.number_of_peaks = number_of_peaks
135
+ self.min_distance = int(min_distance)
136
+ self.number_of_peaks = int(number_of_peaks)
137
+ self.min_boundary_distance = int(min_boundary_distance)
138
+
139
+ self.batch_dims = batch_dims
140
+ if batch_dims is not None:
141
+ self.batch_dims = tuple(int(x) for x in self.batch_dims)
142
+
143
+ self.minimum_score, self.maximum_score = minimum_score, maximum_score
116
144
 
117
145
  # Postprocesing arguments
118
146
  self.fourier_shift = kwargs.get("fourier_shift", None)
@@ -120,156 +148,191 @@ class PeakCaller(ABC):
120
148
  self.targetshape = kwargs.get("targetshape", None)
121
149
  self.templateshape = kwargs.get("templateshape", None)
122
150
 
123
- def __iter__(self):
151
+ def __iter__(self) -> Generator:
124
152
  """
125
153
  Returns a generator to list objects containing translation,
126
154
  rotation, score and details of a given candidate.
127
155
  """
128
- self.peak_list = [backend.to_cpu_array(arr) for arr in self.peak_list]
156
+ self.peak_list = [be.to_cpu_array(arr) for arr in self.peak_list]
129
157
  yield from self.peak_list
130
158
 
131
- def __call__(
132
- self,
133
- score_space: NDArray,
134
- rotation_matrix: NDArray,
135
- **kwargs,
136
- ) -> None:
159
+ @staticmethod
160
+ def _batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
161
+ if batch_dims is None:
162
+ yield (tuple(slice(None) for _ in shape), tuple(0 for _ in shape))
163
+ return None
164
+
165
+ batch_ranges = [range(shape[dim]) for dim in batch_dims]
166
+
167
+ def _generate_slices_recursive(current_dim, current_indices):
168
+ if current_dim == len(batch_dims):
169
+ slice_list, offset_list, batch_index = [], [], 0
170
+ for i in range(len(shape)):
171
+ if i in batch_dims:
172
+ index = current_indices[batch_index]
173
+ slice_list.append(slice(index, index + 1))
174
+ offset_list.append(index)
175
+ batch_index += 1
176
+ else:
177
+ slice_list.append(slice(None))
178
+ offset_list.append(0)
179
+ yield (tuple(slice_list), tuple(offset_list))
180
+ else:
181
+ for index in batch_ranges[current_dim]:
182
+ yield from _generate_slices_recursive(
183
+ current_dim + 1, current_indices + (index,)
184
+ )
185
+
186
+ yield from _generate_slices_recursive(0, ())
187
+
188
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
137
189
  """
138
190
  Update the internal parameter store based on input array.
139
191
 
140
192
  Parameters
141
193
  ----------
142
- score_space : NDArray
143
- Array containing the score space.
144
- rotation_matrix : NDArray
194
+ scores : BackendArray
195
+ Score space data.
196
+ rotation_matrix : BackendArray
145
197
  Rotation matrix used to obtain the score array.
146
198
  **kwargs
147
- Optional keyword arguments passed to :py:meth:`PeakCaller.call_peak`.
199
+ Optional keyword aguments passed to :py:meth:`PeakCaller.call_peaks`.
148
200
  """
149
- peak_positions, peak_details = self.call_peaks(
150
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
151
- )
152
-
153
- if peak_positions is None:
154
- return None
201
+ minimum_score, maximum_score = self.minimum_score, self.maximum_score
202
+ for subset, batch_offset in self._batchify(scores.shape, self.batch_dims):
203
+ batch_offset = be.to_backend_array(batch_offset)
204
+ peak_positions, peak_details = self.call_peaks(
205
+ scores=scores[subset],
206
+ rotation_matrix=rotation_matrix,
207
+ minimum_score=minimum_score,
208
+ maximum_score=maximum_score,
209
+ **kwargs,
210
+ )
155
211
 
156
- peak_positions = backend.astype(peak_positions, int)
157
- if peak_positions.shape[0] == 0:
158
- return None
212
+ if peak_positions is None:
213
+ continue
214
+ if peak_positions.shape[0] == 0:
215
+ continue
159
216
 
160
- if peak_details is None:
161
- peak_details = backend.to_backend_array([-1] * peak_positions.shape[0])
217
+ if peak_details is None:
218
+ peak_details = be.full((peak_positions.shape[0],), fill_value=-1)
162
219
 
163
- if self.min_boundary_distance > 0:
164
- upper_limit = backend.subtract(
165
- score_space.shape, self.min_boundary_distance
166
- )
167
- valid_peaks = (
168
- backend.sum(
169
- backend.multiply(
170
- peak_positions < upper_limit,
171
- peak_positions >= self.min_boundary_distance,
172
- ),
173
- axis=1,
220
+ peak_positions = be.to_backend_array(peak_positions)
221
+ peak_positions = be.add(peak_positions, batch_offset, out=peak_positions)
222
+ peak_positions = be.astype(peak_positions, int)
223
+ if self.min_boundary_distance > 0:
224
+ upper_limit = be.subtract(
225
+ be.to_backend_array(scores.shape), self.min_boundary_distance
226
+ )
227
+ valid_peaks = be.multiply(
228
+ peak_positions < upper_limit,
229
+ peak_positions >= self.min_boundary_distance,
230
+ )
231
+ if self.batch_dims is not None:
232
+ valid_peaks[..., self.batch_dims] = True
233
+
234
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
235
+
236
+ if be.sum(valid_peaks) == 0:
237
+ continue
238
+ peak_positions = peak_positions[valid_peaks]
239
+ peak_details = peak_details[valid_peaks]
240
+
241
+ peak_scores = scores[tuple(peak_positions.T)]
242
+ if minimum_score is not None:
243
+ valid_peaks = peak_scores >= minimum_score
244
+ peak_positions, peak_details, peak_scores = (
245
+ peak_positions[valid_peaks],
246
+ peak_details[valid_peaks],
247
+ peak_scores[valid_peaks],
174
248
  )
175
- == peak_positions.shape[1]
249
+ if maximum_score is not None:
250
+ valid_peaks = peak_scores <= maximum_score
251
+ peak_positions, peak_details, peak_scores = (
252
+ peak_positions[valid_peaks],
253
+ peak_details[valid_peaks],
254
+ peak_scores[valid_peaks],
255
+ )
256
+
257
+ if peak_positions.shape[0] == 0:
258
+ continue
259
+
260
+ rotations = be.repeat(
261
+ rotation_matrix.reshape(1, *rotation_matrix.shape),
262
+ peak_positions.shape[0],
263
+ axis=0,
176
264
  )
177
- if backend.sum(valid_peaks) == 0:
178
- return None
179
265
 
180
- peak_positions, peak_details = (
181
- peak_positions[valid_peaks],
182
- peak_details[valid_peaks],
266
+ self._update(
267
+ peak_positions=peak_positions,
268
+ peak_details=peak_details,
269
+ peak_scores=peak_scores,
270
+ rotations=rotations,
183
271
  )
184
272
 
185
- rotations = backend.repeat(
186
- rotation_matrix.reshape(1, *rotation_matrix.shape),
187
- peak_positions.shape[0],
188
- axis=0,
189
- )
190
- peak_scores = score_space[tuple(peak_positions.T)]
191
-
192
- self._update(
193
- peak_positions=peak_positions,
194
- peak_details=peak_details,
195
- peak_scores=peak_scores,
196
- rotations=rotations,
197
- **kwargs,
198
- )
273
+ return None
199
274
 
200
275
  @abstractmethod
201
- def call_peaks(
202
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
203
- ) -> Tuple[NDArray, NDArray]:
276
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
204
277
  """
205
278
  Call peaks in the score space.
206
279
 
207
- This function is not intended to be called directly, but should rather be
208
- defined by classes inheriting from :py:class:`PeakCaller` to execute a given
209
- peak calling algorithm.
210
-
211
280
  Parameters
212
281
  ----------
213
- score_space : NDArray
214
- Data array of scores.
215
- minimum_score : float
216
- Minimum score value to consider.
217
- min_distance : float
218
- Minimum distance between maxima.
282
+ scores : BackendArray
283
+ Score array.
284
+ **kwargs : dict
285
+ Optional keyword arguments passed to underlying implementations.
219
286
 
220
287
  Returns
221
288
  -------
222
- Tuple[NDArray, NDArray]
289
+ Tuple[BackendArray, BackendArray]
223
290
  Array of peak coordinates and peak details.
224
291
  """
225
292
 
226
293
  @classmethod
227
- def merge(cls, candidates=List[List], **kwargs) -> NDArray:
294
+ def merge(cls, candidates=List[List], **kwargs) -> Tuple:
228
295
  """
229
296
  Merge multiple instances of :py:class:`PeakCaller`.
230
297
 
231
298
  Parameters
232
299
  ----------
233
- candidate_fits : list of lists
300
+ candidates : list of lists
234
301
  Obtained by invoking list on the generator returned by __iter__.
235
- param_stores : list of tuples, optional
236
- List of parameter stores. Each tuple contains candidate data and number
237
- of candidates.
238
302
  **kwargs
239
- Additional keyword arguments.
303
+ Optional keyword arguments.
240
304
 
241
305
  Returns
242
306
  -------
243
- NDArray
244
- NDArray of candidates.
307
+ Tuple
308
+ Tuple of translation, rotation, score and details of candidates.
245
309
  """
246
310
  base = cls(**kwargs)
247
311
  for candidate in candidates:
248
312
  if len(candidate) == 0:
249
313
  continue
250
314
  peak_positions, rotations, peak_scores, peak_details = candidate
251
- kwargs["translation_offset"] = backend.zeros(peak_positions.shape[1])
252
315
  base._update(
253
- peak_positions=backend.to_backend_array(peak_positions),
254
- peak_details=backend.to_backend_array(peak_details),
255
- peak_scores=backend.to_backend_array(peak_scores),
256
- rotations=backend.to_backend_array(rotations),
257
- **kwargs,
316
+ peak_positions=be.to_backend_array(peak_positions),
317
+ peak_details=be.to_backend_array(peak_details),
318
+ peak_scores=be.to_backend_array(peak_scores),
319
+ rotations=be.to_backend_array(rotations),
320
+ offset=kwargs.get("offset", None),
258
321
  )
259
322
  return tuple(base)
260
323
 
261
324
  @staticmethod
262
325
  def oversample_peaks(
263
- score_space: NDArray, peak_positions: NDArray, oversampling_factor: int = 8
326
+ scores: BackendArray, peak_positions: BackendArray, oversampling_factor: int = 8
264
327
  ):
265
328
  """
266
329
  Refines peaks positions in the corresponding score space.
267
330
 
268
331
  Parameters
269
332
  ----------
270
- score_space : NDArray
333
+ scores : BackendArray
271
334
  The d-dimensional array representing the score space.
272
- peak_positions : NDArray
335
+ peak_positions : BackendArray
273
336
  An array of shape (n, d) containing the peak coordinates
274
337
  to be refined, where n is the number of peaks and d is the
275
338
  dimensionality of the score space.
@@ -278,14 +341,14 @@ class PeakCaller(ABC):
278
341
 
279
342
  Returns
280
343
  -------
281
- NDArray
344
+ BackendArray
282
345
  An array of shape (n, d) containing the refined subpixel
283
346
  coordinates of the peaks.
284
347
 
285
348
  Notes
286
349
  -----
287
350
  Floating point peak positions are determined by oversampling the
288
- score_space around peak_positions. The accuracy
351
+ scores around peak_positions. The accuracy
289
352
  of refinement scales with 1 / oversampling_factor.
290
353
 
291
354
  References
@@ -297,8 +360,8 @@ class PeakCaller(ABC):
297
360
  DOI:10.1364/OL.33.000156
298
361
 
299
362
  """
300
- score_space = backend.to_numpy_array(score_space)
301
- peak_positions = backend.to_numpy_array(peak_positions)
363
+ scores = be.to_numpy_array(scores)
364
+ peak_positions = be.to_numpy_array(peak_positions)
302
365
 
303
366
  peak_positions = np.round(
304
367
  np.divide(
@@ -311,10 +374,10 @@ class PeakCaller(ABC):
311
374
  dftshift, np.multiply(peak_positions, oversampling_factor)
312
375
  )
313
376
 
314
- score_space_ft = np.fft.fftn(score_space).conj()
377
+ scores_ft = np.fft.fftn(scores).conj()
315
378
  for index in range(sample_region_offset.shape[0]):
316
379
  cross_correlation_upsampled = _upsampled_dft(
317
- data=score_space_ft,
380
+ data=scores_ft,
318
381
  upsampled_region_size=upsampled_region_size,
319
382
  upsample_factor=oversampling_factor,
320
383
  axis_offsets=sample_region_offset[index],
@@ -327,209 +390,171 @@ class PeakCaller(ABC):
327
390
  maxima = np.divide(np.subtract(maxima, dftshift), oversampling_factor)
328
391
  peak_positions[index] = np.add(peak_positions[index], maxima)
329
392
 
330
- peak_positions = backend.to_backend_array(peak_positions)
393
+ peak_positions = be.to_backend_array(peak_positions)
331
394
 
332
395
  return peak_positions
333
396
 
334
397
  def _update(
335
398
  self,
336
- peak_positions: NDArray,
337
- peak_details: NDArray,
338
- peak_scores: NDArray,
339
- rotations: NDArray,
340
- **kwargs,
341
- ) -> None:
399
+ peak_positions: BackendArray,
400
+ peak_details: BackendArray,
401
+ peak_scores: BackendArray,
402
+ rotations: BackendArray,
403
+ offset: BackendArray = None,
404
+ ):
342
405
  """
343
406
  Update internal parameter store.
344
407
 
345
408
  Parameters
346
409
  ----------
347
- peak_positions : NDArray
348
- Position of peaks with shape n x d where n is the number of
349
- peaks and d the dimension.
350
- peak_scores : NDArray
351
- Corresponding score obtained at each peak.
352
- translation_offset : NDArray, optional
353
- Offset of the score_space, occurs e.g. when template matching
354
- to parts of a tomogram.
355
- rotations: NDArray
356
- Rotations used to obtain the score space from which
357
- the candidate stem.
410
+ peak_positions : BackendArray
411
+ Position of peaks (n, d).
412
+ peak_details : BackendArray
413
+ Details of each peak (n, ).
414
+ rotations: BackendArray
415
+ Rotation at each peak (n, d, d).
416
+ rotations: BackendArray
417
+ Rotation at each peak (n, d, d).
418
+ offset : BackendArray, optional
419
+ Translation offset, e.g. from splitting, (n, ).
358
420
  """
359
- translation_offset = kwargs.get(
360
- "translation_offset", backend.zeros(peak_positions.shape[1])
361
- )
362
- translation_offset = backend.astype(translation_offset, peak_positions.dtype)
421
+ if offset is not None:
422
+ offset = be.astype(offset, peak_positions.dtype)
423
+ peak_positions = be.add(peak_positions, offset, out=peak_positions)
363
424
 
364
- backend.add(peak_positions, translation_offset, out=peak_positions)
365
425
  if not len(self.peak_list):
366
426
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
367
- dim = peak_positions.shape[1]
368
- peak_scores = backend.zeros((0,), peak_scores.dtype)
369
- peak_details = backend.zeros((0,), peak_details.dtype)
370
- rotations = backend.zeros((0, dim, dim), rotations.dtype)
371
- peak_positions = backend.zeros((0, dim), peak_positions.dtype)
372
-
373
- peaks = backend.concatenate((self.peak_list[0], peak_positions))
374
- rotations = backend.concatenate((self.peak_list[1], rotations))
375
- peak_scores = backend.concatenate((self.peak_list[2], peak_scores))
376
- peak_details = backend.concatenate((self.peak_list[3], peak_details))
377
-
378
- top_n = min(backend.size(peak_scores), self.number_of_peaks)
379
- top_scores, *_ = backend.topk_indices(peak_scores, top_n)
427
+ else:
428
+ peak_positions = be.concatenate((self.peak_list[0], peak_positions))
429
+ rotations = be.concatenate((self.peak_list[1], rotations))
430
+ peak_scores = be.concatenate((self.peak_list[2], peak_scores))
431
+ peak_details = be.concatenate((self.peak_list[3], peak_details))
432
+
433
+ if self.batch_dims is None:
434
+ top_n = min(be.size(peak_scores), self.number_of_peaks)
435
+ top_scores, *_ = be.topk_indices(peak_scores, top_n)
436
+ else:
437
+ # Not very performant but fairly robust
438
+ batch_indices = peak_positions[..., self.batch_dims]
439
+ batch_indices = be.subtract(batch_indices, be.min(batch_indices, axis=0))
440
+ multiplier = be.power(
441
+ be.max(batch_indices, axis=0) + 1,
442
+ be.arange(batch_indices.shape[1]),
443
+ )
444
+ batch_indices = be.multiply(batch_indices, multiplier, out=batch_indices)
445
+ batch_indices = be.sum(batch_indices, axis=1)
446
+ unique_indices, batch_counts = be.unique(batch_indices, return_counts=True)
447
+ total_indices = be.arange(peak_scores.shape[0])
448
+ batch_indices = [total_indices[batch_indices == x] for x in unique_indices]
449
+ top_scores = be.concatenate(
450
+ [
451
+ total_indices[indices][
452
+ be.topk_indices(
453
+ peak_scores[indices], min(y, self.number_of_peaks)
454
+ )
455
+ ]
456
+ for indices, y in zip(batch_indices, batch_counts)
457
+ ]
458
+ )
380
459
 
381
460
  final_order = top_scores[
382
- filter_points_indices(peaks[top_scores], self.min_distance)
461
+ filter_points_indices(
462
+ coordinates=peak_positions[top_scores],
463
+ min_distance=self.min_distance,
464
+ batch_dims=self.batch_dims,
465
+ )
383
466
  ]
384
467
 
385
- self.peak_list[0] = peaks[final_order,]
468
+ self.peak_list[0] = peak_positions[final_order,]
386
469
  self.peak_list[1] = rotations[final_order,]
387
470
  self.peak_list[2] = peak_scores[final_order]
388
471
  self.peak_list[3] = peak_details[final_order]
389
472
 
390
473
  def _postprocess(
391
- self, fourier_shift, convolution_mode, targetshape, templateshape, **kwargs
474
+ self,
475
+ fast_shape: Tuple[int],
476
+ targetshape: Tuple[int],
477
+ templateshape: Tuple[int],
478
+ fourier_shift: Tuple[int] = None,
479
+ convolution_mode: str = None,
480
+ shared_memory_handler=None,
481
+ **kwargs,
392
482
  ):
393
- peak_positions = self.peak_list[0]
394
- if not len(peak_positions):
483
+ if not len(self.peak_list):
395
484
  return self
396
485
 
397
- if targetshape is None or templateshape is None:
486
+ peak_positions = self.peak_list[0]
487
+ if not len(peak_positions):
398
488
  return self
399
489
 
400
- # Remove padding to next fast fourier length
401
- score_space_shape = backend.add(targetshape, templateshape) - 1
402
-
490
+ # Wrap peaks around score space
491
+ fast_shape = be.to_backend_array(fast_shape)
403
492
  if fourier_shift is not None:
404
- peak_positions = backend.add(peak_positions, fourier_shift)
405
- backend.divide(peak_positions, score_space_shape).astype(int)
406
-
407
- backend.subtract(
493
+ fourier_shift = be.to_backend_array(fourier_shift)
494
+ peak_positions = be.add(peak_positions, fourier_shift)
495
+ peak_positions = be.subtract(
408
496
  peak_positions,
409
- backend.multiply(
410
- backend.divide(peak_positions, score_space_shape).astype(int),
411
- score_space_shape
497
+ be.multiply(
498
+ be.astype(be.divide(peak_positions, fast_shape), int),
499
+ fast_shape,
412
500
  ),
413
- out = peak_positions
414
501
  )
415
502
 
416
- if convolution_mode is None:
417
- return None
418
-
419
- if convolution_mode == "full":
420
- output_shape = score_space_shape
421
- elif convolution_mode == "same":
503
+ # Remove padding to fast Fourier (and potential full convolution) shape
504
+ targetshape = be.to_backend_array(targetshape)
505
+ templateshape = be.to_backend_array(templateshape)
506
+ fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
507
+ output_shape = fast_shape
508
+ if convolution_mode == "same":
422
509
  output_shape = targetshape
423
510
  elif convolution_mode == "valid":
424
- output_shape = backend.add(
425
- backend.subtract(targetshape, templateshape),
426
- backend.mod(templateshape, 2)
511
+ output_shape = be.add(
512
+ be.subtract(targetshape, templateshape),
513
+ be.mod(templateshape, 2),
427
514
  )
428
515
 
429
- output_shape = backend.to_backend_array(output_shape)
430
- starts = backend.divide(
431
- backend.subtract(score_space_shape, output_shape),
432
- 2
516
+ output_shape = be.to_backend_array(output_shape)
517
+ starts = be.astype(
518
+ be.divide(be.subtract(fast_shape, output_shape), 2),
519
+ be._int_dtype,
433
520
  )
434
- starts = backend.astype(starts, int)
435
- stops = backend.add(starts, output_shape)
436
-
437
- valid_peaks = (
438
- backend.sum(
439
- backend.multiply(
440
- peak_positions > starts,
441
- peak_positions <= stops
442
- ),
443
- axis=1,
444
- )
445
- == peak_positions.shape[1]
446
- )
447
- self.peak_list[0] = backend.subtract(peak_positions, starts)
521
+ stops = be.add(starts, output_shape)
522
+
523
+ valid_peaks = be.multiply(peak_positions > starts, peak_positions <= stops)
524
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
525
+
526
+ self.peak_list[0] = be.subtract(peak_positions, starts)
448
527
  self.peak_list = [x[valid_peaks] for x in self.peak_list]
449
528
  return self
450
529
 
530
+
451
531
  class PeakCallerSort(PeakCaller):
452
532
  """
453
533
  A :py:class:`PeakCaller` subclass that first selects ``number_of_peaks``
454
- highest scores and subsequently filters local maxima to suffice a distance
455
- from one another of ``min_distance``.
456
-
534
+ highest scores.
457
535
  """
458
536
 
459
- def call_peaks(
460
- self, score_space: NDArray, minimum_score: float = None, **kwargs
461
- ) -> Tuple[NDArray, NDArray]:
462
- """
463
- Call peaks in the score space.
464
-
465
- Parameters
466
- ----------
467
- score_space : NDArray
468
- Data array of scores.
469
- minimum_score : float
470
- Minimum score value to consider. If provided, superseeds limit given
471
- by :py:attr:`PeakCaller.number_of_peaks`.
472
-
473
- Returns
474
- -------
475
- Tuple[NDArray, NDArray]
476
- Array of peak coordinates and peak details.
477
- """
478
- flat_score_space = score_space.reshape(-1)
479
- k = min(self.number_of_peaks, backend.size(flat_score_space))
480
-
481
- if minimum_score is not None:
482
- k = backend.sum(score_space >= minimum_score)
537
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
538
+ flat_scores = scores.reshape(-1)
539
+ k = min(self.number_of_peaks, be.size(flat_scores))
483
540
 
484
- top_k_indices, *_ = backend.topk_indices(flat_score_space, k)
541
+ top_k_indices, *_ = be.topk_indices(flat_scores, k)
485
542
 
486
- coordinates = backend.unravel_index(top_k_indices, score_space.shape)
487
- coordinates = backend.transpose(backend.stack(coordinates))
543
+ coordinates = be.unravel_index(top_k_indices, scores.shape)
544
+ coordinates = be.transpose(be.stack(coordinates))
488
545
 
489
- peaks = filter_points(coordinates, self.min_distance)
490
- return peaks, None
546
+ return coordinates, None
491
547
 
492
548
 
493
549
  class PeakCallerMaximumFilter(PeakCaller):
494
550
  """
495
551
  Find local maxima by applying a maximum filter and enforcing a distance
496
552
  constraint subsequently. This is similar to the strategy implemented in
497
- skimage.feature.peak_local_max.
553
+ :obj:`skimage.feature.peak_local_max`.
498
554
  """
499
555
 
500
- def call_peaks(
501
- self, score_space: NDArray, minimum_score: float = None, **kwargs
502
- ) -> Tuple[NDArray, NDArray]:
503
- """
504
- Call peaks in the score space.
505
-
506
- Parameters
507
- ----------
508
- score_space : NDArray
509
- Data array of scores.
510
- minimum_score : float
511
- Minimum score value to consider. If provided, superseeds limit given
512
- by :py:attr:`PeakCaller.number_of_peaks`.
513
-
514
- Returns
515
- -------
516
- Tuple[NDArray, NDArray]
517
- Array of peak coordinates and peak details.
518
- """
519
- peaks = backend.max_filter_coordinates(score_space, self.min_distance)
520
-
521
- scores = score_space[tuple(peaks.T)]
522
-
523
- input_candidates = min(
524
- self.number_of_peaks, peaks.shape[0] - 1, backend.size(score_space) - 1
525
- )
526
- if minimum_score is not None:
527
- input_candidates = backend.sum(scores >= minimum_score)
528
-
529
- top_indices = backend.topk_indices(scores, input_candidates)
530
- peaks = peaks[top_indices]
531
-
532
- return peaks, None
556
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
557
+ return be.max_filter_coordinates(scores, self.min_distance), None
533
558
 
534
559
 
535
560
  class PeakCallerFast(PeakCaller):
@@ -541,70 +566,42 @@ class PeakCallerFast(PeakCaller):
541
566
 
542
567
  """
543
568
 
544
- def call_peaks(
545
- self, score_space: NDArray, minimum_score: float = None, **kwargs
546
- ) -> Tuple[NDArray, NDArray]:
547
- """
548
- Call peaks in the score space.
549
-
550
- Parameters
551
- ----------
552
- score_space : NDArray
553
- Data array of scores.
554
- minimum_score : float
555
- Minimum score value to consider. If provided, superseeds limit given
556
- by :py:attr:`PeakCaller.number_of_peaks`.
557
-
558
- Returns
559
- -------
560
- Tuple[NDArray, NDArray]
561
- Array of peak coordinates and peak details.
562
- """
563
- splits = {
564
- axis: score_space.shape[axis] // self.min_distance
565
- for axis in range(score_space.ndim)
566
- }
567
- slices = split_numpy_array_slices(score_space.shape, splits)
569
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
570
+ splits = {i: x // self.min_distance for i, x in enumerate(scores.shape)}
571
+ slices = split_shape(scores.shape, splits)
568
572
 
569
- coordinates = backend.to_backend_array(
573
+ coordinates = be.to_backend_array(
570
574
  [
571
- backend.unravel_index(
572
- backend.argmax(score_space[subvol]), score_space[subvol].shape
573
- )
575
+ be.unravel_index(be.argmax(scores[subvol]), scores[subvol].shape)
574
576
  for subvol in slices
575
577
  ]
576
578
  )
577
- offset = backend.to_backend_array(
579
+ offset = be.to_backend_array(
578
580
  [tuple(x.start for x in subvol) for subvol in slices]
579
581
  )
580
- backend.add(coordinates, offset, out=coordinates)
581
- coordinates = coordinates[
582
- backend.flip(backend.argsort(score_space[tuple(coordinates.T)]), (0,))
583
- ]
582
+ be.add(coordinates, offset, out=coordinates)
583
+ coordinates = coordinates[be.argsort(-scores[tuple(coordinates.T)])]
584
584
 
585
585
  if coordinates.shape[0] == 0:
586
586
  return None
587
587
 
588
- peaks = filter_points(coordinates, self.min_distance)
589
-
590
- starts = backend.maximum(peaks - self.min_distance, 0)
591
- stops = backend.minimum(peaks + self.min_distance, score_space.shape)
588
+ starts = be.maximum(coordinates - self.min_distance, 0)
589
+ stops = be.minimum(coordinates + self.min_distance, scores.shape)
592
590
  slices_list = [
593
591
  tuple(slice(*coord) for coord in zip(start_row, stop_row))
594
592
  for start_row, stop_row in zip(starts, stops)
595
593
  ]
596
594
 
597
- scores = score_space[tuple(peaks.T)]
598
595
  keep = [
599
- score >= backend.max(score_space[subvol])
600
- for subvol, score in zip(slices_list, scores)
596
+ score_subvol >= be.max(scores[subvol])
597
+ for subvol, score_subvol in zip(slices_list, scores[tuple(coordinates.T)])
601
598
  ]
602
- peaks = peaks[keep,]
599
+ coordinates = coordinates[keep,]
603
600
 
604
- if len(peaks) == 0:
605
- return peaks, None
601
+ if len(coordinates) == 0:
602
+ return coordinates, None
606
603
 
607
- return peaks, None
604
+ return coordinates, None
608
605
 
609
606
 
610
607
  class PeakCallerRecursiveMasking(PeakCaller):
@@ -615,26 +612,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
615
612
 
616
613
  def call_peaks(
617
614
  self,
618
- score_space: NDArray,
619
- rotation_matrix: NDArray,
620
- mask: NDArray = None,
615
+ scores: BackendArray,
616
+ rotation_matrix: BackendArray,
617
+ mask: BackendArray = None,
621
618
  minimum_score: float = None,
622
- rotation_space: NDArray = None,
619
+ rotation_space: BackendArray = None,
623
620
  rotation_mapping: Dict = None,
624
621
  **kwargs,
625
- ) -> Tuple[NDArray, NDArray]:
622
+ ) -> PeakType:
626
623
  """
627
624
  Call peaks in the score space.
628
625
 
629
626
  Parameters
630
627
  ----------
631
- score_space : NDArray
628
+ scores : BackendArray
632
629
  Data array of scores.
633
- rotation_matrix : NDArray
630
+ rotation_matrix : BackendArray
634
631
  Rotation matrix.
635
- mask : NDArray, optional
632
+ mask : BackendArray, optional
636
633
  Mask array, by default None.
637
- rotation_space : NDArray, optional
634
+ rotation_space : BackendArray, optional
638
635
  Rotation space array, by default None.
639
636
  rotation_mapping : Dict optional
640
637
  Dictionary mapping values in rotation_space to Euler angles.
@@ -645,7 +642,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
645
642
 
646
643
  Returns
647
644
  -------
648
- Tuple[NDArray, NDArray]
645
+ Tuple[BackendArray, BackendArray]
649
646
  Array of peak coordinates and peak details.
650
647
 
651
648
  Notes
@@ -659,23 +656,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
659
656
 
660
657
  if mask is None:
661
658
  masking_function = self._mask_scores_box
662
- shape = tuple(self.min_distance for _ in range(score_space.ndim))
663
- mask = backend.zeros(shape, dtype=backend._default_dtype)
659
+ shape = tuple(self.min_distance for _ in range(scores.ndim))
660
+ mask = be.zeros(shape, dtype=be._float_dtype)
664
661
 
665
- rotated_template = backend.zeros(mask.shape, dtype=mask.dtype)
662
+ rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
666
663
 
667
664
  peak_limit = self.number_of_peaks
668
665
  if minimum_score is not None:
669
- peak_limit = backend.size(score_space)
666
+ peak_limit = be.size(scores)
670
667
  else:
671
- minimum_score = backend.min(score_space) - 1
668
+ minimum_score = be.min(scores) - 1
669
+
670
+ scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
671
+ scores_copy[:] = scores
672
672
 
673
673
  while True:
674
- backend.argmax(score_space)
675
- peak = backend.unravel_index(
676
- indices=backend.argmax(score_space), shape=score_space.shape
674
+ be.argmax(scores_copy)
675
+ peak = be.unravel_index(
676
+ indices=be.argmax(scores_copy), shape=scores_copy.shape
677
677
  )
678
- if score_space[tuple(peak)] < minimum_score:
678
+ if scores_copy[tuple(peak)] < minimum_score:
679
679
  break
680
680
 
681
681
  coordinates.append(peak)
@@ -688,7 +688,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
688
688
  )
689
689
 
690
690
  masking_function(
691
- score_space=score_space,
691
+ scores=scores_copy,
692
692
  rotation_matrix=current_rotation_matrix,
693
693
  peak=peak,
694
694
  mask=mask,
@@ -698,33 +698,33 @@ class PeakCallerRecursiveMasking(PeakCaller):
698
698
  if len(coordinates) >= peak_limit:
699
699
  break
700
700
 
701
- peaks = backend.to_backend_array(coordinates)
701
+ peaks = be.to_backend_array(coordinates)
702
702
  return peaks, None
703
703
 
704
704
  @staticmethod
705
705
  def _get_rotation_matrix(
706
- peak: NDArray,
707
- rotation_space: NDArray,
708
- rotation_mapping: NDArray,
709
- rotation_matrix: NDArray,
710
- ) -> NDArray:
706
+ peak: BackendArray,
707
+ rotation_space: BackendArray,
708
+ rotation_mapping: BackendArray,
709
+ rotation_matrix: BackendArray,
710
+ ) -> BackendArray:
711
711
  """
712
712
  Get rotation matrix based on peak and rotation data.
713
713
 
714
714
  Parameters
715
715
  ----------
716
- peak : NDArray
716
+ peak : BackendArray
717
717
  Peak coordinates.
718
- rotation_space : NDArray
718
+ rotation_space : BackendArray
719
719
  Rotation space array.
720
720
  rotation_mapping : Dict
721
721
  Dictionary mapping values in rotation_space to Euler angles.
722
- rotation_matrix : NDArray
722
+ rotation_matrix : BackendArray
723
723
  Current rotation matrix.
724
724
 
725
725
  Returns
726
726
  -------
727
- NDArray
727
+ BackendArray
728
728
  Rotation matrix.
729
729
  """
730
730
  if rotation_space is None or rotation_mapping is None:
@@ -732,130 +732,117 @@ class PeakCallerRecursiveMasking(PeakCaller):
732
732
 
733
733
  rotation = rotation_mapping[rotation_space[tuple(peak)]]
734
734
 
735
- rotation_matrix = backend.to_backend_array(
736
- euler_to_rotationmatrix(backend.to_numpy_array(rotation))
737
- )
738
- return rotation_matrix
735
+ # TODO: Newer versions of rotation mapping contain rotation matrices not angles
736
+ if len(rotation) == 3:
737
+ rotation = be.to_backend_array(
738
+ euler_to_rotationmatrix(be.to_numpy_array(rotation))
739
+ )
740
+ return rotation
739
741
 
740
742
  @staticmethod
741
743
  def _mask_scores_box(
742
- score_space: NDArray, peak: NDArray, mask: NDArray, **kwargs: Dict
744
+ scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
743
745
  ) -> None:
744
746
  """
745
747
  Mask scores in a box around a peak.
746
748
 
747
749
  Parameters
748
750
  ----------
749
- score_space : NDArray
751
+ scores : BackendArray
750
752
  Data array of scores.
751
- peak : NDArray
753
+ peak : BackendArray
752
754
  Peak coordinates.
753
- mask : NDArray
755
+ mask : BackendArray
754
756
  Mask array.
755
757
  """
756
- start = backend.maximum(backend.subtract(peak, mask.shape), 0)
757
- stop = backend.minimum(backend.add(peak, mask.shape), score_space.shape)
758
- start, stop = backend.astype(start, int), backend.astype(stop, int)
758
+ start = be.maximum(be.subtract(peak, mask.shape), 0)
759
+ stop = be.minimum(be.add(peak, mask.shape), scores.shape)
760
+ start, stop = be.astype(start, int), be.astype(stop, int)
759
761
  coords = tuple(slice(*pos) for pos in zip(start, stop))
760
- score_space[coords] = 0
762
+ scores[coords] = 0
761
763
  return None
762
764
 
763
765
  @staticmethod
764
766
  def _mask_scores_rotate(
765
- score_space: NDArray,
766
- peak: NDArray,
767
- mask: NDArray,
768
- rotated_template: NDArray,
769
- rotation_matrix: NDArray,
767
+ scores: BackendArray,
768
+ peak: BackendArray,
769
+ mask: BackendArray,
770
+ rotated_template: BackendArray,
771
+ rotation_matrix: BackendArray,
770
772
  **kwargs: Dict,
771
773
  ) -> None:
772
774
  """
773
- Mask score_space using mask rotation around a peak.
775
+ Mask scores using mask rotation around a peak.
774
776
 
775
777
  Parameters
776
778
  ----------
777
- score_space : NDArray
779
+ scores : BackendArray
778
780
  Data array of scores.
779
- peak : NDArray
781
+ peak : BackendArray
780
782
  Peak coordinates.
781
- mask : NDArray
783
+ mask : BackendArray
782
784
  Mask array.
783
- rotated_template : NDArray
785
+ rotated_template : BackendArray
784
786
  Empty array to write mask rotations to.
785
- rotation_matrix : NDArray
787
+ rotation_matrix : BackendArray
786
788
  Rotation matrix.
787
789
  """
788
- left_pad = backend.divide(mask.shape, 2).astype(int)
789
- right_pad = backend.add(left_pad, backend.mod(mask.shape, 2).astype(int))
790
+ left_pad = be.divide(mask.shape, 2).astype(int)
791
+ right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
790
792
 
791
- score_start = backend.subtract(peak, left_pad)
792
- score_stop = backend.add(peak, right_pad)
793
+ score_start = be.subtract(peak, left_pad)
794
+ score_stop = be.add(peak, right_pad)
793
795
 
794
- template_start = backend.subtract(backend.maximum(score_start, 0), score_start)
795
- template_stop = backend.subtract(
796
- score_stop, backend.minimum(score_stop, score_space.shape)
797
- )
798
- template_stop = backend.subtract(mask.shape, template_stop)
796
+ template_start = be.subtract(be.maximum(score_start, 0), score_start)
797
+ template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
798
+ template_stop = be.subtract(mask.shape, template_stop)
799
799
 
800
- score_start = backend.maximum(score_start, 0)
801
- score_stop = backend.minimum(score_stop, score_space.shape)
802
- score_start = backend.astype(score_start, int)
803
- score_stop = backend.astype(score_stop, int)
800
+ score_start = be.maximum(score_start, 0)
801
+ score_stop = be.minimum(score_stop, scores.shape)
802
+ score_start = be.astype(score_start, int)
803
+ score_stop = be.astype(score_stop, int)
804
804
 
805
- template_start = backend.astype(template_start, int)
806
- template_stop = backend.astype(template_stop, int)
805
+ template_start = be.astype(template_start, int)
806
+ template_stop = be.astype(template_stop, int)
807
807
  coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
808
808
  coords_template = tuple(
809
809
  slice(*pos) for pos in zip(template_start, template_stop)
810
810
  )
811
811
 
812
812
  rotated_template.fill(0)
813
- backend.rotate_array(
813
+ be.rigid_transform(
814
814
  arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
815
815
  )
816
816
 
817
- score_space[coords_score] = backend.multiply(
818
- score_space[coords_score], (rotated_template[coords_template] <= 0.1)
817
+ scores[coords_score] = be.multiply(
818
+ scores[coords_score], (rotated_template[coords_template] <= 0.1)
819
819
  )
820
820
  return None
821
821
 
822
822
 
823
823
  class PeakCallerScipy(PeakCaller):
824
824
  """
825
- Peak calling using skimage.feature.peak_local_max to compute local maxima.
825
+ Peak calling using :obj:`skimage.feature.peak_local_max` to compute local maxima.
826
826
  """
827
827
 
828
828
  def call_peaks(
829
- self, score_space: NDArray, minimum_score: float = None, **kwargs
830
- ) -> Tuple[NDArray, NDArray]:
831
- """
832
- Call peaks in the score space.
833
-
834
- Parameters
835
- ----------
836
- score_space : NDArray
837
- Data array of scores.
838
- minimum_score : float
839
- Minimum score value to consider. If provided, superseeds limit given
840
- by :py:attr:`PeakCaller.number_of_peaks`.
841
-
842
- Returns
843
- -------
844
- Tuple[NDArray, NDArray]
845
- Array of peak coordinates and peak details.
846
- """
847
-
848
- score_space = backend.to_numpy_array(score_space)
829
+ self, scores: BackendArray, minimum_score: float = None, **kwargs
830
+ ) -> PeakType:
831
+ scores = be.to_numpy_array(scores)
849
832
  num_peaks = self.number_of_peaks
850
833
  if minimum_score is not None:
851
834
  num_peaks = np.inf
852
835
 
836
+ non_squeezable_dims = tuple(i for i, x in enumerate(scores.shape) if x != 1)
853
837
  peaks = peak_local_max(
854
- score_space,
838
+ np.squeeze(scores),
855
839
  num_peaks=num_peaks,
856
840
  min_distance=self.min_distance,
857
841
  threshold_abs=minimum_score,
858
842
  )
843
+ peaks_full = np.zeros((peaks.shape[0], scores.ndim), peaks.dtype)
844
+ peaks_full[..., non_squeezable_dims] = peaks[:]
845
+ peaks = be.to_backend_array(peaks_full)
859
846
  return peaks, None
860
847
 
861
848
 
@@ -880,7 +867,7 @@ class PeakClustering(PeakCallerSort):
880
867
  Parameters
881
868
  ----------
882
869
  **kwargs
883
- Additional keyword arguments passed to :py:meth:`PeakCaller.merge`.
870
+ Optional keyword arguments passed to :py:meth:`PeakCaller.merge`.
884
871
 
885
872
  Returns
886
873
  -------
@@ -912,464 +899,249 @@ class PeakClustering(PeakCallerSort):
912
899
  return peaks, rotations, scores, details
913
900
 
914
901
 
915
- class ScoreStatistics(PeakCallerFast):
916
- """
917
- Compute basic statistics on score spaces with respect to a reference
918
- score or value.
919
-
920
- This class is used to evaluate a blurring or scoring method when the correct fit
921
- is known. It is thread-safe and is designed to be shared among multiple processes
922
- with write permissions to the internal parameters.
923
-
924
- After instantiation, the class's functionality can be accessed through the
925
- `__call__` method.
926
-
927
- Parameters
928
- ----------
929
- reference_position : int, optional
930
- Index of the correct fit in the array passed to call. Defaults to None.
931
- min_distance : float, optional
932
- Minimum distance for local maxima. Defaults to None.
933
- reference_fit : float, optional
934
- Score of the correct fit. If set, `reference_position` will be ignored.
935
- Defaults to None.
936
- number_of_peaks : int, optional
937
- Number of candidate fits to consider. Defaults to 1.
938
- """
939
-
940
- def __init__(
941
- self,
942
- reference_position: Tuple[int] = None,
943
- min_distance: float = 10,
944
- reference_fit: float = None,
945
- number_of_peaks: int = 1,
946
- ):
947
- super().__init__(number_of_peaks=number_of_peaks, min_distance=min_distance)
948
- self.lock = Lock()
949
-
950
- self.n = RawValue("Q", 0)
951
- self.rmean = RawValue("d", 0)
952
- self.ssqd = RawValue("d", 0)
953
- self.nbetter_or_equal = RawValue("Q", 0)
954
- self.maximum_value = RawValue("f", 0)
955
- self.minimum_value = RawValue("f", 2**32)
956
- self.shannon_entropy = Manager().list()
957
- self.candidate_fits = Manager().list()
958
- self.rotation_names = Manager().list()
959
- self.reference_fit = RawValue("f", 0)
960
- self.has_reference = RawValue("i", 0)
961
-
962
- self.reference_position = reference_position
963
- if reference_fit is not None:
964
- self.reference_fit.value = reference_fit
965
- self.has_reference.value = 1
966
-
967
- def __call__(
968
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
969
- ) -> None:
970
- """
971
- Processes the input array and rotation matrix.
972
-
973
- Parameters
974
- ----------
975
- arr : NDArray
976
- Input data array.
977
- rotation_matrix : NDArray
978
- Rotation matrix for processing.
979
- """
980
- self.set_reference(score_space, rotation_matrix)
981
-
982
- while not self.has_reference.value:
983
- print("Stalling processes until reference_fit has been set.")
984
- sleep(0.5)
985
-
986
- name = "_".join([str(value) for value in rotation_matrix.ravel()])
987
- n, rmean, ssqd, nbetter_or_equal, max_value, min_value = online_statistics(
988
- score_space, 0, 0.0, 0.0, self.reference_fit.value
989
- )
990
-
991
- freq, _ = np.histogram(score_space, bins=100)
992
- shannon_entropy = entropy(freq / score_space.size)
993
-
994
- peaks, _ = super().call_peaks(
995
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
996
- )
997
- scores = score_space[tuple(peaks.T)]
998
- rotations = np.repeat(
999
- rotation_matrix.reshape(1, *rotation_matrix.shape),
1000
- peaks.shape[0],
1001
- axis=0,
1002
- )
1003
- distances = np.linalg.norm(peaks - self.reference_position[None, :], axis=1)
1004
-
1005
- self._update(
1006
- peak_positions=peaks,
1007
- rotations=rotations,
1008
- peak_scores=scores,
1009
- peak_details=distances,
1010
- n=n,
1011
- rmean=rmean,
1012
- ssqd=ssqd,
1013
- nbetter_or_equal=nbetter_or_equal,
1014
- max_value=max_value,
1015
- min_value=min_value,
1016
- entropy=shannon_entropy,
1017
- name=name,
1018
- )
1019
-
1020
- def __iter__(self):
1021
- param_store = (
1022
- self.peak_list[0],
1023
- self.peak_list[1],
1024
- self.peak_list[2],
1025
- self.peak_list[3],
1026
- self.n.value,
1027
- self.rmean.value,
1028
- self.ssqd.value,
1029
- self.nbetter_or_equal.value,
1030
- self.maximum_value.value,
1031
- self.minimum_value.value,
1032
- list(self.shannon_entropy),
1033
- list(self.rotation_names),
1034
- self.reference_fit.value,
1035
- )
1036
- yield from param_store
1037
-
1038
- def _update(
1039
- self,
1040
- n: int,
1041
- rmean: float,
1042
- ssqd: float,
1043
- nbetter_or_equal: int,
1044
- max_value: float,
1045
- min_value: float,
1046
- entropy: float,
1047
- name: str,
1048
- **kwargs,
1049
- ) -> None:
1050
- """
1051
- Updates the internal statistics of the analyzer.
1052
-
1053
- Parameters
1054
- ----------
1055
- n : int
1056
- Sample size.
1057
- rmean : float
1058
- Running mean.
1059
- ssqd : float
1060
- Sum of squared differences.
1061
- nbetter_or_equal : int
1062
- Number of values better or equal to reference.
1063
- max_value : float
1064
- Maximum value.
1065
- min_value : float
1066
- Minimum value.
1067
- entropy : float
1068
- Shannon entropy.
1069
- candidates : list
1070
- List of candidate fits.
1071
- name : str
1072
- Name or label for the data.
1073
- kwargs : dict
1074
- Keyword arguments passed to PeakCaller._update.
1075
- """
1076
- with self.lock:
1077
- super()._update(**kwargs)
1078
-
1079
- n_total = self.n.value + n
1080
- delta = rmean - self.rmean.value
1081
- delta2 = delta * delta
1082
- self.rmean.value += delta * n / n_total
1083
- self.ssqd.value += ssqd + delta2 * (n * self.n.value) / n_total
1084
- self.n.value = n_total
1085
- self.nbetter_or_equal.value += nbetter_or_equal
1086
- self.minimum_value.value = min(self.minimum_value.value, min_value)
1087
- self.maximum_value.value = max(self.maximum_value.value, max_value)
1088
- self.shannon_entropy.append(entropy)
1089
- self.rotation_names.append(name)
1090
-
1091
- @classmethod
1092
- def merge(cls, param_stores: List[Tuple]) -> Tuple:
1093
- """
1094
- Merges multiple instances of :py:class`ScoreStatistics`.
1095
-
1096
- Parameters
1097
- ----------
1098
- param_stores : list of tuple
1099
- Internal parameter store. Obtained by running `tuple(instance)`.
1100
- Defaults to a list with two empty tuples.
1101
-
1102
- Returns
1103
- -------
1104
- tuple
1105
- Contains the reference fit, the z-transform of the reference fit,
1106
- number of scores, and various other statistics.
1107
- """
1108
- base = cls(reference_position=np.zeros(3, int))
1109
- for param_store in param_stores:
1110
- base._update(
1111
- peak_positions=param_store[0],
1112
- rotations=param_store[1],
1113
- peak_scores=param_store[2],
1114
- peak_details=param_store[3],
1115
- n=param_store[4],
1116
- rmean=param_store[5],
1117
- ssqd=param_store[6],
1118
- nbetter_or_equal=param_store[7],
1119
- max_value=param_store[8],
1120
- min_value=param_store[9],
1121
- entropy=param_store[10],
1122
- name=param_store[11],
1123
- )
1124
- base.reference_fit.value = param_store[12]
1125
- return tuple(base)
1126
-
1127
- def set_reference(self, score_space: NDArray, rotation_matrix: NDArray) -> None:
1128
- """
1129
- Sets the reference for the analyzer based on the input array
1130
- and rotation matrix.
1131
-
1132
- Parameters
1133
- ----------
1134
- score_space : NDArray
1135
- Input data array.
1136
- rotation_matrix : NDArray
1137
- Rotation matrix for setting reference.
1138
- """
1139
- is_ref = np.allclose(
1140
- rotation_matrix,
1141
- np.eye(rotation_matrix.shape[0], dtype=rotation_matrix.dtype),
1142
- )
1143
- if not is_ref:
1144
- return None
1145
-
1146
- reference_position = self.reference_position
1147
- if reference_position is None:
1148
- reference_position = np.divide(score_space.shape, 2).astype(int)
1149
- self.reference_position = reference_position
1150
- self.reference_fit.value = score_space[tuple(reference_position)]
1151
- self.has_reference.value = 1
1152
-
1153
-
1154
902
  class MaxScoreOverRotations:
1155
903
  """
1156
- Obtain the maximum translation score over various rotations.
904
+ Determine the rotation maximizing the score of all given translations.
1157
905
 
1158
906
  Attributes
1159
907
  ----------
1160
- score_space : NDArray
1161
- The score space for the observed rotations.
1162
- rotations : NDArray
1163
- The rotation identifiers for each score.
1164
- translation_offset : NDArray, optional
1165
- The offset applied during translation.
1166
- observed_rotations : int
1167
- Count of observed rotations.
908
+ shape : tuple of ints.
909
+ Shape of ``scores`` and rotations.
910
+ scores : BackendArray
911
+ Array mapping translations to scores.
912
+ rotations : BackendArray
913
+ Array mapping translations to rotation indices.
914
+ rotation_mapping : Dict
915
+ Mapping of rotation matrix bytestrings to rotation indices.
916
+ offset : BackendArray, optional
917
+ Coordinate origin considered during merging, zero by default
1168
918
  use_memmap : bool, optional
1169
- Whether to offload internal data arrays to disk
919
+ Memmap scores and rotations arrays, False by default.
1170
920
  thread_safe: bool, optional
1171
- Whether access to internal data arrays should be thread safe
921
+ Allow class to be modified by multiple processes, True by default.
922
+ only_unique_rotations : bool, optional
923
+ Whether each rotation will be shown only once, False by default.
924
+
925
+ Raises
926
+ ------
927
+ ValueError
928
+ If the data shape cannot be determined from the parameters.
929
+
930
+ Examples
931
+ --------
932
+ The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
933
+ instance
934
+
935
+ >>> from tme.analyzer import MaxScoreOverRotations
936
+ >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
937
+
938
+ The following simulates a template matching run by creating random data for a range
939
+ of rotations and sending it to ``analyzer`` via its __call__ method
940
+
941
+ >>> for rotation_number in range(10):
942
+ >>> scores = np.random.rand(50,50)
943
+ >>> rotation = np.random.rand(scores.ndim, scores.ndim)
944
+ >>> analyzer(scores = scores, rotation_matrix = rotation)
945
+
946
+ The aggregated scores can be extracted by invoking the __iter__ method of
947
+ ``analyzer``
948
+
949
+ >>> results = tuple(analyzer)
950
+
951
+ The ``results`` tuple contains (1) the maximum scores for each translation,
952
+ (2) an offset which is relevant when merging results from split template matching
953
+ using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
954
+ score for a given translation, (4) a dictionary mapping rotation matrices to the
955
+ indices used in (2).
956
+
957
+ We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
958
+ as follows
959
+
960
+ >>> optimal_score = results[0].max()
961
+ >>> optimal_translation = np.where(results[0] == results[0].max())
962
+ >>> optimal_rotation_index = results[2][optimal_translation]
963
+ >>> for key, value in results[3].items():
964
+ >>> if value != optimal_rotation_index:
965
+ >>> continue
966
+ >>> optimal_rotation = np.frombuffer(key, rotation.dtype)
967
+ >>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
968
+
969
+ The outlined procedure is a trivial method to identify high scoring peaks.
970
+ Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
971
+ that can be used.
1172
972
  """
1173
973
 
1174
974
  def __init__(
1175
975
  self,
1176
- score_space_shape: Tuple[int],
1177
- score_space_dtype: type,
1178
- translation_offset: NDArray = None,
976
+ shape: Tuple[int] = None,
977
+ scores: BackendArray = None,
978
+ rotations: BackendArray = None,
979
+ offset: BackendArray = None,
1179
980
  score_threshold: float = 0,
1180
981
  shared_memory_handler: object = None,
1181
- rotation_space_dtype: type = int,
1182
982
  use_memmap: bool = False,
1183
983
  thread_safe: bool = True,
984
+ only_unique_rotations: bool = False,
1184
985
  **kwargs,
1185
986
  ):
1186
- score_space_shape = tuple(int(x) for x in score_space_shape)
1187
- self.score_space = backend.arr_to_sharedarr(
1188
- backend.full(
1189
- shape=score_space_shape,
1190
- dtype=score_space_dtype,
987
+ if shape is None and scores is None:
988
+ raise ValueError("Either scores_shape or scores need to be specified.")
989
+
990
+ if scores is None:
991
+ shape = tuple(int(x) for x in shape)
992
+ scores = be.full(
993
+ shape=shape,
994
+ dtype=be._float_dtype,
1191
995
  fill_value=score_threshold,
1192
- ),
1193
- shared_memory_handler=shared_memory_handler,
1194
- )
1195
- self.rotations = backend.arr_to_sharedarr(
1196
- backend.full(score_space_shape, dtype=rotation_space_dtype, fill_value=-1),
1197
- shared_memory_handler,
1198
- )
1199
- if translation_offset is None:
1200
- translation_offset = backend.zeros(len(score_space_shape))
996
+ )
997
+ self.scores, self.shape = scores, scores.shape
1201
998
 
1202
- self.translation_offset = backend.astype(translation_offset, int)
1203
- self.score_space_shape = score_space_shape
1204
- self.rotation_space_dtype = rotation_space_dtype
1205
- self.score_space_dtype = score_space_dtype
999
+ if rotations is None:
1000
+ rotations = be.full(shape, dtype=be._int_dtype, fill_value=-1)
1001
+ self.rotations = rotations
1002
+
1003
+ self.scores_dtype = self.scores.dtype
1004
+ self.rotations_dtype = self.rotations.dtype
1005
+ self.scores = be.to_sharedarr(self.scores, shared_memory_handler)
1006
+ self.rotations = be.to_sharedarr(self.rotations, shared_memory_handler)
1007
+
1008
+ if offset is None:
1009
+ offset = be.zeros(len(self.shape), be._int_dtype)
1010
+ self.offset = be.astype(offset, int)
1206
1011
 
1207
1012
  self.use_memmap = use_memmap
1208
1013
  self.lock = Manager().Lock() if thread_safe else nullcontext()
1209
- self.lock_is_nullcontext = isinstance(self.score_space, type(backend.zeros((1))))
1210
- self.observed_rotations = Manager().dict() if thread_safe else {}
1211
-
1014
+ self.lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
1015
+ self.rotation_mapping = Manager().dict() if thread_safe else {}
1016
+ self._inversion_mapping = self.lock_is_nullcontext and only_unique_rotations
1212
1017
 
1213
- def _postprocess(self,
1214
- fourier_shift,
1215
- convolution_mode,
1216
- targetshape,
1217
- templateshape,
1018
+ def _postprocess(
1019
+ self,
1020
+ targetshape: Tuple[int],
1021
+ templateshape: Tuple[int],
1022
+ fourier_shift: Tuple[int] = None,
1023
+ convolution_mode: str = None,
1218
1024
  shared_memory_handler=None,
1219
- **kwargs
1220
- ):
1221
- internal_scores = backend.sharedarr_to_arr(
1222
- shape=self.score_space_shape,
1223
- dtype=self.score_space_dtype,
1224
- shm=self.score_space,
1225
- )
1226
- internal_rotations = backend.sharedarr_to_arr(
1227
- shape=self.score_space_shape,
1228
- dtype=self.rotation_space_dtype,
1229
- shm=self.rotations,
1230
- )
1231
-
1025
+ fast_shape: Tuple[int] = None,
1026
+ **kwargs,
1027
+ ) -> "MaxScoreOverRotations":
1028
+ """
1029
+ Correct padding to Fourier (and if requested convolution) shape.
1030
+ """
1031
+ scores = be.from_sharedarr(self.scores)
1032
+ rotations = be.from_sharedarr(self.rotations)
1232
1033
  if fourier_shift is not None:
1233
1034
  axis = tuple(i for i in range(len(fourier_shift)))
1234
- internal_scores = backend.roll(
1235
- internal_scores,
1236
- shift=fourier_shift,
1237
- axis=axis
1238
- )
1239
- internal_rotations = backend.roll(
1240
- internal_rotations,
1241
- shift=fourier_shift,
1242
- axis=axis
1243
- )
1035
+ scores = be.roll(scores, shift=fourier_shift, axis=axis)
1036
+ rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
1244
1037
 
1038
+ convargs = {
1039
+ "s1": targetshape,
1040
+ "s2": templateshape,
1041
+ "convolution_mode": convolution_mode,
1042
+ }
1245
1043
  if convolution_mode is not None:
1246
- internal_scores = apply_convolution_mode(
1247
- internal_scores,
1248
- convolution_mode=convolution_mode,
1249
- s1=targetshape,
1250
- s2=templateshape
1251
- )
1252
- internal_rotations = apply_convolution_mode(
1253
- internal_rotations,
1254
- convolution_mode=convolution_mode,
1255
- s1=targetshape,
1256
- s2=templateshape
1257
- )
1044
+ scores = apply_convolution_mode(scores, **convargs)
1045
+ rotations = apply_convolution_mode(rotations, **convargs)
1258
1046
 
1259
- self.score_space_shape = internal_scores.shape
1260
- self.score_space = backend.arr_to_sharedarr(
1261
- internal_scores,
1262
- shared_memory_handler
1263
- )
1264
- self.rotations = backend.arr_to_sharedarr(
1265
- internal_rotations,
1266
- shared_memory_handler
1267
- )
1047
+ self.shape = scores.shape
1048
+ self.scores = be.to_sharedarr(scores, shared_memory_handler)
1049
+ self.rotations = be.to_sharedarr(rotations, shared_memory_handler)
1268
1050
  return self
1269
1051
 
1052
+ def __iter__(self) -> Generator:
1053
+ scores = be.from_sharedarr(self.scores)
1054
+ rotations = be.from_sharedarr(self.rotations)
1270
1055
 
1271
- def __iter__(self):
1272
- internal_scores = backend.sharedarr_to_arr(
1273
- shape=self.score_space_shape,
1274
- dtype=self.score_space_dtype,
1275
- shm=self.score_space,
1276
- )
1277
- internal_rotations = backend.sharedarr_to_arr(
1278
- shape=self.score_space_shape,
1279
- dtype=self.rotation_space_dtype,
1280
- shm=self.rotations,
1281
- )
1282
-
1283
- internal_scores = backend.to_numpy_array(internal_scores)
1284
- internal_rotations = backend.to_numpy_array(internal_rotations)
1056
+ scores = be.to_numpy_array(scores)
1057
+ rotations = be.to_numpy_array(rotations)
1285
1058
  if self.use_memmap:
1286
- internal_scores_filename = array_to_memmap(internal_scores)
1287
- internal_rotations_filename = array_to_memmap(internal_rotations)
1288
- internal_scores = np.memmap(
1289
- internal_scores_filename,
1059
+ scores = np.memmap(
1060
+ array_to_memmap(scores),
1290
1061
  mode="r",
1291
- dtype=internal_scores.dtype,
1292
- shape=internal_scores.shape,
1062
+ dtype=scores.dtype,
1063
+ shape=scores.shape,
1293
1064
  )
1294
- internal_rotations = np.memmap(
1295
- internal_rotations_filename,
1065
+ rotations = np.memmap(
1066
+ array_to_memmap(rotations),
1296
1067
  mode="r",
1297
- dtype=internal_rotations.dtype,
1298
- shape=internal_rotations.shape,
1068
+ dtype=rotations.dtype,
1069
+ shape=rotations.shape,
1299
1070
  )
1300
1071
  else:
1301
- # Avoid invalidation by shared memory handler with copy
1302
- internal_scores = internal_scores.copy()
1303
- internal_rotations = internal_rotations.copy()
1072
+ # Copy to avoid invalidation by shared memory handler
1073
+ scores, rotations = scores.copy(), rotations.copy()
1074
+
1075
+ if self._inversion_mapping:
1076
+ self.rotation_mapping = {
1077
+ be.tobytes(v): k for k, v in self.rotation_mapping.items()
1078
+ }
1304
1079
 
1305
1080
  param_store = (
1306
- internal_scores,
1307
- backend.to_numpy_array(self.translation_offset),
1308
- internal_rotations,
1309
- dict(self.observed_rotations),
1081
+ scores,
1082
+ be.to_numpy_array(self.offset),
1083
+ rotations,
1084
+ dict(self.rotation_mapping),
1310
1085
  )
1311
1086
  yield from param_store
1312
1087
 
1313
- def __call__(
1314
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1315
- ) -> None:
1088
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
1316
1089
  """
1317
- Update internal parameter store based on `score_space`.
1090
+ Update internal parameter store based on `scores`.
1318
1091
 
1319
1092
  Parameters
1320
1093
  ----------
1321
- score_space : ndarray
1322
- Numpy array containing the score space.
1323
- rotation_matrix : ndarray
1094
+ scores : BackendArray
1095
+ Array containing the score space.
1096
+ rotation_matrix : BackendArray
1324
1097
  Square matrix describing the current rotation.
1325
- **kwargs
1326
- Arbitrary keyword arguments.
1327
1098
  """
1328
- rotation = backend.tobytes(rotation_matrix)
1329
- rotation_index = self.observed_rotations.setdefault(
1330
- rotation, len(self.observed_rotations)
1331
- )
1332
-
1099
+ # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
1100
+ # If the analyzer is not shared and each rotation is unique, we can
1101
+ # use index to rotation mapping and invert prior to merging.
1333
1102
  if self.lock_is_nullcontext:
1334
- backend.max_score_over_rotations(
1335
- score_space=score_space,
1336
- internal_scores=self.score_space,
1337
- internal_rotations=self.rotations,
1103
+ rotation_index = len(self.rotation_mapping)
1104
+ if self._inversion_mapping:
1105
+ self.rotation_mapping[rotation_index] = rotation_matrix
1106
+ else:
1107
+ rotation = be.tobytes(rotation_matrix)
1108
+ rotation_index = self.rotation_mapping.setdefault(
1109
+ rotation, rotation_index
1110
+ )
1111
+ self.scores, self.rotations = be.max_score_over_rotations(
1112
+ scores=scores,
1113
+ max_scores=self.scores,
1114
+ rotations=self.rotations,
1338
1115
  rotation_index=rotation_index,
1339
1116
  )
1340
1117
  return None
1341
1118
 
1119
+ rotation = be.tobytes(rotation_matrix)
1342
1120
  with self.lock:
1343
- internal_scores = backend.sharedarr_to_arr(
1344
- shape=self.score_space_shape,
1345
- dtype=self.score_space_dtype,
1346
- shm=self.score_space,
1121
+ rotation_index = self.rotation_mapping.setdefault(
1122
+ rotation, len(self.rotation_mapping)
1347
1123
  )
1348
- internal_rotations = backend.sharedarr_to_arr(
1349
- shape=self.score_space_shape,
1350
- dtype=self.rotation_space_dtype,
1351
- shm=self.rotations,
1352
- )
1353
-
1354
- backend.max_score_over_rotations(
1355
- score_space=score_space,
1356
- internal_scores=internal_scores,
1357
- internal_rotations=internal_rotations,
1124
+ internal_scores = be.from_sharedarr(self.scores)
1125
+ internal_rotations = be.from_sharedarr(self.rotations)
1126
+ internal_sores, internal_rotations = be.max_score_over_rotations(
1127
+ scores=scores,
1128
+ max_scores=internal_scores,
1129
+ rotations=internal_rotations,
1358
1130
  rotation_index=rotation_index,
1359
1131
  )
1360
1132
  return None
1361
1133
 
1362
1134
  @classmethod
1363
- def merge(cls, param_stores=List[Tuple], **kwargs) -> Tuple[NDArray]:
1135
+ def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple[NDArray]:
1364
1136
  """
1365
1137
  Merges multiple instances of :py:class:`MaxScoreOverRotations`.
1366
1138
 
1367
1139
  Parameters
1368
1140
  ----------
1369
- param_stores : list of tuples, optional
1141
+ param_stores : list of tuples
1370
1142
  Internal parameter store. Obtained by running `tuple(instance)`.
1371
1143
  **kwargs
1372
- Arbitrary keyword arguments.
1144
+ Optional keyword arguments.
1373
1145
 
1374
1146
  Returns
1375
1147
  -------
@@ -1381,51 +1153,51 @@ class MaxScoreOverRotations:
1381
1153
  if len(param_stores) == 1:
1382
1154
  return param_stores[0]
1383
1155
 
1384
- new_rotation_mapping, base_max = {}, None
1385
- scores_out_dtype, rotations_out_dtype = None, None
1156
+ # Determine output array shape and create consistent rotation map
1157
+ new_rotation_mapping, out_shape = {}, None
1386
1158
  for i in range(len(param_stores)):
1387
1159
  if param_stores[i] is None:
1388
1160
  continue
1389
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1390
- if base_max is None:
1391
- base_max = np.zeros(score_space.ndim, int)
1392
- scores_out_dtype = score_space.dtype
1393
- rotations_out_dtype = rotations.dtype
1394
- np.maximum(base_max, np.add(offset, score_space.shape), out=base_max)
1161
+
1162
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1163
+ if out_shape is None:
1164
+ out_shape = np.zeros(scores.ndim, int)
1165
+ scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
1166
+ out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
1395
1167
 
1396
1168
  for key, value in rotation_mapping.items():
1397
1169
  if key not in new_rotation_mapping:
1398
1170
  new_rotation_mapping[key] = len(new_rotation_mapping)
1399
1171
 
1400
- if base_max is None:
1172
+ if out_shape is None:
1401
1173
  return None
1402
1174
 
1403
- base_max = tuple(int(x) for x in base_max)
1175
+ out_shape = tuple(int(x) for x in out_shape)
1404
1176
  use_memmap = kwargs.get("use_memmap", False)
1405
1177
  if use_memmap:
1406
1178
  scores_out_filename = generate_tempfile_name()
1407
1179
  rotations_out_filename = generate_tempfile_name()
1408
1180
 
1409
1181
  scores_out = np.memmap(
1410
- scores_out_filename, mode="w+", shape=base_max, dtype=scores_out_dtype
1182
+ scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
1411
1183
  )
1412
1184
  scores_out.fill(kwargs.get("score_threshold", 0))
1413
1185
  scores_out.flush()
1414
1186
  rotations_out = np.memmap(
1415
1187
  rotations_out_filename,
1416
1188
  mode="w+",
1417
- shape=base_max,
1418
- dtype=rotations_out_dtype,
1189
+ shape=out_shape,
1190
+ dtype=rotations_dtype,
1419
1191
  )
1420
1192
  rotations_out.fill(-1)
1421
1193
  rotations_out.flush()
1422
1194
  else:
1423
1195
  scores_out = np.full(
1424
- base_max,
1196
+ out_shape,
1425
1197
  fill_value=kwargs.get("score_threshold", 0),
1426
- dtype=scores_out_dtype,
1198
+ dtype=scores_dtype,
1427
1199
  )
1428
- rotations_out = np.full(base_max, fill_value=-1, dtype=rotations_out_dtype)
1200
+ rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
1429
1201
 
1430
1202
  for i in range(len(param_stores)):
1431
1203
  if param_stores[i] is None:
@@ -1435,21 +1207,21 @@ class MaxScoreOverRotations:
1435
1207
  scores_out = np.memmap(
1436
1208
  scores_out_filename,
1437
1209
  mode="r+",
1438
- shape=base_max,
1439
- dtype=scores_out_dtype,
1210
+ shape=out_shape,
1211
+ dtype=scores_dtype,
1440
1212
  )
1441
1213
  rotations_out = np.memmap(
1442
1214
  rotations_out_filename,
1443
1215
  mode="r+",
1444
- shape=base_max,
1445
- dtype=rotations_out_dtype,
1216
+ shape=out_shape,
1217
+ dtype=rotations_dtype,
1446
1218
  )
1447
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1448
- stops = np.add(offset, score_space.shape).astype(int)
1219
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1220
+ stops = np.add(offset, scores.shape).astype(int)
1449
1221
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
1450
1222
 
1451
- indices_update = score_space > scores_out[indices]
1452
- scores_out[indices][indices_update] = score_space[indices_update]
1223
+ indices_update = scores > scores_out[indices]
1224
+ scores_out[indices][indices_update] = scores[indices_update]
1453
1225
 
1454
1226
  lookup_table = np.arange(
1455
1227
  len(rotation_mapping) + 1, dtype=rotations_out.dtype
@@ -1462,24 +1234,24 @@ class MaxScoreOverRotations:
1462
1234
  rotations_out[indices][indices_update] = lookup_table[updated_rotations]
1463
1235
 
1464
1236
  if use_memmap:
1465
- score_space._mmap.close()
1237
+ scores._mmap.close()
1466
1238
  rotations._mmap.close()
1467
1239
  scores_out.flush()
1468
1240
  rotations_out.flush()
1469
1241
  scores_out, rotations_out = None, None
1470
1242
 
1471
1243
  param_stores[i] = None
1472
- score_space, rotations = None, None
1244
+ scores, rotations = None, None
1473
1245
 
1474
1246
  if use_memmap:
1475
1247
  scores_out = np.memmap(
1476
- scores_out_filename, mode="r", shape=base_max, dtype=scores_out_dtype
1248
+ scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
1477
1249
  )
1478
1250
  rotations_out = np.memmap(
1479
1251
  rotations_out_filename,
1480
1252
  mode="r",
1481
- shape=base_max,
1482
- dtype=rotations_out_dtype,
1253
+ shape=out_shape,
1254
+ dtype=rotations_dtype,
1483
1255
  )
1484
1256
  return (
1485
1257
  scores_out,
@@ -1488,6 +1260,10 @@ class MaxScoreOverRotations:
1488
1260
  new_rotation_mapping,
1489
1261
  )
1490
1262
 
1263
+ @property
1264
+ def shared(self):
1265
+ return True
1266
+
1491
1267
 
1492
1268
  class _MaxScoreOverTranslations(MaxScoreOverRotations):
1493
1269
  """
@@ -1495,11 +1271,11 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1495
1271
 
1496
1272
  Attributes
1497
1273
  ----------
1498
- score_space : NDArray
1274
+ scores : BackendArray
1499
1275
  The score space for the observed rotations.
1500
- rotations : NDArray
1276
+ rotations : BackendArray
1501
1277
  The rotation identifiers for each score.
1502
- translation_offset : NDArray, optional
1278
+ translation_offset : BackendArray, optional
1503
1279
  The offset applied during translation.
1504
1280
  observed_rotations : int
1505
1281
  Count of observed rotations.
@@ -1510,36 +1286,36 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1510
1286
  """
1511
1287
 
1512
1288
  def __call__(
1513
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1289
+ self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs
1514
1290
  ) -> None:
1515
1291
  """
1516
- Update internal parameter store based on `score_space`.
1292
+ Update internal parameter store based on `scores`.
1517
1293
 
1518
1294
  Parameters
1519
1295
  ----------
1520
- score_space : ndarray
1296
+ scores : BackendArray
1521
1297
  Numpy array containing the score space.
1522
- rotation_matrix : ndarray
1298
+ rotation_matrix : BackendArray
1523
1299
  Square matrix describing the current rotation.
1524
1300
  **kwargs
1525
- Arbitrary keyword arguments.
1301
+ Optional keyword arguments.
1526
1302
  """
1527
1303
  from tme.matching_utils import centered_mask
1528
1304
 
1529
1305
  with self.lock:
1530
- rotation = backend.tobytes(rotation_matrix)
1306
+ rotation = be.tobytes(rotation_matrix)
1531
1307
  if rotation not in self.observed_rotations:
1532
1308
  self.observed_rotations[rotation] = len(self.observed_rotations)
1533
- score_space = centered_mask(score_space, kwargs["template_shape"])
1309
+ scores = centered_mask(scores, kwargs["template_shape"])
1534
1310
  rotation_index = self.observed_rotations[rotation]
1535
- internal_scores = backend.sharedarr_to_arr(
1536
- shape=self.score_space_shape,
1537
- dtype=self.score_space_dtype,
1538
- shm=self.score_space,
1311
+ internal_scores = be.from_sharedarr(
1312
+ shape=self.shape,
1313
+ dtype=self.scores_dtype,
1314
+ shm=self.scores,
1539
1315
  )
1540
- max_score = score_space.max(axis=(1, 2, 3))
1541
- mean_score = score_space.mean(axis=(1, 2, 3))
1542
- std_score = score_space.std(axis=(1, 2, 3))
1316
+ max_score = scores.max(axis=(1, 2, 3))
1317
+ mean_score = scores.mean(axis=(1, 2, 3))
1318
+ std_score = scores.std(axis=(1, 2, 3))
1543
1319
  z_score = (max_score - mean_score) / std_score
1544
1320
  internal_scores[rotation_index] = z_score
1545
1321
 
@@ -1563,7 +1339,7 @@ class MemmapHandler:
1563
1339
  indices : tuple of slice, optional
1564
1340
  Slices specifying which parts of the memmap array will be updated by `__call__`.
1565
1341
  **kwargs
1566
- Arbitrary keyword arguments.
1342
+ Optional keyword arguments.
1567
1343
  """
1568
1344
 
1569
1345
  def __init__(
@@ -1586,15 +1362,13 @@ class MemmapHandler:
1586
1362
  self.dtype = dtype
1587
1363
  self._indices = indices
1588
1364
 
1589
- def __call__(
1590
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1591
- ) -> None:
1365
+ def __call__(self, scores: NDArray, rotation_matrix: NDArray) -> None:
1592
1366
  """
1593
- Write `score_space` to memmap object on disk.
1367
+ Write `scores` to memmap object on disk.
1594
1368
 
1595
1369
  Parameters
1596
1370
  ----------
1597
- score_space : ndarray
1371
+ scores : ndarray
1598
1372
  Numpy array containing the score space.
1599
1373
  rotation_matrix : ndarray
1600
1374
  Square matrix describing the current rotation.
@@ -1606,7 +1380,7 @@ class MemmapHandler:
1606
1380
  array = np.memmap(current_object, mode="r+", shape=self.shape, dtype=self.dtype)
1607
1381
  # Does not really need a lock because processes operate on different rotations
1608
1382
  with self.lock:
1609
- array[self._indices] += score_space
1383
+ array[self._indices] += scores
1610
1384
  array.flush()
1611
1385
 
1612
1386
  def __iter__(self):
@@ -1647,5 +1421,3 @@ class MemmapHandler:
1647
1421
  """
1648
1422
  rotation_string = "_".join(rotation_matrix.ravel().astype(str))
1649
1423
  return self._path_translation[rotation_string]
1650
-
1651
-