pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +35 -21
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_analyzer.py +2 -3
  20. tests/test_backends.py +3 -9
  21. tests/test_density.py +0 -1
  22. tests/test_extensions.py +0 -1
  23. tests/test_matching_utils.py +10 -60
  24. tests/test_rotations.py +1 -1
  25. tme/__version__.py +1 -1
  26. tme/analyzer/_utils.py +4 -4
  27. tme/analyzer/aggregation.py +35 -15
  28. tme/analyzer/peaks.py +11 -10
  29. tme/backends/_jax_utils.py +26 -13
  30. tme/backends/_numpyfftw_utils.py +270 -0
  31. tme/backends/cupy_backend.py +16 -55
  32. tme/backends/jax_backend.py +76 -37
  33. tme/backends/matching_backend.py +17 -51
  34. tme/backends/mlx_backend.py +1 -27
  35. tme/backends/npfftw_backend.py +71 -65
  36. tme/backends/pytorch_backend.py +1 -26
  37. tme/density.py +2 -6
  38. tme/extensions.cpython-311-darwin.so +0 -0
  39. tme/filters/ctf.py +22 -21
  40. tme/filters/wedge.py +10 -7
  41. tme/mask.py +341 -0
  42. tme/matching_data.py +31 -19
  43. tme/matching_exhaustive.py +37 -47
  44. tme/matching_optimization.py +2 -1
  45. tme/matching_scores.py +229 -411
  46. tme/matching_utils.py +73 -422
  47. tme/memory.py +1 -1
  48. tme/orientations.py +13 -8
  49. tme/rotations.py +1 -1
  50. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  51. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
  52. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
  55. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
  56. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tme/backends/_numpyfftw_utils.py (new file)
@@ -0,0 +1,270 @@
+ #!/usr/bin/env python
+ #
+ # Henry Gomersall
+ # heng@kedevelopments.co.uk
+ #
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # * Redistributions of source code must retain the above copyright notice, this
+ # list of conditions and the following disclaimer.
+ #
+ # * Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # * Neither the name of the copyright holder nor the names of its contributors
+ # may be used to endorse or promote products derived from this software without
+ # specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ # POSSIBILITY OF SUCH DAMAGE.
+ #
+
+ # This code has been adapted to add support for the out argument in rfftn, irfftn
+ # to allow for reusing existing array buffers
+ # Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+
+ import threading
+
+ import pyfftw
+ import numpy as np
+ import pyfftw.builders as builders
+ from pyfftw.interfaces import cache
+ from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
+
+
+ def _Xfftn(
+     a,
+     s,
+     axes,
+     overwrite_input,
+     planner_effort,
+     threads,
+     auto_align_input,
+     auto_contiguous,
+     calling_func,
+     normalise_idft=True,
+     ortho=False,
+     real_direction_flag=None,
+     output_array=None,
+ ):
+
+     work_with_copy = False
+
+     a = np.asanyarray(a)
+
+     try:
+         s = tuple(s)
+     except TypeError:
+         pass
+
+     try:
+         axes = tuple(axes)
+     except TypeError:
+         pass
+
+     if calling_func in ("dct", "dst"):
+         # real-to-real transforms require passing an additional flag argument
+         avoid_copy = False
+         args = (
+             overwrite_input,
+             planner_effort,
+             threads,
+             auto_align_input,
+             auto_contiguous,
+             avoid_copy,
+             real_direction_flag,
+         )
+     elif calling_func in ("irfft2", "irfftn"):
+         # overwrite_input is not an argument to irfft2 or irfftn
+         args = (planner_effort, threads, auto_align_input, auto_contiguous)
+
+         if not overwrite_input:
+             # Only irfft2 and irfftn have overwriting the input
+             # as the default (and so require the input array to
+             # be reloaded).
+             work_with_copy = True
+     else:
+         args = (
+             overwrite_input,
+             planner_effort,
+             threads,
+             auto_align_input,
+             auto_contiguous,
+         )
+
+     if not a.flags.writeable:
+         # Special case of a locked array - always work with a
+         # copy. See issue #92.
+         work_with_copy = True
+
+         if overwrite_input:
+             raise ValueError(
+                 "overwrite_input cannot be True when the "
+                 + "input array flags.writeable is False"
+             )
+
+     if work_with_copy:
+         # We make the copy before registering the key so that the
+         # copy's stride information will be cached since this will be
+         # used for planning. Make sure the copy is byte aligned to
+         # prevent further copying
+         a_original = a
+         a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
+         a[...] = a_original
+
+     if cache.is_enabled():
+         alignment = a.ctypes.data % pyfftw.simd_alignment
+
+         key = (
+             calling_func,
+             a.shape,
+             a.strides,
+             a.dtype,
+             s.__hash__(),
+             axes.__hash__(),
+             alignment,
+             args,
+             threading.get_ident(),
+         )
+
+         try:
+             if key in cache._fftw_cache:
+                 FFTW_object = cache._fftw_cache.lookup(key)
+             else:
+                 FFTW_object = None
+
+         except KeyError:
+             # This occurs if the object has fallen out of the cache between
+             # the check and the lookup
+             FFTW_object = None
+
+     if not cache.is_enabled() or FFTW_object is None:
+
+         # If we're going to create a new FFTW object and are not
+         # working with a copy, then we need to copy the input array to
+         # preserve it, otherwise we can't actually take the transform
+         # of the input array! (in general, we have to assume that the
+         # input array will be destroyed during planning).
+         if not work_with_copy:
+             a_copy = a.copy()
+
+         planner_args = (a, s, axes) + args
+
+         FFTW_object = getattr(builders, calling_func)(*planner_args)
+
+         # Only copy if the input array is what was actually used
+         # (otherwise it shouldn't be overwritten)
+         if not work_with_copy and FFTW_object.input_array is a:
+             a[:] = a_copy
+
+         if cache.is_enabled():
+             cache._fftw_cache.insert(FFTW_object, key)
+
+         output_array = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)
+
+     else:
+         orig_output_array = FFTW_object.output_array
+         output_shape = orig_output_array.shape
+         output_dtype = orig_output_array.dtype
+         output_alignment = FFTW_object.output_alignment
+
+         if output_array is None:
+             output_array = pyfftw.empty_aligned(
+                 output_shape, output_dtype, n=output_alignment
+             )
+
+         FFTW_object(
+             input_array=a,
+             output_array=output_array,
+             normalise_idft=normalise_idft,
+             ortho=ortho,
+         )
+
+     return output_array
+
+
+ def rfftn(
+     a,
+     s=None,
+     axes=None,
+     norm=None,
+     overwrite_input=False,
+     planner_effort=None,
+     threads=None,
+     auto_align_input=True,
+     auto_contiguous=True,
+     out=None,
+ ):
+     """Perform an n-D real FFT.
+
+     The first four arguments are as per :func:`numpy.fft.rfftn`;
+     the rest of the arguments are documented
+     in the :ref:`additional arguments docs<interfaces_additional_args>`.
+     """
+     calling_func = "rfftn"
+     planner_effort = _default_effort(planner_effort)
+     threads = _default_threads(threads)
+
+     return _Xfftn(
+         a,
+         s,
+         axes,
+         overwrite_input,
+         planner_effort,
+         threads,
+         auto_align_input,
+         auto_contiguous,
+         calling_func,
+         **_norm_args(norm),
+         output_array=out,
+     )
+
+
+ def irfftn(
+     a,
+     s=None,
+     axes=None,
+     norm=None,
+     overwrite_input=False,
+     planner_effort=None,
+     threads=None,
+     auto_align_input=True,
+     auto_contiguous=True,
+     out=None,
+ ):
+     """Perform an n-D real inverse FFT.
+
+     The first four arguments are as per :func:`numpy.fft.rfftn`;
+     the rest of the arguments are documented
+     in the :ref:`additional arguments docs<interfaces_additional_args>`.
+     """
+     calling_func = "irfftn"
+     planner_effort = _default_effort(planner_effort)
+     threads = _default_threads(threads)
+
+     return _Xfftn(
+         a,
+         s,
+         axes,
+         overwrite_input,
+         planner_effort,
+         threads,
+         auto_align_input,
+         auto_contiguous,
+         calling_func,
+         **_norm_args(norm),
+         output_array=out,
+     )
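
The module above mirrors pyfftw.interfaces.numpy_fft.rfftn/irfftn but threads an out buffer through to the cached FFTW object, so repeated transforms of the same shape can write into a preallocated array. A minimal usage sketch with hypothetical shapes; note the buffer is only reused once a plan for that shape sits in the enabled interface cache:

    import numpy as np
    import pyfftw
    from pyfftw.interfaces import cache
    from tme.backends._numpyfftw_utils import rfftn, irfftn

    cache.enable()  # plans must be cached for the out= buffer to be picked up

    data = pyfftw.empty_aligned((32, 32, 32), dtype="float32")
    data[:] = np.random.rand(*data.shape)

    # Preallocate the half-spectrum output once; the last axis is n // 2 + 1.
    ft_buffer = pyfftw.empty_aligned((32, 32, 17), dtype="complex64")

    for _ in range(10):
        ft = rfftn(data, out=ft_buffer, threads=4)
        back = irfftn(ft, s=data.shape, out=data, threads=4)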

tme/backends/cupy_backend.py
@@ -6,9 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
  """

+ from typing import Tuple, List
  from importlib.util import find_spec
  from contextlib import contextmanager
- from typing import Tuple, Callable, List

  from .npfftw_backend import NumpyFFTWBackend
  from ..types import CupyArray, NDArray, shm_type
@@ -81,6 +81,17 @@ class CupyBackend(NumpyFFTWBackend):
              """,
              "norm_scores",
          )
+
+         # Sum of square computation similar to the demeaned variance in pytom
+         self.ssum = cp.ReductionKernel(
+             f"{ftype} arr",
+             f"{ftype} ret",
+             "arr * arr",
+             "a + b",
+             "ret = a",
+             "0",
+             f"ssum_{ftype}",
+         )
          self.texture_available = find_spec("voltools") is not None

      def to_backend_array(self, arr: NDArray) -> CupyArray:
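
The ReductionKernel squares each element (map expression "arr * arr"), folds the results with "a + b" starting from the identity "0", and stores the reduction in ret, i.e. a fused single-pass sum of squares. A NumPy reference of the same computation, for illustration only:

    import numpy as np

    def ssum_reference(arr: np.ndarray) -> np.floating:
        # map: arr * arr, reduce: a + b, identity: 0 -- the steps the CuPy
        # kernel fuses into one GPU pass
        return np.sum(arr * arr)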
@@ -111,53 +122,14 @@ class CupyBackend(NumpyFFTWBackend):
      def unravel_index(self, indices, shape):
          return self._array_backend.unravel_index(indices=indices, dims=shape)

-     def build_fft(
-         self,
-         fwd_shape: Tuple[int],
-         inv_shape: Tuple[int],
-         inv_output_shape: Tuple[int] = None,
-         fwd_axes: Tuple[int] = None,
-         inv_axes: Tuple[int] = None,
-         **kwargs,
-     ) -> Tuple[Callable, Callable]:
-         cache = self._array_backend.fft.config.get_plan_cache()
-         current_device = self._array_backend.cuda.device.get_device_id()
-
-         previous_transform = [fwd_shape, inv_shape]
-         if current_device in PLAN_CACHE:
-             previous_transform = PLAN_CACHE[current_device]
-
-         real_diff, cmplx_diff = True, True
-         if len(fwd_shape) == len(previous_transform[0]):
-             real_diff = fwd_shape == previous_transform[0]
-         if len(inv_shape) == len(previous_transform[1]):
-             cmplx_diff = inv_shape == previous_transform[1]
-
-         if real_diff or cmplx_diff:
-             cache.clear()
-
-         rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
-         irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
-         irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
-
-         def rfftn(
-             arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
-         ) -> CupyArray:
-             return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
-
-         def irfftn(
-             arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
-         ) -> CupyArray:
-             return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
-
-         PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
-         return rfftn, irfftn
+     def free_cache(self):
+         self._array_backend.fft.config.get_plan_cache().clear()

      def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
          return self._cufft.rfftn(arr, **kwargs)

      def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
-         return self._cufft.irfftn(arr, **kwargs)
+         return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)

      def compute_convolution_shapes(
          self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
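
build_fft and its per-device PLAN_CACHE bookkeeping are replaced by direct rfftn/irfftn calls plus free_cache, which clears the cuFFT plan cache on demand. A sketch of the underlying CuPy call, assuming a CUDA device is available:

    import cupy as cp

    # cuFFT keeps one plan per transform shape; clearing the cache releases
    # the associated device memory, which is what free_cache does.
    plan_cache = cp.fft.config.get_plan_cache()
    plan_cache.clear()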
@@ -178,17 +150,6 @@ class CupyBackend(NumpyFFTWBackend):
          peaks = self._array_backend.array(self._array_backend.nonzero(max_filter)).T
          return peaks

-     # The default methods in Cupy were oddly slow
-     def var(self, a, *args, **kwargs):
-         out = a - self._array_backend.mean(a, *args, **kwargs)
-         self._array_backend.square(out, out)
-         out = self._array_backend.mean(out, *args, **kwargs)
-         return out
-
-     def std(self, a, *args, **kwargs):
-         out = self.var(a, *args, **kwargs)
-         return self._array_backend.sqrt(out)
-
      def _get_texture(self, arr: CupyArray, order: int = 3, prefilter: bool = False):
          key = id(arr)
          if key in TEXTURE_CACHE:
@@ -235,7 +196,7 @@
            )
            return None

-         if data.ndim == 3 and cache and self.texture_available:
+         if data.ndim == 3 and cache and self.texture_available and not batched:
              # Device memory pool (should) come to rescue performance
              temp = self.zeros(data.shape, data.dtype)
              texture = self._get_texture(data, order=order, prefilter=prefilter)

tme/backends/jax_backend.py
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
  """

  from functools import wraps
- from typing import Tuple, List, Callable, Dict
+ from typing import Tuple, List, Dict, Any
+
+ import numpy as np

  from ..types import BackendArray
  from .npfftw_backend import NumpyFFTWBackend, shm_type
@@ -64,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
          arr = arr.at[idx].set(value)
          return arr

+     def addat(self, arr, indices, values):
+         arr = arr.at[indices].add(values)
+         return arr
+
      def topleft_pad(
          self, arr: BackendArray, shape: Tuple[int], padval: int = 0
      ) -> BackendArray:
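
addat wraps JAX's functional index update: unlike numpy.add.at, the update is returned as a new array rather than applied in place. A small illustration:

    import jax.numpy as jnp

    arr = jnp.zeros(5)
    indices = jnp.array([0, 2, 2])
    values = jnp.array([1.0, 3.0, 4.0])

    updated = arr.at[indices].add(values)
    # arr is unchanged; updated == [1., 0., 7., 0., 0.] (duplicates accumulate)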
@@ -88,6 +94,7 @@
              "sqrt",
              "maximum",
              "exp",
+             "mod",
          ]
          for ufunc in ufuncs:
              backend_method = emulate_out(getattr(self._array_backend, ufunc))
@@ -103,27 +110,6 @@
              shape=arr.shape, dtype=arr.dtype, fill_value=value
          )

-     def build_fft(
-         self,
-         fwd_shape: Tuple[int],
-         inv_shape: Tuple[int] = None,
-         inv_output_shape: Tuple[int] = None,
-         fwd_axes: Tuple[int] = None,
-         inv_axes: Tuple[int] = None,
-         **kwargs,
-     ) -> Tuple[Callable, Callable]:
-         rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
-         irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
-         irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
-
-         def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
-             return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
-
-         def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
-             return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
-
-         return rfftn, irfftn
-
      def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
          return self._array_backend.fft.rfftn(arr, **kwargs)

@@ -194,12 +180,37 @@

          return convolution_shape, fast_shape, fast_ft_shape

+     def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
+         if isinstance(obj, np.ndarray):
+             return ("array", (tuple(obj.flatten().tolist()), obj.shape))
+         return ("other", obj)
+
+     def _from_hashable(self, type_info: str, data: Any) -> Any:
+         if type_info == "array":
+             data, shape = data
+             return self.array(data).reshape(shape)
+         return data
+
+     def _dict_to_tuple(self, data: Dict) -> Tuple:
+         return tuple((k, self._to_hashable(v)) for k, v in data.items())
+
+     def _tuple_to_dict(self, data: Tuple) -> Dict:
+         return {x[0]: self._from_hashable(*x[1]) for x in data}
+
+     def _unbatch(self, data, target_ndim, index):
+         if not isinstance(data, type(self.zeros(1))):
+             return data
+         elif data.ndim <= target_ndim:
+             return data
+         return data[index]
+
      def scan(
          self,
          matching_data: type,
          splits: Tuple[Tuple[slice, slice]],
          n_jobs: int,
-         callback_class,
+         callback_class: object,
+         callback_class_args: Dict,
          rotate_mask: bool = False,
          **kwargs,
      ) -> List:
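
The helper methods convert analyzer keyword arguments, which may contain arrays, into a hashable tuple form, presumably so they can be carried through the jitted scan as static values and turned back into a dict afterwards. A standalone sketch of the same round trip (free functions for illustration, not the backend API):

    import numpy as np

    def to_hashable(obj):
        if isinstance(obj, np.ndarray):
            return ("array", (tuple(obj.flatten().tolist()), obj.shape))
        return ("other", obj)

    def from_hashable(type_info, data):
        if type_info == "array":
            values, shape = data
            return np.asarray(values).reshape(shape)
        return data

    kwargs = {"offset": np.array([4, 8]), "thread_safe": False}
    frozen = tuple((k, to_hashable(v)) for k, v in kwargs.items())  # hashable
    restored = {k: from_hashable(*v) for k, v in frozen}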
@@ -208,11 +219,13 @@
          :py:class:`tme.analyzer.MaxScoreOverRotations`.
          """
          from ._jax_utils import scan as scan_inner
+         from ..analyzer import MaxScoreOverRotations

          pad_target = True if len(splits) > 1 else False
          convolution_mode = "valid" if pad_target else "same"
          target_pad = matching_data.target_padding(pad_target=pad_target)

+         score_mask = self.full((1,), fill_value=1, dtype=bool)
          target_shape = tuple(
              (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
          )
@@ -220,16 +233,20 @@
              target_shape=self.to_numpy_array(target_shape),
              template_shape=self.to_numpy_array(matching_data._template.shape),
              batch_mask=self.to_numpy_array(matching_data._batch_mask),
-             pad_target=pad_target,
          )
          analyzer_args = {
-             "convolution_mode": convolution_mode,
+             "shape": fast_shape,
              "fourier_shift": shift,
+             "fast_shape": fast_shape,
              "targetshape": target_shape,
              "templateshape": matching_data.template.shape,
              "convolution_shape": conv_shape,
+             "convolution_mode": convolution_mode,
+             "thread_safe": False,
+             "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
+             "n_rotations": matching_data.rotations.shape[0],
+             "jax_mode": True,
          }
-
          create_target_filter = matching_data.target_filter is not None
          create_template_filter = matching_data.template_filter is not None
          create_filter = create_target_filter or create_template_filter
@@ -245,6 +262,9 @@
              for i in range(matching_data.rotations.shape[0])
          }
          for split_start in range(0, len(splits), n_jobs):
+
+             analyzer_kwargs = []
+
              split_subset = splits[split_start : (split_start + n_jobs)]
              if not len(split_subset):
                  continue
@@ -256,8 +276,18 @@
                  target_pad=target_pad,
                  template_slice=template_split,
              )
+             cur_args = analyzer_args.copy()
+             cur_args["offset"] = translation_offset
+             cur_args.update(callback_class_args)
+
+             analyzer_kwargs.append(self._dict_to_tuple(cur_args))
+
+             if pad_target:
+                 score_mask = base._score_mask(fast_shape, shift)
+
+             _target = self.astype(base._target, self._float_dtype)
              translation_offsets.append(translation_offset)
-             targets.append(self.topleft_pad(base._target, fast_shape))
+             targets.append(self.topleft_pad(_target, fast_shape))

              if create_filter:
                  filter_args = {
@@ -279,24 +309,33 @@

              create_filter, create_template_filter, create_target_filter = (False,) * 3
              base, targets = None, self._array_backend.stack(targets)
-             scores, rotations = scan_inner(
+
+             analyzer_kwargs = tuple(analyzer_kwargs)
+             states = scan_inner(
                  self.astype(targets, self._float_dtype),
-                 matching_data.template,
-                 matching_data.template_mask,
+                 self.astype(matching_data.template, self._float_dtype),
+                 self.astype(matching_data.template_mask, self._float_dtype),
                  matching_data.rotations,
                  template_filter,
                  target_filter,
+                 score_mask,
                  fast_shape,
                  rotate_mask,
+                 callback_class,
+                 analyzer_kwargs,
              )

-             for index in range(scores.shape[0]):
-                 temp = callback_class(
-                     shape=scores[index].shape,
-                     offset=translation_offsets[index],
-                 )
-                 state = (scores[index], rotations[index], rotation_mapping)
-                 ret.append(temp.result(state, **analyzer_args))
+             ndim = targets.ndim - 1
+             for index in range(targets.shape[0]):
+                 kwargs = self._tuple_to_dict(analyzer_kwargs[index])
+                 analyzer = callback_class(**kwargs)
+
+                 state = [self._unbatch(x, ndim, index) for x in states]
+
+                 if isinstance(analyzer, MaxScoreOverRotations):
+                     state[2] = rotation_mapping
+
+                 ret.append(analyzer.result(state, **kwargs))
          return ret

      def get_available_memory(self) -> int:

tme/backends/matching_backend.py
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>

  from abc import ABC, abstractmethod
  from multiprocessing import shared_memory
- from typing import Tuple, Callable, List, Any, Union, Optional, Generator, Dict
+ from typing import Tuple, Callable, List, Any, Union, Optional, Generator

  from ..types import BackendArray, NDArray, Scalar, shm_type

@@ -863,6 +863,17 @@ class MatchingBackend(ABC):
              Indices of ``k`` largest elements in ``arr``.
          """

+     @abstractmethod
+     def ssum(self, arr, *args, **kwargs) -> BackendArray:
+         """
+         Compute the sum of squares of ``arr``.
+
+         Returns
+         -------
+         BackendArray
+             Sum of squares with shape ().
+         """
+
      def indices(self, *args, **kwargs) -> BackendArray:
          """
          Creates an array representing the index grid of an input.
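
A concrete backend would satisfy the new abstract method with a plain sum of squares; a sketch for the NumPy case (not the shipped implementation):

    import numpy as np

    def ssum(arr: np.ndarray) -> np.floating:
        # Sum of squares with shape (); einsum avoids materialising arr * arr.
        flat = np.ascontiguousarray(arr).reshape(-1)
        return np.einsum("i,i->", flat, flat)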
@@ -1087,57 +1098,12 @@
          """

      @abstractmethod
-     def build_fft(
-         self,
-         fwd_shape: Tuple[int],
-         inv_shape: Tuple[int],
-         real_dtype: type,
-         cmpl_dtype: type,
-         inv_output_shape: Tuple[int] = None,
-         temp_fwd: NDArray = None,
-         temp_inv: NDArray = None,
-         fwd_axes: Tuple[int] = None,
-         inv_axes: Tuple[int] = None,
-         fftargs: Dict = {},
-     ) -> Tuple[Callable, Callable]:
-         """
-         Build forward and inverse real fourier transform functions. The returned
-         callables have two parameters ``arr`` and ``out`` which correspond to the
-         input and output of the Fourier transform. The methods return the output
-         of the respective function call, regardless of ``out`` being provided or not,
-         analogous to most numpy functions.
+     def rfftn(self, **kwargs):
+         """Perform an n-D real FFT."""

-         Parameters
-         ----------
-         fwd_shape : tuple
-             Input shape for the forward Fourier transform.
-             (see `compute_convolution_shapes`).
-         inv_shape : tuple
-             Input shape for the inverse Fourier transform.
-         real_dtype : dtype
-             Data type of the forward Fourier transform.
-         complex_dtype : dtype
-             Data type of the inverse Fourier transform.
-         inv_output_shape : tuple, optional
-             Output shape of the inverse Fourier transform. By default fast_shape.
-         fftargs : dict, optional
-             Dictionary passed to pyFFTW builders.
-         temp_fwd : NDArray, optional
-             Temporary array to build the forward transform. Superseeds shape defined by
-             fwd_shape if provided.
-         temp_inv : NDArray, optional
-             Temporary array to build the inverse transform. Superseeds shape defined by
-             inv_shape if provided.
-         fwd_axes : tuple of int
-             Axes to perform the forward Fourier transform over.
-         inv_axes : tuple of int
-             Axes to perform the inverse Fourier transform over.
-
-         Returns
-         -------
-         tuple
-             Tuple of callables for forward and inverse real Fourier transform.
-         """
+     @abstractmethod
+     def irfftn(self, **kwargs):
+         """Perform an n-D real inverse FFT."""

      def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
          """