pytme 0.3b0-cp311-cp311-macosx_15_0_arm64.whl → 0.3.1-cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tests/test_matching_utils.py CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
 from tme.memory import MATCHING_MEMORY_REGISTRY
 from tme.matching_utils import (
     compute_parallelization_schedule,
-    elliptical_mask,
-    box_mask,
-    tube_mask,
     create_mask,
     scramble_phases,
     apply_convolution_mode,
@@ -50,73 +47,26 @@ class TestMatchingUtils:
             max_splits=256,
         )
 
-    def test_create_mask(self):
+    @pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
+    def test_create_mask(self, mask_type: str):
         create_mask(
-            mask_type="ellipse",
+            mask_type=mask_type,
             shape=self.density.shape,
             radius=5,
             center=np.divide(self.density.shape, 2),
+            height=np.max(self.density.shape) // 2,
+            size=np.divide(self.density.shape, 2).astype(int),
+            thickness=2,
+            separation=2,
+            symmetry_axis=1,
+            inner_radius=5,
+            outer_radius=10,
         )
 
     def test_create_mask_error(self):
         with pytest.raises(ValueError):
             create_mask(mask_type=None)
 
-    def test_elliptical_mask(self):
-        elliptical_mask(
-            shape=self.density.shape,
-            radius=5,
-            center=np.divide(self.density.shape, 2),
-        )
-
-    def test_box_mask(self):
-        box_mask(
-            shape=self.density.shape,
-            height=[5, 10, 20],
-            center=np.divide(self.density.shape, 2),
-        )
-
-    def test_tube_mask(self):
-        tube_mask(
-            shape=self.density.shape,
-            outer_radius=10,
-            inner_radius=5,
-            height=5,
-            base_center=np.divide(self.density.shape, 2),
-            symmetry_axis=1,
-        )
-
-    def test_tube_mask_error(self):
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=5,
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=1,
-            )
-
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=10 * np.max(self.density.shape),
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=1,
-            )
-
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=10 * np.max(self.density.shape),
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=len(self.density.shape) + 1,
-            )
-
     def test_scramble_phases(self):
         scramble_phases(arr=self.density.data, noise_proportion=0.5)
 
@@ -139,7 +89,7 @@ class TestMatchingUtils:
         expected_size = np.subtract(
             self.density.shape, self.structure_density.shape
         )
-        expected_size += np.mod(self.structure_density.shape, 2)
+        expected_size += 1
        assert np.allclose(ret.shape, expected_size)
 
     def test_apply_convolution_mode_error(self):
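Note on the change above: 0.3.1 folds the dedicated elliptical_mask, box_mask, and tube_mask helpers into the single create_mask entry point (standalone mask code now lives in the new tme/mask.py, see the file list). A minimal sketch of a direct call for a tube mask, using only the keyword names exercised by the parametrized test; the shape values are illustrative, and since the test passes the full keyword set for every mask type, unused keywords appear to be tolerated:

    import numpy as np
    from tme.matching_utils import create_mask

    shape = (32, 32, 32)
    mask = create_mask(
        mask_type="tube",               # "ellipse", "box", "tube", or "membrane"
        shape=shape,
        center=np.divide(shape, 2),
        inner_radius=5,
        outer_radius=10,
        height=np.max(shape) // 2,
        symmetry_axis=1,
    )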
tests/test_rotations.py CHANGED
@@ -8,6 +8,7 @@ from tme import Density
 from scipy.spatial.transform import Rotation
 from scipy.signal import correlate
 
+from tme.mask import elliptical_mask
 from tme.rotations import (
     euler_from_rotationmatrix,
     euler_to_rotationmatrix,
@@ -16,7 +17,6 @@ from tme.rotations import (
     get_rotation_matrices,
 )
 from tme.matching_utils import (
-    elliptical_mask,
     split_shape,
     compute_full_convolution_index,
 )
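As the import change above shows, elliptical_mask moved from tme.matching_utils to the new tme.mask module. Downstream code presumably needs the same one-line adjustment; a sketch, assuming the function signature is unchanged by the move (keywords taken from the removed test in tests/test_matching_utils.py):

    # pre-0.3.1
    # from tme.matching_utils import elliptical_mask
    # 0.3.1 and later
    from tme.mask import elliptical_mask

    mask = elliptical_mask(shape=(50, 50), radius=5, center=(25, 25))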
tme/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.3.b0"
+__version__ = "0.3.1"
tme/analyzer/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .peaks import *
 from .aggregation import *
 from .proxy import *
-from .base import *
+from .base import *
tme/analyzer/_utils.py CHANGED
@@ -47,10 +47,7 @@ def _convmode_to_shape(
     if convolution_mode == "same":
         output_shape = targetshape
     elif convolution_mode == "valid":
-        output_shape = be.add(
-            be.subtract(targetshape, templateshape),
-            be.mod(templateshape, 2),
-        )
+        output_shape = be.subtract(targetshape, templateshape) + 1
     return be.to_backend_array(output_shape)
 
 
@@ -96,22 +93,22 @@ def cart_to_score(
     templateshape = be.to_backend_array(templateshape)
     convolution_shape = be.to_backend_array(convolution_shape)
 
-    # Compute removed padding
     output_shape = _convmode_to_shape(
         convolution_mode=convolution_mode,
         targetshape=targetshape,
         templateshape=templateshape,
         convolution_shape=convolution_shape,
     )
-    valid_positions = be.multiply(positions >= 0, positions < output_shape)
-    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
 
+    # Offset from padding the target
     starts = be.astype(
         be.divide(be.subtract(convolution_shape, output_shape), 2),
         be._int_dtype,
     )
-
     positions = be.add(positions, starts)
+
+    valid_positions = be.multiply(positions >= 0, positions < fast_shape)
+    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
     if fourier_shift is not None:
         fourier_shift = be.to_backend_array(fourier_shift)
         positions = be.subtract(positions, fourier_shift)
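The _convmode_to_shape change above aligns the "valid" output shape with the standard correlation convention, target - template + 1 per axis, rather than the earlier target - template + (template mod 2). A quick NumPy/SciPy check of that convention (illustration only, independent of the pytme backend layer):

    import numpy as np
    from scipy.signal import correlate

    target = np.zeros((50, 40))
    template = np.ones((8, 5))

    out = correlate(target, template, mode="valid")
    # valid-mode output shape: (50 - 8 + 1, 40 - 5 + 1) == (43, 36)
    assert out.shape == tuple(np.subtract(target.shape, template.shape) + 1)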
tme/analyzer/aggregation.py CHANGED
@@ -57,22 +57,23 @@ class MaxScoreOverRotations(AbstractAnalyzer):
     The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
     instance
 
+    >>> import numpy as np
     >>> from tme.analyzer import MaxScoreOverRotations
-    >>> analyzer = MaxScoreOverRotations(shape = (50, 50))
+    >>> analyzer = MaxScoreOverRotations(shape=(50, 50))
 
     The following simulates a template matching run by creating random data for a range
     of rotations and sending it to ``analyzer`` via its __call__ method
 
-    >> state = analyzer.init_state()
+    >>> state = analyzer.init_state()
     >>> for rotation_number in range(10):
     >>> scores = np.random.rand(50,50)
     >>> rotation = np.random.rand(scores.ndim, scores.ndim)
-    >>> state, analyzer(state, scores = scores, rotation_matrix = rotation)
+    >>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
 
     The aggregated scores can be extracted by invoking the result method of
     ``analyzer``
 
-    >>> results = analyzer.result()
+    >>> results = analyzer.result(state)
 
     The ``results`` tuple contains (1) the maximum scores for each translation,
     (2) an offset which is relevant when merging results from split template matching
@@ -100,6 +101,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         shm_handler: object = None,
         use_memmap: bool = False,
         inversion_mapping: bool = False,
+        jax_mode: bool = False,
         **kwargs,
     ):
         self._use_memmap = use_memmap
@@ -107,6 +109,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         self._shape = tuple(int(x) for x in shape)
         self._inversion_mapping = inversion_mapping
 
+        self._jax_mode = jax_mode
+        if self._jax_mode:
+            self._inversion_mapping = False
+
         if offset is None:
             offset = be.zeros(len(self._shape), be._int_dtype)
         self._offset = be.astype(be.to_backend_array(offset), int)
@@ -138,6 +144,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         state: Tuple,
         scores: BackendArray,
         rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         """
         Update the parameter store.
@@ -167,6 +174,8 @@ class MaxScoreOverRotations(AbstractAnalyzer):
         rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
         if self._inversion_mapping:
             rotation_mapping[rotation_index] = rotation_matrix
+        elif self._jax_mode:
+            rotation_index = kwargs.get("rotation_index", 0)
         else:
             rotation = be.tobytes(rotation_matrix)
             rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
@@ -536,13 +545,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         )
 
     def __call__(
-        self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
+        self,
+        state: Tuple,
+        scores: BackendArray,
+        rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         mask = self._get_constraint(rotation_matrix)
         mask = self._get_score_mask(mask=mask, scores=scores)
 
         scores = be.multiply(scores, mask, out=scores)
-        return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
+        return super().__call__(
+            state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
+        )
 
     def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
         """
@@ -627,7 +642,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
         return scores, rotations, {}
 
     def __call__(
-        self, state, scores: BackendArray, rotation_matrix: BackendArray
+        self,
+        state,
+        scores: BackendArray,
+        rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         prev_scores, rotations, rotation_mapping = state
 
@@ -738,5 +757,5 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
             cls._invert_rmap(master_rotation_mapping),
         )
 
-    def _postprocess(self, **kwargs):
-        return self
+    def result(self, state: Tuple, **kwargs) -> Tuple:
+        return state
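The new jax_mode flag and the **kwargs pass-through in __call__ let the caller hand the analyzer an explicit rotation_index instead of relying on the byte-key rotation mapping, which jax_mode disables per the constructor change above. A hedged sketch of that calling pattern, extrapolated from the diff; random arrays stand in for real template-matching scores:

    import numpy as np
    from tme.analyzer import MaxScoreOverRotations

    analyzer = MaxScoreOverRotations(shape=(50, 50), jax_mode=True)
    state = analyzer.init_state()
    for rotation_index in range(10):
        scores = np.random.rand(50, 50)
        rotation = np.random.rand(scores.ndim, scores.ndim)
        # rotation_index is read from **kwargs when jax_mode is set
        state = analyzer(
            state,
            scores=scores,
            rotation_matrix=rotation,
            rotation_index=rotation_index,
        )
    results = analyzer.result(state)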
tme/analyzer/base.py CHANGED
@@ -38,16 +38,16 @@ class AbstractAnalyzer(ABC):
 
         Returns
         -------
-        state
-            Initial state tuple containing the analyzer's internal data
-            structures. The exact structure depends on the specific
-            implementation.
+        state : tuple
+            Initial state tuple of the analyzer instance. The exact structure
+            depends on the specific implementation.
 
         Notes
         -----
         This method creates the initial state that will be passed to
-        subsequent calls to __call__. The state should contain all
-        necessary data structures for accumulating analysis results.
+        :py:meth:`AbstractAnalyzer.__call__` and finally to
+        :py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
+        data structures for accumulating analysis results.
         """
 
     @abstractmethod
@@ -57,49 +57,39 @@ class AbstractAnalyzer(ABC):
 
         Parameters
         ----------
-        state : object
-            Current analyzer state as returned by init_state() or
-            previous calls to __call__.
+        state : tuple
+            Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
+            or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
         scores : BackendArray
-            Array of scores computed for the current rotation.
+            Array of new scores with dimensionality d.
         rotation_matrix : BackendArray
-            Rotation matrix used to generate the scores.
+            Rotation matrix used to generate scores with shape (d,d).
         **kwargs : dict
-            Additional keyword arguments specific to the analyzer
-            implementation.
+            Keyword arguments used by specific implementations.
 
         Returns
         -------
-        state
-            Updated analyzer state with the new scoring data incorporated.
-
-        Notes
-        -----
-        This method should be pure functional - it should not modify
-        the input state but return a new state with the updates applied.
-        The exact signature may vary between implementations.
+        tuple
+            Updated analyzer state incorporating the new data.
         """
-        pass
 
     @abstractmethod
     def result(self, state: Tuple, **kwargs) -> Tuple:
         """
-        Finalize the analysis and produce the final result.
+        Finalize the analysis by performing potential post processing.
 
         Parameters
         ----------
         state : tuple
-            Final analyzer state containing all accumulated data.
+            Analyzer state containing accumulated data.
         **kwargs : dict
-            Additional keyword arguments for result processing,
-            such as postprocessing parameters.
+            Keyword arguments used by specific implementations.
 
         Returns
         -------
         result
-            Final analysis result. The exact format depends on the
-            analyzer implementation but typically includes processed
-            scores, rotation information, and metadata.
+            Final analysis result. The exact struccture depends on the
+            analyzer implementation.
 
         Notes
         -----
@@ -108,25 +98,24 @@ class AbstractAnalyzer(ABC):
         It may apply postprocessing operations like convolution mode
         correction or coordinate transformations.
         """
-        pass
 
     @classmethod
     @abstractmethod
     def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
         """
-        Merge results from multiple analyzer instances.
+        Merge multiple analyzer results.
 
         Parameters
         ----------
-        results : list
-            List of result objects as returned by the result() method
-            from multiple analyzer instances.
+        results : list of tuple
+            List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
+            from different instances of the same analyzer class.
         **kwargs : dict
-            Additional keyword arguments for merge configuration.
+            Keyword arguments used by specific implementations.
 
         Returns
         -------
-        merged_result
+        tuple
             Single result object combining all input results.
 
         Notes
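The reworked docstrings describe a purely functional protocol: init_state builds the state tuple, __call__ folds new scores into it and returns an updated state, result finalizes it, and merge combines results from several instances. A minimal NumPy-only sketch of a conforming subclass, written against the signatures shown above; the class GlobalMaxAnalyzer is hypothetical and not part of pytme, and real analyzers may need to satisfy additional abstract members not visible in this hunk:

    from typing import List, Tuple

    import numpy as np

    from tme.analyzer.base import AbstractAnalyzer

    class GlobalMaxAnalyzer(AbstractAnalyzer):
        """Hypothetical analyzer that tracks the single best score seen so far."""

        def init_state(self) -> Tuple:
            return (-np.inf, None)  # (best score, rotation matrix that produced it)

        def __call__(self, state: Tuple, scores, rotation_matrix, **kwargs) -> Tuple:
            best_score, _ = state
            current = float(np.max(scores))
            if current > best_score:
                return (current, np.asarray(rotation_matrix))
            return state  # state is returned, never modified in place

        def result(self, state: Tuple, **kwargs) -> Tuple:
            return state

        @classmethod
        def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
            return max(results, key=lambda res: res[0])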
tme/analyzer/peaks.py CHANGED
@@ -17,9 +17,9 @@ from skimage.registration._phase_cross_correlation import _upsampled_dft
 from .base import AbstractAnalyzer
 from ._utils import score_to_cart
 from ..backends import backend as be
-from ..matching_utils import split_shape
 from ..types import BackendArray, NDArray
 from ..rotations import euler_to_rotationmatrix
+from ..matching_utils import split_shape, compute_extraction_box
 
 __all__ = [
     "PeakCaller",
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
         translations = be.full(
             (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
         )
+
+        rdim = len(self.shape)
+        if self.batch_dims:
+            rdim -= len(self.batch_dims)
+
         rotations = be.full(
-            (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
+            (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
         )
-        for i in range(ndim):
+        for i in range(rdim):
             rotations[:, i, i] = 1.0
 
         scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
             Dictionary mapping values in rotations to Euler angles.
             By default None
         min_score : float
-            Minimum score value to consider. If provided, superseeds limit given
-            by :py:attr:`PeakCaller.num_peaks`.
+            Minimum score value to consider.
 
         Returns
         -------
@@ -765,54 +769,57 @@
         values. If rotations and rotation_mapping is provided, the respective
         rotation will be applied to the mask, otherwise rotation_matrix is used.
         """
-        coordinates, masking_function = [], self._mask_scores_rotate
+        peaks = []
+        box = tuple(self.min_distance for _ in range(scores.ndim))
 
-        if mask is None:
-            masking_function = self._mask_scores_box
-            shape = tuple(self.min_distance for _ in range(scores.ndim))
-            mask = be.zeros(shape, dtype=be._float_dtype)
+        scores = be.to_backend_array(scores)
+        if mask is not None:
+            box = mask.shape
+            mask = be.to_backend_array(mask)
+            mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
 
-        rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
-
-        peak_limit = self.num_peaks
-        if min_score is not None:
-            peak_limit = be.size(scores)
-        else:
+        if min_score is None:
             min_score = be.min(scores) - 1
 
-        scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
-        scores_copy[:] = scores
-
+        _scores = be.zeros(scores.shape, dtype=scores.dtype)
+        _scores[:] = scores[:]
         while True:
-            be.argmax(scores_copy)
-            peak = be.unravel_index(
-                indices=be.argmax(scores_copy), shape=scores_copy.shape
-            )
-            if scores_copy[tuple(peak)] < min_score:
+            peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
+            if _scores[tuple(peak)] < min_score:
                 break
+            peaks.append(peak)
 
-            coordinates.append(peak)
-
-            current_rotation_matrix = self._get_rotation_matrix(
-                peak=peak,
-                rotation_space=rotations,
-                rotation_mapping=rotation_mapping,
-                rotation_matrix=rotation_matrix,
+            score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
+                centers=be.to_backend_array(peak)[None],
+                extraction_shape=box,
+                original_shape=scores.shape,
             )
-
-            masking_function(
-                scores=scores_copy,
-                rotation_matrix=current_rotation_matrix,
-                peak=peak,
-                mask=mask,
-                rotated_template=rotated_template,
+            score_slice = tuple(
+                slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
+            )
+            tmpl_slice = tuple(
+                slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
             )
 
-            if len(coordinates) >= peak_limit:
+            score_mask = 0
+            if mask is not None:
+                mask_buffer.fill(0)
+                rmat = self._get_rotation_matrix(
+                    peak=peak,
+                    rotation_space=rotations,
+                    rotation_mapping=rotation_mapping,
+                    rotation_matrix=rotation_matrix,
+                )
+                be.rigid_transform(
+                    arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
+                )
+                score_mask = mask_buffer[tmpl_slice] <= 0.1
+
+            _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
+            if len(peaks) >= self.num_peaks:
                 break
 
-        peaks = be.to_backend_array(coordinates)
-        return peaks, None
+        return be.to_backend_array(peaks), None
 
     @staticmethod
     def _get_rotation_matrix(
@@ -845,93 +852,13 @@
 
         rotation = rotation_mapping[rotation_space[tuple(peak)]]
 
-        # TODO: Newer versions of rotation mapping contain rotation matrices not angles
+        # Old versions of rotation mapping contained Euler angles
        if rotation.ndim != 2:
            rotation = be.to_backend_array(
                euler_to_rotationmatrix(be.to_numpy_array(rotation))
            )
        return rotation
 
-    @staticmethod
-    def _mask_scores_box(
-        scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
-    ) -> None:
-        """
-        Mask scores in a box around a peak.
-
-        Parameters
-        ----------
-        scores : BackendArray
-            Data array of scores.
-        peak : BackendArray
-            Peak coordinates.
-        mask : BackendArray
-            Mask array.
-        """
-        start = be.maximum(be.subtract(peak, mask.shape), 0)
-        stop = be.minimum(be.add(peak, mask.shape), scores.shape)
-        start, stop = be.astype(start, int), be.astype(stop, int)
-        coords = tuple(slice(*pos) for pos in zip(start, stop))
-        scores[coords] = 0
-        return None
-
-    @staticmethod
-    def _mask_scores_rotate(
-        scores: BackendArray,
-        peak: BackendArray,
-        mask: BackendArray,
-        rotated_template: BackendArray,
-        rotation_matrix: BackendArray,
-        **kwargs: Dict,
-    ) -> None:
-        """
-        Mask scores using mask rotation around a peak.
-
-        Parameters
-        ----------
-        scores : BackendArray
-            Data array of scores.
-        peak : BackendArray
-            Peak coordinates.
-        mask : BackendArray
-            Mask array.
-        rotated_template : BackendArray
-            Empty array to write mask rotations to.
-        rotation_matrix : BackendArray
-            Rotation matrix.
-        """
-        left_pad = be.divide(mask.shape, 2).astype(int)
-        right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
-
-        score_start = be.subtract(peak, left_pad)
-        score_stop = be.add(peak, right_pad)
-
-        template_start = be.subtract(be.maximum(score_start, 0), score_start)
-        template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
-        template_stop = be.subtract(mask.shape, template_stop)
-
-        score_start = be.maximum(score_start, 0)
-        score_stop = be.minimum(score_stop, scores.shape)
-        score_start = be.astype(score_start, int)
-        score_stop = be.astype(score_stop, int)
-
-        template_start = be.astype(template_start, int)
-        template_stop = be.astype(template_stop, int)
-        coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
-        coords_template = tuple(
-            slice(*pos) for pos in zip(template_start, template_stop)
-        )
-
-        rotated_template.fill(0)
-        be.rigid_transform(
-            arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
-        )
-
-        scores[coords_score] = be.multiply(
-            scores[coords_score], (rotated_template[coords_template] <= 0.1)
-        )
-        return None
-
 
 class PeakCallerScipy(PeakCaller):
     """
tme/analyzer/proxy.py CHANGED
@@ -18,6 +18,7 @@ from ..backends import backend as be
 
 __all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
 
+
 class StatelessSharedAnalyzerProxy:
     """
     Proxy that wraps functional analyzers for concurrent access via shared memory.