pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
  8. pytme-0.3b0.post1.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +341 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +318 -189
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +12 -12
  19. scripts/pytme_runner.py +769 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/filters/__init__.py +3 -3
  49. tme/filters/_utils.py +36 -10
  50. tme/filters/bandpass.py +229 -188
  51. tme/filters/compose.py +5 -4
  52. tme/filters/ctf.py +516 -254
  53. tme/filters/reconstruction.py +91 -32
  54. tme/filters/wedge.py +196 -135
  55. tme/filters/whitening.py +37 -42
  56. tme/matching_data.py +28 -39
  57. tme/matching_exhaustive.py +31 -27
  58. tme/matching_optimization.py +5 -4
  59. tme/matching_scores.py +25 -15
  60. tme/matching_utils.py +54 -9
  61. tme/memory.py +4 -3
  62. tme/orientations.py +22 -9
  63. tme/parser.py +114 -33
  64. tme/preprocessor.py +6 -5
  65. tme/rotations.py +10 -7
  66. tme/structure.py +4 -3
  67. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  68. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  69. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  70. scripts/estimate_ram_usage.py +0 -97
  71. tests/data/Maps/.DS_Store +0 -0
  72. tests/data/Structures/.DS_Store +0 -0
  73. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py CHANGED
@@ -1,233 +1,149 @@
1
- """ Implements class CTF to create Fourier filter representations.
1
+ """
2
+ Implements class CTF and CTFReconstructed.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  import re
9
- import warnings
10
10
  from typing import Tuple, Dict
11
11
  from dataclasses import dataclass
12
12
 
13
13
  import numpy as np
14
14
 
15
15
  from ..types import NDArray
16
- from ..parser import StarParser
17
16
  from ..backends import backend as be
18
17
  from .compose import ComposableFilter
18
+ from ..parser import StarParser, XMLParser, MDOCParser
19
19
  from ._utils import (
20
20
  frequency_grid_at_angle,
21
21
  compute_tilt_shape,
22
22
  crop_real_fourier,
23
23
  fftfreqn,
24
24
  shift_fourier,
25
+ pad_to_length,
25
26
  )
26
27
 
27
- __all__ = ["CTF"]
28
+ __all__ = ["CTF", "CTFReconstructed", "create_ctf"]
28
29
 
29
30
 
30
31
  @dataclass
31
32
  class CTF(ComposableFilter):
32
33
  """
33
- Generate a contrast transfer function mask.
34
-
35
- References
36
- ----------
37
- .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
38
- Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
34
+ Generate a per-tilt contrast transfer function mask.
39
35
  """
40
36
 
41
- #: The shape of the to-be reconstructed volume.
37
+ #: The shape of the to-be created mask.
42
38
  shape: Tuple[int] = None
43
- #: The defocus value in x direction.
44
- defocus_x: float = None
39
+ #: The defocus value in x direction (in units of sampling rate).
40
+ defocus_x: Tuple[float] = None
45
41
  #: The tilt angles.
46
42
  angles: Tuple[float] = None
47
- #: The axis around which the wedge is opened, defaults to None.
48
- opening_axis: int = None
49
- #: The axis along which the tilt is applied, defaults to None.
50
- tilt_axis: int = None
51
- #: Whether to correct defocus gradient, defaults to False.
43
+ #: The microscope projection axis, defaults to 2 (z).
44
+ opening_axis: int = 2
45
+ #: The axis along which the tilt is applied, defaults to 0 (x).
46
+ tilt_axis: int = 0
47
+ #: Whether to correct defocus gradient, defaults False.
52
48
  correct_defocus_gradient: bool = False
53
- #: The sampling rate, defaults to 1 Angstrom / Voxel.
49
+ #: The sampling rate, defaults to 1 Ångstrom / voxel.
54
50
  sampling_rate: Tuple[float] = 1
55
51
  #: The acceleration voltage in Volts, defaults to 300e3.
56
- acceleration_voltage: float = 300e3
57
- #: The spherical aberration coefficient, defaults to 2.7e7.
58
- spherical_aberration: float = 2.7e7
52
+ acceleration_voltage: Tuple[float] = 300e3
53
+ #: The spherical aberration, defaults to 2.7e7 (in units of sampling rate).
54
+ spherical_aberration: Tuple[float] = 2.7e7
59
55
  #: The amplitude contrast, defaults to 0.07.
60
- amplitude_contrast: float = 0.07
61
- #: The phase shift, defaults to 0.
62
- phase_shift: float = 0
63
- #: The defocus angle, defaults to 0.
64
- defocus_angle: float = 0
65
- #: The defocus value in y direction, defaults to None.
66
- defocus_y: float = None
67
- #: Whether the returned CTF should be phase-flipped.
56
+ amplitude_contrast: Tuple[float] = 0.07
57
+ #: The phase shift in degrees, defaults to 0.
58
+ phase_shift: Tuple[float] = 0
59
+ #: The defocus angle in degrees, defaults to 0.
60
+ defocus_angle: Tuple[float] = 0
61
+ #: The defocus value in y direction, defaults to None (in units of sampling rate).
62
+ defocus_y: Tuple[float] = None
63
+ #: Whether the returned CTF should be phase-flipped, defaults to True.
68
64
  flip_phase: bool = True
69
- #: Whether to return a format compliant with rfft. Only relevant for single angles.
65
+ #: Whether to return a ctf mask for rfft (for :py:class:`CTFReconstructed`).
70
66
  return_real_fourier: bool = False
71
- #: Whether the output should not be used for n+1 dimensional reconstruction
72
- no_reconstruction: bool = True
73
67
 
74
68
  @classmethod
75
- def from_file(cls, filename: str) -> "CTF":
69
+ def from_file(cls, filename: str, **kwargs) -> "CTF":
76
70
  """
77
71
  Initialize :py:class:`CTF` from file.
78
72
 
79
73
  Parameters
80
74
  ----------
81
75
  filename : str
82
- The path to a file with ctf parameters. Supports the following formats:
83
- - CTFFIND4
76
+ The path to a file with ctf parameters. Supports extensions are:
77
+
78
+ +-------+---------------------------------------------------------+
79
+ | .star | GCTF file |
80
+ +-------+---------------------------------------------------------+
81
+ | .xml | WARP/M XML file |
82
+ +-------+---------------------------------------------------------+
83
+ | .mdoc | SerialEM file |
84
+ +-------+---------------------------------------------------------+
85
+ | .* | CTFFIND4 file |
86
+ +-------+---------------------------------------------------------+
87
+ **kwargs : optional
88
+ Overwrite fields that cannot be extracted from input file.
84
89
  """
90
+ func = _from_ctffind
85
91
  if filename.lower().endswith("star"):
86
- data = cls._from_gctf(filename=filename)
87
- else:
88
- data = cls._from_ctffind(filename=filename)
89
-
90
- return cls(
91
- shape=None,
92
- angles=None,
93
- defocus_x=data["defocus_1"],
94
- sampling_rate=data["pixel_size"],
95
- acceleration_voltage=data["acceleration_voltage"],
96
- spherical_aberration=data["spherical_aberration"],
97
- amplitude_contrast=data["amplitude_contrast"],
98
- phase_shift=data["additional_phase_shift"],
99
- defocus_angle=np.degrees(data["azimuth_astigmatism"]),
100
- defocus_y=data["defocus_2"],
101
- )
102
-
103
- @staticmethod
104
- def _from_ctffind(filename: str):
105
- parameter_regex = {
106
- "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
107
- "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
108
- "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
109
- "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
92
+ func = _from_star
93
+ elif filename.lower().endswith("xml"):
94
+ func = _from_xml
95
+ elif filename.lower().endswith("mdoc"):
96
+ func = _from_mdoc
97
+
98
+ data = func(filename=filename)
99
+
100
+ # Pixel size needs to be overwritten by pixel size the ctf is generated for
101
+ init_kwargs = {
102
+ "shape": None,
103
+ "angles": data.get("angles", None),
104
+ "defocus_x": data["defocus_1"],
105
+ "sampling_rate": data["pixel_size"],
106
+ "acceleration_voltage": np.multiply(data["acceleration_voltage"], 1e3),
107
+ "spherical_aberration": data.get("spherical_aberration"),
108
+ "amplitude_contrast": data.get("amplitude_contrast"),
109
+ "phase_shift": data.get("additional_phase_shift"),
110
+ "defocus_angle": data.get("azimuth_astigmatism"),
111
+ "defocus_y": data["defocus_2"],
110
112
  }
111
-
112
- with open(filename, mode="r", encoding="utf-8") as infile:
113
- lines = [x.strip() for x in infile.read().split("\n")]
114
- lines = [x for x in lines if len(x)]
115
-
116
- def _screen_params(line, params, output):
117
- for parameter, regex_pattern in parameter_regex.items():
118
- match = re.search(regex_pattern, line)
119
- if match:
120
- output[parameter] = float(match.group(1))
121
-
122
- columns = {
123
- "micrograph_number": 0,
124
- "defocus_1": 1,
125
- "defocus_2": 2,
126
- "azimuth_astigmatism": 3,
127
- "additional_phase_shift": 4,
128
- "cross_correlation": 5,
129
- "spacing": 6,
130
- }
131
- output = {k: [] for k in columns.keys()}
132
- for line in lines:
133
- if line.startswith("#"):
134
- _screen_params(line, params=parameter_regex, output=output)
135
- continue
136
-
137
- values = line.split()
138
- for key, value in columns.items():
139
- output[key].append(float(values[value]))
140
-
141
- for key in columns:
142
- output[key] = np.array(output[key])
143
-
144
- return output
145
-
146
- @staticmethod
147
- def _from_gctf(filename: str):
148
- parser = StarParser(filename)
149
- ctf_data = parser["data_"]
150
-
151
- mapping = {
152
- "defocus_1": ("_rlnDefocusU", float),
153
- "defocus_2": ("_rlnDefocusV", float),
154
- "pixel_size": ("_rlnDetectorPixelSize", float),
155
- "acceleration_voltage": ("_rlnVoltage", float),
156
- "spherical_aberration": ("_rlnSphericalAberration", float),
157
- "amplitude_contrast": ("_rlnAmplitudeContrast", float),
158
- "additional_phase_shift": (None, float),
159
- "azimuth_astigmatism": ("_rlnDefocusAngle", float),
160
- }
161
- output = {}
162
- for out_key, (key, key_dtype) in mapping.items():
163
- if key not in ctf_data and key is not None:
164
- warnings.warn(f"ctf_data is missing key {key}.")
165
-
166
- key_value = ctf_data.get(key, [0])
167
- output[out_key] = [key_dtype(x) for x in key_value]
168
-
169
- longest_key = max(map(len, output.values()))
170
- output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
171
- return output
113
+ for k, v in kwargs.items():
114
+ if k in init_kwargs and init_kwargs.get(k) is None:
115
+ init_kwargs[k] = v
116
+ init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
117
+ return cls(**init_kwargs)
172
118
 
173
119
  def __post_init__(self):
174
120
  self.defocus_angle = np.radians(self.defocus_angle)
175
-
176
- def _compute_electron_wavelength(self, acceleration_voltage: int = None):
177
- """Computes the wavelength of an electron in angstrom."""
178
-
179
- if acceleration_voltage is None:
180
- acceleration_voltage = self.acceleration_voltage
181
-
182
- # Physical constants expressed in SI units
183
- planck_constant = 6.62606896e-34
184
- electron_charge = 1.60217646e-19
185
- electron_mass = 9.10938215e-31
186
- light_velocity = 299792458
187
-
188
- energy = electron_charge * acceleration_voltage
189
- denominator = energy**2
190
- denominator += 2 * energy * electron_mass * light_velocity**2
191
- electron_wavelength = np.divide(
192
- planck_constant * light_velocity, np.sqrt(denominator)
193
- )
194
- # Convert to Ångstrom
195
- electron_wavelength *= 1e10
196
- return electron_wavelength
121
+ self.phase_shift = np.radians(self.phase_shift)
197
122
 
198
123
  def __call__(self, **kwargs) -> NDArray:
124
+ """
125
+ Returns a CTF stack of chosen parameters with DC component in the center.
126
+ """
199
127
  func_args = vars(self).copy()
200
128
  func_args.update(kwargs)
201
129
 
202
- if len(func_args["angles"]) != len(func_args["defocus_x"]):
203
- func_args["angles"] = self.angles
204
- func_args["return_real_fourier"] = False
205
- func_args["tilt_axis"] = None
206
- func_args["opening_axis"] = None
207
-
208
130
  ret = self.weight(**func_args)
209
131
  ret = be.astype(be.to_backend_array(ret), be._float_dtype)
210
132
  return {
211
133
  "data": ret,
212
- "angles": func_args["angles"],
213
- "tilt_axis": func_args["tilt_axis"],
214
- "opening_axis": func_args["opening_axis"],
134
+ "shape": func_args["shape"],
135
+ "return_real_fourier": func_args.get("return_real_fourier"),
215
136
  "is_multiplicative_filter": True,
216
137
  }
217
138
 
218
- @staticmethod
219
- def _pad_to_length(arr, length: int):
220
- ret = np.atleast_1d(arr)
221
- return np.repeat(ret, length // ret.size)
222
-
223
139
  def weight(
224
140
  self,
225
141
  shape: Tuple[int],
226
142
  defocus_x: Tuple[float],
227
143
  angles: Tuple[float],
228
- opening_axis: int = None,
229
- tilt_axis: int = None,
230
- amplitude_contrast: float = 0.07,
144
+ opening_axis: int = 2,
145
+ tilt_axis: int = 0,
146
+ amplitude_contrast: Tuple[float] = 0.07,
231
147
  phase_shift: Tuple[float] = 0,
232
148
  defocus_angle: Tuple[float] = 0,
233
149
  defocus_y: Tuple[float] = None,
@@ -237,7 +153,6 @@ class CTF(ComposableFilter):
237
153
  spherical_aberration: float = 2.7e3,
238
154
  flip_phase: bool = True,
239
155
  return_real_fourier: bool = False,
240
- no_reconstruction: bool = True,
241
156
  cutoff_frequency: float = 0.5,
242
157
  **kwargs: Dict,
243
158
  ) -> NDArray:
@@ -253,15 +168,15 @@ class CTF(ComposableFilter):
253
168
  angles : tuple of float
254
169
  The tilt angles.
255
170
  opening_axis : int, optional
256
- The axis around which the wedge is opened, defaults to None.
171
+ The axis around which the wedge is opened, defaults to 2.
257
172
  tilt_axis : int, optional
258
- The axis along which the tilt is applied, defaults to None.
259
- amplitude_contrast : float, optional
173
+ The axis along which the tilt is applied, defaults to 0.
174
+ amplitude_contrast : tuple of float, optional
260
175
  The amplitude contrast, defaults to 0.07.
261
176
  phase_shift : tuple of float, optional
262
- The phase shift, defaults to 0.
177
+ The phase shift in radians, defaults to 0.
263
178
  defocus_angle : tuple of float, optional
264
- The defocus angle, defaults to 0.
179
+ The defocus angle in radians, defaults to 0.
265
180
  defocus_y : tuple of float, optional
266
181
  The defocus value in y direction, defaults to None.
267
182
  correct_defocus_gradient : bool, optional
@@ -273,7 +188,7 @@ class CTF(ComposableFilter):
273
188
  spherical_aberration : float, optional
274
189
  The spherical aberration coefficient, defaults to 2.7e3.
275
190
  flip_phase : bool, optional
276
- Whether the returned CTF should be phase-flipped.
191
+ Whether the returned CTF should be phase-flipped, defaults to True.
277
192
  **kwargs : Dict
278
193
  Additional keyword arguments.
279
194
 
@@ -283,101 +198,47 @@ class CTF(ComposableFilter):
283
198
  A stack containing the CTF weight.
284
199
  """
285
200
  angles = np.atleast_1d(angles)
286
- defoci_x = self._pad_to_length(defocus_x, angles.size)
287
- defoci_y = self._pad_to_length(defocus_y, angles.size)
288
- phase_shift = self._pad_to_length(phase_shift, angles.size)
289
- defocus_angle = self._pad_to_length(defocus_angle, angles.size)
290
- spherical_aberration = self._pad_to_length(spherical_aberration, angles.size)
291
- amplitude_contrast = self._pad_to_length(amplitude_contrast, angles.size)
201
+ defoci_x = pad_to_length(defocus_x, angles.size)
202
+ defoci_y = pad_to_length(defocus_y, angles.size)
203
+ phase_shift = pad_to_length(phase_shift, angles.size)
204
+ defocus_angle = pad_to_length(defocus_angle, angles.size)
205
+ spherical_aberration = pad_to_length(spherical_aberration, angles.size)
206
+ amplitude_contrast = pad_to_length(amplitude_contrast, angles.size)
207
+ acceleration_voltage = pad_to_length(acceleration_voltage, angles.size)
292
208
 
293
209
  sampling_rate = np.max(sampling_rate)
294
- tilt_shape = compute_tilt_shape(
210
+ ctf_shape = compute_tilt_shape(
295
211
  shape=shape, opening_axis=opening_axis, reduce_dim=True
296
212
  )
297
- stack = np.zeros((len(angles), *tilt_shape))
213
+ stack = np.zeros((len(angles), *ctf_shape))
298
214
 
299
- correct_defocus_gradient &= len(shape) == 3
300
- correct_defocus_gradient &= tilt_axis is not None
301
- correct_defocus_gradient &= opening_axis is not None
302
-
303
- spherical_aberration /= sampling_rate
304
- electron_wavelength = self._compute_electron_wavelength() / sampling_rate
305
- electron_aberration = spherical_aberration * electron_wavelength**2
215
+ # Shift tilt axis forward
216
+ corrected_tilt_axis = tilt_axis
217
+ if opening_axis and tilt_axis is not None:
218
+ if opening_axis < tilt_axis:
219
+ corrected_tilt_axis -= 1
306
220
 
307
221
  for index, angle in enumerate(angles):
308
222
  defocus_x, defocus_y = defoci_x[index], defoci_y[index]
309
223
 
310
- defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
311
- defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
312
-
313
- if correct_defocus_gradient or defocus_y is not None:
314
- grid_shape = shape
315
- sampling = be.divide(sampling_rate, be.to_backend_array(shape))
316
- sampling = tuple(float(x) for x in sampling)
317
- if not no_reconstruction:
318
- grid_shape = tilt_shape
319
- sampling = tuple(
320
- x for i, x in enumerate(sampling) if i != opening_axis
321
- )
322
-
323
- grid = fftfreqn(
324
- shape=grid_shape,
325
- sampling_rate=sampling,
326
- return_sparse_grid=True,
327
- )
328
-
329
- # This should be done after defocus_x computation
330
- if correct_defocus_gradient:
331
- angle_rad = np.radians(angle)
332
- defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
333
- remaining_axis = tuple(
334
- i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
335
- )[0]
336
-
337
- if tilt_axis > remaining_axis:
338
- defocus_x = np.add(defocus_x, defocus_gradient)
339
- elif tilt_axis < remaining_axis and defocus_y is not None:
340
- defocus_y = np.add(defocus_y, defocus_gradient.T)
341
-
342
- # 0.5 * (dx + dy) + cos(2 * (azimuth - astigmatism) * (dx - dy))
343
- if defocus_y is not None:
344
- defocus_sum = np.add(defocus_x, defocus_y)
345
- defocus_difference = np.subtract(defocus_x, defocus_y)
346
-
347
- angular_grid = np.arctan2(grid[1], grid[0])
348
- defocus_difference = np.multiply(
349
- defocus_difference,
350
- np.cos(2 * (angular_grid - defocus_angle[index])),
351
- )
352
- defocus_x = np.add(defocus_sum, defocus_difference)
353
- defocus_x *= 0.5
354
-
355
- frequency_grid = frequency_grid_at_angle(
356
- shape=shape,
357
- opening_axis=opening_axis,
358
- tilt_axis=tilt_axis,
224
+ correction = correct_defocus_gradient and angle is not None
225
+ chi = create_ctf(
359
226
  angle=angle,
360
- sampling_rate=1,
361
- )
362
- frequency_mask = frequency_grid < cutoff_frequency
363
-
364
- # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
365
- np.square(frequency_grid, out=frequency_grid)
366
- chi = defocus_x - 0.5 * electron_aberration[index] * frequency_grid
367
- np.multiply(chi, np.pi * electron_wavelength, out=chi)
368
- np.multiply(chi, frequency_grid, out=chi)
369
- chi += phase_shift[index]
370
- chi += np.arctan(
371
- np.divide(
372
- amplitude_contrast[index],
373
- np.sqrt(1 - np.square(amplitude_contrast[index])),
374
- )
227
+ shape=ctf_shape,
228
+ defocus_x=defocus_x,
229
+ defocus_y=defocus_y,
230
+ sampling_rate=sampling_rate,
231
+ acceleration_voltage=acceleration_voltage[index],
232
+ correct_defocus_gradient=correction,
233
+ spherical_aberration=spherical_aberration[index],
234
+ cutoff_frequency=cutoff_frequency,
235
+ phase_shift=phase_shift[index],
236
+ defocus_angle=defocus_angle[index],
237
+ amplitude_contrast=amplitude_contrast[index],
238
+ tilt_axis=corrected_tilt_axis,
239
+ opening_axis=opening_axis,
240
+ full_shape=shape,
375
241
  )
376
- np.sin(-chi, out=chi)
377
- np.multiply(chi, frequency_mask, out=chi)
378
-
379
- if no_reconstruction:
380
- chi = shift_fourier(data=chi, shape_is_real_fourier=False)
381
242
 
382
243
  stack[index] = chi
383
244
 
@@ -387,7 +248,408 @@ class CTF(ComposableFilter):
387
248
  np.abs(stack, out=stack)
388
249
 
389
250
  stack = be.to_backend_array(np.squeeze(stack))
390
- if no_reconstruction and return_real_fourier:
251
+ return stack
252
+
253
+
254
class CTFReconstructed(CTF):
    """
    Create a simple contrast transfer function mask without the ability to specify
    per-tilt parameters like in :py:class:`CTF`.
    """

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        # NOTE(review): the class field defaults to 2.7e7; this default only
        # applies when weight is called directly without the argument.
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute a single CTF mask.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift in radians, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle in radians, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1. Only the largest value is used.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped, defaults to True.
        return_real_fourier : bool, optional
            Whether to crop the mask with ``crop_real_fourier`` for use with
            rfft, defaults to False.
        cutoff_frequency : float, optional
            Forwarded to :py:func:`create_ctf`, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments, ignored. They absorb the remaining
            instance fields passed in by :py:meth:`CTF.__call__`.

        Returns
        -------
        NDArray
            The CTF mask after ``shift_fourier`` (and optional rfft cropping).
        """
        stack = create_ctf(
            shape=shape,
            defocus_x=defocus_x,
            defocus_y=defocus_y,
            sampling_rate=np.max(sampling_rate),
            # Use the argument, not self.acceleration_voltage, so overrides
            # passed through __call__ take effect like for every other field.
            acceleration_voltage=acceleration_voltage,
            correct_defocus_gradient=False,
            spherical_aberration=spherical_aberration,
            cutoff_frequency=cutoff_frequency,
            phase_shift=phase_shift,
            defocus_angle=defocus_angle,
            amplitude_contrast=amplitude_contrast,
        )
        stack = shift_fourier(data=stack, shape_is_real_fourier=False)

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = be.to_backend_array(np.squeeze(stack))
        if return_real_fourier:
            stack = crop_real_fourier(stack)

        return stack
336
+
337
+
338
def _from_xml(filename: str) -> Dict:
    """
    Parse CTF parameters from a WARP/M XML file.

    Global values are read from the ``CTF`` parameter section. If the file
    contains a ``GridCTF`` section, its per-node defocus, phase shift and
    defocus angle values replace the global ones.

    Parameters
    ----------
    filename : str
        Path to the XML file.

    Returns
    -------
    Dict
        Parameters keyed by the names :py:meth:`CTF.from_file` expects. Defocus
        is scaled by 1e4 and the spherical aberration by 1e7 to express both in
        Angstrom (assumed to be the unit of the sampling rate).

    Raises
    ------
    ValueError
        If any required parameter, including the tilt angles, is missing.
    """
    data = XMLParser(filename)

    def _grid_values(section: str) -> list:
        # Collect the "Value" attribute of every node in a grid section.
        return [node["@attributes"]["Value"] for node in data[section]["Node"]]

    params = {
        "PhaseShift": None,
        "Amplitude": None,
        "Defocus": None,
        "Voltage": None,
        "Cs": None,
        "DefocusAngle": None,
        "PixelSize": None,
        # Looked up defensively so a missing entry is reported by the
        # ValueError below instead of surfacing as a bare KeyError.
        "Angles": data["Angles"] if "Angles" in data else None,
    }

    for option in data["CTF"]["Param"]:
        attributes = option["@attributes"]
        if attributes["Name"] in params:
            params[attributes["Name"]] = attributes["Value"]

    if "GridCTF" in data:
        params["Defocus"] = _grid_values("GridCTF")
        params["PhaseShift"] = np.degrees(_grid_values("GridCTFPhase"))
        params["DefocusAngle"] = _grid_values("GridCTFDefocusAngle")

    missing = [k for k, v in params.items() if v is None]
    if missing:
        raise ValueError(f"Could not find {missing} in {filename}.")

    params = {
        k: np.array(v) if hasattr(v, "__len__") else float(v)
        for k, v in params.items()
    }

    # Convert units to sampling rate (we assume it is Angstrom)
    params["Cs"] = float(params["Cs"] * 1e7)
    params["Defocus"] = params["Defocus"] * 1e4

    mapping = {
        "angles": "Angles",
        "defocus_1": "Defocus",
        "defocus_2": "Defocus",
        "azimuth_astigmatism": "DefocusAngle",
        "additional_phase_shift": "PhaseShift",
        "acceleration_voltage": "Voltage",
        "spherical_aberration": "Cs",
        "amplitude_contrast": "Amplitude",
        "pixel_size": "PixelSize",
    }
    return {key: params[source] for key, source in mapping.items()}
396
+
397
+
398
def _from_ctffind(filename: str) -> Dict:
    """
    Parse CTF parameters from a CTFFIND4 output file.

    Lines starting with ``#`` are scanned for the pixel size, acceleration
    voltage, spherical aberration and amplitude contrast. Every other non-empty
    line holds the fit for one micrograph and is split into numeric columns.

    Parameters
    ----------
    filename : str
        Path to the CTFFIND4 text output.

    Returns
    -------
    Dict
        One array per column (micrograph_number, defocus_1, defocus_2,
        azimuth_astigmatism, additional_phase_shift, cross_correlation,
        spacing), plus the header values that were found. The spherical
        aberration is converted from mm to Angstrom, the unit of the
        :py:class:`CTF` default of 2.7e7. The phase shift is converted from
        radians to degrees, which :py:class:`CTF` converts back on creation.
    """
    parameter_regex = {
        "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
        "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
        "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
        "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
    }
    columns = {
        "micrograph_number": 0,
        "defocus_1": 1,
        "defocus_2": 2,
        "azimuth_astigmatism": 3,
        "additional_phase_shift": 4,
        "cross_correlation": 5,
        "spacing": 6,
    }

    output = {key: [] for key in columns}
    with open(filename, mode="r", encoding="utf-8") as infile:
        for raw_line in infile:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith("#"):
                # Header/comment line: pick up any global parameters it holds.
                for parameter, pattern in parameter_regex.items():
                    match = re.search(pattern, line)
                    if match:
                        output[parameter] = float(match.group(1))
                continue
            values = line.split()
            for key, index in columns.items():
                output[key].append(float(values[index]))

    for key in columns:
        output[key] = np.array(output[key])

    # CTFFIND4 reports Cs in mm; CTF expects Angstrom (1 mm = 1e7 Angstrom).
    if "spherical_aberration" in output:
        output["spherical_aberration"] = output["spherical_aberration"] * 1e7
    output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
    return output
440
+
441
+
442
def _from_star(filename: str) -> Dict:
    """
    Parse CTF parameters from a STAR file.

    Two layouts are recognized: a STOPGAP wedgelist (``data_stopgap_wedgelist``
    block) and a GCTF/RELION style file (``data_`` block).

    Parameters
    ----------
    filename : str
        Path to the STAR file.

    Returns
    -------
    Dict
        Parameters keyed by the names :py:meth:`CTF.from_file` expects, each a
        list of scaled values. Entries whose column is absent are None.
    """
    import warnings

    parser = StarParser(filename)

    # Each entry maps an output key to (STAR column, dtype, scale factor).
    if "data_stopgap_wedgelist" in parser:
        block_name = "data_stopgap_wedgelist"
        mapping = {
            "angles": ("_tilt_angle", float, 1),
            "defocus_1": ("_defocus", float, 1e4),
            "defocus_2": (None, float, 1e4),
            "pixel_size": ("_pixelsize", float, 1),
            "acceleration_voltage": ("_voltage", float, 1),
            "spherical_aberration": ("_cs", float, 1e7),
            "amplitude_contrast": ("_amp_contrast", float, 1),
            "additional_phase_shift": (None, float, 1),
            "azimuth_astigmatism": (None, float, 1),
        }
    else:
        block_name = "data_"
        mapping = {
            "defocus_1": ("_rlnDefocusU", float, 1),
            "defocus_2": ("_rlnDefocusV", float, 1),
            # NOTE(review): RELION documents _rlnDetectorPixelSize in micrometers
            # (Angstrom pixel size = 1e4 * DetectorPixelSize / Magnification);
            # confirm this column holds Angstrom for the supported files.
            "pixel_size": ("_rlnDetectorPixelSize", float, 1),
            "acceleration_voltage": ("_rlnVoltage", float, 1),
            # Cs is stored in mm; scale to Angstrom like the STOPGAP branch.
            "spherical_aberration": ("_rlnSphericalAberration", float, 1e7),
            "amplitude_contrast": ("_rlnAmplitudeContrast", float, 1),
            "additional_phase_shift": (None, float, 1),
            "azimuth_astigmatism": ("_rlnDefocusAngle", float, 1),
        }

    ctf_data = parser[block_name]
    output = {}
    for out_key, (column, dtype, scale) in mapping.items():
        values = ctf_data.get(column)
        if values is not None:
            try:
                values = [dtype(x) * scale for x in values]
            except (TypeError, ValueError) as exc:
                # Best effort: keep the raw values, but do not hide the problem.
                warnings.warn(
                    f"Could not convert {column} in {filename} ({exc}); "
                    "keeping the unconverted values."
                )
        output[out_key] = values
    return output
482
+
483
+
484
def _from_mdoc(filename: str) -> Dict:
    """
    Parse tilt angles, defocus and voltage from a SerialEM mdoc file.

    Parameters
    ----------
    filename : str
        Path to the mdoc file.

    Returns
    -------
    Dict
        Parameters keyed by the names :py:meth:`CTF.from_file` expects. Keys the
        mdoc does not provide are None, so the caller falls back to kwargs or
        class defaults. ``defocus_1`` is sign-flipped and scaled by 1e4.
    """
    parser = MDOCParser(filename)

    mapping = {
        "angles": "TiltAngle",
        "defocus_1": "Defocus",
        "acceleration_voltage": "Voltage",
        # These will be None, but on purpose.
        # NOTE(review): SerialEM headers usually carry PixelSpacing; consider
        # mapping it if MDOCParser exposes it.
        "pixel_size": "_rlnDetectorPixelSize",
        "defocus_2": "Defocus2",
        "spherical_aberration": "_rlnSphericalAberration",
        "amplitude_contrast": "_rlnAmplitudeContrast",
        "additional_phase_shift": None,
        "azimuth_astigmatism": "_rlnDefocusAngle",
    }
    output = {out_key: parser.get(key, None) for out_key, key in mapping.items()}

    # Adjust convention and convert to Angstrom. Skip when absent, since
    # np.multiply(None, ...) raises and defocus may still come from kwargs.
    if output["defocus_1"] is not None:
        output["defocus_1"] = np.multiply(output["defocus_1"], -1e4)
    return output
506
+
507
+
508
def _compute_electron_wavelength(acceleration_voltage: float = None):
    """
    Compute the relativistic electron wavelength in Ångstrom.

    Uses ``lambda = h * c / sqrt(E**2 + 2 * E * m_e * c**2)`` with ``E = e * U``.

    Parameters
    ----------
    acceleration_voltage : float or array_like
        Acceleration voltage ``U`` in Volts, e.g. 300e3. Required; the ``None``
        default is kept only for signature compatibility.

    Returns
    -------
    float or ndarray
        Electron wavelength in Ångstrom (about 0.0197 for 300e3 Volts).

    Raises
    ------
    TypeError
        If ``acceleration_voltage`` is None.
    """
    if acceleration_voltage is None:
        raise TypeError(
            "acceleration_voltage (in Volts) is required to compute the "
            "electron wavelength."
        )

    # Physical constants expressed in SI units
    planck_constant = 6.62606896e-34
    electron_charge = 1.60217646e-19
    electron_mass = 9.10938215e-31
    light_velocity = 299792458

    # np.multiply also accepts lists/arrays of voltages.
    energy = np.multiply(electron_charge, acceleration_voltage)
    denominator = energy**2 + 2 * energy * electron_mass * light_velocity**2
    electron_wavelength = np.divide(
        planck_constant * light_velocity, np.sqrt(denominator)
    )
    # Convert from meters to Ångstrom
    return electron_wavelength * 1e10
526
+
527
+
528
def create_ctf(
    shape: Tuple[int],
    defocus_x: float,
    acceleration_voltage: float = 300,
    defocus_angle: float = 0,
    phase_shift: float = 0,
    defocus_y: float = None,
    sampling_rate: float = 1,
    spherical_aberration: float = 2.7e7,
    amplitude_contrast: float = 0.07,
    correct_defocus_gradient: bool = False,
    cutoff_frequency: float = 0.5,
    angle: float = None,
    tilt_axis: int = 0,
    opening_axis: int = None,
    full_shape: Tuple[int] = None,
) -> NDArray:
    """
    Create CTF representation using the definition from [1]_.

    The returned array is ``-sin(chi)`` with

    ``chi = pi * lambda * k**2 * (defocus - 0.5 * Cs * lambda**2 * k**2)
    + phase_shift + arctan(A / sqrt(1 - A**2))``

    evaluated on the frequency grid of ``shape``. Entries whose frequency is
    greater than or equal to ``cutoff_frequency`` are set to zero.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the returned CTF mask.
    defocus_x : float
        Defocus in x, e.g. 30000 Angstrom. NOTE(review): unlike the electron
        wavelength and the spherical aberration, this value is not divided by
        ``sampling_rate`` below - confirm callers pass it in the intended unit.
    acceleration_voltage : float, optional
        Acceleration voltage in Volts. It is forwarded to
        ``_compute_electron_wavelength``, which works in SI units, so 300 kV
        has to be passed as 300e3. The default of 300 therefore means 300 V.
        NOTE(review): the default was likely intended to be 300e3.
    defocus_angle : float, optional
        Astigmatism angle in radians, defaults to 0.
    phase_shift : float, optional
        Phase shift from phase plate in radians, defaults to 0.
    defocus_y : float, optional
        Defocus in y, in the same unit as ``defocus_x``. If given, an
        astigmatic CTF is computed.
    sampling_rate : float or tuple of floats
        Sampling rate throughout shape, e.g., 4 Angstrom per voxel.
    spherical_aberration : float, optional
        Spherical aberration of microscope in Angstrom, defaults to 2.7e7.
    amplitude_contrast : float, optional
        Amplitude contrast of microscope, defaults to 0.07.
    correct_defocus_gradient : bool, optional
        Add a defocus gradient along ``tilt_axis`` proportional to
        ``sin(angle)``. Requires ``angle``. NOTE(review): the gradient is only
        applied for ``tilt_axis == 0``, or for ``tilt_axis == 1`` when
        ``defocus_y`` is given; other combinations silently skip it.
    cutoff_frequency : float, optional
        Frequencies at or above this value are zeroed, defaults to 0.5
        (presumably cycles per voxel, given the unit sampling of the grid).
    angle : float, optional
        Assume the created CTF is a projection over opening_axis observed at
        angle, in degrees.
    tilt_axis : int, optional
        Axis the specimen was tilted over, defaults to 0 (x-axis).
    opening_axis : int, optional
        Projection axis, only relevant if angle is given.
    full_shape : tuple of ints
        Shape of the entire volume we are observing a projection of. This is
        required to compute aspect ratios for correct scaling. For instance,
        the 2D CTF slice could be (50,50), while the final 3D CTF volume is
        (50,50,25) with the opening_axis being 2, i.e., the z-axis.

    Returns
    -------
    NDArray
        CTF mask.

    Raises
    ------
    ValueError
        If ``correct_defocus_gradient`` or ``defocus_y`` is used with a
        ``shape`` of fewer than two dimensions, or if
        ``correct_defocus_gradient`` is set without ``angle``.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
           Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """
    # Express wavelength and spherical aberration per unit of sampling_rate
    electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
    electron_wavelength /= sampling_rate
    aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
    if correct_defocus_gradient or defocus_y is not None:
        if len(shape) < 2:
            raise ValueError(f"Length of shape needs to be at least 2, got {shape}")

        sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
        grid = fftfreqn(
            shape=shape,
            sampling_rate=sampling,
            return_sparse_grid=True,
        )

    # The gradient has to be added before the astigmatism averaging below,
    # which folds defocus_y into defocus_x.
    if correct_defocus_gradient:
        if angle is None:
            raise ValueError("Cannot correct for defocus gradient without angle.")

        angle_rad = np.radians(angle)
        defocus_gradient = np.multiply(grid[tilt_axis], np.sin(angle_rad))

        if tilt_axis == 0:
            defocus_x = np.add(defocus_x, defocus_gradient)
        elif tilt_axis == 1 and defocus_y is not None:
            defocus_y = np.add(defocus_y, defocus_gradient)

    # 0.5 * ((dx + dy) + cos(2 * (azimuth - astigmatism)) * (dx - dy))
    if defocus_y is not None:
        defocus_sum = np.add(defocus_x, defocus_y)
        defocus_difference = np.subtract(defocus_x, defocus_y)

        angular_grid = np.arctan2(grid[1], grid[0])
        defocus_difference = np.multiply(
            defocus_difference,
            np.cos(2 * (angular_grid - defocus_angle)),
        )
        defocus_x = np.add(defocus_sum, defocus_difference)
        defocus_x *= 0.5

    frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
    if angle is not None and opening_axis is not None and full_shape is not None:
        frequency_grid = frequency_grid_at_angle(
            shape=full_shape,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            angle=angle,
            sampling_rate=1,
        )
    # Computed on |k| before frequency_grid is squared in place below
    frequency_mask = frequency_grid < cutoff_frequency

    # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
    frequency_grid = np.square(frequency_grid, out=frequency_grid)
    chi = defocus_x - 0.5 * aberration * frequency_grid
    chi = np.multiply(chi, np.pi * electron_wavelength, out=chi)
    chi = np.multiply(chi, frequency_grid, out=chi)
    chi += phase_shift
    chi += np.arctan(
        np.divide(
            amplitude_contrast,
            np.sqrt(1 - np.square(amplitude_contrast)),
        )
    )
    chi = np.sin(-chi, out=chi)
    chi = np.multiply(chi, frequency_mask, out=chi)
    return chi