pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/orientations.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!python3
2
- """ Handle template matching peaks and convert between formats.
2
+ """ Handle template matching orientations and conversion between formats.
3
3
 
4
4
  Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
@@ -8,6 +8,7 @@
8
8
  import re
9
9
  from collections import deque
10
10
  from dataclasses import dataclass
11
+ from string import ascii_lowercase
11
12
  from typing import List, Tuple, Dict
12
13
 
13
14
  import numpy as np
@@ -17,21 +18,85 @@ from scipy.spatial.transform import Rotation
17
18
  @dataclass
18
19
  class Orientations:
19
20
  """
20
- Handle template matching peaks and convert between formats.
21
+ Handle template matching orientations and conversion between formats.
22
+
23
+ Examples
24
+ --------
25
+ The following achieves the minimal definition of an :py:class:`Orientations` instance
26
+
27
+ >>> import numpy as np
28
+ >>> from tme import Orientations
29
+ >>> translations = np.random.randint(low = 0, high = 100, size = (100,3))
30
+ >>> rotations = np.random.rand(100, 3)
31
+ >>> scores = np.random.rand(100)
32
+ >>> details = np.full((100,), fill_value = -1)
33
+ >>> orientations = Orientations(
34
+ >>> translations = translations,
35
+ >>> rotations = rotations,
36
+ >>> scores = scores,
37
+ >>> details = details,
38
+ >>> )
39
+
40
+ The created ``orientations`` object can be written to disk in a range of formats.
41
+ See :py:meth:`Orientations.to_file` for available formats. The following creates
42
+ a STAR file
43
+
44
+ >>> orientations.to_file("test.star")
45
+
46
+ :py:meth:`Orientations.from_file` can create :py:class:`Orientations` instances
47
+ from a range of formats, to enable conversion between formats
48
+
49
+ >>> orientations_star = Orientations.from_file("test.star")
50
+ >>> np.all(orientations.translations == orientations_star.translations)
51
+ True
52
+
53
+ Parameters
54
+ ----------
55
+ translations: np.ndarray
56
+ Array with translations of each orientation (n, d).
57
+ rotations: np.ndarray
58
+ Array with euler angles of each orientation in zyx convention (n, d).
59
+ scores: np.ndarray
60
+ Array with the score of each orientation (n, ).
61
+ details: np.ndarray
62
+ Array with additional orientation details (n, ).
21
63
  """
22
64
 
23
- #: Return a numpy array with translations of each orientation (n x d).
65
+ #: Array with translations of each orientation (n, d).
24
66
  translations: np.ndarray
25
67
 
26
- #: Return a numpy array with euler angles of each orientation in zxy format (n x d).
68
+ #: Array with zyx euler angles of each orientation (n, d).
27
69
  rotations: np.ndarray
28
70
 
29
- #: Return a numpy array with the score of each orientation (n, ).
71
+ #: Array with scores of each orientation (n, ).
30
72
  scores: np.ndarray
31
73
 
32
- #: Return a numpy array with additional orientation details (n, ).
74
+ #: Array with additional details of each orientation (n, ).
33
75
  details: np.ndarray
34
76
 
77
def __post_init__(self):
    """
    Convert all fields to float32 arrays and validate their shapes.

    Raises
    ------
    ValueError
        If any field is zero-dimensional, if the fields disagree in their
        first dimension, or if translations or rotations are not
        two-dimensional.
    """
    self.translations = np.array(self.translations).astype(np.float32)
    self.rotations = np.array(self.rotations).astype(np.float32)
    self.scores = np.array(self.scores).astype(np.float32)
    self.details = np.array(self.details).astype(np.float32)

    fields = {
        "translations": self.translations,
        "rotations": self.rotations,
        "scores": self.scores,
        "details": self.details,
    }

    # Reject 0-d inputs before indexing shape[0], which would otherwise fail
    # with an uninformative IndexError.
    scalar_fields = [name for name, value in fields.items() if value.ndim == 0]
    if scalar_fields:
        raise ValueError(
            "Expected array-like parameters, got scalars for "
            f"{', '.join(scalar_fields)}."
        )

    n_orientations = {value.shape[0] for value in fields.values()}
    if len(n_orientations) != 1:
        raise ValueError(
            "The first dimension of all parameters needs to be of equal length."
        )
    if self.translations.ndim != 2:
        raise ValueError("Expected two dimensional translations parameter.")

    if self.rotations.ndim != 2:
        raise ValueError("Expected two dimensional rotations parameter.")
99
+
35
100
  def __iter__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
36
101
  """
37
102
  Iterate over the current class instance. Each iteration returns a orientation
@@ -65,9 +130,21 @@ class Orientations:
65
130
  "scores",
66
131
  "details",
67
132
  )
68
- kwargs = {attr: getattr(self, attr)[indices] for attr in attributes}
133
+ kwargs = {attr: getattr(self, attr)[indices].copy() for attr in attributes}
69
134
  return self.__class__(**kwargs)
70
135
 
136
def copy(self) -> "Orientations":
    """
    Return a new instance holding copies of all orientation arrays.

    Returns
    -------
    :py:class:`Orientations`
        Independent copy of the current instance.
    """
    # Indexing with every position goes through __getitem__, which copies
    # each attribute array.
    every_index = np.arange(self.scores.size)
    duplicate = self[every_index]
    return duplicate
147
+
71
148
  def to_file(self, filename: str, file_format: type = None, **kwargs) -> None:
72
149
  """
73
150
  Save the current class instance to a file in the specified format.
@@ -77,7 +154,17 @@ class Orientations:
77
154
  filename : str
78
155
  The name of the file where the orientations will be saved.
79
156
  file_format : type, optional
80
- The format in which to save the orientations. Supported formats are 'text' and 'relion'.
157
+ The format in which to save the orientations. Defaults to None and infers
158
+ the file_format from the typical extension. Supported formats are
159
+
160
+ +---------------+----------------------------------------------------+
161
+ | text | pytme's standard tab-separated orientations file |
162
+ +---------------+----------------------------------------------------+
163
+ | relion | Creates a STAR file of orientations |
164
+ +---------------+----------------------------------------------------+
165
+ | dynamo | Creates a dynamo table |
166
+ +---------------+----------------------------------------------------+
167
+
81
168
  **kwargs : dict
82
169
  Additional keyword arguments specific to the file format.
83
170
 
@@ -120,17 +207,23 @@ class Orientations:
120
207
  The file is saved with a header specifying each column: z, y, x, euler_z,
121
208
  euler_y, euler_x, score, detail. Each row in the file corresponds to an orientation.
122
209
  """
210
+ naming = ascii_lowercase[::-1]
123
211
  header = "\t".join(
124
- ["z", "y", "x", "euler_z", "euler_y", "euler_x", "score", "detail"]
212
+ [
213
+ *list(naming[: self.translations.shape[1]]),
214
+ *[f"euler_{x}" for x in naming[: self.rotations.shape[1]]],
215
+ "score",
216
+ "detail",
217
+ ]
125
218
  )
126
219
  with open(filename, mode="w", encoding="utf-8") as ofile:
127
220
  _ = ofile.write(f"{header}\n")
128
221
  for translation, angles, score, detail in self:
129
- translation_string = "\t".join([str(x) for x in translation])
130
- angle_string = "\t".join([str(x) for x in angles])
131
- _ = ofile.write(
132
- f"{translation_string}\t{angle_string}\t{score}\t{detail}\n"
222
+ out_string = (
223
+ "\t".join([str(x) for x in (*translation, *angles, score, detail)])
224
+ + "\n"
133
225
  )
226
+ _ = ofile.write(out_string)
134
227
  return None
135
228
 
136
229
  def _to_dynamo_tbl(
@@ -158,9 +251,6 @@ class Orientations:
158
251
  References
159
252
  ----------
160
253
  .. [1] https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Table
161
-
162
- The file is saved with a standard header used in Dynamo STAR files.
163
- Each row in the file corresponds to an orientation.
164
254
  """
165
255
  with open(filename, mode="w", encoding="utf-8") as ofile:
166
256
  for index, (translation, rotation, score, detail) in enumerate(self):
@@ -316,8 +406,18 @@ class Orientations:
316
406
  filename : str
317
407
  The name of the file from which to read the orientations.
318
408
  file_format : type, optional
319
- The format of the file. Currently, only 'text' format is supported.
320
- **kwargs : dict
409
+ The format of the file. Defaults to None and infers
410
+ the file_format from the typical extension. Supported formats are
411
+
412
+ +---------------+----------------------------------------------------+
413
+ | text | pyTME's standard tab-separated orientations file |
414
+ +---------------+----------------------------------------------------+
415
+ | relion | Creates a STAR file of orientations |
416
+ +---------------+----------------------------------------------------+
417
+ | dynamo | Creates a dynamo table |
418
+ +---------------+----------------------------------------------------+
419
+
420
+ **kwargs
321
421
  Additional keyword arguments specific to the file format.
322
422
 
323
423
  Returns
@@ -330,11 +430,18 @@ class Orientations:
330
430
  ValueError
331
431
  If an unsupported file format is specified.
332
432
  """
333
- mapping = {"text": cls._from_text, "relion": cls._from_relion_star}
433
+ mapping = {
434
+ "text": cls._from_text,
435
+ "relion": cls._from_relion_star,
436
+ "tbl": cls._from_tbl,
437
+ }
334
438
  if file_format is None:
335
439
  file_format = "text"
440
+
336
441
  if filename.lower().endswith(".star"):
337
442
  file_format = "relion"
443
+ elif filename.lower().endswith(".tbl"):
444
+ file_format = "tbl"
338
445
 
339
446
  func = mapping.get(file_format, None)
340
447
  if func is None:
@@ -370,30 +477,61 @@ class Orientations:
370
477
 
371
478
  Notes
372
479
  -----
373
- The text file is expected to have a header and data in columns corresponding to
374
- z, y, x, euler_z, euler_y, euler_x, score, detail.
480
+ The text file is expected to have a header and data in columns. Columns containing
481
+ the name euler are considered to specify rotations. The second last and last
482
+ column correspond to score and detail. It's possible to only specify translations,
483
+ in this case the remaining columns will be filled with trivial values.
375
484
  """
376
485
  with open(filename, mode="r", encoding="utf-8") as infile:
377
486
  data = [x.strip().split("\t") for x in infile.read().split("\n")]
378
- _ = data.pop(0)
379
487
 
488
+ header = data.pop(0)
380
489
  translation, rotation, score, detail = [], [], [], []
381
490
  for candidate in data:
382
491
  if len(candidate) <= 1:
383
492
  continue
384
- if len(candidate) != 8:
385
- candidate.append(-1)
386
493
 
387
- candidate = [float(x) for x in candidate]
388
- translation.append((candidate[0], candidate[1], candidate[2]))
389
- rotation.append((candidate[3], candidate[4], candidate[5]))
390
- score.append(candidate[6])
391
- detail.append(candidate[7])
494
+ translation.append(
495
+ tuple(
496
+ candidate[i] for i, x in enumerate(header) if x in ascii_lowercase
497
+ )
498
+ )
499
+ rotation.append(
500
+ tuple(candidate[i] for i, x in enumerate(header) if "euler" in x)
501
+ )
502
+ score.append(candidate[-2])
503
+ detail.append(candidate[-1])
504
+
505
+ translation = np.vstack(translation)
506
+ rotation = np.vstack(rotation)
507
+ score = np.array(score)
508
+ detail = np.array(detail)
509
+
510
+ if translation.shape[1] == len(header):
511
+ rotation = np.zeros(translation.shape, dtype=np.float32)
512
+ score = np.zeros(translation.shape[0], dtype=np.float32)
513
+ detail = np.zeros(translation.shape[0], dtype=np.float32) - 1
514
+
515
+ if rotation.size == 0 and translation.shape[0] != 0:
516
+ rotation = np.zeros(translation.shape, dtype=np.float32)
517
+
518
+ header_order = tuple(x for x in header if x in ascii_lowercase)
519
+ header_order = zip(header_order, range(len(header_order)))
520
+ sort_order = tuple(
521
+ x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
522
+ )
523
+ translation = translation[..., sort_order]
392
524
 
393
- translation = np.vstack(translation).astype(int)
394
- rotation = np.vstack(rotation).astype(float)
395
- score = np.array(score).astype(float)
396
- detail = np.array(detail).astype(float)
525
+ header_order = tuple(
526
+ x
527
+ for x in header
528
+ if "euler" in x and x.replace("euler_", "") in ascii_lowercase
529
+ )
530
+ header_order = zip(header_order, range(len(header_order)))
531
+ sort_order = tuple(
532
+ x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
533
+ )
534
+ rotation = rotation[..., sort_order]
397
535
 
398
536
  return translation, rotation, score, detail
399
537
 
@@ -448,20 +586,15 @@ class Orientations:
448
586
  ret = cls._parse_star(filename=filename, delimiter=delimiter)
449
587
  ret = ret["data_particles"]
450
588
 
451
- translation = (
452
- np.vstack(
453
- (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
454
- )
455
- .astype(np.float32)
456
- .astype(int)
457
- .T
589
+ translation = np.vstack(
590
+ (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
458
591
  )
592
+ translation = translation.astype(np.float32).T
459
593
 
460
- rotation = (
461
- np.vstack((ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"]))
462
- .astype(np.float32)
463
- .T
594
+ rotation = np.vstack(
595
+ (ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"])
464
596
  )
597
+ rotation = rotation.astype(np.float32).T
465
598
 
466
599
  rotation = Rotation.from_euler("xyx", rotation, degrees=True)
467
600
  rotation = rotation.as_euler(seq="zyx", degrees=True)
@@ -470,6 +603,33 @@ class Orientations:
470
603
 
471
604
  return translation, rotation, score, detail
472
605
 
606
@staticmethod
def _from_tbl(
    filename: str, **kwargs
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Read orientations from a table file with 38 whitespace-separated columns.

    Parameters
    ----------
    filename : str
        Path to the table file.
    **kwargs
        Ignored; accepted for signature compatibility with the other readers.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        Translations (n, 3), rotations (n, 3), scores (n, ) and details (n, ).

        - Translations are read from 1-based columns 26, 25 and 24, in that
          order.
        - Rotations are read from columns 7-9 and converted from "xyx" to
          "zyx" Euler angles in degrees.
        - Scores are read from column 10.
        - Details are set to -1.

    Raises
    ------
    ValueError
        If the file contains no data rows or its first row does not have
        38 columns.
    """
    with open(filename, mode="r", encoding="utf-8") as infile:
        # split() instead of split(" ") so repeated spaces or tabs do not
        # produce empty fields that would shift the column indices.
        data = [line.split() for line in infile.read().split("\n") if line.strip()]

    if not data:
        raise ValueError(f"Expected tbl file {filename} to contain at least one row.")

    if len(data[0]) != 38:
        raise ValueError(
            "Expected tbl file to have 38 columns generated by _to_dynamo_tbl."
        )

    angles = np.array([row[6:9] for row in data], dtype=np.float64)
    rotations = Rotation.from_euler("xyx", angles, degrees=True)
    rotations = rotations.as_euler(seq="zyx", degrees=True)

    scores = np.array([row[9] for row in data], dtype=np.float64)
    details = np.full(len(data), fill_value=-1)
    translations = np.array(
        [(row[25], row[24], row[23]) for row in data], dtype=np.float64
    )
    return translations, rotations, scores, details
632
+
473
633
  def get_extraction_slices(
474
634
  self,
475
635
  target_shape: Tuple[int],
@@ -504,55 +664,55 @@ class Orientations:
504
664
  SystemExit
505
665
  If no peak remains after filtering, indicating an error.
506
666
  """
507
- left_pad = np.divide(extraction_shape, 2).astype(int)
508
- right_pad = np.add(left_pad, np.mod(extraction_shape, 2)).astype(int)
667
+ right_pad = np.divide(extraction_shape, 2).astype(int)
668
+ left_pad = np.add(right_pad, np.mod(extraction_shape, 2)).astype(int)
669
+
670
+ peaks = self.translations.astype(int)
671
+ obs_beg = np.subtract(peaks, left_pad)
672
+ obs_end = np.add(peaks, right_pad)
509
673
 
510
- obs_start = np.subtract(self.translations, left_pad)
511
- obs_stop = np.add(self.translations, right_pad)
674
+ obs_beg = np.maximum(obs_beg, 0)
675
+ obs_end = np.minimum(obs_end, target_shape)
512
676
 
513
- cand_start = np.subtract(np.maximum(obs_start, 0), obs_start)
514
- cand_stop = np.subtract(obs_stop, np.minimum(obs_stop, target_shape))
515
- cand_stop = np.subtract(extraction_shape, cand_stop)
516
- obs_start = np.maximum(obs_start, 0)
517
- obs_stop = np.minimum(obs_stop, target_shape)
677
+ cand_beg = left_pad - np.subtract(peaks, obs_beg)
678
+ cand_end = left_pad + np.subtract(obs_end, peaks)
518
679
 
519
680
  subset = self
520
681
  if drop_out_of_box:
521
- stops = np.subtract(cand_stop, extraction_shape)
682
+ stops = np.subtract(cand_end, extraction_shape)
522
683
  keep_peaks = (
523
684
  np.sum(
524
- np.multiply(cand_start == 0, stops == 0),
685
+ np.multiply(cand_beg == 0, stops == 0),
525
686
  axis=1,
526
687
  )
527
- == self.translations.shape[1]
688
+ == peaks.shape[1]
528
689
  )
529
690
  n_remaining = keep_peaks.sum()
530
691
  if n_remaining == 0:
531
692
  print(
532
693
  "No peak remaining after filtering. Started with"
533
- f" {self.translations.shape[0]} filtered to {n_remaining}."
694
+ f" {peaks.shape[0]} filtered to {n_remaining}."
534
695
  " Consider reducing min_distance, increase num_peaks or use"
535
696
  " a different peak caller."
536
697
  )
537
- exit(-1)
538
698
 
539
- cand_start = cand_start[keep_peaks,]
540
- cand_stop = cand_stop[keep_peaks,]
541
- obs_start = obs_start[keep_peaks,]
542
- obs_stop = obs_stop[keep_peaks,]
699
+ cand_beg = cand_beg[keep_peaks,]
700
+ cand_end = cand_end[keep_peaks,]
701
+ obs_beg = obs_beg[keep_peaks,]
702
+ obs_end = obs_end[keep_peaks,]
543
703
  subset = self[keep_peaks]
544
704
 
545
- cand_start, cand_stop = cand_start.astype(int), cand_stop.astype(int)
546
- obs_start, obs_stop = obs_start.astype(int), obs_stop.astype(int)
705
+ cand_beg, cand_end = cand_beg.astype(int), cand_end.astype(int)
706
+ obs_beg, obs_end = obs_beg.astype(int), obs_end.astype(int)
547
707
 
548
708
  candidate_slices = [
549
709
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
550
- for start_row, stop_row in zip(cand_start, cand_stop)
710
+ for start_row, stop_row in zip(cand_beg, cand_end)
551
711
  ]
552
712
 
553
713
  observation_slices = [
554
714
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
555
- for start_row, stop_row in zip(obs_start, obs_stop)
715
+ for start_row, stop_row in zip(obs_beg, obs_end)
556
716
  ]
557
717
 
558
718
  if return_orientations:
tme/parser.py CHANGED
@@ -137,8 +137,7 @@ class Parser(ABC):
137
137
 
138
138
  class PDBParser(Parser):
139
139
  """
140
- A Parser subclass for converting PDB file data into a dictionary representation.
141
- This class is specifically designed to work with PDB file format.
140
+ Convert PDB file data into a dictionary representation [1]_.
142
141
 
143
142
  References
144
143
  ----------
@@ -228,8 +227,8 @@ class PDBParser(Parser):
228
227
 
229
228
  class MMCIFParser(Parser):
230
229
  """
231
- A Parser subclass for converting MMCIF file data into a dictionary representation.
232
- This implementation heavily relies on the atomium library [1]_.
230
+ Convert MMCIF file data into a dictionary representation. This implementation
231
+ heavily relies on the atomium library [1]_.
233
232
 
234
233
  References
235
234
  ----------
@@ -0,0 +1,2 @@
1
+ from .compose import Compose
2
+ from .frequency_filters import BandPassFilter, LinearWhiteningFilter
@@ -0,0 +1,217 @@
1
+ """ Utilities for the generation of frequency grids.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, List
9
+
10
+ import numpy as np
11
+
12
+ from ..backends import backend as be
13
+ from ..backends import NumpyFFTWBackend
14
+ from ..types import BackendArray, NDArray
15
+ from ..matching_utils import euler_to_rotationmatrix
16
+
17
+
18
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Compute ``shape`` with ``opening_axis`` collapsed.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the input array.
    opening_axis : int
        Index of the axis to collapse. Axes not equal to it are kept as-is.
    reduce_dim : bool, optional
        If True, drop ``opening_axis`` from the result. Otherwise keep it with
        size 1. Defaults to False.

    Returns
    -------
    Tuple[int]
        The resulting shape.
    """
    if reduce_dim:
        return tuple(size for axis, size in enumerate(shape) if axis != opening_axis)
    return tuple(1 if axis == opening_axis else size for axis, size in enumerate(shape))
41
+
42
+
43
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Generate an integer-valued coordinate grid centered on ``size // 2``.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.

    Returns
    -------
    NDArray
        Array of shape ``(len(shape), *shape)``. Entry ``i`` holds each
        position's offset from ``shape[i] // 2`` along axis ``i``.
    """
    axis_offsets = [np.arange(size) - size // 2 for size in shape]
    coordinate_arrays = np.meshgrid(*axis_offsets, indexing="ij")
    return np.array(coordinate_arrays)
61
+
62
+
63
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute frequency magnitudes on a grid that is optionally rotated by ``angle``.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        Value placed at ``tilt_axis`` in the Euler angle vector used to build
        the rotation. A value of 0 skips the rotation.
    sampling_rate : Tuple[float]
        Sampling rate, either one value per axis or a single value that is
        repeated for all axes.
    opening_axis : int, optional
        Axis collapsed to size 1 via ``compute_tilt_shape``, defaults to None.
        Required when ``angle`` is non-zero.
    tilt_axis : int, optional
        Axis receiving ``angle`` in the Euler angle vector, defaults to None.
        Required when ``angle`` is non-zero.

    Returns
    -------
    NDArray
        Euclidean norm of the frequency vector at each grid position, with
        size-1 axes removed.
    """
    rates = np.array(sampling_rate)
    rates = np.repeat(rates, len(shape) // rates.size)

    tilt_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )

    if angle == 0:
        # NOTE(review): this branch normalizes with sampling_rate=1 and ignores
        # ``sampling_rate``, whereas the rotated branch below divides by
        # ``sampling_rate * shape``. The two agree in units only for a sampling
        # rate of 1 -- confirm this is intended.
        return fftfreqn(
            tuple(size for size in tilt_shape if size != 1),
            sampling_rate=1,
            compute_euclidean_norm=True,
        )

    euler_angles = np.zeros(len(shape))
    euler_angles[tilt_axis] = angle
    rotation_matrix = euler_to_rotationmatrix(np.roll(euler_angles, opening_axis - 1))

    # Rotate the unnormalized index grid, then scale each axis to frequencies.
    frequencies = fftfreqn(tilt_shape, sampling_rate=None)
    frequencies = np.einsum("ij,j...->i...", rotation_matrix, frequencies)
    scale = np.multiply(rates, shape).astype(int)

    frequencies = np.divide(frequencies.T, scale).T
    frequencies = np.squeeze(frequencies)
    return np.linalg.norm(frequencies, axis=(0))
119
+
120
+
121
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
    return_sparse_grid: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier transform sample frequencies.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float, Tuple[float] or None
        Sampling rate, either one value per axis or a single value. If None,
        each axis holds its unnormalized offset from ``shape[i] // 2``.
    compute_euclidean_norm : bool, optional
        Return the Euclidean norm over all axes instead of per-axis grids,
        defaults to False.
    shape_is_real_fourier : bool, optional
        Treat the last axis as the output of a real-input FFT. Its values then
        start at 0 and are normalized by ``2 * (shape[-1] - 1) * sampling_rate``.
        Defaults to False.
    return_sparse_grid : bool, optional
        Return a list with one broadcastable array per axis instead of a dense
        stacked grid, defaults to False. Ignored if ``compute_euclidean_norm``
        is True.

    Returns
    -------
    NDArray
        The sample frequencies in one of three forms:

        - the norm grid, if ``compute_euclidean_norm`` is True;
        - a list of per-axis arrays, if ``return_sparse_grid`` is True;
        - otherwise a dense array of shape ``(len(shape), *shape)``.
    """
    # There is no real need to have these operations on GPU right now
    temp_backend = NumpyFFTWBackend()
    norm = temp_backend.full(len(shape), fill_value=1)
    center = temp_backend.astype(temp_backend.divide(shape, 2), temp_backend._int_dtype)
    if sampling_rate is not None:
        norm = temp_backend.astype(temp_backend.multiply(shape, sampling_rate), int)

    if shape_is_real_fourier:
        # The last axis only holds non-negative frequencies starting at zero.
        center[-1], norm[-1] = 0, 1
        if sampling_rate is not None:
            # sampling_rate may be a scalar or one value per axis; only the
            # last axis' rate applies here. Multiplying a tuple/array directly
            # would repeat the tuple or yield an array, neither of which fits
            # into the single element norm[-1].
            last_rate = np.ravel(sampling_rate)[-1]
            norm[-1] = (shape[-1] - 1) * 2 * last_rate

    grids = []
    for i, x in enumerate(shape):
        # Give every axis grid a shape that broadcasts against the others.
        baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
        grid = (temp_backend.arange(x) - center[i]) / norm[i]
        grids.append(temp_backend.reshape(grid, baseline_dims))

    if compute_euclidean_norm:
        grids = sum(temp_backend.square(x) for x in grids)
        grids = temp_backend.sqrt(grids, out=grids)
        return grids

    if return_sparse_grid:
        return grids

    grid_flesh = temp_backend.full(shape, fill_value=1)
    grids = temp_backend.stack(tuple(grid * grid_flesh for grid in grids))

    return grids
177
+
178
+
179
def crop_real_fourier(data: BackendArray) -> BackendArray:
    """
    Crop a full Fourier transform along its last axis to real-FFT length.

    Keeps the first ``1 + n // 2`` entries of the last axis, where ``n`` is
    its size. This is the last-axis length produced by a real-input FFT
    (rfft). Values are not modified; this does not take the real part.

    Parameters
    ----------
    data : BackendArray
        The Fourier transformed data.

    Returns
    -------
    BackendArray
        The cropped data.
    """
    n_keep = data.shape[-1] // 2 + 1
    return data[..., :n_keep]
195
+
196
+
197
def compute_fourier_shape(
    shape: Tuple[int], shape_is_real_fourier: bool = False
) -> List[int]:
    """
    Compute the shape of the real-input Fourier transform of an array.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the real-space array.
    shape_is_real_fourier : bool, optional
        If True, ``shape`` already describes a real Fourier transform and is
        returned unchanged. Defaults to False.

    Returns
    -------
    List[int]
        A new list of ints with the last axis reduced to ``1 + shape[-1] // 2``.
        When ``shape_is_real_fourier`` is True, ``shape`` itself is returned
        (not a list copy).
    """
    if shape_is_real_fourier:
        return shape
    fourier_shape = list(map(int, shape))
    fourier_shape[-1] = fourier_shape[-1] // 2 + 1
    return fourier_shape
205
+
206
+
207
def shift_fourier(
    data: BackendArray, shape_is_real_fourier: bool = False
) -> BackendArray:
    """
    Roll every axis of ``data`` by ``n // 2 + n % 2``, i.e. ceil(n / 2).

    For a full (non-real) transform this moves the zero-frequency entry from
    the center back to index 0, matching ``numpy.fft.ifftshift`` along each
    rolled axis.

    Parameters
    ----------
    data : BackendArray
        Fourier-space data to roll via ``be.roll``.
    shape_is_real_fourier : bool, optional
        If True, the last axis is not rolled (shift 0), defaults to False.

    Returns
    -------
    BackendArray
        The rolled data.
    """
    dims = be.to_backend_array(data.shape)
    offsets = be.add(be.divide(dims, 2), be.mod(dims, 2))
    offsets = [int(value) for value in offsets]
    if shape_is_real_fourier:
        offsets[-1] = 0

    axes = tuple(range(len(offsets)))
    return be.roll(data, offsets, axes)