pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73):
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py CHANGED
@@ -7,7 +7,6 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
9
  import re
10
- import warnings
11
10
  from typing import Tuple, Dict
12
11
  from dataclasses import dataclass
13
12
 
@@ -16,13 +15,14 @@ import numpy as np
16
15
  from ..types import NDArray
17
16
  from ..backends import backend as be
18
17
  from .compose import ComposableFilter
19
- from ..parser import StarParser, XMLParser
18
+ from ..parser import StarParser, XMLParser, MDOCParser
20
19
  from ._utils import (
21
20
  frequency_grid_at_angle,
22
21
  compute_tilt_shape,
23
22
  crop_real_fourier,
24
23
  fftfreqn,
25
24
  shift_fourier,
25
+ pad_to_length,
26
26
  )
27
27
 
28
28
  __all__ = ["CTF", "CTFReconstructed", "create_ctf"]
@@ -34,39 +34,39 @@ class CTF(ComposableFilter):
34
34
  Generate a per-tilt contrast transfer function mask.
35
35
  """
36
36
 
37
- #: The shape of the to-be reconstructed volume.
37
+ #: The shape of the to-be created mask.
38
38
  shape: Tuple[int] = None
39
- #: The defocus value in x direction.
40
- defocus_x: float = None
39
+ #: The defocus in x direction (in units of sampling rate).
40
+ defocus_x: Tuple[float] = None
41
41
  #: The tilt angles.
42
42
  angles: Tuple[float] = None
43
- #: The microscope projection axis, defaults to None.
43
+ #: The microscope projection axis, defaults to 2 (z).
44
44
  opening_axis: int = 2
45
- #: The axis along which the tilt is applied, defaults to 2 (z).
45
+ #: The axis along which the tilt is applied, defaults to 0 (x).
46
46
  tilt_axis: int = 0
47
- #: Whether to correct defocus gradient, defaults to 0 (x).
47
+ #: Whether to correct defocus gradient, defaults False.
48
48
  correct_defocus_gradient: bool = False
49
- #: The sampling rate, defaults to 1 Angstrom / Voxel.
49
+ #: The sampling rate, defaults to 1 Ångstrom / voxel.
50
50
  sampling_rate: Tuple[float] = 1
51
51
  #: The acceleration voltage in Volts, defaults to 300e3.
52
- acceleration_voltage: float = 300e3
53
- #: The spherical aberration coefficient, defaults to 2.7e7.
54
- spherical_aberration: float = 2.7e7
52
+ acceleration_voltage: Tuple[float] = 300e3
53
+ #: The spherical aberration, defaults to 2.7e7 (in units of sampling rate).
54
+ spherical_aberration: Tuple[float] = 2.7e7
55
55
  #: The amplitude contrast, defaults to 0.07.
56
- amplitude_contrast: float = 0.07
56
+ amplitude_contrast: Tuple[float] = 0.07
57
57
  #: The phase shift in degrees, defaults to 0.
58
- phase_shift: float = 0
58
+ phase_shift: Tuple[float] = 0
59
59
  #: The defocus angle in degrees, defaults to 0.
60
- defocus_angle: float = 0
61
- #: The defocus value in y direction, defaults to None.
62
- defocus_y: float = None
63
- #: Whether the returned CTF should be phase-flipped.
60
+ defocus_angle: Tuple[float] = 0
61
+ #: The defocus value in y direction, defaults to None (in units of sampling rate).
62
+ defocus_y: Tuple[float] = None
63
+ #: Whether the returned CTF should be phase-flipped, defaults to True.
64
64
  flip_phase: bool = True
65
- #: Whether to return a format compliant with rfft. Only relevant for single angles.
65
+ #: Whether to return a ctf mask for rfft (for :py:class:`CTFReconstructed`).
66
66
  return_real_fourier: bool = False
67
67
 
68
68
  @classmethod
69
- def from_file(cls, filename: str) -> "CTF":
69
+ def from_file(cls, filename: str, **kwargs) -> "CTF":
70
70
  """
71
71
  Initialize :py:class:`CTF` from file.
72
72
 
@@ -80,36 +80,50 @@ class CTF(ComposableFilter):
80
80
  +-------+---------------------------------------------------------+
81
81
  | .xml | WARP/M XML file |
82
82
  +-------+---------------------------------------------------------+
83
+ | .mdoc | SerialEM file |
84
+ +-------+---------------------------------------------------------+
83
85
  | .* | CTFFIND4 file |
84
86
  +-------+---------------------------------------------------------+
87
+ **kwargs : optional
88
+ Overwrite fields that cannot be extracted from input file.
85
89
  """
86
90
  func = _from_ctffind
87
91
  if filename.lower().endswith("star"):
88
- func = _from_gctf
92
+ func = _from_star
89
93
  elif filename.lower().endswith("xml"):
90
94
  func = _from_xml
95
+ elif filename.lower().endswith("mdoc"):
96
+ func = _from_mdoc
91
97
 
92
98
  data = func(filename=filename)
93
99
 
94
100
  # Pixel size needs to be overwritten by pixel size the ctf is generated for
95
- return cls(
96
- shape=None,
97
- angles=data.get("angles", None),
98
- defocus_x=data["defocus_1"],
99
- sampling_rate=data["pixel_size"],
100
- acceleration_voltage=data["acceleration_voltage"],
101
- spherical_aberration=data["spherical_aberration"],
102
- amplitude_contrast=data["amplitude_contrast"],
103
- phase_shift=data["additional_phase_shift"],
104
- defocus_angle=data["azimuth_astigmatism"],
105
- defocus_y=data["defocus_2"],
106
- )
101
+ init_kwargs = {
102
+ "shape": None,
103
+ "angles": data.get("angles", None),
104
+ "defocus_x": data["defocus_1"],
105
+ "sampling_rate": data["pixel_size"],
106
+ "acceleration_voltage": np.multiply(data["acceleration_voltage"], 1e3),
107
+ "spherical_aberration": data.get("spherical_aberration"),
108
+ "amplitude_contrast": data.get("amplitude_contrast"),
109
+ "phase_shift": data.get("additional_phase_shift"),
110
+ "defocus_angle": data.get("azimuth_astigmatism"),
111
+ "defocus_y": data["defocus_2"],
112
+ }
113
+ for k, v in kwargs.items():
114
+ if k in init_kwargs and init_kwargs.get(k) is None:
115
+ init_kwargs[k] = v
116
+ init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
117
+ return cls(**init_kwargs)
107
118
 
108
119
  def __post_init__(self):
109
120
  self.defocus_angle = np.radians(self.defocus_angle)
110
121
  self.phase_shift = np.radians(self.phase_shift)
111
122
 
112
123
  def __call__(self, **kwargs) -> NDArray:
124
+ """
125
+ Returns a CTF stack of chosen parameters with DC component in the center.
126
+ """
113
127
  func_args = vars(self).copy()
114
128
  func_args.update(kwargs)
115
129
 
@@ -122,11 +136,6 @@ class CTF(ComposableFilter):
122
136
  "is_multiplicative_filter": True,
123
137
  }
124
138
 
125
- @staticmethod
126
- def _pad_to_length(arr, length: int):
127
- ret = np.atleast_1d(arr)
128
- return np.repeat(ret, length // ret.size)
129
-
130
139
  def weight(
131
140
  self,
132
141
  shape: Tuple[int],
@@ -134,7 +143,7 @@ class CTF(ComposableFilter):
134
143
  angles: Tuple[float],
135
144
  opening_axis: int = 2,
136
145
  tilt_axis: int = 0,
137
- amplitude_contrast: float = 0.07,
146
+ amplitude_contrast: Tuple[float] = 0.07,
138
147
  phase_shift: Tuple[float] = 0,
139
148
  defocus_angle: Tuple[float] = 0,
140
149
  defocus_y: Tuple[float] = None,
@@ -155,21 +164,21 @@ class CTF(ComposableFilter):
155
164
  shape : tuple of int
156
165
  The shape of the CTF.
157
166
  defocus_x : tuple of float
158
- The defocus value in x direction.
167
+ The defocus in x direction (in units of sampling rate).
159
168
  angles : tuple of float
160
169
  The tilt angles.
161
170
  opening_axis : int, optional
162
171
  The axis around which the wedge is opened, defaults to 2.
163
172
  tilt_axis : int, optional
164
173
  The axis along which the tilt is applied, defaults to 0.
165
- amplitude_contrast : float, optional
174
+ amplitude_contrast : tuple of float, optional
166
175
  The amplitude contrast, defaults to 0.07.
167
176
  phase_shift : tuple of float, optional
168
177
  The phase shift in radians, defaults to 0.
169
178
  defocus_angle : tuple of float, optional
170
179
  The defocus angle in radians, defaults to 0.
171
180
  defocus_y : tuple of float, optional
172
- The defocus value in y direction, defaults to None.
181
+ The defocus in x direction (in units of sampling rate).
173
182
  correct_defocus_gradient : bool, optional
174
183
  Whether to correct defocus gradient, defaults to False.
175
184
  sampling_rate : tuple of float, optional
@@ -179,7 +188,7 @@ class CTF(ComposableFilter):
179
188
  spherical_aberration : float, optional
180
189
  The spherical aberration coefficient, defaults to 2.7e3.
181
190
  flip_phase : bool, optional
182
- Whether the returned CTF should be phase-flipped.
191
+ Whether the returned CTF should be phase-flipped, defaults to True.
183
192
  **kwargs : Dict
184
193
  Additional keyword arguments.
185
194
 
@@ -189,12 +198,13 @@ class CTF(ComposableFilter):
189
198
  A stack containing the CTF weight.
190
199
  """
191
200
  angles = np.atleast_1d(angles)
192
- defoci_x = self._pad_to_length(defocus_x, angles.size)
193
- defoci_y = self._pad_to_length(defocus_y, angles.size)
194
- phase_shift = self._pad_to_length(phase_shift, angles.size)
195
- defocus_angle = self._pad_to_length(defocus_angle, angles.size)
196
- spherical_aberration = self._pad_to_length(spherical_aberration, angles.size)
197
- amplitude_contrast = self._pad_to_length(amplitude_contrast, angles.size)
201
+ defoci_x = pad_to_length(defocus_x, angles.size)
202
+ defoci_y = pad_to_length(defocus_y, angles.size)
203
+ phase_shift = pad_to_length(phase_shift, angles.size)
204
+ defocus_angle = pad_to_length(defocus_angle, angles.size)
205
+ spherical_aberration = pad_to_length(spherical_aberration, angles.size)
206
+ amplitude_contrast = pad_to_length(amplitude_contrast, angles.size)
207
+ acceleration_voltage = pad_to_length(acceleration_voltage, angles.size)
198
208
 
199
209
  sampling_rate = np.max(sampling_rate)
200
210
  ctf_shape = compute_tilt_shape(
@@ -209,16 +219,14 @@ class CTF(ComposableFilter):
209
219
  corrected_tilt_axis -= 1
210
220
 
211
221
  for index, angle in enumerate(angles):
212
- defocus_x, defocus_y = defoci_x[index], defoci_y[index]
213
-
214
222
  correction = correct_defocus_gradient and angle is not None
215
223
  chi = create_ctf(
216
224
  angle=angle,
217
225
  shape=ctf_shape,
218
- defocus_x=defocus_x,
219
- defocus_y=defocus_y,
226
+ defocus_x=defoci_x[index],
227
+ defocus_y=defoci_y[index],
220
228
  sampling_rate=sampling_rate,
221
- acceleration_voltage=self.acceleration_voltage,
229
+ acceleration_voltage=acceleration_voltage[index],
222
230
  correct_defocus_gradient=correction,
223
231
  spherical_aberration=spherical_aberration[index],
224
232
  cutoff_frequency=cutoff_frequency,
@@ -233,12 +241,10 @@ class CTF(ComposableFilter):
233
241
  stack[index] = chi
234
242
 
235
243
  # Avoid contrast inversion
236
- np.negative(stack, out=stack)
244
+ stack = np.negative(stack, out=stack)
237
245
  if flip_phase:
238
- np.abs(stack, out=stack)
239
-
240
- stack = be.to_backend_array(np.squeeze(stack))
241
- return stack
246
+ stack = np.abs(stack, out=stack)
247
+ return be.to_backend_array(np.squeeze(stack))
242
248
 
243
249
 
244
250
  class CTFReconstructed(CTF):
@@ -271,7 +277,7 @@ class CTFReconstructed(CTF):
271
277
  shape : tuple of int
272
278
  The shape of the CTF.
273
279
  defocus_x : tuple of float
274
- The defocus value in x direction.
280
+ The defocus in x direction in units of sampling rate.
275
281
  opening_axis : int, optional
276
282
  The axis around which the wedge is opened, defaults to 2.
277
283
  amplitude_contrast : float, optional
@@ -281,7 +287,7 @@ class CTFReconstructed(CTF):
281
287
  defocus_angle : tuple of float, optional
282
288
  The defocus angle in radians, defaults to 0.
283
289
  defocus_y : tuple of float, optional
284
- The defocus value in y direction, defaults to None.
290
+ The defocus in y direction in units of sampling rate.
285
291
  sampling_rate : tuple of float, optional
286
292
  The sampling rate, defaults to 1.
287
293
  acceleration_voltage : float, optional
@@ -311,21 +317,18 @@ class CTFReconstructed(CTF):
311
317
  defocus_angle=defocus_angle,
312
318
  amplitude_contrast=amplitude_contrast,
313
319
  )
314
- stack = shift_fourier(data=stack, shape_is_real_fourier=False)
315
-
316
320
  # Avoid contrast inversion
317
321
  np.negative(stack, out=stack)
318
322
  if flip_phase:
319
323
  np.abs(stack, out=stack)
320
324
 
321
- stack = be.to_backend_array(np.squeeze(stack))
325
+ stack = shift_fourier(data=stack, shape_is_real_fourier=False)
322
326
  if return_real_fourier:
323
327
  stack = crop_real_fourier(stack)
328
+ return be.to_backend_array(np.squeeze(stack))
324
329
 
325
- return stack
326
330
 
327
-
328
- def _from_xml(filename: str):
331
+ def _from_xml(filename: str) -> Dict:
329
332
  data = XMLParser(filename)
330
333
 
331
334
  params = {
@@ -353,6 +356,7 @@ def _from_xml(filename: str):
353
356
  params["PhaseShift"] = [
354
357
  ctf_phase[i]["@attributes"]["Value"] for i in range(len(ctf_phase))
355
358
  ]
359
+ params["PhaseShift"] = np.degrees(params["PhaseShift"])
356
360
  ctf_ast = data["GridCTFDefocusAngle"]["Node"]
357
361
  params["DefocusAngle"] = [
358
362
  ctf_ast[i]["@attributes"]["Value"] for i in range(len(ctf_ast))
@@ -384,7 +388,7 @@ def _from_xml(filename: str):
384
388
  return {k: params[v] for k, v in mapping.items()}
385
389
 
386
390
 
387
- def _from_ctffind(filename: str):
391
+ def _from_ctffind(filename: str) -> Dict:
388
392
  parameter_regex = {
389
393
  "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
390
394
  "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
@@ -425,18 +429,64 @@ def _from_ctffind(filename: str):
425
429
  output[key] = np.array(output[key])
426
430
 
427
431
  output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
432
+ cs = output.get("spherical_aberration")
433
+ if cs is not None:
434
+ output["spherical_aberration"] = float(cs) * 1e7
428
435
  return output
429
436
 
430
437
 
431
- def _from_gctf(filename: str):
438
+ def _from_star(filename: str) -> Dict:
432
439
  parser = StarParser(filename)
433
- ctf_data = parser["data_"]
440
+
441
+ if "data_stopgap_wedgelist" in parser:
442
+ key = "data_stopgap_wedgelist"
443
+ mapping = {
444
+ "angles": ("_tilt_angle", float, 1),
445
+ "defocus_1": ("_defocus", float, 1e4),
446
+ "defocus_2": (None, float, 1e4),
447
+ "pixel_size": ("_pixelsize", float, 1),
448
+ "acceleration_voltage": ("_voltage", float, 1),
449
+ "spherical_aberration": ("_cs", float, 1e7),
450
+ "amplitude_contrast": ("_amp_contrast", float, 1),
451
+ "additional_phase_shift": (None, float, 1),
452
+ "azimuth_astigmatism": (None, float, 1),
453
+ }
454
+ else:
455
+ key = "data_"
456
+ mapping = {
457
+ "defocus_1": ("_rlnDefocusU", float, 1),
458
+ "defocus_2": ("_rlnDefocusV", float, 1),
459
+ "pixel_size": ("_rlnDetectorPixelSize", float, 1),
460
+ "acceleration_voltage": ("_rlnVoltage", float, 1),
461
+ "spherical_aberration": ("_rlnSphericalAberration", float, 1),
462
+ "amplitude_contrast": ("_rlnAmplitudeContrast", float, 1),
463
+ "additional_phase_shift": (None, float, 1),
464
+ "azimuth_astigmatism": ("_rlnDefocusAngle", float, 1),
465
+ }
466
+
467
+ output = {}
468
+ ctf_data = parser[key]
469
+ for out_key, (key, key_dtype, scale) in mapping.items():
470
+ key_value = ctf_data.get(key)
471
+ if key_value is not None:
472
+ try:
473
+ key_value = [key_dtype(x) * scale for x in key_value]
474
+ except Exception:
475
+ pass
476
+ output[out_key] = key_value
477
+ return output
478
+
479
+
480
+ def _from_mdoc(filename: str) -> Dict:
481
+ parser = MDOCParser(filename)
434
482
 
435
483
  mapping = {
436
- "defocus_1": ("_rlnDefocusU", float),
437
- "defocus_2": ("_rlnDefocusV", float),
484
+ "angles": ("TiltAngle", float),
485
+ "defocus_1": ("Defocus", float),
486
+ "acceleration_voltage": ("Voltage", float),
487
+ # These will be None, but on purpose
438
488
  "pixel_size": ("_rlnDetectorPixelSize", float),
439
- "acceleration_voltage": ("_rlnVoltage", float),
489
+ "defocus_2": ("Defocus2", float),
440
490
  "spherical_aberration": ("_rlnSphericalAberration", float),
441
491
  "amplitude_contrast": ("_rlnAmplitudeContrast", float),
442
492
  "additional_phase_shift": (None, float),
@@ -444,14 +494,10 @@ def _from_gctf(filename: str):
444
494
  }
445
495
  output = {}
446
496
  for out_key, (key, key_dtype) in mapping.items():
447
- if key not in ctf_data and key is not None:
448
- warnings.warn(f"ctf_data is missing key {key}.")
449
-
450
- key_value = ctf_data.get(key, [0])
451
- output[out_key] = [key_dtype(x) for x in key_value]
497
+ output[out_key] = parser.get(key, None)
452
498
 
453
- longest_key = max(map(len, output.values()))
454
- output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
499
+ # Adjust convention and convert to Angstrom
500
+ output["defocus_1"] = np.multiply(output["defocus_1"], -1e4)
455
501
  return output
456
502
 
457
503
 
@@ -516,7 +562,7 @@ def create_ctf(
516
562
  amplitude_contrast : float, optional
517
563
  Amplitude contrast of microscope, defaults to 0.07.
518
564
  spherical_aberration : float, optional
519
- Spherical aberration of microscope in Angstrom.
565
+ Spherical aberration of microscope in units of sampling rate.
520
566
  angle : float, optional
521
567
  Assume the created CTF is a projection over opening_axis observed at angle.
522
568
  opening_axis : int, optional
@@ -540,10 +586,14 @@ def create_ctf(
540
586
  electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
541
587
  electron_wavelength /= sampling_rate
542
588
  aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
589
+
590
+ defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
591
+ defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
543
592
  if correct_defocus_gradient or defocus_y is not None:
544
593
  if len(shape) < 2:
545
594
  raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
546
595
 
596
+ # Axial distance from grid center in multiples of sampling rate
547
597
  sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
548
598
  grid = fftfreqn(
549
599
  shape=shape,
@@ -569,6 +619,7 @@ def create_ctf(
569
619
  defocus_sum = np.add(defocus_x, defocus_y)
570
620
  defocus_difference = np.subtract(defocus_x, defocus_y)
571
621
 
622
+ # Reusing grid, but in principle pure frequencies would suffice
572
623
  angular_grid = np.arctan2(grid[1], grid[0])
573
624
  defocus_difference = np.multiply(
574
625
  defocus_difference,
@@ -577,7 +628,7 @@ def create_ctf(
577
628
  defocus_x = np.add(defocus_sum, defocus_difference)
578
629
  defocus_x *= 0.5
579
630
 
580
- frequency_grid = fftfreqn(shape, sampling_rate=True, compute_euclidean_norm=True)
631
+ frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
581
632
  if angle is not None and opening_axis is not None and full_shape is not None:
582
633
  frequency_grid = frequency_grid_at_angle(
583
634
  shape=full_shape,
@@ -589,10 +640,10 @@ def create_ctf(
589
640
  frequency_mask = frequency_grid < cutoff_frequency
590
641
 
591
642
  # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
592
- np.square(frequency_grid, out=frequency_grid)
643
+ frequency_grid = np.square(frequency_grid, out=frequency_grid)
593
644
  chi = defocus_x - 0.5 * aberration * frequency_grid
594
- np.multiply(chi, np.pi * electron_wavelength, out=chi)
595
- np.multiply(chi, frequency_grid, out=chi)
645
+ chi = np.multiply(chi, np.pi * electron_wavelength, out=chi)
646
+ chi = np.multiply(chi, frequency_grid, out=chi)
596
647
  chi += phase_shift
597
648
  chi += np.arctan(
598
649
  np.divide(
@@ -600,6 +651,6 @@ def create_ctf(
600
651
  np.sqrt(1 - np.square(amplitude_contrast)),
601
652
  )
602
653
  )
603
- np.sin(-chi, out=chi)
604
- np.multiply(chi, frequency_mask, out=chi)
654
+ chi = np.sin(-chi, out=chi)
655
+ chi = np.multiply(chi, frequency_mask, out=chi)
605
656
  return chi
@@ -11,14 +11,14 @@ from dataclasses import dataclass
11
11
 
12
12
  import numpy as np
13
13
 
14
- from ..types import NDArray
15
14
  from ..backends import backend as be
15
+ from ..types import NDArray, BackendArray
16
16
 
17
17
  from .compose import ComposableFilter
18
18
  from ..rotations import euler_to_rotationmatrix
19
19
  from ._utils import crop_real_fourier, shift_fourier, create_reconstruction_filter
20
20
 
21
- __all__ = ["ReconstructFromTilt"]
21
+ __all__ = ["ReconstructFromTilt", "ShiftFourier"]
22
22
 
23
23
 
24
24
  @dataclass
@@ -52,7 +52,9 @@ class ReconstructFromTilt(ComposableFilter):
52
52
  shape : tuple of int
53
53
  The shape of the reconstruction volume.
54
54
  data : BackendArray
55
- D-dimensional image stack with shape (n, ...)
55
+ D-dimensional image stack with shape (n, ...). The data is assumed to be
56
+ a Fourier transform of the stack you are trying to reconstruct with
57
+ DC component in the center.
56
58
  angles : tuple of float
57
59
  Angle of each individual tilt.
58
60
  return_real_fourier : bool, optional
@@ -87,13 +89,14 @@ class ReconstructFromTilt(ComposableFilter):
87
89
 
88
90
  ret = self.reconstruct(**func_args)
89
91
 
92
+ ret = shift_fourier(data=ret, shape_is_real_fourier=False)
90
93
  if return_real_fourier:
91
94
  ret = crop_real_fourier(ret)
92
95
 
93
96
  return {
94
97
  "data": ret,
95
98
  "shape": func_args["shape"],
96
- "shape_is_real_fourier": return_real_fourier,
99
+ "return_real_fourier": return_real_fourier,
97
100
  "is_multiplicative_filter": False,
98
101
  }
99
102
 
@@ -114,7 +117,7 @@ class ReconstructFromTilt(ComposableFilter):
114
117
  Parameters
115
118
  ----------
116
119
  data : NDArray
117
- The tilt series data.
120
+ The Fourier transform of tilt series data.
118
121
  shape : tuple of int
119
122
  Shape of the reconstruction.
120
123
  angles : tuple of float
@@ -138,9 +141,9 @@ class ReconstructFromTilt(ComposableFilter):
138
141
  return data
139
142
 
140
143
  data = be.to_backend_array(data)
141
- volume_temp = be.zeros(shape, dtype=be._float_dtype)
142
- volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
143
- volume = be.zeros(shape, dtype=be._float_dtype)
144
+ volume_temp = be.zeros(shape, dtype=data.dtype)
145
+ volume_temp_rotated = be.zeros(shape, dtype=data.dtype)
146
+ volume = be.zeros(shape, dtype=data.dtype)
144
147
 
145
148
  slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
146
149
  subset = tuple(
@@ -187,4 +190,30 @@ class ReconstructFromTilt(ComposableFilter):
187
190
  )
188
191
  volume = be.add(volume, volume_temp_rotated, out=volume)
189
192
 
190
- return shift_fourier(data=volume, shape_is_real_fourier=False)
193
+ return volume
194
+
195
+
196
+ class ShiftFourier(ComposableFilter):
197
+ def __call__(
198
+ self,
199
+ data: BackendArray,
200
+ shape_is_real_fourier: bool = False,
201
+ return_real_fourier: bool = True,
202
+ **kwargs,
203
+ ):
204
+ ret = []
205
+ for index in range(data.shape[0]):
206
+ mask = be.to_numpy_array(data[index])
207
+
208
+ mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
209
+ if return_real_fourier:
210
+ mask = crop_real_fourier(mask)
211
+ ret.append(mask[None])
212
+ ret = np.concatenate(ret, axis=0)
213
+
214
+ return {
215
+ "data": ret,
216
+ "shape": kwargs.get("shape"),
217
+ "return_real_fourier": return_real_fourier,
218
+ "is_multiplicative_filter": False,
219
+ }