pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -6,9 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
+ from typing import Tuple, List
9
10
  from importlib.util import find_spec
10
11
  from contextlib import contextmanager
11
- from typing import Tuple, Callable, List
12
12
 
13
13
  from .npfftw_backend import NumpyFFTWBackend
14
14
  from ..types import CupyArray, NDArray, shm_type
@@ -111,53 +111,14 @@ class CupyBackend(NumpyFFTWBackend):
111
111
  def unravel_index(self, indices, shape):
112
112
  return self._array_backend.unravel_index(indices=indices, dims=shape)
113
113
 
114
- def build_fft(
115
- self,
116
- fwd_shape: Tuple[int],
117
- inv_shape: Tuple[int],
118
- inv_output_shape: Tuple[int] = None,
119
- fwd_axes: Tuple[int] = None,
120
- inv_axes: Tuple[int] = None,
121
- **kwargs,
122
- ) -> Tuple[Callable, Callable]:
123
- cache = self._array_backend.fft.config.get_plan_cache()
124
- current_device = self._array_backend.cuda.device.get_device_id()
125
-
126
- previous_transform = [fwd_shape, inv_shape]
127
- if current_device in PLAN_CACHE:
128
- previous_transform = PLAN_CACHE[current_device]
129
-
130
- real_diff, cmplx_diff = True, True
131
- if len(fwd_shape) == len(previous_transform[0]):
132
- real_diff = fwd_shape == previous_transform[0]
133
- if len(inv_shape) == len(previous_transform[1]):
134
- cmplx_diff = inv_shape == previous_transform[1]
135
-
136
- if real_diff or cmplx_diff:
137
- cache.clear()
138
-
139
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
140
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
141
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
142
-
143
- def rfftn(
144
- arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
145
- ) -> CupyArray:
146
- return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
147
-
148
- def irfftn(
149
- arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
150
- ) -> CupyArray:
151
- return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
152
-
153
- PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
154
- return rfftn, irfftn
114
+ def free_cache(self):
115
+ self._array_backend.fft.config.get_plan_cache().clear()
155
116
 
156
117
  def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
157
118
  return self._cufft.rfftn(arr, **kwargs)
158
119
 
159
120
  def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
160
- return self._cufft.irfftn(arr, **kwargs)
121
+ return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)
161
122
 
162
123
  def compute_convolution_shapes(
163
124
  self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
@@ -235,7 +196,7 @@ class CupyBackend(NumpyFFTWBackend):
235
196
  )
236
197
  return None
237
198
 
238
- if data.ndim == 3 and cache and self.texture_available:
199
+ if data.ndim == 3 and cache and self.texture_available and not batched:
239
200
  # Device memory pool (should) come to rescue performance
240
201
  temp = self.zeros(data.shape, data.dtype)
241
202
  texture = self._get_texture(data, order=order, prefilter=prefilter)
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
9
  from functools import wraps
10
- from typing import Tuple, List, Callable, Dict
10
+ from typing import Tuple, List, Dict, Any
11
+
12
+ import numpy as np
11
13
 
12
14
  from ..types import BackendArray
13
15
  from .npfftw_backend import NumpyFFTWBackend, shm_type
@@ -64,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
64
66
  arr = arr.at[idx].set(value)
65
67
  return arr
66
68
 
69
+ def addat(self, arr, indices, values):
70
+ arr = arr.at[indices].add(values)
71
+ return arr
72
+
67
73
  def topleft_pad(
68
74
  self, arr: BackendArray, shape: Tuple[int], padval: int = 0
69
75
  ) -> BackendArray:
@@ -88,6 +94,7 @@ class JaxBackend(NumpyFFTWBackend):
88
94
  "sqrt",
89
95
  "maximum",
90
96
  "exp",
97
+ "mod",
91
98
  ]
92
99
  for ufunc in ufuncs:
93
100
  backend_method = emulate_out(getattr(self._array_backend, ufunc))
@@ -103,27 +110,6 @@ class JaxBackend(NumpyFFTWBackend):
103
110
  shape=arr.shape, dtype=arr.dtype, fill_value=value
104
111
  )
105
112
 
106
- def build_fft(
107
- self,
108
- fwd_shape: Tuple[int],
109
- inv_shape: Tuple[int] = None,
110
- inv_output_shape: Tuple[int] = None,
111
- fwd_axes: Tuple[int] = None,
112
- inv_axes: Tuple[int] = None,
113
- **kwargs,
114
- ) -> Tuple[Callable, Callable]:
115
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
116
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
117
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
118
-
119
- def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
120
- return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
121
-
122
- def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
123
- return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
124
-
125
- return rfftn, irfftn
126
-
127
113
  def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
128
114
  return self._array_backend.fft.rfftn(arr, **kwargs)
129
115
 
@@ -194,12 +180,30 @@ class JaxBackend(NumpyFFTWBackend):
194
180
 
195
181
  return convolution_shape, fast_shape, fast_ft_shape
196
182
 
183
+ def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
184
+ if isinstance(obj, np.ndarray):
185
+ return ("array", (tuple(obj.flatten().tolist()), obj.shape))
186
+ return ("other", obj)
187
+
188
+ def _from_hashable(self, type_info: str, data: Any) -> Any:
189
+ if type_info == "array":
190
+ data, shape = data
191
+ return self.array(data).reshape(shape)
192
+ return data
193
+
194
+ def _dict_to_tuple(self, data: Dict) -> Tuple:
195
+ return tuple((k, self._to_hashable(v)) for k, v in data.items())
196
+
197
+ def _tuple_to_dict(self, data: Tuple) -> Dict:
198
+ return {x[0]: self._from_hashable(*x[1]) for x in data}
199
+
197
200
  def scan(
198
201
  self,
199
202
  matching_data: type,
200
203
  splits: Tuple[Tuple[slice, slice]],
201
204
  n_jobs: int,
202
- callback_class,
205
+ callback_class: object,
206
+ callback_class_args: Dict,
203
207
  rotate_mask: bool = False,
204
208
  **kwargs,
205
209
  ) -> List:
@@ -220,16 +224,20 @@ class JaxBackend(NumpyFFTWBackend):
220
224
  target_shape=self.to_numpy_array(target_shape),
221
225
  template_shape=self.to_numpy_array(matching_data._template.shape),
222
226
  batch_mask=self.to_numpy_array(matching_data._batch_mask),
223
- pad_target=pad_target,
224
227
  )
225
228
  analyzer_args = {
226
- "convolution_mode": convolution_mode,
229
+ "shape": fast_shape,
227
230
  "fourier_shift": shift,
231
+ "fast_shape": fast_shape,
228
232
  "targetshape": target_shape,
229
233
  "templateshape": matching_data.template.shape,
230
234
  "convolution_shape": conv_shape,
235
+ "convolution_mode": convolution_mode,
236
+ "thread_safe": False,
237
+ "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
238
+ "n_rotations": matching_data.rotations.shape[0],
239
+ "jax_mode": True,
231
240
  }
232
-
233
241
  create_target_filter = matching_data.target_filter is not None
234
242
  create_template_filter = matching_data.template_filter is not None
235
243
  create_filter = create_target_filter or create_template_filter
@@ -245,6 +253,9 @@ class JaxBackend(NumpyFFTWBackend):
245
253
  for i in range(matching_data.rotations.shape[0])
246
254
  }
247
255
  for split_start in range(0, len(splits), n_jobs):
256
+
257
+ analyzer_kwargs = []
258
+
248
259
  split_subset = splits[split_start : (split_start + n_jobs)]
249
260
  if not len(split_subset):
250
261
  continue
@@ -256,8 +267,15 @@ class JaxBackend(NumpyFFTWBackend):
256
267
  target_pad=target_pad,
257
268
  template_slice=template_split,
258
269
  )
270
+ cur_args = analyzer_args.copy()
271
+ cur_args["offset"] = translation_offset
272
+ cur_args.update(callback_class_args)
273
+
274
+ analyzer_kwargs.append(self._dict_to_tuple(cur_args))
275
+
276
+ _target = self.astype(base._target, self._float_dtype)
259
277
  translation_offsets.append(translation_offset)
260
- targets.append(self.topleft_pad(base._target, fast_shape))
278
+ targets.append(self.topleft_pad(_target, fast_shape))
261
279
 
262
280
  if create_filter:
263
281
  filter_args = {
@@ -279,24 +297,27 @@ class JaxBackend(NumpyFFTWBackend):
279
297
 
280
298
  create_filter, create_template_filter, create_target_filter = (False,) * 3
281
299
  base, targets = None, self._array_backend.stack(targets)
282
- scores, rotations = scan_inner(
300
+
301
+ analyzer_kwargs = tuple(analyzer_kwargs)
302
+ states = scan_inner(
283
303
  self.astype(targets, self._float_dtype),
284
- matching_data.template,
285
- matching_data.template_mask,
304
+ self.astype(matching_data.template, self._float_dtype),
305
+ self.astype(matching_data.template_mask, self._float_dtype),
286
306
  matching_data.rotations,
287
307
  template_filter,
288
308
  target_filter,
289
309
  fast_shape,
290
310
  rotate_mask,
311
+ callback_class,
312
+ analyzer_kwargs,
291
313
  )
292
314
 
293
- for index in range(scores.shape[0]):
294
- temp = callback_class(
295
- shape=scores[index].shape,
296
- offset=translation_offsets[index],
297
- )
298
- state = (scores[index], rotations[index], rotation_mapping)
299
- ret.append(temp.result(state, **analyzer_args))
315
+ for index in range(targets.shape[0]):
316
+ kwargs = self._tuple_to_dict(analyzer_kwargs[index])
317
+ analyzer = callback_class(**kwargs)
318
+
319
+ state = (states[0][index], states[1][index], rotation_mapping)
320
+ ret.append(analyzer.result(state, **kwargs))
300
321
  return ret
301
322
 
302
323
  def get_available_memory(self) -> int:
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
 
9
9
  from abc import ABC, abstractmethod
10
10
  from multiprocessing import shared_memory
11
- from typing import Tuple, Callable, List, Any, Union, Optional, Generator, Dict
11
+ from typing import Tuple, Callable, List, Any, Union, Optional, Generator
12
12
 
13
13
  from ..types import BackendArray, NDArray, Scalar, shm_type
14
14
 
@@ -1087,57 +1087,12 @@ class MatchingBackend(ABC):
1087
1087
  """
1088
1088
 
1089
1089
  @abstractmethod
1090
- def build_fft(
1091
- self,
1092
- fwd_shape: Tuple[int],
1093
- inv_shape: Tuple[int],
1094
- real_dtype: type,
1095
- cmpl_dtype: type,
1096
- inv_output_shape: Tuple[int] = None,
1097
- temp_fwd: NDArray = None,
1098
- temp_inv: NDArray = None,
1099
- fwd_axes: Tuple[int] = None,
1100
- inv_axes: Tuple[int] = None,
1101
- fftargs: Dict = {},
1102
- ) -> Tuple[Callable, Callable]:
1103
- """
1104
- Build forward and inverse real fourier transform functions. The returned
1105
- callables have two parameters ``arr`` and ``out`` which correspond to the
1106
- input and output of the Fourier transform. The methods return the output
1107
- of the respective function call, regardless of ``out`` being provided or not,
1108
- analogous to most numpy functions.
1090
+ def rfftn(self, **kwargs):
1091
+ """Perform an n-D real FFT."""
1109
1092
 
1110
- Parameters
1111
- ----------
1112
- fwd_shape : tuple
1113
- Input shape for the forward Fourier transform.
1114
- (see `compute_convolution_shapes`).
1115
- inv_shape : tuple
1116
- Input shape for the inverse Fourier transform.
1117
- real_dtype : dtype
1118
- Data type of the forward Fourier transform.
1119
- complex_dtype : dtype
1120
- Data type of the inverse Fourier transform.
1121
- inv_output_shape : tuple, optional
1122
- Output shape of the inverse Fourier transform. By default fast_shape.
1123
- fftargs : dict, optional
1124
- Dictionary passed to pyFFTW builders.
1125
- temp_fwd : NDArray, optional
1126
- Temporary array to build the forward transform. Superseeds shape defined by
1127
- fwd_shape if provided.
1128
- temp_inv : NDArray, optional
1129
- Temporary array to build the inverse transform. Superseeds shape defined by
1130
- inv_shape if provided.
1131
- fwd_axes : tuple of int
1132
- Axes to perform the forward Fourier transform over.
1133
- inv_axes : tuple of int
1134
- Axes to perform the inverse Fourier transform over.
1135
-
1136
- Returns
1137
- -------
1138
- tuple
1139
- Tuple of callables for forward and inverse real Fourier transform.
1140
- """
1093
+ @abstractmethod
1094
+ def irfftn(self, **kwargs):
1095
+ """Perform an n-D real inverse FFT."""
1141
1096
 
1142
1097
  def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
1143
1098
  """
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- from typing import Tuple, List, Callable
9
+ from typing import Tuple, List
10
10
 
11
11
  import numpy as np
12
12
 
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
144
144
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
145
145
  return arr[box]
146
146
 
147
- def build_fft(
148
- self,
149
- fwd_shape: Tuple[int],
150
- inv_shape: Tuple[int] = None,
151
- inv_output_shape: Tuple[int] = None,
152
- fwd_axes: Tuple[int] = None,
153
- inv_axes: Tuple[int] = None,
154
- **kwargs,
155
- ) -> Tuple[Callable, Callable]:
156
- # Runs on mlx.core.cpu until Metal support is available
157
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
158
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
159
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
160
-
161
- def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
162
- out[:] = self._array_backend.fft.rfftn(
163
- arr, s=s, axes=axes, stream=self._array_backend.cpu
164
- )
165
-
166
- def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
167
- out[:] = self._array_backend.fft.irfftn(
168
- arr, s=s, axes=axes, stream=self._array_backend.cpu
169
- )
170
-
171
- return rfftn, irfftn
172
-
173
147
  def rfftn(self, arr, *args, **kwargs):
174
148
  return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
175
149
 
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
9
9
  import os
10
10
  from psutil import virtual_memory
11
11
  from contextlib import contextmanager
12
- from typing import Tuple, Dict, List, Type
12
+ from typing import Tuple, List, Type
13
13
 
14
14
  import scipy
15
15
  import numpy as np
16
16
  from scipy.ndimage import maximum_filter, affine_transform
17
- from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
18
- from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len, interfaces
17
+ from pyfftw import (
18
+ zeros_aligned,
19
+ simd_alignment,
20
+ next_fast_len,
21
+ interfaces,
22
+ config,
23
+ )
19
24
 
20
25
  from ..types import NDArray, BackendArray, shm_type
21
26
  from .matching_backend import MatchingBackend, _create_metafunction
22
27
 
28
+
23
29
  os.environ["MKL_NUM_THREADS"] = "1"
24
30
  os.environ["OMP_NUM_THREADS"] = "1"
25
31
  os.environ["PYFFTW_NUM_THREADS"] = "1"
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
103
109
  self.solve_triangular = self._solve_triangular
104
110
  self.linalg.solve_triangular = scipy.linalg.solve_triangular
105
111
 
112
+ try:
113
+ from ._numpyfftw_utils import rfftn as rfftn_cache
114
+ from ._numpyfftw_utils import irfftn as irfftn_cache
115
+
116
+ self._rfftn = rfftn_cache
117
+ self._irfftn = irfftn_cache
118
+ except Exception as e:
119
+ print(e)
120
+
121
+ config.NUM_THREADS = 1
122
+ config.PLANNER_EFFORT = "FFTW_MEASURE"
123
+ interfaces.cache.enable()
124
+ interfaces.cache.set_keepalive_time(360)
125
+
106
126
  def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
107
127
  # Upper argument is not supported until numpy 2.0
108
128
  ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
138
158
  return float
139
159
 
140
160
  def free_cache(self):
141
- pass
161
+ interfaces.cache.disable()
142
162
 
143
163
  def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
144
164
  return self._array_backend.transpose(arr, *args, **kwargs)
@@ -240,70 +260,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
240
260
  b[tuple(bind)] = arr[tuple(aind)]
241
261
  return b
242
262
 
243
- def build_fft(
244
- self,
245
- fwd_shape: Tuple[int],
246
- inv_shape: Tuple[int],
247
- real_dtype: type,
248
- cmpl_dtype: type,
249
- fftargs: Dict = {},
250
- inv_output_shape: Tuple[int] = None,
251
- temp_fwd: NDArray = None,
252
- temp_inv: NDArray = None,
253
- fwd_axes: Tuple[int] = None,
254
- inv_axes: Tuple[int] = None,
255
- ) -> Tuple[FFTW, FFTW]:
256
- if temp_fwd is None:
257
- temp_fwd = (
258
- self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
259
- )
260
- if temp_inv is None:
261
- temp_inv = (
262
- self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
263
- )
264
-
265
- default_values = {
266
- "planner_effort": "FFTW_MEASURE",
267
- "auto_align_input": False,
268
- "auto_contiguous": False,
269
- "avoid_copy": True,
270
- "overwrite_input": True,
271
- "threads": 1,
272
- }
273
- for key in default_values:
274
- if key in fftargs:
275
- continue
276
- fftargs[key] = default_values[key]
277
-
278
- rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
279
- _rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
280
- overwrite_input = fftargs.pop("overwrite_input", None)
281
-
282
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
283
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
284
- _irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
285
-
286
- def _rfftn_wrapper(arr, out, *args, **kwargs):
287
- return _rfftn(arr, out)
288
-
289
- def _irfftn_wrapper(arr, out, *args, **kwargs):
290
- return _irfftn(arr, out)
291
-
292
- fftargs["overwrite_input"] = overwrite_input
293
- return _rfftn_wrapper, _irfftn_wrapper
263
+ def _rfftn(self, arr, out=None, **kwargs):
264
+ ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
265
+ if out is not None:
266
+ out[:] = ret
267
+ return out
268
+ return ret
294
269
 
295
- @staticmethod
296
- def _format_fft_shape(shape: Tuple[int], axes: Tuple[int] = None):
297
- if axes is None:
298
- return shape
299
- axes = tuple(sorted(range(len(shape))[i] for i in axes))
300
- return tuple(shape[i] for i in axes)
270
+ def _irfftn(self, arr, out=None, **kwargs):
271
+ ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
272
+ if out is not None:
273
+ out[:] = ret
274
+ return out
275
+ return ret
301
276
 
302
- def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
303
- return interfaces.numpy_fft.rfftn(arr, **kwargs)
277
+ def rfftn(
278
+ self,
279
+ arr: NDArray,
280
+ out=None,
281
+ auto_align_input: bool = False,
282
+ auto_contiguous: bool = False,
283
+ overwrite_input: bool = True,
284
+ **kwargs,
285
+ ) -> NDArray:
286
+ return self._rfftn(
287
+ arr,
288
+ auto_align_input=auto_align_input,
289
+ auto_contiguous=auto_contiguous,
290
+ overwrite_input=overwrite_input,
291
+ **kwargs,
292
+ )
304
293
 
305
- def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
306
- return interfaces.numpy_fft.irfftn(arr, **kwargs)
294
+ def irfftn(
295
+ self,
296
+ arr: NDArray,
297
+ out=None,
298
+ auto_align_input: bool = False,
299
+ auto_contiguous: bool = False,
300
+ overwrite_input: bool = True,
301
+ **kwargs,
302
+ ) -> NDArray:
303
+ return self._irfftn(
304
+ arr,
305
+ auto_align_input=auto_align_input,
306
+ auto_contiguous=auto_contiguous,
307
+ overwrite_input=overwrite_input,
308
+ **kwargs,
309
+ )
307
310
 
308
311
  def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
309
312
  new_shape = self.to_backend_array(newshape)
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
7
7
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
  """
9
9
 
10
- from typing import Tuple, Callable
10
+ from typing import Tuple
11
11
  from contextlib import contextmanager
12
12
  from multiprocessing import shared_memory
13
13
  from multiprocessing.managers import SharedMemoryManager
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
273
273
  kwargs["device"] = self.device
274
274
  return self._array_backend.eye(*args, **kwargs)
275
275
 
276
- def build_fft(
277
- self,
278
- fwd_shape: Tuple[int],
279
- inv_shape: Tuple[int],
280
- inv_output_shape: Tuple[int] = None,
281
- fwd_axes: Tuple[int] = None,
282
- inv_axes: Tuple[int] = None,
283
- **kwargs,
284
- ) -> Tuple[Callable, Callable]:
285
- rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
286
- irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
287
- irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
288
-
289
- def rfftn(
290
- arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
291
- ) -> TorchTensor:
292
- return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
293
-
294
- def irfftn(
295
- arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
296
- ) -> TorchTensor:
297
- return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
298
-
299
- return rfftn, irfftn
300
-
301
276
  def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
302
277
  kwargs["dim"] = kwargs.pop("axes", None)
303
278
  return self._array_backend.fft.rfftn(arr, **kwargs)
tme/density.py CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
36
36
  array_to_memmap,
37
37
  memmap_to_array,
38
38
  minimum_enclosing_box,
39
+ is_gzipped,
39
40
  )
40
41
 
41
42
  __all__ = ["Density"]
@@ -331,6 +332,7 @@ class Density:
331
332
  if non_standard_crs:
332
333
  data = np.transpose(data, crs_index)
333
334
  origin = np.take(origin, crs_index)
335
+ sampling_rate = np.take(sampling_rate, crs_index)
334
336
 
335
337
  return data.T, origin[::-1], sampling_rate[::-1], metadata
336
338
 
@@ -2257,9 +2259,3 @@ class Density:
2257
2259
  coordinates = np.array(np.where(data > 0))
2258
2260
  weights = self.data[tuple(coordinates)]
2259
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2260
-
2261
-
2262
- def is_gzipped(filename: str) -> bool:
2263
- """Check if a file is a gzip file by reading its magic number."""
2264
- with open(filename, "rb") as f:
2265
- return f.read(2) == b"\x1f\x8b"
Binary file