pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
tme/backends/mlx_backend.py
CHANGED
@@ -6,7 +6,7 @@ Copyright (c) 2024 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
from typing import Tuple, List
|
9
|
+
from typing import Tuple, List
|
10
10
|
|
11
11
|
import numpy as np
|
12
12
|
|
@@ -144,32 +144,6 @@ class MLXBackend(NumpyFFTWBackend):
|
|
144
144
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
145
145
|
return arr[box]
|
146
146
|
|
147
|
-
def build_fft(
|
148
|
-
self,
|
149
|
-
fwd_shape: Tuple[int],
|
150
|
-
inv_shape: Tuple[int] = None,
|
151
|
-
inv_output_shape: Tuple[int] = None,
|
152
|
-
fwd_axes: Tuple[int] = None,
|
153
|
-
inv_axes: Tuple[int] = None,
|
154
|
-
**kwargs,
|
155
|
-
) -> Tuple[Callable, Callable]:
|
156
|
-
# Runs on mlx.core.cpu until Metal support is available
|
157
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
158
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
159
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
160
|
-
|
161
|
-
def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
|
162
|
-
out[:] = self._array_backend.fft.rfftn(
|
163
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
164
|
-
)
|
165
|
-
|
166
|
-
def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
|
167
|
-
out[:] = self._array_backend.fft.irfftn(
|
168
|
-
arr, s=s, axes=axes, stream=self._array_backend.cpu
|
169
|
-
)
|
170
|
-
|
171
|
-
return rfftn, irfftn
|
172
|
-
|
173
147
|
def rfftn(self, arr, *args, **kwargs):
|
174
148
|
return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
|
175
149
|
|
tme/backends/npfftw_backend.py
CHANGED
@@ -9,17 +9,23 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
9
9
|
import os
|
10
10
|
from psutil import virtual_memory
|
11
11
|
from contextlib import contextmanager
|
12
|
-
from typing import Tuple,
|
12
|
+
from typing import Tuple, List, Type
|
13
13
|
|
14
14
|
import scipy
|
15
15
|
import numpy as np
|
16
16
|
from scipy.ndimage import maximum_filter, affine_transform
|
17
|
-
from pyfftw
|
18
|
-
|
17
|
+
from pyfftw import (
|
18
|
+
zeros_aligned,
|
19
|
+
simd_alignment,
|
20
|
+
next_fast_len,
|
21
|
+
interfaces,
|
22
|
+
config,
|
23
|
+
)
|
19
24
|
|
20
25
|
from ..types import NDArray, BackendArray, shm_type
|
21
26
|
from .matching_backend import MatchingBackend, _create_metafunction
|
22
27
|
|
28
|
+
|
23
29
|
os.environ["MKL_NUM_THREADS"] = "1"
|
24
30
|
os.environ["OMP_NUM_THREADS"] = "1"
|
25
31
|
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
@@ -103,6 +109,20 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
103
109
|
self.solve_triangular = self._solve_triangular
|
104
110
|
self.linalg.solve_triangular = scipy.linalg.solve_triangular
|
105
111
|
|
112
|
+
try:
|
113
|
+
from ._numpyfftw_utils import rfftn as rfftn_cache
|
114
|
+
from ._numpyfftw_utils import irfftn as irfftn_cache
|
115
|
+
|
116
|
+
self._rfftn = rfftn_cache
|
117
|
+
self._irfftn = irfftn_cache
|
118
|
+
except Exception as e:
|
119
|
+
print(e)
|
120
|
+
|
121
|
+
config.NUM_THREADS = 1
|
122
|
+
config.PLANNER_EFFORT = "FFTW_MEASURE"
|
123
|
+
interfaces.cache.enable()
|
124
|
+
interfaces.cache.set_keepalive_time(360)
|
125
|
+
|
106
126
|
def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
|
107
127
|
# Upper argument is not supported until numpy 2.0
|
108
128
|
ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
|
@@ -138,7 +158,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
138
158
|
return float
|
139
159
|
|
140
160
|
def free_cache(self):
|
141
|
-
|
161
|
+
interfaces.cache.disable()
|
142
162
|
|
143
163
|
def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
144
164
|
return self._array_backend.transpose(arr, *args, **kwargs)
|
@@ -181,6 +201,9 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
181
201
|
sorted_indices = self.unravel_index(indices=sorted_indices, shape=arr.shape)
|
182
202
|
return sorted_indices
|
183
203
|
|
204
|
+
def ssum(self, arr, *args, **kwargs):
|
205
|
+
return self.sum(self.square(arr), *args, **kwargs)
|
206
|
+
|
184
207
|
def indices(self, *args, **kwargs) -> NDArray:
|
185
208
|
return self._array_backend.indices(*args, **kwargs)
|
186
209
|
|
@@ -240,70 +263,53 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
240
263
|
b[tuple(bind)] = arr[tuple(aind)]
|
241
264
|
return b
|
242
265
|
|
243
|
-
def
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
fftargs: Dict = {},
|
250
|
-
inv_output_shape: Tuple[int] = None,
|
251
|
-
temp_fwd: NDArray = None,
|
252
|
-
temp_inv: NDArray = None,
|
253
|
-
fwd_axes: Tuple[int] = None,
|
254
|
-
inv_axes: Tuple[int] = None,
|
255
|
-
) -> Tuple[FFTW, FFTW]:
|
256
|
-
if temp_fwd is None:
|
257
|
-
temp_fwd = (
|
258
|
-
self.zeros(fwd_shape, real_dtype) if temp_fwd is None else temp_fwd
|
259
|
-
)
|
260
|
-
if temp_inv is None:
|
261
|
-
temp_inv = (
|
262
|
-
self.zeros(inv_shape, cmpl_dtype) if temp_inv is None else temp_inv
|
263
|
-
)
|
264
|
-
|
265
|
-
default_values = {
|
266
|
-
"planner_effort": "FFTW_MEASURE",
|
267
|
-
"auto_align_input": False,
|
268
|
-
"auto_contiguous": False,
|
269
|
-
"avoid_copy": True,
|
270
|
-
"overwrite_input": True,
|
271
|
-
"threads": 1,
|
272
|
-
}
|
273
|
-
for key in default_values:
|
274
|
-
if key in fftargs:
|
275
|
-
continue
|
276
|
-
fftargs[key] = default_values[key]
|
277
|
-
|
278
|
-
rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
|
279
|
-
_rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)
|
280
|
-
overwrite_input = fftargs.pop("overwrite_input", None)
|
281
|
-
|
282
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
283
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
284
|
-
_irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)
|
285
|
-
|
286
|
-
def _rfftn_wrapper(arr, out, *args, **kwargs):
|
287
|
-
return _rfftn(arr, out)
|
288
|
-
|
289
|
-
def _irfftn_wrapper(arr, out, *args, **kwargs):
|
290
|
-
return _irfftn(arr, out)
|
291
|
-
|
292
|
-
fftargs["overwrite_input"] = overwrite_input
|
293
|
-
return _rfftn_wrapper, _irfftn_wrapper
|
266
|
+
def _rfftn(self, arr, out=None, **kwargs):
|
267
|
+
ret = interfaces.numpy_fft.rfftn(arr, **kwargs)
|
268
|
+
if out is not None:
|
269
|
+
out[:] = ret
|
270
|
+
return out
|
271
|
+
return ret
|
294
272
|
|
295
|
-
|
296
|
-
|
297
|
-
if
|
298
|
-
|
299
|
-
|
300
|
-
return
|
273
|
+
def _irfftn(self, arr, out=None, **kwargs):
|
274
|
+
ret = interfaces.numpy_fft.irfftn(arr, **kwargs)
|
275
|
+
if out is not None:
|
276
|
+
out[:] = ret
|
277
|
+
return out
|
278
|
+
return ret
|
301
279
|
|
302
|
-
def rfftn(
|
303
|
-
|
280
|
+
def rfftn(
|
281
|
+
self,
|
282
|
+
arr: NDArray,
|
283
|
+
out=None,
|
284
|
+
auto_align_input: bool = False,
|
285
|
+
auto_contiguous: bool = False,
|
286
|
+
overwrite_input: bool = True,
|
287
|
+
**kwargs,
|
288
|
+
) -> NDArray:
|
289
|
+
return self._rfftn(
|
290
|
+
arr,
|
291
|
+
auto_align_input=auto_align_input,
|
292
|
+
auto_contiguous=auto_contiguous,
|
293
|
+
overwrite_input=overwrite_input,
|
294
|
+
**kwargs,
|
295
|
+
)
|
304
296
|
|
305
|
-
def irfftn(
|
306
|
-
|
297
|
+
def irfftn(
|
298
|
+
self,
|
299
|
+
arr: NDArray,
|
300
|
+
out=None,
|
301
|
+
auto_align_input: bool = False,
|
302
|
+
auto_contiguous: bool = False,
|
303
|
+
overwrite_input: bool = True,
|
304
|
+
**kwargs,
|
305
|
+
) -> NDArray:
|
306
|
+
return self._irfftn(
|
307
|
+
arr,
|
308
|
+
auto_align_input=auto_align_input,
|
309
|
+
auto_contiguous=auto_contiguous,
|
310
|
+
overwrite_input=overwrite_input,
|
311
|
+
**kwargs,
|
312
|
+
)
|
307
313
|
|
308
314
|
def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
309
315
|
new_shape = self.to_backend_array(newshape)
|
tme/backends/pytorch_backend.py
CHANGED
@@ -7,7 +7,7 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
7
7
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Tuple
|
10
|
+
from typing import Tuple
|
11
11
|
from contextlib import contextmanager
|
12
12
|
from multiprocessing import shared_memory
|
13
13
|
from multiprocessing.managers import SharedMemoryManager
|
@@ -273,31 +273,6 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
273
273
|
kwargs["device"] = self.device
|
274
274
|
return self._array_backend.eye(*args, **kwargs)
|
275
275
|
|
276
|
-
def build_fft(
|
277
|
-
self,
|
278
|
-
fwd_shape: Tuple[int],
|
279
|
-
inv_shape: Tuple[int],
|
280
|
-
inv_output_shape: Tuple[int] = None,
|
281
|
-
fwd_axes: Tuple[int] = None,
|
282
|
-
inv_axes: Tuple[int] = None,
|
283
|
-
**kwargs,
|
284
|
-
) -> Tuple[Callable, Callable]:
|
285
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
286
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
287
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
288
|
-
|
289
|
-
def rfftn(
|
290
|
-
arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
|
291
|
-
) -> TorchTensor:
|
292
|
-
return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)
|
293
|
-
|
294
|
-
def irfftn(
|
295
|
-
arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
|
296
|
-
) -> TorchTensor:
|
297
|
-
return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)
|
298
|
-
|
299
|
-
return rfftn, irfftn
|
300
|
-
|
301
276
|
def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
|
302
277
|
kwargs["dim"] = kwargs.pop("axes", None)
|
303
278
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
tme/density.py
CHANGED
@@ -36,6 +36,7 @@ from .matching_utils import (
|
|
36
36
|
array_to_memmap,
|
37
37
|
memmap_to_array,
|
38
38
|
minimum_enclosing_box,
|
39
|
+
is_gzipped,
|
39
40
|
)
|
40
41
|
|
41
42
|
__all__ = ["Density"]
|
@@ -331,6 +332,7 @@ class Density:
|
|
331
332
|
if non_standard_crs:
|
332
333
|
data = np.transpose(data, crs_index)
|
333
334
|
origin = np.take(origin, crs_index)
|
335
|
+
sampling_rate = np.take(sampling_rate, crs_index)
|
334
336
|
|
335
337
|
return data.T, origin[::-1], sampling_rate[::-1], metadata
|
336
338
|
|
@@ -2194,7 +2196,7 @@ class Density:
|
|
2194
2196
|
|
2195
2197
|
Parameters
|
2196
2198
|
----------
|
2197
|
-
target : Density
|
2199
|
+
target : :py:class:`Density`
|
2198
2200
|
The target map for template matching.
|
2199
2201
|
template : Structure
|
2200
2202
|
The template that should be aligned to the target.
|
@@ -2258,8 +2260,59 @@ class Density:
|
|
2258
2260
|
weights = self.data[tuple(coordinates)]
|
2259
2261
|
return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
|
2260
2262
|
|
2263
|
+
@staticmethod
|
2264
|
+
def fourier_shell_correlation(density1: "Density", density2: "Density") -> NDArray:
|
2265
|
+
"""
|
2266
|
+
Computes the Fourier Shell Correlation (FSC) between two instances of `Density`.
|
2267
|
+
|
2268
|
+
The Fourier transforms of the input maps are divided into shells
|
2269
|
+
based on their spatial frequency. The correlation between corresponding shells
|
2270
|
+
in the two maps is computed to give the FSC.
|
2271
|
+
|
2272
|
+
Parameters
|
2273
|
+
----------
|
2274
|
+
density1 : :py:class:`Density`
|
2275
|
+
Reference for comparison.
|
2276
|
+
density2 : :py:class:`Density`
|
2277
|
+
Target for comparison.
|
2278
|
+
|
2279
|
+
Returns
|
2280
|
+
-------
|
2281
|
+
NDArray
|
2282
|
+
An array of shape (N, 2), where N is the number of shells.
|
2283
|
+
The first column represents the spatial frequency for each shell
|
2284
|
+
and the second column represents the corresponding FSC.
|
2285
|
+
|
2286
|
+
References
|
2287
|
+
----------
|
2288
|
+
.. [1] https://github.com/tdgrant1/denss/blob/master/saxstats/saxstats.py
|
2289
|
+
"""
|
2290
|
+
side = density1.data.shape[0]
|
2291
|
+
df = 1.0 / side
|
2292
|
+
|
2293
|
+
qx_ = np.fft.fftfreq(side) * side * df
|
2294
|
+
qx, qy, qz = np.meshgrid(qx_, qx_, qx_, indexing="ij")
|
2295
|
+
qr = np.sqrt(qx**2 + qy**2 + qz**2)
|
2296
|
+
|
2297
|
+
qmax = np.max(qr)
|
2298
|
+
qstep = np.min(qr[qr > 0])
|
2299
|
+
nbins = int(qmax / qstep)
|
2300
|
+
qbins = np.linspace(0, nbins * qstep, nbins + 1)
|
2301
|
+
qbin_labels = np.searchsorted(qbins, qr, "right") - 1
|
2302
|
+
|
2303
|
+
F1 = np.fft.fftn(density1.data)
|
2304
|
+
F2 = np.fft.fftn(density2.data)
|
2305
|
+
|
2306
|
+
qbin_labels = qbin_labels.reshape(-1)
|
2307
|
+
numerator = np.bincount(
|
2308
|
+
qbin_labels, weights=np.real(F1 * np.conj(F2)).reshape(-1)
|
2309
|
+
)
|
2310
|
+
term1 = np.bincount(qbin_labels, weights=np.abs(F1).reshape(-1) ** 2)
|
2311
|
+
term2 = np.bincount(qbin_labels, weights=np.abs(F2).reshape(-1) ** 2)
|
2312
|
+
np.multiply(term1, term2, out=term1)
|
2313
|
+
denominator = np.sqrt(term1)
|
2314
|
+
FSC = np.divide(numerator, denominator)
|
2315
|
+
|
2316
|
+
qidx = np.where(qbins < qx.max())
|
2261
2317
|
|
2262
|
-
|
2263
|
-
"""Check if a file is a gzip file by reading its magic number."""
|
2264
|
-
with open(filename, "rb") as f:
|
2265
|
-
return f.read(2) == b"\x1f\x8b"
|
2318
|
+
return np.vstack((qbins[qidx], FSC[qidx])).T
|
Binary file
|
tme/filters/ctf.py
CHANGED
@@ -36,7 +36,7 @@ class CTF(ComposableFilter):
|
|
36
36
|
|
37
37
|
#: The shape of the to-be created mask.
|
38
38
|
shape: Tuple[int] = None
|
39
|
-
#: The defocus
|
39
|
+
#: The defocus in x direction (in units of sampling rate).
|
40
40
|
defocus_x: Tuple[float] = None
|
41
41
|
#: The tilt angles.
|
42
42
|
angles: Tuple[float] = None
|
@@ -164,7 +164,7 @@ class CTF(ComposableFilter):
|
|
164
164
|
shape : tuple of int
|
165
165
|
The shape of the CTF.
|
166
166
|
defocus_x : tuple of float
|
167
|
-
The defocus
|
167
|
+
The defocus in x direction (in units of sampling rate).
|
168
168
|
angles : tuple of float
|
169
169
|
The tilt angles.
|
170
170
|
opening_axis : int, optional
|
@@ -178,7 +178,7 @@ class CTF(ComposableFilter):
|
|
178
178
|
defocus_angle : tuple of float, optional
|
179
179
|
The defocus angle in radians, defaults to 0.
|
180
180
|
defocus_y : tuple of float, optional
|
181
|
-
The defocus
|
181
|
+
The defocus in x direction (in units of sampling rate).
|
182
182
|
correct_defocus_gradient : bool, optional
|
183
183
|
Whether to correct defocus gradient, defaults to False.
|
184
184
|
sampling_rate : tuple of float, optional
|
@@ -219,14 +219,12 @@ class CTF(ComposableFilter):
|
|
219
219
|
corrected_tilt_axis -= 1
|
220
220
|
|
221
221
|
for index, angle in enumerate(angles):
|
222
|
-
defocus_x, defocus_y = defoci_x[index], defoci_y[index]
|
223
|
-
|
224
222
|
correction = correct_defocus_gradient and angle is not None
|
225
223
|
chi = create_ctf(
|
226
224
|
angle=angle,
|
227
225
|
shape=ctf_shape,
|
228
|
-
defocus_x=
|
229
|
-
defocus_y=
|
226
|
+
defocus_x=defoci_x[index],
|
227
|
+
defocus_y=defoci_y[index],
|
230
228
|
sampling_rate=sampling_rate,
|
231
229
|
acceleration_voltage=acceleration_voltage[index],
|
232
230
|
correct_defocus_gradient=correction,
|
@@ -243,12 +241,10 @@ class CTF(ComposableFilter):
|
|
243
241
|
stack[index] = chi
|
244
242
|
|
245
243
|
# Avoid contrast inversion
|
246
|
-
np.negative(stack, out=stack)
|
244
|
+
stack = np.negative(stack, out=stack)
|
247
245
|
if flip_phase:
|
248
|
-
np.abs(stack, out=stack)
|
249
|
-
|
250
|
-
stack = be.to_backend_array(np.squeeze(stack))
|
251
|
-
return stack
|
246
|
+
stack = np.abs(stack, out=stack)
|
247
|
+
return be.to_backend_array(np.squeeze(stack))
|
252
248
|
|
253
249
|
|
254
250
|
class CTFReconstructed(CTF):
|
@@ -281,7 +277,7 @@ class CTFReconstructed(CTF):
|
|
281
277
|
shape : tuple of int
|
282
278
|
The shape of the CTF.
|
283
279
|
defocus_x : tuple of float
|
284
|
-
The defocus
|
280
|
+
The defocus in x direction in units of sampling rate.
|
285
281
|
opening_axis : int, optional
|
286
282
|
The axis around which the wedge is opened, defaults to 2.
|
287
283
|
amplitude_contrast : float, optional
|
@@ -291,7 +287,7 @@ class CTFReconstructed(CTF):
|
|
291
287
|
defocus_angle : tuple of float, optional
|
292
288
|
The defocus angle in radians, defaults to 0.
|
293
289
|
defocus_y : tuple of float, optional
|
294
|
-
The defocus
|
290
|
+
The defocus in y direction in units of sampling rate.
|
295
291
|
sampling_rate : tuple of float, optional
|
296
292
|
The sampling rate, defaults to 1.
|
297
293
|
acceleration_voltage : float, optional
|
@@ -321,18 +317,15 @@ class CTFReconstructed(CTF):
|
|
321
317
|
defocus_angle=defocus_angle,
|
322
318
|
amplitude_contrast=amplitude_contrast,
|
323
319
|
)
|
324
|
-
stack = shift_fourier(data=stack, shape_is_real_fourier=False)
|
325
|
-
|
326
320
|
# Avoid contrast inversion
|
327
321
|
np.negative(stack, out=stack)
|
328
322
|
if flip_phase:
|
329
323
|
np.abs(stack, out=stack)
|
330
324
|
|
331
|
-
stack =
|
325
|
+
stack = shift_fourier(data=stack, shape_is_real_fourier=False)
|
332
326
|
if return_real_fourier:
|
333
327
|
stack = crop_real_fourier(stack)
|
334
|
-
|
335
|
-
return stack
|
328
|
+
return be.to_backend_array(np.squeeze(stack))
|
336
329
|
|
337
330
|
|
338
331
|
def _from_xml(filename: str) -> Dict:
|
@@ -436,6 +429,9 @@ def _from_ctffind(filename: str) -> Dict:
|
|
436
429
|
output[key] = np.array(output[key])
|
437
430
|
|
438
431
|
output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
|
432
|
+
cs = output.get("spherical_aberration")
|
433
|
+
if cs is not None:
|
434
|
+
output["spherical_aberration"] = float(cs) * 1e7
|
439
435
|
return output
|
440
436
|
|
441
437
|
|
@@ -566,7 +562,7 @@ def create_ctf(
|
|
566
562
|
amplitude_contrast : float, optional
|
567
563
|
Amplitude contrast of microscope, defaults to 0.07.
|
568
564
|
spherical_aberration : float, optional
|
569
|
-
Spherical aberration of microscope in
|
565
|
+
Spherical aberration of microscope in units of sampling rate.
|
570
566
|
angle : float, optional
|
571
567
|
Assume the created CTF is a projection over opening_axis observed at angle.
|
572
568
|
opening_axis : int, optional
|
@@ -590,10 +586,14 @@ def create_ctf(
|
|
590
586
|
electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
|
591
587
|
electron_wavelength /= sampling_rate
|
592
588
|
aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
|
589
|
+
|
590
|
+
defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
|
591
|
+
defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
|
593
592
|
if correct_defocus_gradient or defocus_y is not None:
|
594
593
|
if len(shape) < 2:
|
595
594
|
raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
|
596
595
|
|
596
|
+
# Axial distance from grid center in multiples of sampling rate
|
597
597
|
sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
|
598
598
|
grid = fftfreqn(
|
599
599
|
shape=shape,
|
@@ -619,6 +619,7 @@ def create_ctf(
|
|
619
619
|
defocus_sum = np.add(defocus_x, defocus_y)
|
620
620
|
defocus_difference = np.subtract(defocus_x, defocus_y)
|
621
621
|
|
622
|
+
# Reusing grid, but in principle pure frequencies would suffice
|
622
623
|
angular_grid = np.arctan2(grid[1], grid[0])
|
623
624
|
defocus_difference = np.multiply(
|
624
625
|
defocus_difference,
|
@@ -627,7 +628,7 @@ def create_ctf(
|
|
627
628
|
defocus_x = np.add(defocus_sum, defocus_difference)
|
628
629
|
defocus_x *= 0.5
|
629
630
|
|
630
|
-
frequency_grid = fftfreqn(shape, sampling_rate=
|
631
|
+
frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
|
631
632
|
if angle is not None and opening_axis is not None and full_shape is not None:
|
632
633
|
frequency_grid = frequency_grid_at_angle(
|
633
634
|
shape=full_shape,
|
tme/filters/wedge.py
CHANGED
@@ -15,7 +15,7 @@ import numpy as np
|
|
15
15
|
from ..types import NDArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from .compose import ComposableFilter
|
18
|
-
from ..matching_utils import
|
18
|
+
from ..matching_utils import center_slice
|
19
19
|
from ..rotations import euler_to_rotationmatrix
|
20
20
|
from ..parser import XMLParser, StarParser, MDOCParser
|
21
21
|
from ._utils import (
|
@@ -207,11 +207,10 @@ class Wedge(ComposableFilter):
|
|
207
207
|
)
|
208
208
|
sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
|
209
209
|
sigma = -2 * np.pi**2 * sigma**2
|
210
|
-
np.square(frequency_grid, out=frequency_grid)
|
211
|
-
np.multiply(sigma, frequency_grid, out=frequency_grid)
|
212
|
-
np.exp(frequency_grid, out=frequency_grid)
|
213
|
-
np.multiply(frequency_grid, np.cos(np.radians(angle))
|
214
|
-
wedges[index] = frequency_grid
|
210
|
+
frequency_grid = np.square(frequency_grid, out=frequency_grid)
|
211
|
+
frequency_grid = np.multiply(sigma, frequency_grid, out=frequency_grid)
|
212
|
+
frequency_grid = np.exp(frequency_grid, out=frequency_grid)
|
213
|
+
wedges[index] = np.multiply(frequency_grid, np.cos(np.radians(angle)))
|
215
214
|
|
216
215
|
return wedges
|
217
216
|
|
@@ -490,7 +489,11 @@ class WedgeReconstructed:
|
|
490
489
|
)
|
491
490
|
wedge_volume += plane_rotated * weights[index]
|
492
491
|
|
493
|
-
|
492
|
+
subset = center_slice(
|
493
|
+
wedge_volume.shape, (shape[opening_axis], shape[tilt_axis])
|
494
|
+
)
|
495
|
+
wedge_volume = wedge_volume[subset]
|
496
|
+
|
494
497
|
np.fmin(wedge_volume, np.max(weights), wedge_volume)
|
495
498
|
|
496
499
|
if opening_axis > tilt_axis:
|