pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
scripts/match_template.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!python3
|
2
|
-
""" CLI
|
2
|
+
""" CLI for basic pyTME template matching functions.
|
3
3
|
|
4
4
|
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
5
|
|
@@ -8,7 +8,6 @@
|
|
8
8
|
import os
|
9
9
|
import argparse
|
10
10
|
import warnings
|
11
|
-
import importlib.util
|
12
11
|
from sys import exit
|
13
12
|
from time import time
|
14
13
|
from typing import Tuple
|
@@ -22,9 +21,7 @@ from tme.matching_utils import (
|
|
22
21
|
get_rotation_matrices,
|
23
22
|
get_rotations_around_vector,
|
24
23
|
compute_parallelization_schedule,
|
25
|
-
euler_from_rotationmatrix,
|
26
24
|
scramble_phases,
|
27
|
-
generate_tempfile_name,
|
28
25
|
write_pickle,
|
29
26
|
)
|
30
27
|
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
@@ -33,7 +30,7 @@ from tme.analyzer import (
|
|
33
30
|
MaxScoreOverRotations,
|
34
31
|
PeakCallerMaximumFilter,
|
35
32
|
)
|
36
|
-
from tme.backends import backend
|
33
|
+
from tme.backends import backend as be
|
37
34
|
from tme.preprocessing import Compose
|
38
35
|
|
39
36
|
|
@@ -52,7 +49,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
|
|
52
49
|
|
53
50
|
def print_entry() -> None:
|
54
51
|
width = 80
|
55
|
-
text = f"
|
52
|
+
text = f" pytme v{__version__} "
|
56
53
|
padding_total = width - len(text) - 2
|
57
54
|
padding_left = padding_total // 2
|
58
55
|
padding_right = padding_total - padding_left
|
@@ -101,7 +98,9 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
101
98
|
f"Expected shape of {mask_path} was {mask_target.shape},"
|
102
99
|
f" got f{mask.shape}"
|
103
100
|
)
|
104
|
-
if not np.allclose(
|
101
|
+
if not np.allclose(
|
102
|
+
np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
|
103
|
+
):
|
105
104
|
raise ValueError(
|
106
105
|
f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
|
107
106
|
f", got f{mask.sampling_rate}"
|
@@ -109,50 +108,6 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
109
108
|
return mask
|
110
109
|
|
111
110
|
|
112
|
-
def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
|
113
|
-
"""
|
114
|
-
Crop the provided data and mask to a smaller box based on a cutoff value.
|
115
|
-
|
116
|
-
Parameters
|
117
|
-
----------
|
118
|
-
data : Density
|
119
|
-
The data that should be cropped.
|
120
|
-
cutoff : float
|
121
|
-
The threshold value to determine which parts of the data should be kept.
|
122
|
-
data_mask : Density, optional
|
123
|
-
A mask for the data that should be cropped.
|
124
|
-
|
125
|
-
Returns
|
126
|
-
-------
|
127
|
-
bool
|
128
|
-
Returns True if the data was adjusted (cropped), otherwise returns False.
|
129
|
-
|
130
|
-
Notes
|
131
|
-
-----
|
132
|
-
Cropping is performed in place.
|
133
|
-
"""
|
134
|
-
if cutoff is None:
|
135
|
-
return False
|
136
|
-
|
137
|
-
box = data.trim_box(cutoff=cutoff)
|
138
|
-
box_mask = box
|
139
|
-
if data_mask is not None:
|
140
|
-
box_mask = data_mask.trim_box(cutoff=cutoff)
|
141
|
-
box = tuple(
|
142
|
-
slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
|
143
|
-
for arr, mask in zip(box, box_mask)
|
144
|
-
)
|
145
|
-
if box == tuple(slice(0, x) for x in data.shape):
|
146
|
-
return False
|
147
|
-
|
148
|
-
data.adjust_box(box)
|
149
|
-
|
150
|
-
if data_mask:
|
151
|
-
data_mask.adjust_box(box)
|
152
|
-
|
153
|
-
return True
|
154
|
-
|
155
|
-
|
156
111
|
def parse_rotation_logic(args, ndim):
|
157
112
|
if args.angular_sampling is not None:
|
158
113
|
rotations = get_rotation_matrices(
|
@@ -177,8 +132,52 @@ def parse_rotation_logic(args, ndim):
|
|
177
132
|
return rotations
|
178
133
|
|
179
134
|
|
180
|
-
|
181
|
-
|
135
|
+
def compute_schedule(
|
136
|
+
args,
|
137
|
+
target: Density,
|
138
|
+
matching_data: MatchingData,
|
139
|
+
callback_class,
|
140
|
+
pad_edges: bool = False,
|
141
|
+
):
|
142
|
+
# User requested target padding
|
143
|
+
if args.pad_edges is True:
|
144
|
+
pad_edges = True
|
145
|
+
template_box = matching_data._output_template_shape
|
146
|
+
if not args.pad_fourier:
|
147
|
+
template_box = tuple(0 for _ in range(len(template_box)))
|
148
|
+
|
149
|
+
target_padding = tuple(0 for _ in range(len(template_box)))
|
150
|
+
if pad_edges:
|
151
|
+
target_padding = matching_data._output_template_shape
|
152
|
+
|
153
|
+
splits, schedule = compute_parallelization_schedule(
|
154
|
+
shape1=target.shape,
|
155
|
+
shape2=tuple(int(x) for x in template_box),
|
156
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
157
|
+
max_cores=args.cores,
|
158
|
+
max_ram=args.memory,
|
159
|
+
split_only_outer=args.use_gpu,
|
160
|
+
matching_method=args.score,
|
161
|
+
analyzer_method=callback_class.__name__,
|
162
|
+
backend=be._backend_name,
|
163
|
+
float_nbytes=be.datatype_bytes(be._float_dtype),
|
164
|
+
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
165
|
+
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
166
|
+
)
|
167
|
+
|
168
|
+
if splits is None:
|
169
|
+
print(
|
170
|
+
"Found no suitable parallelization schedule. Consider increasing"
|
171
|
+
" available RAM or decreasing number of cores."
|
172
|
+
)
|
173
|
+
exit(-1)
|
174
|
+
n_splits = np.prod(list(splits.values()))
|
175
|
+
if pad_edges is False and n_splits > 1:
|
176
|
+
args.pad_edges = True
|
177
|
+
return compute_schedule(args, target, matching_data, callback_class, True)
|
178
|
+
return splits, schedule
|
179
|
+
|
180
|
+
|
182
181
|
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
183
182
|
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
184
183
|
from tme.preprocessing.tilt_series import (
|
@@ -234,18 +233,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
234
233
|
weight_wedge=args.tilt_weighting == "angle",
|
235
234
|
create_continuous_wedge=create_continuous_wedge,
|
236
235
|
)
|
236
|
+
wedge_target = WedgeReconstructed(
|
237
|
+
angles=tilt_angles,
|
238
|
+
weight_wedge=False,
|
239
|
+
create_continuous_wedge=create_continuous_wedge,
|
240
|
+
)
|
241
|
+
target_filter.append(wedge_target)
|
237
242
|
|
238
243
|
wedge.opening_axis = args.wedge_axes[0]
|
239
244
|
wedge.tilt_axis = args.wedge_axes[1]
|
240
245
|
wedge.sampling_rate = template.sampling_rate
|
241
246
|
template_filter.append(wedge)
|
242
247
|
if not isinstance(wedge, WedgeReconstructed):
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
247
|
-
)
|
248
|
+
reconstruction_filter = ReconstructFromTilt(
|
249
|
+
reconstruction_filter=args.reconstruction_filter,
|
250
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
248
251
|
)
|
252
|
+
template_filter.append(reconstruction_filter)
|
249
253
|
|
250
254
|
if args.ctf_file is not None or args.defocus is not None:
|
251
255
|
from tme.preprocessing.tilt_series import CTF
|
@@ -273,7 +277,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
273
277
|
return_real_fourier=True,
|
274
278
|
)
|
275
279
|
ctf.sampling_rate = template.sampling_rate
|
276
|
-
ctf.flip_phase =
|
280
|
+
ctf.flip_phase = args.no_flip_phase
|
277
281
|
ctf.amplitude_contrast = args.amplitude_contrast
|
278
282
|
ctf.spherical_aberration = args.spherical_aberration
|
279
283
|
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
@@ -306,6 +310,12 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
306
310
|
if highpass is not None:
|
307
311
|
highpass = np.max(np.divide(template.sampling_rate, highpass))
|
308
312
|
|
313
|
+
try:
|
314
|
+
if args.lowpass >= args.highpass:
|
315
|
+
warnings.warn("--lowpass should be smaller than --highpass.")
|
316
|
+
except Exception:
|
317
|
+
pass
|
318
|
+
|
309
319
|
bandpass = BandPassFilter(
|
310
320
|
use_gaussian=args.no_pass_smooth,
|
311
321
|
lowpass=lowpass,
|
@@ -313,7 +323,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
313
323
|
sampling_rate=template.sampling_rate,
|
314
324
|
)
|
315
325
|
template_filter.append(bandpass)
|
316
|
-
|
326
|
+
|
327
|
+
if not args.no_filter_target:
|
328
|
+
target_filter.append(bandpass)
|
317
329
|
|
318
330
|
if args.whiten_spectrum:
|
319
331
|
whitening_filter = LinearWhiteningFilter()
|
@@ -335,7 +347,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
335
347
|
|
336
348
|
|
337
349
|
def parse_args():
|
338
|
-
parser = argparse.ArgumentParser(
|
350
|
+
parser = argparse.ArgumentParser(
|
351
|
+
description="Perform template matching.",
|
352
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
353
|
+
)
|
339
354
|
|
340
355
|
io_group = parser.add_argument_group("Input / Output")
|
341
356
|
io_group.add_argument(
|
@@ -384,8 +399,7 @@ def parse_args():
|
|
384
399
|
dest="invert_target_contrast",
|
385
400
|
action="store_true",
|
386
401
|
default=False,
|
387
|
-
help="Invert the target's contrast
|
388
|
-
"This option is intended for targets where templates to-be-matched have "
|
402
|
+
help="Invert the target's contrast for cases where templates to-be-matched have "
|
389
403
|
"negative values, e.g. tomograms.",
|
390
404
|
)
|
391
405
|
io_group.add_argument(
|
@@ -405,13 +419,6 @@ def parse_args():
|
|
405
419
|
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
406
420
|
help="Template matching scoring function.",
|
407
421
|
)
|
408
|
-
scoring_group.add_argument(
|
409
|
-
"-p",
|
410
|
-
dest="peak_calling",
|
411
|
-
action="store_true",
|
412
|
-
default=False,
|
413
|
-
help="Perform peak calling instead of score aggregation.",
|
414
|
-
)
|
415
422
|
|
416
423
|
angular_group = parser.add_argument_group("Angular Sampling")
|
417
424
|
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
@@ -445,7 +452,7 @@ def parse_args():
|
|
445
452
|
type=check_positive,
|
446
453
|
default=360.0,
|
447
454
|
required=False,
|
448
|
-
help="Sampling angle along the z-axis of the cone.
|
455
|
+
help="Sampling angle along the z-axis of the cone.",
|
449
456
|
)
|
450
457
|
angular_group.add_argument(
|
451
458
|
"--axis_sampling",
|
@@ -513,8 +520,7 @@ def parse_args():
|
|
513
520
|
required=False,
|
514
521
|
type=float,
|
515
522
|
default=0.85,
|
516
|
-
help="Fraction of available memory
|
517
|
-
"ignored if --ram is set",
|
523
|
+
help="Fraction of available memory to be used. Ignored if --ram is set.",
|
518
524
|
)
|
519
525
|
computation_group.add_argument(
|
520
526
|
"--temp_directory",
|
@@ -522,7 +528,13 @@ def parse_args():
|
|
522
528
|
default=None,
|
523
529
|
help="Directory for temporary objects. Faster I/O improves runtime.",
|
524
530
|
)
|
525
|
-
|
531
|
+
computation_group.add_argument(
|
532
|
+
"--backend",
|
533
|
+
dest="backend",
|
534
|
+
default=None,
|
535
|
+
choices=be.available_backends(),
|
536
|
+
help="[Expert] Overwrite default computation backend.",
|
537
|
+
)
|
526
538
|
filter_group = parser.add_argument_group("Filters")
|
527
539
|
filter_group.add_argument(
|
528
540
|
"--lowpass",
|
@@ -552,9 +564,9 @@ def parse_args():
|
|
552
564
|
dest="pass_format",
|
553
565
|
type=str,
|
554
566
|
required=False,
|
567
|
+
default="sampling_rate",
|
555
568
|
choices=["sampling_rate", "voxel", "frequency"],
|
556
|
-
help="How values passed to --lowpass and --highpass should be interpreted. "
|
557
|
-
"By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
|
569
|
+
help="How values passed to --lowpass and --highpass should be interpreted. ",
|
558
570
|
)
|
559
571
|
filter_group.add_argument(
|
560
572
|
"--whiten_spectrum",
|
@@ -613,6 +625,13 @@ def parse_args():
|
|
613
625
|
required=False,
|
614
626
|
help="Analogous to --interpolation_order but for reconstruction.",
|
615
627
|
)
|
628
|
+
filter_group.add_argument(
|
629
|
+
"--no_filter_target",
|
630
|
+
dest="no_filter_target",
|
631
|
+
action="store_true",
|
632
|
+
default=False,
|
633
|
+
help="Whether to not apply potential filters to the target.",
|
634
|
+
)
|
616
635
|
|
617
636
|
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
618
637
|
ctf_group.add_argument(
|
@@ -647,7 +666,7 @@ def parse_args():
|
|
647
666
|
type=float,
|
648
667
|
required=False,
|
649
668
|
default=300,
|
650
|
-
help="Acceleration voltage in kV
|
669
|
+
help="Acceleration voltage in kV.",
|
651
670
|
)
|
652
671
|
ctf_group.add_argument(
|
653
672
|
"--spherical_aberration",
|
@@ -663,14 +682,14 @@ def parse_args():
|
|
663
682
|
type=float,
|
664
683
|
required=False,
|
665
684
|
default=0.07,
|
666
|
-
help="Amplitude contrast
|
685
|
+
help="Amplitude contrast.",
|
667
686
|
)
|
668
687
|
ctf_group.add_argument(
|
669
688
|
"--no_flip_phase",
|
670
689
|
dest="no_flip_phase",
|
671
690
|
action="store_false",
|
672
691
|
required=False,
|
673
|
-
help="
|
692
|
+
help="Do not perform phase-flipping CTF correction.",
|
674
693
|
)
|
675
694
|
ctf_group.add_argument(
|
676
695
|
"--correct_defocus_gradient",
|
@@ -682,22 +701,6 @@ def parse_args():
|
|
682
701
|
)
|
683
702
|
|
684
703
|
performance_group = parser.add_argument_group("Performance")
|
685
|
-
performance_group.add_argument(
|
686
|
-
"--cutoff_target",
|
687
|
-
dest="cutoff_target",
|
688
|
-
type=float,
|
689
|
-
required=False,
|
690
|
-
default=None,
|
691
|
-
help="Target contour level (used for cropping).",
|
692
|
-
)
|
693
|
-
performance_group.add_argument(
|
694
|
-
"--cutoff_template",
|
695
|
-
dest="cutoff_template",
|
696
|
-
type=float,
|
697
|
-
required=False,
|
698
|
-
default=None,
|
699
|
-
help="Template contour level (used for cropping).",
|
700
|
-
)
|
701
704
|
performance_group.add_argument(
|
702
705
|
"--no_centering",
|
703
706
|
dest="no_centering",
|
@@ -705,21 +708,28 @@ def parse_args():
|
|
705
708
|
help="Assumes the template is already centered and omits centering.",
|
706
709
|
)
|
707
710
|
performance_group.add_argument(
|
708
|
-
"--
|
709
|
-
dest="
|
711
|
+
"--pad_edges",
|
712
|
+
dest="pad_edges",
|
710
713
|
action="store_true",
|
711
714
|
default=False,
|
712
|
-
help="Whether to
|
713
|
-
"
|
715
|
+
help="Whether to pad the edges of the target. Useful if the target does not "
|
716
|
+
"a well-defined bounding box. Defaults to True if splitting is required.",
|
714
717
|
)
|
715
718
|
performance_group.add_argument(
|
716
|
-
"--
|
717
|
-
dest="
|
719
|
+
"--pad_fourier",
|
720
|
+
dest="pad_fourier",
|
718
721
|
action="store_true",
|
719
722
|
default=False,
|
720
723
|
help="Whether input arrays should not be zero-padded to full convolution shape "
|
721
|
-
"for numerical stability.
|
722
|
-
|
724
|
+
"for numerical stability. Typically only useful when working with small data.",
|
725
|
+
)
|
726
|
+
performance_group.add_argument(
|
727
|
+
"--pad_filter",
|
728
|
+
dest="pad_filter",
|
729
|
+
action="store_true",
|
730
|
+
default=False,
|
731
|
+
help="Pads the filter to the shape of the target. Particularly useful for fast "
|
732
|
+
"oscilating filters to avoid aliasing effects.",
|
723
733
|
)
|
724
734
|
performance_group.add_argument(
|
725
735
|
"--interpolation_order",
|
@@ -727,8 +737,7 @@ def parse_args():
|
|
727
737
|
required=False,
|
728
738
|
type=int,
|
729
739
|
default=3,
|
730
|
-
help="Spline interpolation used for
|
731
|
-
"no interpolation is performed.",
|
740
|
+
help="Spline interpolation used for rotations.",
|
732
741
|
)
|
733
742
|
performance_group.add_argument(
|
734
743
|
"--use_mixed_precision",
|
@@ -755,7 +764,20 @@ def parse_args():
|
|
755
764
|
default=0,
|
756
765
|
help="Minimum template matching scores to consider for analysis.",
|
757
766
|
)
|
758
|
-
|
767
|
+
analyzer_group.add_argument(
|
768
|
+
"-p",
|
769
|
+
dest="peak_calling",
|
770
|
+
action="store_true",
|
771
|
+
default=False,
|
772
|
+
help="Perform peak calling instead of score aggregation.",
|
773
|
+
)
|
774
|
+
analyzer_group.add_argument(
|
775
|
+
"--number_of_peaks",
|
776
|
+
dest="number_of_peaks",
|
777
|
+
action="store_true",
|
778
|
+
default=1000,
|
779
|
+
help="Number of peaks to call, 1000 by default.",
|
780
|
+
)
|
759
781
|
args = parser.parse_args()
|
760
782
|
args.version = __version__
|
761
783
|
|
@@ -770,9 +792,6 @@ def parse_args():
|
|
770
792
|
|
771
793
|
os.environ["TMPDIR"] = args.temp_directory
|
772
794
|
|
773
|
-
args.pad_target_edges = not args.no_edge_padding
|
774
|
-
args.pad_fourier = not args.no_fourier_padding
|
775
|
-
|
776
795
|
if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
|
777
796
|
raise ValueError(
|
778
797
|
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
@@ -827,7 +846,9 @@ def main():
|
|
827
846
|
)
|
828
847
|
|
829
848
|
if target.sampling_rate.size == template.sampling_rate.size:
|
830
|
-
if not np.allclose(
|
849
|
+
if not np.allclose(
|
850
|
+
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
851
|
+
):
|
831
852
|
print(
|
832
853
|
f"Resampling template to {target.sampling_rate}. "
|
833
854
|
"Consider providing a template with the same sampling rate as the target."
|
@@ -842,9 +863,6 @@ def main():
|
|
842
863
|
)
|
843
864
|
|
844
865
|
initial_shape = target.shape
|
845
|
-
is_cropped = crop_data(
|
846
|
-
data=target, data_mask=target_mask, cutoff=args.cutoff_target
|
847
|
-
)
|
848
866
|
print_block(
|
849
867
|
name="Target",
|
850
868
|
data={
|
@@ -853,13 +871,6 @@ def main():
|
|
853
871
|
"Final Shape": target.shape,
|
854
872
|
},
|
855
873
|
)
|
856
|
-
if is_cropped:
|
857
|
-
args.target = generate_tempfile_name(suffix=".mrc")
|
858
|
-
target.to_file(args.target)
|
859
|
-
|
860
|
-
if target_mask:
|
861
|
-
args.target_mask = generate_tempfile_name(suffix=".mrc")
|
862
|
-
target_mask.to_file(args.target_mask)
|
863
874
|
|
864
875
|
if target_mask:
|
865
876
|
print_block(
|
@@ -872,8 +883,6 @@ def main():
|
|
872
883
|
)
|
873
884
|
|
874
885
|
initial_shape = template.shape
|
875
|
-
_ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
|
876
|
-
|
877
886
|
translation = np.zeros(len(template.shape), dtype=np.float32)
|
878
887
|
if not args.no_centering:
|
879
888
|
template, translation = template.centered(0)
|
@@ -921,47 +930,62 @@ def main():
|
|
921
930
|
|
922
931
|
if args.scramble_phases:
|
923
932
|
template.data = scramble_phases(
|
924
|
-
template.data, noise_proportion=1.0, normalize_power=
|
933
|
+
template.data, noise_proportion=1.0, normalize_power=False
|
925
934
|
)
|
926
935
|
|
927
|
-
|
936
|
+
# Determine suitable backend for the selected operation
|
937
|
+
available_backends = be.available_backends()
|
938
|
+
if args.backend is not None:
|
939
|
+
req_backend = args.backend
|
940
|
+
if req_backend not in available_backends:
|
941
|
+
raise ValueError("Requested backend is not available.")
|
942
|
+
available_backends = [req_backend]
|
943
|
+
|
944
|
+
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
928
945
|
if args.use_gpu:
|
929
946
|
args.cores = len(args.gpu_indices)
|
930
|
-
|
931
|
-
|
947
|
+
be_selection = ("pytorch", "cupy", "jax")
|
948
|
+
if args.use_mixed_precision:
|
949
|
+
be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
|
932
950
|
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
951
|
+
available_backends = [x for x in available_backends if x in be_selection]
|
952
|
+
if args.peak_calling:
|
953
|
+
if "jax" in available_backends:
|
954
|
+
available_backends.remove("jax")
|
955
|
+
if args.use_gpu and "pytorch" in available_backends:
|
956
|
+
available_backends = ("pytorch",)
|
957
|
+
|
958
|
+
# dim_match = len(template.shape) == len(target.shape) <= 3
|
959
|
+
# if dim_match and args.use_gpu and "jax" in available_backends:
|
960
|
+
# args.interpolation_order = 1
|
961
|
+
# available_backends = ["jax"]
|
962
|
+
|
963
|
+
backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
|
964
|
+
if args.use_gpu:
|
965
|
+
backend_preference = ("cupy", "pytorch", "jax")
|
966
|
+
for pref in backend_preference:
|
967
|
+
if pref not in available_backends:
|
968
|
+
continue
|
969
|
+
be.change_backend(pref)
|
970
|
+
if pref == "pytorch":
|
971
|
+
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
972
|
+
|
973
|
+
if args.use_mixed_precision:
|
974
|
+
be.change_backend(
|
975
|
+
backend_name=pref,
|
976
|
+
default_dtype=be._array_backend.float16,
|
977
|
+
complex_dtype=be._array_backend.complex64,
|
978
|
+
default_dtype_int=be._array_backend.int16,
|
937
979
|
)
|
980
|
+
break
|
938
981
|
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
else:
|
945
|
-
preferred_backend = "cupy"
|
946
|
-
if not has_cupy:
|
947
|
-
preferred_backend = "pytorch"
|
948
|
-
backend.change_backend(backend_name=preferred_backend, device="cuda")
|
949
|
-
if args.use_mixed_precision and preferred_backend == "pytorch":
|
950
|
-
raise NotImplementedError(
|
951
|
-
"pytorch backend does not yet support mixed precision."
|
952
|
-
" Consider installing CuPy to enable this feature."
|
953
|
-
)
|
954
|
-
elif args.use_mixed_precision:
|
955
|
-
backend.change_backend(
|
956
|
-
backend_name="cupy",
|
957
|
-
default_dtype=backend._array_backend.float16,
|
958
|
-
complex_dtype=backend._array_backend.complex64,
|
959
|
-
default_dtype_int=backend._array_backend.int16,
|
960
|
-
)
|
961
|
-
available_memory = backend.get_available_memory() * args.cores
|
962
|
-
if preferred_backend == "pytorch" and args.interpolation_order == 3:
|
963
|
-
args.interpolation_order = 1
|
982
|
+
if pref == "pytorch" and args.interpolation_order == 3:
|
983
|
+
warnings.warn(
|
984
|
+
"Pytorch does not support --interpolation_order 3, setting it to 1."
|
985
|
+
)
|
986
|
+
args.interpolation_order = 1
|
964
987
|
|
988
|
+
available_memory = be.get_available_memory() * be.device_count()
|
965
989
|
if args.memory is None:
|
966
990
|
args.memory = int(args.memory_scaling * available_memory)
|
967
991
|
|
@@ -978,70 +1002,49 @@ def main():
|
|
978
1002
|
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
979
1003
|
)
|
980
1004
|
|
981
|
-
|
982
|
-
matching_data.template_filter =
|
983
|
-
|
984
|
-
|
985
|
-
template_box = matching_data._output_template_shape
|
986
|
-
if not args.pad_fourier:
|
987
|
-
template_box = np.ones(len(template_box), dtype=int)
|
988
|
-
|
989
|
-
target_padding = np.zeros(
|
990
|
-
(backend.size(matching_data._output_template_shape)), dtype=int
|
991
|
-
)
|
992
|
-
if args.pad_target_edges:
|
993
|
-
target_padding = matching_data._output_template_shape
|
994
|
-
|
995
|
-
splits, schedule = compute_parallelization_schedule(
|
996
|
-
shape1=target.shape,
|
997
|
-
shape2=tuple(int(x) for x in template_box),
|
998
|
-
shape1_padding=tuple(int(x) for x in target_padding),
|
999
|
-
max_cores=args.cores,
|
1000
|
-
max_ram=args.memory,
|
1001
|
-
split_only_outer=args.use_gpu,
|
1002
|
-
matching_method=args.score,
|
1003
|
-
analyzer_method=callback_class.__name__,
|
1004
|
-
backend=backend._backend_name,
|
1005
|
-
float_nbytes=backend.datatype_bytes(backend._float_dtype),
|
1006
|
-
complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
|
1007
|
-
integer_nbytes=backend.datatype_bytes(backend._int_dtype),
|
1005
|
+
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
1006
|
+
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1007
|
+
args, template, target
|
1008
1008
|
)
|
1009
1009
|
|
1010
|
-
|
1011
|
-
print(
|
1012
|
-
"Found no suitable parallelization schedule. Consider increasing"
|
1013
|
-
" available RAM or decreasing number of cores."
|
1014
|
-
)
|
1015
|
-
exit(-1)
|
1010
|
+
splits, schedule = compute_schedule(args, target, matching_data, callback_class)
|
1016
1011
|
|
1017
|
-
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
1018
1012
|
n_splits = np.prod(list(splits.values()))
|
1019
1013
|
target_split = ", ".join(
|
1020
1014
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
1021
1015
|
)
|
1022
1016
|
gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
|
1023
1017
|
options = {
|
1024
|
-
"CPU Cores": args.cores,
|
1025
|
-
"Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
|
1026
|
-
"Use Mixed Precision": args.use_mixed_precision,
|
1027
|
-
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1028
|
-
"Temporary Directory": args.temp_directory,
|
1029
|
-
"Extend Fourier Grid": not args.no_fourier_padding,
|
1030
|
-
"Extend Target Edges": not args.no_edge_padding,
|
1031
|
-
"Interpolation Order": args.interpolation_order,
|
1032
|
-
"Score": f"{args.score}",
|
1033
|
-
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
1034
|
-
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
1035
1018
|
"Angular Sampling": f"{args.angular_sampling}"
|
1036
1019
|
f" [{matching_data.rotations.shape[0]} rotations]",
|
1020
|
+
"Center Template": not args.no_centering,
|
1037
1021
|
"Scramble Template": args.scramble_phases,
|
1038
|
-
"
|
1022
|
+
"Invert Contrast": args.invert_target_contrast,
|
1023
|
+
"Extend Fourier Grid": args.pad_fourier,
|
1024
|
+
"Extend Target Edges": args.pad_edges,
|
1025
|
+
"Interpolation Order": args.interpolation_order,
|
1026
|
+
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
1027
|
+
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
1039
1028
|
}
|
1040
1029
|
|
1041
1030
|
print_block(
|
1042
|
-
name="Template Matching
|
1031
|
+
name="Template Matching",
|
1043
1032
|
data=options,
|
1044
|
-
label_width=max(len(key) for key in options.keys()) +
|
1033
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1034
|
+
)
|
1035
|
+
|
1036
|
+
compute_options = {
|
1037
|
+
"Backend": be._BACKEND_REGISTRY[be._backend_name],
|
1038
|
+
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1039
|
+
"Use Mixed Precision": args.use_mixed_precision,
|
1040
|
+
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1041
|
+
"Temporary Directory": args.temp_directory,
|
1042
|
+
"Target Splits": f"{target_split} [N={n_splits}]",
|
1043
|
+
}
|
1044
|
+
print_block(
|
1045
|
+
name="Computation",
|
1046
|
+
data=compute_options,
|
1047
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1045
1048
|
)
|
1046
1049
|
|
1047
1050
|
filter_args = {
|
@@ -1054,12 +1057,13 @@ def main():
|
|
1054
1057
|
"Tilt Angles": args.tilt_angles,
|
1055
1058
|
"Tilt Weighting": args.tilt_weighting,
|
1056
1059
|
"Reconstruction Filter": args.reconstruction_filter,
|
1060
|
+
"Extend Filter Grid": args.pad_filter,
|
1057
1061
|
}
|
1058
1062
|
if args.ctf_file is not None or args.defocus is not None:
|
1059
1063
|
filter_args["CTF File"] = args.ctf_file
|
1060
1064
|
filter_args["Defocus"] = args.defocus
|
1061
1065
|
filter_args["Phase Shift"] = args.phase_shift
|
1062
|
-
filter_args["
|
1066
|
+
filter_args["Flip Phase"] = args.no_flip_phase
|
1063
1067
|
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1064
1068
|
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1065
1069
|
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
@@ -1070,20 +1074,19 @@ def main():
|
|
1070
1074
|
print_block(
|
1071
1075
|
name="Filters",
|
1072
1076
|
data=filter_args,
|
1073
|
-
label_width=max(len(key) for key in options.keys()) +
|
1077
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1074
1078
|
)
|
1075
1079
|
|
1076
1080
|
analyzer_args = {
|
1077
1081
|
"score_threshold": args.score_threshold,
|
1078
|
-
"number_of_peaks":
|
1079
|
-
"
|
1082
|
+
"number_of_peaks": args.number_of_peaks,
|
1083
|
+
"min_distance": max(template.shape) // 3,
|
1080
1084
|
"use_memmap": args.use_memmap,
|
1081
1085
|
}
|
1082
|
-
analyzer_args = {"Analyzer": callback_class, **analyzer_args}
|
1083
1086
|
print_block(
|
1084
|
-
name="
|
1085
|
-
data=analyzer_args,
|
1086
|
-
label_width=max(len(key) for key in options.keys()) +
|
1087
|
+
name="Analyzer",
|
1088
|
+
data={"Analyzer": callback_class, **analyzer_args},
|
1089
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1087
1090
|
)
|
1088
1091
|
print("\n" + "-" * 80)
|
1089
1092
|
|
@@ -1102,8 +1105,9 @@ def main():
|
|
1102
1105
|
callback_class=callback_class,
|
1103
1106
|
callback_class_args=analyzer_args,
|
1104
1107
|
target_splits=splits,
|
1105
|
-
pad_target_edges=args.
|
1108
|
+
pad_target_edges=args.pad_edges,
|
1106
1109
|
pad_fourier=args.pad_fourier,
|
1110
|
+
pad_template_filter=args.pad_filter,
|
1107
1111
|
interpolation_order=args.interpolation_order,
|
1108
1112
|
)
|
1109
1113
|
|
@@ -1113,19 +1117,18 @@ def main():
|
|
1113
1117
|
candidates[0] *= target_mask.data
|
1114
1118
|
with warnings.catch_warnings():
|
1115
1119
|
warnings.simplefilter("ignore", category=UserWarning)
|
1116
|
-
nbytes =
|
1120
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
1117
1121
|
dtype = np.float32 if nbytes == 4 else np.float16
|
1118
1122
|
rot_dim = matching_data.rotations.shape[1]
|
1119
1123
|
candidates[3] = {
|
1120
|
-
x:
|
1121
|
-
np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1122
|
-
)
|
1124
|
+
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1123
1125
|
for i, x in candidates[3].items()
|
1124
1126
|
}
|
1125
1127
|
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1126
1128
|
write_pickle(data=candidates, filename=args.output)
|
1127
1129
|
|
1128
1130
|
runtime = time() - start
|
1131
|
+
print("\n" + "-" * 80)
|
1129
1132
|
print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
|
1130
1133
|
|
1131
1134
|
|