pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
tme/analyzer/proxy.py ADDED
@@ -0,0 +1,123 @@
+ """
+ Implements SharedAnalyzerProxy to managed shared memory of Analyzer instances
+ across different tasks.
+
+ This is primarily useful for CPU template matching, where parallelization can
+ be performed over rotations, rather than subsections of a large input volume.
+
+ Copyright (c) 2025 European Molecular Biology Laboratory
+
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ """
+
+ from typing import Tuple
+ from multiprocessing import Manager
+ from multiprocessing.shared_memory import SharedMemory
+
+ from ..backends import backend as be
+
+ __all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
+
+
+ class StatelessSharedAnalyzerProxy:
+     """
+     Proxy that wraps functional analyzers for concurrent access via shared memory.
+
+     Enables multiple processes/threads to safely update the same analyzer
+     while preserving the functional interface of the underlying analyzer.
+     """
+
+     def __init__(self, analyzer_class: type, analyzer_params: dict):
+         self._shared = False
+         self._process = self._direct_call
+
+         self._analyzer = analyzer_class(**analyzer_params)
+
+     def __call__(self, state, *args, **kwargs):
+         return self._process(state, *args, **kwargs)
+
+     def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
+         state = self._analyzer.init_state()
+         if shm_handler is not None:
+             self._shared = True
+             state = self._to_shared(state, shm_handler)
+
+             self._lock = Manager().Lock()
+             self._process = self._thread_safe_call
+         return state
+
+     def _to_shared(self, state: Tuple, shm_handler):
+         backend_arr = type(be.zeros((1), dtype=be._float_dtype))
+
+         ret = []
+         for v in state:
+             if isinstance(v, backend_arr):
+                 v = be.to_sharedarr(v, shm_handler)
+             elif isinstance(v, dict):
+                 v = Manager().dict(**v)
+             ret.append(v)
+         return tuple(ret)
+
+     def _shared_to_object(self, shared: type):
+         if not self._shared:
+             return shared
+         if isinstance(shared, tuple) and len(shared):
+             if isinstance(shared[0], SharedMemory):
+                 return be.from_sharedarr(shared)
+         return shared
+
+     def _thread_safe_call(self, state, *args, **kwargs):
+         """Thread-safe call to analyzer"""
+         with self._lock:
+             state = tuple(self._shared_to_object(x) for x in state)
+             return self._direct_call(state, *args, **kwargs)
+
+     def _direct_call(self, state, *args, **kwargs):
+         """Direct call to analyzer without locking"""
+         return self._analyzer(state, *args, **kwargs)
+
+     def result(self, state, **kwargs):
+         """Extract final result"""
+         final_state = state
+         if self._shared:
+             # Convert shared arrays back to regular arrays and copy to
+             # avoid array invalidation by shared memory handler
+             final_state = tuple(self._shared_to_object(x) for x in final_state)
+         return self._analyzer.result(final_state, **kwargs)
+
+     def merge(self, *args, **kwargs):
+         return self._analyzer.merge(*args, **kwargs)
+
+
+ class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
+     """
+     Child of :py:class:`StatelessSharedAnalyzerProxy` that is aware
+     of the current analyzer state to emulate the previous analyzer interface.
+     """
+
+     def __init__(
+         self,
+         analyzer_class: type,
+         analyzer_params: dict,
+         shm_handler: type = None,
+         **kwargs,
+     ):
+         super().__init__(
+             analyzer_class=analyzer_class,
+             analyzer_params=analyzer_params,
+         )
+         if not self._analyzer.shareable:
+             shm_handler = None
+         self.init_state(shm_handler)
+
+     def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
+         self._state = super().init_state(shm_handler, *args, **kwargs)
+
+     def __call__(self, *args, **kwargs):
+         state = super().__call__(self._state, *args, **kwargs)
+         if not self._shared:
+             self._state = state
+
+     def result(self, **kwargs):
+         """Extract final result"""
+         return super().result(self._state, **kwargs)
tme/backends/__init__.py CHANGED
@@ -1,8 +1,9 @@
- """ pyTME backend manager.
+ """
+ pyTME backend manager.

- Copyright (c) 2023 European Molecular Biology Laboratory
+ Copyright (c) 2023 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Dict, List
tme/backends/_cupy_utils.py CHANGED
@@ -1,34 +1,35 @@
- """ Utility functions for cupy backend.
+ """
+ Utility functions for cupy backend.

- The functions spline_filter, _prepad_for_spline_filter, _filter_input,
- _get_coord_affine_batched and affine_transform are largely copied from
- cupyx.scipy.ndimage which operates under the following license
+ The functions spline_filter, _prepad_for_spline_filter, _filter_input,
+ _get_coord_affine_batched and affine_transform are largely copied from
+ cupyx.scipy.ndimage which operates under the following license

- Copyright (c) 2015 Preferred Infrastructure, Inc.
- Copyright (c) 2015 Preferred Networks, Inc.
+ Copyright (c) 2015 Preferred Infrastructure, Inc.
+ Copyright (c) 2015 Preferred Networks, Inc.

- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:

- The above copyright notice and this permission notice shall be included in
- all copies or substantial portions of the Software.
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.

- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- THE SOFTWARE.
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.

- I have since extended the functionality of the cupyx.scipy.ndimage functions
- in question to support batched inputs.
+ I have since extended the functionality of the cupyx.scipy.ndimage functions
+ in question to support batched inputs.

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import numpy
tme/backends/_jax_utils.py CHANGED
@@ -1,24 +1,28 @@
- """ Utility functions for jax backend.
+ """
+ Utility functions for jax backend.

- Copyright (c) 2023-2024 European Molecular Biology Laboratory
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple
from functools import partial

import jax.numpy as jnp
- from jax import pmap, lax
+ from jax import pmap, lax, vmap

from ..types import BackendArray
from ..backends import backend as be
from ..matching_utils import normalize_template as _normalize_template


+ __all__ = ["scan"]
+
+
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
"""
- Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
+ Computes :py:meth:`tme.matching_scores.cc_setup`.
"""
template_ft = jnp.fft.rfftn(template, s=template.shape)
template_ft = template_ft.at[:].multiply(ft_target)
@@ -27,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:


def _flc_scoring(
- template: BackendArray,
- template_mask: BackendArray,
ft_target: BackendArray,
ft_target2: BackendArray,
+ template: BackendArray,
+ template_mask: BackendArray,
n_observations: BackendArray,
eps: float,
**kwargs,
) -> BackendArray:
"""
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
+ Computes :py:meth:`tme.matching_scores.flc_scoring`.
"""
- correlation = _correlate(template=template, ft_target=ft_target)
inv_denominator = _reciprocal_target_std(
ft_target=ft_target,
ft_target2=ft_target2,
@@ -46,18 +49,17 @@ def _flc_scoring(
eps=eps,
n_observations=n_observations,
)
- correlation = correlation.at[:].multiply(inv_denominator)
- return correlation
+ return _flcSphere_scoring(ft_target, template, inv_denominator)


def _flcSphere_scoring(
- template: BackendArray,
ft_target: BackendArray,
+ template: BackendArray,
inv_denominator: BackendArray,
**kwargs,
) -> BackendArray:
"""
- Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
+ Computes :py:meth:`tme.matching_scores.corr_scoring`.
"""
correlation = _correlate(template=template, ft_target=ft_target)
correlation = correlation.at[:].multiply(inv_denominator)
@@ -76,7 +78,7 @@ def _reciprocal_target_std(

See Also
--------
- :py:meth:`tme.matching_exhaustive.flc_scoring`.
+ :py:meth:`tme.matching_scores.flc_scoring`.
"""
ft_shape = template_mask.shape
ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
@@ -162,12 +164,12 @@ def scan(
template_rot = _normalize_template(
template_rot, template_mask_rot, n_observations
)
- template_rot = be.topleft_pad(template_rot, fast_shape)
- template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)
+ rot_pad = be.topleft_pad(template_rot, fast_shape)
+ mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)

scores = scoring_func(
- template=template_rot,
- template_mask=template_mask_rot,
+ template=rot_pad,
+ template_mask=mask_rot_pad,
ft_target=ft_target,
ft_target2=ft_target2,
inv_denominator=inv_denominator,
tme/backends/cupy_backend.py CHANGED
@@ -1,17 +1,15 @@
- """ Backend using cupy for template matching.
+ """
+ Backend using cupy for template matching.

- Copyright (c) 2023 European Molecular Biology Laboratory
+ Copyright (c) 2023 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

- import warnings
from importlib.util import find_spec
from contextlib import contextmanager
from typing import Tuple, Callable, List

- import numpy as np
-
from .npfftw_backend import NumpyFFTWBackend
from ..types import CupyArray, NDArray, shm_type

@@ -113,16 +111,6 @@ class CupyBackend(NumpyFFTWBackend):
def unravel_index(self, indices, shape):
return self._array_backend.unravel_index(indices=indices, dims=shape)

- def unique(self, ar, axis=None, *args, **kwargs):
- if axis is None:
- return self._array_backend.unique(ar=ar, axis=axis, *args, **kwargs)
-
- warnings.warn("Axis argument not yet supported in CupY, falling back to NumPy.")
- ret = np.unique(ar=self.to_numpy_array(ar), axis=axis, *args, **kwargs)
- if not isinstance(ret, tuple):
- return self.to_backend_array(ret)
- return tuple(self.to_backend_array(k) for k in ret)
-
def build_fft(
self,
fwd_shape: Tuple[int],
@@ -155,15 +143,14 @@ class CupyBackend(NumpyFFTWBackend):
def rfftn(
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
) -> CupyArray:
- return self.rfftn(arr, s=s, axes=fwd_axes)
+ return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)

def irfftn(
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
) -> CupyArray:
- return self.irfftn(arr, s=s, axes=inv_axes)
+ return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)

PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
-
return rfftn, irfftn

def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
@@ -248,13 +235,13 @@ class CupyBackend(NumpyFFTWBackend):
)
return None

- # if data.ndim == 3 and cache and self.texture_available:
- # # Device memory pool (should) come to rescue performance
- # temp = self.zeros(data.shape, data.dtype)
- # texture = self._get_texture(data, order=order, prefilter=prefilter)
- # texture.affine(transform_m=matrix, profile=False, output=temp)
- # output[out_slice] = temp
- # return None
+ if data.ndim == 3 and cache and self.texture_available:
+ # Device memory pool (should) come to rescue performance
+ temp = self.zeros(data.shape, data.dtype)
+ texture = self._get_texture(data, order=order, prefilter=prefilter)
+ texture.affine(transform_m=matrix, profile=False, output=temp)
+ output[out_slice] = temp
+ return None

self.affine_transform(
input=data,
tme/backends/jax_backend.py CHANGED
@@ -1,12 +1,13 @@
- """ Backend using jax for template matching.
+ """
+ Backend using jax for template matching.

- Copyright (c) 2023-2024 European Molecular Biology Laboratory
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from functools import wraps
- from typing import Tuple, List, Callable
+ from typing import Tuple, List, Callable, Dict

from ..types import BackendArray
from .npfftw_backend import NumpyFFTWBackend, shm_type
@@ -50,12 +51,6 @@ class JaxBackend(NumpyFFTWBackend):
)
self.scipy = jsp
self._create_ufuncs()
- try:
- from ._jax_utils import scan as _
-
- self.scan = self._scan
- except Exception:
- pass

def from_sharedarr(self, arr: BackendArray) -> BackendArray:
return arr
@@ -188,7 +183,18 @@ class JaxBackend(NumpyFFTWBackend):
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
return max_scores, rotations

- def _scan(
+ def compute_convolution_shapes(
+ self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
+ ) -> Tuple[List[int], List[int], List[int]]:
+ from scipy.fft import next_fast_len
+
+ convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
+ fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
+ fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
+
+ return convolution_shape, fast_shape, fast_ft_shape
+
+ def scan(
self,
matching_data: type,
splits: Tuple[Tuple[slice, slice]],
@@ -213,9 +219,9 @@ class JaxBackend(NumpyFFTWBackend):
conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
target_shape=self.to_numpy_array(target_shape),
template_shape=self.to_numpy_array(matching_data._template.shape),
- pad_fourier=False,
+ batch_mask=self.to_numpy_array(matching_data._batch_mask),
+ pad_target=pad_target,
)
-
analyzer_args = {
"convolution_mode": convolution_mode,
"fourier_shift": shift,
@@ -245,19 +251,18 @@ class JaxBackend(NumpyFFTWBackend):

targets, translation_offsets = [], []
for target_split, template_split in split_subset:
- base = matching_data.subset_by_slice(
+ base, translation_offset = matching_data.subset_by_slice(
target_slice=target_split,
target_pad=target_pad,
template_slice=template_split,
)
- translation_offsets.append(base._translation_offset)
+ translation_offsets.append(translation_offset)
targets.append(self.topleft_pad(base._target, fast_shape))

if create_filter:
filter_args = {
"data_rfft": self.fft.rfftn(targets[0]),
"return_real_fourier": True,
- "shape_is_real_fourier": False,
}

if create_template_filter:
@@ -287,15 +292,11 @@ class JaxBackend(NumpyFFTWBackend):

for index in range(scores.shape[0]):
temp = callback_class(
- shape=scores.shape,
- scores=scores[index],
- rotations=rotations[index],
- thread_safe=False,
+ shape=scores[index].shape,
offset=translation_offsets[index],
)
- temp.rotation_mapping = rotation_mapping
- ret.append(tuple(temp._postprocess(**analyzer_args)))
-
+ state = (scores[index], rotations[index], rotation_mapping)
+ ret.append(temp.result(state, **analyzer_args))
return ret

def get_available_memory(self) -> int:
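The compute_convolution_shapes method added to JaxBackend above pads the full linear convolution extent up to FFT-friendly sizes. A standalone sketch of the same arithmetic using scipy.fft.next_fast_len, as in the diff; the free function below is illustrative only, not part of the package API:

from scipy.fft import next_fast_len


def convolution_shapes(arr1_shape, arr2_shape):
    # Full linear convolution extent, then padded up to FFT-friendly sizes.
    convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
    fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
    # rfftn stores only half of the last axis for real input.
    fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
    return convolution_shape, fast_shape, fast_ft_shape


print(convolution_shapes((96, 96, 96), (32, 32, 32)))
# ([127, 127, 127], [128, 128, 128], [128, 128, 65])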
tme/backends/matching_backend.py CHANGED
@@ -1,8 +1,9 @@
- """ Strategy pattern to allow for flexible array / FFT backends.
+ """
+ Strategy pattern to allow for flexible array / FFT backends.

- Copyright (c) 2023 European Molecular Biology Laboratory
+ Copyright (c) 2023 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from abc import ABC, abstractmethod
tme/backends/mlx_backend.py CHANGED
@@ -1,8 +1,9 @@
- """ Backend using Apple's MLX library for template matching.
+ """
+ Backend using Apple's MLX library for template matching.

- Copyright (c) 2024 European Molecular Biology Laboratory
+ Copyright (c) 2024 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple, List, Callable
tme/backends/npfftw_backend.py CHANGED
@@ -1,8 +1,9 @@
- """ Backend using numpy and pyFFTW for template matching.
+ """
+ Backend using numpy and pyFFTW for template matching.

- Copyright (c) 2023 European Molecular Biology Laboratory
+ Copyright (c) 2023 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import os
@@ -159,8 +160,9 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
temp = self._array_backend.zeros(1, dtype=dtype)
return temp.nbytes

- @staticmethod
- def astype(arr, dtype: Type) -> NDArray:
+ def astype(self, arr, dtype: Type) -> NDArray:
+ if self._array_backend.iscomplexobj(arr):
+ arr = arr.real
return arr.astype(dtype)

@staticmethod
@@ -396,33 +398,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
out_mask: NDArray = None,
order: int = 3,
cache: bool = False,
+ batched: bool = False,
) -> Tuple[NDArray, NDArray]:
- out = self.zeros_like(arr) if out is None else out
- batched = arr.ndim != rotation_matrix.shape[0]
-
- center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
- if not use_geometric_center:
- center = self.center_of_mass(arr, cutoff=0)
-
- offset = int(arr.ndim - rotation_matrix.shape[0])
- center = center[offset:]
- translation = self.zeros(center.size) if translation is None else translation
- matrix = self._rigid_transform_matrix(
- rotation_matrix=rotation_matrix,
- translation=translation,
- center=center,
- )
-
- subset = tuple(slice(None) for _ in range(arr.ndim))
- if offset > 1:
- subset = tuple(
- 0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
+ if out is None:
+ out = self.zeros_like(arr)
+
+ # Check whether rotation_matrix is already a rigid transform matrix
+ matrix = rotation_matrix
+ if matrix.shape[-1] == (arr.ndim - int(batched)):
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
+ if not use_geometric_center:
+ center = self.center_of_mass(arr, cutoff=0)
+
+ offset = int(arr.ndim - rotation_matrix.shape[0])
+ center = center[offset:]
+ translation = (
+ self.zeros(center.size) if translation is None else translation
+ )
+ matrix = self._rigid_transform_matrix(
+ rotation_matrix=rotation_matrix,
+ translation=translation,
+ center=center,
)

self._rigid_transform(
- data=arr[subset],
+ data=arr,
matrix=matrix,
- output=out[subset],
+ output=out,
order=order,
prefilter=True,
cache=cache,
@@ -431,11 +433,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):

# Applying the prefilter leads to artifacts in the mask.
if arr_mask is not None:
- out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
+ if out_mask is None:
+ out_mask = self.zeros_like(arr_mask)
+
self._rigid_transform(
- data=arr_mask[subset],
+ data=arr_mask,
matrix=matrix,
- output=out_mask[subset],
+ output=out_mask,
order=order,
prefilter=False,
cache=cache,
tme/backends/pytorch_backend.py CHANGED
@@ -1,9 +1,10 @@
- """ Backend using pytorch and optionally GPU acceleration for
- template matching.
+ """
+ Backend using pytorch and optionally GPU acceleration for
+ template matching.

- Copyright (c) 2023 European Molecular Biology Laboratory
+ Copyright (c) 2023 European Molecular Biology Laboratory

- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple, Callable
@@ -134,6 +135,15 @@ class PytorchBackend(NumpyFFTWBackend):
def astype(self, arr: TorchTensor, dtype: type) -> TorchTensor:
return arr.to(dtype)

+ @staticmethod
+ def at(arr, idx, value) -> NDArray:
+ arr[idx] = value
+ return arr
+
+ @staticmethod
+ def addat(arr, indices, *args, **kwargs) -> NDArray:
+ return arr.index_put_(indices, *args, accumulate=True, **kwargs)
+
def flip(self, a, axis, **kwargs):
return self._array_backend.flip(input=a, dims=axis, **kwargs)

@@ -296,6 +306,9 @@ class PytorchBackend(NumpyFFTWBackend):
kwargs["dim"] = kwargs.pop("axes", None)
return self._array_backend.fft.irfftn(arr, **kwargs)

+ def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
+ return rotation_matrix
+
def rigid_transform(
self,
arr: TorchTensor,
@@ -307,6 +320,7 @@ class PytorchBackend(NumpyFFTWBackend):
out_mask: TorchTensor = None,
order: int = 1,
cache: bool = False,
+ **kwargs,
):
_mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
mode = _mode_mapping.get(order, None)
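The at/addat helpers added to PytorchBackend above provide NumPy-style indexed assignment and scatter-add for torch tensors. A small illustration of the scatter-add semantics behind addat (torch.Tensor.index_put_ with accumulate=True adds values at repeated indices); shapes and values are illustrative only:

import torch

arr = torch.zeros(5)
idx = (torch.tensor([0, 1, 1, 3]),)  # index_put_ expects a tuple of index tensors
val = torch.tensor([1.0, 2.0, 3.0, 4.0])

arr.index_put_(idx, val, accumulate=True)  # what addat does under the hood
print(arr)  # tensor([1., 5., 0., 4., 0.])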