pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -53
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +396 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -201
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +158 -28
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.dist-info/RECORD +0 -119
- pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
scripts/match_template.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
#!python3
|
2
|
-
"""
|
2
|
+
"""CLI for basic pyTME template matching functions.
|
3
3
|
|
4
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
5
|
|
6
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
import os
|
9
9
|
import argparse
|
@@ -12,67 +12,44 @@ from sys import exit
|
|
12
12
|
from time import time
|
13
13
|
from typing import Tuple
|
14
14
|
from copy import deepcopy
|
15
|
-
from os.path import exists
|
16
15
|
from tempfile import gettempdir
|
16
|
+
from os.path import exists, abspath
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from tme.backends import backend as be
|
21
|
-
from tme import Density, __version__
|
21
|
+
from tme import Density, __version__, Orientations
|
22
22
|
from tme.matching_utils import scramble_phases, write_pickle
|
23
23
|
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
24
24
|
from tme.rotations import (
|
25
25
|
get_cone_rotations,
|
26
26
|
get_rotation_matrices,
|
27
|
+
euler_to_rotationmatrix,
|
27
28
|
)
|
28
29
|
from tme.matching_data import MatchingData
|
29
30
|
from tme.analyzer import (
|
30
31
|
MaxScoreOverRotations,
|
31
32
|
PeakCallerMaximumFilter,
|
33
|
+
MaxScoreOverRotationsConstrained,
|
32
34
|
)
|
33
35
|
from tme.filters import (
|
34
36
|
CTF,
|
35
37
|
Wedge,
|
36
38
|
Compose,
|
37
|
-
|
39
|
+
BandPass,
|
40
|
+
CTFReconstructed,
|
38
41
|
WedgeReconstructed,
|
39
42
|
ReconstructFromTilt,
|
40
43
|
LinearWhiteningFilter,
|
44
|
+
BandPassReconstructed,
|
45
|
+
)
|
46
|
+
from tme.cli import (
|
47
|
+
get_func_fullname,
|
48
|
+
print_block,
|
49
|
+
print_entry,
|
50
|
+
check_positive,
|
51
|
+
sanitize_name,
|
41
52
|
)
|
42
|
-
|
43
|
-
|
44
|
-
def get_func_fullname(func) -> str:
|
45
|
-
"""Returns the full name of the given function, including its module."""
|
46
|
-
return f"<function '{func.__module__}.{func.__name__}'>"
|
47
|
-
|
48
|
-
|
49
|
-
def print_block(name: str, data: dict, label_width=20) -> None:
|
50
|
-
"""Prints a formatted block of information."""
|
51
|
-
print(f"\n> {name}")
|
52
|
-
for key, value in data.items():
|
53
|
-
if isinstance(value, np.ndarray):
|
54
|
-
value = value.shape
|
55
|
-
formatted_value = str(value)
|
56
|
-
print(f" - {key + ':':<{label_width}} {formatted_value}")
|
57
|
-
|
58
|
-
|
59
|
-
def print_entry() -> None:
|
60
|
-
width = 80
|
61
|
-
text = f" pytme v{__version__} "
|
62
|
-
padding_total = width - len(text) - 2
|
63
|
-
padding_left = padding_total // 2
|
64
|
-
padding_right = padding_total - padding_left
|
65
|
-
|
66
|
-
print("*" * width)
|
67
|
-
print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
|
68
|
-
print("*" * width)
|
69
|
-
|
70
|
-
|
71
|
-
def check_positive(value):
|
72
|
-
ivalue = float(value)
|
73
|
-
if ivalue <= 0:
|
74
|
-
raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
|
75
|
-
return ivalue
|
76
53
|
|
77
54
|
|
78
55
|
def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
@@ -118,6 +95,14 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
118
95
|
|
119
96
|
|
120
97
|
def parse_rotation_logic(args, ndim):
|
98
|
+
if args.particle_diameter is not None:
|
99
|
+
resolution = Density.from_file(args.target, use_memmap=True)
|
100
|
+
resolution = 360 * np.maximum(
|
101
|
+
np.max(2 * resolution.sampling_rate),
|
102
|
+
args.lowpass if args.lowpass is not None else 0,
|
103
|
+
)
|
104
|
+
args.angular_sampling = resolution / (3.14159265358979 * args.particle_diameter)
|
105
|
+
|
121
106
|
if args.angular_sampling is not None:
|
122
107
|
rotations = get_rotation_matrices(
|
123
108
|
angular_sampling=args.angular_sampling,
|
@@ -138,7 +123,7 @@ def parse_rotation_logic(args, ndim):
|
|
138
123
|
axis_sampling=args.axis_sampling,
|
139
124
|
n_symmetry=args.axis_symmetry,
|
140
125
|
axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
|
141
|
-
reference=[0, 0, -1],
|
126
|
+
reference=[0, 0, -1 if args.invert_cone else 1],
|
142
127
|
)
|
143
128
|
return rotations
|
144
129
|
|
@@ -178,123 +163,82 @@ def compute_schedule(
|
|
178
163
|
|
179
164
|
|
180
165
|
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
|
-
needs_reconstruction = False
|
182
166
|
template_filter, target_filter = [], []
|
167
|
+
|
168
|
+
if args.tilt_angles is None:
|
169
|
+
args.tilt_angles = args.ctf_file
|
170
|
+
|
171
|
+
wedge = None
|
183
172
|
if args.tilt_angles is not None:
|
184
|
-
needs_reconstruction = args.tilt_weighting is not None
|
185
173
|
try:
|
186
174
|
wedge = Wedge.from_file(args.tilt_angles)
|
187
175
|
wedge.weight_type = args.tilt_weighting
|
188
|
-
if args.tilt_weighting in ("angle", None)
|
176
|
+
if args.tilt_weighting in ("angle", None):
|
189
177
|
wedge = WedgeReconstructed(
|
190
178
|
angles=wedge.angles,
|
191
179
|
weight_wedge=args.tilt_weighting == "angle",
|
192
|
-
opening_axis=args.wedge_axes[0],
|
193
|
-
tilt_axis=args.wedge_axes[1],
|
194
180
|
)
|
195
|
-
except FileNotFoundError:
|
196
|
-
tilt_step, create_continuous_wedge = None, True
|
181
|
+
except (FileNotFoundError, AttributeError):
|
197
182
|
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
tilt_step = float(tilt_step)
|
205
|
-
tilt_angles = np.arange(
|
206
|
-
-tilt_start, tilt_stop + tilt_step, tilt_step
|
207
|
-
).tolist()
|
208
|
-
|
209
|
-
if args.tilt_weighting is not None and tilt_step is None:
|
210
|
-
raise ValueError(
|
211
|
-
"Tilt weighting is not supported for continuous wedges."
|
212
|
-
)
|
213
|
-
if args.tilt_weighting not in ("angle", None):
|
214
|
-
raise ValueError(
|
215
|
-
"Tilt weighting schemes other than 'angle' or 'None' require "
|
216
|
-
"a specification of electron doses via --tilt_angles."
|
217
|
-
)
|
218
|
-
|
219
|
-
wedge = Wedge(
|
220
|
-
angles=tilt_angles,
|
221
|
-
opening_axis=args.wedge_axes[0],
|
222
|
-
tilt_axis=args.wedge_axes[1],
|
223
|
-
shape=None,
|
224
|
-
weight_type=None,
|
225
|
-
weights=np.ones_like(tilt_angles),
|
183
|
+
tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
|
184
|
+
wedge = WedgeReconstructed(
|
185
|
+
angles=(tilt_start, tilt_stop),
|
186
|
+
create_continuous_wedge=True,
|
187
|
+
weight_wedge=False,
|
188
|
+
reconstruction_filter=args.reconstruction_filter,
|
226
189
|
)
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
wedge_target = WedgeReconstructed(
|
237
|
-
angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
|
238
|
-
weight_wedge=False,
|
239
|
-
create_continuous_wedge=True,
|
240
|
-
opening_axis=args.wedge_axes[0],
|
241
|
-
tilt_axis=args.wedge_axes[1],
|
242
|
-
)
|
243
|
-
target_filter.append(wedge_target)
|
190
|
+
wedge.opening_axis, wedge.tilt_axis = args.wedge_axes
|
191
|
+
|
192
|
+
wedge_target = WedgeReconstructed(
|
193
|
+
angles=wedge.angles,
|
194
|
+
weight_wedge=False,
|
195
|
+
create_continuous_wedge=True,
|
196
|
+
opening_axis=wedge.opening_axis,
|
197
|
+
tilt_axis=wedge.tilt_axis,
|
198
|
+
)
|
244
199
|
|
245
200
|
wedge.sampling_rate = template.sampling_rate
|
201
|
+
wedge_target.sampling_rate = template.sampling_rate
|
202
|
+
|
203
|
+
target_filter.append(wedge_target)
|
246
204
|
template_filter.append(wedge)
|
247
|
-
if not isinstance(wedge, WedgeReconstructed):
|
248
|
-
reconstruction_filter = ReconstructFromTilt(
|
249
|
-
reconstruction_filter=args.reconstruction_filter,
|
250
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
251
|
-
)
|
252
|
-
template_filter.append(reconstruction_filter)
|
253
205
|
|
254
206
|
if args.ctf_file is not None or args.defocus is not None:
|
255
|
-
|
256
|
-
|
257
|
-
|
207
|
+
try:
|
208
|
+
ctf = CTF.from_file(
|
209
|
+
args.ctf_file,
|
210
|
+
spherical_aberration=args.spherical_aberration,
|
211
|
+
amplitude_contrast=args.amplitude_contrast,
|
212
|
+
acceleration_voltage=args.acceleration_voltage * 1e3,
|
213
|
+
)
|
214
|
+
if (len(ctf.angles) == 0) and wedge is None:
|
215
|
+
raise ValueError(
|
216
|
+
"You requested to specify the CTF per tilt, but did not specify "
|
217
|
+
"tilt angles via --tilt-angles or --ctf-file. "
|
218
|
+
)
|
219
|
+
if len(ctf.angles) == 0:
|
220
|
+
ctf.angles = wedge.angles
|
221
|
+
|
258
222
|
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
259
|
-
if n_tilts_ctfs != n_tils_angles:
|
223
|
+
if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
|
260
224
|
raise ValueError(
|
261
|
-
f"CTF file contains {n_tilts_ctfs}
|
262
|
-
f"
|
263
|
-
"per micrograph."
|
225
|
+
f"CTF file contains {n_tilts_ctfs} tilt, but recieved "
|
226
|
+
f"{n_tils_angles} tilt angles. Expected one angle per tilt"
|
264
227
|
)
|
265
|
-
|
266
|
-
|
267
|
-
ctf
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
defocus_y=None,
|
274
|
-
angles=[0],
|
275
|
-
shape=None,
|
276
|
-
return_real_fourier=True,
|
228
|
+
|
229
|
+
except (FileNotFoundError, AttributeError):
|
230
|
+
ctf = CTFReconstructed(
|
231
|
+
defocus_x=args.defocus,
|
232
|
+
phase_shift=args.phase_shift,
|
233
|
+
amplitude_contrast=args.amplitude_contrast,
|
234
|
+
spherical_aberration=args.spherical_aberration,
|
235
|
+
acceleration_voltage=args.acceleration_voltage * 1e3,
|
277
236
|
)
|
278
|
-
ctf.sampling_rate = template.sampling_rate
|
279
237
|
ctf.flip_phase = args.no_flip_phase
|
280
|
-
ctf.
|
281
|
-
ctf.
|
282
|
-
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
238
|
+
ctf.sampling_rate = template.sampling_rate
|
239
|
+
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
283
240
|
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
284
|
-
|
285
|
-
if not needs_reconstruction:
|
286
|
-
template_filter.append(ctf)
|
287
|
-
elif isinstance(template_filter[-1], ReconstructFromTilt):
|
288
|
-
template_filter.insert(-1, ctf)
|
289
|
-
else:
|
290
|
-
template_filter.insert(0, ctf)
|
291
|
-
template_filter.insert(
|
292
|
-
1,
|
293
|
-
ReconstructFromTilt(
|
294
|
-
reconstruction_filter=args.reconstruction_filter,
|
295
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
296
|
-
),
|
297
|
-
)
|
241
|
+
template_filter.append(ctf)
|
298
242
|
|
299
243
|
if args.lowpass or args.highpass is not None:
|
300
244
|
lowpass, highpass = args.lowpass, args.highpass
|
@@ -315,7 +259,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
315
259
|
except Exception:
|
316
260
|
pass
|
317
261
|
|
318
|
-
bandpass =
|
262
|
+
bandpass = BandPassReconstructed(
|
319
263
|
use_gaussian=args.no_pass_smooth,
|
320
264
|
lowpass=lowpass,
|
321
265
|
highpass=highpass,
|
@@ -329,10 +273,30 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
329
273
|
template_filter.append(whitening_filter)
|
330
274
|
target_filter.append(whitening_filter)
|
331
275
|
|
332
|
-
|
276
|
+
rec_filt = (Wedge, CTF)
|
277
|
+
needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
|
278
|
+
if needs_reconstruction > 0 and args.reconstruction_filter is None:
|
333
279
|
warnings.warn(
|
334
|
-
"Consider using a --reconstruction_filter such as '
|
280
|
+
"Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
|
281
|
+
"to avoid artifacts from reconstruction using weighted backprojection."
|
282
|
+
)
|
283
|
+
|
284
|
+
template_filter = sorted(
|
285
|
+
template_filter, key=lambda x: type(x) in rec_filt, reverse=True
|
286
|
+
)
|
287
|
+
if needs_reconstruction > 0:
|
288
|
+
relevant_filters = [x for x in template_filter if type(x) in rec_filt]
|
289
|
+
if len(relevant_filters) == 0:
|
290
|
+
raise ValueError("Filters require ")
|
291
|
+
|
292
|
+
reconstruction_filter = ReconstructFromTilt(
|
293
|
+
reconstruction_filter=args.reconstruction_filter,
|
294
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
295
|
+
angles=relevant_filters[0].angles,
|
296
|
+
opening_axis=args.wedge_axes[0],
|
297
|
+
tilt_axis=args.wedge_axes[1],
|
335
298
|
)
|
299
|
+
template_filter.insert(needs_reconstruction, reconstruction_filter)
|
336
300
|
|
337
301
|
template_filter = Compose(template_filter) if len(template_filter) else None
|
338
302
|
target_filter = Compose(target_filter) if len(target_filter) else None
|
@@ -342,6 +306,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
342
306
|
return template_filter, target_filter
|
343
307
|
|
344
308
|
|
309
|
+
def _format_sampling(arr, decimals: int = 2):
|
310
|
+
return tuple(round(float(x), decimals) for x in arr)
|
311
|
+
|
312
|
+
|
345
313
|
def parse_args():
|
346
314
|
parser = argparse.ArgumentParser(
|
347
315
|
description="Perform template matching.",
|
@@ -352,7 +320,6 @@ def parse_args():
|
|
352
320
|
io_group.add_argument(
|
353
321
|
"-m",
|
354
322
|
"--target",
|
355
|
-
dest="target",
|
356
323
|
type=str,
|
357
324
|
required=True,
|
358
325
|
help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
|
@@ -360,8 +327,8 @@ def parse_args():
|
|
360
327
|
"https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
|
361
328
|
)
|
362
329
|
io_group.add_argument(
|
363
|
-
"
|
364
|
-
|
330
|
+
"-M",
|
331
|
+
"--target-mask",
|
365
332
|
type=str,
|
366
333
|
required=False,
|
367
334
|
help="Path to a mask for the target in a supported format (see target).",
|
@@ -369,14 +336,13 @@ def parse_args():
|
|
369
336
|
io_group.add_argument(
|
370
337
|
"-i",
|
371
338
|
"--template",
|
372
|
-
dest="template",
|
373
339
|
type=str,
|
374
340
|
required=True,
|
375
341
|
help="Path to a template in PDB/MMCIF or other supported formats (see target).",
|
376
342
|
)
|
377
343
|
io_group.add_argument(
|
378
|
-
"
|
379
|
-
|
344
|
+
"-I",
|
345
|
+
"--template-mask",
|
380
346
|
type=str,
|
381
347
|
required=False,
|
382
348
|
help="Path to a mask for the template in a supported format (see target).",
|
@@ -384,32 +350,62 @@ def parse_args():
|
|
384
350
|
io_group.add_argument(
|
385
351
|
"-o",
|
386
352
|
"--output",
|
387
|
-
dest="output",
|
388
353
|
type=str,
|
389
354
|
required=False,
|
390
355
|
default="output.pickle",
|
391
356
|
help="Path to the output pickle file.",
|
392
357
|
)
|
393
358
|
io_group.add_argument(
|
394
|
-
"--
|
395
|
-
dest="invert_target_contrast",
|
359
|
+
"--invert-target-contrast",
|
396
360
|
action="store_true",
|
397
361
|
default=False,
|
398
362
|
help="Invert the target's contrast for cases where templates to-be-matched have "
|
399
363
|
"negative values, e.g. tomograms.",
|
400
364
|
)
|
401
365
|
io_group.add_argument(
|
402
|
-
"--
|
403
|
-
dest="scramble_phases",
|
366
|
+
"--scramble-phases",
|
404
367
|
action="store_true",
|
405
368
|
default=False,
|
406
369
|
help="Phase scramble the template to generate a noise score background.",
|
407
370
|
)
|
408
371
|
|
372
|
+
sampling_group = parser.add_argument_group("Sampling")
|
373
|
+
sampling_group.add_argument(
|
374
|
+
"--orientations",
|
375
|
+
default=None,
|
376
|
+
required=False,
|
377
|
+
help="Path to a file readable by Orientations.from_file containing "
|
378
|
+
"translations and rotations of candidate peaks to refine.",
|
379
|
+
)
|
380
|
+
sampling_group.add_argument(
|
381
|
+
"--orientations-scaling",
|
382
|
+
required=False,
|
383
|
+
type=float,
|
384
|
+
default=1.0,
|
385
|
+
help="Scaling factor to map candidate translations onto the target. "
|
386
|
+
"Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
|
387
|
+
"the corresponding --orientations-scaling would be 3.",
|
388
|
+
)
|
389
|
+
sampling_group.add_argument(
|
390
|
+
"--orientations-cone",
|
391
|
+
required=False,
|
392
|
+
type=float,
|
393
|
+
default=20.0,
|
394
|
+
help="Accept orientations within specified cone angle of each orientation.",
|
395
|
+
)
|
396
|
+
sampling_group.add_argument(
|
397
|
+
"--orientations-uncertainty",
|
398
|
+
required=False,
|
399
|
+
type=str,
|
400
|
+
default="10",
|
401
|
+
help="Accept translations within the specified radius of each orientation. "
|
402
|
+
"Can be a single value or comma-separated string for per-axis uncertainty.",
|
403
|
+
)
|
404
|
+
|
409
405
|
scoring_group = parser.add_argument_group("Scoring")
|
410
406
|
scoring_group.add_argument(
|
411
407
|
"-s",
|
412
|
-
|
408
|
+
"--score",
|
413
409
|
type=str,
|
414
410
|
default="FLCSphericalMask",
|
415
411
|
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
@@ -421,67 +417,64 @@ def parse_args():
|
|
421
417
|
|
422
418
|
angular_exclusive.add_argument(
|
423
419
|
"-a",
|
424
|
-
|
420
|
+
"--angular-sampling",
|
425
421
|
type=check_positive,
|
426
422
|
default=None,
|
427
|
-
help="Angular sampling rate
|
428
|
-
"A lower number yields more rotations. Values >= 180 sample only the identity.",
|
423
|
+
help="Angular sampling rate. Lower values = more rotations, higher precision.",
|
429
424
|
)
|
430
425
|
angular_exclusive.add_argument(
|
431
|
-
"--
|
432
|
-
dest="cone_angle",
|
426
|
+
"--cone-angle",
|
433
427
|
type=check_positive,
|
434
428
|
default=None,
|
435
429
|
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
436
430
|
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
437
431
|
)
|
432
|
+
angular_exclusive.add_argument(
|
433
|
+
"--particle-diameter",
|
434
|
+
type=check_positive,
|
435
|
+
default=None,
|
436
|
+
help="Particle diameter in units of sampling rate.",
|
437
|
+
)
|
438
438
|
angular_group.add_argument(
|
439
|
-
"--
|
440
|
-
dest="cone_axis",
|
439
|
+
"--cone-axis",
|
441
440
|
type=check_positive,
|
442
441
|
default=2,
|
443
442
|
help="Principal axis to build cone around.",
|
444
443
|
)
|
445
444
|
angular_group.add_argument(
|
446
|
-
"--
|
447
|
-
dest="invert_cone",
|
445
|
+
"--invert-cone",
|
448
446
|
action="store_true",
|
449
447
|
help="Invert cone handedness.",
|
450
448
|
)
|
451
449
|
angular_group.add_argument(
|
452
|
-
"--
|
453
|
-
dest="cone_sampling",
|
450
|
+
"--cone-sampling",
|
454
451
|
type=check_positive,
|
455
452
|
default=None,
|
456
453
|
help="Sampling rate of the cone in degrees.",
|
457
454
|
)
|
458
455
|
angular_group.add_argument(
|
459
|
-
"--
|
460
|
-
dest="axis_angle",
|
456
|
+
"--axis-angle",
|
461
457
|
type=check_positive,
|
462
458
|
default=360.0,
|
463
459
|
required=False,
|
464
460
|
help="Sampling angle along the z-axis of the cone.",
|
465
461
|
)
|
466
462
|
angular_group.add_argument(
|
467
|
-
"--
|
468
|
-
dest="axis_sampling",
|
463
|
+
"--axis-sampling",
|
469
464
|
type=check_positive,
|
470
465
|
default=None,
|
471
466
|
required=False,
|
472
|
-
help="Sampling rate along the z-axis of the cone. Defaults to --
|
467
|
+
help="Sampling rate along the z-axis of the cone. Defaults to --cone-sampling.",
|
473
468
|
)
|
474
469
|
angular_group.add_argument(
|
475
|
-
"--
|
476
|
-
dest="axis_symmetry",
|
470
|
+
"--axis-symmetry",
|
477
471
|
type=check_positive,
|
478
472
|
default=1,
|
479
473
|
required=False,
|
480
474
|
help="N-fold symmetry around z-axis of the cone.",
|
481
475
|
)
|
482
476
|
angular_group.add_argument(
|
483
|
-
"--
|
484
|
-
dest="no_use_optimized_set",
|
477
|
+
"--no-use-optimized-set",
|
485
478
|
action="store_true",
|
486
479
|
default=False,
|
487
480
|
required=False,
|
@@ -498,56 +491,40 @@ def parse_args():
|
|
498
491
|
help="Number of cores used for template matching.",
|
499
492
|
)
|
500
493
|
computation_group.add_argument(
|
501
|
-
"--
|
502
|
-
dest="use_gpu",
|
503
|
-
action="store_true",
|
504
|
-
default=False,
|
505
|
-
help="Whether to perform computations on the GPU.",
|
506
|
-
)
|
507
|
-
computation_group.add_argument(
|
508
|
-
"--gpu_indices",
|
509
|
-
dest="gpu_indices",
|
494
|
+
"--gpu-indices",
|
510
495
|
type=str,
|
511
496
|
default=None,
|
512
|
-
help="Comma-separated
|
513
|
-
"
|
514
|
-
" If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
|
515
|
-
" be respected.",
|
497
|
+
help="Comma-separated GPU indices (e.g., '0,1,2' for first 3 GPUs). Otherwise "
|
498
|
+
"CUDA_VISIBLE_DEVICES will be used.",
|
516
499
|
)
|
517
500
|
computation_group.add_argument(
|
518
|
-
"
|
519
|
-
"--ram",
|
520
|
-
dest="memory",
|
501
|
+
"--memory",
|
521
502
|
required=False,
|
522
503
|
type=int,
|
523
504
|
default=None,
|
524
505
|
help="Amount of memory that can be used in bytes.",
|
525
506
|
)
|
526
507
|
computation_group.add_argument(
|
527
|
-
"--
|
528
|
-
dest="memory_scaling",
|
508
|
+
"--memory-scaling",
|
529
509
|
required=False,
|
530
510
|
type=float,
|
531
511
|
default=0.85,
|
532
|
-
help="Fraction of available memory to be used. Ignored if --
|
512
|
+
help="Fraction of available memory to be used. Ignored if --memory is set.",
|
533
513
|
)
|
534
514
|
computation_group.add_argument(
|
535
|
-
"--
|
536
|
-
|
537
|
-
|
538
|
-
help="Directory for temporary objects. Faster I/O improves runtime.",
|
515
|
+
"--temp-directory",
|
516
|
+
default=gettempdir(),
|
517
|
+
help="Temporary directory for memmaps. Better I/O improves runtime.",
|
539
518
|
)
|
540
519
|
computation_group.add_argument(
|
541
520
|
"--backend",
|
542
|
-
|
543
|
-
default=None,
|
521
|
+
default=be._backend_name,
|
544
522
|
choices=be.available_backends(),
|
545
|
-
help="
|
523
|
+
help="Set computation backend.",
|
546
524
|
)
|
547
525
|
filter_group = parser.add_argument_group("Filters")
|
548
526
|
filter_group.add_argument(
|
549
527
|
"--lowpass",
|
550
|
-
dest="lowpass",
|
551
528
|
type=float,
|
552
529
|
required=False,
|
553
530
|
help="Resolution to lowpass filter template and target to in the same unit "
|
@@ -555,58 +532,54 @@ def parse_args():
|
|
555
532
|
)
|
556
533
|
filter_group.add_argument(
|
557
534
|
"--highpass",
|
558
|
-
dest="highpass",
|
559
535
|
type=float,
|
560
536
|
required=False,
|
561
537
|
help="Resolution to highpass filter template and target to in the same unit "
|
562
538
|
"as the sampling rate of template and target (typically Ångstrom).",
|
563
539
|
)
|
564
540
|
filter_group.add_argument(
|
565
|
-
"--
|
566
|
-
dest="no_pass_smooth",
|
541
|
+
"--no-pass-smooth",
|
567
542
|
action="store_false",
|
568
543
|
default=True,
|
569
544
|
help="Whether a hard edge filter should be used for --lowpass and --highpass.",
|
570
545
|
)
|
571
546
|
filter_group.add_argument(
|
572
|
-
"--
|
573
|
-
dest="pass_format",
|
547
|
+
"--pass-format",
|
574
548
|
type=str,
|
575
549
|
required=False,
|
576
550
|
default="sampling_rate",
|
577
551
|
choices=["sampling_rate", "voxel", "frequency"],
|
578
|
-
help="How values passed to --lowpass and --highpass should be interpreted. "
|
552
|
+
help="How values passed to --lowpass and --highpass should be interpreted. "
|
553
|
+
"Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
|
579
554
|
)
|
580
555
|
filter_group.add_argument(
|
581
|
-
"--
|
582
|
-
dest="whiten_spectrum",
|
556
|
+
"--whiten-spectrum",
|
583
557
|
action="store_true",
|
584
558
|
default=None,
|
585
559
|
help="Apply spectral whitening to template and target based on target spectrum.",
|
586
560
|
)
|
587
561
|
filter_group.add_argument(
|
588
|
-
"--
|
589
|
-
dest="wedge_axes",
|
562
|
+
"--wedge-axes",
|
590
563
|
type=str,
|
591
564
|
required=False,
|
592
|
-
default=
|
593
|
-
help="Indices of wedge opening and tilt axis, e.g
|
594
|
-
"
|
565
|
+
default="2,0",
|
566
|
+
help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
|
567
|
+
"for the typical projection over z and tilting over the x-axis.",
|
595
568
|
)
|
596
569
|
filter_group.add_argument(
|
597
|
-
"--
|
598
|
-
dest="tilt_angles",
|
570
|
+
"--tilt-angles",
|
599
571
|
type=str,
|
600
572
|
required=False,
|
601
573
|
default=None,
|
602
|
-
help="Path to a
|
603
|
-
"
|
604
|
-
"
|
605
|
-
"
|
574
|
+
help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
|
575
|
+
"a tomostar STAR file, an MMOD file, a tab-separated file with column name "
|
576
|
+
"'angles', or a single column file without header. Exposure will be taken from "
|
577
|
+
"the input file , if you are using a tab-separated file, the column names "
|
578
|
+
"'angles' and 'weights' need to be present. It is also possible to specify a "
|
579
|
+
"continuous wedge mask using e.g., -50,45.",
|
606
580
|
)
|
607
581
|
filter_group.add_argument(
|
608
|
-
"--
|
609
|
-
dest="tilt_weighting",
|
582
|
+
"--tilt-weighting",
|
610
583
|
type=str,
|
611
584
|
required=False,
|
612
585
|
choices=["angle", "relion", "grigorieff"],
|
@@ -615,28 +588,25 @@ def parse_args():
|
|
615
588
|
"angle (cosine based weighting), "
|
616
589
|
"relion (relion formalism for wedge weighting) requires,"
|
617
590
|
"grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
|
618
|
-
"relion and grigorieff require electron doses in --
|
591
|
+
"relion and grigorieff require electron doses in --tilt-angles weights column.",
|
619
592
|
)
|
620
593
|
filter_group.add_argument(
|
621
|
-
"--
|
622
|
-
dest="reconstruction_filter",
|
594
|
+
"--reconstruction-filter",
|
623
595
|
type=str,
|
624
596
|
required=False,
|
625
597
|
choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
|
626
|
-
default=
|
598
|
+
default="ramp",
|
627
599
|
help="Filter applied when reconstructing (N+1)-D from N-D filters.",
|
628
600
|
)
|
629
601
|
filter_group.add_argument(
|
630
|
-
"--
|
631
|
-
dest="reconstruction_interpolation_order",
|
602
|
+
"--reconstruction-interpolation-order",
|
632
603
|
type=int,
|
633
604
|
default=1,
|
634
605
|
required=False,
|
635
|
-
help="Analogous to --
|
606
|
+
help="Analogous to --interpolation-order but for reconstruction.",
|
636
607
|
)
|
637
608
|
filter_group.add_argument(
|
638
|
-
"--
|
639
|
-
dest="no_filter_target",
|
609
|
+
"--no-filter-target",
|
640
610
|
action="store_true",
|
641
611
|
default=False,
|
642
612
|
help="Whether to not apply potential filters to the target.",
|
@@ -644,65 +614,58 @@ def parse_args():
|
|
644
614
|
|
645
615
|
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
646
616
|
ctf_group.add_argument(
|
647
|
-
"--
|
648
|
-
dest="ctf_file",
|
617
|
+
"--ctf-file",
|
649
618
|
type=str,
|
650
619
|
required=False,
|
651
620
|
default=None,
|
652
|
-
help="Path to a file with CTF parameters
|
653
|
-
"
|
621
|
+
help="Path to a file with CTF parameters. This can be a Warp/M XML file "
|
622
|
+
"a GCTF/Relion STAR file, an MDOC file, or the output of CTFFIND4. If the file "
|
623
|
+
" does not specify tilt angles, --tilt-angles are used.",
|
654
624
|
)
|
655
625
|
ctf_group.add_argument(
|
656
626
|
"--defocus",
|
657
|
-
dest="defocus",
|
658
627
|
type=float,
|
659
628
|
required=False,
|
660
629
|
default=None,
|
661
|
-
help="Defocus in units of sampling rate (typically Ångstrom). "
|
662
|
-
"Superseded by --
|
630
|
+
help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
|
631
|
+
"for a defocus of 3 micrometer. Superseded by --ctf-file.",
|
663
632
|
)
|
664
633
|
ctf_group.add_argument(
|
665
|
-
"--
|
666
|
-
dest="phase_shift",
|
634
|
+
"--phase-shift",
|
667
635
|
type=float,
|
668
636
|
required=False,
|
669
637
|
default=0,
|
670
|
-
help="Phase shift in degrees. Superseded by --
|
638
|
+
help="Phase shift in degrees. Superseded by --ctf-file.",
|
671
639
|
)
|
672
640
|
ctf_group.add_argument(
|
673
|
-
"--
|
674
|
-
dest="acceleration_voltage",
|
641
|
+
"--acceleration-voltage",
|
675
642
|
type=float,
|
676
643
|
required=False,
|
677
644
|
default=300,
|
678
645
|
help="Acceleration voltage in kV.",
|
679
646
|
)
|
680
647
|
ctf_group.add_argument(
|
681
|
-
"--
|
682
|
-
dest="spherical_aberration",
|
648
|
+
"--spherical-aberration",
|
683
649
|
type=float,
|
684
650
|
required=False,
|
685
651
|
default=2.7e7,
|
686
652
|
help="Spherical aberration in units of sampling rate (typically Ångstrom).",
|
687
653
|
)
|
688
654
|
ctf_group.add_argument(
|
689
|
-
"--
|
690
|
-
dest="amplitude_contrast",
|
655
|
+
"--amplitude-contrast",
|
691
656
|
type=float,
|
692
657
|
required=False,
|
693
658
|
default=0.07,
|
694
659
|
help="Amplitude contrast.",
|
695
660
|
)
|
696
661
|
ctf_group.add_argument(
|
697
|
-
"--
|
698
|
-
dest="no_flip_phase",
|
662
|
+
"--no-flip-phase",
|
699
663
|
action="store_false",
|
700
664
|
required=False,
|
701
665
|
help="Do not perform phase-flipping CTF correction.",
|
702
666
|
)
|
703
667
|
ctf_group.add_argument(
|
704
|
-
"--
|
705
|
-
dest="correct_defocus_gradient",
|
668
|
+
"--correct-defocus-gradient",
|
706
669
|
action="store_true",
|
707
670
|
required=False,
|
708
671
|
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
@@ -711,55 +674,49 @@ def parse_args():
|
|
711
674
|
|
712
675
|
performance_group = parser.add_argument_group("Performance")
|
713
676
|
performance_group.add_argument(
|
714
|
-
"--
|
715
|
-
dest="no_centering",
|
677
|
+
"--centering",
|
716
678
|
action="store_true",
|
717
|
-
help="
|
679
|
+
help="Center the template in the box if it has not been done already.",
|
718
680
|
)
|
719
681
|
performance_group.add_argument(
|
720
|
-
"--
|
721
|
-
dest="pad_edges",
|
682
|
+
"--pad-edges",
|
722
683
|
action="store_true",
|
723
684
|
default=False,
|
724
|
-
help="
|
725
|
-
"
|
685
|
+
help="Useful if the target does not have a well-defined bounding box. Will be "
|
686
|
+
"activated automatically if splitting is required to avoid boundary artifacts.",
|
726
687
|
)
|
727
688
|
performance_group.add_argument(
|
728
|
-
"--
|
729
|
-
dest="pad_filter",
|
689
|
+
"--pad-filter",
|
730
690
|
action="store_true",
|
731
691
|
default=False,
|
732
|
-
help="
|
692
|
+
help="Pad the template filter to the shape of the target. Useful for fast "
|
733
693
|
"oscilating filters to avoid aliasing effects.",
|
734
694
|
)
|
735
695
|
performance_group.add_argument(
|
736
|
-
"--
|
737
|
-
dest="interpolation_order",
|
696
|
+
"--interpolation-order",
|
738
697
|
required=False,
|
739
698
|
type=int,
|
740
|
-
default=
|
741
|
-
help="Spline interpolation used for rotations."
|
699
|
+
default=None,
|
700
|
+
help="Spline interpolation used for rotations. Defaults to 3, and 1 for jax "
|
701
|
+
"and pytorch backends.",
|
742
702
|
)
|
743
703
|
performance_group.add_argument(
|
744
|
-
"--
|
745
|
-
dest="use_mixed_precision",
|
704
|
+
"--use-mixed-precision",
|
746
705
|
action="store_true",
|
747
706
|
default=False,
|
748
|
-
help="Use float16 for real values operations where possible."
|
707
|
+
help="Use float16 for real values operations where possible. Not supported "
|
708
|
+
"for jax backend.",
|
749
709
|
)
|
750
710
|
performance_group.add_argument(
|
751
|
-
"--
|
752
|
-
dest="use_memmap",
|
711
|
+
"--use-memmap",
|
753
712
|
action="store_true",
|
754
713
|
default=False,
|
755
|
-
help="
|
756
|
-
"Particularly useful for large inputs in combination with --use_gpu.",
|
714
|
+
help="Memmap large data to disk, e.g., matching on unbinned tomograms.",
|
757
715
|
)
|
758
716
|
|
759
|
-
analyzer_group = parser.add_argument_group("
|
717
|
+
analyzer_group = parser.add_argument_group("Output / Analysis")
|
760
718
|
analyzer_group.add_argument(
|
761
|
-
"--
|
762
|
-
dest="score_threshold",
|
719
|
+
"--score-threshold",
|
763
720
|
required=False,
|
764
721
|
type=float,
|
765
722
|
default=0,
|
@@ -767,21 +724,25 @@ def parse_args():
|
|
767
724
|
)
|
768
725
|
analyzer_group.add_argument(
|
769
726
|
"-p",
|
770
|
-
|
727
|
+
"--peak-calling",
|
771
728
|
action="store_true",
|
772
729
|
default=False,
|
773
730
|
help="Perform peak calling instead of score aggregation.",
|
774
731
|
)
|
775
732
|
analyzer_group.add_argument(
|
776
|
-
"--
|
777
|
-
|
778
|
-
action="store_true",
|
733
|
+
"--num-peaks",
|
734
|
+
type=int,
|
779
735
|
default=1000,
|
780
736
|
help="Number of peaks to call, 1000 by default.",
|
781
737
|
)
|
782
738
|
args = parser.parse_args()
|
783
739
|
args.version = __version__
|
784
740
|
|
741
|
+
if args.interpolation_order is None:
|
742
|
+
args.interpolation_order = 3
|
743
|
+
if args.backend in ("jax", "pytorch"):
|
744
|
+
args.interpolation_order = 1
|
745
|
+
|
785
746
|
if args.interpolation_order < 0:
|
786
747
|
args.interpolation_order = None
|
787
748
|
|
@@ -789,41 +750,42 @@ def parse_args():
|
|
789
750
|
args.temp_directory = gettempdir()
|
790
751
|
|
791
752
|
os.environ["TMPDIR"] = args.temp_directory
|
792
|
-
if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
|
793
|
-
raise ValueError(
|
794
|
-
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
795
|
-
)
|
796
|
-
|
797
|
-
gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
798
753
|
if args.gpu_indices is not None:
|
799
754
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
|
800
755
|
|
801
|
-
if args.
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
756
|
+
if args.tilt_angles is not None and not exists(args.tilt_angles):
|
757
|
+
try:
|
758
|
+
float(args.tilt_angles.split(",")[0])
|
759
|
+
except Exception:
|
760
|
+
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
761
|
+
|
762
|
+
if args.ctf_file is not None and args.tilt_angles is None:
|
763
|
+
# Check if tilt angles can be extracted from CTF specification
|
764
|
+
try:
|
765
|
+
ctf = CTF.from_file(args.ctf_file)
|
766
|
+
if ctf.angles is None:
|
767
|
+
raise ValueError
|
768
|
+
args.tilt_angles = args.ctf_file
|
769
|
+
except Exception:
|
770
|
+
raise ValueError(
|
771
|
+
"Need to specify --tilt-angles when not provided in --ctf-file."
|
807
772
|
)
|
808
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
809
|
-
args.gpu_indices = [
|
810
|
-
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
811
|
-
]
|
812
773
|
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
774
|
+
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
775
|
+
if args.orientations is not None:
|
776
|
+
orientations = Orientations.from_file(args.orientations)
|
777
|
+
orientations.translations = np.divide(
|
778
|
+
orientations.translations, args.orientations_scaling
|
779
|
+
)
|
780
|
+
args.orientations = orientations
|
821
781
|
|
822
|
-
|
823
|
-
|
782
|
+
args.target = abspath(args.target)
|
783
|
+
if args.target_mask is not None:
|
784
|
+
args.target_mask = abspath(args.target_mask)
|
824
785
|
|
825
|
-
|
826
|
-
|
786
|
+
args.template = abspath(args.template)
|
787
|
+
if args.template_mask is not None:
|
788
|
+
args.template_mask = abspath(args.template_mask)
|
827
789
|
|
828
790
|
return args
|
829
791
|
|
@@ -864,7 +826,7 @@ def main():
|
|
864
826
|
name="Target",
|
865
827
|
data={
|
866
828
|
"Initial Shape": initial_shape,
|
867
|
-
"Sampling Rate":
|
829
|
+
"Sampling Rate": _format_sampling(target.sampling_rate),
|
868
830
|
"Final Shape": target.shape,
|
869
831
|
},
|
870
832
|
)
|
@@ -874,27 +836,28 @@ def main():
|
|
874
836
|
name="Target Mask",
|
875
837
|
data={
|
876
838
|
"Initial Shape": initial_shape,
|
877
|
-
"Sampling Rate":
|
839
|
+
"Sampling Rate": _format_sampling(target_mask.sampling_rate),
|
878
840
|
"Final Shape": target_mask.shape,
|
879
841
|
},
|
880
842
|
)
|
881
843
|
|
882
844
|
initial_shape = template.shape
|
883
845
|
translation = np.zeros(len(template.shape), dtype=np.float32)
|
884
|
-
if
|
846
|
+
if args.centering:
|
885
847
|
template, translation = template.centered(0)
|
848
|
+
|
886
849
|
print_block(
|
887
850
|
name="Template",
|
888
851
|
data={
|
889
852
|
"Initial Shape": initial_shape,
|
890
|
-
"Sampling Rate":
|
853
|
+
"Sampling Rate": _format_sampling(template.sampling_rate),
|
891
854
|
"Final Shape": template.shape,
|
892
855
|
},
|
893
856
|
)
|
894
857
|
|
895
858
|
if template_mask is None:
|
896
859
|
template_mask = template.empty
|
897
|
-
if not args.
|
860
|
+
if not args.centering:
|
898
861
|
enclosing_box = template.minimum_enclosing_box(
|
899
862
|
0, use_geometric_center=False
|
900
863
|
)
|
@@ -919,7 +882,7 @@ def main():
|
|
919
882
|
name="Template Mask",
|
920
883
|
data={
|
921
884
|
"Inital Shape": initial_shape,
|
922
|
-
"Sampling Rate":
|
885
|
+
"Sampling Rate": _format_sampling(template_mask.sampling_rate),
|
923
886
|
"Final Shape": template_mask.shape,
|
924
887
|
},
|
925
888
|
)
|
@@ -930,65 +893,78 @@ def main():
|
|
930
893
|
template.data, noise_proportion=1.0, normalize_power=False
|
931
894
|
)
|
932
895
|
|
896
|
+
callback_class = MaxScoreOverRotations
|
897
|
+
if args.peak_calling:
|
898
|
+
callback_class = PeakCallerMaximumFilter
|
899
|
+
|
900
|
+
if args.orientations is not None:
|
901
|
+
callback_class = MaxScoreOverRotationsConstrained
|
902
|
+
|
933
903
|
# Determine suitable backend for the selected operation
|
934
904
|
available_backends = be.available_backends()
|
935
|
-
if args.backend
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
942
|
-
if args.use_gpu:
|
943
|
-
args.cores = len(args.gpu_indices)
|
944
|
-
be_selection = ("pytorch", "cupy", "jax")
|
945
|
-
if args.use_mixed_precision:
|
946
|
-
be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
|
947
|
-
|
948
|
-
available_backends = [x for x in available_backends if x in be_selection]
|
949
|
-
if args.peak_calling:
|
950
|
-
if "jax" in available_backends:
|
951
|
-
available_backends.remove("jax")
|
952
|
-
if args.use_gpu and "pytorch" in available_backends:
|
953
|
-
available_backends = ("pytorch",)
|
954
|
-
|
955
|
-
# dim_match = len(template.shape) == len(target.shape) <= 3
|
956
|
-
# if dim_match and args.use_gpu and "jax" in available_backends:
|
957
|
-
# args.interpolation_order = 1
|
958
|
-
# available_backends = ["jax"]
|
959
|
-
|
960
|
-
backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
|
961
|
-
if args.use_gpu:
|
962
|
-
backend_preference = ("cupy", "pytorch", "jax")
|
963
|
-
for pref in backend_preference:
|
964
|
-
if pref not in available_backends:
|
965
|
-
continue
|
966
|
-
be.change_backend(pref)
|
967
|
-
if pref == "pytorch":
|
968
|
-
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
969
|
-
|
970
|
-
if args.use_mixed_precision:
|
971
|
-
be.change_backend(
|
972
|
-
backend_name=pref,
|
973
|
-
default_dtype=be._array_backend.float16,
|
974
|
-
complex_dtype=be._array_backend.complex64,
|
975
|
-
default_dtype_int=be._array_backend.int16,
|
976
|
-
)
|
977
|
-
break
|
905
|
+
if args.backend not in available_backends:
|
906
|
+
raise ValueError("Requested backend is not available.")
|
907
|
+
if args.backend == "jax" and callback_class != MaxScoreOverRotations:
|
908
|
+
raise ValueError(
|
909
|
+
"Jax backend only supports the MaxScoreOverRotations analyzer."
|
910
|
+
)
|
978
911
|
|
979
|
-
if
|
912
|
+
if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
|
980
913
|
warnings.warn(
|
981
|
-
"
|
914
|
+
"Jax and pytorch do not support interpolation order 3, setting it to 1."
|
982
915
|
)
|
983
916
|
args.interpolation_order = 1
|
984
917
|
|
918
|
+
if args.backend in ("pytorch", "cupy", "jax"):
|
919
|
+
gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
920
|
+
if gpu_devices is None:
|
921
|
+
warnings.warn(
|
922
|
+
"No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
|
923
|
+
"Assuming device 0.",
|
924
|
+
)
|
925
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
926
|
+
|
927
|
+
args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
|
928
|
+
args.gpu_indices = [
|
929
|
+
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
930
|
+
]
|
931
|
+
|
932
|
+
# Finally set the desired backend
|
933
|
+
device = "cuda"
|
934
|
+
args.use_gpu = False
|
935
|
+
be.change_backend(args.backend)
|
936
|
+
if args.backend in ("jax", "pytorch", "cupy"):
|
937
|
+
args.use_gpu = True
|
938
|
+
|
939
|
+
if args.backend == "pytorch":
|
940
|
+
try:
|
941
|
+
be.change_backend("pytorch", device=device)
|
942
|
+
# Trigger exception if not compiled with device
|
943
|
+
be.get_available_memory()
|
944
|
+
except Exception as e:
|
945
|
+
print(e)
|
946
|
+
device = "cpu"
|
947
|
+
args.use_gpu = False
|
948
|
+
be.change_backend("pytorch", device=device)
|
949
|
+
|
950
|
+
# TODO: Make the inverse casting from complex64 -> float 16 stable
|
951
|
+
# if args.use_mixed_precision:
|
952
|
+
# be.change_backend(
|
953
|
+
# backend_name=args.backend,
|
954
|
+
# float_dtype=be._array_backend.float16,
|
955
|
+
# complex_dtype=be._array_backend.complex64,
|
956
|
+
# int_dtype=be._array_backend.int16,
|
957
|
+
# device=device,
|
958
|
+
# )
|
959
|
+
|
985
960
|
available_memory = be.get_available_memory() * be.device_count()
|
986
961
|
if args.memory is None:
|
987
962
|
args.memory = int(args.memory_scaling * available_memory)
|
988
963
|
|
989
|
-
|
990
|
-
|
991
|
-
|
964
|
+
if args.orientations_uncertainty is not None:
|
965
|
+
args.orientations_uncertainty = tuple(
|
966
|
+
int(x) for x in args.orientations_uncertainty.split(",")
|
967
|
+
)
|
992
968
|
|
993
969
|
matching_data = MatchingData(
|
994
970
|
target=target,
|
@@ -1018,7 +994,7 @@ def main():
|
|
1018
994
|
options = {
|
1019
995
|
"Angular Sampling": f"{args.angular_sampling}"
|
1020
996
|
f" [{matching_data.rotations.shape[0]} rotations]",
|
1021
|
-
"Center Template":
|
997
|
+
"Center Template": args.centering,
|
1022
998
|
"Scramble Template": args.scramble_phases,
|
1023
999
|
"Invert Contrast": args.invert_target_contrast,
|
1024
1000
|
"Extend Target Edges": args.pad_edges,
|
@@ -1061,12 +1037,7 @@ def main():
|
|
1061
1037
|
}
|
1062
1038
|
if args.ctf_file is not None or args.defocus is not None:
|
1063
1039
|
filter_args["CTF File"] = args.ctf_file
|
1064
|
-
filter_args["Defocus"] = args.defocus
|
1065
|
-
filter_args["Phase Shift"] = args.phase_shift
|
1066
1040
|
filter_args["Flip Phase"] = args.no_flip_phase
|
1067
|
-
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1068
|
-
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1069
|
-
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
1070
1041
|
filter_args["Correct Defocus"] = args.correct_defocus_gradient
|
1071
1042
|
|
1072
1043
|
filter_args = {k: v for k, v in filter_args.items() if v is not None}
|
@@ -1079,13 +1050,25 @@ def main():
|
|
1079
1050
|
|
1080
1051
|
analyzer_args = {
|
1081
1052
|
"score_threshold": args.score_threshold,
|
1082
|
-
"
|
1053
|
+
"num_peaks": args.num_peaks,
|
1083
1054
|
"min_distance": max(template.shape) // 3,
|
1084
1055
|
"use_memmap": args.use_memmap,
|
1085
1056
|
}
|
1057
|
+
if args.orientations is not None:
|
1058
|
+
analyzer_args["reference"] = (0, 0, 1)
|
1059
|
+
analyzer_args["cone_angle"] = args.orientations_cone
|
1060
|
+
analyzer_args["acceptance_radius"] = args.orientations_uncertainty
|
1061
|
+
analyzer_args["positions"] = args.orientations.translations
|
1062
|
+
analyzer_args["rotations"] = euler_to_rotationmatrix(
|
1063
|
+
args.orientations.rotations
|
1064
|
+
)
|
1065
|
+
|
1086
1066
|
print_block(
|
1087
1067
|
name="Analyzer",
|
1088
|
-
data={
|
1068
|
+
data={
|
1069
|
+
"Analyzer": callback_class,
|
1070
|
+
**{sanitize_name(k): v for k, v in analyzer_args.items()},
|
1071
|
+
},
|
1089
1072
|
label_width=max(len(key) for key in options.keys()) + 3,
|
1090
1073
|
)
|
1091
1074
|
print("\n" + "-" * 80)
|
@@ -1111,18 +1094,6 @@ def main():
|
|
1111
1094
|
)
|
1112
1095
|
|
1113
1096
|
candidates = list(candidates) if candidates is not None else []
|
1114
|
-
if issubclass(callback_class, MaxScoreOverRotations):
|
1115
|
-
if target_mask is not None and args.score != "MCC":
|
1116
|
-
candidates[0] *= target_mask.data
|
1117
|
-
with warnings.catch_warnings():
|
1118
|
-
warnings.simplefilter("ignore", category=UserWarning)
|
1119
|
-
nbytes = be.datatype_bytes(be._float_dtype)
|
1120
|
-
dtype = np.float32 if nbytes == 4 else np.float16
|
1121
|
-
rot_dim = matching_data.rotations.shape[1]
|
1122
|
-
candidates[3] = {
|
1123
|
-
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1124
|
-
for i, x in candidates[3].items()
|
1125
|
-
}
|
1126
1097
|
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1127
1098
|
write_pickle(data=candidates, filename=args.output)
|
1128
1099
|
|