pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/analyzer.py CHANGED
@@ -4,59 +4,58 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- from time import sleep
8
- from typing import Tuple, List, Dict
9
- from abc import ABC, abstractmethod
10
7
  from contextlib import nullcontext
11
- from multiprocessing import RawValue, Manager, Lock
8
+ from abc import ABC, abstractmethod
9
+ from multiprocessing import Manager, Lock
10
+ from typing import Tuple, List, Dict, Generator
12
11
 
13
12
  import numpy as np
14
- from numpy.typing import NDArray
15
- from scipy.stats import entropy
16
13
  from sklearn.cluster import DBSCAN
17
14
  from skimage.feature import peak_local_max
18
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
19
- from .extensions import max_index_by_label, online_statistics, find_candidate_indices
16
+
17
+ from .backends import backend as be
18
+ from .types import BackendArray, NDArray
19
+ from .extensions import max_index_by_label, find_candidate_indices
20
20
  from .matching_utils import (
21
- split_numpy_array_slices,
21
+ split_shape,
22
22
  array_to_memmap,
23
23
  generate_tempfile_name,
24
24
  euler_to_rotationmatrix,
25
25
  apply_convolution_mode,
26
26
  )
27
- from .backends import backend
27
+
28
+ PeakType = Tuple[BackendArray, BackendArray]
28
29
 
29
30
 
30
- def filter_points_indices_bucket(
31
- coordinates: NDArray, min_distance: Tuple[int]
32
- ) -> NDArray:
33
- coordinates = backend.subtract(coordinates, backend.min(coordinates, axis=0))
34
- bucket_indices = backend.astype(backend.divide(coordinates, min_distance), int)
35
- multiplier = backend.power(
36
- backend.max(bucket_indices, axis=0) + 1, backend.arange(bucket_indices.shape[1])
31
+ def _filter_bucket(coordinates: BackendArray, min_distance: Tuple[int]) -> BackendArray:
32
+ coordinates = be.subtract(coordinates, be.min(coordinates, axis=0))
33
+ bucket_indices = be.astype(be.divide(coordinates, min_distance), int)
34
+ multiplier = be.power(
35
+ be.max(bucket_indices, axis=0) + 1, be.arange(bucket_indices.shape[1])
37
36
  )
38
- backend.multiply(bucket_indices, multiplier, out=bucket_indices)
39
- flattened_indices = backend.sum(bucket_indices, axis=1)
40
- _, unique_indices = backend.unique(flattened_indices, return_index=True)
41
- unique_indices = unique_indices[backend.argsort(unique_indices)]
37
+ be.multiply(bucket_indices, multiplier, out=bucket_indices)
38
+ flattened_indices = be.sum(bucket_indices, axis=1)
39
+ _, unique_indices = be.unique(flattened_indices, return_index=True)
40
+ unique_indices = unique_indices[be.argsort(unique_indices)]
42
41
  return unique_indices
43
42
 
44
43
 
45
44
  def filter_points_indices(
46
- coordinates: NDArray,
45
+ coordinates: BackendArray,
47
46
  min_distance: float,
48
47
  bucket_cutoff: int = 1e4,
49
48
  batch_dims: Tuple[int] = None,
50
- ) -> NDArray:
49
+ ) -> BackendArray:
51
50
  if min_distance <= 0:
52
- return backend.arange(coordinates.shape[0])
51
+ return be.arange(coordinates.shape[0])
53
52
  if coordinates.shape[0] == 0:
54
53
  return ()
55
54
 
56
55
  if batch_dims is not None:
57
- coordinates_new = backend.zeros(coordinates.shape, coordinates.dtype)
56
+ coordinates_new = be.zeros(coordinates.shape, coordinates.dtype)
58
57
  coordinates_new[:] = coordinates
59
- coordinates_new[..., batch_dims] = backend.astype(
58
+ coordinates_new[..., batch_dims] = be.astype(
60
59
  coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
61
60
  )
62
61
  coordinates = coordinates_new
@@ -66,7 +65,7 @@ def filter_points_indices(
66
65
  elif coordinates.shape[0] > bucket_cutoff or not isinstance(
67
66
  coordinates, np.ndarray
68
67
  ):
69
- return filter_points_indices_bucket(coordinates, min_distance)
68
+ return _filter_bucket(coordinates, min_distance)
70
69
  distances = np.linalg.norm(coordinates[:, None] - coordinates, axis=-1)
71
70
  distances = np.tril(distances)
72
71
  keep = np.sum(distances > min_distance, axis=1)
@@ -76,7 +75,7 @@ def filter_points_indices(
76
75
 
77
76
  def filter_points(
78
77
  coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
79
- ) -> NDArray:
78
+ ) -> BackendArray:
80
79
  unique_indices = filter_points_indices(coordinates, min_distance, batch_dims)
81
80
  coordinates = coordinates[unique_indices]
82
81
  return coordinates
@@ -96,8 +95,13 @@ class PeakCaller(ABC):
96
95
  Minimum distance to array boundaries.
97
96
  batch_dims : int, optional
98
97
  Peak calling batch dimensions.
98
+ minimum_score : float
99
+ Minimum score from which to consider peaks. If provided, superseeds limits
100
+ presented by :py:attr:`PeakCaller.number_of_peaks`.
101
+ maximum_score : float
102
+ Maximum score upon which to consider peaks,
99
103
  **kwargs
100
- Additional keyword arguments.
104
+ Optional keyword arguments.
101
105
 
102
106
  Raises
103
107
  ------
@@ -112,12 +116,10 @@ class PeakCaller(ABC):
112
116
  min_distance: int = 1,
113
117
  min_boundary_distance: int = 0,
114
118
  batch_dims: Tuple[int] = None,
119
+ minimum_score: float = None,
120
+ maximum_score: float = None,
115
121
  **kwargs,
116
122
  ):
117
- number_of_peaks = int(number_of_peaks)
118
- min_distance, min_boundary_distance = int(min_distance), int(
119
- min_boundary_distance
120
- )
121
123
  if number_of_peaks <= 0:
122
124
  raise ValueError(
123
125
  f"number_of_peaks has to be larger than 0, got {number_of_peaks}"
@@ -130,26 +132,28 @@ class PeakCaller(ABC):
130
132
  )
131
133
 
132
134
  self.peak_list = []
133
- self.min_distance = min_distance
134
- self.min_boundary_distance = min_boundary_distance
135
- self.number_of_peaks = number_of_peaks
135
+ self.min_distance = int(min_distance)
136
+ self.number_of_peaks = int(number_of_peaks)
137
+ self.min_boundary_distance = int(min_boundary_distance)
136
138
 
137
139
  self.batch_dims = batch_dims
138
140
  if batch_dims is not None:
139
141
  self.batch_dims = tuple(int(x) for x in self.batch_dims)
140
142
 
143
+ self.minimum_score, self.maximum_score = minimum_score, maximum_score
144
+
141
145
  # Postprocesing arguments
142
146
  self.fourier_shift = kwargs.get("fourier_shift", None)
143
147
  self.convolution_mode = kwargs.get("convolution_mode", None)
144
148
  self.targetshape = kwargs.get("targetshape", None)
145
149
  self.templateshape = kwargs.get("templateshape", None)
146
150
 
147
- def __iter__(self):
151
+ def __iter__(self) -> Generator:
148
152
  """
149
153
  Returns a generator to list objects containing translation,
150
154
  rotation, score and details of a given candidate.
151
155
  """
152
- self.peak_list = [backend.to_cpu_array(arr) for arr in self.peak_list]
156
+ self.peak_list = [be.to_cpu_array(arr) for arr in self.peak_list]
153
157
  yield from self.peak_list
154
158
 
155
159
  @staticmethod
@@ -181,34 +185,24 @@ class PeakCaller(ABC):
181
185
 
182
186
  yield from _generate_slices_recursive(0, ())
183
187
 
184
- def __call__(
185
- self,
186
- score_space: NDArray,
187
- rotation_matrix: NDArray,
188
- minimum_score: float = None,
189
- maximum_score: float = None,
190
- **kwargs,
191
- ) -> None:
188
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
192
189
  """
193
190
  Update the internal parameter store based on input array.
194
191
 
195
192
  Parameters
196
193
  ----------
197
- score_space : NDArray
198
- Array containing the score space.
199
- rotation_matrix : NDArray
194
+ scores : BackendArray
195
+ Score space data.
196
+ rotation_matrix : BackendArray
200
197
  Rotation matrix used to obtain the score array.
201
- minimum_score : float
202
- Minimum score from which to consider peaks. If provided, superseeds limits
203
- presented by :py:attr:`PeakCaller.number_of_peaks`.
204
- maximum_score : float
205
- Maximum score upon which to consider peaks,
206
198
  **kwargs
207
- Optional keyword arguments passed to :py:meth:`PeakCaller.call_peak`.
199
+ Optional keyword aguments passed to :py:meth:`PeakCaller.call_peaks`.
208
200
  """
209
- for subset, offset in self._batchify(score_space.shape, self.batch_dims):
201
+ minimum_score, maximum_score = self.minimum_score, self.maximum_score
202
+ for subset, batch_offset in self._batchify(scores.shape, self.batch_dims):
203
+ batch_offset = be.to_backend_array(batch_offset)
210
204
  peak_positions, peak_details = self.call_peaks(
211
- score_space=score_space[subset],
205
+ scores=scores[subset],
212
206
  rotation_matrix=rotation_matrix,
213
207
  minimum_score=minimum_score,
214
208
  maximum_score=maximum_score,
@@ -221,34 +215,30 @@ class PeakCaller(ABC):
221
215
  continue
222
216
 
223
217
  if peak_details is None:
224
- peak_details = backend.full((peak_positions.shape[0],), fill_value=-1)
218
+ peak_details = be.full((peak_positions.shape[0],), fill_value=-1)
225
219
 
226
- backend.add(peak_positions, offset, out=peak_positions)
227
- peak_positions = backend.astype(peak_positions, int)
220
+ peak_positions = be.to_backend_array(peak_positions)
221
+ peak_positions = be.add(peak_positions, batch_offset, out=peak_positions)
222
+ peak_positions = be.astype(peak_positions, int)
228
223
  if self.min_boundary_distance > 0:
229
- upper_limit = backend.subtract(
230
- score_space.shape, self.min_boundary_distance
224
+ upper_limit = be.subtract(
225
+ be.to_backend_array(scores.shape), self.min_boundary_distance
231
226
  )
232
- valid_peaks = backend.multiply(
227
+ valid_peaks = be.multiply(
233
228
  peak_positions < upper_limit,
234
229
  peak_positions >= self.min_boundary_distance,
235
230
  )
236
231
  if self.batch_dims is not None:
237
232
  valid_peaks[..., self.batch_dims] = True
238
233
 
239
- valid_peaks = (
240
- backend.sum(valid_peaks, axis=1) == peak_positions.shape[1]
241
- )
234
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
242
235
 
243
- if backend.sum(valid_peaks) == 0:
236
+ if be.sum(valid_peaks) == 0:
244
237
  continue
238
+ peak_positions = peak_positions[valid_peaks]
239
+ peak_details = peak_details[valid_peaks]
245
240
 
246
- peak_positions, peak_details = (
247
- peak_positions[valid_peaks],
248
- peak_details[valid_peaks],
249
- )
250
-
251
- peak_scores = score_space[tuple(peak_positions.T)]
241
+ peak_scores = scores[tuple(peak_positions.T)]
252
242
  if minimum_score is not None:
253
243
  valid_peaks = peak_scores >= minimum_score
254
244
  peak_positions, peak_details, peak_scores = (
@@ -267,7 +257,7 @@ class PeakCaller(ABC):
267
257
  if peak_positions.shape[0] == 0:
268
258
  continue
269
259
 
270
- rotations = backend.repeat(
260
+ rotations = be.repeat(
271
261
  rotation_matrix.reshape(1, *rotation_matrix.shape),
272
262
  peak_positions.shape[0],
273
263
  axis=0,
@@ -278,83 +268,71 @@ class PeakCaller(ABC):
278
268
  peak_details=peak_details,
279
269
  peak_scores=peak_scores,
280
270
  rotations=rotations,
281
- batch_offset=offset,
282
- **kwargs,
283
271
  )
284
272
 
285
273
  return None
286
274
 
287
275
  @abstractmethod
288
- def call_peaks(
289
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
290
- ) -> Tuple[NDArray, NDArray]:
276
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
291
277
  """
292
278
  Call peaks in the score space.
293
279
 
294
- This function is not intended to be called directly, but should rather be
295
- defined by classes inheriting from :py:class:`PeakCaller` to execute a given
296
- peak calling algorithm.
297
-
298
280
  Parameters
299
281
  ----------
300
- score_space : NDArray
301
- Data array of scores.
302
- **kwargs : Dict, optional
303
- Keyword arguments passed to __call__.
282
+ scores : BackendArray
283
+ Score array.
284
+ **kwargs : dict
285
+ Optional keyword arguments passed to underlying implementations.
304
286
 
305
287
  Returns
306
288
  -------
307
- Tuple[NDArray, NDArray]
289
+ Tuple[BackendArray, BackendArray]
308
290
  Array of peak coordinates and peak details.
309
291
  """
310
292
 
311
293
  @classmethod
312
- def merge(cls, candidates=List[List], **kwargs) -> NDArray:
294
+ def merge(cls, candidates=List[List], **kwargs) -> Tuple:
313
295
  """
314
296
  Merge multiple instances of :py:class:`PeakCaller`.
315
297
 
316
298
  Parameters
317
299
  ----------
318
- candidate_fits : list of lists
300
+ candidates : list of lists
319
301
  Obtained by invoking list on the generator returned by __iter__.
320
- param_stores : list of tuples, optional
321
- List of parameter stores. Each tuple contains candidate data and number
322
- of candidates.
323
302
  **kwargs
324
- Additional keyword arguments.
303
+ Optional keyword arguments.
325
304
 
326
305
  Returns
327
306
  -------
328
- NDArray
329
- NDArray of candidates.
307
+ Tuple
308
+ Tuple of translation, rotation, score and details of candidates.
330
309
  """
331
310
  base = cls(**kwargs)
332
311
  for candidate in candidates:
333
312
  if len(candidate) == 0:
334
313
  continue
335
314
  peak_positions, rotations, peak_scores, peak_details = candidate
336
- kwargs["translation_offset"] = backend.zeros(peak_positions.shape[1])
337
315
  base._update(
338
- peak_positions=backend.to_backend_array(peak_positions),
339
- peak_details=backend.to_backend_array(peak_details),
340
- peak_scores=backend.to_backend_array(peak_scores),
341
- rotations=backend.to_backend_array(rotations),
342
- **kwargs,
316
+ peak_positions=be.to_backend_array(peak_positions),
317
+ peak_details=be.to_backend_array(peak_details),
318
+ peak_scores=be.to_backend_array(peak_scores),
319
+ rotations=be.to_backend_array(rotations),
320
+ offset=kwargs.get("offset", None),
343
321
  )
344
322
  return tuple(base)
345
323
 
346
324
  @staticmethod
347
325
  def oversample_peaks(
348
- score_space: NDArray, peak_positions: NDArray, oversampling_factor: int = 8
326
+ scores: BackendArray, peak_positions: BackendArray, oversampling_factor: int = 8
349
327
  ):
350
328
  """
351
329
  Refines peaks positions in the corresponding score space.
352
330
 
353
331
  Parameters
354
332
  ----------
355
- score_space : NDArray
333
+ scores : BackendArray
356
334
  The d-dimensional array representing the score space.
357
- peak_positions : NDArray
335
+ peak_positions : BackendArray
358
336
  An array of shape (n, d) containing the peak coordinates
359
337
  to be refined, where n is the number of peaks and d is the
360
338
  dimensionality of the score space.
@@ -363,14 +341,14 @@ class PeakCaller(ABC):
363
341
 
364
342
  Returns
365
343
  -------
366
- NDArray
344
+ BackendArray
367
345
  An array of shape (n, d) containing the refined subpixel
368
346
  coordinates of the peaks.
369
347
 
370
348
  Notes
371
349
  -----
372
350
  Floating point peak positions are determined by oversampling the
373
- score_space around peak_positions. The accuracy
351
+ scores around peak_positions. The accuracy
374
352
  of refinement scales with 1 / oversampling_factor.
375
353
 
376
354
  References
@@ -382,8 +360,8 @@ class PeakCaller(ABC):
382
360
  DOI:10.1364/OL.33.000156
383
361
 
384
362
  """
385
- score_space = backend.to_numpy_array(score_space)
386
- peak_positions = backend.to_numpy_array(peak_positions)
363
+ scores = be.to_numpy_array(scores)
364
+ peak_positions = be.to_numpy_array(peak_positions)
387
365
 
388
366
  peak_positions = np.round(
389
367
  np.divide(
@@ -396,10 +374,10 @@ class PeakCaller(ABC):
396
374
  dftshift, np.multiply(peak_positions, oversampling_factor)
397
375
  )
398
376
 
399
- score_space_ft = np.fft.fftn(score_space).conj()
377
+ scores_ft = np.fft.fftn(scores).conj()
400
378
  for index in range(sample_region_offset.shape[0]):
401
379
  cross_correlation_upsampled = _upsampled_dft(
402
- data=score_space_ft,
380
+ data=scores_ft,
403
381
  upsampled_region_size=upsampled_region_size,
404
382
  upsample_factor=oversampling_factor,
405
383
  axis_offsets=sample_region_offset[index],
@@ -412,73 +390,66 @@ class PeakCaller(ABC):
412
390
  maxima = np.divide(np.subtract(maxima, dftshift), oversampling_factor)
413
391
  peak_positions[index] = np.add(peak_positions[index], maxima)
414
392
 
415
- peak_positions = backend.to_backend_array(peak_positions)
393
+ peak_positions = be.to_backend_array(peak_positions)
416
394
 
417
395
  return peak_positions
418
396
 
419
397
  def _update(
420
398
  self,
421
- peak_positions: NDArray,
422
- peak_details: NDArray,
423
- peak_scores: NDArray,
424
- rotations: NDArray,
425
- **kwargs,
426
- ) -> None:
399
+ peak_positions: BackendArray,
400
+ peak_details: BackendArray,
401
+ peak_scores: BackendArray,
402
+ rotations: BackendArray,
403
+ offset: BackendArray = None,
404
+ ):
427
405
  """
428
406
  Update internal parameter store.
429
407
 
430
408
  Parameters
431
409
  ----------
432
- peak_positions : NDArray
433
- Position of peaks with shape n x d where n is the number of
434
- peaks and d the dimension.
435
- peak_scores : NDArray
436
- Corresponding score obtained at each peak.
437
- translation_offset : NDArray, optional
438
- Offset of the score_space, occurs e.g. when template matching
439
- to parts of a tomogram.
440
- rotations: NDArray
441
- Rotations used to obtain the score space from which
442
- the candidate stem.
410
+ peak_positions : BackendArray
411
+ Position of peaks (n, d).
412
+ peak_details : BackendArray
413
+ Details of each peak (n, ).
414
+ rotations: BackendArray
415
+ Rotation at each peak (n, d, d).
416
+ rotations: BackendArray
417
+ Rotation at each peak (n, d, d).
418
+ offset : BackendArray, optional
419
+ Translation offset, e.g. from splitting, (n, ).
443
420
  """
444
- translation_offset = kwargs.get(
445
- "translation_offset", backend.zeros(peak_positions.shape[1])
446
- )
447
- translation_offset = backend.astype(translation_offset, peak_positions.dtype)
421
+ if offset is not None:
422
+ offset = be.astype(offset, peak_positions.dtype)
423
+ peak_positions = be.add(peak_positions, offset, out=peak_positions)
448
424
 
449
- backend.add(peak_positions, translation_offset, out=peak_positions)
450
425
  if not len(self.peak_list):
451
426
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
452
-
453
- peak_positions = backend.concatenate((self.peak_list[0], peak_positions))
454
- rotations = backend.concatenate((self.peak_list[1], rotations))
455
- peak_scores = backend.concatenate((self.peak_list[2], peak_scores))
456
- peak_details = backend.concatenate((self.peak_list[3], peak_details))
427
+ else:
428
+ peak_positions = be.concatenate((self.peak_list[0], peak_positions))
429
+ rotations = be.concatenate((self.peak_list[1], rotations))
430
+ peak_scores = be.concatenate((self.peak_list[2], peak_scores))
431
+ peak_details = be.concatenate((self.peak_list[3], peak_details))
457
432
 
458
433
  if self.batch_dims is None:
459
- top_n = min(backend.size(peak_scores), self.number_of_peaks)
460
- top_scores, *_ = backend.topk_indices(peak_scores, top_n)
434
+ top_n = min(be.size(peak_scores), self.number_of_peaks)
435
+ top_scores, *_ = be.topk_indices(peak_scores, top_n)
461
436
  else:
462
437
  # Not very performant but fairly robust
463
438
  batch_indices = peak_positions[..., self.batch_dims]
464
- backend.subtract(
465
- batch_indices, backend.min(batch_indices, axis=0), out=batch_indices
466
- )
467
- multiplier = backend.power(
468
- backend.max(batch_indices, axis=0) + 1,
469
- backend.arange(batch_indices.shape[1]),
439
+ batch_indices = be.subtract(batch_indices, be.min(batch_indices, axis=0))
440
+ multiplier = be.power(
441
+ be.max(batch_indices, axis=0) + 1,
442
+ be.arange(batch_indices.shape[1]),
470
443
  )
471
- backend.multiply(batch_indices, multiplier, out=batch_indices)
472
- batch_indices = backend.sum(batch_indices, axis=1)
473
- unique_indices, batch_counts = backend.unique(
474
- batch_indices, return_counts=True
475
- )
476
- total_indices = backend.arange(peak_scores.shape[0])
444
+ batch_indices = be.multiply(batch_indices, multiplier, out=batch_indices)
445
+ batch_indices = be.sum(batch_indices, axis=1)
446
+ unique_indices, batch_counts = be.unique(batch_indices, return_counts=True)
447
+ total_indices = be.arange(peak_scores.shape[0])
477
448
  batch_indices = [total_indices[batch_indices == x] for x in unique_indices]
478
- top_scores = backend.concatenate(
449
+ top_scores = be.concatenate(
479
450
  [
480
451
  total_indices[indices][
481
- backend.topk_indices(
452
+ be.topk_indices(
482
453
  peak_scores[indices], min(y, self.number_of_peaks)
483
454
  )
484
455
  ]
@@ -500,7 +471,14 @@ class PeakCaller(ABC):
500
471
  self.peak_list[3] = peak_details[final_order]
501
472
 
502
473
  def _postprocess(
503
- self, fourier_shift, convolution_mode, targetshape, templateshape, **kwargs
474
+ self,
475
+ fast_shape: Tuple[int],
476
+ targetshape: Tuple[int],
477
+ templateshape: Tuple[int],
478
+ fourier_shift: Tuple[int] = None,
479
+ convolution_mode: str = None,
480
+ shared_memory_handler=None,
481
+ **kwargs,
504
482
  ):
505
483
  if not len(self.peak_list):
506
484
  return self
@@ -509,52 +487,43 @@ class PeakCaller(ABC):
509
487
  if not len(peak_positions):
510
488
  return self
511
489
 
512
- if targetshape is None or templateshape is None:
513
- return self
514
-
515
- # Remove padding to next fast fourier length
516
- score_space_shape = backend.add(targetshape, templateshape) - 1
517
-
490
+ # Wrap peaks around score space
491
+ fast_shape = be.to_backend_array(fast_shape)
518
492
  if fourier_shift is not None:
519
- peak_positions = backend.add(peak_positions, fourier_shift)
520
-
521
- backend.subtract(
493
+ fourier_shift = be.to_backend_array(fourier_shift)
494
+ peak_positions = be.add(peak_positions, fourier_shift)
495
+ peak_positions = be.subtract(
522
496
  peak_positions,
523
- backend.multiply(
524
- backend.astype(
525
- backend.divide(peak_positions, score_space_shape), int
526
- ),
527
- score_space_shape,
497
+ be.multiply(
498
+ be.astype(be.divide(peak_positions, fast_shape), int),
499
+ fast_shape,
528
500
  ),
529
- out=peak_positions,
530
501
  )
531
502
 
532
- if convolution_mode is None:
533
- return None
534
-
535
- if convolution_mode == "full":
536
- output_shape = score_space_shape
537
- elif convolution_mode == "same":
503
+ # Remove padding to fast Fourier (and potential full convolution) shape
504
+ targetshape = be.to_backend_array(targetshape)
505
+ templateshape = be.to_backend_array(templateshape)
506
+ fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
507
+ output_shape = fast_shape
508
+ if convolution_mode == "same":
538
509
  output_shape = targetshape
539
510
  elif convolution_mode == "valid":
540
- output_shape = backend.add(
541
- backend.subtract(targetshape, templateshape),
542
- backend.mod(templateshape, 2),
511
+ output_shape = be.add(
512
+ be.subtract(targetshape, templateshape),
513
+ be.mod(templateshape, 2),
543
514
  )
544
515
 
545
- output_shape = backend.to_backend_array(output_shape)
546
- starts = backend.divide(backend.subtract(score_space_shape, output_shape), 2)
547
- starts = backend.astype(starts, int)
548
- stops = backend.add(starts, output_shape)
549
-
550
- valid_peaks = (
551
- backend.sum(
552
- backend.multiply(peak_positions > starts, peak_positions <= stops),
553
- axis=1,
554
- )
555
- == peak_positions.shape[1]
516
+ output_shape = be.to_backend_array(output_shape)
517
+ starts = be.astype(
518
+ be.divide(be.subtract(fast_shape, output_shape), 2),
519
+ be._int_dtype,
556
520
  )
557
- self.peak_list[0] = backend.subtract(peak_positions, starts)
521
+ stops = be.add(starts, output_shape)
522
+
523
+ valid_peaks = be.multiply(peak_positions > starts, peak_positions <= stops)
524
+ valid_peaks = be.sum(valid_peaks, axis=1) == peak_positions.shape[1]
525
+
526
+ self.peak_list[0] = be.subtract(peak_positions, starts)
558
527
  self.peak_list = [x[valid_peaks] for x in self.peak_list]
559
528
  return self
560
529
 
@@ -565,27 +534,14 @@ class PeakCallerSort(PeakCaller):
565
534
  highest scores.
566
535
  """
567
536
 
568
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
569
- """
570
- Call peaks in the score space.
571
-
572
- Parameters
573
- ----------
574
- score_space : NDArray
575
- Data array of scores.
576
-
577
- Returns
578
- -------
579
- Tuple[NDArray, NDArray]
580
- Array of peak coordinates and peak details.
581
- """
582
- flat_score_space = score_space.reshape(-1)
583
- k = min(self.number_of_peaks, backend.size(flat_score_space))
537
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
538
+ flat_scores = scores.reshape(-1)
539
+ k = min(self.number_of_peaks, be.size(flat_scores))
584
540
 
585
- top_k_indices, *_ = backend.topk_indices(flat_score_space, k)
541
+ top_k_indices, *_ = be.topk_indices(flat_scores, k)
586
542
 
587
- coordinates = backend.unravel_index(top_k_indices, score_space.shape)
588
- coordinates = backend.transpose(backend.stack(coordinates))
543
+ coordinates = be.unravel_index(top_k_indices, scores.shape)
544
+ coordinates = be.transpose(be.stack(coordinates))
589
545
 
590
546
  return coordinates, None
591
547
 
@@ -594,28 +550,11 @@ class PeakCallerMaximumFilter(PeakCaller):
594
550
  """
595
551
  Find local maxima by applying a maximum filter and enforcing a distance
596
552
  constraint subsequently. This is similar to the strategy implemented in
597
- skimage.feature.peak_local_max.
553
+ :obj:`skimage.feature.peak_local_max`.
598
554
  """
599
555
 
600
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
601
- """
602
- Call peaks in the score space.
603
-
604
- Parameters
605
- ----------
606
- score_space : NDArray
607
- Data array of scores.
608
- kwargs: Dict, optional
609
- Optional keyword arguments.
610
-
611
- Returns
612
- -------
613
- Tuple[NDArray, NDArray]
614
- Array of peak coordinates and peak details.
615
- """
616
- peaks = backend.max_filter_coordinates(score_space, self.min_distance)
617
-
618
- return peaks, None
556
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
557
+ return be.max_filter_coordinates(scores, self.min_distance), None
619
558
 
620
559
 
621
560
  class PeakCallerFast(PeakCaller):
@@ -627,56 +566,35 @@ class PeakCallerFast(PeakCaller):
627
566
 
628
567
  """
629
568
 
630
- def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
631
- """
632
- Call peaks in the score space.
633
-
634
- Parameters
635
- ----------
636
- score_space : NDArray
637
- Data array of scores.
569
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
570
+ splits = {i: x // self.min_distance for i, x in enumerate(scores.shape)}
571
+ slices = split_shape(scores.shape, splits)
638
572
 
639
- Returns
640
- -------
641
- Tuple[NDArray, NDArray]
642
- Array of peak coordinates and peak details.
643
- """
644
- splits = {
645
- axis: score_space.shape[axis] // self.min_distance
646
- for axis in range(score_space.ndim)
647
- }
648
- slices = split_numpy_array_slices(score_space.shape, splits)
649
-
650
- coordinates = backend.to_backend_array(
573
+ coordinates = be.to_backend_array(
651
574
  [
652
- backend.unravel_index(
653
- backend.argmax(score_space[subvol]), score_space[subvol].shape
654
- )
575
+ be.unravel_index(be.argmax(scores[subvol]), scores[subvol].shape)
655
576
  for subvol in slices
656
577
  ]
657
578
  )
658
- offset = backend.to_backend_array(
579
+ offset = be.to_backend_array(
659
580
  [tuple(x.start for x in subvol) for subvol in slices]
660
581
  )
661
- backend.add(coordinates, offset, out=coordinates)
662
- coordinates = coordinates[
663
- backend.flip(backend.argsort(score_space[tuple(coordinates.T)]), (0,))
664
- ]
582
+ be.add(coordinates, offset, out=coordinates)
583
+ coordinates = coordinates[be.argsort(-scores[tuple(coordinates.T)])]
665
584
 
666
585
  if coordinates.shape[0] == 0:
667
586
  return None
668
587
 
669
- starts = backend.maximum(coordinates - self.min_distance, 0)
670
- stops = backend.minimum(coordinates + self.min_distance, score_space.shape)
588
+ starts = be.maximum(coordinates - self.min_distance, 0)
589
+ stops = be.minimum(coordinates + self.min_distance, scores.shape)
671
590
  slices_list = [
672
591
  tuple(slice(*coord) for coord in zip(start_row, stop_row))
673
592
  for start_row, stop_row in zip(starts, stops)
674
593
  ]
675
594
 
676
- scores = score_space[tuple(coordinates.T)]
677
595
  keep = [
678
- score >= backend.max(score_space[subvol])
679
- for subvol, score in zip(slices_list, scores)
596
+ score_subvol >= be.max(scores[subvol])
597
+ for subvol, score_subvol in zip(slices_list, scores[tuple(coordinates.T)])
680
598
  ]
681
599
  coordinates = coordinates[keep,]
682
600
 
@@ -694,26 +612,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
694
612
 
695
613
  def call_peaks(
696
614
  self,
697
- score_space: NDArray,
698
- rotation_matrix: NDArray,
699
- mask: NDArray = None,
615
+ scores: BackendArray,
616
+ rotation_matrix: BackendArray,
617
+ mask: BackendArray = None,
700
618
  minimum_score: float = None,
701
- rotation_space: NDArray = None,
619
+ rotation_space: BackendArray = None,
702
620
  rotation_mapping: Dict = None,
703
621
  **kwargs,
704
- ) -> Tuple[NDArray, NDArray]:
622
+ ) -> PeakType:
705
623
  """
706
624
  Call peaks in the score space.
707
625
 
708
626
  Parameters
709
627
  ----------
710
- score_space : NDArray
628
+ scores : BackendArray
711
629
  Data array of scores.
712
- rotation_matrix : NDArray
630
+ rotation_matrix : BackendArray
713
631
  Rotation matrix.
714
- mask : NDArray, optional
632
+ mask : BackendArray, optional
715
633
  Mask array, by default None.
716
- rotation_space : NDArray, optional
634
+ rotation_space : BackendArray, optional
717
635
  Rotation space array, by default None.
718
636
  rotation_mapping : Dict optional
719
637
  Dictionary mapping values in rotation_space to Euler angles.
@@ -724,7 +642,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
724
642
 
725
643
  Returns
726
644
  -------
727
- Tuple[NDArray, NDArray]
645
+ Tuple[BackendArray, BackendArray]
728
646
  Array of peak coordinates and peak details.
729
647
 
730
648
  Notes
@@ -738,26 +656,26 @@ class PeakCallerRecursiveMasking(PeakCaller):
738
656
 
739
657
  if mask is None:
740
658
  masking_function = self._mask_scores_box
741
- shape = tuple(self.min_distance for _ in range(score_space.ndim))
742
- mask = backend.zeros(shape, dtype=backend._float_dtype)
659
+ shape = tuple(self.min_distance for _ in range(scores.ndim))
660
+ mask = be.zeros(shape, dtype=be._float_dtype)
743
661
 
744
- rotated_template = backend.zeros(mask.shape, dtype=mask.dtype)
662
+ rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
745
663
 
746
664
  peak_limit = self.number_of_peaks
747
665
  if minimum_score is not None:
748
- peak_limit = backend.size(score_space)
666
+ peak_limit = be.size(scores)
749
667
  else:
750
- minimum_score = backend.min(score_space) - 1
668
+ minimum_score = be.min(scores) - 1
751
669
 
752
- scores = backend.zeros(score_space.shape, dtype=score_space.dtype)
753
- scores[:] = score_space
670
+ scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
671
+ scores_copy[:] = scores
754
672
 
755
673
  while True:
756
- backend.argmax(scores)
757
- peak = backend.unravel_index(
758
- indices=backend.argmax(scores), shape=scores.shape
674
+ be.argmax(scores_copy)
675
+ peak = be.unravel_index(
676
+ indices=be.argmax(scores_copy), shape=scores_copy.shape
759
677
  )
760
- if scores[tuple(peak)] < minimum_score:
678
+ if scores_copy[tuple(peak)] < minimum_score:
761
679
  break
762
680
 
763
681
  coordinates.append(peak)
@@ -770,7 +688,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
770
688
  )
771
689
 
772
690
  masking_function(
773
- score_space=scores,
691
+ scores=scores_copy,
774
692
  rotation_matrix=current_rotation_matrix,
775
693
  peak=peak,
776
694
  mask=mask,
@@ -780,33 +698,33 @@ class PeakCallerRecursiveMasking(PeakCaller):
780
698
  if len(coordinates) >= peak_limit:
781
699
  break
782
700
 
783
- peaks = backend.to_backend_array(coordinates)
701
+ peaks = be.to_backend_array(coordinates)
784
702
  return peaks, None
785
703
 
786
704
  @staticmethod
787
705
  def _get_rotation_matrix(
788
- peak: NDArray,
789
- rotation_space: NDArray,
790
- rotation_mapping: NDArray,
791
- rotation_matrix: NDArray,
792
- ) -> NDArray:
706
+ peak: BackendArray,
707
+ rotation_space: BackendArray,
708
+ rotation_mapping: BackendArray,
709
+ rotation_matrix: BackendArray,
710
+ ) -> BackendArray:
793
711
  """
794
712
  Get rotation matrix based on peak and rotation data.
795
713
 
796
714
  Parameters
797
715
  ----------
798
- peak : NDArray
716
+ peak : BackendArray
799
717
  Peak coordinates.
800
- rotation_space : NDArray
718
+ rotation_space : BackendArray
801
719
  Rotation space array.
802
720
  rotation_mapping : Dict
803
721
  Dictionary mapping values in rotation_space to Euler angles.
804
- rotation_matrix : NDArray
722
+ rotation_matrix : BackendArray
805
723
  Current rotation matrix.
806
724
 
807
725
  Returns
808
726
  -------
809
- NDArray
727
+ BackendArray
810
728
  Rotation matrix.
811
729
  """
812
730
  if rotation_space is None or rotation_mapping is None:
@@ -814,135 +732,117 @@ class PeakCallerRecursiveMasking(PeakCaller):
814
732
 
815
733
  rotation = rotation_mapping[rotation_space[tuple(peak)]]
816
734
 
817
- rotation_matrix = backend.to_backend_array(
818
- euler_to_rotationmatrix(backend.to_numpy_array(rotation))
819
- )
820
- return rotation_matrix
735
+ # TODO: Newer versions of rotation mapping contain rotation matrices not angles
736
+ if len(rotation) == 3:
737
+ rotation = be.to_backend_array(
738
+ euler_to_rotationmatrix(be.to_numpy_array(rotation))
739
+ )
740
+ return rotation
821
741
 
822
742
  @staticmethod
823
743
  def _mask_scores_box(
824
- score_space: NDArray, peak: NDArray, mask: NDArray, **kwargs: Dict
744
+ scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
825
745
  ) -> None:
826
746
  """
827
747
  Mask scores in a box around a peak.
828
748
 
829
749
  Parameters
830
750
  ----------
831
- score_space : NDArray
751
+ scores : BackendArray
832
752
  Data array of scores.
833
- peak : NDArray
753
+ peak : BackendArray
834
754
  Peak coordinates.
835
- mask : NDArray
755
+ mask : BackendArray
836
756
  Mask array.
837
757
  """
838
- start = backend.maximum(backend.subtract(peak, mask.shape), 0)
839
- stop = backend.minimum(backend.add(peak, mask.shape), score_space.shape)
840
- start, stop = backend.astype(start, int), backend.astype(stop, int)
758
+ start = be.maximum(be.subtract(peak, mask.shape), 0)
759
+ stop = be.minimum(be.add(peak, mask.shape), scores.shape)
760
+ start, stop = be.astype(start, int), be.astype(stop, int)
841
761
  coords = tuple(slice(*pos) for pos in zip(start, stop))
842
- score_space[coords] = 0
762
+ scores[coords] = 0
843
763
  return None
844
764
 
845
765
  @staticmethod
846
766
  def _mask_scores_rotate(
847
- score_space: NDArray,
848
- peak: NDArray,
849
- mask: NDArray,
850
- rotated_template: NDArray,
851
- rotation_matrix: NDArray,
767
+ scores: BackendArray,
768
+ peak: BackendArray,
769
+ mask: BackendArray,
770
+ rotated_template: BackendArray,
771
+ rotation_matrix: BackendArray,
852
772
  **kwargs: Dict,
853
773
  ) -> None:
854
774
  """
855
- Mask score_space using mask rotation around a peak.
775
+ Mask scores using mask rotation around a peak.
856
776
 
857
777
  Parameters
858
778
  ----------
859
- score_space : NDArray
779
+ scores : BackendArray
860
780
  Data array of scores.
861
- peak : NDArray
781
+ peak : BackendArray
862
782
  Peak coordinates.
863
- mask : NDArray
783
+ mask : BackendArray
864
784
  Mask array.
865
- rotated_template : NDArray
785
+ rotated_template : BackendArray
866
786
  Empty array to write mask rotations to.
867
- rotation_matrix : NDArray
787
+ rotation_matrix : BackendArray
868
788
  Rotation matrix.
869
789
  """
870
- left_pad = backend.divide(mask.shape, 2).astype(int)
871
- right_pad = backend.add(left_pad, backend.mod(mask.shape, 2).astype(int))
790
+ left_pad = be.divide(mask.shape, 2).astype(int)
791
+ right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
872
792
 
873
- score_start = backend.subtract(peak, left_pad)
874
- score_stop = backend.add(peak, right_pad)
793
+ score_start = be.subtract(peak, left_pad)
794
+ score_stop = be.add(peak, right_pad)
875
795
 
876
- template_start = backend.subtract(backend.maximum(score_start, 0), score_start)
877
- template_stop = backend.subtract(
878
- score_stop, backend.minimum(score_stop, score_space.shape)
879
- )
880
- template_stop = backend.subtract(mask.shape, template_stop)
796
+ template_start = be.subtract(be.maximum(score_start, 0), score_start)
797
+ template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
798
+ template_stop = be.subtract(mask.shape, template_stop)
881
799
 
882
- score_start = backend.maximum(score_start, 0)
883
- score_stop = backend.minimum(score_stop, score_space.shape)
884
- score_start = backend.astype(score_start, int)
885
- score_stop = backend.astype(score_stop, int)
800
+ score_start = be.maximum(score_start, 0)
801
+ score_stop = be.minimum(score_stop, scores.shape)
802
+ score_start = be.astype(score_start, int)
803
+ score_stop = be.astype(score_stop, int)
886
804
 
887
- template_start = backend.astype(template_start, int)
888
- template_stop = backend.astype(template_stop, int)
805
+ template_start = be.astype(template_start, int)
806
+ template_stop = be.astype(template_stop, int)
889
807
  coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
890
808
  coords_template = tuple(
891
809
  slice(*pos) for pos in zip(template_start, template_stop)
892
810
  )
893
811
 
894
812
  rotated_template.fill(0)
895
- backend.rotate_array(
813
+ be.rigid_transform(
896
814
  arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
897
815
  )
898
816
 
899
- score_space[coords_score] = backend.multiply(
900
- score_space[coords_score], (rotated_template[coords_template] <= 0.1)
817
+ scores[coords_score] = be.multiply(
818
+ scores[coords_score], (rotated_template[coords_template] <= 0.1)
901
819
  )
902
820
  return None
903
821
 
904
822
 
905
823
  class PeakCallerScipy(PeakCaller):
906
824
  """
907
- Peak calling using skimage.feature.peak_local_max to compute local maxima.
825
+ Peak calling using :obj:`skimage.feature.peak_local_max` to compute local maxima.
908
826
  """
909
827
 
910
828
  def call_peaks(
911
- self, score_space: NDArray, minimum_score: float = None, **kwargs
912
- ) -> Tuple[NDArray, NDArray]:
913
- """
914
- Call peaks in the score space.
915
-
916
- Parameters
917
- ----------
918
- score_space : NDArray
919
- Data array of scores.
920
- minimum_score : float
921
- Minimum score value to consider. If provided, superseeds limit given
922
- by :py:attr:`PeakCaller.number_of_peaks`.
923
-
924
- Returns
925
- -------
926
- Tuple[NDArray, NDArray]
927
- Array of peak coordinates and peak details.
928
- """
929
- score_space = backend.to_numpy_array(score_space)
829
+ self, scores: BackendArray, minimum_score: float = None, **kwargs
830
+ ) -> PeakType:
831
+ scores = be.to_numpy_array(scores)
930
832
  num_peaks = self.number_of_peaks
931
833
  if minimum_score is not None:
932
834
  num_peaks = np.inf
933
835
 
934
- non_squeezable_dims = tuple(
935
- i for i, x in enumerate(score_space.shape) if x != 1
936
- )
836
+ non_squeezable_dims = tuple(i for i, x in enumerate(scores.shape) if x != 1)
937
837
  peaks = peak_local_max(
938
- np.squeeze(score_space),
838
+ np.squeeze(scores),
939
839
  num_peaks=num_peaks,
940
840
  min_distance=self.min_distance,
941
841
  threshold_abs=minimum_score,
942
842
  )
943
- peaks_full = np.zeros((peaks.shape[0], score_space.ndim), peaks.dtype)
843
+ peaks_full = np.zeros((peaks.shape[0], scores.ndim), peaks.dtype)
944
844
  peaks_full[..., non_squeezable_dims] = peaks[:]
945
- peaks = backend.to_backend_array(peaks_full)
845
+ peaks = be.to_backend_array(peaks_full)
946
846
  return peaks, None
947
847
 
948
848
 
@@ -967,7 +867,7 @@ class PeakClustering(PeakCallerSort):
967
867
  Parameters
968
868
  ----------
969
869
  **kwargs
970
- Additional keyword arguments passed to :py:meth:`PeakCaller.merge`.
870
+ Optional keyword arguments passed to :py:meth:`PeakCaller.merge`.
971
871
 
972
872
  Returns
973
873
  -------
@@ -999,263 +899,33 @@ class PeakClustering(PeakCallerSort):
999
899
  return peaks, rotations, scores, details
1000
900
 
1001
901
 
1002
- class ScoreStatistics(PeakCallerFast):
1003
- """
1004
- Compute basic statistics on score spaces with respect to a reference
1005
- score or value.
1006
-
1007
- This class is used to evaluate a blurring or scoring method when the correct fit
1008
- is known. It is thread-safe and is designed to be shared among multiple processes
1009
- with write permissions to the internal parameters.
1010
-
1011
- After instantiation, the class's functionality can be accessed through the
1012
- `__call__` method.
1013
-
1014
- Parameters
1015
- ----------
1016
- reference_position : int, optional
1017
- Index of the correct fit in the array passed to call. Defaults to None.
1018
- min_distance : float, optional
1019
- Minimum distance for local maxima. Defaults to None.
1020
- reference_fit : float, optional
1021
- Score of the correct fit. If set, `reference_position` will be ignored.
1022
- Defaults to None.
1023
- number_of_peaks : int, optional
1024
- Number of candidate fits to consider. Defaults to 1.
1025
- """
1026
-
1027
- def __init__(
1028
- self,
1029
- reference_position: Tuple[int] = None,
1030
- min_distance: float = 10,
1031
- reference_fit: float = None,
1032
- number_of_peaks: int = 1,
1033
- ):
1034
- super().__init__(number_of_peaks=number_of_peaks, min_distance=min_distance)
1035
- self.lock = Lock()
1036
-
1037
- self.n = RawValue("Q", 0)
1038
- self.rmean = RawValue("d", 0)
1039
- self.ssqd = RawValue("d", 0)
1040
- self.nbetter_or_equal = RawValue("Q", 0)
1041
- self.maximum_value = RawValue("f", 0)
1042
- self.minimum_value = RawValue("f", 2**32)
1043
- self.shannon_entropy = Manager().list()
1044
- self.candidate_fits = Manager().list()
1045
- self.rotation_names = Manager().list()
1046
- self.reference_fit = RawValue("f", 0)
1047
- self.has_reference = RawValue("i", 0)
1048
-
1049
- self.reference_position = reference_position
1050
- if reference_fit is not None:
1051
- self.reference_fit.value = reference_fit
1052
- self.has_reference.value = 1
1053
-
1054
- def __call__(
1055
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1056
- ) -> None:
1057
- """
1058
- Processes the input array and rotation matrix.
1059
-
1060
- Parameters
1061
- ----------
1062
- arr : NDArray
1063
- Input data array.
1064
- rotation_matrix : NDArray
1065
- Rotation matrix for processing.
1066
- """
1067
- self.set_reference(score_space, rotation_matrix)
1068
-
1069
- while not self.has_reference.value:
1070
- print("Stalling processes until reference_fit has been set.")
1071
- sleep(0.5)
1072
-
1073
- name = "_".join([str(value) for value in rotation_matrix.ravel()])
1074
- n, rmean, ssqd, nbetter_or_equal, max_value, min_value = online_statistics(
1075
- score_space, 0, 0.0, 0.0, self.reference_fit.value
1076
- )
1077
-
1078
- freq, _ = np.histogram(score_space, bins=100)
1079
- shannon_entropy = entropy(freq / score_space.size)
1080
-
1081
- peaks, _ = super().call_peaks(
1082
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
1083
- )
1084
- scores = score_space[tuple(peaks.T)]
1085
- rotations = np.repeat(
1086
- rotation_matrix.reshape(1, *rotation_matrix.shape),
1087
- peaks.shape[0],
1088
- axis=0,
1089
- )
1090
- distances = np.linalg.norm(peaks - self.reference_position[None, :], axis=1)
1091
-
1092
- self._update(
1093
- peak_positions=peaks,
1094
- rotations=rotations,
1095
- peak_scores=scores,
1096
- peak_details=distances,
1097
- n=n,
1098
- rmean=rmean,
1099
- ssqd=ssqd,
1100
- nbetter_or_equal=nbetter_or_equal,
1101
- max_value=max_value,
1102
- min_value=min_value,
1103
- entropy=shannon_entropy,
1104
- name=name,
1105
- )
1106
-
1107
- def __iter__(self):
1108
- param_store = (
1109
- self.peak_list[0],
1110
- self.peak_list[1],
1111
- self.peak_list[2],
1112
- self.peak_list[3],
1113
- self.n.value,
1114
- self.rmean.value,
1115
- self.ssqd.value,
1116
- self.nbetter_or_equal.value,
1117
- self.maximum_value.value,
1118
- self.minimum_value.value,
1119
- list(self.shannon_entropy),
1120
- list(self.rotation_names),
1121
- self.reference_fit.value,
1122
- )
1123
- yield from param_store
1124
-
1125
- def _update(
1126
- self,
1127
- n: int,
1128
- rmean: float,
1129
- ssqd: float,
1130
- nbetter_or_equal: int,
1131
- max_value: float,
1132
- min_value: float,
1133
- entropy: float,
1134
- name: str,
1135
- **kwargs,
1136
- ) -> None:
1137
- """
1138
- Updates the internal statistics of the analyzer.
1139
-
1140
- Parameters
1141
- ----------
1142
- n : int
1143
- Sample size.
1144
- rmean : float
1145
- Running mean.
1146
- ssqd : float
1147
- Sum of squared differences.
1148
- nbetter_or_equal : int
1149
- Number of values better or equal to reference.
1150
- max_value : float
1151
- Maximum value.
1152
- min_value : float
1153
- Minimum value.
1154
- entropy : float
1155
- Shannon entropy.
1156
- candidates : list
1157
- List of candidate fits.
1158
- name : str
1159
- Name or label for the data.
1160
- kwargs : dict
1161
- Keyword arguments passed to PeakCaller._update.
1162
- """
1163
- with self.lock:
1164
- super()._update(**kwargs)
1165
-
1166
- n_total = self.n.value + n
1167
- delta = rmean - self.rmean.value
1168
- delta2 = delta * delta
1169
- self.rmean.value += delta * n / n_total
1170
- self.ssqd.value += ssqd + delta2 * (n * self.n.value) / n_total
1171
- self.n.value = n_total
1172
- self.nbetter_or_equal.value += nbetter_or_equal
1173
- self.minimum_value.value = min(self.minimum_value.value, min_value)
1174
- self.maximum_value.value = max(self.maximum_value.value, max_value)
1175
- self.shannon_entropy.append(entropy)
1176
- self.rotation_names.append(name)
1177
-
1178
- @classmethod
1179
- def merge(cls, param_stores: List[Tuple]) -> Tuple:
1180
- """
1181
- Merges multiple instances of :py:class`ScoreStatistics`.
1182
-
1183
- Parameters
1184
- ----------
1185
- param_stores : list of tuple
1186
- Internal parameter store. Obtained by running `tuple(instance)`.
1187
- Defaults to a list with two empty tuples.
1188
-
1189
- Returns
1190
- -------
1191
- tuple
1192
- Contains the reference fit, the z-transform of the reference fit,
1193
- number of scores, and various other statistics.
1194
- """
1195
- base = cls(reference_position=np.zeros(3, int))
1196
- for param_store in param_stores:
1197
- base._update(
1198
- peak_positions=param_store[0],
1199
- rotations=param_store[1],
1200
- peak_scores=param_store[2],
1201
- peak_details=param_store[3],
1202
- n=param_store[4],
1203
- rmean=param_store[5],
1204
- ssqd=param_store[6],
1205
- nbetter_or_equal=param_store[7],
1206
- max_value=param_store[8],
1207
- min_value=param_store[9],
1208
- entropy=param_store[10],
1209
- name=param_store[11],
1210
- )
1211
- base.reference_fit.value = param_store[12]
1212
- return tuple(base)
1213
-
1214
- def set_reference(self, score_space: NDArray, rotation_matrix: NDArray) -> None:
1215
- """
1216
- Sets the reference for the analyzer based on the input array
1217
- and rotation matrix.
1218
-
1219
- Parameters
1220
- ----------
1221
- score_space : NDArray
1222
- Input data array.
1223
- rotation_matrix : NDArray
1224
- Rotation matrix for setting reference.
1225
- """
1226
- is_ref = np.allclose(
1227
- rotation_matrix,
1228
- np.eye(rotation_matrix.shape[0], dtype=rotation_matrix.dtype),
1229
- )
1230
- if not is_ref:
1231
- return None
1232
-
1233
- reference_position = self.reference_position
1234
- if reference_position is None:
1235
- reference_position = np.divide(score_space.shape, 2).astype(int)
1236
- self.reference_position = reference_position
1237
- self.reference_fit.value = score_space[tuple(reference_position)]
1238
- self.has_reference.value = 1
1239
-
1240
-
1241
902
  class MaxScoreOverRotations:
1242
903
  """
1243
- Obtain the maximum translation score over various rotations.
904
+ Determine the rotation maximizing the score of all given translations.
1244
905
 
1245
906
  Attributes
1246
907
  ----------
1247
- score_space : NDArray
1248
- The score space for the observed rotations.
1249
- rotations : NDArray
1250
- The rotation identifiers for each score.
1251
- translation_offset : NDArray, optional
1252
- The offset applied during translation.
1253
- observed_rotations : int
1254
- Count of observed rotations.
908
+ shape : tuple of ints.
909
+ Shape of ``scores`` and rotations.
910
+ scores : BackendArray
911
+ Array mapping translations to scores.
912
+ rotations : BackendArray
913
+ Array mapping translations to rotation indices.
914
+ rotation_mapping : Dict
915
+ Mapping of rotation matrix bytestrings to rotation indices.
916
+ offset : BackendArray, optional
917
+ Coordinate origin considered during merging, zero by default
1255
918
  use_memmap : bool, optional
1256
- Whether to offload internal data arrays to disk
919
+ Memmap scores and rotations arrays, False by default.
1257
920
  thread_safe: bool, optional
1258
- Whether access to internal data arrays should be thread safe
921
+ Allow class to be modified by multiple processes, True by default.
922
+ only_unique_rotations : bool, optional
923
+ Whether each rotation will be shown only once, False by default.
924
+
925
+ Raises
926
+ ------
927
+ ValueError
928
+ If the data shape cannot be determined from the parameters.
1259
929
 
1260
930
  Examples
1261
931
  --------
@@ -1263,11 +933,7 @@ class MaxScoreOverRotations:
1263
933
  instance
1264
934
 
1265
935
  >>> from tme.analyzer import MaxScoreOverRotations
1266
- >>> analyzer = MaxScoreOverRotations(
1267
- >>> score_space_shape = (50, 50),
1268
- >>> score_space_dtype = np.float32,
1269
- >>> rotation_space_dtype = np.int32,
1270
- >>> )
936
+ >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
1271
937
 
1272
938
  The following simulates a template matching run by creating random data for a range
1273
939
  of rotations and sending it to ``analyzer`` via its __call__ method
@@ -1275,9 +941,9 @@ class MaxScoreOverRotations:
1275
941
  >>> for rotation_number in range(10):
1276
942
  >>> scores = np.random.rand(50,50)
1277
943
  >>> rotation = np.random.rand(scores.ndim, scores.ndim)
1278
- >>> analyzer(score_space = scores, rotation_matrix = rotation)
944
+ >>> analyzer(scores = scores, rotation_matrix = rotation)
1279
945
 
1280
- The aggregated scores can be exctracted by invoking the __iter__ method of
946
+ The aggregated scores can be extracted by invoking the __iter__ method of
1281
947
  ``analyzer``
1282
948
 
1283
949
  >>> results = tuple(analyzer)
@@ -1288,7 +954,7 @@ class MaxScoreOverRotations:
1288
954
  score for a given translation, (4) a dictionary mapping rotation matrices to the
1289
955
  indices used in (2).
1290
956
 
1291
- We can extract the ``optimal_score`, ``optimal_translation`` and ``optimal_rotation``
957
+ We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
1292
958
  as follows
1293
959
 
1294
960
  >>> optimal_score = results[0].max()
@@ -1307,201 +973,175 @@ class MaxScoreOverRotations:
1307
973
 
1308
974
  def __init__(
1309
975
  self,
1310
- score_space_shape: Tuple[int],
1311
- score_space_dtype: type,
1312
- translation_offset: NDArray = None,
976
+ shape: Tuple[int] = None,
977
+ scores: BackendArray = None,
978
+ rotations: BackendArray = None,
979
+ offset: BackendArray = None,
1313
980
  score_threshold: float = 0,
1314
981
  shared_memory_handler: object = None,
1315
- rotation_space_dtype: type = int,
1316
982
  use_memmap: bool = False,
1317
983
  thread_safe: bool = True,
984
+ only_unique_rotations: bool = False,
1318
985
  **kwargs,
1319
986
  ):
1320
- score_space_shape = tuple(int(x) for x in score_space_shape)
1321
- self.score_space = backend.arr_to_sharedarr(
1322
- backend.full(
1323
- shape=score_space_shape,
1324
- dtype=score_space_dtype,
987
+ if shape is None and scores is None:
988
+ raise ValueError("Either scores_shape or scores need to be specified.")
989
+
990
+ if scores is None:
991
+ shape = tuple(int(x) for x in shape)
992
+ scores = be.full(
993
+ shape=shape,
994
+ dtype=be._float_dtype,
1325
995
  fill_value=score_threshold,
1326
- ),
1327
- shared_memory_handler=shared_memory_handler,
1328
- )
1329
- self.rotations = backend.arr_to_sharedarr(
1330
- backend.full(score_space_shape, dtype=rotation_space_dtype, fill_value=-1),
1331
- shared_memory_handler,
1332
- )
1333
- if translation_offset is None:
1334
- translation_offset = backend.zeros(len(score_space_shape))
996
+ )
997
+ self.scores, self.shape = scores, scores.shape
998
+
999
+ if rotations is None:
1000
+ rotations = be.full(shape, dtype=be._int_dtype, fill_value=-1)
1001
+ self.rotations = rotations
1002
+
1003
+ self.scores_dtype = self.scores.dtype
1004
+ self.rotations_dtype = self.rotations.dtype
1005
+ self.scores = be.to_sharedarr(self.scores, shared_memory_handler)
1006
+ self.rotations = be.to_sharedarr(self.rotations, shared_memory_handler)
1335
1007
 
1336
- self.translation_offset = backend.astype(translation_offset, int)
1337
- self.score_space_shape = score_space_shape
1338
- self.rotation_space_dtype = rotation_space_dtype
1339
- self.score_space_dtype = score_space_dtype
1008
+ if offset is None:
1009
+ offset = be.zeros(len(self.shape), be._int_dtype)
1010
+ self.offset = be.astype(offset, int)
1340
1011
 
1341
1012
  self.use_memmap = use_memmap
1342
1013
  self.lock = Manager().Lock() if thread_safe else nullcontext()
1343
- self.lock_is_nullcontext = isinstance(
1344
- self.score_space, type(backend.zeros((1)))
1345
- )
1346
- self.observed_rotations = Manager().dict() if thread_safe else {}
1014
+ self.lock_is_nullcontext = isinstance(self.scores, type(be.zeros((1))))
1015
+ self.rotation_mapping = Manager().dict() if thread_safe else {}
1016
+ self._inversion_mapping = self.lock_is_nullcontext and only_unique_rotations
1347
1017
 
1348
1018
  def _postprocess(
1349
1019
  self,
1350
- fourier_shift,
1351
- convolution_mode,
1352
- targetshape,
1353
- templateshape,
1020
+ targetshape: Tuple[int],
1021
+ templateshape: Tuple[int],
1022
+ fourier_shift: Tuple[int] = None,
1023
+ convolution_mode: str = None,
1354
1024
  shared_memory_handler=None,
1025
+ fast_shape: Tuple[int] = None,
1355
1026
  **kwargs,
1356
- ):
1357
- internal_scores = backend.sharedarr_to_arr(
1358
- shape=self.score_space_shape,
1359
- dtype=self.score_space_dtype,
1360
- shm=self.score_space,
1361
- )
1362
- internal_rotations = backend.sharedarr_to_arr(
1363
- shape=self.score_space_shape,
1364
- dtype=self.rotation_space_dtype,
1365
- shm=self.rotations,
1366
- )
1367
-
1027
+ ) -> "MaxScoreOverRotations":
1028
+ """
1029
+ Correct padding to Fourier (and if requested convolution) shape.
1030
+ """
1031
+ scores = be.from_sharedarr(self.scores)
1032
+ rotations = be.from_sharedarr(self.rotations)
1368
1033
  if fourier_shift is not None:
1369
1034
  axis = tuple(i for i in range(len(fourier_shift)))
1370
- internal_scores = backend.roll(
1371
- internal_scores, shift=fourier_shift, axis=axis
1372
- )
1373
- internal_rotations = backend.roll(
1374
- internal_rotations, shift=fourier_shift, axis=axis
1375
- )
1035
+ scores = be.roll(scores, shift=fourier_shift, axis=axis)
1036
+ rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
1376
1037
 
1038
+ convargs = {
1039
+ "s1": targetshape,
1040
+ "s2": templateshape,
1041
+ "convolution_mode": convolution_mode,
1042
+ }
1377
1043
  if convolution_mode is not None:
1378
- internal_scores = apply_convolution_mode(
1379
- internal_scores,
1380
- convolution_mode=convolution_mode,
1381
- s1=targetshape,
1382
- s2=templateshape,
1383
- )
1384
- internal_rotations = apply_convolution_mode(
1385
- internal_rotations,
1386
- convolution_mode=convolution_mode,
1387
- s1=targetshape,
1388
- s2=templateshape,
1389
- )
1044
+ scores = apply_convolution_mode(scores, **convargs)
1045
+ rotations = apply_convolution_mode(rotations, **convargs)
1390
1046
 
1391
- self.score_space_shape = internal_scores.shape
1392
- self.score_space = backend.arr_to_sharedarr(
1393
- internal_scores, shared_memory_handler
1394
- )
1395
- self.rotations = backend.arr_to_sharedarr(
1396
- internal_rotations, shared_memory_handler
1397
- )
1047
+ self.shape = scores.shape
1048
+ self.scores = be.to_sharedarr(scores, shared_memory_handler)
1049
+ self.rotations = be.to_sharedarr(rotations, shared_memory_handler)
1398
1050
  return self
1399
1051
 
1400
- def __iter__(self):
1401
- internal_scores = backend.sharedarr_to_arr(
1402
- shape=self.score_space_shape,
1403
- dtype=self.score_space_dtype,
1404
- shm=self.score_space,
1405
- )
1406
- internal_rotations = backend.sharedarr_to_arr(
1407
- shape=self.score_space_shape,
1408
- dtype=self.rotation_space_dtype,
1409
- shm=self.rotations,
1410
- )
1052
+ def __iter__(self) -> Generator:
1053
+ scores = be.from_sharedarr(self.scores)
1054
+ rotations = be.from_sharedarr(self.rotations)
1411
1055
 
1412
- internal_scores = backend.to_numpy_array(internal_scores)
1413
- internal_rotations = backend.to_numpy_array(internal_rotations)
1056
+ scores = be.to_numpy_array(scores)
1057
+ rotations = be.to_numpy_array(rotations)
1414
1058
  if self.use_memmap:
1415
- internal_scores_filename = array_to_memmap(internal_scores)
1416
- internal_rotations_filename = array_to_memmap(internal_rotations)
1417
- internal_scores = np.memmap(
1418
- internal_scores_filename,
1059
+ scores = np.memmap(
1060
+ array_to_memmap(scores),
1419
1061
  mode="r",
1420
- dtype=internal_scores.dtype,
1421
- shape=internal_scores.shape,
1062
+ dtype=scores.dtype,
1063
+ shape=scores.shape,
1422
1064
  )
1423
- internal_rotations = np.memmap(
1424
- internal_rotations_filename,
1065
+ rotations = np.memmap(
1066
+ array_to_memmap(rotations),
1425
1067
  mode="r",
1426
- dtype=internal_rotations.dtype,
1427
- shape=internal_rotations.shape,
1068
+ dtype=rotations.dtype,
1069
+ shape=rotations.shape,
1428
1070
  )
1429
1071
  else:
1430
- # Avoid invalidation by shared memory handler with copy
1431
- internal_scores = internal_scores.copy()
1432
- internal_rotations = internal_rotations.copy()
1072
+ # Copy to avoid invalidation by shared memory handler
1073
+ scores, rotations = scores.copy(), rotations.copy()
1074
+
1075
+ if self._inversion_mapping:
1076
+ self.rotation_mapping = {
1077
+ be.tobytes(v): k for k, v in self.rotation_mapping.items()
1078
+ }
1433
1079
 
1434
1080
  param_store = (
1435
- internal_scores,
1436
- backend.to_numpy_array(self.translation_offset),
1437
- internal_rotations,
1438
- dict(self.observed_rotations),
1081
+ scores,
1082
+ be.to_numpy_array(self.offset),
1083
+ rotations,
1084
+ dict(self.rotation_mapping),
1439
1085
  )
1440
1086
  yield from param_store
1441
1087
 
1442
- def __call__(
1443
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1444
- ) -> None:
1088
+ def __call__(self, scores: BackendArray, rotation_matrix: BackendArray):
1445
1089
  """
1446
- Update internal parameter store based on `score_space`.
1090
+ Update internal parameter store based on `scores`.
1447
1091
 
1448
1092
  Parameters
1449
1093
  ----------
1450
- score_space : ndarray
1451
- Numpy array containing the score space.
1452
- rotation_matrix : ndarray
1094
+ scores : BackendArray
1095
+ Array containing the score space.
1096
+ rotation_matrix : BackendArray
1453
1097
  Square matrix describing the current rotation.
1454
- **kwargs
1455
- Arbitrary keyword arguments.
1456
1098
  """
1457
- rotation = backend.tobytes(rotation_matrix)
1458
-
1099
+ # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
1100
+ # If the analyzer is not shared and each rotation is unique, we can
1101
+ # use index to rotation mapping and invert prior to merging.
1459
1102
  if self.lock_is_nullcontext:
1460
- rotation_index = self.observed_rotations.setdefault(
1461
- rotation, len(self.observed_rotations)
1462
- )
1463
- backend.max_score_over_rotations(
1464
- score_space=score_space,
1465
- internal_scores=self.score_space,
1466
- internal_rotations=self.rotations,
1103
+ rotation_index = len(self.rotation_mapping)
1104
+ if self._inversion_mapping:
1105
+ self.rotation_mapping[rotation_index] = rotation_matrix
1106
+ else:
1107
+ rotation = be.tobytes(rotation_matrix)
1108
+ rotation_index = self.rotation_mapping.setdefault(
1109
+ rotation, rotation_index
1110
+ )
1111
+ self.scores, self.rotations = be.max_score_over_rotations(
1112
+ scores=scores,
1113
+ max_scores=self.scores,
1114
+ rotations=self.rotations,
1467
1115
  rotation_index=rotation_index,
1468
1116
  )
1469
1117
  return None
1470
1118
 
1119
+ rotation = be.tobytes(rotation_matrix)
1471
1120
  with self.lock:
1472
- rotation_index = self.observed_rotations.setdefault(
1473
- rotation, len(self.observed_rotations)
1474
- )
1475
- internal_scores = backend.sharedarr_to_arr(
1476
- shape=self.score_space_shape,
1477
- dtype=self.score_space_dtype,
1478
- shm=self.score_space,
1121
+ rotation_index = self.rotation_mapping.setdefault(
1122
+ rotation, len(self.rotation_mapping)
1479
1123
  )
1480
- internal_rotations = backend.sharedarr_to_arr(
1481
- shape=self.score_space_shape,
1482
- dtype=self.rotation_space_dtype,
1483
- shm=self.rotations,
1484
- )
1485
-
1486
- backend.max_score_over_rotations(
1487
- score_space=score_space,
1488
- internal_scores=internal_scores,
1489
- internal_rotations=internal_rotations,
1124
+ internal_scores = be.from_sharedarr(self.scores)
1125
+ internal_rotations = be.from_sharedarr(self.rotations)
1126
+ internal_sores, internal_rotations = be.max_score_over_rotations(
1127
+ scores=scores,
1128
+ max_scores=internal_scores,
1129
+ rotations=internal_rotations,
1490
1130
  rotation_index=rotation_index,
1491
1131
  )
1492
1132
  return None
1493
1133
 
1494
1134
  @classmethod
1495
- def merge(cls, param_stores=List[Tuple], **kwargs) -> Tuple[NDArray]:
1135
+ def merge(cls, param_stores: List[Tuple], **kwargs) -> Tuple[NDArray]:
1496
1136
  """
1497
1137
  Merges multiple instances of :py:class:`MaxScoreOverRotations`.
1498
1138
 
1499
1139
  Parameters
1500
1140
  ----------
1501
- param_stores : list of tuples, optional
1141
+ param_stores : list of tuples
1502
1142
  Internal parameter store. Obtained by running `tuple(instance)`.
1503
1143
  **kwargs
1504
- Arbitrary keyword arguments.
1144
+ Optional keyword arguments.
1505
1145
 
1506
1146
  Returns
1507
1147
  -------
@@ -1513,51 +1153,51 @@ class MaxScoreOverRotations:
1513
1153
  if len(param_stores) == 1:
1514
1154
  return param_stores[0]
1515
1155
 
1516
- new_rotation_mapping, base_max = {}, None
1517
- scores_out_dtype, rotations_out_dtype = None, None
1156
+ # Determine output array shape and create consistent rotation map
1157
+ new_rotation_mapping, out_shape = {}, None
1518
1158
  for i in range(len(param_stores)):
1519
1159
  if param_stores[i] is None:
1520
1160
  continue
1521
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1522
- if base_max is None:
1523
- base_max = np.zeros(score_space.ndim, int)
1524
- scores_out_dtype = score_space.dtype
1525
- rotations_out_dtype = rotations.dtype
1526
- np.maximum(base_max, np.add(offset, score_space.shape), out=base_max)
1161
+
1162
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1163
+ if out_shape is None:
1164
+ out_shape = np.zeros(scores.ndim, int)
1165
+ scores_dtype, rotations_dtype = scores.dtype, rotations.dtype
1166
+ out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
1527
1167
 
1528
1168
  for key, value in rotation_mapping.items():
1529
1169
  if key not in new_rotation_mapping:
1530
1170
  new_rotation_mapping[key] = len(new_rotation_mapping)
1531
1171
 
1532
- if base_max is None:
1172
+ if out_shape is None:
1533
1173
  return None
1534
1174
 
1535
- base_max = tuple(int(x) for x in base_max)
1175
+ out_shape = tuple(int(x) for x in out_shape)
1536
1176
  use_memmap = kwargs.get("use_memmap", False)
1537
1177
  if use_memmap:
1538
1178
  scores_out_filename = generate_tempfile_name()
1539
1179
  rotations_out_filename = generate_tempfile_name()
1540
1180
 
1541
1181
  scores_out = np.memmap(
1542
- scores_out_filename, mode="w+", shape=base_max, dtype=scores_out_dtype
1182
+ scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
1543
1183
  )
1544
1184
  scores_out.fill(kwargs.get("score_threshold", 0))
1545
1185
  scores_out.flush()
1546
1186
  rotations_out = np.memmap(
1547
1187
  rotations_out_filename,
1548
1188
  mode="w+",
1549
- shape=base_max,
1550
- dtype=rotations_out_dtype,
1189
+ shape=out_shape,
1190
+ dtype=rotations_dtype,
1551
1191
  )
1552
1192
  rotations_out.fill(-1)
1553
1193
  rotations_out.flush()
1554
1194
  else:
1555
1195
  scores_out = np.full(
1556
- base_max,
1196
+ out_shape,
1557
1197
  fill_value=kwargs.get("score_threshold", 0),
1558
- dtype=scores_out_dtype,
1198
+ dtype=scores_dtype,
1559
1199
  )
1560
- rotations_out = np.full(base_max, fill_value=-1, dtype=rotations_out_dtype)
1200
+ rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
1561
1201
 
1562
1202
  for i in range(len(param_stores)):
1563
1203
  if param_stores[i] is None:
@@ -1567,21 +1207,21 @@ class MaxScoreOverRotations:
1567
1207
  scores_out = np.memmap(
1568
1208
  scores_out_filename,
1569
1209
  mode="r+",
1570
- shape=base_max,
1571
- dtype=scores_out_dtype,
1210
+ shape=out_shape,
1211
+ dtype=scores_dtype,
1572
1212
  )
1573
1213
  rotations_out = np.memmap(
1574
1214
  rotations_out_filename,
1575
1215
  mode="r+",
1576
- shape=base_max,
1577
- dtype=rotations_out_dtype,
1216
+ shape=out_shape,
1217
+ dtype=rotations_dtype,
1578
1218
  )
1579
- score_space, offset, rotations, rotation_mapping = param_stores[i]
1580
- stops = np.add(offset, score_space.shape).astype(int)
1219
+ scores, offset, rotations, rotation_mapping = param_stores[i]
1220
+ stops = np.add(offset, scores.shape).astype(int)
1581
1221
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
1582
1222
 
1583
- indices_update = score_space > scores_out[indices]
1584
- scores_out[indices][indices_update] = score_space[indices_update]
1223
+ indices_update = scores > scores_out[indices]
1224
+ scores_out[indices][indices_update] = scores[indices_update]
1585
1225
 
1586
1226
  lookup_table = np.arange(
1587
1227
  len(rotation_mapping) + 1, dtype=rotations_out.dtype
@@ -1594,24 +1234,24 @@ class MaxScoreOverRotations:
1594
1234
  rotations_out[indices][indices_update] = lookup_table[updated_rotations]
1595
1235
 
1596
1236
  if use_memmap:
1597
- score_space._mmap.close()
1237
+ scores._mmap.close()
1598
1238
  rotations._mmap.close()
1599
1239
  scores_out.flush()
1600
1240
  rotations_out.flush()
1601
1241
  scores_out, rotations_out = None, None
1602
1242
 
1603
1243
  param_stores[i] = None
1604
- score_space, rotations = None, None
1244
+ scores, rotations = None, None
1605
1245
 
1606
1246
  if use_memmap:
1607
1247
  scores_out = np.memmap(
1608
- scores_out_filename, mode="r", shape=base_max, dtype=scores_out_dtype
1248
+ scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
1609
1249
  )
1610
1250
  rotations_out = np.memmap(
1611
1251
  rotations_out_filename,
1612
1252
  mode="r",
1613
- shape=base_max,
1614
- dtype=rotations_out_dtype,
1253
+ shape=out_shape,
1254
+ dtype=rotations_dtype,
1615
1255
  )
1616
1256
  return (
1617
1257
  scores_out,
@@ -1620,6 +1260,10 @@ class MaxScoreOverRotations:
1620
1260
  new_rotation_mapping,
1621
1261
  )
1622
1262
 
1263
+ @property
1264
+ def shared(self):
1265
+ return True
1266
+
1623
1267
 
1624
1268
  class _MaxScoreOverTranslations(MaxScoreOverRotations):
1625
1269
  """
@@ -1627,11 +1271,11 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1627
1271
 
1628
1272
  Attributes
1629
1273
  ----------
1630
- score_space : NDArray
1274
+ scores : BackendArray
1631
1275
  The score space for the observed rotations.
1632
- rotations : NDArray
1276
+ rotations : BackendArray
1633
1277
  The rotation identifiers for each score.
1634
- translation_offset : NDArray, optional
1278
+ translation_offset : BackendArray, optional
1635
1279
  The offset applied during translation.
1636
1280
  observed_rotations : int
1637
1281
  Count of observed rotations.
@@ -1642,36 +1286,36 @@ class _MaxScoreOverTranslations(MaxScoreOverRotations):
1642
1286
  """
1643
1287
 
1644
1288
  def __call__(
1645
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1289
+ self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs
1646
1290
  ) -> None:
1647
1291
  """
1648
- Update internal parameter store based on `score_space`.
1292
+ Update internal parameter store based on `scores`.
1649
1293
 
1650
1294
  Parameters
1651
1295
  ----------
1652
- score_space : ndarray
1296
+ scores : BackendArray
1653
1297
  Numpy array containing the score space.
1654
- rotation_matrix : ndarray
1298
+ rotation_matrix : BackendArray
1655
1299
  Square matrix describing the current rotation.
1656
1300
  **kwargs
1657
- Arbitrary keyword arguments.
1301
+ Optional keyword arguments.
1658
1302
  """
1659
1303
  from tme.matching_utils import centered_mask
1660
1304
 
1661
1305
  with self.lock:
1662
- rotation = backend.tobytes(rotation_matrix)
1306
+ rotation = be.tobytes(rotation_matrix)
1663
1307
  if rotation not in self.observed_rotations:
1664
1308
  self.observed_rotations[rotation] = len(self.observed_rotations)
1665
- score_space = centered_mask(score_space, kwargs["template_shape"])
1309
+ scores = centered_mask(scores, kwargs["template_shape"])
1666
1310
  rotation_index = self.observed_rotations[rotation]
1667
- internal_scores = backend.sharedarr_to_arr(
1668
- shape=self.score_space_shape,
1669
- dtype=self.score_space_dtype,
1670
- shm=self.score_space,
1311
+ internal_scores = be.from_sharedarr(
1312
+ shape=self.shape,
1313
+ dtype=self.scores_dtype,
1314
+ shm=self.scores,
1671
1315
  )
1672
- max_score = score_space.max(axis=(1, 2, 3))
1673
- mean_score = score_space.mean(axis=(1, 2, 3))
1674
- std_score = score_space.std(axis=(1, 2, 3))
1316
+ max_score = scores.max(axis=(1, 2, 3))
1317
+ mean_score = scores.mean(axis=(1, 2, 3))
1318
+ std_score = scores.std(axis=(1, 2, 3))
1675
1319
  z_score = (max_score - mean_score) / std_score
1676
1320
  internal_scores[rotation_index] = z_score
1677
1321
 
@@ -1695,7 +1339,7 @@ class MemmapHandler:
1695
1339
  indices : tuple of slice, optional
1696
1340
  Slices specifying which parts of the memmap array will be updated by `__call__`.
1697
1341
  **kwargs
1698
- Arbitrary keyword arguments.
1342
+ Optional keyword arguments.
1699
1343
  """
1700
1344
 
1701
1345
  def __init__(
@@ -1718,15 +1362,13 @@ class MemmapHandler:
1718
1362
  self.dtype = dtype
1719
1363
  self._indices = indices
1720
1364
 
1721
- def __call__(
1722
- self, score_space: NDArray, rotation_matrix: NDArray, **kwargs
1723
- ) -> None:
1365
+ def __call__(self, scores: NDArray, rotation_matrix: NDArray) -> None:
1724
1366
  """
1725
- Write `score_space` to memmap object on disk.
1367
+ Write `scores` to memmap object on disk.
1726
1368
 
1727
1369
  Parameters
1728
1370
  ----------
1729
- score_space : ndarray
1371
+ scores : ndarray
1730
1372
  Numpy array containing the score space.
1731
1373
  rotation_matrix : ndarray
1732
1374
  Square matrix describing the current rotation.
@@ -1738,7 +1380,7 @@ class MemmapHandler:
1738
1380
  array = np.memmap(current_object, mode="r+", shape=self.shape, dtype=self.dtype)
1739
1381
  # Does not really need a lock because processes operate on different rotations
1740
1382
  with self.lock:
1741
- array[self._indices] += score_space
1383
+ array[self._indices] += scores
1742
1384
  array.flush()
1743
1385
 
1744
1386
  def __iter__(self):