pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1011 @@
1
+ """ Defines filters on tomographic tilt series.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ import re
8
+ from typing import Tuple, Dict
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+
13
+ from .. import Preprocessor
14
+ from ..types import NDArray
15
+ from ..backends import backend as be
16
+ from ..matching_utils import euler_to_rotationmatrix
17
+ from ._utils import (
18
+ frequency_grid_at_angle,
19
+ compute_tilt_shape,
20
+ crop_real_fourier,
21
+ fftfreqn,
22
+ shift_fourier,
23
+ )
24
+
25
+
26
def create_reconstruction_filter(
    filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
):
    """Create a reconstruction filter of given filter_type.

    Parameters
    ----------
    filter_shape : tuple of int
        Shape of the returned filter. The ``ramp`` filter requires a
        two-dimensional shape.
    filter_type: str
        The type of created filter, available options are:

        +---------------+----------------------------------------------------+
        | ram-lak       | Returns |w|                                        |
        +---------------+----------------------------------------------------+
        | ramp-cont     | Principles of Computerized Tomographic Imaging Avin|
        |               | ash C. Kak and Malcolm Slaney Chap 3 Eq. 61 [1]_   |
        +---------------+----------------------------------------------------+
        | ramp          | Like ramp-cont but considering tilt angles         |
        +---------------+----------------------------------------------------+
        | shepp-logan   | |w| * sinc(|w| / 2) [2]_                           |
        +---------------+----------------------------------------------------+
        | cosine        | |w| * cos(|w| * pi / 2) [2]_                       |
        +---------------+----------------------------------------------------+
        | hamming       | |w| * (.54 + .46 * cos(|w| * pi)) [2]_             |
        +---------------+----------------------------------------------------+
    kwargs: Dict
        Keyword arguments for particular filter_types. ``ramp`` requires
        ``tilt_angles``, a sequence of at least two tilt angles in degrees.

    Returns
    -------
    NDArray
        Reconstruction filter

    Raises
    ------
    ValueError
        If the filter type is unsupported, or if ``ramp`` is requested without
        at least two tilt angles.

    References
    ----------
    .. [1] Principles of Computerized Tomographic Imaging Avinash C. Kak and Malcolm Slaney Chap 3 Eq. 61
    .. [2] https://odlgroup.github.io/odl/index.html
    """
    filter_type = str(filter_type).lower()

    if filter_type == "ramp-cont":
        ret, ndim = None, len(filter_shape)
        for dim, size in enumerate(filter_shape):
            # Spatial-domain kernel of Kak & Slaney Eq. 61 in FFT order: 1/4 at
            # the origin, 0 at even and -1 / (pi * n)^2 at odd offsets n. Using
            # the circular distance to the origin keeps the kernel symmetric for
            # every size, not only for multiples of four.
            offsets = np.arange(size)
            n = np.minimum(offsets, size - offsets)
            odd = n % 2 == 1
            ret1d = np.zeros(size)
            ret1d[0] = 0.25
            ret1d[odd] = -1 / (np.pi * n[odd]) ** 2
            ret1d_shape = tuple(size if i == dim else 1 for i in range(ndim))
            ret1d = ret1d.reshape(ret1d_shape)
            ret = ret1d if ret is None else ret * ret1d
        return 2 * np.real(np.fft.fftn(ret))

    if filter_type == "ramp":
        tilt_angles = kwargs.get("tilt_angles", None)
        if tilt_angles is None or tilt_angles is False:
            raise ValueError("'ramp' filter requires specifying tilt angles.")
        tilt_angles = np.asarray(tilt_angles)
        if tilt_angles.size < 2:
            raise ValueError("'ramp' filter requires at least two tilt angles.")

        # Ramp along the first axis, normalized to 0.5 at the edges.
        size, odd = filter_shape[0], filter_shape[0] % 2
        ret = np.arange(-size // 2 + odd, size // 2 + odd, 1, dtype=np.float32)
        ret /= size // 2
        ret *= 0.5
        np.abs(ret, out=ret)

        # Scale by the smallest angular increment and clip at one.
        min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
        ret *= min_increment * size
        np.fmin(ret, 1, out=ret)

        return np.tile(ret[:, np.newaxis], (1, filter_shape[1]))

    windows = {
        "ram-lak": np.copy,
        "shepp-logan": lambda w: w * np.sinc(w / 2),
        "cosine": lambda w: w * np.cos(w * np.pi / 2),
        "hamming": lambda w: w * (0.54 + 0.46 * np.cos(w * np.pi)),
    }
    if filter_type not in windows:
        raise ValueError("Unsupported filter type")

    # Only these windows depend on the frequency norm, so the grid is not
    # computed for the "ramp" and "ramp-cont" filters.
    freq = fftfreqn(filter_shape, sampling_rate=0.5, compute_euclidean_norm=True)
    return windows[filter_type](freq)
115
+
116
+
117
@dataclass
class ReconstructFromTilt:
    """Reconstruct a volume from a tilt series.

    Calling an instance merges its fields with the keyword arguments of the
    call (keyword arguments take precedence) and forwards them to
    :py:meth:`reconstruct`. The tilt series itself is passed as ``data``.
    """

    #: Shape of the reconstruction.
    shape: Tuple[int] = None
    #: Angle of each individual tilt.
    angles: Tuple[float] = None
    #: The axis around which the volume is opened.
    opening_axis: int = 0
    #: Axis the plane is tilted over.
    tilt_axis: int = 2
    #: Whether to return a shape compliant with rfftn.
    return_real_fourier: bool = True
    #: Interpolation order used for rotation
    interpolation_order: int = 1
    #: Filter window applied during reconstruction, see
    #: :py:meth:`create_reconstruction_filter` for available options.
    reconstruction_filter: str = None

    def __call__(self, **kwargs):
        """
        Reconstruct the tilt series passed as ``data`` keyword argument.

        Returns
        -------
        Dict
            The reconstruction under ``data`` together with its shape and the
            geometry used to compute it.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        ret = self.reconstruct(**func_args)

        return {
            "data": ret,
            "shape": ret.shape,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": False,
        }

    @staticmethod
    def reconstruct(
        data: NDArray,
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        interpolation_order: int = 1,
        return_real_fourier: bool = True,
        reconstruction_filter: str = None,
        **kwargs,
    ):
        """
        Reconstruct a volume from a tilt series.

        Each tilt is written into the central one-voxel slab of the volume
        along ``opening_axis``, optionally multiplied by the reconstruction
        filter, rotated by its tilt angle and accumulated. The accumulated
        volume is passed through ``shift_fourier`` and optionally cropped
        with ``crop_real_fourier``.

        Parameters:
        -----------
        data : NDArray
            The tilt series data, one tilt per entry along the first axis.
        shape : tuple of int
            Shape of the reconstruction.
        angles : tuple of float
            Angle of each individual tilt.
        opening_axis : int
            The axis around which the volume is opened.
        tilt_axis : int
            Axis the plane is tilted over.
        interpolation_order : int, optional
            Interpolation order used for rotation, defaults to 1.
        return_real_fourier : bool, optional
            Whether to return a shape compliant with rfftn, defaults to True.
        reconstruction_filter : str, optional
            Filter window applied during reconstruction, defaults to None.
            See :py:meth:`create_reconstruction_filter` for available options.

        Returns:
        --------
        NDArray
            The reconstructed volume, or ``data`` unchanged if it already has
            the requested ``shape``.
        """
        # Nothing to reconstruct if the input already has the target shape.
        if data.shape == shape:
            return data

        data = be.to_backend_array(data)
        volume_temp = be.zeros(shape, dtype=be._float_dtype)
        volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
        volume = be.zeros(shape, dtype=be._float_dtype)

        # Central one-voxel slab along the opening axis.
        slices = tuple(slice(a, a + 1) for a in be.astype(be.divide(shape, 2), int))
        subset = tuple(
            slice(None) if i != opening_axis else slices[opening_axis]
            for i in range(len(shape))
        )
        angles_loop = be.zeros(len(shape))

        # Insert a singleton axis at the opening axis; the offset of one
        # accounts for the leading tilt axis of data.
        wedge_dim = [x for x in data.shape]
        wedge_dim.insert(1 + opening_axis, 1)
        wedges = be.reshape(data, wedge_dim)

        rec_filter = 1
        if reconstruction_filter is not None:
            rec_filter = create_reconstruction_filter(
                filter_type=reconstruction_filter,
                filter_shape=tuple(x for x in wedges[0].shape if x != 1),
                tilt_angles=angles,
            )
            if tilt_axis > 0:
                rec_filter = rec_filter.T

            # This is most likely an upstream bug
            if tilt_axis == 1 and opening_axis == 0:
                rec_filter = rec_filter.T

            rec_filter = be.to_backend_array(rec_filter)
            rec_filter = be.reshape(rec_filter, wedges[0].shape)

        for index in range(len(angles)):
            be.fill(angles_loop, 0)
            be.fill(volume_temp, 0)
            be.fill(volume_temp_rotated, 0)

            volume_temp[subset] = wedges[index] * rec_filter

            angles_loop[tilt_axis] = angles[index]
            angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
            rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
            rotation_matrix = be.to_backend_array(rotation_matrix)

            be.rigid_transform(
                arr=volume_temp,
                rotation_matrix=rotation_matrix,
                out=volume_temp_rotated,
                use_geometric_center=True,
                order=interpolation_order,
            )
            be.add(volume, volume_temp_rotated, out=volume)

        volume = shift_fourier(data=volume, shape_is_real_fourier=False)

        if return_real_fourier:
            volume = crop_real_fourier(volume)

        return volume
254
+
255
+
256
class Wedge:
    """
    Generate wedge mask for tomographic data.

    Parameters:
    -----------
    shape : tuple of int
        The shape of the reconstruction volume.
    tilt_axis : int
        Axis the plane is tilted over.
    opening_axis : int
        The axis around which the volume is opened.
    angles : tuple of float
        The tilt angles in degrees.
    weights : tuple of float
        The weights corresponding to each tilt angle. Used directly by the
        default weighting and as exposure by the relion and grigorieff schemes.
    weight_type : str, optional
        The type of weighting to apply, defaults to None. One of None, 'angle',
        'relion' or 'grigorieff'.
    frequency_cutoff : float, optional
        The frequency cutoff value, defaults to 0.5. Stored, but currently not
        applied by any of the weighting schemes.
    """

    def __init__(
        self,
        shape: Tuple[int],
        tilt_axis: int,
        opening_axis: int,
        angles: Tuple[float],
        weights: Tuple[float],
        weight_type: str = None,
        frequency_cutoff: float = 0.5,
    ):
        self.shape = shape
        self.tilt_axis = tilt_axis
        self.opening_axis = opening_axis
        self.angles = angles
        self.weights = weights
        # Stored so that __call__ actually honours the requested weighting.
        self.weight_type = weight_type
        self.frequency_cutoff = frequency_cutoff

    @classmethod
    def from_file(cls, filename: str) -> "Wedge":
        """
        Generate a :py:class:`Wedge` instance by reading tilt angles and weights
        from a tab-separated text file.

        Parameters:
        -----------
        filename : str
            The path to the file containing tilt angles and weights. The first
            line holds column headers; an ``angles`` column is required and a
            ``weights`` column is optional (weights default to one).

        Returns:
        --------
        :py:class:`Wedge`
            Class instance initialized with angles and weights from the file.

        Raises:
        -------
        ValueError
            If the angles column is missing or weights and angles differ in length.
        """
        data = cls._from_text(filename)

        angles, weights = data.get("angles", None), data.get("weights", None)
        if angles is None:
            raise ValueError(f"Could not find colum angles in {filename}")

        if weights is None:
            weights = [1] * len(angles)

        if len(weights) != len(angles):
            raise ValueError("Length of weights and angles differ.")

        return cls(
            shape=None,
            tilt_axis=0,
            opening_axis=2,
            angles=np.array(angles, dtype=np.float32),
            weights=np.array(weights, dtype=np.float32),
        )

    @staticmethod
    def _from_text(filename: str, delimiter="\t") -> Dict:
        """
        Read column data from a text file.

        Parameters:
        -----------
        filename : str
            The path to the text file.
        delimiter : str, optional
            The delimiter used in the file, defaults to '\t'.

        Returns:
        --------
        Dict
            A dictionary with one key per column header, mapping to the column's
            values as strings. Empty if the file contains no non-empty lines.
        """
        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
        rows = [x.split(delimiter) for x in lines if len(x)]

        if not rows:
            return {}

        headers = rows.pop(0)
        return {header: list(column) for header, column in zip(headers, zip(*rows))}

    def _resolve_parameters(self, kwargs: Dict) -> Dict:
        """Return the wedge geometry, preferring values passed at call time."""
        keys = ("shape", "angles", "weights", "opening_axis", "tilt_axis")
        return {key: kwargs.get(key, getattr(self, key)) for key in keys}

    def __call__(self, **kwargs: Dict) -> NDArray:
        """
        Compute the weighted wedges.

        Instance attributes are merged with ``kwargs`` (``kwargs`` take
        precedence) and dispatched to the weighting given by ``weight_type``.

        Returns:
        --------
        Dict
            The weighted wedges under ``data`` together with the geometry used.

        Raises:
        -------
        ValueError
            If ``weight_type`` is not supported.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        weight_types = {
            None: self.weight_angle,
            "angle": self.weight_angle,
            "relion": self.weight_relion,
            "grigorieff": self.weight_grigorieff,
        }

        weight_type = func_args.get("weight_type", None)
        if weight_type not in weight_types:
            # str() is required because None is one of the supported keys.
            supported = ",".join(str(x) for x in weight_types.keys())
            raise ValueError(f"Supported weight_types are {supported}")

        if weight_type == "angle":
            func_args["weights"] = np.cos(np.radians(func_args["angles"]))

        ret = weight_types[weight_type](**func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)

        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def weight_angle(
        shape: Tuple[int],
        weights: Tuple[float],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        **kwargs,
    ) -> NDArray:
        """
        Generate wedges filled with a constant weight per tilt.

        When called through ``weight_type='angle'`` the weights are the cosine
        of the tilt angles; otherwise the provided weights are used.
        """
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        wedges = np.zeros((len(angles), *tilt_shape))
        for index in range(len(angles)):
            wedges[index] = weights[index]

        return wedges

    def weight_relion(self, **kwargs) -> NDArray:
        """
        Generate weighted wedges based on the RELION 1.4 formalism, weighting each
        angle using the cosine of the angle and a Gaussian lowpass filter computed
        with respect to the exposure per angstrom.

        Geometry parameters (shape, angles, weights, opening_axis, tilt_axis)
        passed as keyword arguments take precedence over instance attributes.

        Returns:
        --------
        NDArray
            Weighted wedges.
        """
        params = self._resolve_parameters(kwargs)
        shape, angles, weights = params["shape"], params["angles"], params["weights"]

        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=params["opening_axis"], reduce_dim=True
        )

        wedges = np.zeros((len(angles), *tilt_shape))
        for index, angle in enumerate(angles):
            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=params["opening_axis"],
                tilt_axis=params["tilt_axis"],
                angle=angle,
                sampling_rate=1,
            )

            sigma = np.sqrt(weights[index] * 4 / (8 * np.pi**2))
            sigma = -2 * np.pi**2 * sigma**2
            np.square(frequency_grid, out=frequency_grid)
            np.multiply(sigma, frequency_grid, out=frequency_grid)
            np.exp(frequency_grid, out=frequency_grid)
            np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges

    def weight_grigorieff(
        self,
        amplitude: float = 0.245,
        power: float = -1.665,
        offset: float = 2.81,
        **kwargs,
    ) -> NDArray:
        """
        Generate weighted wedges based on the formalism introduced in [1]_.

        Each tilt is weighted by exp(-weight / (2 * (amplitude * f^power + offset)))
        with f the spatial frequency. Geometry parameters passed as keyword
        arguments take precedence over instance attributes.

        Returns:
        --------
        NDArray
            Weighted wedges.

        References
        ----------
        .. [1] Timothy Grant, Nikolaus Grigorieff (2015), eLife 4:e06980.
        """
        params = self._resolve_parameters(kwargs)
        shape, angles, weights = params["shape"], params["angles"], params["weights"]

        tilt_shape = compute_tilt_shape(
            shape=shape,
            opening_axis=params["opening_axis"],
            reduce_dim=True,
        )

        wedges = np.zeros((len(angles), *tilt_shape), dtype=be._float_dtype)
        for index, angle in enumerate(angles):
            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=params["opening_axis"],
                tilt_axis=params["tilt_axis"],
                angle=angle,
                sampling_rate=1,
            )

            # The zero frequency raises a divide warning for negative powers.
            with np.errstate(divide="ignore"):
                np.power(frequency_grid, power, out=frequency_grid)
            np.multiply(amplitude, frequency_grid, out=frequency_grid)
            np.add(frequency_grid, offset, out=frequency_grid)
            np.multiply(-2, frequency_grid, out=frequency_grid)
            np.divide(
                weights[index],
                frequency_grid,
                out=frequency_grid,
            )

            np.exp(frequency_grid, out=frequency_grid)

            wedges[index] = frequency_grid

        return wedges
507
+
508
+
509
class WedgeReconstructed:
    """
    Initialize :py:class:`WedgeReconstructed`.

    Parameters:
    -----------
    angles : Tuple[float], optional
        The tilt angles, defaults to None. For a continuous wedge these are
        the start and stop tilt angle.
    start_tilt : float, optional
        Start tilt angle, only used if ``angles`` is None and ``stop_tilt``
        is given, in which case ``angles`` becomes (start_tilt, stop_tilt).
    stop_tilt : float, optional
        Stop tilt angle, see ``start_tilt``.
    opening_axis : int, optional
        The axis around which the wedge is opened, defaults to 0.
    tilt_axis : int, optional
        The axis along which the tilt is applied, defaults to 2.
    weight_wedge : bool, optional
        Whether step wedges are weighted by the cosine of the tilt angle,
        defaults to False.
    create_continuous_wedge : bool, optional
        Whether to create a continuous instead of a step-wise wedge,
        defaults to False.
    **kwargs : Dict
        Additional keyword arguments.
    """

    def __init__(
        self,
        angles: Tuple[float] = None,
        start_tilt: float = None,
        stop_tilt: float = None,
        opening_axis: int = 0,
        tilt_axis: int = 2,
        weight_wedge: bool = False,
        create_continuous_wedge: bool = False,
        **kwargs: Dict,
    ):
        # start_tilt/stop_tilt used to be accepted but silently discarded.
        if angles is None and start_tilt is not None and stop_tilt is not None:
            angles = (start_tilt, stop_tilt)

        self.angles = angles
        self.opening_axis = opening_axis
        self.tilt_axis = tilt_axis
        self.weight_wedge = weight_wedge
        self.create_continuous_wedge = create_continuous_wedge

    def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
        """
        Generate the reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        **kwargs : Dict
            Additional keyword arguments, overriding instance attributes.
            ``return_real_fourier`` is required.

        Returns:
        --------
        Dict
            A dictionary containing the reconstructed wedge and related information.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        if kwargs.get("is_fourier_shape", False):
            print("Cannot create continuous wedge mask basde on real fourier shape.")

        func = self.step_wedge
        if func_args.get("create_continuous_wedge", False):
            func = self.continuous_wedge

        ret = func(shape=shape, **func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)

        return {
            "data": ret,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "shape": ret.shape,
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
            "angles": func_args["angles"],
        }

    @staticmethod
    def continuous_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a continuous reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            Start and stop tilt angle.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn-compliant mask.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()
        start_tilt, stop_tilt = angles
        ret = preprocessor.continuous_wedge_mask(
            start_tilt=start_tilt,
            stop_tilt=stop_tilt,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
            infinite_plane=False,
        )

        return ret

    @staticmethod
    def step_wedge(
        shape: Tuple[int],
        angles: Tuple[float],
        opening_axis: int,
        tilt_axis: int,
        return_real_fourier: bool,
        weight_wedge: bool = False,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Generate a step-wise reconstructed wedge.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the reconstruction volume.
        angles : tuple of float
            The tilt angles in degrees.
        opening_axis : int
            The axis around which the wedge is opened.
        tilt_axis : int
            The axis along which the tilt is applied.
        weight_wedge : bool, optional
            Whether to weight the wedge by the cosine of the angle.
        return_real_fourier : bool
            Whether to omit negative frequencies, i.e. return an rfftn-compliant mask.

        Returns:
        --------
        NDArray
            The reconstructed wedge.
        """
        preprocessor = Preprocessor()

        angles = np.asarray(be.to_numpy_array(angles))
        weights = np.ones(angles.size)
        if weight_wedge:
            weights = np.cos(np.radians(angles))
        ret = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            weights=weights,
            start_tilt=None,
            stop_tilt=None,
            tilt_step=None,
            shape=shape,
            opening_axis=opening_axis,
            tilt_axis=tilt_axis,
            omit_negative_frequencies=return_real_fourier,
        )

        return ret
677
+
678
+
679
@dataclass
class CTF:
    """
    Representation of a contrast transfer function (CTF) [1]_.

    Length-valued parameters (defocus, spherical aberration) are expected in
    Angstrom. Upon initialization they, and the electron wavelength derived
    from the acceleration voltage, are divided by the largest sampling rate.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
           Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int]
    #: The defocus value in x direction.
    defocus_x: float
    #: The tilt angles.
    angles: Tuple[float] = None
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient, defaults to 2.7e7.
    spherical_aberration: float = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle in degrees (converted to radians on init), defaults to 0.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None
    #: Whether the returned CTF should be phase-flipped.
    flip_phase: bool = True
    #: Whether to return a format compliant with rfft. Only relevant for single angles.
    return_real_fourier: bool = False

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters:
        -----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - CTFFIND4
        """
        data = cls._from_ctffind(filename=filename)

        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            # CTFFIND reports keV and mm, whereas this class expects Volts
            # and Angstrom.
            acceleration_voltage=data["acceleration_voltage"] * 1e3,
            spherical_aberration=data["spherical_aberration"] * 1e7,
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # CTFFIND reports the astigmatism azimuth in degrees, which is the
            # unit __post_init__ converts to radians.
            defocus_angle=data["azimuth_astigmatism"],
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str):
        """
        Parse a CTFFIND4 output text file.

        Lines starting with ``#`` are screened for the pixel size, acceleration
        voltage, spherical aberration and amplitude contrast. All other
        non-empty lines are parsed as whitespace-separated numeric columns.

        Returns
        -------
        Dict
            Scalar header parameters and one array per column.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
        lines = [x for x in lines if len(x)]

        def _screen_params(line, params, output):
            for parameter, regex_pattern in params.items():
                match = re.search(regex_pattern, line)
                if match:
                    output[parameter] = float(match.group(1))

        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }
        output = {k: [] for k in columns.keys()}
        for line in lines:
            if line.startswith("#"):
                _screen_params(line, params=parameter_regex, output=output)
                continue

            values = line.split()
            for key, value in columns.items():
                output[key].append(float(values[value]))

        for key in columns:
            output[key] = np.array(output[key])

        return output

    def __post_init__(self):
        self.defocus_angle = np.radians(self.defocus_angle)

        kwargs = {
            "defocus_x": self.defocus_x,
            "defocus_y": self.defocus_y,
            "spherical_aberration": self.spherical_aberration,
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        self._update_parameters(
            electron_wavelength=self._compute_electron_wavelength(), **kwargs
        )

    def _compute_electron_wavelength(self, acceleration_voltage: int = None):
        """Computes the relativistic wavelength of an electron in angstrom.

        Parameters
        ----------
        acceleration_voltage : float, optional
            Acceleration voltage in Volts, defaults to the instance value.
        """

        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        energy = electron_charge * acceleration_voltage
        denominator = energy**2
        denominator += 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        # Meters to Angstrom.
        electron_wavelength *= 1e10
        return electron_wavelength

    def _update_parameters(self, **kwargs):
        """Update multiple parameters of the CTF instance.

        Length-valued parameters are divided by the largest sampling rate.
        Updating the acceleration voltage also recomputes the electron
        wavelength from the new voltage.
        """
        voxel_based = [
            "electron_wavelength",
            "spherical_aberration",
            "defocus_x",
            "defocus_y",
        ]
        if "sampling_rate" in kwargs:
            self.sampling_rate = kwargs["sampling_rate"]

        if "acceleration_voltage" in kwargs:
            # Use the new voltage; the attribute is only updated further below.
            kwargs["electron_wavelength"] = self._compute_electron_wavelength(
                kwargs["acceleration_voltage"]
            )

        for key, value in kwargs.items():
            if key in voxel_based and value is not None:
                value = np.divide(value, np.max(self.sampling_rate))
            setattr(self, key, value)

    def __call__(self, **kwargs) -> NDArray:
        """
        Compute the CTF using instance attributes overridden by ``kwargs``.

        If the number of angles does not match the number of defocus values,
        the instance angles are used and no tilt geometry is applied.

        Returns
        -------
        Dict
            The CTF under ``data`` together with the geometry used.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        # np.size also handles None and scalar defocus values, where len() fails.
        if np.size(func_args["angles"]) != np.size(func_args["defocus_x"]):
            func_args["angles"] = self.angles
            func_args["return_real_fourier"] = False
            func_args["tilt_axis"] = None
            func_args["opening_axis"] = None

        ret = self.weight(**func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        electron_wavelength: float = None,
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Per-tilt parameters (defocus_x, defocus_y, phase_shift, defocus_angle)
        may be scalars or length-one sequences, in which case the same value
        is used for every tilt.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction.
        angles : tuple of float
            The tilt angles.
        electron_wavelength : float, optional
            The electron wavelength, defaults to None.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle in radians, defaults to 0.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False. Only applied
            for three-dimensional shapes with tilt and opening axis given.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1.
        acceleration_voltage : float, optional
            The acceleration voltage in electron microscopy, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3. Note that
            ``__call__`` always passes the instance value.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        return_real_fourier : bool, optional
            Whether to crop to an rfft-compliant shape; only for single angles.
        **kwargs : Dict
            Additional keyword arguments.

        Returns:
        --------
        NDArray
            A stack containing the CTF weight.
        """
        angles = np.atleast_1d(angles)
        n_tilts = len(angles)

        def _per_tilt(value):
            # Broadcast scalar or length-one parameters to one value per tilt.
            value = np.atleast_1d(value)
            return np.repeat(value, n_tilts) if value.size == 1 else value

        defoci_x = _per_tilt(defocus_x)
        defoci_y = _per_tilt(defocus_y)
        phase_shift = _per_tilt(phase_shift)
        defocus_angle = _per_tilt(defocus_angle)

        sampling_rate = np.max(sampling_rate)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((n_tilts, *tilt_shape))

        correct_defocus_gradient &= len(shape) == 3
        correct_defocus_gradient &= tilt_axis is not None
        correct_defocus_gradient &= opening_axis is not None

        for index, angle in enumerate(angles):
            defocus_x, defocus_y = defoci_x[index], defoci_y[index]

            if correct_defocus_gradient or defocus_y is not None:
                grid = fftfreqn(
                    shape=shape,
                    sampling_rate=be.divide(sampling_rate, shape),
                    return_sparse_grid=True,
                )

            # This should be done after defocus_x computation
            if correct_defocus_gradient:
                angle_rad = np.radians(angle)

                defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
                remaining_axis = tuple(
                    i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
                )[0]

                if tilt_axis > remaining_axis:
                    defocus_x = np.add(defocus_x, defocus_gradient)
                elif tilt_axis < remaining_axis and defocus_y is not None:
                    defocus_y = np.add(defocus_y, defocus_gradient.T)

            # Astigmatism: effective defocus depends on the azimuth of each
            # frequency relative to the defocus angle.
            if defocus_y is not None:
                defocus_sum = np.add(defocus_x, defocus_y)
                defocus_difference = np.subtract(defocus_x, defocus_y)
                angular_grid = np.arctan2(grid[0], grid[1])
                defocus_difference *= np.cos(2 * (angular_grid - defocus_angle[index]))
                defocus_x = np.add(defocus_sum, defocus_difference)
                defocus_x *= 0.5

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            np.square(frequency_grid, out=frequency_grid)

            # chi = pi * lambda * f^2 * (defocus - 0.5 * Cs * lambda^2 * f^2)
            electron_aberration = spherical_aberration * electron_wavelength**2
            chi = defocus_x - 0.5 * electron_aberration * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength, out=chi)
            np.multiply(chi, frequency_grid, out=chi)

            chi += phase_shift[index]
            chi += np.arctan(
                np.divide(
                    amplitude_contrast,
                    np.sqrt(1 - np.square(amplitude_contrast)),
                )
            )
            np.sin(-chi, out=chi)
            stack[index] = chi

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = np.squeeze(stack)
        stack = be.to_backend_array(stack)

        if len(angles) == 1:
            stack = shift_fourier(data=stack, shape_is_real_fourier=False)
            if return_real_fourier:
                stack = crop_real_fourier(stack)

        return stack