pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/orientations.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!python3
2
- """ Handle template matching peaks and convert between formats.
2
+ """ Handle template matching orientations and conversion between formats.
3
3
 
4
4
  Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
@@ -8,6 +8,7 @@
8
8
  import re
9
9
  from collections import deque
10
10
  from dataclasses import dataclass
11
+ from string import ascii_lowercase
11
12
  from typing import List, Tuple, Dict
12
13
 
13
14
  import numpy as np
@@ -17,7 +18,48 @@ from scipy.spatial.transform import Rotation
17
18
  @dataclass
18
19
  class Orientations:
19
20
  """
20
- Handle template matching peaks and convert between formats.
21
+ Handle template matching orientations and conversion between formats.
22
+
23
+ Examples
24
+ --------
25
+ The following achieves the minimal definition of an :py:class:`Orientations` instance
26
+
27
+ >>> import numpy as np
28
+ >>> from tme import Orientations
29
+ >>> translations = np.random.randint(low = 0, high = 100, size = (100,3))
30
+ >>> rotations = np.random.rand(100, 3)
31
+ >>> scores = np.random.rand(100)
32
+ >>> details = np.full((100,), fill_value = -1)
33
+ >>> orientations = Orientations(
34
+ >>> translations = translations,
35
+ >>> rotations = rotations,
36
+ >>> scores = scores,
37
+ >>> details = details,
38
+ >>> )
39
+
40
+ The created ``orientations`` object can be written to disk in a range of formats.
41
+ See :py:meth:`Orientations.to_file` for available formats. The following creates
42
+ a STAR file
43
+
44
+ >>> orientations.to_file("test.star")
45
+
46
+ :py:meth:`Orientations.from_file` can create :py:class:`Orientations` instances
47
+ from a range of formats, to enable conversion between formats
48
+
49
+ >>> orientations_star = Orientations.from_file("test.star")
50
+ >>> np.all(orientations.translations == orientations_star.translations)
51
+ True
52
+
53
+ Parameters
54
+ ----------
55
+ translations: np.ndarray
56
+ Array with translations of each orientations (n, d).
57
+ rotations: np.ndarray
58
+ Array with euler angles of each orientation in zxy convention (n, d).
59
+ scores: np.ndarray
60
+ Array with the score of each orientation (n, ).
61
+ details: np.ndarray
62
+ Array with additional orientation details (n, ).
21
63
  """
22
64
 
23
65
  #: Return a numpy array with translations of each orientation (n x d).
@@ -32,6 +74,29 @@ class Orientations:
32
74
  #: Return a numpy array with additional orientation details (n, ).
33
75
  details: np.ndarray
34
76
 
77
+ def __post_init__(self):
78
+ self.translations = np.array(self.translations).astype(np.float32)
79
+ self.rotations = np.array(self.rotations).astype(np.float32)
80
+ self.scores = np.array(self.scores).astype(np.float32)
81
+ self.details = np.array(self.details).astype(np.float32)
82
+ n_orientations = set(
83
+ [
84
+ self.translations.shape[0],
85
+ self.rotations.shape[0],
86
+ self.scores.shape[0],
87
+ self.details.shape[0],
88
+ ]
89
+ )
90
+ if len(n_orientations) != 1:
91
+ raise ValueError(
92
+ "The first dimension of all parameters needs to be of equal length."
93
+ )
94
+ if self.translations.ndim != 2:
95
+ raise ValueError("Expected two dimensional translations parameter.")
96
+
97
+ if self.rotations.ndim != 2:
98
+ raise ValueError("Expected two dimensional rotations parameter.")
99
+
35
100
  def __iter__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
36
101
  """
37
102
  Iterate over the current class instance. Each iteration returns a orientation
@@ -77,7 +142,17 @@ class Orientations:
77
142
  filename : str
78
143
  The name of the file where the orientations will be saved.
79
144
  file_format : type, optional
80
- The format in which to save the orientations. Supported formats are 'text' and 'relion'.
145
+ The format in which to save the orientations. Defaults to None and infers
146
+ the file_format from the typical extension. Supported formats are
147
+
148
+ +---------------+----------------------------------------------------+
149
+ | text | pyTME's standard tab-separated orientations file |
150
+ +---------------+----------------------------------------------------+
151
+ | relion | Creates a STAR file of orientations |
152
+ +---------------+----------------------------------------------------+
153
+ | dynamo | Creates a dynamo table |
154
+ +---------------+----------------------------------------------------+
155
+
81
156
  **kwargs : dict
82
157
  Additional keyword arguments specific to the file format.
83
158
 
@@ -120,8 +195,14 @@ class Orientations:
120
195
  The file is saved with a header specifying each column: z, y, x, euler_z,
121
196
  euler_y, euler_x, score, detail. Each row in the file corresponds to an orientation.
122
197
  """
198
+ naming = ascii_lowercase[::-1]
123
199
  header = "\t".join(
124
- ["z", "y", "x", "euler_z", "euler_y", "euler_x", "score", "detail"]
200
+ [
201
+ *list(naming[: self.translations.shape[1]]),
202
+ *[f"euler_{x}" for x in naming[: self.rotations.shape[1]]],
203
+ "score",
204
+ "detail",
205
+ ]
125
206
  )
126
207
  with open(filename, mode="w", encoding="utf-8") as ofile:
127
208
  _ = ofile.write(f"{header}\n")
@@ -158,9 +239,6 @@ class Orientations:
158
239
  References
159
240
  ----------
160
241
  .. [1] https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Table
161
-
162
- The file is saved with a standard header used in Dynamo STAR files.
163
- Each row in the file corresponds to an orientation.
164
242
  """
165
243
  with open(filename, mode="w", encoding="utf-8") as ofile:
166
244
  for index, (translation, rotation, score, detail) in enumerate(self):
@@ -316,8 +394,18 @@ class Orientations:
316
394
  filename : str
317
395
  The name of the file from which to read the orientations.
318
396
  file_format : type, optional
319
- The format of the file. Currently, only 'text' format is supported.
320
- **kwargs : dict
397
+ The format of the file. Defaults to None and infers
398
+ the file_format from the typical extension. Supported formats are
399
+
400
+ +---------------+----------------------------------------------------+
401
+ | text | pyTME's standard tab-separated orientations file |
402
+ +---------------+----------------------------------------------------+
403
+ | relion | Creates a STAR file of orientations |
404
+ +---------------+----------------------------------------------------+
405
+ | dynamo | Creates a dynamo table |
406
+ +---------------+----------------------------------------------------+
407
+
408
+ **kwargs
321
409
  Additional keyword arguments specific to the file format.
322
410
 
323
411
  Returns
@@ -330,11 +418,18 @@ class Orientations:
330
418
  ValueError
331
419
  If an unsupported file format is specified.
332
420
  """
333
- mapping = {"text": cls._from_text, "relion": cls._from_relion_star}
421
+ mapping = {
422
+ "text": cls._from_text,
423
+ "relion": cls._from_relion_star,
424
+ "tbl": cls._from_tbl,
425
+ }
334
426
  if file_format is None:
335
427
  file_format = "text"
428
+
336
429
  if filename.lower().endswith(".star"):
337
430
  file_format = "relion"
431
+ elif filename.lower().endswith(".tbl"):
432
+ file_format = "tbl"
338
433
 
339
434
  func = mapping.get(file_format, None)
340
435
  if func is None:
@@ -375,25 +470,28 @@ class Orientations:
375
470
  """
376
471
  with open(filename, mode="r", encoding="utf-8") as infile:
377
472
  data = [x.strip().split("\t") for x in infile.read().split("\n")]
378
- _ = data.pop(0)
379
473
 
474
+ header = data.pop(0)
380
475
  translation, rotation, score, detail = [], [], [], []
381
476
  for candidate in data:
382
477
  if len(candidate) <= 1:
383
478
  continue
384
- if len(candidate) != 8:
385
- candidate.append(-1)
386
479
 
387
- candidate = [float(x) for x in candidate]
388
- translation.append((candidate[0], candidate[1], candidate[2]))
389
- rotation.append((candidate[3], candidate[4], candidate[5]))
390
- score.append(candidate[6])
391
- detail.append(candidate[7])
480
+ translation.append(
481
+ tuple(
482
+ candidate[i] for i, x in enumerate(header) if x in ascii_lowercase
483
+ )
484
+ )
485
+ rotation.append(
486
+ tuple(candidate[i] for i, x in enumerate(header) if "euler" in x)
487
+ )
488
+ score.append(candidate[-2])
489
+ detail.append(candidate[-1])
392
490
 
393
- translation = np.vstack(translation).astype(int)
394
- rotation = np.vstack(rotation).astype(float)
395
- score = np.array(score).astype(float)
396
- detail = np.array(detail).astype(float)
491
+ translation = np.vstack(translation)
492
+ rotation = np.vstack(rotation)
493
+ score = np.array(score)
494
+ detail = np.array(detail)
397
495
 
398
496
  return translation, rotation, score, detail
399
497
 
@@ -448,20 +546,15 @@ class Orientations:
448
546
  ret = cls._parse_star(filename=filename, delimiter=delimiter)
449
547
  ret = ret["data_particles"]
450
548
 
451
- translation = (
452
- np.vstack(
453
- (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
454
- )
455
- .astype(np.float32)
456
- .astype(int)
457
- .T
549
+ translation = np.vstack(
550
+ (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
458
551
  )
552
+ translation = translation.astype(np.float32).T
459
553
 
460
- rotation = (
461
- np.vstack((ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"]))
462
- .astype(np.float32)
463
- .T
554
+ rotation = np.vstack(
555
+ (ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"])
464
556
  )
557
+ rotation = rotation.astype(np.float32).T
465
558
 
466
559
  rotation = Rotation.from_euler("xyx", rotation, degrees=True)
467
560
  rotation = rotation.as_euler(seq="zyx", degrees=True)
@@ -470,6 +563,33 @@ class Orientations:
470
563
 
471
564
  return translation, rotation, score, detail
472
565
 
566
+ @staticmethod
567
+ def _from_tbl(
568
+ filename: str, **kwargs
569
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
570
+ with open(filename, mode="r", encoding="utf-8") as infile:
571
+ data = infile.read().split("\n")
572
+ data = [x.strip().split(" ") for x in data if len(x.strip())]
573
+
574
+ if len(data[0]) != 38:
575
+ raise ValueError(
576
+ "Expected tbl file to have 38 columns generated by _to_tbl."
577
+ )
578
+
579
+ translations, rotations, scores, details = [], [], [], []
580
+ for peak in data:
581
+ rotation = Rotation.from_euler(
582
+ "xyx", (peak[6], peak[7], peak[8]), degrees=True
583
+ )
584
+ rotations.append(rotation.as_euler(seq="zyx", degrees=True))
585
+ scores.append(peak[9])
586
+ details.append(-1)
587
+ translations.append((peak[25], peak[24], peak[23]))
588
+
589
+ translations, rotations = np.array(translations), np.array(rotations)
590
+ scores, details = np.array(scores), np.array(details)
591
+ return translations, rotations, scores, details
592
+
473
593
  def get_extraction_slices(
474
594
  self,
475
595
  target_shape: Tuple[int],
@@ -504,55 +624,55 @@ class Orientations:
504
624
  SystemExit
505
625
  If no peak remains after filtering, indicating an error.
506
626
  """
507
- left_pad = np.divide(extraction_shape, 2).astype(int)
508
- right_pad = np.add(left_pad, np.mod(extraction_shape, 2)).astype(int)
627
+ right_pad = np.divide(extraction_shape, 2).astype(int)
628
+ left_pad = np.add(right_pad, np.mod(extraction_shape, 2)).astype(int)
629
+
630
+ peaks = self.translations.astype(int)
631
+ obs_beg = np.subtract(peaks, left_pad)
632
+ obs_end = np.add(peaks, right_pad)
509
633
 
510
- obs_start = np.subtract(self.translations, left_pad)
511
- obs_stop = np.add(self.translations, right_pad)
634
+ obs_beg = np.maximum(obs_beg, 0)
635
+ obs_end = np.minimum(obs_end, target_shape)
512
636
 
513
- cand_start = np.subtract(np.maximum(obs_start, 0), obs_start)
514
- cand_stop = np.subtract(obs_stop, np.minimum(obs_stop, target_shape))
515
- cand_stop = np.subtract(extraction_shape, cand_stop)
516
- obs_start = np.maximum(obs_start, 0)
517
- obs_stop = np.minimum(obs_stop, target_shape)
637
+ cand_beg = left_pad - np.subtract(peaks, obs_beg)
638
+ cand_end = left_pad + np.subtract(obs_end, peaks)
518
639
 
519
640
  subset = self
520
641
  if drop_out_of_box:
521
- stops = np.subtract(cand_stop, extraction_shape)
642
+ stops = np.subtract(cand_end, extraction_shape)
522
643
  keep_peaks = (
523
644
  np.sum(
524
- np.multiply(cand_start == 0, stops == 0),
645
+ np.multiply(cand_beg == 0, stops == 0),
525
646
  axis=1,
526
647
  )
527
- == self.translations.shape[1]
648
+ == peaks.shape[1]
528
649
  )
529
650
  n_remaining = keep_peaks.sum()
530
651
  if n_remaining == 0:
531
652
  print(
532
653
  "No peak remaining after filtering. Started with"
533
- f" {self.translations.shape[0]} filtered to {n_remaining}."
654
+ f" {peaks.shape[0]} filtered to {n_remaining}."
534
655
  " Consider reducing min_distance, increase num_peaks or use"
535
656
  " a different peak caller."
536
657
  )
537
- exit(-1)
538
658
 
539
- cand_start = cand_start[keep_peaks,]
540
- cand_stop = cand_stop[keep_peaks,]
541
- obs_start = obs_start[keep_peaks,]
542
- obs_stop = obs_stop[keep_peaks,]
659
+ cand_beg = cand_beg[keep_peaks,]
660
+ cand_end = cand_end[keep_peaks,]
661
+ obs_beg = obs_beg[keep_peaks,]
662
+ obs_end = obs_end[keep_peaks,]
543
663
  subset = self[keep_peaks]
544
664
 
545
- cand_start, cand_stop = cand_start.astype(int), cand_stop.astype(int)
546
- obs_start, obs_stop = obs_start.astype(int), obs_stop.astype(int)
665
+ cand_beg, cand_end = cand_beg.astype(int), cand_end.astype(int)
666
+ obs_beg, obs_end = obs_beg.astype(int), obs_end.astype(int)
547
667
 
548
668
  candidate_slices = [
549
669
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
550
- for start_row, stop_row in zip(cand_start, cand_stop)
670
+ for start_row, stop_row in zip(cand_beg, cand_end)
551
671
  ]
552
672
 
553
673
  observation_slices = [
554
674
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
555
- for start_row, stop_row in zip(obs_start, obs_stop)
675
+ for start_row, stop_row in zip(obs_beg, obs_end)
556
676
  ]
557
677
 
558
678
  if return_orientations:
@@ -0,0 +1,2 @@
1
+ from .compose import Compose
2
+ from .frequency_filters import BandPassFilter, LinearWhiteningFilter
@@ -0,0 +1,188 @@
1
+ """ Utilities for the generation of frequency grids.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ from numpy.typing import NDArray
12
+
13
+ from ..backends import backend
14
+ from ..matching_utils import euler_to_rotationmatrix
15
+
16
+
17
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Given an opening_axis, computes the shape of the remaining dimensions.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the input array.
    opening_axis : int
        The axis along which the array will be tilted. Negative values count
        from the end, as in numpy. If None, or if the axis is out of range,
        no axis is modified and ``shape`` is returned as a tuple.
    reduce_dim : bool, optional
        If True, drop the opening axis from the result. Otherwise its size is
        set to 1. Defaults to False.

    Returns
    -------
    Tuple[int]
        The shape of the array after tilting.
    """
    if opening_axis is not None and opening_axis < 0:
        # Without this, e.g. -1 would never match an enumerate index and the
        # opening axis would silently be ignored.
        opening_axis += len(shape)

    if reduce_dim:
        return tuple(size for axis, size in enumerate(shape) if axis != opening_axis)

    return tuple(
        1 if axis == opening_axis else size for axis, size in enumerate(shape)
    )
40
+
41
+
42
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build an integer index grid whose origin sits at ``size // 2`` on each axis.

    Parameters
    ----------
    shape : Tuple[int]
        Extent of the grid along each axis.

    Returns
    -------
    NDArray
        Array of shape ``(len(shape), *shape)``. Entry ``[k, ...]`` holds the
        signed offset from ``shape[k] // 2`` along axis ``k``.
    """
    offsets = [np.arange(size) - size // 2 for size in shape]
    return np.array(np.meshgrid(*offsets, indexing="ij"))
60
+
61
+
62
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.

    The grid is built on ``shape`` with ``opening_axis`` collapsed to size 1.
    It is optionally rotated by ``angle`` about ``tilt_axis`` and then reduced
    to the Euclidean norm of the frequency coordinates.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        The angle at which to generate the grid.
    sampling_rate : Tuple[float]
        The sampling rate for each dimension. A scalar is repeated for every axis.
    opening_axis : int, optional
        The axis to be opened, defaults to None.
    tilt_axis : int, optional
        The axis along which the grid is tilted, defaults to None.

    Returns
    -------
    NDArray
        The frequency magnitude at each grid point. Singleton spatial axes,
        such as the collapsed opening axis, are removed.
    """
    sampling_rate = np.array(sampling_rate)
    sampling_rate = np.repeat(sampling_rate, len(shape) // sampling_rate.size)

    tilt_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )
    index_grid = centered_grid(shape=tilt_shape)
    if angle != 0:
        # NOTE(review): tilt_axis / opening_axis must be set when angle != 0.
        # With None, angles[None] assigns every entry and `None - 1` raises.
        angles = np.zeros(len(shape))
        angles[tilt_axis] = angle
        rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
        index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)

    # NOTE(review): an axis of size 1 in `shape` yields a zero divisor here
    # (int(1 / 2) == 0) and hence inf/nan entries.
    norm = np.divide(1, 2 * sampling_rate * np.divide(shape, 2).astype(int))

    index_grid = np.multiply(index_grid.T, norm).T

    # Squeeze only singleton *spatial* axes (e.g. the collapsed opening axis).
    # Axis 0 enumerates the coordinate components and must be kept. Otherwise
    # 1D input loses it, and the norm below would collapse the grid to a scalar.
    singleton_axes = tuple(
        axis for axis, size in enumerate(index_grid.shape) if axis > 0 and size == 1
    )
    index_grid = np.squeeze(index_grid, axis=singleton_axes)
    index_grid = np.linalg.norm(index_grid, axis=0)
    return index_grid
109
+
110
+
111
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier Transform sample frequencies.

    The zero frequency is placed at index ``shape[i] // 2`` on every axis
    (fftshifted layout). The exception is the last axis when
    ``shape_is_real_fourier`` is set, where it is at index 0.

    Parameters
    ----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate. If None, the raw integer offsets from the center
        are returned instead of frequencies.
    compute_euclidean_norm : bool, optional
        Whether to compute the Euclidean norm, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, defaults to False.

    Returns
    -------
    NDArray
        Per-axis sample frequencies stacked along the first axis. If
        ``compute_euclidean_norm`` is set, their Euclidean norm is returned instead.
    """
    center = backend.astype(backend.divide(shape, 2), backend._int_dtype)

    norm = np.ones(len(shape))
    if sampling_rate is not None:
        # Per-axis N * d. Keep it as float: casting to int truncates it whenever
        # N * d is not an integer, which distorts every frequency on that axis.
        norm = backend.astype(backend.multiply(shape, sampling_rate), float)

    if shape_is_real_fourier:
        center[-1] = 0
        if sampling_rate is None:
            norm[-1] = 1
        else:
            # The last axis holds shape[-1] bins of a signal of length
            # 2 * (shape[-1] - 1), assuming even length. norm[-1] currently
            # equals shape[-1] * d, so rescale it to 2 * (shape[-1] - 1) * d.
            # Deriving d from norm avoids indexing sampling_rate, which may be
            # a scalar or a per-axis sequence.
            norm[-1] = norm[-1] * (2 * (shape[-1] - 1)) / shape[-1]

    indices = backend.transpose(backend.indices(shape))
    indices -= center
    indices = backend.divide(indices, norm)
    indices = backend.transpose(indices)

    if compute_euclidean_norm:
        indices = backend.square(indices)
        indices = backend.sum(indices, axis=0)
        backend.sqrt(indices, out=indices)

    return indices
159
+
160
+
161
def crop_real_fourier(data: NDArray) -> NDArray:
    """
    Keep the leading ``n // 2 + 1`` entries along the last axis of ``data``.

    ``n`` is ``data.shape[-1]``. The kept length matches what a real-input
    FFT (rfft) of a length-``n`` signal produces along its last axis.

    Parameters
    ----------
    data : NDArray
        The Fourier transformed data.

    Returns
    -------
    NDArray
        A view of ``data`` truncated along the last axis.
    """
    n_kept = data.shape[-1] // 2 + 1
    return data[..., :n_kept]
177
+
178
+
179
def shift_fourier(data: NDArray, shape_is_real_fourier: bool = False):
    """
    Roll every axis of ``data`` by ``ceil(n / 2)``, where ``n`` is its length.

    Rolling by ``ceil(n / 2)`` is the same as rolling by ``-(n // 2)``. For
    numpy input this therefore matches ``numpy.fft.ifftshift`` over all axes:
    it moves the zero frequency from the center back to index 0.

    Parameters
    ----------
    data : NDArray
        Array to shift.
    shape_is_real_fourier : bool, optional
        If True, the last axis is not shifted. This is the layout of a
        real-input (rfft) transform. Defaults to False.

    Returns
    -------
    NDArray
        The rolled array.
    """
    # Compute the shifts with plain Python ints rather than backend arithmetic
    # on the shape tuple. The list is mutable regardless of backend (e.g. JAX
    # arrays do not support item assignment).
    shift = [size // 2 + size % 2 for size in data.shape]
    if shape_is_real_fourier:
        shift[-1] = 0

    return backend.roll(data, tuple(shift), tuple(range(len(shift))))
@@ -0,0 +1,31 @@
1
+ """ Defines a specification for filters that can be used with
2
+ :py:class:`tme.preprocessing.compose.Compose`.
3
+
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ from typing import Dict
9
+ from abc import ABC, abstractmethod
10
+
11
+
12
class ComposableFilter(ABC):
    """
    Abstract interface for filters that can be chained with
    :py:class:`tme.preprocessing.compose.Compose`.

    Subclasses must implement :py:meth:`__call__`. It returns a dictionary
    describing the outcome of the filter.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Apply the filter.

        Parameters
        ----------
        *args : tuple
            Positional arguments supplied by the caller.
        **kwargs : dict
            Keyword arguments supplied by the caller.

        Returns
        -------
        Dict
            Dictionary describing the result of the filtering operation.
        """
@@ -0,0 +1,51 @@
1
+ """ Combine filters using an interface analogous to pytorch's Compose.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, Dict
9
+
10
+ from tme.backends import backend
11
+
12
+
13
class Compose:
    """
    Chain several transformations into one callable.

    Works like pytorch's Compose. Transformations run in order. The dictionary
    returned by one transformation is merged into the keyword arguments of the
    next.

    Parameters
    ----------
    transforms : Tuple[object]
        Ordered sequence of callables. Each accepts keyword arguments and
        returns a dictionary.
    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Run the transformations in order.

        If a transformation's result sets ``is_multiplicative_filter``, its
        ``data`` is multiplied in place by the previous result's ``data``, and
        its ``merge`` entry is set to None.

        Parameters
        ----------
        **kwargs : Dict
            Keyword arguments passed to the first transformation. They are
            extended with each intermediate result for the following ones.

        Returns
        -------
        Dict
            The result of the last transformation, or an empty dictionary
            if there are none.
        """
        if not len(self.transforms):
            return {}

        first, *remaining = self.transforms
        meta = first(**kwargs)
        for transform in remaining:
            # Expose the previous result to the next transformation.
            kwargs.update(meta)
            result = transform(**kwargs)

            if result.get("is_multiplicative_filter", False):
                # Accumulate the filter product in place.
                backend.multiply(result["data"], meta["data"], out=result["data"])
                result["merge"] = None

            meta = result

        return meta