pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tme/filters/bandpass.py
ADDED
@@ -0,0 +1,230 @@
|
|
1
|
+
""" Implements class BandPassFilter to create Fourier filter representations.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple
|
9
|
+
from math import log, sqrt
|
10
|
+
|
11
|
+
from ..types import BackendArray
|
12
|
+
from ..backends import backend as be
|
13
|
+
from .compose import ComposableFilter
|
14
|
+
from ._utils import fftfreqn, crop_real_fourier, shift_fourier
|
15
|
+
|
16
|
+
# Public names exported by this module.
__all__ = ["BandPassFilter"]
|
17
|
+
|
18
|
+
|
19
|
+
class BandPassFilter(ComposableFilter):
    """
    Generate bandpass filters in Fourier space.

    Parameters
    ----------
    lowpass : float, optional
        The lowpass cutoff, defaults to None.
    highpass : float, optional
        The highpass cutoff, defaults to None.
    sampling_rate : Tuple[float], optional
        The sampling rate in Fourier space, defaults to 1.
    use_gaussian : bool, optional
        Whether to use Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to return only the real Fourier space, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape represents the real Fourier space, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate.
        highpass : float
            The highpass cutoff in units of sampling rate.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space, 1.0 inside the pass band
            and 0.0 outside.
        """
        # A shape that already is a real Fourier shape must not be cropped again.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        sampling_rate = be.to_backend_array(sampling_rate)

        # Without a lowpass every frequency passes the upper bound.
        highcut = grid.max()
        if lowpass is not None:
            highcut = be.max(2 * sampling_rate / lowpass)

        # Without a highpass every frequency passes the lower bound.
        lowcut = 0
        if highpass is not None:
            lowcut = be.max(2 * sampling_rate / highpass)

        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using Gaussians.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate.
        highpass : float
            The highpass cutoff in units of sampling rate.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space. Each Gaussian equals 0.5 at
            its cutoff frequency. Without any cutoff an all-ones array of
            ``shape`` is returned.
        """
        # A shape that already is a real Fourier shape must not be cropped again.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        # grid now holds -|f|^2, the exponent numerator of exp(-f^2 / (2 sigma^2)).
        grid = -be.square(grid, out=grid)

        has_lowpass, has_highpass = False, False
        # Dividing by sqrt(2 ln 2) makes exp(-f^2 / (2 sigma^2)) equal 0.5 at
        # f = upper_sampling / cutoff, i.e. the cutoff marks the half maximum.
        norm = float(sqrt(2 * log(2)))
        upper_sampling = float(
            be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
        )

        # Clamp cutoffs away from zero to avoid division by zero below.
        if lowpass is not None:
            lowpass, has_lowpass = float(lowpass), True
            lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
        if highpass is not None:
            highpass, has_highpass = float(highpass), True
            highpass = be.maximum(highpass, be.eps(be._float_dtype))

        if has_lowpass:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = be.multiply(2, be.square(lowpass))
            # grid may only be overwritten if the highpass does not need it.
            if not has_highpass:
                lowpass_filter = be.divide(grid, lowpass, out=grid)
            else:
                lowpass_filter = be.divide(grid, lowpass)
            lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)

        if has_highpass:
            highpass = upper_sampling / (highpass * norm)
            highpass = be.multiply(2, be.square(highpass))
            highpass_filter = be.divide(grid, highpass, out=grid)
            highpass_filter = be.exp(highpass_filter, out=highpass_filter)
            highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)

        if has_lowpass and not has_highpass:
            bandpass_filter = lowpass_filter
        elif not has_lowpass and has_highpass:
            bandpass_filter = highpass_filter
        elif has_lowpass and has_highpass:
            bandpass_filter = be.multiply(
                lowpass_filter, highpass_filter, out=lowpass_filter
            )
        else:
            bandpass_filter = be.full(shape, fill_value=1, dtype=be._float_dtype)

        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    def __call__(self, **kwargs):
        """
        Compute the bandpass mask.

        Parameters
        ----------
        **kwargs : dict
            Per-call overrides of the instance parameters plus arguments the
            filter functions need, e.g. ``shape``.

        Returns
        -------
        dict
            ``data`` holds the mask, ``sampling_rate`` the sampling rate used
            and ``is_multiplicative_filter`` is always True.
        """
        # Copy the instance attributes: updating vars(self) directly would
        # write every call's kwargs (shape, data, ...) into the instance.
        func_args = vars(self).copy()
        func_args.update(kwargs)

        func = self.discrete_bandpass
        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass

        mask = func(**func_args)

        return {
            "data": be.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }
|
tme/filters/compose.py
ADDED
@@ -0,0 +1,81 @@
|
|
1
|
+
""" Combine filters using an interface analogous to pytorch's Compose.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple, Dict
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
|
11
|
+
from tme.backends import backend as be
|
12
|
+
|
13
|
+
# Public names exported by this module.
__all__ = ["Compose", "ComposableFilter"]
|
14
|
+
|
15
|
+
|
16
|
+
class Compose:
    """
    Compose a series of transformations.

    This class allows composing multiple transformations together. Each transformation
    is expected to be a callable that accepts keyword arguments and returns metadata.

    Parameters
    ----------
    transforms : Tuple[object]
        A tuple containing transformation objects.

    Returns
    -------
    Dict
        Metadata resulting from the composed transformations.

    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Apply all transforms in order and return the resulting metadata.

        The kept output of one transform is merged into the keyword arguments
        of the next one. The output of a transform after the first is only
        kept if it contains a ``"data"`` key. If a kept output is flagged with
        ``is_multiplicative_filter``, its ``"data"`` is multiplied with the
        ``"data"`` of the previously kept output, when there is one.

        Parameters
        ----------
        **kwargs : Dict
            Keyword arguments passed to the first transform.

        Returns
        -------
        Dict
            The metadata of the last kept transform output, or an empty dict
            if there are no transforms.
        """
        if not self.transforms:
            return {}

        meta = self.transforms[0](**kwargs)
        for transform in self.transforms[1:]:
            kwargs.update(meta)
            ret = transform(**kwargs)

            # Outputs without data do not replace the accumulated result.
            if "data" not in ret:
                continue

            if ret.get("is_multiplicative_filter", False):
                # The previous output may lack data, e.g. when the first
                # transform only contributed metadata.
                prev_data = meta.pop("data", None)
                if prev_data is not None:
                    ret["data"] = be.multiply(ret["data"], prev_data)
                ret["merge"], prev_data = None, None

            meta = ret

        return meta
|
59
|
+
|
60
|
+
|
61
|
+
class ComposableFilter(ABC):
    """
    Abstract interface for filters that can be chained.

    Subclasses are callables that return their result as a dictionary.
    Filters in this package put the computed array under ``"data"``, which is
    the key :py:class:`Compose` looks for when chaining.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Evaluate the filter.

        Parameters
        ----------
        *args : tuple
            Positional arguments for the concrete filter.
        **kwargs : dict
            Keyword arguments for the concrete filter.

        Returns
        -------
        Dict
            The filtering result and its accompanying metadata.
        """
|
tme/filters/ctf.py
ADDED
@@ -0,0 +1,393 @@
|
|
1
|
+
""" Implements class CTF to create Fourier filter representations.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import re
|
9
|
+
import warnings
|
10
|
+
from typing import Tuple, Dict
|
11
|
+
from dataclasses import dataclass
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
|
15
|
+
from ..types import NDArray
|
16
|
+
from ..parser import StarParser
|
17
|
+
from ..backends import backend as be
|
18
|
+
from .compose import ComposableFilter
|
19
|
+
from ._utils import (
|
20
|
+
frequency_grid_at_angle,
|
21
|
+
compute_tilt_shape,
|
22
|
+
crop_real_fourier,
|
23
|
+
fftfreqn,
|
24
|
+
shift_fourier,
|
25
|
+
)
|
26
|
+
|
27
|
+
# Public names exported by this module.
__all__ = ["CTF"]
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class CTF(ComposableFilter):
    """
    Generate a contrast transfer function mask.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
           Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int] = None
    #: The defocus value in x direction.
    defocus_x: float = None
    #: The tilt angles.
    angles: Tuple[float] = None
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1 Angstrom / Voxel.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient in Angstrom, defaults to 2.7e7.
    spherical_aberration: float = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle in degrees (converted to radians on init), defaults to 0.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None
    #: Whether the returned CTF should be phase-flipped.
    flip_phase: bool = True
    #: Whether to return a format compliant with rfft. Only relevant for single angles.
    return_real_fourier: bool = False
    #: Whether the output should not be used for n+1 dimensional reconstruction
    no_reconstruction: bool = True

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters
        ----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - STAR files with RELION-style ``_rln`` labels (names ending in "star")
            - CTFFIND4 output (any other name)

        Returns
        -------
        CTF
            Instance holding the defocus, astigmatism and microscope
            parameters from ``filename``. ``shape`` and ``angles`` are unset.
        """
        if filename.lower().endswith("star"):
            data = cls._from_gctf(filename=filename)
        else:
            data = cls._from_ctffind(filename=filename)

        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            # Files store the voltage in kV, this class expects Volts.
            acceleration_voltage=np.multiply(data["acceleration_voltage"], 1e3),
            # Files store Cs in mm, this class expects Angstrom.
            spherical_aberration=np.multiply(data["spherical_aberration"], 1e7),
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # The azimuth already is in degrees, which __post_init__ expects.
            defocus_angle=data["azimuth_astigmatism"],
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str):
        """
        Parse a CTFFIND4 output file.

        Header lines (starting with ``#``) are scanned for pixel size,
        voltage (keV), spherical aberration (mm) and amplitude contrast, which
        are stored as floats. Every other line is a whitespace-separated row
        whose first seven columns are collected into numpy arrays.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
            lines = [x for x in lines if len(x)]

        def _screen_params(line, params, output):
            for parameter, regex_pattern in params.items():
                match = re.search(regex_pattern, line)
                if match:
                    output[parameter] = float(match.group(1))

        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }
        output = {k: [] for k in columns.keys()}
        for line in lines:
            if line.startswith("#"):
                _screen_params(line, params=parameter_regex, output=output)
                continue

            values = line.split()
            for key, value in columns.items():
                output[key].append(float(values[value]))

        for key in columns:
            output[key] = np.array(output[key])

        return output

    @staticmethod
    def _from_gctf(filename: str):
        """
        Parse a STAR file with RELION-style CTF labels.

        Missing labels trigger a warning and default to ``[0]``; the phase
        shift is always ``[0]``. Single-valued entries are repeated to the
        length of the longest entry.
        """
        parser = StarParser(filename)
        ctf_data = parser["data_"]

        mapping = {
            "defocus_1": ("_rlnDefocusU", float),
            "defocus_2": ("_rlnDefocusV", float),
            "pixel_size": ("_rlnDetectorPixelSize", float),
            "acceleration_voltage": ("_rlnVoltage", float),
            "spherical_aberration": ("_rlnSphericalAberration", float),
            "amplitude_contrast": ("_rlnAmplitudeContrast", float),
            "additional_phase_shift": (None, float),
            "azimuth_astigmatism": ("_rlnDefocusAngle", float),
        }
        output = {}
        for out_key, (key, key_dtype) in mapping.items():
            if key not in ctf_data and key is not None:
                warnings.warn(f"ctf_data is missing key {key}.")

            key_value = ctf_data.get(key, [0])
            output[out_key] = [key_dtype(x) for x in key_value]

        longest_key = max(map(len, output.values()))
        output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
        return output

    def __post_init__(self):
        # defocus_angle is given in degrees; all computations use radians.
        self.defocus_angle = np.radians(self.defocus_angle)

    def _compute_electron_wavelength(self, acceleration_voltage: int = None):
        """
        Compute the relativistic electron wavelength in Angstrom.

        Parameters
        ----------
        acceleration_voltage : float, optional
            Acceleration voltage in Volts, defaults to the instance value.
        """

        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        # lambda = h c / sqrt(E^2 + 2 E m c^2)
        energy = electron_charge * acceleration_voltage
        denominator = energy**2
        denominator += 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        # Convert to Ångstrom
        electron_wavelength *= 1e10
        return electron_wavelength

    def __call__(self, **kwargs) -> Dict:
        """
        Compute the CTF and return it in the composable-filter format.

        Parameters
        ----------
        **kwargs : dict
            Per-call overrides of the instance attributes plus any arguments
            :py:meth:`weight` requires.

        Returns
        -------
        dict
            ``data`` holds the CTF, alongside the ``angles``, ``tilt_axis`` and
            ``opening_axis`` that were used. ``is_multiplicative_filter`` is True.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        # np.size (unlike len) also accepts scalar defocus values.
        if np.size(func_args["angles"]) != np.size(func_args["defocus_x"]):
            func_args["angles"] = self.angles
            func_args["return_real_fourier"] = False
            func_args["tilt_axis"] = None
            func_args["opening_axis"] = None

        ret = self.weight(**func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def _pad_to_length(arr, length: int):
        """
        Return ``arr`` as a 1D array repeated ``length // arr.size`` times.

        Intended for scalars or arrays that already have ``length`` elements;
        other sizes yield arrays shorter than ``length``.
        """
        ret = np.atleast_1d(arr)
        return np.repeat(ret, length // ret.size)

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        no_reconstruction: bool = True,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction.
        angles : tuple of float
            The tilt angles.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle in radians, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts used for the electron
            wavelength, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
            NOTE(review): the dataclass default is 2.7e7; :py:meth:`__call__`
            always passes the instance value, so this default only applies
            to direct calls.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        return_real_fourier : bool, optional
            Whether to crop to the rfft layout (only with no_reconstruction).
        no_reconstruction : bool, optional
            Whether to Fourier-shift each slice, defaults to True.
        cutoff_frequency : float, optional
            Frequencies at or above this value are set to zero, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments.

        Returns
        -------
        NDArray
            A stack containing the CTF weight.
        """
        angles = np.atleast_1d(angles)
        defoci_x = self._pad_to_length(defocus_x, angles.size)
        defoci_y = self._pad_to_length(defocus_y, angles.size)
        phase_shift = self._pad_to_length(phase_shift, angles.size)
        defocus_angle = self._pad_to_length(defocus_angle, angles.size)
        spherical_aberration = self._pad_to_length(spherical_aberration, angles.size)
        amplitude_contrast = self._pad_to_length(amplitude_contrast, angles.size)

        sampling_rate = np.max(sampling_rate)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((len(angles), *tilt_shape))

        # Gradient correction requires a 3D shape with known tilt/opening axes.
        correct_defocus_gradient &= len(shape) == 3
        correct_defocus_gradient &= tilt_axis is not None
        correct_defocus_gradient &= opening_axis is not None

        # Convert lengths to voxel units. Not done in place, so integer-valued
        # inputs do not fail numpy's casting rules.
        spherical_aberration = spherical_aberration / sampling_rate
        electron_wavelength = self._compute_electron_wavelength(acceleration_voltage)
        electron_wavelength = electron_wavelength / sampling_rate
        electron_aberration = spherical_aberration * electron_wavelength**2

        for index, angle in enumerate(angles):
            defocus_x, defocus_y = defoci_x[index], defoci_y[index]

            defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
            defocus_y = defocus_y / sampling_rate if defocus_y is not None else None

            if correct_defocus_gradient or defocus_y is not None:
                grid_shape = shape
                sampling = be.divide(sampling_rate, be.to_backend_array(shape))
                sampling = tuple(float(x) for x in sampling)
                if not no_reconstruction:
                    grid_shape = tilt_shape
                    sampling = tuple(
                        x for i, x in enumerate(sampling) if i != opening_axis
                    )

                grid = fftfreqn(
                    shape=grid_shape,
                    sampling_rate=sampling,
                    return_sparse_grid=True,
                )

            # This must happen after defocus_x has been converted to voxels.
            if correct_defocus_gradient:
                angle_rad = np.radians(angle)
                defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
                remaining_axis = tuple(
                    i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
                )[0]

                if tilt_axis > remaining_axis:
                    defocus_x = np.add(defocus_x, defocus_gradient)
                elif tilt_axis < remaining_axis and defocus_y is not None:
                    defocus_y = np.add(defocus_y, defocus_gradient.T)

            # Astigmatic defocus:
            # 0.5 * (dx + dy + (dx - dy) * cos(2 * (azimuth - astigmatism_angle)))
            if defocus_y is not None:
                defocus_sum = np.add(defocus_x, defocus_y)
                defocus_difference = np.subtract(defocus_x, defocus_y)

                angular_grid = np.arctan2(grid[1], grid[0])
                defocus_difference = np.multiply(
                    defocus_difference,
                    np.cos(2 * (angular_grid - defocus_angle[index])),
                )
                defocus_x = np.add(defocus_sum, defocus_difference)
                defocus_x *= 0.5

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            frequency_mask = frequency_grid < cutoff_frequency

            # chi = pi * lambda * k^2 * (dx - 0.5 * Cs * lambda^2 * k^2)
            #       + phase_shift + arctan(A / sqrt(1 - A^2)); CTF = -sin(chi)
            np.square(frequency_grid, out=frequency_grid)
            chi = defocus_x - 0.5 * electron_aberration[index] * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength, out=chi)
            np.multiply(chi, frequency_grid, out=chi)
            chi += phase_shift[index]
            chi += np.arctan(
                np.divide(
                    amplitude_contrast[index],
                    np.sqrt(1 - np.square(amplitude_contrast[index])),
                )
            )
            np.sin(-chi, out=chi)
            np.multiply(chi, frequency_mask, out=chi)

            if no_reconstruction:
                chi = shift_fourier(data=chi, shape_is_real_fourier=False)

            stack[index] = chi

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = be.to_backend_array(np.squeeze(stack))
        if no_reconstruction and return_real_fourier:
            stack = crop_real_fourier(stack)

        return stack
|