pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +23 -10
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +13 -3
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +15 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +5 -44
- tme/backends/jax_backend.py +58 -37
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +68 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +7 -19
- tme/matching_exhaustive.py +34 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +206 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +4 -6
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/backends/cupy_backend.py
CHANGED
@@ -6,9 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
+
from typing import Tuple, List
|
9
10
|
from importlib.util import find_spec
|
10
11
|
from contextlib import contextmanager
|
11
|
-
from typing import Tuple, Callable, List
|
12
12
|
|
13
13
|
from .npfftw_backend import NumpyFFTWBackend
|
14
14
|
from ..types import CupyArray, NDArray, shm_type
|
@@ -111,53 +111,14 @@ class CupyBackend(NumpyFFTWBackend):
|
|
111
111
|
def unravel_index(self, indices, shape):
|
112
112
|
return self._array_backend.unravel_index(indices=indices, dims=shape)
|
113
113
|
|
114
|
-
def
|
115
|
-
self
|
116
|
-
fwd_shape: Tuple[int],
|
117
|
-
inv_shape: Tuple[int],
|
118
|
-
inv_output_shape: Tuple[int] = None,
|
119
|
-
fwd_axes: Tuple[int] = None,
|
120
|
-
inv_axes: Tuple[int] = None,
|
121
|
-
**kwargs,
|
122
|
-
) -> Tuple[Callable, Callable]:
|
123
|
-
cache = self._array_backend.fft.config.get_plan_cache()
|
124
|
-
current_device = self._array_backend.cuda.device.get_device_id()
|
125
|
-
|
126
|
-
previous_transform = [fwd_shape, inv_shape]
|
127
|
-
if current_device in PLAN_CACHE:
|
128
|
-
previous_transform = PLAN_CACHE[current_device]
|
129
|
-
|
130
|
-
real_diff, cmplx_diff = True, True
|
131
|
-
if len(fwd_shape) == len(previous_transform[0]):
|
132
|
-
real_diff = fwd_shape == previous_transform[0]
|
133
|
-
if len(inv_shape) == len(previous_transform[1]):
|
134
|
-
cmplx_diff = inv_shape == previous_transform[1]
|
135
|
-
|
136
|
-
if real_diff or cmplx_diff:
|
137
|
-
cache.clear()
|
138
|
-
|
139
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
140
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
141
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
142
|
-
|
143
|
-
def rfftn(
|
144
|
-
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
|
145
|
-
) -> CupyArray:
|
146
|
-
return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
|
147
|
-
|
148
|
-
def irfftn(
|
149
|
-
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
|
150
|
-
) -> CupyArray:
|
151
|
-
return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
|
152
|
-
|
153
|
-
PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
|
154
|
-
return rfftn, irfftn
|
114
|
+
def free_cache(self):
|
115
|
+
self._array_backend.fft.config.get_plan_cache().clear()
|
155
116
|
|
156
117
|
def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
157
118
|
return self._cufft.rfftn(arr, **kwargs)
|
158
119
|
|
159
120
|
def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
160
|
-
return self._cufft.irfftn(arr, **kwargs)
|
121
|
+
return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)
|
161
122
|
|
162
123
|
def compute_convolution_shapes(
|
163
124
|
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
@@ -235,7 +196,7 @@ class CupyBackend(NumpyFFTWBackend):
|
|
235
196
|
)
|
236
197
|
return None
|
237
198
|
|
238
|
-
if data.ndim == 3 and cache and self.texture_available:
|
199
|
+
if data.ndim == 3 and cache and self.texture_available and not batched:
|
239
200
|
# Device memory pool (should) come to rescue performance
|
240
201
|
temp = self.zeros(data.shape, data.dtype)
|
241
202
|
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
tme/backends/jax_backend.py
CHANGED
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from functools import wraps
|
10
|
-
from typing import Tuple, List,
|
10
|
+
from typing import Tuple, List, Dict, Any
|
11
|
+
|
12
|
+
import numpy as np
|
11
13
|
|
12
14
|
from ..types import BackendArray
|
13
15
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
@@ -64,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
|
|
64
66
|
arr = arr.at[idx].set(value)
|
65
67
|
return arr
|
66
68
|
|
69
|
+
def addat(self, arr, indices, values):
|
70
|
+
arr = arr.at[indices].add(values)
|
71
|
+
return arr
|
72
|
+
|
67
73
|
def topleft_pad(
|
68
74
|
self, arr: BackendArray, shape: Tuple[int], padval: int = 0
|
69
75
|
) -> BackendArray:
|
@@ -88,6 +94,7 @@ class JaxBackend(NumpyFFTWBackend):
|
|
88
94
|
"sqrt",
|
89
95
|
"maximum",
|
90
96
|
"exp",
|
97
|
+
"mod",
|
91
98
|
]
|
92
99
|
for ufunc in ufuncs:
|
93
100
|
backend_method = emulate_out(getattr(self._array_backend, ufunc))
|
@@ -103,27 +110,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
103
110
|
shape=arr.shape, dtype=arr.dtype, fill_value=value
|
104
111
|
)
|
105
112
|
|
106
|
-
def build_fft(
|
107
|
-
self,
|
108
|
-
fwd_shape: Tuple[int],
|
109
|
-
inv_shape: Tuple[int] = None,
|
110
|
-
inv_output_shape: Tuple[int] = None,
|
111
|
-
fwd_axes: Tuple[int] = None,
|
112
|
-
inv_axes: Tuple[int] = None,
|
113
|
-
**kwargs,
|
114
|
-
) -> Tuple[Callable, Callable]:
|
115
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
116
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
117
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
118
|
-
|
119
|
-
def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
|
120
|
-
return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
|
121
|
-
|
122
|
-
def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
|
123
|
-
return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
|
124
|
-
|
125
|
-
return rfftn, irfftn
|
126
|
-
|
127
113
|
def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
|
128
114
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
129
115
|
|
@@ -194,12 +180,30 @@ class JaxBackend(NumpyFFTWBackend):
|
|
194
180
|
|
195
181
|
return convolution_shape, fast_shape, fast_ft_shape
|
196
182
|
|
183
|
+
def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
|
184
|
+
if isinstance(obj, np.ndarray):
|
185
|
+
return ("array", (tuple(obj.flatten().tolist()), obj.shape))
|
186
|
+
return ("other", obj)
|
187
|
+
|
188
|
+
def _from_hashable(self, type_info: str, data: Any) -> Any:
|
189
|
+
if type_info == "array":
|
190
|
+
data, shape = data
|
191
|
+
return self.array(data).reshape(shape)
|
192
|
+
return data
|
193
|
+
|
194
|
+
def _dict_to_tuple(self, data: Dict) -> Tuple:
|
195
|
+
return tuple((k, self._to_hashable(v)) for k, v in data.items())
|
196
|
+
|
197
|
+
def _tuple_to_dict(self, data: Tuple) -> Dict:
|
198
|
+
return {x[0]: self._from_hashable(*x[1]) for x in data}
|
199
|
+
|
197
200
|
def scan(
|
198
201
|
self,
|
199
202
|
matching_data: type,
|
200
203
|
splits: Tuple[Tuple[slice, slice]],
|
201
204
|
n_jobs: int,
|
202
|
-
callback_class,
|
205
|
+
callback_class: object,
|
206
|
+
callback_class_args: Dict,
|
203
207
|
rotate_mask: bool = False,
|
204
208
|
**kwargs,
|
205
209
|
) -> List:
|
@@ -220,16 +224,20 @@ class JaxBackend(NumpyFFTWBackend):
|
|
220
224
|
target_shape=self.to_numpy_array(target_shape),
|
221
225
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
222
226
|
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
223
|
-
pad_target=pad_target,
|
224
227
|
)
|
225
228
|
analyzer_args = {
|
226
|
-
"
|
229
|
+
"shape": fast_shape,
|
227
230
|
"fourier_shift": shift,
|
231
|
+
"fast_shape": fast_shape,
|
228
232
|
"targetshape": target_shape,
|
229
233
|
"templateshape": matching_data.template.shape,
|
230
234
|
"convolution_shape": conv_shape,
|
235
|
+
"convolution_mode": convolution_mode,
|
236
|
+
"thread_safe": False,
|
237
|
+
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
238
|
+
"n_rotations": matching_data.rotations.shape[0],
|
239
|
+
"jax_mode": True,
|
231
240
|
}
|
232
|
-
|
233
241
|
create_target_filter = matching_data.target_filter is not None
|
234
242
|
create_template_filter = matching_data.template_filter is not None
|
235
243
|
create_filter = create_target_filter or create_template_filter
|
@@ -245,6 +253,9 @@ class JaxBackend(NumpyFFTWBackend):
|
|
245
253
|
for i in range(matching_data.rotations.shape[0])
|
246
254
|
}
|
247
255
|
for split_start in range(0, len(splits), n_jobs):
|
256
|
+
|
257
|
+
analyzer_kwargs = []
|
258
|
+
|
248
259
|
split_subset = splits[split_start : (split_start + n_jobs)]
|
249
260
|
if not len(split_subset):
|
250
261
|
continue
|
@@ -256,8 +267,15 @@ class JaxBackend(NumpyFFTWBackend):
|
|
256
267
|
target_pad=target_pad,
|
257
268
|
template_slice=template_split,
|
258
269
|
)
|
270
|
+
cur_args = analyzer_args.copy()
|
271
|
+
cur_args["offset"] = translation_offset
|
272
|
+
cur_args.update(callback_class_args)
|
273
|
+
|
274
|
+
analyzer_kwargs.append(self._dict_to_tuple(cur_args))
|
275
|
+
|
276
|
+
_target = self.astype(base._target, self._float_dtype)
|
259
277
|
translation_offsets.append(translation_offset)
|
260
|
-
targets.append(self.topleft_pad(
|
278
|
+
targets.append(self.topleft_pad(_target, fast_shape))
|
261
279
|
|
262
280
|
if create_filter:
|
263
281
|
filter_args = {
|
@@ -279,24 +297,27 @@ class JaxBackend(NumpyFFTWBackend):
|
|
279
297
|
|
280
298
|
create_filter, create_template_filter, create_target_filter = (False,) * 3
|
281
299
|
base, targets = None, self._array_backend.stack(targets)
|
282
|
-
|
300
|
+
|
301
|
+
analyzer_kwargs = tuple(analyzer_kwargs)
|
302
|
+
states = scan_inner(
|
283
303
|
self.astype(targets, self._float_dtype),
|
284
|
-
matching_data.template,
|
285
|
-
matching_data.template_mask,
|
304
|
+
self.astype(matching_data.template, self._float_dtype),
|
305
|
+
self.astype(matching_data.template_mask, self._float_dtype),
|
286
306
|
matching_data.rotations,
|
287
307
|
template_filter,
|
288
308
|
target_filter,
|
289
309
|
fast_shape,
|
290
310
|
rotate_mask,
|
311
|
+
callback_class,
|
312
|
+
analyzer_kwargs,
|
291
313
|
)
|
292
314
|
|
293
|
-
for index in range(
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
)
|
298
|
-
state
|
299
|
-
ret.append(temp.result(state, **analyzer_args))
|
315
|
+
for index in range(targets.shape[0]):
|
316
|
+
kwargs = self._tuple_to_dict(analyzer_kwargs[index])
|
317
|
+
analyzer = callback_class(**kwargs)
|
318
|
+
|
319
|
+
state = (states[0][index], states[1][index], rotation_mapping)
|
320
|
+
ret.append(analyzer.result(state, **kwargs))
|
300
321
|
return ret
|
301
322
|
|
302
323
|
def get_available_memory(self) -> int:
|
tme/backends/matching_backend.py
CHANGED
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
8
8
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from multiprocessing import shared_memory
|
11
|
-
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
11
|
+
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
12
12
|
|
13
13
|
from ..types import BackendArray, NDArray, Scalar, shm_type
|
14
14
|
|
@@ -1087,57 +1087,12 @@ class MatchingBackend(ABC):
|
|
1087
1087
|
"""
|
1088
1088
|
|
1089
1089
|
@abstractmethod
|
1090
|
-
def
|
1091
|
-
|
1092
|
-
fwd_shape: Tuple[int],
|
1093
|
-
inv_shape: Tuple[int],
|
1094
|
-
real_dtype: type,
|
1095
|
-
cmpl_dtype: type,
|
1096
|
-
inv_output_shape: Tuple[int] = None,
|
1097
|
-
temp_fwd: NDArray = None,
|
1098
|
-
temp_inv: NDArray = None,
|
1099
|
-
fwd_axes: Tuple[int] = None,
|
1100
|
-
inv_axes: Tuple[int] = None,
|
1101
|
-
fftargs: Dict = {},
|
1102
|
-
) -> Tuple[Callable, Callable]:
|
1103
|
-
"""
|
1104
|
-
Build forward and inverse real fourier transform functions. The returned
|
1105
|
-
callables have two parameters ``arr`` and ``out`` which correspond to the
|
1106
|
-
input and output of the Fourier transform. The methods return the output
|
1107
|
-
of the respective function call, regardless of ``out`` being provided or not,
|
1108
|
-
analogous to most numpy functions.
|
1090
|
+
def rfftn(self, **kwargs):
|
1091
|
+
"""Perform an n-D real FFT."""
|
1109
1092
|
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
Input shape for the forward Fourier transform.
|
1114
|
-
(see `compute_convolution_shapes`).
|
1115
|
-
inv_shape : tuple
|
1116
|
-
Input shape for the inverse Fourier transform.
|
1117
|
-
real_dtype : dtype
|
1118
|
-
Data type of the forward Fourier transform.
|
1119
|
-
complex_dtype : dtype
|
1120
|
-
Data type of the inverse Fourier transform.
|
1121
|
-
inv_output_shape : tuple, optional
|
1122
|
-
Output shape of the inverse Fourier transform. By default fast_shape.
|
1123
|
-
fftargs : dict, optional
|
1124
|
-
Dictionary passed to pyFFTW builders.
|
1125
|
-
temp_fwd : NDArray, optional
|
1126
|
-
Temporary array to build the forward transform. Superseeds shape defined by
|
1127
|
-
fwd_shape if provided.
|
1128
|
-
temp_inv : NDArray, optional
|
1129
|
-
Temporary array to build the inverse transform. Superseeds shape defined by
|
1130
|
-
inv_shape if provided.
|
1131
|
-
fwd_axes : tuple of int
|
1132
|
-
Axes to perform the forward Fourier transform over.
|
1133
|
-
inv_axes : tuple of int
|
1134
|
-
Axes to perform the inverse Fourier transform over.
|
1135
|
-
|
1136
|
-
Returns
|
1137
|
-
-------
|
1138
|
-
tuple
|
1139
|
-
Tuple of callables for forward and inverse real Fourier transform.
|
1140
|
-
"""
|
1093
|
+
@abstractmethod
|
1094
|
+
def irfftn(self, **kwargs):
|
1095
|
+
"""Perform an n-D real inverse FFT."""
|
1141
1096
|
|
1142
1097
|
def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
|
1143
1098
|
"""
|
tme/backends/mlx_backend.py
CHANGED
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
from typing import Tuple, List
|
9
|
+
from typing import Tuple, List
|
10
10
|
|
11
11
|
import numpy as np
|
12
12
|
|
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
|
|
144
144
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
145
145
|
return arr[box]
|
146
146
|
|
147
|
-
def build_fft(
|
148
|
-
self,
|
149
|
-
fwd_shape: Tuple[int],
|
150
|
-
inv_shape: Tuple[int] = None,
|
151
|
-
inv_output_shape: Tuple[int] = None,
|
152
|
-
fwd_axes: Tuple[int] = None,
|
153
|
-
inv_axes: Tuple[int] = None,
|
154
|
-
**kwargs,
|
155
|
-
) -> Tuple[Callable, Callable]:
|
156
|
-
# Runs on mlx.core.cpu until Metal support is available
|
157
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
158
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
159
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
160
|
-
|
161
|
-
def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
|
162
|
-
out[:] = self._array_backend.fft.rfftn(
|
163
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
164
|
-
)
|
165
|
-
|
166
|
-
def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
|
167
|
-
out[:] = self._array_backend.fft.irfftn(
|
168
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
169
|
-
)
|
170
|
-
|
171
|
-
return rfftn, irfftn
|
172
|
-
|
173
147
|
def rfftn(self, arr, *args, **kwargs):
|
174
148
|
return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
|
175
149
|
|
tme/backends/npfftw_backend.py
CHANGED
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
9
9
|
import os
|
10
10
|
from psutil import virtual_memory
|
11
11
|
from contextlib import contextmanager
|
12
|
-
from typing import Tuple,
|
12
|
+
from typing import Tuple, List, Type
|
13
13
|
|
14
14
|
import scipy
|
15
15
|
import numpy as np
|
16
16
|
from scipy.ndimage import maximum_filter, affine_transform
|
17
|
-
from pyfftw
|
18
|
-
|
17
|
+
from pyfftw import (
|
18
|
+
zeros_aligned,
|
19
|
+
simd_alignment,
|
20
|
+
next_fast_len,
|
21
|
+
interfaces,
|
22
|
+
config,
|
23
|
+
)
|
19
24
|
|
20
25
|
from ..types import NDArray, BackendArray, shm_type
|
21
26
|
from .matching_backend import MatchingBackend, _create_metafunction
|
22
27
|
|
28
|
+
|
23
29
|
os.environ["MKL_NUM_THREADS"] = "1"
|
24
30
|
os.environ["OMP_NUM_THREADS"] = "1"
|
25
31
|
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
103
109
|
self.solve_triangular = self._solve_triangular
|
104
110
|
self.linalg.solve_triangular = scipy.linalg.solve_triangular
|
105
111
|
|
112
|
+
try:
|
113
|
+
from ._numpyfftw_utils import rfftn as rfftn_cache
|
114
|
+
from ._numpyfftw_utils import irfftn as irfftn_cache
|
115
|
+
|
116
|
+
self._rfftn = rfftn_cache
|
117
|
+
self._irfftn = irfftn_cache
|
118
|
+
except Exception as e:
|
119
|
+
print(e)
|
120
|
+
|
121
|
+
config.NUM_THREADS = 1
|
122
|
+
config.PLANNER_EFFORT = "FFTW_MEASURE"
|
123
|
+
interfaces.cache.enable()
|
124
|
+
interfaces.cache.set_keepalive_time(360)
|
125
|
+
|
106
126
|
def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
|
107
127
|
# Upper argument is not supported until numpy 2.0
|
108
128
|
ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
|
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
138
158
|
return float
|
139
159
|
|
140
160
|
def free_cache(self):
|
141
|
-
|
161
|
+
interfaces.cache.disable()
|
142
162
|
|
143
163
|
def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
144
164
|
return self._array_backend.transpose(arr, *args, **kwargs)
|
@@ -240,70 +260,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
240
260
|
b[tuple(bind)] = arr[tuple(aind)]
|
241
261
|
return b
|
242
262
|
|
243
|
-
def
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
fftargs: Dict = {},
|
250
|
-
inv_output_shape: Tuple[int] = None,
|
251
|
-
temp_fwd: NDArray = None,
|
252
|
-
temp_inv: NDArray = None,
|
253
|
-
fwd_axes: Tuple[int] = None,
|
254
|
-
inv_axes: Tuple[int] = None,
|
255
|
-
) -> Tuple[FFTW, FFTW]:
|
256
|
-
if temp_fwd is None:
|
257
|
-
temp_fwd = (
|
258
|
-
self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
|
259
|
-
)
|
260
|
-
if temp_inv is None:
|
261
|
-
temp_inv = (
|
262
|
-
self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
|
263
|
-
)
|
264
|
-
|
265
|
-
default_values = {
|
266
|
-
"planner_effort": "FFTW_MEASURE",
|
267
|
-
"auto_align_input": False,
|
268
|
-
"auto_contiguous": False,
|
269
|
-
"avoid_copy": True,
|
270
|
-
"overwrite_input": True,
|
271
|
-
"threads": 1,
|
272
|
-
}
|
273
|
-
for key in default_values:
|
274
|
-
if key in fftargs:
|
275
|
-
continue
|
276
|
-
fftargs[key] = default_values[key]
|
277
|
-
|
278
|
-
rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
|
279
|
-
_rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
|
280
|
-
overwrite_input = fftargs.pop("overwrite_input", None)
|
281
|
-
|
282
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
283
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
284
|
-
_irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
|
285
|
-
|
286
|
-
def _rfftn_wrapper(arr, out, *args, **kwargs):
|
287
|
-
return _rfftn(arr, out)
|
288
|
-
|
289
|
-
def _irfftn_wrapper(arr, out, *args, **kwargs):
|
290
|
-
return _irfftn(arr, out)
|
291
|
-
|
292
|
-
fftargs["overwrite_input"] = overwrite_input
|
293
|
-
return _rfftn_wrapper, _irfftn_wrapper
|
263
|
+
def _rfftn(self, arr, out=None, **kwargs):
|
264
|
+
ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
|
265
|
+
if out is not None:
|
266
|
+
out[:] = ret
|
267
|
+
return out
|
268
|
+
return ret
|
294
269
|
|
295
|
-
|
296
|
-
|
297
|
-
if
|
298
|
-
|
299
|
-
|
300
|
-
return
|
270
|
+
def _irfftn(self, arr, out=None, **kwargs):
|
271
|
+
ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
|
272
|
+
if out is not None:
|
273
|
+
out[:] = ret
|
274
|
+
return out
|
275
|
+
return ret
|
301
276
|
|
302
|
-
def rfftn(
|
303
|
-
|
277
|
+
def rfftn(
|
278
|
+
self,
|
279
|
+
arr: NDArray,
|
280
|
+
out=None,
|
281
|
+
auto_align_input: bool = False,
|
282
|
+
auto_contiguous: bool = False,
|
283
|
+
overwrite_input: bool = True,
|
284
|
+
**kwargs,
|
285
|
+
) -> NDArray:
|
286
|
+
return self._rfftn(
|
287
|
+
arr,
|
288
|
+
auto_align_input=auto_align_input,
|
289
|
+
auto_contiguous=auto_contiguous,
|
290
|
+
overwrite_input=overwrite_input,
|
291
|
+
**kwargs,
|
292
|
+
)
|
304
293
|
|
305
|
-
def irfftn(
|
306
|
-
|
294
|
+
def irfftn(
|
295
|
+
self,
|
296
|
+
arr: NDArray,
|
297
|
+
out=None,
|
298
|
+
auto_align_input: bool = False,
|
299
|
+
auto_contiguous: bool = False,
|
300
|
+
overwrite_input: bool = True,
|
301
|
+
**kwargs,
|
302
|
+
) -> NDArray:
|
303
|
+
return self._irfftn(
|
304
|
+
arr,
|
305
|
+
auto_align_input=auto_align_input,
|
306
|
+
auto_contiguous=auto_contiguous,
|
307
|
+
overwrite_input=overwrite_input,
|
308
|
+
**kwargs,
|
309
|
+
)
|
307
310
|
|
308
311
|
def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
309
312
|
new_shape = self.to_backend_array(newshape)
|
tme/backends/pytorch_backend.py
CHANGED
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
7
7
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Tuple
|
10
|
+
from typing import Tuple
|
11
11
|
from contextlib import contextmanager
|
12
12
|
from multiprocessing import shared_memory
|
13
13
|
from multiprocessing.managers import SharedMemoryManager
|
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
273
273
|
kwargs["device"] = self.device
|
274
274
|
return self._array_backend.eye(*args, **kwargs)
|
275
275
|
|
276
|
-
def build_fft(
|
277
|
-
self,
|
278
|
-
fwd_shape: Tuple[int],
|
279
|
-
inv_shape: Tuple[int],
|
280
|
-
inv_output_shape: Tuple[int] = None,
|
281
|
-
fwd_axes: Tuple[int] = None,
|
282
|
-
inv_axes: Tuple[int] = None,
|
283
|
-
**kwargs,
|
284
|
-
) -> Tuple[Callable, Callable]:
|
285
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
286
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
287
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
288
|
-
|
289
|
-
def rfftn(
|
290
|
-
arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
|
291
|
-
) -> TorchTensor:
|
292
|
-
return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
|
293
|
-
|
294
|
-
def irfftn(
|
295
|
-
arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
|
296
|
-
) -> TorchTensor:
|
297
|
-
return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
|
298
|
-
|
299
|
-
return rfftn, irfftn
|
300
|
-
|
301
276
|
def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
302
277
|
kwargs["dim"] = kwargs.pop("axes", None)
|
303
278
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
tme/density.py
CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
|
|
36
36
|
array_to_memmap,
|
37
37
|
memmap_to_array,
|
38
38
|
minimum_enclosing_box,
|
39
|
+
is_gzipped,
|
39
40
|
)
|
40
41
|
|
41
42
|
__all__ = ["Density"]
|
@@ -331,6 +332,7 @@ class Density:
|
|
331
332
|
if non_standard_crs:
|
332
333
|
data = np.transpose(data, crs_index)
|
333
334
|
origin = np.take(origin, crs_index)
|
335
|
+
sampling_rate = np.take(sampling_rate, crs_index)
|
334
336
|
|
335
337
|
return data.T, origin[::-1], sampling_rate[::-1], metadata
|
336
338
|
|
@@ -2257,9 +2259,3 @@ class Density:
|
|
2257
2259
|
coordinates = np.array(np.where(data > 0))
|
2258
2260
|
weights = self.data[tuple(coordinates)]
|
2259
2261
|
return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
|
2260
|
-
|
2261
|
-
|
2262
|
-
def is_gzipped(filename: str) -> bool:
|
2263
|
-
"""Check if a file is a gzip file by reading its magic number."""
|
2264
|
-
with open(filename, "rb") as f:
|
2265
|
-
return f.read(2) == b"\x1f\x8b"
|
Binary file
|