pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/backends/jax_backend.py
CHANGED
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from functools import wraps
|
10
|
-
from typing import Tuple, List,
|
10
|
+
from typing import Tuple, List, Dict, Any
|
11
|
+
|
12
|
+
import numpy as np
|
11
13
|
|
12
14
|
from ..types import BackendArray
|
13
15
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
@@ -51,12 +53,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
51
53
|
)
|
52
54
|
self.scipy = jsp
|
53
55
|
self._create_ufuncs()
|
54
|
-
try:
|
55
|
-
from ._jax_utils import scan as _
|
56
|
-
|
57
|
-
self.scan = self._scan
|
58
|
-
except Exception:
|
59
|
-
pass
|
60
56
|
|
61
57
|
def from_sharedarr(self, arr: BackendArray) -> BackendArray:
|
62
58
|
return arr
|
@@ -70,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
|
|
70
66
|
arr = arr.at[idx].set(value)
|
71
67
|
return arr
|
72
68
|
|
69
|
+
def addat(self, arr, indices, values):
|
70
|
+
arr = arr.at[indices].add(values)
|
71
|
+
return arr
|
72
|
+
|
73
73
|
def topleft_pad(
|
74
74
|
self, arr: BackendArray, shape: Tuple[int], padval: int = 0
|
75
75
|
) -> BackendArray:
|
@@ -94,6 +94,7 @@ class JaxBackend(NumpyFFTWBackend):
|
|
94
94
|
"sqrt",
|
95
95
|
"maximum",
|
96
96
|
"exp",
|
97
|
+
"mod",
|
97
98
|
]
|
98
99
|
for ufunc in ufuncs:
|
99
100
|
backend_method = emulate_out(getattr(self._array_backend, ufunc))
|
@@ -109,27 +110,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
109
110
|
shape=arr.shape, dtype=arr.dtype, fill_value=value
|
110
111
|
)
|
111
112
|
|
112
|
-
def build_fft(
|
113
|
-
self,
|
114
|
-
fwd_shape: Tuple[int],
|
115
|
-
inv_shape: Tuple[int] = None,
|
116
|
-
inv_output_shape: Tuple[int] = None,
|
117
|
-
fwd_axes: Tuple[int] = None,
|
118
|
-
inv_axes: Tuple[int] = None,
|
119
|
-
**kwargs,
|
120
|
-
) -> Tuple[Callable, Callable]:
|
121
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
122
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
123
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
124
|
-
|
125
|
-
def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
|
126
|
-
return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
|
127
|
-
|
128
|
-
def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
|
129
|
-
return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
|
130
|
-
|
131
|
-
return rfftn, irfftn
|
132
|
-
|
133
113
|
def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
|
134
114
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
135
115
|
|
@@ -189,12 +169,41 @@ class JaxBackend(NumpyFFTWBackend):
|
|
189
169
|
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
|
190
170
|
return max_scores, rotations
|
191
171
|
|
192
|
-
def
|
172
|
+
def compute_convolution_shapes(
|
173
|
+
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
174
|
+
) -> Tuple[List[int], List[int], List[int]]:
|
175
|
+
from scipy.fft import next_fast_len
|
176
|
+
|
177
|
+
convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
|
178
|
+
fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
|
179
|
+
fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
|
180
|
+
|
181
|
+
return convolution_shape, fast_shape, fast_ft_shape
|
182
|
+
|
183
|
+
def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
|
184
|
+
if isinstance(obj, np.ndarray):
|
185
|
+
return ("array", (tuple(obj.flatten().tolist()), obj.shape))
|
186
|
+
return ("other", obj)
|
187
|
+
|
188
|
+
def _from_hashable(self, type_info: str, data: Any) -> Any:
|
189
|
+
if type_info == "array":
|
190
|
+
data, shape = data
|
191
|
+
return self.array(data).reshape(shape)
|
192
|
+
return data
|
193
|
+
|
194
|
+
def _dict_to_tuple(self, data: Dict) -> Tuple:
|
195
|
+
return tuple((k, self._to_hashable(v)) for k, v in data.items())
|
196
|
+
|
197
|
+
def _tuple_to_dict(self, data: Tuple) -> Dict:
|
198
|
+
return {x[0]: self._from_hashable(*x[1]) for x in data}
|
199
|
+
|
200
|
+
def scan(
|
193
201
|
self,
|
194
202
|
matching_data: type,
|
195
203
|
splits: Tuple[Tuple[slice, slice]],
|
196
204
|
n_jobs: int,
|
197
|
-
callback_class,
|
205
|
+
callback_class: object,
|
206
|
+
callback_class_args: Dict,
|
198
207
|
rotate_mask: bool = False,
|
199
208
|
**kwargs,
|
200
209
|
) -> List:
|
@@ -214,17 +223,21 @@ class JaxBackend(NumpyFFTWBackend):
|
|
214
223
|
conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
215
224
|
target_shape=self.to_numpy_array(target_shape),
|
216
225
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
217
|
-
|
226
|
+
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
218
227
|
)
|
219
|
-
|
220
228
|
analyzer_args = {
|
221
|
-
"
|
229
|
+
"shape": fast_shape,
|
222
230
|
"fourier_shift": shift,
|
231
|
+
"fast_shape": fast_shape,
|
223
232
|
"targetshape": target_shape,
|
224
233
|
"templateshape": matching_data.template.shape,
|
225
234
|
"convolution_shape": conv_shape,
|
235
|
+
"convolution_mode": convolution_mode,
|
236
|
+
"thread_safe": False,
|
237
|
+
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
238
|
+
"n_rotations": matching_data.rotations.shape[0],
|
239
|
+
"jax_mode": True,
|
226
240
|
}
|
227
|
-
|
228
241
|
create_target_filter = matching_data.target_filter is not None
|
229
242
|
create_template_filter = matching_data.template_filter is not None
|
230
243
|
create_filter = create_target_filter or create_template_filter
|
@@ -240,25 +253,34 @@ class JaxBackend(NumpyFFTWBackend):
|
|
240
253
|
for i in range(matching_data.rotations.shape[0])
|
241
254
|
}
|
242
255
|
for split_start in range(0, len(splits), n_jobs):
|
256
|
+
|
257
|
+
analyzer_kwargs = []
|
258
|
+
|
243
259
|
split_subset = splits[split_start : (split_start + n_jobs)]
|
244
260
|
if not len(split_subset):
|
245
261
|
continue
|
246
262
|
|
247
263
|
targets, translation_offsets = [], []
|
248
264
|
for target_split, template_split in split_subset:
|
249
|
-
base = matching_data.subset_by_slice(
|
265
|
+
base, translation_offset = matching_data.subset_by_slice(
|
250
266
|
target_slice=target_split,
|
251
267
|
target_pad=target_pad,
|
252
268
|
template_slice=template_split,
|
253
269
|
)
|
254
|
-
|
255
|
-
|
270
|
+
cur_args = analyzer_args.copy()
|
271
|
+
cur_args["offset"] = translation_offset
|
272
|
+
cur_args.update(callback_class_args)
|
273
|
+
|
274
|
+
analyzer_kwargs.append(self._dict_to_tuple(cur_args))
|
275
|
+
|
276
|
+
_target = self.astype(base._target, self._float_dtype)
|
277
|
+
translation_offsets.append(translation_offset)
|
278
|
+
targets.append(self.topleft_pad(_target, fast_shape))
|
256
279
|
|
257
280
|
if create_filter:
|
258
281
|
filter_args = {
|
259
282
|
"data_rfft": self.fft.rfftn(targets[0]),
|
260
283
|
"return_real_fourier": True,
|
261
|
-
"shape_is_real_fourier": False,
|
262
284
|
}
|
263
285
|
|
264
286
|
if create_template_filter:
|
@@ -275,25 +297,27 @@ class JaxBackend(NumpyFFTWBackend):
|
|
275
297
|
|
276
298
|
create_filter, create_template_filter, create_target_filter = (False,) * 3
|
277
299
|
base, targets = None, self._array_backend.stack(targets)
|
278
|
-
|
300
|
+
|
301
|
+
analyzer_kwargs = tuple(analyzer_kwargs)
|
302
|
+
states = scan_inner(
|
279
303
|
self.astype(targets, self._float_dtype),
|
280
|
-
matching_data.template,
|
281
|
-
matching_data.template_mask,
|
304
|
+
self.astype(matching_data.template, self._float_dtype),
|
305
|
+
self.astype(matching_data.template_mask, self._float_dtype),
|
282
306
|
matching_data.rotations,
|
283
307
|
template_filter,
|
284
308
|
target_filter,
|
285
309
|
fast_shape,
|
286
310
|
rotate_mask,
|
311
|
+
callback_class,
|
312
|
+
analyzer_kwargs,
|
287
313
|
)
|
288
314
|
|
289
|
-
for index in range(
|
290
|
-
|
291
|
-
|
292
|
-
offset=translation_offsets[index],
|
293
|
-
)
|
294
|
-
state = (scores, rotations, rotation_mapping)
|
295
|
-
ret.append(temp.result(state, **analyzer_args))
|
315
|
+
for index in range(targets.shape[0]):
|
316
|
+
kwargs = self._tuple_to_dict(analyzer_kwargs[index])
|
317
|
+
analyzer = callback_class(**kwargs)
|
296
318
|
|
319
|
+
state = (states[0][index], states[1][index], rotation_mapping)
|
320
|
+
ret.append(analyzer.result(state, **kwargs))
|
297
321
|
return ret
|
298
322
|
|
299
323
|
def get_available_memory(self) -> int:
|
tme/backends/matching_backend.py
CHANGED
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
8
8
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from multiprocessing import shared_memory
|
11
|
-
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
11
|
+
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
12
12
|
|
13
13
|
from ..types import BackendArray, NDArray, Scalar, shm_type
|
14
14
|
|
@@ -1087,57 +1087,12 @@ class MatchingBackend(ABC):
|
|
1087
1087
|
"""
|
1088
1088
|
|
1089
1089
|
@abstractmethod
|
1090
|
-
def
|
1091
|
-
|
1092
|
-
fwd_shape: Tuple[int],
|
1093
|
-
inv_shape: Tuple[int],
|
1094
|
-
real_dtype: type,
|
1095
|
-
cmpl_dtype: type,
|
1096
|
-
inv_output_shape: Tuple[int] = None,
|
1097
|
-
temp_fwd: NDArray = None,
|
1098
|
-
temp_inv: NDArray = None,
|
1099
|
-
fwd_axes: Tuple[int] = None,
|
1100
|
-
inv_axes: Tuple[int] = None,
|
1101
|
-
fftargs: Dict = {},
|
1102
|
-
) -> Tuple[Callable, Callable]:
|
1103
|
-
"""
|
1104
|
-
Build forward and inverse real fourier transform functions. The returned
|
1105
|
-
callables have two parameters ``arr`` and ``out`` which correspond to the
|
1106
|
-
input and output of the Fourier transform. The methods return the output
|
1107
|
-
of the respective function call, regardless of ``out`` being provided or not,
|
1108
|
-
analogous to most numpy functions.
|
1090
|
+
def rfftn(self, **kwargs):
|
1091
|
+
"""Perform an n-D real FFT."""
|
1109
1092
|
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
Input shape for the forward Fourier transform.
|
1114
|
-
(see `compute_convolution_shapes`).
|
1115
|
-
inv_shape : tuple
|
1116
|
-
Input shape for the inverse Fourier transform.
|
1117
|
-
real_dtype : dtype
|
1118
|
-
Data type of the forward Fourier transform.
|
1119
|
-
complex_dtype : dtype
|
1120
|
-
Data type of the inverse Fourier transform.
|
1121
|
-
inv_output_shape : tuple, optional
|
1122
|
-
Output shape of the inverse Fourier transform. By default fast_shape.
|
1123
|
-
fftargs : dict, optional
|
1124
|
-
Dictionary passed to pyFFTW builders.
|
1125
|
-
temp_fwd : NDArray, optional
|
1126
|
-
Temporary array to build the forward transform. Superseeds shape defined by
|
1127
|
-
fwd_shape if provided.
|
1128
|
-
temp_inv : NDArray, optional
|
1129
|
-
Temporary array to build the inverse transform. Superseeds shape defined by
|
1130
|
-
inv_shape if provided.
|
1131
|
-
fwd_axes : tuple of int
|
1132
|
-
Axes to perform the forward Fourier transform over.
|
1133
|
-
inv_axes : tuple of int
|
1134
|
-
Axes to perform the inverse Fourier transform over.
|
1135
|
-
|
1136
|
-
Returns
|
1137
|
-
-------
|
1138
|
-
tuple
|
1139
|
-
Tuple of callables for forward and inverse real Fourier transform.
|
1140
|
-
"""
|
1093
|
+
@abstractmethod
|
1094
|
+
def irfftn(self, **kwargs):
|
1095
|
+
"""Perform an n-D real inverse FFT."""
|
1141
1096
|
|
1142
1097
|
def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
|
1143
1098
|
"""
|
tme/backends/mlx_backend.py
CHANGED
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
from typing import Tuple, List
|
9
|
+
from typing import Tuple, List
|
10
10
|
|
11
11
|
import numpy as np
|
12
12
|
|
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
|
|
144
144
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
145
145
|
return arr[box]
|
146
146
|
|
147
|
-
def build_fft(
|
148
|
-
self,
|
149
|
-
fwd_shape: Tuple[int],
|
150
|
-
inv_shape: Tuple[int] = None,
|
151
|
-
inv_output_shape: Tuple[int] = None,
|
152
|
-
fwd_axes: Tuple[int] = None,
|
153
|
-
inv_axes: Tuple[int] = None,
|
154
|
-
**kwargs,
|
155
|
-
) -> Tuple[Callable, Callable]:
|
156
|
-
# Runs on mlx.core.cpu until Metal support is available
|
157
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
158
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
159
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
160
|
-
|
161
|
-
def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
|
162
|
-
out[:] = self._array_backend.fft.rfftn(
|
163
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
164
|
-
)
|
165
|
-
|
166
|
-
def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
|
167
|
-
out[:] = self._array_backend.fft.irfftn(
|
168
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
169
|
-
)
|
170
|
-
|
171
|
-
return rfftn, irfftn
|
172
|
-
|
173
147
|
def rfftn(self, arr, *args, **kwargs):
|
174
148
|
return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
|
175
149
|
|
tme/backends/npfftw_backend.py
CHANGED
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
9
9
|
import os
|
10
10
|
from psutil import virtual_memory
|
11
11
|
from contextlib import contextmanager
|
12
|
-
from typing import Tuple,
|
12
|
+
from typing import Tuple, List, Type
|
13
13
|
|
14
14
|
import scipy
|
15
15
|
import numpy as np
|
16
16
|
from scipy.ndimage import maximum_filter, affine_transform
|
17
|
-
from pyfftw
|
18
|
-
|
17
|
+
from pyfftw import (
|
18
|
+
zeros_aligned,
|
19
|
+
simd_alignment,
|
20
|
+
next_fast_len,
|
21
|
+
interfaces,
|
22
|
+
config,
|
23
|
+
)
|
19
24
|
|
20
25
|
from ..types import NDArray, BackendArray, shm_type
|
21
26
|
from .matching_backend import MatchingBackend, _create_metafunction
|
22
27
|
|
28
|
+
|
23
29
|
os.environ["MKL_NUM_THREADS"] = "1"
|
24
30
|
os.environ["OMP_NUM_THREADS"] = "1"
|
25
31
|
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
103
109
|
self.solve_triangular = self._solve_triangular
|
104
110
|
self.linalg.solve_triangular = scipy.linalg.solve_triangular
|
105
111
|
|
112
|
+
try:
|
113
|
+
from ._numpyfftw_utils import rfftn as rfftn_cache
|
114
|
+
from ._numpyfftw_utils import irfftn as irfftn_cache
|
115
|
+
|
116
|
+
self._rfftn = rfftn_cache
|
117
|
+
self._irfftn = irfftn_cache
|
118
|
+
except Exception as e:
|
119
|
+
print(e)
|
120
|
+
|
121
|
+
config.NUM_THREADS = 1
|
122
|
+
config.PLANNER_EFFORT = "FFTW_MEASURE"
|
123
|
+
interfaces.cache.enable()
|
124
|
+
interfaces.cache.set_keepalive_time(360)
|
125
|
+
|
106
126
|
def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
|
107
127
|
# Upper argument is not supported until numpy 2.0
|
108
128
|
ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
|
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
138
158
|
return float
|
139
159
|
|
140
160
|
def free_cache(self):
|
141
|
-
|
161
|
+
interfaces.cache.disable()
|
142
162
|
|
143
163
|
def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
144
164
|
return self._array_backend.transpose(arr, *args, **kwargs)
|
@@ -240,70 +260,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
240
260
|
b[tuple(bind)] = arr[tuple(aind)]
|
241
261
|
return b
|
242
262
|
|
243
|
-
def
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
fftargs: Dict = {},
|
250
|
-
inv_output_shape: Tuple[int] = None,
|
251
|
-
temp_fwd: NDArray = None,
|
252
|
-
temp_inv: NDArray = None,
|
253
|
-
fwd_axes: Tuple[int] = None,
|
254
|
-
inv_axes: Tuple[int] = None,
|
255
|
-
) -> Tuple[FFTW, FFTW]:
|
256
|
-
if temp_fwd is None:
|
257
|
-
temp_fwd = (
|
258
|
-
self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
|
259
|
-
)
|
260
|
-
if temp_inv is None:
|
261
|
-
temp_inv = (
|
262
|
-
self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
|
263
|
-
)
|
264
|
-
|
265
|
-
default_values = {
|
266
|
-
"planner_effort": "FFTW_MEASURE",
|
267
|
-
"auto_align_input": False,
|
268
|
-
"auto_contiguous": False,
|
269
|
-
"avoid_copy": True,
|
270
|
-
"overwrite_input": True,
|
271
|
-
"threads": 1,
|
272
|
-
}
|
273
|
-
for key in default_values:
|
274
|
-
if key in fftargs:
|
275
|
-
continue
|
276
|
-
fftargs[key] = default_values[key]
|
277
|
-
|
278
|
-
rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
|
279
|
-
_rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
|
280
|
-
overwrite_input = fftargs.pop("overwrite_input", None)
|
281
|
-
|
282
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
283
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
284
|
-
_irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
|
285
|
-
|
286
|
-
def _rfftn_wrapper(arr, out, *args, **kwargs):
|
287
|
-
return _rfftn(arr, out)
|
288
|
-
|
289
|
-
def _irfftn_wrapper(arr, out, *args, **kwargs):
|
290
|
-
return _irfftn(arr, out)
|
291
|
-
|
292
|
-
fftargs["overwrite_input"] = overwrite_input
|
293
|
-
return _rfftn_wrapper, _irfftn_wrapper
|
263
|
+
def _rfftn(self, arr, out=None, **kwargs):
|
264
|
+
ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
|
265
|
+
if out is not None:
|
266
|
+
out[:] = ret
|
267
|
+
return out
|
268
|
+
return ret
|
294
269
|
|
295
|
-
|
296
|
-
|
297
|
-
if
|
298
|
-
|
299
|
-
|
300
|
-
return
|
270
|
+
def _irfftn(self, arr, out=None, **kwargs):
|
271
|
+
ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
|
272
|
+
if out is not None:
|
273
|
+
out[:] = ret
|
274
|
+
return out
|
275
|
+
return ret
|
301
276
|
|
302
|
-
def rfftn(
|
303
|
-
|
277
|
+
def rfftn(
|
278
|
+
self,
|
279
|
+
arr: NDArray,
|
280
|
+
out=None,
|
281
|
+
auto_align_input: bool = False,
|
282
|
+
auto_contiguous: bool = False,
|
283
|
+
overwrite_input: bool = True,
|
284
|
+
**kwargs,
|
285
|
+
) -> NDArray:
|
286
|
+
return self._rfftn(
|
287
|
+
arr,
|
288
|
+
auto_align_input=auto_align_input,
|
289
|
+
auto_contiguous=auto_contiguous,
|
290
|
+
overwrite_input=overwrite_input,
|
291
|
+
**kwargs,
|
292
|
+
)
|
304
293
|
|
305
|
-
def irfftn(
|
306
|
-
|
294
|
+
def irfftn(
|
295
|
+
self,
|
296
|
+
arr: NDArray,
|
297
|
+
out=None,
|
298
|
+
auto_align_input: bool = False,
|
299
|
+
auto_contiguous: bool = False,
|
300
|
+
overwrite_input: bool = True,
|
301
|
+
**kwargs,
|
302
|
+
) -> NDArray:
|
303
|
+
return self._irfftn(
|
304
|
+
arr,
|
305
|
+
auto_align_input=auto_align_input,
|
306
|
+
auto_contiguous=auto_contiguous,
|
307
|
+
overwrite_input=overwrite_input,
|
308
|
+
**kwargs,
|
309
|
+
)
|
307
310
|
|
308
311
|
def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
309
312
|
new_shape = self.to_backend_array(newshape)
|
@@ -398,33 +401,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
398
401
|
out_mask: NDArray = None,
|
399
402
|
order: int = 3,
|
400
403
|
cache: bool = False,
|
404
|
+
batched: bool = False,
|
401
405
|
) -> Tuple[NDArray, NDArray]:
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
translation=
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
|
406
|
+
if out is None:
|
407
|
+
out = self.zeros_like(arr)
|
408
|
+
|
409
|
+
# Check whether rotation_matrix is already a rigid transform matrix
|
410
|
+
matrix = rotation_matrix
|
411
|
+
if matrix.shape[-1] == (arr.ndim - int(batched)):
|
412
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
413
|
+
if not use_geometric_center:
|
414
|
+
center = self.center_of_mass(arr, cutoff=0)
|
415
|
+
|
416
|
+
offset = int(arr.ndim - rotation_matrix.shape[0])
|
417
|
+
center = center[offset:]
|
418
|
+
translation = (
|
419
|
+
self.zeros(center.size) if translation is None else translation
|
420
|
+
)
|
421
|
+
matrix = self._rigid_transform_matrix(
|
422
|
+
rotation_matrix=rotation_matrix,
|
423
|
+
translation=translation,
|
424
|
+
center=center,
|
422
425
|
)
|
423
426
|
|
424
427
|
self._rigid_transform(
|
425
|
-
data=arr
|
428
|
+
data=arr,
|
426
429
|
matrix=matrix,
|
427
|
-
output=out
|
430
|
+
output=out,
|
428
431
|
order=order,
|
429
432
|
prefilter=True,
|
430
433
|
cache=cache,
|
@@ -433,11 +436,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
433
436
|
|
434
437
|
# Applying the prefilter leads to artifacts in the mask.
|
435
438
|
if arr_mask is not None:
|
436
|
-
|
439
|
+
if out_mask is None:
|
440
|
+
out_mask = self.zeros_like(arr_mask)
|
441
|
+
|
437
442
|
self._rigid_transform(
|
438
|
-
data=arr_mask
|
443
|
+
data=arr_mask,
|
439
444
|
matrix=matrix,
|
440
|
-
output=out_mask
|
445
|
+
output=out_mask,
|
441
446
|
order=order,
|
442
447
|
prefilter=False,
|
443
448
|
cache=cache,
|
tme/backends/pytorch_backend.py
CHANGED
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
7
7
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Tuple
|
10
|
+
from typing import Tuple
|
11
11
|
from contextlib import contextmanager
|
12
12
|
from multiprocessing import shared_memory
|
13
13
|
from multiprocessing.managers import SharedMemoryManager
|
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
273
273
|
kwargs["device"] = self.device
|
274
274
|
return self._array_backend.eye(*args, **kwargs)
|
275
275
|
|
276
|
-
def build_fft(
|
277
|
-
self,
|
278
|
-
fwd_shape: Tuple[int],
|
279
|
-
inv_shape: Tuple[int],
|
280
|
-
inv_output_shape: Tuple[int] = None,
|
281
|
-
fwd_axes: Tuple[int] = None,
|
282
|
-
inv_axes: Tuple[int] = None,
|
283
|
-
**kwargs,
|
284
|
-
) -> Tuple[Callable, Callable]:
|
285
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
286
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
287
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
288
|
-
|
289
|
-
def rfftn(
|
290
|
-
arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
|
291
|
-
) -> TorchTensor:
|
292
|
-
return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
|
293
|
-
|
294
|
-
def irfftn(
|
295
|
-
arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
|
296
|
-
) -> TorchTensor:
|
297
|
-
return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
|
298
|
-
|
299
|
-
return rfftn, irfftn
|
300
|
-
|
301
276
|
def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
302
277
|
kwargs["dim"] = kwargs.pop("axes", None)
|
303
278
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
@@ -306,6 +281,9 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
306
281
|
kwargs["dim"] = kwargs.pop("axes", None)
|
307
282
|
return self._array_backend.fft.irfftn(arr, **kwargs)
|
308
283
|
|
284
|
+
def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
|
285
|
+
return rotation_matrix
|
286
|
+
|
309
287
|
def rigid_transform(
|
310
288
|
self,
|
311
289
|
arr: TorchTensor,
|
@@ -317,6 +295,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
317
295
|
out_mask: TorchTensor = None,
|
318
296
|
order: int = 1,
|
319
297
|
cache: bool = False,
|
298
|
+
**kwargs,
|
320
299
|
):
|
321
300
|
_mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
|
322
301
|
mode = _mode_mapping.get(order, None)
|
tme/density.py
CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
|
|
36
36
|
array_to_memmap,
|
37
37
|
memmap_to_array,
|
38
38
|
minimum_enclosing_box,
|
39
|
+
is_gzipped,
|
39
40
|
)
|
40
41
|
|
41
42
|
__all__ = ["Density"]
|
@@ -331,6 +332,7 @@ class Density:
|
|
331
332
|
if non_standard_crs:
|
332
333
|
data = np.transpose(data, crs_index)
|
333
334
|
origin = np.take(origin, crs_index)
|
335
|
+
sampling_rate = np.take(sampling_rate, crs_index)
|
334
336
|
|
335
337
|
return data.T, origin[::-1], sampling_rate[::-1], metadata
|
336
338
|
|
@@ -1763,12 +1765,13 @@ class Density:
|
|
1763
1765
|
axis=axis,
|
1764
1766
|
)
|
1765
1767
|
|
1766
|
-
|
1768
|
+
mask, mask_ret = np.where(mask), np.where(mask_ret)
|
1769
|
+
|
1770
|
+
arr_ft = np.fft.fftn(self.data)[mask]
|
1767
1771
|
arr_ft *= np.prod(ret_shape) / np.prod(self.shape)
|
1768
1772
|
ret_ft = np.zeros(ret_shape, dtype=arr_ft.dtype)
|
1769
|
-
ret_ft
|
1770
|
-
ret.data = np.real(np.fft.ifftn(ret_ft))
|
1771
|
-
|
1773
|
+
np.add.at(ret_ft, mask_ret, arr_ft)
|
1774
|
+
ret.data = np.real(np.fft.ifftn(ret_ft)).astype(self.data.dtype)
|
1772
1775
|
ret.sampling_rate = new_sampling_rate
|
1773
1776
|
return ret
|
1774
1777
|
|
@@ -2256,9 +2259,3 @@ class Density:
|
|
2256
2259
|
coordinates = np.array(np.where(data > 0))
|
2257
2260
|
weights = self.data[tuple(coordinates)]
|
2258
2261
|
return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
|
2259
|
-
|
2260
|
-
|
2261
|
-
def is_gzipped(filename: str) -> bool:
|
2262
|
-
"""Check if a file is a gzip file by reading its magic number."""
|
2263
|
-
with open(filename, "rb") as f:
|
2264
|
-
return f.read(2) == b"\x1f\x8b"
|
Binary file
|