pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +473 -140
- scripts/match_template_filters.py +458 -169
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +278 -148
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +22 -12
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +85 -64
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +86 -60
- tme/matching_exhaustive.py +245 -166
- tme/matching_optimization.py +137 -69
- tme/matching_utils.py +1 -1
- tme/orientations.py +175 -55
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +188 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +51 -0
- tme/preprocessing/frequency_filters.py +378 -0
- tme/preprocessing/tilt_series.py +1017 -0
- tme/preprocessor.py +17 -7
- tme/structure.py +4 -1
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/orientations.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!python3
|
2
|
-
""" Handle template matching
|
2
|
+
""" Handle template matching orientations and conversion between formats.
|
3
3
|
|
4
4
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
5
|
|
@@ -8,6 +8,7 @@
|
|
8
8
|
import re
|
9
9
|
from collections import deque
|
10
10
|
from dataclasses import dataclass
|
11
|
+
from string import ascii_lowercase
|
11
12
|
from typing import List, Tuple, Dict
|
12
13
|
|
13
14
|
import numpy as np
|
@@ -17,7 +18,48 @@ from scipy.spatial.transform import Rotation
|
|
17
18
|
@dataclass
|
18
19
|
class Orientations:
|
19
20
|
"""
|
20
|
-
Handle template matching
|
21
|
+
Handle template matching orientations and conversion between formats.
|
22
|
+
|
23
|
+
Examples
|
24
|
+
--------
|
25
|
+
The following achieves the minimal definition of an :py:class:`Orientations` instance
|
26
|
+
|
27
|
+
>>> import numpy as np
|
28
|
+
>>> from tme import Orientations
|
29
|
+
>>> translations = np.random.randint(low = 0, high = 100, size = (100,3))
|
30
|
+
>>> rotations = np.random.rand(100, 3)
|
31
|
+
>>> scores = np.random.rand(100)
|
32
|
+
>>> details = np.full((100,), fill_value = -1)
|
33
|
+
>>> orientations = Orientations(
|
34
|
+
>>> translations = translations,
|
35
|
+
>>> rotations = rotations,
|
36
|
+
>>> scores = scores,
|
37
|
+
>>> details = details,
|
38
|
+
>>> )
|
39
|
+
|
40
|
+
The created ``orientations`` object can be written to disk in a range of formats.
|
41
|
+
See :py:meth:`Orientations.to_file` for available formats. The following creates
|
42
|
+
a STAR file
|
43
|
+
|
44
|
+
>>> orientations.to_file("test.star")
|
45
|
+
|
46
|
+
:py:meth:`Orientations.from_file` can create :py:class:`Orientations` instances
|
47
|
+
from a range of formats, to enable conversion between formats
|
48
|
+
|
49
|
+
>>> orientations_star = Orientations.from_file("test.star")
|
50
|
+
>>> np.all(orientations.translations == orientations_star.translations)
|
51
|
+
True
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
translations: np.ndarray
|
56
|
+
Array with translations of each orientation (n, d).
|
57
|
+
rotations: np.ndarray
|
58
|
+
Array with euler angles of each orientation in zxy convention (n, d).
|
59
|
+
scores: np.ndarray
|
60
|
+
Array with the score of each orientation (n, ).
|
61
|
+
details: np.ndarray
|
62
|
+
Array with additional orientation details (n, ).
|
21
63
|
"""
|
22
64
|
|
23
65
|
#: Return a numpy array with translations of each orientation (n x d).
|
@@ -32,6 +74,29 @@ class Orientations:
|
|
32
74
|
#: Return a numpy array with additional orientation details (n, ).
|
33
75
|
details: np.ndarray
|
34
76
|
|
77
|
+
def __post_init__(self):
|
78
|
+
self.translations = np.array(self.translations).astype(np.float32)
|
79
|
+
self.rotations = np.array(self.rotations).astype(np.float32)
|
80
|
+
self.scores = np.array(self.scores).astype(np.float32)
|
81
|
+
self.details = np.array(self.details).astype(np.float32)
|
82
|
+
n_orientations = set(
|
83
|
+
[
|
84
|
+
self.translations.shape[0],
|
85
|
+
self.rotations.shape[0],
|
86
|
+
self.scores.shape[0],
|
87
|
+
self.details.shape[0],
|
88
|
+
]
|
89
|
+
)
|
90
|
+
if len(n_orientations) != 1:
|
91
|
+
raise ValueError(
|
92
|
+
"The first dimension of all parameters needs to be of equal length."
|
93
|
+
)
|
94
|
+
if self.translations.ndim != 2:
|
95
|
+
raise ValueError("Expected two dimensional translations parameter.")
|
96
|
+
|
97
|
+
if self.rotations.ndim != 2:
|
98
|
+
raise ValueError("Expected two dimensional rotations parameter.")
|
99
|
+
|
35
100
|
def __iter__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
36
101
|
"""
|
37
102
|
Iterate over the current class instance. Each iteration returns an orientation
|
@@ -77,7 +142,17 @@ class Orientations:
|
|
77
142
|
filename : str
|
78
143
|
The name of the file where the orientations will be saved.
|
79
144
|
file_format : type, optional
|
80
|
-
The format in which to save the orientations.
|
145
|
+
The format in which to save the orientations. Defaults to None and infers
|
146
|
+
the file_format from the typical extension. Supported formats are
|
147
|
+
|
148
|
+
+---------------+----------------------------------------------------+
|
149
|
+
| text | pyTME's standard tab-separated orientations file |
|
150
|
+
+---------------+----------------------------------------------------+
|
151
|
+
| relion | Creates a STAR file of orientations |
|
152
|
+
+---------------+----------------------------------------------------+
|
153
|
+
| dynamo | Creates a dynamo table |
|
154
|
+
+---------------+----------------------------------------------------+
|
155
|
+
|
81
156
|
**kwargs : dict
|
82
157
|
Additional keyword arguments specific to the file format.
|
83
158
|
|
@@ -120,8 +195,14 @@ class Orientations:
|
|
120
195
|
The file is saved with a header specifying each column: z, y, x, euler_z,
|
121
196
|
euler_y, euler_x, score, detail. Each row in the file corresponds to an orientation.
|
122
197
|
"""
|
198
|
+
naming = ascii_lowercase[::-1]
|
123
199
|
header = "\t".join(
|
124
|
-
[
|
200
|
+
[
|
201
|
+
*list(naming[: self.translations.shape[1]]),
|
202
|
+
*[f"euler_{x}" for x in naming[: self.rotations.shape[1]]],
|
203
|
+
"score",
|
204
|
+
"detail",
|
205
|
+
]
|
125
206
|
)
|
126
207
|
with open(filename, mode="w", encoding="utf-8") as ofile:
|
127
208
|
_ = ofile.write(f"{header}\n")
|
@@ -158,9 +239,6 @@ class Orientations:
|
|
158
239
|
References
|
159
240
|
----------
|
160
241
|
.. [1] https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Table
|
161
|
-
|
162
|
-
The file is saved with a standard header used in Dynamo STAR files.
|
163
|
-
Each row in the file corresponds to an orientation.
|
164
242
|
"""
|
165
243
|
with open(filename, mode="w", encoding="utf-8") as ofile:
|
166
244
|
for index, (translation, rotation, score, detail) in enumerate(self):
|
@@ -316,8 +394,18 @@ class Orientations:
|
|
316
394
|
filename : str
|
317
395
|
The name of the file from which to read the orientations.
|
318
396
|
file_format : type, optional
|
319
|
-
The format of the file.
|
320
|
-
|
397
|
+
The format of the file. Defaults to None and infers
|
398
|
+
the file_format from the typical extension. Supported formats are
|
399
|
+
|
400
|
+
+---------------+----------------------------------------------------+
|
401
|
+
| text | pyTME's standard tab-separated orientations file |
|
402
|
+
+---------------+----------------------------------------------------+
|
403
|
+
| relion | Reads a STAR file of orientations |
|
404
|
+
+---------------+----------------------------------------------------+
|
405
|
+
| tbl | Reads a dynamo table |
|
406
|
+
+---------------+----------------------------------------------------+
|
407
|
+
|
408
|
+
**kwargs
|
321
409
|
Additional keyword arguments specific to the file format.
|
322
410
|
|
323
411
|
Returns
|
@@ -330,11 +418,18 @@ class Orientations:
|
|
330
418
|
ValueError
|
331
419
|
If an unsupported file format is specified.
|
332
420
|
"""
|
333
|
-
mapping = {
|
421
|
+
mapping = {
|
422
|
+
"text": cls._from_text,
|
423
|
+
"relion": cls._from_relion_star,
|
424
|
+
"tbl": cls._from_tbl,
|
425
|
+
}
|
334
426
|
if file_format is None:
|
335
427
|
file_format = "text"
|
428
|
+
|
336
429
|
if filename.lower().endswith(".star"):
|
337
430
|
file_format = "relion"
|
431
|
+
elif filename.lower().endswith(".tbl"):
|
432
|
+
file_format = "tbl"
|
338
433
|
|
339
434
|
func = mapping.get(file_format, None)
|
340
435
|
if func is None:
|
@@ -375,25 +470,28 @@ class Orientations:
|
|
375
470
|
"""
|
376
471
|
with open(filename, mode="r", encoding="utf-8") as infile:
|
377
472
|
data = [x.strip().split("\t") for x in infile.read().split("\n")]
|
378
|
-
_ = data.pop(0)
|
379
473
|
|
474
|
+
header = data.pop(0)
|
380
475
|
translation, rotation, score, detail = [], [], [], []
|
381
476
|
for candidate in data:
|
382
477
|
if len(candidate) <= 1:
|
383
478
|
continue
|
384
|
-
if len(candidate) != 8:
|
385
|
-
candidate.append(-1)
|
386
479
|
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
480
|
+
translation.append(
|
481
|
+
tuple(
|
482
|
+
candidate[i] for i, x in enumerate(header) if x in ascii_lowercase
|
483
|
+
)
|
484
|
+
)
|
485
|
+
rotation.append(
|
486
|
+
tuple(candidate[i] for i, x in enumerate(header) if "euler" in x)
|
487
|
+
)
|
488
|
+
score.append(candidate[-2])
|
489
|
+
detail.append(candidate[-1])
|
392
490
|
|
393
|
-
translation = np.vstack(translation)
|
394
|
-
rotation = np.vstack(rotation)
|
395
|
-
score = np.array(score)
|
396
|
-
detail = np.array(detail)
|
491
|
+
translation = np.vstack(translation)
|
492
|
+
rotation = np.vstack(rotation)
|
493
|
+
score = np.array(score)
|
494
|
+
detail = np.array(detail)
|
397
495
|
|
398
496
|
return translation, rotation, score, detail
|
399
497
|
|
@@ -448,20 +546,15 @@ class Orientations:
|
|
448
546
|
ret = cls._parse_star(filename=filename, delimiter=delimiter)
|
449
547
|
ret = ret["data_particles"]
|
450
548
|
|
451
|
-
translation = (
|
452
|
-
|
453
|
-
(ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
|
454
|
-
)
|
455
|
-
.astype(np.float32)
|
456
|
-
.astype(int)
|
457
|
-
.T
|
549
|
+
translation = np.vstack(
|
550
|
+
(ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
|
458
551
|
)
|
552
|
+
translation = translation.astype(np.float32).T
|
459
553
|
|
460
|
-
rotation = (
|
461
|
-
|
462
|
-
.astype(np.float32)
|
463
|
-
.T
|
554
|
+
rotation = np.vstack(
|
555
|
+
(ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"])
|
464
556
|
)
|
557
|
+
rotation = rotation.astype(np.float32).T
|
465
558
|
|
466
559
|
rotation = Rotation.from_euler("xyx", rotation, degrees=True)
|
467
560
|
rotation = rotation.as_euler(seq="zyx", degrees=True)
|
@@ -470,6 +563,33 @@ class Orientations:
|
|
470
563
|
|
471
564
|
return translation, rotation, score, detail
|
472
565
|
|
566
|
+
@staticmethod
|
567
|
+
def _from_tbl(
|
568
|
+
filename: str, **kwargs
|
569
|
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
570
|
+
with open(filename, mode="r", encoding="utf-8") as infile:
|
571
|
+
data = infile.read().split("\n")
|
572
|
+
data = [x.strip().split(" ") for x in data if len(x.strip())]
|
573
|
+
|
574
|
+
if len(data[0]) != 38:
|
575
|
+
raise ValueError(
|
576
|
+
"Expected tbl file to have 38 columns generated by _to_tbl."
|
577
|
+
)
|
578
|
+
|
579
|
+
translations, rotations, scores, details = [], [], [], []
|
580
|
+
for peak in data:
|
581
|
+
rotation = Rotation.from_euler(
|
582
|
+
"xyx", (peak[6], peak[7], peak[8]), degrees=True
|
583
|
+
)
|
584
|
+
rotations.append(rotation.as_euler(seq="zyx", degrees=True))
|
585
|
+
scores.append(peak[9])
|
586
|
+
details.append(-1)
|
587
|
+
translations.append((peak[25], peak[24], peak[23]))
|
588
|
+
|
589
|
+
translations, rotations = np.array(translations), np.array(rotations)
|
590
|
+
scores, details = np.array(scores), np.array(details)
|
591
|
+
return translations, rotations, scores, details
|
592
|
+
|
473
593
|
def get_extraction_slices(
|
474
594
|
self,
|
475
595
|
target_shape: Tuple[int],
|
@@ -504,55 +624,55 @@ class Orientations:
|
|
504
624
|
SystemExit
|
505
625
|
If no peak remains after filtering, indicating an error.
|
506
626
|
"""
|
507
|
-
|
508
|
-
|
627
|
+
right_pad = np.divide(extraction_shape, 2).astype(int)
|
628
|
+
left_pad = np.add(right_pad, np.mod(extraction_shape, 2)).astype(int)
|
629
|
+
|
630
|
+
peaks = self.translations.astype(int)
|
631
|
+
obs_beg = np.subtract(peaks, left_pad)
|
632
|
+
obs_end = np.add(peaks, right_pad)
|
509
633
|
|
510
|
-
|
511
|
-
|
634
|
+
obs_beg = np.maximum(obs_beg, 0)
|
635
|
+
obs_end = np.minimum(obs_end, target_shape)
|
512
636
|
|
513
|
-
|
514
|
-
|
515
|
-
cand_stop = np.subtract(extraction_shape, cand_stop)
|
516
|
-
obs_start = np.maximum(obs_start, 0)
|
517
|
-
obs_stop = np.minimum(obs_stop, target_shape)
|
637
|
+
cand_beg = left_pad - np.subtract(peaks, obs_beg)
|
638
|
+
cand_end = left_pad + np.subtract(obs_end, peaks)
|
518
639
|
|
519
640
|
subset = self
|
520
641
|
if drop_out_of_box:
|
521
|
-
stops = np.subtract(
|
642
|
+
stops = np.subtract(cand_end, extraction_shape)
|
522
643
|
keep_peaks = (
|
523
644
|
np.sum(
|
524
|
-
np.multiply(
|
645
|
+
np.multiply(cand_beg == 0, stops == 0),
|
525
646
|
axis=1,
|
526
647
|
)
|
527
|
-
==
|
648
|
+
== peaks.shape[1]
|
528
649
|
)
|
529
650
|
n_remaining = keep_peaks.sum()
|
530
651
|
if n_remaining == 0:
|
531
652
|
print(
|
532
653
|
"No peak remaining after filtering. Started with"
|
533
|
-
f" {
|
654
|
+
f" {peaks.shape[0]} filtered to {n_remaining}."
|
534
655
|
" Consider reducing min_distance, increase num_peaks or use"
|
535
656
|
" a different peak caller."
|
536
657
|
)
|
537
|
-
exit(-1)
|
538
658
|
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
659
|
+
cand_beg = cand_beg[keep_peaks,]
|
660
|
+
cand_end = cand_end[keep_peaks,]
|
661
|
+
obs_beg = obs_beg[keep_peaks,]
|
662
|
+
obs_end = obs_end[keep_peaks,]
|
543
663
|
subset = self[keep_peaks]
|
544
664
|
|
545
|
-
|
546
|
-
|
665
|
+
cand_beg, cand_end = cand_beg.astype(int), cand_end.astype(int)
|
666
|
+
obs_beg, obs_end = obs_beg.astype(int), obs_end.astype(int)
|
547
667
|
|
548
668
|
candidate_slices = [
|
549
669
|
tuple(slice(s, e) for s, e in zip(start_row, stop_row))
|
550
|
-
for start_row, stop_row in zip(
|
670
|
+
for start_row, stop_row in zip(cand_beg, cand_end)
|
551
671
|
]
|
552
672
|
|
553
673
|
observation_slices = [
|
554
674
|
tuple(slice(s, e) for s, e in zip(start_row, stop_row))
|
555
|
-
for start_row, stop_row in zip(
|
675
|
+
for start_row, stop_row in zip(obs_beg, obs_end)
|
556
676
|
]
|
557
677
|
|
558
678
|
if return_orientations:
|
@@ -0,0 +1,188 @@
|
|
1
|
+
""" Utilities for the generation of frequency grids.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from numpy.typing import NDArray
|
12
|
+
|
13
|
+
from ..backends import backend
|
14
|
+
from ..matching_utils import euler_to_rotationmatrix
|
15
|
+
|
16
|
+
|
17
|
+
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Derive the shape of a single tilt from a volume shape.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the input array.
    opening_axis : int
        Index of the axis that is collapsed. An index that matches no axis
        (for instance None) leaves the shape unchanged.
    reduce_dim : bool, optional (default=False)
        If True the collapsed axis is removed, otherwise it is kept with
        length one.

    Returns:
    --------
    Tuple[int]
        The shape of the array after tilting.
    """
    if reduce_dim:
        # Drop the opening axis entirely.
        return tuple(
            extent for axis, extent in enumerate(shape) if axis != opening_axis
        )

    # Keep the dimensionality and collapse the opening axis to length one.
    return tuple(
        1 if axis == opening_axis else extent for axis, extent in enumerate(shape)
    )
|
40
|
+
|
41
|
+
|
42
|
+
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build an integer coordinate grid whose origin sits at ``size // 2``.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.

    Returns:
    --------
    NDArray
        Array of shape (len(shape), *shape) holding, per axis, the
        coordinate of every grid point relative to the center.
    """
    axis_coordinates = []
    for size in shape:
        # Shift 0..size-1 so that index size // 2 maps to coordinate zero.
        axis_coordinates.append(np.arange(size) - size // 2)

    mesh = np.meshgrid(*axis_coordinates, indexing="ij")
    return np.array(mesh)
|
60
|
+
|
61
|
+
|
62
|
+
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute the frequency magnitude on a (possibly rotated) centered grid.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        The angle at which to generate the grid. A value of zero skips
        the rotation step.
    sampling_rate : Tuple[float]
        The sampling rate for each dimension. Fewer values than
        dimensions are repeated to cover all axes.
    opening_axis : int, optional
        The axis to be opened, defaults to None.
    tilt_axis : int, optional
        The axis along which the grid is tilted, defaults to None.

    Returns:
    --------
    NDArray
        Euclidean norm of the scaled grid coordinates, with length-one
        axes removed.
    """
    rates = np.array(sampling_rate)
    rates = np.repeat(rates, len(shape) // rates.size)

    grid = centered_grid(
        shape=compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=False
        )
    )

    if angle != 0:
        euler_angles = np.zeros(len(shape))
        euler_angles[tilt_axis] = angle
        matrix = euler_to_rotationmatrix(np.roll(euler_angles, opening_axis - 1))
        # Rotate the coordinate vector stored along the leading axis.
        grid = np.einsum("ij,j...->i...", matrix, grid)

    half_extent = np.divide(shape, 2).astype(int)
    scale = np.divide(1, 2 * rates * half_extent)

    # Transposing moves the coordinate axis last so that the per-axis scale
    # broadcasts against it, then the layout is restored.
    grid = np.squeeze(np.multiply(grid.T, scale).T)
    return np.linalg.norm(grid, axis=0)
|
109
|
+
|
110
|
+
|
111
|
+
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier Transform sample frequencies.

    The grid is built from the integer indices of ``shape``, shifted by
    ``shape / 2`` (cast to the backend integer type) and divided by a
    per-axis normalisation factor.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate. If None, every axis is normalised by one.
    compute_euclidean_norm : bool, optional
        Whether to compute the Euclidean norm, defaults to False. If True the
        per-axis frequencies are reduced over the leading axis.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, defaults to False.
        If True the last axis is not shifted and is normalised separately.

    Returns:
    --------
    NDArray
        The sample frequencies.
    """
    # Per-axis index that is subtracted from the grid below.
    center = backend.astype(backend.divide(shape, 2), backend._int_dtype)

    norm = np.ones(len(shape))
    if sampling_rate is not None:
        # NOTE(review): the cast to int truncates shape * sampling_rate for
        # non-integer products - confirm this is intended.
        norm = backend.astype(backend.multiply(shape, sampling_rate), int)

    if shape_is_real_fourier:
        # The last axis is left unshifted and gets its own normalisation.
        center[-1] = 0
        norm[-1] = 1
        if sampling_rate is not None:
            # NOTE(review): sampling_rate is used as a single value here;
            # presumably callers pass a scalar in this mode - verify.
            norm[-1] = (shape[-1] - 1) * 2 * sampling_rate

    # Transposing puts the axis that enumerates dimensions last, so center
    # and norm broadcast against it; the second transpose restores the layout.
    indices = backend.transpose(backend.indices(shape))
    indices -= center
    indices = backend.divide(indices, norm)
    indices = backend.transpose(indices)

    if compute_euclidean_norm:
        indices = backend.square(indices)
        indices = backend.sum(indices, axis=0)
        # Square root is written back into the same array.
        backend.sqrt(indices, out=indices)

    return indices
|
159
|
+
|
160
|
+
|
161
|
+
def crop_real_fourier(data: NDArray) -> NDArray:
    """
    Keep the leading ``n // 2 + 1`` entries along the last axis of ``data``.

    Parameters:
    -----------
    data : NDArray
        The Fourier transformed data.

    Returns:
    --------
    NDArray
        The cropped data.
    """
    last_axis_length = data.shape[-1]
    return data[..., : last_axis_length // 2 + 1]
|
177
|
+
|
178
|
+
|
179
|
+
def shift_fourier(data: NDArray, shape_is_real_fourier: bool = False):
    """
    Roll ``data`` along every axis by half of its extent, rounded up.

    Parameters:
    -----------
    data : NDArray
        The array to shift.
    shape_is_real_fourier : bool, optional
        If True the last axis is not shifted, defaults to False.

    Returns:
    --------
    NDArray
        The rolled array.
    """
    half_extent = backend.astype(backend.divide(data.shape, 2), int)
    remainder = backend.mod(data.shape, 2)
    # Integer half plus the parity bit rounds odd extents up.
    shift = backend.add(half_extent, remainder)

    if shape_is_real_fourier:
        shift[-1] = 0

    axes = tuple(range(len(shift)))
    return backend.roll(data, shift, axes)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
""" Defines a specification for filters that can be used with
|
2
|
+
:py:class:`tme.preprocessing.compose.Compose`.
|
3
|
+
|
4
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
from typing import Dict
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
|
11
|
+
|
12
|
+
class ComposableFilter(ABC):
    """
    Strategy class for composable filters.

    Subclasses must implement :py:meth:`__call__`, which is declared abstract
    here and is annotated to return a dictionary.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Apply the filter. Must be overridden by subclasses.

        Parameters:
        -----------
        *args : tuple
            Variable length argument list.
        **kwargs : dict
            Arbitrary keyword arguments.

        Returns:
        --------
        Dict
            A dictionary representing the result of the filtering operation.
        """
|
@@ -0,0 +1,51 @@
|
|
1
|
+
""" Combine filters using an interface analogous to pytorch's Compose.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple, Dict
|
9
|
+
|
10
|
+
from tme.backends import backend
|
11
|
+
|
12
|
+
|
13
|
+
class Compose:
    """
    Chain several transformations into a single callable.

    Each transformation is called with keyword arguments and returns a
    metadata dictionary. The metadata of one transformation is merged into
    the keyword arguments passed to the next one.

    Parameters:
    -----------
    transforms : Tuple[object]
        A tuple containing transformation objects.

    Returns:
    --------
    Dict
        Metadata resulting from the composed transformations.

    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Run all transformations in order and return the last metadata.

        An empty pipeline returns an empty dictionary. If a transformation
        reports ``is_multiplicative_filter``, its ``data`` is multiplied in
        place with the ``data`` of the previous result and its ``merge``
        entry is set to None.
        """
        if not len(self.transforms):
            return {}

        pipeline = iter(self.transforms)
        meta = next(pipeline)(**kwargs)

        for transform in pipeline:
            # Earlier results become inputs of the following transformation.
            kwargs.update(meta)
            current = transform(**kwargs)

            multiplicative = current.get("is_multiplicative_filter", False)
            if multiplicative:
                backend.multiply(current["data"], meta["data"], out=current["data"])
                current["merge"] = None

            meta = current

        return meta
|