pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +183 -69
- scripts/match_template_filters.py +193 -71
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +259 -117
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +20 -8
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +79 -60
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +85 -61
- tme/matching_exhaustive.py +222 -129
- tme/matching_optimization.py +117 -76
- tme/orientations.py +175 -55
- tme/preprocessing/_utils.py +17 -5
- tme/preprocessing/composable_filter.py +2 -1
- tme/preprocessing/compose.py +1 -2
- tme/preprocessing/frequency_filters.py +97 -41
- tme/preprocessing/tilt_series.py +137 -87
- tme/preprocessor.py +3 -0
- tme/structure.py +4 -1
- pytme-0.2.0.dist-info/RECORD +0 -72
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -217,7 +217,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
217
217
|
if args.tilt_weighting not in ("angle", None):
|
218
218
|
raise ValueError(
|
219
219
|
"Tilt weighting schemes other than 'angle' or 'None' require "
|
220
|
-
"a specification of electron doses."
|
220
|
+
"a specification of electron doses via --tilt_angles."
|
221
221
|
)
|
222
222
|
|
223
223
|
wedge = Wedge(
|
@@ -240,31 +240,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
240
240
|
wedge.sampling_rate = template.sampling_rate
|
241
241
|
template_filter.append(wedge)
|
242
242
|
if not isinstance(wedge, WedgeReconstructed):
|
243
|
-
template_filter.append(
|
244
|
-
|
245
|
-
|
243
|
+
template_filter.append(
|
244
|
+
ReconstructFromTilt(
|
245
|
+
reconstruction_filter=args.reconstruction_filter,
|
246
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
247
|
+
)
|
248
|
+
)
|
246
249
|
|
247
|
-
if args.ctf_file is not None:
|
250
|
+
if args.ctf_file is not None or args.defocus is not None:
|
248
251
|
from tme.preprocessing.tilt_series import CTF
|
249
252
|
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
253
|
+
needs_reconstruction = True
|
254
|
+
if args.ctf_file is not None:
|
255
|
+
ctf = CTF.from_file(args.ctf_file)
|
256
|
+
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
257
|
+
if n_tilts_ctfs != n_tils_angles:
|
258
|
+
raise ValueError(
|
259
|
+
f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
|
260
|
+
f"recieved {n_tils_angles} tilt angles. Expected one angle "
|
261
|
+
"per micrograph."
|
262
|
+
)
|
263
|
+
ctf.angles = wedge.angles
|
264
|
+
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
265
|
+
else:
|
266
|
+
needs_reconstruction = False
|
267
|
+
ctf = CTF(
|
268
|
+
defocus_x=[args.defocus],
|
269
|
+
phase_shift=[args.phase_shift],
|
270
|
+
defocus_y=None,
|
271
|
+
angles=[0],
|
272
|
+
shape=None,
|
273
|
+
return_real_fourier=True,
|
257
274
|
)
|
258
|
-
ctf.
|
259
|
-
ctf.
|
260
|
-
|
261
|
-
|
275
|
+
ctf.sampling_rate = template.sampling_rate
|
276
|
+
ctf.flip_phase = not args.no_flip_phase
|
277
|
+
ctf.amplitude_contrast = args.amplitude_contrast
|
278
|
+
ctf.spherical_aberration = args.spherical_aberration
|
279
|
+
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
280
|
+
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
281
|
+
|
282
|
+
if not needs_reconstruction:
|
283
|
+
template_filter.append(ctf)
|
284
|
+
elif isinstance(template_filter[-1], ReconstructFromTilt):
|
262
285
|
template_filter.insert(-1, ctf)
|
263
286
|
else:
|
264
287
|
template_filter.insert(0, ctf)
|
265
|
-
template_filter.insert(
|
266
|
-
|
267
|
-
|
288
|
+
template_filter.insert(
|
289
|
+
1,
|
290
|
+
ReconstructFromTilt(
|
291
|
+
reconstruction_filter=args.reconstruction_filter,
|
292
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
293
|
+
),
|
294
|
+
)
|
268
295
|
|
269
296
|
if args.lowpass or args.highpass is not None:
|
270
297
|
lowpass, highpass = args.lowpass, args.highpass
|
@@ -293,6 +320,14 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
293
320
|
template_filter.append(whitening_filter)
|
294
321
|
target_filter.append(whitening_filter)
|
295
322
|
|
323
|
+
needs_reconstruction = any(
|
324
|
+
[isinstance(t, ReconstructFromTilt) for t in template_filter]
|
325
|
+
)
|
326
|
+
if needs_reconstruction and args.reconstruction_filter is None:
|
327
|
+
warnings.warn(
|
328
|
+
"Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
|
329
|
+
)
|
330
|
+
|
296
331
|
template_filter = Compose(template_filter) if len(template_filter) else None
|
297
332
|
target_filter = Compose(target_filter) if len(target_filter) else None
|
298
333
|
|
@@ -510,7 +545,7 @@ def parse_args():
|
|
510
545
|
dest="no_pass_smooth",
|
511
546
|
action="store_false",
|
512
547
|
default=True,
|
513
|
-
help="Whether a hard edge filter should be used for --lowpass and --highpass."
|
548
|
+
help="Whether a hard edge filter should be used for --lowpass and --highpass.",
|
514
549
|
)
|
515
550
|
filter_group.add_argument(
|
516
551
|
"--pass_format",
|
@@ -519,7 +554,7 @@ def parse_args():
|
|
519
554
|
required=False,
|
520
555
|
choices=["sampling_rate", "voxel", "frequency"],
|
521
556
|
help="How values passed to --lowpass and --highpass should be interpreted. "
|
522
|
-
"By default, they are assumed to be in units of sampling rate, e.g. Ångstrom."
|
557
|
+
"By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
|
523
558
|
)
|
524
559
|
filter_group.add_argument(
|
525
560
|
"--whiten_spectrum",
|
@@ -561,23 +596,90 @@ def parse_args():
|
|
561
596
|
"grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
|
562
597
|
"relion and grigorieff require electron doses in --tilt_angles weights column.",
|
563
598
|
)
|
564
|
-
# filter_group.add_argument(
|
565
|
-
# "--ctf_file",
|
566
|
-
# dest="ctf_file",
|
567
|
-
# type=str,
|
568
|
-
# required=False,
|
569
|
-
# default=None,
|
570
|
-
# help="Path to a file with CTF parameters from CTFFIND4.",
|
571
|
-
# )
|
572
599
|
filter_group.add_argument(
|
573
600
|
"--reconstruction_filter",
|
574
601
|
dest="reconstruction_filter",
|
575
602
|
type=str,
|
576
603
|
required=False,
|
577
|
-
choices
|
604
|
+
choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
|
578
605
|
default=None,
|
579
606
|
help="Filter applied when reconstructing (N+1)-D from N-D filters.",
|
580
607
|
)
|
608
|
+
filter_group.add_argument(
|
609
|
+
"--reconstruction_interpolation_order",
|
610
|
+
dest="reconstruction_interpolation_order",
|
611
|
+
type=int,
|
612
|
+
default=1,
|
613
|
+
required=False,
|
614
|
+
help="Analogous to --interpolation_order but for reconstruction.",
|
615
|
+
)
|
616
|
+
|
617
|
+
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
618
|
+
ctf_group.add_argument(
|
619
|
+
"--ctf_file",
|
620
|
+
dest="ctf_file",
|
621
|
+
type=str,
|
622
|
+
required=False,
|
623
|
+
default=None,
|
624
|
+
help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
|
625
|
+
"interpreted as tilt obtained at the angle specified in --tilt_angles. ",
|
626
|
+
)
|
627
|
+
ctf_group.add_argument(
|
628
|
+
"--defocus",
|
629
|
+
dest="defocus",
|
630
|
+
type=float,
|
631
|
+
required=False,
|
632
|
+
default=None,
|
633
|
+
help="Defocus in units of sampling rate (typically Ångstrom). "
|
634
|
+
"Superseded by --ctf_file.",
|
635
|
+
)
|
636
|
+
ctf_group.add_argument(
|
637
|
+
"--phase_shift",
|
638
|
+
dest="phase_shift",
|
639
|
+
type=float,
|
640
|
+
required=False,
|
641
|
+
default=0,
|
642
|
+
help="Phase shift in degrees. Superseded by --ctf_file.",
|
643
|
+
)
|
644
|
+
ctf_group.add_argument(
|
645
|
+
"--acceleration_voltage",
|
646
|
+
dest="acceleration_voltage",
|
647
|
+
type=float,
|
648
|
+
required=False,
|
649
|
+
default=300,
|
650
|
+
help="Acceleration voltage in kV, defaults to 300.",
|
651
|
+
)
|
652
|
+
ctf_group.add_argument(
|
653
|
+
"--spherical_aberration",
|
654
|
+
dest="spherical_aberration",
|
655
|
+
type=float,
|
656
|
+
required=False,
|
657
|
+
default=2.7e7,
|
658
|
+
help="Spherical aberration in units of sampling rate (typically Ångstrom).",
|
659
|
+
)
|
660
|
+
ctf_group.add_argument(
|
661
|
+
"--amplitude_contrast",
|
662
|
+
dest="amplitude_contrast",
|
663
|
+
type=float,
|
664
|
+
required=False,
|
665
|
+
default=0.07,
|
666
|
+
help="Amplitude contrast, defaults to 0.07.",
|
667
|
+
)
|
668
|
+
ctf_group.add_argument(
|
669
|
+
"--no_flip_phase",
|
670
|
+
dest="no_flip_phase",
|
671
|
+
action="store_false",
|
672
|
+
required=False,
|
673
|
+
help="Whether the phase of the computed CTF should not be flipped.",
|
674
|
+
)
|
675
|
+
ctf_group.add_argument(
|
676
|
+
"--correct_defocus_gradient",
|
677
|
+
dest="correct_defocus_gradient",
|
678
|
+
action="store_true",
|
679
|
+
required=False,
|
680
|
+
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
681
|
+
"defocus gradients.",
|
682
|
+
)
|
581
683
|
|
582
684
|
performance_group = parser.add_argument_group("Performance")
|
583
685
|
performance_group.add_argument(
|
@@ -655,12 +757,11 @@ def parse_args():
|
|
655
757
|
)
|
656
758
|
|
657
759
|
args = parser.parse_args()
|
760
|
+
args.version = __version__
|
658
761
|
|
659
762
|
if args.interpolation_order < 0:
|
660
763
|
args.interpolation_order = None
|
661
764
|
|
662
|
-
args.ctf_file = None
|
663
|
-
|
664
765
|
if args.temp_directory is None:
|
665
766
|
default = abspath(".")
|
666
767
|
if os.environ.get("TMPDIR", None) is not None:
|
@@ -725,12 +826,13 @@ def main():
|
|
725
826
|
sampling_rate=target.sampling_rate,
|
726
827
|
)
|
727
828
|
|
728
|
-
if
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
829
|
+
if target.sampling_rate.size == template.sampling_rate.size:
|
830
|
+
if not np.allclose(target.sampling_rate, template.sampling_rate):
|
831
|
+
print(
|
832
|
+
f"Resampling template to {target.sampling_rate}. "
|
833
|
+
"Consider providing a template with the same sampling rate as the target."
|
834
|
+
)
|
835
|
+
template = template.resample(target.sampling_rate, order=3)
|
734
836
|
|
735
837
|
template_mask = load_and_validate_mask(
|
736
838
|
mask_target=template, mask_path=args.template_mask
|
@@ -863,31 +965,46 @@ def main():
|
|
863
965
|
if args.memory is None:
|
864
966
|
args.memory = int(args.memory_scaling * available_memory)
|
865
967
|
|
866
|
-
|
867
|
-
if args.
|
868
|
-
|
968
|
+
callback_class = MaxScoreOverRotations
|
969
|
+
if args.peak_calling:
|
970
|
+
callback_class = PeakCallerMaximumFilter
|
971
|
+
|
972
|
+
matching_data = MatchingData(
|
973
|
+
target=target,
|
974
|
+
template=template.data,
|
975
|
+
target_mask=target_mask,
|
976
|
+
template_mask=template_mask,
|
977
|
+
invert_target=args.invert_target_contrast,
|
978
|
+
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
979
|
+
)
|
980
|
+
|
981
|
+
template_filter, target_filter = setup_filter(args, template, target)
|
982
|
+
matching_data.template_filter = template_filter
|
983
|
+
matching_data.target_filter = target_filter
|
869
984
|
|
870
|
-
template_box =
|
985
|
+
template_box = matching_data._output_template_shape
|
871
986
|
if not args.pad_fourier:
|
872
987
|
template_box = np.ones(len(template_box), dtype=int)
|
873
988
|
|
874
|
-
|
875
|
-
|
876
|
-
|
989
|
+
target_padding = np.zeros(
|
990
|
+
(backend.size(matching_data._output_template_shape)), dtype=int
|
991
|
+
)
|
992
|
+
if args.pad_target_edges:
|
993
|
+
target_padding = matching_data._output_template_shape
|
877
994
|
|
878
995
|
splits, schedule = compute_parallelization_schedule(
|
879
996
|
shape1=target.shape,
|
880
|
-
shape2=template_box,
|
881
|
-
shape1_padding=target_padding,
|
997
|
+
shape2=tuple(int(x) for x in template_box),
|
998
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
882
999
|
max_cores=args.cores,
|
883
1000
|
max_ram=args.memory,
|
884
1001
|
split_only_outer=args.use_gpu,
|
885
1002
|
matching_method=args.score,
|
886
1003
|
analyzer_method=callback_class.__name__,
|
887
1004
|
backend=backend._backend_name,
|
888
|
-
float_nbytes=backend.datatype_bytes(backend.
|
1005
|
+
float_nbytes=backend.datatype_bytes(backend._float_dtype),
|
889
1006
|
complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
|
890
|
-
integer_nbytes=backend.datatype_bytes(backend.
|
1007
|
+
integer_nbytes=backend.datatype_bytes(backend._int_dtype),
|
891
1008
|
)
|
892
1009
|
|
893
1010
|
if splits is None:
|
@@ -898,20 +1015,6 @@ def main():
|
|
898
1015
|
exit(-1)
|
899
1016
|
|
900
1017
|
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
901
|
-
matching_data = MatchingData(target=target, template=template.data)
|
902
|
-
matching_data.rotations = parse_rotation_logic(args=args, ndim=target.data.ndim)
|
903
|
-
|
904
|
-
template_filter, target_filter = setup_filter(args, template, target)
|
905
|
-
matching_data.template_filter = template_filter
|
906
|
-
matching_data.target_filter = target_filter
|
907
|
-
|
908
|
-
matching_data.template_filter = template_filter
|
909
|
-
matching_data._invert_target = args.invert_target_contrast
|
910
|
-
if target_mask is not None:
|
911
|
-
matching_data.target_mask = target_mask
|
912
|
-
if template_mask is not None:
|
913
|
-
matching_data.template_mask = template_mask.data
|
914
|
-
|
915
1018
|
n_splits = np.prod(list(splits.values()))
|
916
1019
|
target_split = ", ".join(
|
917
1020
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
@@ -945,13 +1048,23 @@ def main():
|
|
945
1048
|
"Lowpass": args.lowpass,
|
946
1049
|
"Highpass": args.highpass,
|
947
1050
|
"Smooth Pass": args.no_pass_smooth,
|
948
|
-
"Pass Format"
|
1051
|
+
"Pass Format": args.pass_format,
|
949
1052
|
"Spectral Whitening": args.whiten_spectrum,
|
950
1053
|
"Wedge Axes": args.wedge_axes,
|
951
1054
|
"Tilt Angles": args.tilt_angles,
|
952
1055
|
"Tilt Weighting": args.tilt_weighting,
|
953
|
-
"
|
1056
|
+
"Reconstruction Filter": args.reconstruction_filter,
|
954
1057
|
}
|
1058
|
+
if args.ctf_file is not None or args.defocus is not None:
|
1059
|
+
filter_args["CTF File"] = args.ctf_file
|
1060
|
+
filter_args["Defocus"] = args.defocus
|
1061
|
+
filter_args["Phase Shift"] = args.phase_shift
|
1062
|
+
filter_args["No Flip Phase"] = args.no_flip_phase
|
1063
|
+
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1064
|
+
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1065
|
+
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
1066
|
+
filter_args["Correct Defocus"] = args.correct_defocus_gradient
|
1067
|
+
|
955
1068
|
filter_args = {k: v for k, v in filter_args.items() if v is not None}
|
956
1069
|
if len(filter_args):
|
957
1070
|
print_block(
|
@@ -1000,15 +1113,16 @@ def main():
|
|
1000
1113
|
candidates[0] *= target_mask.data
|
1001
1114
|
with warnings.catch_warnings():
|
1002
1115
|
warnings.simplefilter("ignore", category=UserWarning)
|
1116
|
+
nbytes = backend.datatype_bytes(backend._float_dtype)
|
1117
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
1118
|
+
rot_dim = matching_data.rotations.shape[1]
|
1003
1119
|
candidates[3] = {
|
1004
1120
|
x: euler_from_rotationmatrix(
|
1005
|
-
np.frombuffer(i, dtype=
|
1006
|
-
candidates[0].ndim, candidates[0].ndim
|
1007
|
-
)
|
1121
|
+
np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1008
1122
|
)
|
1009
1123
|
for i, x in candidates[3].items()
|
1010
1124
|
}
|
1011
|
-
candidates.append((target.origin, template.origin,
|
1125
|
+
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1012
1126
|
write_pickle(data=candidates, filename=args.output)
|
1013
1127
|
|
1014
1128
|
runtime = time() - start
|
@@ -9,7 +9,7 @@ import argparse
|
|
9
9
|
from sys import exit
|
10
10
|
from os import getcwd
|
11
11
|
from os.path import join, abspath
|
12
|
-
from typing import List
|
12
|
+
from typing import List, Tuple
|
13
13
|
from os.path import splitext
|
14
14
|
|
15
15
|
import numpy as np
|
@@ -29,6 +29,7 @@ from tme.matching_utils import (
|
|
29
29
|
euler_to_rotationmatrix,
|
30
30
|
euler_from_rotationmatrix,
|
31
31
|
)
|
32
|
+
from tme.matching_optimization import create_score_object, optimize_match
|
32
33
|
|
33
34
|
PEAK_CALLERS = {
|
34
35
|
"PeakCallerSort": PeakCallerSort,
|
@@ -181,6 +182,13 @@ def parse_args():
|
|
181
182
|
required=False,
|
182
183
|
help="Number of accepted false-positives picks to determine minimum score.",
|
183
184
|
)
|
185
|
+
additional_group.add_argument(
|
186
|
+
"--local_optimization",
|
187
|
+
action="store_true",
|
188
|
+
required=False,
|
189
|
+
help="[Experimental] Perform local optimization of candidates. Useful when the "
|
190
|
+
"number of identified candidats is small (< 10).",
|
191
|
+
)
|
184
192
|
|
185
193
|
args = parser.parse_args()
|
186
194
|
|
@@ -195,28 +203,34 @@ def parse_args():
|
|
195
203
|
|
196
204
|
if args.minimum_score is not None or args.n_false_positives is not None:
|
197
205
|
args.number_of_peaks = np.iinfo(np.int64).max
|
198
|
-
|
206
|
+
elif args.number_of_peaks is None:
|
199
207
|
args.number_of_peaks = 1000
|
200
208
|
|
201
209
|
return args
|
202
210
|
|
203
211
|
|
204
|
-
def load_template(
|
212
|
+
def load_template(
|
213
|
+
filepath: str,
|
214
|
+
sampling_rate: NDArray,
|
215
|
+
centering: bool = True,
|
216
|
+
target_shape: Tuple[int] = None,
|
217
|
+
):
|
205
218
|
try:
|
206
219
|
template = Density.from_file(filepath)
|
207
|
-
|
220
|
+
center = np.divide(np.subtract(template.shape, 1), 2)
|
208
221
|
template_is_density = True
|
209
|
-
except
|
222
|
+
except Exception:
|
210
223
|
template = Structure.from_file(filepath)
|
211
|
-
|
224
|
+
center = template.center_of_mass()[::-1]
|
212
225
|
template = Density.from_structure(template, sampling_rate=sampling_rate)
|
213
226
|
template_is_density = False
|
214
227
|
|
215
|
-
translation = np.zeros_like(
|
216
|
-
if
|
228
|
+
translation = np.zeros_like(center)
|
229
|
+
if centering and template_is_density:
|
217
230
|
template, translation = template.centered(0)
|
231
|
+
center = np.divide(np.subtract(template.shape, 1), 2)
|
218
232
|
|
219
|
-
return template,
|
233
|
+
return template, center, translation, template_is_density
|
220
234
|
|
221
235
|
|
222
236
|
def merge_outputs(data, filepaths: List[str], args):
|
@@ -226,7 +240,7 @@ def merge_outputs(data, filepaths: List[str], args):
|
|
226
240
|
if data[0].ndim != data[2].ndim:
|
227
241
|
return data, 1
|
228
242
|
|
229
|
-
from tme.matching_exhaustive import
|
243
|
+
from tme.matching_exhaustive import normalize_under_mask
|
230
244
|
|
231
245
|
def _norm_scores(data, args):
|
232
246
|
target_origin, _, sampling_rate, cli_args = data[-1]
|
@@ -235,7 +249,7 @@ def merge_outputs(data, filepaths: List[str], args):
|
|
235
249
|
ret = load_template(
|
236
250
|
filepath=cli_args.template,
|
237
251
|
sampling_rate=sampling_rate,
|
238
|
-
|
252
|
+
centering=not cli_args.no_centering,
|
239
253
|
)
|
240
254
|
template, center_of_mass, translation, template_is_density = ret
|
241
255
|
|
@@ -256,7 +270,7 @@ def merge_outputs(data, filepaths: List[str], args):
|
|
256
270
|
mask.shape, np.multiply(args.min_boundary_distance, 2)
|
257
271
|
).astype(int)
|
258
272
|
mask[cropped_shape] = 0
|
259
|
-
|
273
|
+
normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
|
260
274
|
return data[0]
|
261
275
|
|
262
276
|
entities = np.zeros_like(data[0])
|
@@ -280,7 +294,7 @@ def main():
|
|
280
294
|
ret = load_template(
|
281
295
|
filepath=cli_args.template,
|
282
296
|
sampling_rate=sampling_rate,
|
283
|
-
|
297
|
+
centering=not cli_args.no_centering,
|
284
298
|
)
|
285
299
|
template, center_of_mass, translation, template_is_density = ret
|
286
300
|
|
@@ -310,7 +324,9 @@ def main():
|
|
310
324
|
max_shape = np.max(template.shape)
|
311
325
|
args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
|
312
326
|
|
313
|
-
|
327
|
+
entities = None
|
328
|
+
if len(args.input_file) > 1:
|
329
|
+
data, entities = merge_outputs(data=data, filepaths=args.input_file, args=args)
|
314
330
|
|
315
331
|
orientations = args.orientations
|
316
332
|
if orientations is None:
|
@@ -346,31 +362,33 @@ def main():
|
|
346
362
|
minimum_score = max(minimum_score, 0)
|
347
363
|
args.minimum_score = minimum_score
|
348
364
|
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
args.
|
365
|
+
args.batch_dims = None
|
366
|
+
if hasattr(cli_args, "target_batch"):
|
367
|
+
args.batch_dims = cli_args.target_batch
|
368
|
+
|
369
|
+
peak_caller_kwargs = {
|
370
|
+
"number_of_peaks": args.number_of_peaks,
|
371
|
+
"min_distance": args.min_distance,
|
372
|
+
"min_boundary_distance": args.min_boundary_distance,
|
373
|
+
"batch_dims": args.batch_dims,
|
374
|
+
}
|
356
375
|
|
376
|
+
peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
|
357
377
|
peak_caller(
|
358
378
|
scores,
|
359
|
-
rotation_matrix=np.eye(
|
379
|
+
rotation_matrix=np.eye(template.data.ndim),
|
360
380
|
mask=template.data,
|
361
381
|
rotation_mapping=rotation_mapping,
|
362
382
|
rotation_array=rotation_array,
|
363
383
|
minimum_score=args.minimum_score,
|
364
384
|
)
|
365
385
|
candidates = peak_caller.merge(
|
366
|
-
candidates=[tuple(peak_caller)],
|
367
|
-
number_of_peaks=args.number_of_peaks,
|
368
|
-
min_distance=args.min_distance,
|
369
|
-
min_boundary_distance=args.min_boundary_distance,
|
386
|
+
candidates=[tuple(peak_caller)], **peak_caller_kwargs
|
370
387
|
)
|
371
388
|
if len(candidates) == 0:
|
372
|
-
|
373
|
-
|
389
|
+
candidates = [[], [], [], []]
|
390
|
+
print("Found no peaks, consider changing peak calling parameters.")
|
391
|
+
exit(0)
|
374
392
|
|
375
393
|
for translation, _, score, detail in zip(*candidates):
|
376
394
|
rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
|
@@ -381,8 +399,13 @@ def main():
|
|
381
399
|
for i in range(translation.shape[0]):
|
382
400
|
rotations.append(euler_from_rotationmatrix(rotation[i]))
|
383
401
|
|
384
|
-
|
402
|
+
if len(rotations):
|
403
|
+
rotations = np.vstack(rotations).astype(float)
|
385
404
|
translations, scores, details = candidates[0], candidates[2], candidates[3]
|
405
|
+
|
406
|
+
if entities is not None:
|
407
|
+
details = entities[tuple(translations.T)]
|
408
|
+
|
386
409
|
orientations = Orientations(
|
387
410
|
translations=translations,
|
388
411
|
rotations=rotations,
|
@@ -390,14 +413,55 @@ def main():
|
|
390
413
|
details=details,
|
391
414
|
)
|
392
415
|
|
393
|
-
if args.minimum_score is not None:
|
416
|
+
if args.minimum_score is not None and len(orientations.scores):
|
394
417
|
keep = orientations.scores >= args.minimum_score
|
395
418
|
orientations = orientations[keep]
|
396
419
|
|
397
|
-
if args.maximum_score is not None:
|
420
|
+
if args.maximum_score is not None and len(orientations.scores):
|
398
421
|
keep = orientations.scores <= args.maximum_score
|
399
422
|
orientations = orientations[keep]
|
400
423
|
|
424
|
+
if args.peak_oversampling > 1:
|
425
|
+
peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller]()
|
426
|
+
if data[0].ndim != data[2].ndim:
|
427
|
+
print(
|
428
|
+
"Input pickle does not contain template matching scores."
|
429
|
+
" Cannot oversample peaks."
|
430
|
+
)
|
431
|
+
exit(-1)
|
432
|
+
orientations.translations = peak_caller.oversample_peaks(
|
433
|
+
score_space=data[0],
|
434
|
+
peak_positions=orientations.translations,
|
435
|
+
oversampling_factor=args.peak_oversampling,
|
436
|
+
)
|
437
|
+
|
438
|
+
if args.local_optimization:
|
439
|
+
target = Density.from_file(cli_args.target)
|
440
|
+
orientations.translations = orientations.translations.astype(np.float32)
|
441
|
+
orientations.rotations = orientations.rotations.astype(np.float32)
|
442
|
+
for index, (translation, angles, *_) in enumerate(orientations):
|
443
|
+
score_object = create_score_object(
|
444
|
+
score="FLC",
|
445
|
+
target=target.data.copy(),
|
446
|
+
template=template.data.copy(),
|
447
|
+
template_mask=template_mask.data.copy(),
|
448
|
+
)
|
449
|
+
|
450
|
+
center = np.divide(template.shape, 2)
|
451
|
+
init_translation = np.subtract(translation, center)
|
452
|
+
bounds_translation = tuple((x - 5, x + 5) for x in init_translation)
|
453
|
+
|
454
|
+
translation, rotation_matrix, score = optimize_match(
|
455
|
+
score_object=score_object,
|
456
|
+
optimization_method="basinhopping",
|
457
|
+
bounds_translation=bounds_translation,
|
458
|
+
maxiter=3,
|
459
|
+
x0=[*init_translation, *angles],
|
460
|
+
)
|
461
|
+
orientations.translations[index] = np.add(translation, center)
|
462
|
+
orientations.rotations[index] = angles
|
463
|
+
orientations.scores[index] = score * -1
|
464
|
+
|
401
465
|
if args.output_format == "orientations":
|
402
466
|
orientations.to_file(filename=f"{args.output_prefix}.tsv", file_format="text")
|
403
467
|
exit(0)
|
@@ -515,7 +579,6 @@ def main():
|
|
515
579
|
)
|
516
580
|
rotation_matrix = rotation.inv().as_matrix()
|
517
581
|
|
518
|
-
# rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
|
519
582
|
subset = Density(target.data[obs_slices[index]])
|
520
583
|
subset = subset.rigid_transform(rotation_matrix=rotation_matrix, order=1)
|
521
584
|
|
@@ -526,35 +589,30 @@ def main():
|
|
526
589
|
ret.to_file(f"{args.output_prefix}_average.mrc")
|
527
590
|
exit(0)
|
528
591
|
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
)
|
536
|
-
exit(-1)
|
537
|
-
orientations.translations = peak_caller.oversample_peaks(
|
538
|
-
score_space=data[0],
|
539
|
-
translations=orientations.translations,
|
540
|
-
oversampling_factor=args.oversampling_factor,
|
541
|
-
)
|
592
|
+
template, center, *_ = load_template(
|
593
|
+
filepath=cli_args.template,
|
594
|
+
sampling_rate=sampling_rate,
|
595
|
+
centering=not cli_args.no_centering,
|
596
|
+
target_shape=target.shape,
|
597
|
+
)
|
542
598
|
|
543
599
|
for index, (translation, angles, *_) in enumerate(orientations):
|
544
600
|
rotation_matrix = euler_to_rotationmatrix(angles)
|
545
601
|
if template_is_density:
|
546
|
-
translation = np.subtract(translation,
|
602
|
+
translation = np.subtract(translation, center)
|
547
603
|
transformed_template = template.rigid_transform(
|
548
604
|
rotation_matrix=rotation_matrix
|
549
605
|
)
|
550
|
-
|
551
|
-
|
606
|
+
transformed_template.origin = np.add(
|
607
|
+
target_origin, np.multiply(translation, sampling_rate)
|
608
|
+
)
|
609
|
+
|
552
610
|
else:
|
553
611
|
template = Structure.from_file(cli_args.template)
|
554
612
|
new_center_of_mass = np.add(
|
555
613
|
np.multiply(translation, sampling_rate), target_origin
|
556
614
|
)
|
557
|
-
translation = np.subtract(new_center_of_mass,
|
615
|
+
translation = np.subtract(new_center_of_mass, center)
|
558
616
|
transformed_template = template.rigid_transform(
|
559
617
|
translation=translation[::-1],
|
560
618
|
rotation_matrix=rotation_matrix[::-1, ::-1],
|
@@ -789,7 +789,10 @@ class AlignmentWidget(widgets.Container):
|
|
789
789
|
active_layer = self.viewer.layers.selection.active
|
790
790
|
if active_layer is None:
|
791
791
|
return ()
|
792
|
-
|
792
|
+
try:
|
793
|
+
return [i for i in range(active_layer.data.ndim)]
|
794
|
+
except Exception:
|
795
|
+
return ()
|
793
796
|
|
794
797
|
def _update_align_axis(self, *args):
|
795
798
|
self.align_axis.choices = self._get_active_layer_dims()
|