pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -33
  23. tme/matching_exhaustive.py +354 -225
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.9.dist-info/RECORD +0 -61
  38. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,967 @@
1
+ """ Defines filters on tomographic tilt series.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ import re
8
+ from typing import Tuple, Dict
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+ from numpy.typing import NDArray
13
+
14
+ from .. import Preprocessor
15
+ from ..backends import backend
16
+ from ..matching_utils import euler_to_rotationmatrix
17
+
18
+ from ._utils import (
19
+ frequency_grid_at_angle,
20
+ compute_tilt_shape,
21
+ crop_real_fourier,
22
+ centered_grid,
23
+ fftfreqn
24
+ )
25
+
26
+
27
def create_reconstruction_filter(filter_shape: Tuple[int], filter_type: str):
    """Create a reconstruction filter of given filter_type.

    Parameters
    ----------
    filter_shape : tuple of int
        Shape of the returned filter.
    filter_type : str
        The type of created filter (case-insensitive). Available options are:

        +---------------+----------------------------------------------------+
        | 'ram-lak'     | |w|                                                |
        +---------------+----------------------------------------------------+
        | 'ramp'        | Kak and Slaney, Chap. 3, Eq. 61 [1]_               |
        +---------------+----------------------------------------------------+
        | 'shepp-logan' | |w| * sinc(|w| / 2) [2]_                           |
        +---------------+----------------------------------------------------+
        | 'cosine'      | |w| * cos(|w| * pi / 2) [2]_                       |
        +---------------+----------------------------------------------------+
        | 'hamming'     | |w| * (.54 + .46 * cos(|w| * pi)) [2]_             |
        +---------------+----------------------------------------------------+

        Here |w| is the Euclidean frequency norm returned by
        ``fftfreqn(filter_shape, sampling_rate=.5, compute_euclidean_norm=True)``.
        The 'ramp' filter does not use |w|: it is the separable product of 1D
        spatial ramp kernels transformed with ``np.fft.fftn``, and is therefore
        returned in unshifted FFT order.

    Returns
    -------
    NDArray
        Reconstruction filter.

    Raises
    ------
    ValueError
        If filter_type is not one of the options listed above.

    References
    ----------
    .. [1] Principles of Computerized Tomographic Imaging, Avinash C. Kak and
       Malcolm Slaney, Chap. 3, Eq. 61.
    .. [2] https://odlgroup.github.io/odl/index.html
    """
    filter_type = str(filter_type).lower()
    if filter_type not in ("ram-lak", "ramp", "shepp-logan", "cosine", "hamming"):
        raise ValueError("Unsupported filter type")

    # The ramp filter is built in real space and needs no frequency grid.
    if filter_type == "ramp":
        return _ramp_filter(filter_shape)

    freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)
    if filter_type == "ram-lak":
        return np.copy(freq)
    if filter_type == "shepp-logan":
        return freq * np.sinc(freq / 2)
    if filter_type == "cosine":
        return freq * np.cos(freq * np.pi / 2)
    return freq * (0.54 + 0.46 * np.cos(freq * np.pi))


def _ramp_filter(filter_shape: Tuple[int]) -> NDArray:
    """Return the separable Kak & Slaney ramp filter for ``filter_shape``."""
    ret, ndim = None, len(filter_shape)
    for dim, size in enumerate(filter_shape):
        # Signed distance of each sample from the origin in FFT order, i.e.
        # 0, 1, ..., -1. This indexing is valid for any size, including odd
        # sizes and sizes that are not a multiple of four.
        offset = np.arange(size)
        offset[offset > (size - 1) // 2] -= size

        # Spatial kernel: 1/4 at the origin, -1 / (pi * n)^2 at odd n, 0 else.
        ret1d = np.zeros(size)
        is_odd = np.abs(offset) % 2 == 1
        ret1d[is_odd] = -1 / (np.pi * offset[is_odd]) ** 2
        ret1d[0] = 0.25

        # Orient the 1D kernel along `dim` so the product is separable.
        ret1d = ret1d.reshape(tuple(size if i == dim else 1 for i in range(ndim)))
        ret = ret1d if ret is None else ret * ret1d
    return 2 * np.real(np.fft.fftn(ret))
94
+
95
@dataclass
class ReconstructFromTilt:
    """Reconstruct a volume from a tilt series."""

    #: Shape of the reconstruction.
    shape: Tuple[int] = None
    #: Angle of each individual tilt.
    angles: Tuple[float] = None
    #: The axis around which the volume is opened.
    opening_axis: int = 0
    #: Axis the plane is tilted over.
    tilt_axis: int = 2
    #: Whether to return a shape compliant with rfftn.
    return_real_fourier: bool = True
    #: Interpolation order used for rotation.
    interpolation_order: int = 1
    #: Filter window applied during reconstruction, see
    #: :py:meth:`create_reconstruction_filter` for available options.
    reconstruction_filter: str = None

    def __call__(self, **kwargs) -> Dict:
        """
        Reconstruct ``kwargs["data"]`` using the instance fields as defaults.

        Parameters:
        -----------
        **kwargs : Dict
            Must contain ``data``. Any other key overrides the field of the
            same name for this call only.

        Returns:
        --------
        Dict
            The reconstruction under ``data`` alongside its shape and the
            geometry used to create it.
        """
        # vars(self) is the live instance __dict__. Merge into a new dict so
        # that neither the (possibly large) data array nor per-call overrides
        # are persisted on the instance.
        func_args = {**vars(self), **kwargs}

        ret = self.reconstruct(**func_args)

        return {
            "data": ret,
            "shape": ret.shape,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": False,
        }

    @staticmethod
    def reconstruct(
        data: NDArray,
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        interpolation_order: int = 1,
        return_real_fourier: bool = True,
        reconstruction_filter: str = None,
        **kwargs,
    ):
        """
        Reconstruct the volume from tilt series.

        Parameters:
        -----------
        data : NDArray
            The tilt series data, one tilt per entry along the first axis.
        shape : tuple of int
            Shape of the reconstruction. If ``data`` already has this shape it
            is returned unchanged.
        angles : tuple of float
            Angle of each individual tilt.
        opening_axis : int
            The axis around which the volume is opened.
        tilt_axis : int
            Axis the plane is tilted over.
        interpolation_order : int, optional
            Interpolation order used for rotation, defaults to 1.
        return_real_fourier : bool, optional
            Whether to return a shape compliant with rfftn, defaults to True.
        reconstruction_filter : str, optional
            Filter window applied during reconstruction.
            See :py:meth:`create_reconstruction_filter` for available options.

        Returns:
        --------
        NDArray
            The reconstructed volume.
        """
        # Compare as tuples so that e.g. list-valued shapes are recognized.
        if shape is not None and tuple(data.shape) == tuple(shape):
            return data

        data = backend.to_backend_array(data)
        volume_temp = backend.zeros(shape, dtype=backend._default_dtype)
        volume_temp_rotated = backend.zeros(shape, dtype=backend._default_dtype)
        volume = backend.zeros(shape, dtype=backend._default_dtype)

        # Each tilt is inserted as the central slab along opening_axis.
        slices = tuple(slice(a, a + 1) for a in backend.divide(shape, 2).astype(int))
        subset = tuple(
            slice(None) if i != opening_axis else slices[opening_axis]
            for i in range(len(shape))
        )
        angles_loop = backend.zeros(len(shape))
        wedge_dim = [x for x in data.shape]
        wedge_dim.insert(1 + opening_axis, 1)
        wedges = backend.reshape(data, wedge_dim)

        rec_filter = 1
        if reconstruction_filter is not None:
            rec_filter = create_reconstruction_filter(
                filter_type=reconstruction_filter,
                filter_shape=tuple(x for x in wedges[0].shape if x != 1),
            )
            # The filter is created with NumPy, the wedges live on the backend.
            rec_filter = backend.to_backend_array(rec_filter)

        for index in range(len(angles)):
            backend.fill(angles_loop, 0)
            backend.fill(volume_temp, 0)
            backend.fill(volume_temp_rotated, 0)

            volume_temp[subset] = wedges[index] * rec_filter

            angles_loop[tilt_axis] = angles[index]
            angles_loop = backend.roll(angles_loop, opening_axis - 1, axis=0)
            rotation_matrix = euler_to_rotationmatrix(
                backend.to_numpy_array(angles_loop)
            )
            rotation_matrix = backend.to_backend_array(rotation_matrix)

            backend.rotate_array(
                arr=volume_temp,
                rotation_matrix=rotation_matrix,
                out=volume_temp_rotated,
                use_geometric_center=True,
                order=interpolation_order,
            )
            backend.add(volume, volume_temp_rotated, out=volume)

        shift = backend.add(
            backend.astype(backend.divide(volume.shape, 2), int),
            backend.mod(volume.shape, 2),
        )
        volume = backend.roll(volume, shift, tuple(i for i in range(len(shift))))

        if return_real_fourier:
            volume = crop_real_fourier(volume)

        return volume
227
+
228
+
229
class Wedge:
    """
    Generate wedge mask for tomographic data.

    Parameters:
    -----------
    shape : tuple of int
        The shape of the reconstruction volume.
    tilt_axis : int
        Axis the plane is tilted over.
    opening_axis : int
        The axis around which the volume is opened.
    angles : tuple of float
        The tilt angles in degrees.
    weights : tuple of float
        Per-tilt weights. Used directly by the default (None) weighting and as
        the exposure term by the 'relion' and 'grigorieff' weightings.
    weight_type : str, optional
        The weighting to apply, one of None, 'angle', 'relion' or
        'grigorieff', defaults to None.
    frequency_cutoff : float, optional
        Frequencies above this value are zeroed, defaults to 0.5.

    Returns:
    --------
    Dict
        A dictionary containing weighted wedges and related information.
    """

    def __init__(
        self,
        shape: Tuple[int],
        tilt_axis: int,
        opening_axis: int,
        angles: Tuple[float],
        weights: Tuple[float],
        weight_type: str = None,
        frequency_cutoff: float = 0.5,
    ):
        self.shape = shape
        self.tilt_axis = tilt_axis
        self.opening_axis = opening_axis
        self.angles = angles
        self.weights = weights
        # Stored so that __call__ picks it up through vars(self).
        self.weight_type = weight_type
        self.frequency_cutoff = frequency_cutoff

    @classmethod
    def from_file(cls, filename: str) -> "Wedge":
        """
        Generate a :py:class:`Wedge` instance by reading tilt angles and weights
        from a tab-separated text file.

        Parameters:
        -----------
        filename : str
            The path to the file containing tilt angles and weights. It needs
            an ``angles`` column and may have a ``weights`` column.

        Returns:
        --------
        :py:class:`Wedge`
            Class instance initialized with angles and weights from the file.
            Weights default to one if the file has no ``weights`` column.

        Raises:
        -------
        ValueError
            If the file has no ``angles`` column, or the number of weights
            differs from the number of angles.
        """
        data = cls._from_text(filename)

        angles, weights = data.get("angles", None), data.get("weights", None)
        if angles is None:
            raise ValueError(f"Could not find column angles in {filename}")

        if weights is None:
            weights = [1] * len(angles)

        if len(weights) != len(angles):
            raise ValueError("Length of weights and angles differ.")

        return cls(
            shape=None,
            tilt_axis=0,
            opening_axis=2,
            angles=np.array(angles, dtype=np.float32),
            weights=np.array(weights, dtype=np.float32),
        )

    @staticmethod
    def _from_text(filename: str, delimiter="\t") -> Dict:
        """
        Read column data from a text file.

        Parameters:
        -----------
        filename : str
            The path to the text file. The first non-empty line holds the
            column headers.
        delimiter : str, optional
            The delimiter used in the file, defaults to '\t'.

        Returns:
        --------
        Dict
            A dictionary with one key per column, mapping to the column's
            values as a list of strings.

        Raises:
        -------
        ValueError
            If the file contains no non-empty line.
        """
        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [line.strip() for line in infile.read().split("\n")]
        rows = [line.split(delimiter) for line in lines if len(line)]

        if not rows:
            raise ValueError(f"Could not find a header line in {filename}")

        headers = rows.pop(0)
        return {header: list(column) for header, column in zip(headers, zip(*rows))}

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Compute the weighted wedges of every tilt.

        Parameters:
        -----------
        **kwargs : Dict
            Overrides for the instance attributes, e.g. ``shape`` or
            ``weight_type``. They are stored on the instance.

        Returns:
        --------
        Dict
            The weighted wedges under ``data`` alongside the tilt geometry.

        Raises:
        -------
        ValueError
            If the requested weight_type is not supported.
        """
        # vars(self) is the live instance __dict__, so keyword arguments are
        # persisted as attributes. weight_relion and weight_grigorieff rely on
        # this, since they read their inputs from ``self``.
        func_args = vars(self)
        func_args.update(kwargs)

        weight_types = {
            None: self.weight_angle,
            "angle": self.weight_angle,
            "relion": self.weight_relion,
            "grigorieff": self.weight_grigorieff,
        }

        weight_type = func_args.get("weight_type", None)
        if weight_type not in weight_types:
            # str() is required because one of the keys is None.
            supported = ",".join(str(x) for x in weight_types)
            raise ValueError(f"Supported weight_types are {supported}")

        # Substitute cosine weights for this call only, so the stored weights
        # (exposures for relion/grigorieff weighting) are not overwritten.
        weight_args = dict(func_args)
        if weight_type == "angle":
            weight_args["weights"] = np.cos(np.radians(func_args["angles"]))

        ret = weight_types[weight_type](**weight_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._default_dtype)

        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def weight_angle(
        shape: Tuple[int],
        weights: Tuple[float],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        frequency_cutoff: float = 0.5,
        **kwargs,
    ) -> NDArray:
        """
        Generate wedges holding each tilt's weight up to frequency_cutoff.

        ``__call__`` passes ``cos(angle)`` as weights for weight_type 'angle'.

        Returns:
        --------
        NDArray
            Stack of weighted wedges, one per angle.
        """
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        wedges = np.zeros((len(angles), *tilt_shape))
        for index, angle in enumerate(angles):
            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            wedges[index] = weights[index] * (frequency_grid <= frequency_cutoff)

        return wedges

    def weight_relion(self, **kwargs) -> NDArray:
        """
        Generate weighted wedges based on the RELION 1.4 formalism, weighting each
        angle using the cosine of the angle and a Gaussian lowpass filter computed
        with respect to the exposure per angstrom.

        Reads shape, geometry, angles, weights and frequency_cutoff from the
        instance; keyword arguments are ignored.

        Returns:
        --------
        NDArray
            Weighted wedges.
        """
        tilt_shape = compute_tilt_shape(
            shape=self.shape, opening_axis=self.opening_axis, reduce_dim=True
        )

        wedges = np.zeros((len(self.angles), *tilt_shape))
        for index, angle in enumerate(self.angles):
            frequency_grid = frequency_grid_at_angle(
                shape=self.shape,
                opening_axis=self.opening_axis,
                tilt_axis=self.tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            frequency_mask = frequency_grid <= self.frequency_cutoff

            # Algebraically the exponent below is -weights[index] * f**2.
            sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
            sigma = -2 * np.pi**2 * sigma**2
            np.square(frequency_grid, out=frequency_grid)
            np.multiply(sigma, frequency_grid, out=frequency_grid)
            np.exp(frequency_grid, out=frequency_grid)
            np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
            np.multiply(frequency_grid, frequency_mask, out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges

    def weight_grigorieff(
        self,
        amplitude: float = 0.245,
        power: float = -1.665,
        offset: float = 2.81,
        **kwargs,
    ) -> NDArray:
        """
        Generate weighted wedges based on the formalism introduced in [1]_.

        Each tilt is weighted by
        ``exp(weights / (-2 * (amplitude * f**power + offset)))``. Reads
        shape, geometry, angles, weights and frequency_cutoff from the
        instance.

        Returns:
        --------
        NDArray
            Weighted wedges.

        References
        ----------
        .. [1] Timothy Grant, Nikolaus Grigorieff (2015), eLife 4:e06980.
        """
        tilt_shape = compute_tilt_shape(
            shape=self.shape, opening_axis=self.opening_axis, reduce_dim=True,
        )

        wedges = np.zeros((len(self.angles), *tilt_shape), dtype=backend._default_dtype)
        for index, angle in enumerate(self.angles):
            frequency_grid = frequency_grid_at_angle(
                shape=self.shape,
                opening_axis=self.opening_axis,
                tilt_axis=self.tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            frequency_mask = frequency_grid <= self.frequency_cutoff

            # power is negative, so f == 0 yields inf; the result at that
            # frequency then evaluates to exp(-0) == 1.
            with np.errstate(divide="ignore"):
                np.power(frequency_grid, power, out=frequency_grid)
                np.multiply(amplitude, frequency_grid, out=frequency_grid)
                np.add(frequency_grid, offset, out=frequency_grid)
                np.multiply(-2, frequency_grid, out=frequency_grid)
                np.divide(
                    self.weights[index],
                    frequency_grid,
                    out=frequency_grid,
                )

            np.exp(frequency_grid, out=frequency_grid)
            np.multiply(frequency_grid, frequency_mask, out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges
490
+
491
+
492
class WedgeReconstructed:
    """
    Create a wedge mask as it appears in a reconstructed volume.

    Parameters:
    -----------
    angles : Tuple[float], optional
        Tilt angles of a step-wise wedge, or the (start, stop) tilt pair of a
        continuous wedge, defaults to None.
    start_tilt : float, optional
        Start tilt of a continuous wedge. Together with ``stop_tilt`` it is
        used as ``angles`` when ``angles`` is None and
        ``create_continuous_wedge`` is True, defaults to None.
    stop_tilt : float, optional
        Stop tilt of a continuous wedge, see ``start_tilt``, defaults to None.
    opening_axis : int, optional
        The axis around which the wedge is opened, defaults to 0.
    tilt_axis : int, optional
        The axis along which the tilt is applied, defaults to 2.
    weight_wedge : bool, optional
        Whether to weight each tilt of a step-wise wedge by the cosine of its
        angle, defaults to False.
    create_continuous_wedge : bool, optional
        Whether to create a continuous instead of a step-wise wedge, defaults
        to False.
    **kwargs : Dict
        Additional keyword arguments, ignored.
    """

    def __init__(
        self,
        angles: Tuple[float] = None,
        start_tilt: float = None,
        stop_tilt: float = None,
        opening_axis: int = 0,
        tilt_axis: int = 2,
        weight_wedge: bool = False,
        create_continuous_wedge: bool = False,
        **kwargs: Dict,
    ):
        # continuous_wedge unpacks angles as (start, stop), so fall back to the
        # explicitly provided bounds when no angles were given.
        if angles is None and create_continuous_wedge:
            if start_tilt is not None and stop_tilt is not None:
                angles = (start_tilt, stop_tilt)

        self.angles = angles
        self.opening_axis = opening_axis
        self.tilt_axis = tilt_axis
        self.weight_wedge = weight_wedge
        self.create_continuous_wedge = create_continuous_wedge

    def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
        """
        Generate the reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        **kwargs : Dict
            Overrides for the instance attributes. ``return_real_fourier`` is
            required: it is not set by ``__init__`` and must be passed here
            or set on the instance. Keyword arguments are stored on the
            instance.

        Returns:
        --------
        Dict
            A dictionary containing the reconstructed wedge and related information.
        """
        # vars(self) is the live instance __dict__, so kwargs persist.
        func_args = vars(self)
        func_args.update(kwargs)

        if kwargs.get("is_fourier_shape", False):
            print("Cannot create continuous wedge mask based on real fourier shape.")

        func = self.step_wedge
        if func_args.get("create_continuous_wedge", False):
            func = self.continuous_wedge

        ret = func(shape=shape, **func_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._default_dtype)

        return {
            "data": ret,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "shape": ret.shape,
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
            "angles": func_args["angles"],
        }

    @staticmethod
    def continuous_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a continuous reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            Start and stop tilt angle.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn shape.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()
        start_tilt, stop_tilt = angles
        ret = preprocessor.continuous_wedge_mask(
            start_tilt=start_tilt,
            stop_tilt=stop_tilt,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
            infinite_plane=False,
        )

        return ret

    @staticmethod
    def step_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        weight_wedge: bool = False,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a step-wise reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            The tilt angles in degrees.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn shape.
        weight_wedge : bool, optional
            Whether to weight the wedge by the cosine of the angle.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()

        angles = np.asarray(backend.to_numpy_array(angles))
        weights = np.ones(angles.size)
        if weight_wedge:
            weights = np.cos(np.radians(angles))
        ret = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            weights=weights,
            start_tilt=None,
            stop_tilt=None,
            tilt_step=None,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
        )

        return ret
663
+
664
+
665
@dataclass
class CTF:
    """
    Representation of a contrast transfer function (CTF) [1]_.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
       Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int]
    #: The defocus value in x direction, per tilt or one value for all tilts.
    defocus_x: float
    #: The tilt angles.
    angles: Tuple[float]
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient, defaults to 2.7e3.
    #: NOTE(review): the unit is not fixed here; from_file passes the CTFFIND4
    #: value in mm unchanged, which looks inconsistent with the default.
    spherical_aberration: float = 2.7e3
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift in radians, per tilt or one value, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle in degrees, converted to radians on initialization.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters:
        -----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - CTFFIND4
        """
        data = cls._from_ctffind(filename=filename)

        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            # The CTFFIND4 header reports keV, the class expects Volts.
            acceleration_voltage=data["acceleration_voltage"] * 1e3,
            spherical_aberration=data["spherical_aberration"],
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # CTFFIND4 reports the azimuth in degrees, which is what
            # __post_init__ expects before converting to radians.
            defocus_angle=data["azimuth_astigmatism"],
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str) -> Dict:
        """
        Parse a CTFFIND4 text output.

        Lines starting with '#' are scanned for the microscope parameters;
        every other non-empty line is a whitespace-separated row of per
        micrograph values.

        Returns:
        --------
        Dict
            Scalar header parameters and one array per data column.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }
        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]

        output = {key: [] for key in columns}
        for line in lines:
            if not line:
                continue

            if line.startswith("#"):
                for parameter, regex_pattern in parameter_regex.items():
                    match = re.search(regex_pattern, line)
                    if match:
                        output[parameter] = float(match.group(1))
                continue

            values = line.split()
            for key, column in columns.items():
                output[key].append(float(values[column]))

        for key in columns:
            output[key] = np.array(output[key])

        return output

    def __post_init__(self):
        self.defocus_angle = np.radians(self.defocus_angle)

        kwargs = {
            "defocus_x": self.defocus_x,
            "defocus_y": self.defocus_y,
            "spherical_aberration": self.spherical_aberration,
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        self._update_parameters(
            electron_wavelength=self._compute_electron_wavelength(), **kwargs
        )

    def _compute_electron_wavelength(self, acceleration_voltage: float = None):
        """Computes the relativistic wavelength of an electron in angstrom.

        Uses ``self.acceleration_voltage`` (in Volts) if none is given.
        """

        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        energy = electron_charge * acceleration_voltage
        denominator = energy**2
        denominator += 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        electron_wavelength *= 1e10
        return electron_wavelength

    def _update_parameters(self, **kwargs):
        """Update multiple parameters of the CTF instance.

        Length-valued parameters are divided by the (possibly updated)
        sampling rate before being stored.
        """
        voxel_based = [
            "electron_wavelength",
            "spherical_aberration",
            "defocus_x",
            "defocus_y",
        ]

        if "sampling_rate" in kwargs:
            self.sampling_rate = kwargs["sampling_rate"]

        if "acceleration_voltage" in kwargs:
            # Use the new voltage; the attribute is only updated further below.
            kwargs["electron_wavelength"] = self._compute_electron_wavelength(
                kwargs["acceleration_voltage"]
            )

        for key, value in kwargs.items():
            if key in voxel_based:
                value = value / self.sampling_rate
            setattr(self, key, value)

    @staticmethod
    def _per_tilt(value, index: int):
        """Return ``value[index]`` for per-tilt sequences, scalars and None as is."""
        if value is None or np.ndim(value) == 0:
            return value
        return value[index]

    def __call__(self, **kwargs) -> Dict:
        """Compute the CTF stack; kwargs override and are stored on the instance."""
        # vars(self) is the live instance __dict__, so kwargs persist.
        func_args = vars(self)
        func_args.update(kwargs)

        ret = self.weight(**func_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._default_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        electron_wavelength: float = None,
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction, per tilt or one value for all.
        angles : tuple of float
            The tilt angles in degrees.
        electron_wavelength : float, optional
            Not used; the wavelength is recomputed from acceleration_voltage
            and divided by sampling_rate.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift in radians, per tilt or one value, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle in radians, per tilt or one value, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, per tilt or one value, defaults
            to None (no astigmatism).
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False. Only
            applied for three-dimensional shapes.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
        **kwargs : Dict
            Additional keyword arguments.

        Returns:
        --------
        NDArray
            A stack containing the CTF weight of each tilt.
        """
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((len(angles), *tilt_shape))

        # __post_init__ divides defocus and spherical aberration by the
        # sampling rate, so the wavelength is divided likewise to keep all
        # lengths in the same unit as the sampling_rate=1 frequency grid.
        electron_wavelength = np.divide(
            self._compute_electron_wavelength(acceleration_voltage), sampling_rate
        )
        electron_aberration = spherical_aberration * electron_wavelength**2
        amplitude_phase = np.arctan(
            np.divide(
                amplitude_contrast,
                np.sqrt(1 - np.square(amplitude_contrast)),
            )
        )

        # Loop-invariant real-space grid and its polar angle.
        grid = centered_grid(shape=tilt_shape)
        grid = np.divide(grid.T, (sampling_rate, sampling_rate)).T
        angular_grid = np.arctan2(grid[0], grid[1])

        correct_defocus_gradient = correct_defocus_gradient and len(shape) == 3
        if correct_defocus_gradient:
            remaining_axis = tuple(
                i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
            )[0]

        for index, angle in enumerate(angles):
            tilt_defocus_x = self._per_tilt(defocus_x, index)
            tilt_defocus_y = self._per_tilt(defocus_y, index)

            # This should be done after defocus_x computation
            if correct_defocus_gradient:
                angle_rad = np.radians(angle)
                # NOTE(review): the gradient is scaled by sin(angle) here and
                # again below, i.e. by sin(angle)**2 overall - confirm intent.
                defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))

                if tilt_axis > remaining_axis:
                    tilt_defocus_x = np.add(
                        tilt_defocus_x,
                        np.multiply(defocus_gradient, np.sin(angle_rad)),
                    )
                elif tilt_axis < remaining_axis and tilt_defocus_y is not None:
                    tilt_defocus_y = np.add(
                        tilt_defocus_y,
                        np.multiply(defocus_gradient.T, np.sin(angle_rad)),
                    )

            # Astigmatism: interpolate between both defoci by azimuth.
            if tilt_defocus_y is not None:
                defocus_sum = np.add(tilt_defocus_x, tilt_defocus_y)
                defocus_difference = np.subtract(tilt_defocus_x, tilt_defocus_y)
                defocus_difference = defocus_difference * np.cos(
                    2 * (angular_grid - self._per_tilt(defocus_angle, index))
                )
                tilt_defocus_x = np.add(defocus_sum, defocus_difference)
                tilt_defocus_x *= 0.5

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            np.square(frequency_grid, out=frequency_grid)

            chi = tilt_defocus_x - 0.5 * electron_aberration * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength, out=chi)
            np.multiply(chi, frequency_grid, out=chi)

            chi += self._per_tilt(phase_shift, index)
            chi += amplitude_phase
            np.sin(-chi, out=chi)
            stack[index] = chi

        return stack