pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1339 @@
|
|
1
|
+
#!python3
|
2
|
+
"""CLI for basic pyTME template matching functions.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
import os
|
9
|
+
import argparse
|
10
|
+
import warnings
|
11
|
+
from sys import exit
|
12
|
+
from time import time
|
13
|
+
from typing import Tuple
|
14
|
+
from copy import deepcopy
|
15
|
+
from os.path import exists
|
16
|
+
from tempfile import gettempdir
|
17
|
+
|
18
|
+
import numpy as np
|
19
|
+
|
20
|
+
from tme.backends import backend as be
|
21
|
+
from tme import Density, __version__, Orientations
|
22
|
+
from tme.matching_utils import scramble_phases, write_pickle, generate_tempfile_name
|
23
|
+
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
24
|
+
from tme.rotations import (
|
25
|
+
get_cone_rotations,
|
26
|
+
get_rotation_matrices,
|
27
|
+
euler_to_rotationmatrix,
|
28
|
+
)
|
29
|
+
from tme.matching_data import MatchingData
|
30
|
+
from tme.analyzer import (
|
31
|
+
MaxScoreOverRotations,
|
32
|
+
PeakCallerMaximumFilter,
|
33
|
+
MaxScoreOverRotationsConstrained,
|
34
|
+
)
|
35
|
+
from tme.filters import (
|
36
|
+
CTF,
|
37
|
+
Wedge,
|
38
|
+
Compose,
|
39
|
+
BandPass,
|
40
|
+
ShiftFourier,
|
41
|
+
CTFReconstructed,
|
42
|
+
WedgeReconstructed,
|
43
|
+
ReconstructFromTilt,
|
44
|
+
LinearWhiteningFilter,
|
45
|
+
BandPassReconstructed
|
46
|
+
)
|
47
|
+
from tme.cli import get_func_fullname, print_block, print_entry, check_positive
|
48
|
+
|
49
|
+
|
50
|
+
def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
    """
    Load a mask in CCP4/MRC format and assess whether its sampling rate
    and shape match its target.

    Parameters
    ----------
    mask_target : Density
        Object the mask should be applied to.
    mask_path : str
        Path to the mask in CCP4/MRC format. If None, no mask is loaded.
    kwargs : dict, optional
        Keyword arguments passed to :py:meth:`tme.density.Density.from_file`.

    Raises
    ------
    ValueError
        If shape or sampling rate do not match between mask_target and mask.

    Returns
    -------
    Density or None
        The loaded mask, with its origin copied from ``mask_target``, or None
        if ``mask_path`` is None.
    """
    if mask_path is None:
        return None

    mask = Density.from_file(mask_path, **kwargs)
    # Align the mask with its target; only shape and sampling rate are checked.
    mask.origin = deepcopy(mask_target.origin)

    # Compare as tuples so masks of a different dimensionality produce this
    # error instead of a numpy broadcasting error.
    if tuple(mask.shape) != tuple(mask_target.shape):
        raise ValueError(
            f"Expected shape of {mask_path} was {mask_target.shape},"
            f" got {mask.shape}"
        )
    if not np.allclose(
        np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
    ):
        raise ValueError(
            f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
            f", got {mask.sampling_rate}"
        )
    return mask
|
91
|
+
|
92
|
+
|
93
|
+
def parse_rotation_logic(args, ndim):
    """
    Build the rotation matrices requested on the command line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments. Reads particle_diameter, target, lowpass,
        angular_sampling, no_use_optimized_set and the cone_*/axis_* options.
        If particle_diameter is set, ``args.angular_sampling`` is overwritten;
        if axis_sampling is None it is set to cone_sampling.
    ndim : int
        Dimensionality of the rotation matrices.

    Returns
    -------
    np.ndarray
        Rotation matrices. For angular sampling >= 180 degrees this is the
        identity with shape (1, ndim, ndim).
    """
    if args.particle_diameter is not None:
        # Angular step from the achievable resolution: the larger of twice the
        # largest sampling rate (Nyquist) and the requested lowpass cutoff.
        target = Density.from_file(args.target, use_memmap=True)
        resolution = np.maximum(
            np.max(2 * target.sampling_rate),
            args.lowpass if args.lowpass is not None else 0,
        )
        args.angular_sampling = 360 * resolution / (np.pi * args.particle_diameter)

    if args.angular_sampling is not None:
        # Sampling steps of 180 degrees or more only sample the identity, so
        # avoid generating a rotation set that would be discarded.
        if args.angular_sampling >= 180:
            return np.eye(ndim).reshape(1, ndim, ndim)
        return get_rotation_matrices(
            angular_sampling=args.angular_sampling,
            dim=ndim,
            use_optimized_set=not args.no_use_optimized_set,
        )

    if args.axis_sampling is None:
        args.axis_sampling = args.cone_sampling

    return get_cone_rotations(
        cone_angle=args.cone_angle,
        cone_sampling=args.cone_sampling,
        axis_angle=args.axis_angle,
        axis_sampling=args.axis_sampling,
        n_symmetry=args.axis_symmetry,
        axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
        reference=[0, 0, -1],
    )
|
125
|
+
|
126
|
+
|
127
|
+
def compute_schedule(
    args,
    matching_data: MatchingData,
    callback_class,
    pad_edges: bool = False,
):
    """
    Compute the splitting and parallelization schedule for template matching.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; reads pad_edges, score, use_gpu, memory and cores.
        ``args.pad_edges`` is set to True if padding gets enabled automatically.
    matching_data : MatchingData
        Matching problem whose ``computation_schedule`` is queried.
    callback_class : type
        Analyzer class; only its ``__name__`` is used.
    pad_edges : bool, optional
        Whether to pad the target edges. Forced to True if ``args.pad_edges``
        is True.

    Returns
    -------
    tuple
        ``(splits, schedule)`` as returned by ``computation_schedule``.

    Notes
    -----
    Calls ``sys.exit(-1)`` if no schedule could be found. If the target would
    be split while ``pad_edges`` is False and ``matching_data._target_dim`` is
    empty, padding is enabled and the schedule is computed once more.
    """
    while True:
        # An explicit user request for target padding always wins.
        if args.pad_edges is True:
            pad_edges = True

        splits, schedule = matching_data.computation_schedule(
            matching_method=args.score,
            analyzer_method=callback_class.__name__,
            use_gpu=args.use_gpu,
            pad_fourier=False,
            pad_target_edges=pad_edges,
            available_memory=args.memory,
            max_cores=args.cores,
        )

        if splits is None:
            print(
                "Found no suitable parallelization schedule. Consider increasing"
                " available RAM or decreasing number of cores."
            )
            exit(-1)

        split_count = np.prod(list(splits.values()))
        must_pad = (
            pad_edges is False
            and len(matching_data._target_dim) == 0
            and split_count > 1
        )
        if not must_pad:
            return splits, schedule

        # Splitting an unpadded target is not accepted here: enable padding
        # (also on args, for later consumers) and recompute.
        args.pad_edges = True
        pad_edges = True
|
159
|
+
|
160
|
+
|
161
|
+
def extract_tilts(args, target):
    """
    Extract tilts from ``target`` via ``Projector.extract_tilts``.

    Tilt angles come from ``args.tilt_angles``: either a file readable by
    ``Wedge.from_file`` or, failing that, a "start,stop" string whose absolute
    values bound a continuous wedge. The extracted tilts are written to a
    temporary HDF5 file and reloaded memory-mapped.

    Parameters
    ----------
    args : argparse.Namespace
        Reads tilt_angles, tilt_weighting and reconstruction_filter.
    target : Density
        Density to extract tilts from.

    Returns
    -------
    Density
        Memory-mapped tilt stack whose sampling rate along the first axis is 1
        and otherwise matches ``target.sampling_rate[1:]``.
    """
    from tme.projection import Projector

    try:
        wedge = Wedge.from_file(args.tilt_angles)
        wedge.weight_type = args.tilt_weighting
        if args.tilt_weighting in ("angle", None):
            wedge = WedgeReconstructed(
                angles=wedge.angles,
                weight_wedge=args.tilt_weighting == "angle",
            )
    except (FileNotFoundError, AttributeError):
        # Not a readable angles file: interpret the value as "start,stop".
        start_text, stop_text = args.tilt_angles.split(",")
        wedge = WedgeReconstructed(
            angles=(abs(float(start_text)), abs(float(stop_text))),
            create_continuous_wedge=True,
            weight_wedge=False,
            reconstruction_filter=args.reconstruction_filter,
        )

    tilt_stack = Projector(target.data).extract_tilts(tilt_angles=wedge.angles)
    stacked = Density(tilt_stack, sampling_rate=(1, *target.sampling_rate[1:]))

    # NOTE(review): the temporary file is never removed here; it has to
    # outlive the memory map returned below.
    temp_path = generate_tempfile_name("h5")
    stacked.to_file(temp_path)
    return Density.from_file(temp_path, use_memmap=True)
|
188
|
+
|
189
|
+
|
190
|
+
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
    """
    Assemble the Fourier filters applied to template and target.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. Reads the tilt, CTF, band-pass,
        whitening and reconstruction options as well as ``no_filter_target``.
    template : Density
        Template whose sampling rate parameterizes the CTF and band-pass.
    target : Density
        Target. Not used by this function.

    Returns
    -------
    Tuple[Compose, Compose]
        Template and target filter chains. Each is None if no filter applies;
        the target chain is also None if ``args.no_filter_target`` is set.

    Raises
    ------
    ValueError
        If ``args.tilt_angles`` cannot be read as an angles file, if the CTF
        file has no tilt angles and none were given via ``--tilt_angles``, or
        if the number of CTFs differs from the number of tilt angles.
    """
    template_filter, target_filter = [], []

    wedge = None
    if args.tilt_angles is not None:
        try:
            wedge = Wedge.from_file(args.tilt_angles)
            wedge.weight_type = args.tilt_weighting
        except (FileNotFoundError, AttributeError) as exc:
            raise ValueError(
                "Projection matching angles need to be specified via angles file."
            ) from exc

        # The target gets an unweighted continuous wedge over the same angles,
        # the template gets the wedge read from file.
        wedge_target = WedgeReconstructed(
            angles=wedge.angles,
            weight_wedge=False,
            create_continuous_wedge=True,
            opening_axis=args.wedge_axes[0],
            tilt_axis=args.wedge_axes[1],
        )
        wedge.opening_axis = args.wedge_axes[0]
        wedge.tilt_axis = args.wedge_axes[1]

        target_filter.append(wedge_target)
        template_filter.append(wedge)

    if args.ctf_file is not None or args.defocus is not None:
        ctf = None
        if args.ctf_file is not None:
            try:
                ctf = CTF.from_file(args.ctf_file)
            except (FileNotFoundError, AttributeError):
                # Unreadable CTF file: fall back to the --defocus parameters.
                ctf = None

        if ctf is None:
            ctf = CTFReconstructed(defocus_x=args.defocus, phase_shift=args.phase_shift)
        else:
            if len(ctf.angles) == 0:
                if wedge is None:
                    raise ValueError(
                        "You requested to specify the CTF per tilt, but did not specify "
                        "tilt angles via --tilt_angles or --ctf_file (Warp/M XML format). "
                    )
                ctf.angles = wedge.angles

            # Only cross-check against --tilt_angles when they were given;
            # wedge is None otherwise and has no angles to compare with.
            n_tilts_ctfs = len(ctf.defocus_x)
            if isinstance(wedge, Wedge) and n_tilts_ctfs != len(wedge.angles):
                n_tils_angles = len(wedge.angles)
                raise ValueError(
                    f"CTF file contains {n_tilts_ctfs} tilt, but match_template "
                    f"recieved {n_tils_angles} tilt angles. Expected one angle "
                    "per tilt."
                )

        ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
        ctf.sampling_rate = template.sampling_rate
        ctf.flip_phase = args.no_flip_phase
        ctf.amplitude_contrast = args.amplitude_contrast
        ctf.spherical_aberration = args.spherical_aberration
        # Presumably converts kV from the command line to V; confirm in CTF.
        ctf.acceleration_voltage = args.acceleration_voltage * 1e3
        ctf.correct_defocus_gradient = args.correct_defocus_gradient
        template_filter.append(ctf)

    if args.lowpass is not None or args.highpass is not None:
        lowpass, highpass = args.lowpass, args.highpass
        # Convert both cutoffs to resolutions in units of the sampling rate.
        if args.pass_format == "voxel":
            if lowpass is not None:
                lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
            if highpass is not None:
                highpass = np.max(np.multiply(highpass, template.sampling_rate))
        elif args.pass_format == "frequency":
            if lowpass is not None:
                lowpass = np.max(np.divide(template.sampling_rate, lowpass))
            if highpass is not None:
                highpass = np.max(np.divide(template.sampling_rate, highpass))

        # Compare the converted values so the check holds for every
        # --pass_format (raw frequencies are ordered the other way round).
        if lowpass is not None and highpass is not None and lowpass >= highpass:
            warnings.warn("--lowpass should be smaller than --highpass.")

        bandpass = BandPassReconstructed(
            use_gaussian=args.no_pass_smooth,
            lowpass=lowpass,
            highpass=highpass,
            sampling_rate=template.sampling_rate,
        )
        template_filter.append(bandpass)
        target_filter.append(bandpass)

    if args.whiten_spectrum:
        whitening_filter = LinearWhiteningFilter()
        template_filter.append(whitening_filter)
        target_filter.append(whitening_filter)

    # Filters of exactly these types (subclasses excluded) are followed by a
    # ReconstructFromTilt step.
    rec_filt = (Wedge, CTF)
    needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
    if needs_reconstruction > 0 and args.reconstruction_filter is None:
        warnings.warn(
            "Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
            "to avoid artifacts from reconstruction using weighted backprojection."
        )

    # Stable sort: filters needing reconstruction move to the front while the
    # relative order within both groups is preserved.
    template_filter = sorted(
        template_filter, key=lambda x: type(x) in rec_filt, reverse=True
    )
    if needs_reconstruction > 0:
        reconstruction_filter = ReconstructFromTilt(
            reconstruction_filter=args.reconstruction_filter,
            interpolation_order=args.reconstruction_interpolation_order,
            angles=template_filter[0].angles,
            opening_axis=args.wedge_axes[0],
            tilt_axis=args.wedge_axes[1],
        )
        # Insert directly after the filters that need reconstruction.
        template_filter.insert(needs_reconstruction, reconstruction_filter)

    template_filter = Compose(template_filter) if len(template_filter) else None
    target_filter = Compose(target_filter) if len(target_filter) else None
    if args.no_filter_target:
        target_filter = None

    return template_filter, target_filter
|
312
|
+
|
313
|
+
|
314
|
+
def setup_projection_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
    """
    Assemble the filters used when matching against tilt projections.

    Wedge, CTF and band-pass are parameterized by the individual tilt angles
    from ``--tilt_angles``. No reconstruction step is added, and the template
    chain (and the target chain, if non-empty) ends with ``ShiftFourier``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. Reads the tilt, CTF, band-pass and
        whitening options as well as ``no_filter_target``.
    template : Density
        Template whose sampling rate parameterizes the CTF and band-pass.
    target : Density
        Target. Not used by this function.

    Returns
    -------
    Tuple[Compose, Compose]
        Template and target filter chains. The target chain is None if it
        holds no filters or if ``args.no_filter_target`` is set.

    Raises
    ------
    ValueError
        If ``args.tilt_angles`` cannot be read as an angles file, if CTF or
        band-pass filtering is requested without tilt angles, or if the number
        of CTFs differs from the number of tilt angles.
    """
    template_filter, target_filter = [], []

    wedge = None
    if args.tilt_angles is not None:
        try:
            wedge = Wedge.from_file(args.tilt_angles)
            wedge.weight_type = args.tilt_weighting
        except (FileNotFoundError, AttributeError) as exc:
            raise ValueError(
                "Projection matching angles need to be specified via angles file."
            ) from exc
        wedge.opening_axis = args.wedge_axes[0]
        wedge.tilt_axis = args.wedge_axes[1]
        template_filter.append(wedge)

    if args.ctf_file is not None or args.defocus is not None:
        ctf = None
        if args.ctf_file is not None:
            try:
                ctf = CTF.from_file(args.ctf_file)
            except (FileNotFoundError, AttributeError):
                # Unreadable CTF file: fall back to the --defocus parameters.
                ctf = None

        if ctf is None:
            if wedge is None:
                raise ValueError(
                    "Projection matching with --defocus requires tilt angles "
                    "via --tilt_angles."
                )
            ctf = CTF(
                defocus_x=args.defocus,
                phase_shift=args.phase_shift,
                angles=wedge.angles,
            )
        else:
            if len(ctf.angles) == 0:
                if wedge is None:
                    raise ValueError(
                        "You requested to specify the CTF per tilt, but did not specify "
                        "tilt angles via --tilt_angles or --ctf_file (Warp/M XML format). "
                    )
                ctf.angles = wedge.angles

            # Only cross-check against --tilt_angles when they were given;
            # wedge is None otherwise and has no angles to compare with.
            n_tilts_ctfs = len(ctf.defocus_x)
            if isinstance(wedge, Wedge) and n_tilts_ctfs != len(wedge.angles):
                n_tils_angles = len(wedge.angles)
                raise ValueError(
                    f"CTF file contains {n_tilts_ctfs} tilt, but match_template "
                    f"recieved {n_tils_angles} tilt angles. Expected one angle "
                    "per tilt."
                )

        ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
        ctf.sampling_rate = template.sampling_rate
        ctf.flip_phase = args.no_flip_phase
        ctf.amplitude_contrast = args.amplitude_contrast
        ctf.spherical_aberration = args.spherical_aberration
        # Presumably converts kV from the command line to V; confirm in CTF.
        ctf.acceleration_voltage = args.acceleration_voltage * 1e3
        ctf.correct_defocus_gradient = args.correct_defocus_gradient
        template_filter.append(ctf)

    if args.lowpass is not None or args.highpass is not None:
        if wedge is None:
            raise ValueError(
                "Projection matching with --lowpass/--highpass requires tilt "
                "angles via --tilt_angles."
            )
        lowpass, highpass = args.lowpass, args.highpass
        # Convert both cutoffs to resolutions in units of the sampling rate.
        if args.pass_format == "voxel":
            if lowpass is not None:
                lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
            if highpass is not None:
                highpass = np.max(np.multiply(highpass, template.sampling_rate))
        elif args.pass_format == "frequency":
            if lowpass is not None:
                lowpass = np.max(np.divide(template.sampling_rate, lowpass))
            if highpass is not None:
                highpass = np.max(np.divide(template.sampling_rate, highpass))

        # Compare the converted values so the check holds for every
        # --pass_format (raw frequencies are ordered the other way round).
        if lowpass is not None and highpass is not None and lowpass >= highpass:
            warnings.warn("--lowpass should be smaller than --highpass.")

        bandpass = BandPass(
            angles=wedge.angles,
            use_gaussian=args.no_pass_smooth,
            lowpass=lowpass,
            highpass=highpass,
            sampling_rate=template.sampling_rate,
        )
        template_filter.append(bandpass)
        target_filter.append(bandpass)

    if args.whiten_spectrum:
        whitening_filter = LinearWhiteningFilter()
        template_filter.append(whitening_filter)
        target_filter.append(whitening_filter)

    template_filter.append(ShiftFourier())
    if len(target_filter):
        target_filter.append(ShiftFourier())

    template_filter = Compose(template_filter) if len(template_filter) else None
    target_filter = Compose(target_filter) if len(target_filter) else None
    if args.no_filter_target:
        target_filter = None

    return template_filter, target_filter
|
409
|
+
|
410
|
+
|
411
|
+
def _format_sampling(arr, decimals: int = 2):
|
412
|
+
return tuple(round(float(x), decimals) for x in arr)
|
413
|
+
|
414
|
+
def parse_args():
|
415
|
+
parser = argparse.ArgumentParser(
|
416
|
+
description="Perform template matching.",
|
417
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
418
|
+
)
|
419
|
+
|
420
|
+
io_group = parser.add_argument_group("Input / Output")
|
421
|
+
io_group.add_argument(
|
422
|
+
"-m",
|
423
|
+
"--target",
|
424
|
+
dest="target",
|
425
|
+
type=str,
|
426
|
+
required=True,
|
427
|
+
help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
|
428
|
+
"tme.density.Density.from_file "
|
429
|
+
"https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
|
430
|
+
)
|
431
|
+
io_group.add_argument(
|
432
|
+
"--target_mask",
|
433
|
+
"--target-mask",
|
434
|
+
dest="target_mask",
|
435
|
+
type=str,
|
436
|
+
required=False,
|
437
|
+
help="Path to a mask for the target in a supported format (see target).",
|
438
|
+
)
|
439
|
+
io_group.add_argument(
|
440
|
+
"-i",
|
441
|
+
"--template",
|
442
|
+
dest="template",
|
443
|
+
type=str,
|
444
|
+
required=True,
|
445
|
+
help="Path to a template in PDB/MMCIF or other supported formats (see target).",
|
446
|
+
)
|
447
|
+
io_group.add_argument(
|
448
|
+
"--template_mask",
|
449
|
+
"--template-mask",
|
450
|
+
dest="template_mask",
|
451
|
+
type=str,
|
452
|
+
required=False,
|
453
|
+
help="Path to a mask for the template in a supported format (see target).",
|
454
|
+
)
|
455
|
+
io_group.add_argument(
|
456
|
+
"-o",
|
457
|
+
"--output",
|
458
|
+
dest="output",
|
459
|
+
type=str,
|
460
|
+
required=False,
|
461
|
+
default="output.pickle",
|
462
|
+
help="Path to the output pickle file.",
|
463
|
+
)
|
464
|
+
io_group.add_argument(
|
465
|
+
"--invert_target_contrast",
|
466
|
+
"--invert-target-contrast",
|
467
|
+
dest="invert_target_contrast",
|
468
|
+
action="store_true",
|
469
|
+
default=False,
|
470
|
+
help="Invert the target's contrast for cases where templates to-be-matched have "
|
471
|
+
"negative values, e.g. tomograms.",
|
472
|
+
)
|
473
|
+
io_group.add_argument(
|
474
|
+
"--scramble_phases",
|
475
|
+
"--scramble-phases",
|
476
|
+
dest="scramble_phases",
|
477
|
+
action="store_true",
|
478
|
+
default=False,
|
479
|
+
help="Phase scramble the template to generate a noise score background.",
|
480
|
+
)
|
481
|
+
|
482
|
+
sampling_group = parser.add_argument_group("Sampling")
|
483
|
+
sampling_group.add_argument(
|
484
|
+
"--orientations",
|
485
|
+
dest="orientations",
|
486
|
+
default=None,
|
487
|
+
required=False,
|
488
|
+
help="Path to a file readable via Orientations.from_file containing "
|
489
|
+
"translations and rotations of candidate peaks to refine.",
|
490
|
+
)
|
491
|
+
sampling_group.add_argument(
|
492
|
+
"--orientations_scaling",
|
493
|
+
"--orientations-scaling",
|
494
|
+
required=False,
|
495
|
+
type=float,
|
496
|
+
default=1.0,
|
497
|
+
help="Scaling factor to map candidate translations onto the target. "
|
498
|
+
"Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
|
499
|
+
"the corresponding orientations_scaling would be 3.",
|
500
|
+
)
|
501
|
+
sampling_group.add_argument(
|
502
|
+
"--orientations_cone",
|
503
|
+
"--orientations-cone",
|
504
|
+
required=False,
|
505
|
+
type=float,
|
506
|
+
default=20.0,
|
507
|
+
help="Accept orientations within specified cone angle of each orientation.",
|
508
|
+
)
|
509
|
+
sampling_group.add_argument(
|
510
|
+
"--orientations_uncertainty",
|
511
|
+
"--orientations-uncertainty",
|
512
|
+
required=False,
|
513
|
+
type=str,
|
514
|
+
default="10",
|
515
|
+
help="Accept translations within the specified radius of each orientation. "
|
516
|
+
"Can be a single value or comma-separated string for per-axis uncertainty.",
|
517
|
+
)
|
518
|
+
|
519
|
+
scoring_group = parser.add_argument_group("Scoring")
|
520
|
+
scoring_group.add_argument(
|
521
|
+
"-s",
|
522
|
+
dest="score",
|
523
|
+
type=str,
|
524
|
+
default="FLCSphericalMask",
|
525
|
+
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
526
|
+
help="Template matching scoring function.",
|
527
|
+
)
|
528
|
+
|
529
|
+
angular_group = parser.add_argument_group("Angular Sampling")
|
530
|
+
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
531
|
+
|
532
|
+
angular_exclusive.add_argument(
|
533
|
+
"-a",
|
534
|
+
"--angular_sampling",
|
535
|
+
"--angular-sampling",
|
536
|
+
dest="angular_sampling",
|
537
|
+
type=check_positive,
|
538
|
+
default=None,
|
539
|
+
help="Angular sampling rate using optimized rotational sets."
|
540
|
+
"A lower number yields more rotations. Values >= 180 sample only the identity.",
|
541
|
+
)
|
542
|
+
angular_exclusive.add_argument(
|
543
|
+
"--cone_angle",
|
544
|
+
"--cone-angle",
|
545
|
+
dest="cone_angle",
|
546
|
+
type=check_positive,
|
547
|
+
default=None,
|
548
|
+
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
549
|
+
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
550
|
+
)
|
551
|
+
angular_exclusive.add_argument(
|
552
|
+
"--particle_diameter",
|
553
|
+
"--particle-diameter",
|
554
|
+
dest="particle_diameter",
|
555
|
+
type=check_positive,
|
556
|
+
default=None,
|
557
|
+
help="Particle diameter in units of sampling rate.",
|
558
|
+
)
|
559
|
+
angular_group.add_argument(
|
560
|
+
"--cone_axis",
|
561
|
+
"--cone-axis",
|
562
|
+
dest="cone_axis",
|
563
|
+
type=check_positive,
|
564
|
+
default=2,
|
565
|
+
help="Principal axis to build cone around.",
|
566
|
+
)
|
567
|
+
angular_group.add_argument(
|
568
|
+
"--invert_cone",
|
569
|
+
"--invert-cone",
|
570
|
+
dest="invert_cone",
|
571
|
+
action="store_true",
|
572
|
+
help="Invert cone handedness direction from up to down.",
|
573
|
+
)
|
574
|
+
angular_group.add_argument(
|
575
|
+
"--cone_sampling",
|
576
|
+
"--cone-sampling",
|
577
|
+
dest="cone_sampling",
|
578
|
+
type=check_positive,
|
579
|
+
default=None,
|
580
|
+
help="Sampling rate of the cone in degrees.",
|
581
|
+
)
|
582
|
+
angular_group.add_argument(
|
583
|
+
"--axis_angle",
|
584
|
+
"--axis-angle",
|
585
|
+
dest="axis_angle",
|
586
|
+
type=check_positive,
|
587
|
+
default=360.0,
|
588
|
+
required=False,
|
589
|
+
help="Sampling angle along the principal axis of the cone.",
|
590
|
+
)
|
591
|
+
angular_group.add_argument(
|
592
|
+
"--axis_sampling",
|
593
|
+
"--axis-sampling",
|
594
|
+
dest="axis_sampling",
|
595
|
+
type=check_positive,
|
596
|
+
default=None,
|
597
|
+
required=False,
|
598
|
+
help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
|
599
|
+
)
|
600
|
+
angular_group.add_argument(
|
601
|
+
"--axis_symmetry",
|
602
|
+
"--axis-symmetry",
|
603
|
+
dest="axis_symmetry",
|
604
|
+
type=check_positive,
|
605
|
+
default=1,
|
606
|
+
required=False,
|
607
|
+
help="N-fold symmetry around z-axis of the cone.",
|
608
|
+
)
|
609
|
+
angular_group.add_argument(
|
610
|
+
"--no_use_optimized_set",
|
611
|
+
"--no-use-optimized-set",
|
612
|
+
dest="no_use_optimized_set",
|
613
|
+
action="store_true",
|
614
|
+
default=False,
|
615
|
+
required=False,
|
616
|
+
help="Whether to use random uniform instead of optimized rotation sets.",
|
617
|
+
)
|
618
|
+
|
619
|
+
computation_group = parser.add_argument_group("Computation")
|
620
|
+
computation_group.add_argument(
|
621
|
+
"-n",
|
622
|
+
dest="cores",
|
623
|
+
required=False,
|
624
|
+
type=int,
|
625
|
+
default=4,
|
626
|
+
help="Number of cores used for template matching.",
|
627
|
+
)
|
628
|
+
computation_group.add_argument(
|
629
|
+
"--use_gpu",
|
630
|
+
"--use-gpu",
|
631
|
+
dest="use_gpu",
|
632
|
+
action="store_true",
|
633
|
+
default=False,
|
634
|
+
help="Whether to perform computations on the GPU.",
|
635
|
+
)
|
636
|
+
computation_group.add_argument(
|
637
|
+
"--gpu_indices",
|
638
|
+
"--gpu-indices",
|
639
|
+
dest="gpu_indices",
|
640
|
+
type=str,
|
641
|
+
default=None,
|
642
|
+
help="Comma-separated list of GPU indices to use. For example,"
|
643
|
+
" 0,1 for the first and second GPU. Only used if --use_gpu is set."
|
644
|
+
" If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
|
645
|
+
" be respected.",
|
646
|
+
)
|
647
|
+
computation_group.add_argument(
|
648
|
+
"--memory",
|
649
|
+
dest="memory",
|
650
|
+
required=False,
|
651
|
+
type=int,
|
652
|
+
default=None,
|
653
|
+
help="Amount of memory that can be used in bytes.",
|
654
|
+
)
|
655
|
+
computation_group.add_argument(
|
656
|
+
"--memory_scaling",
|
657
|
+
"--memory-scaling",
|
658
|
+
dest="memory_scaling",
|
659
|
+
required=False,
|
660
|
+
type=float,
|
661
|
+
default=0.85,
|
662
|
+
help="Fraction of available memory to be used. Ignored if --memory is set.",
|
663
|
+
)
|
664
|
+
computation_group.add_argument(
|
665
|
+
"--temp_directory",
|
666
|
+
"--temp-directory",
|
667
|
+
dest="temp_directory",
|
668
|
+
default=None,
|
669
|
+
help="Directory for temporary objects. Faster I/O improves runtime.",
|
670
|
+
)
|
671
|
+
computation_group.add_argument(
|
672
|
+
"--backend",
|
673
|
+
dest="backend",
|
674
|
+
default=be._backend_name,
|
675
|
+
choices=be.available_backends(),
|
676
|
+
help="[Expert] Overwrite default computation backend.",
|
677
|
+
)
|
678
|
+
filter_group = parser.add_argument_group("Filters")
|
679
|
+
filter_group.add_argument(
|
680
|
+
"--lowpass",
|
681
|
+
dest="lowpass",
|
682
|
+
type=float,
|
683
|
+
required=False,
|
684
|
+
help="Resolution to lowpass filter template and target to in the same unit "
|
685
|
+
"as the sampling rate of template and target (typically Ångstrom).",
|
686
|
+
)
|
687
|
+
filter_group.add_argument(
|
688
|
+
"--highpass",
|
689
|
+
dest="highpass",
|
690
|
+
type=float,
|
691
|
+
required=False,
|
692
|
+
help="Resolution to highpass filter template and target to in the same unit "
|
693
|
+
"as the sampling rate of template and target (typically Ångstrom).",
|
694
|
+
)
|
695
|
+
filter_group.add_argument(
|
696
|
+
"--no_pass_smooth",
|
697
|
+
"--no-pass-smooth",
|
698
|
+
dest="no_pass_smooth",
|
699
|
+
action="store_false",
|
700
|
+
default=True,
|
701
|
+
help="Whether a hard edge filter should be used for --lowpass and --highpass.",
|
702
|
+
)
|
703
|
+
filter_group.add_argument(
|
704
|
+
"--pass_format",
|
705
|
+
"--pass-format",
|
706
|
+
dest="pass_format",
|
707
|
+
type=str,
|
708
|
+
required=False,
|
709
|
+
default="sampling_rate",
|
710
|
+
choices=["sampling_rate", "voxel", "frequency"],
|
711
|
+
help="How values passed to --lowpass and --highpass should be interpreted. "
|
712
|
+
"Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
|
713
|
+
)
|
714
|
+
filter_group.add_argument(
|
715
|
+
"--whiten_spectrum",
|
716
|
+
"--whiten-spectrum",
|
717
|
+
dest="whiten_spectrum",
|
718
|
+
action="store_true",
|
719
|
+
default=None,
|
720
|
+
help="Apply spectral whitening to template and target based on target spectrum.",
|
721
|
+
)
|
722
|
+
filter_group.add_argument(
|
723
|
+
"--wedge_axes",
|
724
|
+
"--wedge-axes",
|
725
|
+
dest="wedge_axes",
|
726
|
+
type=str,
|
727
|
+
required=False,
|
728
|
+
default="2,0",
|
729
|
+
help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
|
730
|
+
"for the typical projection over z and tilting over the x-axis.",
|
731
|
+
)
|
732
|
+
filter_group.add_argument(
|
733
|
+
"--tilt_angles",
|
734
|
+
"--tilt-angles",
|
735
|
+
dest="tilt_angles",
|
736
|
+
type=str,
|
737
|
+
required=False,
|
738
|
+
default=None,
|
739
|
+
help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
|
740
|
+
"a tomostar STAR file, a tab-separated file with column name 'angles', or a "
|
741
|
+
"single column file without header. Exposure will be taken from the input file "
|
742
|
+
", if you are using a tab-separated file, the column names 'angles' and "
|
743
|
+
"'weights' need to be present. It is also possible to specify a continuous "
|
744
|
+
"wedge mask using e.g., -50,45.",
|
745
|
+
)
|
746
|
+
filter_group.add_argument(
|
747
|
+
"--tilt_weighting",
|
748
|
+
"--tilt-weighting",
|
749
|
+
dest="tilt_weighting",
|
750
|
+
type=str,
|
751
|
+
required=False,
|
752
|
+
choices=["angle", "relion", "grigorieff"],
|
753
|
+
default=None,
|
754
|
+
help="Weighting scheme used to reweight individual tilts. Available options: "
|
755
|
+
"angle (cosine based weighting), "
|
756
|
+
"relion (relion formalism for wedge weighting) requires,"
|
757
|
+
"grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
|
758
|
+
"relion and grigorieff require electron doses in --tilt_angles weights column.",
|
759
|
+
)
|
760
|
+
filter_group.add_argument(
|
761
|
+
"--reconstruction_filter",
|
762
|
+
"--reconstruction-filter",
|
763
|
+
dest="reconstruction_filter",
|
764
|
+
type=str,
|
765
|
+
required=False,
|
766
|
+
choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
|
767
|
+
default=None,
|
768
|
+
help="Filter applied when reconstructing (N+1)-D from N-D filters.",
|
769
|
+
)
|
770
|
+
filter_group.add_argument(
|
771
|
+
"--reconstruction_interpolation_order",
|
772
|
+
"--reconstruction-interpolation-order",
|
773
|
+
dest="reconstruction_interpolation_order",
|
774
|
+
type=int,
|
775
|
+
default=1,
|
776
|
+
required=False,
|
777
|
+
help="Analogous to --interpolation_order but for reconstruction.",
|
778
|
+
)
|
779
|
+
filter_group.add_argument(
|
780
|
+
"--no_filter_target",
|
781
|
+
"--no-filter-target",
|
782
|
+
dest="no_filter_target",
|
783
|
+
action="store_true",
|
784
|
+
default=False,
|
785
|
+
help="Whether to not apply potential filters to the target.",
|
786
|
+
)
|
787
|
+
|
788
|
+
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
789
|
+
ctf_group.add_argument(
|
790
|
+
"--ctf_file",
|
791
|
+
"--ctf-file",
|
792
|
+
dest="ctf_file",
|
793
|
+
type=str,
|
794
|
+
required=False,
|
795
|
+
default=None,
|
796
|
+
help="Path to a file with CTF parameters. This can be a Warp/M XML file "
|
797
|
+
"a GCTF/Relion STAR file, or the output of CTFFIND4. If the file does not "
|
798
|
+
"specify tilt angles, the angles specified with --tilt_angles are used.",
|
799
|
+
)
|
800
|
+
ctf_group.add_argument(
|
801
|
+
"--defocus",
|
802
|
+
dest="defocus",
|
803
|
+
type=float,
|
804
|
+
required=False,
|
805
|
+
default=None,
|
806
|
+
help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
|
807
|
+
"for a defocus of 3 micrometer. Superseded by --ctf_file.",
|
808
|
+
)
|
809
|
+
ctf_group.add_argument(
|
810
|
+
"--phase_shift",
|
811
|
+
"--phase-shift",
|
812
|
+
dest="phase_shift",
|
813
|
+
type=float,
|
814
|
+
required=False,
|
815
|
+
default=0,
|
816
|
+
help="Phase shift in degrees. Superseded by --ctf_file.",
|
817
|
+
)
|
818
|
+
ctf_group.add_argument(
|
819
|
+
"--acceleration_voltage",
|
820
|
+
"--acceleration-voltage",
|
821
|
+
dest="acceleration_voltage",
|
822
|
+
type=float,
|
823
|
+
required=False,
|
824
|
+
default=300,
|
825
|
+
help="Acceleration voltage in kV.",
|
826
|
+
)
|
827
|
+
ctf_group.add_argument(
|
828
|
+
"--spherical_aberration",
|
829
|
+
"--spherical-aberration",
|
830
|
+
dest="spherical_aberration",
|
831
|
+
type=float,
|
832
|
+
required=False,
|
833
|
+
default=2.7e7,
|
834
|
+
help="Spherical aberration in units of sampling rate (typically Ångstrom).",
|
835
|
+
)
|
836
|
+
ctf_group.add_argument(
|
837
|
+
"--amplitude_contrast",
|
838
|
+
"--amplitude-contrast",
|
839
|
+
dest="amplitude_contrast",
|
840
|
+
type=float,
|
841
|
+
required=False,
|
842
|
+
default=0.07,
|
843
|
+
help="Amplitude contrast.",
|
844
|
+
)
|
845
|
+
ctf_group.add_argument(
|
846
|
+
"--no_flip_phase",
|
847
|
+
"--no-flip-phase",
|
848
|
+
dest="no_flip_phase",
|
849
|
+
action="store_false",
|
850
|
+
required=False,
|
851
|
+
help="Do not perform phase-flipping CTF correction.",
|
852
|
+
)
|
853
|
+
ctf_group.add_argument(
|
854
|
+
"--correct_defocus_gradient",
|
855
|
+
"--correct-defocus-gradient",
|
856
|
+
dest="correct_defocus_gradient",
|
857
|
+
action="store_true",
|
858
|
+
required=False,
|
859
|
+
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
860
|
+
"defocus gradients.",
|
861
|
+
)
|
862
|
+
|
863
|
+
performance_group = parser.add_argument_group("Performance")
|
864
|
+
performance_group.add_argument(
|
865
|
+
"--no_centering",
|
866
|
+
"--no-centering",
|
867
|
+
dest="no_centering",
|
868
|
+
action="store_true",
|
869
|
+
help="Assumes the template is already centered and omits centering.",
|
870
|
+
)
|
871
|
+
performance_group.add_argument(
|
872
|
+
"--pad_edges",
|
873
|
+
"--pad-edges",
|
874
|
+
dest="pad_edges",
|
875
|
+
action="store_true",
|
876
|
+
default=False,
|
877
|
+
help="Whether to pad the edges of the target. Useful if the target does not "
|
878
|
+
"a well-defined bounding box. Defaults to True if splitting is required.",
|
879
|
+
)
|
880
|
+
performance_group.add_argument(
|
881
|
+
"--pad_filter",
|
882
|
+
"--pad-filter",
|
883
|
+
dest="pad_filter",
|
884
|
+
action="store_true",
|
885
|
+
default=False,
|
886
|
+
help="Pads the filter to the shape of the target. Particularly useful for fast "
|
887
|
+
"oscilating filters to avoid aliasing effects.",
|
888
|
+
)
|
889
|
+
performance_group.add_argument(
|
890
|
+
"--interpolation_order",
|
891
|
+
"--interpolation-order",
|
892
|
+
dest="interpolation_order",
|
893
|
+
required=False,
|
894
|
+
type=int,
|
895
|
+
default=None,
|
896
|
+
help="Spline interpolation used for rotations.",
|
897
|
+
)
|
898
|
+
performance_group.add_argument(
|
899
|
+
"--use_mixed_precision",
|
900
|
+
"--use-mixed-precision",
|
901
|
+
dest="use_mixed_precision",
|
902
|
+
action="store_true",
|
903
|
+
default=False,
|
904
|
+
help="Use float16 for real values operations where possible. Not supported "
|
905
|
+
"for jax backend.",
|
906
|
+
)
|
907
|
+
performance_group.add_argument(
|
908
|
+
"--use_memmap",
|
909
|
+
"--use-memmap",
|
910
|
+
dest="use_memmap",
|
911
|
+
action="store_true",
|
912
|
+
default=False,
|
913
|
+
help="Use memmaps to offload large data objects to disk. "
|
914
|
+
"Particularly useful for large inputs in combination with --use_gpu.",
|
915
|
+
)
|
916
|
+
|
917
|
+
analyzer_group = parser.add_argument_group("Analyzer")
|
918
|
+
analyzer_group.add_argument(
|
919
|
+
"--score_threshold",
|
920
|
+
"--score-threshold",
|
921
|
+
dest="score_threshold",
|
922
|
+
required=False,
|
923
|
+
type=float,
|
924
|
+
default=0,
|
925
|
+
help="Minimum template matching scores to consider for analysis.",
|
926
|
+
)
|
927
|
+
analyzer_group.add_argument(
|
928
|
+
"-p",
|
929
|
+
"--peak-calling",
|
930
|
+
dest="peak_calling",
|
931
|
+
action="store_true",
|
932
|
+
default=False,
|
933
|
+
help="Perform peak calling instead of score aggregation.",
|
934
|
+
)
|
935
|
+
analyzer_group.add_argument(
|
936
|
+
"--num_peaks",
|
937
|
+
"--num-peaks",
|
938
|
+
dest="num_peaks",
|
939
|
+
default=1000,
|
940
|
+
help="Number of peaks to call, 1000 by default.",
|
941
|
+
)
|
942
|
+
|
943
|
+
projection_group = parser.add_argument_group("Projection")
|
944
|
+
projection_group.add_argument(
|
945
|
+
"--projection_matching",
|
946
|
+
"--projection-matching",
|
947
|
+
dest="projection_matching",
|
948
|
+
action="store_true",
|
949
|
+
help="Perform projection matching instead of nD-nD matching.",
|
950
|
+
)
|
951
|
+
projection_group.add_argument(
|
952
|
+
"--extract_tilts",
|
953
|
+
"--extract-tilts",
|
954
|
+
dest="extract_tilts",
|
955
|
+
action="store_true",
|
956
|
+
help="Assume target is a reconstruction we have to extract tilts from. If the "
|
957
|
+
"target is a tilt series already, this flag can be omitted.",
|
958
|
+
)
|
959
|
+
|
960
|
+
|
961
|
+
args = parser.parse_args()
|
962
|
+
args.version = __version__
|
963
|
+
|
964
|
+
if args.interpolation_order is None:
|
965
|
+
args.interpolation_order = 3
|
966
|
+
if args.backend in ("jax", "pytorch"):
|
967
|
+
args.interpolation_order = 1
|
968
|
+
|
969
|
+
if args.interpolation_order < 0:
|
970
|
+
args.interpolation_order = None
|
971
|
+
|
972
|
+
if args.temp_directory is None:
|
973
|
+
args.temp_directory = gettempdir()
|
974
|
+
|
975
|
+
os.environ["TMPDIR"] = args.temp_directory
|
976
|
+
if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
|
977
|
+
raise ValueError(
|
978
|
+
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
979
|
+
)
|
980
|
+
|
981
|
+
if args.gpu_indices is not None:
|
982
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
|
983
|
+
|
984
|
+
if args.use_gpu:
|
985
|
+
warnings.warn(
|
986
|
+
"The use_gpu flag is no longer required and automatically "
|
987
|
+
"determined based on the selected backend."
|
988
|
+
)
|
989
|
+
|
990
|
+
if args.tilt_angles is not None:
|
991
|
+
if args.wedge_axes is None:
|
992
|
+
raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
|
993
|
+
if not exists(args.tilt_angles):
|
994
|
+
try:
|
995
|
+
float(args.tilt_angles.split(",")[0])
|
996
|
+
except ValueError:
|
997
|
+
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
998
|
+
|
999
|
+
if args.extract_tilts:
|
1000
|
+
args.projection_matching = True
|
1001
|
+
|
1002
|
+
if args.extract_tilts and args.tilt_angles is None:
|
1003
|
+
raise ValueError("Need to specify --tilt_angles when --extract_tilts is set.")
|
1004
|
+
|
1005
|
+
if args.projection_matching and args.backend != "jax":
|
1006
|
+
raise ValueError("Projection matching is only supported for --backend jax.")
|
1007
|
+
|
1008
|
+
if args.ctf_file is not None and args.tilt_angles is None:
|
1009
|
+
raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
|
1010
|
+
|
1011
|
+
if args.wedge_axes is not None:
|
1012
|
+
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
1013
|
+
|
1014
|
+
if args.orientations is not None:
|
1015
|
+
orientations = Orientations.from_file(args.orientations)
|
1016
|
+
orientations.translations = np.divide(
|
1017
|
+
orientations.translations, args.orientations_scaling
|
1018
|
+
)
|
1019
|
+
args.orientations = orientations
|
1020
|
+
|
1021
|
+
return args
|
1022
|
+
|
1023
|
+
|
1024
|
+
def main():
    """Run exhaustive template matching from the command line.

    Parses the command line, loads target and template, prepares the masks,
    selects and configures the compute backend, sets up filters and the job
    schedule, runs :func:`scan_subsets`, and pickles the resulting candidates
    (with ``(target.origin, template.origin, template.sampling_rate, args)``
    appended) to ``args.output``.

    Raises
    ------
    ValueError
        If the requested backend is not available, or if the jax backend is
        combined with an analyzer other than MaxScoreOverRotations.
    """
    args = parse_args()
    print_entry()

    # Backends that run on CUDA devices.
    gpu_backends = ("pytorch", "cupy", "jax")

    target = Density.from_file(args.target, use_memmap=True)
    if args.extract_tilts:
        target = extract_tilts(args=args, target=target)

    try:
        template = Density.from_file(args.template)
    except Exception:
        # Not readable as a density; build the template from a structure
        # file on the target's sampling rate instead.
        template = Density.from_structure(
            filename_or_structure=args.template,
            sampling_rate=target.sampling_rate,
        )

    if target.sampling_rate.size == template.sampling_rate.size:
        if not np.allclose(
            np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
        ) and not args.projection_matching:
            print(
                "Target and template sampling rate do not match. "
                "Make sure this is intended."
            )

    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=args.template_mask
    )
    target_mask = load_and_validate_mask(
        mask_target=target, mask_path=args.target_mask, use_memmap=True
    )

    initial_shape = target.shape
    print_block(
        name="Target",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(target.sampling_rate),
            "Final Shape": target.shape,
        },
    )

    if target_mask is not None:
        print_block(
            name="Target Mask",
            data={
                "Initial Shape": initial_shape,
                "Sampling Rate": _format_sampling(target_mask.sampling_rate),
                "Final Shape": target_mask.shape,
            },
        )

    initial_shape = template.shape
    translation = np.zeros(len(template.shape), dtype=np.float32)
    if not args.no_centering:
        template, translation = template.centered(0)
    print_block(
        name="Template",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(template.sampling_rate),
            "Final Shape": template.shape,
        },
    )

    if template_mask is None:
        # No mask supplied: use an all-ones mask built from the (already
        # centered) template, so no additional translation is needed.
        template_mask = template.empty
        if not args.no_centering:
            enclosing_box = template.minimum_enclosing_box(
                0, use_geometric_center=False
            )
            template_mask.adjust_box(enclosing_box)

        template_mask.data[:] = 1
        translation = np.zeros_like(translation)

    # Bring the mask onto the template grid and apply the same shift the
    # template received, including any origin offset between the two.
    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
        order=1,
    )
    template_mask.origin = template.origin.copy()
    print_block(
        name="Template Mask",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": _format_sampling(template_mask.sampling_rate),
            "Final Shape": template_mask.shape,
        },
    )
    print("\n" + "-" * 80)

    if args.scramble_phases:
        template.data = scramble_phases(
            template.data, noise_proportion=1.0, normalize_power=False
        )

    callback_class = MaxScoreOverRotations
    if args.peak_calling:
        callback_class = PeakCallerMaximumFilter

    if args.orientations is not None:
        callback_class = MaxScoreOverRotationsConstrained

    # Determine suitable backend for the selected operation
    available_backends = be.available_backends()
    if args.backend not in available_backends:
        raise ValueError("Requested backend is not available.")
    if args.backend == "jax" and callback_class is not MaxScoreOverRotations:
        raise ValueError(
            "Jax backend only supports the MaxScoreOverRotations analyzer."
        )

    if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
        warnings.warn(
            "Jax and pytorch do not support interpolation order 3, setting it to 1."
        )
        args.interpolation_order = 1

    if args.backend in gpu_backends:
        gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if gpu_devices is None:
            warnings.warn(
                "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
                "Assuming device 0.",
            )
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        # One worker per visible device.
        visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        args.cores = len(visible_devices)
        args.gpu_indices = [int(x) for x in visible_devices]

    # Finally set the desired backend
    device = "cuda"
    be.change_backend(args.backend)
    if args.backend in gpu_backends:
        args.use_gpu = True

    if args.backend == "pytorch":
        try:
            be.change_backend("pytorch", device=device)
            # Trigger exception if not compiled with device
            be.get_available_memory()
        except Exception as e:
            print(e)
            device = "cpu"
            # Fell back to CPU execution, so no GPU is in use anymore.
            args.use_gpu = False
            be.change_backend("pytorch", device=device)

    if args.use_mixed_precision:
        be.change_backend(
            backend_name=args.backend,
            float_dtype=be._array_backend.float16,
            complex_dtype=be._array_backend.complex64,
            int_dtype=be._array_backend.int16,
            device=device,
        )

    available_memory = be.get_available_memory() * be.device_count()
    if args.memory is None:
        args.memory = int(args.memory_scaling * available_memory)

    if args.orientations_uncertainty is not None:
        args.orientations_uncertainty = tuple(
            int(x) for x in args.orientations_uncertainty.split(",")
        )

    matching_data = MatchingData(
        target=target,
        template=template.data,
        target_mask=target_mask,
        template_mask=template_mask,
        invert_target=args.invert_target_contrast,
        rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
    )

    setup_filt = setup_filter
    if args.projection_matching:
        setup_filt = setup_projection_filter
        matching_data.set_matching_dimension(target_dim=0)

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
    matching_data.template_filter, matching_data.target_filter = setup_filt(
        args, template, target
    )

    splits, schedule = compute_schedule(args, matching_data, callback_class)

    n_splits = np.prod(list(splits.values()))
    target_split = ", ".join(
        ":".join(str(x) for x in axis) for axis in splits.items()
    )

    # Only count GPUs when a GPU backend actually runs on a CUDA device. For
    # CPU backends args.gpu_indices is still the raw comma-separated string.
    gpus_used = 0
    if args.backend in gpu_backends and device == "cuda":
        gpus_used = len(args.gpu_indices)

    options = {
        "Angular Sampling": f"{args.angular_sampling}"
        f" [{matching_data.rotations.shape[0]} rotations]",
        "Center Template": not args.no_centering,
        "Scramble Template": args.scramble_phases,
        "Invert Contrast": args.invert_target_contrast,
        "Extend Target Edges": args.pad_edges,
        "Interpolation Order": args.interpolation_order,
        "Setup Function": f"{get_func_fullname(matching_setup)}",
        "Scoring Function": f"{get_func_fullname(matching_score)}",
    }
    # Shared label width so all summary blocks are aligned identically.
    label_width = max(len(key) for key in options) + 3

    print_block(
        name="Template Matching",
        data=options,
        label_width=label_width,
    )

    compute_options = {
        "Backend": be._BACKEND_REGISTRY[be._backend_name],
        "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
        "Use Mixed Precision": args.use_mixed_precision,
        "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
        "Temporary Directory": args.temp_directory,
        "Target Splits": f"{target_split} [N={n_splits}]",
    }
    print_block(
        name="Computation",
        data=compute_options,
        label_width=label_width,
    )

    filter_args = {
        "Lowpass": args.lowpass,
        "Highpass": args.highpass,
        "Smooth Pass": args.no_pass_smooth,
        "Pass Format": args.pass_format,
        "Spectral Whitening": args.whiten_spectrum,
        "Wedge Axes": args.wedge_axes,
        "Tilt Angles": args.tilt_angles,
        "Tilt Weighting": args.tilt_weighting,
        "Reconstruction Filter": args.reconstruction_filter,
        "Extend Filter Grid": args.pad_filter,
    }
    if args.ctf_file is not None or args.defocus is not None:
        filter_args["CTF File"] = args.ctf_file
        filter_args["Defocus"] = args.defocus
        filter_args["Phase Shift"] = args.phase_shift
        filter_args["Flip Phase"] = args.no_flip_phase
        filter_args["Acceleration Voltage"] = args.acceleration_voltage
        filter_args["Spherical Aberration"] = args.spherical_aberration
        filter_args["Amplitude Contrast"] = args.amplitude_contrast
        filter_args["Correct Defocus"] = args.correct_defocus_gradient

    # Only report filters that were actually configured.
    filter_args = {k: v for k, v in filter_args.items() if v is not None}
    if filter_args:
        print_block(
            name="Filters",
            data=filter_args,
            label_width=label_width,
        )

    analyzer_args = {
        "score_threshold": args.score_threshold,
        # --num_peaks is declared without type=, so command-line values
        # arrive as strings while the default is an int.
        "num_peaks": int(args.num_peaks),
        "min_distance": max(template.shape) // 3,
        "use_memmap": args.use_memmap,
    }
    if args.orientations is not None:
        analyzer_args["reference"] = (0, 0, 1)
        analyzer_args["cone_angle"] = args.orientations_cone
        analyzer_args["acceptance_radius"] = args.orientations_uncertainty
        analyzer_args["positions"] = args.orientations.translations
        analyzer_args["rotations"] = euler_to_rotationmatrix(
            args.orientations.rotations
        )

    print_block(
        name="Analyzer",
        data={"Analyzer": callback_class, **analyzer_args},
        label_width=label_width,
    )
    print("\n" + "-" * 80)

    outer_jobs = f"{schedule[0]} job{'s' if schedule[0] > 1 else ''}"
    inner_jobs = f"{schedule[1]} core{'s' if schedule[1] > 1 else ''}"
    split_label = f"{n_splits} split{'s' if n_splits > 1 else ''}"
    print(f"\nDistributing {split_label} on {outer_jobs} each using {inner_jobs}.")

    start = time()
    print("Running Template Matching. This might take a while ...")
    candidates = scan_subsets(
        matching_data=matching_data,
        job_schedule=schedule,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args=analyzer_args,
        target_splits=splits,
        pad_target_edges=args.pad_edges,
        pad_template_filter=args.pad_filter,
        interpolation_order=args.interpolation_order,
        match_projection=args.projection_matching,
    )

    # Append the metadata needed to interpret the results alongside them.
    candidates = list(candidates) if candidates is not None else []
    candidates.append((target.origin, template.origin, template.sampling_rate, args))
    write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
    print("\n" + "-" * 80)
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")


if __name__ == "__main__":
    main()
|