pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@
|
|
8
8
|
import os
|
9
9
|
import argparse
|
10
10
|
import warnings
|
11
|
-
import importlib.util
|
12
11
|
from sys import exit
|
13
12
|
from time import time
|
14
13
|
from typing import Tuple
|
@@ -22,7 +21,6 @@ from tme.matching_utils import (
|
|
22
21
|
get_rotation_matrices,
|
23
22
|
get_rotations_around_vector,
|
24
23
|
compute_parallelization_schedule,
|
25
|
-
euler_from_rotationmatrix,
|
26
24
|
scramble_phases,
|
27
25
|
generate_tempfile_name,
|
28
26
|
write_pickle,
|
@@ -33,7 +31,7 @@ from tme.analyzer import (
|
|
33
31
|
MaxScoreOverRotations,
|
34
32
|
PeakCallerMaximumFilter,
|
35
33
|
)
|
36
|
-
from tme.backends import backend
|
34
|
+
from tme.backends import backend as be
|
37
35
|
from tme.preprocessing import Compose
|
38
36
|
|
39
37
|
|
@@ -52,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
|
|
52
50
|
|
53
51
|
def print_entry() -> None:
|
54
52
|
width = 80
|
55
|
-
text = f"
|
53
|
+
text = f" pytme v{__version__} "
|
56
54
|
padding_total = width - len(text) - 2
|
57
55
|
padding_left = padding_total // 2
|
58
56
|
padding_right = padding_total - padding_left
|
@@ -273,7 +271,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
273
271
|
return_real_fourier=True,
|
274
272
|
)
|
275
273
|
ctf.sampling_rate = template.sampling_rate
|
276
|
-
ctf.flip_phase =
|
274
|
+
ctf.flip_phase = args.no_flip_phase
|
277
275
|
ctf.amplitude_contrast = args.amplitude_contrast
|
278
276
|
ctf.spherical_aberration = args.spherical_aberration
|
279
277
|
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
@@ -306,6 +304,12 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
306
304
|
if highpass is not None:
|
307
305
|
highpass = np.max(np.divide(template.sampling_rate, highpass))
|
308
306
|
|
307
|
+
try:
|
308
|
+
if args.lowpass >= args.highpass:
|
309
|
+
warnings.warn("--lowpass should be smaller than --highpass.")
|
310
|
+
except Exception:
|
311
|
+
pass
|
312
|
+
|
309
313
|
bandpass = BandPassFilter(
|
310
314
|
use_gaussian=args.no_pass_smooth,
|
311
315
|
lowpass=lowpass,
|
@@ -313,7 +317,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
313
317
|
sampling_rate=template.sampling_rate,
|
314
318
|
)
|
315
319
|
template_filter.append(bandpass)
|
316
|
-
|
320
|
+
|
321
|
+
if not args.no_filter_target:
|
322
|
+
target_filter.append(bandpass)
|
317
323
|
|
318
324
|
if args.whiten_spectrum:
|
319
325
|
whitening_filter = LinearWhiteningFilter()
|
@@ -335,7 +341,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
335
341
|
|
336
342
|
|
337
343
|
def parse_args():
|
338
|
-
parser = argparse.ArgumentParser(
|
344
|
+
parser = argparse.ArgumentParser(
|
345
|
+
description="Perform template matching.",
|
346
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
347
|
+
)
|
339
348
|
|
340
349
|
io_group = parser.add_argument_group("Input / Output")
|
341
350
|
io_group.add_argument(
|
@@ -405,13 +414,6 @@ def parse_args():
|
|
405
414
|
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
406
415
|
help="Template matching scoring function.",
|
407
416
|
)
|
408
|
-
scoring_group.add_argument(
|
409
|
-
"-p",
|
410
|
-
dest="peak_calling",
|
411
|
-
action="store_true",
|
412
|
-
default=False,
|
413
|
-
help="Perform peak calling instead of score aggregation.",
|
414
|
-
)
|
415
417
|
|
416
418
|
angular_group = parser.add_argument_group("Angular Sampling")
|
417
419
|
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
@@ -445,7 +447,7 @@ def parse_args():
|
|
445
447
|
type=check_positive,
|
446
448
|
default=360.0,
|
447
449
|
required=False,
|
448
|
-
help="Sampling angle along the z-axis of the cone.
|
450
|
+
help="Sampling angle along the z-axis of the cone.",
|
449
451
|
)
|
450
452
|
angular_group.add_argument(
|
451
453
|
"--axis_sampling",
|
@@ -513,8 +515,7 @@ def parse_args():
|
|
513
515
|
required=False,
|
514
516
|
type=float,
|
515
517
|
default=0.85,
|
516
|
-
help="Fraction of available memory
|
517
|
-
"ignored if --ram is set",
|
518
|
+
help="Fraction of available memory to be used. Ignored if --ram is set.",
|
518
519
|
)
|
519
520
|
computation_group.add_argument(
|
520
521
|
"--temp_directory",
|
@@ -522,7 +523,13 @@ def parse_args():
|
|
522
523
|
default=None,
|
523
524
|
help="Directory for temporary objects. Faster I/O improves runtime.",
|
524
525
|
)
|
525
|
-
|
526
|
+
computation_group.add_argument(
|
527
|
+
"--backend",
|
528
|
+
dest="backend",
|
529
|
+
default=None,
|
530
|
+
choices=be.available_backends(),
|
531
|
+
help="[Expert] Overwrite default computation backend.",
|
532
|
+
)
|
526
533
|
filter_group = parser.add_argument_group("Filters")
|
527
534
|
filter_group.add_argument(
|
528
535
|
"--lowpass",
|
@@ -552,9 +559,9 @@ def parse_args():
|
|
552
559
|
dest="pass_format",
|
553
560
|
type=str,
|
554
561
|
required=False,
|
562
|
+
default="sampling_rate",
|
555
563
|
choices=["sampling_rate", "voxel", "frequency"],
|
556
|
-
help="How values passed to --lowpass and --highpass should be interpreted. "
|
557
|
-
"By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
|
564
|
+
help="How values passed to --lowpass and --highpass should be interpreted. ",
|
558
565
|
)
|
559
566
|
filter_group.add_argument(
|
560
567
|
"--whiten_spectrum",
|
@@ -613,6 +620,13 @@ def parse_args():
|
|
613
620
|
required=False,
|
614
621
|
help="Analogous to --interpolation_order but for reconstruction.",
|
615
622
|
)
|
623
|
+
filter_group.add_argument(
|
624
|
+
"--no_filter_target",
|
625
|
+
dest="no_filter_target",
|
626
|
+
action="store_true",
|
627
|
+
default=False,
|
628
|
+
help="Whether to not apply potential filters to the target.",
|
629
|
+
)
|
616
630
|
|
617
631
|
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
618
632
|
ctf_group.add_argument(
|
@@ -647,7 +661,7 @@ def parse_args():
|
|
647
661
|
type=float,
|
648
662
|
required=False,
|
649
663
|
default=300,
|
650
|
-
help="Acceleration voltage in kV
|
664
|
+
help="Acceleration voltage in kV.",
|
651
665
|
)
|
652
666
|
ctf_group.add_argument(
|
653
667
|
"--spherical_aberration",
|
@@ -663,14 +677,14 @@ def parse_args():
|
|
663
677
|
type=float,
|
664
678
|
required=False,
|
665
679
|
default=0.07,
|
666
|
-
help="Amplitude contrast
|
680
|
+
help="Amplitude contrast.",
|
667
681
|
)
|
668
682
|
ctf_group.add_argument(
|
669
683
|
"--no_flip_phase",
|
670
684
|
dest="no_flip_phase",
|
671
685
|
action="store_false",
|
672
686
|
required=False,
|
673
|
-
help="
|
687
|
+
help="Do not perform phase-flipping CTF correction.",
|
674
688
|
)
|
675
689
|
ctf_group.add_argument(
|
676
690
|
"--correct_defocus_gradient",
|
@@ -721,14 +735,22 @@ def parse_args():
|
|
721
735
|
"for numerical stability. When working with very large targets, e.g. tomograms, "
|
722
736
|
"it is safe to use this flag and benefit from the performance gain.",
|
723
737
|
)
|
738
|
+
performance_group.add_argument(
|
739
|
+
"--no_filter_padding",
|
740
|
+
dest="no_filter_padding",
|
741
|
+
action="store_true",
|
742
|
+
default=False,
|
743
|
+
help="Omits padding of optional template filters. Particularly effective when "
|
744
|
+
"the target is much larger than the template. However, for fast osciliating "
|
745
|
+
"filters setting this flag can introduce aliasing effects.",
|
746
|
+
)
|
724
747
|
performance_group.add_argument(
|
725
748
|
"--interpolation_order",
|
726
749
|
dest="interpolation_order",
|
727
750
|
required=False,
|
728
751
|
type=int,
|
729
752
|
default=3,
|
730
|
-
help="Spline interpolation used for
|
731
|
-
"no interpolation is performed.",
|
753
|
+
help="Spline interpolation used for rotations.",
|
732
754
|
)
|
733
755
|
performance_group.add_argument(
|
734
756
|
"--use_mixed_precision",
|
@@ -755,7 +777,20 @@ def parse_args():
|
|
755
777
|
default=0,
|
756
778
|
help="Minimum template matching scores to consider for analysis.",
|
757
779
|
)
|
758
|
-
|
780
|
+
analyzer_group.add_argument(
|
781
|
+
"-p",
|
782
|
+
dest="peak_calling",
|
783
|
+
action="store_true",
|
784
|
+
default=False,
|
785
|
+
help="Perform peak calling instead of score aggregation.",
|
786
|
+
)
|
787
|
+
analyzer_group.add_argument(
|
788
|
+
"--number_of_peaks",
|
789
|
+
dest="number_of_peaks",
|
790
|
+
action="store_true",
|
791
|
+
default=1000,
|
792
|
+
help="Number of peaks to call, 1000 by default.",
|
793
|
+
)
|
759
794
|
args = parser.parse_args()
|
760
795
|
args.version = __version__
|
761
796
|
|
@@ -924,44 +959,56 @@ def main():
|
|
924
959
|
template.data, noise_proportion=1.0, normalize_power=True
|
925
960
|
)
|
926
961
|
|
927
|
-
|
962
|
+
# Determine suitable backend for the selected operation
|
963
|
+
available_backends = be.available_backends()
|
964
|
+
if args.backend is not None:
|
965
|
+
req_backend = args.backend
|
966
|
+
if req_backend not in available_backends:
|
967
|
+
raise ValueError("Requested backend is not available.")
|
968
|
+
available_backends = [req_backend]
|
969
|
+
|
970
|
+
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
928
971
|
if args.use_gpu:
|
929
972
|
args.cores = len(args.gpu_indices)
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
if not has_torch and not has_cupy:
|
934
|
-
raise ValueError(
|
935
|
-
"Found neither CuPy nor PyTorch installation. You need to install"
|
936
|
-
" either to enable GPU support."
|
937
|
-
)
|
973
|
+
be_selection = ("pytorch", "cupy", "jax")
|
974
|
+
if args.use_mixed_precision:
|
975
|
+
be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
|
938
976
|
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
if not has_cupy:
|
947
|
-
preferred_backend = "pytorch"
|
948
|
-
backend.change_backend(backend_name=preferred_backend, device="cuda")
|
949
|
-
if args.use_mixed_precision and preferred_backend == "pytorch":
|
977
|
+
available_backends = [x for x in available_backends if x in be_selection]
|
978
|
+
if args.peak_calling:
|
979
|
+
if "jax" in available_backends:
|
980
|
+
available_backends.remove("jax")
|
981
|
+
if args.use_gpu and "pytorch" in available_backends:
|
982
|
+
available_backends = ("pytorch",)
|
983
|
+
if args.interpolation_order == 3:
|
950
984
|
raise NotImplementedError(
|
951
|
-
"
|
952
|
-
" Consider installing CuPy to enable this feature."
|
985
|
+
"Pytorch does not support --interpolation_order 3, 1 is supported."
|
953
986
|
)
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
complex_dtype=backend._array_backend.complex64,
|
959
|
-
default_dtype_int=backend._array_backend.int16,
|
960
|
-
)
|
961
|
-
available_memory = backend.get_available_memory() * args.cores
|
962
|
-
if preferred_backend == "pytorch" and args.interpolation_order == 3:
|
963
|
-
args.interpolation_order = 1
|
987
|
+
# dim_match = len(template.shape) == len(target.shape) <= 3
|
988
|
+
# if dim_match and args.use_gpu and "jax" in available_backends:
|
989
|
+
# args.interpolation_order = 1
|
990
|
+
# available_backends = ["jax"]
|
964
991
|
|
992
|
+
backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
|
993
|
+
if args.use_gpu:
|
994
|
+
backend_preference = ("cupy", "pytorch", "jax")
|
995
|
+
for pref in backend_preference:
|
996
|
+
if pref not in available_backends:
|
997
|
+
continue
|
998
|
+
be.change_backend(pref)
|
999
|
+
if pref == "pytorch":
|
1000
|
+
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
1001
|
+
|
1002
|
+
if args.use_mixed_precision:
|
1003
|
+
be.change_backend(
|
1004
|
+
backend_name=pref,
|
1005
|
+
default_dtype=be._array_backend.float16,
|
1006
|
+
complex_dtype=be._array_backend.complex64,
|
1007
|
+
default_dtype_int=be._array_backend.int16,
|
1008
|
+
)
|
1009
|
+
break
|
1010
|
+
|
1011
|
+
available_memory = be.get_available_memory() * be.device_count()
|
965
1012
|
if args.memory is None:
|
966
1013
|
args.memory = int(args.memory_scaling * available_memory)
|
967
1014
|
|
@@ -978,17 +1025,15 @@ def main():
|
|
978
1025
|
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
979
1026
|
)
|
980
1027
|
|
981
|
-
template_filter, target_filter = setup_filter(
|
982
|
-
|
983
|
-
|
1028
|
+
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1029
|
+
args, template, target
|
1030
|
+
)
|
984
1031
|
|
985
1032
|
template_box = matching_data._output_template_shape
|
986
1033
|
if not args.pad_fourier:
|
987
|
-
template_box =
|
1034
|
+
template_box = tuple(0 for _ in range(len(template_box)))
|
988
1035
|
|
989
|
-
target_padding =
|
990
|
-
(backend.size(matching_data._output_template_shape)), dtype=int
|
991
|
-
)
|
1036
|
+
target_padding = tuple(0 for _ in range(len(template_box)))
|
992
1037
|
if args.pad_target_edges:
|
993
1038
|
target_padding = matching_data._output_template_shape
|
994
1039
|
|
@@ -1001,10 +1046,10 @@ def main():
|
|
1001
1046
|
split_only_outer=args.use_gpu,
|
1002
1047
|
matching_method=args.score,
|
1003
1048
|
analyzer_method=callback_class.__name__,
|
1004
|
-
backend=
|
1005
|
-
float_nbytes=
|
1006
|
-
complex_nbytes=
|
1007
|
-
integer_nbytes=
|
1049
|
+
backend=be._backend_name,
|
1050
|
+
float_nbytes=be.datatype_bytes(be._float_dtype),
|
1051
|
+
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
1052
|
+
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
1008
1053
|
)
|
1009
1054
|
|
1010
1055
|
if splits is None:
|
@@ -1021,27 +1066,36 @@ def main():
|
|
1021
1066
|
)
|
1022
1067
|
gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
|
1023
1068
|
options = {
|
1024
|
-
"
|
1025
|
-
"
|
1026
|
-
"
|
1027
|
-
"
|
1028
|
-
"
|
1069
|
+
"Angular Sampling": f"{args.angular_sampling}"
|
1070
|
+
f" [{matching_data.rotations.shape[0]} rotations]",
|
1071
|
+
"Center Template": not args.no_centering,
|
1072
|
+
"Scramble Template": args.scramble_phases,
|
1073
|
+
"Invert Contrast": args.invert_target_contrast,
|
1029
1074
|
"Extend Fourier Grid": not args.no_fourier_padding,
|
1030
1075
|
"Extend Target Edges": not args.no_edge_padding,
|
1031
1076
|
"Interpolation Order": args.interpolation_order,
|
1032
|
-
"Score": f"{args.score}",
|
1033
1077
|
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
1034
1078
|
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
1035
|
-
"Angular Sampling": f"{args.angular_sampling}"
|
1036
|
-
f" [{matching_data.rotations.shape[0]} rotations]",
|
1037
|
-
"Scramble Template": args.scramble_phases,
|
1038
|
-
"Target Splits": f"{target_split} [N={n_splits}]",
|
1039
1079
|
}
|
1040
1080
|
|
1041
1081
|
print_block(
|
1042
|
-
name="Template Matching
|
1082
|
+
name="Template Matching",
|
1043
1083
|
data=options,
|
1044
|
-
label_width=max(len(key) for key in options.keys()) +
|
1084
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1085
|
+
)
|
1086
|
+
|
1087
|
+
compute_options = {
|
1088
|
+
"Backend": be._BACKEND_REGISTRY[be._backend_name],
|
1089
|
+
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1090
|
+
"Use Mixed Precision": args.use_mixed_precision,
|
1091
|
+
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1092
|
+
"Temporary Directory": args.temp_directory,
|
1093
|
+
"Target Splits": f"{target_split} [N={n_splits}]",
|
1094
|
+
}
|
1095
|
+
print_block(
|
1096
|
+
name="Computation",
|
1097
|
+
data=compute_options,
|
1098
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1045
1099
|
)
|
1046
1100
|
|
1047
1101
|
filter_args = {
|
@@ -1059,7 +1113,7 @@ def main():
|
|
1059
1113
|
filter_args["CTF File"] = args.ctf_file
|
1060
1114
|
filter_args["Defocus"] = args.defocus
|
1061
1115
|
filter_args["Phase Shift"] = args.phase_shift
|
1062
|
-
filter_args["
|
1116
|
+
filter_args["Flip Phase"] = args.no_flip_phase
|
1063
1117
|
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1064
1118
|
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1065
1119
|
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
@@ -1070,20 +1124,20 @@ def main():
|
|
1070
1124
|
print_block(
|
1071
1125
|
name="Filters",
|
1072
1126
|
data=filter_args,
|
1073
|
-
label_width=max(len(key) for key in options.keys()) +
|
1127
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1074
1128
|
)
|
1075
1129
|
|
1076
1130
|
analyzer_args = {
|
1077
1131
|
"score_threshold": args.score_threshold,
|
1078
|
-
"number_of_peaks":
|
1079
|
-
"
|
1132
|
+
"number_of_peaks": args.number_of_peaks,
|
1133
|
+
"min_distance": max(template.shape) // 2,
|
1134
|
+
"min_boundary_distance": max(template.shape) // 2,
|
1080
1135
|
"use_memmap": args.use_memmap,
|
1081
1136
|
}
|
1082
|
-
analyzer_args = {"Analyzer": callback_class, **analyzer_args}
|
1083
1137
|
print_block(
|
1084
|
-
name="
|
1085
|
-
data=analyzer_args,
|
1086
|
-
label_width=max(len(key) for key in options.keys()) +
|
1138
|
+
name="Analyzer",
|
1139
|
+
data={"Analyzer": callback_class, **analyzer_args},
|
1140
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1087
1141
|
)
|
1088
1142
|
print("\n" + "-" * 80)
|
1089
1143
|
|
@@ -1104,6 +1158,7 @@ def main():
|
|
1104
1158
|
target_splits=splits,
|
1105
1159
|
pad_target_edges=args.pad_target_edges,
|
1106
1160
|
pad_fourier=args.pad_fourier,
|
1161
|
+
pad_template_filter=not args.no_filter_padding,
|
1107
1162
|
interpolation_order=args.interpolation_order,
|
1108
1163
|
)
|
1109
1164
|
|
@@ -1113,19 +1168,18 @@ def main():
|
|
1113
1168
|
candidates[0] *= target_mask.data
|
1114
1169
|
with warnings.catch_warnings():
|
1115
1170
|
warnings.simplefilter("ignore", category=UserWarning)
|
1116
|
-
nbytes =
|
1171
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
1117
1172
|
dtype = np.float32 if nbytes == 4 else np.float16
|
1118
1173
|
rot_dim = matching_data.rotations.shape[1]
|
1119
1174
|
candidates[3] = {
|
1120
|
-
x:
|
1121
|
-
np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1122
|
-
)
|
1175
|
+
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1123
1176
|
for i, x in candidates[3].items()
|
1124
1177
|
}
|
1125
1178
|
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1126
1179
|
write_pickle(data=candidates, filename=args.output)
|
1127
1180
|
|
1128
1181
|
runtime = time() - start
|
1182
|
+
print("\n" + "-" * 80)
|
1129
1183
|
print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
|
1130
1184
|
|
1131
1185
|
|
@@ -8,9 +8,8 @@
|
|
8
8
|
import argparse
|
9
9
|
from sys import exit
|
10
10
|
from os import getcwd
|
11
|
-
from os.path import join, abspath
|
12
11
|
from typing import List, Tuple
|
13
|
-
from os.path import splitext
|
12
|
+
from os.path import join, abspath, splitext
|
14
13
|
|
15
14
|
import numpy as np
|
16
15
|
from numpy.typing import NDArray
|
@@ -26,6 +25,7 @@ from tme.analyzer import (
|
|
26
25
|
)
|
27
26
|
from tme.matching_utils import (
|
28
27
|
load_pickle,
|
28
|
+
centered_mask,
|
29
29
|
euler_to_rotationmatrix,
|
30
30
|
euler_from_rotationmatrix,
|
31
31
|
)
|
@@ -41,9 +41,7 @@ PEAK_CALLERS = {
|
|
41
41
|
|
42
42
|
|
43
43
|
def parse_args():
|
44
|
-
parser = argparse.ArgumentParser(
|
45
|
-
description="Peak Calling for Template Matching Outputs"
|
46
|
-
)
|
44
|
+
parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
|
47
45
|
|
48
46
|
input_group = parser.add_argument_group("Input")
|
49
47
|
output_group = parser.add_argument_group("Output")
|
@@ -56,6 +54,13 @@ def parse_args():
|
|
56
54
|
nargs="+",
|
57
55
|
help="Path to the output of match_template.py.",
|
58
56
|
)
|
57
|
+
input_group.add_argument(
|
58
|
+
"--background_file",
|
59
|
+
required=False,
|
60
|
+
nargs="+",
|
61
|
+
help="Path to an output of match_template.py used for normalization. "
|
62
|
+
"For instance from --scramble_phases or a different template.",
|
63
|
+
)
|
59
64
|
input_group.add_argument(
|
60
65
|
"--target_mask",
|
61
66
|
required=False,
|
@@ -87,7 +92,7 @@ def parse_args():
|
|
87
92
|
"average",
|
88
93
|
],
|
89
94
|
default="orientations",
|
90
|
-
help="Available output formats:"
|
95
|
+
help="Available output formats: "
|
91
96
|
"orientations (translation, rotation, and score), "
|
92
97
|
"alignment (aligned template to target based on orientations), "
|
93
98
|
"extraction (extract regions around peaks from targets, i.e. subtomograms), "
|
@@ -206,6 +211,15 @@ def parse_args():
|
|
206
211
|
elif args.number_of_peaks is None:
|
207
212
|
args.number_of_peaks = 1000
|
208
213
|
|
214
|
+
if args.background_file is None:
|
215
|
+
args.background_file = [None]
|
216
|
+
if len(args.background_file) == 1:
|
217
|
+
args.background_file = args.background_file * len(args.input_file)
|
218
|
+
elif len(args.background_file) not in (0, len(args.input_file)):
|
219
|
+
raise ValueError(
|
220
|
+
"--background_file needs to be specified once or for each --input_file."
|
221
|
+
)
|
222
|
+
|
209
223
|
return args
|
210
224
|
|
211
225
|
|
@@ -233,8 +247,8 @@ def load_template(
|
|
233
247
|
return template, center, translation, template_is_density
|
234
248
|
|
235
249
|
|
236
|
-
def merge_outputs(data,
|
237
|
-
if len(
|
250
|
+
def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
|
251
|
+
if len(foreground_paths) == 0:
|
238
252
|
return data, 1
|
239
253
|
|
240
254
|
if data[0].ndim != data[2].ndim:
|
@@ -275,8 +289,11 @@ def merge_outputs(data, filepaths: List[str], args):
|
|
275
289
|
|
276
290
|
entities = np.zeros_like(data[0])
|
277
291
|
data[0] = _norm_scores(data=data, args=args)
|
278
|
-
for index, filepath in enumerate(
|
279
|
-
new_scores = _norm_scores(
|
292
|
+
for index, filepath in enumerate(foreground_paths):
|
293
|
+
new_scores = _norm_scores(
|
294
|
+
data=load_match_template_output(filepath, background_paths[index]),
|
295
|
+
args=args,
|
296
|
+
)
|
280
297
|
indices = new_scores > data[0]
|
281
298
|
entities[indices] = index + 1
|
282
299
|
data[0][indices] = new_scores[indices]
|
@@ -284,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
|
|
284
301
|
return data, entities
|
285
302
|
|
286
303
|
|
304
|
+
def load_match_template_output(foreground_path, background_path):
|
305
|
+
data = load_pickle(foreground_path)
|
306
|
+
if background_path is not None:
|
307
|
+
data_background = load_pickle(background_path)
|
308
|
+
data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
|
309
|
+
np.fmax(data[0], 0, out=data[0])
|
310
|
+
return data
|
311
|
+
|
312
|
+
|
287
313
|
def main():
|
288
314
|
args = parse_args()
|
289
|
-
data =
|
315
|
+
data = load_match_template_output(args.input_file[0], args.background_file[0])
|
290
316
|
|
291
317
|
target_origin, _, sampling_rate, cli_args = data[-1]
|
292
318
|
|
@@ -326,7 +352,12 @@ def main():
|
|
326
352
|
|
327
353
|
entities = None
|
328
354
|
if len(args.input_file) > 1:
|
329
|
-
data, entities = merge_outputs(
|
355
|
+
data, entities = merge_outputs(
|
356
|
+
data=data,
|
357
|
+
foreground_paths=args.input_file,
|
358
|
+
background_paths=args.background_file,
|
359
|
+
args=args,
|
360
|
+
)
|
330
361
|
|
331
362
|
orientations = args.orientations
|
332
363
|
if orientations is None:
|
@@ -339,24 +370,27 @@ def main():
|
|
339
370
|
target_mask = Density.from_file(args.target_mask)
|
340
371
|
scores = scores * target_mask.data
|
341
372
|
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
scores.shape, np.multiply(args.min_boundary_distance, 2)
|
346
|
-
).astype(int)
|
373
|
+
cropped_shape = np.subtract(
|
374
|
+
scores.shape, np.multiply(args.min_boundary_distance, 2)
|
375
|
+
).astype(int)
|
347
376
|
|
348
|
-
|
377
|
+
if args.min_boundary_distance > 0:
|
378
|
+
scores = centered_mask(scores, new_shape=cropped_shape)
|
379
|
+
|
380
|
+
if args.n_false_positives is not None:
|
381
|
+
# Rickgauer et al. 2017
|
382
|
+
cropped_slice = tuple(
|
349
383
|
slice(
|
350
384
|
int(args.min_boundary_distance),
|
351
385
|
int(x - args.min_boundary_distance),
|
352
386
|
)
|
353
387
|
for x in scores.shape
|
354
388
|
)
|
355
|
-
|
356
|
-
n_correlations = np.size(scores[
|
389
|
+
args.n_false_positives = max(args.n_false_positives, 1)
|
390
|
+
n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
|
357
391
|
minimum_score = np.multiply(
|
358
392
|
erfcinv(2 * args.n_false_positives / n_correlations),
|
359
|
-
np.sqrt(2) * np.std(scores[
|
393
|
+
np.sqrt(2) * np.std(scores[cropped_slice]),
|
360
394
|
)
|
361
395
|
print(f"Determined minimum score cutoff: {minimum_score}.")
|
362
396
|
minimum_score = max(minimum_score, 0)
|
@@ -371,6 +405,8 @@ def main():
|
|
371
405
|
"min_distance": args.min_distance,
|
372
406
|
"min_boundary_distance": args.min_boundary_distance,
|
373
407
|
"batch_dims": args.batch_dims,
|
408
|
+
"minimum_score": args.minimum_score,
|
409
|
+
"maximum_score": args.maximum_score,
|
374
410
|
}
|
375
411
|
|
376
412
|
peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
|
@@ -380,7 +416,6 @@ def main():
|
|
380
416
|
mask=template.data,
|
381
417
|
rotation_mapping=rotation_mapping,
|
382
418
|
rotation_array=rotation_array,
|
383
|
-
minimum_score=args.minimum_score,
|
384
419
|
)
|
385
420
|
candidates = peak_caller.merge(
|
386
421
|
candidates=[tuple(peak_caller)], **peak_caller_kwargs
|
@@ -388,10 +423,16 @@ def main():
|
|
388
423
|
if len(candidates) == 0:
|
389
424
|
candidates = [[], [], [], []]
|
390
425
|
print("Found no peaks, consider changing peak calling parameters.")
|
391
|
-
exit(
|
426
|
+
exit(-1)
|
392
427
|
|
393
428
|
for translation, _, score, detail in zip(*candidates):
|
394
|
-
|
429
|
+
rotation_index = rotation_array[tuple(translation)]
|
430
|
+
rotation = rotation_mapping.get(
|
431
|
+
rotation_index, np.zeros(template.data.ndim, int)
|
432
|
+
)
|
433
|
+
if rotation.ndim == 2:
|
434
|
+
rotation = euler_from_rotationmatrix(rotation)
|
435
|
+
rotations.append(rotation)
|
395
436
|
|
396
437
|
else:
|
397
438
|
candidates = data
|
@@ -430,7 +471,7 @@ def main():
|
|
430
471
|
)
|
431
472
|
exit(-1)
|
432
473
|
orientations.translations = peak_caller.oversample_peaks(
|
433
|
-
|
474
|
+
scores=data[0],
|
434
475
|
peak_positions=orientations.translations,
|
435
476
|
oversampling_factor=args.peak_oversampling,
|
436
477
|
)
|
@@ -570,7 +611,7 @@ def main():
|
|
570
611
|
return_orientations=True,
|
571
612
|
)
|
572
613
|
out = np.zeros_like(template.data)
|
573
|
-
out = np.zeros(np.multiply(template.shape, 2).astype(int))
|
614
|
+
# out = np.zeros(np.multiply(template.shape, 2).astype(int))
|
574
615
|
for index in range(len(cand_slices)):
|
575
616
|
from scipy.spatial.transform import Rotation
|
576
617
|
|