pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/analyzer.py CHANGED
@@ -4,59 +4,58 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- from time import sleep
8
- from typing import Tuple, List, Dict
9
- from abc import ABC, abstractmethod
10
7
  from contextlib import nullcontext
11
- from multiprocessing import RawValue, Manager, Lock
8
+ from abc import ABC, abstractmethod
9
+ from multiprocessing import Manager, Lock
10
+ from typing import Tuple, List, Dict, Generator
12
11
 
13
12
  import numpy as np
14
- from numpy.typing import NDArray
15
- from scipy.stats import entropy
16
13
  from sklearn.cluster import DBSCAN
17
14
  from skimage.feature import peak_local_max
18
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
19
- from .extensions import max_index_by_label, online_statistics, find_candidate_indices
16
+
17
+ from .backends import backend as be
18
+ from .types import BackendArray, NDArray
19
+ from .extensions import max_index_by_label, find_candidate_indices
20
20
  from .matching_utils import (
21
- split_numpy_array_slices,
21
+ split_shape,
22
22
  array_to_memmap,
23
23
  generate_tempfile_name,
24
24
  euler_to_rotationmatrix,
25
25
  apply_convolution_mode,
26
26
  )
27
- from .backends import backend
27
+
28
+ PeakType = Tuple[BackendArray, BackendArray]
28
29
 
29
30
 
30
- def filter_points_indices_bucket(
31
- coordinates: NDArray, min_distance: Tuple[int]
32
- ) -> NDArray:
33
- coordinates = backend.subtract(coordinates, backend.min(coordinates, axis=0))
34
- bucket_indices = backend.astype(backend.divide(coordinates, min_distance), int)
35
- multiplier = backend.power(
36
- backend.max(bucket_indices, axis=0) + 1, backend.arange(bucket_indices.shape[1])
31
+ def _filter_bucket(coordinates: BackendArray, min_distance: Tuple[int]) -> BackendArray:
32
+ coordinates = be.subtract(coordinates, be.min(coordinates, axis=0))
33
+ bucket_indices = be.astype(be.divide(coordinates, min_distance), int)
34
+ multiplier = be.power(
35
+ be.max(bucket_indices, axis=0) + 1, be.arange(bucket_indices.shape[1])
37
36
  )
38
- backend.multiply(bucket_indices, multiplier, out=bucket_indices)
39
- flattened_indices = backend.sum(bucket_indices, axis=1)
40
- _, unique_indices = backend.unique(flattened_indices, return_index=True)
41
- unique_indices = unique_indices[backend.argsort(unique_indices)]
37
+ be.multiply(bucket_indices, multiplier, out=bucket_indices)
38
+ flattened_indices = be.sum(bucket_indices, axis=1)
39
+ _, unique_indices = be.unique(flattened_indices, return_index=True)
40
+ unique_indices = unique_indices[be.argsort(unique_indices)]
42
41
  return unique_indices
43
42
 
44
43
 
45
44
  def filter_points_indices(
46
- coordinates: NDArray,
45
+ coordinates: BackendArray,
47
46
  min_distance: float,
48
47
  bucket_cutoff: int = 1e4,
49
48
  batch_dims: Tuple[int] = None,
50
- ) -> NDArray:
49
+ ) -> BackendArray:
51
50
  if min_distance <= 0:
52
- return backend.arange(coordinates.shape[0])
51
+ return be.arange(coordinates.shape[0])
53
52
  if coordinates.shape[0] == 0:
54
53
  return ()
55
54
 
56
55
  if batch_dims is not None:
57
- coordinates_new = backend.zeros(coordinates.shape, coordinates.dtype)
56
+ coordinates_new = be.zeros(coordinates.shape, coordinates.dtype)
58
57
  coordinates_new[:] = coordinates
59
- coordinates_new[..., batch_dims] = backend.astype(
58
+ coordinates_new[..., batch_dims] = be.astype(
60
59
  coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
61
60
  )
62
61
  coordinates = coordinates_new
@@ -66,7 +65,7 @@ def filter_points_indices(
66
65
  elif coordinates.shape[0] > bucket_cutoff or not isinstance(
67
66
  coordinates, np.ndarray
68
67
  ):
69
- return filter_points_indices_bucket(coordinates, min_distance)
68
+ return _filter_bucket(coordinates, min_distance)
70
69
  distances = np.linalg.norm(coordinates[:, None] - coordinates, axis=-1)
71
70
  distances = np.tril(distances)
72
71
  keep = np.sum(distances > min_distance, axis=1)
@@ -76,7 +75,7 @@ def filter_points_indices(
76
75
 
77
76
  def filter_points(
78
77
  coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
79
- ) -> NDArray:
78
+ ) -> BackendArray:
80
79
  unique_indices = filter_points_indices(coordinates, min_distance, batch_dims)
81
80
  coordinates = coordinates[unique_indices]
82
81
  return coordinates
@@ -96,8 +95,13 @@ class PeakCaller(ABC):
96
95
  Minimum distance to array boundaries.
97
96
  batch_dims : int, optional
98
97
  Peak calling batch dimensions.
98
+ minimum_score : float
99
+ Minimum score from which to consider peaks. If provided, superseeds limits
100
+ presented by :py:attr:`PeakCaller.number_of_peaks`.
101
+ maximum_score : float
102
+ Maximum score upon which to consider peaks,
99
103
  **kwargs
100
- Additional keyword arguments.
104
+ Optional keyword arguments.
101
105
 
102
106
  Raises
103
107
  ------
@@ -112,12 +116,10 @@ class PeakCaller(ABC):
112
116
  min_distance: int = 1,
113
117
  min_boundary_distance: int = 0,
114
118
  batch_dims: Tuple[int] = None,
119
+ minimum_score: float = None,
120
+ maximum_score: float = None,
115
121
  **kwargs,
116
122
  ):
117
- number_of_peaks = int(number_of_peaks)
118
- min_distance, min_boundary_distance = int(min_distance), int(
119
- min_boundary_distance
120
- )
121
123
  if number_of_peaks <= 0:
122
124
  raise ValueError(
123
125
  f"number_of_peaks has to be larger than 0, got {number_of_peaks}"
@@ -130,26 +132,28 @@ class PeakCaller(ABC):
130
132
  )
131
133
 
132
134
  self.peak_list = []
133
- self.min_distance = min_distance
134
- self.min_boundary_distance = min_boundary_distance
135
- self.number_of_peaks = number_of_peaks
135
+ self.min_distance = int(min_distance)
136
+ self.number_of_peaks = int(number_of_peaks)
137
+ self.min_boundary_distance = int(min_boundary_distance)
136
138
 
137
139
  self.batch_dims = batch_dims
138
140
  if batch_dims is not None:
139
141
  self.batch_dims = tuple(int(x) for x in self.batch_dims)
140
142
 
143
+ self.minimum_score, self.maximum_score = minimum_score, maximum_score
144
+
141
145
  # Postprocesing arguments
142
146
  self.fourier_shift = kwargs.get("fourier_shift", None)
143
147
  self.convolution_mode = kwargs.get("convolution_mode", None)
144
148
  self.targetshape = kwargs.get("targetshape", None)
145
149
  self.templateshape = kwargs.get("templateshape", None)
146
150
 
147
- def __iter__(self):
151
+ def __iter__(self) -> Generator:
148
152
  """
149
153
  Returns a generator to list objects containing translation,
150
154
  rotation, score and details of a given candidate.
151
155
  """
152
- self.peak_list = [backend.to_cpu_array(arr) for arr in self.peak_list]
156
+ self.peak_list = [be.to_cpu_array(arr) for arr in self.peak_list]
153
157
  yield from self.peak_list
154
158
 
155
159
  @staticmethod
@@ -181,34 +185,24 @@ class PeakCaller(ABC):
181
185
 
182
186
  yield from _generate_slices_recursive(0, ())
183
187
 
184
- def __call__(
185
- self,
186
- score_space: NDArray,
187
- rotation_matrix: NDArray,
188
- minimum_score: float = None,
189
- maximum_score: float = None,
190
- **kwargs,
191
- ) -> None:
188
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
192
189
  """
193
190
  Update the internal parameter store based on input array.
194
191
 
195
192
  Parameters
196
193
  ----------
197
- score_space : NDArray
198
- Array containing the score space.
199
- rotation_matrix : NDArray
194
+ scores : BackendArray
195
+ Score space data.
196
+ rotation_matrix : BackendArray
200
197
  Rotation matrix used to obtain the score array.
201
- minimum_score : float
202
- Minimum score from which to consider peaks. If provided, superseeds limits
203
- presented by :py:attr:`PeakCaller.number_of_peaks`.
204
- maximum_score : float
205
- Maximum score upon which to consider peaks,
206
198
  **kwargs
207
- Optional keyword arguments passed to :py:meth:`PeakCaller.call_peak`.
199
+ Optional keyword aguments passed to :py:meth:`PeakCaller.call_peaks`.
208
200
  """
209
- for subset, offset in self._batchify(score_space.shape, self.batch_dims):
201
+ minimum_score, maximum_score = self.minimum_score, self.maximum_score
202
+ for subset, batch_offset in self._batchify(scores.shape, self.batch_dims):
203
+ batch_offset = be.to_backend_array(batch_offset)
210
204
  peak_positions, peak_details = self.call_peaks(
211
- score_space=score_space[subset],
205
+ scores=scores[subset],
212
206
  rotation_matrix=rotation_matrix,
213
207
  minimum_score=minimum_score,
214
208
  maximum_score=maximum_score,
@@ -221,34 +215,30 @@ class PeakCaller(ABC):
221
215
  continue
222
216
 
223
217
  if peak_details is None:
224
- peak_details = backend.full((peak_positions.shape[0],), fill_value=-1)
218
+ peak_details = be.full((peak_positions.shape[0],), fill_value=-1)
225
219
 
226
- backend.add(peak_positions, offset, out=peak_positions)
227
- peak_positions = backend.astype(peak_positions, int)
220
+ peak_positions = be.to_backend_array(peak_positions)
221
+ peak_positions = be.add(peak_positions, batch_offset, out=peak_positions)
222
+ peak_positions = be.astype(peak_positions, int)
228
223
  if self.min_boundary_distance > 0:
229
- upper_limit = backend.subtract(
230
- score_space.shape, self.min_boundary_distance
224
+ upper_limit = be.subtract(
225
+ be.to_backend_array(scores.shape), self.min_boundary_distance
231
226
  )
232
- valid_peaks = backend.multiply(
227
+ valid_peaks = be.multiply(
233
228
  peak_positions < upper_limit,
234
229
  peak_positions >= self.min_boundary_distance,
235
230
  )
236
231
  if self.batch_dims is not None:
237
232
  valid_peaks[..., self.batch_dims] = True
238
233
 
239
- valid_peaks = (
240
- backend.sum(valid_peaks, axis=1) == peak_positions.shape[1]
241
- )
234
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
242
235
 
243
- if backend.sum(valid_peaks) == 0:
236
+ if be.sum(valid_peaks) == 0:
244
237
  continue
238
+ peak_positions = peak_positions[valid_peaks]
239
+ peak_details = peak_details[valid_peaks]
245
240
 
246
- peak_positions, peak_details = (
247
- peak_positions[valid_peaks],
248
- peak_details[valid_peaks],
249
- )
250
-
251
- peak_scores = score_space[tuple(peak_positions.T)]
241
+ peak_scores = scores[tuple(peak_positions.T)]
252
242
  if minimum_score is not None:
253
243
  valid_peaks = peak_scores >= minimum_score
254
244
  peak_positions, peak_details, peak_scores = (
@@ -267,7 +257,7 @@ class PeakCaller(ABC):
267
257
  if peak_positions.shape[0] == 0:
268
258
  continue
269
259
 
270
- rotations = backend.repeat(
260
+ rotations = be.repeat(
271
261
  rotation_matrix.reshape(1, *rotation_matrix.shape),
272
262
  peak_positions.shape[0],
273
263
  axis=0,
@@ -278,83 +268,71 @@ class PeakCaller(ABC):
278
268
  peak_details=peak_details,
279
269
  peak_scores=peak_scores,
280
270
  rotations=rotations,
281
- batch_offset=offset,
282
- **kwargs,
283
271
  )
284
272
 
285
273
  return None
286
274
 
287
275
  @abstractmethod
288
- def call_peaks(
289
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
290
- ) -> Tuple[NDArray, NDArray]:
276
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
291
277
  """
292
278
  Call peaks in the score space.
293
279
 
294
- This function is not intended to be called directly, but should rather be
295
- defined by classes inheriting from :py:class:`PeakCaller` to execute a given
296
- peak calling algorithm.
297
-
298
280
  Parameters
299
281
  ----------
300
- score_space : NDArray
301
- Data array of scores.
302
- **kwargs : Dict, optional
303
- Keyword arguments passed to __call__.
282
+ scores : BackendArray
283
+ Score array.
284
+ **kwargs : dict
285
+ Optional keyword arguments passed to underlying implementations.
304
286
 
305
287
  Returns
306
288
  -------
307
- Tuple[NDArray, NDArray]
289
+ Tuple[BackendArray, BackendArray]
308
290
  Array of peak coordinates and peak details.
309
291
  """
310
292
 
311
293
  @classmethod
312
- def merge(cls, candidates=List[List], **kwargs) -> NDArray:
294
+ def merge(cls, candidates=List[List], **kwargs) -> Tuple:
313
295
  """
314
296
  Merge multiple instances of :py:class:`PeakCaller`.
315
297
 
316
298
  Parameters
317
299
  ----------
318
- candidate_fits : list of lists
300
+ candidates : list of lists
319
301
  Obtained by invoking list on the generator returned by __iter__.
320
- param_stores : list of tuples, optional
321
- List of parameter stores. Each tuple contains candidate data and number
322
- of candidates.
323
302
  **kwargs
324
- Additional keyword arguments.
303
+ Optional keyword arguments.
325
304
 
326
305
  Returns
327
306
  -------
328
- NDArray
329
- NDArray of candidates.
307
+ Tuple
308
+ Tuple of translation, rotation, score and details of candidates.
330
309
  """
331
310
  base = cls(**kwargs)
332
311
  for candidate in candidates:
333
312
  if len(candidate) == 0:
334
313
  continue
335
314
  peak_positions, rotations, peak_scores, peak_details = candidate
336
- kwargs["translation_offset"] = backend.zeros(peak_positions.shape[1])
337
315
  base._update(
338
- peak_positions=backend.to_backend_array(peak_positions),
339
- peak_details=backend.to_backend_array(peak_details),
340
- peak_scores=backend.to_backend_array(peak_scores),
341
- rotations=backend.to_backend_array(rotations),
342
- **kwargs,
316
+ peak_positions=be.to_backend_array(peak_positions),
317
+ peak_details=be.to_backend_array(peak_details),
318
+ peak_scores=be.to_backend_array(peak_scores),
319
+ rotations=be.to_backend_array(rotations),
320
+ offset=kwargs.get("offset", None),
343
321
  )
344
322
  return tuple(base)
345
323
 
346
324
  @staticmethod
347
325
  def oversample_peaks(
348
- score_space: NDArray, peak_positions: NDArray, oversampling_factor: int = 8
326
+ scores: BackendArray, peak_positions: BackendArray, oversampling_factor: int = 8
349
327
  ):
350
328
  """
351
329
  Refines peaks positions in the corresponding score space.
352
330
 
353
331
  Parameters
354
332
  ----------
355
- score_space : NDArray
333
+ scores : BackendArray
356
334
  The d-dimensional array representing the score space.
357
- peak_positions : NDArray
335
+ peak_positions : BackendArray
358
336
  An array of shape (n, d) containing the peak coordinates
359
337
  to be refined, where n is the number of peaks and d is the
360
338
  dimensionality of the score space.
@@ -363,14 +341,14 @@ class PeakCaller(ABC):
363
341
 
364
342
  Returns
365
343
  -------
366
- NDArray
344
+ BackendArray
367
345
  An array of shape (n, d) containing the refined subpixel
368
346
  coordinates of the peaks.
369
347
 
370
348
  Notes
371
349
  -----
372
350
  Floating point peak positions are determined by oversampling the
373
- score_space around peak_positions. The accuracy
351
+ scores around peak_positions. The accuracy
374
352
  of refinement scales with 1 / oversampling_factor.
375
353
 
376
354
  References
@@ -382,8 +360,8 @@ class PeakCaller(ABC):
382
360
  DOI:10.1364/OL.33.000156
383
361
 
384
362
  """
385
- score_space = backend.to_numpy_array(score_space)
386
- peak_positions = backend.to_numpy_array(peak_positions)
363
+ scores = be.to_numpy_array(scores)
364
+ peak_positions = be.to_numpy_array(peak_positions)
387
365
 
388
366
  peak_positions = np.round(
389
367
  np.divide(
@@ -396,10 +374,10 @@ class PeakCaller(ABC):
396
374
  dftshift, np.multiply(peak_positions, oversampling_factor)
397
375
  )
398
376
 
399
- score_space_ft = np.fft.fftn(score_space).conj()
377
+ scores_ft = np.fft.fftn(scores).conj()
400
378
  for index in range(sample_region_offset.shape[0]):
401
379
  cross_correlation_upsampled = _upsampled_dft(
402
- data=score_space_ft,
380
+ data=scores_ft,
403
381
  upsampled_region_size=upsampled_region_size,
404
382
  upsample_factor=oversampling_factor,
405
383
  axis_offsets=sample_region_offset[index],
@@ -412,73 +390,66 @@ class PeakCaller(ABC):
412
390
  maxima = np.divide(np.subtract(maxima, dftshift), oversampling_factor)
413
391
  peak_positions[index] = np.add(peak_positions[index], maxima)
414
392
 
415
- peak_positions = backend.to_backend_array(peak_positions)
393
+ peak_positions = be.to_backend_array(peak_positions)
416
394
 
417
395
  return peak_positions
418
396
 
419
397
  def _update(
420
398
  self,
421
- peak_positions: NDArray,
422
- peak_details: NDArray,
423
- peak_scores: NDArray,
424
- rotations: NDArray,
425
- **kwargs,
426
- ) -> None:
399
+ peak_positions: BackendArray,
400
+ peak_details: BackendArray,
401
+ peak_scores: BackendArray,
402
+ rotations: BackendArray,
403
+ offset: BackendArray = None,
404
+ ):
427
405
  """
428
406
  Update internal parameter store.
429
407
 
430
408
  Parameters
431
409
  ----------
432
- peak_positions : NDArray
433
- Position of peaks with shape n x d where n is the number of
434
- peaks and d the dimension.
435
- peak_scores : NDArray
436
- Corresponding score obtained at each peak.
437
- translation_offset : NDArray, optional
438
- Offset of the score_space, occurs e.g. when template matching
439
- to parts of a tomogram.
440
- rotations: NDArray
441
- Rotations used to obtain the score space from which
442
- the candidate stem.
410
+ peak_positions : BackendArray
411
+ Position of peaks (n, d).
412
+ peak_details : BackendArray
413
+ Details of each peak (n, ).
414
+ rotations: BackendArray
415
+ Rotation at each peak (n, d, d).
416
+ rotations: BackendArray
417
+ Rotation at each peak (n, d, d).
418
+ offset : BackendArray, optional
419
+ Translation offset, e.g. from splitting, (n, ).
443
420
  """
444
- translation_offset = kwargs.get(
445
- "translation_offset", backend.zeros(peak_positions.shape[1])
446
- )
447
- translation_offset = backend.astype(translation_offset, peak_positions.dtype)
421
+ if offset is not None:
422
+ offset = be.astype(offset, peak_positions.dtype)
423
+ peak_positions = be.add(peak_positions, offset, out=peak_positions)
448
424
 
449
- backend.add(peak_positions, translation_offset, out=peak_positions)
450
425
  if not len(self.peak_list):
451
426
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
452
-
453
- peak_positions = backend.concatenate((self.peak_list[0], peak_positions))
454
- rotations = backend.concatenate((self.peak_list[1], rotations))
455
- peak_scores = backend.concatenate((self.peak_list[2], peak_scores))
456
- peak_details = backend.concatenate((self.peak_list[3], peak_details))
427
+ else:
428
+ peak_positions = be.concatenate((self.peak_list[0], peak_positions))
429
+ rotations = be.concatenate((self.peak_list[1], rotations))
430
+ peak_scores = be.concatenate((self.peak_list[2], peak_scores))
431
+ peak_details = be.concatenate((self.peak_list[3], peak_details))
457
432
 
458
433
  if self.batch_dims is None:
459
- top_n = min(backend.size(peak_scores), self.number_of_peaks)
460
- top_scores, *_ = backend.topk_indices(peak_scores, top_n)
434
+ top_n = min(be.size(peak_scores), self.number_of_peaks)
435
+ top_scores, *_ = be.topk_indices(peak_scores, top_n)
461
436
  else:
462
437
  # Not very performant but fairly robust
463
438
  batch_indices = peak_positions[..., self.batch_dims]
464
- backend.subtract(
465
- batch_indices, backend.min(batch_indices, axis=0), out=batch_indices
466
- )
467
- multiplier = backend.power(
468
- backend.max(batch_indices, axis=0) + 1,
469
- backend.arange(batch_indices.shape[1]),
470
- )
471
- backend.multiply(batch_indices, multiplier, out=batch_indices)
472
- batch_indices = backend.sum(batch_indices, axis=1)
473
- unique_indices, batch_counts = backend.unique(
474
- batch_indices, return_counts=True
439
+ batch_indices = be.subtract(batch_indices, be.min(batch_indices, axis=0))
440
+ multiplier = be.power(
441
+ be.max(batch_indices, axis=0) + 1,
442
+ be.arange(batch_indices.shape[1]),
475
443
  )
476
- total_indices = backend.arange(peak_scores.shape[0])
444
+ batch_indices = be.multiply(batch_indices, multiplier, out=batch_indices)
445
+ batch_indices = be.sum(batch_indices, axis=1)
446
+ unique_indices, batch_counts = be.unique(batch_indices, return_counts=True)
447
+ total_indices = be.arange(peak_scores.shape[0])
477
448
  batch_indices = [total_indices[batch_indices == x] for x in unique_indices]
478
- top_scores = backend.concatenate(
449
+ top_scores = be.concatenate(
479
450
  [
480
451
  total_indices[indices][
481
- backend.topk_indices(
452
+ be.topk_indices(
482
453
  peak_scores[indices], min(y, self.number_of_peaks)
483
454
  )
484
455
  ]
@@ -488,19 +459,27 @@ class PeakCaller(ABC):
488
459
 
489
460
  final_order = top_scores[
490
461
  filter_points_indices(
491
- coordinates=peak_positions[top_scores],
462
+ coordinates=peak_positions[top_scores, :],
492
463
  min_distance=self.min_distance,
493
464
  batch_dims=self.batch_dims,
494
465
  )
495
466
  ]
496
467
 
497
- self.peak_list[0] = peak_positions[final_order,]
498
- self.peak_list[1] = rotations[final_order,]
468
+ self.peak_list[0] = peak_positions[final_order, :]
469
+ self.peak_list[1] = rotations[final_order, :]
499
470
  self.peak_list[2] = peak_scores[final_order]
500
471
  self.peak_list[3] = peak_details[final_order]
501
472
 
502
473
  def _postprocess(
503
- self, fourier_shift, convolution_mode, targetshape, templateshape, **kwargs
474
+ self,
475
+ fast_shape: Tuple[int],
476
+ targetshape: Tuple[int],
477
+ templateshape: Tuple[int],
478
+ convolution_shape: Tuple[int] = None,
479
+ fourier_shift: Tuple[int] = None,
480
+ convolution_mode: str = None,
481
+ shared_memory_handler=None,
482
+ **kwargs,
504
483
  ):
505
484
  if not len(self.peak_list):
506
485
  return self
@@ -509,52 +488,43 @@ class PeakCaller(ABC):
509
488
  if not len(peak_positions):
510
489
  return self
511
490
 
512
- if targetshape is None or templateshape is None:
513
- return self
514
-
515
- # Remove padding to next fast fourier length
516
- score_space_shape = backend.add(targetshape, templateshape) - 1
517
-
491
+ # Wrap peaks around score space
492
+ convolution_shape = be.to_backend_array(convolution_shape)
493
+ fast_shape = be.to_backend_array(fast_shape)
518
494
  if fourier_shift is not None:
519
- peak_positions = backend.add(peak_positions, fourier_shift)
520
-
521
- backend.subtract(
495
+ fourier_shift = be.to_backend_array(fourier_shift)
496
+ peak_positions = be.add(peak_positions, fourier_shift)
497
+ peak_positions = be.subtract(
522
498
  peak_positions,
523
- backend.multiply(
524
- backend.astype(
525
- backend.divide(peak_positions, score_space_shape), int
526
- ),
527
- score_space_shape,
499
+ be.multiply(
500
+ be.astype(be.divide(peak_positions, fast_shape), int),
501
+ fast_shape,
528
502
  ),
529
- out=peak_positions,
530
503
  )
531
504
 
532
- if convolution_mode is None:
533
- return None
534
-
535
- if convolution_mode == "full":
536
- output_shape = score_space_shape
537
- elif convolution_mode == "same":
505
+ # Remove padding to fast Fourier (and potential full convolution) shape
506
+ output_shape = convolution_shape
507
+ targetshape = be.to_backend_array(targetshape)
508
+ templateshape = be.to_backend_array(templateshape)
509
+ if convolution_mode == "same":
538
510
  output_shape = targetshape
539
511
  elif convolution_mode == "valid":
540
- output_shape = backend.add(
541
- backend.subtract(targetshape, templateshape),
542
- backend.mod(templateshape, 2),
512
+ output_shape = be.add(
513
+ be.subtract(targetshape, templateshape),
514
+ be.mod(templateshape, 2),
543
515
  )
544
516
 
545
- output_shape = backend.to_backend_array(output_shape)
546
- starts = backend.divide(backend.subtract(score_space_shape, output_shape), 2)
547
- starts = backend.astype(starts, int)
548
- stops = backend.add(starts, output_shape)
549
-
550
- valid_peaks = (
551
- backend.sum(
552
- backend.multiply(peak_positions > starts, peak_positions <= stops),
553
- axis=1,
554
- )
555
- == peak_positions.shape[1]
517
+ output_shape = be.to_backend_array(output_shape)
518
+ starts = be.astype(
519
+ be.divide(be.subtract(convolution_shape, output_shape), 2),
520
+ be._int_dtype,
556
521
  )
557
- self.peak_list[0] = backend.subtract(peak_positions, starts)
522
+ stops = be.add(starts, output_shape)
523
+
524
+ valid_peaks = be.multiply(peak_positions > starts, peak_positions <= stops)
525
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
526
+
527
+ self.peak_list[0] = be.subtract(peak_positions, starts)
558
528
  self.peak_list = [x[valid_peaks] for x in self.peak_list]
559
529
  return self
560
530
 
@@ -565,27 +535,14 @@ class PeakCallerSort(PeakCaller):
565
535
  highest scores.
566
536
  """
567
537
 
568
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
569
- """
570
- Call peaks in the score space.
538
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
539
+ flat_scores = scores.reshape(-1)
540
+ k = min(self.number_of_peaks, be.size(flat_scores))
571
541
 
572
- Parameters
573
- ----------
574
- score_space : NDArray
575
- Data array of scores.
542
+ top_k_indices, *_ = be.topk_indices(flat_scores, k)
576
543
 
577
- Returns
578
- -------
579
- Tuple[NDArray, NDArray]
580
- Array of peak coordinates and peak details.
581
- """
582
- flat_score_space = score_space.reshape(-1)
583
- k = min(self.number_of_peaks, backend.size(flat_score_space))
584
-
585
- top_k_indices, *_ = backend.topk_indices(flat_score_space, k)
586
-
587
- coordinates = backend.unravel_index(top_k_indices, score_space.shape)
588
- coordinates = backend.transpose(backend.stack(coordinates))
544
+ coordinates = be.unravel_index(top_k_indices, scores.shape)
545
+ coordinates = be.transpose(be.stack(coordinates))
589
546
 
590
547
  return coordinates, None
591
548
 
@@ -594,28 +551,11 @@ class PeakCallerMaximumFilter(PeakCaller):
594
551
  """
595
552
  Find local maxima by applying a maximum filter and enforcing a distance
596
553
  constraint subsequently. This is similar to the strategy implemented in
597
- skimage.feature.peak_local_max.
554
+ :obj:`skimage.feature.peak_local_max`.
598
555
  """
599
556
 
600
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
601
- """
602
- Call peaks in the score space.
603
-
604
- Parameters
605
- ----------
606
- score_space : NDArray
607
- Data array of scores.
608
- kwargs: Dict, optional
609
- Optional keyword arguments.
610
-
611
- Returns
612
- -------
613
- Tuple[NDArray, NDArray]
614
- Array of peak coordinates and peak details.
615
- """
616
- peaks = backend.max_filter_coordinates(score_space, self.min_distance)
617
-
618
- return peaks, None
557
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
558
+ return be.max_filter_coordinates(scores, self.min_distance), None
619
559
 
620
560
 
621
561
  class PeakCallerFast(PeakCaller):
@@ -627,56 +567,35 @@ class PeakCallerFast(PeakCaller):
627
567
 
628
568
  """
629
569
 
630
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
631
- """
632
- Call peaks in the score space.
633
-
634
- Parameters
635
- ----------
636
- score_space : NDArray
637
- Data array of scores.
570
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
571
+ splits = {i: x // self.min_distance for i, x in enumerate(scores.shape)}
572
+ slices = split_shape(scores.shape, splits)
638
573
 
639
- Returns
640
- -------
641
- Tuple[NDArray, NDArray]
642
- Array of peak coordinates and peak details.
643
- """
644
- splits = {
645
- axis: score_space.shape[axis] // self.min_distance
646
- for axis in range(score_space.ndim)
647
- }
648
- slices = split_numpy_array_slices(score_space.shape, splits)
649
-
650
- coordinates = backend.to_backend_array(
574
+ coordinates = be.to_backend_array(
651
575
  [
652
- backend.unravel_index(
653
- backend.argmax(score_space[subvol]), score_space[subvol].shape
654
- )
576
+ be.unravel_index(be.argmax(scores[subvol]), scores[subvol].shape)
655
577
  for subvol in slices
656
578
  ]
657
579
  )
658
- offset = backend.to_backend_array(
580
+ offset = be.to_backend_array(
659
581
  [tuple(x.start for x in subvol) for subvol in slices]
660
582
  )
661
- backend.add(coordinates, offset, out=coordinates)
662
- coordinates = coordinates[
663
- backend.flip(backend.argsort(score_space[tuple(coordinates.T)]), (0,))
664
- ]
583
+ be.add(coordinates, offset, out=coordinates)
584
+ coordinates = coordinates[be.argsort(-scores[tuple(coordinates.T)])]
665
585
 
666
586
  if coordinates.shape[0] == 0:
667
587
  return None
668
588
 
669
- starts = backend.maximum(coordinates - self.min_distance, 0)
670
- stops = backend.minimum(coordinates + self.min_distance, score_space.shape)
589
+ starts = be.maximum(coordinates - self.min_distance, 0)
590
+ stops = be.minimum(coordinates + self.min_distance, scores.shape)
671
591
  slices_list = [
672
592
  tuple(slice(*coord) for coord in zip(start_row, stop_row))
673
593
  for start_row, stop_row in zip(starts, stops)
674
594
  ]
675
595
 
676
- scores = score_space[tuple(coordinates.T)]
677
596
  keep = [
678
- score >= backend.max(score_space[subvol])
679
- for subvol, score in zip(slices_list, scores)
597
+ score_subvol >= be.max(scores[subvol])
598
+ for subvol, score_subvol in zip(slices_list, scores[tuple(coordinates.T)])
680
599
  ]
681
600
  coordinates = coordinates[keep,]
682
601
 
@@ -694,26 +613,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
694
613
 
695
614
  def call_peaks(
696
615
  self,
697
- score_space: NDArray,
698
- rotation_matrix: NDArray,
699
- mask: NDArray = None,
616
+ scores: BackendArray,
617
+ rotation_matrix: BackendArray,
618
+ mask: BackendArray = None,
700
619
  minimum_score: float = None,
701
- rotation_space: NDArray = None,
620
+ rotation_space: BackendArray = None,
702
621
  rotation_mapping: Dict = None,
703
622
  **kwargs,
704
- ) -> Tuple[NDArray, NDArray]:
623
+ ) -> PeakType:
705
624
  """
706
625
  Call peaks in the score space.
707
626
 
708
627
  Parameters
709
628
  ----------
710
- score_space : NDArray
629
+ scores : BackendArray
711
630
  Data array of scores.
712
- rotation_matrix : NDArray
631
+ rotation_matrix : BackendArray
713
632
  Rotation matrix.
714
- mask : NDArray, optional
633
+ mask : BackendArray, optional
715
634
  Mask array, by default None.
716
- rotation_space : NDArray, optional
635
+ rotation_space : BackendArray, optional
717
636
  Rotation space array, by default None.
718
637
  rotation_mapping : Dict optional
719
638
  Dictionary mapping values in rotation_space to Euler angles.
@@ -724,7 +643,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
724
643
 
725
644
  Returns
726
645
  -------
727
- Tuple[NDArray, NDArray]
646
+ Tuple[BackendArray, BackendArray]
728
647
  Array of peak coordinates and peak details.
729
648
 
730
649
  Notes
@@ -738,26 +657,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
738
657
 
739
658
  if mask is None:
740
659
  masking_function = self._mask_scores_box
741
- shape = tuple(self.min_distance for _ in range(score_space.ndim))
742
- mask = backend.zeros(shape, dtype=backend._float_dtype)
660
+ shape = tuple(self.min_distance for _ in range(scores.ndim))
661
+ mask = be.zeros(shape, dtype=be._float_dtype)
743
662
 
744
- rotated_template = backend.zeros(mask.shape, dtype=mask.dtype)
663
+ rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
745
664
 
746
665
  peak_limit = self.number_of_peaks
747
666
  if minimum_score is not None:
748
- peak_limit = backend.size(score_space)
667
+ peak_limit = be.size(scores)
749
668
  else:
750
- minimum_score = backend.min(score_space) - 1
669
+ minimum_score = be.min(scores) - 1
751
670
 
752
- scores = backend.zeros(score_space.shape, dtype=score_space.dtype)
753
- scores[:] = score_space
671
+ scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
672
+ scores_copy[:] = scores
754
673
 
755
674
  while True:
756
- backend.argmax(scores)
757
- peak = backend.unravel_index(
758
- indices=backend.argmax(scores), shape=scores.shape
675
+ be.argmax(scores_copy)
676
+ peak = be.unravel_index(
677
+ indices=be.argmax(scores_copy), shape=scores_copy.shape
759
678
  )
760
- if scores[tuple(peak)] < minimum_score:
679
+ if scores_copy[tuple(peak)] < minimum_score:
761
680
  break
762
681
 
763
682
  coordinates.append(peak)
@@ -770,7 +689,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
770
689
  )
771
690
 
772
691
  masking_function(
773
- score_space=scores,
692
+ scores=scores_copy,
774
693
  rotation_matrix=current_rotation_matrix,
775
694
  peak=peak,
776
695
  mask=mask,
@@ -780,33 +699,33 @@ class PeakCallerRecursiveMasking(PeakCaller):
780
699
  if len(coordinates) >= peak_limit:
781
700
  break
782
701
 
783
- peaks = backend.to_backend_array(coordinates)
702
+ peaks = be.to_backend_array(coordinates)
784
703
  return peaks, None
785
704
 
786
705
  @staticmethod
787
706
  def _get_rotation_matrix(
788
- peak: NDArray,
789
- rotation_space: NDArray,
790
- rotation_mapping: NDArray,
791
- rotation_matrix: NDArray,
792
- ) -> NDArray:
707
+ peak: BackendArray,
708
+ rotation_space: BackendArray,
709
+ rotation_mapping: BackendArray,
710
+ rotation_matrix: BackendArray,
711
+ ) -> BackendArray:
793
712
  """
794
713
  Get rotation matrix based on peak and rotation data.
795
714
 
796
715
  Parameters
797
716
  ----------
798
- peak : NDArray
717
+ peak : BackendArray
799
718
  Peak coordinates.
800
- rotation_space : NDArray
719
+ rotation_space : BackendArray
801
720
  Rotation space array.
802
721
  rotation_mapping : Dict
803
722
  Dictionary mapping values in rotation_space to Euler angles.
804
- rotation_matrix : NDArray
723
+ rotation_matrix : BackendArray
805
724
  Current rotation matrix.
806
725
 
807
726
  Returns
808
727
  -------
809
- NDArray
728
+ BackendArray
810
729
  Rotation matrix.
811
730
  """
812
731
  if rotation_space is None or rotation_mapping is None:
@@ -814,135 +733,117 @@ class PeakCallerRecursiveMasking(PeakCaller):
814
733
 
815
734
  rotation = rotation_mapping[rotation_space[tuple(peak)]]
816
735
 
817
- rotation_matrix = backend.to_backend_array(
818
- euler_to_rotationmatrix(backend.to_numpy_array(rotation))
819
- )
820
- return rotation_matrix
736
+ # TODO: Newer versions of rotation mapping contain rotation matrices not angles
737
+ if len(rotation) == 3:
738
+ rotation = be.to_backend_array(
739
+ euler_to_rotationmatrix(be.to_numpy_array(rotation))
740
+ )
741
+ return rotation
821
742
 
822
743
  @staticmethod
823
744
  def _mask_scores_box(
824
- score_space: NDArray, peak: NDArray, mask: NDArray, **kwargs: Dict
745
+ scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
825
746
  ) -> None:
826
747
  """
827
748
  Mask scores in a box around a peak.
828
749
 
829
750
  Parameters
830
751
  ----------
831
- score_space : NDArray
752
+ scores : BackendArray
832
753
  Data array of scores.
833
- peak : NDArray
754
+ peak : BackendArray
834
755
  Peak coordinates.
835
- mask : NDArray
756
+ mask : BackendArray
836
757
  Mask array.
837
758
  """
838
- start = backend.maximum(backend.subtract(peak, mask.shape), 0)
839
- stop = backend.minimum(backend.add(peak, mask.shape), score_space.shape)
840
- start, stop = backend.astype(start, int), backend.astype(stop, int)
759
+ start = be.maximum(be.subtract(peak, mask.shape), 0)
760
+ stop = be.minimum(be.add(peak, mask.shape), scores.shape)
761
+ start, stop = be.astype(start, int), be.astype(stop, int)
841
762
  coords = tuple(slice(*pos) for pos in zip(start, stop))
842
- score_space[coords] = 0
763
+ scores[coords] = 0
843
764
  return None
844
765
 
845
766
  @staticmethod
846
767
  def _mask_scores_rotate(
847
- score_space: NDArray,
848
- peak: NDArray,
849
- mask: NDArray,
850
- rotated_template: NDArray,
851
- rotation_matrix: NDArray,
768
+ scores: BackendArray,
769
+ peak: BackendArray,
770
+ mask: BackendArray,
771
+ rotated_template: BackendArray,
772
+ rotation_matrix: BackendArray,
852
773
  **kwargs: Dict,
853
774
  ) -> None:
854
775
  """
855
- Mask score_space using mask rotation around a peak.
776
+ Mask scores using mask rotation around a peak.
856
777
 
857
778
  Parameters
858
779
  ----------
859
- score_space : NDArray
780
+ scores : BackendArray
860
781
  Data array of scores.
861
- peak : NDArray
782
+ peak : BackendArray
862
783
  Peak coordinates.
863
- mask : NDArray
784
+ mask : BackendArray
864
785
  Mask array.
865
- rotated_template : NDArray
786
+ rotated_template : BackendArray
866
787
  Empty array to write mask rotations to.
867
- rotation_matrix : NDArray
788
+ rotation_matrix : BackendArray
868
789
  Rotation matrix.
869
790
  """
870
- left_pad = backend.divide(mask.shape, 2).astype(int)
871
- right_pad = backend.add(left_pad, backend.mod(mask.shape, 2).astype(int))
791
+ left_pad = be.divide(mask.shape, 2).astype(int)
792
+ right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
872
793
 
873
- score_start = backend.subtract(peak, left_pad)
874
- score_stop = backend.add(peak, right_pad)
794
+ score_start = be.subtract(peak, left_pad)
795
+ score_stop = be.add(peak, right_pad)
875
796
 
876
- template_start = backend.subtract(backend.maximum(score_start, 0), score_start)
877
- template_stop = backend.subtract(
878
- score_stop, backend.minimum(score_stop, score_space.shape)
879
- )
880
- template_stop = backend.subtract(mask.shape, template_stop)
797
+ template_start = be.subtract(be.maximum(score_start, 0), score_start)
798
+ template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
799
+ template_stop = be.subtract(mask.shape, template_stop)
881
800
 
882
- score_start = backend.maximum(score_start, 0)
883
- score_stop = backend.minimum(score_stop, score_space.shape)
884
- score_start = backend.astype(score_start, int)
885
- score_stop = backend.astype(score_stop, int)
801
+ score_start = be.maximum(score_start, 0)
802
+ score_stop = be.minimum(score_stop, scores.shape)
803
+ score_start = be.astype(score_start, int)
804
+ score_stop = be.astype(score_stop, int)
886
805
 
887
- template_start = backend.astype(template_start, int)
888
- template_stop = backend.astype(template_stop, int)
806
+ template_start = be.astype(template_start, int)
807
+ template_stop = be.astype(template_stop, int)
889
808
  coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
890
809
  coords_template = tuple(
891
810
  slice(*pos) for pos in zip(template_start, template_stop)
892
811
  )
893
812
 
894
813
  rotated_template.fill(0)
895
- backend.rotate_array(
814
+ be.rigid_transform(
896
815
  arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
897
816
  )
898
817
 
899
- score_space[coords_score] = backend.multiply(
900
- score_space[coords_score], (rotated_template[coords_template] <= 0.1)
818
+ scores[coords_score] = be.multiply(
819
+ scores[coords_score], (rotated_template[coords_template] <= 0.1)
901
820
  )
902
821
  return None
903
822
 
904
823
 
905
824
  class PeakCallerScipy(PeakCaller):
906
825
  """
907
- Peak calling using skimage.feature.peak_local_max to compute local maxima.
826
+ Peak calling using :obj:`skimage.feature.peak_local_max` to compute local maxima.
908
827
  """
909
828
 
910
829
  def call_peaks(
911
- self, score_space: NDArray, minimum_score: float = None, **kwargs
912
- ) -> Tuple[NDArray, NDArray]:
913
- """
914
- Call peaks in the score space.
915
-
916
- Parameters
917
- ----------
918
- score_space : NDArray
919
- Data array of scores.
920
- minimum_score : float
921
- Minimum score value to consider. If provided, superseeds limit given
922
- by :py:attr:`PeakCaller.number_of_peaks`.
923
-
924
- Returns
925
- -------
926
- Tuple[NDArray, NDArray]
927
- Array of peak coordinates and peak details.
928
- """
929
- score_space = backend.to_numpy_array(score_space)
830
+ self, scores: BackendArray, minimum_score: float = None, **kwargs
831
+ ) -> PeakType:
832
+ scores = be.to_numpy_array(scores)
930
833
  num_peaks = self.number_of_peaks
931
834
  if minimum_score is not None:
932
835
  num_peaks = np.inf
933
836
 
934
- non_squeezable_dims = tuple(
935
- i for i, x in enumerate(score_space.shape) if x != 1
936
- )
837
+ non_squeezable_dims = tuple(i for i, x in enumerate(scores.shape) if x != 1)
937
838
  peaks = peak_local_max(
938
- np.squeeze(score_space),
839
+ np.squeeze(scores),
939
840
  num_peaks=num_peaks,
940
841
  min_distance=self.min_distance,
941
842
  threshold_abs=minimum_score,
942
843
  )
943
- peaks_full = np.zeros((peaks.shape[0], score_space.ndim), peaks.dtype)
844
+ peaks_full = np.zeros((peaks.shape[0], scores.ndim), peaks.dtype)
944
845
  peaks_full[..., non_squeezable_dims] = peaks[:]
945
- peaks = backend.to_backend_array(peaks_full)
846
+ peaks = be.to_backend_array(peaks_full)
946
847
  return peaks, None
947
848
 
948
849
 
@@ -967,7 +868,7 @@ class PeakClustering(PeakCallerSort):
967
868
  Parameters
968
869
  ----------
969
870
  **kwargs
970
- Additional keyword arguments passed to :py:meth:`PeakCaller.merge`.
871
+ Optional keyword arguments passed to :py:meth:`PeakCaller.merge`.
971
872
 
972
873
  Returns
973
874
  -------
@@ -999,263 +900,33 @@ class PeakClustering(PeakCallerSort):
999
900
  return peaks, rotations, scores, details
1000
901
 
1001
902
 
1002
- class ScoreStatistics(PeakCallerFast):
1003
- """
1004
- Compute basic statistics on score spaces with respect to a reference
1005
- score or value.
1006
-
1007
- This class is used to evaluate a blurring or scoring method when the correct fit
1008
- is known. It is thread-safe and is designed to be shared among multiple processes
1009
- with write permissions to the internal parameters.
1010
-
1011
- After instantiation, the class's functionality can be accessed through the
1012
- `__call__` method.
1013
-
1014
- Parameters
1015
- ----------
1016
- reference_position : int, optional
1017
- Index of the correct fit in the array passed to call. Defaults to None.
1018
- min_distance : float, optional
1019
- Minimum distance for local maxima. Defaults to None.
1020
- reference_fit : float, optional
1021
- Score of the correct fit. If set, `reference_position` will be ignored.
1022
- Defaults to None.
1023
- number_of_peaks : int, optional
1024
- Number of candidate fits to consider. Defaults to 1.
1025
- """
1026
-
1027
- def __init__(
1028
- self,
1029
- reference_position: Tuple[int] = None,
1030
- min_distance: float = 10,
1031
- reference_fit: float = None,
1032
- number_of_peaks: int = 1,
1033
- ):
1034
- super().__init__(number_of_peaks=number_of_peaks, min_distance=min_distance)
1035
- self.lock = Lock()
1036
-
1037
- self.n = RawValue("Q", 0)
1038
- self.rmean = RawValue("d", 0)
1039
- self.ssqd = RawValue("d", 0)
1040
- self.nbetter_or_equal = RawValue("Q", 0)
1041
- self.maximum_value = RawValue("f", 0)
1042
- self.minimum_value = RawValue("f", 2**32)
1043
- self.shannon_entropy = Manager().list()
1044
- self.candidate_fits = Manager().list()
1045
- self.rotation_names = Manager().list()
1046
- self.reference_fit = RawValue("f", 0)
1047
- self.has_reference = RawValue("i", 0)
1048
-
1049
- self.reference_position = reference_position
1050
- if reference_fit is not None:
1051
- self.reference_fit.value = reference_fit
1052
- self.has_reference.value = 1
1053
-
1054
- def __call__(
1055
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1056
- ) -> None:
1057
- """
1058
- Processes the input array and rotation matrix.
1059
-
1060
- Parameters
1061
- ----------
1062
- arr : NDArray
1063
- Input data array.
1064
- rotation_matrix : NDArray
1065
- Rotation matrix for processing.
1066
- """
1067
- self.set_reference(score_space, rotation_matrix)
1068
-
1069
- while not self.has_reference.value:
1070
- print("Stalling processes until reference_fit has been set.")
1071
- sleep(0.5)
1072
-
1073
- name = "_".join([str(value) for value in rotation_matrix.ravel()])
1074
- n, rmean, ssqd, nbetter_or_equal, max_value, min_value = online_statistics(
1075
- score_space, 0, 0.0, 0.0, self.reference_fit.value
1076
- )
1077
-
1078
- freq, _ = np.histogram(score_space, bins=100)
1079
- shannon_entropy = entropy(freq / score_space.size)
1080
-
1081
- peaks, _ = super().call_peaks(
1082
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
1083
- )
1084
- scores = score_space[tuple(peaks.T)]
1085
- rotations = np.repeat(
1086
- rotation_matrix.reshape(1, *rotation_matrix.shape),
1087
- peaks.shape[0],
1088
- axis=0,
1089
- )
1090
- distances = np.linalg.norm(peaks - self.reference_position[None, :], axis=1)
1091
-
1092
- self._update(
1093
- peak_positions=peaks,
1094
- rotations=rotations,
1095
- peak_scores=scores,
1096
- peak_details=distances,
1097
- n=n,
1098
- rmean=rmean,
1099
- ssqd=ssqd,
1100
- nbetter_or_equal=nbetter_or_equal,
1101
- max_value=max_value,
1102
- min_value=min_value,
1103
- entropy=shannon_entropy,
1104
- name=name,
1105
- )
1106
-
1107
- def __iter__(self):
1108
- param_store = (
1109
- self.peak_list[0],
1110
- self.peak_list[1],
1111
- self.peak_list[2],
1112
- self.peak_list[3],
1113
- self.n.value,
1114
- self.rmean.value,
1115
- self.ssqd.value,
1116
- self.nbetter_or_equal.value,
1117
- self.maximum_value.value,
1118
- self.minimum_value.value,
1119
- list(self.shannon_entropy),
1120
- list(self.rotation_names),
1121
- self.reference_fit.value,
1122
- )
1123
- yield from param_store
1124
-
1125
- def _update(
1126
- self,
1127
- n: int,
1128
- rmean: float,
1129
- ssqd: float,
1130
- nbetter_or_equal: int,
1131
- max_value: float,
1132
- min_value: float,
1133
- entropy: float,
1134
- name: str,
1135
- **kwargs,
1136
- ) -> None:
1137
- """
1138
- Updates the internal statistics of the analyzer.
1139
-
1140
- Parameters
1141
- ----------
1142
- n : int
1143
- Sample size.
1144
- rmean : float
1145
- Running mean.
1146
- ssqd : float
1147
- Sum of squared differences.
1148
- nbetter_or_equal : int
1149
- Number of values better or equal to reference.
1150
- max_value : float
1151
- Maximum value.
1152
- min_value : float
1153
- Minimum value.
1154
- entropy : float
1155
- Shannon entropy.
1156
- candidates : list
1157
- List of candidate fits.
1158
- name : str
1159
- Name or label for the data.
1160
- kwargs : dict
1161
- Keyword arguments passed to PeakCaller._update.
1162
- """
1163
- with self.lock:
1164
- super()._update(**kwargs)
1165
-
1166
- n_total = self.n.value + n
1167
- delta = rmean - self.rmean.value
1168
- delta2 = delta * delta
1169
- self.rmean.value += delta * n / n_total
1170
- self.ssqd.value += ssqd + delta2 * (n * self.n.value) / n_total
1171
- self.n.value = n_total
1172
- self.nbetter_or_equal.value += nbetter_or_equal
1173
- self.minimum_value.value = min(self.minimum_value.value, min_value)
1174
- self.maximum_value.value = max(self.maximum_value.value, max_value)
1175
- self.shannon_entropy.append(entropy)
1176
- self.rotation_names.append(name)
1177
-
1178
- @classmethod
1179
- def merge(cls, param_stores: List[Tuple]) -> Tuple:
1180
- """
1181
- Merges multiple instances of :py:class`ScoreStatistics`.
1182
-
1183
- Parameters
1184
- ----------
1185
- param_stores : list of tuple
1186
- Internal parameter store. Obtained by running `tuple(instance)`.
1187
- Defaults to a list with two empty tuples.
1188
-
1189
- Returns
1190
- -------
1191
- tuple
1192
- Contains the reference fit, the z-transform of the reference fit,
1193
- number of scores, and various other statistics.
1194
- """
1195
- base = cls(reference_position=np.zeros(3, int))
1196
- for param_store in param_stores:
1197
- base._update(
1198
- peak_positions=param_store[0],
1199
- rotations=param_store[1],
1200
- peak_scores=param_store[2],
1201
- peak_details=param_store[3],
1202
- n=param_store[4],
1203
- rmean=param_store[5],
1204
- ssqd=param_store[6],
1205
- nbetter_or_equal=param_store[7],
1206
- max_value=param_store[8],
1207
- min_value=param_store[9],
1208
- entropy=param_store[10],
1209
- name=param_store[11],
1210
- )
1211
- base.reference_fit.value = param_store[12]
1212
- return tuple(base)
1213
-
1214
- def set_reference(self, score_space: NDArray, rotation_matrix: NDArray) -> None:
1215
- """
1216
- Sets the reference for the analyzer based on the input array
1217
- and rotation matrix.
1218
-
1219
- Parameters
1220
- ----------
1221
- score_space : NDArray
1222
- Input data array.
1223
- rotation_matrix : NDArray
1224
- Rotation matrix for setting reference.
1225
- """
1226
- is_ref = np.allclose(
1227
- rotation_matrix,
1228
- np.eye(rotation_matrix.shape[0], dtype=rotation_matrix.dtype),
1229
- )
1230
- if not is_ref:
1231
- return None
1232
-
1233
- reference_position = self.reference_position
1234
- if reference_position is None:
1235
- reference_position = np.divide(score_space.shape, 2).astype(int)
1236
- self.reference_position = reference_position
1237
- self.reference_fit.value = score_space[tuple(reference_position)]
1238
- self.has_reference.value = 1
1239
-
1240
-
1241
903
  class MaxScoreOverRotations:
1242
904
  """
1243
- Obtain the maximum translation score over various rotations.
905
+ Determine the rotation maximizing the score of all given translations.
1244
906
 
1245
907
  Attributes
1246
908
  ----------
1247
- score_space : NDArray
1248
- The score space for the observed rotations.
1249
- rotations : NDArray
1250
- The rotation identifiers for each score.
1251
- translation_offset : NDArray, optional
1252
- The offset applied during translation.
1253
- observed_rotations : int
1254
- Count of observed rotations.
909
+ shape : tuple of ints.
910
+ Shape of ``scores`` and rotations.
911
+ scores : BackendArray
912
+ Array mapping translations to scores.
913
+ rotations : BackendArray
914
+ Array mapping translations to rotation indices.
915
+ rotation_mapping : Dict
916
+ Mapping of rotation matrix bytestrings to rotation indices.
917
+ offset : BackendArray, optional
918
+ Coordinate origin considered during merging, zero by default
1255
919
  use_memmap : bool, optional
1256
- Whether to offload internal data arrays to disk
920
+ Memmap scores and rotations arrays, False by default.
1257
921
  thread_safe: bool, optional
1258
- Whether access to internal data arrays should be thread safe
922
+ Allow class to be modified by multiple processes, True by default.
923
+ only_unique_rotations : bool, optional
924
+ Whether each rotation will be shown only once, False by default.
925
+
926
+ Raises
927
+ ------
928
+ ValueError
929
+ If the data shape cannot be determined from the parameters.
1259
930
 
1260
931
  Examples
1261
932
  --------
@@ -1263,11 +934,7 @@ class MaxScoreOverRotations:
1263
934
  instance
1264
935
 
1265
936
  >>> from tme.analyzer import MaxScoreOverRotations
1266
- >>> analyzer = MaxScoreOverRotations(
1267
- >>> score_space_shape = (50, 50),
1268
- >>> score_space_dtype = np.float32,
1269
- >>> rotation_space_dtype = np.int32,
1270
- >>> )
937
+ >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
1271
938
 
1272
939
  The following simulates a template matching run by creating random data for a range
1273
940
  of rotations and sending it to ``analyzer`` via its __call__ method
@@ -1275,9 +942,9 @@ class MaxScoreOverRotations:
1275
942
  >>> for rotation_number in range(10):
1276
943
  >>> scores = np.random.rand(50,50)
1277
944
  >>> rotation = np.random.rand(scores.ndim, scores.ndim)
1278
- >>> analyzer(score_space = scores, rotation_matrix = rotation)
945
+ >>> analyzer(scores = scores, rotation_matrix = rotation)
1279
946
 
1280
- The aggregated scores can be exctracted by invoking the __iter__ method of
947
+ The aggregated scores can be extracted by invoking the __iter__ method of
1281
948
  ``analyzer``
1282
949
 
1283
950
  >>> results = tuple(analyzer)
@@ -1288,7 +955,7 @@ class MaxScoreOverRotations:
1288
955
  score for a given translation, (4) a dictionary mapping rotation matrices to the
1289
956
  indices used in (2).
1290
957
 
1291
- We can extract the ``optimal_score`, ``optimal_translation`` and ``optimal_rotation``
958
+ We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
1292
959
  as follows
1293
960
 
1294
961
  >>> optimal_score = results[0].max()
@@ -1307,201 +974,177 @@ class MaxScoreOverRotations:
1307
974
 
1308
975
  def __init__(
1309
976
  self,
1310
- score_space_shape: Tuple[int],
1311
- score_space_dtype: type,
1312
- translation_offset: NDArray = None,
977
+ shape: Tuple[int] = None,
978
+ scores: BackendArray = None,
979
+ rotations: BackendArray = None,
980
+ offset: BackendArray = None,
1313
981
  score_threshold: float = 0,
1314
982
  shared_memory_handler: object = None,
1315
- rotation_space_dtype: type = int,
1316
983
  use_memmap: bool = False,
1317
984
  thread_safe: bool = True,
985
+ only_unique_rotations: bool = False,
1318
986
  **kwargs,
1319
987
  ):
1320
- score_space_shape = tuple(int(x) for x in score_space_shape)
1321
- self.score_space = backend.arr_to_sharedarr(
1322
- backend.full(
1323
- shape=score_space_shape,
1324
- dtype=score_space_dtype,
988
+ if shape is None and scores is None:
989
+ raise ValueError("Either scores_shape or scores need to be specified.")
990
+
991
+ if scores is None:
992
+ shape = tuple(int(x) for x in shape)
993
+ scores = be.full(
994
+ shape=shape,
995
+ dtype=be._float_dtype,
1325
996
  fill_value=score_threshold,
1326
- ),
1327
- shared_memory_handler=shared_memory_handler,
1328
- )
1329
- self.rotations = backend.arr_to_sharedarr(
1330
- backend.full(score_space_shape, dtype=rotation_space_dtype, fill_value=-1),
1331
- shared_memory_handler,
1332
- )
1333
- if translation_offset is None:
1334
- translation_offset = backend.zeros(len(score_space_shape))
997
+ )
998
+ self.scores, self.shape = scores, scores.shape
999
+
1000
+ if rotations is None:
1001
+ rotations = be.full(shape, dtype=be._int_dtype, fill_value=-1)
1002
+ self.rotations = rotations
1335
1003
 
1336
- self.translation_offset = backend.astype(translation_offset, int)
1337
- self.score_space_shape = score_space_shape
1338
- self.rotation_space_dtype = rotation_space_dtype
1339
- self.score_space_dtype = score_space_dtype
1004
+ self.scores_dtype = self.scores.dtype
1005
+ self.rotations_dtype = self.rotations.dtype
1006
+ self.scores = be.to_sharedarr(self.scores, shared_memory_handler)
1007
+ self.rotations = be.to_sharedarr(self.rotations, shared_memory_handler)
1008
+
1009
+ if offset is None:
1010
+ offset = be.zeros(len(self.shape), be._int_dtype)
1011
+ self.offset = be.astype(offset, int)
1340
1012
 
1341
1013
  self.use_memmap = use_memmap
1342
1014
  self.lock = Manager().Lock() if thread_safe else nullcontext()
1343
- self.lock_is_nullcontext = isinstance(
1344
- self.score_space, type(backend.zeros((1)))
1345
- )
1346
- self.observed_rotations = Manager().dict() if thread_safe else {}
1015
+ self.lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
1016
+ self.rotation_mapping = Manager().dict() if thread_safe else {}
1017
+ self._inversion_mapping = self.lock_is_nullcontext and only_unique_rotations
1347
1018
 
1348
1019
  def _postprocess(
1349
1020
  self,
1350
- fourier_shift,
1351
- convolution_mode,
1352
- targetshape,
1353
- templateshape,
1021
+ targetshape: Tuple[int],
1022
+ templateshape: Tuple[int],
1023
+ convolution_shape: Tuple[int],
1024
+ fourier_shift: Tuple[int] = None,
1025
+ convolution_mode: str = None,
1354
1026
  shared_memory_handler=None,
1027
+ fast_shape: Tuple[int] = None,
1355
1028
  **kwargs,
1356
- ):
1357
- internal_scores = backend.sharedarr_to_arr(
1358
- shape=self.score_space_shape,
1359
- dtype=self.score_space_dtype,
1360
- shm=self.score_space,
1361
- )
1362
- internal_rotations = backend.sharedarr_to_arr(
1363
- shape=self.score_space_shape,
1364
- dtype=self.rotation_space_dtype,
1365
- shm=self.rotations,
1366
- )
1367
-
1029
+ ) -> "MaxScoreOverRotations":
1030
+ """
1031
+ Correct padding to Fourier (and if requested convolution) shape.
1032
+ """
1033
+ scores = be.from_sharedarr(self.scores)
1034
+ rotations = be.from_sharedarr(self.rotations)
1368
1035
  if fourier_shift is not None:
1369
1036
  axis = tuple(i for i in range(len(fourier_shift)))
1370
- internal_scores = backend.roll(
1371
- internal_scores, shift=fourier_shift, axis=axis
1372
- )
1373
- internal_rotations = backend.roll(
1374
- internal_rotations, shift=fourier_shift, axis=axis
1375
- )
1376
-
1037
+ scores = be.roll(scores, shift=fourier_shift, axis=axis)
1038
+ rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
1039
+
1040
+ convargs = {
1041
+ "s1": targetshape,
1042
+ "s2": templateshape,
1043
+ "convolution_mode": convolution_mode,
1044
+ "convolution_shape": convolution_shape,
1045
+ }
1377
1046
  if convolution_mode is not None:
1378
- internal_scores = apply_convolution_mode(
1379
- internal_scores,
1380
- convolution_mode=convolution_mode,
1381
- s1=targetshape,
1382
- s2=templateshape,
1383
- )
1384
- internal_rotations = apply_convolution_mode(
1385
- internal_rotations,
1386
- convolution_mode=convolution_mode,
1387
- s1=targetshape,
1388
- s2=templateshape,
1389
- )
1047
+ scores = apply_convolution_mode(scores, **convargs)
1048
+ rotations = apply_convolution_mode(rotations, **convargs)
1390
1049
 
1391
- self.score_space_shape = internal_scores.shape
1392
- self.score_space = backend.arr_to_sharedarr(
1393
- internal_scores, shared_memory_handler
1394
- )
1395
- self.rotations = backend.arr_to_sharedarr(
1396
- internal_rotations, shared_memory_handler
1397
- )
1050
+ self.shape = scores.shape
1051
+ self.scores = be.to_sharedarr(scores, shared_memory_handler)
1052
+ self.rotations = be.to_sharedarr(rotations, shared_memory_handler)
1398
1053
  return self
1399
1054
 
1400
- def __iter__(self):
1401
- internal_scores = backend.sharedarr_to_arr(
1402
- shape=self.score_space_shape,
1403
- dtype=self.score_space_dtype,
1404
- shm=self.score_space,
1405
- )
1406
- internal_rotations = backend.sharedarr_to_arr(
1407
- shape=self.score_space_shape,
1408
- dtype=self.rotation_space_dtype,
1409
- shm=self.rotations,
1410
- )
1055
+ def __iter__(self) -> Generator:
1056
+ scores = be.from_sharedarr(self.scores)
1057
+ rotations = be.from_sharedarr(self.rotations)
1411
1058
 
1412
- internal_scores = backend.to_numpy_array(internal_scores)
1413
- internal_rotations = backend.to_numpy_array(internal_rotations)
1059
+ scores = be.to_numpy_array(scores)
1060
+ rotations = be.to_numpy_array(rotations)
1414
1061
  if self.use_memmap:
1415
- internal_scores_filename = array_to_memmap(internal_scores)
1416
- internal_rotations_filename = array_to_memmap(internal_rotations)
1417
- internal_scores = np.memmap(
1418
- internal_scores_filename,
1062
+ scores = np.memmap(
1063
+ array_to_memmap(scores),
1419
1064
  mode="r",
1420
- dtype=internal_scores.dtype,
1421
- shape=internal_scores.shape,
1065
+ dtype=scores.dtype,
1066
+ shape=scores.shape,
1422
1067
  )
1423
- internal_rotations = np.memmap(
1424
- internal_rotations_filename,
1068
+ rotations = np.memmap(
1069
+ array_to_memmap(rotations),
1425
1070
  mode="r",
1426
- dtype=internal_rotations.dtype,
1427
- shape=internal_rotations.shape,
1071
+ dtype=rotations.dtype,
1072
+ shape=rotations.shape,
1428
1073
  )
1429
1074
  else:
1430
- # Avoid invalidation by shared memory handler with copy
1431
- internal_scores = internal_scores.copy()
1432
- internal_rotations = internal_rotations.copy()
1075
+ # Copy to avoid invalidation by shared memory handler
1076
+ scores, rotations = scores.copy(), rotations.copy()
1077
+
1078
+ if self._inversion_mapping:
1079
+ self.rotation_mapping = {
1080
+ be.tobytes(v): k for k, v in self.rotation_mapping.items()
1081
+ }
1433
1082
 
1434
1083
  param_store = (
1435
- internal_scores,
1436
- backend.to_numpy_array(self.translation_offset),
1437
- internal_rotations,
1438
- dict(self.observed_rotations),
1084
+ scores,
1085
+ be.to_numpy_array(self.offset),
1086
+ rotations,
1087
+ dict(self.rotation_mapping),
1439
1088
  )
1440
1089
  yield from param_store
1441
1090
 
1442
- def __call__(
1443
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1444
- ) -> None:
1091
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
1445
1092
  """
1446
- Update internal parameter store based on `score_space`.
1093
+ Update internal parameter store based on `scores`.
1447
1094
 
1448
1095
  Parameters
1449
1096
  ----------
1450
- score_space : ndarray
1451
- Numpy array containing the score space.
1452
- rotation_matrix : ndarray
1097
+ scores : BackendArray
1098
+ Array containing the score space.
1099
+ rotation_matrix : BackendArray
1453
1100
  Square matrix describing the current rotation.
1454
- **kwargs
1455
- Arbitrary keyword arguments.
1456
1101
  """
1457
- rotation = backend.tobytes(rotation_matrix)
1458
-
1102
+ # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
1103
+ # If the analyzer is not shared and each rotation is unique, we can
1104
+ # use index to rotation mapping and invert prior to merging.
1459
1105
  if self.lock_is_nullcontext:
1460
- rotation_index = self.observed_rotations.setdefault(
1461
- rotation, len(self.observed_rotations)
1462
- )
1463
- backend.max_score_over_rotations(
1464
- score_space=score_space,
1465
- internal_scores=self.score_space,
1466
- internal_rotations=self.rotations,
1106
+ rotation_index = len(self.rotation_mapping)
1107
+ if self._inversion_mapping:
1108
+ self.rotation_mapping[rotation_index] = rotation_matrix
1109
+ else:
1110
+ rotation = be.tobytes(rotation_matrix)
1111
+ rotation_index = self.rotation_mapping.setdefault(
1112
+ rotation, rotation_index
1113
+ )
1114
+ self.scores, self.rotations = be.max_score_over_rotations(
1115
+ scores=scores,
1116
+ max_scores=self.scores,
1117
+ rotations=self.rotations,
1467
1118
  rotation_index=rotation_index,
1468
1119
  )
1469
1120
  return None
1470
1121
 
1122
+ rotation = be.tobytes(rotation_matrix)
1471
1123
  with self.lock:
1472
- rotation_index = self.observed_rotations.setdefault(
1473
- rotation, len(self.observed_rotations)
1474
- )
1475
- internal_scores = backend.sharedarr_to_arr(
1476
- shape=self.score_space_shape,
1477
- dtype=self.score_space_dtype,
1478
- shm=self.score_space,
1124
+ rotation_index = self.rotation_mapping.setdefault(
1125
+ rotation, len(self.rotation_mapping)
1479
1126
  )
1480
- internal_rotations = backend.sharedarr_to_arr(
1481
- shape=self.score_space_shape,
1482
- dtype=self.rotation_space_dtype,
1483
- shm=self.rotations,
1484
- )
1485
-
1486
- backend.max_score_over_rotations(
1487
- score_space=score_space,
1488
- internal_scores=internal_scores,
1489
- internal_rotations=internal_rotations,
1127
+ internal_scores = be.from_sharedarr(self.scores)
1128
+ internal_rotations = be.from_sharedarr(self.rotations)
1129
+ internal_sores, internal_rotations = be.max_score_over_rotations(
1130
+ scores=scores,
1131
+ max_scores=internal_scores,
1132
+ rotations=internal_rotations,
1490
1133
  rotation_index=rotation_index,
1491
1134
  )
1492
1135
  return None
1493
1136
 
1494
1137
  @classmethod
1495
- def merge(cls, param_stores=List[Tuple], **kwargs) -> Tuple[NDArray]:
1138
+ def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple[NDArray]:
1496
1139
  """
1497
1140
  Merges multiple instances of :py:class:`MaxScoreOverRotations`.
1498
1141
 
1499
1142
  Parameters
1500
1143
  ----------
1501
- param_stores : list of tuples, optional
1144
+ param_stores : list of tuples
1502
1145
  Internal parameter store. Obtained by running `tuple(instance)`.
1503
1146
  **kwargs
1504
- Arbitrary keyword arguments.
1147
+ Optional keyword arguments.
1505
1148
 
1506
1149
  Returns
1507
1150
  -------
@@ -1513,51 +1156,51 @@ class MaxScoreOverRotations:
1513
1156
  if len(param_stores) == 1:
1514
1157
  return param_stores[0]
1515
1158
 
1516
- new_rotation_mapping, base_max = {}, None
1517
- scores_out_dtype, rotations_out_dtype = None, None
1159
+ # Determine output array shape and create consistent rotation map
1160
+ new_rotation_mapping, out_shape = {}, None
1518
1161
  for i in range(len(param_stores)):
1519
1162
  if param_stores[i] is None:
1520
1163
  continue
1521
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1522
- if base_max is None:
1523
- base_max = np.zeros(score_space.ndim, int)
1524
- scores_out_dtype = score_space.dtype
1525
- rotations_out_dtype = rotations.dtype
1526
- np.maximum(base_max, np.add(offset, score_space.shape), out=base_max)
1164
+
1165
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1166
+ if out_shape is None:
1167
+ out_shape = np.zeros(scores.ndim, int)
1168
+ scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
1169
+ out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
1527
1170
 
1528
1171
  for key, value in rotation_mapping.items():
1529
1172
  if key not in new_rotation_mapping:
1530
1173
  new_rotation_mapping[key] = len(new_rotation_mapping)
1531
1174
 
1532
- if base_max is None:
1175
+ if out_shape is None:
1533
1176
  return None
1534
1177
 
1535
- base_max = tuple(int(x) for x in base_max)
1178
+ out_shape = tuple(int(x) for x in out_shape)
1536
1179
  use_memmap = kwargs.get("use_memmap", False)
1537
1180
  if use_memmap:
1538
1181
  scores_out_filename = generate_tempfile_name()
1539
1182
  rotations_out_filename = generate_tempfile_name()
1540
1183
 
1541
1184
  scores_out = np.memmap(
1542
- scores_out_filename, mode="w+", shape=base_max, dtype=scores_out_dtype
1185
+ scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
1543
1186
  )
1544
1187
  scores_out.fill(kwargs.get("score_threshold", 0))
1545
1188
  scores_out.flush()
1546
1189
  rotations_out = np.memmap(
1547
1190
  rotations_out_filename,
1548
1191
  mode="w+",
1549
- shape=base_max,
1550
- dtype=rotations_out_dtype,
1192
+ shape=out_shape,
1193
+ dtype=rotations_dtype,
1551
1194
  )
1552
1195
  rotations_out.fill(-1)
1553
1196
  rotations_out.flush()
1554
1197
  else:
1555
1198
  scores_out = np.full(
1556
- base_max,
1199
+ out_shape,
1557
1200
  fill_value=kwargs.get("score_threshold", 0),
1558
- dtype=scores_out_dtype,
1201
+ dtype=scores_dtype,
1559
1202
  )
1560
- rotations_out = np.full(base_max, fill_value=-1, dtype=rotations_out_dtype)
1203
+ rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
1561
1204
 
1562
1205
  for i in range(len(param_stores)):
1563
1206
  if param_stores[i] is None:
@@ -1567,21 +1210,21 @@ class MaxScoreOverRotations:
1567
1210
  scores_out = np.memmap(
1568
1211
  scores_out_filename,
1569
1212
  mode="r+",
1570
- shape=base_max,
1571
- dtype=scores_out_dtype,
1213
+ shape=out_shape,
1214
+ dtype=scores_dtype,
1572
1215
  )
1573
1216
  rotations_out = np.memmap(
1574
1217
  rotations_out_filename,
1575
1218
  mode="r+",
1576
- shape=base_max,
1577
- dtype=rotations_out_dtype,
1219
+ shape=out_shape,
1220
+ dtype=rotations_dtype,
1578
1221
  )
1579
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1580
- stops = np.add(offset, score_space.shape).astype(int)
1222
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1223
+ stops = np.add(offset, scores.shape).astype(int)
1581
1224
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
1582
1225
 
1583
- indices_update = score_space > scores_out[indices]
1584
- scores_out[indices][indices_update] = score_space[indices_update]
1226
+ indices_update = scores > scores_out[indices]
1227
+ scores_out[indices][indices_update] = scores[indices_update]
1585
1228
 
1586
1229
  lookup_table = np.arange(
1587
1230
  len(rotation_mapping) + 1, dtype=rotations_out.dtype
@@ -1594,24 +1237,24 @@ class MaxScoreOverRotations:
1594
1237
  rotations_out[indices][indices_update] = lookup_table[updated_rotations]
1595
1238
 
1596
1239
  if use_memmap:
1597
- score_space._mmap.close()
1240
+ scores._mmap.close()
1598
1241
  rotations._mmap.close()
1599
1242
  scores_out.flush()
1600
1243
  rotations_out.flush()
1601
1244
  scores_out, rotations_out = None, None
1602
1245
 
1603
1246
  param_stores[i] = None
1604
- score_space, rotations = None, None
1247
+ scores, rotations = None, None
1605
1248
 
1606
1249
  if use_memmap:
1607
1250
  scores_out = np.memmap(
1608
- scores_out_filename, mode="r", shape=base_max, dtype=scores_out_dtype
1251
+ scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
1609
1252
  )
1610
1253
  rotations_out = np.memmap(
1611
1254
  rotations_out_filename,
1612
1255
  mode="r",
1613
- shape=base_max,
1614
- dtype=rotations_out_dtype,
1256
+ shape=out_shape,
1257
+ dtype=rotations_dtype,
1615
1258
  )
1616
1259
  return (
1617
1260
  scores_out,
@@ -1620,6 +1263,10 @@ class MaxScoreOverRotations:
1620
1263
  new_rotation_mapping,
1621
1264
  )
1622
1265
 
1266
+ @property
1267
+ def shared(self):
1268
+ return True
1269
+
1623
1270
 
1624
1271
  class _MaxScoreOverTranslations(MaxScoreOverRotations):
1625
1272
  """
@@ -1627,11 +1274,11 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1627
1274
 
1628
1275
  Attributes
1629
1276
  ----------
1630
- score_space : NDArray
1277
+ scores : BackendArray
1631
1278
  The score space for the observed rotations.
1632
- rotations : NDArray
1279
+ rotations : BackendArray
1633
1280
  The rotation identifiers for each score.
1634
- translation_offset : NDArray, optional
1281
+ translation_offset : BackendArray, optional
1635
1282
  The offset applied during translation.
1636
1283
  observed_rotations : int
1637
1284
  Count of observed rotations.
@@ -1642,36 +1289,36 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1642
1289
  """
1643
1290
 
1644
1291
  def __call__(
1645
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1292
+ self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs
1646
1293
  ) -> None:
1647
1294
  """
1648
- Update internal parameter store based on `score_space`.
1295
+ Update internal parameter store based on `scores`.
1649
1296
 
1650
1297
  Parameters
1651
1298
  ----------
1652
- score_space : ndarray
1299
+ scores : BackendArray
1653
1300
  Numpy array containing the score space.
1654
- rotation_matrix : ndarray
1301
+ rotation_matrix : BackendArray
1655
1302
  Square matrix describing the current rotation.
1656
1303
  **kwargs
1657
- Arbitrary keyword arguments.
1304
+ Optional keyword arguments.
1658
1305
  """
1659
1306
  from tme.matching_utils import centered_mask
1660
1307
 
1661
1308
  with self.lock:
1662
- rotation = backend.tobytes(rotation_matrix)
1309
+ rotation = be.tobytes(rotation_matrix)
1663
1310
  if rotation not in self.observed_rotations:
1664
1311
  self.observed_rotations[rotation] = len(self.observed_rotations)
1665
- score_space = centered_mask(score_space, kwargs["template_shape"])
1312
+ scores = centered_mask(scores, kwargs["template_shape"])
1666
1313
  rotation_index = self.observed_rotations[rotation]
1667
- internal_scores = backend.sharedarr_to_arr(
1668
- shape=self.score_space_shape,
1669
- dtype=self.score_space_dtype,
1670
- shm=self.score_space,
1314
+ internal_scores = be.from_sharedarr(
1315
+ shape=self.shape,
1316
+ dtype=self.scores_dtype,
1317
+ shm=self.scores,
1671
1318
  )
1672
- max_score = score_space.max(axis=(1, 2, 3))
1673
- mean_score = score_space.mean(axis=(1, 2, 3))
1674
- std_score = score_space.std(axis=(1, 2, 3))
1319
+ max_score = scores.max(axis=(1, 2, 3))
1320
+ mean_score = scores.mean(axis=(1, 2, 3))
1321
+ std_score = scores.std(axis=(1, 2, 3))
1675
1322
  z_score = (max_score - mean_score) / std_score
1676
1323
  internal_scores[rotation_index] = z_score
1677
1324
 
@@ -1695,7 +1342,7 @@ class MemmapHandler:
1695
1342
  indices : tuple of slice, optional
1696
1343
  Slices specifying which parts of the memmap array will be updated by `__call__`.
1697
1344
  **kwargs
1698
- Arbitrary keyword arguments.
1345
+ Optional keyword arguments.
1699
1346
  """
1700
1347
 
1701
1348
  def __init__(
@@ -1718,15 +1365,13 @@ class MemmapHandler:
1718
1365
  self.dtype = dtype
1719
1366
  self._indices = indices
1720
1367
 
1721
- def __call__(
1722
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1723
- ) -> None:
1368
+ def __call__(self, scores: NDArray, rotation_matrix: NDArray) -> None:
1724
1369
  """
1725
- Write `score_space` to memmap object on disk.
1370
+ Write `scores` to memmap object on disk.
1726
1371
 
1727
1372
  Parameters
1728
1373
  ----------
1729
- score_space : ndarray
1374
+ scores : ndarray
1730
1375
  Numpy array containing the score space.
1731
1376
  rotation_matrix : ndarray
1732
1377
  Square matrix describing the current rotation.
@@ -1738,7 +1383,7 @@ class MemmapHandler:
1738
1383
  array = np.memmap(current_object, mode="r+", shape=self.shape, dtype=self.dtype)
1739
1384
  # Does not really need a lock because processes operate on different rotations
1740
1385
  with self.lock:
1741
- array[self._indices] += score_space
1386
+ array[self._indices] += scores
1742
1387
  array.flush()
1743
1388
 
1744
1389
  def __iter__(self):