pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
- pytme-0.3b0.post1.dist-info/RECORD +126 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
- pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +341 -378
- pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +318 -189
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +12 -12
- scripts/pytme_runner.py +769 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -54
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +395 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -204
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +54 -9
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.post1.dist-info/RECORD +0 -119
- pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
|
-
#!
|
2
|
-
""" CLI for basic pyTME template matching functions.
|
1
|
+
#!python3
|
2
|
+
""" CLI interface for basic pyTME template matching functions.
|
3
3
|
|
4
4
|
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
5
|
|
@@ -12,34 +12,28 @@ from sys import exit
|
|
12
12
|
from time import time
|
13
13
|
from typing import Tuple
|
14
14
|
from copy import deepcopy
|
15
|
-
from os.path import exists
|
16
|
-
from tempfile import gettempdir
|
15
|
+
from os.path import abspath, exists
|
17
16
|
|
18
17
|
import numpy as np
|
19
18
|
|
20
|
-
from tme.backends import backend as be
|
21
19
|
from tme import Density, __version__
|
22
|
-
from tme.matching_utils import
|
23
|
-
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
24
|
-
from tme.rotations import (
|
25
|
-
get_cone_rotations,
|
20
|
+
from tme.matching_utils import (
|
26
21
|
get_rotation_matrices,
|
22
|
+
get_rotations_around_vector,
|
23
|
+
compute_parallelization_schedule,
|
24
|
+
scramble_phases,
|
25
|
+
generate_tempfile_name,
|
26
|
+
write_pickle,
|
27
27
|
)
|
28
|
+
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
28
29
|
from tme.matching_data import MatchingData
|
29
30
|
from tme.analyzer import (
|
30
31
|
MaxScoreOverRotations,
|
31
32
|
PeakCallerMaximumFilter,
|
32
33
|
)
|
33
|
-
from tme.
|
34
|
-
|
35
|
-
|
36
|
-
Compose,
|
37
|
-
BandPassFilter,
|
38
|
-
WedgeReconstructed,
|
39
|
-
ReconstructFromTilt,
|
40
|
-
LinearWhiteningFilter,
|
41
|
-
)
|
42
|
-
|
34
|
+
from tme.backends import backend as be
|
35
|
+
from tme.preprocessing import Compose
|
36
|
+
from tme.scoring import flc_scoring2
|
43
37
|
|
44
38
|
def get_func_fullname(func) -> str:
|
45
39
|
"""Returns the full name of the given function, including its module."""
|
@@ -50,8 +44,6 @@ def print_block(name: str, data: dict, label_width=20) -> None:
|
|
50
44
|
"""Prints a formatted block of information."""
|
51
45
|
print(f"\n> {name}")
|
52
46
|
for key, value in data.items():
|
53
|
-
if isinstance(value, np.ndarray):
|
54
|
-
value = value.shape
|
55
47
|
formatted_value = str(value)
|
56
48
|
print(f" - {key + ':':<{label_width}} {formatted_value}")
|
57
49
|
|
@@ -107,9 +99,7 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
107
99
|
f"Expected shape of {mask_path} was {mask_target.shape},"
|
108
100
|
f" got f{mask.shape}"
|
109
101
|
)
|
110
|
-
if not np.allclose(
|
111
|
-
np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
|
112
|
-
):
|
102
|
+
if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
|
113
103
|
raise ValueError(
|
114
104
|
f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
|
115
105
|
f", got f{mask.sampling_rate}"
|
@@ -117,6 +107,50 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
117
107
|
return mask
|
118
108
|
|
119
109
|
|
110
|
+
def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
|
111
|
+
"""
|
112
|
+
Crop the provided data and mask to a smaller box based on a cutoff value.
|
113
|
+
|
114
|
+
Parameters
|
115
|
+
----------
|
116
|
+
data : Density
|
117
|
+
The data that should be cropped.
|
118
|
+
cutoff : float
|
119
|
+
The threshold value to determine which parts of the data should be kept.
|
120
|
+
data_mask : Density, optional
|
121
|
+
A mask for the data that should be cropped.
|
122
|
+
|
123
|
+
Returns
|
124
|
+
-------
|
125
|
+
bool
|
126
|
+
Returns True if the data was adjusted (cropped), otherwise returns False.
|
127
|
+
|
128
|
+
Notes
|
129
|
+
-----
|
130
|
+
Cropping is performed in place.
|
131
|
+
"""
|
132
|
+
if cutoff is None:
|
133
|
+
return False
|
134
|
+
|
135
|
+
box = data.trim_box(cutoff=cutoff)
|
136
|
+
box_mask = box
|
137
|
+
if data_mask is not None:
|
138
|
+
box_mask = data_mask.trim_box(cutoff=cutoff)
|
139
|
+
box = tuple(
|
140
|
+
slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
|
141
|
+
for arr, mask in zip(box, box_mask)
|
142
|
+
)
|
143
|
+
if box == tuple(slice(0, x) for x in data.shape):
|
144
|
+
return False
|
145
|
+
|
146
|
+
data.adjust_box(box)
|
147
|
+
|
148
|
+
if data_mask:
|
149
|
+
data_mask.adjust_box(box)
|
150
|
+
|
151
|
+
return True
|
152
|
+
|
153
|
+
|
120
154
|
def parse_rotation_logic(args, ndim):
|
121
155
|
if args.angular_sampling is not None:
|
122
156
|
rotations = get_rotation_matrices(
|
@@ -131,66 +165,34 @@ def parse_rotation_logic(args, ndim):
|
|
131
165
|
if args.axis_sampling is None:
|
132
166
|
args.axis_sampling = args.cone_sampling
|
133
167
|
|
134
|
-
rotations =
|
168
|
+
rotations = get_rotations_around_vector(
|
135
169
|
cone_angle=args.cone_angle,
|
136
170
|
cone_sampling=args.cone_sampling,
|
137
171
|
axis_angle=args.axis_angle,
|
138
172
|
axis_sampling=args.axis_sampling,
|
139
173
|
n_symmetry=args.axis_symmetry,
|
140
|
-
axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
|
141
|
-
reference=[0, 0, -1],
|
142
174
|
)
|
143
175
|
return rotations
|
144
176
|
|
145
177
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
pad_edges = True
|
155
|
-
|
156
|
-
splits, schedule = matching_data.computation_schedule(
|
157
|
-
matching_method=args.score,
|
158
|
-
analyzer_method=callback_class.__name__,
|
159
|
-
use_gpu=args.use_gpu,
|
160
|
-
pad_fourier=False,
|
161
|
-
pad_target_edges=pad_edges,
|
162
|
-
available_memory=args.memory,
|
163
|
-
max_cores=args.cores,
|
178
|
+
# TODO: Think about whether wedge mask should also be added to target
|
179
|
+
# For now leave it at the cost of incorrect upper bound on the scores
|
180
|
+
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
|
+
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
182
|
+
from tme.preprocessing.tilt_series import (
|
183
|
+
Wedge,
|
184
|
+
WedgeReconstructed,
|
185
|
+
ReconstructFromTilt,
|
164
186
|
)
|
165
187
|
|
166
|
-
if splits is None:
|
167
|
-
print(
|
168
|
-
"Found no suitable parallelization schedule. Consider increasing"
|
169
|
-
" available RAM or decreasing number of cores."
|
170
|
-
)
|
171
|
-
exit(-1)
|
172
|
-
|
173
|
-
n_splits = np.prod(list(splits.values()))
|
174
|
-
if pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
|
175
|
-
args.pad_edges = True
|
176
|
-
return compute_schedule(args, matching_data, callback_class, True)
|
177
|
-
return splits, schedule
|
178
|
-
|
179
|
-
|
180
|
-
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
|
-
needs_reconstruction = False
|
182
188
|
template_filter, target_filter = [], []
|
183
189
|
if args.tilt_angles is not None:
|
184
|
-
needs_reconstruction = args.tilt_weighting is not None
|
185
190
|
try:
|
186
191
|
wedge = Wedge.from_file(args.tilt_angles)
|
187
192
|
wedge.weight_type = args.tilt_weighting
|
188
193
|
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
189
194
|
wedge = WedgeReconstructed(
|
190
|
-
angles=wedge.angles,
|
191
|
-
weight_wedge=args.tilt_weighting == "angle",
|
192
|
-
opening_axis=args.wedge_axes[0],
|
193
|
-
tilt_axis=args.wedge_axes[1],
|
195
|
+
angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
|
194
196
|
)
|
195
197
|
except FileNotFoundError:
|
196
198
|
tilt_step, create_continuous_wedge = None, True
|
@@ -229,29 +231,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
229
231
|
angles=tilt_angles,
|
230
232
|
weight_wedge=args.tilt_weighting == "angle",
|
231
233
|
create_continuous_wedge=create_continuous_wedge,
|
232
|
-
reconstruction_filter=args.reconstruction_filter,
|
233
|
-
opening_axis=args.wedge_axes[0],
|
234
|
-
tilt_axis=args.wedge_axes[1],
|
235
|
-
)
|
236
|
-
wedge_target = WedgeReconstructed(
|
237
|
-
angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
|
238
|
-
weight_wedge=False,
|
239
|
-
create_continuous_wedge=True,
|
240
|
-
opening_axis=args.wedge_axes[0],
|
241
|
-
tilt_axis=args.wedge_axes[1],
|
242
234
|
)
|
243
|
-
target_filter.append(wedge_target)
|
244
235
|
|
236
|
+
wedge.opening_axis = args.wedge_axes[0]
|
237
|
+
wedge.tilt_axis = args.wedge_axes[1]
|
245
238
|
wedge.sampling_rate = template.sampling_rate
|
246
239
|
template_filter.append(wedge)
|
247
240
|
if not isinstance(wedge, WedgeReconstructed):
|
248
|
-
|
249
|
-
|
250
|
-
|
241
|
+
template_filter.append(
|
242
|
+
ReconstructFromTilt(
|
243
|
+
reconstruction_filter=args.reconstruction_filter,
|
244
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
245
|
+
)
|
251
246
|
)
|
252
|
-
template_filter.append(reconstruction_filter)
|
253
247
|
|
254
248
|
if args.ctf_file is not None or args.defocus is not None:
|
249
|
+
from tme.preprocessing.tilt_series import CTF
|
250
|
+
|
255
251
|
needs_reconstruction = True
|
256
252
|
if args.ctf_file is not None:
|
257
253
|
ctf = CTF.from_file(args.ctf_file)
|
@@ -263,7 +259,6 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
263
259
|
"per micrograph."
|
264
260
|
)
|
265
261
|
ctf.angles = wedge.angles
|
266
|
-
ctf.no_reconstruction = False
|
267
262
|
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
268
263
|
else:
|
269
264
|
needs_reconstruction = False
|
@@ -322,13 +317,18 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
322
317
|
sampling_rate=template.sampling_rate,
|
323
318
|
)
|
324
319
|
template_filter.append(bandpass)
|
325
|
-
|
320
|
+
|
321
|
+
if not args.no_filter_target:
|
322
|
+
target_filter.append(bandpass)
|
326
323
|
|
327
324
|
if args.whiten_spectrum:
|
328
325
|
whitening_filter = LinearWhiteningFilter()
|
329
326
|
template_filter.append(whitening_filter)
|
330
327
|
target_filter.append(whitening_filter)
|
331
328
|
|
329
|
+
needs_reconstruction = any(
|
330
|
+
[isinstance(t, ReconstructFromTilt) for t in template_filter]
|
331
|
+
)
|
332
332
|
if needs_reconstruction and args.reconstruction_filter is None:
|
333
333
|
warnings.warn(
|
334
334
|
"Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
|
@@ -336,8 +336,6 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
336
336
|
|
337
337
|
template_filter = Compose(template_filter) if len(template_filter) else None
|
338
338
|
target_filter = Compose(target_filter) if len(target_filter) else None
|
339
|
-
if args.no_filter_target:
|
340
|
-
target_filter = None
|
341
339
|
|
342
340
|
return template_filter, target_filter
|
343
341
|
|
@@ -345,7 +343,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
345
343
|
def parse_args():
|
346
344
|
parser = argparse.ArgumentParser(
|
347
345
|
description="Perform template matching.",
|
348
|
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
346
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
349
347
|
)
|
350
348
|
|
351
349
|
io_group = parser.add_argument_group("Input / Output")
|
@@ -395,7 +393,8 @@ def parse_args():
|
|
395
393
|
dest="invert_target_contrast",
|
396
394
|
action="store_true",
|
397
395
|
default=False,
|
398
|
-
help="Invert the target's contrast
|
396
|
+
help="Invert the target's contrast and rescale linearly between zero and one. "
|
397
|
+
"This option is intended for targets where templates to-be-matched have "
|
399
398
|
"negative values, e.g. tomograms.",
|
400
399
|
)
|
401
400
|
io_group.add_argument(
|
@@ -435,19 +434,6 @@ def parse_args():
|
|
435
434
|
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
436
435
|
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
437
436
|
)
|
438
|
-
angular_group.add_argument(
|
439
|
-
"--cone_axis",
|
440
|
-
dest="cone_axis",
|
441
|
-
type=check_positive,
|
442
|
-
default=2,
|
443
|
-
help="Principal axis to build cone around.",
|
444
|
-
)
|
445
|
-
angular_group.add_argument(
|
446
|
-
"--invert_cone",
|
447
|
-
dest="invert_cone",
|
448
|
-
action="store_true",
|
449
|
-
help="Invert cone handedness.",
|
450
|
-
)
|
451
437
|
angular_group.add_argument(
|
452
438
|
"--cone_sampling",
|
453
439
|
dest="cone_sampling",
|
@@ -529,7 +515,7 @@ def parse_args():
|
|
529
515
|
required=False,
|
530
516
|
type=float,
|
531
517
|
default=0.85,
|
532
|
-
help="Fraction of available memory to be used. Ignored if --ram is set."
|
518
|
+
help="Fraction of available memory to be used. Ignored if --ram is set."
|
533
519
|
)
|
534
520
|
computation_group.add_argument(
|
535
521
|
"--temp_directory",
|
@@ -544,6 +530,7 @@ def parse_args():
|
|
544
530
|
choices=be.available_backends(),
|
545
531
|
help="[Expert] Overwrite default computation backend.",
|
546
532
|
)
|
533
|
+
|
547
534
|
filter_group = parser.add_argument_group("Filters")
|
548
535
|
filter_group.add_argument(
|
549
536
|
"--lowpass",
|
@@ -590,8 +577,8 @@ def parse_args():
|
|
590
577
|
type=str,
|
591
578
|
required=False,
|
592
579
|
default=None,
|
593
|
-
help="Indices of wedge opening and tilt axis, e.g.
|
594
|
-
"in z and tilted over the x
|
580
|
+
help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
|
581
|
+
"in z-direction and tilted over the x axis.",
|
595
582
|
)
|
596
583
|
filter_group.add_argument(
|
597
584
|
"--tilt_angles",
|
@@ -698,7 +685,7 @@ def parse_args():
|
|
698
685
|
dest="no_flip_phase",
|
699
686
|
action="store_false",
|
700
687
|
required=False,
|
701
|
-
help="
|
688
|
+
help="Perform phase-flipping CTF correction.",
|
702
689
|
)
|
703
690
|
ctf_group.add_argument(
|
704
691
|
"--correct_defocus_gradient",
|
@@ -710,6 +697,22 @@ def parse_args():
|
|
710
697
|
)
|
711
698
|
|
712
699
|
performance_group = parser.add_argument_group("Performance")
|
700
|
+
performance_group.add_argument(
|
701
|
+
"--cutoff_target",
|
702
|
+
dest="cutoff_target",
|
703
|
+
type=float,
|
704
|
+
required=False,
|
705
|
+
default=None,
|
706
|
+
help="Target contour level (used for cropping).",
|
707
|
+
)
|
708
|
+
performance_group.add_argument(
|
709
|
+
"--cutoff_template",
|
710
|
+
dest="cutoff_template",
|
711
|
+
type=float,
|
712
|
+
required=False,
|
713
|
+
default=None,
|
714
|
+
help="Template contour level (used for cropping).",
|
715
|
+
)
|
713
716
|
performance_group.add_argument(
|
714
717
|
"--no_centering",
|
715
718
|
dest="no_centering",
|
@@ -717,20 +720,30 @@ def parse_args():
|
|
717
720
|
help="Assumes the template is already centered and omits centering.",
|
718
721
|
)
|
719
722
|
performance_group.add_argument(
|
720
|
-
"--
|
721
|
-
dest="
|
723
|
+
"--no_edge_padding",
|
724
|
+
dest="no_edge_padding",
|
725
|
+
action="store_true",
|
726
|
+
default=False,
|
727
|
+
help="Whether to not pad the edges of the target. Can be set if the target"
|
728
|
+
" has a well defined bounding box, e.g. a masked reconstruction.",
|
729
|
+
)
|
730
|
+
performance_group.add_argument(
|
731
|
+
"--no_fourier_padding",
|
732
|
+
dest="no_fourier_padding",
|
722
733
|
action="store_true",
|
723
734
|
default=False,
|
724
|
-
help="Whether
|
725
|
-
"
|
735
|
+
help="Whether input arrays should not be zero-padded to full convolution shape "
|
736
|
+
"for numerical stability. When working with very large targets, e.g. tomograms, "
|
737
|
+
"it is safe to use this flag and benefit from the performance gain.",
|
726
738
|
)
|
727
739
|
performance_group.add_argument(
|
728
|
-
"--
|
729
|
-
dest="
|
740
|
+
"--no_filter_padding",
|
741
|
+
dest="no_filter_padding",
|
730
742
|
action="store_true",
|
731
743
|
default=False,
|
732
|
-
help="
|
733
|
-
"
|
744
|
+
help="Omits padding of optional template filters. Particularly effective when "
|
745
|
+
"the target is much larger than the template. However, for fast osciliating "
|
746
|
+
"filters setting this flag can introduce aliasing effects.",
|
734
747
|
)
|
735
748
|
performance_group.add_argument(
|
736
749
|
"--interpolation_order",
|
@@ -777,7 +790,7 @@ def parse_args():
|
|
777
790
|
dest="number_of_peaks",
|
778
791
|
action="store_true",
|
779
792
|
default=1000,
|
780
|
-
help="Number of peaks to call, 1000 by default
|
793
|
+
help="Number of peaks to call, 1000 by default..",
|
781
794
|
)
|
782
795
|
args = parser.parse_args()
|
783
796
|
args.version = __version__
|
@@ -786,9 +799,16 @@ def parse_args():
|
|
786
799
|
args.interpolation_order = None
|
787
800
|
|
788
801
|
if args.temp_directory is None:
|
789
|
-
|
802
|
+
default = abspath(".")
|
803
|
+
if os.environ.get("TMPDIR", None) is not None:
|
804
|
+
default = os.environ.get("TMPDIR")
|
805
|
+
args.temp_directory = default
|
790
806
|
|
791
807
|
os.environ["TMPDIR"] = args.temp_directory
|
808
|
+
|
809
|
+
args.pad_target_edges = not args.no_edge_padding
|
810
|
+
args.pad_fourier = not args.no_fourier_padding
|
811
|
+
|
792
812
|
if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
|
793
813
|
raise ValueError(
|
794
814
|
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
@@ -837,15 +857,15 @@ def main():
|
|
837
857
|
try:
|
838
858
|
template = Density.from_file(args.template)
|
839
859
|
except Exception:
|
860
|
+
drop = target.metadata.get("batch_dimension", ())
|
861
|
+
keep = [i not in drop for i in range(target.data.ndim)]
|
840
862
|
template = Density.from_structure(
|
841
863
|
filename_or_structure=args.template,
|
842
|
-
sampling_rate=target.sampling_rate,
|
864
|
+
sampling_rate=target.sampling_rate[keep],
|
843
865
|
)
|
844
866
|
|
845
867
|
if target.sampling_rate.size == template.sampling_rate.size:
|
846
|
-
if not np.allclose(
|
847
|
-
np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
|
848
|
-
):
|
868
|
+
if not np.allclose(target.sampling_rate, template.sampling_rate):
|
849
869
|
print(
|
850
870
|
f"Resampling template to {target.sampling_rate}. "
|
851
871
|
"Consider providing a template with the same sampling rate as the target."
|
@@ -860,6 +880,9 @@ def main():
|
|
860
880
|
)
|
861
881
|
|
862
882
|
initial_shape = target.shape
|
883
|
+
is_cropped = crop_data(
|
884
|
+
data=target, data_mask=target_mask, cutoff=args.cutoff_target
|
885
|
+
)
|
863
886
|
print_block(
|
864
887
|
name="Target",
|
865
888
|
data={
|
@@ -868,6 +891,13 @@ def main():
|
|
868
891
|
"Final Shape": target.shape,
|
869
892
|
},
|
870
893
|
)
|
894
|
+
if is_cropped:
|
895
|
+
args.target = generate_tempfile_name(suffix=".mrc")
|
896
|
+
target.to_file(args.target)
|
897
|
+
|
898
|
+
if target_mask:
|
899
|
+
args.target_mask = generate_tempfile_name(suffix=".mrc")
|
900
|
+
target_mask.to_file(args.target_mask)
|
871
901
|
|
872
902
|
if target_mask:
|
873
903
|
print_block(
|
@@ -880,6 +910,8 @@ def main():
|
|
880
910
|
)
|
881
911
|
|
882
912
|
initial_shape = template.shape
|
913
|
+
_ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
|
914
|
+
|
883
915
|
translation = np.zeros(len(template.shape), dtype=np.float32)
|
884
916
|
if not args.no_centering:
|
885
917
|
template, translation = template.centered(0)
|
@@ -927,7 +959,7 @@ def main():
|
|
927
959
|
|
928
960
|
if args.scramble_phases:
|
929
961
|
template.data = scramble_phases(
|
930
|
-
template.data, noise_proportion=1.0, normalize_power=
|
962
|
+
template.data, noise_proportion=1.0, normalize_power=True
|
931
963
|
)
|
932
964
|
|
933
965
|
# Determine suitable backend for the selected operation
|
@@ -935,8 +967,10 @@ def main():
|
|
935
967
|
if args.backend is not None:
|
936
968
|
req_backend = args.backend
|
937
969
|
if req_backend not in available_backends:
|
938
|
-
raise ValueError(
|
939
|
-
|
970
|
+
raise ValueError(
|
971
|
+
"Requested backend is not available."
|
972
|
+
)
|
973
|
+
available_backends = [req_backend,]
|
940
974
|
|
941
975
|
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
942
976
|
if args.use_gpu:
|
@@ -951,21 +985,23 @@ def main():
|
|
951
985
|
available_backends.remove("jax")
|
952
986
|
if args.use_gpu and "pytorch" in available_backends:
|
953
987
|
available_backends = ("pytorch",)
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
988
|
+
if args.interpolation_order == 3:
|
989
|
+
raise NotImplementedError(
|
990
|
+
"Pytorch does not support --interpolation_order 3, 1 is supported."
|
991
|
+
)
|
992
|
+
ndim = len(template.shape)
|
993
|
+
if len(target.shape) == ndim and ndim <= 3 and args.use_gpu:
|
994
|
+
available_backends = ["jax", ]
|
959
995
|
|
960
996
|
backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
|
961
997
|
if args.use_gpu:
|
962
|
-
backend_preference = ("cupy", "
|
998
|
+
backend_preference = ("cupy", "jax", "pytorch")
|
963
999
|
for pref in backend_preference:
|
964
1000
|
if pref not in available_backends:
|
965
1001
|
continue
|
966
1002
|
be.change_backend(pref)
|
967
1003
|
if pref == "pytorch":
|
968
|
-
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
1004
|
+
be.change_backend(pref, device = "cuda" if args.use_gpu else "cpu")
|
969
1005
|
|
970
1006
|
if args.use_mixed_precision:
|
971
1007
|
be.change_backend(
|
@@ -976,12 +1012,6 @@ def main():
|
|
976
1012
|
)
|
977
1013
|
break
|
978
1014
|
|
979
|
-
if pref == "pytorch" and args.interpolation_order == 3:
|
980
|
-
warnings.warn(
|
981
|
-
"Pytorch does not support --interpolation_order 3, setting it to 1."
|
982
|
-
)
|
983
|
-
args.interpolation_order = 1
|
984
|
-
|
985
1015
|
available_memory = be.get_available_memory() * be.device_count()
|
986
1016
|
if args.memory is None:
|
987
1017
|
args.memory = int(args.memory_scaling * available_memory)
|
@@ -999,17 +1029,49 @@ def main():
|
|
999
1029
|
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
1000
1030
|
)
|
1001
1031
|
|
1002
|
-
|
1003
|
-
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1032
|
+
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1004
1033
|
args, template, target
|
1005
1034
|
)
|
1006
1035
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1036
|
+
target_dims = target.metadata.get("batch_dimension", None)
|
1037
|
+
matching_data._set_matching_dimension(target_dims=target_dims, template_dims=None)
|
1038
|
+
args.score = "FLC" if target_dims is not None else args.score
|
1039
|
+
args.target_batch, args.template_batch = target_dims, None
|
1040
|
+
|
1041
|
+
template_box = matching_data._output_template_shape
|
1042
|
+
if not args.pad_fourier:
|
1043
|
+
template_box = tuple(0 for _ in range(len(template_box)))
|
1044
|
+
|
1045
|
+
target_padding = tuple(0 for _ in range(len(template_box)))
|
1046
|
+
if args.pad_target_edges:
|
1047
|
+
target_padding = matching_data._output_template_shape
|
1048
|
+
|
1049
|
+
splits, schedule = compute_parallelization_schedule(
|
1050
|
+
shape1=target.shape,
|
1051
|
+
shape2=tuple(int(x) for x in template_box),
|
1052
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
1053
|
+
max_cores=args.cores,
|
1054
|
+
max_ram=args.memory,
|
1055
|
+
split_only_outer=args.use_gpu,
|
1056
|
+
matching_method=args.score,
|
1057
|
+
analyzer_method=callback_class.__name__,
|
1058
|
+
backend=be._backend_name,
|
1059
|
+
float_nbytes=be.datatype_bytes(be._float_dtype),
|
1060
|
+
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
1061
|
+
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
1062
|
+
split_axes=target_dims,
|
1010
1063
|
)
|
1011
|
-
splits, schedule = compute_schedule(args, matching_data, callback_class)
|
1012
1064
|
|
1065
|
+
if splits is None:
|
1066
|
+
print(
|
1067
|
+
"Found no suitable parallelization schedule. Consider increasing"
|
1068
|
+
" available RAM or decreasing number of cores."
|
1069
|
+
)
|
1070
|
+
exit(-1)
|
1071
|
+
|
1072
|
+
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
1073
|
+
if target_dims is not None:
|
1074
|
+
matching_score = flc_scoring2
|
1013
1075
|
n_splits = np.prod(list(splits.values()))
|
1014
1076
|
target_split = ", ".join(
|
1015
1077
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
@@ -1021,7 +1083,8 @@ def main():
|
|
1021
1083
|
"Center Template": not args.no_centering,
|
1022
1084
|
"Scramble Template": args.scramble_phases,
|
1023
1085
|
"Invert Contrast": args.invert_target_contrast,
|
1024
|
-
"Extend
|
1086
|
+
"Extend Fourier Grid": not args.no_fourier_padding,
|
1087
|
+
"Extend Target Edges": not args.no_edge_padding,
|
1025
1088
|
"Interpolation Order": args.interpolation_order,
|
1026
1089
|
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
1027
1090
|
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
@@ -1034,8 +1097,8 @@ def main():
|
|
1034
1097
|
)
|
1035
1098
|
|
1036
1099
|
compute_options = {
|
1037
|
-
"Backend":
|
1038
|
-
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1100
|
+
"Backend" :be._BACKEND_REGISTRY[be._backend_name],
|
1101
|
+
"Compute Devices" : f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1039
1102
|
"Use Mixed Precision": args.use_mixed_precision,
|
1040
1103
|
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1041
1104
|
"Temporary Directory": args.temp_directory,
|
@@ -1057,7 +1120,6 @@ def main():
|
|
1057
1120
|
"Tilt Angles": args.tilt_angles,
|
1058
1121
|
"Tilt Weighting": args.tilt_weighting,
|
1059
1122
|
"Reconstruction Filter": args.reconstruction_filter,
|
1060
|
-
"Extend Filter Grid": args.pad_filter,
|
1061
1123
|
}
|
1062
1124
|
if args.ctf_file is not None or args.defocus is not None:
|
1063
1125
|
filter_args["CTF File"] = args.ctf_file
|
@@ -1080,7 +1142,8 @@ def main():
|
|
1080
1142
|
analyzer_args = {
|
1081
1143
|
"score_threshold": args.score_threshold,
|
1082
1144
|
"number_of_peaks": args.number_of_peaks,
|
1083
|
-
"min_distance": max(template.shape) //
|
1145
|
+
"min_distance" : max(template.shape) // 2,
|
1146
|
+
"min_boundary_distance" : max(template.shape) // 2,
|
1084
1147
|
"use_memmap": args.use_memmap,
|
1085
1148
|
}
|
1086
1149
|
print_block(
|
@@ -1105,13 +1168,14 @@ def main():
|
|
1105
1168
|
callback_class=callback_class,
|
1106
1169
|
callback_class_args=analyzer_args,
|
1107
1170
|
target_splits=splits,
|
1108
|
-
pad_target_edges=args.
|
1109
|
-
|
1171
|
+
pad_target_edges=args.pad_target_edges,
|
1172
|
+
pad_fourier=args.pad_fourier,
|
1173
|
+
pad_template_filter=not args.no_filter_padding,
|
1110
1174
|
interpolation_order=args.interpolation_order,
|
1111
1175
|
)
|
1112
1176
|
|
1113
1177
|
candidates = list(candidates) if candidates is not None else []
|
1114
|
-
if
|
1178
|
+
if callback_class == MaxScoreOverRotations:
|
1115
1179
|
if target_mask is not None and args.score != "MCC":
|
1116
1180
|
candidates[0] *= target_mask.data
|
1117
1181
|
with warnings.catch_warnings():
|
@@ -1123,6 +1187,7 @@ def main():
|
|
1123
1187
|
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1124
1188
|
for i, x in candidates[3].items()
|
1125
1189
|
}
|
1190
|
+
print(np.where(candidates[0] == candidates[0].max()), candidates[0].max())
|
1126
1191
|
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1127
1192
|
write_pickle(data=candidates, filename=args.output)
|
1128
1193
|
|