pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Implements abstract base class for template matching analyzers.
3
+
4
+ Copyright (c) 2025 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+
9
+ from typing import Tuple, List
10
+ from abc import ABC, abstractmethod
11
+
12
+ __all__ = ["AbstractAnalyzer"]
13
+
14
+
15
+ class AbstractAnalyzer(ABC):
16
+ """
17
+ Abstract base class for template matching analyzers.
18
+ """
19
+
20
+ @property
21
+ def shareable(self):
22
+ """
23
+ Indicate whether the analyzer can be shared across processes.
24
+
25
+ Returns
26
+ -------
27
+ bool
28
+ True if the analyzer supports shared memory operations
29
+ and can be safely used across multiple processes, False
30
+ if it should only be used within a single process.
31
+ """
32
+ return False
33
+
34
+ @abstractmethod
35
+ def init_state(self, *args, **kwargs) -> Tuple:
36
+ """
37
+ Initialize the analyzer state.
38
+
39
+ Returns
40
+ -------
41
+ state : tuple
42
+ Initial state tuple of the analyzer instance. The exact structure
43
+ depends on the specific implementation.
44
+
45
+ Notes
46
+ -----
47
+ This method creates the initial state that will be passed to
48
+ :py:meth:`AbstractAnalyzer.__call__` and finally to
49
+ :py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
50
+ data structures for accumulating analysis results.
51
+ """
52
+
53
+ @abstractmethod
54
+ def __call__(self, state, scores, rotation_matrix, **kwargs) -> Tuple:
55
+ """
56
+ Update the analyzer state with new scoring data.
57
+
58
+ Parameters
59
+ ----------
60
+ state : tuple
61
+ Current analyzer state as returned by :py:meth:`AbstractAnalyzer.init_state`
62
+ or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
63
+ scores : BackendArray
64
+ Array of new scores with dimensionality d.
65
+ rotation_matrix : BackendArray
66
+ Rotation matrix used to generate scores with shape (d,d).
67
+ **kwargs : dict
68
+ Keyword arguments used by specific implementations.
69
+
70
+ Returns
71
+ -------
72
+ tuple
73
+ Updated analyzer state incorporating the new data.
74
+ """
75
+
76
+ @abstractmethod
77
+ def result(self, state: Tuple, **kwargs) -> Tuple:
78
+ """
79
+ Finalize the analysis by performing potential post processing.
80
+
81
+ Parameters
82
+ ----------
83
+ state : tuple
84
+ Analyzer state containing accumulated data.
85
+ **kwargs : dict
86
+ Keyword arguments used by specific implementations.
87
+
88
+ Returns
89
+ -------
90
+ result
91
+ Final analysis result. The exact structure depends on the
92
+ analyzer implementation.
93
+
94
+ Notes
95
+ -----
96
+ This method converts the internal analyzer state into the
97
+ final output format expected by the template matching pipeline.
98
+ It may apply postprocessing operations like convolution mode
99
+ correction or coordinate transformations.
100
+ """
101
+
102
+ @classmethod
103
+ @abstractmethod
104
+ def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
105
+ """
106
+ Merge multiple analyzer results.
107
+
108
+ Parameters
109
+ ----------
110
+ results : list of tuple
111
+ List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
112
+ from different instances of the same analyzer class.
113
+ **kwargs : dict
114
+ Keyword arguments used by specific implementations.
115
+
116
+ Returns
117
+ -------
118
+ tuple
119
+ Single result object combining all input results.
120
+
121
+ Notes
122
+ -----
123
+ This method enables parallel processing by allowing results
124
+ from different processes or splits to be combined into a
125
+ unified result. The merge operation should handle overlapping
126
+ data appropriately and maintain consistency.
127
+ """
tme/analyzer/peaks.py CHANGED
@@ -1,23 +1,25 @@
1
- """ Implements classes to analyze outputs from exhaustive template matching.
1
+ """
2
+ Implements classes to analyze outputs from exhaustive template matching.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from functools import wraps
9
- from abc import ABC, abstractmethod
10
- from typing import Tuple, List, Dict, Generator
10
+ from abc import abstractmethod
11
+ from typing import Tuple, List, Dict
11
12
 
12
13
  import numpy as np
13
14
  from skimage.feature import peak_local_max
14
15
  from skimage.registration._phase_cross_correlation import _upsampled_dft
15
16
 
17
+ from .base import AbstractAnalyzer
16
18
  from ._utils import score_to_cart
17
19
  from ..backends import backend as be
18
- from ..matching_utils import split_shape
19
20
  from ..types import BackendArray, NDArray
20
21
  from ..rotations import euler_to_rotationmatrix
22
+ from ..matching_utils import split_shape, compute_extraction_box
21
23
 
22
24
  __all__ = [
23
25
  "PeakCaller",
@@ -141,7 +143,7 @@ def batchify(shape: Tuple[int], batch_dims: Tuple[int] = None) -> List:
141
143
  yield from _generate_slices_recursive(0, ())
142
144
 
143
145
 
144
- class PeakCaller(ABC):
146
+ class PeakCaller(AbstractAnalyzer):
145
147
  """
146
148
  Base class for peak calling algorithms.
147
149
 
@@ -190,16 +192,7 @@ class PeakCaller(ABC):
190
192
  if min_boundary_distance < 0:
191
193
  raise ValueError("min_boundary_distance has to be non-negative.")
192
194
 
193
- ndim = len(shape)
194
- self.translations = be.full(
195
- (num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
196
- )
197
- self.rotations = be.full(
198
- (num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
199
- )
200
- self.scores = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
201
- self.details = be.full((num_peaks,), fill_value=0, dtype=be._float_dtype)
202
-
195
+ self.shape = shape
203
196
  self.num_peaks = int(num_peaks)
204
197
  self.min_distance = int(min_distance)
205
198
  self.min_boundary_distance = int(min_boundary_distance)
@@ -210,31 +203,47 @@ class PeakCaller(ABC):
210
203
 
211
204
  self.min_score, self.max_score = min_score, max_score
212
205
 
213
- # Postprocessing arguments
214
- self.fourier_shift = kwargs.get("fourier_shift", None)
215
- self.convolution_mode = kwargs.get("convolution_mode", None)
216
- self.targetshape = kwargs.get("targetshape", None)
217
- self.templateshape = kwargs.get("templateshape", None)
218
-
219
- def __iter__(self) -> Generator:
206
+ @abstractmethod
207
+ def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
220
208
  """
221
- Returns a generator to list objects containing translation,
222
- rotation, score and details of a given candidate.
209
+ Call peaks in the score space.
210
+
211
+ Parameters
212
+ ----------
213
+ scores : BackendArray
214
+ Score array to update analyzer with.
215
+ **kwargs : dict
216
+ Optional keyword arguments passed to underlying implementations.
217
+
218
+ Returns
219
+ -------
220
+ BackendArray
221
+ Peak positions (n, d).
222
+ BackendArray
223
+ Peak details (n, ).
223
224
  """
224
- self.peak_list = [
225
- be.to_cpu_array(self.translations),
226
- be.to_cpu_array(self.rotations),
227
- be.to_cpu_array(self.scores),
228
- be.to_cpu_array(self.details),
229
- ]
230
- yield from self.peak_list
225
+
226
+ def init_state(self):
227
+ ndim = len(self.shape)
228
+ translations = be.full(
229
+ (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
230
+ )
231
+ rotations = be.full(
232
+ (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
233
+ )
234
+ for i in range(ndim):
235
+ rotations[:, i, i] = 1.0
236
+
237
+ scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
238
+ details = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
239
+ return translations, rotations, scores, details
231
240
 
232
241
  def _get_peak_mask(self, peaks: BackendArray, scores: BackendArray) -> BackendArray:
233
242
  if not len(peaks):
234
243
  return None
235
244
 
236
245
  valid_peaks = be.full((peaks.shape[0],), fill_value=1) == 1
237
- if self.min_boundary_distance > 0:
246
+ if self.min_boundary_distance >= 0:
238
247
  upper_limit = be.subtract(
239
248
  be.to_backend_array(scores.shape), self.min_boundary_distance
240
249
  )
@@ -315,20 +324,34 @@ class PeakCaller(ABC):
315
324
  peak_positions = be.astype(peak_positions, int)
316
325
  return peak_positions, peak_details
317
326
 
318
- def __call__(self, scores: BackendArray, rotation_matrix: BackendArray, **kwargs):
327
+ def __call__(
328
+ self,
329
+ state: Tuple,
330
+ scores: BackendArray,
331
+ rotation_matrix: BackendArray,
332
+ **kwargs,
333
+ ) -> Tuple:
319
334
  """
320
335
  Update the internal parameter store based on input array.
321
336
 
322
337
  Parameters
323
338
  ----------
339
+ state : tuple
340
+ Current state tuple where:
341
+ - positions : BackendArray, (n, d) of peak positions
342
+ - rotations : BackendArray, (n, d, d) of corresponding rotations
343
+ - scores : BackendArray, (n, ) of peak scores
344
+ - details : BackendArray, (n, ) of peak details
324
345
  scores : BackendArray
325
- Score space data.
346
+ Array of new scores to update analyzer with.
326
347
  rotation_matrix : BackendArray
327
348
  Rotation matrix used to obtain the score array.
328
349
  **kwargs
329
350
  Optional keyword arguments passed to :py:meth:`PeakCaller.call_peaks`.
330
351
  """
331
- for ret in self._call_peaks(scores=scores, rotation_matrix=rotation_matrix):
352
+ for ret in self._call_peaks(
353
+ scores=scores, rotation_matrix=rotation_matrix, **kwargs
354
+ ):
332
355
  peak_positions, peak_details = ret
333
356
  if peak_positions is None:
334
357
  continue
@@ -341,7 +364,6 @@ class PeakCaller(ABC):
341
364
  peak_scores = scores[tuple(peak_positions.T)]
342
365
  if peak_details is not None:
343
366
  peak_details = peak_details[valid_peaks]
344
- # peak_details, peak_scores = peak_scores, -peak_details
345
367
  else:
346
368
  peak_details = be.full(peak_scores.shape, fill_value=-1)
347
369
 
@@ -351,66 +373,57 @@ class PeakCaller(ABC):
351
373
  axis=0,
352
374
  )
353
375
 
354
- self._update(
376
+ state = self._update(
377
+ state,
355
378
  peak_positions=peak_positions,
356
379
  peak_details=peak_details,
357
380
  peak_scores=peak_scores,
358
- rotations=rotations,
381
+ peak_rotations=rotations,
359
382
  )
360
383
 
361
- return None
362
-
363
- @abstractmethod
364
- def call_peaks(self, scores: BackendArray, **kwargs) -> PeakType:
365
- """
366
- Call peaks in the score space.
367
-
368
- Parameters
369
- ----------
370
- scores : BackendArray
371
- Score array.
372
- **kwargs : dict
373
- Optional keyword arguments passed to underlying implementations.
374
-
375
- Returns
376
- -------
377
- Tuple[BackendArray, BackendArray]
378
- Array of peak coordinates and peak details.
379
- """
384
+ return state
380
385
 
381
386
  @classmethod
382
- def merge(cls, candidates=List[List], **kwargs) -> Tuple:
387
+ def merge(cls, results=List[Tuple], **kwargs) -> Tuple:
383
388
  """
384
389
  Merge multiple instances of :py:class:`PeakCaller`.
385
390
 
386
391
  Parameters
387
392
  ----------
388
- candidates : list of lists
389
- Obtained by invoking list on the generator returned by __iter__.
393
+ results : list of tuple
394
+ List of instance results created by applying `result`.
390
395
  **kwargs
391
396
  Optional keyword arguments.
392
397
 
393
398
  Returns
394
399
  -------
395
- Tuple
396
- Tuple of translation, rotation, score and details of candidates.
400
+ NDArray
401
+ Peak positions (n, d).
402
+ NDArray
403
+ Peak rotation matrices (n, d, d).
404
+ NDArray
405
+ Peak scores (n, ).
406
+ NDArray
407
+ Peak details (n,).
397
408
  """
398
409
  if "shape" not in kwargs:
399
- kwargs["shape"] = tuple(1 for _ in range(candidates[0][0].shape[1]))
410
+ kwargs["shape"] = tuple(1 for _ in range(results[0][0].shape[1]))
400
411
 
401
412
  base = cls(**kwargs)
402
- for candidate in candidates:
403
- if len(candidate) == 0:
413
+ base_state = base.init_state()
414
+ for result in results:
415
+ if len(result) == 0:
404
416
  continue
405
- peak_positions, rotations, peak_scores, peak_details = candidate
406
- base._update(
417
+ peak_positions, rotations, peak_scores, peak_details = result
418
+ base_state = base._update(
419
+ base_state,
407
420
  peak_positions=be.to_backend_array(peak_positions),
408
421
  peak_details=be.to_backend_array(peak_details),
409
422
  peak_scores=be.to_backend_array(peak_scores),
410
- rotations=be.to_backend_array(rotations),
423
+ peak_rotations=be.to_backend_array(rotations),
411
424
  offset=kwargs.get("offset", None),
412
425
  )
413
- return tuple(base)
426
+ return base_state
414
427
 
415
428
  @staticmethod
416
429
  def oversample_peaks(
@@ -517,10 +530,11 @@ class PeakCaller(ABC):
517
530
 
518
531
  def _update(
519
532
  self,
533
+ state,
520
534
  peak_positions: BackendArray,
521
535
  peak_details: BackendArray,
522
536
  peak_scores: BackendArray,
523
- rotations: BackendArray,
537
+ peak_rotations: BackendArray,
524
538
  offset: BackendArray = None,
525
539
  ):
526
540
  """
@@ -539,14 +553,15 @@ class PeakCaller(ABC):
539
553
  offset : BackendArray, optional
540
554
  Translation offset, e.g. from splitting, (d, ).
541
555
  """
556
+ translations, rotations, scores, details = state
542
557
  if offset is not None:
543
558
  offset = be.astype(be.to_backend_array(offset), peak_positions.dtype)
544
559
  peak_positions = be.add(peak_positions, offset, out=peak_positions)
545
560
 
546
- positions = be.concatenate((self.translations, peak_positions))
547
- rotations = be.concatenate((self.rotations, rotations))
548
- scores = be.concatenate((self.scores, peak_scores))
549
- details = be.concatenate((self.details, peak_details))
561
+ positions = be.concatenate((translations, peak_positions))
562
+ rotations = be.concatenate((rotations, peak_rotations))
563
+ scores = be.concatenate((scores, peak_scores))
564
+ details = be.concatenate((details, peak_details))
550
565
 
551
566
  # topk filtering after distances yields more distributed peak calls
552
567
  distance_order = filter_points_indices(
@@ -561,22 +576,69 @@ class PeakCaller(ABC):
561
576
  )
562
577
  final_order = distance_order[top_scores]
563
578
 
564
- self.translations = positions[final_order, :]
565
- self.rotations = rotations[final_order, :]
566
- self.scores = scores[final_order]
567
- self.details = details[final_order]
579
+ translations = positions[final_order, :]
580
+ rotations = rotations[final_order, :]
581
+ scores = scores[final_order]
582
+ details = details[final_order]
583
+ return translations, rotations, scores, details
568
584
 
569
- def _postprocess(self, **kwargs):
570
- if not len(self.translations):
571
- return self
585
+ def result(
586
+ self,
587
+ state,
588
+ fast_shape: Tuple[int] = None,
589
+ targetshape: Tuple[int] = None,
590
+ templateshape: Tuple[int] = None,
591
+ convolution_shape: Tuple[int] = None,
592
+ fourier_shift: Tuple[int] = None,
593
+ convolution_mode: str = None,
594
+ **kwargs,
595
+ ):
596
+ """
597
+ Finalize the analysis result with optional postprocessing.
572
598
 
573
- positions, valid_peaks = score_to_cart(self.translations, **kwargs)
599
+ Parameters
600
+ ----------
601
+ state : tuple
602
+ Current state tuple where:
603
+ - positions : BackendArray, (n, d) of peak positions
604
+ - rotations : BackendArray, (n, d, d) of corresponding rotations
605
+ - scores : BackendArray, (n, ) of peak scores
606
+ - details : BackendArray, (n, ) of peak details
607
+ targetshape : Tuple[int], optional
608
+ Shape of the target for convolution mode correction.
609
+ templateshape : Tuple[int], optional
610
+ Shape of the template for convolution mode correction.
611
+ convolution_shape : Tuple[int], optional
612
+ Shape used for convolution.
613
+ fourier_shift : Tuple[int], optional
614
+ Shift to apply for Fourier correction.
615
+ convolution_mode : str, optional
616
+ Convolution mode for padding correction.
617
+ **kwargs
618
+ Additional keyword arguments.
574
619
 
575
- self.translations = positions[valid_peaks]
576
- self.rotations = self.rotations[valid_peaks]
577
- self.scores = self.scores[valid_peaks]
578
- self.details = self.details[valid_peaks]
579
- return self
620
+ Returns
621
+ -------
622
+ tuple
623
+ Final result tuple (positions, rotations, scores, details).
624
+ """
625
+ translations, rotations, scores, details = state
626
+
627
+ positions, valid_peaks = score_to_cart(
628
+ positions=translations,
629
+ fast_shape=fast_shape,
630
+ targetshape=targetshape,
631
+ templateshape=templateshape,
632
+ convolution_shape=convolution_shape,
633
+ fourier_shift=fourier_shift,
634
+ convolution_mode=convolution_mode,
635
+ **kwargs,
636
+ )
637
+ translations = be.to_cpu_array(positions[valid_peaks])
638
+ rotations = be.to_cpu_array(rotations[valid_peaks])
639
+ scores = be.to_cpu_array(scores[valid_peaks])
640
+ details = be.to_cpu_array(details[valid_peaks])
641
+ return translations, rotations, scores, details
580
642
 
581
643
 
582
644
  class PeakCallerSort(PeakCaller):
@@ -703,14 +765,14 @@ class PeakCallerRecursiveMasking(PeakCaller):
703
765
  values. If rotations and rotation_mapping is provided, the respective
704
766
  rotation will be applied to the mask, otherwise rotation_matrix is used.
705
767
  """
706
- coordinates, masking_function = [], self._mask_scores_rotate
768
+ peaks = []
769
+ box = tuple(self.min_distance for _ in range(scores.ndim))
707
770
 
708
- if mask is None:
709
- masking_function = self._mask_scores_box
710
- shape = tuple(self.min_distance for _ in range(scores.ndim))
711
- mask = be.zeros(shape, dtype=be._float_dtype)
712
-
713
- rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
771
+ scores = be.to_backend_array(scores)
772
+ if mask is not None:
773
+ box = mask.shape
774
+ mask = be.to_backend_array(mask)
775
+ mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
714
776
 
715
777
  peak_limit = self.num_peaks
716
778
  if min_score is not None:
@@ -718,39 +780,45 @@ class PeakCallerRecursiveMasking(PeakCaller):
718
780
  else:
719
781
  min_score = be.min(scores) - 1
720
782
 
721
- scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
722
- scores_copy[:] = scores
723
-
783
+ _scores = be.zeros(scores.shape, dtype=scores.dtype)
784
+ _scores[:] = scores[:]
724
785
  while True:
725
- be.argmax(scores_copy)
726
- peak = be.unravel_index(
727
- indices=be.argmax(scores_copy), shape=scores_copy.shape
728
- )
729
- if scores_copy[tuple(peak)] < min_score:
786
+ peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
787
+ if _scores[tuple(peak)] < min_score:
730
788
  break
789
+ peaks.append(peak)
731
790
 
732
- coordinates.append(peak)
733
-
734
- current_rotation_matrix = self._get_rotation_matrix(
735
- peak=peak,
736
- rotation_space=rotations,
737
- rotation_mapping=rotation_mapping,
738
- rotation_matrix=rotation_matrix,
791
+ score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
792
+ centers=be.to_backend_array(peak)[None],
793
+ extraction_shape=box,
794
+ original_shape=scores.shape,
739
795
  )
740
-
741
- masking_function(
742
- scores=scores_copy,
743
- rotation_matrix=current_rotation_matrix,
744
- peak=peak,
745
- mask=mask,
746
- rotated_template=rotated_template,
796
+ score_slice = tuple(
797
+ slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
798
+ )
799
+ tmpl_slice = tuple(
800
+ slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
747
801
  )
748
802
 
749
- if len(coordinates) >= peak_limit:
803
+ score_mask = 0
804
+ if mask is not None:
805
+ mask_buffer.fill(0)
806
+ rmat = self._get_rotation_matrix(
807
+ peak=peak,
808
+ rotation_space=rotations,
809
+ rotation_mapping=rotation_mapping,
810
+ rotation_matrix=rotation_matrix,
811
+ )
812
+ be.rigid_transform(
813
+ arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
814
+ )
815
+ score_mask = mask_buffer[tmpl_slice] <= 0.1
816
+
817
+ _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
818
+ if len(peaks) >= peak_limit:
750
819
  break
751
820
 
752
- peaks = be.to_backend_array(coordinates)
753
- return peaks, None
821
+ return be.to_backend_array(peaks), None
754
822
 
755
823
  @staticmethod
756
824
  def _get_rotation_matrix(
@@ -790,86 +858,6 @@ class PeakCallerRecursiveMasking(PeakCaller):
790
858
  )
791
859
  return rotation
792
860
 
793
- @staticmethod
794
- def _mask_scores_box(
795
- scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
796
- ) -> None:
797
- """
798
- Mask scores in a box around a peak.
799
-
800
- Parameters
801
- ----------
802
- scores : BackendArray
803
- Data array of scores.
804
- peak : BackendArray
805
- Peak coordinates.
806
- mask : BackendArray
807
- Mask array.
808
- """
809
- start = be.maximum(be.subtract(peak, mask.shape), 0)
810
- stop = be.minimum(be.add(peak, mask.shape), scores.shape)
811
- start, stop = be.astype(start, int), be.astype(stop, int)
812
- coords = tuple(slice(*pos) for pos in zip(start, stop))
813
- scores[coords] = 0
814
- return None
815
-
816
- @staticmethod
817
- def _mask_scores_rotate(
818
- scores: BackendArray,
819
- peak: BackendArray,
820
- mask: BackendArray,
821
- rotated_template: BackendArray,
822
- rotation_matrix: BackendArray,
823
- **kwargs: Dict,
824
- ) -> None:
825
- """
826
- Mask scores using mask rotation around a peak.
827
-
828
- Parameters
829
- ----------
830
- scores : BackendArray
831
- Data array of scores.
832
- peak : BackendArray
833
- Peak coordinates.
834
- mask : BackendArray
835
- Mask array.
836
- rotated_template : BackendArray
837
- Empty array to write mask rotations to.
838
- rotation_matrix : BackendArray
839
- Rotation matrix.
840
- """
841
- left_pad = be.divide(mask.shape, 2).astype(int)
842
- right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
843
-
844
- score_start = be.subtract(peak, left_pad)
845
- score_stop = be.add(peak, right_pad)
846
-
847
- template_start = be.subtract(be.maximum(score_start, 0), score_start)
848
- template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
849
- template_stop = be.subtract(mask.shape, template_stop)
850
-
851
- score_start = be.maximum(score_start, 0)
852
- score_stop = be.minimum(score_stop, scores.shape)
853
- score_start = be.astype(score_start, int)
854
- score_stop = be.astype(score_stop, int)
855
-
856
- template_start = be.astype(template_start, int)
857
- template_stop = be.astype(template_stop, int)
858
- coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
859
- coords_template = tuple(
860
- slice(*pos) for pos in zip(template_start, template_stop)
861
- )
862
-
863
- rotated_template.fill(0)
864
- be.rigid_transform(
865
- arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
866
- )
867
-
868
- scores[coords_score] = be.multiply(
869
- scores[coords_score], (rotated_template[coords_template] <= 0.1)
870
- )
871
- return None
872
-
873
861
 
874
862
  class PeakCallerScipy(PeakCaller):
875
863
  """