pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
  7. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -29
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +385 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -88
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +6 -4
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py CHANGED
@@ -1,8 +1,9 @@
1
- """ Implements class CTF to create Fourier filter representations.
1
+ """
2
+ Implements class CTF and CTFReconstructed.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  import re
@@ -13,9 +14,9 @@ from dataclasses import dataclass
13
14
  import numpy as np
14
15
 
15
16
  from ..types import NDArray
16
- from ..parser import StarParser
17
17
  from ..backends import backend as be
18
18
  from .compose import ComposableFilter
19
+ from ..parser import StarParser, XMLParser
19
20
  from ._utils import (
20
21
  frequency_grid_at_angle,
21
22
  compute_tilt_shape,
@@ -24,18 +25,13 @@ from ._utils import (
24
25
  shift_fourier,
25
26
  )
26
27
 
27
- __all__ = ["CTF"]
28
+ __all__ = ["CTF", "CTFReconstructed", "create_ctf"]
28
29
 
29
30
 
30
31
  @dataclass
31
32
  class CTF(ComposableFilter):
32
33
  """
33
- Generate a contrast transfer function mask.
34
-
35
- References
36
- ----------
37
- .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
38
- Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
34
+ Generate a per-tilt contrast transfer function mask.
39
35
  """
40
36
 
41
37
  #: The shape of the to-be reconstructed volume.
@@ -44,11 +40,11 @@ class CTF(ComposableFilter):
44
40
  defocus_x: float = None
45
41
  #: The tilt angles.
46
42
  angles: Tuple[float] = None
47
- #: The axis around which the wedge is opened, defaults to None.
48
- opening_axis: int = None
49
- #: The axis along which the tilt is applied, defaults to None.
50
- tilt_axis: int = None
51
- #: Whether to correct defocus gradient, defaults to False.
43
+ #: The microscope projection axis, defaults to 2 (z).
44
+ opening_axis: int = 2
45
+ #: The axis along which the tilt is applied, defaults to 0 (x).
46
+ tilt_axis: int = 0
47
+ #: Whether to correct defocus gradient, defaults to False.
52
48
  correct_defocus_gradient: bool = False
53
49
  #: The sampling rate, defaults to 1 Angstrom / Voxel.
54
50
  sampling_rate: Tuple[float] = 1
@@ -58,9 +54,9 @@ class CTF(ComposableFilter):
58
54
  spherical_aberration: float = 2.7e7
59
55
  #: The amplitude contrast, defaults to 0.07.
60
56
  amplitude_contrast: float = 0.07
61
- #: The phase shift, defaults to 0.
57
+ #: The phase shift in degrees, defaults to 0.
62
58
  phase_shift: float = 0
63
- #: The defocus angle, defaults to 0.
59
+ #: The defocus angle in degrees, defaults to 0.
64
60
  defocus_angle: float = 0
65
61
  #: The defocus value in y direction, defaults to None.
66
62
  defocus_y: float = None
@@ -68,8 +64,6 @@ class CTF(ComposableFilter):
68
64
  flip_phase: bool = True
69
65
  #: Whether to return a format compliant with rfft. Only relevant for single angles.
70
66
  return_real_fourier: bool = False
71
- #: Whether the output should not be used for n+1 dimensional reconstruction
72
- no_reconstruction: bool = True
73
67
 
74
68
  @classmethod
75
69
  def from_file(cls, filename: str) -> "CTF":
@@ -79,139 +73,52 @@ class CTF(ComposableFilter):
79
73
  Parameters
80
74
  ----------
81
75
  filename : str
82
- The path to a file with ctf parameters. Supports the following formats:
83
- - CTFFIND4
76
+ The path to a file with ctf parameters. Supported extensions are:
77
+
78
+ +-------+---------------------------------------------------------+
79
+ | .star | GCTF file |
80
+ +-------+---------------------------------------------------------+
81
+ | .xml | WARP/M XML file |
82
+ +-------+---------------------------------------------------------+
83
+ | .* | CTFFIND4 file |
84
+ +-------+---------------------------------------------------------+
84
85
  """
86
+ func = _from_ctffind
85
87
  if filename.lower().endswith("star"):
86
- data = cls._from_gctf(filename=filename)
87
- else:
88
- data = cls._from_ctffind(filename=filename)
88
+ func = _from_gctf
89
+ elif filename.lower().endswith("xml"):
90
+ func = _from_xml
89
91
 
92
+ data = func(filename=filename)
93
+
94
+ # Pixel size needs to be overwritten by pixel size the ctf is generated for
90
95
  return cls(
91
96
  shape=None,
92
- angles=None,
97
+ angles=data.get("angles", None),
93
98
  defocus_x=data["defocus_1"],
94
99
  sampling_rate=data["pixel_size"],
95
100
  acceleration_voltage=data["acceleration_voltage"],
96
101
  spherical_aberration=data["spherical_aberration"],
97
102
  amplitude_contrast=data["amplitude_contrast"],
98
103
  phase_shift=data["additional_phase_shift"],
99
- defocus_angle=np.degrees(data["azimuth_astigmatism"]),
104
+ defocus_angle=data["azimuth_astigmatism"],
100
105
  defocus_y=data["defocus_2"],
101
106
  )
102
107
 
103
- @staticmethod
104
- def _from_ctffind(filename: str):
105
- parameter_regex = {
106
- "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
107
- "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
108
- "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
109
- "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
110
- }
111
-
112
- with open(filename, mode="r", encoding="utf-8") as infile:
113
- lines = [x.strip() for x in infile.read().split("\n")]
114
- lines = [x for x in lines if len(x)]
115
-
116
- def _screen_params(line, params, output):
117
- for parameter, regex_pattern in parameter_regex.items():
118
- match = re.search(regex_pattern, line)
119
- if match:
120
- output[parameter] = float(match.group(1))
121
-
122
- columns = {
123
- "micrograph_number": 0,
124
- "defocus_1": 1,
125
- "defocus_2": 2,
126
- "azimuth_astigmatism": 3,
127
- "additional_phase_shift": 4,
128
- "cross_correlation": 5,
129
- "spacing": 6,
130
- }
131
- output = {k: [] for k in columns.keys()}
132
- for line in lines:
133
- if line.startswith("#"):
134
- _screen_params(line, params=parameter_regex, output=output)
135
- continue
136
-
137
- values = line.split()
138
- for key, value in columns.items():
139
- output[key].append(float(values[value]))
140
-
141
- for key in columns:
142
- output[key] = np.array(output[key])
143
-
144
- return output
145
-
146
- @staticmethod
147
- def _from_gctf(filename: str):
148
- parser = StarParser(filename)
149
- ctf_data = parser["data_"]
150
-
151
- mapping = {
152
- "defocus_1": ("_rlnDefocusU", float),
153
- "defocus_2": ("_rlnDefocusV", float),
154
- "pixel_size": ("_rlnDetectorPixelSize", float),
155
- "acceleration_voltage": ("_rlnVoltage", float),
156
- "spherical_aberration": ("_rlnSphericalAberration", float),
157
- "amplitude_contrast": ("_rlnAmplitudeContrast", float),
158
- "additional_phase_shift": (None, float),
159
- "azimuth_astigmatism": ("_rlnDefocusAngle", float),
160
- }
161
- output = {}
162
- for out_key, (key, key_dtype) in mapping.items():
163
- if key not in ctf_data and key is not None:
164
- warnings.warn(f"ctf_data is missing key {key}.")
165
-
166
- key_value = ctf_data.get(key, [0])
167
- output[out_key] = [key_dtype(x) for x in key_value]
168
-
169
- longest_key = max(map(len, output.values()))
170
- output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
171
- return output
172
-
173
108
  def __post_init__(self):
174
109
  self.defocus_angle = np.radians(self.defocus_angle)
175
-
176
- def _compute_electron_wavelength(self, acceleration_voltage: int = None):
177
- """Computes the wavelength of an electron in angstrom."""
178
-
179
- if acceleration_voltage is None:
180
- acceleration_voltage = self.acceleration_voltage
181
-
182
- # Physical constants expressed in SI units
183
- planck_constant = 6.62606896e-34
184
- electron_charge = 1.60217646e-19
185
- electron_mass = 9.10938215e-31
186
- light_velocity = 299792458
187
-
188
- energy = electron_charge * acceleration_voltage
189
- denominator = energy**2
190
- denominator += 2 * energy * electron_mass * light_velocity**2
191
- electron_wavelength = np.divide(
192
- planck_constant * light_velocity, np.sqrt(denominator)
193
- )
194
- # Convert to Ångstrom
195
- electron_wavelength *= 1e10
196
- return electron_wavelength
110
+ self.phase_shift = np.radians(self.phase_shift)
197
111
 
198
112
  def __call__(self, **kwargs) -> NDArray:
199
113
  func_args = vars(self).copy()
200
114
  func_args.update(kwargs)
201
115
 
202
- if len(func_args["angles"]) != len(func_args["defocus_x"]):
203
- func_args["angles"] = self.angles
204
- func_args["return_real_fourier"] = False
205
- func_args["tilt_axis"] = None
206
- func_args["opening_axis"] = None
207
-
208
116
  ret = self.weight(**func_args)
209
117
  ret = be.astype(be.to_backend_array(ret), be._float_dtype)
210
118
  return {
211
119
  "data": ret,
212
- "angles": func_args["angles"],
213
- "tilt_axis": func_args["tilt_axis"],
214
- "opening_axis": func_args["opening_axis"],
120
+ "shape": func_args["shape"],
121
+ "return_real_fourier": func_args.get("return_real_fourier"),
215
122
  "is_multiplicative_filter": True,
216
123
  }
217
124
 
@@ -225,8 +132,8 @@ class CTF(ComposableFilter):
225
132
  shape: Tuple[int],
226
133
  defocus_x: Tuple[float],
227
134
  angles: Tuple[float],
228
- opening_axis: int = None,
229
- tilt_axis: int = None,
135
+ opening_axis: int = 2,
136
+ tilt_axis: int = 0,
230
137
  amplitude_contrast: float = 0.07,
231
138
  phase_shift: Tuple[float] = 0,
232
139
  defocus_angle: Tuple[float] = 0,
@@ -237,7 +144,6 @@ class CTF(ComposableFilter):
237
144
  spherical_aberration: float = 2.7e3,
238
145
  flip_phase: bool = True,
239
146
  return_real_fourier: bool = False,
240
- no_reconstruction: bool = True,
241
147
  cutoff_frequency: float = 0.5,
242
148
  **kwargs: Dict,
243
149
  ) -> NDArray:
@@ -253,15 +159,15 @@ class CTF(ComposableFilter):
253
159
  angles : tuple of float
254
160
  The tilt angles.
255
161
  opening_axis : int, optional
256
- The axis around which the wedge is opened, defaults to None.
162
+ The axis around which the wedge is opened, defaults to 2.
257
163
  tilt_axis : int, optional
258
- The axis along which the tilt is applied, defaults to None.
164
+ The axis along which the tilt is applied, defaults to 0.
259
165
  amplitude_contrast : float, optional
260
166
  The amplitude contrast, defaults to 0.07.
261
167
  phase_shift : tuple of float, optional
262
- The phase shift, defaults to 0.
168
+ The phase shift in radians, defaults to 0.
263
169
  defocus_angle : tuple of float, optional
264
- The defocus angle, defaults to 0.
170
+ The defocus angle in radians, defaults to 0.
265
171
  defocus_y : tuple of float, optional
266
172
  The defocus value in y direction, defaults to None.
267
173
  correct_defocus_gradient : bool, optional
@@ -291,93 +197,38 @@ class CTF(ComposableFilter):
291
197
  amplitude_contrast = self._pad_to_length(amplitude_contrast, angles.size)
292
198
 
293
199
  sampling_rate = np.max(sampling_rate)
294
- tilt_shape = compute_tilt_shape(
200
+ ctf_shape = compute_tilt_shape(
295
201
  shape=shape, opening_axis=opening_axis, reduce_dim=True
296
202
  )
297
- stack = np.zeros((len(angles), *tilt_shape))
203
+ stack = np.zeros((len(angles), *ctf_shape))
298
204
 
299
- correct_defocus_gradient &= len(shape) == 3
300
- correct_defocus_gradient &= tilt_axis is not None
301
- correct_defocus_gradient &= opening_axis is not None
302
-
303
- spherical_aberration /= sampling_rate
304
- electron_wavelength = self._compute_electron_wavelength() / sampling_rate
305
- electron_aberration = spherical_aberration * electron_wavelength**2
205
+ # Shift tilt axis forward
206
+ corrected_tilt_axis = tilt_axis
207
+ if opening_axis and tilt_axis is not None:
208
+ if opening_axis < tilt_axis:
209
+ corrected_tilt_axis -= 1
306
210
 
307
211
  for index, angle in enumerate(angles):
308
212
  defocus_x, defocus_y = defoci_x[index], defoci_y[index]
309
213
 
310
- defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
311
- defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
312
-
313
- if correct_defocus_gradient or defocus_y is not None:
314
- grid_shape = shape
315
- sampling = be.divide(sampling_rate, be.to_backend_array(shape))
316
- sampling = tuple(float(x) for x in sampling)
317
- if not no_reconstruction:
318
- grid_shape = tilt_shape
319
- sampling = tuple(
320
- x for i, x in enumerate(sampling) if i != opening_axis
321
- )
322
-
323
- grid = fftfreqn(
324
- shape=grid_shape,
325
- sampling_rate=sampling,
326
- return_sparse_grid=True,
327
- )
328
-
329
- # This should be done after defocus_x computation
330
- if correct_defocus_gradient:
331
- angle_rad = np.radians(angle)
332
- defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
333
- remaining_axis = tuple(
334
- i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
335
- )[0]
336
-
337
- if tilt_axis > remaining_axis:
338
- defocus_x = np.add(defocus_x, defocus_gradient)
339
- elif tilt_axis < remaining_axis and defocus_y is not None:
340
- defocus_y = np.add(defocus_y, defocus_gradient.T)
341
-
342
- # 0.5 * (dx + dy) + cos(2 * (azimuth - astigmatism) * (dx - dy))
343
- if defocus_y is not None:
344
- defocus_sum = np.add(defocus_x, defocus_y)
345
- defocus_difference = np.subtract(defocus_x, defocus_y)
346
-
347
- angular_grid = np.arctan2(grid[1], grid[0])
348
- defocus_difference = np.multiply(
349
- defocus_difference,
350
- np.cos(2 * (angular_grid - defocus_angle[index])),
351
- )
352
- defocus_x = np.add(defocus_sum, defocus_difference)
353
- defocus_x *= 0.5
354
-
355
- frequency_grid = frequency_grid_at_angle(
356
- shape=shape,
357
- opening_axis=opening_axis,
358
- tilt_axis=tilt_axis,
214
+ correction = correct_defocus_gradient and angle is not None
215
+ chi = create_ctf(
359
216
  angle=angle,
360
- sampling_rate=1,
361
- )
362
- frequency_mask = frequency_grid < cutoff_frequency
363
-
364
- # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
365
- np.square(frequency_grid, out=frequency_grid)
366
- chi = defocus_x - 0.5 * electron_aberration[index] * frequency_grid
367
- np.multiply(chi, np.pi * electron_wavelength, out=chi)
368
- np.multiply(chi, frequency_grid, out=chi)
369
- chi += phase_shift[index]
370
- chi += np.arctan(
371
- np.divide(
372
- amplitude_contrast[index],
373
- np.sqrt(1 - np.square(amplitude_contrast[index])),
374
- )
217
+ shape=ctf_shape,
218
+ defocus_x=defocus_x,
219
+ defocus_y=defocus_y,
220
+ sampling_rate=sampling_rate,
221
+ acceleration_voltage=self.acceleration_voltage,
222
+ correct_defocus_gradient=correction,
223
+ spherical_aberration=spherical_aberration[index],
224
+ cutoff_frequency=cutoff_frequency,
225
+ phase_shift=phase_shift[index],
226
+ defocus_angle=defocus_angle[index],
227
+ amplitude_contrast=amplitude_contrast[index],
228
+ tilt_axis=corrected_tilt_axis,
229
+ opening_axis=opening_axis,
230
+ full_shape=shape,
375
231
  )
376
- np.sin(-chi, out=chi)
377
- np.multiply(chi, frequency_mask, out=chi)
378
-
379
- if no_reconstruction:
380
- chi = shift_fourier(data=chi, shape_is_real_fourier=False)
381
232
 
382
233
  stack[index] = chi
383
234
 
@@ -387,7 +238,368 @@ class CTF(ComposableFilter):
387
238
  np.abs(stack, out=stack)
388
239
 
389
240
  stack = be.to_backend_array(np.squeeze(stack))
390
- if no_reconstruction and return_real_fourier:
241
+ return stack
242
+
243
+
244
class CTFReconstructed(CTF):
    """
    Create a simple contrast transfer function mask without the ability to specify
    per-tilt parameters like in :py:class:`CTF`.

    A single CTF is evaluated over the full ``shape`` from scalar parameters,
    instead of one CTF per tilt angle.
    """

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e7,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute a single CTF mask for ``shape``.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : float
            The defocus value in x direction, in units of sampling_rate.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : float, optional
            The phase shift in radians, defaults to 0.
        defocus_angle : float, optional
            The defocus angle in radians, defaults to 0.
        defocus_y : float, optional
            The defocus value in y direction, defaults to None.
        sampling_rate : float or tuple of float, optional
            The sampling rate, defaults to 1. Only the largest value is used.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient in units of sampling_rate,
            defaults to 2.7e7 (2.7 mm expressed in Angstrom).
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped, i.e., whether its
            absolute value is returned. Defaults to True.
        return_real_fourier : bool, optional
            Whether to crop the result to the rfft-compliant half via
            ``crop_real_fourier``, defaults to False.
        cutoff_frequency : float, optional
            Frequency cutoff forwarded to ``create_ctf``, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments. They are accepted and ignored, which
            allows passing the complete attribute set of the filter, e.g.,
            ``angles`` or ``opening_axis``.

        Returns
        -------
        NDArray
            The CTF mask as a backend array.
        """
        stack = create_ctf(
            shape=shape,
            defocus_x=defocus_x,
            defocus_y=defocus_y,
            sampling_rate=np.max(sampling_rate),
            # Use the argument rather than self.acceleration_voltage so keyword
            # overrides passed through __call__ take effect, as they do for
            # every other parameter.
            acceleration_voltage=acceleration_voltage,
            correct_defocus_gradient=False,
            spherical_aberration=spherical_aberration,
            cutoff_frequency=cutoff_frequency,
            phase_shift=phase_shift,
            defocus_angle=defocus_angle,
            amplitude_contrast=amplitude_contrast,
        )
        stack = shift_fourier(data=stack, shape_is_real_fourier=False)

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = be.to_backend_array(np.squeeze(stack))
        if return_real_fourier:
            stack = crop_real_fourier(stack)

        return stack
326
+
327
+
328
def _from_xml(filename: str):
    """
    Parse CTF parameters from a Warp/M XML file.

    Global values are read from the ``Param`` entries of the ``CTF`` node. If
    the file contains a ``GridCTF`` node, the per-tilt values in ``GridCTF``,
    ``GridCTFPhase`` and ``GridCTFDefocusAngle`` replace the global defocus,
    phase shift and defocus angle, respectively.

    Parameters
    ----------
    filename : str
        Path to the XML file.

    Returns
    -------
    dict
        Mapping with keys ``angles``, ``defocus_1``, ``defocus_2``,
        ``azimuth_astigmatism``, ``additional_phase_shift``,
        ``acceleration_voltage``, ``spherical_aberration``,
        ``amplitude_contrast`` and ``pixel_size``. Sequences are returned as
        float arrays and single values as floats. ``spherical_aberration`` is
        scaled by 1e7 and defocus values by 1e4 (presumably mm and micrometer
        to Angstrom). All other values are passed through unscaled.
        ``defocus_2`` equals ``defocus_1``, i.e., no astigmatism magnitude is read.

    Raises
    ------
    ValueError
        If any of the required parameters is absent from the file.
    """

    def _as_list(nodes):
        # XML-to-dict conversions commonly collapse a single child into a dict
        # rather than a one-element list; normalize so iteration is uniform.
        return [nodes] if isinstance(nodes, dict) else nodes

    def _to_numeric(value):
        # Strings define __len__ too, so test for them explicitly. Otherwise a
        # textual scalar such as "300" would become a 0-d string array.
        if isinstance(value, str) or not hasattr(value, "__len__"):
            return float(value)
        return np.array(value, dtype=float)

    data = XMLParser(filename)

    params = {
        "PhaseShift": None,
        "Amplitude": None,
        "Defocus": None,
        "Voltage": None,
        "Cs": None,
        "DefocusAngle": None,
        "PixelSize": None,
        # Report a missing Angles entry through the ValueError below instead
        # of failing with a bare KeyError here.
        "Angles": data["Angles"] if "Angles" in data else None,
    }

    for option in _as_list(data["CTF"]["Param"]):
        attributes = option["@attributes"]
        name = attributes["Name"]
        if name in params:
            params[name] = attributes["Value"]

    if "GridCTF" in data:
        # Per-tilt grid values take precedence over the global parameters.
        grid_sources = {
            "Defocus": "GridCTF",
            "PhaseShift": "GridCTFPhase",
            "DefocusAngle": "GridCTFDefocusAngle",
        }
        for param_name, grid_name in grid_sources.items():
            nodes = _as_list(data[grid_name]["Node"])
            params[param_name] = [node["@attributes"]["Value"] for node in nodes]

    missing = [k for k, v in params.items() if v is None]
    if missing:
        raise ValueError(f"Could not find {missing} in {filename}.")

    params = {k: _to_numeric(v) for k, v in params.items()}

    # Convert units to sampling rate (we assume it is Angstrom)
    params["Cs"] = float(params["Cs"] * 1e7)
    params["Defocus"] = params["Defocus"] * 1e4

    mapping = {
        "angles": "Angles",
        "defocus_1": "Defocus",
        "defocus_2": "Defocus",
        "azimuth_astigmatism": "DefocusAngle",
        "additional_phase_shift": "PhaseShift",
        "acceleration_voltage": "Voltage",
        "spherical_aberration": "Cs",
        "amplitude_contrast": "Amplitude",
        "pixel_size": "PixelSize",
    }
    return {k: params[v] for k, v in mapping.items()}
385
+
386
+
387
def _from_ctffind(filename: str):
    """
    Parse a CTFFIND4 diagnostic text file.

    Lines starting with ``#`` are scanned for pixel size (Angstroms),
    acceleration voltage (keV), spherical aberration (mm) and amplitude
    contrast. These values are stored as floats exactly as written, without
    unit conversion, and a later matching header line overwrites an earlier one.
    Every other non-empty line is split on whitespace, and its first seven
    columns are collected as floats.

    Parameters
    ----------
    filename : str
        Path to the CTFFIND4 output file.

    Returns
    -------
    dict
        Per-micrograph column arrays (``micrograph_number``, ``defocus_1``,
        ``defocus_2``, ``azimuth_astigmatism``, ``additional_phase_shift``,
        ``cross_correlation``, ``spacing``) plus any scalar header parameters
        that were found. ``additional_phase_shift`` is passed through
        ``np.degrees``, i.e., the file value is treated as radians.
    """
    header_patterns = {
        "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
        "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
        "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
        "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
    }
    column_index = {
        "micrograph_number": 0,
        "defocus_1": 1,
        "defocus_2": 2,
        "azimuth_astigmatism": 3,
        "additional_phase_shift": 4,
        "cross_correlation": 5,
        "spacing": 6,
    }

    with open(filename, mode="r", encoding="utf-8") as infile:
        content = infile.read()

    parsed = {name: [] for name in column_index}
    for raw_line in content.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith("#"):
            # Header/comment line: harvest any acquisition parameters it holds.
            for name, pattern in header_patterns.items():
                hit = re.search(pattern, line)
                if hit:
                    parsed[name] = float(hit.group(1))
            continue

        fields = line.split()
        for name, column in column_index.items():
            parsed[name].append(float(fields[column]))

    for name in column_index:
        parsed[name] = np.array(parsed[name])

    parsed["additional_phase_shift"] = np.degrees(parsed["additional_phase_shift"])
    return parsed
429
+
430
+
431
def _from_gctf(filename: str):
    """
    Parse CTF parameters from the ``data_`` block of a GCTF STAR file.

    Values are cast to float without unit conversion. ``additional_phase_shift``
    has no source column and therefore is always zero. A warning is emitted for
    every expected column that is absent, in which case ``[0]`` is used. Entries
    holding a single value are repeated to the length of the longest entry so
    that all lists align.

    Parameters
    ----------
    filename : str
        Path to the STAR file.

    Returns
    -------
    dict
        Mapping from parameter name to a list of floats.
    """
    ctf_data = StarParser(filename)["data_"]

    column_map = {
        "defocus_1": ("_rlnDefocusU", float),
        "defocus_2": ("_rlnDefocusV", float),
        "pixel_size": ("_rlnDetectorPixelSize", float),
        "acceleration_voltage": ("_rlnVoltage", float),
        "spherical_aberration": ("_rlnSphericalAberration", float),
        "amplitude_contrast": ("_rlnAmplitudeContrast", float),
        "additional_phase_shift": (None, float),
        "azimuth_astigmatism": ("_rlnDefocusAngle", float),
    }

    parsed = {}
    for target, (column, cast) in column_map.items():
        if column is not None and column not in ctf_data:
            warnings.warn(f"ctf_data is missing key {column}.")
        parsed[target] = [cast(entry) for entry in ctf_data.get(column, [0])]

    # Broadcast scalar entries so every parameter has one value per row.
    n_entries = max(len(values) for values in parsed.values())
    return {
        target: values * n_entries if len(values) == 1 else values
        for target, values in parsed.items()
    }
456
+
457
+
458
def _compute_electron_wavelength(acceleration_voltage: float = None):
    """
    Compute the relativistic wavelength of an electron in Angstrom.

    Uses lambda = h * c / sqrt(E**2 + 2 * E * m * c**2) with E = e * U.

    Parameters
    ----------
    acceleration_voltage : float
        Acceleration voltage U in Volts, e.g., 300e3 for a 300 kV microscope.
        NumPy arrays are evaluated element-wise.

    Returns
    -------
    float or NDArray
        Electron wavelength in Angstrom.

    Raises
    ------
    TypeError
        If ``acceleration_voltage`` is None.
    """
    if acceleration_voltage is None:
        # The None default is kept for signature compatibility, but a voltage
        # is always required; fail with a clear message instead of an opaque
        # arithmetic error.
        raise TypeError(
            "acceleration_voltage is required to compute the electron wavelength."
        )

    # Physical constants expressed in SI units
    planck_constant = 6.62606896e-34  # J s
    electron_charge = 1.60217646e-19  # C
    electron_mass = 9.10938215e-31  # kg
    light_velocity = 299792458  # m / s

    # Kinetic energy in Joule, then the relativistic correction term.
    energy = electron_charge * acceleration_voltage
    denominator = energy**2
    denominator += 2 * energy * electron_mass * light_velocity**2
    electron_wavelength = np.divide(
        planck_constant * light_velocity, np.sqrt(denominator)
    )
    # Convert from meter to Ångstrom
    electron_wavelength *= 1e10
    return electron_wavelength
476
+
477
+
478
def create_ctf(
    shape: Tuple[int],
    defocus_x: float,
    acceleration_voltage: float = 300e3,
    defocus_angle: float = 0,
    phase_shift: float = 0,
    defocus_y: float = None,
    sampling_rate: float = 1,
    spherical_aberration: float = 2.7e7,
    amplitude_contrast: float = 0.07,
    correct_defocus_gradient: bool = False,
    cutoff_frequency: float = 0.5,
    angle: float = None,
    tilt_axis: int = 0,
    opening_axis: int = None,
    full_shape: Tuple[int] = None,
) -> NDArray:
    """
    Create CTF representation using the definition from [1]_.

    The result is ``sin(-chi)`` multiplied by a mask that keeps frequencies
    below ``cutoff_frequency``, where
    ``chi = pi * lambda * k**2 * (defocus - 0.5 * Cs * lambda**2 * k**2)
    + phase_shift + arctan(A / sqrt(1 - A**2))``.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the returned CTF mask.
    defocus_x : float
        Defocus in x in units of sampling rate, e.g. 30000 Angstrom.
    acceleration_voltage : float, optional
        Acceleration voltage in Volts, defaults to 300e3 (300 kV).
    defocus_angle : float, optional
        Astigmatism angle in radians, defaults to 0.
    phase_shift : float, optional
        Phase shift from phase plate in radians, defaults to 0.
    defocus_y : float, optional
        Defocus in y in units of sampling rate. If given, the effective defocus
        at frequency angle theta is
        ``0.5 * (dx + dy + (dx - dy) * cos(2 * (theta - defocus_angle)))``.
    sampling_rate : float or tuple of floats
        Sampling rate throughout shape, e.g., 4 Angstrom per voxel. A sequence
        is reduced to its maximum.
    spherical_aberration : float, optional
        Spherical aberration of microscope in units of sampling rate, defaults
        to 2.7e7 (2.7 mm expressed in Angstrom).
    amplitude_contrast : float, optional
        Amplitude contrast of microscope, defaults to 0.07.
    correct_defocus_gradient : bool, optional
        Whether to add a defocus gradient along tilt_axis scaled by
        ``sin(angle)``. Requires angle. It is only applied for tilt_axis 0, or
        for tilt_axis 1 when defocus_y is given. Defaults to False.
    cutoff_frequency : float, optional
        Frequencies at or above this value are set to zero, defaults to 0.5.
    angle : float, optional
        Tilt angle in degrees. Assume the created CTF is a projection over
        opening_axis observed at angle.
    tilt_axis : int, optional
        Axes the specimen was tilted over, defaults to 0 (x-axis).
    opening_axis : int, optional
        Projection axis, only relevant if angle is given.
    full_shape : tuple of ints
        Shape of the entire volume we are observing a projection of. This is required
        to compute aspect ratios for correct scaling. For instance, the 2D CTF slice
        could be (50,50), while the final 3D CTF volume is (50,50,25) with the
        opening_axis being 2, i.e., the z-axis.

    Returns
    -------
    NDArray
        CTF mask.

    Raises
    ------
    ValueError
        If defocus_y or correct_defocus_gradient is used with fewer than two
        dimensions, or if correct_defocus_gradient is requested without angle.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
       Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """
    # Anisotropic sampling is reduced to its largest value, which matches what
    # the callers in this module pass via np.max(sampling_rate). A scalar is
    # required for the unit conversions below.
    sampling_rate = float(np.max(sampling_rate))

    # Express wavelength and spherical aberration in units of voxels.
    electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
    electron_wavelength /= sampling_rate
    aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2

    if correct_defocus_gradient or defocus_y is not None:
        if len(shape) < 2:
            raise ValueError(f"Length of shape needs to be at least 2, got {shape}")

        sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
        grid = fftfreqn(
            shape=shape,
            sampling_rate=sampling,
            return_sparse_grid=True,
        )

    # The gradient has to be added before the astigmatism averaging below.
    if correct_defocus_gradient:
        if angle is None:
            raise ValueError("Cannot correct for defocus gradient without angle.")

        angle_rad = np.radians(angle)
        defocus_gradient = np.multiply(grid[tilt_axis], np.sin(angle_rad))

        if tilt_axis == 0:
            defocus_x = np.add(defocus_x, defocus_gradient)
        elif tilt_axis == 1 and defocus_y is not None:
            defocus_y = np.add(defocus_y, defocus_gradient)

    # 0.5 * (dx + dy + (dx - dy) * cos(2 * (theta - defocus_angle)))
    if defocus_y is not None:
        defocus_sum = np.add(defocus_x, defocus_y)
        defocus_difference = np.subtract(defocus_x, defocus_y)

        angular_grid = np.arctan2(grid[1], grid[0])
        defocus_difference = np.multiply(
            defocus_difference,
            np.cos(2 * (angular_grid - defocus_angle)),
        )
        defocus_x = np.add(defocus_sum, defocus_difference)
        defocus_x *= 0.5

    # Unit sampling: frequencies are in cycles per voxel.
    frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
    if angle is not None and opening_axis is not None and full_shape is not None:
        frequency_grid = frequency_grid_at_angle(
            shape=full_shape,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            angle=angle,
            sampling_rate=1,
        )
    frequency_mask = frequency_grid < cutoff_frequency

    # chi = pi * lambda * k^2 * (defocus - 0.5 * Cs * lambda^2 * k^2)
    #       + phase_shift + arctan(A / sqrt(1 - A^2))
    np.square(frequency_grid, out=frequency_grid)
    chi = defocus_x - 0.5 * aberration * frequency_grid
    np.multiply(chi, np.pi * electron_wavelength, out=chi)
    np.multiply(chi, frequency_grid, out=chi)
    chi += phase_shift
    chi += np.arctan(
        np.divide(
            amplitude_contrast,
            np.sqrt(1 - np.square(amplitude_contrast)),
        )
    )
    np.sin(-chi, out=chi)
    np.multiply(chi, frequency_mask, out=chi)
    return chi