pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/analyzer.py CHANGED
@@ -43,14 +43,29 @@ def filter_points_indices_bucket(
43
43
 
44
44
 
45
45
  def filter_points_indices(
46
- coordinates: NDArray, min_distance: float, bucket_cutoff: int = 1e4
46
+ coordinates: NDArray,
47
+ min_distance: float,
48
+ bucket_cutoff: int = 1e4,
49
+ batch_dims: Tuple[int] = None,
47
50
  ) -> NDArray:
48
51
  if min_distance <= 0:
49
52
  return backend.arange(coordinates.shape[0])
53
+ if coordinates.shape[0] == 0:
54
+ return ()
55
+
56
+ if batch_dims is not None:
57
+ coordinates_new = backend.zeros(coordinates.shape, coordinates.dtype)
58
+ coordinates_new[:] = coordinates
59
+ coordinates_new[..., batch_dims] = backend.astype(
60
+ coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
61
+ )
62
+ coordinates = coordinates_new
50
63
 
51
64
  if isinstance(coordinates, np.ndarray):
52
65
  return find_candidate_indices(coordinates, min_distance)
53
- elif coordinates.shape[0] > bucket_cutoff:
66
+ elif coordinates.shape[0] > bucket_cutoff or not isinstance(
67
+ coordinates, np.ndarray
68
+ ):
54
69
  return filter_points_indices_bucket(coordinates, min_distance)
55
70
  distances = np.linalg.norm(coordinates[:, None] - coordinates, axis=-1)
56
71
  distances = np.tril(distances)
@@ -59,8 +74,10 @@ def filter_points_indices(
59
74
  return indices[keep == indices]
60
75
 
61
76
 
62
- def filter_points(coordinates: NDArray, min_distance: Tuple[int]) -> NDArray:
63
- unique_indices = filter_points_indices(coordinates, min_distance)
77
+ def filter_points(
78
+ coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
79
+ ) -> NDArray:
80
+ unique_indices = filter_points_indices(coordinates, min_distance, batch_dims)
64
81
  coordinates = coordinates[unique_indices]
65
82
  return coordinates
66
83
 
@@ -77,6 +94,8 @@ class PeakCaller(ABC):
77
94
  Minimum distance between peaks.
78
95
  min_boundary_distance : int, optional
79
96
  Minimum distance to array boundaries.
97
+ batch_dims : int, optional
98
+ Peak calling batch dimensions.
80
99
  **kwargs
81
100
  Additional keyword arguments.
82
101
 
@@ -92,6 +111,7 @@ class PeakCaller(ABC):
92
111
  number_of_peaks: int = 1000,
93
112
  min_distance: int = 1,
94
113
  min_boundary_distance: int = 0,
114
+ batch_dims: Tuple[int] = None,
95
115
  **kwargs,
96
116
  ):
97
117
  number_of_peaks = int(number_of_peaks)
@@ -114,6 +134,10 @@ class PeakCaller(ABC):
114
134
  self.min_boundary_distance = min_boundary_distance
115
135
  self.number_of_peaks = number_of_peaks
116
136
 
137
+ self.batch_dims = batch_dims
138
+ if batch_dims is not None:
139
+ self.batch_dims = tuple(int(x) for x in self.batch_dims)
140
+
117
141
  # Postprocesing arguments
118
142
  self.fourier_shift = kwargs.get("fourier_shift", None)
119
143
  self.convolution_mode = kwargs.get("convolution_mode", None)
@@ -128,10 +152,41 @@ class PeakCaller(ABC):
128
152
  self.peak_list = [backend.to_cpu_array(arr) for arr in self.peak_list]
129
153
  yield from self.peak_list
130
154
 
155
+ @staticmethod
156
+ def _batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
157
+ if batch_dims is None:
158
+ yield (tuple(slice(None) for _ in shape), tuple(0 for _ in shape))
159
+ return None
160
+
161
+ batch_ranges = [range(shape[dim]) for dim in batch_dims]
162
+
163
+ def _generate_slices_recursive(current_dim, current_indices):
164
+ if current_dim == len(batch_dims):
165
+ slice_list, offset_list, batch_index = [], [], 0
166
+ for i in range(len(shape)):
167
+ if i in batch_dims:
168
+ index = current_indices[batch_index]
169
+ slice_list.append(slice(index, index + 1))
170
+ offset_list.append(index)
171
+ batch_index += 1
172
+ else:
173
+ slice_list.append(slice(None))
174
+ offset_list.append(0)
175
+ yield (tuple(slice_list), tuple(offset_list))
176
+ else:
177
+ for index in batch_ranges[current_dim]:
178
+ yield from _generate_slices_recursive(
179
+ current_dim + 1, current_indices + (index,)
180
+ )
181
+
182
+ yield from _generate_slices_recursive(0, ())
183
+
131
184
  def __call__(
132
185
  self,
133
186
  score_space: NDArray,
134
187
  rotation_matrix: NDArray,
188
+ minimum_score: float = None,
189
+ maximum_score: float = None,
135
190
  **kwargs,
136
191
  ) -> None:
137
192
  """
@@ -143,59 +198,91 @@ class PeakCaller(ABC):
143
198
  Array containing the score space.
144
199
  rotation_matrix : NDArray
145
200
  Rotation matrix used to obtain the score array.
201
+ minimum_score : float
202
+ Minimum score from which to consider peaks. If provided, supersedes limits
203
+ presented by :py:attr:`PeakCaller.number_of_peaks`.
204
+ maximum_score : float
205
+ Maximum score up to which to consider peaks.
146
206
  **kwargs
147
207
  Optional keyword arguments passed to :py:meth:`PeakCaller.call_peak`.
148
208
  """
149
- peak_positions, peak_details = self.call_peaks(
150
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
151
- )
209
+ for subset, offset in self._batchify(score_space.shape, self.batch_dims):
210
+ peak_positions, peak_details = self.call_peaks(
211
+ score_space=score_space[subset],
212
+ rotation_matrix=rotation_matrix,
213
+ minimum_score=minimum_score,
214
+ maximum_score=maximum_score,
215
+ **kwargs,
216
+ )
152
217
 
153
- if peak_positions is None:
154
- return None
218
+ if peak_positions is None:
219
+ continue
220
+ if peak_positions.shape[0] == 0:
221
+ continue
155
222
 
156
- peak_positions = backend.astype(peak_positions, int)
157
- if peak_positions.shape[0] == 0:
158
- return None
223
+ if peak_details is None:
224
+ peak_details = backend.full((peak_positions.shape[0],), fill_value=-1)
225
+
226
+ backend.add(peak_positions, offset, out=peak_positions)
227
+ peak_positions = backend.astype(peak_positions, int)
228
+ if self.min_boundary_distance > 0:
229
+ upper_limit = backend.subtract(
230
+ score_space.shape, self.min_boundary_distance
231
+ )
232
+ valid_peaks = backend.multiply(
233
+ peak_positions < upper_limit,
234
+ peak_positions >= self.min_boundary_distance,
235
+ )
236
+ if self.batch_dims is not None:
237
+ valid_peaks[..., self.batch_dims] = True
159
238
 
160
- if peak_details is None:
161
- peak_details = backend.to_backend_array([-1] * peak_positions.shape[0])
239
+ valid_peaks = (
240
+ backend.sum(valid_peaks, axis=1) == peak_positions.shape[1]
241
+ )
162
242
 
163
- if self.min_boundary_distance > 0:
164
- upper_limit = backend.subtract(
165
- score_space.shape, self.min_boundary_distance
166
- )
167
- valid_peaks = (
168
- backend.sum(
169
- backend.multiply(
170
- peak_positions < upper_limit,
171
- peak_positions >= self.min_boundary_distance,
172
- ),
173
- axis=1,
243
+ if backend.sum(valid_peaks) == 0:
244
+ continue
245
+
246
+ peak_positions, peak_details = (
247
+ peak_positions[valid_peaks],
248
+ peak_details[valid_peaks],
174
249
  )
175
- == peak_positions.shape[1]
176
- )
177
- if backend.sum(valid_peaks) == 0:
178
- return None
179
250
 
180
- peak_positions, peak_details = (
181
- peak_positions[valid_peaks],
182
- peak_details[valid_peaks],
251
+ peak_scores = score_space[tuple(peak_positions.T)]
252
+ if minimum_score is not None:
253
+ valid_peaks = peak_scores >= minimum_score
254
+ peak_positions, peak_details, peak_scores = (
255
+ peak_positions[valid_peaks],
256
+ peak_details[valid_peaks],
257
+ peak_scores[valid_peaks],
258
+ )
259
+ if maximum_score is not None:
260
+ valid_peaks = peak_scores <= maximum_score
261
+ peak_positions, peak_details, peak_scores = (
262
+ peak_positions[valid_peaks],
263
+ peak_details[valid_peaks],
264
+ peak_scores[valid_peaks],
265
+ )
266
+
267
+ if peak_positions.shape[0] == 0:
268
+ continue
269
+
270
+ rotations = backend.repeat(
271
+ rotation_matrix.reshape(1, *rotation_matrix.shape),
272
+ peak_positions.shape[0],
273
+ axis=0,
183
274
  )
184
275
 
185
- rotations = backend.repeat(
186
- rotation_matrix.reshape(1, *rotation_matrix.shape),
187
- peak_positions.shape[0],
188
- axis=0,
189
- )
190
- peak_scores = score_space[tuple(peak_positions.T)]
276
+ self._update(
277
+ peak_positions=peak_positions,
278
+ peak_details=peak_details,
279
+ peak_scores=peak_scores,
280
+ rotations=rotations,
281
+ batch_offset=offset,
282
+ **kwargs,
283
+ )
191
284
 
192
- self._update(
193
- peak_positions=peak_positions,
194
- peak_details=peak_details,
195
- peak_scores=peak_scores,
196
- rotations=rotations,
197
- **kwargs,
198
- )
285
+ return None
199
286
 
200
287
  @abstractmethod
201
288
  def call_peaks(
@@ -212,10 +299,8 @@ class PeakCaller(ABC):
212
299
  ----------
213
300
  score_space : NDArray
214
301
  Data array of scores.
215
- minimum_score : float
216
- Minimum score value to consider.
217
- min_distance : float
218
- Minimum distance between maxima.
302
+ **kwargs : Dict, optional
303
+ Keyword arguments passed to __call__.
219
304
 
220
305
  Returns
221
306
  -------
@@ -364,25 +449,52 @@ class PeakCaller(ABC):
364
449
  backend.add(peak_positions, translation_offset, out=peak_positions)
365
450
  if not len(self.peak_list):
366
451
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
367
- dim = peak_positions.shape[1]
368
- peak_scores = backend.zeros((0,), peak_scores.dtype)
369
- peak_details = backend.zeros((0,), peak_details.dtype)
370
- rotations = backend.zeros((0, dim, dim), rotations.dtype)
371
- peak_positions = backend.zeros((0, dim), peak_positions.dtype)
372
452
 
373
- peaks = backend.concatenate((self.peak_list[0], peak_positions))
453
+ peak_positions = backend.concatenate((self.peak_list[0], peak_positions))
374
454
  rotations = backend.concatenate((self.peak_list[1], rotations))
375
455
  peak_scores = backend.concatenate((self.peak_list[2], peak_scores))
376
456
  peak_details = backend.concatenate((self.peak_list[3], peak_details))
377
457
 
378
- top_n = min(backend.size(peak_scores), self.number_of_peaks)
379
- top_scores, *_ = backend.topk_indices(peak_scores, top_n)
458
+ if self.batch_dims is None:
459
+ top_n = min(backend.size(peak_scores), self.number_of_peaks)
460
+ top_scores, *_ = backend.topk_indices(peak_scores, top_n)
461
+ else:
462
+ # Not very performant but fairly robust
463
+ batch_indices = peak_positions[..., self.batch_dims]
464
+ backend.subtract(
465
+ batch_indices, backend.min(batch_indices, axis=0), out=batch_indices
466
+ )
467
+ multiplier = backend.power(
468
+ backend.max(batch_indices, axis=0) + 1,
469
+ backend.arange(batch_indices.shape[1]),
470
+ )
471
+ backend.multiply(batch_indices, multiplier, out=batch_indices)
472
+ batch_indices = backend.sum(batch_indices, axis=1)
473
+ unique_indices, batch_counts = backend.unique(
474
+ batch_indices, return_counts=True
475
+ )
476
+ total_indices = backend.arange(peak_scores.shape[0])
477
+ batch_indices = [total_indices[batch_indices == x] for x in unique_indices]
478
+ top_scores = backend.concatenate(
479
+ [
480
+ total_indices[indices][
481
+ backend.topk_indices(
482
+ peak_scores[indices], min(y, self.number_of_peaks)
483
+ )
484
+ ]
485
+ for indices, y in zip(batch_indices, batch_counts)
486
+ ]
487
+ )
380
488
 
381
489
  final_order = top_scores[
382
- filter_points_indices(peaks[top_scores], self.min_distance)
490
+ filter_points_indices(
491
+ coordinates=peak_positions[top_scores],
492
+ min_distance=self.min_distance,
493
+ batch_dims=self.batch_dims,
494
+ )
383
495
  ]
384
496
 
385
- self.peak_list[0] = peaks[final_order,]
497
+ self.peak_list[0] = peak_positions[final_order,]
386
498
  self.peak_list[1] = rotations[final_order,]
387
499
  self.peak_list[2] = peak_scores[final_order]
388
500
  self.peak_list[3] = peak_details[final_order]
@@ -390,6 +502,9 @@ class PeakCaller(ABC):
390
502
  def _postprocess(
391
503
  self, fourier_shift, convolution_mode, targetshape, templateshape, **kwargs
392
504
  ):
505
+ if not len(self.peak_list):
506
+ return self
507
+
393
508
  peak_positions = self.peak_list[0]
394
509
  if not len(peak_positions):
395
510
  return self
@@ -402,15 +517,16 @@ class PeakCaller(ABC):
402
517
 
403
518
  if fourier_shift is not None:
404
519
  peak_positions = backend.add(peak_positions, fourier_shift)
405
- backend.divide(peak_positions, score_space_shape).astype(int)
406
520
 
407
521
  backend.subtract(
408
522
  peak_positions,
409
523
  backend.multiply(
410
- backend.divide(peak_positions, score_space_shape).astype(int),
411
- score_space_shape
524
+ backend.astype(
525
+ backend.divide(peak_positions, score_space_shape), int
526
+ ),
527
+ score_space_shape,
412
528
  ),
413
- out = peak_positions
529
+ out=peak_positions,
414
530
  )
415
531
 
416
532
  if convolution_mode is None:
@@ -423,23 +539,17 @@ class PeakCaller(ABC):
423
539
  elif convolution_mode == "valid":
424
540
  output_shape = backend.add(
425
541
  backend.subtract(targetshape, templateshape),
426
- backend.mod(templateshape, 2)
542
+ backend.mod(templateshape, 2),
427
543
  )
428
544
 
429
545
  output_shape = backend.to_backend_array(output_shape)
430
- starts = backend.divide(
431
- backend.subtract(score_space_shape, output_shape),
432
- 2
433
- )
546
+ starts = backend.divide(backend.subtract(score_space_shape, output_shape), 2)
434
547
  starts = backend.astype(starts, int)
435
548
  stops = backend.add(starts, output_shape)
436
549
 
437
550
  valid_peaks = (
438
551
  backend.sum(
439
- backend.multiply(
440
- peak_positions > starts,
441
- peak_positions <= stops
442
- ),
552
+ backend.multiply(peak_positions > starts, peak_positions <= stops),
443
553
  axis=1,
444
554
  )
445
555
  == peak_positions.shape[1]
@@ -448,17 +558,14 @@ class PeakCaller(ABC):
448
558
  self.peak_list = [x[valid_peaks] for x in self.peak_list]
449
559
  return self
450
560
 
561
+
451
562
  class PeakCallerSort(PeakCaller):
452
563
  """
453
564
  A :py:class:`PeakCaller` subclass that first selects ``number_of_peaks``
454
- highest scores and subsequently filters local maxima to suffice a distance
455
- from one another of ``min_distance``.
456
-
565
+ highest scores.
457
566
  """
458
567
 
459
- def call_peaks(
460
- self, score_space: NDArray, minimum_score: float = None, **kwargs
461
- ) -> Tuple[NDArray, NDArray]:
568
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
462
569
  """
463
570
  Call peaks in the score space.
464
571
 
@@ -466,9 +573,6 @@ class PeakCallerSort(PeakCaller):
466
573
  ----------
467
574
  score_space : NDArray
468
575
  Data array of scores.
469
- minimum_score : float
470
- Minimum score value to consider. If provided, superseeds limit given
471
- by :py:attr:`PeakCaller.number_of_peaks`.
472
576
 
473
577
  Returns
474
578
  -------
@@ -478,16 +582,12 @@ class PeakCallerSort(PeakCaller):
478
582
  flat_score_space = score_space.reshape(-1)
479
583
  k = min(self.number_of_peaks, backend.size(flat_score_space))
480
584
 
481
- if minimum_score is not None:
482
- k = backend.sum(score_space >= minimum_score)
483
-
484
585
  top_k_indices, *_ = backend.topk_indices(flat_score_space, k)
485
586
 
486
587
  coordinates = backend.unravel_index(top_k_indices, score_space.shape)
487
588
  coordinates = backend.transpose(backend.stack(coordinates))
488
589
 
489
- peaks = filter_points(coordinates, self.min_distance)
490
- return peaks, None
590
+ return coordinates, None
491
591
 
492
592
 
493
593
  class PeakCallerMaximumFilter(PeakCaller):
@@ -497,9 +597,7 @@ class PeakCallerMaximumFilter(PeakCaller):
497
597
  skimage.feature.peak_local_max.
498
598
  """
499
599
 
500
- def call_peaks(
501
- self, score_space: NDArray, minimum_score: float = None, **kwargs
502
- ) -> Tuple[NDArray, NDArray]:
600
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
503
601
  """
504
602
  Call peaks in the score space.
505
603
 
@@ -507,9 +605,8 @@ class PeakCallerMaximumFilter(PeakCaller):
507
605
  ----------
508
606
  score_space : NDArray
509
607
  Data array of scores.
510
- minimum_score : float
511
- Minimum score value to consider. If provided, superseeds limit given
512
- by :py:attr:`PeakCaller.number_of_peaks`.
608
+ kwargs: Dict, optional
609
+ Optional keyword arguments.
513
610
 
514
611
  Returns
515
612
  -------
@@ -518,17 +615,6 @@ class PeakCallerMaximumFilter(PeakCaller):
518
615
  """
519
616
  peaks = backend.max_filter_coordinates(score_space, self.min_distance)
520
617
 
521
- scores = score_space[tuple(peaks.T)]
522
-
523
- input_candidates = min(
524
- self.number_of_peaks, peaks.shape[0] - 1, backend.size(score_space) - 1
525
- )
526
- if minimum_score is not None:
527
- input_candidates = backend.sum(scores >= minimum_score)
528
-
529
- top_indices = backend.topk_indices(scores, input_candidates)
530
- peaks = peaks[top_indices]
531
-
532
618
  return peaks, None
533
619
 
534
620
 
@@ -541,9 +627,7 @@ class PeakCallerFast(PeakCaller):
541
627
 
542
628
  """
543
629
 
544
- def call_peaks(
545
- self, score_space: NDArray, minimum_score: float = None, **kwargs
546
- ) -> Tuple[NDArray, NDArray]:
630
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
547
631
  """
548
632
  Call peaks in the score space.
549
633
 
@@ -551,9 +635,6 @@ class PeakCallerFast(PeakCaller):
551
635
  ----------
552
636
  score_space : NDArray
553
637
  Data array of scores.
554
- minimum_score : float
555
- Minimum score value to consider. If provided, superseeds limit given
556
- by :py:attr:`PeakCaller.number_of_peaks`.
557
638
 
558
639
  Returns
559
640
  -------
@@ -585,26 +666,24 @@ class PeakCallerFast(PeakCaller):
585
666
  if coordinates.shape[0] == 0:
586
667
  return None
587
668
 
588
- peaks = filter_points(coordinates, self.min_distance)
589
-
590
- starts = backend.maximum(peaks - self.min_distance, 0)
591
- stops = backend.minimum(peaks + self.min_distance, score_space.shape)
669
+ starts = backend.maximum(coordinates - self.min_distance, 0)
670
+ stops = backend.minimum(coordinates + self.min_distance, score_space.shape)
592
671
  slices_list = [
593
672
  tuple(slice(*coord) for coord in zip(start_row, stop_row))
594
673
  for start_row, stop_row in zip(starts, stops)
595
674
  ]
596
675
 
597
- scores = score_space[tuple(peaks.T)]
676
+ scores = score_space[tuple(coordinates.T)]
598
677
  keep = [
599
678
  score >= backend.max(score_space[subvol])
600
679
  for subvol, score in zip(slices_list, scores)
601
680
  ]
602
- peaks = peaks[keep,]
681
+ coordinates = coordinates[keep,]
603
682
 
604
- if len(peaks) == 0:
605
- return peaks, None
683
+ if len(coordinates) == 0:
684
+ return coordinates, None
606
685
 
607
- return peaks, None
686
+ return coordinates, None
608
687
 
609
688
 
610
689
  class PeakCallerRecursiveMasking(PeakCaller):
@@ -660,7 +739,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
660
739
  if mask is None:
661
740
  masking_function = self._mask_scores_box
662
741
  shape = tuple(self.min_distance for _ in range(score_space.ndim))
663
- mask = backend.zeros(shape, dtype=backend._default_dtype)
742
+ mask = backend.zeros(shape, dtype=backend._float_dtype)
664
743
 
665
744
  rotated_template = backend.zeros(mask.shape, dtype=mask.dtype)
666
745
 
@@ -670,12 +749,15 @@ class PeakCallerRecursiveMasking(PeakCaller):
670
749
  else:
671
750
  minimum_score = backend.min(score_space) - 1
672
751
 
752
+ scores = backend.zeros(score_space.shape, dtype=score_space.dtype)
753
+ scores[:] = score_space
754
+
673
755
  while True:
674
- backend.argmax(score_space)
756
+ backend.argmax(scores)
675
757
  peak = backend.unravel_index(
676
- indices=backend.argmax(score_space), shape=score_space.shape
758
+ indices=backend.argmax(scores), shape=scores.shape
677
759
  )
678
- if score_space[tuple(peak)] < minimum_score:
760
+ if scores[tuple(peak)] < minimum_score:
679
761
  break
680
762
 
681
763
  coordinates.append(peak)
@@ -688,7 +770,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
688
770
  )
689
771
 
690
772
  masking_function(
691
- score_space=score_space,
773
+ score_space=scores,
692
774
  rotation_matrix=current_rotation_matrix,
693
775
  peak=peak,
694
776
  mask=mask,
@@ -844,18 +926,23 @@ class PeakCallerScipy(PeakCaller):
844
926
  Tuple[NDArray, NDArray]
845
927
  Array of peak coordinates and peak details.
846
928
  """
847
-
848
929
  score_space = backend.to_numpy_array(score_space)
849
930
  num_peaks = self.number_of_peaks
850
931
  if minimum_score is not None:
851
932
  num_peaks = np.inf
852
933
 
934
+ non_squeezable_dims = tuple(
935
+ i for i, x in enumerate(score_space.shape) if x != 1
936
+ )
853
937
  peaks = peak_local_max(
854
- score_space,
938
+ np.squeeze(score_space),
855
939
  num_peaks=num_peaks,
856
940
  min_distance=self.min_distance,
857
941
  threshold_abs=minimum_score,
858
942
  )
943
+ peaks_full = np.zeros((peaks.shape[0], score_space.ndim), peaks.dtype)
944
+ peaks_full[..., non_squeezable_dims] = peaks[:]
945
+ peaks = backend.to_backend_array(peaks_full)
859
946
  return peaks, None
860
947
 
861
948
 
@@ -1169,6 +1256,53 @@ class MaxScoreOverRotations:
1169
1256
  Whether to offload internal data arrays to disk
1170
1257
  thread_safe: bool, optional
1171
1258
  Whether access to internal data arrays should be thread safe
1259
+
1260
+ Examples
1261
+ --------
1262
+ The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
1263
+ instance
1264
+
1265
+ >>> from tme.analyzer import MaxScoreOverRotations
1266
+ >>> analyzer = MaxScoreOverRotations(
1267
+ >>> score_space_shape = (50, 50),
1268
+ >>> score_space_dtype = np.float32,
1269
+ >>> rotation_space_dtype = np.int32,
1270
+ >>> )
1271
+
1272
+ The following simulates a template matching run by creating random data for a range
1273
+ of rotations and sending it to ``analyzer`` via its __call__ method
1274
+
1275
+ >>> for rotation_number in range(10):
1276
+ >>> scores = np.random.rand(50,50)
1277
+ >>> rotation = np.random.rand(scores.ndim, scores.ndim)
1278
+ >>> analyzer(score_space = scores, rotation_matrix = rotation)
1279
+
1280
+ The aggregated scores can be extracted by invoking the __iter__ method of
1281
+ ``analyzer``
1282
+
1283
+ >>> results = tuple(analyzer)
1284
+
1285
+ The ``results`` tuple contains (1) the maximum scores for each translation,
1286
+ (2) an offset which is relevant when merging results from split template matching
1287
+ using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
1288
+ score for a given translation, (4) a dictionary mapping rotation matrices to the
1289
+ indices used in (3).
1290
+
1291
+ We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
1292
+ as follows
1293
+
1294
+ >>> optimal_score = results[0].max()
1295
+ >>> optimal_translation = np.where(results[0] == results[0].max())
1296
+ >>> optimal_rotation_index = results[2][optimal_translation]
1297
+ >>> for key, value in results[3].items():
1298
+ >>> if value != optimal_rotation_index:
1299
+ >>> continue
1300
+ >>> optimal_rotation = np.frombuffer(key, rotation.dtype)
1301
+ >>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
1302
+
1303
+ The outlined procedure is a trivial method to identify high scoring peaks.
1304
+ Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
1305
+ that can be used.
1172
1306
  """
1173
1307
 
1174
1308
  def __init__(
@@ -1206,18 +1340,20 @@ class MaxScoreOverRotations:
1206
1340
 
1207
1341
  self.use_memmap = use_memmap
1208
1342
  self.lock = Manager().Lock() if thread_safe else nullcontext()
1209
- self.lock_is_nullcontext = isinstance(self.score_space, type(backend.zeros((1))))
1343
+ self.lock_is_nullcontext = isinstance(
1344
+ self.score_space, type(backend.zeros((1)))
1345
+ )
1210
1346
  self.observed_rotations = Manager().dict() if thread_safe else {}
1211
1347
 
1212
-
1213
- def _postprocess(self,
1348
+ def _postprocess(
1349
+ self,
1214
1350
  fourier_shift,
1215
1351
  convolution_mode,
1216
1352
  targetshape,
1217
1353
  templateshape,
1218
1354
  shared_memory_handler=None,
1219
- **kwargs
1220
- ):
1355
+ **kwargs,
1356
+ ):
1221
1357
  internal_scores = backend.sharedarr_to_arr(
1222
1358
  shape=self.score_space_shape,
1223
1359
  dtype=self.score_space_dtype,
@@ -1232,14 +1368,10 @@ class MaxScoreOverRotations:
1232
1368
  if fourier_shift is not None:
1233
1369
  axis = tuple(i for i in range(len(fourier_shift)))
1234
1370
  internal_scores = backend.roll(
1235
- internal_scores,
1236
- shift=fourier_shift,
1237
- axis=axis
1371
+ internal_scores, shift=fourier_shift, axis=axis
1238
1372
  )
1239
1373
  internal_rotations = backend.roll(
1240
- internal_rotations,
1241
- shift=fourier_shift,
1242
- axis=axis
1374
+ internal_rotations, shift=fourier_shift, axis=axis
1243
1375
  )
1244
1376
 
1245
1377
  if convolution_mode is not None:
@@ -1247,27 +1379,24 @@ class MaxScoreOverRotations:
1247
1379
  internal_scores,
1248
1380
  convolution_mode=convolution_mode,
1249
1381
  s1=targetshape,
1250
- s2=templateshape
1382
+ s2=templateshape,
1251
1383
  )
1252
1384
  internal_rotations = apply_convolution_mode(
1253
1385
  internal_rotations,
1254
1386
  convolution_mode=convolution_mode,
1255
1387
  s1=targetshape,
1256
- s2=templateshape
1388
+ s2=templateshape,
1257
1389
  )
1258
1390
 
1259
1391
  self.score_space_shape = internal_scores.shape
1260
1392
  self.score_space = backend.arr_to_sharedarr(
1261
- internal_scores,
1262
- shared_memory_handler
1393
+ internal_scores, shared_memory_handler
1263
1394
  )
1264
1395
  self.rotations = backend.arr_to_sharedarr(
1265
- internal_rotations,
1266
- shared_memory_handler
1396
+ internal_rotations, shared_memory_handler
1267
1397
  )
1268
1398
  return self
1269
1399
 
1270
-
1271
1400
  def __iter__(self):
1272
1401
  internal_scores = backend.sharedarr_to_arr(
1273
1402
  shape=self.score_space_shape,
@@ -1326,11 +1455,11 @@ class MaxScoreOverRotations:
1326
1455
  Arbitrary keyword arguments.
1327
1456
  """
1328
1457
  rotation = backend.tobytes(rotation_matrix)
1329
- rotation_index = self.observed_rotations.setdefault(
1330
- rotation, len(self.observed_rotations)
1331
- )
1332
1458
 
1333
1459
  if self.lock_is_nullcontext:
1460
+ rotation_index = self.observed_rotations.setdefault(
1461
+ rotation, len(self.observed_rotations)
1462
+ )
1334
1463
  backend.max_score_over_rotations(
1335
1464
  score_space=score_space,
1336
1465
  internal_scores=self.score_space,
@@ -1340,6 +1469,9 @@ class MaxScoreOverRotations:
1340
1469
  return None
1341
1470
 
1342
1471
  with self.lock:
1472
+ rotation_index = self.observed_rotations.setdefault(
1473
+ rotation, len(self.observed_rotations)
1474
+ )
1343
1475
  internal_scores = backend.sharedarr_to_arr(
1344
1476
  shape=self.score_space_shape,
1345
1477
  dtype=self.score_space_dtype,
@@ -1647,5 +1779,3 @@ class MemmapHandler:
1647
1779
  """
1648
1780
  rotation_string = "_".join(rotation_matrix.ravel().astype(str))
1649
1781
  return self._path_translation[rotation_string]
1650
-
1651
-