pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1017 @@
1
+ """ Defines filters on tomographic tilt series.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ import re
8
+ from typing import Tuple, Dict
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+ from numpy.typing import NDArray
13
+
14
+ from .. import Preprocessor
15
+ from ..backends import backend
16
+ from ..matching_utils import euler_to_rotationmatrix
17
+
18
+ from ._utils import (
19
+ frequency_grid_at_angle,
20
+ compute_tilt_shape,
21
+ crop_real_fourier,
22
+ centered_grid,
23
+ fftfreqn,
24
+ shift_fourier,
25
+ )
26
+
27
+
28
def create_reconstruction_filter(
    filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
):
    """Create a reconstruction filter of given filter_type.

    Parameters
    ----------
    filter_shape : tuple of int
        Shape of the returned filter. The 'ramp' filter expects a
        two-dimensional shape.
    filter_type : str
        The type of created filter (case-insensitive), available options are:

        +---------------+----------------------------------------------------+
        | ram-lak       | Returns |w|                                        |
        +---------------+----------------------------------------------------+
        | ramp-cont     | Principles of Computerized Tomographic Imaging,    |
        |               | Avinash C. Kak and Malcolm Slaney, Chap 3 Eq. 61   |
        |               | [1]_                                               |
        +---------------+----------------------------------------------------+
        | ramp          | Like ramp-cont but considering tilt angles         |
        +---------------+----------------------------------------------------+
        | shepp-logan   | |w| * sinc(|w| / 2) [2]_                           |
        +---------------+----------------------------------------------------+
        | cosine        | |w| * cos(|w| * pi / 2) [2]_                       |
        +---------------+----------------------------------------------------+
        | hamming       | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_            |
        +---------------+----------------------------------------------------+

    kwargs : Dict
        Keyword arguments for particular filter_types. The 'ramp' filter
        requires ``tilt_angles``, the tilt angles in degrees.

    Returns
    -------
    NDArray
        Reconstruction filter

    Raises
    ------
    ValueError
        If filter_type is unsupported, or if the 'ramp' filter is requested
        without at least two tilt angles.

    References
    ----------
    .. [1] Principles of Computerized Tomographic Imaging Avinash C. Kak and Malcolm Slaney Chap 3 Eq. 61
    .. [2] https://odlgroup.github.io/odl/index.html
    """
    filter_type = str(filter_type).lower()

    # Windows applied to the Euclidean norm of the frequency grid.
    frequency_windows = {
        "ram-lak": np.copy,
        "shepp-logan": lambda freq: freq * np.sinc(freq / 2),
        "cosine": lambda freq: freq * np.cos(freq * np.pi / 2),
        "hamming": lambda freq: freq * (0.54 + 0.46 * np.cos(freq * np.pi)),
    }

    if filter_type in frequency_windows:
        # Only the window-based filters need the frequency grid.
        freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)
        return frequency_windows[filter_type](freq)

    if filter_type == "ramp-cont":
        ret, ndim = None, len(filter_shape)
        for dim, size in enumerate(filter_shape):
            # Signed sample offsets in FFT order (0, 1, ..., -2, -1). Building
            # the spatial kernel from circular offsets keeps it symmetric for
            # every size, not only for sizes divisible by four.
            n = np.concatenate(
                (np.arange(0, (size + 1) // 2), np.arange(-(size // 2), 0))
            )
            ret1d = np.zeros(size)
            odd = n % 2 == 1
            ret1d[odd] = -1 / (np.pi * n[odd]) ** 2
            ret1d[0] = 0.25
            ret1d = ret1d.reshape(
                tuple(size if i == dim else 1 for i in range(ndim))
            )
            ret = ret1d if ret is None else ret * ret1d
        return 2 * np.real(np.fft.fftn(ret))

    if filter_type == "ramp":
        tilt_angles = kwargs.get("tilt_angles", False)
        if tilt_angles is False or tilt_angles is None:
            raise ValueError("'ramp' filter requires specifying tilt angles.")

        sorted_angles = np.sort(np.asarray(tilt_angles).ravel())
        if sorted_angles.size < 2:
            raise ValueError("'ramp' filter requires at least two tilt angles.")

        size, odd = filter_shape[0], filter_shape[0] % 2
        ret = np.arange(-size // 2 + odd, size // 2 + odd, 1, dtype=np.float32)
        # Guard against division by zero for a single-sample axis.
        ret /= max(size // 2, 1)
        ret *= 0.5
        np.abs(ret, out=ret)

        # Scale by the smallest angular increment (radians) and clip at one.
        min_increment = np.radians(np.min(np.abs(np.diff(sorted_angles))))
        ret *= min_increment * size
        np.fmin(ret, 1, out=ret)

        return np.tile(ret[:, np.newaxis], (1, filter_shape[1]))

    raise ValueError("Unsupported filter type")
118
+
119
+
120
@dataclass
class ReconstructFromTilt:
    """Reconstruct a volume from a tilt series."""

    #: Shape of the reconstruction.
    shape: Tuple[int] = None
    #: Angle of each individual tilt.
    angles: Tuple[float] = None
    #: The axis around which the volume is opened.
    opening_axis: int = 0
    #: Axis the plane is tilted over.
    tilt_axis: int = 2
    #: Whether to return a shape compliant with rfftn.
    return_real_fourier: bool = True
    #: Interpolation order used for rotation.
    interpolation_order: int = 1
    #: Filter window applied during reconstruction.
    reconstruction_filter: str = None

    def __call__(self, **kwargs) -> Dict:
        """
        Reconstruct using the instance attributes, overridden by ``kwargs``.

        Parameters:
        -----------
        **kwargs : Dict
            Must contain ``data``; any further keys override the attributes
            forwarded to :py:meth:`reconstruct`.

        Returns:
        --------
        Dict
            The reconstruction under ``data`` together with its shape and the
            geometry used to compute it.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        ret = self.reconstruct(**func_args)

        return {
            "data": ret,
            "shape": ret.shape,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": False,
        }

    @staticmethod
    def reconstruct(
        data: NDArray,
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        interpolation_order: int = 1,
        return_real_fourier: bool = True,
        reconstruction_filter: str = None,
        **kwargs,
    ):
        """
        Reconstruct a volume from a tilt series.

        Parameters:
        -----------
        data : NDArray
            The tilt series data. Returned unchanged if it already has ``shape``.
        shape : tuple of int
            Shape of the reconstruction.
        angles : tuple of float
            Angle of each individual tilt.
        opening_axis : int
            The axis around which the volume is opened.
        tilt_axis : int
            Axis the plane is tilted over.
        interpolation_order : int, optional
            Interpolation order used for rotation, defaults to 1.
        return_real_fourier : bool, optional
            Whether to return a shape compliant with rfftn, defaults to True.
        reconstruction_filter : str, optional
            Filter window applied during reconstruction.
            See :py:meth:`create_reconstruction_filter` for available options.

        Returns:
        --------
        NDArray
            The reconstructed volume.
        """
        # Compare as tuples so list-valued shapes are recognized as well.
        if shape is not None and tuple(data.shape) == tuple(shape):
            return data

        data = backend.to_backend_array(data)
        volume_temp = backend.zeros(shape, dtype=backend._float_dtype)
        volume_temp_rotated = backend.zeros(shape, dtype=backend._float_dtype)
        volume = backend.zeros(shape, dtype=backend._float_dtype)

        # Central slab along the opening axis into which each tilt is placed.
        slices = tuple(
            slice(a, a + 1) for a in backend.astype(backend.divide(shape, 2), int)
        )
        subset = tuple(
            slice(None) if i != opening_axis else slices[opening_axis]
            for i in range(len(shape))
        )
        angles_loop = backend.zeros(len(shape))
        wedge_dim = [x for x in data.shape]
        wedge_dim.insert(1 + opening_axis, 1)
        wedges = backend.reshape(data, wedge_dim)

        rec_filter = 1
        if reconstruction_filter is not None:
            rec_filter = create_reconstruction_filter(
                filter_type=reconstruction_filter,
                filter_shape=tuple(x for x in wedges[0].shape if x != 1),
                tilt_angles=angles,
            )
            if tilt_axis > 0:
                rec_filter = rec_filter.T

            # This is most likely an upstream bug
            if tilt_axis == 1 and opening_axis == 0:
                rec_filter = rec_filter.T

            rec_filter = backend.to_backend_array(rec_filter)
            rec_filter = backend.reshape(rec_filter, wedges[0].shape)

        for index in range(len(angles)):
            backend.fill(angles_loop, 0)
            backend.fill(volume_temp, 0)
            backend.fill(volume_temp_rotated, 0)

            volume_temp[subset] = wedges[index] * rec_filter

            angles_loop[tilt_axis] = angles[index]
            angles_loop = backend.roll(angles_loop, (opening_axis - 1,), axis=0)
            rotation_matrix = euler_to_rotationmatrix(
                backend.to_numpy_array(angles_loop)
            )
            rotation_matrix = backend.to_backend_array(rotation_matrix)

            backend.rotate_array(
                arr=volume_temp,
                rotation_matrix=rotation_matrix,
                out=volume_temp_rotated,
                use_geometric_center=True,
                order=interpolation_order,
            )
            backend.add(volume, volume_temp_rotated, out=volume)

        volume = shift_fourier(data=volume, shape_is_real_fourier=False)

        if return_real_fourier:
            volume = crop_real_fourier(volume)

        return volume
261
+
262
+
263
class Wedge:
    """
    Generate wedge mask for tomographic data.

    Parameters:
    -----------
    shape : tuple of int
        The shape of the reconstruction volume.
    tilt_axis : int
        Axis the plane is tilted over.
    opening_axis : int
        The axis around which the volume is opened.
    angles : tuple of float
        The tilt angles in degrees.
    weights : tuple of float
        The weights corresponding to each tilt angle.
    weight_type : str, optional
        The type of weighting to apply, one of None, 'angle', 'relion' or
        'grigorieff', defaults to None.
    frequency_cutoff : float, optional
        The frequency cutoff value, defaults to 0.5.
    """

    def __init__(
        self,
        shape: Tuple[int],
        tilt_axis: int,
        opening_axis: int,
        angles: Tuple[float],
        weights: Tuple[float],
        weight_type: str = None,
        frequency_cutoff: float = 0.5,
    ):
        self.shape = shape
        self.tilt_axis = tilt_axis
        self.opening_axis = opening_axis
        self.angles = angles
        self.weights = weights
        # Stored so that __call__ applies the requested weighting by default.
        self.weight_type = weight_type
        self.frequency_cutoff = frequency_cutoff

    @classmethod
    def from_file(cls, filename: str) -> "Wedge":
        """
        Generate a :py:class:`Wedge` instance by reading tilt angles and weights
        from a tab-separated text file.

        Parameters:
        -----------
        filename : str
            The path to the file containing tilt angles and weights. The file
            needs an 'angles' column; a missing 'weights' column defaults to
            a weight of one per angle.

        Returns:
        --------
        :py:class:`Wedge`
            Class instance initialized with angles and weights from the file,
            using tilt_axis 0, opening_axis 2 and no shape.

        Raises:
        -------
        ValueError
            If the file has no angles column or the column lengths differ.
        """
        data = cls._from_text(filename)

        angles, weights = data.get("angles", None), data.get("weights", None)
        if angles is None:
            raise ValueError(f"Could not find colum angles in {filename}")

        if weights is None:
            weights = [1] * len(angles)

        if len(weights) != len(angles):
            raise ValueError("Length of weights and angles differ.")

        return cls(
            shape=None,
            tilt_axis=0,
            opening_axis=2,
            angles=np.array(angles, dtype=np.float32),
            weights=np.array(weights, dtype=np.float32),
        )

    @staticmethod
    def _from_text(filename: str, delimiter="\t") -> Dict:
        """
        Read column data from a text file.

        Parameters:
        -----------
        filename : str
            The path to the text file. The first non-empty line holds the
            column headers.
        delimiter : str, optional
            The delimiter used in the file, defaults to '\t'.

        Returns:
        --------
        Dict
            A dictionary mapping each header to its column values as strings.

        Raises:
        -------
        ValueError
            If the file contains no non-empty lines.
        """
        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]

        rows = [x.split(delimiter) for x in lines if len(x)]
        if not rows:
            raise ValueError(f"{filename} does not contain any data.")

        headers = rows.pop(0)
        return {header: list(column) for header, column in zip(headers, zip(*rows))}

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Compute the weighted wedges using instance attributes overridden by kwargs.

        Returns:
        --------
        Dict
            The weighted wedges under ``data`` together with related geometry.

        Raises:
        -------
        ValueError
            If the weight_type is not supported.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        weight_types = {
            None: self.weight_angle,
            "angle": self.weight_angle,
            "relion": self.weight_relion,
            "grigorieff": self.weight_grigorieff,
        }

        weight_type = func_args.get("weight_type", None)
        if weight_type not in weight_types:
            # str() is required because the None key cannot be joined directly.
            supported = ",".join(str(key) for key in weight_types)
            raise ValueError(f"Supported weight_types are {supported}")

        if weight_type == "angle":
            # Use the same angles that weight_angle iterates over.
            func_args["weights"] = np.cos(np.radians(func_args["angles"]))

        ret = weight_types[weight_type](**func_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)

        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def weight_angle(
        shape: Tuple[int],
        weights: Tuple[float],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        **kwargs,
    ) -> NDArray:
        """
        Generate wedges filled with a constant weight per tilt angle.

        Each tilt receives ``weights[index]``; :py:meth:`__call__` sets these
        to the cosine of the angles for weight_type 'angle'.
        """
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        wedges = np.zeros((len(angles), *tilt_shape))
        for index in range(len(angles)):
            wedges[index] = weights[index]

        return wedges

    def weight_relion(self, **kwargs) -> NDArray:
        """
        Generate weighted wedges based on the RELION 1.4 formalism, weighting each
        angle using the cosine of the angle and a Gaussian lowpass filter computed
        with respect to the exposure per angstrom.

        Returns:
        --------
        NDArray
            Weighted wedges.
        """
        tilt_shape = compute_tilt_shape(
            shape=self.shape, opening_axis=self.opening_axis, reduce_dim=True
        )

        wedges = np.zeros((len(self.angles), *tilt_shape))
        for index, angle in enumerate(self.angles):
            frequency_grid = frequency_grid_at_angle(
                shape=self.shape,
                opening_axis=self.opening_axis,
                tilt_axis=self.tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            # frequency_mask = frequency_grid <= self.frequency_cutoff

            sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
            sigma = -2 * np.pi**2 * sigma**2
            np.square(frequency_grid, out=frequency_grid)
            np.multiply(sigma, frequency_grid, out=frequency_grid)
            np.exp(frequency_grid, out=frequency_grid)
            np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
            # np.multiply(frequency_grid, frequency_mask, out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges

    def weight_grigorieff(
        self,
        amplitude: float = 0.245,
        power: float = -1.665,
        offset: float = 2.81,
        **kwargs,
    ) -> NDArray:
        """
        Generate weighted wedges based on the formalism introduced in [1]_.

        Returns:
        --------
        NDArray
            Weighted wedges.

        References
        ----------
        .. [1] Timothy Grant, Nikolaus Grigorieff (2015), eLife 4:e06980.
        """
        tilt_shape = compute_tilt_shape(
            shape=self.shape,
            opening_axis=self.opening_axis,
            reduce_dim=True,
        )

        wedges = np.zeros((len(self.angles), *tilt_shape), dtype=backend._float_dtype)
        for index, angle in enumerate(self.angles):
            frequency_grid = frequency_grid_at_angle(
                shape=self.shape,
                opening_axis=self.opening_axis,
                tilt_axis=self.tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            # frequency_mask = frequency_grid <= self.frequency_cutoff

            # Zero frequency yields inf here, which maps to a weight of one below.
            with np.errstate(divide="ignore"):
                np.power(frequency_grid, power, out=frequency_grid)
                np.multiply(amplitude, frequency_grid, out=frequency_grid)
                np.add(frequency_grid, offset, out=frequency_grid)
                np.multiply(-2, frequency_grid, out=frequency_grid)
                np.divide(
                    self.weights[index],
                    frequency_grid,
                    out=frequency_grid,
                )

            np.exp(frequency_grid, out=frequency_grid)
            # np.multiply(frequency_grid, frequency_mask, out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges
514
+
515
+
516
class WedgeReconstructed:
    """
    Initialize :py:class:`WedgeReconstructed`.

    Parameters:
    -----------
    angles : Tuple[float], optional
        The tilt angles, defaults to None. For a continuous wedge these are the
        start and stop tilt angle.
    start_tilt : float, optional
        Start tilt angle. Used together with ``stop_tilt`` as ``angles`` when
        ``angles`` is not given.
    stop_tilt : float, optional
        Stop tilt angle. Used together with ``start_tilt`` as ``angles`` when
        ``angles`` is not given.
    opening_axis : int, optional
        The axis around which the wedge is opened, defaults to 0.
    tilt_axis : int, optional
        The axis along which the tilt is applied, defaults to 2.
    weight_wedge : bool, optional
        Whether to weight a step wedge by the cosine of the tilt angles,
        defaults to False.
    create_continuous_wedge : bool, optional
        Whether to create a continuous rather than a step-wise wedge,
        defaults to False.
    **kwargs : Dict
        Additional keyword arguments, ignored.
    """

    def __init__(
        self,
        angles: Tuple[float] = None,
        start_tilt: float = None,
        stop_tilt: float = None,
        opening_axis: int = 0,
        tilt_axis: int = 2,
        weight_wedge: bool = False,
        create_continuous_wedge: bool = False,
        **kwargs: Dict,
    ):
        # start_tilt/stop_tilt were previously accepted but silently dropped.
        if angles is None and start_tilt is not None and stop_tilt is not None:
            angles = (start_tilt, stop_tilt)

        self.angles = angles
        self.opening_axis = opening_axis
        self.tilt_axis = tilt_axis
        self.weight_wedge = weight_wedge
        self.create_continuous_wedge = create_continuous_wedge

    def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
        """
        Generate the reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        **kwargs : Dict
            Additional keyword arguments overriding instance attributes. Must
            contain ``return_real_fourier``.

        Returns:
        --------
        Dict
            A dictionary containing the reconstructed wedge and related information.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        if kwargs.get("is_fourier_shape", False):
            print("Cannot create continuous wedge mask basde on real fourier shape.")

        func = self.step_wedge
        if func_args.get("create_continuous_wedge", False):
            func = self.continuous_wedge

        ret = func(shape=shape, **func_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)

        return {
            "data": ret,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "shape": ret.shape,
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
            "angles": func_args["angles"],
        }

    @staticmethod
    def continuous_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a continuous reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            Start and stop tilt angle.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn-compliant mask.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()
        start_tilt, stop_tilt = angles
        ret = preprocessor.continuous_wedge_mask(
            start_tilt=start_tilt,
            stop_tilt=stop_tilt,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
            infinite_plane=False,
        )

        return ret

    @staticmethod
    def step_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        weight_wedge: bool = False,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a step-wise reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            The tilt angles in degrees.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        weight_wedge : bool, optional
            Whether to weight the wedge by the cosine of the angle.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn-compliant mask.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()

        angles = np.asarray(backend.to_numpy_array(angles))
        weights = np.ones(angles.size)
        if weight_wedge:
            weights = np.cos(np.radians(angles))
        ret = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            weights=weights,
            start_tilt=None,
            stop_tilt=None,
            tilt_step=None,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
        )

        return ret
684
+
685
+
686
@dataclass
class CTF:
    """
    Representation of a contrast transfer function (CTF) [1]_.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
       Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int]
    #: The defocus value in x direction.
    defocus_x: float
    #: The tilt angles.
    angles: Tuple[float] = None
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient in Angstrom, defaults to 2.7e7 (2.7 mm).
    spherical_aberration: float = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle in degrees (stored in radians after init), defaults to 0.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None
    #: Whether the returned CTF should be phase-flipped.
    flip_phase: bool = True
    #: Whether to return a format compliant with rfft. Only relevant for single angles.
    return_real_fourier: bool = False

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters:
        -----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - CTFFIND4

        Returns:
        --------
        :py:class:`CTF`
            Instance without shape and angles; these need to be set separately.
        """
        data = cls._from_ctffind(filename=filename)

        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            # CTFFIND4 reports the voltage in kV and Cs in mm, whereas this
            # class expects Volts and Angstrom.
            acceleration_voltage=data["acceleration_voltage"] * 1e3,
            spherical_aberration=data["spherical_aberration"] * 1e7,
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # CTFFIND4 reports the astigmatism azimuth in degrees, which is what
            # defocus_angle expects; __post_init__ converts it to radians.
            defocus_angle=data["azimuth_astigmatism"],
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str):
        """
        Parse a CTFFIND4 output file.

        Parameters:
        -----------
        filename : str
            Path to the CTFFIND4 output text file.

        Returns:
        --------
        Dict
            Header parameters (pixel size, voltage, Cs, amplitude contrast) as
            floats in the units of the file, and one array per result column.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }
        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
        lines = [x for x in lines if len(x)]

        output = {key: [] for key in columns}
        for line in lines:
            # Comment lines carry the acquisition parameters.
            if line.startswith("#"):
                for parameter, pattern in parameter_regex.items():
                    match = re.search(pattern, line)
                    if match:
                        output[parameter] = float(match.group(1))
                continue

            values = line.split()
            for key, column in columns.items():
                output[key].append(float(values[column]))

        for key in columns:
            output[key] = np.array(output[key])

        return output

    def __post_init__(self):
        self.defocus_angle = np.radians(self.defocus_angle)

        kwargs = {
            "defocus_x": self.defocus_x,
            "defocus_y": self.defocus_y,
            "spherical_aberration": self.spherical_aberration,
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        self._update_parameters(
            electron_wavelength=self._compute_electron_wavelength(), **kwargs
        )

    def _compute_electron_wavelength(self, acceleration_voltage: int = None):
        """Computes the relativistic wavelength of an electron in angstrom.

        Uses ``self.acceleration_voltage`` (Volts) unless a voltage is given.
        """

        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        energy = electron_charge * acceleration_voltage
        denominator = energy**2
        denominator += 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        electron_wavelength *= 1e10
        return electron_wavelength

    def _update_parameters(self, **kwargs):
        """Update multiple parameters of the CTF instance.

        Length-valued parameters are converted to voxel units by dividing them
        by the largest sampling rate.
        """
        voxel_based = [
            "electron_wavelength",
            "spherical_aberration",
            "defocus_x",
            "defocus_y",
        ]
        if "sampling_rate" in kwargs:
            self.sampling_rate = kwargs["sampling_rate"]

        if "acceleration_voltage" in kwargs:
            # Use the new voltage; the attribute is only updated further below.
            kwargs["electron_wavelength"] = self._compute_electron_wavelength(
                kwargs["acceleration_voltage"]
            )

        for key, value in kwargs.items():
            if key in voxel_based and value is not None:
                value = np.divide(value, np.max(self.sampling_rate))
            setattr(self, key, value)

    def __call__(self, **kwargs) -> Dict:
        """
        Compute the CTF using instance attributes overridden by kwargs.

        Returns:
        --------
        Dict
            The CTF stack under ``data`` together with related geometry.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        # np.size also handles scalar defocus values, where len() would fail.
        if np.size(func_args["angles"]) != np.size(func_args["defocus_x"]):
            func_args["angles"] = self.angles
            func_args["return_real_fourier"] = False
            func_args["tilt_axis"] = None
            func_args["opening_axis"] = None

        ret = self.weight(**func_args)
        ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def _broadcast_per_tilt(values, n_tilts: int) -> NDArray:
        """Return values as a 1D array, repeating a single value n_tilts times."""
        values = np.atleast_1d(values)
        if values.size == 1 and n_tilts > 1:
            values = np.repeat(values, n_tilts)
        return values

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        electron_wavelength: float = None,
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction, one per angle or a single value.
        angles : tuple of float
            The tilt angles.
        electron_wavelength : float, optional
            The electron wavelength, defaults to None. Currently unused; the
            wavelength is derived from the instance acceleration voltage.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift, one per angle or a single value, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle, one per angle or a single value, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1.
        acceleration_voltage : float, optional
            The acceleration voltage in electron microscopy, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        **kwargs : Dict
            Additional keyword arguments.

        Returns:
        --------
        NDArray
            A stack containing the CTF weight.
        """
        angles = np.atleast_1d(angles)
        n_tilts = len(angles)

        # Repeat scalar per-tilt parameters so every tilt can be indexed.
        defoci_x = self._broadcast_per_tilt(defocus_x, n_tilts)
        defoci_y = None
        if defocus_y is not None:
            defoci_y = self._broadcast_per_tilt(defocus_y, n_tilts)
        phase_shift = self._broadcast_per_tilt(phase_shift, n_tilts)
        defocus_angle = self._broadcast_per_tilt(defocus_angle, n_tilts)

        sampling_rate = np.max(sampling_rate)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((n_tilts, *tilt_shape))
        # NOTE: derived from self.acceleration_voltage, not from the arguments.
        electron_wavelength = self._compute_electron_wavelength() / sampling_rate
        electron_aberration = spherical_aberration * electron_wavelength**2
        amplitude_phase = np.arctan(
            np.divide(
                amplitude_contrast,
                np.sqrt(1 - np.square(amplitude_contrast)),
            )
        )

        correct_defocus_gradient &= len(shape) == 3
        correct_defocus_gradient &= tilt_axis is not None
        correct_defocus_gradient &= opening_axis is not None

        # The spatial grid does not depend on the tilt angle.
        grid = backend.to_numpy_array(centered_grid(shape=tilt_shape))
        grid = np.divide(grid.T, sampling_rate).T
        angular_grid = None
        if defoci_y is not None:
            angular_grid = np.arctan2(grid[0], grid[1])

        remaining_axis = None
        if correct_defocus_gradient:
            remaining_axis = tuple(
                i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
            )[0]

        for index, angle in enumerate(angles):
            defocus_x = defoci_x[index]
            defocus_y = None if defoci_y is None else defoci_y[index]

            # The gradient is applied before combining defocus_x and defocus_y.
            if correct_defocus_gradient:
                defocus_gradient = np.multiply(grid[1], np.sin(np.radians(angle)))
                if tilt_axis > remaining_axis:
                    defocus_x = np.add(defocus_x, defocus_gradient)
                elif tilt_axis < remaining_axis and defocus_y is not None:
                    defocus_y = np.add(defocus_y, defocus_gradient.T)

            # Astigmatism: interpolate between both defoci by azimuth.
            if defocus_y is not None:
                defocus_sum = np.add(defocus_x, defocus_y)
                defocus_difference = np.subtract(defocus_x, defocus_y)
                defocus_difference *= np.cos(2 * (angular_grid - defocus_angle[index]))
                defocus_x = np.add(defocus_sum, defocus_difference)
                defocus_x *= 0.5

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            frequency_grid *= frequency_grid <= 0.5
            np.square(frequency_grid, out=frequency_grid)

            chi = defocus_x - 0.5 * electron_aberration * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength, out=chi)
            np.multiply(chi, frequency_grid, out=chi)

            chi += phase_shift[index]
            chi += amplitude_phase
            np.sin(-chi, out=chi)
            stack[index] = chi

        if flip_phase:
            np.abs(stack, out=stack)

        np.negative(stack, out=stack)
        stack = np.squeeze(stack)

        stack = backend.to_backend_array(stack)

        if n_tilts == 1:
            stack = shift_fourier(data=stack, shape_is_real_fourier=False)
            if return_real_fourier:
                stack = crop_real_fourier(stack)

        return stack