pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/METADATA +21 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +193 -27
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  70. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Implements abstract base class for template matching analyzers.
3
+
4
+ Copyright (c) 2025 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+
9
+ from typing import Tuple, List
10
+ from abc import ABC, abstractmethod
11
+
12
+ __all__ = ["AbstractAnalyzer"]
13
+
14
+
15
+ class AbstractAnalyzer(ABC):
16
+ """
17
+ Abstract base class for template matching analyzers.
18
+ """
19
+
20
+ @property
21
+ def shareable(self):
22
+ """
23
+ Indicate whether the analyzer can be shared across processes.
24
+
25
+ Returns
26
+ -------
27
+ bool
28
+ True if the analyzer supports shared memory operations
29
+ and can be safely used across multiple processes, False
30
+ if it should only be used within a single process.
31
+ """
32
+ return False
33
+
34
+ @abstractmethod
35
+ def init_state(self, *args, **kwargs) -> Tuple:
36
+ """
37
+ Initialize the analyzer state.
38
+
39
+ Returns
40
+ -------
41
+ state : tuple
42
+ Initial state tuple of the analyzer instance. The exact structure
43
+ depends on the specific implementation.
44
+
45
+ Notes
46
+ -----
47
+ This method creates the initial state that will be passed to
48
+ :py:meth:`AbstractAnalyzer.__call__` and finally to
49
+ :py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
50
+ data structures for accumulating analysis results.
51
+ """
52
+
53
+ @abstractmethod
54
+ def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
55
+ """
56
+ Update the analyzer state with new scoring data.
57
+
58
+ Parameters
59
+ ----------
60
+ state : tuple
61
+ Current analyzer state as returned by :py:meth:`AbstractAnalyzer.init_state`
62
+ or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
63
+ scores : BackendArray
64
+ Array of new scores with dimensionality d.
65
+ rotation_matrix : BackendArray
66
+ Rotation matrix used to generate scores with shape (d,d).
67
+ **kwargs : dict
68
+ Keyword arguments used by specific implementations.
69
+
70
+ Returns
71
+ -------
72
+ tuple
73
+ Updated analyzer state incorporating the new data.
74
+ """
75
+
76
+ @abstractmethod
77
+ def result(self, state: Tuple, **kwargs) -> Tuple:
78
+ """
79
+ Finalize the analysis by performing potential post processing.
80
+
81
+ Parameters
82
+ ----------
83
+ state : tuple
84
+ Analyzer state containing accumulated data.
85
+ **kwargs : dict
86
+ Keyword arguments used by specific implementations.
87
+
88
+ Returns
89
+ -------
90
+ result
91
+ Final analysis result. The exact structure depends on the
92
+ analyzer implementation.
93
+
94
+ Notes
95
+ -----
96
+ This method converts the internal analyzer state into the
97
+ final output format expected by the template matching pipeline.
98
+ It may apply postprocessing operations like convolution mode
99
+ correction or coordinate transformations.
100
+ """
101
+
102
+ @classmethod
103
+ @abstractmethod
104
+ def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
105
+ """
106
+ Merge multiple analyzer results.
107
+
108
+ Parameters
109
+ ----------
110
+ results : list of tuple
111
+ List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
112
+ from different instances of the same analyzer class.
113
+ **kwargs : dict
114
+ Keyword arguments used by specific implementations.
115
+
116
+ Returns
117
+ -------
118
+ tuple
119
+ Single result object combining all input results.
120
+
121
+ Notes
122
+ -----
123
+ This method enables parallel processing by allowing results
124
+ from different processes or splits to be combined into a
125
+ unified result. The merge operation should handle overlapping
126
+ data appropriately and maintain consistency.
127
+ """
tme/analyzer/peaks.py CHANGED
@@ -1,23 +1,25 @@
1
- """ Implements classes to analyze outputs from exhaustive template matching.
1
+ """
2
+ Implements classes to analyze outputs from exhaustive template matching.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from functools import wraps
9
- from abc import ABC, abstractmethod
10
- from typing import Tuple, List, Dict, Generator
10
+ from abc import abstractmethod
11
+ from typing import Tuple, List, Dict
11
12
 
12
13
  import numpy as np
13
14
  from skimage.feature import peak_local_max
14
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
15
16
 
17
+ from .base import AbstractAnalyzer
16
18
  from ._utils import score_to_cart
17
19
  from ..backends import backend as be
18
- from ..matching_utils import split_shape
19
20
  from ..types import BackendArray, NDArray
20
21
  from ..rotations import euler_to_rotationmatrix
22
+ from ..matching_utils import split_shape, compute_extraction_box
21
23
 
22
24
  __all__ = [
23
25
  "PeakCaller",
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
141
143
  yield from _generate_slices_recursive(0, ())
142
144
 
143
145
 
144
- class PeakCaller(ABC):
146
+ class PeakCaller(AbstractAnalyzer):
145
147
  """
146
148
  Base class for peak calling algorithms.
147
149
 
@@ -190,19 +192,7 @@ class PeakCaller(ABC):
190
192
  if min_boundary_distance < 0:
191
193
  raise ValueError("min_boundary_distance has to be non-negative.")
192
194
 
193
- ndim = len(shape)
194
- self.translations = be.full(
195
- (num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
196
- )
197
- self.rotations = be.full(
198
- (num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
199
- )
200
- for i in range(ndim):
201
- self.rotations[:, i, i] = 1.0
202
-
203
- self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
204
- self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
205
-
195
+ self.shape = shape
206
196
  self.num_peaks = int(num_peaks)
207
197
  self.min_distance = int(min_distance)
208
198
  self.min_boundary_distance = int(min_boundary_distance)
@@ -213,31 +203,47 @@ class PeakCaller(ABC):
213
203
 
214
204
  self.min_score, self.max_score = min_score, max_score
215
205
 
216
- # Postprocessing arguments
217
- self.fourier_shift = kwargs.get("fourier_shift", None)
218
- self.convolution_mode = kwargs.get("convolution_mode", None)
219
- self.targetshape = kwargs.get("targetshape", None)
220
- self.templateshape = kwargs.get("templateshape", None)
221
-
222
- def __iter__(self) -> Generator:
206
+ @abstractmethod
207
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
223
208
  """
224
- Returns a generator to list objects containing translation,
225
- rotation, score and details of a given candidate.
209
+ Call peaks in the score space.
210
+
211
+ Parameters
212
+ ----------
213
+ scores : BackendArray
214
+ Score array to update analyzer with.
215
+ **kwargs : dict
216
+ Optional keyword arguments passed to underlying implementations.
217
+
218
+ Returns
219
+ -------
220
+ BackendArray
221
+ Peak positions (n, d).
222
+ BackendArray
223
+ Peak details (n, d).
226
224
  """
227
- self.peak_list = [
228
- be.to_cpu_array(self.translations),
229
- be.to_cpu_array(self.rotations),
230
- be.to_cpu_array(self.scores),
231
- be.to_cpu_array(self.details),
232
- ]
233
- yield from self.peak_list
225
+
226
+ def init_state(self):
227
+ ndim = len(self.shape)
228
+ translations = be.full(
229
+ (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
230
+ )
231
+ rotations = be.full(
232
+ (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
233
+ )
234
+ for i in range(ndim):
235
+ rotations[:, i, i] = 1.0
236
+
237
+ scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
238
+ details = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
239
+ return translations, rotations, scores, details
234
240
 
235
241
  def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
236
242
  if not len(peaks):
237
243
  return None
238
244
 
239
245
  valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
240
- if self.min_boundary_distance > 0:
246
+ if self.min_boundary_distance >= 0:
241
247
  upper_limit = be.subtract(
242
248
  be.to_backend_array(scores.shape), self.min_boundary_distance
243
249
  )
@@ -318,20 +324,34 @@ class PeakCaller(ABC):
318
324
  peak_positions = be.astype(peak_positions, int)
319
325
  return peak_positions, peak_details
320
326
 
321
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
327
+ def __call__(
328
+ self,
329
+ state: Tuple,
330
+ scores: BackendArray,
331
+ rotation_matrix: BackendArray,
332
+ **kwargs,
333
+ ) -> Tuple:
322
334
  """
323
335
  Update the internal parameter store based on input array.
324
336
 
325
337
  Parameters
326
338
  ----------
339
+ state : tuple
340
+ Current state tuple where:
341
+ - positions : BackendArray, (n, d) of peak positions
342
+ - rotations : BackendArray, (n, d, d) of corresponding rotations
343
+ - scores : BackendArray, (n, ) of peak scores
344
+ - details : BackendArray, (n, ) of peak details
327
345
  scores : BackendArray
328
- Score space data.
346
+ Array of new scores to update analyzer with.
329
347
  rotation_matrix : BackendArray
330
348
  Rotation matrix used to obtain the score array.
331
349
  **kwargs
332
350
  Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
333
351
  """
334
- for ret in self._call_peaks(scores=scores, rotation_matrix=rotation_matrix):
352
+ for ret in self._call_peaks(
353
+ scores=scores, rotation_matrix=rotation_matrix, **kwargs
354
+ ):
335
355
  peak_positions, peak_details = ret
336
356
  if peak_positions is None:
337
357
  continue
@@ -344,7 +364,6 @@ class PeakCaller(ABC):
344
364
  peak_scores = scores[tuple(peak_positions.T)]
345
365
  if peak_details is not None:
346
366
  peak_details = peak_details[valid_peaks]
347
- # peak_details, peak_scores = peak_scores, -peak_details
348
367
  else:
349
368
  peak_details = be.full(peak_scores.shape, fill_value=-1)
350
369
 
@@ -354,66 +373,57 @@ class PeakCaller(ABC):
354
373
  axis=0,
355
374
  )
356
375
 
357
- self._update(
376
+ state = self._update(
377
+ state,
358
378
  peak_positions=peak_positions,
359
379
  peak_details=peak_details,
360
380
  peak_scores=peak_scores,
361
- rotations=rotations,
381
+ peak_rotations=rotations,
362
382
  )
363
383
 
364
- return None
365
-
366
- @abstractmethod
367
- def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
368
- """
369
- Call peaks in the score space.
370
-
371
- Parameters
372
- ----------
373
- scores : BackendArray
374
- Score array.
375
- **kwargs : dict
376
- Optional keyword arguments passed to underlying implementations.
377
-
378
- Returns
379
- -------
380
- Tuple[BackendArray, BackendArray]
381
- Array of peak coordinates and peak details.
382
- """
384
+ return state
383
385
 
384
386
  @classmethod
385
- def merge(cls, candidates=List[List], **kwargs) -> Tuple:
387
+ def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
386
388
  """
387
389
  Merge multiple instances of :py:class:`PeakCaller`.
388
390
 
389
391
  Parameters
390
392
  ----------
391
- candidates : list of lists
392
- Obtained by invoking list on the generator returned by __iter__.
393
+ results : list of tuple
394
+ List of instance results created by applying `result`.
393
395
  **kwargs
394
396
  Optional keyword arguments.
395
397
 
396
398
  Returns
397
399
  -------
398
- Tuple
399
- Tuple of translation, rotation, score and details of candidates.
400
+ NDArray
401
+ Peak positions (n, d).
402
+ NDArray
403
+ Peak rotation matrices (n, d, d).
404
+ NDArray
405
+ Peak scores (n, ).
406
+ NDArray
407
+ Peak details (n,).
400
408
  """
401
409
  if "shape" not in kwargs:
402
- kwargs["shape"] = tuple(1 for _ in range(candidates[0][0].shape[1]))
410
+ kwargs["shape"] = tuple(1 for _ in range(results[0][0].shape[1]))
403
411
 
404
412
  base = cls(**kwargs)
405
- for candidate in candidates:
406
- if len(candidate) == 0:
413
+ base_state = base.init_state()
414
+ for result in results:
415
+ if len(result) == 0:
407
416
  continue
408
- peak_positions, rotations, peak_scores, peak_details = candidate
409
- base._update(
417
+ peak_positions, rotations, peak_scores, peak_details = result
418
+ base_state = base._update(
419
+ base_state,
410
420
  peak_positions=be.to_backend_array(peak_positions),
411
421
  peak_details=be.to_backend_array(peak_details),
412
422
  peak_scores=be.to_backend_array(peak_scores),
413
- rotations=be.to_backend_array(rotations),
423
+ peak_rotations=be.to_backend_array(rotations),
414
424
  offset=kwargs.get("offset", None),
415
425
  )
416
- return tuple(base)
426
+ return base_state
417
427
 
418
428
  @staticmethod
419
429
  def oversample_peaks(
@@ -520,10 +530,11 @@ class PeakCaller(ABC):
520
530
 
521
531
  def _update(
522
532
  self,
533
+ state,
523
534
  peak_positions: BackendArray,
524
535
  peak_details: BackendArray,
525
536
  peak_scores: BackendArray,
526
- rotations: BackendArray,
537
+ peak_rotations: BackendArray,
527
538
  offset: BackendArray = None,
528
539
  ):
529
540
  """
@@ -542,14 +553,15 @@ class PeakCaller(ABC):
542
553
  offset : BackendArray, optional
543
554
  Translation offset, e.g. from splitting, (d, ).
544
555
  """
556
+ translations, rotations, scores, details = state
545
557
  if offset is not None:
546
558
  offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
547
559
  peak_positions = be.add(peak_positions, offset, out=peak_positions)
548
560
 
549
- positions = be.concatenate((self.translations, peak_positions))
550
- rotations = be.concatenate((self.rotations, rotations))
551
- scores = be.concatenate((self.scores, peak_scores))
552
- details = be.concatenate((self.details, peak_details))
561
+ positions = be.concatenate((translations, peak_positions))
562
+ rotations = be.concatenate((rotations, peak_rotations))
563
+ scores = be.concatenate((scores, peak_scores))
564
+ details = be.concatenate((details, peak_details))
553
565
 
554
566
  # topk filtering after distances yields more distributed peak calls
555
567
  distance_order = filter_points_indices(
@@ -564,22 +576,69 @@ class PeakCaller(ABC):
564
576
  )
565
577
  final_order = distance_order[top_scores]
566
578
 
567
- self.translations = positions[final_order, :]
568
- self.rotations = rotations[final_order, :]
569
- self.scores = scores[final_order]
570
- self.details = details[final_order]
579
+ translations = positions[final_order, :]
580
+ rotations = rotations[final_order, :]
581
+ scores = scores[final_order]
582
+ details = details[final_order]
583
+ return translations, rotations, scores, details
571
584
 
572
- def _postprocess(self, **kwargs):
573
- if not len(self.translations):
574
- return self
585
+ def result(
586
+ self,
587
+ state,
588
+ fast_shape: Tuple[int] = None,
589
+ targetshape: Tuple[int] = None,
590
+ templateshape: Tuple[int] = None,
591
+ convolution_shape: Tuple[int] = None,
592
+ fourier_shift: Tuple[int] = None,
593
+ convolution_mode: str = None,
594
+ **kwargs,
595
+ ):
596
+ """
597
+ Finalize the analysis result with optional postprocessing.
575
598
 
576
- positions, valid_peaks = score_to_cart(self.translations, **kwargs)
599
+ Parameters
600
+ ----------
601
+ state : tuple
602
+ Current state tuple where:
603
+ - positions : BackendArray, (n, d) of peak positions
604
+ - rotations : BackendArray, (n, d, d) of corresponding rotations
605
+ - scores : BackendArray, (n, ) of peak scores
606
+ - details : BackendArray, (n, ) of peak details
607
+ targetshape : Tuple[int], optional
608
+ Shape of the target for convolution mode correction.
609
+ templateshape : Tuple[int], optional
610
+ Shape of the template for convolution mode correction.
611
+ convolution_shape : Tuple[int], optional
612
+ Shape used for convolution.
613
+ fourier_shift : Tuple[int], optional.
614
+ Shift to apply for Fourier correction.
615
+ convolution_mode : str, optional
616
+ Convolution mode for padding correction.
617
+ **kwargs
618
+ Additional keyword arguments.
577
619
 
578
- self.translations = positions[valid_peaks]
579
- self.rotations = self.rotations[valid_peaks]
580
- self.scores = self.scores[valid_peaks]
581
- self.details = self.details[valid_peaks]
582
- return self
620
+ Returns
621
+ -------
622
+ tuple
623
+ Final result tuple (positions, rotations, scores, details).
624
+ """
625
+ translations, rotations, scores, details = state
626
+
627
+ positions, valid_peaks = score_to_cart(
628
+ positions=translations,
629
+ fast_shape=fast_shape,
630
+ targetshape=targetshape,
631
+ templateshape=templateshape,
632
+ convolution_shape=convolution_shape,
633
+ fourier_shift=fourier_shift,
634
+ convolution_mode=convolution_mode,
635
+ **kwargs,
636
+ )
637
+ translations = be.to_cpu_array(positions[valid_peaks])
638
+ rotations = be.to_cpu_array(rotations[valid_peaks])
639
+ scores = be.to_cpu_array(scores[valid_peaks])
640
+ details = be.to_cpu_array(details[valid_peaks])
641
+ return translations, rotations, scores, details
583
642
 
584
643
 
585
644
  class PeakCallerSort(PeakCaller):
@@ -706,14 +765,14 @@ class PeakCallerRecursiveMasking(PeakCaller):
706
765
  values. If rotations and rotation_mapping is provided, the respective
707
766
  rotation will be applied to the mask, otherwise rotation_matrix is used.
708
767
  """
709
- coordinates, masking_function = [], self._mask_scores_rotate
710
-
711
- if mask is None:
712
- masking_function = self._mask_scores_box
713
- shape = tuple(self.min_distance for _ in range(scores.ndim))
714
- mask = be.zeros(shape, dtype=be._float_dtype)
768
+ peaks = []
769
+ box = tuple(self.min_distance for _ in range(scores.ndim))
715
770
 
716
- rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
771
+ scores = be.to_backend_array(scores)
772
+ if mask is not None:
773
+ box = mask.shape
774
+ mask = be.to_backend_array(mask)
775
+ mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
717
776
 
718
777
  peak_limit = self.num_peaks
719
778
  if min_score is not None:
@@ -721,39 +780,45 @@ class PeakCallerRecursiveMasking(PeakCaller):
721
780
  else:
722
781
  min_score = be.min(scores) - 1
723
782
 
724
- scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
725
- scores_copy[:] = scores
726
-
783
+ _scores = be.zeros(scores.shape, dtype=scores.dtype)
784
+ _scores[:] = scores[:]
727
785
  while True:
728
- be.argmax(scores_copy)
729
- peak = be.unravel_index(
730
- indices=be.argmax(scores_copy), shape=scores_copy.shape
731
- )
732
- if scores_copy[tuple(peak)] < min_score:
786
+ peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
787
+ if _scores[tuple(peak)] < min_score:
733
788
  break
789
+ peaks.append(peak)
734
790
 
735
- coordinates.append(peak)
736
-
737
- current_rotation_matrix = self._get_rotation_matrix(
738
- peak=peak,
739
- rotation_space=rotations,
740
- rotation_mapping=rotation_mapping,
741
- rotation_matrix=rotation_matrix,
791
+ score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
792
+ centers=be.to_backend_array(peak)[None],
793
+ extraction_shape=box,
794
+ original_shape=scores.shape,
742
795
  )
743
-
744
- masking_function(
745
- scores=scores_copy,
746
- rotation_matrix=current_rotation_matrix,
747
- peak=peak,
748
- mask=mask,
749
- rotated_template=rotated_template,
796
+ score_slice = tuple(
797
+ slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
798
+ )
799
+ tmpl_slice = tuple(
800
+ slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
750
801
  )
751
802
 
752
- if len(coordinates) >= peak_limit:
803
+ score_mask = 0
804
+ if mask is not None:
805
+ mask_buffer.fill(0)
806
+ rmat = self._get_rotation_matrix(
807
+ peak=peak,
808
+ rotation_space=rotations,
809
+ rotation_mapping=rotation_mapping,
810
+ rotation_matrix=rotation_matrix,
811
+ )
812
+ be.rigid_transform(
813
+ arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
814
+ )
815
+ score_mask = mask_buffer[tmpl_slice] <= 0.1
816
+
817
+ _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
818
+ if len(peaks) >= peak_limit:
753
819
  break
754
820
 
755
- peaks = be.to_backend_array(coordinates)
756
- return peaks, None
821
+ return be.to_backend_array(peaks), None
757
822
 
758
823
  @staticmethod
759
824
  def _get_rotation_matrix(
@@ -793,86 +858,6 @@ class PeakCallerRecursiveMasking(PeakCaller):
793
858
  )
794
859
  return rotation
795
860
 
796
- @staticmethod
797
- def _mask_scores_box(
798
- scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
799
- ) -> None:
800
- """
801
- Mask scores in a box around a peak.
802
-
803
- Parameters
804
- ----------
805
- scores : BackendArray
806
- Data array of scores.
807
- peak : BackendArray
808
- Peak coordinates.
809
- mask : BackendArray
810
- Mask array.
811
- """
812
- start = be.maximum(be.subtract(peak, mask.shape), 0)
813
- stop = be.minimum(be.add(peak, mask.shape), scores.shape)
814
- start, stop = be.astype(start, int), be.astype(stop, int)
815
- coords = tuple(slice(*pos) for pos in zip(start, stop))
816
- scores[coords] = 0
817
- return None
818
-
819
- @staticmethod
820
- def _mask_scores_rotate(
821
- scores: BackendArray,
822
- peak: BackendArray,
823
- mask: BackendArray,
824
- rotated_template: BackendArray,
825
- rotation_matrix: BackendArray,
826
- **kwargs: Dict,
827
- ) -> None:
828
- """
829
- Mask scores using mask rotation around a peak.
830
-
831
- Parameters
832
- ----------
833
- scores : BackendArray
834
- Data array of scores.
835
- peak : BackendArray
836
- Peak coordinates.
837
- mask : BackendArray
838
- Mask array.
839
- rotated_template : BackendArray
840
- Empty array to write mask rotations to.
841
- rotation_matrix : BackendArray
842
- Rotation matrix.
843
- """
844
- left_pad = be.divide(mask.shape, 2).astype(int)
845
- right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
846
-
847
- score_start = be.subtract(peak, left_pad)
848
- score_stop = be.add(peak, right_pad)
849
-
850
- template_start = be.subtract(be.maximum(score_start, 0), score_start)
851
- template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
852
- template_stop = be.subtract(mask.shape, template_stop)
853
-
854
- score_start = be.maximum(score_start, 0)
855
- score_stop = be.minimum(score_stop, scores.shape)
856
- score_start = be.astype(score_start, int)
857
- score_stop = be.astype(score_stop, int)
858
-
859
- template_start = be.astype(template_start, int)
860
- template_stop = be.astype(template_stop, int)
861
- coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
862
- coords_template = tuple(
863
- slice(*pos) for pos in zip(template_start, template_stop)
864
- )
865
-
866
- rotated_template.fill(0)
867
- be.rigid_transform(
868
- arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
869
- )
870
-
871
- scores[coords_score] = be.multiply(
872
- scores[coords_score], (rotated_template[coords_template] <= 0.1)
873
- )
874
- return None
875
-
876
861
 
877
862
  class PeakCallerScipy(PeakCaller):
878
863
  """