pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
  3. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
  4. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
  8. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
  9. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/match_template.py +163 -201
  14. scripts/match_template_filters.py +1200 -0
  15. scripts/postprocess.py +48 -39
  16. scripts/preprocess.py +10 -23
  17. scripts/preprocessor_gui.py +3 -4
  18. scripts/pytme_runner.py +769 -0
  19. scripts/refine_matches.py +0 -1
  20. tests/preprocessing/test_frequency_filters.py +19 -10
  21. tests/test_analyzer.py +122 -122
  22. tests/test_backends.py +1 -0
  23. tests/test_matching_cli.py +30 -30
  24. tests/test_matching_data.py +5 -5
  25. tests/test_matching_utils.py +1 -1
  26. tme/__version__.py +1 -1
  27. tme/analyzer/__init__.py +1 -1
  28. tme/analyzer/_utils.py +1 -4
  29. tme/analyzer/aggregation.py +15 -6
  30. tme/analyzer/base.py +25 -36
  31. tme/analyzer/peaks.py +39 -113
  32. tme/analyzer/proxy.py +1 -0
  33. tme/backends/_jax_utils.py +16 -15
  34. tme/backends/cupy_backend.py +9 -13
  35. tme/backends/jax_backend.py +19 -16
  36. tme/backends/npfftw_backend.py +27 -25
  37. tme/backends/pytorch_backend.py +4 -0
  38. tme/density.py +5 -4
  39. tme/filters/__init__.py +2 -2
  40. tme/filters/_utils.py +32 -7
  41. tme/filters/bandpass.py +225 -186
  42. tme/filters/ctf.py +117 -67
  43. tme/filters/reconstruction.py +38 -9
  44. tme/filters/wedge.py +88 -105
  45. tme/filters/whitening.py +1 -6
  46. tme/matching_data.py +24 -36
  47. tme/matching_exhaustive.py +14 -11
  48. tme/matching_scores.py +21 -12
  49. tme/matching_utils.py +13 -6
  50. tme/orientations.py +13 -3
  51. tme/parser.py +109 -29
  52. tme/preprocessor.py +2 -2
  53. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  54. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  55. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/analyzer/base.py CHANGED
@@ -38,16 +38,16 @@ class AbstractAnalyzer(ABC):
38
38
 
39
39
  Returns
40
40
  -------
41
- state
42
- Initial state tuple containing the analyzer's internal data
43
- structures. The exact structure depends on the specific
44
- implementation.
41
+ state : tuple
42
+ Initial state tuple of the analyzer instance. The exact structure
43
+ depends on the specific implementation.
45
44
 
46
45
  Notes
47
46
  -----
48
47
  This method creates the initial state that will be passed to
49
- subsequent calls to __call__. The state should contain all
50
- necessary data structures for accumulating analysis results.
48
+ :py:meth:`AbstractAnalyzer.__call__` and finally to
49
+ :py:meth:`AbstractAnalyzer.result`. The state should contain all necessary
50
+ data structures for accumulating analysis results.
51
51
  """
52
52
 
53
53
  @abstractmethod
@@ -57,49 +57,39 @@ class AbstractAnalyzer(ABC):
57
57
 
58
58
  Parameters
59
59
  ----------
60
- state : object
61
- Current analyzer state as returned by init_state() or
62
- previous calls to __call__.
60
+ state : tuple
61
+ Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
62
+ or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
63
63
  scores : BackendArray
64
- Array of scores computed for the current rotation.
64
+ Array of new scores with dimensionality d.
65
65
  rotation_matrix : BackendArray
66
- Rotation matrix used to generate the scores.
66
+ Rotation matrix used to generate scores with shape (d,d).
67
67
  **kwargs : dict
68
- Additional keyword arguments specific to the analyzer
69
- implementation.
68
+ Keyword arguments used by specific implementations.
70
69
 
71
70
  Returns
72
71
  -------
73
- state
74
- Updated analyzer state with the new scoring data incorporated.
75
-
76
- Notes
77
- -----
78
- This method should be pure functional - it should not modify
79
- the input state but return a new state with the updates applied.
80
- The exact signature may vary between implementations.
72
+ tuple
73
+ Updated analyzer state incorporating the new data.
81
74
  """
82
- pass
83
75
 
84
76
  @abstractmethod
85
77
  def result(self, state: Tuple, **kwargs) -> Tuple:
86
78
  """
87
- Finalize the analysis and produce the final result.
79
+ Finalize the analysis by performing potential post processing.
88
80
 
89
81
  Parameters
90
82
  ----------
91
83
  state : tuple
92
- Final analyzer state containing all accumulated data.
84
+ Analyzer state containing accumulated data.
93
85
  **kwargs : dict
94
- Additional keyword arguments for result processing,
95
- such as postprocessing parameters.
86
+ Keyword arguments used by specific implementations.
96
87
 
97
88
  Returns
98
89
  -------
99
90
  result
100
- Final analysis result. The exact format depends on the
101
- analyzer implementation but typically includes processed
102
- scores, rotation information, and metadata.
91
+ Final analysis result. The exact struccture depends on the
92
+ analyzer implementation.
103
93
 
104
94
  Notes
105
95
  -----
@@ -108,25 +98,24 @@ class AbstractAnalyzer(ABC):
108
98
  It may apply postprocessing operations like convolution mode
109
99
  correction or coordinate transformations.
110
100
  """
111
- pass
112
101
 
113
102
  @classmethod
114
103
  @abstractmethod
115
104
  def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
116
105
  """
117
- Merge results from multiple analyzer instances.
106
+ Merge multiple analyzer results.
118
107
 
119
108
  Parameters
120
109
  ----------
121
- results : list
122
- List of result objects as returned by the result() method
123
- from multiple analyzer instances.
110
+ results : list of tuple
111
+ List of tuple objects returned by :py:meth:`AbstractAnalyzer.result`
112
+ from different instances of the same analyzer class.
124
113
  **kwargs : dict
125
- Additional keyword arguments for merge configuration.
114
+ Keyword arguments used by specific implementations.
126
115
 
127
116
  Returns
128
117
  -------
129
- merged_result
118
+ tuple
130
119
  Single result object combining all input results.
131
120
 
132
121
  Notes
tme/analyzer/peaks.py CHANGED
@@ -17,9 +17,9 @@ from skimage.registration._phase_cross_correlation import _upsampled_dft
17
17
  from .base import AbstractAnalyzer
18
18
  from ._utils import score_to_cart
19
19
  from ..backends import backend as be
20
- from ..matching_utils import split_shape
21
20
  from ..types import BackendArray, NDArray
22
21
  from ..rotations import euler_to_rotationmatrix
22
+ from ..matching_utils import split_shape, compute_extraction_box
23
23
 
24
24
  __all__ = [
25
25
  "PeakCaller",
@@ -765,14 +765,14 @@ class PeakCallerRecursiveMasking(PeakCaller):
765
765
  values. If rotations and rotation_mapping is provided, the respective
766
766
  rotation will be applied to the mask, otherwise rotation_matrix is used.
767
767
  """
768
- coordinates, masking_function = [], self._mask_scores_rotate
768
+ peaks = []
769
+ box = tuple(self.min_distance for _ in range(scores.ndim))
769
770
 
770
- if mask is None:
771
- masking_function = self._mask_scores_box
772
- shape = tuple(self.min_distance for _ in range(scores.ndim))
773
- mask = be.zeros(shape, dtype=be._float_dtype)
774
-
775
- rotated_template = be.zeros(mask.shape, dtype=mask.dtype)
771
+ scores = be.to_backend_array(scores)
772
+ if mask is not None:
773
+ box = mask.shape
774
+ mask = be.to_backend_array(mask)
775
+ mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
776
776
 
777
777
  peak_limit = self.num_peaks
778
778
  if min_score is not None:
@@ -780,39 +780,45 @@ class PeakCallerRecursiveMasking(PeakCaller):
780
780
  else:
781
781
  min_score = be.min(scores) - 1
782
782
 
783
- scores_copy = be.zeros(scores.shape, dtype=scores.dtype)
784
- scores_copy[:] = scores
785
-
783
+ _scores = be.zeros(scores.shape, dtype=scores.dtype)
784
+ _scores[:] = scores[:]
786
785
  while True:
787
- be.argmax(scores_copy)
788
- peak = be.unravel_index(
789
- indices=be.argmax(scores_copy), shape=scores_copy.shape
790
- )
791
- if scores_copy[tuple(peak)] < min_score:
786
+ peak = be.unravel_index(indices=be.argmax(_scores), shape=_scores.shape)
787
+ if _scores[tuple(peak)] < min_score:
792
788
  break
789
+ peaks.append(peak)
793
790
 
794
- coordinates.append(peak)
795
-
796
- current_rotation_matrix = self._get_rotation_matrix(
797
- peak=peak,
798
- rotation_space=rotations,
799
- rotation_mapping=rotation_mapping,
800
- rotation_matrix=rotation_matrix,
791
+ score_beg, score_end, tmpl_beg, tmpl_end, _ = compute_extraction_box(
792
+ centers=be.to_backend_array(peak)[None],
793
+ extraction_shape=box,
794
+ original_shape=scores.shape,
801
795
  )
802
-
803
- masking_function(
804
- scores=scores_copy,
805
- rotation_matrix=current_rotation_matrix,
806
- peak=peak,
807
- mask=mask,
808
- rotated_template=rotated_template,
796
+ score_slice = tuple(
797
+ slice(int(x), int(y)) for x, y in zip(score_beg[0], score_end[0])
798
+ )
799
+ tmpl_slice = tuple(
800
+ slice(int(x), int(y)) for x, y in zip(tmpl_beg[0], tmpl_end[0])
809
801
  )
810
802
 
811
- if len(coordinates) >= peak_limit:
803
+ score_mask = 0
804
+ if mask is not None:
805
+ mask_buffer.fill(0)
806
+ rmat = self._get_rotation_matrix(
807
+ peak=peak,
808
+ rotation_space=rotations,
809
+ rotation_mapping=rotation_mapping,
810
+ rotation_matrix=rotation_matrix,
811
+ )
812
+ be.rigid_transform(
813
+ arr=mask, rotation_matrix=rmat, order=1, out=mask_buffer
814
+ )
815
+ score_mask = mask_buffer[tmpl_slice] <= 0.1
816
+
817
+ _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
818
+ if len(peaks) >= peak_limit:
812
819
  break
813
820
 
814
- peaks = be.to_backend_array(coordinates)
815
- return peaks, None
821
+ return be.to_backend_array(peaks), None
816
822
 
817
823
  @staticmethod
818
824
  def _get_rotation_matrix(
@@ -852,86 +858,6 @@ class PeakCallerRecursiveMasking(PeakCaller):
852
858
  )
853
859
  return rotation
854
860
 
855
- @staticmethod
856
- def _mask_scores_box(
857
- scores: BackendArray, peak: BackendArray, mask: BackendArray, **kwargs: Dict
858
- ) -> None:
859
- """
860
- Mask scores in a box around a peak.
861
-
862
- Parameters
863
- ----------
864
- scores : BackendArray
865
- Data array of scores.
866
- peak : BackendArray
867
- Peak coordinates.
868
- mask : BackendArray
869
- Mask array.
870
- """
871
- start = be.maximum(be.subtract(peak, mask.shape), 0)
872
- stop = be.minimum(be.add(peak, mask.shape), scores.shape)
873
- start, stop = be.astype(start, int), be.astype(stop, int)
874
- coords = tuple(slice(*pos) for pos in zip(start, stop))
875
- scores[coords] = 0
876
- return None
877
-
878
- @staticmethod
879
- def _mask_scores_rotate(
880
- scores: BackendArray,
881
- peak: BackendArray,
882
- mask: BackendArray,
883
- rotated_template: BackendArray,
884
- rotation_matrix: BackendArray,
885
- **kwargs: Dict,
886
- ) -> None:
887
- """
888
- Mask scores using mask rotation around a peak.
889
-
890
- Parameters
891
- ----------
892
- scores : BackendArray
893
- Data array of scores.
894
- peak : BackendArray
895
- Peak coordinates.
896
- mask : BackendArray
897
- Mask array.
898
- rotated_template : BackendArray
899
- Empty array to write mask rotations to.
900
- rotation_matrix : BackendArray
901
- Rotation matrix.
902
- """
903
- left_pad = be.divide(mask.shape, 2).astype(int)
904
- right_pad = be.add(left_pad, be.mod(mask.shape, 2).astype(int))
905
-
906
- score_start = be.subtract(peak, left_pad)
907
- score_stop = be.add(peak, right_pad)
908
-
909
- template_start = be.subtract(be.maximum(score_start, 0), score_start)
910
- template_stop = be.subtract(score_stop, be.minimum(score_stop, scores.shape))
911
- template_stop = be.subtract(mask.shape, template_stop)
912
-
913
- score_start = be.maximum(score_start, 0)
914
- score_stop = be.minimum(score_stop, scores.shape)
915
- score_start = be.astype(score_start, int)
916
- score_stop = be.astype(score_stop, int)
917
-
918
- template_start = be.astype(template_start, int)
919
- template_stop = be.astype(template_stop, int)
920
- coords_score = tuple(slice(*pos) for pos in zip(score_start, score_stop))
921
- coords_template = tuple(
922
- slice(*pos) for pos in zip(template_start, template_stop)
923
- )
924
-
925
- rotated_template.fill(0)
926
- be.rigid_transform(
927
- arr=mask, rotation_matrix=rotation_matrix, order=1, out=rotated_template
928
- )
929
-
930
- scores[coords_score] = be.multiply(
931
- scores[coords_score], (rotated_template[coords_template] <= 0.1)
932
- )
933
- return None
934
-
935
861
 
936
862
  class PeakCallerScipy(PeakCaller):
937
863
  """
tme/analyzer/proxy.py CHANGED
@@ -18,6 +18,7 @@ from ..backends import backend as be
18
18
 
19
19
  __all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
20
20
 
21
+
21
22
  class StatelessSharedAnalyzerProxy:
22
23
  """
23
24
  Proxy that wraps functional analyzers for concurrent access via shared memory.
tme/backends/_jax_utils.py CHANGED
@@ -10,16 +10,19 @@ from typing import Tuple
10
10
  from functools import partial
11
11
 
12
12
  import jax.numpy as jnp
13
- from jax import pmap, lax
13
+ from jax import pmap, lax, vmap
14
14
 
15
15
  from ..types import BackendArray
16
16
  from ..backends import backend as be
17
17
  from ..matching_utils import normalize_template as _normalize_template
18
18
 
19
19
 
20
+ __all__ = ["scan"]
21
+
22
+
20
23
  def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
21
24
  """
22
- Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
25
+ Computes :py:meth:`tme.matching_scores.cc_setup`.
23
26
  """
24
27
  template_ft = jnp.fft.rfftn(template, s=template.shape)
25
28
  template_ft = template_ft.at[:].multiply(ft_target)
@@ -28,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
28
31
 
29
32
 
30
33
  def _flc_scoring(
31
- template: BackendArray,
32
- template_mask: BackendArray,
33
34
  ft_target: BackendArray,
34
35
  ft_target2: BackendArray,
36
+ template: BackendArray,
37
+ template_mask: BackendArray,
35
38
  n_observations: BackendArray,
36
39
  eps: float,
37
40
  **kwargs,
38
41
  ) -> BackendArray:
39
42
  """
40
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
43
+ Computes :py:meth:`tme.matching_scores.flc_scoring`.
41
44
  """
42
- correlation = _correlate(template=template, ft_target=ft_target)
43
45
  inv_denominator = _reciprocal_target_std(
44
46
  ft_target=ft_target,
45
47
  ft_target2=ft_target2,
@@ -47,18 +49,17 @@ def _flc_scoring(
47
49
  eps=eps,
48
50
  n_observations=n_observations,
49
51
  )
50
- correlation = correlation.at[:].multiply(inv_denominator)
51
- return correlation
52
+ return _flcSphere_scoring(ft_target, template, inv_denominator)
52
53
 
53
54
 
54
55
  def _flcSphere_scoring(
55
- template: BackendArray,
56
56
  ft_target: BackendArray,
57
+ template: BackendArray,
57
58
  inv_denominator: BackendArray,
58
59
  **kwargs,
59
60
  ) -> BackendArray:
60
61
  """
61
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
62
+ Computes :py:meth:`tme.matching_scores.corr_scoring`.
62
63
  """
63
64
  correlation = _correlate(template=template, ft_target=ft_target)
64
65
  correlation = correlation.at[:].multiply(inv_denominator)
@@ -77,7 +78,7 @@ def _reciprocal_target_std(
77
78
 
78
79
  See Also
79
80
  --------
80
- :py:meth:`tme.matching_exhaustive.flc_scoring`.
81
+ :py:meth:`tme.matching_scores.flc_scoring`.
81
82
  """
82
83
  ft_shape = template_mask.shape
83
84
  ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
@@ -163,12 +164,12 @@ def scan(
163
164
  template_rot = _normalize_template(
164
165
  template_rot, template_mask_rot, n_observations
165
166
  )
166
- template_rot = be.topleft_pad(template_rot, fast_shape)
167
- template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)
167
+ rot_pad = be.topleft_pad(template_rot, fast_shape)
168
+ mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
168
169
 
169
170
  scores = scoring_func(
170
- template=template_rot,
171
- template_mask=template_mask_rot,
171
+ template=rot_pad,
172
+ template_mask=mask_rot_pad,
172
173
  ft_target=ft_target,
173
174
  ft_target2=ft_target2,
174
175
  inv_denominator=inv_denominator,
tme/backends/cupy_backend.py CHANGED
@@ -6,13 +6,10 @@ Copyright (c) 2023 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- import warnings
10
9
  from importlib.util import find_spec
11
10
  from contextlib import contextmanager
12
11
  from typing import Tuple, Callable, List
13
12
 
14
- import numpy as np
15
-
16
13
  from .npfftw_backend import NumpyFFTWBackend
17
14
  from ..types import CupyArray, NDArray, shm_type
18
15
 
@@ -146,15 +143,14 @@ class CupyBackend(NumpyFFTWBackend):
146
143
  def rfftn(
147
144
  arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
148
145
  ) -> CupyArray:
149
- return self.rfftn(arr, s=s, axes=fwd_axes)
146
+ return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
150
147
 
151
148
  def irfftn(
152
149
  arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
153
150
  ) -> CupyArray:
154
- return self.irfftn(arr, s=s, axes=inv_axes)
151
+ return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
155
152
 
156
153
  PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
157
-
158
154
  return rfftn, irfftn
159
155
 
160
156
  def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
@@ -239,13 +235,13 @@ class CupyBackend(NumpyFFTWBackend):
239
235
  )
240
236
  return None
241
237
 
242
- # if data.ndim == 3 and cache and self.texture_available:
243
- # # Device memory pool (should) come to rescue performance
244
- # temp = self.zeros(data.shape, data.dtype)
245
- # texture = self._get_texture(data, order=order, prefilter=prefilter)
246
- # texture.affine(transform_m=matrix, profile=False, output=temp)
247
- # output[out_slice] = temp
248
- # return None
238
+ if data.ndim == 3 and cache and self.texture_available:
239
+ # Device memory pool (should) come to rescue performance
240
+ temp = self.zeros(data.shape, data.dtype)
241
+ texture = self._get_texture(data, order=order, prefilter=prefilter)
242
+ texture.affine(transform_m=matrix, profile=False, output=temp)
243
+ output[out_slice] = temp
244
+ return None
249
245
 
250
246
  self.affine_transform(
251
247
  input=data,
tme/backends/jax_backend.py CHANGED
@@ -7,7 +7,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
9
  from functools import wraps
10
- from typing import Tuple, List, Callable
10
+ from typing import Tuple, List, Callable, Dict
11
11
 
12
12
  from ..types import BackendArray
13
13
  from .npfftw_backend import NumpyFFTWBackend, shm_type
@@ -51,12 +51,6 @@ class JaxBackend(NumpyFFTWBackend):
51
51
  )
52
52
  self.scipy = jsp
53
53
  self._create_ufuncs()
54
- try:
55
- from ._jax_utils import scan as _
56
-
57
- self.scan = self._scan
58
- except Exception:
59
- pass
60
54
 
61
55
  def from_sharedarr(self, arr: BackendArray) -> BackendArray:
62
56
  return arr
@@ -189,7 +183,18 @@ class JaxBackend(NumpyFFTWBackend):
189
183
  rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
190
184
  return max_scores, rotations
191
185
 
192
- def _scan(
186
+ def compute_convolution_shapes(
187
+ self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
188
+ ) -> Tuple[List[int], List[int], List[int]]:
189
+ from scipy.fft import next_fast_len
190
+
191
+ convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
192
+ fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
193
+ fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
194
+
195
+ return convolution_shape, fast_shape, fast_ft_shape
196
+
197
+ def scan(
193
198
  self,
194
199
  matching_data: type,
195
200
  splits: Tuple[Tuple[slice, slice]],
@@ -214,9 +219,9 @@ class JaxBackend(NumpyFFTWBackend):
214
219
  conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
215
220
  target_shape=self.to_numpy_array(target_shape),
216
221
  template_shape=self.to_numpy_array(matching_data._template.shape),
217
- pad_fourier=False,
222
+ batch_mask=self.to_numpy_array(matching_data._batch_mask),
223
+ pad_target=pad_target,
218
224
  )
219
-
220
225
  analyzer_args = {
221
226
  "convolution_mode": convolution_mode,
222
227
  "fourier_shift": shift,
@@ -246,19 +251,18 @@ class JaxBackend(NumpyFFTWBackend):
246
251
 
247
252
  targets, translation_offsets = [], []
248
253
  for target_split, template_split in split_subset:
249
- base = matching_data.subset_by_slice(
254
+ base, translation_offset = matching_data.subset_by_slice(
250
255
  target_slice=target_split,
251
256
  target_pad=target_pad,
252
257
  template_slice=template_split,
253
258
  )
254
- translation_offsets.append(base._translation_offset)
259
+ translation_offsets.append(translation_offset)
255
260
  targets.append(self.topleft_pad(base._target, fast_shape))
256
261
 
257
262
  if create_filter:
258
263
  filter_args = {
259
264
  "data_rfft": self.fft.rfftn(targets[0]),
260
265
  "return_real_fourier": True,
261
- "shape_is_real_fourier": False,
262
266
  }
263
267
 
264
268
  if create_template_filter:
@@ -288,12 +292,11 @@ class JaxBackend(NumpyFFTWBackend):
288
292
 
289
293
  for index in range(scores.shape[0]):
290
294
  temp = callback_class(
291
- shape=scores.shape,
295
+ shape=scores[index].shape,
292
296
  offset=translation_offsets[index],
293
297
  )
294
- state = (scores, rotations, rotation_mapping)
298
+ state = (scores[index], rotations[index], rotation_mapping)
295
299
  ret.append(temp.result(state, **analyzer_args))
296
-
297
300
  return ret
298
301
 
299
302
  def get_available_memory(self) -> int:
tme/backends/npfftw_backend.py CHANGED
@@ -398,33 +398,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
398
398
  out_mask: NDArray = None,
399
399
  order: int = 3,
400
400
  cache: bool = False,
401
+ batched: bool = False,
401
402
  ) -> Tuple[NDArray, NDArray]:
402
- out = self.zeros_like(arr) if out is None else out
403
- batched = arr.ndim != rotation_matrix.shape[0]
404
-
405
- center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
406
- if not use_geometric_center:
407
- center = self.center_of_mass(arr, cutoff=0)
408
-
409
- offset = int(arr.ndim - rotation_matrix.shape[0])
410
- center = center[offset:]
411
- translation = self.zeros(center.size) if translation is None else translation
412
- matrix = self._rigid_transform_matrix(
413
- rotation_matrix=rotation_matrix,
414
- translation=translation,
415
- center=center,
416
- )
417
-
418
- subset = tuple(slice(None) for _ in range(arr.ndim))
419
- if offset > 1:
420
- subset = tuple(
421
- 0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
403
+ if out is None:
404
+ out = self.zeros_like(arr)
405
+
406
+ # Check whether rotation_matrix is already a rigid transform matrix
407
+ matrix = rotation_matrix
408
+ if matrix.shape[-1] == (arr.ndim - int(batched)):
409
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
410
+ if not use_geometric_center:
411
+ center = self.center_of_mass(arr, cutoff=0)
412
+
413
+ offset = int(arr.ndim - rotation_matrix.shape[0])
414
+ center = center[offset:]
415
+ translation = (
416
+ self.zeros(center.size) if translation is None else translation
417
+ )
418
+ matrix = self._rigid_transform_matrix(
419
+ rotation_matrix=rotation_matrix,
420
+ translation=translation,
421
+ center=center,
422
422
  )
423
423
 
424
424
  self._rigid_transform(
425
- data=arr[subset],
425
+ data=arr,
426
426
  matrix=matrix,
427
- output=out[subset],
427
+ output=out,
428
428
  order=order,
429
429
  prefilter=True,
430
430
  cache=cache,
@@ -433,11 +433,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
433
433
 
434
434
  # Applying the prefilter leads to artifacts in the mask.
435
435
  if arr_mask is not None:
436
- out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
436
+ if out_mask is None:
437
+ out_mask = self.zeros_like(arr_mask)
438
+
437
439
  self._rigid_transform(
438
- data=arr_mask[subset],
440
+ data=arr_mask,
439
441
  matrix=matrix,
440
- output=out_mask[subset],
442
+ output=out_mask,
441
443
  order=order,
442
444
  prefilter=False,
443
445
  cache=cache,
tme/backends/pytorch_backend.py CHANGED
@@ -306,6 +306,9 @@ class PytorchBackend(NumpyFFTWBackend):
306
306
  kwargs["dim"] = kwargs.pop("axes", None)
307
307
  return self._array_backend.fft.irfftn(arr, **kwargs)
308
308
 
309
+ def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
310
+ return rotation_matrix
311
+
309
312
  def rigid_transform(
310
313
  self,
311
314
  arr: TorchTensor,
@@ -317,6 +320,7 @@ class PytorchBackend(NumpyFFTWBackend):
317
320
  out_mask: TorchTensor = None,
318
321
  order: int = 1,
319
322
  cache: bool = False,
323
+ **kwargs,
320
324
  ):
321
325
  _mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
322
326
  mode = _mode_mapping.get(order, None)
tme/density.py CHANGED
@@ -1763,12 +1763,13 @@ class Density:
1763
1763
  axis=axis,
1764
1764
  )
1765
1765
 
1766
- arr_ft = np.fft.fftn(self.data)
1766
+ mask, mask_ret = np.where(mask), np.where(mask_ret)
1767
+
1768
+ arr_ft = np.fft.fftn(self.data)[mask]
1767
1769
  arr_ft *= np.prod(ret_shape) / np.prod(self.shape)
1768
1770
  ret_ft = np.zeros(ret_shape, dtype=arr_ft.dtype)
1769
- ret_ft[mask_ret] = arr_ft[mask]
1770
- ret.data = np.real(np.fft.ifftn(ret_ft))
1771
-
1771
+ np.add.at(ret_ft, mask_ret, arr_ft)
1772
+ ret.data = np.real(np.fft.ifftn(ret_ft)).astype(self.data.dtype)
1772
1773
  ret.sampling_rate = new_sampling_rate
1773
1774
  return ret
1774
1775
 
tme/filters/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from .ctf import CTF, CTFReconstructed
2
2
  from .compose import Compose, ComposableFilter
3
- from .bandpass import BandPassFilter
3
+ from .bandpass import BandPass, BandPassReconstructed
4
4
  from .whitening import LinearWhiteningFilter
5
5
  from .wedge import Wedge, WedgeReconstructed
6
- from .reconstruction import ReconstructFromTilt
6
+ from .reconstruction import ReconstructFromTilt, ShiftFourier