pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,230 @@
1
+ """ Implements class BandPassFilter to create Fourier filter representations.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple
9
+ from math import log, sqrt
10
+
11
+ from ..types import BackendArray
12
+ from ..backends import backend as be
13
+ from .compose import ComposableFilter
14
+ from ._utils import fftfreqn, crop_real_fourier, shift_fourier
15
+
16
+ __all__ = ["BandPassFilter"]
17
+
18
+
19
class BandPassFilter(ComposableFilter):
    """
    Generate bandpass filters in Fourier space.

    Parameters
    ----------
    lowpass : float, optional
        The lowpass cutoff, defaults to None.
    highpass : float, optional
        The highpass cutoff, defaults to None.
    sampling_rate : Tuple[float], optional
        The sampling rate in Fourier space, defaults to 1.
    use_gaussian : bool, optional
        Whether to use Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to return only the real Fourier space, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape represents the real Fourier space, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        The filter is one wherever the frequency grid lies within
        ``[lowcut, highcut]`` (both bounds inclusive) and zero elsewhere.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. If None, the upper
            bound is the largest value of the frequency grid.
        highpass : float
            The highpass cutoff in units of sampling rate. If None, the lower
            bound is zero.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments. They are accepted and ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space.
        """
        # A shape that already describes the reduced layout is not cropped again.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        sampling_rate = be.to_backend_array(sampling_rate)

        # Without a lowpass cutoff every frequency up to the grid maximum passes.
        highcut = grid.max()
        if lowpass is not None:
            highcut = be.max(2 * sampling_rate / lowpass)

        # Without a highpass cutoff nothing is removed at low frequencies.
        lowcut = 0
        if highpass is not None:
            lowcut = be.max(2 * sampling_rate / highpass)

        # Multiplying by 1.0 turns the boolean mask into a floating point array.
        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using Gaussians.

        The lowpass component is ``exp(-grid**2 / (2 * sigma**2))`` and the
        highpass component is one minus a Gaussian of the same form. ``sigma``
        is chosen so that the Gaussian equals 0.5 where the frequency grid
        equals ``max(2 * sampling_rate) / cutoff``.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables the
            lowpass component.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables the
            highpass component.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments. They are accepted and ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space. If neither cutoff is given,
            an array of ones with the requested shape.
        """
        # A shape that already describes the reduced layout is not cropped again.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        # From here on grid holds the negated squared frequencies.
        grid = -be.square(grid, out=grid)

        has_lowpass, has_highpass = False, False
        # exp(-x**2 / (2 * sigma**2)) equals 0.5 at x = sigma * sqrt(2 * ln(2)).
        norm = float(sqrt(2 * log(2)))
        upper_sampling = float(
            be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
        )

        # Clamp the cutoffs to machine epsilon to avoid dividing by zero below.
        if lowpass is not None:
            lowpass, has_lowpass = float(lowpass), True
            lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
        if highpass is not None:
            highpass, has_highpass = float(highpass), True
            highpass = be.maximum(highpass, be.eps(be._float_dtype))

        if has_lowpass:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = be.multiply(2, be.square(lowpass))
            if not has_highpass:
                # grid is not needed afterwards and can serve as output buffer.
                lowpass_filter = be.divide(grid, lowpass, out=grid)
            else:
                # grid is still required by the highpass branch, so allocate.
                lowpass_filter = be.divide(grid, lowpass)
            lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)

        if has_highpass:
            highpass = upper_sampling / (highpass * norm)
            highpass = be.multiply(2, be.square(highpass))
            highpass_filter = be.divide(grid, highpass, out=grid)
            highpass_filter = be.exp(highpass_filter, out=highpass_filter)
            highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)

        if has_lowpass and not has_highpass:
            bandpass_filter = lowpass_filter
        elif not has_lowpass and has_highpass:
            bandpass_filter = highpass_filter
        elif has_lowpass and has_highpass:
            bandpass_filter = be.multiply(
                lowpass_filter, highpass_filter, out=lowpass_filter
            )
        else:
            bandpass_filter = be.full(shape, fill_value=1, dtype=be._float_dtype)

        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    def __call__(self, **kwargs) -> dict:
        """
        Build the filter from the instance attributes and ``kwargs``.

        Parameters
        ----------
        **kwargs : dict
            Values that take precedence over the instance attributes for this
            call only, e.g. ``shape``. The instance itself is not modified.

        Returns
        -------
        dict
            Dictionary with the keys ``data`` (the filter as backend array),
            ``sampling_rate`` and ``is_multiplicative_filter`` (always True).
        """
        # vars(self) is the live instance __dict__. Copy it so call-time
        # arguments do not become attributes and leak into later calls.
        func_args = vars(self).copy()
        func_args.update(kwargs)

        func = self.discrete_bandpass
        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass

        mask = func(**func_args)

        return {
            "data": be.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }
tme/filters/compose.py ADDED
@@ -0,0 +1,81 @@
1
+ """ Combine filters using an interface analogous to pytorch's Compose.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, Dict
9
+ from abc import ABC, abstractmethod
10
+
11
+ from tme.backends import backend as be
12
+
13
+ __all__ = ["Compose", "ComposableFilter"]
14
+
15
+
16
class Compose:
    """
    Chain several transformations into a single callable.

    Every transformation is called with keyword arguments and has to return
    a dictionary. The dictionary of the most recent transformation that
    produced a ``data`` entry is forwarded as keyword arguments to the
    following transformations.

    Parameters
    ----------
    transforms : Tuple[object]
        The transformations in the order in which they are applied.

    Returns
    -------
    Dict
        Metadata resulting from the composed transformations.

    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Apply all transformations in sequence.

        Parameters
        ----------
        **kwargs : Dict
            Keyword arguments handed to the first transformation. Later
            transformations additionally receive the accumulated metadata.

        Returns
        -------
        Dict
            The metadata of the last transformation that returned ``data``,
            the metadata of the first transformation if no later one did, or
            an empty dictionary if there are no transformations.
        """
        if len(self.transforms) == 0:
            return {}

        result = self.transforms[0](**kwargs)
        for step in self.transforms[1:]:
            kwargs.update(result)
            current = step(**kwargs)

            # Transformations without data leave the running result untouched.
            if "data" in current:
                if current.get("is_multiplicative_filter", False):
                    # Multiplicative filters are folded into the previous data.
                    previous = result.pop("data")
                    current["data"] = be.multiply(current["data"], previous)
                    current["merge"] = None
                    previous = None

                result = current

        return result
59
+
60
+
61
class ComposableFilter(ABC):
    """
    Abstract base class of filters that can be chained.

    Subclasses have to implement :py:meth:`__call__`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Compute the filter.

        Parameters
        ----------
        *args : tuple
            Positional arguments of the concrete filter.
        **kwargs : dict
            Keyword arguments of the concrete filter.

        Returns
        -------
        Dict
            A dictionary representing the result of the filtering operation.
        """
tme/filters/ctf.py ADDED
@@ -0,0 +1,393 @@
1
+ """ Implements class CTF to create Fourier filter representations.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import re
9
+ import warnings
10
+ from typing import Tuple, Dict
11
+ from dataclasses import dataclass
12
+
13
+ import numpy as np
14
+
15
+ from ..types import NDArray
16
+ from ..parser import StarParser
17
+ from ..backends import backend as be
18
+ from .compose import ComposableFilter
19
+ from ._utils import (
20
+ frequency_grid_at_angle,
21
+ compute_tilt_shape,
22
+ crop_real_fourier,
23
+ fftfreqn,
24
+ shift_fourier,
25
+ )
26
+
27
+ __all__ = ["CTF"]
28
+
29
+
30
@dataclass
class CTF(ComposableFilter):
    """
    Generate a contrast transfer function mask.

    References
    ----------
    .. [1]  CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
            Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int] = None
    #: The defocus value in x direction.
    defocus_x: float = None
    #: The tilt angles.
    angles: Tuple[float] = None
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1 Angstrom / Voxel.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient, defaults to 2.7e7.
    spherical_aberration: float = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle, defaults to 0. Converted to radians in __post_init__.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None
    #: Whether the returned CTF should be phase-flipped.
    flip_phase: bool = True
    #: Whether to return a format compliant with rfft. Only relevant for single angles.
    return_real_fourier: bool = False
    #: Whether the output should not be used for n+1 dimensional reconstruction
    no_reconstruction: bool = True

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters
        ----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - CTFFIND4
            - STAR files, chosen when the file name ends with "star"

        Returns
        -------
        :py:class:`CTF`
            Instance created from the parsed parameters. ``shape`` and
            ``angles`` are left as None.
        """
        if filename.lower().endswith("star"):
            data = cls._from_gctf(filename=filename)
        else:
            data = cls._from_ctffind(filename=filename)

        # NOTE(review): parsed values are passed on without unit conversion,
        # e.g. the CTFFIND patterns match keV and mm -- confirm against the
        # units expected by the class attributes.
        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            acceleration_voltage=data["acceleration_voltage"],
            spherical_aberration=data["spherical_aberration"],
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # __post_init__ converts the angle back to radians.
            defocus_angle=np.degrees(data["azimuth_astigmatism"]),
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str):
        """
        Parse a CTFFIND text output file.

        Lines starting with "#" are searched for acquisition parameters, all
        other non-empty lines are read as whitespace-separated numeric rows.

        Parameters
        ----------
        filename : str
            Path to the file.

        Returns
        -------
        dict
            Maps each column name to a numpy array with one entry per row,
            and each acquisition parameter found in the header to a float.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
            lines = [x for x in lines if len(x)]

        def _screen_params(line, params, output):
            # Store every parameter whose pattern matches the header line.
            for parameter, regex_pattern in params.items():
                match = re.search(regex_pattern, line)
                if match:
                    output[parameter] = float(match.group(1))

        # Column index of each value within a data row.
        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }
        output = {k: [] for k in columns.keys()}
        for line in lines:
            if line.startswith("#"):
                _screen_params(line, params=parameter_regex, output=output)
                continue

            values = line.split()
            for key, value in columns.items():
                output[key].append(float(values[value]))

        for key in columns:
            output[key] = np.array(output[key])

        return output

    @staticmethod
    def _from_gctf(filename: str):
        """
        Parse CTF parameters from a STAR file.

        Parameters
        ----------
        filename : str
            Path to the file.

        Returns
        -------
        dict
            Maps parameter names to lists of values. Parameters with a single
            value are repeated to the length of the longest parameter list.
            Missing parameters are set to zero and trigger a warning.
        """
        parser = StarParser(filename)
        ctf_data = parser["data_"]

        # Output name -> (STAR column, type). None marks a parameter without
        # a STAR column, which silently falls back to zero.
        mapping = {
            "defocus_1": ("_rlnDefocusU", float),
            "defocus_2": ("_rlnDefocusV", float),
            "pixel_size": ("_rlnDetectorPixelSize", float),
            "acceleration_voltage": ("_rlnVoltage", float),
            "spherical_aberration": ("_rlnSphericalAberration", float),
            "amplitude_contrast": ("_rlnAmplitudeContrast", float),
            "additional_phase_shift": (None, float),
            "azimuth_astigmatism": ("_rlnDefocusAngle", float),
        }
        output = {}
        for out_key, (key, key_dtype) in mapping.items():
            if key not in ctf_data and key is not None:
                warnings.warn(f"ctf_data is missing key {key}.")

            key_value = ctf_data.get(key, [0])
            output[out_key] = [key_dtype(x) for x in key_value]

        longest_key = max(map(len, output.values()))
        output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
        return output

    def __post_init__(self):
        """Convert the defocus angle passed to the constructor to radians."""
        self.defocus_angle = np.radians(self.defocus_angle)

    def _compute_electron_wavelength(self, acceleration_voltage: int = None):
        """
        Computes the wavelength of an electron in angstrom.

        Parameters
        ----------
        acceleration_voltage : int, optional
            The acceleration voltage. If None, the instance attribute
            ``acceleration_voltage`` is used.

        Returns
        -------
        float
            The relativistic electron wavelength in angstrom.
        """

        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        # lambda = h * c / sqrt(E^2 + 2 * E * m * c^2)
        energy = electron_charge * acceleration_voltage
        denominator = energy**2
        denominator += 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        # Convert to Ångstrom
        electron_wavelength *= 1e10
        return electron_wavelength

    def __call__(self, **kwargs) -> Dict:
        """
        Compute the CTF from the instance attributes and ``kwargs``.

        Parameters
        ----------
        **kwargs : dict
            Values that take precedence over the instance attributes for this
            call. The instance itself is not modified.

        Returns
        -------
        Dict
            Dictionary with the keys ``data`` (the CTF as backend array),
            ``angles``, ``tilt_axis``, ``opening_axis`` and
            ``is_multiplicative_filter`` (always True).
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        # If the number of angles and defocus values disagree, fall back to
        # the instance angles and disable the tilt-specific options.
        if len(func_args["angles"]) != len(func_args["defocus_x"]):
            func_args["angles"] = self.angles
            func_args["return_real_fourier"] = False
            func_args["tilt_axis"] = None
            func_args["opening_axis"] = None

        ret = self.weight(**func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def _pad_to_length(arr, length: int):
        """
        Repeat every element of ``arr`` ``length // size`` times.

        A scalar hence becomes an array with ``length`` entries, while an
        input that already has ``length`` entries is returned as a copy.
        """
        ret = np.atleast_1d(arr)
        return np.repeat(ret, length // ret.size)

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        no_reconstruction: bool = True,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction.
        angles : tuple of float
            The tilt angles.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False. Only
            applied for three-dimensional shapes when both tilt_axis and
            opening_axis are given.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1. The maximum value is used.
        acceleration_voltage : float, optional
            The acceleration voltage in electron microscopy, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        return_real_fourier : bool, optional
            Whether to crop the output via crop_real_fourier, defaults to
            False. Only applied if no_reconstruction is True.
        no_reconstruction : bool, optional
            Whether the output should not be used for n+1 dimensional
            reconstruction, defaults to True.
        cutoff_frequency : float, optional
            Frequencies at or above this value are set to zero, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments.

        Returns
        -------
        NDArray
            A stack containing the CTF weight.
        """
        # NOTE(review): the default spherical_aberration here (2.7e3) differs
        # from the class attribute default (2.7e7) -- confirm which is intended.
        angles = np.atleast_1d(angles)
        defoci_x = self._pad_to_length(defocus_x, angles.size)
        defoci_y = self._pad_to_length(defocus_y, angles.size)
        phase_shift = self._pad_to_length(phase_shift, angles.size)
        defocus_angle = self._pad_to_length(defocus_angle, angles.size)
        spherical_aberration = self._pad_to_length(spherical_aberration, angles.size)
        amplitude_contrast = self._pad_to_length(amplitude_contrast, angles.size)

        sampling_rate = np.max(sampling_rate)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((len(angles), *tilt_shape))

        correct_defocus_gradient &= len(shape) == 3
        correct_defocus_gradient &= tilt_axis is not None
        correct_defocus_gradient &= opening_axis is not None

        # Express all lengths in units of the sampling rate.
        spherical_aberration /= sampling_rate
        # Use the voltage passed to this method, like every other parameter,
        # instead of silently falling back to the instance attribute.
        electron_wavelength = self._compute_electron_wavelength(acceleration_voltage)
        electron_wavelength = electron_wavelength / sampling_rate
        electron_aberration = spherical_aberration * electron_wavelength**2

        for index, angle in enumerate(angles):
            defocus_x, defocus_y = defoci_x[index], defoci_y[index]

            defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
            defocus_y = defocus_y / sampling_rate if defocus_y is not None else None

            # The coordinate grid is only needed for astigmatism and gradients.
            if correct_defocus_gradient or defocus_y is not None:
                grid_shape = shape
                sampling = be.divide(sampling_rate, be.to_backend_array(shape))
                sampling = tuple(float(x) for x in sampling)
                if not no_reconstruction:
                    grid_shape = tilt_shape
                    sampling = tuple(
                        x for i, x in enumerate(sampling) if i != opening_axis
                    )

                grid = fftfreqn(
                    shape=grid_shape,
                    sampling_rate=sampling,
                    return_sparse_grid=True,
                )

            # This should be done after defocus_x computation
            if correct_defocus_gradient:
                angle_rad = np.radians(angle)
                defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
                remaining_axis = tuple(
                    i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
                )[0]

                if tilt_axis > remaining_axis:
                    defocus_x = np.add(defocus_x, defocus_gradient)
                elif tilt_axis < remaining_axis and defocus_y is not None:
                    defocus_y = np.add(defocus_y, defocus_gradient.T)

            # 0.5 * ((dx + dy) + (dx - dy) * cos(2 * (azimuth - astigmatism)))
            if defocus_y is not None:
                defocus_sum = np.add(defocus_x, defocus_y)
                defocus_difference = np.subtract(defocus_x, defocus_y)

                angular_grid = np.arctan2(grid[1], grid[0])
                defocus_difference = np.multiply(
                    defocus_difference,
                    np.cos(2 * (angular_grid - defocus_angle[index])),
                )
                defocus_x = np.add(defocus_sum, defocus_difference)
                defocus_x *= 0.5

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            # Has to be computed before the grid is squared in place below.
            frequency_mask = frequency_grid < cutoff_frequency

            # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
            np.square(frequency_grid, out=frequency_grid)
            chi = defocus_x - 0.5 * electron_aberration[index] * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength, out=chi)
            np.multiply(chi, frequency_grid, out=chi)
            chi += phase_shift[index]
            chi += np.arctan(
                np.divide(
                    amplitude_contrast[index],
                    np.sqrt(1 - np.square(amplitude_contrast[index])),
                )
            )
            np.sin(-chi, out=chi)
            np.multiply(chi, frequency_mask, out=chi)

            if no_reconstruction:
                chi = shift_fourier(data=chi, shape_is_real_fourier=False)

            stack[index] = chi

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = be.to_backend_array(np.squeeze(stack))
        if no_reconstruction and return_real_fourier:
            stack = crop_real_fourier(stack)

        return stack