pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
9
  from functools import wraps
10
- from typing import Tuple, List, Callable
10
+ from typing import Tuple, List, Dict, Any
11
+
12
+ import numpy as np
11
13
 
12
14
  from ..types import BackendArray
13
15
  from .npfftw_backend import NumpyFFTWBackend, shm_type
@@ -51,12 +53,6 @@ class JaxBackend(NumpyFFTWBackend):
51
53
  )
52
54
  self.scipy = jsp
53
55
  self._create_ufuncs()
54
- try:
55
- from ._jax_utils import scan as _
56
-
57
- self.scan = self._scan
58
- except Exception:
59
- pass
60
56
 
61
57
  def from_sharedarr(self, arr: BackendArray) -> BackendArray:
62
58
  return arr
@@ -70,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
70
66
  arr = arr.at[idx].set(value)
71
67
  return arr
72
68
 
69
+ def addat(self, arr, indices, values):
70
+ arr = arr.at[indices].add(values)
71
+ return arr
72
+
73
73
  def topleft_pad(
74
74
  self, arr: BackendArray, shape: Tuple[int], padval: int = 0
75
75
  ) -> BackendArray:
@@ -94,6 +94,7 @@ class JaxBackend(NumpyFFTWBackend):
94
94
  "sqrt",
95
95
  "maximum",
96
96
  "exp",
97
+ "mod",
97
98
  ]
98
99
  for ufunc in ufuncs:
99
100
  backend_method = emulate_out(getattr(self._array_backend, ufunc))
@@ -109,27 +110,6 @@ class JaxBackend(NumpyFFTWBackend):
109
110
  shape=arr.shape, dtype=arr.dtype, fill_value=value
110
111
  )
111
112
 
112
- def build_fft(
113
- self,
114
- fwd_shape: Tuple[int],
115
- inv_shape: Tuple[int] = None,
116
- inv_output_shape: Tuple[int] = None,
117
- fwd_axes: Tuple[int] = None,
118
- inv_axes: Tuple[int] = None,
119
- **kwargs,
120
- ) -> Tuple[Callable, Callable]:
121
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
122
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
123
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
124
-
125
- def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
126
- return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
127
-
128
- def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
129
- return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
130
-
131
- return rfftn, irfftn
132
-
133
113
  def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
134
114
  return self._array_backend.fft.rfftn(arr, **kwargs)
135
115
 
@@ -189,12 +169,41 @@ class JaxBackend(NumpyFFTWBackend):
189
169
  rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
190
170
  return max_scores, rotations
191
171
 
192
- def _scan(
172
+ def compute_convolution_shapes(
173
+ self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
174
+ ) -> Tuple[List[int], List[int], List[int]]:
175
+ from scipy.fft import next_fast_len
176
+
177
+ convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
178
+ fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
179
+ fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
180
+
181
+ return convolution_shape, fast_shape, fast_ft_shape
182
+
183
+ def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
184
+ if isinstance(obj, np.ndarray):
185
+ return ("array", (tuple(obj.flatten().tolist()), obj.shape))
186
+ return ("other", obj)
187
+
188
+ def _from_hashable(self, type_info: str, data: Any) -> Any:
189
+ if type_info == "array":
190
+ data, shape = data
191
+ return self.array(data).reshape(shape)
192
+ return data
193
+
194
+ def _dict_to_tuple(self, data: Dict) -> Tuple:
195
+ return tuple((k, self._to_hashable(v)) for k, v in data.items())
196
+
197
+ def _tuple_to_dict(self, data: Tuple) -> Dict:
198
+ return {x[0]: self._from_hashable(*x[1]) for x in data}
199
+
200
+ def scan(
193
201
  self,
194
202
  matching_data: type,
195
203
  splits: Tuple[Tuple[slice, slice]],
196
204
  n_jobs: int,
197
- callback_class,
205
+ callback_class: object,
206
+ callback_class_args: Dict,
198
207
  rotate_mask: bool = False,
199
208
  **kwargs,
200
209
  ) -> List:
@@ -214,17 +223,21 @@ class JaxBackend(NumpyFFTWBackend):
214
223
  conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
215
224
  target_shape=self.to_numpy_array(target_shape),
216
225
  template_shape=self.to_numpy_array(matching_data._template.shape),
217
- pad_fourier=False,
226
+ batch_mask=self.to_numpy_array(matching_data._batch_mask),
218
227
  )
219
-
220
228
  analyzer_args = {
221
- "convolution_mode": convolution_mode,
229
+ "shape": fast_shape,
222
230
  "fourier_shift": shift,
231
+ "fast_shape": fast_shape,
223
232
  "targetshape": target_shape,
224
233
  "templateshape": matching_data.template.shape,
225
234
  "convolution_shape": conv_shape,
235
+ "convolution_mode": convolution_mode,
236
+ "thread_safe": False,
237
+ "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
238
+ "n_rotations": matching_data.rotations.shape[0],
239
+ "jax_mode": True,
226
240
  }
227
-
228
241
  create_target_filter = matching_data.target_filter is not None
229
242
  create_template_filter = matching_data.template_filter is not None
230
243
  create_filter = create_target_filter or create_template_filter
@@ -240,25 +253,34 @@ class JaxBackend(NumpyFFTWBackend):
240
253
  for i in range(matching_data.rotations.shape[0])
241
254
  }
242
255
  for split_start in range(0, len(splits), n_jobs):
256
+
257
+ analyzer_kwargs = []
258
+
243
259
  split_subset = splits[split_start : (split_start + n_jobs)]
244
260
  if not len(split_subset):
245
261
  continue
246
262
 
247
263
  targets, translation_offsets = [], []
248
264
  for target_split, template_split in split_subset:
249
- base = matching_data.subset_by_slice(
265
+ base, translation_offset = matching_data.subset_by_slice(
250
266
  target_slice=target_split,
251
267
  target_pad=target_pad,
252
268
  template_slice=template_split,
253
269
  )
254
- translation_offsets.append(base._translation_offset)
255
- targets.append(self.topleft_pad(base._target, fast_shape))
270
+ cur_args = analyzer_args.copy()
271
+ cur_args["offset"] = translation_offset
272
+ cur_args.update(callback_class_args)
273
+
274
+ analyzer_kwargs.append(self._dict_to_tuple(cur_args))
275
+
276
+ _target = self.astype(base._target, self._float_dtype)
277
+ translation_offsets.append(translation_offset)
278
+ targets.append(self.topleft_pad(_target, fast_shape))
256
279
 
257
280
  if create_filter:
258
281
  filter_args = {
259
282
  "data_rfft": self.fft.rfftn(targets[0]),
260
283
  "return_real_fourier": True,
261
- "shape_is_real_fourier": False,
262
284
  }
263
285
 
264
286
  if create_template_filter:
@@ -275,25 +297,27 @@ class JaxBackend(NumpyFFTWBackend):
275
297
 
276
298
  create_filter, create_template_filter, create_target_filter = (False,) * 3
277
299
  base, targets = None, self._array_backend.stack(targets)
278
- scores, rotations = scan_inner(
300
+
301
+ analyzer_kwargs = tuple(analyzer_kwargs)
302
+ states = scan_inner(
279
303
  self.astype(targets, self._float_dtype),
280
- matching_data.template,
281
- matching_data.template_mask,
304
+ self.astype(matching_data.template, self._float_dtype),
305
+ self.astype(matching_data.template_mask, self._float_dtype),
282
306
  matching_data.rotations,
283
307
  template_filter,
284
308
  target_filter,
285
309
  fast_shape,
286
310
  rotate_mask,
311
+ callback_class,
312
+ analyzer_kwargs,
287
313
  )
288
314
 
289
- for index in range(scores.shape[0]):
290
- temp = callback_class(
291
- shape=scores.shape,
292
- offset=translation_offsets[index],
293
- )
294
- state = (scores, rotations, rotation_mapping)
295
- ret.append(temp.result(state, **analyzer_args))
315
+ for index in range(targets.shape[0]):
316
+ kwargs = self._tuple_to_dict(analyzer_kwargs[index])
317
+ analyzer = callback_class(**kwargs)
296
318
 
319
+ state = (states[0][index], states[1][index], rotation_mapping)
320
+ ret.append(analyzer.result(state, **kwargs))
297
321
  return ret
298
322
 
299
323
  def get_available_memory(self) -> int:
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
 
9
9
  from abc import ABC, abstractmethod
10
10
  from multiprocessing import shared_memory
11
- from typing import Tuple, Callable, List, Any, Union, Optional, Generator, Dict
11
+ from typing import Tuple, Callable, List, Any, Union, Optional, Generator
12
12
 
13
13
  from ..types import BackendArray, NDArray, Scalar, shm_type
14
14
 
@@ -1087,57 +1087,12 @@ class MatchingBackend(ABC):
1087
1087
  """
1088
1088
 
1089
1089
  @abstractmethod
1090
- def build_fft(
1091
- self,
1092
- fwd_shape: Tuple[int],
1093
- inv_shape: Tuple[int],
1094
- real_dtype: type,
1095
- cmpl_dtype: type,
1096
- inv_output_shape: Tuple[int] = None,
1097
- temp_fwd: NDArray = None,
1098
- temp_inv: NDArray = None,
1099
- fwd_axes: Tuple[int] = None,
1100
- inv_axes: Tuple[int] = None,
1101
- fftargs: Dict = {},
1102
- ) -> Tuple[Callable, Callable]:
1103
- """
1104
- Build forward and inverse real fourier transform functions. The returned
1105
- callables have two parameters ``arr`` and ``out`` which correspond to the
1106
- input and output of the Fourier transform. The methods return the output
1107
- of the respective function call, regardless of ``out`` being provided or not,
1108
- analogous to most numpy functions.
1090
+ def rfftn(self, **kwargs):
1091
+ """Perform an n-D real FFT."""
1109
1092
 
1110
- Parameters
1111
- ----------
1112
- fwd_shape : tuple
1113
- Input shape for the forward Fourier transform.
1114
- (see `compute_convolution_shapes`).
1115
- inv_shape : tuple
1116
- Input shape for the inverse Fourier transform.
1117
- real_dtype : dtype
1118
- Data type of the forward Fourier transform.
1119
- complex_dtype : dtype
1120
- Data type of the inverse Fourier transform.
1121
- inv_output_shape : tuple, optional
1122
- Output shape of the inverse Fourier transform. By default fast_shape.
1123
- fftargs : dict, optional
1124
- Dictionary passed to pyFFTW builders.
1125
- temp_fwd : NDArray, optional
1126
- Temporary array to build the forward transform. Superseeds shape defined by
1127
- fwd_shape if provided.
1128
- temp_inv : NDArray, optional
1129
- Temporary array to build the inverse transform. Superseeds shape defined by
1130
- inv_shape if provided.
1131
- fwd_axes : tuple of int
1132
- Axes to perform the forward Fourier transform over.
1133
- inv_axes : tuple of int
1134
- Axes to perform the inverse Fourier transform over.
1135
-
1136
- Returns
1137
- -------
1138
- tuple
1139
- Tuple of callables for forward and inverse real Fourier transform.
1140
- """
1093
+ @abstractmethod
1094
+ def irfftn(self, **kwargs):
1095
+ """Perform an n-D real inverse FFT."""
1141
1096
 
1142
1097
  def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
1143
1098
  """
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- from typing import Tuple, List, Callable
9
+ from typing import Tuple, List
10
10
 
11
11
  import numpy as np
12
12
 
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
144
144
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
145
145
  return arr[box]
146
146
 
147
- def build_fft(
148
- self,
149
- fwd_shape: Tuple[int],
150
- inv_shape: Tuple[int] = None,
151
- inv_output_shape: Tuple[int] = None,
152
- fwd_axes: Tuple[int] = None,
153
- inv_axes: Tuple[int] = None,
154
- **kwargs,
155
- ) -> Tuple[Callable, Callable]:
156
- # Runs on mlx.core.cpu until Metal support is available
157
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
158
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
159
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
160
-
161
- def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
162
- out[:] = self._array_backend.fft.rfftn(
163
- arr, s=s, axes=axes, stream=self._array_backend.cpu
164
- )
165
-
166
- def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
167
- out[:] = self._array_backend.fft.irfftn(
168
- arr, s=s, axes=axes, stream=self._array_backend.cpu
169
- )
170
-
171
- return rfftn, irfftn
172
-
173
147
  def rfftn(self, arr, *args, **kwargs):
174
148
  return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
175
149
 
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
9
9
  import os
10
10
  from psutil import virtual_memory
11
11
  from contextlib import contextmanager
12
- from typing import Tuple, Dict, List, Type
12
+ from typing import Tuple, List, Type
13
13
 
14
14
  import scipy
15
15
  import numpy as np
16
16
  from scipy.ndimage import maximum_filter, affine_transform
17
- from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
18
- from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len, interfaces
17
+ from pyfftw import (
18
+ zeros_aligned,
19
+ simd_alignment,
20
+ next_fast_len,
21
+ interfaces,
22
+ config,
23
+ )
19
24
 
20
25
  from ..types import NDArray, BackendArray, shm_type
21
26
  from .matching_backend import MatchingBackend, _create_metafunction
22
27
 
28
+
23
29
  os.environ["MKL_NUM_THREADS"] = "1"
24
30
  os.environ["OMP_NUM_THREADS"] = "1"
25
31
  os.environ["PYFFTW_NUM_THREADS"] = "1"
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
103
109
  self.solve_triangular = self._solve_triangular
104
110
  self.linalg.solve_triangular = scipy.linalg.solve_triangular
105
111
 
112
+ try:
113
+ from ._numpyfftw_utils import rfftn as rfftn_cache
114
+ from ._numpyfftw_utils import irfftn as irfftn_cache
115
+
116
+ self._rfftn = rfftn_cache
117
+ self._irfftn = irfftn_cache
118
+ except Exception as e:
119
+ print(e)
120
+
121
+ config.NUM_THREADS = 1
122
+ config.PLANNER_EFFORT = "FFTW_MEASURE"
123
+ interfaces.cache.enable()
124
+ interfaces.cache.set_keepalive_time(360)
125
+
106
126
  def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
107
127
  # Upper argument is not supported until numpy 2.0
108
128
  ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
138
158
  return float
139
159
 
140
160
  def free_cache(self):
141
- pass
161
+ interfaces.cache.disable()
142
162
 
143
163
  def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
144
164
  return self._array_backend.transpose(arr, *args, **kwargs)
@@ -240,70 +260,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
240
260
  b[tuple(bind)] = arr[tuple(aind)]
241
261
  return b
242
262
 
243
- def build_fft(
244
- self,
245
- fwd_shape: Tuple[int],
246
- inv_shape: Tuple[int],
247
- real_dtype: type,
248
- cmpl_dtype: type,
249
- fftargs: Dict = {},
250
- inv_output_shape: Tuple[int] = None,
251
- temp_fwd: NDArray = None,
252
- temp_inv: NDArray = None,
253
- fwd_axes: Tuple[int] = None,
254
- inv_axes: Tuple[int] = None,
255
- ) -> Tuple[FFTW, FFTW]:
256
- if temp_fwd is None:
257
- temp_fwd = (
258
- self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
259
- )
260
- if temp_inv is None:
261
- temp_inv = (
262
- self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
263
- )
264
-
265
- default_values = {
266
- "planner_effort": "FFTW_MEASURE",
267
- "auto_align_input": False,
268
- "auto_contiguous": False,
269
- "avoid_copy": True,
270
- "overwrite_input": True,
271
- "threads": 1,
272
- }
273
- for key in default_values:
274
- if key in fftargs:
275
- continue
276
- fftargs[key] = default_values[key]
277
-
278
- rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
279
- _rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
280
- overwrite_input = fftargs.pop("overwrite_input", None)
281
-
282
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
283
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
284
- _irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
285
-
286
- def _rfftn_wrapper(arr, out, *args, **kwargs):
287
- return _rfftn(arr, out)
288
-
289
- def _irfftn_wrapper(arr, out, *args, **kwargs):
290
- return _irfftn(arr, out)
291
-
292
- fftargs["overwrite_input"] = overwrite_input
293
- return _rfftn_wrapper, _irfftn_wrapper
263
+ def _rfftn(self, arr, out=None, **kwargs):
264
+ ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
265
+ if out is not None:
266
+ out[:] = ret
267
+ return out
268
+ return ret
294
269
 
295
- @staticmethod
296
- def _format_fft_shape(shape: Tuple[int], axes: Tuple[int] = None):
297
- if axes is None:
298
- return shape
299
- axes = tuple(sorted(range(len(shape))[i] for i in axes))
300
- return tuple(shape[i] for i in axes)
270
+ def _irfftn(self, arr, out=None, **kwargs):
271
+ ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
272
+ if out is not None:
273
+ out[:] = ret
274
+ return out
275
+ return ret
301
276
 
302
- def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
303
- return interfaces.numpy_fft.rfftn(arr, **kwargs)
277
+ def rfftn(
278
+ self,
279
+ arr: NDArray,
280
+ out=None,
281
+ auto_align_input: bool = False,
282
+ auto_contiguous: bool = False,
283
+ overwrite_input: bool = True,
284
+ **kwargs,
285
+ ) -> NDArray:
286
+ return self._rfftn(
287
+ arr,
288
+ auto_align_input=auto_align_input,
289
+ auto_contiguous=auto_contiguous,
290
+ overwrite_input=overwrite_input,
291
+ **kwargs,
292
+ )
304
293
 
305
- def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
306
- return interfaces.numpy_fft.irfftn(arr, **kwargs)
294
+ def irfftn(
295
+ self,
296
+ arr: NDArray,
297
+ out=None,
298
+ auto_align_input: bool = False,
299
+ auto_contiguous: bool = False,
300
+ overwrite_input: bool = True,
301
+ **kwargs,
302
+ ) -> NDArray:
303
+ return self._irfftn(
304
+ arr,
305
+ auto_align_input=auto_align_input,
306
+ auto_contiguous=auto_contiguous,
307
+ overwrite_input=overwrite_input,
308
+ **kwargs,
309
+ )
307
310
 
308
311
  def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
309
312
  new_shape = self.to_backend_array(newshape)
@@ -398,33 +401,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
398
401
  out_mask: NDArray = None,
399
402
  order: int = 3,
400
403
  cache: bool = False,
404
+ batched: bool = False,
401
405
  ) -> Tuple[NDArray, NDArray]:
402
- out = self.zeros_like(arr) if out is None else out
403
- batched = arr.ndim != rotation_matrix.shape[0]
404
-
405
- center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
406
- if not use_geometric_center:
407
- center = self.center_of_mass(arr, cutoff=0)
408
-
409
- offset = int(arr.ndim - rotation_matrix.shape[0])
410
- center = center[offset:]
411
- translation = self.zeros(center.size) if translation is None else translation
412
- matrix = self._rigid_transform_matrix(
413
- rotation_matrix=rotation_matrix,
414
- translation=translation,
415
- center=center,
416
- )
417
-
418
- subset = tuple(slice(None) for _ in range(arr.ndim))
419
- if offset > 1:
420
- subset = tuple(
421
- 0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
406
+ if out is None:
407
+ out = self.zeros_like(arr)
408
+
409
+ # Check whether rotation_matrix is already a rigid transform matrix
410
+ matrix = rotation_matrix
411
+ if matrix.shape[-1] == (arr.ndim - int(batched)):
412
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
413
+ if not use_geometric_center:
414
+ center = self.center_of_mass(arr, cutoff=0)
415
+
416
+ offset = int(arr.ndim - rotation_matrix.shape[0])
417
+ center = center[offset:]
418
+ translation = (
419
+ self.zeros(center.size) if translation is None else translation
420
+ )
421
+ matrix = self._rigid_transform_matrix(
422
+ rotation_matrix=rotation_matrix,
423
+ translation=translation,
424
+ center=center,
422
425
  )
423
426
 
424
427
  self._rigid_transform(
425
- data=arr[subset],
428
+ data=arr,
426
429
  matrix=matrix,
427
- output=out[subset],
430
+ output=out,
428
431
  order=order,
429
432
  prefilter=True,
430
433
  cache=cache,
@@ -433,11 +436,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
433
436
 
434
437
  # Applying the prefilter leads to artifacts in the mask.
435
438
  if arr_mask is not None:
436
- out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
439
+ if out_mask is None:
440
+ out_mask = self.zeros_like(arr_mask)
441
+
437
442
  self._rigid_transform(
438
- data=arr_mask[subset],
443
+ data=arr_mask,
439
444
  matrix=matrix,
440
- output=out_mask[subset],
445
+ output=out_mask,
441
446
  order=order,
442
447
  prefilter=False,
443
448
  cache=cache,
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
7
7
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
  """
9
9
 
10
- from typing import Tuple, Callable
10
+ from typing import Tuple
11
11
  from contextlib import contextmanager
12
12
  from multiprocessing import shared_memory
13
13
  from multiprocessing.managers import SharedMemoryManager
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
273
273
  kwargs["device"] = self.device
274
274
  return self._array_backend.eye(*args, **kwargs)
275
275
 
276
- def build_fft(
277
- self,
278
- fwd_shape: Tuple[int],
279
- inv_shape: Tuple[int],
280
- inv_output_shape: Tuple[int] = None,
281
- fwd_axes: Tuple[int] = None,
282
- inv_axes: Tuple[int] = None,
283
- **kwargs,
284
- ) -> Tuple[Callable, Callable]:
285
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
286
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
287
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
288
-
289
- def rfftn(
290
- arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
291
- ) -> TorchTensor:
292
- return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
293
-
294
- def irfftn(
295
- arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
296
- ) -> TorchTensor:
297
- return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
298
-
299
- return rfftn, irfftn
300
-
301
276
  def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
302
277
  kwargs["dim"] = kwargs.pop("axes", None)
303
278
  return self._array_backend.fft.rfftn(arr, **kwargs)
@@ -306,6 +281,9 @@ class PytorchBackend(NumpyFFTWBackend):
306
281
  kwargs["dim"] = kwargs.pop("axes", None)
307
282
  return self._array_backend.fft.irfftn(arr, **kwargs)
308
283
 
284
+ def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
285
+ return rotation_matrix
286
+
309
287
  def rigid_transform(
310
288
  self,
311
289
  arr: TorchTensor,
@@ -317,6 +295,7 @@ class PytorchBackend(NumpyFFTWBackend):
317
295
  out_mask: TorchTensor = None,
318
296
  order: int = 1,
319
297
  cache: bool = False,
298
+ **kwargs,
320
299
  ):
321
300
  _mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
322
301
  mode = _mode_mapping.get(order, None)
tme/density.py CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
36
36
  array_to_memmap,
37
37
  memmap_to_array,
38
38
  minimum_enclosing_box,
39
+ is_gzipped,
39
40
  )
40
41
 
41
42
  __all__ = ["Density"]
@@ -331,6 +332,7 @@ class Density:
331
332
  if non_standard_crs:
332
333
  data = np.transpose(data, crs_index)
333
334
  origin = np.take(origin, crs_index)
335
+ sampling_rate = np.take(sampling_rate, crs_index)
334
336
 
335
337
  return data.T, origin[::-1], sampling_rate[::-1], metadata
336
338
 
@@ -1763,12 +1765,13 @@ class Density:
1763
1765
  axis=axis,
1764
1766
  )
1765
1767
 
1766
- arr_ft = np.fft.fftn(self.data)
1768
+ mask, mask_ret = np.where(mask), np.where(mask_ret)
1769
+
1770
+ arr_ft = np.fft.fftn(self.data)[mask]
1767
1771
  arr_ft *= np.prod(ret_shape) / np.prod(self.shape)
1768
1772
  ret_ft = np.zeros(ret_shape, dtype=arr_ft.dtype)
1769
- ret_ft[mask_ret] = arr_ft[mask]
1770
- ret.data = np.real(np.fft.ifftn(ret_ft))
1771
-
1773
+ np.add.at(ret_ft, mask_ret, arr_ft)
1774
+ ret.data = np.real(np.fft.ifftn(ret_ft)).astype(self.data.dtype)
1772
1775
  ret.sampling_rate = new_sampling_rate
1773
1776
  return ret
1774
1777
 
@@ -2256,9 +2259,3 @@ class Density:
2256
2259
  coordinates = np.array(np.where(data > 0))
2257
2260
  weights = self.data[tuple(coordinates)]
2258
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2259
-
2260
-
2261
- def is_gzipped(filename: str) -> bool:
2262
- """Check if a file is a gzip file by reading its magic number."""
2263
- with open(filename, "rb") as f:
2264
- return f.read(2) == b"\x1f\x8b"
Binary file