pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/analyzer.py CHANGED
@@ -43,14 +43,29 @@ def filter_points_indices_bucket(
43
43
 
44
44
 
45
45
  def filter_points_indices(
46
- coordinates: NDArray, min_distance: float, bucket_cutoff: int = 1e4
46
+ coordinates: NDArray,
47
+ min_distance: float,
48
+ bucket_cutoff: int = 1e4,
49
+ batch_dims: Tuple[int] = None,
47
50
  ) -> NDArray:
48
51
  if min_distance <= 0:
49
52
  return backend.arange(coordinates.shape[0])
53
+ if coordinates.shape[0] == 0:
54
+ return ()
55
+
56
+ if batch_dims is not None:
57
+ coordinates_new = backend.zeros(coordinates.shape, coordinates.dtype)
58
+ coordinates_new[:] = coordinates
59
+ coordinates_new[..., batch_dims] = backend.astype(
60
+ coordinates[..., batch_dims] * (2 * min_distance), coordinates_new.dtype
61
+ )
62
+ coordinates = coordinates_new
50
63
 
51
64
  if isinstance(coordinates, np.ndarray):
52
65
  return find_candidate_indices(coordinates, min_distance)
53
- elif coordinates.shape[0] > bucket_cutoff:
66
+ elif coordinates.shape[0] > bucket_cutoff or not isinstance(
67
+ coordinates, np.ndarray
68
+ ):
54
69
  return filter_points_indices_bucket(coordinates, min_distance)
55
70
  distances = np.linalg.norm(coordinates[:, None] - coordinates, axis=-1)
56
71
  distances = np.tril(distances)
@@ -59,8 +74,10 @@ def filter_points_indices(
59
74
  return indices[keep == indices]
60
75
 
61
76
 
62
- def filter_points(coordinates: NDArray, min_distance: Tuple[int]) -> NDArray:
63
- unique_indices = filter_points_indices(coordinates, min_distance)
77
+ def filter_points(
78
+ coordinates: NDArray, min_distance: Tuple[int], batch_dims: Tuple[int] = None
79
+ ) -> NDArray:
80
+ unique_indices = filter_points_indices(coordinates, min_distance, batch_dims)
64
81
  coordinates = coordinates[unique_indices]
65
82
  return coordinates
66
83
 
@@ -77,6 +94,8 @@ class PeakCaller(ABC):
77
94
  Minimum distance between peaks.
78
95
  min_boundary_distance : int, optional
79
96
  Minimum distance to array boundaries.
97
+ batch_dims : int, optional
98
+ Peak calling batch dimensions.
80
99
  **kwargs
81
100
  Additional keyword arguments.
82
101
 
@@ -92,6 +111,7 @@ class PeakCaller(ABC):
92
111
  number_of_peaks: int = 1000,
93
112
  min_distance: int = 1,
94
113
  min_boundary_distance: int = 0,
114
+ batch_dims: Tuple[int] = None,
95
115
  **kwargs,
96
116
  ):
97
117
  number_of_peaks = int(number_of_peaks)
@@ -114,6 +134,10 @@ class PeakCaller(ABC):
114
134
  self.min_boundary_distance = min_boundary_distance
115
135
  self.number_of_peaks = number_of_peaks
116
136
 
137
+ self.batch_dims = batch_dims
138
+ if batch_dims is not None:
139
+ self.batch_dims = tuple(int(x) for x in self.batch_dims)
140
+
117
141
  # Postprocesing arguments
118
142
  self.fourier_shift = kwargs.get("fourier_shift", None)
119
143
  self.convolution_mode = kwargs.get("convolution_mode", None)
@@ -128,10 +152,41 @@ class PeakCaller(ABC):
128
152
  self.peak_list = [backend.to_cpu_array(arr) for arr in self.peak_list]
129
153
  yield from self.peak_list
130
154
 
155
+ @staticmethod
156
+ def _batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
157
+ if batch_dims is None:
158
+ yield (tuple(slice(None) for _ in shape), tuple(0 for _ in shape))
159
+ return None
160
+
161
+ batch_ranges = [range(shape[dim]) for dim in batch_dims]
162
+
163
+ def _generate_slices_recursive(current_dim, current_indices):
164
+ if current_dim == len(batch_dims):
165
+ slice_list, offset_list, batch_index = [], [], 0
166
+ for i in range(len(shape)):
167
+ if i in batch_dims:
168
+ index = current_indices[batch_index]
169
+ slice_list.append(slice(index, index + 1))
170
+ offset_list.append(index)
171
+ batch_index += 1
172
+ else:
173
+ slice_list.append(slice(None))
174
+ offset_list.append(0)
175
+ yield (tuple(slice_list), tuple(offset_list))
176
+ else:
177
+ for index in batch_ranges[current_dim]:
178
+ yield from _generate_slices_recursive(
179
+ current_dim + 1, current_indices + (index,)
180
+ )
181
+
182
+ yield from _generate_slices_recursive(0, ())
183
+
131
184
  def __call__(
132
185
  self,
133
186
  score_space: NDArray,
134
187
  rotation_matrix: NDArray,
188
+ minimum_score: float = None,
189
+ maximum_score: float = None,
135
190
  **kwargs,
136
191
  ) -> None:
137
192
  """
@@ -143,59 +198,91 @@ class PeakCaller(ABC):
143
198
  Array containing the score space.
144
199
  rotation_matrix : NDArray
145
200
  Rotation matrix used to obtain the score array.
201
+ minimum_score : float
202
+ Minimum score from which to consider peaks. If provided, superseeds limits
203
+ presented by :py:attr:`PeakCaller.number_of_peaks`.
204
+ maximum_score : float
205
+ Maximum score upon which to consider peaks,
146
206
  **kwargs
147
207
  Optional keyword arguments passed to :py:meth:`PeakCaller.call_peak`.
148
208
  """
149
- peak_positions, peak_details = self.call_peaks(
150
- score_space=score_space, rotation_matrix=rotation_matrix, **kwargs
151
- )
209
+ for subset, offset in self._batchify(score_space.shape, self.batch_dims):
210
+ peak_positions, peak_details = self.call_peaks(
211
+ score_space=score_space[subset],
212
+ rotation_matrix=rotation_matrix,
213
+ minimum_score=minimum_score,
214
+ maximum_score=maximum_score,
215
+ **kwargs,
216
+ )
152
217
 
153
- if peak_positions is None:
154
- return None
218
+ if peak_positions is None:
219
+ continue
220
+ if peak_positions.shape[0] == 0:
221
+ continue
155
222
 
156
- peak_positions = backend.astype(peak_positions, int)
157
- if peak_positions.shape[0] == 0:
158
- return None
223
+ if peak_details is None:
224
+ peak_details = backend.full((peak_positions.shape[0],), fill_value=-1)
159
225
 
160
- if peak_details is None:
161
- peak_details = backend.to_backend_array([-1] * peak_positions.shape[0])
226
+ backend.add(peak_positions, offset, out=peak_positions)
227
+ peak_positions = backend.astype(peak_positions, int)
228
+ if self.min_boundary_distance > 0:
229
+ upper_limit = backend.subtract(
230
+ score_space.shape, self.min_boundary_distance
231
+ )
232
+ valid_peaks = backend.multiply(
233
+ peak_positions < upper_limit,
234
+ peak_positions >= self.min_boundary_distance,
235
+ )
236
+ if self.batch_dims is not None:
237
+ valid_peaks[..., self.batch_dims] = True
162
238
 
163
- if self.min_boundary_distance > 0:
164
- upper_limit = backend.subtract(
165
- score_space.shape, self.min_boundary_distance
166
- )
167
- valid_peaks = (
168
- backend.sum(
169
- backend.multiply(
170
- peak_positions < upper_limit,
171
- peak_positions >= self.min_boundary_distance,
172
- ),
173
- axis=1,
239
+ valid_peaks = (
240
+ backend.sum(valid_peaks, axis=1) == peak_positions.shape[1]
174
241
  )
175
- == peak_positions.shape[1]
176
- )
177
- if backend.sum(valid_peaks) == 0:
178
- return None
179
242
 
180
- peak_positions, peak_details = (
181
- peak_positions[valid_peaks],
182
- peak_details[valid_peaks],
243
+ if backend.sum(valid_peaks) == 0:
244
+ continue
245
+
246
+ peak_positions, peak_details = (
247
+ peak_positions[valid_peaks],
248
+ peak_details[valid_peaks],
249
+ )
250
+
251
+ peak_scores = score_space[tuple(peak_positions.T)]
252
+ if minimum_score is not None:
253
+ valid_peaks = peak_scores >= minimum_score
254
+ peak_positions, peak_details, peak_scores = (
255
+ peak_positions[valid_peaks],
256
+ peak_details[valid_peaks],
257
+ peak_scores[valid_peaks],
258
+ )
259
+ if maximum_score is not None:
260
+ valid_peaks = peak_scores <= maximum_score
261
+ peak_positions, peak_details, peak_scores = (
262
+ peak_positions[valid_peaks],
263
+ peak_details[valid_peaks],
264
+ peak_scores[valid_peaks],
265
+ )
266
+
267
+ if peak_positions.shape[0] == 0:
268
+ continue
269
+
270
+ rotations = backend.repeat(
271
+ rotation_matrix.reshape(1, *rotation_matrix.shape),
272
+ peak_positions.shape[0],
273
+ axis=0,
183
274
  )
184
275
 
185
- rotations = backend.repeat(
186
- rotation_matrix.reshape(1, *rotation_matrix.shape),
187
- peak_positions.shape[0],
188
- axis=0,
189
- )
190
- peak_scores = score_space[tuple(peak_positions.T)]
276
+ self._update(
277
+ peak_positions=peak_positions,
278
+ peak_details=peak_details,
279
+ peak_scores=peak_scores,
280
+ rotations=rotations,
281
+ batch_offset=offset,
282
+ **kwargs,
283
+ )
191
284
 
192
- self._update(
193
- peak_positions=peak_positions,
194
- peak_details=peak_details,
195
- peak_scores=peak_scores,
196
- rotations=rotations,
197
- **kwargs,
198
- )
285
+ return None
199
286
 
200
287
  @abstractmethod
201
288
  def call_peaks(
@@ -212,10 +299,8 @@ class PeakCaller(ABC):
212
299
  ----------
213
300
  score_space : NDArray
214
301
  Data array of scores.
215
- minimum_score : float
216
- Minimum score value to consider.
217
- min_distance : float
218
- Minimum distance between maxima.
302
+ **kwargs : Dict, optional
303
+ Keyword arguments passed to __call__.
219
304
 
220
305
  Returns
221
306
  -------
@@ -364,25 +449,52 @@ class PeakCaller(ABC):
364
449
  backend.add(peak_positions, translation_offset, out=peak_positions)
365
450
  if not len(self.peak_list):
366
451
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
367
- dim = peak_positions.shape[1]
368
- peak_scores = backend.zeros((0,), peak_scores.dtype)
369
- peak_details = backend.zeros((0,), peak_details.dtype)
370
- rotations = backend.zeros((0, dim, dim), rotations.dtype)
371
- peak_positions = backend.zeros((0, dim), peak_positions.dtype)
372
452
 
373
- peaks = backend.concatenate((self.peak_list[0], peak_positions))
453
+ peak_positions = backend.concatenate((self.peak_list[0], peak_positions))
374
454
  rotations = backend.concatenate((self.peak_list[1], rotations))
375
455
  peak_scores = backend.concatenate((self.peak_list[2], peak_scores))
376
456
  peak_details = backend.concatenate((self.peak_list[3], peak_details))
377
457
 
378
- top_n = min(backend.size(peak_scores), self.number_of_peaks)
379
- top_scores, *_ = backend.topk_indices(peak_scores, top_n)
458
+ if self.batch_dims is None:
459
+ top_n = min(backend.size(peak_scores), self.number_of_peaks)
460
+ top_scores, *_ = backend.topk_indices(peak_scores, top_n)
461
+ else:
462
+ # Not very performant but fairly robust
463
+ batch_indices = peak_positions[..., self.batch_dims]
464
+ backend.subtract(
465
+ batch_indices, backend.min(batch_indices, axis=0), out=batch_indices
466
+ )
467
+ multiplier = backend.power(
468
+ backend.max(batch_indices, axis=0) + 1,
469
+ backend.arange(batch_indices.shape[1]),
470
+ )
471
+ backend.multiply(batch_indices, multiplier, out=batch_indices)
472
+ batch_indices = backend.sum(batch_indices, axis=1)
473
+ unique_indices, batch_counts = backend.unique(
474
+ batch_indices, return_counts=True
475
+ )
476
+ total_indices = backend.arange(peak_scores.shape[0])
477
+ batch_indices = [total_indices[batch_indices == x] for x in unique_indices]
478
+ top_scores = backend.concatenate(
479
+ [
480
+ total_indices[indices][
481
+ backend.topk_indices(
482
+ peak_scores[indices], min(y, self.number_of_peaks)
483
+ )
484
+ ]
485
+ for indices, y in zip(batch_indices, batch_counts)
486
+ ]
487
+ )
380
488
 
381
489
  final_order = top_scores[
382
- filter_points_indices(peaks[top_scores], self.min_distance)
490
+ filter_points_indices(
491
+ coordinates=peak_positions[top_scores],
492
+ min_distance=self.min_distance,
493
+ batch_dims=self.batch_dims,
494
+ )
383
495
  ]
384
496
 
385
- self.peak_list[0] = peaks[final_order,]
497
+ self.peak_list[0] = peak_positions[final_order,]
386
498
  self.peak_list[1] = rotations[final_order,]
387
499
  self.peak_list[2] = peak_scores[final_order]
388
500
  self.peak_list[3] = peak_details[final_order]
@@ -390,6 +502,9 @@ class PeakCaller(ABC):
390
502
  def _postprocess(
391
503
  self, fourier_shift, convolution_mode, targetshape, templateshape, **kwargs
392
504
  ):
505
+ if not len(self.peak_list):
506
+ return self
507
+
393
508
  peak_positions = self.peak_list[0]
394
509
  if not len(peak_positions):
395
510
  return self
@@ -402,12 +517,13 @@ class PeakCaller(ABC):
402
517
 
403
518
  if fourier_shift is not None:
404
519
  peak_positions = backend.add(peak_positions, fourier_shift)
405
- backend.divide(peak_positions, score_space_shape).astype(int)
406
520
 
407
521
  backend.subtract(
408
522
  peak_positions,
409
523
  backend.multiply(
410
- backend.divide(peak_positions, score_space_shape).astype(int),
524
+ backend.astype(
525
+ backend.divide(peak_positions, score_space_shape), int
526
+ ),
411
527
  score_space_shape,
412
528
  ),
413
529
  out=peak_positions,
@@ -446,14 +562,10 @@ class PeakCaller(ABC):
446
562
  class PeakCallerSort(PeakCaller):
447
563
  """
448
564
  A :py:class:`PeakCaller` subclass that first selects ``number_of_peaks``
449
- highest scores and subsequently filters local maxima to suffice a distance
450
- from one another of ``min_distance``.
451
-
565
+ highest scores.
452
566
  """
453
567
 
454
- def call_peaks(
455
- self, score_space: NDArray, minimum_score: float = None, **kwargs
456
- ) -> Tuple[NDArray, NDArray]:
568
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
457
569
  """
458
570
  Call peaks in the score space.
459
571
 
@@ -461,9 +573,6 @@ class PeakCallerSort(PeakCaller):
461
573
  ----------
462
574
  score_space : NDArray
463
575
  Data array of scores.
464
- minimum_score : float
465
- Minimum score value to consider. If provided, superseeds limit given
466
- by :py:attr:`PeakCaller.number_of_peaks`.
467
576
 
468
577
  Returns
469
578
  -------
@@ -473,16 +582,12 @@ class PeakCallerSort(PeakCaller):
473
582
  flat_score_space = score_space.reshape(-1)
474
583
  k = min(self.number_of_peaks, backend.size(flat_score_space))
475
584
 
476
- if minimum_score is not None:
477
- k = backend.sum(score_space >= minimum_score)
478
-
479
585
  top_k_indices, *_ = backend.topk_indices(flat_score_space, k)
480
586
 
481
587
  coordinates = backend.unravel_index(top_k_indices, score_space.shape)
482
588
  coordinates = backend.transpose(backend.stack(coordinates))
483
589
 
484
- peaks = filter_points(coordinates, self.min_distance)
485
- return peaks, None
590
+ return coordinates, None
486
591
 
487
592
 
488
593
  class PeakCallerMaximumFilter(PeakCaller):
@@ -492,9 +597,7 @@ class PeakCallerMaximumFilter(PeakCaller):
492
597
  skimage.feature.peak_local_max.
493
598
  """
494
599
 
495
- def call_peaks(
496
- self, score_space: NDArray, minimum_score: float = None, **kwargs
497
- ) -> Tuple[NDArray, NDArray]:
600
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
498
601
  """
499
602
  Call peaks in the score space.
500
603
 
@@ -502,9 +605,8 @@ class PeakCallerMaximumFilter(PeakCaller):
502
605
  ----------
503
606
  score_space : NDArray
504
607
  Data array of scores.
505
- minimum_score : float
506
- Minimum score value to consider. If provided, superseeds limit given
507
- by :py:attr:`PeakCaller.number_of_peaks`.
608
+ kwargs: Dict, optional
609
+ Optional keyword arguments.
508
610
 
509
611
  Returns
510
612
  -------
@@ -513,17 +615,6 @@ class PeakCallerMaximumFilter(PeakCaller):
513
615
  """
514
616
  peaks = backend.max_filter_coordinates(score_space, self.min_distance)
515
617
 
516
- scores = score_space[tuple(peaks.T)]
517
-
518
- input_candidates = min(
519
- self.number_of_peaks, peaks.shape[0] - 1, backend.size(score_space) - 1
520
- )
521
- if minimum_score is not None:
522
- input_candidates = backend.sum(scores >= minimum_score)
523
-
524
- top_indices = backend.topk_indices(scores, input_candidates)
525
- peaks = peaks[top_indices]
526
-
527
618
  return peaks, None
528
619
 
529
620
 
@@ -536,9 +627,7 @@ class PeakCallerFast(PeakCaller):
536
627
 
537
628
  """
538
629
 
539
- def call_peaks(
540
- self, score_space: NDArray, minimum_score: float = None, **kwargs
541
- ) -> Tuple[NDArray, NDArray]:
630
+ def call_peaks(self, score_space: NDArray, **kwargs) -> Tuple[NDArray, NDArray]:
542
631
  """
543
632
  Call peaks in the score space.
544
633
 
@@ -546,9 +635,6 @@ class PeakCallerFast(PeakCaller):
546
635
  ----------
547
636
  score_space : NDArray
548
637
  Data array of scores.
549
- minimum_score : float
550
- Minimum score value to consider. If provided, superseeds limit given
551
- by :py:attr:`PeakCaller.number_of_peaks`.
552
638
 
553
639
  Returns
554
640
  -------
@@ -580,26 +666,24 @@ class PeakCallerFast(PeakCaller):
580
666
  if coordinates.shape[0] == 0:
581
667
  return None
582
668
 
583
- peaks = filter_points(coordinates, self.min_distance)
584
-
585
- starts = backend.maximum(peaks - self.min_distance, 0)
586
- stops = backend.minimum(peaks + self.min_distance, score_space.shape)
669
+ starts = backend.maximum(coordinates - self.min_distance, 0)
670
+ stops = backend.minimum(coordinates + self.min_distance, score_space.shape)
587
671
  slices_list = [
588
672
  tuple(slice(*coord) for coord in zip(start_row, stop_row))
589
673
  for start_row, stop_row in zip(starts, stops)
590
674
  ]
591
675
 
592
- scores = score_space[tuple(peaks.T)]
676
+ scores = score_space[tuple(coordinates.T)]
593
677
  keep = [
594
678
  score >= backend.max(score_space[subvol])
595
679
  for subvol, score in zip(slices_list, scores)
596
680
  ]
597
- peaks = peaks[keep,]
681
+ coordinates = coordinates[keep,]
598
682
 
599
- if len(peaks) == 0:
600
- return peaks, None
683
+ if len(coordinates) == 0:
684
+ return coordinates, None
601
685
 
602
- return peaks, None
686
+ return coordinates, None
603
687
 
604
688
 
605
689
  class PeakCallerRecursiveMasking(PeakCaller):
@@ -655,7 +739,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
655
739
  if mask is None:
656
740
  masking_function = self._mask_scores_box
657
741
  shape = tuple(self.min_distance for _ in range(score_space.ndim))
658
- mask = backend.zeros(shape, dtype=backend._default_dtype)
742
+ mask = backend.zeros(shape, dtype=backend._float_dtype)
659
743
 
660
744
  rotated_template = backend.zeros(mask.shape, dtype=mask.dtype)
661
745
 
@@ -665,12 +749,15 @@ class PeakCallerRecursiveMasking(PeakCaller):
665
749
  else:
666
750
  minimum_score = backend.min(score_space) - 1
667
751
 
752
+ scores = backend.zeros(score_space.shape, dtype=score_space.dtype)
753
+ scores[:] = score_space
754
+
668
755
  while True:
669
- backend.argmax(score_space)
756
+ backend.argmax(scores)
670
757
  peak = backend.unravel_index(
671
- indices=backend.argmax(score_space), shape=score_space.shape
758
+ indices=backend.argmax(scores), shape=scores.shape
672
759
  )
673
- if score_space[tuple(peak)] < minimum_score:
760
+ if scores[tuple(peak)] < minimum_score:
674
761
  break
675
762
 
676
763
  coordinates.append(peak)
@@ -683,7 +770,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
683
770
  )
684
771
 
685
772
  masking_function(
686
- score_space=score_space,
773
+ score_space=scores,
687
774
  rotation_matrix=current_rotation_matrix,
688
775
  peak=peak,
689
776
  mask=mask,
@@ -839,18 +926,23 @@ class PeakCallerScipy(PeakCaller):
839
926
  Tuple[NDArray, NDArray]
840
927
  Array of peak coordinates and peak details.
841
928
  """
842
-
843
929
  score_space = backend.to_numpy_array(score_space)
844
930
  num_peaks = self.number_of_peaks
845
931
  if minimum_score is not None:
846
932
  num_peaks = np.inf
847
933
 
934
+ non_squeezable_dims = tuple(
935
+ i for i, x in enumerate(score_space.shape) if x != 1
936
+ )
848
937
  peaks = peak_local_max(
849
- score_space,
938
+ np.squeeze(score_space),
850
939
  num_peaks=num_peaks,
851
940
  min_distance=self.min_distance,
852
941
  threshold_abs=minimum_score,
853
942
  )
943
+ peaks_full = np.zeros((peaks.shape[0], score_space.ndim), peaks.dtype)
944
+ peaks_full[..., non_squeezable_dims] = peaks[:]
945
+ peaks = backend.to_backend_array(peaks_full)
854
946
  return peaks, None
855
947
 
856
948
 
@@ -1164,6 +1256,53 @@ class MaxScoreOverRotations:
1164
1256
  Whether to offload internal data arrays to disk
1165
1257
  thread_safe: bool, optional
1166
1258
  Whether access to internal data arrays should be thread safe
1259
+
1260
+ Examples
1261
+ --------
1262
+ The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
1263
+ instance
1264
+
1265
+ >>> from tme.analyzer import MaxScoreOverRotations
1266
+ >>> analyzer = MaxScoreOverRotations(
1267
+ >>> score_space_shape = (50, 50),
1268
+ >>> score_space_dtype = np.float32,
1269
+ >>> rotation_space_dtype = np.int32,
1270
+ >>> )
1271
+
1272
+ The following simulates a template matching run by creating random data for a range
1273
+ of rotations and sending it to ``analyzer`` via its __call__ method
1274
+
1275
+ >>> for rotation_number in range(10):
1276
+ >>> scores = np.random.rand(50,50)
1277
+ >>> rotation = np.random.rand(scores.ndim, scores.ndim)
1278
+ >>> analyzer(score_space = scores, rotation_matrix = rotation)
1279
+
1280
+ The aggregated scores can be exctracted by invoking the __iter__ method of
1281
+ ``analyzer``
1282
+
1283
+ >>> results = tuple(analyzer)
1284
+
1285
+ The ``results`` tuple contains (1) the maximum scores for each translation,
1286
+ (2) an offset which is relevant when merging results from split template matching
1287
+ using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
1288
+ score for a given translation, (4) a dictionary mapping rotation matrices to the
1289
+ indices used in (2).
1290
+
1291
+ We can extract the ``optimal_score`, ``optimal_translation`` and ``optimal_rotation``
1292
+ as follows
1293
+
1294
+ >>> optimal_score = results[0].max()
1295
+ >>> optimal_translation = np.where(results[0] == results[0].max())
1296
+ >>> optimal_rotation_index = results[2][optimal_translation]
1297
+ >>> for key, value in results[3].items():
1298
+ >>> if value != optimal_rotation_index:
1299
+ >>> continue
1300
+ >>> optimal_rotation = np.frombuffer(key, rotation.dtype)
1301
+ >>> optimal_rotation = optimal_rotation.reshape(scores.ndim, scores.ndim)
1302
+
1303
+ The outlined procedure is a trivial method to identify high scoring peaks.
1304
+ Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
1305
+ that can be used.
1167
1306
  """
1168
1307
 
1169
1308
  def __init__(
@@ -1316,11 +1455,11 @@ class MaxScoreOverRotations:
1316
1455
  Arbitrary keyword arguments.
1317
1456
  """
1318
1457
  rotation = backend.tobytes(rotation_matrix)
1319
- rotation_index = self.observed_rotations.setdefault(
1320
- rotation, len(self.observed_rotations)
1321
- )
1322
1458
 
1323
1459
  if self.lock_is_nullcontext:
1460
+ rotation_index = self.observed_rotations.setdefault(
1461
+ rotation, len(self.observed_rotations)
1462
+ )
1324
1463
  backend.max_score_over_rotations(
1325
1464
  score_space=score_space,
1326
1465
  internal_scores=self.score_space,
@@ -1330,6 +1469,9 @@ class MaxScoreOverRotations:
1330
1469
  return None
1331
1470
 
1332
1471
  with self.lock:
1472
+ rotation_index = self.observed_rotations.setdefault(
1473
+ rotation, len(self.observed_rotations)
1474
+ )
1333
1475
  internal_scores = backend.sharedarr_to_arr(
1334
1476
  shape=self.score_space_shape,
1335
1477
  dtype=self.score_space_dtype,
tme/backends/__init__.py CHANGED
@@ -57,6 +57,7 @@ class BackendManager:
57
57
  "cpu_backend": NumpyFFTWBackend,
58
58
  "pytorch": PytorchBackend,
59
59
  "cupy": CupyBackend,
60
+ "mlx": MLXBackend,
60
61
  }
61
62
  self._backend = NumpyFFTWBackend()
62
63
  self._backend_name = "cpu_backend"