pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/match_template.py +91 -142
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/postprocess.py +20 -29
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +6 -9
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/METADATA +11 -10
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/RECORD +33 -32
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +91 -142
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +6 -9
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +34 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/density.py +15 -8
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +24 -17
- tme/matching_exhaustive.py +36 -19
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +7 -2
- tme/orientations.py +26 -9
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/tilt_series.py +10 -32
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/WHEEL +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
#!python
|
2
|
-
""" CLI
|
2
|
+
""" CLI for basic pyTME template matching functions.
|
3
3
|
|
4
4
|
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
5
|
|
@@ -22,7 +22,6 @@ from tme.matching_utils import (
|
|
22
22
|
get_rotations_around_vector,
|
23
23
|
compute_parallelization_schedule,
|
24
24
|
scramble_phases,
|
25
|
-
generate_tempfile_name,
|
26
25
|
write_pickle,
|
27
26
|
)
|
28
27
|
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
@@ -99,7 +98,9 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
99
98
|
f"Expected shape of {mask_path} was {mask_target.shape},"
|
100
99
|
f" got f{mask.shape}"
|
101
100
|
)
|
102
|
-
if not np.allclose(
|
101
|
+
if not np.allclose(
|
102
|
+
np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
|
103
|
+
):
|
103
104
|
raise ValueError(
|
104
105
|
f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
|
105
106
|
f", got f{mask.sampling_rate}"
|
@@ -107,50 +108,6 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
107
108
|
return mask
|
108
109
|
|
109
110
|
|
110
|
-
def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
|
111
|
-
"""
|
112
|
-
Crop the provided data and mask to a smaller box based on a cutoff value.
|
113
|
-
|
114
|
-
Parameters
|
115
|
-
----------
|
116
|
-
data : Density
|
117
|
-
The data that should be cropped.
|
118
|
-
cutoff : float
|
119
|
-
The threshold value to determine which parts of the data should be kept.
|
120
|
-
data_mask : Density, optional
|
121
|
-
A mask for the data that should be cropped.
|
122
|
-
|
123
|
-
Returns
|
124
|
-
-------
|
125
|
-
bool
|
126
|
-
Returns True if the data was adjusted (cropped), otherwise returns False.
|
127
|
-
|
128
|
-
Notes
|
129
|
-
-----
|
130
|
-
Cropping is performed in place.
|
131
|
-
"""
|
132
|
-
if cutoff is None:
|
133
|
-
return False
|
134
|
-
|
135
|
-
box = data.trim_box(cutoff=cutoff)
|
136
|
-
box_mask = box
|
137
|
-
if data_mask is not None:
|
138
|
-
box_mask = data_mask.trim_box(cutoff=cutoff)
|
139
|
-
box = tuple(
|
140
|
-
slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
|
141
|
-
for arr, mask in zip(box, box_mask)
|
142
|
-
)
|
143
|
-
if box == tuple(slice(0, x) for x in data.shape):
|
144
|
-
return False
|
145
|
-
|
146
|
-
data.adjust_box(box)
|
147
|
-
|
148
|
-
if data_mask:
|
149
|
-
data_mask.adjust_box(box)
|
150
|
-
|
151
|
-
return True
|
152
|
-
|
153
|
-
|
154
111
|
def parse_rotation_logic(args, ndim):
|
155
112
|
if args.angular_sampling is not None:
|
156
113
|
rotations = get_rotation_matrices(
|
@@ -175,8 +132,52 @@ def parse_rotation_logic(args, ndim):
|
|
175
132
|
return rotations
|
176
133
|
|
177
134
|
|
178
|
-
|
179
|
-
|
135
|
+
def compute_schedule(
|
136
|
+
args,
|
137
|
+
target: Density,
|
138
|
+
matching_data: MatchingData,
|
139
|
+
callback_class,
|
140
|
+
pad_edges: bool = False,
|
141
|
+
):
|
142
|
+
# User requested target padding
|
143
|
+
if args.pad_edges is True:
|
144
|
+
pad_edges = True
|
145
|
+
template_box = matching_data._output_template_shape
|
146
|
+
if not args.pad_fourier:
|
147
|
+
template_box = tuple(0 for _ in range(len(template_box)))
|
148
|
+
|
149
|
+
target_padding = tuple(0 for _ in range(len(template_box)))
|
150
|
+
if pad_edges:
|
151
|
+
target_padding = matching_data._output_template_shape
|
152
|
+
|
153
|
+
splits, schedule = compute_parallelization_schedule(
|
154
|
+
shape1=target.shape,
|
155
|
+
shape2=tuple(int(x) for x in template_box),
|
156
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
157
|
+
max_cores=args.cores,
|
158
|
+
max_ram=args.memory,
|
159
|
+
split_only_outer=args.use_gpu,
|
160
|
+
matching_method=args.score,
|
161
|
+
analyzer_method=callback_class.__name__,
|
162
|
+
backend=be._backend_name,
|
163
|
+
float_nbytes=be.datatype_bytes(be._float_dtype),
|
164
|
+
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
165
|
+
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
166
|
+
)
|
167
|
+
|
168
|
+
if splits is None:
|
169
|
+
print(
|
170
|
+
"Found no suitable parallelization schedule. Consider increasing"
|
171
|
+
" available RAM or decreasing number of cores."
|
172
|
+
)
|
173
|
+
exit(-1)
|
174
|
+
n_splits = np.prod(list(splits.values()))
|
175
|
+
if pad_edges is False and n_splits > 1:
|
176
|
+
args.pad_edges = True
|
177
|
+
return compute_schedule(args, target, matching_data, callback_class, True)
|
178
|
+
return splits, schedule
|
179
|
+
|
180
|
+
|
180
181
|
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
182
|
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
182
183
|
from tme.preprocessing.tilt_series import (
|
@@ -232,18 +233,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
232
233
|
weight_wedge=args.tilt_weighting == "angle",
|
233
234
|
create_continuous_wedge=create_continuous_wedge,
|
234
235
|
)
|
236
|
+
wedge_target = WedgeReconstructed(
|
237
|
+
angles=tilt_angles,
|
238
|
+
weight_wedge=False,
|
239
|
+
create_continuous_wedge=create_continuous_wedge,
|
240
|
+
)
|
241
|
+
target_filter.append(wedge_target)
|
235
242
|
|
236
243
|
wedge.opening_axis = args.wedge_axes[0]
|
237
244
|
wedge.tilt_axis = args.wedge_axes[1]
|
238
245
|
wedge.sampling_rate = template.sampling_rate
|
239
246
|
template_filter.append(wedge)
|
240
247
|
if not isinstance(wedge, WedgeReconstructed):
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
245
|
-
)
|
248
|
+
reconstruction_filter = ReconstructFromTilt(
|
249
|
+
reconstruction_filter=args.reconstruction_filter,
|
250
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
246
251
|
)
|
252
|
+
template_filter.append(reconstruction_filter)
|
247
253
|
|
248
254
|
if args.ctf_file is not None or args.defocus is not None:
|
249
255
|
from tme.preprocessing.tilt_series import CTF
|
@@ -393,8 +399,7 @@ def parse_args():
|
|
393
399
|
dest="invert_target_contrast",
|
394
400
|
action="store_true",
|
395
401
|
default=False,
|
396
|
-
help="Invert the target's contrast
|
397
|
-
"This option is intended for targets where templates to-be-matched have "
|
402
|
+
help="Invert the target's contrast for cases where templates to-be-matched have "
|
398
403
|
"negative values, e.g. tomograms.",
|
399
404
|
)
|
400
405
|
io_group.add_argument(
|
@@ -696,22 +701,6 @@ def parse_args():
|
|
696
701
|
)
|
697
702
|
|
698
703
|
performance_group = parser.add_argument_group("Performance")
|
699
|
-
performance_group.add_argument(
|
700
|
-
"--cutoff_target",
|
701
|
-
dest="cutoff_target",
|
702
|
-
type=float,
|
703
|
-
required=False,
|
704
|
-
default=None,
|
705
|
-
help="Target contour level (used for cropping).",
|
706
|
-
)
|
707
|
-
performance_group.add_argument(
|
708
|
-
"--cutoff_template",
|
709
|
-
dest="cutoff_template",
|
710
|
-
type=float,
|
711
|
-
required=False,
|
712
|
-
default=None,
|
713
|
-
help="Template contour level (used for cropping).",
|
714
|
-
)
|
715
704
|
performance_group.add_argument(
|
716
705
|
"--no_centering",
|
717
706
|
dest="no_centering",
|
@@ -719,30 +708,28 @@ def parse_args():
|
|
719
708
|
help="Assumes the template is already centered and omits centering.",
|
720
709
|
)
|
721
710
|
performance_group.add_argument(
|
722
|
-
"--
|
723
|
-
dest="
|
711
|
+
"--pad_edges",
|
712
|
+
dest="pad_edges",
|
724
713
|
action="store_true",
|
725
714
|
default=False,
|
726
|
-
help="Whether to
|
727
|
-
"
|
715
|
+
help="Whether to pad the edges of the target. Useful if the target does not "
|
716
|
+
"a well-defined bounding box. Defaults to True if splitting is required.",
|
728
717
|
)
|
729
718
|
performance_group.add_argument(
|
730
|
-
"--
|
731
|
-
dest="
|
719
|
+
"--pad_fourier",
|
720
|
+
dest="pad_fourier",
|
732
721
|
action="store_true",
|
733
722
|
default=False,
|
734
723
|
help="Whether input arrays should not be zero-padded to full convolution shape "
|
735
|
-
"for numerical stability.
|
736
|
-
"it is safe to use this flag and benefit from the performance gain.",
|
724
|
+
"for numerical stability. Typically only useful when working with small data.",
|
737
725
|
)
|
738
726
|
performance_group.add_argument(
|
739
|
-
"--
|
740
|
-
dest="
|
727
|
+
"--pad_filter",
|
728
|
+
dest="pad_filter",
|
741
729
|
action="store_true",
|
742
730
|
default=False,
|
743
|
-
help="
|
744
|
-
"
|
745
|
-
"filters setting this flag can introduce aliasing effects.",
|
731
|
+
help="Pads the filter to the shape of the target. Particularly useful for fast "
|
732
|
+
"oscilating filters to avoid aliasing effects.",
|
746
733
|
)
|
747
734
|
performance_group.add_argument(
|
748
735
|
"--interpolation_order",
|
@@ -805,9 +792,6 @@ def parse_args():
|
|
805
792
|
|
806
793
|
os.environ["TMPDIR"] = args.temp_directory
|
807
794
|
|
808
|
-
args.pad_target_edges = not args.no_edge_padding
|
809
|
-
args.pad_fourier = not args.no_fourier_padding
|
810
|
-
|
811
795
|
if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
|
812
796
|
raise ValueError(
|
813
797
|
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
@@ -862,7 +846,9 @@ def main():
|
|
862
846
|
)
|
863
847
|
|
864
848
|
if target.sampling_rate.size == template.sampling_rate.size:
|
865
|
-
if not np.allclose(
|
849
|
+
if not np.allclose(
|
850
|
+
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
851
|
+
):
|
866
852
|
print(
|
867
853
|
f"Resampling template to {target.sampling_rate}. "
|
868
854
|
"Consider providing a template with the same sampling rate as the target."
|
@@ -877,9 +863,6 @@ def main():
|
|
877
863
|
)
|
878
864
|
|
879
865
|
initial_shape = target.shape
|
880
|
-
is_cropped = crop_data(
|
881
|
-
data=target, data_mask=target_mask, cutoff=args.cutoff_target
|
882
|
-
)
|
883
866
|
print_block(
|
884
867
|
name="Target",
|
885
868
|
data={
|
@@ -888,13 +871,6 @@ def main():
|
|
888
871
|
"Final Shape": target.shape,
|
889
872
|
},
|
890
873
|
)
|
891
|
-
if is_cropped:
|
892
|
-
args.target = generate_tempfile_name(suffix=".mrc")
|
893
|
-
target.to_file(args.target)
|
894
|
-
|
895
|
-
if target_mask:
|
896
|
-
args.target_mask = generate_tempfile_name(suffix=".mrc")
|
897
|
-
target_mask.to_file(args.target_mask)
|
898
874
|
|
899
875
|
if target_mask:
|
900
876
|
print_block(
|
@@ -907,8 +883,6 @@ def main():
|
|
907
883
|
)
|
908
884
|
|
909
885
|
initial_shape = template.shape
|
910
|
-
_ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
|
911
|
-
|
912
886
|
translation = np.zeros(len(template.shape), dtype=np.float32)
|
913
887
|
if not args.no_centering:
|
914
888
|
template, translation = template.centered(0)
|
@@ -956,7 +930,7 @@ def main():
|
|
956
930
|
|
957
931
|
if args.scramble_phases:
|
958
932
|
template.data = scramble_phases(
|
959
|
-
template.data, noise_proportion=1.0, normalize_power=
|
933
|
+
template.data, noise_proportion=1.0, normalize_power=False
|
960
934
|
)
|
961
935
|
|
962
936
|
# Determine suitable backend for the selected operation
|
@@ -980,10 +954,7 @@ def main():
|
|
980
954
|
available_backends.remove("jax")
|
981
955
|
if args.use_gpu and "pytorch" in available_backends:
|
982
956
|
available_backends = ("pytorch",)
|
983
|
-
|
984
|
-
raise NotImplementedError(
|
985
|
-
"Pytorch does not support --interpolation_order 3, 1 is supported."
|
986
|
-
)
|
957
|
+
|
987
958
|
# dim_match = len(template.shape) == len(target.shape) <= 3
|
988
959
|
# if dim_match and args.use_gpu and "jax" in available_backends:
|
989
960
|
# args.interpolation_order = 1
|
@@ -1008,6 +979,12 @@ def main():
|
|
1008
979
|
)
|
1009
980
|
break
|
1010
981
|
|
982
|
+
if pref == "pytorch" and args.interpolation_order == 3:
|
983
|
+
warnings.warn(
|
984
|
+
"Pytorch does not support --interpolation_order 3, setting it to 1."
|
985
|
+
)
|
986
|
+
args.interpolation_order = 1
|
987
|
+
|
1011
988
|
available_memory = be.get_available_memory() * be.device_count()
|
1012
989
|
if args.memory is None:
|
1013
990
|
args.memory = int(args.memory_scaling * available_memory)
|
@@ -1025,41 +1002,13 @@ def main():
|
|
1025
1002
|
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
1026
1003
|
)
|
1027
1004
|
|
1005
|
+
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
1028
1006
|
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1029
1007
|
args, template, target
|
1030
1008
|
)
|
1031
1009
|
|
1032
|
-
|
1033
|
-
if not args.pad_fourier:
|
1034
|
-
template_box = tuple(0 for _ in range(len(template_box)))
|
1035
|
-
|
1036
|
-
target_padding = tuple(0 for _ in range(len(template_box)))
|
1037
|
-
if args.pad_target_edges:
|
1038
|
-
target_padding = matching_data._output_template_shape
|
1039
|
-
|
1040
|
-
splits, schedule = compute_parallelization_schedule(
|
1041
|
-
shape1=target.shape,
|
1042
|
-
shape2=tuple(int(x) for x in template_box),
|
1043
|
-
shape1_padding=tuple(int(x) for x in target_padding),
|
1044
|
-
max_cores=args.cores,
|
1045
|
-
max_ram=args.memory,
|
1046
|
-
split_only_outer=args.use_gpu,
|
1047
|
-
matching_method=args.score,
|
1048
|
-
analyzer_method=callback_class.__name__,
|
1049
|
-
backend=be._backend_name,
|
1050
|
-
float_nbytes=be.datatype_bytes(be._float_dtype),
|
1051
|
-
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
1052
|
-
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
1053
|
-
)
|
1054
|
-
|
1055
|
-
if splits is None:
|
1056
|
-
print(
|
1057
|
-
"Found no suitable parallelization schedule. Consider increasing"
|
1058
|
-
" available RAM or decreasing number of cores."
|
1059
|
-
)
|
1060
|
-
exit(-1)
|
1010
|
+
splits, schedule = compute_schedule(args, target, matching_data, callback_class)
|
1061
1011
|
|
1062
|
-
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
1063
1012
|
n_splits = np.prod(list(splits.values()))
|
1064
1013
|
target_split = ", ".join(
|
1065
1014
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
@@ -1071,8 +1020,8 @@ def main():
|
|
1071
1020
|
"Center Template": not args.no_centering,
|
1072
1021
|
"Scramble Template": args.scramble_phases,
|
1073
1022
|
"Invert Contrast": args.invert_target_contrast,
|
1074
|
-
"Extend Fourier Grid":
|
1075
|
-
"Extend Target Edges":
|
1023
|
+
"Extend Fourier Grid": args.pad_fourier,
|
1024
|
+
"Extend Target Edges": args.pad_edges,
|
1076
1025
|
"Interpolation Order": args.interpolation_order,
|
1077
1026
|
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
1078
1027
|
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
@@ -1108,6 +1057,7 @@ def main():
|
|
1108
1057
|
"Tilt Angles": args.tilt_angles,
|
1109
1058
|
"Tilt Weighting": args.tilt_weighting,
|
1110
1059
|
"Reconstruction Filter": args.reconstruction_filter,
|
1060
|
+
"Extend Filter Grid": args.pad_filter,
|
1111
1061
|
}
|
1112
1062
|
if args.ctf_file is not None or args.defocus is not None:
|
1113
1063
|
filter_args["CTF File"] = args.ctf_file
|
@@ -1130,8 +1080,7 @@ def main():
|
|
1130
1080
|
analyzer_args = {
|
1131
1081
|
"score_threshold": args.score_threshold,
|
1132
1082
|
"number_of_peaks": args.number_of_peaks,
|
1133
|
-
"min_distance": max(template.shape) //
|
1134
|
-
"min_boundary_distance": max(template.shape) // 2,
|
1083
|
+
"min_distance": max(template.shape) // 3,
|
1135
1084
|
"use_memmap": args.use_memmap,
|
1136
1085
|
}
|
1137
1086
|
print_block(
|
@@ -1156,9 +1105,9 @@ def main():
|
|
1156
1105
|
callback_class=callback_class,
|
1157
1106
|
callback_class_args=analyzer_args,
|
1158
1107
|
target_splits=splits,
|
1159
|
-
pad_target_edges=args.
|
1108
|
+
pad_target_edges=args.pad_edges,
|
1160
1109
|
pad_fourier=args.pad_fourier,
|
1161
|
-
pad_template_filter=
|
1110
|
+
pad_template_filter=args.pad_filter,
|
1162
1111
|
interpolation_order=args.interpolation_order,
|
1163
1112
|
)
|
1164
1113
|
|
@@ -509,19 +509,7 @@ def main():
|
|
509
509
|
|
510
510
|
target = Density.from_file(cli_args.target)
|
511
511
|
if args.invert_target_contrast:
|
512
|
-
|
513
|
-
target.data = target.data * -1
|
514
|
-
target.data = np.divide(
|
515
|
-
np.subtract(target.data, target.data.mean()), target.data.std()
|
516
|
-
)
|
517
|
-
else:
|
518
|
-
target.data = (
|
519
|
-
-np.divide(
|
520
|
-
np.subtract(target.data, target.data.min()),
|
521
|
-
np.subtract(target.data.max(), target.data.min()),
|
522
|
-
)
|
523
|
-
+ 1
|
524
|
-
)
|
512
|
+
target.data = target.data * -1
|
525
513
|
|
526
514
|
if args.output_format in ("extraction", "relion"):
|
527
515
|
if not np.all(np.divide(target.shape, template.shape) > 2):
|
@@ -546,10 +534,14 @@ def main():
|
|
546
534
|
|
547
535
|
working_directory = getcwd()
|
548
536
|
if args.output_format == "relion":
|
537
|
+
name = [
|
538
|
+
join(working_directory, f"{args.output_prefix}_{index}.mrc")
|
539
|
+
for index in range(len(cand_slices))
|
540
|
+
]
|
549
541
|
orientations.to_file(
|
550
542
|
filename=f"{args.output_prefix}.star",
|
551
543
|
file_format="relion",
|
552
|
-
|
544
|
+
name=name,
|
553
545
|
ctf_image=args.wedge_mask,
|
554
546
|
sampling_rate=target.sampling_rate.max(),
|
555
547
|
subtomogram_size=extraction_shape[0],
|
@@ -606,24 +598,22 @@ def main():
|
|
606
598
|
if args.output_format == "average":
|
607
599
|
orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
|
608
600
|
target_shape=target.shape,
|
609
|
-
extraction_shape=
|
601
|
+
extraction_shape=template.shape,
|
610
602
|
drop_out_of_box=True,
|
611
603
|
return_orientations=True,
|
612
604
|
)
|
613
605
|
out = np.zeros_like(template.data)
|
614
|
-
# out = np.zeros(np.multiply(template.shape, 2).astype(int))
|
615
606
|
for index in range(len(cand_slices)):
|
616
|
-
from scipy.spatial.transform import Rotation
|
617
|
-
|
618
|
-
rotation = Rotation.from_euler(
|
619
|
-
angles=orientations.rotations[index], seq="zyx", degrees=True
|
620
|
-
)
|
621
|
-
rotation_matrix = rotation.inv().as_matrix()
|
622
|
-
|
623
607
|
subset = Density(target.data[obs_slices[index]])
|
624
|
-
|
608
|
+
rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
|
625
609
|
|
610
|
+
subset = subset.rigid_transform(
|
611
|
+
rotation_matrix=np.linalg.inv(rotation_matrix),
|
612
|
+
order=1,
|
613
|
+
use_geometric_center=True,
|
614
|
+
)
|
626
615
|
np.add(out, subset.data, out=out)
|
616
|
+
|
627
617
|
out /= len(cand_slices)
|
628
618
|
ret = Density(out, sampling_rate=template.sampling_rate, origin=0)
|
629
619
|
ret.pad(template.shape, center=True)
|
@@ -637,17 +627,18 @@ def main():
|
|
637
627
|
target_shape=target.shape,
|
638
628
|
)
|
639
629
|
|
630
|
+
# Template is larger than target
|
640
631
|
for index, (translation, angles, *_) in enumerate(orientations):
|
641
632
|
rotation_matrix = euler_to_rotationmatrix(angles)
|
642
633
|
if template_is_density:
|
643
|
-
translation = np.subtract(translation, center)
|
644
634
|
transformed_template = template.rigid_transform(
|
645
|
-
rotation_matrix=rotation_matrix
|
646
|
-
)
|
647
|
-
transformed_template.origin = np.add(
|
648
|
-
target_origin, np.multiply(translation, sampling_rate)
|
635
|
+
rotation_matrix=rotation_matrix, use_geometric_center=True
|
649
636
|
)
|
650
637
|
|
638
|
+
# Just adapting the coordinate system not the in-box position
|
639
|
+
shift = np.multiply(np.subtract(translation, center), sampling_rate)
|
640
|
+
transformed_template.origin = np.add(target_origin, shift)
|
641
|
+
|
651
642
|
else:
|
652
643
|
template = Structure.from_file(cli_args.template)
|
653
644
|
new_center_of_mass = np.add(
|
@@ -0,0 +1,132 @@
|
|
1
|
+
#!python
|
2
|
+
""" Preprocessing routines for template matching.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
import warnings
|
9
|
+
import argparse
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from tme import Density, Structure
|
13
|
+
from tme.backends import backend as be
|
14
|
+
from tme.preprocessing.frequency_filters import BandPassFilter
|
15
|
+
|
16
|
+
|
17
|
+
def parse_args():
|
18
|
+
parser = argparse.ArgumentParser(
|
19
|
+
description="Perform template matching preprocessing.",
|
20
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
21
|
+
)
|
22
|
+
|
23
|
+
io_group = parser.add_argument_group("Input / Output")
|
24
|
+
io_group.add_argument(
|
25
|
+
"-m",
|
26
|
+
"--data",
|
27
|
+
dest="data",
|
28
|
+
type=str,
|
29
|
+
required=True,
|
30
|
+
help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
|
31
|
+
"tme.density.Density.from_file "
|
32
|
+
"https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
|
33
|
+
)
|
34
|
+
io_group.add_argument(
|
35
|
+
"-o",
|
36
|
+
"--output",
|
37
|
+
dest="output",
|
38
|
+
type=str,
|
39
|
+
required=True,
|
40
|
+
help="Path the output should be written to.",
|
41
|
+
)
|
42
|
+
|
43
|
+
box_group = parser.add_argument_group("Box")
|
44
|
+
box_group.add_argument(
|
45
|
+
"--box_size",
|
46
|
+
dest="box_size",
|
47
|
+
type=int,
|
48
|
+
required=True,
|
49
|
+
help="Box size of the output",
|
50
|
+
)
|
51
|
+
box_group.add_argument(
|
52
|
+
"--sampling_rate",
|
53
|
+
dest="sampling_rate",
|
54
|
+
type=float,
|
55
|
+
required=True,
|
56
|
+
help="Sampling rate of the output file.",
|
57
|
+
)
|
58
|
+
|
59
|
+
modulation_group = parser.add_argument_group("Modulation")
|
60
|
+
modulation_group.add_argument(
|
61
|
+
"--invert_contrast",
|
62
|
+
dest="invert_contrast",
|
63
|
+
action="store_true",
|
64
|
+
required=False,
|
65
|
+
help="Inverts the template contrast.",
|
66
|
+
)
|
67
|
+
modulation_group.add_argument(
|
68
|
+
"--lowpass",
|
69
|
+
dest="lowpass",
|
70
|
+
type=float,
|
71
|
+
required=False,
|
72
|
+
default=None,
|
73
|
+
help="Lowpass filter the template to the given resolution. Nyquist by default. "
|
74
|
+
"A value of 0 disables the filter.",
|
75
|
+
)
|
76
|
+
modulation_group.add_argument(
|
77
|
+
"--no_centering",
|
78
|
+
dest="no_centering",
|
79
|
+
action="store_true",
|
80
|
+
help="Assumes the template is already centered and omits centering.",
|
81
|
+
)
|
82
|
+
args = parser.parse_args()
|
83
|
+
return args
|
84
|
+
|
85
|
+
|
86
|
+
def main():
|
87
|
+
args = parse_args()
|
88
|
+
|
89
|
+
try:
|
90
|
+
data = Structure.from_file(args.data)
|
91
|
+
data = Density.from_structure(data, sampling_rate=args.sampling_rate)
|
92
|
+
except NotImplementedError:
|
93
|
+
data = Density.from_file(args.data)
|
94
|
+
|
95
|
+
if not args.no_centering:
|
96
|
+
data, _ = data.centered(0)
|
97
|
+
|
98
|
+
recommended_box = be.compute_convolution_shapes([args.box_size], [1])[1][0]
|
99
|
+
if recommended_box != args.box_size:
|
100
|
+
warnings.warn(
|
101
|
+
f"Consider using --box_size {recommended_box} instead of {args.box_size}."
|
102
|
+
)
|
103
|
+
|
104
|
+
data.pad(
|
105
|
+
np.multiply(args.box_size, np.divide(args.sampling_rate, data.sampling_rate)),
|
106
|
+
center=True,
|
107
|
+
)
|
108
|
+
|
109
|
+
bpf_mask = 1
|
110
|
+
lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
|
111
|
+
if args.lowpass != 0:
|
112
|
+
bpf_mask = BandPassFilter(
|
113
|
+
lowpass=lowpass,
|
114
|
+
highpass=None,
|
115
|
+
use_gaussian=True,
|
116
|
+
return_real_fourier=True,
|
117
|
+
shape_is_real_fourier=False,
|
118
|
+
)(shape=data.shape)["data"]
|
119
|
+
|
120
|
+
data_ft = np.fft.rfftn(data.data, s=data.shape)
|
121
|
+
data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
|
122
|
+
data.data = np.fft.irfftn(data_ft, s=data.shape).real
|
123
|
+
|
124
|
+
data = data.resample(args.sampling_rate, method="spline", order=3)
|
125
|
+
|
126
|
+
if args.invert_contrast:
|
127
|
+
data.data = data.data * -1
|
128
|
+
|
129
|
+
data.to_file(args.output)
|
130
|
+
|
131
|
+
if __name__ == "__main__":
|
132
|
+
main()
|
@@ -132,14 +132,6 @@ def local_gaussian_filter(
|
|
132
132
|
)
|
133
133
|
|
134
134
|
|
135
|
-
def ntree(
|
136
|
-
template: NDArray,
|
137
|
-
sigma_range: Tuple[float, float],
|
138
|
-
**kwargs: dict,
|
139
|
-
) -> NDArray:
|
140
|
-
return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
|
141
|
-
|
142
|
-
|
143
135
|
def mean(
|
144
136
|
template: NDArray,
|
145
137
|
width: int,
|
@@ -197,6 +189,10 @@ def compute_power_spectrum(template: NDArray) -> NDArray:
|
|
197
189
|
return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
|
198
190
|
|
199
191
|
|
192
|
+
def invert_contrast(template: NDArray) -> NDArray:
|
193
|
+
return template * -1
|
194
|
+
|
195
|
+
|
200
196
|
def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
|
201
197
|
"""
|
202
198
|
Creates list of magicui widgets by inspecting function typing ann
|
@@ -252,13 +248,13 @@ WRAPPED_FUNCTIONS = {
|
|
252
248
|
"gaussian_filter": gaussian_filter,
|
253
249
|
"bandpass_filter": bandpass_filter,
|
254
250
|
"edge_gaussian_filter": edge_gaussian_filter,
|
255
|
-
"ntree_filter": ntree,
|
256
251
|
"local_gaussian_filter": local_gaussian_filter,
|
257
252
|
"difference_of_gaussian_filter": difference_of_gaussian_filter,
|
258
253
|
"mean_filter": mean,
|
259
254
|
"wedge_filter": wedge,
|
260
255
|
"power_spectrum": compute_power_spectrum,
|
261
256
|
"ctf": ctf_filter,
|
257
|
+
"invert_contrast": invert_contrast,
|
262
258
|
}
|
263
259
|
|
264
260
|
EXCLUDED_FUNCTIONS = [
|
@@ -634,6 +630,7 @@ class MaskWidget(widgets.Container):
|
|
634
630
|
|
635
631
|
data = active_layer.data.copy()
|
636
632
|
cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
|
633
|
+
cutoff = max(cutoff, np.finfo(np.float32).resolution)
|
637
634
|
data[data < cutoff] = 0
|
638
635
|
|
639
636
|
center_of_mass = Density.center_of_mass(np.abs(data), 0)
|