pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
#!python
|
2
|
+
""" Estimate RAM requirements for template matching jobs.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
import numpy as np
|
9
|
+
import argparse
|
10
|
+
from tme import Density
|
11
|
+
from tme.matching_utils import estimate_ram_usage
|
12
|
+
from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
|
13
|
+
|
14
|
+
|
15
|
+
def parse_args():
|
16
|
+
parser = argparse.ArgumentParser(
|
17
|
+
description="Estimate RAM usage for template matching."
|
18
|
+
)
|
19
|
+
parser.add_argument(
|
20
|
+
"-m",
|
21
|
+
"--target",
|
22
|
+
dest="target",
|
23
|
+
type=str,
|
24
|
+
required=True,
|
25
|
+
help="Path to a target in CCP4/MRC format.",
|
26
|
+
)
|
27
|
+
parser.add_argument(
|
28
|
+
"-i",
|
29
|
+
"--template",
|
30
|
+
dest="template",
|
31
|
+
type=str,
|
32
|
+
required=True,
|
33
|
+
help="Path to a template in PDB/MMCIF or CCP4/MRC format.",
|
34
|
+
)
|
35
|
+
parser.add_argument(
|
36
|
+
"--matching_method",
|
37
|
+
required=False,
|
38
|
+
default=None,
|
39
|
+
help="Analyzer method to use.",
|
40
|
+
)
|
41
|
+
parser.add_argument(
|
42
|
+
"-s",
|
43
|
+
dest="score",
|
44
|
+
type=str,
|
45
|
+
default="FLCSphericalMask",
|
46
|
+
help="Template matching scoring function.",
|
47
|
+
choices=MATCHING_EXHAUSTIVE_REGISTER.keys(),
|
48
|
+
)
|
49
|
+
parser.add_argument(
|
50
|
+
"--ncores", type=int, help="Number of cores for parallelization.", required=True
|
51
|
+
)
|
52
|
+
parser.add_argument(
|
53
|
+
"--no_edge_padding",
|
54
|
+
dest="no_edge_padding",
|
55
|
+
action="store_true",
|
56
|
+
default=False,
|
57
|
+
help="Whether to pad the edges of the target. This is useful, if the target"
|
58
|
+
" has a well defined bounding box, e.g. a density map.",
|
59
|
+
)
|
60
|
+
parser.add_argument(
|
61
|
+
"--no_fourier_padding",
|
62
|
+
dest="no_fourier_padding",
|
63
|
+
action="store_true",
|
64
|
+
default=False,
|
65
|
+
help="Whether input arrays should be zero-padded to the full convolution shape"
|
66
|
+
" for numerical stability. When working with very large targets such as"
|
67
|
+
" tomograms it is safe to use this flag and benefit from the performance gain.",
|
68
|
+
)
|
69
|
+
args = parser.parse_args()
|
70
|
+
return args
|
71
|
+
|
72
|
+
|
73
|
+
def main():
|
74
|
+
args = parse_args()
|
75
|
+
target = Density.from_file(args.target)
|
76
|
+
template = Density.from_file(args.template)
|
77
|
+
|
78
|
+
target_box = target.shape
|
79
|
+
if not args.no_edge_padding:
|
80
|
+
target_box = np.add(target_box, template.shape)
|
81
|
+
|
82
|
+
template_box = template.shape
|
83
|
+
if args.no_fourier_padding:
|
84
|
+
template_box = np.ones(len(template_box), dtype=int)
|
85
|
+
|
86
|
+
result = estimate_ram_usage(
|
87
|
+
shape1=target_box,
|
88
|
+
shape2=template_box,
|
89
|
+
matching_method=args.score,
|
90
|
+
ncores=args.ncores,
|
91
|
+
analyzer_method="MaxScoreOverRotations",
|
92
|
+
)
|
93
|
+
print(result)
|
94
|
+
|
95
|
+
|
96
|
+
if __name__ == "__main__":
|
97
|
+
main()
|
@@ -12,8 +12,8 @@ from sys import exit
|
|
12
12
|
from time import time
|
13
13
|
from typing import Tuple
|
14
14
|
from copy import deepcopy
|
15
|
-
from os.path import exists
|
16
15
|
from tempfile import gettempdir
|
16
|
+
from os.path import exists, abspath
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
|
@@ -359,8 +359,8 @@ def parse_args():
|
|
359
359
|
"--invert-target-contrast",
|
360
360
|
action="store_true",
|
361
361
|
default=False,
|
362
|
-
help="Invert the target
|
363
|
-
"
|
362
|
+
help="Invert the target contrast. Useful for matching on tomograms if the "
|
363
|
+
"template has not been inverted.",
|
364
364
|
)
|
365
365
|
io_group.add_argument(
|
366
366
|
"--scramble-phases",
|
@@ -576,7 +576,7 @@ def parse_args():
|
|
576
576
|
"'angles', or a single column file without header. Exposure will be taken from "
|
577
577
|
"the input file , if you are using a tab-separated file, the column names "
|
578
578
|
"'angles' and 'weights' need to be present. It is also possible to specify a "
|
579
|
-
"continuous wedge mask using e.g.,
|
579
|
+
"continuous wedge mask using e.g., 50,45.",
|
580
580
|
)
|
581
581
|
filter_group.add_argument(
|
582
582
|
"--tilt-weighting",
|
@@ -685,13 +685,6 @@ def parse_args():
|
|
685
685
|
help="Useful if the target does not have a well-defined bounding box. Will be "
|
686
686
|
"activated automatically if splitting is required to avoid boundary artifacts.",
|
687
687
|
)
|
688
|
-
performance_group.add_argument(
|
689
|
-
"--pad-filter",
|
690
|
-
action="store_true",
|
691
|
-
default=False,
|
692
|
-
help="Pad the template filter to the shape of the target. Useful for fast "
|
693
|
-
"oscilating filters to avoid aliasing effects.",
|
694
|
-
)
|
695
688
|
performance_group.add_argument(
|
696
689
|
"--interpolation-order",
|
697
690
|
required=False,
|
@@ -700,13 +693,6 @@ def parse_args():
|
|
700
693
|
help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
|
701
694
|
"and pytorch backends.",
|
702
695
|
)
|
703
|
-
performance_group.add_argument(
|
704
|
-
"--use-mixed-precision",
|
705
|
-
action="store_true",
|
706
|
-
default=False,
|
707
|
-
help="Use float16 for real values operations where possible. Not supported "
|
708
|
-
"for jax backend.",
|
709
|
-
)
|
710
696
|
performance_group.add_argument(
|
711
697
|
"--use-memmap",
|
712
698
|
action="store_true",
|
@@ -742,6 +728,7 @@ def parse_args():
|
|
742
728
|
args.interpolation_order = 3
|
743
729
|
if args.backend in ("jax", "pytorch"):
|
744
730
|
args.interpolation_order = 1
|
731
|
+
args.reconstruction_interpolation_order = 1
|
745
732
|
|
746
733
|
if args.interpolation_order < 0:
|
747
734
|
args.interpolation_order = None
|
@@ -779,6 +766,14 @@ def parse_args():
|
|
779
766
|
)
|
780
767
|
args.orientations = orientations
|
781
768
|
|
769
|
+
args.target = abspath(args.target)
|
770
|
+
if args.target_mask is not None:
|
771
|
+
args.target_mask = abspath(args.target_mask)
|
772
|
+
|
773
|
+
args.template = abspath(args.template)
|
774
|
+
if args.template_mask is not None:
|
775
|
+
args.template_mask = abspath(args.template_mask)
|
776
|
+
|
782
777
|
return args
|
783
778
|
|
784
779
|
|
@@ -796,15 +791,23 @@ def main():
|
|
796
791
|
sampling_rate=target.sampling_rate,
|
797
792
|
)
|
798
793
|
|
794
|
+
if np.allclose(target.sampling_rate, 1):
|
795
|
+
warnings.warn(
|
796
|
+
"Target sampling rate is 1.0, which may indicate missing or incorrect "
|
797
|
+
"metadata. Verify that your target file contains proper sampling rate "
|
798
|
+
"information, as filters (CTF, BandPass) require accurate sampling rates "
|
799
|
+
"to function correctly."
|
800
|
+
)
|
801
|
+
|
799
802
|
if target.sampling_rate.size == template.sampling_rate.size:
|
800
803
|
if not np.allclose(
|
801
804
|
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
802
805
|
):
|
803
|
-
|
804
|
-
f"
|
805
|
-
"
|
806
|
+
warnings.warn(
|
807
|
+
f"Sampling rate mismatch detected: target={target.sampling_rate} "
|
808
|
+
f"template={template.sampling_rate}. Proceeding with user-provided "
|
809
|
+
f"values. Make sure this is intentional. "
|
806
810
|
)
|
807
|
-
template = template.resample(target.sampling_rate, order=3)
|
808
811
|
|
809
812
|
template_mask = load_and_validate_mask(
|
810
813
|
mask_target=template, mask_path=args.template_mask
|
@@ -881,16 +884,13 @@ def main():
|
|
881
884
|
print("\n" + "-" * 80)
|
882
885
|
|
883
886
|
if args.scramble_phases:
|
884
|
-
template.data = scramble_phases(
|
885
|
-
template.data, noise_proportion=1.0, normalize_power=False
|
886
|
-
)
|
887
|
+
template.data = scramble_phases(template.data, noise_proportion=1.0)
|
887
888
|
|
888
889
|
callback_class = MaxScoreOverRotations
|
889
|
-
if args.peak_calling:
|
890
|
-
callback_class = PeakCallerMaximumFilter
|
891
|
-
|
892
890
|
if args.orientations is not None:
|
893
891
|
callback_class = MaxScoreOverRotationsConstrained
|
892
|
+
elif args.peak_calling:
|
893
|
+
callback_class = PeakCallerMaximumFilter
|
894
894
|
|
895
895
|
# Determine suitable backend for the selected operation
|
896
896
|
available_backends = be.available_backends()
|
@@ -939,16 +939,6 @@ def main():
|
|
939
939
|
args.use_gpu = False
|
940
940
|
be.change_backend("pytorch", device=device)
|
941
941
|
|
942
|
-
# TODO: Make the inverse casting from complex64 -> float 16 stable
|
943
|
-
# if args.use_mixed_precision:
|
944
|
-
# be.change_backend(
|
945
|
-
# backend_name=args.backend,
|
946
|
-
# float_dtype=be._array_backend.float16,
|
947
|
-
# complex_dtype=be._array_backend.complex64,
|
948
|
-
# int_dtype=be._array_backend.int16,
|
949
|
-
# device=device,
|
950
|
-
# )
|
951
|
-
|
952
942
|
available_memory = be.get_available_memory() * be.device_count()
|
953
943
|
if args.memory is None:
|
954
944
|
args.memory = int(args.memory_scaling * available_memory)
|
@@ -976,6 +966,8 @@ def main():
|
|
976
966
|
target_dim=target.metadata.get("batch_dimension", None),
|
977
967
|
template_dim=template.metadata.get("batch_dimension", None),
|
978
968
|
)
|
969
|
+
args.batch_dims = tuple(int(x) for x in np.where(matching_data._batch_mask)[0])
|
970
|
+
|
979
971
|
splits, schedule = compute_schedule(args, matching_data, callback_class)
|
980
972
|
|
981
973
|
n_splits = np.prod(list(splits.values()))
|
@@ -1004,7 +996,6 @@ def main():
|
|
1004
996
|
compute_options = {
|
1005
997
|
"Backend": be._BACKEND_REGISTRY[be._backend_name],
|
1006
998
|
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1007
|
-
"Use Mixed Precision": args.use_mixed_precision,
|
1008
999
|
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1009
1000
|
"Temporary Directory": args.temp_directory,
|
1010
1001
|
"Target Splits": f"{target_split} [N={n_splits}]",
|
@@ -1025,7 +1016,6 @@ def main():
|
|
1025
1016
|
"Tilt Angles": args.tilt_angles,
|
1026
1017
|
"Tilt Weighting": args.tilt_weighting,
|
1027
1018
|
"Reconstruction Filter": args.reconstruction_filter,
|
1028
|
-
"Extend Filter Grid": args.pad_filter,
|
1029
1019
|
}
|
1030
1020
|
if args.ctf_file is not None or args.defocus is not None:
|
1031
1021
|
filter_args["CTF File"] = args.ctf_file
|
@@ -1081,7 +1071,6 @@ def main():
|
|
1081
1071
|
callback_class_args=analyzer_args,
|
1082
1072
|
target_splits=splits,
|
1083
1073
|
pad_target_edges=args.pad_edges,
|
1084
|
-
pad_template_filter=args.pad_filter,
|
1085
1074
|
interpolation_order=args.interpolation_order,
|
1086
1075
|
)
|
1087
1076
|
|
@@ -17,7 +17,7 @@ from scipy.special import erfcinv
|
|
17
17
|
|
18
18
|
from tme import Density, Structure, Orientations
|
19
19
|
from tme.cli import sanitize_name, print_block, print_entry
|
20
|
-
from tme.matching_utils import load_pickle,
|
20
|
+
from tme.matching_utils import load_pickle, center_slice, write_pickle
|
21
21
|
from tme.matching_optimization import create_score_object, optimize_match
|
22
22
|
from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
|
23
23
|
from tme.analyzer import (
|
@@ -87,6 +87,11 @@ def parse_args():
|
|
87
87
|
help="Output prefix. Defaults to basename of first input. Extension is "
|
88
88
|
"added with respect to chosen output format.",
|
89
89
|
)
|
90
|
+
output_group.add_argument(
|
91
|
+
"--angles-clockwise",
|
92
|
+
action="store_true",
|
93
|
+
help="Report Euler angles in clockwise format expected by RELION.",
|
94
|
+
)
|
90
95
|
output_group.add_argument(
|
91
96
|
"--output-format",
|
92
97
|
choices=[
|
@@ -112,7 +117,7 @@ def parse_args():
|
|
112
117
|
peak_group.add_argument(
|
113
118
|
"--peak-caller",
|
114
119
|
choices=list(PEAK_CALLERS.keys()),
|
115
|
-
default="
|
120
|
+
default="PeakCallerMaximumFilter",
|
116
121
|
help="Peak caller for local maxima identification.",
|
117
122
|
)
|
118
123
|
peak_group.add_argument(
|
@@ -183,7 +188,7 @@ def parse_args():
|
|
183
188
|
)
|
184
189
|
additional_group.add_argument(
|
185
190
|
"--n-false-positives",
|
186
|
-
type=
|
191
|
+
type=float,
|
187
192
|
default=None,
|
188
193
|
required=False,
|
189
194
|
help="Number of accepted false-positives picks to determine minimum score.",
|
@@ -313,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
313
318
|
data = load_matching_output(foreground)
|
314
319
|
scores, _, rotations, rotation_mapping, *_ = data
|
315
320
|
|
316
|
-
# We could normalize to unit sdev, but that might lead to unexpected
|
317
|
-
# results for flat background distributions
|
318
|
-
scores -= scores.mean()
|
319
321
|
indices = tuple(slice(0, x) for x in scores.shape)
|
320
|
-
|
321
322
|
indices_update = scores > scores_out[indices]
|
322
323
|
scores_out[indices][indices_update] = scores[indices_update]
|
323
324
|
|
@@ -364,9 +365,8 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
364
365
|
scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
|
365
366
|
for background in backgrounds:
|
366
367
|
data_norm = load_matching_output(background)
|
368
|
+
scores, _, rotations, rotation_mapping, *_ = data_norm
|
367
369
|
|
368
|
-
scores = data_norm[0]
|
369
|
-
scores -= scores.mean()
|
370
370
|
indices = tuple(slice(0, x) for x in scores.shape)
|
371
371
|
indices_update = scores > scores_norm[indices]
|
372
372
|
scores_norm[indices][indices_update] = scores[indices_update]
|
@@ -375,8 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
|
|
375
375
|
update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
|
376
376
|
scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
|
377
377
|
scores_out[update] = data[0][update] - scores_norm[update]
|
378
|
-
scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
|
379
378
|
scores_out = np.fmax(scores_out, 0, out=scores_out)
|
379
|
+
scores_out[update] += scores_norm[update].mean()
|
380
|
+
|
381
|
+
# scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
|
380
382
|
data[0] = scores_out
|
381
383
|
|
382
384
|
fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
|
@@ -478,15 +480,21 @@ def main():
|
|
478
480
|
if orientations is None:
|
479
481
|
translations, rotations, scores, details = [], [], [], []
|
480
482
|
|
481
|
-
|
482
|
-
|
483
|
+
var = None
|
484
|
+
# Data processed by normalize_input is guaranteed to have this shape)
|
485
|
+
scores, _, rotation_array, rotation_mapping, *_ = data
|
486
|
+
if len(data) == 6:
|
487
|
+
scores, _, rotation_array, rotation_mapping, var, *_ = data
|
483
488
|
|
484
489
|
cropped_shape = np.subtract(
|
485
490
|
scores.shape, np.multiply(args.min_boundary_distance, 2)
|
486
491
|
).astype(int)
|
487
492
|
|
488
493
|
if args.min_boundary_distance > 0:
|
489
|
-
|
494
|
+
_scores = np.zeros_like(scores)
|
495
|
+
subset = center_slice(scores.shape, cropped_shape)
|
496
|
+
_scores[subset] = scores[subset]
|
497
|
+
scores = _scores
|
490
498
|
|
491
499
|
if args.n_false_positives is not None:
|
492
500
|
# Rickgauer et al. 2017
|
@@ -499,17 +507,20 @@ def main():
|
|
499
507
|
)
|
500
508
|
args.n_false_positives = max(args.n_false_positives, 1)
|
501
509
|
n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
|
510
|
+
std = np.std(scores[cropped_slice])
|
511
|
+
if var is not None:
|
512
|
+
std = np.asarray(np.sqrt(var)).reshape(())
|
513
|
+
|
502
514
|
minimum_score = np.multiply(
|
503
515
|
erfcinv(2 * args.n_false_positives / n_correlations),
|
504
|
-
np.sqrt(2) *
|
516
|
+
np.sqrt(2) * std,
|
505
517
|
)
|
506
|
-
print(f"Determined
|
507
|
-
|
508
|
-
args.min_score = minimum_score
|
518
|
+
print(f"Determined cutoff --min-score {minimum_score}.")
|
519
|
+
args.min_score = max(minimum_score, 0)
|
509
520
|
|
510
521
|
args.batch_dims = None
|
511
|
-
if hasattr(cli_args, "
|
512
|
-
args.batch_dims = cli_args.
|
522
|
+
if hasattr(cli_args, "batch_dims"):
|
523
|
+
args.batch_dims = cli_args.batch_dims
|
513
524
|
|
514
525
|
peak_caller_kwargs = {
|
515
526
|
"shape": scores.shape,
|
@@ -517,8 +528,8 @@ def main():
|
|
517
528
|
"min_distance": args.min_distance,
|
518
529
|
"min_boundary_distance": args.min_boundary_distance,
|
519
530
|
"batch_dims": args.batch_dims,
|
520
|
-
"
|
521
|
-
"
|
531
|
+
"min_score": args.min_score,
|
532
|
+
"max_score": args.max_score,
|
522
533
|
}
|
523
534
|
|
524
535
|
peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
|
@@ -611,6 +622,9 @@ def main():
|
|
611
622
|
orientations.rotations[index] = angles
|
612
623
|
orientations.scores[index] = score * -1
|
613
624
|
|
625
|
+
if args.angles_clockwise:
|
626
|
+
orientations.rotations *= -1
|
627
|
+
|
614
628
|
if args.output_format in ("orientations", "relion4", "relion5"):
|
615
629
|
file_format, extension = "text", "tsv"
|
616
630
|
|
@@ -19,13 +19,13 @@ from magicgui import widgets
|
|
19
19
|
from numpy.typing import NDArray
|
20
20
|
from napari.layers import Image
|
21
21
|
from scipy.fft import next_fast_len
|
22
|
-
from qtpy.QtWidgets import QFileDialog
|
23
22
|
from napari.utils.events import EventedList
|
23
|
+
from qtpy.QtWidgets import QFileDialog, QMessageBox
|
24
24
|
|
25
|
-
from tme.backends import backend
|
25
|
+
from tme.backends import backend as be
|
26
26
|
from tme.rotations import align_vectors
|
27
|
-
from tme import Preprocessor, Density, Orientations
|
28
27
|
from tme.matching_utils import create_mask, load_pickle
|
28
|
+
from tme import Preprocessor, Density, Orientations
|
29
29
|
from tme.filters import BandPassReconstructed, CTFReconstructed
|
30
30
|
|
31
31
|
preprocessor = Preprocessor()
|
@@ -69,7 +69,7 @@ def ctf_filter(
|
|
69
69
|
flip_phase: bool = False,
|
70
70
|
) -> NDArray:
|
71
71
|
fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
|
72
|
-
template_pad =
|
72
|
+
template_pad = be.topleft_pad(template, fast_shape)
|
73
73
|
template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
|
74
74
|
ctf = CTFReconstructed(
|
75
75
|
shape=fast_shape,
|
@@ -85,7 +85,7 @@ def ctf_filter(
|
|
85
85
|
)
|
86
86
|
np.multiply(template_ft, ctf()["data"], out=template_ft)
|
87
87
|
template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
|
88
|
-
template =
|
88
|
+
template = be.topleft_pad(template_pad, template.shape)
|
89
89
|
return template
|
90
90
|
|
91
91
|
|
@@ -392,7 +392,13 @@ class FilterWidget(widgets.Container):
|
|
392
392
|
metadata = selected_layer_metadata.copy()
|
393
393
|
if "filter_parameters" not in metadata:
|
394
394
|
metadata["filter_parameters"] = []
|
395
|
-
|
395
|
+
|
396
|
+
payload = {filter_name: kwargs.copy()}
|
397
|
+
if isinstance(metadata["filter_parameters"], dict):
|
398
|
+
metadata["filter_parameters"].update(payload)
|
399
|
+
else:
|
400
|
+
metadata["filter_parameters"].append(payload)
|
401
|
+
|
396
402
|
metadata["used_filter"] = filter_name
|
397
403
|
new_layer.metadata = metadata
|
398
404
|
|
@@ -450,7 +456,30 @@ def box_mask(
|
|
450
456
|
mask_type="box",
|
451
457
|
shape=template.shape,
|
452
458
|
center=(center_x, center_y, center_z),
|
453
|
-
|
459
|
+
size=(height_x, height_y, height_z),
|
460
|
+
sigma_decay=sigma_decay,
|
461
|
+
)
|
462
|
+
|
463
|
+
|
464
|
+
def membrane_mask(
|
465
|
+
template: NDArray,
|
466
|
+
symmetry_axis: int,
|
467
|
+
center_x: float,
|
468
|
+
center_y: float,
|
469
|
+
center_z: float,
|
470
|
+
radius: float,
|
471
|
+
thickness: float = 1,
|
472
|
+
separation: float = 3,
|
473
|
+
sigma_decay: float = 0,
|
474
|
+
**kwargs,
|
475
|
+
) -> NDArray:
|
476
|
+
return create_mask(
|
477
|
+
center=(center_x, center_y, center_z),
|
478
|
+
mask_type="membrane",
|
479
|
+
shape=template.shape,
|
480
|
+
radius=radius,
|
481
|
+
thickness=thickness,
|
482
|
+
separation=separation,
|
454
483
|
sigma_decay=sigma_decay,
|
455
484
|
)
|
456
485
|
|
@@ -471,7 +500,7 @@ def tube_mask(
|
|
471
500
|
mask_type="tube",
|
472
501
|
shape=template.shape,
|
473
502
|
symmetry_axis=symmetry_axis,
|
474
|
-
|
503
|
+
center=(center_x, center_y, center_z),
|
475
504
|
inner_radius=inner_radius,
|
476
505
|
outer_radius=outer_radius,
|
477
506
|
height=height,
|
@@ -584,6 +613,7 @@ class MaskWidget(widgets.Container):
|
|
584
613
|
"Ellipsoid": ellipsod_mask,
|
585
614
|
"Tube": tube_mask,
|
586
615
|
"Box": box_mask,
|
616
|
+
"Membrane": membrane_mask,
|
587
617
|
"Wedge": wedge_mask,
|
588
618
|
"Threshold": threshold_mask,
|
589
619
|
"Lowpass": lowpass_mask,
|
@@ -817,7 +847,7 @@ class AlignmentWidget(widgets.Container):
|
|
817
847
|
principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
|
818
848
|
|
819
849
|
rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
|
820
|
-
rotated_data, _ =
|
850
|
+
rotated_data, _ = be.rigid_transform(
|
821
851
|
arr=active_layer.data,
|
822
852
|
rotation_matrix=rotation_matrix,
|
823
853
|
use_geometric_center=False,
|
@@ -953,7 +983,6 @@ class PointCloudWidget(widgets.Container):
|
|
953
983
|
if not isinstance(layer, napari.layers.Points):
|
954
984
|
continue
|
955
985
|
|
956
|
-
layer.face_color = "white"
|
957
986
|
if event == "Label":
|
958
987
|
if len(layer.properties.get("detail", ())) == 0:
|
959
988
|
continue
|
@@ -970,9 +999,7 @@ class PointCloudWidget(widgets.Container):
|
|
970
999
|
layer.face_color = "score_scaled"
|
971
1000
|
layer.face_colormap = "turbo"
|
972
1001
|
layer.face_color_mode = "colormap"
|
973
|
-
|
974
1002
|
layer.refresh_colors()
|
975
|
-
|
976
1003
|
return None
|
977
1004
|
|
978
1005
|
def _set_positive(self, event):
|
@@ -1140,9 +1167,18 @@ class MatchingWidget(widgets.Container):
|
|
1140
1167
|
self.viewer = viewer
|
1141
1168
|
self.dataframes = {}
|
1142
1169
|
|
1170
|
+
option_container = widgets.Container(layout="horizontal")
|
1171
|
+
self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
|
1172
|
+
self.load_rotations_checkbox = widgets.CheckBox(
|
1173
|
+
text="Load Rotations", value=False
|
1174
|
+
)
|
1175
|
+
option_container.append(self.load_target_checkbox)
|
1176
|
+
option_container.append(self.load_rotations_checkbox)
|
1177
|
+
|
1143
1178
|
self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
|
1144
1179
|
self.import_button.clicked.connect(self._get_load_path)
|
1145
1180
|
|
1181
|
+
self.append(option_container)
|
1146
1182
|
self.append(self.import_button)
|
1147
1183
|
|
1148
1184
|
def _get_load_path(self, event):
|
@@ -1150,7 +1186,7 @@ class MatchingWidget(widgets.Container):
|
|
1150
1186
|
self.native,
|
1151
1187
|
"Open Pickle File...",
|
1152
1188
|
"",
|
1153
|
-
"Pickle Files (*.pickle);;All Files (*)",
|
1189
|
+
"Pickle Files (*.pickle *pickle.gz);;All Files (*)",
|
1154
1190
|
)
|
1155
1191
|
if filename:
|
1156
1192
|
self._load_data(filename)
|
@@ -1159,14 +1195,35 @@ class MatchingWidget(widgets.Container):
|
|
1159
1195
|
data = load_pickle(filename)
|
1160
1196
|
|
1161
1197
|
fname = basename(filename).replace(".pickle", "")
|
1198
|
+
|
1199
|
+
if self.load_target_checkbox.value:
|
1200
|
+
try:
|
1201
|
+
target = Density.from_file(data[-1][-1].target)
|
1202
|
+
_ = self.viewer.add_image(
|
1203
|
+
data=target.data,
|
1204
|
+
name=f"{fname}_target",
|
1205
|
+
metadata={
|
1206
|
+
"origin": target.origin,
|
1207
|
+
"sampling_rate": target.sampling_rate,
|
1208
|
+
},
|
1209
|
+
)
|
1210
|
+
except Exception as e:
|
1211
|
+
msg = QMessageBox(self.native)
|
1212
|
+
msg.setIcon(QMessageBox.Warning)
|
1213
|
+
msg.setWindowTitle("Loading Error")
|
1214
|
+
msg.setText(str(e))
|
1215
|
+
msg.setStandardButtons(QMessageBox.Ok)
|
1216
|
+
msg.exec_()
|
1217
|
+
|
1162
1218
|
if data[0].ndim == data[2].ndim:
|
1163
1219
|
metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1220
|
+
if self.load_rotations_checkbox.value:
|
1221
|
+
_ = self.viewer.add_image(
|
1222
|
+
data=data[2],
|
1223
|
+
name=f"{fname}_rotations",
|
1224
|
+
colormap="orange",
|
1225
|
+
metadata=metadata,
|
1226
|
+
)
|
1170
1227
|
_ = self.viewer.add_image(
|
1171
1228
|
data=data[0],
|
1172
1229
|
name=f"{fname}_scores",
|
@@ -1174,11 +1231,8 @@ class MatchingWidget(widgets.Container):
|
|
1174
1231
|
metadata=metadata,
|
1175
1232
|
)
|
1176
1233
|
return None
|
1177
|
-
detail = np.zeros_like(data[2])
|
1178
|
-
else:
|
1179
|
-
detail = data[3]
|
1180
1234
|
|
1181
|
-
point_properties = {"score": data[2], "detail":
|
1235
|
+
point_properties = {"score": data[2], "detail": data[3]}
|
1182
1236
|
point_properties["score_scaled"] = np.log1p(
|
1183
1237
|
point_properties["score"] - point_properties["score"].min()
|
1184
1238
|
)
|
@@ -1191,8 +1245,26 @@ class MatchingWidget(widgets.Container):
|
|
1191
1245
|
)
|
1192
1246
|
|
1193
1247
|
|
1248
|
+
class CustomNapariViewer(napari.Viewer):
|
1249
|
+
"""
|
1250
|
+
Custom viewer to ensure 3D image layers are by default shown as xy projection.
|
1251
|
+
"""
|
1252
|
+
|
1253
|
+
def add_image(self, data, **kwargs):
|
1254
|
+
viewer_ndim = len(self.dims.order)
|
1255
|
+
layer = super().add_image(data, **kwargs)
|
1256
|
+
|
1257
|
+
try:
|
1258
|
+
# Set to xy view the first time data is opened
|
1259
|
+
if viewer_ndim != 3 and data.ndim == 3:
|
1260
|
+
self.dims.order = (2, 0, 1)
|
1261
|
+
except Exception:
|
1262
|
+
pass
|
1263
|
+
return layer
|
1264
|
+
|
1265
|
+
|
1194
1266
|
def main():
|
1195
|
-
viewer =
|
1267
|
+
viewer = CustomNapariViewer()
|
1196
1268
|
|
1197
1269
|
filter_widget = FilterWidget(preprocessor, viewer)
|
1198
1270
|
mask_widget = MaskWidget(viewer)
|